dslcompile_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{Expr, Ident, Token, parse_macro_input};
5
6// Direct egglog integration for compile-time optimization
7#[cfg(feature = "optimization")]
8use egglog::EGraph;
9
10/// Procedural macro for compile-time egglog optimization with direct code generation
11///
12/// This macro:
13/// 1. Parses the expression at compile time
14/// 2. Converts to AST representation  
15/// 3. Runs egglog equality saturation during macro expansion
16/// 4. Generates direct Rust expressions (no runtime dispatch, no enums)
17///
18/// Usage: `optimize_compile_time!(expr, [var1, var2, ...])`
19/// Returns: Direct Rust expression that compiles to optimal assembly
20///
21/// Example:
22/// ```rust
23/// use dslcompile_macros::optimize_compile_time;
24///
25/// // This is a compile-time optimization example
26/// // The macro would optimize mathematical expressions at compile time
27/// // For now, this is a placeholder that demonstrates the syntax
28/// # fn main() {
29/// #     // Placeholder test - the actual macro requires more complex setup
30/// #     let x = 1.0;
31/// #     let y = 2.0;
32/// #     let result = x + y; // This would be: optimize_compile_time!(x + y, [x, y]);
33/// #     assert_eq!(result, 3.0);
34/// # }
35/// ```
36#[proc_macro]
37pub fn optimize_compile_time(input: TokenStream) -> TokenStream {
38    let input = parse_macro_input!(input as OptimizeInput);
39
40    // Convert the expression to our internal AST representation
41    let ast = match expr_to_ast(&input.expr) {
42        Ok(ast) => ast,
43        Err(e) => {
44            return syn::Error::new_spanned(
45                &input.expr,
46                format!("Failed to parse expression: {e}"),
47            )
48            .to_compile_error()
49            .into();
50        }
51    };
52
53    // Run egglog optimization at compile time
54    let optimized_ast = run_compile_time_optimization(&ast);
55
56    // Generate direct Rust code
57    let generated_code = ast_to_rust_expr(&optimized_ast, &input.vars);
58
59    // Return the optimized expression
60    quote! {
61        {
62            #generated_code
63        }
64    }
65    .into()
66}
67
68/// Input structure for the macro
69struct OptimizeInput {
70    expr: Expr,
71    vars: Vec<Ident>,
72}
73
74impl syn::parse::Parse for OptimizeInput {
75    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
76        let expr = input.parse()?;
77        input.parse::<Token![,]>()?;
78
79        let content;
80        syn::bracketed!(content in input);
81
82        let vars = content
83            .parse_terminated(Ident::parse, Token![,])?
84            .into_iter()
85            .collect();
86
87        Ok(OptimizeInput { expr, vars })
88    }
89}
90
/// Expression tree used while the procedural macro parses and optimizes its input.
#[derive(Clone, Debug, PartialEq)]
enum CompileTimeAST {
    /// Positional variable; the index refers to the macro's variable list.
    Variable(usize),
    /// Numeric literal.
    Constant(f64),
    /// `left + right`
    Add(Box<CompileTimeAST>, Box<CompileTimeAST>),
    /// `left * right`
    Mul(Box<CompileTimeAST>, Box<CompileTimeAST>),
    /// `left - right`
    Sub(Box<CompileTimeAST>, Box<CompileTimeAST>),
    /// `sin(inner)`
    Sin(Box<CompileTimeAST>),
    /// `cos(inner)`
    Cos(Box<CompileTimeAST>),
    /// `exp(inner)`
    Exp(Box<CompileTimeAST>),
    /// `ln(inner)`
    Ln(Box<CompileTimeAST>),
    /// `base ^ exponent`
    Pow(Box<CompileTimeAST>, Box<CompileTimeAST>),
}

