1#[cfg(feature = "egraph")]
23mod backend {
24 use crate::kernel::{ExprData, ExprId, ExprPool};
25 use std::collections::HashMap;
26
27 pub(super) fn expr_to_egglog(expr: ExprId, pool: &ExprPool) -> String {
32 enum Node {
33 Num(i64),
34 Var(String),
35 Add(Vec<ExprId>),
36 Mul(Vec<ExprId>),
37 Pow(ExprId, ExprId),
38 Func(String, ExprId),
39 Unsupported,
40 }
41
42 let node = pool.with(expr, |data| match data {
43 ExprData::Integer(n) => {
44 let v =
45 n.0.to_i64()
46 .unwrap_or(if n.0 > 0 { i64::MAX } else { i64::MIN });
47 Node::Num(v)
48 }
49 ExprData::Rational(_) | ExprData::Float(_) => Node::Unsupported,
50 ExprData::Symbol { name, .. } => Node::Var(name.clone()),
51 ExprData::Add(args) => Node::Add(args.clone()),
52 ExprData::Mul(args) => Node::Mul(args.clone()),
53 ExprData::Pow { base, exp } => Node::Pow(*base, *exp),
54 ExprData::Func { name, args } if args.len() == 1 => Node::Func(name.clone(), args[0]),
55 ExprData::Func { .. } => Node::Unsupported,
56 ExprData::Piecewise { .. }
57 | ExprData::Predicate { .. }
58 | ExprData::Forall { .. }
59 | ExprData::Exists { .. }
60 | ExprData::BigO(_) => Node::Unsupported,
61 });
62
63 match node {
64 Node::Num(n) => format!("(Num {n})"),
65 Node::Var(name) => format!("(Var \"{name}\")"),
66 Node::Add(args) => {
67 let mut it = args.into_iter();
69 let first = it.next().expect(
70 "Add node must have at least one argument — ExprPool invariant violated",
71 );
72 let init = expr_to_egglog(first, pool);
73 it.fold(init, |acc, id| {
74 format!("(Add {acc} {})", expr_to_egglog(id, pool))
75 })
76 }
77 Node::Mul(args) => {
78 let mut it = args.into_iter();
79 let first = it.next().expect(
80 "Mul node must have at least one argument — ExprPool invariant violated",
81 );
82 let init = expr_to_egglog(first, pool);
83 it.fold(init, |acc, id| {
84 format!("(Mul {acc} {})", expr_to_egglog(id, pool))
85 })
86 }
87 Node::Pow(base, exp) => format!(
88 "(Pow {} {})",
89 expr_to_egglog(base, pool),
90 expr_to_egglog(exp, pool)
91 ),
92 Node::Func(name, arg) => {
93 let inner = expr_to_egglog(arg, pool);
94 match name.as_str() {
95 "sin" => format!("(Sin {inner})"),
96 "cos" => format!("(Cos {inner})"),
97 "exp" => format!("(Exp {inner})"),
98 "log" => format!("(Log {inner})"),
99 "sqrt" => format!("(Sqrt {inner})"),
100 _ => format!("(Var \"{name}_{inner}\")"),
101 }
102 }
103 Node::Unsupported => "(Num 0)".to_string(),
104 }
105 }
106
107 fn count_dag_nodes(expr: ExprId, pool: &ExprPool) -> usize {
116 let mut visited = std::collections::HashSet::new();
117 count_dag_nodes_rec(expr, pool, &mut visited);
118 visited.len()
119 }
120
121 fn count_dag_nodes_rec(
122 expr: ExprId,
123 pool: &ExprPool,
124 visited: &mut std::collections::HashSet<ExprId>,
125 ) {
126 if !visited.insert(expr) {
127 return;
128 }
129 match pool.get(expr) {
130 ExprData::Add(args) | ExprData::Mul(args) => {
131 for &a in &args {
132 count_dag_nodes_rec(a, pool, visited);
133 }
134 }
135 ExprData::Pow { base, exp } => {
136 count_dag_nodes_rec(base, pool, visited);
137 count_dag_nodes_rec(exp, pool, visited);
138 }
139 ExprData::Func { args, .. } => {
140 for &a in &args {
141 count_dag_nodes_rec(a, pool, visited);
142 }
143 }
144 ExprData::Piecewise { branches, default } => {
145 for (cond, val) in &branches {
146 count_dag_nodes_rec(*cond, pool, visited);
147 count_dag_nodes_rec(*val, pool, visited);
148 }
149 count_dag_nodes_rec(default, pool, visited);
150 }
151 ExprData::Predicate { args, .. } => {
152 for a in args {
153 count_dag_nodes_rec(a, pool, visited);
154 }
155 }
156 ExprData::Forall { var, body } | ExprData::Exists { var, body } => {
157 count_dag_nodes_rec(var, pool, visited);
158 count_dag_nodes_rec(body, pool, visited);
159 }
160 ExprData::BigO(arg) => {
161 count_dag_nodes_rec(arg, pool, visited);
162 }
163 ExprData::Integer(_)
165 | ExprData::Rational(_)
166 | ExprData::Float(_)
167 | ExprData::Symbol { .. } => {}
168 }
169 }
170
171 fn egglog_program(expr_str: &str, config: &super::EgraphConfig) -> String {
172 let node_limit_line = String::new();
175 let iter_limit_line = config
176 .iter_limit
177 .map(|n| format!("(set-option iteration_limit {n})\n"))
178 .unwrap_or_default();
179
180 let si = config.shrink_iters;
181 let ei = config.explore_iters;
182 let ci = config.const_fold_iters;
183
184 let trig_rules = if config.include_trig_rules {
186 "(rewrite (Add (Mul (Sin ?x) (Sin ?x)) (Mul (Cos ?x) (Cos ?x))) (Num 1) :ruleset explore)\n\
189 (rewrite (Add (Mul (Cos ?x) (Cos ?x)) (Mul (Sin ?x) (Sin ?x))) (Num 1) :ruleset explore)\n\
190 (rewrite (Add (Pow (Sin ?x) (Num 2)) (Pow (Cos ?x) (Num 2))) (Num 1) :ruleset explore)\n\
191 (rewrite (Add (Pow (Cos ?x) (Num 2)) (Pow (Sin ?x) (Num 2))) (Num 1) :ruleset explore)"
192 } else {
193 ""
194 };
195
196 let log_exp_rules = if config.include_log_exp_rules {
197 "(rewrite (Exp (Log ?x)) ?x :ruleset explore)\n\
198 (rewrite (Log (Exp ?x)) ?x :ruleset explore)"
199 } else {
200 ""
201 };
202
203 format!(
204 r#"
205{node_limit_line}{iter_limit_line}(datatype Expr
206 (Num i64)
207 (Var String)
208 (Add Expr Expr)
209 (Mul Expr Expr)
210 (Pow Expr Expr)
211 (Sin Expr)
212 (Cos Expr)
213 (Exp Expr)
214 (Log Expr)
215 (Sqrt Expr))
216
217; ── shrink ruleset: identity / absorption / cancellation ─────────────────────
218(ruleset shrink)
219(rewrite (Add ?x (Num 0)) ?x :ruleset shrink)
220(rewrite (Add (Num 0) ?x) ?x :ruleset shrink)
221(rewrite (Mul ?x (Num 1)) ?x :ruleset shrink)
222(rewrite (Mul (Num 1) ?x) ?x :ruleset shrink)
223(rewrite (Mul ?x (Num 0)) (Num 0) :ruleset shrink)
224(rewrite (Mul (Num 0) ?x) (Num 0) :ruleset shrink)
225(rewrite (Pow ?x (Num 1)) ?x :ruleset shrink)
226(rewrite (Pow ?x (Num 0)) (Num 1) :ruleset shrink)
227(rewrite (Add ?x (Mul (Num -1) ?x)) (Num 0) :ruleset shrink)
228(rewrite (Add (Mul (Num -1) ?x) ?x) (Num 0) :ruleset shrink)
229(rewrite (Mul ?x (Pow ?x (Num -1))) (Num 1) :ruleset shrink)
230(rewrite (Mul (Pow ?x (Num -1)) ?x) (Num 1) :ruleset shrink)
231
232; ── explore ruleset: trig and log/exp identities (default: both enabled) ──────
233(ruleset explore)
234{trig_rules}
235{log_exp_rules}
236(rewrite (Mul (Num -1) (Mul (Num -1) ?x)) ?x :ruleset explore)
237
238; ── constant folding ──────────────────────────────────────────────────────────
239(ruleset const-fold)
240(rule ((= e (Add (Num ?a) (Num ?b))))
241 ((union e (Num (+ ?a ?b))))
242 :ruleset const-fold)
243(rule ((= e (Mul (Num ?a) (Num ?b))))
244 ((union e (Num (* ?a ?b))))
245 :ruleset const-fold)
246(rule ((= e (Pow (Num ?a) (Num ?b))) (>= ?b 0))
247 ((union e (Num (^ ?a ?b))))
248 :ruleset const-fold)
249
250; ── phased schedule: shrink → const-fold → explore → shrink → const-fold ─────
251(let __expr {expr})
252(run shrink {si})
253(run const-fold {ci})
254(run explore {ei})
255(run shrink {si})
256(run const-fold {ci})
257(extract __expr)
258"#,
259 node_limit_line = node_limit_line,
260 iter_limit_line = iter_limit_line,
261 trig_rules = trig_rules,
262 log_exp_rules = log_exp_rules,
263 expr = expr_str,
264 si = si,
265 ei = ei,
266 ci = ci,
267 )
268 }
269
270 fn flatten_add_args(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
276 match pool.get(expr) {
277 ExprData::Add(args) => args
278 .iter()
279 .flat_map(|&a| flatten_add_args(a, pool))
280 .collect(),
281 _ => vec![expr],
282 }
283 }
284
285 fn flatten_mul_args(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
287 match pool.get(expr) {
288 ExprData::Mul(args) => args
289 .iter()
290 .flat_map(|&a| flatten_mul_args(a, pool))
291 .collect(),
292 _ => vec![expr],
293 }
294 }
295
296 fn parse_egglog_term(s: &str, pool: &ExprPool) -> Option<ExprId> {
297 let s = s.trim();
298 if s.starts_with('(') && s.ends_with(')') {
299 let inner = &s[1..s.len() - 1];
300 let (head, rest) = split_head(inner)?;
301 match head {
302 "Num" => {
303 let n: i64 = rest.trim().parse().ok()?;
304 Some(pool.integer(n))
305 }
306 "Var" => {
307 let name = rest.trim().trim_matches('"');
308 Some(pool.symbol(name, crate::kernel::Domain::Real))
309 }
310 "Add" => {
311 let (a_str, b_str) = split_two_args(rest)?;
312 let a = parse_egglog_term(&a_str, pool)?;
313 let b = parse_egglog_term(&b_str, pool)?;
314 let mut children = flatten_add_args(a, pool);
316 children.extend(flatten_add_args(b, pool));
317 Some(pool.add(children))
318 }
319 "Mul" => {
320 let (a_str, b_str) = split_two_args(rest)?;
321 let a = parse_egglog_term(&a_str, pool)?;
322 let b = parse_egglog_term(&b_str, pool)?;
323 let mut children = flatten_mul_args(a, pool);
324 children.extend(flatten_mul_args(b, pool));
325 Some(pool.mul(children))
326 }
327 "Pow" => {
328 let (a_str, b_str) = split_two_args(rest)?;
329 let a = parse_egglog_term(&a_str, pool)?;
330 let b = parse_egglog_term(&b_str, pool)?;
331 Some(pool.pow(a, b))
332 }
333 "Sin" => Some(pool.func("sin", vec![parse_egglog_term(rest.trim(), pool)?])),
334 "Cos" => Some(pool.func("cos", vec![parse_egglog_term(rest.trim(), pool)?])),
335 "Exp" => Some(pool.func("exp", vec![parse_egglog_term(rest.trim(), pool)?])),
336 "Log" => Some(pool.func("log", vec![parse_egglog_term(rest.trim(), pool)?])),
337 "Sqrt" => Some(pool.func("sqrt", vec![parse_egglog_term(rest.trim(), pool)?])),
338 _ => None,
339 }
340 } else {
341 let n: i64 = s.parse().ok()?;
342 Some(pool.integer(n))
343 }
344 }
345
/// Splits `s` at its first whitespace character into `(head, rest)`.
///
/// Leading and trailing whitespace of `s` is ignored. Only the single
/// separating character is removed, so `rest` may still start with
/// whitespace. Returns `None` when the trimmed input contains no whitespace.
///
/// The separator is removed by character rather than by a fixed one-byte
/// offset, so multi-byte Unicode whitespace cannot cause a slice panic.
fn split_head(s: &str) -> Option<(&str, &str)> {
    s.trim().split_once(char::is_whitespace)
}
351
352 fn split_two_args(s: &str) -> Option<(String, String)> {
353 let s = s.trim();
354 let (first, remainder) = consume_term(s)?;
355 let second = remainder.trim();
356 Some((first.to_string(), second.to_string()))
357 }
358
/// Consumes one term from the front of `s` and returns `(term, remainder)`.
///
/// After skipping leading whitespace, a term is either a balanced
/// parenthesised group or an atom that runs up to the next whitespace or `)`.
/// Parentheses between double quotes are not counted. Returns `None` when a
/// parenthesised group is never closed.
fn consume_term(s: &str) -> Option<(&str, &str)> {
    let text = s.trim_start();

    if !text.starts_with('(') {
        let cut = text
            .find(|c: char| c == ')' || c.is_whitespace())
            .unwrap_or(text.len());
        return Some(text.split_at(cut));
    }

    let mut open = 0usize;
    let mut quoted = false;
    for (idx, ch) in text.char_indices() {
        if ch == '"' {
            quoted = !quoted;
            continue;
        }
        if quoted {
            continue;
        }
        if ch == '(' {
            open += 1;
        } else if ch == ')' {
            open -= 1;
            if open == 0 {
                // `)` is one byte wide, so `idx + 1` is a char boundary.
                return Some(text.split_at(idx + 1));
            }
        }
    }
    None
}
385
386 fn extract_linear_term(expr: ExprId, pool: &ExprPool) -> Option<(i64, ExprId)> {
394 match pool.get(expr) {
395 ExprData::Symbol { .. } => Some((1, expr)),
396 ExprData::Mul(args) if args.len() == 2 => {
397 let (a, b) = (args[0], args[1]);
398 if let ExprData::Integer(n) = pool.get(a) {
399 if matches!(pool.get(b), ExprData::Symbol { .. }) {
400 return n.0.to_i64().map(|c| (c, b));
401 }
402 }
403 if let ExprData::Integer(n) = pool.get(b) {
404 if matches!(pool.get(a), ExprData::Symbol { .. }) {
405 return n.0.to_i64().map(|c| (c, a));
406 }
407 }
408 None
409 }
410 _ => None,
411 }
412 }
413
414 pub(super) fn canonicalize_linear(expr: ExprId, pool: &ExprPool) -> ExprId {
421 match pool.get(expr) {
422 ExprData::Add(args) => {
423 let args: Vec<ExprId> =
424 args.iter().map(|&a| canonicalize_linear(a, pool)).collect();
425
426 let mut coeff_map: HashMap<ExprId, i64> = HashMap::new();
427 let mut non_linear: Vec<ExprId> = Vec::new();
428 let mut found_linear = false;
429
430 for &arg in &args {
431 if let Some((coeff, base)) = extract_linear_term(arg, pool) {
432 *coeff_map.entry(base).or_insert(0) += coeff;
433 found_linear = true;
434 } else {
435 non_linear.push(arg);
436 }
437 }
438
439 if !found_linear {
440 return pool.add(args);
441 }
442
443 let mut result: Vec<ExprId> = non_linear;
444 let mut pairs: Vec<(ExprId, i64)> = coeff_map.into_iter().collect();
446 pairs.sort_by_key(|(id, _)| *id);
447 for (base, coeff) in pairs {
448 match coeff {
449 0 => {}
450 1 => result.push(base),
451 c => result.push(pool.mul(vec![pool.integer(c), base])),
452 }
453 }
454
455 match result.len() {
456 0 => pool.integer(0_i32),
457 1 => result[0],
458 _ => pool.add(result),
459 }
460 }
461 ExprData::Mul(args) => {
462 let args: Vec<ExprId> =
463 args.iter().map(|&a| canonicalize_linear(a, pool)).collect();
464 pool.mul(args)
465 }
466 ExprData::Pow { base, exp } => {
467 let base = canonicalize_linear(base, pool);
468 let exp = canonicalize_linear(exp, pool);
469 pool.pow(base, exp)
470 }
471 ExprData::Func { name, args } => {
472 let args: Vec<ExprId> =
473 args.iter().map(|&a| canonicalize_linear(a, pool)).collect();
474 pool.func(&name, args)
475 }
476 _ => expr,
477 }
478 }
479
480 pub fn simplify_egraph_impl(
485 expr: ExprId,
486 pool: &ExprPool,
487 config: &super::EgraphConfig,
488 ) -> crate::deriv::log::DerivedExpr<ExprId> {
489 use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
490 use crate::kernel::expr_props::expr_contains_noncommutative_symbol;
491
492 if expr_contains_noncommutative_symbol(pool, expr) {
493 return super::super::engine::simplify(expr, pool);
494 }
495
496 if let Some(limit) = config.node_limit {
500 let n = count_dag_nodes(expr, pool);
501 if n > limit {
502 let mut log = DerivationLog::new();
503 log.push(RewriteStep::simple(
504 "egraph_node_limit_exceeded",
505 expr,
506 expr,
507 ));
508 return DerivedExpr::with_log(expr, log);
509 }
510 }
511
512 let expr_str = expr_to_egglog(expr, pool);
513 let program = egglog_program(&expr_str, config);
514
515 let result: Option<ExprId> = (|| {
516 let mut egraph = egglog::EGraph::default();
517 let outputs = egraph.parse_and_run_program(None, &program).ok()?;
518 let term_str = outputs.into_iter().last()?;
519 parse_egglog_term(&term_str, pool)
520 })();
521
522 let simplified = result.unwrap_or(expr);
523 let simplified = canonicalize_linear(simplified, pool);
525
526 let mut log = DerivationLog::new();
527 if simplified != expr {
528 log.push(RewriteStep::simple("egraph_simplify", expr, simplified));
529 }
530 DerivedExpr::with_log(simplified, log)
531 }
532}
533
534use crate::deriv::log::DerivedExpr;
539use crate::kernel::{ExprId, ExprPool};
540
/// Cost model for e-graph terms.
///
/// Implementors must be `Send + Sync`.
pub trait EgraphCost: Send + Sync {
    /// Returns the cost of a node with operator name `op`, given the
    /// already-computed costs of its children in `child_costs`.
    fn cost(&self, op: &str, child_costs: &[f64]) -> f64;
}
559
560pub struct SizeCost;
562impl EgraphCost for SizeCost {
563 fn cost(&self, _op: &str, child_costs: &[f64]) -> f64 {
564 1.0 + child_costs.iter().sum::<f64>()
565 }
566}
567
568pub struct OpCost;
570impl EgraphCost for OpCost {
571 fn cost(&self, op: &str, child_costs: &[f64]) -> f64 {
572 let w = match op {
573 "Num" | "Var" => 0.1,
574 "Add" => 1.0,
575 "Mul" => 1.5,
576 "Pow" => 3.0,
577 "Sin" | "Cos" | "Exp" | "Log" | "Sqrt" => 5.0,
578 _ => 2.0,
579 };
580 w + child_costs.iter().sum::<f64>()
581 }
582}
583
584pub struct DepthCost;
589impl EgraphCost for DepthCost {
590 fn cost(&self, _op: &str, child_costs: &[f64]) -> f64 {
591 1.0 + child_costs.iter().cloned().fold(0.0_f64, f64::max)
592 }
593}
594
595pub struct StabilityCost;
602impl EgraphCost for StabilityCost {
603 fn cost(&self, op: &str, child_costs: &[f64]) -> f64 {
604 let base = 1.0 + child_costs.iter().sum::<f64>();
605 match op {
606 "Add" | "Sub"
608 if child_costs.len() == 2 && child_costs[0] > 1.0 && child_costs[1] > 1.0 =>
609 {
610 base * 3.0
611 }
612 "Pow" => base * 2.0,
613 _ => base,
614 }
615 }
616}
617
618pub struct NoncommutativeCost;
625impl EgraphCost for NoncommutativeCost {
626 fn cost(&self, op: &str, child_costs: &[f64]) -> f64 {
627 let base = SizeCost.cost(op, child_costs);
628 match op {
629 "Mul" => base + 1.0e-6 * child_costs.len() as f64,
630 _ => base,
631 }
632 }
633}
634
/// Tuning knobs for the e-graph simplifier.
#[derive(Debug, Clone)]
pub struct EgraphConfig {
    /// Iteration count for each run of the `shrink` ruleset.
    pub shrink_iters: usize,
    /// Iteration count for the run of the `explore` ruleset.
    pub explore_iters: usize,
    /// Iteration count for each run of the `const-fold` ruleset.
    pub const_fold_iters: usize,
    /// Largest input DAG size accepted; `None` disables the check.
    pub node_limit: Option<usize>,
    /// Value for egglog's `iteration_limit` option; `None` omits the option.
    pub iter_limit: Option<usize>,
    /// Whether the `sin² + cos² = 1` rewrites are included.
    pub include_trig_rules: bool,
    /// Whether the `exp(log x)` / `log(exp x)` rewrites are included.
    pub include_log_exp_rules: bool,
}

