1#[cfg(feature = "optimization")]
10use egglog::EGraph;
11
12use crate::error::{MathCompileError, Result};
13use crate::final_tagless::ASTRepr;
14use std::collections::HashMap;
15
/// Named algebraic simplifications.
///
/// Each variant corresponds to one of the hand-written rewrite helpers on
/// `EgglogOptimizer` (e.g. `optimize_add_zero`, `optimize_pow_one`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationPattern {
    /// `0 + x → x`
    AddZeroLeft,
    /// `x + 0 → x`
    AddZeroRight,
    /// `x + x → 2 * x`
    AddSameExpr,
    /// `0 * x → 0`
    MulZeroLeft,
    /// `x * 0 → 0`
    MulZeroRight,
    /// `1 * x → x`
    MulOneLeft,
    /// `x * 1 → x`
    MulOneRight,
    /// `ln(exp(x)) → x`
    LnExp,
    /// `exp(ln(x)) → x`
    ExpLn,
    /// `x ^ 0 → 1`
    PowZero,
    /// `x ^ 1 → x`
    PowOne,
}
42
/// Expression optimizer backed by an egglog e-graph.
///
/// Only available when the `optimization` feature is enabled.
#[cfg(feature = "optimization")]
pub struct EgglogOptimizer {
    /// Monotonic counter used to mint unique `expr_N` names for inserted expressions.
    var_counter: usize,
    /// Original input expressions, keyed by their `expr_N` name.
    expr_map: HashMap<String, ASTRepr<f64>>,
    /// The egglog e-graph that holds the `Math` datatype, the rewrite rules and
    /// every expression inserted by `optimize`.
    egraph: EGraph,
}
53
#[cfg(feature = "optimization")]
impl EgglogOptimizer {
    /// Creates an optimizer whose e-graph is preloaded with the `Math` datatype
    /// and the algebraic rewrite rules in `program`.
    ///
    /// # Errors
    ///
    /// Returns [`MathCompileError::Optimization`] if egglog rejects the program.
    pub fn new() -> Result<Self> {
        let mut egraph = EGraph::default();

        // Notes on the rule set:
        // - `Var` holds an `i64` because `jit_repr_to_egglog` renders variable
        //   indices as bare integers (`(Var 0)`); a `String` field makes egglog
        //   reject every expression that mentions a variable.
        // - `0^x -> 0` is intentionally absent: together with `x^0 -> 1` it would
        //   union `(Num 0.0)` and `(Num 1.0)` through `0^0`.
        // - `sqrt(x*x) -> x` is intentionally absent: the identity is `|x|`.
        // - Rules such as `x/x -> 1`, `exp(ln x) -> x`, `ln(xy) -> ln x + ln y` and
        //   `(xy)^z -> x^z * y^z` only hold where the operands are in the
        //   function's domain (non-zero / positive arguments).
        let program = r"
            (datatype Math
                (Num f64)
                (Var i64)
                (Add Math Math)
                (Sub Math Math)
                (Mul Math Math)
                (Div Math Math)
                (Pow Math Math)
                (Neg Math)
                (Ln Math)
                (Exp Math)
                (Sin Math)
                (Cos Math)
                (Sqrt Math))

            ; Commutativity
            (rewrite (Add ?x ?y) (Add ?y ?x))
            (rewrite (Mul ?x ?y) (Mul ?y ?x))

            ; Arithmetic identities
            (rewrite (Add ?x (Num 0.0)) ?x)
            (rewrite (Add (Num 0.0) ?x) ?x)
            (rewrite (Mul ?x (Num 1.0)) ?x)
            (rewrite (Mul (Num 1.0) ?x) ?x)
            (rewrite (Mul ?x (Num 0.0)) (Num 0.0))
            (rewrite (Mul (Num 0.0) ?x) (Num 0.0))
            (rewrite (Sub ?x (Num 0.0)) ?x)
            (rewrite (Sub ?x ?x) (Num 0.0))
            (rewrite (Div ?x (Num 1.0)) ?x)
            (rewrite (Div ?x ?x) (Num 1.0))
            (rewrite (Pow ?x (Num 0.0)) (Num 1.0))
            (rewrite (Pow ?x (Num 1.0)) ?x)
            (rewrite (Pow (Num 1.0) ?x) (Num 1.0))

            ; Negation
            (rewrite (Neg (Neg ?x)) ?x)
            (rewrite (Neg (Num 0.0)) (Num 0.0))
            (rewrite (Add (Neg ?x) ?x) (Num 0.0))
            (rewrite (Add ?x (Neg ?x)) (Num 0.0))

            ; Exponentials and logarithms
            (rewrite (Ln (Num 1.0)) (Num 0.0))
            (rewrite (Ln (Exp ?x)) ?x)
            (rewrite (Exp (Num 0.0)) (Num 1.0))
            (rewrite (Exp (Ln ?x)) ?x)
            (rewrite (Exp (Add ?x ?y)) (Mul (Exp ?x) (Exp ?y)))
            (rewrite (Ln (Mul ?x ?y)) (Add (Ln ?x) (Ln ?y)))

            ; Trigonometry
            (rewrite (Sin (Num 0.0)) (Num 0.0))
            (rewrite (Cos (Num 0.0)) (Num 1.0))
            (rewrite (Add (Mul (Sin ?x) (Sin ?x)) (Mul (Cos ?x) (Cos ?x))) (Num 1.0))

            ; Square roots
            (rewrite (Sqrt (Num 0.0)) (Num 0.0))
            (rewrite (Sqrt (Num 1.0)) (Num 1.0))
            (rewrite (Pow (Sqrt ?x) (Num 2.0)) ?x)

            ; Doubling and reciprocals
            (rewrite (Add ?x ?x) (Mul (Num 2.0) ?x))
            (rewrite (Mul (Num 2.0) ?x) (Add ?x ?x))
            (rewrite (Mul ?x (Div (Num 1.0) ?x)) (Num 1.0))

            ; Powers
            (rewrite (Pow ?x (Add ?a ?b)) (Mul (Pow ?x ?a) (Pow ?x ?b)))
            (rewrite (Pow (Mul ?x ?y) ?z) (Mul (Pow ?x ?z) (Pow ?y ?z)))
            (rewrite (Mul (Pow ?x ?a) (Pow ?x ?b)) (Pow ?x (Add ?a ?b)))

            ; Distributivity
            (rewrite (Mul ?x (Add ?y ?z)) (Add (Mul ?x ?y) (Mul ?x ?z)))
            (rewrite (Mul (Add ?y ?z) ?x) (Add (Mul ?y ?x) (Mul ?z ?x)))
        ";

        egraph.parse_and_run_program(None, program).map_err(|e| {
            MathCompileError::Optimization(format!("Failed to initialize egglog with rules: {e}"))
        })?;

        Ok(Self {
            egraph,
            expr_map: HashMap::new(),
            var_counter: 0,
        })
    }

