Skip to main content

alkahest_cas/simplify/
egraph.rs

/// E-graph based simplifier using egglog.
///
/// Enabled only when the `egraph` feature is active.  Falls back to the
/// rule-based engine otherwise; see `simplify_egraph` for the stable
/// public entry point that is always available.
///
/// # Encoding strategy
///
/// Alkahest uses n-ary `Add`/`Mul`, but egglog works with fixed-arity
/// constructors.  We left-fold n-ary sums/products into binary trees for
/// submission, then *flatten* the extracted binary tree back to n-ary on
/// the way out (see `parse_egglog_term`).  Commutativity is handled at
/// construction time (children sorted by ExprId); associativity is not
/// added as a rule to avoid AC explosion — the phased schedule plus the
/// flattening round-trip is sufficient for practical inputs.
///
/// Only part of the expression language has an *exact* egglog encoding:
/// integers that fit in `i64`, symbols, `Add`/`Mul`/`Pow`, and the unary
/// functions `sin`, `cos`, `exp`, `log` and `sqrt`.  Inputs containing
/// anything else are handed to the rule-based engine rather than being
/// approximated (see `collect_encodable_symbols`).  Extracted `Var`s are
/// mapped back to the input's own symbol ids, so symbol domains survive
/// the round-trip.
///
/// # Schedule (RW-2)
///
/// The iteration counts and node/iteration limits are taken from
/// [`EgraphConfig`], allowing callers to trade completeness for bounded
/// run time on large inputs.
#[cfg(feature = "egraph")]
mod backend {
    use crate::kernel::{ExprData, ExprId, ExprPool};
    use std::collections::{HashMap, HashSet};
    use std::fmt::Write as _;

    // -----------------------------------------------------------------------
    // 0. Supported subset
    // -----------------------------------------------------------------------

    /// Unary kernel functions that have a dedicated egglog constructor, as
    /// `(kernel function name, egglog constructor)` pairs.
    const FUNC_CTORS: [(&str, &str); 5] = [
        ("sin", "Sin"),
        ("cos", "Cos"),
        ("exp", "Exp"),
        ("log", "Log"),
        ("sqrt", "Sqrt"),
    ];