impl CompileTimeAST {
    /// Renders this tree as an egglog s-expression.
    ///
    /// Variables become `(Var "x<index>")`, constants become `(Num <value>)`
    /// (whole numbers are printed with one decimal place), and subtraction is
    /// rewritten as `(Add left (Neg right))`.
    #[cfg(feature = "optimization")]
    fn to_egglog(&self) -> String {
        /// `(head arg)`
        fn one(head: &str, arg: &CompileTimeAST) -> String {
            format!("({head} {})", arg.to_egglog())
        }

        /// `(head lhs rhs)`
        fn two(head: &str, lhs: &CompileTimeAST, rhs: &CompileTimeAST) -> String {
            format!("({head} {} {})", lhs.to_egglog(), rhs.to_egglog())
        }

        match self {
            Self::Variable(id) => format!("(Var \"x{id}\")"),
            // Whole numbers get an explicit `.0`.
            Self::Constant(val) if val.fract() == 0.0 => format!("(Num {val:.1})"),
            Self::Constant(val) => format!("(Num {val})"),
            Self::Add(lhs, rhs) => two("Add", lhs, rhs),
            Self::Mul(lhs, rhs) => two("Mul", lhs, rhs),
            // Canonical form: a - b is expressed as a + (-b).
            Self::Sub(lhs, rhs) => {
                format!("(Add {} (Neg {}))", lhs.to_egglog(), rhs.to_egglog())
            }
            Self::Sin(arg) => one("Sin", arg),
            Self::Cos(arg) => one("Cos", arg),
            Self::Exp(arg) => one("Exp", arg),
            Self::Ln(arg) => one("Ln", arg),
            Self::Pow(base, exponent) => two("Pow", base, exponent),
        }
    }
}
147
148/// Convert Rust expression to our internal AST representation
149fn expr_to_ast(expr: &Expr) -> Result<CompileTimeAST, String> {
150    match expr {
151        // Method calls like var::<0>().sin().add(...)
152        Expr::MethodCall(method_call) => {
153            let receiver_ast = expr_to_ast(&method_call.receiver)?;
154
155            match method_call.method.to_string().as_str() {
156                "add" => {
157                    if method_call.args.len() != 1 {
158                        return Err("add() requires exactly one argument".to_string());
159                    }
160                    let arg_ast = expr_to_ast(&method_call.args[0])?;
161                    Ok(CompileTimeAST::Add(
162                        Box::new(receiver_ast),
163                        Box::new(arg_ast),
164                    ))
165                }
166                "sub" => {
167                    if method_call.args.len() != 1 {
168                        return Err("sub() requires exactly one argument".to_string());
169                    }
170                    let arg_ast = expr_to_ast(&method_call.args[0])?;
171                    Ok(CompileTimeAST::Sub(
172                        Box::new(receiver_ast),
173                        Box::new(arg_ast),
174                    ))
175                }
176                "mul" => {
177                    if method_call.args.len() != 1 {
178                        return Err("mul() requires exactly one argument".to_string());
179                    }
180                    let arg_ast = expr_to_ast(&method_call.args[0])?;
181                    Ok(CompileTimeAST::Mul(
182                        Box::new(receiver_ast),
183                        Box::new(arg_ast),
184                    ))
185                }
186                "div" => {
187                    if method_call.args.len() != 1 {
188                        return Err("div() requires exactly one argument".to_string());
189                    }
190                    let arg_ast = expr_to_ast(&method_call.args[0])?;
191                    // Convert division to multiplication by reciprocal: a / b = a * b^(-1)
192                    Ok(CompileTimeAST::Mul(
193                        Box::new(receiver_ast),
194                        Box::new(CompileTimeAST::Pow(
195                            Box::new(arg_ast),
196                            Box::new(CompileTimeAST::Constant(-1.0)),
197                        )),
198                    ))
199                }
200                "pow" => {
201                    if method_call.args.len() != 1 {
202                        return Err("pow() requires exactly one argument".to_string());
203                    }
204                    let arg_ast = expr_to_ast(&method_call.args[0])?;
205                    Ok(CompileTimeAST::Pow(
206                        Box::new(receiver_ast),
207                        Box::new(arg_ast),
208                    ))
209                }
210                "sin" => {
211                    if !method_call.args.is_empty() {
212                        return Err("sin() takes no arguments".to_string());
213                    }
214                    Ok(CompileTimeAST::Sin(Box::new(receiver_ast)))
215                }
216                "cos" => {
217                    if !method_call.args.is_empty() {
218                        return Err("cos() takes no arguments".to_string());
219                    }
220                    Ok(CompileTimeAST::Cos(Box::new(receiver_ast)))
221                }
222                "exp" => {
223                    if !method_call.args.is_empty() {
224                        return Err("exp() takes no arguments".to_string());
225                    }
226                    Ok(CompileTimeAST::Exp(Box::new(receiver_ast)))
227                }
228                "ln" => {
229                    if !method_call.args.is_empty() {
230                        return Err("ln() takes no arguments".to_string());
231                    }
232                    Ok(CompileTimeAST::Ln(Box::new(receiver_ast)))
233                }
234                "sqrt" => {
235                    if !method_call.args.is_empty() {
236                        return Err("sqrt() takes no arguments".to_string());
237                    }
238                    // Convert sqrt to power of 0.5: sqrt(x) = x^0.5
239                    Ok(CompileTimeAST::Pow(
240                        Box::new(receiver_ast),
241                        Box::new(CompileTimeAST::Constant(0.5)),
242                    ))
243                }
244                "neg" => {
245                    if !method_call.args.is_empty() {
246                        return Err("neg() takes no arguments".to_string());
247                    }
248                    // Convert negation to multiplication by -1: -x = (-1) * x
249                    Ok(CompileTimeAST::Mul(
250                        Box::new(CompileTimeAST::Constant(-1.0)),
251                        Box::new(receiver_ast),
252                    ))
253                }
254                _ => Err(format!("Unknown method: {}", method_call.method)),
255            }
256        }
257
258        // Function calls like var::<0>() or constant(1.0)
259        Expr::Call(call) => {
260            if let Expr::Path(path) = &*call.func {
261                if let Some(segment) = path.path.segments.last() {
262                    match segment.ident.to_string().as_str() {
263                        "var" => {
264                            // Extract the const generic parameter
265                            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
266                                if let Some(syn::GenericArgument::Const(const_expr)) =
267                                    args.args.first()
268                                {
269                                    if let Expr::Lit(syn::ExprLit {
270                                        lit: syn::Lit::Int(lit_int),
271                                        ..
272                                    }) = const_expr
273                                    {
274                                        let var_id: usize = lit_int
275                                            .base10_parse()
276                                            .map_err(|_| "Invalid variable ID".to_string())?;
277                                        return Ok(CompileTimeAST::Variable(var_id));
278                                    }
279                                }
280                            }
281                            Err("Invalid var::<ID>() syntax".to_string())
282                        }
283                        "constant" => {
284                            if call.args.len() != 1 {
285                                return Err("constant() requires exactly one argument".to_string());
286                            }
287
288                            // Handle different types of constant arguments
289                            match &call.args[0] {
290                                // Direct float literal: constant(1.0)
291                                Expr::Lit(syn::ExprLit { lit: syn::Lit::Float(lit_float), .. }) => {
292                                    let value: f64 = lit_float.base10_parse()
293                                        .map_err(|_| "Invalid float literal".to_string())?;
294                                    Ok(CompileTimeAST::Constant(value))
295                                }
296                                // Direct int literal: constant(1)
297                                Expr::Lit(syn::ExprLit { lit: syn::Lit::Int(lit_int), .. }) => {
298                                    let value: f64 = lit_int.base10_parse::<i64>()
299                                        .map_err(|_| "Invalid int literal".to_string())? as f64;
300                                    Ok(CompileTimeAST::Constant(value))
301                                }
302                                // Unary expression: constant(*c) or constant(-1.0)
303                                Expr::Unary(unary) => {
304                                    match unary.op {
305                                        syn::UnOp::Deref(_) => {
306                                            // Handle *c - this is a dereference of a variable
307                                            // For macro purposes, we'll treat this as a runtime constant
308                                            // that gets evaluated when the macro is expanded
309                                            Err("constant() with variable dereference not supported in compile-time optimization".to_string())
310                                        }
311                                        syn::UnOp::Neg(_) => {
312                                            // Handle -1.0 or -1
313                                            match &*unary.expr {
314                                                Expr::Lit(syn::ExprLit { lit: syn::Lit::Float(lit_float), .. }) => {
315                                                    let value: f64 = lit_float.base10_parse()
316                                                        .map_err(|_| "Invalid float literal".to_string())?;
317                                                    Ok(CompileTimeAST::Constant(-value))
318                                                }
319                                                Expr::Lit(syn::ExprLit { lit: syn::Lit::Int(lit_int), .. }) => {
320                                                    let value: f64 = lit_int.base10_parse::<i64>()
321                                                        .map_err(|_| "Invalid int literal".to_string())? as f64;
322                                                    Ok(CompileTimeAST::Constant(-value))
323                                                }
324                                                _ => Err("constant() with complex negative expression not supported".to_string())
325                                            }
326                                        }
327                                        _ => Err("constant() with unsupported unary operator".to_string())
328                                    }
329                                }
330                                // Variable or other complex expression
331                                _ => Err("constant() argument must be a numeric literal (variables not supported in compile-time optimization)".to_string())
332                            }
333                        }
334                        _ => Err(format!("Unknown function: {}", segment.ident)),
335                    }
336                } else {
337                    Err("Invalid function call".to_string())
338                }
339            } else {
340                Err("Complex function calls not supported".to_string())
341            }
342        }
343
344        _ => Err("Unsupported expression type".to_string()),
345    }
346}
347
/// Simplifies `ast` with the egglog engine while the macro is being expanded.
///
/// The tree is bound to a name inside a freshly created e-graph, the rewrite
/// rules are run for a fixed number of iterations, and the extracted term is
/// parsed back into a `CompileTimeAST`. If any step fails (engine setup,
/// binding, running, extraction or parsing) the original tree is returned
/// unchanged, so a failure here never fails the macro.
#[cfg(feature = "optimization")]
fn run_compile_time_optimization(ast: &CompileTimeAST) -> CompileTimeAST {
    let attempt = || -> Option<CompileTimeAST> {
        // Engine preloaded with the mathematical rewrite rules.
        let mut engine = create_egglog_with_math_rules().ok()?;

        // Bind the expression under a fixed name so it can be extracted later.
        let binding = "expr_0";
        let term = ast.to_egglog();
        let bind_command = format!("(let {binding} {term})");
        engine.parse_and_run_program(None, &bind_command).ok()?;

        // STRICT iteration limit to prevent infinite expansion.
        engine.parse_and_run_program(None, "(run 3)").ok()?;

        // Ask for the best representative and turn it back into our AST.
        let extract_command = format!("(extract {binding})");
        let report = engine
            .parse_and_run_program(None, &extract_command)
            .ok()?;
        parse_egglog_result(&report.join("\n")).ok()
    };

    attempt().unwrap_or_else(|| ast.clone())
}
383
/// Creates an egglog e-graph preloaded with the `Math` datatype and a set of
/// simplification-only rewrite rules.
///
/// Every constructor used by a rule must be declared in the `Math` datatype,
/// and every constructor that can appear in an extracted term must be
/// understood by `parse_sexpr`. For that reason there is no `Sqrt` rule:
/// square roots stay in their `(Pow x (Num 0.5))` form. A rule mentioning an
/// undeclared constructor makes the whole program fail to load, which in turn
/// makes every optimization silently fall back to the unoptimized expression.
///
/// # Errors
///
/// Returns the engine's error message, prefixed with
/// `Failed to initialize egglog:`, when the program cannot be loaded.
#[cfg(feature = "optimization")]
fn create_egglog_with_math_rules() -> Result<EGraph, String> {
    let mut egraph = EGraph::default();

    // Load a SAFE mathematical optimization program (no infinite expansion)
    let program = r"
; Mathematical expression datatype
(datatype Math
  (Num f64)
  (Var String)
  (Add Math Math)
  (Mul Math Math)
  (Neg Math)
  (Pow Math Math)
  (Ln Math)
  (Exp Math)
  (Sin Math)
  (Cos Math))

; SAFE SIMPLIFICATION RULES (no expansion)
; Identity rules
(rewrite (Add a (Num 0.0)) a)
(rewrite (Add (Num 0.0) a) a)
(rewrite (Mul a (Num 1.0)) a)
(rewrite (Mul (Num 1.0) a) a)
(rewrite (Mul a (Num 0.0)) (Num 0.0))
(rewrite (Mul (Num 0.0) a) (Num 0.0))
(rewrite (Pow a (Num 0.0)) (Num 1.0))
(rewrite (Pow a (Num 1.0)) a)

; SAFE transcendental identities (only simplifying)
(rewrite (Ln (Exp x)) x)
; Remove the problematic expansion rule: (rewrite (Exp (Add a b)) (Mul (Exp a) (Exp b)))

; SAFE specific patterns (no general commutativity/associativity)
(rewrite (Ln (Mul (Exp a) (Exp b))) (Add a b))

; Double negation
(rewrite (Neg (Neg x)) x)

; Power simplifications
(rewrite (Pow (Exp x) y) (Exp (Mul x y)))
; NOTE: no rewrite to Sqrt here. Sqrt is not a constructor of Math and cannot be
; parsed back into the macro AST, so square roots remain (Pow x (Num 0.5)).
";

    egraph
        .parse_and_run_program(None, program)
        .map_err(|e| format!("Failed to initialize egglog: {e}"))?;

    Ok(egraph)
}
436
/// Converts the text produced by an egglog `extract` command into a `CompileTimeAST`.
///
/// Surrounding whitespace is ignored; the remaining text must be one
/// parenthesized s-expression.
#[cfg(feature = "optimization")]
fn parse_egglog_result(output: &str) -> Result<CompileTimeAST, String> {
    parse_sexpr(output.trim())
}
443
/// Parses one parenthesized s-expression into a `CompileTimeAST`.
///
/// The text between the outer parentheses is split into a head symbol and its
/// operands; nested operands are parsed recursively. `Neg` has no dedicated
/// AST node and is turned into a multiplication by `-1`.
///
/// # Errors
///
/// Returns a message when the text is not wrapped in parentheses, is empty,
/// has an unknown head symbol, has the wrong number of operands, or contains
/// a malformed number or variable name.
#[cfg(feature = "optimization")]
fn parse_sexpr(s: &str) -> Result<CompileTimeAST, String> {
    let form = s.trim();
    if !(form.starts_with('(') && form.ends_with(')')) {
        return Err("Invalid s-expression format".to_string());
    }

    // Strip the outer parentheses and split what is left at the top level.
    let parts = tokenize_sexpr(&form[1..form.len() - 1])?;
    let (head, operands) = match parts.split_first() {
        Some(split) => split,
        None => return Err("Empty s-expression".to_string()),
    };

    // The single operand of `label`, parsed and boxed.
    let unary = |label: &str| -> Result<Box<CompileTimeAST>, String> {
        match operands {
            [only] => parse_sexpr(only).map(Box::new),
            _ => Err(format!("{label} requires exactly one argument")),
        }
    };
    // Both operands of `label`, parsed left to right and boxed.
    let binary = |label: &str| -> Result<(Box<CompileTimeAST>, Box<CompileTimeAST>), String> {
        match operands {
            [first, second] => {
                let lhs = parse_sexpr(first)?;
                let rhs = parse_sexpr(second)?;
                Ok((Box::new(lhs), Box::new(rhs)))
            }
            _ => Err(format!("{label} requires exactly two arguments")),
        }
    };

    match head.as_str() {
        "Num" => match operands {
            [text] => text
                .parse::<f64>()
                .map(CompileTimeAST::Constant)
                .map_err(|_| "Invalid number format".to_string()),
            _ => Err("Num requires exactly one argument".to_string()),
        },
        "Var" => match operands {
            [quoted] => {
                // Variable names look like "x0"; the digits are the index.
                let digits = quoted
                    .trim_matches('"')
                    .strip_prefix('x')
                    .ok_or_else(|| "Invalid variable name format".to_string())?;
                digits
                    .parse::<usize>()
                    .map(CompileTimeAST::Variable)
                    .map_err(|_| "Invalid variable index".to_string())
            }
            _ => Err("Var requires exactly one argument".to_string()),
        },
        "Add" => binary("Add").map(|(lhs, rhs)| CompileTimeAST::Add(lhs, rhs)),
        "Mul" => binary("Mul").map(|(lhs, rhs)| CompileTimeAST::Mul(lhs, rhs)),
        // Neg x is represented as (-1) * x
        "Neg" => unary("Neg")
            .map(|inner| CompileTimeAST::Mul(Box::new(CompileTimeAST::Constant(-1.0)), inner)),
        "Pow" => binary("Pow").map(|(base, exponent)| CompileTimeAST::Pow(base, exponent)),
        "Ln" => unary("Ln").map(CompileTimeAST::Ln),
        "Exp" => unary("Exp").map(CompileTimeAST::Exp),
        "Sin" => unary("Sin").map(CompileTimeAST::Sin),
        "Cos" => unary("Cos").map(CompileTimeAST::Cos),
        _ => Err(format!("Unknown function: {head}")),
    }
}
550
551/// Tokenize s-expression while respecting nested parentheses
552#[cfg(feature = "optimization")]
///
/// `s` is the text *between* the outer parentheses of one form. It is split at
/// top-level whitespace into atoms and complete parenthesized sub-forms; a
/// sub-form is returned as a single token including its own parentheses.
/// Parentheses and whitespace inside double quotes are treated as ordinary
/// characters. Backslash escapes inside quotes are not interpreted.
///
/// # Errors
///
/// Returns an error when a `)` has no matching `(`, when a `(` is never
/// closed, or when a double-quoted string is not terminated.
fn tokenize_sexpr(s: &str) -> Result<Vec<String>, String> {
    /// Moves the pending text into `tokens` (if it is not blank) and resets the buffer.
    fn flush(pending: &mut String, tokens: &mut Vec<String>) {
        let piece = pending.trim();
        if !piece.is_empty() {
            tokens.push(piece.to_string());
        }
        pending.clear();
    }

    let mut tokens = Vec::new();
    let mut current_token = String::new();
    let mut paren_depth: usize = 0;
    let mut in_quotes = false;

    for ch in s.chars() {
        match ch {
            '"' => {
                in_quotes = !in_quotes;
                current_token.push(ch);
            }
            '(' if !in_quotes => {
                // An atom directly followed by `(` ends here.
                if paren_depth == 0 {
                    flush(&mut current_token, &mut tokens);
                }
                paren_depth += 1;
                current_token.push(ch);
            }
            ')' if !in_quotes => {
                // A closing parenthesis without an opener is malformed input.
                paren_depth = paren_depth
                    .checked_sub(1)
                    .ok_or_else(|| "Unbalanced s-expression: unexpected ')'".to_string())?;
                current_token.push(ch);
                if paren_depth == 0 {
                    flush(&mut current_token, &mut tokens);
                }
            }
            // Whitespace separates tokens only at the top level and outside quotes.
            _ if ch.is_whitespace() && !in_quotes && paren_depth == 0 => {
                flush(&mut current_token, &mut tokens);
            }
            _ => current_token.push(ch),
        }
    }

    if in_quotes {
        return Err("Unbalanced s-expression: unterminated string".to_string());
    }
    if paren_depth != 0 {
        return Err("Unbalanced s-expression: missing ')'".to_string());
    }

    flush(&mut current_token, &mut tokens);
    Ok(tokens)
}
599
600/// Convert optimized AST to direct Rust expression
601fn ast_to_rust_expr(ast: &CompileTimeAST, vars: &[Ident]) -> TokenStream2 {
602    match ast {
603        CompileTimeAST::Constant(c) => {
604            quote! { #c }
605        }
606        CompileTimeAST::Variable(idx) => {
607            if *idx < vars.len() {
608                let var = &vars[*idx];
609                quote! { #var }
610            } else {
611                quote! { 0.0 /* undefined variable */ }
612            }
613        }
614        CompileTimeAST::Add(left, right) => {
615            let left_expr = ast_to_rust_expr(left, vars);
616            let right_expr = ast_to_rust_expr(right, vars);
617            quote! { (#left_expr + #right_expr) }
618        }
619        CompileTimeAST::Sub(left, right) => {
620            let left_expr = ast_to_rust_expr(left, vars);
621            let right_expr = ast_to_rust_expr(right, vars);
622            quote! { (#left_expr - #right_expr) }
623        }
624        CompileTimeAST::Mul(left, right) => {
625            let left_expr = ast_to_rust_expr(left, vars);
626            let right_expr = ast_to_rust_expr(right, vars);
627            quote! { (#left_expr * #right_expr) }
628        }
629        CompileTimeAST::Pow(base, exp) => {
630            let base_expr = ast_to_rust_expr_with_parens(base, vars);
631            let exp_expr = ast_to_rust_expr(exp, vars);
632            quote! { #base_expr.powf(#exp_expr) }
633        }
634        CompileTimeAST::Sin(inner) => {
635            let inner_expr = ast_to_rust_expr_with_parens(inner, vars);
636            quote! { #inner_expr.sin() }
637        }
638        CompileTimeAST::Cos(inner) => {
639            let inner_expr = ast_to_rust_expr_with_parens(inner, vars);
640            quote! { #inner_expr.cos() }
641        }
642        CompileTimeAST::Exp(inner) => {
643            let inner_expr = ast_to_rust_expr_with_parens(inner, vars);
644            quote! { #inner_expr.exp() }
645        }
646        CompileTimeAST::Ln(inner) => {
647            let inner_expr = ast_to_rust_expr_with_parens(inner, vars);
648            quote! { #inner_expr.ln() }
649        }
650    }
651}
652
653/// Generate expression with parentheses when needed for method calls
654fn ast_to_rust_expr_with_parens(ast: &CompileTimeAST, vars: &[Ident]) -> TokenStream2 {
655    match ast {
656        // Simple expressions don't need parentheses for method calls
657        CompileTimeAST::Constant(_) | CompileTimeAST::Variable(_) => ast_to_rust_expr(ast, vars),
658        // Complex expressions need parentheses
659        _ => {
660            let expr = ast_to_rust_expr(ast, vars);
661            quote! { (#expr) }
662        }
663    }
664}
665
/// Fallback optimization when egglog is not available
///
/// This variant is compiled only when the `optimization` feature is disabled.
/// It has the same signature as the egglog-backed variant and delegates to
/// `apply_basic_optimizations`, returning the simplified tree.
#[cfg(not(feature = "optimization"))]
fn run_compile_time_optimization(ast: &CompileTimeAST) -> CompileTimeAST {
    // Apply basic manual optimizations as fallback
    apply_basic_optimizations(ast)
}
672
673/// Apply basic optimization rules without egglog
674#[cfg(not(feature = "optimization"))]
675fn apply_basic_optimizations(ast: &CompileTimeAST) -> CompileTimeAST {
676    match ast {
677        // x + 0 -> x
678        CompileTimeAST::Add(left, right) => {
679            let left_opt = apply_basic_optimizations(left);
680            let right_opt = apply_basic_optimizations(right);
681
682            if let CompileTimeAST::Constant(0.0) = right_opt {
683                left_opt
684            } else if let CompileTimeAST::Constant(0.0) = left_opt {
685                right_opt
686            } else {
687                CompileTimeAST::Add(Box::new(left_opt), Box::new(right_opt))
688            }
689        }
690        // x * 1 -> x, x * 0 -> 0
691        CompileTimeAST::Mul(left, right) => {
692            let left_opt = apply_basic_optimizations(left);
693            let right_opt = apply_basic_optimizations(right);
694
695            if let CompileTimeAST::Constant(1.0) = right_opt {
696                left_opt
697            } else if let CompileTimeAST::Constant(1.0) = left_opt {
698                right_opt
699            } else if let CompileTimeAST::Constant(0.0) = right_opt {
700                CompileTimeAST::Constant(0.0)
701            } else if let CompileTimeAST::Constant(0.0) = left_opt {
702                CompileTimeAST::Constant(0.0)
703            } else {
704                CompileTimeAST::Mul(Box::new(left_opt), Box::new(right_opt))
705            }
706        }
707        // ln(exp(x)) -> x
708        CompileTimeAST::Ln(inner) => {
709            let inner_opt = apply_basic_optimizations(inner);
710            if let CompileTimeAST::Exp(exp_inner) = &inner_opt {
711                (**exp_inner).clone()
712            } else {
713                CompileTimeAST::Ln(Box::new(inner_opt))
714            }
715        }
716        // Recursively optimize other expressions
717        CompileTimeAST::Sub(left, right) => CompileTimeAST::Sub(
718            Box::new(apply_basic_optimizations(left)),
719            Box::new(apply_basic_optimizations(right)),
720        ),
721        CompileTimeAST::Pow(base, exp) => CompileTimeAST::Pow(
722            Box::new(apply_basic_optimizations(base)),
723            Box::new(apply_basic_optimizations(exp)),
724        ),
725        CompileTimeAST::Sin(inner) => {
726            CompileTimeAST::Sin(Box::new(apply_basic_optimizations(inner)))
727        }
728        CompileTimeAST::Cos(inner) => {
729            CompileTimeAST::Cos(Box::new(apply_basic_optimizations(inner)))
730        }
731        CompileTimeAST::Exp(inner) => {
732            CompileTimeAST::Exp(Box::new(apply_basic_optimizations(inner)))
733        }
734        // Leaf nodes
735        CompileTimeAST::Variable(_) | CompileTimeAST::Constant(_) => ast.clone(),
736    }
737}