evalexpr_jit/
expr.rs

1//! Expression module for representing mathematical expressions.
2//!
3//! This module defines the core expression types used to represent mathematical expressions
4//! in a form that supports both JIT compilation and automatic differentiation. The main types are:
5//!
6//! - `Expr`: An enum representing different kinds of mathematical expressions
7//! - `VarRef`: A struct containing metadata about variables in expressions
8//!
9//! The expression tree is built recursively using `Box<Expr>` for nested expressions and can be:
10//! - JIT compiled into machine code using Cranelift
11//! - Symbolically differentiated to compute derivatives
12//! - Evaluated efficiently at runtime
13//! - Simplified using algebraic rules
14//! - Modified by inserting replacement expressions
15//!
16//! Supported operations include:
17//! - Basic arithmetic (+, -, *, /)
18//! - Variables and constants
19//! - Absolute value
//! - Exponentiation with integer, floating-point, or expression exponents
//! - Elementary functions (exp, ln, sqrt, sin, cos)
22//! - Expression caching for optimization
23//!
24//! # Expression Tree Structure
25//! The expression tree is built recursively with each node being one of:
26//! - Leaf nodes: Constants and Variables
//! - Unary operations: Abs, Neg, Exp, Ln, Sqrt, Sin, Cos
//! - Binary operations: Add, Sub, Mul, Div, PowExpr (exponent is itself an expression)
//! - Special nodes: Pow (integer exponent), PowFloat (floating-point exponent), Cached expressions
30//!
31//! # Symbolic Differentiation
32//! The derivative method implements symbolic differentiation by recursively applying
33//! calculus rules like:
34//! - Product rule
35//! - Quotient rule
36//! - Chain rule
37//! - Power rule
//! - Special function derivatives (abs, exp, ln, sqrt, sin, cos)
39//!
40//! # Expression Simplification
41//! The simplify method performs algebraic simplifications including:
42//! - Constant folding (e.g. 2 + 3 → 5)
43//! - Identity rules (e.g. x + 0 → x, x * 1 → x)
44//! - Exponent rules (e.g. x^0 → 1, x^1 → x)
45//! - Expression caching for repeated subexpressions
46//! - Special function simplifications
47//!
48//! # Expression Modification
49//! The insert method allows replacing parts of an expression tree by:
50//! - Matching nodes using a predicate function
51//! - Replacing matched nodes with a new expression
52//! - Recursively traversing and rebuilding the tree
53
54use cranelift::prelude::*;
55use cranelift_codegen::ir::{immediates::Offset32, Value};
56use cranelift_module::Module;
57
58use crate::{errors::EquationError, operators};
59
60/// Represents a reference to a variable in an expression.
61///
62/// Contains metadata needed to generate code that loads the variable's value:
63/// - The variable's name as a string
64/// - A Cranelift Value representing the pointer to the input array
65/// - The variable's index in the input array
66#[derive(Debug, Clone, PartialEq)]
67pub struct VarRef {
68    pub name: String,
69    pub vec_ref: Value,
70    pub index: u32,
71}
72
73/// An expression tree node representing mathematical operations.
74///
75/// This enum represents different types of mathematical expressions that can be:
76/// - JIT compiled into machine code using Cranelift
77/// - Symbolically differentiated to compute derivatives
78/// - Simplified using algebraic rules
79/// - Modified by inserting replacement expressions
80///
81/// The expression tree is built recursively using Box<Expr> for nested expressions.
82#[derive(Debug, Clone, PartialEq)]
83pub enum Expr {
84    /// A constant floating point value
85    Const(f64),
86    /// A reference to a variable
87    Var(VarRef),
88    /// Addition of two expressions
89    Add(Box<Expr>, Box<Expr>),
90    /// Multiplication of two expressions
91    Mul(Box<Expr>, Box<Expr>),
92    /// Subtraction of two expressions
93    Sub(Box<Expr>, Box<Expr>),
94    /// Division of two expressions
95    Div(Box<Expr>, Box<Expr>),
96    /// Absolute value of an expression
97    Abs(Box<Expr>),
98    /// Exponentiation of an expression by an integer constant
99    Pow(Box<Expr>, i64),
100    /// Exponentiation of an expression by a floating point constant
101    PowFloat(Box<Expr>, f64),
102    /// Exponentiation of an expression by another expression
103    PowExpr(Box<Expr>, Box<Expr>),
104    /// Exponential function of an expression
105    Exp(Box<Expr>),
106    /// Natural logarithm of an expression
107    Ln(Box<Expr>),
108    /// Square root of an expression
109    Sqrt(Box<Expr>),
110    /// Sine of an expression (argument in radians)
111    Sin(Box<Expr>),
112    /// Cosine of an expression (argument in radians)
113    Cos(Box<Expr>),
114    /// Negation of an expression
115    Neg(Box<Expr>),
116    /// Cached expression with optional pre-computed value
117    Cached(Box<Expr>, Option<f64>),
118}
119
/// One instruction of a flattened (linear) expression program.
///
/// A sequence of these operations is stored in `FlattenedExpr::ops`. Each variant
/// mirrors one kind of `Expr` node; operands are taken from an evaluation stack
/// rather than from child pointers.
///
/// `PartialEq` is derived (like `Expr` and `VarRef`) so that operation sequences can
/// be compared directly, e.g. in tests. `Eq` is intentionally not derived because
/// some variants carry `f64` payloads.
#[derive(Debug, Clone, PartialEq)]
pub enum LinearOp {
    /// Load a constant value.
    LoadConst(f64),
    /// Load a variable by its index.
    LoadVar(u32),
    /// Add two values from stack positions.
    Add,
    /// Subtract two values from stack positions.
    Sub,
    /// Multiply two values from stack positions.
    Mul,
    /// Divide two values from stack positions.
    Div,
    /// Absolute value of the stack top.
    Abs,
    /// Negate the stack top.
    Neg,
    /// Power operation with a constant integer exponent.
    PowConst(i64),
    /// Power operation with a constant floating point exponent.
    PowFloat(f64),
    /// Power operation whose exponent is itself a computed expression.
    PowExpr,
    /// Exponential of the stack top.
    Exp,
    /// Natural logarithm of the stack top.
    Ln,
    /// Square root of the stack top.
    Sqrt,
    /// Sine of the stack top (argument in radians).
    Sin,
    /// Cosine of the stack top (argument in radians).
    Cos,
}
156
157/// Flattened expression representation for efficient evaluation
158#[derive(Debug, Clone)]
159pub struct FlattenedExpr {
160    /// Linear sequence of operations
161    pub ops: Vec<LinearOp>,
162    /// Maximum variable index accessed
163    pub max_var_index: Option<u32>,
164    /// Pre-computed constant result (if expression is constant)
165    pub constant_result: Option<f64>,
166}
167
168impl Expr {
169    /// Pre-evaluates constants and caches variable loads for improved performance
170    pub fn pre_evaluate(
171        &self,
172        var_cache: &mut std::collections::HashMap<String, f64>,
173    ) -> Box<Expr> {
174        match self {
175            Expr::Const(_) => Box::new(self.clone()),
176
177            Expr::Var(var_ref) => {
178                // Check if we can pre-evaluate this variable
179                if let Some(&value) = var_cache.get(&var_ref.name) {
180                    Box::new(Expr::Const(value))
181                } else {
182                    Box::new(self.clone())
183                }
184            }
185
186            Expr::Add(left, right) => {
187                let l = left.pre_evaluate(var_cache);
188                let r = right.pre_evaluate(var_cache);
189                match (&*l, &*r) {
190                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a + b)),
191                    _ => Box::new(Expr::Add(l, r)),
192                }
193            }
194
195            Expr::Sub(left, right) => {
196                let l = left.pre_evaluate(var_cache);
197                let r = right.pre_evaluate(var_cache);
198                match (&*l, &*r) {
199                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a - b)),
200                    _ => Box::new(Expr::Sub(l, r)),
201                }
202            }
203
204            Expr::Mul(left, right) => {
205                let l = left.pre_evaluate(var_cache);
206                let r = right.pre_evaluate(var_cache);
207                match (&*l, &*r) {
208                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a * b)),
209                    _ => Box::new(Expr::Mul(l, r)),
210                }
211            }
212
213            Expr::Div(left, right) => {
214                let l = left.pre_evaluate(var_cache);
215                let r = right.pre_evaluate(var_cache);
216                match (&*l, &*r) {
217                    (Expr::Const(a), Expr::Const(b)) if *b != 0.0 => Box::new(Expr::Const(a / b)),
218                    _ => Box::new(Expr::Div(l, r)),
219                }
220            }
221
222            Expr::Abs(expr) => {
223                let e = expr.pre_evaluate(var_cache);
224                match &*e {
225                    Expr::Const(a) => Box::new(Expr::Const(a.abs())),
226                    _ => Box::new(Expr::Abs(e)),
227                }
228            }
229
230            Expr::Neg(expr) => {
231                let e = expr.pre_evaluate(var_cache);
232                match &*e {
233                    Expr::Const(a) => Box::new(Expr::Const(-a)),
234                    _ => Box::new(Expr::Neg(e)),
235                }
236            }
237
238            Expr::Pow(base, exp) => {
239                let b = base.pre_evaluate(var_cache);
240                match &*b {
241                    Expr::Const(a) => Box::new(Expr::Const(a.powi(*exp as i32))),
242                    _ => Box::new(Expr::Pow(b, *exp)),
243                }
244            }
245
246            Expr::PowFloat(base, exp) => {
247                let b = base.pre_evaluate(var_cache);
248                match &*b {
249                    Expr::Const(a) => Box::new(Expr::Const(a.powf(*exp))),
250                    _ => Box::new(Expr::PowFloat(b, *exp)),
251                }
252            }
253
254            Expr::PowExpr(base, exponent) => {
255                let b = base.pre_evaluate(var_cache);
256                let e = exponent.pre_evaluate(var_cache);
257                match (&*b, &*e) {
258                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a.powf(*b))),
259                    _ => Box::new(Expr::PowExpr(b, e)),
260                }
261            }
262
263            Expr::Exp(expr) => {
264                let e = expr.pre_evaluate(var_cache);
265                match &*e {
266                    Expr::Const(a) => Box::new(Expr::Const(a.exp())),
267                    _ => Box::new(Expr::Exp(e)),
268                }
269            }
270
271            Expr::Ln(expr) => {
272                let e = expr.pre_evaluate(var_cache);
273                match &*e {
274                    Expr::Const(a) if *a > 0.0 => Box::new(Expr::Const(a.ln())),
275                    _ => Box::new(Expr::Ln(e)),
276                }
277            }
278
279            Expr::Sqrt(expr) => {
280                let e = expr.pre_evaluate(var_cache);
281                match &*e {
282                    Expr::Const(a) if *a >= 0.0 => Box::new(Expr::Const(a.sqrt())),
283                    _ => Box::new(Expr::Sqrt(e)),
284                }
285            }
286
287            Expr::Sin(expr) => {
288                let e = expr.pre_evaluate(var_cache);
289                match &*e {
290                    Expr::Const(a) => Box::new(Expr::Const(a.sin())),
291                    _ => Box::new(Expr::Sin(e)),
292                }
293            }
294
295            Expr::Cos(expr) => {
296                let e = expr.pre_evaluate(var_cache);
297                match &*e {
298                    Expr::Const(a) => Box::new(Expr::Const(a.cos())),
299                    _ => Box::new(Expr::Cos(e)),
300                }
301            }
302
303            Expr::Cached(expr, _) => expr.pre_evaluate(var_cache),
304        }
305    }
306
307    /// Computes the symbolic derivative of this expression with respect to a variable.
308    ///
309    /// Recursively applies the rules of differentiation to build a new expression tree
310    /// representing the derivative. The rules implemented are:
311    /// - d/dx(c) = 0 for constants
312    /// - d/dx(x) = 1 for the variable we're differentiating with respect to
313    /// - d/dx(y) = 0 for other variables
314    /// - Sum rule: d/dx(f + g) = df/dx + dg/dx
315    /// - Product rule: d/dx(f * g) = f * dg/dx + g * df/dx
316    /// - Quotient rule: d/dx(f/g) = (g * df/dx - f * dg/dx) / g^2
317    /// - Chain rule for abs: d/dx|f| = f/|f| * df/dx
318    /// - Power rule: d/dx(f^n) = n * f^(n-1) * df/dx
319    /// - Chain rule for exp: d/dx(e^f) = e^f * df/dx
320    /// - Chain rule for ln: d/dx(ln(f)) = 1/f * df/dx
321    /// - Chain rule for sqrt: d/dx(sqrt(f)) = 1/(2*sqrt(f)) * df/dx
322    /// - Negation: d/dx(-f) = -(df/dx)
323    ///
324    /// # Arguments
325    /// * `with_respect_to` - The name of the variable to differentiate with respect to
326    ///
327    /// # Returns
328    /// A new expression tree representing the derivative
329    pub fn derivative(&self, with_respect_to: &str) -> Box<Expr> {
330        match self {
331            Expr::Const(_) => Box::new(Expr::Const(0.0)),
332
333            Expr::Var(var_ref) => {
334                if var_ref.name == with_respect_to {
335                    Box::new(Expr::Const(1.0))
336                } else {
337                    Box::new(Expr::Const(0.0))
338                }
339            }
340
341            Expr::Add(left, right) => {
342                // d/dx(f + g) = df/dx + dg/dx
343                Box::new(Expr::Add(
344                    left.derivative(with_respect_to),
345                    right.derivative(with_respect_to),
346                ))
347            }
348
349            Expr::Sub(left, right) => {
350                // d/dx(f - g) = df/dx - dg/dx
351                Box::new(Expr::Sub(
352                    left.derivative(with_respect_to),
353                    right.derivative(with_respect_to),
354                ))
355            }
356
357            Expr::Mul(left, right) => {
358                // d/dx(f * g) = f * dg/dx + g * df/dx
359                Box::new(Expr::Add(
360                    Box::new(Expr::Mul(left.clone(), right.derivative(with_respect_to))),
361                    Box::new(Expr::Mul(right.clone(), left.derivative(with_respect_to))),
362                ))
363            }
364
365            Expr::Div(left, right) => {
366                // d/dx(f/g) = (g * df/dx - f * dg/dx) / g^2
367                Box::new(Expr::Div(
368                    Box::new(Expr::Sub(
369                        Box::new(Expr::Mul(right.clone(), left.derivative(with_respect_to))),
370                        Box::new(Expr::Mul(left.clone(), right.derivative(with_respect_to))),
371                    )),
372                    Box::new(Expr::Pow(right.clone(), 2)),
373                ))
374            }
375
376            Expr::Abs(expr) => {
377                // d/dx|f| = f/|f| * df/dx
378                Box::new(Expr::Mul(
379                    Box::new(Expr::Div(expr.clone(), Box::new(Expr::Abs(expr.clone())))),
380                    expr.derivative(with_respect_to),
381                ))
382            }
383
384            Expr::Pow(base, exp) => {
385                // d/dx(f^n) = n * f^(n-1) * df/dx
386                Box::new(Expr::Mul(
387                    Box::new(Expr::Mul(
388                        Box::new(Expr::Const(*exp as f64)),
389                        Box::new(Expr::Pow(base.clone(), exp - 1)),
390                    )),
391                    base.derivative(with_respect_to),
392                ))
393            }
394
395            Expr::PowFloat(base, exp) => {
396                // d/dx(f^c) = c * f^(c-1) * df/dx
397                Box::new(Expr::Mul(
398                    Box::new(Expr::Mul(
399                        Box::new(Expr::Const(*exp)),
400                        Box::new(Expr::PowFloat(base.clone(), exp - 1.0)),
401                    )),
402                    base.derivative(with_respect_to),
403                ))
404            }
405
406            Expr::PowExpr(base, exponent) => {
407                // d/dx(f^g) = f^g * (g' * ln(f) + g * f'/f)
408                // Using the general power rule
409                Box::new(Expr::Mul(
410                    Box::new(Expr::PowExpr(base.clone(), exponent.clone())),
411                    Box::new(Expr::Add(
412                        Box::new(Expr::Mul(
413                            exponent.derivative(with_respect_to),
414                            Box::new(Expr::Ln(base.clone())),
415                        )),
416                        Box::new(Expr::Mul(
417                            exponent.clone(),
418                            Box::new(Expr::Div(base.derivative(with_respect_to), base.clone())),
419                        )),
420                    )),
421                ))
422            }
423
424            Expr::Exp(expr) => {
425                // d/dx(e^f) = e^f * df/dx
426                Box::new(Expr::Mul(
427                    Box::new(Expr::Exp(expr.clone())),
428                    expr.derivative(with_respect_to),
429                ))
430            }
431
432            Expr::Ln(expr) => {
433                // d/dx(ln(f)) = 1/f * df/dx
434                Box::new(Expr::Mul(
435                    Box::new(Expr::Div(Box::new(Expr::Const(1.0)), expr.clone())),
436                    expr.derivative(with_respect_to),
437                ))
438            }
439
440            Expr::Sqrt(expr) => {
441                // d/dx(sqrt(f)) = 1/(2*sqrt(f)) * df/dx
442                Box::new(Expr::Mul(
443                    Box::new(Expr::Div(
444                        Box::new(Expr::Const(1.0)),
445                        Box::new(Expr::Mul(
446                            Box::new(Expr::Const(2.0)),
447                            Box::new(Expr::Sqrt(expr.clone())),
448                        )),
449                    )),
450                    expr.derivative(with_respect_to),
451                ))
452            }
453
454            Expr::Sin(expr) => {
455                // d/dx(sin(f)) = cos(f) * df/dx
456                Box::new(Expr::Mul(
457                    Box::new(Expr::Cos(expr.clone())),
458                    expr.derivative(with_respect_to),
459                ))
460            }
461
462            Expr::Cos(expr) => {
463                // d/dx(cos(f)) = -sin(f) * df/dx
464                Box::new(Expr::Mul(
465                    Box::new(Expr::Neg(Box::new(Expr::Sin(expr.clone())))),
466                    expr.derivative(with_respect_to),
467                ))
468            }
469
470            Expr::Neg(expr) => {
471                // d/dx(-f) = -(df/dx)
472                Box::new(Expr::Neg(expr.derivative(with_respect_to)))
473            }
474
475            Expr::Cached(expr, _) => expr.derivative(with_respect_to),
476        }
477    }
478
479    /// Simplifies the expression by folding constants and applying basic algebraic rules.
480    ///
481    /// This method performs several types of algebraic simplifications:
482    ///
483    /// # Constant Folding
484    /// - Evaluates constant expressions: 2 + 3 → 5
485    /// - Simplifies operations with special constants: x * 0 → 0
486    ///
487    /// # Identity Rules
488    /// - Additive identity: x + 0 → x
489    /// - Multiplicative identity: x * 1 → x
490    /// - Division identity: x / 1 → x
491    /// - Division by self: x / x → 1
492    ///
493    /// # Exponent Rules
494    /// - Zero exponent: x^0 → 1 (except when x = 0)
495    /// - First power: x^1 → x
496    /// - Nested exponents: (x^a)^b → x^(a*b)
497    ///
498    /// # Special Function Simplification
499    /// - Absolute value: |-3| → 3, ||x|| → |x|
500    /// - Double negation: -(-x) → x
501    /// - Evaluates constant special functions: ln(1) → 0
502    ///
503    /// # Expression Caching
504    /// - Caches repeated subexpressions to avoid redundant computation
505    /// - Preserves existing cached values
506    ///
507    /// # Returns
508    /// A new simplified expression tree
509    pub fn simplify(&self) -> Box<Expr> {
510        match self {
511            // Base cases - constants and variables remain unchanged
512            Expr::Const(_) | Expr::Var(_) => Box::new(self.clone()),
513
514            Expr::Add(left, right) => {
515                let l = left.simplify();
516                let r = right.simplify();
517                match (&*l, &*r) {
518                    // Fold constants: 1 + 2 -> 3
519                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a + b)),
520                    // Identity: x + 0 -> x
521                    (expr, Expr::Const(0.0)) | (Expr::Const(0.0), expr) => Box::new(expr.clone()),
522                    // Combine like terms: c1*x + c2*x -> (c1+c2)*x
523                    (Expr::Mul(a1, x1), Expr::Mul(a2, x2)) if x1 == x2 => {
524                        let combined_coeff = Expr::Add(a1.clone(), a2.clone()).simplify();
525                        Box::new(Expr::Mul(combined_coeff, x1.clone()))
526                    }
527                    // Associativity: (x + c1) + c2 -> x + (c1 + c2)
528                    (Expr::Add(x, c1), c2)
529                        if matches!(**c1, Expr::Const(_)) && matches!(*c2, Expr::Const(_)) =>
530                    {
531                        Box::new(Expr::Add(
532                            x.clone(),
533                            Expr::Add(c1.clone(), Box::new(c2.clone())).simplify(),
534                        ))
535                    }
536                    _ => Box::new(Expr::Add(l, r)),
537                }
538            }
539
540            Expr::Sub(left, right) => {
541                let l = left.simplify();
542                let r = right.simplify();
543                match (&*l, &*r) {
544                    // Fold constants: 3 - 2 -> 1
545                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a - b)),
546                    // Identity: x - 0 -> x
547                    (expr, Expr::Const(0.0)) => Box::new(expr.clone()),
548                    // Zero: x - x -> 0
549                    (a, b) if a == b => Box::new(Expr::Const(0.0)),
550                    // Combine like terms: c1*x - c2*x -> (c1-c2)*x
551                    (Expr::Mul(a1, x1), Expr::Mul(a2, x2)) if x1 == x2 => {
552                        let combined_coeff = Expr::Sub(a1.clone(), a2.clone()).simplify();
553                        Box::new(Expr::Mul(combined_coeff, x1.clone()))
554                    }
555                    // Convert subtraction to addition: x - c -> x + (-c)
556                    (x, Expr::Const(c)) => {
557                        Box::new(Expr::Add(Box::new(x.clone()), Box::new(Expr::Const(-c))))
558                    }
559                    _ => Box::new(Expr::Sub(l, r)),
560                }
561            }
562
563            Expr::Mul(left, right) => {
564                let l = left.simplify();
565                let r = right.simplify();
566
567                // Common subexpression elimination
568                if l == r {
569                    return Box::new(Expr::Pow(l, 2)); // x * x -> x^2
570                }
571
572                match (&*l, &*r) {
573                    // Fold constants: 2 * 3 -> 6
574                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a * b)),
575                    // Zero property: x * 0 -> 0
576                    (Expr::Const(0.0), _) | (_, Expr::Const(0.0)) => Box::new(Expr::Const(0.0)),
577                    // Identity: x * 1 -> x
578                    (expr, Expr::Const(1.0)) | (Expr::Const(1.0), expr) => Box::new(expr.clone()),
579                    // Negative one: x * (-1) -> -x
580                    (expr, Expr::Const(-1.0)) | (Expr::Const(-1.0), expr) => {
581                        Box::new(Expr::Neg(Box::new(expr.clone())))
582                    }
583                    // Combine exponents: x^a * x^b -> x^(a+b)
584                    (Expr::Pow(b1, e1), Expr::Pow(b2, e2)) if b1 == b2 => {
585                        Box::new(Expr::Pow(b1.clone(), e1 + e2))
586                    }
587                    // Distribute constants: c * (x + y) -> c*x + c*y (only if beneficial)
588                    (Expr::Const(c), Expr::Add(x, y)) | (Expr::Add(x, y), Expr::Const(c))
589                        if c.abs() < 10.0 =>
590                    {
591                        Expr::Add(
592                            Box::new(Expr::Mul(Box::new(Expr::Const(*c)), x.clone())),
593                            Box::new(Expr::Mul(Box::new(Expr::Const(*c)), y.clone())),
594                        )
595                        .simplify()
596                    }
597                    // Strength reduction: x * 2 -> x + x (only for small integers)
598                    (expr, Expr::Const(2.0)) | (Expr::Const(2.0), expr) => {
599                        Box::new(Expr::Add(Box::new(expr.clone()), Box::new(expr.clone())))
600                    }
601                    // Associativity: (c1 * x) * c2 -> (c1 * c2) * x
602                    (Expr::Mul(c1, x), c2)
603                        if matches!(**c1, Expr::Const(_)) && matches!(*c2, Expr::Const(_)) =>
604                    {
605                        Box::new(Expr::Mul(
606                            Expr::Mul(c1.clone(), Box::new(c2.clone())).simplify(),
607                            x.clone(),
608                        ))
609                    }
610                    _ => Box::new(Expr::Mul(l, r)),
611                }
612            }
613
614            Expr::Div(left, right) => {
615                let l = left.simplify();
616                let r = right.simplify();
617                match (&*l, &*r) {
618                    // Fold constants: 6 / 2 -> 3
619                    (Expr::Const(a), Expr::Const(b)) if *b != 0.0 => Box::new(Expr::Const(a / b)),
620                    // Zero numerator: 0 / x -> 0
621                    (Expr::Const(0.0), _) => Box::new(Expr::Const(0.0)),
622                    // Identity: x / 1 -> x
623                    (expr, Expr::Const(1.0)) => Box::new(expr.clone()),
624                    // Division by negative one: x / (-1) -> -x
625                    (expr, Expr::Const(-1.0)) => Box::new(Expr::Neg(Box::new(expr.clone()))),
626                    // Identity: x / x -> 1
627                    (a, b) if a == b => Box::new(Expr::Const(1.0)),
628                    // Simplify exponents: x^a / x^b -> x^(a-b)
629                    (Expr::Pow(b1, e1), Expr::Pow(b2, e2)) if b1 == b2 => {
630                        Box::new(Expr::Pow(b1.clone(), e1 - e2))
631                    }
632                    // Convert division by constant to multiplication: x / c -> x * (1/c)
633                    (x, Expr::Const(c)) if *c != 0.0 && c.abs() > 1e-10 => Box::new(Expr::Mul(
634                        Box::new(x.clone()),
635                        Box::new(Expr::Const(1.0 / c)),
636                    )),
637                    // Simplify nested divisions: (x/y)/z -> x/(y*z)
638                    (Expr::Div(x, y), z) => Box::new(Expr::Div(
639                        x.clone(),
640                        Box::new(Expr::Mul(y.clone(), Box::new(z.clone()))),
641                    )),
642                    _ => Box::new(Expr::Div(l, r)),
643                }
644            }
645
646            Expr::Abs(expr) => {
647                let e = expr.simplify();
648                match &*e {
649                    // Fold constants: abs(3) -> 3
650                    Expr::Const(a) => Box::new(Expr::Const(a.abs())),
651                    // Nested abs: abs(abs(x)) -> abs(x)
652                    Expr::Abs(inner) => Box::new(Expr::Abs(inner.clone())),
653                    // abs(-x) -> abs(x)
654                    Expr::Neg(inner) => Box::new(Expr::Abs(inner.clone())),
655                    // abs(x^2) -> x^2 (even powers are always positive)
656                    Expr::Pow(_, exp) if exp % 2 == 0 => e,
657                    _ => Box::new(Expr::Abs(e)),
658                }
659            }
660
661            Expr::Pow(base, exp) => {
662                let b = base.simplify();
663                match (&*b, exp) {
664                    // x^0 -> 1 (including 0^0 = 1 by convention)
665                    (_, 0) => Box::new(Expr::Const(1.0)),
666                    // Fold constants: 2^3 -> 8
667                    (Expr::Const(a), exp) => Box::new(Expr::Const(a.powi(*exp as i32))),
668                    // Identity: x^1 -> x
669                    (expr, 1) => Box::new(expr.clone()),
670                    // Simplify negative exponents: x^(-n) -> 1/(x^n)
671                    (expr, exp) if *exp < 0 => Box::new(Expr::Div(
672                        Box::new(Expr::Const(1.0)),
673                        Box::new(Expr::Pow(Box::new(expr.clone()), -exp)),
674                    )),
675                    // Nested exponents: (x^a)^b -> x^(a*b)
676                    (Expr::Pow(inner_base, inner_exp), outer_exp) => {
677                        Box::new(Expr::Pow(inner_base.clone(), inner_exp * outer_exp))
678                    }
679                    // Power of product: (x*y)^n -> x^n * y^n (only for small n)
680                    (Expr::Mul(x, y), n) if *n >= 2 && *n <= 4 => Box::new(Expr::Mul(
681                        Box::new(Expr::Pow(x.clone(), *n)),
682                        Box::new(Expr::Pow(y.clone(), *n)),
683                    )),
684                    _ => Box::new(Expr::Pow(b, *exp)),
685                }
686            }
687
688            Expr::PowFloat(base, exp) => {
689                let b = base.simplify();
690                match (&*b, exp) {
691                    // x^0.0 -> 1
692                    (_, exp) if exp.abs() < 1e-10 => Box::new(Expr::Const(1.0)),
693                    // Fold constants: 2.0^3.5 -> result
694                    (Expr::Const(a), exp) => Box::new(Expr::Const(a.powf(*exp))),
695                    // Identity: x^1.0 -> x
696                    (expr, exp) if (exp - 1.0).abs() < 1e-10 => Box::new(expr.clone()),
697                    // Convert to integer power if possible
698                    (expr, exp) if exp.fract().abs() < 1e-10 => {
699                        Box::new(Expr::Pow(Box::new(expr.clone()), *exp as i64))
700                    }
701                    _ => Box::new(Expr::PowFloat(b, *exp)),
702                }
703            }
704
705            Expr::PowExpr(base, exponent) => {
706                let b = base.simplify();
707                let e = exponent.simplify();
708                match (&*b, &*e) {
709                    // Fold constants: 2^3 -> 8
710                    (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a.powf(*b))),
711                    // x^0 -> 1
712                    (_, Expr::Const(0.0)) => Box::new(Expr::Const(1.0)),
713                    // x^1 -> x
714                    (expr, Expr::Const(1.0)) => Box::new(expr.clone()),
715                    // Convert to simpler forms if exponent is constant
716                    (expr, Expr::Const(exp)) if exp.fract().abs() < 1e-10 => {
717                        Box::new(Expr::Pow(Box::new(expr.clone()), *exp as i64))
718                    }
719                    (expr, Expr::Const(exp)) => {
720                        Box::new(Expr::PowFloat(Box::new(expr.clone()), *exp))
721                    }
722                    _ => Box::new(Expr::PowExpr(b, e)),
723                }
724            }
725
726            Expr::Exp(expr) => {
727                let e = expr.simplify();
728                match &*e {
729                    // exp(0) -> 1
730                    Expr::Const(0.0) => Box::new(Expr::Const(1.0)),
731                    // Fold constants: exp(c) -> e^c
732                    Expr::Const(a) => Box::new(Expr::Const(a.exp())),
733                    // exp(ln(x)) -> x
734                    Expr::Ln(inner) => inner.clone(),
735                    // exp(x + y) -> exp(x) * exp(y)
736                    Expr::Add(x, y) => Box::new(Expr::Mul(
737                        Box::new(Expr::Exp(x.clone())),
738                        Box::new(Expr::Exp(y.clone())),
739                    )),
740                    _ => Box::new(Expr::Exp(e)),
741                }
742            }
743
744            Expr::Ln(expr) => {
745                let e = expr.simplify();
746                match &*e {
747                    // Fold constants: ln(c) -> ln(c)
748                    Expr::Const(a) if *a > 0.0 => Box::new(Expr::Const(a.ln())),
749                    // ln(1) -> 0
750                    Expr::Const(1.0) => Box::new(Expr::Const(0.0)),
751                    // ln(exp(x)) -> x
752                    Expr::Exp(inner) => inner.clone(),
753                    // ln(x*y) -> ln(x) + ln(y)
754                    Expr::Mul(x, y) => Box::new(Expr::Add(
755                        Box::new(Expr::Ln(x.clone())),
756                        Box::new(Expr::Ln(y.clone())),
757                    )),
758                    // ln(x/y) -> ln(x) - ln(y)
759                    Expr::Div(x, y) => Box::new(Expr::Sub(
760                        Box::new(Expr::Ln(x.clone())),
761                        Box::new(Expr::Ln(y.clone())),
762                    )),
763                    // ln(x^n) -> n * ln(x)
764                    Expr::Pow(x, n) => Box::new(Expr::Mul(
765                        Box::new(Expr::Const(*n as f64)),
766                        Box::new(Expr::Ln(x.clone())),
767                    )),
768                    _ => Box::new(Expr::Ln(e)),
769                }
770            }
771
772            Expr::Sqrt(expr) => {
773                let e = expr.simplify();
774                match &*e {
775                    // Fold constants: sqrt(c) -> sqrt(c)
776                    Expr::Const(a) if *a >= 0.0 => Box::new(Expr::Const(a.sqrt())),
777                    // sqrt(0) -> 0
778                    Expr::Const(0.0) => Box::new(Expr::Const(0.0)),
779                    // sqrt(1) -> 1
780                    Expr::Const(1.0) => Box::new(Expr::Const(1.0)),
781                    // sqrt(x^2) -> abs(x)
782                    Expr::Pow(x, 2) => Box::new(Expr::Abs(x.clone())),
783                    // sqrt(x*y) -> sqrt(x) * sqrt(y)
784                    Expr::Mul(x, y) => Box::new(Expr::Mul(
785                        Box::new(Expr::Sqrt(x.clone())),
786                        Box::new(Expr::Sqrt(y.clone())),
787                    )),
788                    _ => Box::new(Expr::Sqrt(e)),
789                }
790            }
791
792            Expr::Sin(expr) => {
793                let e = expr.simplify();
794                match &*e {
795                    // sin(0) -> 0
796                    Expr::Const(0.0) => Box::new(Expr::Const(0.0)),
797                    // Fold constants: sin(c) -> sin(c)
798                    Expr::Const(a) => Box::new(Expr::Const(a.sin())),
799                    _ => Box::new(Expr::Sin(e)),
800                }
801            }
802
803            Expr::Cos(expr) => {
804                let e = expr.simplify();
805                match &*e {
806                    // cos(0) -> 1
807                    Expr::Const(0.0) => Box::new(Expr::Const(1.0)),
808                    // Fold constants: cos(c) -> cos(c)
809                    Expr::Const(a) => Box::new(Expr::Const(a.cos())),
810                    _ => Box::new(Expr::Cos(e)),
811                }
812            }
813
814            Expr::Neg(expr) => {
815                let e = expr.simplify();
816                match &*e {
817                    // Fold constants: -3 -> -3
818                    Expr::Const(a) => Box::new(Expr::Const(-a)),
819                    // Double negation: -(-x) -> x
820                    Expr::Neg(inner) => inner.clone(),
821                    // Distribute negation: -(x + y) -> -x - y
822                    Expr::Add(x, y) => {
823                        Expr::Sub(Box::new(Expr::Neg(x.clone())), y.clone()).simplify()
824                    }
825                    // Distribute negation: -(x - y) -> -x + y
826                    Expr::Sub(x, y) => {
827                        Expr::Add(Box::new(Expr::Neg(x.clone())), y.clone()).simplify()
828                    }
829                    // Factor out negation: -(c*x) -> (-c)*x
830                    Expr::Mul(c, x) if matches!(**c, Expr::Const(_)) => {
831                        Expr::Mul(Box::new(Expr::Neg(c.clone())), x.clone()).simplify()
832                    }
833                    _ => Box::new(Expr::Neg(e)),
834                }
835            }
836
837            Expr::Cached(expr, cached_value) => {
838                if cached_value.is_some() {
839                    Box::new(self.clone())
840                } else {
841                    // Simplify the inner expression directly
842                    expr.simplify()
843                }
844            }
845        }
846    }
847
848    /// Inserts an expression by replacing nodes that match a predicate.
849    ///
850    /// Recursively traverses the expression tree and replaces any nodes that match
851    /// the given predicate with the replacement expression. This allows for targeted
852    /// modifications of the expression tree.
853    ///
854    /// # Arguments
855    /// * `predicate` - A closure that determines which nodes to replace
856    /// * `replacement` - The expression to insert where the predicate matches
857    ///
858    /// # Returns
859    /// A new expression tree with the replacements applied
860    pub fn insert<F>(&self, predicate: F, replacement: &Expr) -> Box<Expr>
861    where
862        F: Fn(&Expr) -> bool + Clone,
863    {
864        if predicate(self) {
865            Box::new(replacement.clone())
866        } else {
867            match self {
868                Expr::Const(_) | Expr::Var(_) => Box::new(self.clone()),
869                Expr::Add(left, right) => Box::new(Expr::Add(
870                    left.insert(predicate.clone(), replacement),
871                    right.insert(predicate, replacement),
872                )),
873                Expr::Mul(left, right) => Box::new(Expr::Mul(
874                    left.insert(predicate.clone(), replacement),
875                    right.insert(predicate, replacement),
876                )),
877                Expr::Sub(left, right) => Box::new(Expr::Sub(
878                    left.insert(predicate.clone(), replacement),
879                    right.insert(predicate, replacement),
880                )),
881                Expr::Div(left, right) => Box::new(Expr::Div(
882                    left.insert(predicate.clone(), replacement),
883                    right.insert(predicate, replacement),
884                )),
885                Expr::Abs(expr) => Box::new(Expr::Abs(expr.insert(predicate, replacement))),
886                Expr::Pow(base, exp) => {
887                    Box::new(Expr::Pow(base.insert(predicate, replacement), *exp))
888                }
889                Expr::PowFloat(base, exp) => {
890                    Box::new(Expr::PowFloat(base.insert(predicate, replacement), *exp))
891                }
892                Expr::PowExpr(base, exponent) => Box::new(Expr::PowExpr(
893                    base.insert(predicate.clone(), replacement),
894                    exponent.insert(predicate, replacement),
895                )),
896                Expr::Exp(expr) => Box::new(Expr::Exp(expr.insert(predicate, replacement))),
897                Expr::Ln(expr) => Box::new(Expr::Ln(expr.insert(predicate, replacement))),
898                Expr::Sqrt(expr) => Box::new(Expr::Sqrt(expr.insert(predicate, replacement))),
899                Expr::Sin(expr) => Box::new(Expr::Sin(expr.insert(predicate, replacement))),
900                Expr::Cos(expr) => Box::new(Expr::Cos(expr.insert(predicate, replacement))),
901                Expr::Neg(expr) => Box::new(Expr::Neg(expr.insert(predicate, replacement))),
902                Expr::Cached(expr, _) => {
903                    Box::new(Expr::Cached(expr.insert(predicate, replacement), None))
904                }
905            }
906        }
907    }
908
909    /// Converts expression tree to flattened linear operations for efficient evaluation.
910    ///
911    /// This optimization eliminates:
912    /// - Tree traversal overhead
913    /// - Function call overhead  
914    /// - Memory allocation in hot path
915    /// - Variable lookup overhead
916    ///
917    /// The result is a linear sequence of stack-based operations that can be
918    /// executed with minimal overhead.
919    pub fn flatten(&self) -> FlattenedExpr {
920        let mut ops = Vec::new();
921        let mut max_var_index = None;
922
923        // Check if entire expression is constant
924        if let Some(constant) = self.try_evaluate_constant() {
925            return FlattenedExpr {
926                ops: vec![LinearOp::LoadConst(constant)],
927                max_var_index: None,
928                constant_result: Some(constant),
929            };
930        }
931
932        self.flatten_recursive(&mut ops, &mut max_var_index);
933
934        FlattenedExpr {
935            ops,
936            max_var_index,
937            constant_result: None,
938        }
939    }
940
941    /// Tries to evaluate expression as constant (aggressive constant folding)
942    fn try_evaluate_constant(&self) -> Option<f64> {
943        match self {
944            Expr::Const(val) => Some(*val),
945            Expr::Var(_) => None,
946            Expr::Add(left, right) => {
947                Some(left.try_evaluate_constant()? + right.try_evaluate_constant()?)
948            }
949            Expr::Sub(left, right) => {
950                Some(left.try_evaluate_constant()? - right.try_evaluate_constant()?)
951            }
952            Expr::Mul(left, right) => {
953                Some(left.try_evaluate_constant()? * right.try_evaluate_constant()?)
954            }
955            Expr::Div(left, right) => {
956                let r = right.try_evaluate_constant()?;
957                if r.abs() < 1e-10 {
958                    return None;
959                }
960                Some(left.try_evaluate_constant()? / r)
961            }
962            Expr::Abs(expr) => Some(expr.try_evaluate_constant()?.abs()),
963            Expr::Neg(expr) => Some(-expr.try_evaluate_constant()?),
964            Expr::Pow(base, exp) => Some(base.try_evaluate_constant()?.powi(*exp as i32)),
965            Expr::PowFloat(base, exp) => Some(base.try_evaluate_constant()?.powf(*exp)),
966            Expr::PowExpr(base, exponent) => Some(
967                base.try_evaluate_constant()?
968                    .powf(exponent.try_evaluate_constant()?),
969            ),
970            Expr::Exp(expr) => Some(expr.try_evaluate_constant()?.exp()),
971            Expr::Ln(expr) => {
972                let val = expr.try_evaluate_constant()?;
973                if val <= 0.0 {
974                    return None;
975                }
976                Some(val.ln())
977            }
978            Expr::Sqrt(expr) => {
979                let val = expr.try_evaluate_constant()?;
980                if val < 0.0 {
981                    return None;
982                }
983                Some(val.sqrt())
984            }
985            Expr::Sin(expr) => Some(expr.try_evaluate_constant()?.sin()),
986            Expr::Cos(expr) => Some(expr.try_evaluate_constant()?.cos()),
987            Expr::Cached(expr, cached_value) => {
988                cached_value.or_else(|| expr.try_evaluate_constant())
989            }
990        }
991    }
992
993    /// Recursively flattens expression into linear operations
994    fn flatten_recursive(&self, ops: &mut Vec<LinearOp>, max_var_index: &mut Option<u32>) {
995        match self {
996            Expr::Const(val) => {
997                ops.push(LinearOp::LoadConst(*val));
998            }
999
1000            Expr::Var(var_ref) => {
1001                let index = var_ref.index;
1002                *max_var_index = Some(max_var_index.unwrap_or(0).max(index));
1003                ops.push(LinearOp::LoadVar(index));
1004            }
1005
1006            Expr::Add(left, right) => {
1007                left.flatten_recursive(ops, max_var_index);
1008                right.flatten_recursive(ops, max_var_index);
1009                ops.push(LinearOp::Add);
1010            }
1011
1012            Expr::Sub(left, right) => {
1013                left.flatten_recursive(ops, max_var_index);
1014                right.flatten_recursive(ops, max_var_index);
1015                ops.push(LinearOp::Sub);
1016            }
1017
1018            Expr::Mul(left, right) => {
1019                left.flatten_recursive(ops, max_var_index);
1020                right.flatten_recursive(ops, max_var_index);
1021                ops.push(LinearOp::Mul);
1022            }
1023
1024            Expr::Div(left, right) => {
1025                left.flatten_recursive(ops, max_var_index);
1026                right.flatten_recursive(ops, max_var_index);
1027                ops.push(LinearOp::Div);
1028            }
1029
1030            Expr::Abs(expr) => {
1031                expr.flatten_recursive(ops, max_var_index);
1032                ops.push(LinearOp::Abs);
1033            }
1034
1035            Expr::Neg(expr) => {
1036                expr.flatten_recursive(ops, max_var_index);
1037                ops.push(LinearOp::Neg);
1038            }
1039
1040            Expr::Pow(base, exp) => {
1041                base.flatten_recursive(ops, max_var_index);
1042                ops.push(LinearOp::PowConst(*exp));
1043            }
1044
1045            Expr::PowFloat(base, exp) => {
1046                base.flatten_recursive(ops, max_var_index);
1047                ops.push(LinearOp::PowFloat(*exp));
1048            }
1049
1050            Expr::PowExpr(base, exponent) => {
1051                base.flatten_recursive(ops, max_var_index);
1052                exponent.flatten_recursive(ops, max_var_index);
1053                ops.push(LinearOp::PowExpr);
1054            }
1055
1056            Expr::Exp(expr) => {
1057                expr.flatten_recursive(ops, max_var_index);
1058                ops.push(LinearOp::Exp);
1059            }
1060
1061            Expr::Ln(expr) => {
1062                expr.flatten_recursive(ops, max_var_index);
1063                ops.push(LinearOp::Ln);
1064            }
1065
1066            Expr::Sqrt(expr) => {
1067                expr.flatten_recursive(ops, max_var_index);
1068                ops.push(LinearOp::Sqrt);
1069            }
1070
1071            Expr::Sin(expr) => {
1072                expr.flatten_recursive(ops, max_var_index);
1073                ops.push(LinearOp::Sin);
1074            }
1075
1076            Expr::Cos(expr) => {
1077                expr.flatten_recursive(ops, max_var_index);
1078                ops.push(LinearOp::Cos);
1079            }
1080
1081            Expr::Cached(expr, cached_value) => {
1082                if let Some(val) = cached_value {
1083                    ops.push(LinearOp::LoadConst(*val));
1084                } else {
1085                    expr.flatten_recursive(ops, max_var_index);
1086                }
1087            }
1088        }
1089    }
1090
    /// Emits Cranelift IR for this expression from its flattened (postfix) form.
    ///
    /// The expression is first lowered with `flatten`. If it folds to a constant,
    /// a single `f64const` is emitted and returned. Otherwise every `LinearOp` is
    /// translated in order while `value_stack` tracks the SSA values produced so
    /// far: loads push a value, operators pop their operands (right-hand side
    /// first) and push the result.
    ///
    /// Variables are read with an `F64` load from the first parameter of the
    /// builder's current block, at byte offset `index * 8`.
    ///
    /// # Arguments
    /// * `builder` - Function builder; its current block's first parameter is
    ///   used as the input pointer
    /// * `module` - Module passed to the `link_*` / `call_*` helpers for `powf`,
    ///   `exp`, `ln`, `sqrt`, `sin` and `cos`
    ///
    /// # Returns
    /// The SSA value left on top of the stack. Every path visible here returns
    /// `Ok`; failures surface as panics instead (see below).
    ///
    /// # Panics
    /// Panics if the builder has no current block, if any `link_*` helper
    /// returns an error (its result is unwrapped), or if the value stack is
    /// empty when an operand or the final result is popped.
    pub fn codegen_flattened(
        &self,
        builder: &mut FunctionBuilder,
        module: &mut dyn Module,
    ) -> Result<Value, EquationError> {
        let flattened = self.flatten();

        // Fast path: a fully constant expression needs no loads or calls.
        if let Some(constant) = flattened.constant_result {
            return Ok(builder.ins().f64const(constant));
        }

        // Compile-time operand stack; capacity is an upper bound since each op
        // pushes at most one value.
        let mut value_stack = Vec::with_capacity(flattened.ops.len());

        // Input pointer: first parameter of the current block, fetched once.
        let input_ptr = builder
            .func
            .dfg
            .block_params(builder.current_block().unwrap())[0];

        // Translate the postfix program one op at a time.
        for op in &flattened.ops {
            match op {
                LinearOp::LoadConst(val) => {
                    value_stack.push(builder.ins().f64const(*val));
                }

                LinearOp::LoadVar(index) => {
                    // Byte offset of the `index`-th f64 (8 bytes each).
                    // NOTE(review): `index as i32 * 8` can wrap/overflow for very
                    // large indices, and the load is flagged `notrap` — confirm
                    // callers bound the variable count.
                    let offset = (*index as i32) * 8;
                    let memflags = MemFlags::new().with_aligned().with_readonly().with_notrap();
                    let val =
                        builder
                            .ins()
                            .load(types::F64, memflags, input_ptr, Offset32::new(offset));
                    value_stack.push(val);
                }

                // Binary operators: the right operand was pushed last, so it is
                // popped first.
                LinearOp::Add => {
                    let rhs = value_stack.pop().unwrap();
                    let lhs = value_stack.pop().unwrap();
                    value_stack.push(builder.ins().fadd(lhs, rhs));
                }

                LinearOp::Sub => {
                    let rhs = value_stack.pop().unwrap();
                    let lhs = value_stack.pop().unwrap();
                    value_stack.push(builder.ins().fsub(lhs, rhs));
                }

                LinearOp::Mul => {
                    let rhs = value_stack.pop().unwrap();
                    let lhs = value_stack.pop().unwrap();
                    value_stack.push(builder.ins().fmul(lhs, rhs));
                }

                LinearOp::Div => {
                    let rhs = value_stack.pop().unwrap();
                    let lhs = value_stack.pop().unwrap();
                    value_stack.push(builder.ins().fdiv(lhs, rhs));
                }

                LinearOp::Abs => {
                    let val = value_stack.pop().unwrap();
                    value_stack.push(builder.ins().fabs(val));
                }

                LinearOp::Neg => {
                    let val = value_stack.pop().unwrap();
                    value_stack.push(builder.ins().fneg(val));
                }

                LinearOp::PowConst(exp) => {
                    let base = value_stack.pop().unwrap();
                    // Small exponents are expanded into explicit multiplies and
                    // divides; everything else goes through binary exponentiation.
                    let result = match *exp {
                        0 => builder.ins().f64const(1.0),
                        1 => base,
                        2 => builder.ins().fmul(base, base),
                        3 => {
                            let square = builder.ins().fmul(base, base);
                            builder.ins().fmul(square, base)
                        }
                        4 => {
                            let square = builder.ins().fmul(base, base);
                            builder.ins().fmul(square, square)
                        }
                        -1 => {
                            let one = builder.ins().f64const(1.0);
                            builder.ins().fdiv(one, base)
                        }
                        -2 => {
                            let square = builder.ins().fmul(base, base);
                            let one = builder.ins().f64const(1.0);
                            builder.ins().fdiv(one, square)
                        }
                        _ => {
                            // For other exponents, use optimized binary exponentiation
                            generate_optimized_power(builder, base, *exp)
                        }
                    };
                    value_stack.push(result);
                }

                // The remaining ops are lowered to calls into externally linked
                // routines. NOTE(review): link failures panic via `unwrap` even
                // though this function returns `Result`.
                LinearOp::PowFloat(exp) => {
                    let base = value_stack.pop().unwrap();
                    let func_id = crate::operators::pow::link_powf(module).unwrap();
                    let exp_val = builder.ins().f64const(*exp);
                    let result =
                        crate::operators::pow::call_powf(builder, module, func_id, base, exp_val);
                    value_stack.push(result);
                }

                LinearOp::PowExpr => {
                    // Exponent was flattened after the base, so it is on top.
                    let exponent = value_stack.pop().unwrap();
                    let base = value_stack.pop().unwrap();
                    let func_id = crate::operators::pow::link_powf(module).unwrap();
                    let result =
                        crate::operators::pow::call_powf(builder, module, func_id, base, exponent);
                    value_stack.push(result);
                }

                LinearOp::Exp => {
                    let arg = value_stack.pop().unwrap();
                    let func_id = operators::exp::link_exp(module).unwrap();
                    let result = operators::exp::call_exp(builder, module, func_id, arg);
                    value_stack.push(result);
                }

                LinearOp::Ln => {
                    let arg = value_stack.pop().unwrap();
                    let func_id = operators::ln::link_ln(module).unwrap();
                    let result = operators::ln::call_ln(builder, module, func_id, arg);
                    value_stack.push(result);
                }

                LinearOp::Sqrt => {
                    let arg = value_stack.pop().unwrap();
                    let func_id = operators::sqrt::link_sqrt(module).unwrap();
                    let result = operators::sqrt::call_sqrt(builder, module, func_id, arg);
                    value_stack.push(result);
                }

                LinearOp::Sin => {
                    let arg = value_stack.pop().unwrap();
                    let func_id = crate::operators::trigonometric::link_sin(module).unwrap();
                    let result =
                        crate::operators::trigonometric::call_sin(builder, module, func_id, arg);
                    value_stack.push(result);
                }

                LinearOp::Cos => {
                    let arg = value_stack.pop().unwrap();
                    let func_id = crate::operators::trigonometric::link_cos(module).unwrap();
                    let result =
                        crate::operators::trigonometric::call_cos(builder, module, func_id, arg);
                    value_stack.push(result);
                }
            }
        }

        // The value remaining on top of the stack is the expression's result.
        Ok(value_stack.pop().unwrap())
    }
