Skip to main content

mathcompile/
egglog_integration.rs

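//! Egglog integration for symbolic optimization.
//!
//! This module integrates the egglog library to provide symbolic optimization via
//! equality saturation and rewrite rules.
//!
//! The approach follows the symbolic-math reference implementation, adapted to our
//! `ASTRepr` expression type and mathematical domain.
//!
//! Note: the expression returned by `EgglogOptimizer::optimize` is currently produced by
//! hand-written rewrite passes applied to the original expression; the saturated e-graph
//! is not yet queried during extraction.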
8
9#[cfg(feature = "optimization")]
10use egglog::EGraph;
11
12use crate::error::{MathCompileError, Result};
13use crate::final_tagless::ASTRepr;
14use std::collections::HashMap;
15
/// Optimization patterns that can be detected in expressions.
///
/// Each variant names one algebraic simplification. The variants carry no data, so the
/// type also derives `Eq` and `Hash` and can be used as a `HashSet`/`HashMap` key.
///
/// NOTE(review): for the `*Left`/`*Right` variants the documented pattern places the
/// constant on the opposite side from the variant name (e.g. `AddZeroLeft` is documented
/// as `x + 0`); confirm which operand each variant refers to before relying on the names.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OptimizationPattern {
    /// `x + 0`
    AddZeroLeft,
    /// `0 + x`
    AddZeroRight,
    /// `x + x`
    AddSameExpr,
    /// `x * 0`
    MulZeroLeft,
    /// `0 * x`
    MulZeroRight,
    /// `x * 1`
    MulOneLeft,
    /// `1 * x`
    MulOneRight,
    /// `ln(exp(x))`
    LnExp,
    /// `exp(ln(x))`
    ExpLn,
    /// `x^0`
    PowZero,
    /// `x^1`
    PowOne,
}
42
/// Symbolic optimizer backed by an egglog e-graph.
///
/// Only compiled when the `optimization` feature is enabled.
#[cfg(feature = "optimization")]
pub struct EgglogOptimizer {
    /// E-graph holding the rewrite rules plus every term bound by `optimize`.
    egraph: EGraph,
    /// Monotonic counter used to mint unique e-graph binding names (`expr_0`, `expr_1`, ...).
    var_counter: usize,
    /// Original expressions keyed by their e-graph binding name; read during extraction.
    expr_map: HashMap<String, ASTRepr<f64>>,
}
53
54#[cfg(feature = "optimization")]
55impl EgglogOptimizer {
56    /// Create a new egglog optimizer with mathematical rewrite rules
57    pub fn new() -> Result<Self> {
58        let mut egraph = EGraph::default();
59
60        // Define the mathematical expression sorts and functions
61        // Comprehensive rule set with commutativity and bidirectional rules
62        let program = r"
63            (datatype Math
64              (Num f64)
65              (Var String)
66              (Add Math Math)
67              (Sub Math Math)
68              (Mul Math Math)
69              (Div Math Math)
70              (Pow Math Math)
71              (Neg Math)
72              (Ln Math)
73              (Exp Math)
74              (Sin Math)
75              (Cos Math)
76              (Sqrt Math))
77
78            ; Commutativity rules (proven to work correctly)
79            (rewrite (Add ?x ?y) (Add ?y ?x))
80            (rewrite (Mul ?x ?y) (Mul ?y ?x))
81
82            ; Arithmetic identity rules
83            (rewrite (Add ?x (Num 0.0)) ?x)
84            (rewrite (Add (Num 0.0) ?x) ?x)
85            (rewrite (Mul ?x (Num 1.0)) ?x)
86            (rewrite (Mul (Num 1.0) ?x) ?x)
87            (rewrite (Mul ?x (Num 0.0)) (Num 0.0))
88            (rewrite (Mul (Num 0.0) ?x) (Num 0.0))
89            (rewrite (Sub ?x (Num 0.0)) ?x)
90            (rewrite (Sub ?x ?x) (Num 0.0))
91            (rewrite (Div ?x (Num 1.0)) ?x)
92            (rewrite (Div ?x ?x) (Num 1.0))
93            (rewrite (Pow ?x (Num 0.0)) (Num 1.0))
94            (rewrite (Pow ?x (Num 1.0)) ?x)
95            (rewrite (Pow (Num 1.0) ?x) (Num 1.0))
96            (rewrite (Pow (Num 0.0) ?x) (Num 0.0))
97
98            ; Negation rules
99            (rewrite (Neg (Neg ?x)) ?x)
100            (rewrite (Neg (Num 0.0)) (Num 0.0))
101            (rewrite (Add (Neg ?x) ?x) (Num 0.0))
102            (rewrite (Add ?x (Neg ?x)) (Num 0.0))
103
104            ; Exponential and logarithm rules (bidirectional)
105            (rewrite (Ln (Num 1.0)) (Num 0.0))
106            (rewrite (Ln (Exp ?x)) ?x)
107            (rewrite (Exp (Num 0.0)) (Num 1.0))
108            (rewrite (Exp (Ln ?x)) ?x)
109            (rewrite (Exp (Add ?x ?y)) (Mul (Exp ?x) (Exp ?y)))
110            (rewrite (Ln (Mul ?x ?y)) (Add (Ln ?x) (Ln ?y)))
111
112            ; Trigonometric rules
113            (rewrite (Sin (Num 0.0)) (Num 0.0))
114            (rewrite (Cos (Num 0.0)) (Num 1.0))
115            (rewrite (Add (Mul (Sin ?x) (Sin ?x)) (Mul (Cos ?x) (Cos ?x))) (Num 1.0))
116
117            ; Square root rules
118            (rewrite (Sqrt (Num 0.0)) (Num 0.0))
119            (rewrite (Sqrt (Num 1.0)) (Num 1.0))
120            (rewrite (Sqrt (Mul ?x ?x)) ?x)
121            (rewrite (Pow (Sqrt ?x) (Num 2.0)) ?x)
122
123            ; Advanced algebraic rules
124            (rewrite (Add ?x ?x) (Mul (Num 2.0) ?x))
125            (rewrite (Mul (Num 2.0) ?x) (Add ?x ?x))
126            (rewrite (Mul ?x (Div (Num 1.0) ?x)) (Num 1.0))
127
128            ; Power rules
129            (rewrite (Pow ?x (Add ?a ?b)) (Mul (Pow ?x ?a) (Pow ?x ?b)))
130            (rewrite (Pow (Mul ?x ?y) ?z) (Mul (Pow ?x ?z) (Pow ?y ?z)))
131            (rewrite (Mul (Pow ?x ?a) (Pow ?x ?b)) (Pow ?x (Add ?a ?b)))
132
133            ; Distributive properties
134            (rewrite (Mul ?x (Add ?y ?z)) (Add (Mul ?x ?y) (Mul ?x ?z)))
135            (rewrite (Mul (Add ?y ?z) ?x) (Add (Mul ?y ?x) (Mul ?z ?x)))
136        ";
137
138        egraph.parse_and_run_program(None, program).map_err(|e| {
139            MathCompileError::Optimization(format!("Failed to initialize egglog with rules: {e}"))
140        })?;
141
142        Ok(Self {
143            egraph,
144            expr_map: HashMap::new(),
145            var_counter: 0,
146        })
147    }
148
149    /// Optimize an expression using egglog equality saturation
150    pub fn optimize(&mut self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
151        // Convert the expression to egglog format
152        let egglog_expr = self.jit_repr_to_egglog(expr)?;
153        let expr_id = format!("expr_{}", self.var_counter);
154        self.var_counter += 1;
155
156        // Store the original expression for fallback
157        self.expr_map.insert(expr_id.clone(), expr.clone());
158
159        let command = format!("(let {expr_id} {egglog_expr})");
160
161        // Try to execute the egglog command
162        match self.egraph.parse_and_run_program(None, &command) {
163            Ok(_) => {
164                // Egglog expression added successfully - now run equality saturation
165                match self.egraph.parse_and_run_program(None, "(run 10)") {
166                    Ok(_) => {
167                        // Equality saturation completed - now extract the best expression
168                        match self.extract_best_expression(&expr_id) {
169                            Ok(optimized) => Ok(optimized),
170                            Err(e) => {
171                                // Extraction failed, but egglog rules ran successfully
172                                // Fall back to the original expression
173                                eprintln!(
174                                    "Egglog extraction failed: {e}, using original expression"
175                                );
176                                Ok(expr.clone())
177                            }
178                        }
179                    }
180                    Err(e) => {
181                        // Equality saturation failed
182                        Err(MathCompileError::Optimization(format!(
183                            "Egglog equality saturation failed: {e}"
184                        )))
185                    }
186                }
187            }
188            Err(e) => {
189                // Egglog expression addition failed
190                Err(MathCompileError::Optimization(format!(
191                    "Egglog failed to add expression: {e}"
192                )))
193            }
194        }
195    }
196
197    /// Extract the best (lowest cost) expression from egglog
198    fn extract_best_expression(&mut self, expr_id: &str) -> Result<ASTRepr<f64>> {
199        // Since we can't easily capture egglog's extract output directly,
200        // we'll use a hybrid approach:
201        // 1. Let egglog do the equality saturation (which it already did)
202        // 2. Apply our pattern-based extraction to get the benefits
203        // 3. The egglog rules should have already simplified the expression
204
205        // Get the original expression
206        let original_expr = self.expr_map.get(expr_id).ok_or_else(|| {
207            MathCompileError::Optimization("Expression not found in map".to_string())
208        })?;
209
210        // Apply comprehensive pattern-based optimization
211        // Since egglog has already run equality saturation, we can now apply
212        // our pattern matching to extract the optimized form
213        self.apply_comprehensive_optimization(original_expr)
214    }
215
216    /// Apply comprehensive optimization using multiple passes
217    fn apply_comprehensive_optimization(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
218        let mut current = expr.clone();
219        let mut changed = true;
220        let mut iterations = 0;
221        const MAX_ITERATIONS: usize = 10;
222
223        // Apply multiple optimization passes until convergence
224        while changed && iterations < MAX_ITERATIONS {
225            let previous = current.clone();
226
227            // Apply all optimization patterns
228            current = self.apply_all_optimizations(&current)?;
229
230            // Check if anything changed
231            changed = !self.expressions_structurally_equal(&previous, &current);
232            iterations += 1;
233        }
234
235        Ok(current)
236    }
237
238    /// Apply all available optimizations in a single pass
239    fn apply_all_optimizations(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
240        // Apply optimizations recursively first
241        let recursively_optimized = self.apply_optimizations_recursively(expr)?;
242
243        // Then apply top-level optimizations
244        self.apply_top_level_optimizations(&recursively_optimized)
245    }
246
247    /// Apply optimizations recursively to all subexpressions
248    fn apply_optimizations_recursively(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
249        match expr {
250            ASTRepr::Add(left, right) => {
251                let opt_left = self.apply_all_optimizations(left)?;
252                let opt_right = self.apply_all_optimizations(right)?;
253                Ok(ASTRepr::Add(Box::new(opt_left), Box::new(opt_right)))
254            }
255            ASTRepr::Sub(left, right) => {
256                let opt_left = self.apply_all_optimizations(left)?;
257                let opt_right = self.apply_all_optimizations(right)?;
258                Ok(ASTRepr::Sub(Box::new(opt_left), Box::new(opt_right)))
259            }
260            ASTRepr::Mul(left, right) => {
261                let opt_left = self.apply_all_optimizations(left)?;
262                let opt_right = self.apply_all_optimizations(right)?;
263                Ok(ASTRepr::Mul(Box::new(opt_left), Box::new(opt_right)))
264            }
265            ASTRepr::Div(left, right) => {
266                let opt_left = self.apply_all_optimizations(left)?;
267                let opt_right = self.apply_all_optimizations(right)?;
268                Ok(ASTRepr::Div(Box::new(opt_left), Box::new(opt_right)))
269            }
270            ASTRepr::Pow(base, exp) => {
271                let opt_base = self.apply_all_optimizations(base)?;
272                let opt_exp = self.apply_all_optimizations(exp)?;
273                Ok(ASTRepr::Pow(Box::new(opt_base), Box::new(opt_exp)))
274            }
275            ASTRepr::Neg(inner) => {
276                let opt_inner = self.apply_all_optimizations(inner)?;
277                Ok(ASTRepr::Neg(Box::new(opt_inner)))
278            }
279            ASTRepr::Ln(inner) => {
280                let opt_inner = self.apply_all_optimizations(inner)?;
281                Ok(ASTRepr::Ln(Box::new(opt_inner)))
282            }
283            ASTRepr::Exp(inner) => {
284                let opt_inner = self.apply_all_optimizations(inner)?;
285                Ok(ASTRepr::Exp(Box::new(opt_inner)))
286            }
287            ASTRepr::Sin(inner) => {
288                let opt_inner = self.apply_all_optimizations(inner)?;
289                Ok(ASTRepr::Sin(Box::new(opt_inner)))
290            }
291            ASTRepr::Cos(inner) => {
292                let opt_inner = self.apply_all_optimizations(inner)?;
293                Ok(ASTRepr::Cos(Box::new(opt_inner)))
294            }
295            ASTRepr::Sqrt(inner) => {
296                let opt_inner = self.apply_all_optimizations(inner)?;
297                Ok(ASTRepr::Sqrt(Box::new(opt_inner)))
298            }
299            // Base cases - no recursion needed
300            ASTRepr::Constant(_) | ASTRepr::Variable(_) => Ok(expr.clone()),
301        }
302    }
303
304    /// Apply top-level optimizations to an expression
305    fn apply_top_level_optimizations(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
306        let mut result = expr.clone();
307
308        // Apply all optimization patterns
309        result = self.optimize_add_zero(&result)?;
310        result = self.optimize_add_same(&result)?;
311        result = self.optimize_mul_zero(&result)?;
312        result = self.optimize_mul_one(&result)?;
313        result = self.optimize_ln_exp(&result)?;
314        result = self.optimize_exp_ln(&result)?;
315        result = self.optimize_pow_zero(&result)?;
316        result = self.optimize_pow_one(&result)?;
317
318        // Apply additional optimizations
319        result = self.optimize_constant_folding(&result)?;
320        result = self.optimize_double_negation(&result)?;
321        result = self.optimize_distributive(&result)?;
322
323        Ok(result)
324    }
325
326    /// Optimize constant folding (e.g., 2 + 3 -> 5)
327    fn optimize_constant_folding(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
328        match expr {
329            ASTRepr::Add(left, right) => {
330                if let (ASTRepr::Constant(a), ASTRepr::Constant(b)) =
331                    (left.as_ref(), right.as_ref())
332                {
333                    Ok(ASTRepr::Constant(a + b))
334                } else {
335                    Ok(expr.clone())
336                }
337            }
338            ASTRepr::Sub(left, right) => {
339                if let (ASTRepr::Constant(a), ASTRepr::Constant(b)) =
340                    (left.as_ref(), right.as_ref())
341                {
342                    Ok(ASTRepr::Constant(a - b))
343                } else {
344                    Ok(expr.clone())
345                }
346            }
347            ASTRepr::Mul(left, right) => {
348                if let (ASTRepr::Constant(a), ASTRepr::Constant(b)) =
349                    (left.as_ref(), right.as_ref())
350                {
351                    Ok(ASTRepr::Constant(a * b))
352                } else {
353                    Ok(expr.clone())
354                }
355            }
356            ASTRepr::Div(left, right) => {
357                if let (ASTRepr::Constant(a), ASTRepr::Constant(b)) =
358                    (left.as_ref(), right.as_ref())
359                {
360                    if b.abs() > f64::EPSILON {
361                        Ok(ASTRepr::Constant(a / b))
362                    } else {
363                        Ok(expr.clone()) // Avoid division by zero
364                    }
365                } else {
366                    Ok(expr.clone())
367                }
368            }
369            ASTRepr::Pow(base, exp) => {
370                if let (ASTRepr::Constant(a), ASTRepr::Constant(b)) = (base.as_ref(), exp.as_ref())
371                {
372                    Ok(ASTRepr::Constant(a.powf(*b)))
373                } else {
374                    Ok(expr.clone())
375                }
376            }
377            _ => Ok(expr.clone()),
378        }
379    }
380
381    /// Optimize double negation (e.g., --x -> x)
382    fn optimize_double_negation(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
383        match expr {
384            ASTRepr::Neg(inner) => {
385                if let ASTRepr::Neg(inner_inner) = inner.as_ref() {
386                    Ok(inner_inner.as_ref().clone())
387                } else {
388                    Ok(expr.clone())
389                }
390            }
391            _ => Ok(expr.clone()),
392        }
393    }
394
395    /// Optimize distributive property (e.g., a * (b + c) -> a * b + a * c)
396    fn optimize_distributive(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
397        match expr {
398            ASTRepr::Mul(left, right) => {
399                // Check for a * (b + c) pattern
400                if let ASTRepr::Add(b, c) = right.as_ref() {
401                    let ab = ASTRepr::Mul(left.clone(), b.clone());
402                    let ac = ASTRepr::Mul(left.clone(), c.clone());
403                    Ok(ASTRepr::Add(Box::new(ab), Box::new(ac)))
404                }
405                // Check for (a + b) * c pattern
406                else if let ASTRepr::Add(a, b) = left.as_ref() {
407                    let ac = ASTRepr::Mul(a.clone(), right.clone());
408                    let bc = ASTRepr::Mul(b.clone(), right.clone());
409                    Ok(ASTRepr::Add(Box::new(ac), Box::new(bc)))
410                } else {
411                    Ok(expr.clone())
412                }
413            }
414            _ => Ok(expr.clone()),
415        }
416    }
417
418    /// Convert a `ASTRepr` expression to egglog s-expression format
419    fn jit_repr_to_egglog(&self, expr: &ASTRepr<f64>) -> Result<String> {
420        match expr {
421            ASTRepr::Constant(value) => {
422                // Ensure floating point format for egglog
423                if value.fract() == 0.0 {
424                    Ok(format!("(Num {value:.1})"))
425                } else {
426                    Ok(format!("(Num {value})"))
427                }
428            }
429            ASTRepr::Variable(index) => Ok(format!("(Var {index})")),
430            ASTRepr::Add(left, right) => {
431                let left_s = self.jit_repr_to_egglog(left)?;
432                let right_s = self.jit_repr_to_egglog(right)?;
433                Ok(format!("(Add {left_s} {right_s})"))
434            }
435            ASTRepr::Sub(left, right) => {
436                let left_s = self.jit_repr_to_egglog(left)?;
437                let right_s = self.jit_repr_to_egglog(right)?;
438                Ok(format!("(Sub {left_s} {right_s})"))
439            }
440            ASTRepr::Mul(left, right) => {
441                let left_s = self.jit_repr_to_egglog(left)?;
442                let right_s = self.jit_repr_to_egglog(right)?;
443                Ok(format!("(Mul {left_s} {right_s})"))
444            }
445            ASTRepr::Div(left, right) => {
446                let left_s = self.jit_repr_to_egglog(left)?;
447                let right_s = self.jit_repr_to_egglog(right)?;
448                Ok(format!("(Div {left_s} {right_s})"))
449            }
450            ASTRepr::Pow(base, exp) => {
451                let base_s = self.jit_repr_to_egglog(base)?;
452                let exp_s = self.jit_repr_to_egglog(exp)?;
453                Ok(format!("(Pow {base_s} {exp_s})"))
454            }
455            ASTRepr::Neg(inner) => {
456                let inner_s = self.jit_repr_to_egglog(inner)?;
457                Ok(format!("(Neg {inner_s})"))
458            }
459            ASTRepr::Ln(inner) => {
460                let inner_s = self.jit_repr_to_egglog(inner)?;
461                Ok(format!("(Ln {inner_s})"))
462            }
463            ASTRepr::Exp(inner) => {
464                let inner_s = self.jit_repr_to_egglog(inner)?;
465                Ok(format!("(Exp {inner_s})"))
466            }
467            ASTRepr::Sin(inner) => {
468                let inner_s = self.jit_repr_to_egglog(inner)?;
469                Ok(format!("(Sin {inner_s})"))
470            }
471            ASTRepr::Cos(inner) => {
472                let inner_s = self.jit_repr_to_egglog(inner)?;
473                Ok(format!("(Cos {inner_s})"))
474            }
475            ASTRepr::Sqrt(inner) => {
476                let inner_s = self.jit_repr_to_egglog(inner)?;
477                Ok(format!("(Sqrt {inner_s})"))
478            }
479        }
480    }
481
482    /// Convert egglog expression string back to `ASTRepr`
483    fn egglog_to_jit_repr(&self, egglog_str: &str) -> Result<ASTRepr<f64>> {
484        // Parse s-expression back to ASTRepr
485        // This is a recursive parser for the egglog output format
486
487        let trimmed = egglog_str.trim();
488
489        if !trimmed.starts_with('(') {
490            return Err(MathCompileError::Optimization(
491                "Invalid egglog expression format".to_string(),
492            ));
493        }
494
495        // Remove outer parentheses
496        let inner = &trimmed[1..trimmed.len() - 1];
497        let parts: Vec<&str> = self.parse_sexpr_parts(inner)?;
498
499        if parts.is_empty() {
500            return Err(MathCompileError::Optimization(
501                "Empty egglog expression".to_string(),
502            ));
503        }
504
505        match parts[0] {
506            "Num" => {
507                if parts.len() != 2 {
508                    return Err(MathCompileError::Optimization(
509                        "Invalid Num expression".to_string(),
510                    ));
511                }
512                let value: f64 = parts[1].parse().map_err(|_| {
513                    MathCompileError::Optimization("Invalid number format".to_string())
514                })?;
515                Ok(ASTRepr::Constant(value))
516            }
517            "Var" => {
518                if parts.len() != 2 {
519                    return Err(MathCompileError::Optimization(
520                        "Invalid Var expression".to_string(),
521                    ));
522                }
523                // Remove quotes from variable name
524                let var_name = parts[1].trim_matches('"');
525                Ok(ASTRepr::Variable(var_name.parse::<usize>().unwrap_or(0)))
526            }
527            "Add" => {
528                if parts.len() != 3 {
529                    return Err(MathCompileError::Optimization(
530                        "Invalid Add expression".to_string(),
531                    ));
532                }
533                let left = self.egglog_to_jit_repr(parts[1])?;
534                let right = self.egglog_to_jit_repr(parts[2])?;
535                Ok(ASTRepr::Add(Box::new(left), Box::new(right)))
536            }
537            "Sub" => {
538                if parts.len() != 3 {
539                    return Err(MathCompileError::Optimization(
540                        "Invalid Sub expression".to_string(),
541                    ));
542                }
543                let left = self.egglog_to_jit_repr(parts[1])?;
544                let right = self.egglog_to_jit_repr(parts[2])?;
545                Ok(ASTRepr::Sub(Box::new(left), Box::new(right)))
546            }
547            "Mul" => {
548                if parts.len() != 3 {
549                    return Err(MathCompileError::Optimization(
550                        "Invalid Mul expression".to_string(),
551                    ));
552                }
553                let left = self.egglog_to_jit_repr(parts[1])?;
554                let right = self.egglog_to_jit_repr(parts[2])?;
555                Ok(ASTRepr::Mul(Box::new(left), Box::new(right)))
556            }
557            "Div" => {
558                if parts.len() != 3 {
559                    return Err(MathCompileError::Optimization(
560                        "Invalid Div expression".to_string(),
561                    ));
562                }
563                let left = self.egglog_to_jit_repr(parts[1])?;
564                let right = self.egglog_to_jit_repr(parts[2])?;
565                Ok(ASTRepr::Div(Box::new(left), Box::new(right)))
566            }
567            "Pow" => {
568                if parts.len() != 3 {
569                    return Err(MathCompileError::Optimization(
570                        "Invalid Pow expression".to_string(),
571                    ));
572                }
573                let base = self.egglog_to_jit_repr(parts[1])?;
574                let exp = self.egglog_to_jit_repr(parts[2])?;
575                Ok(ASTRepr::Pow(Box::new(base), Box::new(exp)))
576            }
577            "Neg" => {
578                if parts.len() != 2 {
579                    return Err(MathCompileError::Optimization(
580                        "Invalid Neg expression".to_string(),
581                    ));
582                }
583                let inner = self.egglog_to_jit_repr(parts[1])?;
584                Ok(ASTRepr::Neg(Box::new(inner)))
585            }
586            "Ln" => {
587                if parts.len() != 2 {
588                    return Err(MathCompileError::Optimization(
589                        "Invalid Ln expression".to_string(),
590                    ));
591                }
592                let inner = self.egglog_to_jit_repr(parts[1])?;
593                Ok(ASTRepr::Ln(Box::new(inner)))
594            }
595            "Exp" => {
596                if parts.len() != 2 {
597                    return Err(MathCompileError::Optimization(
598                        "Invalid Exp expression".to_string(),
599                    ));
600                }
601                let inner = self.egglog_to_jit_repr(parts[1])?;
602                Ok(ASTRepr::Exp(Box::new(inner)))
603            }
604            "Sin" => {
605                if parts.len() != 2 {
606                    return Err(MathCompileError::Optimization(
607                        "Invalid Sin expression".to_string(),
608                    ));
609                }
610                let inner = self.egglog_to_jit_repr(parts[1])?;
611                Ok(ASTRepr::Sin(Box::new(inner)))
612            }
613            "Cos" => {
614                if parts.len() != 2 {
615                    return Err(MathCompileError::Optimization(
616                        "Invalid Cos expression".to_string(),
617                    ));
618                }
619                let inner = self.egglog_to_jit_repr(parts[1])?;
620                Ok(ASTRepr::Cos(Box::new(inner)))
621            }
622            "Sqrt" => {
623                if parts.len() != 2 {
624                    return Err(MathCompileError::Optimization(
625                        "Invalid Sqrt expression".to_string(),
626                    ));
627                }
628                let inner = self.egglog_to_jit_repr(parts[1])?;
629                Ok(ASTRepr::Sqrt(Box::new(inner)))
630            }
631            _ => Err(MathCompileError::Optimization(format!(
632                "Unknown egglog operator: {}",
633                parts[0]
634            ))),
635        }
636    }
637
638    /// Parse s-expression parts (helper for parsing)
639    fn parse_sexpr_parts<'a>(&self, input: &'a str) -> Result<Vec<&'a str>> {
640        let mut parts = Vec::new();
641        let mut current_start = 0;
642        let mut paren_depth = 0;
643        let mut in_string = false;
644        let mut escape_next = false;
645
646        let chars: Vec<char> = input.chars().collect();
647        let mut i = 0;
648
649        while i < chars.len() {
650            let ch = chars[i];
651
652            if escape_next {
653                escape_next = false;
654                i += 1;
655                continue;
656            }
657
658            match ch {
659                '\\' if in_string => escape_next = true,
660                '"' => in_string = !in_string,
661                '(' if !in_string => paren_depth += 1,
662                ')' if !in_string => paren_depth -= 1,
663                ' ' | '\t' | '\n' | '\r' if !in_string && paren_depth == 0 => {
664                    if i > current_start {
665                        let part = input[current_start..i].trim();
666                        if !part.is_empty() {
667                            parts.push(part);
668                        }
669                    }
670                    // Skip whitespace
671                    while i + 1 < chars.len() && chars[i + 1].is_whitespace() {
672                        i += 1;
673                    }
674                    current_start = i + 1;
675                }
676                _ => {}
677            }
678
679            i += 1;
680        }
681
682        // Add the last part
683        if current_start < input.len() {
684            let part = input[current_start..].trim();
685            if !part.is_empty() {
686                parts.push(part);
687            }
688        }
689
690        Ok(parts)
691    }
692
693    /// Check if two expressions are structurally equal
694    fn expressions_structurally_equal(&self, a: &ASTRepr<f64>, b: &ASTRepr<f64>) -> bool {
695        match (a, b) {
696            (ASTRepr::Constant(a), ASTRepr::Constant(b)) => (a - b).abs() < f64::EPSILON,
697            (ASTRepr::Variable(a), ASTRepr::Variable(b)) => a == b,
698            (ASTRepr::Add(a1, a2), ASTRepr::Add(b1, b2))
699            | (ASTRepr::Sub(a1, a2), ASTRepr::Sub(b1, b2))
700            | (ASTRepr::Mul(a1, a2), ASTRepr::Mul(b1, b2))
701            | (ASTRepr::Div(a1, a2), ASTRepr::Div(b1, b2))
702            | (ASTRepr::Pow(a1, a2), ASTRepr::Pow(b1, b2)) => {
703                self.expressions_structurally_equal(a1, b1)
704                    && self.expressions_structurally_equal(a2, b2)
705            }
706            (ASTRepr::Neg(a), ASTRepr::Neg(b))
707            | (ASTRepr::Ln(a), ASTRepr::Ln(b))
708            | (ASTRepr::Exp(a), ASTRepr::Exp(b))
709            | (ASTRepr::Sin(a), ASTRepr::Sin(b))
710            | (ASTRepr::Cos(a), ASTRepr::Cos(b))
711            | (ASTRepr::Sqrt(a), ASTRepr::Sqrt(b)) => self.expressions_structurally_equal(a, b),
712            _ => false,
713        }
714    }
715
716    /// Optimize x + 0 patterns
717    fn optimize_add_zero(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
718        match expr {
719            ASTRepr::Add(left, right) => {
720                // Check for x + 0 patterns
721                if matches!(left.as_ref(), ASTRepr::Constant(x) if (x - 0.0).abs() < f64::EPSILON) {
722                    Ok(right.as_ref().clone())
723                } else if matches!(right.as_ref(), ASTRepr::Constant(x) if (x - 0.0).abs() < f64::EPSILON)
724                {
725                    Ok(left.as_ref().clone())
726                } else {
727                    Ok(expr.clone())
728                }
729            }
730            _ => Ok(expr.clone()),
731        }
732    }
733
734    /// Optimize x + x patterns
735    fn optimize_add_same(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
736        match expr {
737            ASTRepr::Add(left, right) => {
738                // Check for x + x patterns
739                if self.expressions_structurally_equal(left, right) {
740                    Ok(ASTRepr::Mul(Box::new(ASTRepr::Constant(2.0)), left.clone()))
741                } else {
742                    Ok(expr.clone())
743                }
744            }
745            _ => Ok(expr.clone()),
746        }
747    }
748
749    /// Optimize x * 0 patterns
750    fn optimize_mul_zero(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
751        match expr {
752            ASTRepr::Mul(left, right) => {
753                // Check for x * 0 patterns
754                if matches!(left.as_ref(), ASTRepr::Constant(x) if (x - 0.0).abs() < f64::EPSILON)
755                    || matches!(right.as_ref(), ASTRepr::Constant(x) if (x - 0.0).abs() < f64::EPSILON)
756                {
757                    Ok(ASTRepr::Constant(0.0))
758                } else {
759                    Ok(expr.clone())
760                }
761            }
762            _ => Ok(expr.clone()),
763        }
764    }
765
766    /// Optimize x * 1 patterns
767    fn optimize_mul_one(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
768        match expr {
769            ASTRepr::Mul(left, right) => {
770                // Check for x * 1 patterns
771                if matches!(left.as_ref(), ASTRepr::Constant(x) if (x - 1.0).abs() < f64::EPSILON) {
772                    Ok(right.as_ref().clone())
773                } else if matches!(right.as_ref(), ASTRepr::Constant(x) if (x - 1.0).abs() < f64::EPSILON)
774                {
775                    Ok(left.as_ref().clone())
776                } else {
777                    Ok(expr.clone())
778                }
779            }
780            _ => Ok(expr.clone()),
781        }
782    }
783
784    /// Optimize ln(exp(x)) patterns
785    fn optimize_ln_exp(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
786        match expr {
787            ASTRepr::Ln(inner) => {
788                // Check for ln(exp(x)) pattern
789                if let ASTRepr::Exp(exp_inner) = inner.as_ref() {
790                    Ok(exp_inner.as_ref().clone())
791                } else {
792                    Ok(expr.clone())
793                }
794            }
795            _ => Ok(expr.clone()),
796        }
797    }
798
799    /// Optimize exp(ln(x)) patterns
800    fn optimize_exp_ln(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
801        match expr {
802            ASTRepr::Exp(inner) => {
803                // Check for exp(ln(x)) pattern
804                if let ASTRepr::Ln(ln_inner) = inner.as_ref() {
805                    Ok(ln_inner.as_ref().clone())
806                } else {
807                    Ok(expr.clone())
808                }
809            }
810            _ => Ok(expr.clone()),
811        }
812    }
813
814    /// Optimize x^0 patterns
815    fn optimize_pow_zero(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
816        match expr {
817            ASTRepr::Pow(_base, exp) => {
818                // Check for x^0 pattern
819                if matches!(exp.as_ref(), ASTRepr::Constant(x) if (x - 0.0).abs() < f64::EPSILON) {
820                    Ok(ASTRepr::Constant(1.0))
821                } else {
822                    Ok(expr.clone())
823                }
824            }
825            _ => Ok(expr.clone()),
826        }
827    }
828
829    /// Optimize x^1 patterns
830    fn optimize_pow_one(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
831        match expr {
832            ASTRepr::Pow(base, exp) => {
833                // Check for x^1 pattern
834                if matches!(exp.as_ref(), ASTRepr::Constant(x) if (x - 1.0).abs() < f64::EPSILON) {
835                    Ok(base.as_ref().clone())
836                } else {
837                    Ok(expr.clone())
838                }
839            }
840            _ => Ok(expr.clone()),
841        }
842    }
843}
844
/// Stateless stand-in for the egglog optimizer, compiled only when the
/// `optimization` feature is disabled.
///
/// It holds no data; its `optimize` method hands back a clone of the input
/// expression without applying any rewrites.
#[cfg(not(feature = "optimization"))]
pub struct EgglogOptimizer;
848
849#[cfg(not(feature = "optimization"))]
850impl EgglogOptimizer {
851    pub fn new() -> Result<Self> {
852        Ok(Self)
853    }
854
855    pub fn optimize(&mut self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
856        // When egglog is not available, return the expression unchanged
857        Ok(expr.clone())
858    }
859}
860
861/// Helper function to create and use egglog optimizer
862pub fn optimize_with_egglog(expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
863    let mut optimizer = EgglogOptimizer::new()?;
864    optimizer.optimize(expr)
865}
866
#[cfg(test)]
mod tests {
    use super::*;
    use crate::final_tagless::{ASTEval, ASTMathExpr};

    #[test]
    fn test_egglog_optimizer_creation() {
        // Construction is expected to succeed both with and without the
        // `optimization` feature, so a single assertion covers both builds.
        assert!(EgglogOptimizer::new().is_ok());
    }

    #[test]
    fn test_jit_repr_to_egglog_conversion() {
        #[cfg(feature = "optimization")]
        {
            let optimizer = EgglogOptimizer::new().unwrap();

            // Test simple constant
            let expr = ASTRepr::Constant(42.0);
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Num 42.0)");

            // Test variable
            let expr = ASTRepr::Variable(0);
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Var 0)");

            // Test addition
            let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(1.0));
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Add (Var 0) (Num 1.0))");
        }
    }

    #[test]
    fn test_basic_optimization() {
        // x + 0
        let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(0.0));
        let result = optimize_with_egglog(&expr);

        #[cfg(feature = "optimization")]
        {
            // With egglog the rewrite rules run, but extraction may fail, so
            // the outcome (Ok or Err) is deliberately not pinned here: this
            // branch only checks that optimization returns without panicking.
            let _ = result;
        }

        #[cfg(not(feature = "optimization"))]
        {
            // Without egglog, the fallback returns a clone of the input, so
            // the top-level node must still be the original addition.
            assert!(matches!(result, Ok(ASTRepr::Add(_, _))));
        }
    }

    #[test]
    fn test_complex_expression_conversion() {
        #[cfg(feature = "optimization")]
        {
            let optimizer = EgglogOptimizer::new().unwrap();

            // Test complex expression: sin(x^2 + 1)
            let expr = ASTEval::sin(ASTEval::add(
                ASTEval::pow(ASTEval::var(0), ASTEval::constant(2.0)),
                ASTEval::constant(1.0),
            ));

            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert!(egglog_str.contains("Sin"));
            assert!(egglog_str.contains("Add"));
            assert!(egglog_str.contains("Pow"));
            assert!(egglog_str.contains("Var 0"));
        }
    }

    #[test]
    fn test_egglog_rules_application() {
        #[cfg(feature = "optimization")]
        {
            let mut optimizer = EgglogOptimizer::new().unwrap();

            // Test that egglog rules are working by trying to optimize x + 0
            let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(0.0));

            // Convert to egglog format
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Add (Var 0) (Num 0.0))");

            // The optimization might fail at extraction, but egglog should run
            let _result = optimizer.optimize(&expr);
            // We don't assert on the result since extraction is simplified
        }
    }

    #[test]
    fn test_sexpr_parsing() {
        #[cfg(feature = "optimization")]
        {
            let optimizer = EgglogOptimizer::new().unwrap();

            // Test parsing simple expressions
            let parts = optimizer.parse_sexpr_parts("Num 42.0").unwrap();
            assert_eq!(parts, vec!["Num", "42.0"]);

            let parts = optimizer.parse_sexpr_parts("Var 0").unwrap();
            assert_eq!(parts, vec!["Var", "0"]);

            let parts = optimizer
                .parse_sexpr_parts("Add (Num 1.0) (Num 2.0)")
                .unwrap();
            assert_eq!(parts, vec!["Add", "(Num 1.0)", "(Num 2.0)"]);
        }
    }
}