    /// egglog constructor for the unary kernel function `func`, if any.
    fn egglog_ctor_for(func: &str) -> Option<&'static str> {
        FUNC_CTORS
            .iter()
            .find(|(name, _)| *name == func)
            .map(|&(_, ctor)| ctor)
    }

    /// Kernel function name for the unary egglog constructor `ctor`, if any.
    fn kernel_func_for(ctor: &str) -> Option<&'static str> {
        FUNC_CTORS
            .iter()
            .find(|(_, c)| *c == ctor)
            .map(|&(name, _)| name)
    }

    /// Check that `expr` has an exact egglog encoding and collect its symbols.
    ///
    /// Returns `None` if any reachable node cannot be represented faithfully:
    /// - rationals and floats;
    /// - integers outside `i64`;
    /// - empty `Add`/`Mul`;
    /// - functions other than the unary ones in `FUNC_CTORS`;
    /// - piecewise, predicate, quantifier and `BigO` nodes;
    /// - symbol names containing `"` or `\` (they would break the egglog
    ///   string literal);
    /// - two distinct symbols sharing a name (egglog's `Var` is keyed by name
    ///   alone, so they would be conflated).
    ///
    /// On success, returns a name → `ExprId` table that `parse_egglog_term`
    /// uses to map extracted `Var`s back to the original symbols.
    fn collect_encodable_symbols(
        expr: ExprId,
        pool: &ExprPool,
    ) -> Option<HashMap<String, ExprId>> {
        let mut symbols = HashMap::new();
        let mut visited = HashSet::new();
        let mut stack = vec![expr];
        while let Some(id) = stack.pop() {
            if !visited.insert(id) {
                continue;
            }
            match pool.get(id) {
                ExprData::Integer(n) => {
                    if n.0.to_i64().is_none() {
                        return None;
                    }
                }
                ExprData::Symbol { name, .. } => {
                    if name.contains(|c: char| c == '"' || c == '\\') {
                        return None;
                    }
                    // Each ExprId is visited once, so an existing entry means
                    // a *different* symbol (e.g. another domain) with this
                    // name.  Assumes the pool interns identical symbols.
                    if symbols.insert(name, id).is_some() {
                        return None;
                    }
                }
                ExprData::Add(args) | ExprData::Mul(args) => {
                    if args.is_empty() {
                        return None;
                    }
                    stack.extend(args.iter().copied());
                }
                ExprData::Pow { base, exp } => {
                    stack.push(base);
                    stack.push(exp);
                }
                ExprData::Func { name, args } => {
                    if args.len() != 1 || egglog_ctor_for(&name).is_none() {
                        return None;
                    }
                    stack.push(args[0]);
                }
                ExprData::Rational(_)
                | ExprData::Float(_)
                | ExprData::Piecewise { .. }
                | ExprData::Predicate { .. }
                | ExprData::Forall { .. }
                | ExprData::Exists { .. }
                | ExprData::BigO(_) => return None,
            }
        }
        Some(symbols)
    }

    // -----------------------------------------------------------------------
    // 1. Serialise ExprId → egglog expression string (binary left-fold)
    // -----------------------------------------------------------------------

    /// Encode `expr` as an egglog `Expr` term.
    ///
    /// n-ary `Add`/`Mul` are left-folded into binary constructors.  This
    /// encoder is total, but **not exact** outside the subset accepted by
    /// `collect_encodable_symbols`:
    /// - unsupported nodes are emitted as `(Num 0)`;
    /// - integers outside `i64` are saturated;
    /// - unknown unary functions become a `Var` that is not valid egglog.
    ///
    /// Callers that need value-preserving output must check that subset first.
    ///
    /// # Panics
    /// If an `Add` or `Mul` node has no arguments (an `ExprPool` invariant).
    pub(super) fn expr_to_egglog(expr: ExprId, pool: &ExprPool) -> String {
        let mut out = String::new();
        write_egglog(expr, pool, &mut out);
        out
    }

    /// Append the egglog encoding of `expr` to `out` (see `expr_to_egglog`).
    fn write_egglog(expr: ExprId, pool: &ExprPool, out: &mut String) {
        match pool.get(expr) {
            ExprData::Integer(n) => {
                let v = n
                    .0
                    .to_i64()
                    .unwrap_or(if n.0 > 0 { i64::MAX } else { i64::MIN });
                // Writing into a String cannot fail.
                let _ = write!(out, "(Num {v})");
            }
            ExprData::Symbol { name, .. } => {
                let _ = write!(out, "(Var \"{name}\")");
            }
            ExprData::Add(args) => write_left_fold("Add", &args, pool, out),
            ExprData::Mul(args) => write_left_fold("Mul", &args, pool, out),
            ExprData::Pow { base, exp } => {
                out.push_str("(Pow ");
                write_egglog(base, pool, out);
                out.push(' ');
                write_egglog(exp, pool, out);
                out.push(')');
            }
            ExprData::Func { name, args } if args.len() == 1 => match egglog_ctor_for(&name) {
                Some(ctor) => {
                    out.push('(');
                    out.push_str(ctor);
                    out.push(' ');
                    write_egglog(args[0], pool, out);
                    out.push(')');
                }
                None => {
                    // Not round-trippable; kept for backward compatibility of
                    // this encoder's output.
                    let inner = expr_to_egglog(args[0], pool);
                    let _ = write!(out, "(Var \"{name}_{inner}\")");
                }
            },
            ExprData::Rational(_)
            | ExprData::Float(_)
            | ExprData::Func { .. }
            | ExprData::Piecewise { .. }
            | ExprData::Predicate { .. }
            | ExprData::Forall { .. }
            | ExprData::Exists { .. }
            | ExprData::BigO(_) => out.push_str("(Num 0)"),
        }
    }

    /// Append `args` left-folded under the binary constructor `ctor`.
    ///
    /// `[a, b, c]` becomes `(ctor (ctor a b) c)`; a single argument is
    /// emitted bare.  Writes into one buffer, so the cost is linear in the
    /// output size rather than quadratic in the argument count.
    fn write_left_fold(ctor: &str, args: &[ExprId], pool: &ExprPool, out: &mut String) {
        let (first, rest) = match args.split_first() {
            Some(split) => split,
            None => panic!(
                "{ctor} node must have at least one argument — ExprPool invariant violated"
            ),
        };
        // All opening constructors come first, then the operands, each
        // followed by the paren that closes one fold level.
        for _ in rest {
            out.push('(');
            out.push_str(ctor);
            out.push(' ');
        }
        write_egglog(*first, pool, out);
        for &arg in rest {
            out.push(' ');
            write_egglog(arg, pool, out);
            out.push(')');
        }
    }

    // -----------------------------------------------------------------------
    // 2. Build the complete egglog program  (RW-2: uses EgraphConfig)
    // -----------------------------------------------------------------------

    /// Count unique nodes in the expression DAG.
    ///
    /// Used to enforce `EgraphConfig::node_limit` before handing the
    /// expression to egglog, preventing OOM on pathological inputs.  Uses an
    /// explicit stack, so very deep expressions cannot overflow the call
    /// stack here.
    fn count_dag_nodes(expr: ExprId, pool: &ExprPool) -> usize {
        let mut visited = HashSet::new();
        let mut stack = vec![expr];
        while let Some(id) = stack.pop() {
            if !visited.insert(id) {
                continue;
            }
            match pool.get(id) {
                ExprData::Add(args) | ExprData::Mul(args) => stack.extend(args.iter().copied()),
                ExprData::Pow { base, exp } => {
                    stack.push(base);
                    stack.push(exp);
                }
                ExprData::Func { args, .. } => stack.extend(args.iter().copied()),
                ExprData::Piecewise { branches, default } => {
                    for (cond, val) in &branches {
                        stack.push(*cond);
                        stack.push(*val);
                    }
                    stack.push(default);
                }
                ExprData::Predicate { args, .. } => stack.extend(args),
                ExprData::Forall { var, body } | ExprData::Exists { var, body } => {
                    stack.push(var);
                    stack.push(body);
                }
                ExprData::BigO(arg) => stack.push(arg),
                // Leaf nodes
                ExprData::Integer(_)
                | ExprData::Rational(_)
                | ExprData::Float(_)
                | ExprData::Symbol { .. } => {}
            }
        }
        visited.len()
    }

    /// Build the complete egglog program for the encoded expression `expr_str`.
    ///
    /// The program does the following, in order:
    /// 1. Declares the `Expr` datatype.
    /// 2. Declares the `shrink`, `explore` and `const-fold` rulesets.  The
    ///    trig and log/exp rules are included only if `config` enables them.
    /// 3. Runs the phased schedule shrink → const-fold → explore → shrink →
    ///    const-fold, with iteration counts from `config`.
    /// 4. Extracts `__expr`.
    fn egglog_program(expr_str: &str, config: &super::EgraphConfig) -> String {
        // node_limit is enforced as a pre-saturation DAG-size check in
        // simplify_egraph_impl; egglog 0.4 does not expose a per-run node cap.
        let iter_limit_line = config
            .iter_limit
            .map(|n| format!("(set-option iteration_limit {n})\n"))
            .unwrap_or_default();

        let si = config.shrink_iters;
        let ei = config.explore_iters;
        let ci = config.const_fold_iters;

        // Conditionally include trig / log-exp rules based on config flags.
        let trig_rules = if config.include_trig_rules {
            // Both Mul form (sin(x)*sin(x)) and Pow form (sin(x)^2) are matched
            // so the identity fires regardless of how the square is represented.
            "(rewrite (Add (Mul (Sin ?x) (Sin ?x)) (Mul (Cos ?x) (Cos ?x))) (Num 1) :ruleset explore)\n\
             (rewrite (Add (Mul (Cos ?x) (Cos ?x)) (Mul (Sin ?x) (Sin ?x))) (Num 1) :ruleset explore)\n\
             (rewrite (Add (Pow (Sin ?x) (Num 2)) (Pow (Cos ?x) (Num 2))) (Num 1) :ruleset explore)\n\
             (rewrite (Add (Pow (Cos ?x) (Num 2)) (Pow (Sin ?x) (Num 2))) (Num 1) :ruleset explore)"
        } else {
            ""
        };

        let log_exp_rules = if config.include_log_exp_rules {
            "(rewrite (Exp (Log ?x)) ?x :ruleset explore)\n\
             (rewrite (Log (Exp ?x)) ?x :ruleset explore)"
        } else {
            ""
        };

        // NOTE(review): integer folds use egglog's checked `+`/`*`; results
        // outside i64 rely on how egglog reports a failed primitive in an
        // action (a run error makes the caller keep the input) — confirm.
        format!(
            r#"
{iter_limit_line}(datatype Expr
  (Num i64)
  (Var String)
  (Add Expr Expr)
  (Mul Expr Expr)
  (Pow Expr Expr)
  (Sin Expr)
  (Cos Expr)
  (Exp Expr)
  (Log Expr)
  (Sqrt Expr))

; ── shrink ruleset: identity / absorption / cancellation ─────────────────────
(ruleset shrink)
(rewrite (Add ?x (Num 0)) ?x :ruleset shrink)
(rewrite (Add (Num 0) ?x) ?x :ruleset shrink)
(rewrite (Mul ?x (Num 1)) ?x :ruleset shrink)
(rewrite (Mul (Num 1) ?x) ?x :ruleset shrink)
(rewrite (Mul ?x (Num 0)) (Num 0) :ruleset shrink)
(rewrite (Mul (Num 0) ?x) (Num 0) :ruleset shrink)
(rewrite (Pow ?x (Num 1)) ?x :ruleset shrink)
(rewrite (Pow ?x (Num 0)) (Num 1) :ruleset shrink)
(rewrite (Add ?x (Mul (Num -1) ?x)) (Num 0) :ruleset shrink)
(rewrite (Add (Mul (Num -1) ?x) ?x) (Num 0) :ruleset shrink)
(rewrite (Mul ?x (Pow ?x (Num -1))) (Num 1) :ruleset shrink)
(rewrite (Mul (Pow ?x (Num -1)) ?x) (Num 1) :ruleset shrink)

; ── explore ruleset: trig and log/exp identities (default: both enabled) ──────
(ruleset explore)
{trig_rules}
{log_exp_rules}
(rewrite (Mul (Num -1) (Mul (Num -1) ?x)) ?x :ruleset explore)

; ── constant folding ──────────────────────────────────────────────────────────
(ruleset const-fold)
(rule ((= e (Add (Num ?a) (Num ?b))))
      ((union e (Num (+ ?a ?b))))
      :ruleset const-fold)
(rule ((= e (Mul (Num ?a) (Num ?b))))
      ((union e (Num (* ?a ?b))))
      :ruleset const-fold)
; Integer powers.  egglog's i64 `^` is bitwise XOR, not exponentiation, so
; powers are folded by halving the exponent, a^b = a^(b/2) * a^(b - b/2),
; down to a^1 = a and a^0 = 1; the Mul rule above multiplies the pieces.
(rule ((= e (Pow (Num ?a) (Num 0))))
      ((union e (Num 1)))
      :ruleset const-fold)
(rule ((= e (Pow (Num ?a) (Num 1))))
      ((union e (Num ?a)))
      :ruleset const-fold)
(rule ((= e (Pow (Num ?a) (Num ?b))) (>= ?b 2))
      ((union e (Mul (Pow (Num ?a) (Num (/ ?b 2)))
                     (Pow (Num ?a) (Num (- ?b (/ ?b 2)))))))
      :ruleset const-fold)

; ── phased schedule: shrink → const-fold → explore → shrink → const-fold ─────
(let __expr {expr_str})
(run shrink {si})
(run const-fold {ci})
(run explore {ei})
(run shrink {si})
(run const-fold {ci})
(extract __expr)
"#
        )
    }

    // -----------------------------------------------------------------------
    // 3. Parse egglog output back to ExprId  (RW-1: flatten binary → n-ary)
    // -----------------------------------------------------------------------

    /// Collect all top-level Add children, recursively flattening nested Adds.
    fn flatten_add_args(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
        match pool.get(expr) {
            ExprData::Add(args) => args
                .iter()
                .flat_map(|&a| flatten_add_args(a, pool))
                .collect(),
            _ => vec![expr],
        }
    }

    /// Collect all top-level Mul children, recursively flattening nested Muls.
    fn flatten_mul_args(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
        match pool.get(expr) {
            ExprData::Mul(args) => args
                .iter()
                .flat_map(|&a| flatten_mul_args(a, pool))
                .collect(),
            _ => vec![expr],
        }
    }

    /// Parse an extracted egglog term back into `pool`.
    ///
    /// `Var` names are resolved through `symbols` (built from the input by
    /// `collect_encodable_symbols`), so the original symbol — including its
    /// domain — is reused.  Returns `None` for malformed terms, unknown
    /// constructors or unknown variable names.
    fn parse_egglog_term(
        s: &str,
        pool: &ExprPool,
        symbols: &HashMap<String, ExprId>,
    ) -> Option<ExprId> {
        let s = s.trim();
        if !(s.starts_with('(') && s.ends_with(')')) {
            // Bare integer literal.
            return s.parse::<i64>().ok().map(|n| pool.integer(n));
        }
        let (head, rest) = split_head(&s[1..s.len() - 1])?;
        match head {
            "Num" => rest.trim().parse::<i64>().ok().map(|n| pool.integer(n)),
            "Var" => symbols.get(rest.trim().trim_matches('"')).copied(),
            "Add" => {
                let (a, b) = parse_two_args(rest, pool, symbols)?;
                // RW-1: flatten binary tree back to n-ary on the way out.
                let mut children = flatten_add_args(a, pool);
                children.extend(flatten_add_args(b, pool));
                Some(pool.add(children))
            }
            "Mul" => {
                let (a, b) = parse_two_args(rest, pool, symbols)?;
                let mut children = flatten_mul_args(a, pool);
                children.extend(flatten_mul_args(b, pool));
                Some(pool.mul(children))
            }
            "Pow" => {
                let (base, exp) = parse_two_args(rest, pool, symbols)?;
                Some(pool.pow(base, exp))
            }
            ctor => {
                let func = kernel_func_for(ctor)?;
                let arg = parse_egglog_term(rest, pool, symbols)?;
                Some(pool.func(func, vec![arg]))
            }
        }
    }

    /// Parse the two operand terms of a binary constructor.
    fn parse_two_args(
        rest: &str,
        pool: &ExprPool,
        symbols: &HashMap<String, ExprId>,
    ) -> Option<(ExprId, ExprId)> {
        let (a, b) = split_two_args(rest)?;
        Some((
            parse_egglog_term(a, pool, symbols)?,
            parse_egglog_term(b, pool, symbols)?,
        ))
    }

    /// Split `"Head rest…"` at the first whitespace into `("Head", "rest…")`.
    fn split_head(s: &str) -> Option<(&str, &str)> {
        s.trim().split_once(char::is_whitespace)
    }

    /// Split two space-separated terms; the second is whatever follows the first.
    fn split_two_args(s: &str) -> Option<(&str, &str)> {
        let (first, remainder) = consume_term(s.trim())?;
        Some((first, remainder.trim()))
    }

    /// Split off the first term of `s`: either a balanced parenthesised term
    /// (parentheses inside `"…"` strings are ignored) or a bare token.
    fn consume_term(s: &str) -> Option<(&str, &str)> {
        let s = s.trim_start();
        if s.starts_with('(') {
            let mut depth = 0usize;
            let mut in_string = false;
            for (i, c) in s.char_indices() {
                match c {
                    '"' => in_string = !in_string,
                    '(' if !in_string => depth += 1,
                    ')' if !in_string => {
                        depth -= 1;
                        if depth == 0 {
                            return Some((&s[..=i], &s[i + 1..]));
                        }
                    }
                    _ => {}
                }
            }
            None
        } else {
            let end = s
                .find(|c: char| c.is_whitespace() || c == ')')
                .unwrap_or(s.len());
            Some((&s[..end], &s[end..]))
        }
    }

    // -----------------------------------------------------------------------
    // RW-3: Linear-expression canonizer (post-extraction pass)
    // -----------------------------------------------------------------------

    /// Try to extract a linear term as `(integer_coefficient, base_expr)`.
    ///
    /// Recognises: bare symbols (coeff = 1) and `Mul(Integer, Symbol)`.
    fn extract_linear_term(expr: ExprId, pool: &ExprPool) -> Option<(i64, ExprId)> {
        match pool.get(expr) {
            ExprData::Symbol { .. } => Some((1, expr)),
            ExprData::Mul(args) if args.len() == 2 => {
                let (a, b) = (args[0], args[1]);
                if let ExprData::Integer(n) = pool.get(a) {
                    if matches!(pool.get(b), ExprData::Symbol { .. }) {
                        return n.0.to_i64().map(|c| (c, b));
                    }
                }
                if let ExprData::Integer(n) = pool.get(b) {
                    if matches!(pool.get(a), ExprData::Symbol { .. }) {
                        return n.0.to_i64().map(|c| (c, a));
                    }
                }
                None
            }
            _ => None,
        }
    }

    /// Canonicalize linear combinations in an expression.
    ///
    /// At each `Add` node, collects `(coefficient, symbol)` pairs and sums
    /// coefficients for identical bases, eliminating zero terms.  If summing
    /// a coefficient would overflow `i64`, that `Add` is left uncombined
    /// (its children are still canonicalized).
    ///
    /// Example: `2*x + 3*x + y` → `5*x + y`.
    pub(super) fn canonicalize_linear(expr: ExprId, pool: &ExprPool) -> ExprId {
        match pool.get(expr) {
            ExprData::Add(args) => {
                let args: Vec<ExprId> =
                    args.iter().map(|&a| canonicalize_linear(a, pool)).collect();

                let mut coeff_map: HashMap<ExprId, i64> = HashMap::new();
                let mut non_linear: Vec<ExprId> = Vec::new();
                let mut found_linear = false;
                let mut overflowed = false;

                for &arg in &args {
                    match extract_linear_term(arg, pool) {
                        Some((coeff, base)) => {
                            let total = coeff_map.entry(base).or_insert(0);
                            match total.checked_add(coeff) {
                                Some(sum) => *total = sum,
                                None => {
                                    overflowed = true;
                                    break;
                                }
                            }
                            found_linear = true;
                        }
                        None => non_linear.push(arg),
                    }
                }

                // Nothing to combine, or combining would overflow i64: keep
                // the sum as-is rather than panic or wrap.
                if overflowed || !found_linear {
                    return pool.add(args);
                }

                let mut result: Vec<ExprId> = non_linear;
                // Sort by key for determinism
                let mut pairs: Vec<(ExprId, i64)> = coeff_map.into_iter().collect();
                pairs.sort_by_key(|(id, _)| *id);
                for (base, coeff) in pairs {
                    match coeff {
                        0 => {}
                        1 => result.push(base),
                        c => result.push(pool.mul(vec![pool.integer(c), base])),
                    }
                }

                match result.len() {
                    0 => pool.integer(0_i32),
                    1 => result[0],
                    _ => pool.add(result),
                }
            }
            ExprData::Mul(args) => {
                let args: Vec<ExprId> =
                    args.iter().map(|&a| canonicalize_linear(a, pool)).collect();
                pool.mul(args)
            }
            ExprData::Pow { base, exp } => {
                let base = canonicalize_linear(base, pool);
                let exp = canonicalize_linear(exp, pool);
                pool.pow(base, exp)
            }
            ExprData::Func { name, args } => {
                let args: Vec<ExprId> =
                    args.iter().map(|&a| canonicalize_linear(a, pool)).collect();
                pool.func(&name, args)
            }
            _ => expr,
        }
    }

    // -----------------------------------------------------------------------
    // 4. Public implementation
    // -----------------------------------------------------------------------

    /// Run `program` in a fresh e-graph and parse its last output (the
    /// `extract` result) back into `pool`.
    ///
    /// Returns `None` on any egglog error or if the output cannot be parsed.
    fn run_egglog(
        program: &str,
        pool: &ExprPool,
        symbols: &HashMap<String, ExprId>,
    ) -> Option<ExprId> {
        let mut egraph = egglog::EGraph::default();
        let outputs = egraph.parse_and_run_program(None, program).ok()?;
        let term = outputs.into_iter().last()?;
        parse_egglog_term(&term, pool, symbols)
    }

    /// Simplify `expr` by equality saturation under `config`.
    ///
    /// Routing:
    /// - Expressions with a non-commutative symbol, or with anything lacking
    ///   an exact egglog encoding (see `collect_encodable_symbols`), go to
    ///   the rule-based engine.
    /// - Expressions whose DAG has more than `config.node_limit` unique nodes
    ///   are returned unchanged, with an `egraph_node_limit_exceeded` log
    ///   step.
    /// - Otherwise the egglog program is run.  If egglog or the parser fails,
    ///   `expr` itself is used.  The linear canonizer is applied to the
    ///   result.
    pub fn simplify_egraph_impl(
        expr: ExprId,
        pool: &ExprPool,
        config: &super::EgraphConfig,
    ) -> crate::deriv::log::DerivedExpr<ExprId> {
        use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
        use crate::kernel::expr_props::expr_contains_noncommutative_symbol;

        if expr_contains_noncommutative_symbol(pool, expr) {
            return super::super::engine::simplify(expr, pool);
        }

        // Enforce the node limit before handing the expression to egglog.
        // Saturation can materialise exponentially many equivalent forms, so a
        // hard pre-check on input size prevents OOM on pathological inputs.
        if let Some(limit) = config.node_limit {
            if count_dag_nodes(expr, pool) > limit {
                let mut log = DerivationLog::new();
                log.push(RewriteStep::simple(
                    "egraph_node_limit_exceeded",
                    expr,
                    expr,
                ));
                return DerivedExpr::with_log(expr, log);
            }
        }

        // Encoding an unsupported node approximately (e.g. a rational as
        // `(Num 0)`) would silently change the expression's value, so such
        // inputs go to the rule engine instead.
        let symbols = match collect_encodable_symbols(expr, pool) {
            Some(symbols) => symbols,
            None => return super::super::engine::simplify(expr, pool),
        };

        let program = egglog_program(&expr_to_egglog(expr, pool), config);
        let simplified = run_egglog(&program, pool, &symbols).unwrap_or(expr);
        // RW-3: apply linear canonizer as a post-extraction pass.
        let simplified = canonicalize_linear(simplified, pool);

        let mut log = DerivationLog::new();
        if simplified != expr {
            log.push(RewriteStep::simple("egraph_simplify", expr, simplified));
        }
        DerivedExpr::with_log(simplified, log)
    }
}
533
534// ---------------------------------------------------------------------------
535// PA-6 / RW-4 — Pluggable e-graph cost functions
536// ---------------------------------------------------------------------------
537
538use crate::deriv::log::DerivedExpr;
539use crate::kernel::{ExprId, ExprPool};
540
/// Cost model for choosing among equivalent expressions in the e-graph.
///
/// A cost model scores one node from its operator name and the
/// already-computed costs of its children.  Lower totals are preferred.
/// Implement this trait to define custom extraction objectives.
///
/// **Status:** a cost passed to [`simplify_egraph_with`] is currently
/// accepted but not consulted.  Pluggable extraction needs an egglog API
/// that does not exist yet, so egglog's default extraction is used.
///
/// # Built-in implementations
///
/// | Type | Description |
/// |------|-------------|
/// | [`SizeCost`] | Every node costs 1 (tree size); matches egglog's default. |
/// | [`OpCost`]   | Operators weighted by evaluation cost. |
/// | [`DepthCost`]| Cost = max child depth + 1. |
/// | [`StabilityCost`] | Penalises catastrophic cancellation. |
/// | [`NoncommutativeCost`] | Tie-break for non-commutative `Mul` chains (V3-2). |
pub trait EgraphCost: Send + Sync {
    /// Compute the cost of a node.
    ///
    /// `op` is the egglog constructor name (for example `"Add"`, `"Num"` or
    /// `"Sin"`).  `child_costs` holds the costs of the node's children, in
    /// argument order.
    fn cost(&self, op: &str, child_costs: &[f64]) -> f64;
}
559
560/// Every node costs 1 (tree-size cost). This is the egglog default.
561pub struct SizeCost;
562impl EgraphCost for SizeCost {
563    fn cost(&self, _op: &str, child_costs: &[f64]) -> f64 {
564        1.0 + child_costs.iter().sum::<f64>()
565    }
566}
567
568/// Operators weighted by their numerical evaluation cost.
569pub struct OpCost;
570impl EgraphCost for OpCost {
571    fn cost(&self, op: &str, child_costs: &[f64]) -> f64 {
572        let w = match op {
573            "Num" | "Var" => 0.1,
574            "Add" => 1.0,
575            "Mul" => 1.5,
576            "Pow" => 3.0,
577            "Sin" | "Cos" | "Exp" | "Log" | "Sqrt" => 5.0,
578            _ => 2.0,
579        };
580        w + child_costs.iter().sum::<f64>()
581    }
582}
583
584/// Cost = max child depth + 1.
585///
586/// Minimises the critical-path length; useful for GPU / parallel evaluation
587/// where depth determines the number of synchronisation barriers.
588pub struct DepthCost;
589impl EgraphCost for DepthCost {
590    fn cost(&self, _op: &str, child_costs: &[f64]) -> f64 {
591        1.0 + child_costs.iter().cloned().fold(0.0_f64, f64::max)
592    }
593}
594
595/// Penalises catastrophic cancellation.
596///
597/// Applies a `3×` multiplier to binary `Add`/`Sub` nodes whose both children
598/// have non-trivial cost (i.e. not a bare literal), discouraging expressions
599/// of the form `large_expr - large_expr` in favour of Horner form or
600/// log-sum-exp style rewrites.
601pub struct StabilityCost;
602impl EgraphCost for StabilityCost {
603    fn cost(&self, op: &str, child_costs: &[f64]) -> f64 {
604        let base = 1.0 + child_costs.iter().sum::<f64>();
605        match op {
606            // Penalise binary add/sub between two non-trivial children.
607            "Add" | "Sub"
608                if child_costs.len() == 2 && child_costs[0] > 1.0 && child_costs[1] > 1.0 =>
609            {
610                base * 3.0
611            }
612            "Pow" => base * 2.0,
613            _ => base,
614        }
615    }
616}
617
618/// Extraction cost biased toward **left-to-right** (`Mul`) products (V3-2).
619///
620/// When egglog gains a fully pluggable extractor, this can rank
621/// normal-ordered operator strings (Pauli / Clifford) lower than scrambled
622/// permutations. Today it adds a small tie-break on `Mul` so experiments
623/// with non-commuting `Var` encodings stay deterministic.
624pub struct NoncommutativeCost;
625impl EgraphCost for NoncommutativeCost {
626    fn cost(&self, op: &str, child_costs: &[f64]) -> f64 {
627        let base = SizeCost.cost(op, child_costs);
628        match op {
629            "Mul" => base + 1.0e-6 * child_costs.len() as f64,
630            _ => base,
631        }
632    }
633}
634
635// ---------------------------------------------------------------------------
636// PA-6 — Schedule configuration  (RW-2: node_limit / iter_limit)
637// ---------------------------------------------------------------------------
638
/// Configuration for the e-graph schedule and extraction strategy.
///
/// Pass to [`simplify_egraph_with`] to customise iteration counts and
/// resource limits.  When the `egraph` feature is disabled the
/// configuration is accepted but ignored.
///
/// # Rule flags
///
/// By default both `include_trig_rules` and `include_log_exp_rules` are `true`,
/// so `simplify_egraph` reduces `sin²(x)+cos²(x)→1` and `exp(log(x))→x`
/// without any extra configuration.  Set either flag to `false` to suppress
/// the corresponding rule set (useful when you need to benchmark rule impact or
/// avoid domain-sensitive rewrites).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EgraphConfig {
    /// Saturation iterations in each *shrinking* phase.  The schedule runs
    /// two shrinking phases.  Default 5.
    pub shrink_iters: usize,
    /// Saturation iterations in the *exploring* phase. Default 3.
    pub explore_iters: usize,
    /// Constant-folding iterations run after each shrinking phase. Default 3.
    pub const_fold_iters: usize,
    /// Skip saturation entirely when the input expression's DAG has more
    /// than this many unique nodes; the expression is then returned
    /// unchanged.  This is a pre-check on the *input* size, not a cap on the
    /// e-graph during saturation.  `None` = unlimited.
    pub node_limit: Option<usize>,
    /// Iteration cap emitted as `(set-option iteration_limit n)` at the top of
    /// the egglog program.  `None` = the option is not set.
    pub iter_limit: Option<usize>,
    /// Include the Pythagorean trig identity (`sin²+cos²→1`) in the explore phase.
    /// Default `true`.
    pub include_trig_rules: bool,
    /// Include exp/log cancellation (`exp(log(x))→x`, `log(exp(x))→x`) in the
    /// explore phase. Default `true`.
    pub include_log_exp_rules: bool,
}

