evalexpr_jit/
expr.rs

1//! Expression module for representing mathematical expressions.
2//!
3//! This module defines the core expression types used to represent mathematical expressions
4//! in a form that supports both JIT compilation and automatic differentiation. The main types are:
5//!
6//! - `Expr`: An enum representing different kinds of mathematical expressions
7//! - `VarRef`: A struct containing metadata about variables in expressions
8//!
9//! The expression tree is built recursively using `Box<Expr>` for nested expressions and can be:
10//! - JIT compiled into machine code using Cranelift
11//! - Symbolically differentiated to compute derivatives
12//! - Evaluated efficiently at runtime
13//! - Simplified using algebraic rules
14//! - Modified by inserting replacement expressions
15//!
//! Supported operations include:
//! - Basic arithmetic (+, -, *, /) and negation
//! - Variables and constants
//! - Absolute value
//! - Exponentiation by an integer constant, a floating-point constant, or another expression
//! - Elementary functions (exp, ln, sqrt, sin, cos)
//! - Expression caching for optimization
23//!
24//! # Expression Tree Structure
//! The expression tree is built recursively with each node being one of:
//! - Leaf nodes: Constants and Variables
//! - Unary operations: Abs, Neg, Exp, Ln, Sqrt, Sin, Cos
//! - Binary operations: Add, Sub, Mul, Div, PowExpr (expression raised to an expression)
//! - Special nodes: Pow (integer exponent), PowFloat (floating-point exponent), Cached expressions
30//!
31//! # Symbolic Differentiation
32//! The derivative method implements symbolic differentiation by recursively applying
33//! calculus rules like:
34//! - Product rule
35//! - Quotient rule
36//! - Chain rule
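//! - Power rule (integer and floating-point constant exponents)
//! - General power rule for expression exponents (f^g)
//! - Special function derivatives (exp, ln, sqrt, sin, cos)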
39//!
40//! # Expression Simplification
41//! The simplify method performs algebraic simplifications including:
42//! - Constant folding (e.g. 2 + 3 → 5)
43//! - Identity rules (e.g. x + 0 → x, x * 1 → x)
44//! - Exponent rules (e.g. x^0 → 1, x^1 → x)
45//! - Expression caching for repeated subexpressions
46//! - Special function simplifications
47//!
48//! # Expression Modification
49//! The insert method allows replacing parts of an expression tree by:
50//! - Matching nodes using a predicate function
51//! - Replacing matched nodes with a new expression
52//! - Recursively traversing and rebuilding the tree
53
54use cranelift::prelude::*;
55use cranelift_codegen::ir::{immediates::Offset32, Value};
56use cranelift_module::Module;
57
58use crate::{errors::EquationError, operators};
59
60/// Represents a reference to a variable in an expression.
61///
62/// Contains metadata needed to generate code that loads the variable's value:
63/// - The variable's name as a string
64/// - A Cranelift Value representing the pointer to the input array
65/// - The variable's index in the input array
66#[derive(Debug, Clone, PartialEq)]
67pub struct VarRef {
68    pub name: String,
69    pub vec_ref: Value,
70    pub index: u32,
71}
72
73/// An expression tree node representing mathematical operations.
74///
75/// This enum represents different types of mathematical expressions that can be:
76/// - JIT compiled into machine code using Cranelift
77/// - Symbolically differentiated to compute derivatives
78/// - Simplified using algebraic rules
79/// - Modified by inserting replacement expressions
80///
81/// The expression tree is built recursively using Box<Expr> for nested expressions.
82#[derive(Debug, Clone, PartialEq)]
83pub enum Expr {
84    /// A constant floating point value
85    Const(f64),
86    /// A reference to a variable
87    Var(VarRef),
88    /// Addition of two expressions
89    Add(Box<Expr>, Box<Expr>),
90    /// Multiplication of two expressions
91    Mul(Box<Expr>, Box<Expr>),
92    /// Subtraction of two expressions
93    Sub(Box<Expr>, Box<Expr>),
94    /// Division of two expressions
95    Div(Box<Expr>, Box<Expr>),
96    /// Absolute value of an expression
97    Abs(Box<Expr>),
98    /// Exponentiation of an expression by an integer constant
99    Pow(Box<Expr>, i64),
100    /// Exponentiation of an expression by a floating point constant
101    PowFloat(Box<Expr>, f64),
102    /// Exponentiation of an expression by another expression
103    PowExpr(Box<Expr>, Box<Expr>),
104    /// Exponential function of an expression
105    Exp(Box<Expr>),
106    /// Natural logarithm of an expression
107    Ln(Box<Expr>),
108    /// Square root of an expression
109    Sqrt(Box<Expr>),
110    /// Sine of an expression (argument in radians)
111    Sin(Box<Expr>),
112    /// Cosine of an expression (argument in radians)
113    Cos(Box<Expr>),
114    /// Negation of an expression
115    Neg(Box<Expr>),
116    /// Cached expression with optional pre-computed value
117    Cached(Box<Expr>, Option<f64>),
118}
119
/// One instruction in the flattened, stack-based form of an expression.
///
/// Load instructions carry their operand; arithmetic and function instructions
/// carry none (or only a constant exponent) and work on values already on the
/// evaluation stack. The evaluator itself lives outside this type.
#[derive(Debug, Clone)]
pub enum LinearOp {
    /// Load the carried constant.
    LoadConst(f64),
    /// Load the input variable at the carried index.
    LoadVar(u32),
    /// Addition of two stack values.
    Add,
    /// Subtraction of two stack values.
    Sub,
    /// Multiplication of two stack values.
    Mul,
    /// Division of two stack values.
    Div,
    /// Absolute value of the top stack value.
    Abs,
    /// Negation of the top stack value.
    Neg,
    /// Power with the carried constant integer exponent.
    PowConst(i64),
    /// Power with the carried constant floating-point exponent.
    PowFloat(f64),
    /// Power with a computed (non-constant) exponent; carries no payload.
    PowExpr,
    /// Exponential `e^x` of the top stack value.
    Exp,
    /// Natural logarithm of the top stack value.
    Ln,
    /// Square root of the top stack value.
    Sqrt,
    /// Sine of the top stack value (argument in radians).
    Sin,
    /// Cosine of the top stack value (argument in radians).
    Cos,
}
156
157/// Flattened expression representation for efficient evaluation
158#[derive(Debug, Clone)]
159pub struct FlattenedExpr {
160    /// Linear sequence of operations
161    pub ops: Vec<LinearOp>,
162    /// Maximum variable index accessed
163    pub max_var_index: Option<u32>,
164    /// Pre-computed constant result (if expression is constant)
165    pub constant_result: Option<f64>,
166}
167
168impl Expr {
169    /// Pre-evaluates constants and caches variable loads for improved performance
170    pub fn pre_evaluate(
171        &self,
172        var_cache: &mut std::collections::HashMap<String, f64>,
173    ) -> Box<Expr> {
174        match self {
175            Expr::Const(_) => Box::new(self.clone()),
176
177            Expr::Var(var_ref) => {
178                // Check if we can pre-evaluate this variable
179                if let Some(&value) = var_cache.get(&var_ref.name) {
180                    Box::new(Expr::Const(value))
181                } else {
182                    Box::new(self.clone())
183                }
184            }
185
186            Expr::Add(left, right) => {
187                let l = left.pre_evaluate(var_cache);
188                let r = right.pre_evaluate(var_cache);
189                match (&*l, &*r) {
190                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a + b)),
191                    _ => Box::new(Expr::Add(l, r)),
192                }
193            }
194
195            Expr::Sub(left, right) => {
196                let l = left.pre_evaluate(var_cache);
197                let r = right.pre_evaluate(var_cache);
198                match (&*l, &*r) {
199                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a - b)),
200                    _ => Box::new(Expr::Sub(l, r)),
201                }
202            }
203
204            Expr::Mul(left, right) => {
205                let l = left.pre_evaluate(var_cache);
206                let r = right.pre_evaluate(var_cache);
207                match (&*l, &*r) {
208                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a * b)),
209                    _ => Box::new(Expr::Mul(l, r)),
210                }
211            }
212
213            Expr::Div(left, right) => {
214                let l = left.pre_evaluate(var_cache);
215                let r = right.pre_evaluate(var_cache);
216                match (&*l, &*r) {
217                    (Expr::Const(a), Expr::Const(b)) if *b != 0.0 => Box::new(Expr::Const(a / b)),
218                    _ => Box::new(Expr::Div(l, r)),
219                }
220            }
221
222            Expr::Abs(expr) => {
223                let e = expr.pre_evaluate(var_cache);
224                match &*e {
225                    Expr::Const(a) => Box::new(Expr::Const(a.abs())),
226                    _ => Box::new(Expr::Abs(e)),
227                }
228            }
229
230            Expr::Neg(expr) => {
231                let e = expr.pre_evaluate(var_cache);
232                match &*e {
233                    Expr::Const(a) => Box::new(Expr::Const(-a)),
234                    _ => Box::new(Expr::Neg(e)),
235                }
236            }
237
238            Expr::Pow(base, exp) => {
239                let b = base.pre_evaluate(var_cache);
240                match &*b {
241                    Expr::Const(a) => Box::new(Expr::Const(a.powi(*exp as i32))),
242                    _ => Box::new(Expr::Pow(b, *exp)),
243                }
244            }
245
246            Expr::PowFloat(base, exp) => {
247                let b = base.pre_evaluate(var_cache);
248                match &*b {
249                    Expr::Const(a) => Box::new(Expr::Const(a.powf(*exp))),
250                    _ => Box::new(Expr::PowFloat(b, *exp)),
251                }
252            }
253
254            Expr::PowExpr(base, exponent) => {
255                let b = base.pre_evaluate(var_cache);
256                let e = exponent.pre_evaluate(var_cache);
257                match (&*b, &*e) {
258                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a.powf(*b))),
259                    _ => Box::new(Expr::PowExpr(b, e)),
260                }
261            }
262
263            Expr::Exp(expr) => {
264                let e = expr.pre_evaluate(var_cache);
265                match &*e {
266                    Expr::Const(a) => Box::new(Expr::Const(a.exp())),
267                    _ => Box::new(Expr::Exp(e)),
268                }
269            }
270
271            Expr::Ln(expr) => {
272                let e = expr.pre_evaluate(var_cache);
273                match &*e {
274                    Expr::Const(a) if *a > 0.0 => Box::new(Expr::Const(a.ln())),
275                    _ => Box::new(Expr::Ln(e)),
276                }
277            }
278
279            Expr::Sqrt(expr) => {
280                let e = expr.pre_evaluate(var_cache);
281                match &*e {
282                    Expr::Const(a) if *a >= 0.0 => Box::new(Expr::Const(a.sqrt())),
283                    _ => Box::new(Expr::Sqrt(e)),
284                }
285            }
286
287            Expr::Sin(expr) => {
288                let e = expr.pre_evaluate(var_cache);
289                match &*e {
290                    Expr::Const(a) => Box::new(Expr::Const(a.sin())),
291                    _ => Box::new(Expr::Sin(e)),
292                }
293            }
294
295            Expr::Cos(expr) => {
296                let e = expr.pre_evaluate(var_cache);
297                match &*e {
298                    Expr::Const(a) => Box::new(Expr::Const(a.cos())),
299                    _ => Box::new(Expr::Cos(e)),
300                }
301            }
302
303            Expr::Cached(expr, _) => expr.pre_evaluate(var_cache),
304        }
305    }
306
307    /// Computes the symbolic derivative of this expression with respect to a variable.
308    ///
309    /// Recursively applies the rules of differentiation to build a new expression tree
310    /// representing the derivative. The rules implemented are:
311    /// - d/dx(c) = 0 for constants
312    /// - d/dx(x) = 1 for the variable we're differentiating with respect to
313    /// - d/dx(y) = 0 for other variables
314    /// - Sum rule: d/dx(f + g) = df/dx + dg/dx
315    /// - Product rule: d/dx(f * g) = f * dg/dx + g * df/dx
316    /// - Quotient rule: d/dx(f/g) = (g * df/dx - f * dg/dx) / g^2
317    /// - Chain rule for abs: d/dx|f| = f/|f| * df/dx
318    /// - Power rule: d/dx(f^n) = n * f^(n-1) * df/dx
319    /// - Chain rule for exp: d/dx(e^f) = e^f * df/dx
320    /// - Chain rule for ln: d/dx(ln(f)) = 1/f * df/dx
321    /// - Chain rule for sqrt: d/dx(sqrt(f)) = 1/(2*sqrt(f)) * df/dx
322    /// - Negation: d/dx(-f) = -(df/dx)
323    ///
324    /// # Arguments
325    /// * `with_respect_to` - The name of the variable to differentiate with respect to
326    ///
327    /// # Returns
328    /// A new expression tree representing the derivative
329    pub fn derivative(&self, with_respect_to: &str) -> Box<Expr> {
330        match self {
331            Expr::Const(_) => Box::new(Expr::Const(0.0)),
332
333            Expr::Var(var_ref) => {
334                if var_ref.name == with_respect_to {
335                    Box::new(Expr::Const(1.0))
336                } else {
337                    Box::new(Expr::Const(0.0))
338                }
339            }
340
341            Expr::Add(left, right) => {
342                // d/dx(f + g) = df/dx + dg/dx
343                Box::new(Expr::Add(
344                    left.derivative(with_respect_to),
345                    right.derivative(with_respect_to),
346                ))
347            }
348
349            Expr::Sub(left, right) => {
350                // d/dx(f - g) = df/dx - dg/dx
351                Box::new(Expr::Sub(
352                    left.derivative(with_respect_to),
353                    right.derivative(with_respect_to),
354                ))
355            }
356
357            Expr::Mul(left, right) => {
358                // d/dx(f * g) = f * dg/dx + g * df/dx
359                Box::new(Expr::Add(
360                    Box::new(Expr::Mul(left.clone(), right.derivative(with_respect_to))),
361                    Box::new(Expr::Mul(right.clone(), left.derivative(with_respect_to))),
362                ))
363            }
364
365            Expr::Div(left, right) => {
366                // d/dx(f/g) = (g * df/dx - f * dg/dx) / g^2
367                Box::new(Expr::Div(
368                    Box::new(Expr::Sub(
369                        Box::new(Expr::Mul(right.clone(), left.derivative(with_respect_to))),
370                        Box::new(Expr::Mul(left.clone(), right.derivative(with_respect_to))),
371                    )),
372                    Box::new(Expr::Pow(right.clone(), 2)),
373                ))
374            }
375
376            Expr::Abs(expr) => {
377                // d/dx|f| = f/|f| * df/dx
378                Box::new(Expr::Mul(
379                    Box::new(Expr::Div(expr.clone(), Box::new(Expr::Abs(expr.clone())))),
380                    expr.derivative(with_respect_to),
381                ))
382            }
383
384            Expr::Pow(base, exp) => {
385                // d/dx(f^n) = n * f^(n-1) * df/dx
386                Box::new(Expr::Mul(
387                    Box::new(Expr::Mul(
388                        Box::new(Expr::Const(*exp as f64)),
389                        Box::new(Expr::Pow(base.clone(), exp - 1)),
390                    )),
391                    base.derivative(with_respect_to),
392                ))
393            }
394
395            Expr::PowFloat(base, exp) => {
396                // d/dx(f^c) = c * f^(c-1) * df/dx
397                Box::new(Expr::Mul(
398                    Box::new(Expr::Mul(
399                        Box::new(Expr::Const(*exp)),
400                        Box::new(Expr::PowFloat(base.clone(), exp - 1.0)),
401                    )),
402                    base.derivative(with_respect_to),
403                ))
404            }
405
406            Expr::PowExpr(base, exponent) => {
407                // d/dx(f^g) = f^g * (g' * ln(f) + g * f'/f)
408                // Using the general power rule
409                Box::new(Expr::Mul(
410                    Box::new(Expr::PowExpr(base.clone(), exponent.clone())),
411                    Box::new(Expr::Add(
412                        Box::new(Expr::Mul(
413                            exponent.derivative(with_respect_to),
414                            Box::new(Expr::Ln(base.clone())),
415                        )),
416                        Box::new(Expr::Mul(
417                            exponent.clone(),
418                            Box::new(Expr::Div(base.derivative(with_respect_to), base.clone())),
419                        )),
420                    )),
421                ))
422            }
423
424            Expr::Exp(expr) => {
425                // d/dx(e^f) = e^f * df/dx
426                Box::new(Expr::Mul(
427                    Box::new(Expr::Exp(expr.clone())),
428                    expr.derivative(with_respect_to),
429                ))
430            }
431
432            Expr::Ln(expr) => {
433                // d/dx(ln(f)) = 1/f * df/dx
434                Box::new(Expr::Mul(
435                    Box::new(Expr::Div(Box::new(Expr::Const(1.0)), expr.clone())),
436                    expr.derivative(with_respect_to),
437                ))
438            }
439
440            Expr::Sqrt(expr) => {
441                // d/dx(sqrt(f)) = 1/(2*sqrt(f)) * df/dx
442                Box::new(Expr::Mul(
443                    Box::new(Expr::Div(
444                        Box::new(Expr::Const(1.0)),
445                        Box::new(Expr::Sqrt(expr.clone())),
446                    )),
447                    expr.derivative(with_respect_to),
448                ))
449            }
450
451            Expr::Sin(expr) => {
452                // d/dx(sin(f)) = cos(f) * df/dx
453                Box::new(Expr::Mul(
454                    Box::new(Expr::Cos(expr.clone())),
455                    expr.derivative(with_respect_to),
456                ))
457            }
458
459            Expr::Cos(expr) => {
460                // d/dx(cos(f)) = -sin(f) * df/dx
461                Box::new(Expr::Mul(
462                    Box::new(Expr::Neg(Box::new(Expr::Sin(expr.clone())))),
463                    expr.derivative(with_respect_to),
464                ))
465            }
466
467            Expr::Neg(expr) => {
468                // d/dx(-f) = -(df/dx)
469                Box::new(Expr::Neg(expr.derivative(with_respect_to)))
470            }
471
472            Expr::Cached(expr, _) => expr.derivative(with_respect_to),
473        }
474    }
475
476    /// Simplifies the expression by folding constants and applying basic algebraic rules.
477    ///
478    /// This method performs several types of algebraic simplifications:
479    ///
480    /// # Constant Folding
481    /// - Evaluates constant expressions: 2 + 3 → 5
482    /// - Simplifies operations with special constants: x * 0 → 0
483    ///
484    /// # Identity Rules
485    /// - Additive identity: x + 0 → x
486    /// - Multiplicative identity: x * 1 → x
487    /// - Division identity: x / 1 → x
488    /// - Division by self: x / x → 1
489    ///
490    /// # Exponent Rules
491    /// - Zero exponent: x^0 → 1 (except when x = 0)
492    /// - First power: x^1 → x
493    /// - Nested exponents: (x^a)^b → x^(a*b)
494    ///
495    /// # Special Function Simplification
496    /// - Absolute value: |-3| → 3, ||x|| → |x|
497    /// - Double negation: -(-x) → x
498    /// - Evaluates constant special functions: ln(1) → 0
499    ///
500    /// # Expression Caching
501    /// - Caches repeated subexpressions to avoid redundant computation
502    /// - Preserves existing cached values
503    ///
504    /// # Returns
505    /// A new simplified expression tree
506    pub fn simplify(&self) -> Box<Expr> {
507        match self {
508            // Base cases - constants and variables remain unchanged
509            Expr::Const(_) | Expr::Var(_) => Box::new(self.clone()),
510
511            Expr::Add(left, right) => {
512                let l = left.simplify();
513                let r = right.simplify();
514                match (&*l, &*r) {
515                    // Fold constants: 1 + 2 -> 3
516                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a + b)),
517                    // Identity: x + 0 -> x
518                    (expr, Expr::Const(0.0)) | (Expr::Const(0.0), expr) => Box::new(expr.clone()),
519                    // Combine like terms: c1*x + c2*x -> (c1+c2)*x
520                    (Expr::Mul(a1, x1), Expr::Mul(a2, x2)) if x1 == x2 => {
521                        let combined_coeff = Expr::Add(a1.clone(), a2.clone()).simplify();
522                        Box::new(Expr::Mul(combined_coeff, x1.clone()))
523                    }
524                    // Associativity: (x + c1) + c2 -> x + (c1 + c2)
525                    (Expr::Add(x, c1), c2)
526                        if matches!(**c1, Expr::Const(_)) && matches!(*c2, Expr::Const(_)) =>
527                    {
528                        Box::new(Expr::Add(
529                            x.clone(),
530                            Expr::Add(c1.clone(), Box::new(c2.clone())).simplify(),
531                        ))
532                    }
533                    _ => Box::new(Expr::Add(l, r)),
534                }
535            }
536
537            Expr::Sub(left, right) => {
538                let l = left.simplify();
539                let r = right.simplify();
540                match (&*l, &*r) {
541                    // Fold constants: 3 - 2 -> 1
542                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a - b)),
543                    // Identity: x - 0 -> x
544                    (expr, Expr::Const(0.0)) => Box::new(expr.clone()),
545                    // Zero: x - x -> 0
546                    (a, b) if a == b => Box::new(Expr::Const(0.0)),
547                    // Combine like terms: c1*x - c2*x -> (c1-c2)*x
548                    (Expr::Mul(a1, x1), Expr::Mul(a2, x2)) if x1 == x2 => {
549                        let combined_coeff = Expr::Sub(a1.clone(), a2.clone()).simplify();
550                        Box::new(Expr::Mul(combined_coeff, x1.clone()))
551                    }
552                    // Convert subtraction to addition: x - c -> x + (-c)
553                    (x, Expr::Const(c)) => {
554                        Box::new(Expr::Add(Box::new(x.clone()), Box::new(Expr::Const(-c))))
555                    }
556                    _ => Box::new(Expr::Sub(l, r)),
557                }
558            }
559
560            Expr::Mul(left, right) => {
561                let l = left.simplify();
562                let r = right.simplify();
563
564                // Common subexpression elimination
565                if l == r {
566                    return Box::new(Expr::Pow(l, 2)); // x * x -> x^2
567                }
568
569                match (&*l, &*r) {
570                    // Fold constants: 2 * 3 -> 6
571                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a * b)),
572                    // Zero property: x * 0 -> 0
573                    (Expr::Const(0.0), _) | (_, Expr::Const(0.0)) => Box::new(Expr::Const(0.0)),
574                    // Identity: x * 1 -> x
575                    (expr, Expr::Const(1.0)) | (Expr::Const(1.0), expr) => Box::new(expr.clone()),
576                    // Negative one: x * (-1) -> -x
577                    (expr, Expr::Const(-1.0)) | (Expr::Const(-1.0), expr) => {
578                        Box::new(Expr::Neg(Box::new(expr.clone())))
579                    }
580                    // Combine exponents: x^a * x^b -> x^(a+b)
581                    (Expr::Pow(b1, e1), Expr::Pow(b2, e2)) if b1 == b2 => {
582                        Box::new(Expr::Pow(b1.clone(), e1 + e2))
583                    }
584                    // Distribute constants: c * (x + y) -> c*x + c*y (only if beneficial)
585                    (Expr::Const(c), Expr::Add(x, y)) | (Expr::Add(x, y), Expr::Const(c))
586                        if c.abs() < 10.0 =>
587                    {
588                        Expr::Add(
589                            Box::new(Expr::Mul(Box::new(Expr::Const(*c)), x.clone())),
590                            Box::new(Expr::Mul(Box::new(Expr::Const(*c)), y.clone())),
591                        )
592                        .simplify()
593                    }
594                    // Strength reduction: x * 2 -> x + x (only for small integers)
595                    (expr, Expr::Const(2.0)) | (Expr::Const(2.0), expr) => {
596                        Box::new(Expr::Add(Box::new(expr.clone()), Box::new(expr.clone())))
597                    }
598                    // Associativity: (c1 * x) * c2 -> (c1 * c2) * x
599                    (Expr::Mul(c1, x), c2)
600                        if matches!(**c1, Expr::Const(_)) && matches!(*c2, Expr::Const(_)) =>
601                    {
602                        Box::new(Expr::Mul(
603                            Expr::Mul(c1.clone(), Box::new(c2.clone())).simplify(),
604                            x.clone(),
605                        ))
606                    }
607                    _ => Box::new(Expr::Mul(l, r)),
608                }
609            }
610
611            Expr::Div(left, right) => {
612                let l = left.simplify();
613                let r = right.simplify();
614                match (&*l, &*r) {
615                    // Fold constants: 6 / 2 -> 3
616                    (Expr::Const(a), Expr::Const(b)) if *b != 0.0 => Box::new(Expr::Const(a / b)),
617                    // Zero numerator: 0 / x -> 0
618                    (Expr::Const(0.0), _) => Box::new(Expr::Const(0.0)),
619                    // Identity: x / 1 -> x
620                    (expr, Expr::Const(1.0)) => Box::new(expr.clone()),
621                    // Division by negative one: x / (-1) -> -x
622                    (expr, Expr::Const(-1.0)) => Box::new(Expr::Neg(Box::new(expr.clone()))),
623                    // Identity: x / x -> 1
624                    (a, b) if a == b => Box::new(Expr::Const(1.0)),
625                    // Simplify exponents: x^a / x^b -> x^(a-b)
626                    (Expr::Pow(b1, e1), Expr::Pow(b2, e2)) if b1 == b2 => {
627                        Box::new(Expr::Pow(b1.clone(), e1 - e2))
628                    }
629                    // Convert division by constant to multiplication: x / c -> x * (1/c)
630                    (x, Expr::Const(c)) if *c != 0.0 && c.abs() > 1e-10 => Box::new(Expr::Mul(
631                        Box::new(x.clone()),
632                        Box::new(Expr::Const(1.0 / c)),
633                    )),
634                    // Simplify nested divisions: (x/y)/z -> x/(y*z)
635                    (Expr::Div(x, y), z) => Box::new(Expr::Div(
636                        x.clone(),
637                        Box::new(Expr::Mul(y.clone(), Box::new(z.clone()))),
638                    )),
639                    _ => Box::new(Expr::Div(l, r)),
640                }
641            }
642
643            Expr::Abs(expr) => {
644                let e = expr.simplify();
645                match &*e {
646                    // Fold constants: abs(3) -> 3
647                    Expr::Const(a) => Box::new(Expr::Const(a.abs())),
648                    // Nested abs: abs(abs(x)) -> abs(x)
649                    Expr::Abs(inner) => Box::new(Expr::Abs(inner.clone())),
650                    // abs(-x) -> abs(x)
651                    Expr::Neg(inner) => Box::new(Expr::Abs(inner.clone())),
652                    // abs(x^2) -> x^2 (even powers are always positive)
653                    Expr::Pow(_, exp) if exp % 2 == 0 => e,
654                    _ => Box::new(Expr::Abs(e)),
655                }
656            }
657
658            Expr::Pow(base, exp) => {
659                let b = base.simplify();
660                match (&*b, exp) {
661                    // x^0 -> 1 (including 0^0 = 1 by convention)
662                    (_, 0) => Box::new(Expr::Const(1.0)),
663                    // Fold constants: 2^3 -> 8
664                    (Expr::Const(a), exp) => Box::new(Expr::Const(a.powi(*exp as i32))),
665                    // Identity: x^1 -> x
666                    (expr, 1) => Box::new(expr.clone()),
667                    // Simplify negative exponents: x^(-n) -> 1/(x^n)
668                    (expr, exp) if *exp < 0 => Box::new(Expr::Div(
669                        Box::new(Expr::Const(1.0)),
670                        Box::new(Expr::Pow(Box::new(expr.clone()), -exp)),
671                    )),
672                    // Nested exponents: (x^a)^b -> x^(a*b)
673                    (Expr::Pow(inner_base, inner_exp), outer_exp) => {
674                        Box::new(Expr::Pow(inner_base.clone(), inner_exp * outer_exp))
675                    }
676                    // Power of product: (x*y)^n -> x^n * y^n (only for small n)
677                    (Expr::Mul(x, y), n) if *n >= 2 && *n <= 4 => Box::new(Expr::Mul(
678                        Box::new(Expr::Pow(x.clone(), *n)),
679                        Box::new(Expr::Pow(y.clone(), *n)),
680                    )),
681                    _ => Box::new(Expr::Pow(b, *exp)),
682                }
683            }
684
685            Expr::PowFloat(base, exp) => {
686                let b = base.simplify();
687                match (&*b, exp) {
688                    // x^0.0 -> 1
689                    (_, exp) if exp.abs() < 1e-10 => Box::new(Expr::Const(1.0)),
690                    // Fold constants: 2.0^3.5 -> result
691                    (Expr::Const(a), exp) => Box::new(Expr::Const(a.powf(*exp))),
692                    // Identity: x^1.0 -> x
693                    (expr, exp) if (exp - 1.0).abs() < 1e-10 => Box::new(expr.clone()),
694                    // Convert to integer power if possible
695                    (expr, exp) if exp.fract().abs() < 1e-10 => {
696                        Box::new(Expr::Pow(Box::new(expr.clone()), *exp as i64))
697                    }
698                    _ => Box::new(Expr::PowFloat(b, *exp)),
699                }
700            }
701
702            Expr::PowExpr(base, exponent) => {
703                let b = base.simplify();
704                let e = exponent.simplify();
705                match (&*b, &*e) {
706                    // Fold constants: 2^3 -> 8
707                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a.powf(*b))),
708                    // x^0 -> 1
709                    (_, Expr::Const(0.0)) => Box::new(Expr::Const(1.0)),
710                    // x^1 -> x
711                    (expr, Expr::Const(1.0)) => Box::new(expr.clone()),
712                    // Convert to simpler forms if exponent is constant
713                    (expr, Expr::Const(exp)) if exp.fract().abs() < 1e-10 => {
714                        Box::new(Expr::Pow(Box::new(expr.clone()), *exp as i64))
715                    }
716                    (expr, Expr::Const(exp)) => {
717                        Box::new(Expr::PowFloat(Box::new(expr.clone()), *exp))
718                    }
719                    _ => Box::new(Expr::PowExpr(b, e)),
720                }
721            }
722
723            Expr::Exp(expr) => {
724                let e = expr.simplify();
725                match &*e {
726                    // exp(0) -> 1
727                    Expr::Const(0.0) => Box::new(Expr::Const(1.0)),
728                    // Fold constants: exp(c) -> e^c
729                    Expr::Const(a) => Box::new(Expr::Const(a.exp())),
730                    // exp(ln(x)) -> x
731                    Expr::Ln(inner) => inner.clone(),
732                    // exp(x + y) -> exp(x) * exp(y)
733                    Expr::Add(x, y) => Box::new(Expr::Mul(
734                        Box::new(Expr::Exp(x.clone())),
735                        Box::new(Expr::Exp(y.clone())),
736                    )),
737                    _ => Box::new(Expr::Exp(e)),
738                }
739            }
740
741            Expr::Ln(expr) => {
742                let e = expr.simplify();
743                match &*e {
744                    // Fold constants: ln(c) -> ln(c)
745                    Expr::Const(a) if *a > 0.0 => Box::new(Expr::Const(a.ln())),
746                    // ln(1) -> 0
747                    Expr::Const(1.0) => Box::new(Expr::Const(0.0)),
748                    // ln(exp(x)) -> x
749                    Expr::Exp(inner) => inner.clone(),
750                    // ln(x*y) -> ln(x) + ln(y)
751                    Expr::Mul(x, y) => Box::new(Expr::Add(
752                        Box::new(Expr::Ln(x.clone())),
753                        Box::new(Expr::Ln(y.clone())),
754                    )),
755                    // ln(x/y) -> ln(x) - ln(y)
756                    Expr::Div(x, y) => Box::new(Expr::Sub(
757                        Box::new(Expr::Ln(x.clone())),
758                        Box::new(Expr::Ln(y.clone())),
759                    )),
760                    // ln(x^n) -> n * ln(x)
761                    Expr::Pow(x, n) => Box::new(Expr::Mul(
762                        Box::new(Expr::Const(*n as f64)),
763                        Box::new(Expr::Ln(x.clone())),
764                    )),
765                    _ => Box::new(Expr::Ln(e)),
766                }
767            }
768
769            Expr::Sqrt(expr) => {
770                let e = expr.simplify();
771                match &*e {
772                    // Fold constants: sqrt(c) -> sqrt(c)
773                    Expr::Const(a) if *a >= 0.0 => Box::new(Expr::Const(a.sqrt())),
774                    // sqrt(0) -> 0
775                    Expr::Const(0.0) => Box::new(Expr::Const(0.0)),
776                    // sqrt(1) -> 1
777                    Expr::Const(1.0) => Box::new(Expr::Const(1.0)),
778                    // sqrt(x^2) -> abs(x)
779                    Expr::Pow(x, 2) => Box::new(Expr::Abs(x.clone())),
780                    // sqrt(x*y) -> sqrt(x) * sqrt(y)
781                    Expr::Mul(x, y) => Box::new(Expr::Mul(
782                        Box::new(Expr::Sqrt(x.clone())),
783                        Box::new(Expr::Sqrt(y.clone())),
784                    )),
785                    _ => Box::new(Expr::Sqrt(e)),
786                }
787            }
788
789            Expr::Sin(expr) => {
790                let e = expr.simplify();
791                match &*e {
792                    // sin(0) -> 0
793                    Expr::Const(0.0) => Box::new(Expr::Const(0.0)),
794                    // Fold constants: sin(c) -> sin(c)
795                    Expr::Const(a) => Box::new(Expr::Const(a.sin())),
796                    _ => Box::new(Expr::Sin(e)),
797                }
798            }
799
800            Expr::Cos(expr) => {
801                let e = expr.simplify();
802                match &*e {
803                    // cos(0) -> 1
804                    Expr::Const(0.0) => Box::new(Expr::Const(1.0)),
805                    // Fold constants: cos(c) -> cos(c)
806                    Expr::Const(a) => Box::new(Expr::Const(a.cos())),
807                    _ => Box::new(Expr::Cos(e)),
808                }
809            }
810
811            Expr::Neg(expr) => {
812                let e = expr.simplify();
813                match &*e {
814                    // Fold constants: -3 -> -3
815                    Expr::Const(a) => Box::new(Expr::Const(-a)),
816                    // Double negation: -(-x) -> x
817                    Expr::Neg(inner) => inner.clone(),
818                    // Distribute negation: -(x + y) -> -x - y
819                    Expr::Add(x, y) => {
820                        Expr::Sub(Box::new(Expr::Neg(x.clone())), y.clone()).simplify()
821                    }
822                    // Distribute negation: -(x - y) -> -x + y
823                    Expr::Sub(x, y) => {
824                        Expr::Add(Box::new(Expr::Neg(x.clone())), y.clone()).simplify()
825                    }
826                    // Factor out negation: -(c*x) -> (-c)*x
827                    Expr::Mul(c, x) if matches!(**c, Expr::Const(_)) => {
828                        Expr::Mul(Box::new(Expr::Neg(c.clone())), x.clone()).simplify()
829                    }
830                    _ => Box::new(Expr::Neg(e)),
831                }
832            }
833
834            Expr::Cached(expr, cached_value) => {
835                if cached_value.is_some() {
836                    Box::new(self.clone())
837                } else {
838                    // Simplify the inner expression directly
839                    expr.simplify()
840                }
841            }
842        }
843    }
844
845    /// Inserts an expression by replacing nodes that match a predicate.
846    ///
847    /// Recursively traverses the expression tree and replaces any nodes that match
848    /// the given predicate with the replacement expression. This allows for targeted
849    /// modifications of the expression tree.
850    ///
851    /// # Arguments
852    /// * `predicate` - A closure that determines which nodes to replace
853    /// * `replacement` - The expression to insert where the predicate matches
854    ///
855    /// # Returns
856    /// A new expression tree with the replacements applied
857    pub fn insert<F>(&self, predicate: F, replacement: &Expr) -> Box<Expr>
858    where
859        F: Fn(&Expr) -> bool + Clone,
860    {
861        if predicate(self) {
862            Box::new(replacement.clone())
863        } else {
864            match self {
865                Expr::Const(_) | Expr::Var(_) => Box::new(self.clone()),
866                Expr::Add(left, right) => Box::new(Expr::Add(
867                    left.insert(predicate.clone(), replacement),
868                    right.insert(predicate, replacement),
869                )),
870                Expr::Mul(left, right) => Box::new(Expr::Mul(
871                    left.insert(predicate.clone(), replacement),
872                    right.insert(predicate, replacement),
873                )),
874                Expr::Sub(left, right) => Box::new(Expr::Sub(
875                    left.insert(predicate.clone(), replacement),
876                    right.insert(predicate, replacement),
877                )),
878                Expr::Div(left, right) => Box::new(Expr::Div(
879                    left.insert(predicate.clone(), replacement),
880                    right.insert(predicate, replacement),
881                )),
882                Expr::Abs(expr) => Box::new(Expr::Abs(expr.insert(predicate, replacement))),
883                Expr::Pow(base, exp) => {
884                    Box::new(Expr::Pow(base.insert(predicate, replacement), *exp))
885                }
886                Expr::PowFloat(base, exp) => {
887                    Box::new(Expr::PowFloat(base.insert(predicate, replacement), *exp))
888                }
889                Expr::PowExpr(base, exponent) => Box::new(Expr::PowExpr(
890                    base.insert(predicate.clone(), replacement),
891                    exponent.insert(predicate, replacement),
892                )),
893                Expr::Exp(expr) => Box::new(Expr::Exp(expr.insert(predicate, replacement))),
894                Expr::Ln(expr) => Box::new(Expr::Ln(expr.insert(predicate, replacement))),
895                Expr::Sqrt(expr) => Box::new(Expr::Sqrt(expr.insert(predicate, replacement))),
896                Expr::Sin(expr) => Box::new(Expr::Sin(expr.insert(predicate, replacement))),
897                Expr::Cos(expr) => Box::new(Expr::Cos(expr.insert(predicate, replacement))),
898                Expr::Neg(expr) => Box::new(Expr::Neg(expr.insert(predicate, replacement))),
899                Expr::Cached(expr, _) => {
900                    Box::new(Expr::Cached(expr.insert(predicate, replacement), None))
901                }
902            }
903        }
904    }
905
906    /// Converts expression tree to flattened linear operations for efficient evaluation.
907    ///
908    /// This optimization eliminates:
909    /// - Tree traversal overhead
910    /// - Function call overhead  
911    /// - Memory allocation in hot path
912    /// - Variable lookup overhead
913    ///
914    /// The result is a linear sequence of stack-based operations that can be
915    /// executed with minimal overhead.
916    pub fn flatten(&self) -> FlattenedExpr {
917        let mut ops = Vec::new();
918        let mut max_var_index = None;
919
920        // Check if entire expression is constant
921        if let Some(constant) = self.try_evaluate_constant() {
922            return FlattenedExpr {
923                ops: vec![LinearOp::LoadConst(constant)],
924                max_var_index: None,
925                constant_result: Some(constant),
926            };
927        }
928
929        self.flatten_recursive(&mut ops, &mut max_var_index);
930
931        FlattenedExpr {
932            ops,
933            max_var_index,
934            constant_result: None,
935        }
936    }
937
938    /// Tries to evaluate expression as constant (aggressive constant folding)
939    fn try_evaluate_constant(&self) -> Option<f64> {
940        match self {
941            Expr::Const(val) => Some(*val),
942            Expr::Var(_) => None,
943            Expr::Add(left, right) => {
944                Some(left.try_evaluate_constant()? + right.try_evaluate_constant()?)
945            }
946            Expr::Sub(left, right) => {
947                Some(left.try_evaluate_constant()? - right.try_evaluate_constant()?)
948            }
949            Expr::Mul(left, right) => {
950                Some(left.try_evaluate_constant()? * right.try_evaluate_constant()?)
951            }
952            Expr::Div(left, right) => {
953                let r = right.try_evaluate_constant()?;
954                if r.abs() < 1e-10 {
955                    return None;
956                }
957                Some(left.try_evaluate_constant()? / r)
958            }
959            Expr::Abs(expr) => Some(expr.try_evaluate_constant()?.abs()),
960            Expr::Neg(expr) => Some(-expr.try_evaluate_constant()?),
961            Expr::Pow(base, exp) => Some(base.try_evaluate_constant()?.powi(*exp as i32)),
962            Expr::PowFloat(base, exp) => Some(base.try_evaluate_constant()?.powf(*exp)),
963            Expr::PowExpr(base, exponent) => Some(
964                base.try_evaluate_constant()?
965                    .powf(exponent.try_evaluate_constant()?),
966            ),
967            Expr::Exp(expr) => Some(expr.try_evaluate_constant()?.exp()),
968            Expr::Ln(expr) => {
969                let val = expr.try_evaluate_constant()?;
970                if val <= 0.0 {
971                    return None;
972                }
973                Some(val.ln())
974            }
975            Expr::Sqrt(expr) => {
976                let val = expr.try_evaluate_constant()?;
977                if val < 0.0 {
978                    return None;
979                }
980                Some(val.sqrt())
981            }
982            Expr::Sin(expr) => Some(expr.try_evaluate_constant()?.sin()),
983            Expr::Cos(expr) => Some(expr.try_evaluate_constant()?.cos()),
984            Expr::Cached(expr, cached_value) => {
985                cached_value.or_else(|| expr.try_evaluate_constant())
986            }
987        }
988    }
989
990    /// Recursively flattens expression into linear operations
991    fn flatten_recursive(&self, ops: &mut Vec<LinearOp>, max_var_index: &mut Option<u32>) {
992        match self {
993            Expr::Const(val) => {
994                ops.push(LinearOp::LoadConst(*val));
995            }
996
997            Expr::Var(var_ref) => {
998                let index = var_ref.index;
999                *max_var_index = Some(max_var_index.unwrap_or(0).max(index));
1000                ops.push(LinearOp::LoadVar(index));
1001            }
1002
1003            Expr::Add(left, right) => {
1004                left.flatten_recursive(ops, max_var_index);
1005                right.flatten_recursive(ops, max_var_index);
1006                ops.push(LinearOp::Add);
1007            }
1008
1009            Expr::Sub(left, right) => {
1010                left.flatten_recursive(ops, max_var_index);
1011                right.flatten_recursive(ops, max_var_index);
1012                ops.push(LinearOp::Sub);
1013            }
1014
1015            Expr::Mul(left, right) => {
1016                left.flatten_recursive(ops, max_var_index);
1017                right.flatten_recursive(ops, max_var_index);
1018                ops.push(LinearOp::Mul);
1019            }
1020
1021            Expr::Div(left, right) => {
1022                left.flatten_recursive(ops, max_var_index);
1023                right.flatten_recursive(ops, max_var_index);
1024                ops.push(LinearOp::Div);
1025            }
1026
1027            Expr::Abs(expr) => {
1028                expr.flatten_recursive(ops, max_var_index);
1029                ops.push(LinearOp::Abs);
1030            }
1031
1032            Expr::Neg(expr) => {
1033                expr.flatten_recursive(ops, max_var_index);
1034                ops.push(LinearOp::Neg);
1035            }
1036
1037            Expr::Pow(base, exp) => {
1038                base.flatten_recursive(ops, max_var_index);
1039                ops.push(LinearOp::PowConst(*exp));
1040            }
1041
1042            Expr::PowFloat(base, exp) => {
1043                base.flatten_recursive(ops, max_var_index);
1044                ops.push(LinearOp::PowFloat(*exp));
1045            }
1046
1047            Expr::PowExpr(base, exponent) => {
1048                base.flatten_recursive(ops, max_var_index);
1049                exponent.flatten_recursive(ops, max_var_index);
1050                ops.push(LinearOp::PowExpr);
1051            }
1052
1053            Expr::Exp(expr) => {
1054                expr.flatten_recursive(ops, max_var_index);
1055                ops.push(LinearOp::Exp);
1056            }
1057
1058            Expr::Ln(expr) => {
1059                expr.flatten_recursive(ops, max_var_index);
1060                ops.push(LinearOp::Ln);
1061            }
1062
1063            Expr::Sqrt(expr) => {
1064                expr.flatten_recursive(ops, max_var_index);
1065                ops.push(LinearOp::Sqrt);
1066            }
1067
1068            Expr::Sin(expr) => {
1069                expr.flatten_recursive(ops, max_var_index);
1070                ops.push(LinearOp::Sin);
1071            }
1072
1073            Expr::Cos(expr) => {
1074                expr.flatten_recursive(ops, max_var_index);
1075                ops.push(LinearOp::Cos);
1076            }
1077
1078            Expr::Cached(expr, cached_value) => {
1079                if let Some(val) = cached_value {
1080                    ops.push(LinearOp::LoadConst(*val));
1081                } else {
1082                    expr.flatten_recursive(ops, max_var_index);
1083                }
1084            }
1085        }
1086    }
1087
1088    /// Generates ultra-optimized linear code from flattened operations.
1089    ///
1090    /// This eliminates all function call overhead by generating a single
1091    /// linear sequence of optimal instructions with direct register allocation.
1092    pub fn codegen_flattened(
1093        &self,
1094        builder: &mut FunctionBuilder,
1095        module: &mut dyn Module,
1096    ) -> Result<Value, EquationError> {
1097        let flattened = self.flatten();
1098
1099        // Fast path for constant expressions
1100        if let Some(constant) = flattened.constant_result {
1101            return Ok(builder.ins().f64const(constant));
1102        }
1103
1104        // Pre-allocate stack for operations (eliminates allocations)
1105        let mut value_stack = Vec::with_capacity(flattened.ops.len());
1106
1107        // Get input pointer once
1108        let input_ptr = builder
1109            .func
1110            .dfg
1111            .block_params(builder.current_block().unwrap())[0];
1112
1113        // Execute linear operations with optimal code generation
1114        for op in &flattened.ops {
1115            match op {
1116                LinearOp::LoadConst(val) => {
1117                    value_stack.push(builder.ins().f64const(*val));
1118                }
1119
1120                LinearOp::LoadVar(index) => {
1121                    let offset = (*index as i32) * 8;
1122                    let memflags = MemFlags::new().with_aligned().with_readonly().with_notrap();
1123                    let val =
1124                        builder
1125                            .ins()
1126                            .load(types::F64, memflags, input_ptr, Offset32::new(offset));
1127                    value_stack.push(val);
1128                }
1129
1130                LinearOp::Add => {
1131                    let rhs = value_stack.pop().unwrap();
1132                    let lhs = value_stack.pop().unwrap();
1133                    value_stack.push(builder.ins().fadd(lhs, rhs));
1134                }
1135
1136                LinearOp::Sub => {
1137                    let rhs = value_stack.pop().unwrap();
1138                    let lhs = value_stack.pop().unwrap();
1139                    value_stack.push(builder.ins().fsub(lhs, rhs));
1140                }
1141
1142                LinearOp::Mul => {
1143                    let rhs = value_stack.pop().unwrap();
1144                    let lhs = value_stack.pop().unwrap();
1145                    value_stack.push(builder.ins().fmul(lhs, rhs));
1146                }
1147
1148                LinearOp::Div => {
1149                    let rhs = value_stack.pop().unwrap();
1150                    let lhs = value_stack.pop().unwrap();
1151                    value_stack.push(builder.ins().fdiv(lhs, rhs));
1152                }
1153
1154                LinearOp::Abs => {
1155                    let val = value_stack.pop().unwrap();
1156                    value_stack.push(builder.ins().fabs(val));
1157                }
1158
1159                LinearOp::Neg => {
1160                    let val = value_stack.pop().unwrap();
1161                    value_stack.push(builder.ins().fneg(val));
1162                }
1163
1164                LinearOp::PowConst(exp) => {
1165                    let base = value_stack.pop().unwrap();
1166                    let result = match *exp {
1167                        0 => builder.ins().f64const(1.0),
1168                        1 => base,
1169                        2 => builder.ins().fmul(base, base),
1170                        3 => {
1171                            let square = builder.ins().fmul(base, base);
1172                            builder.ins().fmul(square, base)
1173                        }
1174                        4 => {
1175                            let square = builder.ins().fmul(base, base);
1176                            builder.ins().fmul(square, square)
1177                        }
1178                        -1 => {
1179                            let one = builder.ins().f64const(1.0);
1180                            builder.ins().fdiv(one, base)
1181                        }
1182                        -2 => {
1183                            let square = builder.ins().fmul(base, base);
1184                            let one = builder.ins().f64const(1.0);
1185                            builder.ins().fdiv(one, square)
1186                        }
1187                        _ => {
1188                            // For other exponents, use optimized binary exponentiation
1189                            generate_optimized_power(builder, base, *exp)
1190                        }
1191                    };
1192                    value_stack.push(result);
1193                }
1194
1195                LinearOp::PowFloat(exp) => {
1196                    let base = value_stack.pop().unwrap();
1197                    let func_id = crate::operators::pow::link_powf(module).unwrap();
1198                    let exp_val = builder.ins().f64const(*exp);
1199                    let result =
1200                        crate::operators::pow::call_powf(builder, module, func_id, base, exp_val);
1201                    value_stack.push(result);
1202                }
1203
1204                LinearOp::PowExpr => {
1205                    let exponent = value_stack.pop().unwrap();
1206                    let base = value_stack.pop().unwrap();
1207                    let func_id = crate::operators::pow::link_powf(module).unwrap();
1208                    let result =
1209                        crate::operators::pow::call_powf(builder, module, func_id, base, exponent);
1210                    value_stack.push(result);
1211                }
1212
1213                LinearOp::Exp => {
1214                    let arg = value_stack.pop().unwrap();
1215                    let func_id = operators::exp::link_exp(module).unwrap();
1216                    let result = operators::exp::call_exp(builder, module, func_id, arg);
1217                    value_stack.push(result);
1218                }
1219
1220                LinearOp::Ln => {
1221                    let arg = value_stack.pop().unwrap();
1222                    let func_id = operators::ln::link_ln(module).unwrap();
1223                    let result = operators::ln::call_ln(builder, module, func_id, arg);
1224                    value_stack.push(result);
1225                }
1226
1227                LinearOp::Sqrt => {
1228                    let arg = value_stack.pop().unwrap();
1229                    let func_id = operators::sqrt::link_sqrt(module).unwrap();
1230                    let result = operators::sqrt::call_sqrt(builder, module, func_id, arg);
1231                    value_stack.push(result);
1232                }
1233
1234                LinearOp::Sin => {
1235                    let arg = value_stack.pop().unwrap();
1236                    let func_id = crate::operators::trigonometric::link_sin(module).unwrap();
1237                    let result =
1238                        crate::operators::trigonometric::call_sin(builder, module, func_id, arg);
1239                    value_stack.push(result);
1240                }
1241
1242                LinearOp::Cos => {
1243                    let arg = value_stack.pop().unwrap();
1244                    let func_id = crate::operators::trigonometric::link_cos(module).unwrap();
1245                    let result =
1246                        crate::operators::trigonometric::call_cos(builder, module, func_id, arg);
1247                    value_stack.push(result);
1248                }
1249            }
1250        }
1251
1252        // Return final result
1253        Ok(value_stack.pop().unwrap())
1254    }
1255}
1256
1257/// Generates optimized power operation using binary exponentiation
1258fn generate_optimized_power(builder: &mut FunctionBuilder, base: Value, exp: i64) -> Value {
1259    if exp == 0 {
1260        return builder.ins().f64const(1.0);
1261    }
1262
1263    if exp == 1 {
1264        return base;
1265    }
1266
1267    let abs_exp = exp.abs();
1268    let mut result = builder.ins().f64const(1.0);
1269    let mut current_base = base;
1270    let mut remaining = abs_exp;
1271
1272    // Binary exponentiation - optimal for any exponent
1273    while remaining > 0 {
1274        if remaining & 1 == 1 {
1275            result = builder.ins().fmul(result, current_base);
1276        }
1277        if remaining > 1 {
1278            current_base = builder.ins().fmul(current_base, current_base);
1279        }
1280        remaining >>= 1;
1281    }
1282
1283    if exp < 0 {
1284        let one = builder.ins().f64const(1.0);
1285        builder.ins().fdiv(one, result)
1286    } else {
1287        result
1288    }
1289}
1290
1291/// Implements string formatting for expressions.
1292///
1293/// This implementation converts expressions to their standard mathematical notation:
1294/// - Constants are formatted as numbers
1295/// - Variables are formatted as their names
1296/// - Binary operations (+,-,*,/) are wrapped in parentheses
1297/// - Functions (exp, ln, sqrt) use function call notation
1298/// - Absolute value uses |x| notation
1299/// - Exponents use ^
1300/// - Negation uses - prefix
1301/// - Cached expressions display their underlying expression
1302impl std::fmt::Display for Expr {
1303    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1304        match self {
1305            Expr::Const(val) => write!(f, "{val}"),
1306            Expr::Var(var_ref) => write!(f, "{0}", var_ref.name),
1307            Expr::Add(left, right) => write!(f, "({left} + {right})"),
1308            Expr::Mul(left, right) => write!(f, "({left} * {right})"),
1309            Expr::Sub(left, right) => write!(f, "({left} - {right})"),
1310            Expr::Div(left, right) => write!(f, "({left} / {right})"),
1311            Expr::Abs(expr) => write!(f, "|{expr}|"),
1312            Expr::Pow(base, exp) => write!(f, "({base}^{exp})"),
1313            Expr::PowFloat(base, exp) => write!(f, "({base}^{exp})"),
1314            Expr::PowExpr(base, exponent) => write!(f, "({base}^{exponent})"),
1315            Expr::Exp(expr) => write!(f, "exp({expr})"),
1316            Expr::Ln(expr) => write!(f, "ln({expr})"),
1317            Expr::Sqrt(expr) => write!(f, "sqrt({expr})"),
1318            Expr::Sin(expr) => write!(f, "sin({expr})"),
1319            Expr::Cos(expr) => write!(f, "cos({expr})"),
1320            Expr::Neg(expr) => write!(f, "-({expr})"),
1321            Expr::Cached(expr, _) => write!(f, "{expr}"),
1322        }
1323    }
1324}
1325
1326#[cfg(test)]
1327mod tests {
1328    use super::*;
1329
1330    // Helper function to create a variable
1331    fn var(name: &str) -> Box<Expr> {
1332        Box::new(Expr::Var(VarRef {
1333            name: name.to_string(),
1334            vec_ref: Value::from_u32(0),
1335            index: 0,
1336        }))
1337    }
1338
1339    #[test]
1340    fn test_simplify() {
1341        // Helper function to create a variable
1342        fn var(name: &str) -> Box<Expr> {
1343            Box::new(Expr::Var(VarRef {
1344                name: name.to_string(),
1345                vec_ref: Value::from_u32(0), // Dummy value for testing
1346                index: 0,
1347            }))
1348        }
1349
1350        // Test constant folding
1351        // 2 + 3 → 5
1352        assert_eq!(
1353            *Expr::Add(Box::new(Expr::Const(2.0)), Box::new(Expr::Const(3.0))).simplify(),
1354            Expr::Const(5.0)
1355        );
1356
1357        // Test additive identity
1358        // x + 0 → x
1359        assert_eq!(
1360            *Expr::Add(var("x"), Box::new(Expr::Const(0.0))).simplify(),
1361            *var("x")
1362        );
1363
1364        // Test multiplicative identity
1365        // x * 1 → x
1366        assert_eq!(
1367            *Expr::Mul(var("x"), Box::new(Expr::Const(1.0))).simplify(),
1368            *var("x")
1369        );
1370
1371        // Test multiplication by zero
1372        // x * 0 → 0
1373        assert_eq!(
1374            *Expr::Mul(var("x"), Box::new(Expr::Const(0.0))).simplify(),
1375            Expr::Const(0.0)
1376        );
1377
1378        // Test division identity
1379        // x / 1 → x
1380        assert_eq!(
1381            *Expr::Div(var("x"), Box::new(Expr::Const(1.0))).simplify(),
1382            *var("x")
1383        );
1384
1385        // Test division by self
1386        // x / x → 1
1387        assert_eq!(*Expr::Div(var("x"), var("x")).simplify(), Expr::Const(1.0));
1388
1389        // Test exponent simplification
1390        // x^0 → 1
1391        assert_eq!(*Expr::Pow(var("x"), 0).simplify(), Expr::Const(1.0));
1392        // x^1 → x
1393        assert_eq!(*Expr::Pow(var("x"), 1).simplify(), *var("x"));
1394
1395        // Test absolute value of constant
1396        // |-3| → 3
1397        assert_eq!(
1398            *Expr::Abs(Box::new(Expr::Const(-3.0))).simplify(),
1399            Expr::Const(3.0)
1400        );
1401
1402        // Test nested absolute value
1403        // ||x|| → |x|
1404        assert_eq!(
1405            *Expr::Abs(Box::new(Expr::Abs(var("x")))).simplify(),
1406            Expr::Abs(var("x"))
1407        );
1408    }
1409
1410    #[test]
1411    fn test_insert() {
1412        // Helper function to create a variable
1413        fn var(name: &str) -> Box<Expr> {
1414            Box::new(Expr::Var(VarRef {
1415                name: name.to_string(),
1416                vec_ref: Value::from_u32(0),
1417                index: 0,
1418            }))
1419        }
1420
1421        // Create expression: x + y
1422        let expr = Box::new(Expr::Add(var("x"), var("y")));
1423
1424        // Replace all occurrences of 'x' with '2*z'
1425        let replacement = Box::new(Expr::Mul(Box::new(Expr::Const(2.0)), var("z")));
1426
1427        let result = expr.insert(|e| matches!(e, Expr::Var(v) if v.name == "x"), &replacement);
1428
1429        // Expected: (2*z) + y
1430        assert_eq!(
1431            *result,
1432            Expr::Add(
1433                Box::new(Expr::Mul(Box::new(Expr::Const(2.0)), var("z"),)),
1434                var("y"),
1435            )
1436        );
1437    }
1438
1439    #[test]
1440    fn test_derivative() {
1441        // Test constant derivative
1442        assert_eq!(*Expr::Const(5.0).derivative("x"), Expr::Const(0.0));
1443
1444        // Test variable derivatives (x)' = 1, (y)' = 0
1445        assert_eq!(*var("x").derivative("x"), Expr::Const(1.0));
1446        assert_eq!(*var("y").derivative("x"), Expr::Const(0.0));
1447
1448        // Test sum rule (u+v)' = u' + v'
1449        let sum = Box::new(Expr::Add(var("x"), var("y")));
1450        assert_eq!(
1451            *sum.derivative("x"),
1452            Expr::Add(Box::new(Expr::Const(1.0)), Box::new(Expr::Const(0.0)))
1453        );
1454
1455        // Test product rule (u*v)' = u'*v + u*v'
1456        let product = Box::new(Expr::Mul(var("x"), var("y")));
1457        assert_eq!(
1458            *product.derivative("x"),
1459            Expr::Add(
1460                Box::new(Expr::Mul(var("x"), Box::new(Expr::Const(0.0)))),
1461                Box::new(Expr::Mul(var("y"), Box::new(Expr::Const(1.0))))
1462            )
1463        );
1464
1465        // Test power rule (u^v)' = u'*v*u^(v-1) + ln(u)*u^v*v'
1466        let power = Box::new(Expr::Pow(var("x"), 3));
1467        assert_eq!(
1468            *power.derivative("x"),
1469            Expr::Mul(
1470                Box::new(Expr::Mul(
1471                    Box::new(Expr::Const(3.0)),
1472                    Box::new(Expr::Pow(var("x"), 2))
1473                )),
1474                Box::new(Expr::Const(1.0))
1475            )
1476        );
1477    }
1478
1479    #[test]
1480    fn test_complex_simplifications() {
1481        // Test nested operations: (x + 0) * (y + 0) → x * y
1482        let expr = Box::new(Expr::Mul(
1483            Box::new(Expr::Add(var("x"), Box::new(Expr::Const(0.0)))),
1484            Box::new(Expr::Add(var("y"), Box::new(Expr::Const(0.0)))),
1485        ));
1486        assert_eq!(*expr.simplify(), Expr::Mul(var("x"), var("y")));
1487
1488        // Test double negation: -(-x) → x
1489        let expr = Box::new(Expr::Neg(Box::new(Expr::Neg(var("x")))));
1490        assert_eq!(*expr.simplify(), *var("x"));
1491
1492        // Test multiplication by 1: (1 * x) * (y * 1) → x * y
1493        let expr = Box::new(Expr::Mul(
1494            Box::new(Expr::Mul(Box::new(Expr::Const(1.0)), var("x"))),
1495            Box::new(Expr::Mul(var("y"), Box::new(Expr::Const(1.0)))),
1496        ));
1497        assert_eq!(*expr.simplify(), Expr::Mul(var("x"), var("y")));
1498
1499        // Test division simplification: (x/y)/(x/y) → 1
1500        let div = Box::new(Expr::Div(var("x"), var("y")));
1501        let expr = Box::new(Expr::Div(div.clone(), div));
1502        assert_eq!(*expr.simplify(), Expr::Const(1.0));
1503    }
1504
1505    #[test]
1506    fn test_special_functions() {
1507        // Test abs(abs(x)) simplification to abs(x)
1508        let expr = Box::new(Expr::Abs(Box::new(Expr::Abs(var("x")))));
1509        assert_eq!(*expr.simplify(), Expr::Abs(var("x")));
1510
1511        // Test sqrt(x^2) - should simplify to abs(x)
1512        let expr = Box::new(Expr::Sqrt(Box::new(Expr::Pow(var("x"), 2))));
1513        assert_eq!(*expr.simplify(), Expr::Abs(var("x")));
1514
1515        // Test constant special functions
1516        // exp(0) = 1
1517        assert_eq!(
1518            *Expr::Exp(Box::new(Expr::Const(0.0))).simplify(),
1519            Expr::Const(1.0)
1520        );
1521        // ln(1) = 0
1522        assert_eq!(
1523            *Expr::Ln(Box::new(Expr::Const(1.0))).simplify(),
1524            Expr::Const(0.0)
1525        );
1526    }
1527
1528    #[test]
1529    fn test_display() {
1530        // Test basic expressions
1531        assert_eq!(format!("{}", Expr::Const(5.0)), "5");
1532        assert_eq!(format!("{}", *var("x")), "x");
1533
1534        // Test binary operations
1535        let sum = Expr::Add(var("x"), var("y"));
1536        assert_eq!(format!("{sum}"), "(x + y)");
1537
1538        let product = Expr::Mul(var("x"), var("y"));
1539        assert_eq!(format!("{product}"), "(x * y)");
1540
1541        // Test special functions
1542        let exp = Expr::Exp(var("x"));
1543        assert_eq!(format!("{exp}"), "exp(x)");
1544
1545        let abs = Expr::Abs(var("x"));
1546        assert_eq!(format!("{abs}"), "|x|");
1547
1548        // Test complex expression
1549        let complex = Expr::Div(
1550            Box::new(Expr::Add(Box::new(Expr::Pow(var("x"), 2)), var("y"))),
1551            var("z"),
1552        );
1553        assert_eq!(format!("{complex}"), "(((x^2) + y) / z)");
1554    }
1555
1556    #[test]
1557    fn test_cached_expressions() {
1558        // Test cached constant
1559        let cached = Box::new(Expr::Cached(Box::new(Expr::Const(5.0)), Some(5.0)));
1560        assert_eq!(*cached.simplify(), *cached);
1561
1562        // Test uncached expression simplification
1563        let uncached = Box::new(Expr::Cached(
1564            Box::new(Expr::Add(
1565                Box::new(Expr::Const(2.0)),
1566                Box::new(Expr::Const(3.0)),
1567            )),
1568            None,
1569        ));
1570        assert_eq!(*uncached.simplify(), Expr::Const(5.0));
1571    }
1572}