1258}
1259
1260/// Generates optimized power operation using binary exponentiation
1261fn generate_optimized_power(builder: &mut FunctionBuilder, base: Value, exp: i64) -> Value {
1262    if exp == 0 {
1263        return builder.ins().f64const(1.0);
1264    }
1265
1266    if exp == 1 {
1267        return base;
1268    }
1269
1270    let abs_exp = exp.abs();
1271    let mut result = builder.ins().f64const(1.0);
1272    let mut current_base = base;
1273    let mut remaining = abs_exp;
1274
1275    // Binary exponentiation - optimal for any exponent
1276    while remaining > 0 {
1277        if remaining & 1 == 1 {
1278            result = builder.ins().fmul(result, current_base);
1279        }
1280        if remaining > 1 {
1281            current_base = builder.ins().fmul(current_base, current_base);
1282        }
1283        remaining >>= 1;
1284    }
1285
1286    if exp < 0 {
1287        let one = builder.ins().f64const(1.0);
1288        builder.ins().fdiv(one, result)
1289    } else {
1290        result
1291    }
1292}
1293
1294/// Implements string formatting for expressions.
1295///
1296/// This implementation converts expressions to their standard mathematical notation:
1297/// - Constants are formatted as numbers
1298/// - Variables are formatted as their names
1299/// - Binary operations (+,-,*,/) are wrapped in parentheses
1300/// - Functions (exp, ln, sqrt) use function call notation
1301/// - Absolute value uses |x| notation
1302/// - Exponents use ^
1303/// - Negation uses - prefix
1304/// - Cached expressions display their underlying expression
1305impl std::fmt::Display for Expr {
1306    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1307        match self {
1308            Expr::Const(val) => write!(f, "{val}"),
1309            Expr::Var(var_ref) => write!(f, "{0}", var_ref.name),
1310            Expr::Add(left, right) => write!(f, "({left} + {right})"),
1311            Expr::Mul(left, right) => write!(f, "({left} * {right})"),
1312            Expr::Sub(left, right) => write!(f, "({left} - {right})"),
1313            Expr::Div(left, right) => write!(f, "({left} / {right})"),
1314            Expr::Abs(expr) => write!(f, "|{expr}|"),
1315            Expr::Pow(base, exp) => write!(f, "({base}^{exp})"),
1316            Expr::PowFloat(base, exp) => write!(f, "({base}^{exp})"),
1317            Expr::PowExpr(base, exponent) => write!(f, "({base}^{exponent})"),
1318            Expr::Exp(expr) => write!(f, "exp({expr})"),
1319            Expr::Ln(expr) => write!(f, "ln({expr})"),
1320            Expr::Sqrt(expr) => write!(f, "sqrt({expr})"),
1321            Expr::Sin(expr) => write!(f, "sin({expr})"),
1322            Expr::Cos(expr) => write!(f, "cos({expr})"),
1323            Expr::Neg(expr) => write!(f, "-({expr})"),
1324            Expr::Cached(expr, _) => write!(f, "{expr}"),
1325        }
1326    }
1327}
1328
1329#[cfg(test)]
1330mod tests {
1331    use super::*;
1332
1333    // Helper function to create a variable
1334    fn var(name: &str) -> Box<Expr> {
1335        Box::new(Expr::Var(VarRef {
1336            name: name.to_string(),
1337            vec_ref: Value::from_u32(0),
1338            index: 0,
1339        }))
1340    }
1341
1342    #[test]
1343    fn test_simplify() {
1344        // Helper function to create a variable
1345        fn var(name: &str) -> Box<Expr> {
1346            Box::new(Expr::Var(VarRef {
1347                name: name.to_string(),
1348                vec_ref: Value::from_u32(0), // Dummy value for testing
1349                index: 0,
1350            }))
1351        }
1352
1353        // Test constant folding
1354        // 2 + 3 → 5
1355        assert_eq!(
1356            *Expr::Add(Box::new(Expr::Const(2.0)), Box::new(Expr::Const(3.0))).simplify(),
1357            Expr::Const(5.0)
1358        );
1359
1360        // Test additive identity
1361        // x + 0 → x
1362        assert_eq!(
1363            *Expr::Add(var("x"), Box::new(Expr::Const(0.0))).simplify(),
1364            *var("x")
1365        );
1366
1367        // Test multiplicative identity
1368        // x * 1 → x
1369        assert_eq!(
1370            *Expr::Mul(var("x"), Box::new(Expr::Const(1.0))).simplify(),
1371            *var("x")
1372        );
1373
1374        // Test multiplication by zero
1375        // x * 0 → 0
1376        assert_eq!(
1377            *Expr::Mul(var("x"), Box::new(Expr::Const(0.0))).simplify(),
1378            Expr::Const(0.0)
1379        );
1380
1381        // Test division identity
1382        // x / 1 → x
1383        assert_eq!(
1384            *Expr::Div(var("x"), Box::new(Expr::Const(1.0))).simplify(),
1385            *var("x")
1386        );
1387
1388        // Test division by self
1389        // x / x → 1
1390        assert_eq!(*Expr::Div(var("x"), var("x")).simplify(), Expr::Const(1.0));
1391
1392        // Test exponent simplification
1393        // x^0 → 1
1394        assert_eq!(*Expr::Pow(var("x"), 0).simplify(), Expr::Const(1.0));
1395        // x^1 → x
1396        assert_eq!(*Expr::Pow(var("x"), 1).simplify(), *var("x"));
1397
1398        // Test absolute value of constant
1399        // |-3| → 3
1400        assert_eq!(
1401            *Expr::Abs(Box::new(Expr::Const(-3.0))).simplify(),
1402            Expr::Const(3.0)
1403        );
1404
1405        // Test nested absolute value
1406        // ||x|| → |x|
1407        assert_eq!(
1408            *Expr::Abs(Box::new(Expr::Abs(var("x")))).simplify(),
1409            Expr::Abs(var("x"))
1410        );
1411    }
1412
1413    #[test]
1414    fn test_insert() {
1415        // Helper function to create a variable
1416        fn var(name: &str) -> Box<Expr> {
1417            Box::new(Expr::Var(VarRef {
1418                name: name.to_string(),
1419                vec_ref: Value::from_u32(0),
1420                index: 0,
1421            }))
1422        }
1423
1424        // Create expression: x + y
1425        let expr = Box::new(Expr::Add(var("x"), var("y")));
1426
1427        // Replace all occurrences of 'x' with '2*z'
1428        let replacement = Box::new(Expr::Mul(Box::new(Expr::Const(2.0)), var("z")));
1429
1430        let result = expr.insert(|e| matches!(e, Expr::Var(v) if v.name == "x"), &replacement);
1431
1432        // Expected: (2*z) + y
1433        assert_eq!(
1434            *result,
1435            Expr::Add(
1436                Box::new(Expr::Mul(Box::new(Expr::Const(2.0)), var("z"),)),
1437                var("y"),
1438            )
1439        );
1440    }
1441
1442    #[test]
1443    fn test_derivative() {
1444        // Test constant derivative
1445        assert_eq!(*Expr::Const(5.0).derivative("x"), Expr::Const(0.0));
1446
1447        // Test variable derivatives (x)' = 1, (y)' = 0
1448        assert_eq!(*var("x").derivative("x"), Expr::Const(1.0));
1449        assert_eq!(*var("y").derivative("x"), Expr::Const(0.0));
1450
1451        // Test sum rule (u+v)' = u' + v'
1452        let sum = Box::new(Expr::Add(var("x"), var("y")));
1453        assert_eq!(
1454            *sum.derivative("x"),
1455            Expr::Add(Box::new(Expr::Const(1.0)), Box::new(Expr::Const(0.0)))
1456        );
1457
1458        // Test product rule (u*v)' = u'*v + u*v'
1459        let product = Box::new(Expr::Mul(var("x"), var("y")));
1460        assert_eq!(
1461            *product.derivative("x"),
1462            Expr::Add(
1463                Box::new(Expr::Mul(var("x"), Box::new(Expr::Const(0.0)))),
1464                Box::new(Expr::Mul(var("y"), Box::new(Expr::Const(1.0))))
1465            )
1466        );
1467
1468        // Test power rule (u^v)' = u'*v*u^(v-1) + ln(u)*u^v*v'
1469        let power = Box::new(Expr::Pow(var("x"), 3));
1470        assert_eq!(
1471            *power.derivative("x"),
1472            Expr::Mul(
1473                Box::new(Expr::Mul(
1474                    Box::new(Expr::Const(3.0)),
1475                    Box::new(Expr::Pow(var("x"), 2))
1476                )),
1477                Box::new(Expr::Const(1.0))
1478            )
1479        );
1480    }
1481
1482    #[test]
1483    fn test_complex_simplifications() {
1484        // Test nested operations: (x + 0) * (y + 0) → x * y
1485        let expr = Box::new(Expr::Mul(
1486            Box::new(Expr::Add(var("x"), Box::new(Expr::Const(0.0)))),
1487            Box::new(Expr::Add(var("y"), Box::new(Expr::Const(0.0)))),
1488        ));
1489        assert_eq!(*expr.simplify(), Expr::Mul(var("x"), var("y")));
1490
1491        // Test double negation: -(-x) → x
1492        let expr = Box::new(Expr::Neg(Box::new(Expr::Neg(var("x")))));
1493        assert_eq!(*expr.simplify(), *var("x"));
1494
1495        // Test multiplication by 1: (1 * x) * (y * 1) → x * y
1496        let expr = Box::new(Expr::Mul(
1497            Box::new(Expr::Mul(Box::new(Expr::Const(1.0)), var("x"))),
1498            Box::new(Expr::Mul(var("y"), Box::new(Expr::Const(1.0)))),
1499        ));
1500        assert_eq!(*expr.simplify(), Expr::Mul(var("x"), var("y")));
1501
1502        // Test division simplification: (x/y)/(x/y) → 1
1503        let div = Box::new(Expr::Div(var("x"), var("y")));
1504        let expr = Box::new(Expr::Div(div.clone(), div));
1505        assert_eq!(*expr.simplify(), Expr::Const(1.0));
1506    }
1507
1508    #[test]
1509    fn test_special_functions() {
1510        // Test abs(abs(x)) simplification to abs(x)
1511        let expr = Box::new(Expr::Abs(Box::new(Expr::Abs(var("x")))));
1512        assert_eq!(*expr.simplify(), Expr::Abs(var("x")));
1513
1514        // Test sqrt(x^2) - should simplify to abs(x)
1515        let expr = Box::new(Expr::Sqrt(Box::new(Expr::Pow(var("x"), 2))));
1516        assert_eq!(*expr.simplify(), Expr::Abs(var("x")));
1517
1518        // Test constant special functions
1519        // exp(0) = 1
1520        assert_eq!(
1521            *Expr::Exp(Box::new(Expr::Const(0.0))).simplify(),
1522            Expr::Const(1.0)
1523        );
1524        // ln(1) = 0
1525        assert_eq!(
1526            *Expr::Ln(Box::new(Expr::Const(1.0))).simplify(),
1527            Expr::Const(0.0)
1528        );
1529    }
1530
1531    #[test]
1532    fn test_display() {
1533        // Test basic expressions
1534        assert_eq!(format!("{}", Expr::Const(5.0)), "5");
1535        assert_eq!(format!("{}", *var("x")), "x");
1536
1537        // Test binary operations
1538        let sum = Expr::Add(var("x"), var("y"));
1539        assert_eq!(format!("{sum}"), "(x + y)");
1540
1541        let product = Expr::Mul(var("x"), var("y"));
1542        assert_eq!(format!("{product}"), "(x * y)");
1543
1544        // Test special functions
1545        let exp = Expr::Exp(var("x"));
1546        assert_eq!(format!("{exp}"), "exp(x)");
1547
1548        let abs = Expr::Abs(var("x"));
1549        assert_eq!(format!("{abs}"), "|x|");
1550
1551        // Test complex expression
1552        let complex = Expr::Div(
1553            Box::new(Expr::Add(Box::new(Expr::Pow(var("x"), 2)), var("y"))),
1554            var("z"),
1555        );
1556        assert_eq!(format!("{complex}"), "(((x^2) + y) / z)");
1557    }
1558
1559    #[test]
1560    fn test_cached_expressions() {
1561        // Test cached constant
1562        let cached = Box::new(Expr::Cached(Box::new(Expr::Const(5.0)), Some(5.0)));
1563        assert_eq!(*cached.simplify(), *cached);
1564
1565        // Test uncached expression simplification
1566        let uncached = Box::new(Expr::Cached(
1567            Box::new(Expr::Add(
1568                Box::new(Expr::Const(2.0)),
1569                Box::new(Expr::Const(3.0)),
1570            )),
1571            None,
1572        ));
1573        assert_eq!(*uncached.simplify(), Expr::Const(5.0));
1574    }
1575}