mod cost;
mod transform;
use super::eval::{sample, GrammarOracle, SampleConfig};
use crate::grammar::Grammar;
use transform::{apply_candidate, enumerate_candidates, Candidate, CandidateKind};
/// Default for `MinimizeOptions::precision_budget`; `passes_gate` also falls
/// back to this value when the configured budget is not finite.
const DEFAULT_PRECISION_BUDGET: f64 = 0.0;
/// Default for `MinimizeOptions::sample_budget`.
const DEFAULT_SAMPLE_BUDGET: usize = 256;
/// Default for `MinimizeOptions::max_iterations`.
const DEFAULT_MAX_ITERATIONS: usize = 64;
/// Tolerance used when comparing cost deltas (`minimize`,
/// `ScoredCandidate::is_better_than`) and in the final precision comparison of
/// `passes_gate`.
const COST_EPSILON: f64 = 1e-9;
/// Two-part cost of a grammar as returned by [`mdl_cost`].
///
/// The field names suggest a description length measured in bits, split into
/// the part spent on the grammar and the part spent on the data.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Mdl {
    /// Grammar component of the cost.
    pub grammar_bits: f64,
    /// Data component of the cost.
    pub data_bits: f64,
}

impl Mdl {
    /// Returns the combined cost: `grammar_bits + data_bits`.
    #[must_use]
    pub fn total(self) -> f64 {
        let Self {
            grammar_bits,
            data_bits,
        } = self;
        grammar_bits + data_bits
    }
}
/// Tuning knobs for [`minimize`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MinimizeOptions {
/// Slack added to the measured precision in the gate check. Negative values
/// are clamped to `0.0` and non-finite values are replaced by the default
/// budget (see `passes_gate`).
pub precision_budget: f64,
/// Number of samples requested from a trial grammar by the gate check;
/// `0` skips the sampling part of the gate entirely.
pub sample_budget: usize,
/// Upper bound on the number of rounds `minimize` runs. The gate check also
/// passes this value (raised to at least 1) as the sampler's `max_depth`.
pub max_iterations: usize,
}
impl Default for MinimizeOptions {
    /// Builds the options from the module-level `DEFAULT_*` constants.
    fn default() -> Self {
        let precision_budget = DEFAULT_PRECISION_BUDGET;
        let sample_budget = DEFAULT_SAMPLE_BUDGET;
        let max_iterations = DEFAULT_MAX_ITERATIONS;
        Self {
            precision_budget,
            sample_budget,
            max_iterations,
        }
    }
}
/// Counters describing what [`minimize`] did; each one accumulates across all
/// rounds of a single call.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MinimizeReport {
/// Accepted candidates of kind `CandidateKind::Merge`.
pub merges_applied: usize,
/// Accepted candidates of kind `CandidateKind::Inline`.
pub inlines_applied: usize,
/// Accepted candidates of kind `CandidateKind::Factor`.
pub factorings_applied: usize,
/// Accepted candidates of kind `CandidateKind::Prune`.
pub prunes_applied: usize,
/// Candidates discarded because they left the grammar unchanged or did not
/// reduce the total cost by more than `COST_EPSILON`.
pub candidates_rejected_by_mdl: usize,
/// Candidates that reduced the cost but failed `passes_gate`.
pub candidates_rejected_by_gate: usize,
}
/// Outcome of a [`minimize`] call.
#[derive(Clone, Debug, PartialEq)]
pub struct MinimizeResult {
/// The grammar after the last accepted candidate (a clone of the input
/// grammar if nothing was accepted).
pub grammar: Grammar,
/// Cost of the input grammar, measured before any candidate was applied.
pub before: Mdl,
/// Cost of `grammar`, i.e. the cost recorded for the last accepted candidate
/// (equal to `before` if nothing was accepted).
pub after: Mdl,
/// Acceptance and rejection counters collected during the run.
pub report: MinimizeReport,
}
/// Computes the [`Mdl`] cost of `grammar` with respect to `examples`.
///
/// Public entry point that forwards unchanged to the private
/// `cost::mdl_cost` implementation.
#[must_use]
pub fn mdl_cost(grammar: &Grammar, examples: &[String]) -> Mdl {
cost::mdl_cost(grammar, examples)
}
/// Greedily rewrites `grammar` to lower its [`Mdl`] cost on `examples`.
///
/// The search runs for at most `opts.max_iterations` rounds. In each round
/// every candidate produced by `enumerate_candidates` for the current grammar
/// is applied to obtain a trial grammar, and the trial is kept in the running
/// only if
///
/// 1. it differs from the current grammar,
/// 2. its total cost is lower than the current cost by more than
///    `COST_EPSILON`, and
/// 3. it passes `passes_gate`, which is always evaluated against the
///    *original* `grammar`, not the current intermediate one.
///
/// The surviving candidate with the largest cost reduction is applied (ties
/// within `COST_EPSILON` go to the earliest candidate in enumeration order).
/// The search stops early when a round has no candidates or no survivor.
///
/// Returns the final grammar together with the cost before and after and a
/// [`MinimizeReport`] of accepted and rejected candidates.
#[must_use]
pub fn minimize(grammar: &Grammar, examples: &[String], opts: MinimizeOptions) -> MinimizeResult {
    let before = mdl_cost(grammar, examples);
    let mut current = grammar.clone();
    let mut current_cost = before;
    let mut report = MinimizeReport::default();
    for _ in 0..opts.max_iterations {
        let candidates = enumerate_candidates(&current);
        if candidates.is_empty() {
            break;
        }
        let mut best: Option<ScoredCandidate> = None;
        for (order, candidate) in candidates.into_iter().enumerate() {
            let trial = apply_candidate(&current, candidate);
            // A candidate that changes nothing cannot lower the cost, so it is
            // booked as an MDL rejection without recomputing the cost.
            if trial == current {
                report.candidates_rejected_by_mdl =
                    report.candidates_rejected_by_mdl.saturating_add(1);
                continue;
            }
            let trial_cost = mdl_cost(&trial, examples);
            let delta = trial_cost.total() - current_cost.total();
            // Every comparison with NaN is false, so a NaN delta (for example
            // infinite minus infinite cost) would slip past a plain `>=` test,
            // be treated as an improvement, and then shadow every later
            // candidate in `is_better_than`. Reject it explicitly.
            if delta.is_nan() || delta >= -COST_EPSILON {
                report.candidates_rejected_by_mdl =
                    report.candidates_rejected_by_mdl.saturating_add(1);
                continue;
            }
            // The gate compares against the caller's original grammar.
            if !passes_gate(grammar, &trial, examples, opts) {
                report.candidates_rejected_by_gate =
                    report.candidates_rejected_by_gate.saturating_add(1);
                continue;
            }
            let scored = ScoredCandidate {
                order,
                candidate,
                grammar: trial,
                cost: trial_cost,
                delta,
            };
            if best
                .as_ref()
                .map_or(true, |best| scored.is_better_than(best))
            {
                best = Some(scored);
            }
        }
        // No candidate survived this round: the current grammar is final.
        let Some(best) = best else {
            break;
        };
        record_acceptance(&mut report, best.candidate.kind());
        current = best.grammar;
        current_cost = best.cost;
    }
    MinimizeResult {
        grammar: current,
        before,
        after: current_cost,
        report,
    }
}
/// A candidate that passed every check in one round of `minimize`, bundled
/// with the data needed to compare it against other survivors and to apply it.
#[derive(Clone, Debug, PartialEq)]
struct ScoredCandidate {
/// Position of the candidate in the round's enumeration; used as tie-breaker.
order: usize,
/// The candidate itself.
candidate: Candidate,
/// The trial grammar obtained by applying `candidate`.
grammar: Grammar,
/// Cost of `grammar`.
cost: Mdl,
/// Total cost of `grammar` minus total cost of the grammar it was applied to
/// (more negative means a larger saving).
delta: f64,
}
impl ScoredCandidate {
    /// Reports whether `self` should replace `other` as the round's best.
    ///
    /// `self` wins outright when its delta is lower than `other`'s by more
    /// than `COST_EPSILON`. When the two deltas are within `COST_EPSILON` of
    /// each other, the candidate with the smaller `order` wins.
    fn is_better_than(&self, other: &Self) -> bool {
        if self.delta < other.delta - COST_EPSILON {
            return true;
        }
        let tied = (self.delta - other.delta).abs() <= COST_EPSILON;
        tied && self.order < other.order
    }
}
/// Bumps the `*_applied` counter in `report` that corresponds to `kind`.
///
/// The increment saturates instead of overflowing.
fn record_acceptance(report: &mut MinimizeReport, kind: CandidateKind) {
    let counter = match kind {
        CandidateKind::Merge => &mut report.merges_applied,
        CandidateKind::Inline => &mut report.inlines_applied,
        CandidateKind::Factor => &mut report.factorings_applied,
        CandidateKind::Prune => &mut report.prunes_applied,
    };
    *counter = counter.saturating_add(1);
}
/// Decides whether `trial` is an acceptable replacement relative to `baseline`.
///
/// The check has two parts:
///
/// 1. `trial` must accept every string in `examples`.
/// 2. Unless `opts.sample_budget` is `0` (which skips this part), strings are
///    sampled from `trial` and the fraction accepted by `baseline` is
///    measured. The gate passes when that fraction plus the precision budget
///    plus `COST_EPSILON` reaches `1.0`.
///
/// A sampling error or an empty sample set fails the gate. The precision
/// budget is `opts.precision_budget` clamped to be non-negative, or
/// `DEFAULT_PRECISION_BUDGET` when the configured value is not finite.
fn passes_gate(
    baseline: &Grammar,
    trial: &Grammar,
    examples: &[String],
    opts: MinimizeOptions,
) -> bool {
    let trial_oracle = GrammarOracle::new(trial);
    if examples.iter().any(|example| !trial_oracle.accepts(example)) {
        return false;
    }
    if opts.sample_budget == 0 {
        return true;
    }

    // NOTE(review): the sampler's `max_depth` is taken from `max_iterations`,
    // the round limit of `minimize`; confirm that coupling is intended.
    let config = SampleConfig {
        count: opts.sample_budget.max(1),
        max_depth: opts.max_iterations.max(1),
        ..SampleConfig::default()
    };
    let samples = match sample(trial, &config) {
        Ok(drawn) if !drawn.is_empty() => drawn,
        _ => return false,
    };

    let baseline_oracle = GrammarOracle::new(baseline);
    let accepted = samples
        .iter()
        .filter(|drawn| baseline_oracle.accepts(drawn))
        .count();
    let precision = ratio(accepted, samples.len());
    let budget = match opts.precision_budget {
        configured if configured.is_finite() => configured.max(0.0),
        _ => DEFAULT_PRECISION_BUDGET,
    };
    precision + budget + COST_EPSILON >= 1.0
}
/// Returns `numerator / denominator` as an `f64`.
///
/// Callers must pass a non-zero `denominator`; this is checked only in builds
/// with debug assertions enabled.
fn ratio(numerator: usize, denominator: usize) -> f64 {
    debug_assert!(denominator > 0);
    let top = usize_to_f64(numerator);
    let bottom = usize_to_f64(denominator);
    top / bottom
}

/// Converts a `usize` to `f64` with an `as` cast.
// The cast may round values that need more than 53 bits of mantissa; the lint
// is silenced here so the lossy conversion lives in exactly one place.
#[allow(clippy::cast_precision_loss)]
const fn usize_to_f64(value: usize) -> f64 {
    value as f64
}