    /// Optimizes `expr` and returns a simplified expression.
    ///
    /// The expression is bound in the e-graph under a fresh `expr_N` name and
    /// saturated for 10 iterations. The returned tree is then produced by the
    /// hand-written rewrite pass (`apply_comprehensive_optimization`) applied to
    /// the original input; the e-graph itself is not extracted from yet.
    ///
    /// # Errors
    ///
    /// Returns [`MathCompileError::Optimization`] if the expression cannot be
    /// encoded for egglog, egglog rejects it, or saturation fails. If only the
    /// final rewrite stage fails, the original expression is returned unchanged
    /// and a message is printed to stderr.
    pub fn optimize(&mut self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        let egglog_expr = self.jit_repr_to_egglog(expr)?;

        let expr_id = format!("expr_{}", self.var_counter);
        self.var_counter += 1;
        self.expr_map.insert(expr_id.clone(), expr.clone());

        if let Err(e) = self.run_egglog(&expr_id, &egglog_expr) {
            // The entry would otherwise never be consumed; drop it so failed
            // calls do not accumulate clones.
            self.expr_map.remove(&expr_id);
            return Err(e);
        }

        match self.extract_best_expression(&expr_id) {
            Ok(optimized) => Ok(optimized),
            Err(e) => {
                eprintln!("Egglog extraction failed: {e}, using original expression");
                Ok(expr.clone())
            }
        }
    }