impl Default for EgraphConfig {
    /// Five shrink iterations, three explore and three const-fold iterations,
    /// no limits, and both optional rule groups enabled.
    fn default() -> Self {
        Self {
            shrink_iters: 5,
            explore_iters: 3,
            const_fold_iters: 3,
            node_limit: None,
            iter_limit: None,
            include_trig_rules: true,
            include_log_exp_rules: true,
        }
    }
}
684
685pub fn simplify_egraph(expr: ExprId, pool: &ExprPool) -> DerivedExpr<ExprId> {
693 #[cfg(feature = "egraph")]
694 {
695 backend::simplify_egraph_impl(expr, pool, &EgraphConfig::default())
696 }
697 #[cfg(not(feature = "egraph"))]
698 {
699 super::engine::simplify(expr, pool)
700 }
701}
702
703pub fn simplify_egraph_with(
710 expr: ExprId,
711 pool: &ExprPool,
712 config: &EgraphConfig,
713 _cost: &dyn EgraphCost,
714) -> DerivedExpr<ExprId> {
715 #[cfg(feature = "egraph")]
716 {
717 backend::simplify_egraph_impl(expr, pool, config)
718 }
719 #[cfg(not(feature = "egraph"))]
720 {
721 let _ = config;
722 super::engine::simplify(expr, pool)
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// `x + y - x` must not collapse to zero.
    #[test]
    fn egraph_simplify_x_plus_y_minus_x() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let neg_x = pool.mul(vec![pool.integer(-1_i32), x]);
        let expr = pool.add(vec![x, y, neg_x]);
        let result = simplify_egraph(expr, &pool);
        assert_ne!(result.value, pool.integer(0_i32), "should not be zero");
    }

    /// `3 + 4` folds to `7`.
    #[test]
    fn egraph_simplify_const_fold() {
        let pool = ExprPool::new();
        let expr = pool.add(vec![pool.integer(3_i32), pool.integer(4_i32)]);
        let result = simplify_egraph(expr, &pool);
        assert_eq!(result.value, pool.integer(7_i32));
    }

    /// `x + 0` simplifies to `x`.
    #[test]
    fn egraph_simplify_add_zero() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![x, pool.integer(0_i32)]);
        let result = simplify_egraph(expr, &pool);
        assert_eq!(result.value, x);
    }

    /// `x * 1` simplifies to `x`.
    #[test]
    fn egraph_simplify_mul_one() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.mul(vec![x, pool.integer(1_i32)]);
        let result = simplify_egraph(expr, &pool);
        assert_eq!(result.value, x);
    }

    /// `x * 0` simplifies to `0`.
    #[test]
    fn egraph_simplify_mul_zero() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.mul(vec![x, pool.integer(0_i32)]);
        let result = simplify_egraph(expr, &pool);
        assert_eq!(result.value, pool.integer(0_i32));
    }

    /// A rational input must not make the simplifier panic.
    #[test]
    fn egraph_fallback_no_panic_on_rational() {
        let pool = ExprPool::new();
        let r = pool.rational(1, 3);
        let _ = simplify_egraph(r, &pool);
    }

    /// A three-term sum must come back as a three-term `Add`.
    ///
    /// The assertion fails when the result is not an `Add` at all, instead of
    /// silently skipping the length check.
    #[test]
    fn egraph_round_trips_nary_add() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let z = pool.symbol("z", Domain::Real);
        let expr = pool.add(vec![x, y, z]);
        let result = simplify_egraph(expr, &pool);
        match crate::kernel::ExprPool::get(&pool, result.value) {
            crate::kernel::ExprData::Add(args) => assert_eq!(args.len(), 3),
            _ => panic!("expected x + y + z to remain an Add node"),
        }
    }

    /// `2x + 3x` is combined into `5x` by the linear canonicaliser.
    #[test]
    fn linear_canonizer_combines_like_terms() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let three_x = pool.mul(vec![pool.integer(3_i32), x]);
        let expr = pool.add(vec![two_x, three_x]);
        #[cfg(feature = "egraph")]
        {
            let result = backend::canonicalize_linear(expr, &pool);
            let five_x = pool.mul(vec![pool.integer(5_i32), x]);
            assert_eq!(result, five_x);
        }
        #[cfg(not(feature = "egraph"))]
        let _ = expr;
    }

    /// A generous node limit does not block simplification.
    #[test]
    fn egraph_with_node_limit() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![x, pool.integer(0_i32)]);
        let config = EgraphConfig {
            node_limit: Some(10_000),
            ..EgraphConfig::default()
        };
        let result = simplify_egraph_with(expr, &pool, &config, &SizeCost);
        assert_eq!(result.value, x);
    }

    /// Non-commutative symbols are still simplified (`A + 0` gives `A`).
    #[test]
    fn egraph_noncommutative_falls_back_to_rules() {
        let pool = ExprPool::new();
        let a = pool.symbol_commutative("A", Domain::Real, false);
        let expr = pool.add(vec![a, pool.integer(0_i32)]);
        let result = simplify_egraph(expr, &pool);
        assert_eq!(result.value, a);
    }

    /// `NoncommutativeCost` yields a finite value.
    #[test]
    fn noncommutative_cost_is_callable() {
        let nc = NoncommutativeCost;
        let v = nc.cost("Mul", &[1.0, 1.0]);
        assert!(v.is_finite());
    }

    /// A binary `Add` of two compound children costs more than one with a
    /// cheap child.
    #[test]
    fn stability_cost_penalises_binary_add() {
        let sc = StabilityCost;
        let penalised = sc.cost("Add", &[2.0, 2.0]);
        let normal = sc.cost("Add", &[0.1, 2.0]);
        assert!(penalised > normal);
    }

    /// `sin(x)^2 + cos(x)^2` simplifies to `1` with the e-graph backend.
    #[test]
    fn egraph_trig_identity_pow_form() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let sin_x = pool.func("sin", vec![x]);
        let cos_x = pool.func("cos", vec![x]);
        let sin2 = pool.pow(sin_x, pool.integer(2_i32));
        let cos2 = pool.pow(cos_x, pool.integer(2_i32));
        let expr = pool.add(vec![sin2, cos2]);
        #[cfg(feature = "egraph")]
        {
            let result = simplify_egraph(expr, &pool);
            assert_eq!(result.value, pool.integer(1_i32));
        }
        #[cfg(not(feature = "egraph"))]
        let _ = expr;
    }

    /// `exp(log(x))` simplifies to `x` with the e-graph backend.
    #[test]
    fn egraph_exp_of_log() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.func("exp", vec![pool.func("log", vec![x])]);
        #[cfg(feature = "egraph")]
        {
            let result = simplify_egraph(expr, &pool);
            assert_eq!(result.value, x);
        }
        #[cfg(not(feature = "egraph"))]
        let _ = expr;
    }

    /// `log(exp(x))` simplifies to `x` with the e-graph backend.
    #[test]
    fn egraph_log_of_exp() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.func("log", vec![pool.func("exp", vec![x])]);
        #[cfg(feature = "egraph")]
        {
            let result = simplify_egraph(expr, &pool);
            assert_eq!(result.value, x);
        }
        #[cfg(not(feature = "egraph"))]
        let _ = expr;
    }

    /// With the trig rules disabled, `sin^2 + cos^2` is not rewritten to `1`.
    #[test]
    fn egraph_opt_out_trig_rules() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let sin_x = pool.func("sin", vec![x]);
        let cos_x = pool.func("cos", vec![x]);
        let sin2 = pool.pow(sin_x, pool.integer(2_i32));
        let cos2 = pool.pow(cos_x, pool.integer(2_i32));
        let expr = pool.add(vec![sin2, cos2]);
        let config = EgraphConfig {
            include_trig_rules: false,
            ..EgraphConfig::default()
        };
        let result = simplify_egraph_with(expr, &pool, &config, &SizeCost);
        assert_ne!(result.value, pool.integer(1_i32));
    }

    /// With the log/exp rules disabled, `exp(log(x))` is not rewritten to `x`.
    #[test]
    fn egraph_opt_out_log_exp_rules() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.func("exp", vec![pool.func("log", vec![x])]);
        let config = EgraphConfig {
            include_log_exp_rules: false,
            ..EgraphConfig::default()
        };
        let result = simplify_egraph_with(expr, &pool, &config, &SizeCost);
        assert_ne!(result.value, x);
    }
}