impl Default for EgraphConfig {
    /// Balanced defaults: 5 shrink, 3 explore and 3 const-fold iterations, no
    /// limits, and both optional rule families enabled.
    fn default() -> Self {
        EgraphConfig {
            shrink_iters: 5,
            explore_iters: 3,
            const_fold_iters: 3,
            node_limit: None,
            iter_limit: None,
            include_trig_rules: true,
            include_log_exp_rules: true,
        }
    }
}
684
685// ---------------------------------------------------------------------------
686// Public entry points
687// ---------------------------------------------------------------------------
688
689/// Simplify `expr` using the e-graph backend with default settings.
690///
691/// Falls back to the rule-based simplifier when `egraph` feature is off.
692pub fn simplify_egraph(expr: ExprId, pool: &ExprPool) -> DerivedExpr<ExprId> {
693    #[cfg(feature = "egraph")]
694    {
695        backend::simplify_egraph_impl(expr, pool, &EgraphConfig::default())
696    }
697    #[cfg(not(feature = "egraph"))]
698    {
699        super::engine::simplify(expr, pool)
700    }
701}
702
703/// Simplify `expr` using the e-graph backend with a custom configuration.
704///
705/// The `cost` parameter documents the intended extraction preference; full
706/// pluggable-extractor support requires a future egglog API.  The config
707/// schedule limits (`node_limit`, `iter_limit`, phase iters) are wired
708/// into the egglog program today.
709pub fn simplify_egraph_with(
710    expr: ExprId,
711    pool: &ExprPool,
712    config: &EgraphConfig,
713    _cost: &dyn EgraphCost,
714) -> DerivedExpr<ExprId> {
715    #[cfg(feature = "egraph")]
716    {
717        backend::simplify_egraph_impl(expr, pool, config)
718    }
719    #[cfg(not(feature = "egraph"))]
720    {
721        let _ = config;
722        super::engine::simplify(expr, pool)
723    }
724}
725
726// ---------------------------------------------------------------------------
727// Tests
728// ---------------------------------------------------------------------------
729
#[cfg(test)]
mod tests {
    //! Unit tests for the e-graph simplifier front end (`simplify_egraph`,
    //! `simplify_egraph_with`), the linear canonizer, and the cost models.
    //! Tests that rely on rules only present in the egglog backend are gated
    //! on the `egraph` feature; the rest must also pass on the rule-based
    //! fallback.