    /// Binds `egglog_expr` to `expr_id` in the e-graph and runs 10 saturation rounds.
    fn run_egglog(&mut self, expr_id: &str, egglog_expr: &str) -> Result<()> {
        let command = format!("(let {expr_id} {egglog_expr})");
        self.egraph
            .parse_and_run_program(None, &command)
            .map_err(|e| {
                MathCompileError::Optimization(format!("Egglog failed to add expression: {e}"))
            })?;
        self.egraph
            .parse_and_run_program(None, "(run 10)")
            .map_err(|e| {
                MathCompileError::Optimization(format!(
                    "Egglog equality saturation failed: {e}"
                ))
            })?;
        Ok(())
    }

    /// Takes the original expression registered under `expr_id` out of
    /// `expr_map` and runs the hand-written rewrite pass over it.
    ///
    /// The entry is removed (not just read) so repeated `optimize` calls do not
    /// grow `expr_map` without bound.
    fn extract_best_expression(&mut self, expr_id: &str) -> Result<ASTRepr<f64>> {
        let original_expr = self.expr_map.remove(expr_id).ok_or_else(|| {
            MathCompileError::Optimization("Expression not found in map".to_string())
        })?;

        self.apply_comprehensive_optimization(&original_expr)
    }

    /// Repeats full rewrite passes until the tree stops changing, capped at
    /// `MAX_ITERATIONS` passes.
    fn apply_comprehensive_optimization(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        const MAX_ITERATIONS: usize = 10;

        let mut current = expr.clone();
        for _ in 0..MAX_ITERATIONS {
            let next = self.apply_all_optimizations(&current)?;
            let reached_fixpoint = self.expressions_structurally_equal(&current, &next);
            current = next;
            if reached_fixpoint {
                break;
            }
        }

        Ok(current)
    }

    /// One bottom-up pass: rewrite the children first, then the node itself.
    fn apply_all_optimizations(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        let recursively_optimized = self.apply_optimizations_recursively(expr)?;
        self.apply_top_level_optimizations(&recursively_optimized)
    }

    /// Fully optimizes a child node and boxes the result.
    fn optimize_child(&self, child: &ASTRepr<f64>) -> Result<Box<ASTRepr<f64>>> {
        self.apply_all_optimizations(child).map(Box::new)
    }

    /// Rebuilds `expr` with every child optimized (left operand before right).
    fn apply_optimizations_recursively(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Add(l, r) => ASTRepr::Add(self.optimize_child(l)?, self.optimize_child(r)?),
            ASTRepr::Sub(l, r) => ASTRepr::Sub(self.optimize_child(l)?, self.optimize_child(r)?),
            ASTRepr::Mul(l, r) => ASTRepr::Mul(self.optimize_child(l)?, self.optimize_child(r)?),
            ASTRepr::Div(l, r) => ASTRepr::Div(self.optimize_child(l)?, self.optimize_child(r)?),
            ASTRepr::Pow(b, e) => ASTRepr::Pow(self.optimize_child(b)?, self.optimize_child(e)?),
            ASTRepr::Neg(inner) => ASTRepr::Neg(self.optimize_child(inner)?),
            ASTRepr::Ln(inner) => ASTRepr::Ln(self.optimize_child(inner)?),
            ASTRepr::Exp(inner) => ASTRepr::Exp(self.optimize_child(inner)?),
            ASTRepr::Sin(inner) => ASTRepr::Sin(self.optimize_child(inner)?),
            ASTRepr::Cos(inner) => ASTRepr::Cos(self.optimize_child(inner)?),
            ASTRepr::Sqrt(inner) => ASTRepr::Sqrt(self.optimize_child(inner)?),
            ASTRepr::Constant(_) | ASTRepr::Variable(_) => expr.clone(),
        })
    }

    /// Applies each single-node rewrite to the root of `expr`, in a fixed order.
    fn apply_top_level_optimizations(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        let mut result = self.optimize_add_zero(expr)?;
        result = self.optimize_add_same(&result)?;
        result = self.optimize_mul_zero(&result)?;
        result = self.optimize_mul_one(&result)?;
        result = self.optimize_ln_exp(&result)?;
        result = self.optimize_exp_ln(&result)?;
        result = self.optimize_pow_zero(&result)?;
        result = self.optimize_pow_one(&result)?;

        result = self.optimize_constant_folding(&result)?;
        result = self.optimize_double_negation(&result)?;
        result = self.optimize_distributive(&result)?;

        Ok(result)
    }

    /// Returns `true` when `expr` is exactly the constant `value`.
    ///
    /// The comparison is exact on purpose: an epsilon test would treat tiny but
    /// meaningful constants (e.g. `1e-20`) as zero. `-0.0` matches `0.0`.
    fn is_constant(expr: &ASTRepr<f64>, value: f64) -> bool {
        matches!(expr, ASTRepr::Constant(c) if *c == value)
    }

    /// Returns both operands' values when both are constants.
    fn constant_pair(left: &ASTRepr<f64>, right: &ASTRepr<f64>) -> Option<(f64, f64)> {
        match (left, right) {
            (ASTRepr::Constant(a), ASTRepr::Constant(b)) => Some((*a, *b)),
            _ => None,
        }
    }

    /// Folds operations whose operands are all constants into a single constant.
    ///
    /// Division by an exact zero is left symbolic rather than folded to ±inf/NaN.
    fn optimize_constant_folding(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        let folded = match expr {
            ASTRepr::Add(l, r) => Self::constant_pair(l, r).map(|(a, b)| a + b),
            ASTRepr::Sub(l, r) => Self::constant_pair(l, r).map(|(a, b)| a - b),
            ASTRepr::Mul(l, r) => Self::constant_pair(l, r).map(|(a, b)| a * b),
            ASTRepr::Div(l, r) => Self::constant_pair(l, r)
                .filter(|&(_, b)| b != 0.0)
                .map(|(a, b)| a / b),
            ASTRepr::Pow(base, exp) => Self::constant_pair(base, exp).map(|(a, b)| a.powf(b)),
            ASTRepr::Neg(inner) => match inner.as_ref() {
                ASTRepr::Constant(a) => Some(-*a),
                _ => None,
            },
            _ => None,
        };

        Ok(folded.map_or_else(|| expr.clone(), ASTRepr::Constant))
    }

    /// `-(-x) → x`
    fn optimize_double_negation(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Neg(inner) => match inner.as_ref() {
                ASTRepr::Neg(x) => x.as_ref().clone(),
                _ => expr.clone(),
            },
            _ => expr.clone(),
        })
    }

    /// `x * (a + b) → x*a + x*b`, and `(a + b) * x → a*x + b*x` when the right
    /// operand is not a sum.
    ///
    /// Note: this expands rather than shrinks the tree; repeated passes over
    /// nested products of sums grow the expression multiplicatively.
    fn optimize_distributive(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Mul(left, right) => match (left.as_ref(), right.as_ref()) {
                (_, ASTRepr::Add(b, c)) => ASTRepr::Add(
                    Box::new(ASTRepr::Mul(left.clone(), b.clone())),
                    Box::new(ASTRepr::Mul(left.clone(), c.clone())),
                ),
                (ASTRepr::Add(a, b), _) => ASTRepr::Add(
                    Box::new(ASTRepr::Mul(a.clone(), right.clone())),
                    Box::new(ASTRepr::Mul(b.clone(), right.clone())),
                ),
                _ => expr.clone(),
            },
            _ => expr.clone(),
        })
    }

    /// Renders `expr` as an egglog `Math` term, e.g. `(Add (Var 0) (Num 1.0))`.
    ///
    /// # Errors
    ///
    /// Returns [`MathCompileError::Optimization`] for NaN or infinite constants,
    /// which have no egglog `f64` literal form.
    fn jit_repr_to_egglog(&self, expr: &ASTRepr<f64>) -> Result<String> {
        Ok(match expr {
            ASTRepr::Constant(value) => {
                if !value.is_finite() {
                    return Err(MathCompileError::Optimization(format!(
                        "Cannot encode non-finite constant {value} for egglog"
                    )));
                }
                if value.fract() == 0.0 {
                    // egglog f64 literals need a decimal point: `42.0`, not `42`.
                    format!("(Num {value:.1})")
                } else {
                    format!("(Num {value})")
                }
            }
            ASTRepr::Variable(index) => format!("(Var {index})"),
            ASTRepr::Add(l, r) => self.binary_to_egglog("Add", l, r)?,
            ASTRepr::Sub(l, r) => self.binary_to_egglog("Sub", l, r)?,
            ASTRepr::Mul(l, r) => self.binary_to_egglog("Mul", l, r)?,
            ASTRepr::Div(l, r) => self.binary_to_egglog("Div", l, r)?,
            ASTRepr::Pow(b, e) => self.binary_to_egglog("Pow", b, e)?,
            ASTRepr::Neg(inner) => self.unary_to_egglog("Neg", inner)?,
            ASTRepr::Ln(inner) => self.unary_to_egglog("Ln", inner)?,
            ASTRepr::Exp(inner) => self.unary_to_egglog("Exp", inner)?,
            ASTRepr::Sin(inner) => self.unary_to_egglog("Sin", inner)?,
            ASTRepr::Cos(inner) => self.unary_to_egglog("Cos", inner)?,
            ASTRepr::Sqrt(inner) => self.unary_to_egglog("Sqrt", inner)?,
        })
    }

    /// Renders `(op <inner>)`.
    fn unary_to_egglog(&self, op: &str, inner: &ASTRepr<f64>) -> Result<String> {
        let inner_s = self.jit_repr_to_egglog(inner)?;
        Ok(format!("({op} {inner_s})"))
    }

    /// Renders `(op <left> <right>)`.
    fn binary_to_egglog(
        &self,
        op: &str,
        left: &ASTRepr<f64>,
        right: &ASTRepr<f64>,
    ) -> Result<String> {
        let left_s = self.jit_repr_to_egglog(left)?;
        let right_s = self.jit_repr_to_egglog(right)?;
        Ok(format!("({op} {left_s} {right_s})"))
    }

    /// Parses an egglog `Math` term (as produced by `jit_repr_to_egglog`) back
    /// into an `ASTRepr`.
    ///
    /// # Errors
    ///
    /// Returns [`MathCompileError::Optimization`] for text that is not a
    /// parenthesized term, has unbalanced parentheses, an unknown operator, the
    /// wrong number of arguments, an unparsable number, or a non-integer
    /// variable index.
    // Not called yet: results are currently produced by the hand-written pass.
    #[allow(dead_code)]
    fn egglog_to_jit_repr(&self, egglog_str: &str) -> Result<ASTRepr<f64>> {
        let inner = egglog_str
            .trim()
            .strip_prefix('(')
            .and_then(|rest| rest.strip_suffix(')'))
            .ok_or_else(|| {
                MathCompileError::Optimization("Invalid egglog expression format".to_string())
            })?;

        let parts = self.parse_sexpr_parts(inner)?;
        let (head, args) = match parts.split_first() {
            Some((head, args)) => (*head, args),
            None => {
                return Err(MathCompileError::Optimization(
                    "Empty egglog expression".to_string(),
                ))
            }
        };

        let parse = |s: &str| -> Result<Box<ASTRepr<f64>>> {
            self.egglog_to_jit_repr(s).map(Box::new)
        };

        match (head, args) {
            ("Num", &[literal]) => literal.parse::<f64>().map(ASTRepr::Constant).map_err(|_| {
                MathCompileError::Optimization("Invalid number format".to_string())
            }),
            ("Var", &[name]) => name
                .trim_matches('"')
                .parse::<usize>()
                .map(ASTRepr::Variable)
                .map_err(|_| {
                    MathCompileError::Optimization(format!("Invalid variable index: {name}"))
                }),
            ("Add", &[l, r]) => Ok(ASTRepr::Add(parse(l)?, parse(r)?)),
            ("Sub", &[l, r]) => Ok(ASTRepr::Sub(parse(l)?, parse(r)?)),
            ("Mul", &[l, r]) => Ok(ASTRepr::Mul(parse(l)?, parse(r)?)),
            ("Div", &[l, r]) => Ok(ASTRepr::Div(parse(l)?, parse(r)?)),
            ("Pow", &[b, e]) => Ok(ASTRepr::Pow(parse(b)?, parse(e)?)),
            ("Neg", &[x]) => Ok(ASTRepr::Neg(parse(x)?)),
            ("Ln", &[x]) => Ok(ASTRepr::Ln(parse(x)?)),
            ("Exp", &[x]) => Ok(ASTRepr::Exp(parse(x)?)),
            ("Sin", &[x]) => Ok(ASTRepr::Sin(parse(x)?)),
            ("Cos", &[x]) => Ok(ASTRepr::Cos(parse(x)?)),
            ("Sqrt", &[x]) => Ok(ASTRepr::Sqrt(parse(x)?)),
            (
                "Num" | "Var" | "Add" | "Sub" | "Mul" | "Div" | "Pow" | "Neg" | "Ln" | "Exp"
                | "Sin" | "Cos" | "Sqrt",
                _,
            ) => Err(MathCompileError::Optimization(format!(
                "Invalid {head} expression"
            ))),
            (other, _) => Err(MathCompileError::Optimization(format!(
                "Unknown egglog operator: {other}"
            ))),
        }
    }

    /// Splits the inside of an s-expression into its top-level items, e.g.
    /// `Add (Num 1.0) (Num 2.0)` → `["Add", "(Num 1.0)", "(Num 2.0)"]`.
    ///
    /// Whitespace inside parentheses or double-quoted strings (with `\` escapes)
    /// does not split. Offsets are byte offsets, so non-ASCII input is sliced
    /// correctly.
    ///
    /// # Errors
    ///
    /// Returns [`MathCompileError::Optimization`] when parentheses or quotes are
    /// unbalanced.
    fn parse_sexpr_parts<'a>(&self, input: &'a str) -> Result<Vec<&'a str>> {
        let unbalanced = || {
            MathCompileError::Optimization(format!("Unbalanced egglog expression: {input}"))
        };

        let mut parts = Vec::new();
        let mut start = 0;
        let mut depth: usize = 0;
        let mut in_string = false;
        let mut escape_next = false;

        for (pos, ch) in input.char_indices() {
            if escape_next {
                escape_next = false;
                continue;
            }

            match ch {
                '\\' if in_string => escape_next = true,
                '"' => in_string = !in_string,
                '(' if !in_string => depth += 1,
                ')' if !in_string => {
                    if depth == 0 {
                        return Err(unbalanced());
                    }
                    depth -= 1;
                }
                c if c.is_whitespace() && !in_string && depth == 0 => {
                    let part = input[start..pos].trim();
                    if !part.is_empty() {
                        parts.push(part);
                    }
                    start = pos + c.len_utf8();
                }
                _ => {}
            }
        }

        if depth != 0 || in_string {
            return Err(unbalanced());
        }

        let tail = input[start..].trim();
        if !tail.is_empty() {
            parts.push(tail);
        }

        Ok(parts)
    }

    /// Structural equality of two trees.
    ///
    /// Constants compare exactly (an epsilon would conflate distinct tiny
    /// constants such as `1e-20` and `2e-20`); NaN equals NaN so the fixed-point
    /// loop can terminate on expressions that contain NaN.
    fn expressions_structurally_equal(&self, a: &ASTRepr<f64>, b: &ASTRepr<f64>) -> bool {
        match (a, b) {
            (ASTRepr::Constant(x), ASTRepr::Constant(y)) => {
                x == y || (x.is_nan() && y.is_nan())
            }
            (ASTRepr::Variable(x), ASTRepr::Variable(y)) => x == y,
            (ASTRepr::Add(a1, a2), ASTRepr::Add(b1, b2))
            | (ASTRepr::Sub(a1, a2), ASTRepr::Sub(b1, b2))
            | (ASTRepr::Mul(a1, a2), ASTRepr::Mul(b1, b2))
            | (ASTRepr::Div(a1, a2), ASTRepr::Div(b1, b2))
            | (ASTRepr::Pow(a1, a2), ASTRepr::Pow(b1, b2)) => {
                self.expressions_structurally_equal(a1, b1)
                    && self.expressions_structurally_equal(a2, b2)
            }
            (ASTRepr::Neg(x), ASTRepr::Neg(y))
            | (ASTRepr::Ln(x), ASTRepr::Ln(y))
            | (ASTRepr::Exp(x), ASTRepr::Exp(y))
            | (ASTRepr::Sin(x), ASTRepr::Sin(y))
            | (ASTRepr::Cos(x), ASTRepr::Cos(y))
            | (ASTRepr::Sqrt(x), ASTRepr::Sqrt(y)) => self.expressions_structurally_equal(x, y),
            _ => false,
        }
    }

    /// `0 + x → x` and `x + 0 → x`.
    fn optimize_add_zero(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Add(left, right) if Self::is_constant(left, 0.0) => right.as_ref().clone(),
            ASTRepr::Add(left, right) if Self::is_constant(right, 0.0) => left.as_ref().clone(),
            _ => expr.clone(),
        })
    }

    /// `x + x → 2 * x`.
    fn optimize_add_same(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Add(left, right) if self.expressions_structurally_equal(left, right) => {
                ASTRepr::Mul(Box::new(ASTRepr::Constant(2.0)), left.clone())
            }
            _ => expr.clone(),
        })
    }

    /// `0 * x → 0` and `x * 0 → 0` (ignores the IEEE cases where `x` is NaN or infinite).
    fn optimize_mul_zero(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Mul(left, right)
                if Self::is_constant(left, 0.0) || Self::is_constant(right, 0.0) =>
            {
                ASTRepr::Constant(0.0)
            }
            _ => expr.clone(),
        })
    }

    /// `1 * x → x` and `x * 1 → x`.
    fn optimize_mul_one(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Mul(left, right) if Self::is_constant(left, 1.0) => right.as_ref().clone(),
            ASTRepr::Mul(left, right) if Self::is_constant(right, 1.0) => left.as_ref().clone(),
            _ => expr.clone(),
        })
    }

    /// `ln(exp(x)) → x`.
    fn optimize_ln_exp(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Ln(inner) => match inner.as_ref() {
                ASTRepr::Exp(x) => x.as_ref().clone(),
                _ => expr.clone(),
            },
            _ => expr.clone(),
        })
    }

    /// `exp(ln(x)) → x` (only valid for `x > 0`).
    fn optimize_exp_ln(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Exp(inner) => match inner.as_ref() {
                ASTRepr::Ln(x) => x.as_ref().clone(),
                _ => expr.clone(),
            },
            _ => expr.clone(),
        })
    }

    /// `x ^ 0 → 1`.
    fn optimize_pow_zero(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Pow(_, exponent) if Self::is_constant(exponent, 0.0) => {
                ASTRepr::Constant(1.0)
            }
            _ => expr.clone(),
        })
    }

    /// `x ^ 1 → x`.
    fn optimize_pow_one(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        Ok(match expr {
            ASTRepr::Pow(base, exponent) if Self::is_constant(exponent, 1.0) => {
                base.as_ref().clone()
            }
            _ => expr.clone(),
        })
    }
}
844
845#[cfg(not(feature = "optimization"))]
847pub struct EgglogOptimizer;
848
849#[cfg(not(feature = "optimization"))]
850impl EgglogOptimizer {
851 pub fn new() -> Result<Self> {
852 Ok(Self)
853 }
854
855 pub fn optimize(&mut self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
856 Ok(expr.clone())
858 }
859}
860
861pub fn optimize_with_egglog(expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
863 let mut optimizer = EgglogOptimizer::new()?;
864 optimizer.optimize(expr)
865}
866
#[cfg(test)]
mod tests {
    use super::*;
    use crate::final_tagless::{ASTEval, ASTMathExpr};