    use super::*;
    use crate::kernel::{Domain, ExprData, ExprId, ExprPool};

    /// Builds `sin(x)^2 + cos(x)^2` in `pool`; shared by the trig-rule tests
    /// so the "enabled" and "opted-out" variants exercise the same input.
    fn sin_sq_plus_cos_sq(pool: &ExprPool, x: ExprId) -> ExprId {
        let sin2 = pool.pow(pool.func("sin", vec![x]), pool.integer(2_i32));
        let cos2 = pool.pow(pool.func("cos", vec![x]), pool.integer(2_i32));
        pool.add(vec![sin2, cos2])
    }

    #[test]
    fn egraph_simplify_x_plus_y_minus_x() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let neg_x = pool.mul(vec![pool.integer(-1_i32), x]);
        let expr = pool.add(vec![x, y, neg_x]);
        let result = simplify_egraph(expr, &pool);
        // Soundness guard: x + y - x equals y, so cancelling the whole sum
        // to zero would be an incorrect rewrite.  (No associativity rule is
        // registered, so reaching exactly `y` is not asserted here.)
        assert_ne!(result.value, pool.integer(0_i32), "should not be zero");
    }

    #[test]
    fn egraph_simplify_const_fold() {
        let pool = ExprPool::new();
        let expr = pool.add(vec![pool.integer(3_i32), pool.integer(4_i32)]);
        let result = simplify_egraph(expr, &pool);
        assert_eq!(result.value, pool.integer(7_i32));
    }

    #[test]
    fn egraph_simplify_add_zero() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![x, pool.integer(0_i32)]);
        let result = simplify_egraph(expr, &pool);
        assert_eq!(result.value, x);
    }

    #[test]
    fn egraph_simplify_mul_one() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.mul(vec![x, pool.integer(1_i32)]);
        let result = simplify_egraph(expr, &pool);
        assert_eq!(result.value, x);
    }

    #[test]
    fn egraph_simplify_mul_zero() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.mul(vec![x, pool.integer(0_i32)]);
        let result = simplify_egraph(expr, &pool);
        assert_eq!(result.value, pool.integer(0_i32));
    }

    #[test]
    fn egraph_fallback_no_panic_on_rational() {
        let pool = ExprPool::new();
        let r = pool.rational(1, 3);
        let _ = simplify_egraph(r, &pool);
    }

    // RW-1: flattening round-trip
    #[test]
    fn egraph_round_trips_nary_add() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let z = pool.symbol("z", Domain::Real);
        // x + y + z should survive the egglog round-trip as a 3-arg Add
        let expr = pool.add(vec![x, y, z]);
        let result = simplify_egraph(expr, &pool);
        // Must still be a single flat Add (not a nested binary tree).  A
        // non-Add result is a failure; an `if let` here would let it pass
        // without checking anything.
        match ExprPool::get(&pool, result.value) {
            ExprData::Add(args) => {
                assert_eq!(args.len(), 3, "Add should stay 3-ary after round-trip")
            }
            _ => panic!("expected x + y + z to round-trip as an Add"),
        }
    }

    // RW-3: linear canonizer
    #[test]
    fn linear_canonizer_combines_like_terms() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        // 2*x + 3*x = 5*x
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let three_x = pool.mul(vec![pool.integer(3_i32), x]);
        let expr = pool.add(vec![two_x, three_x]);
        #[cfg(feature = "egraph")]
        {
            let result = backend::canonicalize_linear(expr, &pool);
            let five_x = pool.mul(vec![pool.integer(5_i32), x]);
            assert_eq!(result, five_x);
        }
        // Without the backend there is nothing to call; silence unused-var.
        #[cfg(not(feature = "egraph"))]
        let _ = expr;
    }

    // RW-2: config wiring compiles and does not panic
    #[test]
    fn egraph_with_node_limit() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![x, pool.integer(0_i32)]);
        let config = EgraphConfig {
            node_limit: Some(10_000),
            ..EgraphConfig::default()
        };
        let result = simplify_egraph_with(expr, &pool, &config, &SizeCost);
        assert_eq!(result.value, x);
    }

    #[test]
    fn egraph_noncommutative_falls_back_to_rules() {
        let pool = ExprPool::new();
        let a = pool.symbol_commutative("A", Domain::Real, false);
        let expr = pool.add(vec![a, pool.integer(0_i32)]);
        let result = simplify_egraph(expr, &pool);
        assert_eq!(result.value, a);
    }

    // V3-2: NoncommutativeCost is callable
    #[test]
    fn noncommutative_cost_is_callable() {
        let nc = NoncommutativeCost;
        let v = nc.cost("Mul", &[1.0, 1.0]);
        assert!(v.is_finite());
    }

    // RW-4: StabilityCost is callable
    #[test]
    fn stability_cost_penalises_binary_add() {
        let sc = StabilityCost;
        let penalised = sc.cost("Add", &[2.0, 2.0]);
        let normal = sc.cost("Add", &[0.1, 2.0]);
        assert!(penalised > normal);
    }

    // V1-15: trig identity via Pow form (sin(x)^2 + cos(x)^2 → 1)
    #[test]
    fn egraph_trig_identity_pow_form() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = sin_sq_plus_cos_sq(&pool, x);
        #[cfg(feature = "egraph")]
        {
            let result = simplify_egraph(expr, &pool);
            assert_eq!(result.value, pool.integer(1_i32));
        }
        #[cfg(not(feature = "egraph"))]
        let _ = expr;
    }

    // V1-15: exp(log(x)) → x
    #[test]
    fn egraph_exp_of_log() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.func("exp", vec![pool.func("log", vec![x])]);
        #[cfg(feature = "egraph")]
        {
            let result = simplify_egraph(expr, &pool);
            assert_eq!(result.value, x);
        }
        #[cfg(not(feature = "egraph"))]
        let _ = expr;
    }

    // V1-15: log(exp(x)) → x
    #[test]
    fn egraph_log_of_exp() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.func("log", vec![pool.func("exp", vec![x])]);
        #[cfg(feature = "egraph")]
        {
            let result = simplify_egraph(expr, &pool);
            assert_eq!(result.value, x);
        }
        #[cfg(not(feature = "egraph"))]
        let _ = expr;
    }

    // V1-15: opt-out trig rules via config
    #[test]
    fn egraph_opt_out_trig_rules() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = sin_sq_plus_cos_sq(&pool, x);
        let config = EgraphConfig {
            include_trig_rules: false,
            ..EgraphConfig::default()
        };
        let result = simplify_egraph_with(expr, &pool, &config, &SizeCost);
        assert_ne!(result.value, pool.integer(1_i32));
    }

    // V1-15: opt-out log/exp rules via config
    #[test]
    fn egraph_opt_out_log_exp_rules() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.func("exp", vec![pool.func("log", vec![x])]);
        let config = EgraphConfig {
            include_log_exp_rules: false,
            ..EgraphConfig::default()
        };
        let result = simplify_egraph_with(expr, &pool, &config, &SizeCost);
        assert_ne!(result.value, x);
    }
}