    #[test]
    fn test_egglog_optimizer_creation() {
        // Both the egglog-backed and the fallback constructor must succeed.
        assert!(EgglogOptimizer::new().is_ok());
    }

    #[test]
    fn test_jit_repr_to_egglog_conversion() {
        #[cfg(feature = "optimization")]
        {
            let optimizer = EgglogOptimizer::new().unwrap();

            let expr = ASTRepr::Constant(42.0);
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Num 42.0)");

            let expr = ASTRepr::Variable(0);
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Var 0)");

            let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(1.0));
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Add (Var 0) (Num 1.0))");
        }
    }

    #[test]
    fn test_basic_optimization() {
        let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(0.0));
        let result = optimize_with_egglog(&expr);

        #[cfg(feature = "optimization")]
        {
            // Whenever optimization succeeds, `x + 0` must simplify to `x`.
            if let Ok(optimized) = &result {
                assert!(
                    matches!(optimized, ASTRepr::Variable(0)),
                    "x + 0 should simplify to x"
                );
            }
        }

        #[cfg(not(feature = "optimization"))]
        {
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_complex_expression_conversion() {
        #[cfg(feature = "optimization")]
        {
            let optimizer = EgglogOptimizer::new().unwrap();

            let expr = ASTEval::sin(ASTEval::add(
                ASTEval::pow(ASTEval::var(0), ASTEval::constant(2.0)),
                ASTEval::constant(1.0),
            ));

            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert!(egglog_str.contains("Sin"));
            assert!(egglog_str.contains("Add"));
            assert!(egglog_str.contains("Pow"));
            assert!(egglog_str.contains("Var 0"));
        }
    }

    #[test]
    fn test_egglog_rules_application() {
        #[cfg(feature = "optimization")]
        {
            let mut optimizer = EgglogOptimizer::new().unwrap();

            let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(0.0));

            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Add (Var 0) (Num 0.0))");

            // Whenever optimization succeeds, `x + 0` must simplify to `x`.
            if let Ok(optimized) = optimizer.optimize(&expr) {
                assert!(
                    matches!(optimized, ASTRepr::Variable(0)),
                    "x + 0 should simplify to x"
                );
            }
        }
    }

    #[test]
    fn test_sexpr_parsing() {
        #[cfg(feature = "optimization")]
        {
            let optimizer = EgglogOptimizer::new().unwrap();

            let parts = optimizer.parse_sexpr_parts("Num 42.0").unwrap();
            assert_eq!(parts, vec!["Num", "42.0"]);

            let parts = optimizer.parse_sexpr_parts("Var 0").unwrap();
            assert_eq!(parts, vec!["Var", "0"]);

            let parts = optimizer
                .parse_sexpr_parts("Add (Num 1.0) (Num 2.0)")
                .unwrap();
            assert_eq!(parts, vec!["Add", "(Num 1.0)", "(Num 2.0)"]);
        }
    }
}