mathcompile/
final_tagless.rs

1//! Final Tagless Approach for Symbolic Mathematical Expressions
2//!
3//! This module implements the final tagless approach to solve the expression problem in symbolic
4//! mathematics. The final tagless approach uses traits with Generic Associated Types (GATs) to
5//! represent mathematical operations, enabling both easy extension of operations and interpreters
6//! without modifying existing code.
7//!
8//! # Technical Motivation
9//!
10//! Traditional approaches to symbolic mathematics face the expression problem: adding new operations
11//! requires modifying existing interpreter code, while adding new interpreters requires modifying
12//! existing operation definitions. The final tagless approach solves this by:
13//!
14//! 1. **Parameterizing representation types**: Operations are defined over abstract representation
15//!    types `Repr<T>`, allowing different interpreters to use different concrete representations
16//! 2. **Trait-based extensibility**: New operations can be added via trait extension without
17//!    modifying existing code
18//! 3. **Zero intermediate representation**: Expressions compile directly to target representations
19//!    without building intermediate ASTs
20//!
21//! # Architecture
22//!
23//! ## Core Traits
24//!
25//! - **`MathExpr`**: Defines basic mathematical operations (arithmetic, transcendental functions)
26//! - **`StatisticalExpr`**: Extends `MathExpr` with statistical functions (logistic, softplus)
27//! - **`NumericType`**: Helper trait bundling common numeric type requirements
28//!
29//! ## Interpreters
30//!
31//! - **`DirectEval`**: Immediate evaluation using native Rust operations (`type Repr<T> = T`)
32//! - **`PrettyPrint`**: String representation generation (`type Repr<T> = String`)
33//!
34//! # Usage Patterns
35//!
36//! ## Polymorphic Expression Definition
37//!
38//! Define mathematical expressions that work with any interpreter:
39//!
40//! ```rust
41//! use mathcompile::final_tagless::*;
42//!
43//! // Define a quadratic function: 2x² + 3x + 1
44//! fn quadratic<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
45//! where
46//!     E::Repr<f64>: Clone,
47//! {
48//!     let a = E::constant(2.0);
49//!     let b = E::constant(3.0);
50//!     let c = E::constant(1.0);
51//!     
52//!     E::add(
53//!         E::add(
54//!             E::mul(a, E::pow(x.clone(), E::constant(2.0))),
55//!             E::mul(b, x)
56//!         ),
57//!         c
58//!     )
59//! }
60//! ```
61//!
62//! ## Direct Evaluation
63//!
64//! Evaluate expressions immediately using native Rust operations:
65//!
66//! ```rust
67//! # use mathcompile::final_tagless::*;
68//! # fn quadratic<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
69//! # where E::Repr<f64>: Clone,
70//! # { E::add(E::add(E::mul(E::constant(2.0), E::pow(x.clone(), E::constant(2.0))), E::mul(E::constant(3.0), x)), E::constant(1.0)) }
71//! let result = quadratic::<DirectEval>(DirectEval::var("x", 2.0));
72//! assert_eq!(result, 15.0); // 2(4) + 3(2) + 1 = 15
73//! ```
74//!
75//! ## Pretty Printing
76//!
77//! Generate human-readable mathematical notation:
78//!
79//! ```rust
80//! # use mathcompile::final_tagless::*;
81//! # fn quadratic<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
82//! # where E::Repr<f64>: Clone,
83//! # { E::add(E::add(E::mul(E::constant(2.0), E::pow(x.clone(), E::constant(2.0))), E::mul(E::constant(3.0), x)), E::constant(1.0)) }
84//! let pretty = quadratic::<PrettyPrint>(PrettyPrint::var("x"));
85//! println!("Expression: {}", pretty);
//! // Output: "(((2 * (x ^ 2)) + (3 * x)) + 1)"
87//! ```
88//!
89//! # Extension Example
90//!
91//! Adding new operations requires only trait extension:
92//!
93//! ```rust
94//! use mathcompile::final_tagless::*;
95//! use num_traits::Float;
96//!
97//! // Extend with hyperbolic functions
98//! trait HyperbolicExpr: MathExpr {
99//!     fn tanh<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T>
100//!     where
101//!         Self::Repr<T>: Clone,
102//!     {
103//!         let exp_x = Self::exp(x.clone());
104//!         let exp_neg_x = Self::exp(Self::neg(x));
105//!         let numerator = Self::sub(exp_x.clone(), exp_neg_x.clone());
106//!         let denominator = Self::add(exp_x, exp_neg_x);
107//!         Self::div(numerator, denominator)
108//!     }
109//! }
110//!
111//! // Automatically works with all existing interpreters
112//! impl HyperbolicExpr for DirectEval {}
113//! impl HyperbolicExpr for PrettyPrint {}
114//! ```
115
116use num_traits::Float;
117use std::collections::HashMap;
118use std::fmt::Debug;
119use std::ops::{Add, Div, Mul, Neg, Sub};
120use std::sync::{Arc, RwLock};
121
/// Convenience bound collecting everything the expression traits need from a value type.
///
/// Instead of repeating `Clone + Default + Send + Sync + 'static + Display + Debug` on every
/// method of `MathExpr`, those methods simply require `T: NumericType`.
pub trait NumericType:
    std::fmt::Debug + std::fmt::Display + Default + Clone + Send + Sync + 'static
{
}

/// Any type meeting the bounds above is automatically a `NumericType`; no manual impl is needed.
impl<T: std::fmt::Debug + std::fmt::Display + Default + Clone + Send + Sync + 'static>
    NumericType for T
{
}
134
135/// Core trait for mathematical expressions using Generic Associated Types (GATs)
136/// This follows the final tagless approach where the representation type is parameterized
137/// and works with generic numeric types including AD types
138pub trait MathExpr {
139    /// The representation type parameterized by the value type
140    type Repr<T>;
141
142    /// Create a constant value
143    fn constant<T: NumericType>(value: T) -> Self::Repr<T>;
144
145    /// Create a variable reference by name (registers variable automatically)
146    fn var<T: NumericType>(name: &str) -> Self::Repr<T>;
147
148    /// Create a variable reference by index (for performance-critical code)
149    fn var_by_index<T: NumericType>(index: usize) -> Self::Repr<T>;
150
151    // Arithmetic operations with flexible type parameters
152    /// Addition operation
153    fn add<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
154    where
155        L: NumericType + Add<R, Output = Output>,
156        R: NumericType,
157        Output: NumericType;
158
159    /// Subtraction operation
160    fn sub<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
161    where
162        L: NumericType + Sub<R, Output = Output>,
163        R: NumericType,
164        Output: NumericType;
165
166    /// Multiplication operation
167    fn mul<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
168    where
169        L: NumericType + Mul<R, Output = Output>,
170        R: NumericType,
171        Output: NumericType;
172
173    /// Division operation
174    fn div<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
175    where
176        L: NumericType + Div<R, Output = Output>,
177        R: NumericType,
178        Output: NumericType;
179
180    /// Power operation
181    fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T>;
182
183    /// Negation operation
184    fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T>;
185
186    /// Natural logarithm
187    fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
188
189    /// Exponential function
190    fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
191
192    /// Square root
193    fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
194
195    /// Sine function
196    fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
197
198    /// Cosine function
199    fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T>;
200}
201
202/// Polynomial evaluation utilities using Horner's method
203///
204/// This module provides efficient polynomial evaluation using the final tagless approach.
205/// Horner's method reduces the number of multiplications and provides better numerical
206/// stability compared to naive polynomial evaluation.
207pub mod polynomial {
208    use super::{MathExpr, NumericType};
209    use std::ops::{Add, Mul, Sub};
210
211    /// Evaluate a polynomial using Horner's method
212    ///
213    /// Given coefficients [a₀, a₁, a₂, ..., aₙ] representing the polynomial:
214    /// a₀ + a₁x + a₂x² + ... + aₙxⁿ
215    ///
216    /// Horner's method evaluates this as:
217    /// a₀ + x(a₁ + x(a₂ + x(...)))
218    ///
219    /// This reduces the number of multiplications from O(n²) to O(n) and
220    /// provides better numerical stability.
221    ///
222    /// # Examples
223    ///
224    /// ```rust
225    /// use mathcompile::final_tagless::{DirectEval, polynomial::horner};
226    ///
227    /// // Evaluate 1 + 3x + 2x² at x = 2
228    /// let coeffs = [1.0, 3.0, 2.0]; // [constant, x, x²]
229    /// let x = DirectEval::var("x", 2.0);
230    /// let result = horner::<DirectEval, f64>(&coeffs, x);
231    /// assert_eq!(result, 15.0); // 1 + 3(2) + 2(4) = 15
232    /// ```
233    ///
234    /// # Type Parameters
235    ///
236    /// - `E`: The expression interpreter (`DirectEval`, `PrettyPrint`, etc.)
237    /// - `T`: The numeric type (f64, f32, etc.)
238    pub fn horner<E: MathExpr, T>(coeffs: &[T], x: E::Repr<T>) -> E::Repr<T>
239    where
240        T: NumericType + Clone + Add<Output = T> + Mul<Output = T>,
241        E::Repr<T>: Clone,
242    {
243        if coeffs.is_empty() {
244            return E::constant(T::default());
245        }
246
247        if coeffs.len() == 1 {
248            return E::constant(coeffs[0].clone());
249        }
250
251        // Start with the highest degree coefficient (last in ascending order)
252        let mut result = E::constant(coeffs[coeffs.len() - 1].clone());
253
254        // Work backwards through the coefficients (from highest to lowest degree)
255        for coeff in coeffs.iter().rev().skip(1) {
256            result = E::add(E::mul(result, x.clone()), E::constant(coeff.clone()));
257        }
258
259        result
260    }
261
262    /// Evaluate a polynomial with explicit coefficients using Horner's method
263    ///
264    /// This is a convenience function for when you want to specify coefficients
265    /// as expression representations rather than raw values.
266    ///
267    /// # Examples
268    ///
269    /// ```rust
270    /// use mathcompile::final_tagless::{DirectEval, MathExpr, polynomial::horner_expr};
271    ///
272    /// // Evaluate 1 + 3x + 2x² at x = 2
273    /// let coeffs = [
274    ///     DirectEval::constant(1.0), // constant term
275    ///     DirectEval::constant(3.0), // x coefficient  
276    ///     DirectEval::constant(2.0), // x² coefficient
277    /// ];
278    /// let x = DirectEval::var("x", 2.0);
279    /// let result = horner_expr::<DirectEval, f64>(&coeffs, x);
280    /// assert_eq!(result, 15.0);
281    /// ```
282    pub fn horner_expr<E: MathExpr, T>(coeffs: &[E::Repr<T>], x: E::Repr<T>) -> E::Repr<T>
283    where
284        T: NumericType + Add<Output = T> + Mul<Output = T>,
285        E::Repr<T>: Clone,
286    {
287        if coeffs.is_empty() {
288            return E::constant(T::default());
289        }
290
291        if coeffs.len() == 1 {
292            return coeffs[0].clone();
293        }
294
295        // Start with the highest degree coefficient
296        let mut result = coeffs[coeffs.len() - 1].clone();
297
298        // Work backwards through the coefficients
299        for coeff in coeffs.iter().rev().skip(1) {
300            result = E::add(E::mul(result, x.clone()), coeff.clone());
301        }
302
303        result
304    }
305
306    /// Create a polynomial from its roots using the final tagless approach
307    ///
308    /// Given roots [r₁, r₂, ..., rₙ], constructs the polynomial:
309    /// (x - r₁)(x - r₂)...(x - rₙ)
310    ///
311    /// # Examples
312    ///
313    /// ```rust
314    /// use mathcompile::final_tagless::{DirectEval, polynomial::from_roots};
315    ///
316    /// // Create polynomial with roots at 1 and 2: (x-1)(x-2) = x² - 3x + 2
317    /// let roots = [1.0, 2.0];
318    /// let x = DirectEval::var("x", 0.0);
319    /// let poly = from_roots::<DirectEval, f64>(&roots, x);
320    /// // At x=0: (0-1)(0-2) = 2
321    /// assert_eq!(poly, 2.0);
322    /// ```
323    pub fn from_roots<E: MathExpr, T>(roots: &[T], x: E::Repr<T>) -> E::Repr<T>
324    where
325        T: NumericType + Clone + Sub<Output = T> + num_traits::One,
326        E::Repr<T>: Clone,
327    {
328        if roots.is_empty() {
329            return E::constant(num_traits::One::one());
330        }
331
332        let mut result = E::sub(x.clone(), E::constant(roots[0].clone()));
333
334        for root in roots.iter().skip(1) {
335            let factor = E::sub(x.clone(), E::constant(root.clone()));
336            result = E::mul(result, factor);
337        }
338
339        result
340    }
341
342    /// Evaluate the derivative of a polynomial using Horner's method
343    ///
344    /// Given coefficients [a₀, a₁, a₂, ..., aₙ] representing:
345    /// a₀ + a₁x + a₂x² + ... + aₙxⁿ
346    ///
347    /// The derivative is: a₁ + 2a₂x + 3a₃x² + ... + naₙx^(n-1)
348    ///
349    /// # Examples
350    ///
351    /// ```rust
352    /// use mathcompile::final_tagless::{DirectEval, polynomial::horner_derivative};
353    ///
354    /// // Derivative of 1 + 3x + 2x² is 3 + 4x
355    /// let coeffs = [1.0, 3.0, 2.0]; // [constant, x, x²]
356    /// let x = DirectEval::var("x", 2.0);
357    /// let result = horner_derivative::<DirectEval, f64>(&coeffs, x);
358    /// assert_eq!(result, 11.0); // 3 + 4(2) = 11
359    /// ```
360    pub fn horner_derivative<E: MathExpr, T>(coeffs: &[T], x: E::Repr<T>) -> E::Repr<T>
361    where
362        T: NumericType + Clone + Add<Output = T> + Mul<Output = T> + num_traits::FromPrimitive,
363        E::Repr<T>: Clone,
364    {
365        if coeffs.len() <= 1 {
366            return E::constant(T::default());
367        }
368
369        // Create derivative coefficients: [a₁, 2a₂, 3a₃, ...]
370        let mut deriv_coeffs = Vec::with_capacity(coeffs.len() - 1);
371        for (i, coeff) in coeffs.iter().enumerate().skip(1) {
372            // Multiply coefficient by its power
373            let power = num_traits::FromPrimitive::from_usize(i).unwrap_or_else(|| T::default());
374            deriv_coeffs.push(coeff.clone() * power);
375        }
376
377        horner::<E, T>(&deriv_coeffs, x)
378    }
379}
380
381/// Direct evaluation interpreter for immediate computation
382///
383/// This interpreter provides immediate evaluation of mathematical expressions using native Rust
384/// operations. It represents expressions directly as their computed values (`type Repr<T> = T`),
385/// making it the simplest and most straightforward interpreter implementation.
386///
387/// # Characteristics
388///
389/// - **Zero overhead**: Direct mapping to native Rust operations
390/// - **Immediate evaluation**: No intermediate representation or compilation step
391/// - **Type preservation**: Works with any numeric type that implements required traits
392/// - **Reference implementation**: Serves as the canonical behavior for other interpreters
393///
394/// # Usage Patterns
395///
396/// ## Simple Expression Evaluation
397///
398/// ```rust
399/// use mathcompile::final_tagless::{DirectEval, MathExpr};
400///
401/// // Define a mathematical function
402/// fn polynomial<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
403/// where
404///     E::Repr<f64>: Clone,
405/// {
406///     // 3x² + 2x + 1
407///     let x_squared = E::pow(x.clone(), E::constant(2.0));
408///     let three_x_squared = E::mul(E::constant(3.0), x_squared);
409///     let two_x = E::mul(E::constant(2.0), x);
410///     E::add(E::add(three_x_squared, two_x), E::constant(1.0))
411/// }
412///
413/// // Evaluate directly with a specific value
414/// let result = polynomial::<DirectEval>(DirectEval::var("x", 2.0));
415/// assert_eq!(result, 17.0); // 3(4) + 2(2) + 1 = 17
416/// ```
417///
418/// ## Working with Different Numeric Types
419///
420/// ```rust
421/// # use mathcompile::final_tagless::{DirectEval, MathExpr, NumericType};
422/// // Function that works with any numeric type
423/// fn linear<E: MathExpr, T>(x: E::Repr<T>, slope: T, intercept: T) -> E::Repr<T>
424/// where
425///     T: Clone + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + NumericType,
426/// {
427///     E::add(E::mul(E::constant(slope), x), E::constant(intercept))
428/// }
429///
430/// // Works with f32
431/// let result_f32 = linear::<DirectEval, f32>(
432///     DirectEval::var("x", 3.0_f32),
433///     2.0_f32,
434///     1.0_f32
435/// );
436/// assert_eq!(result_f32, 7.0_f32);
437///
438/// // Works with f64
439/// let result_f64 = linear::<DirectEval, f64>(
440///     DirectEval::var("x", 3.0_f64),
441///     2.0_f64,
442///     1.0_f64
443/// );
444/// assert_eq!(result_f64, 7.0_f64);
445/// ```
446///
447/// ## Testing and Validation
448///
449/// `DirectEval` is particularly useful for testing the correctness of expressions
450/// before using them with other interpreters:
451///
452/// ```rust
453/// # use mathcompile::final_tagless::{DirectEval, MathExpr, StatisticalExpr};
454/// // Test a statistical function
455/// fn test_logistic<E: StatisticalExpr>(x: E::Repr<f64>) -> E::Repr<f64> {
456///     E::logistic(x)
457/// }
458///
459/// // Verify known values
460/// let result_zero = test_logistic::<DirectEval>(DirectEval::var("x", 0.0));
461/// assert!((result_zero - 0.5).abs() < 1e-10); // logistic(0) = 0.5
462///
463/// let result_large = test_logistic::<DirectEval>(DirectEval::var("x", 10.0));
464/// assert!(result_large > 0.99); // logistic(10) ≈ 1.0
465/// ```
466pub struct DirectEval;
467
468impl DirectEval {
469    /// Create a variable with a specific value for direct evaluation
470    /// Note: This no longer registers variables globally - use `ExpressionBuilder` for that
471    #[must_use]
472    pub fn var<T: NumericType>(name: &str, value: T) -> T {
473        value
474    }
475
476    /// Create a variable by index with a specific value (for performance)
477    #[must_use]
478    pub fn var_by_index<T: NumericType>(_index: usize, value: T) -> T {
479        value
480    }
481
482    /// Evaluate an expression with variables provided as a vector (efficient)
483    #[must_use]
484    pub fn eval_with_vars<T: NumericType + Float + Copy>(expr: &ASTRepr<T>, variables: &[T]) -> T {
485        Self::eval_vars_optimized(expr, variables)
486    }
487
488    /// Optimized variable evaluation without additional allocations
489    #[must_use]
490    pub fn eval_vars_optimized<T: NumericType + Float + Copy>(
491        expr: &ASTRepr<T>,
492        variables: &[T],
493    ) -> T {
494        match expr {
495            ASTRepr::Constant(value) => *value,
496            ASTRepr::Variable(index) => variables.get(*index).copied().unwrap_or_else(|| T::zero()),
497            ASTRepr::Add(left, right) => {
498                Self::eval_vars_optimized(left, variables)
499                    + Self::eval_vars_optimized(right, variables)
500            }
501            ASTRepr::Sub(left, right) => {
502                Self::eval_vars_optimized(left, variables)
503                    - Self::eval_vars_optimized(right, variables)
504            }
505            ASTRepr::Mul(left, right) => {
506                Self::eval_vars_optimized(left, variables)
507                    * Self::eval_vars_optimized(right, variables)
508            }
509            ASTRepr::Div(left, right) => {
510                Self::eval_vars_optimized(left, variables)
511                    / Self::eval_vars_optimized(right, variables)
512            }
513            ASTRepr::Pow(base, exp) => Self::eval_vars_optimized(base, variables)
514                .powf(Self::eval_vars_optimized(exp, variables)),
515            ASTRepr::Neg(inner) => -Self::eval_vars_optimized(inner, variables),
516            ASTRepr::Ln(inner) => Self::eval_vars_optimized(inner, variables).ln(),
517            ASTRepr::Exp(inner) => Self::eval_vars_optimized(inner, variables).exp(),
518            ASTRepr::Sin(inner) => Self::eval_vars_optimized(inner, variables).sin(),
519            ASTRepr::Cos(inner) => Self::eval_vars_optimized(inner, variables).cos(),
520            ASTRepr::Sqrt(inner) => Self::eval_vars_optimized(inner, variables).sqrt(),
521        }
522    }
523
524    /// Evaluate a two-variable expression with specific values (optimized version)
525    #[must_use]
526    pub fn eval_two_vars(expr: &ASTRepr<f64>, x: f64, y: f64) -> f64 {
527        Self::eval_two_vars_fast(expr, x, y)
528    }
529
530    /// Fast evaluation without heap allocation for two variables
531    #[must_use]
532    pub fn eval_two_vars_fast(expr: &ASTRepr<f64>, x: f64, y: f64) -> f64 {
533        match expr {
534            ASTRepr::Constant(value) => *value,
535            ASTRepr::Variable(index) => match *index {
536                0 => x,
537                1 => y,
538                _ => 0.0, // Default for out-of-bounds
539            },
540            ASTRepr::Add(left, right) => {
541                Self::eval_two_vars_fast(left, x, y) + Self::eval_two_vars_fast(right, x, y)
542            }
543            ASTRepr::Sub(left, right) => {
544                Self::eval_two_vars_fast(left, x, y) - Self::eval_two_vars_fast(right, x, y)
545            }
546            ASTRepr::Mul(left, right) => {
547                Self::eval_two_vars_fast(left, x, y) * Self::eval_two_vars_fast(right, x, y)
548            }
549            ASTRepr::Div(left, right) => {
550                Self::eval_two_vars_fast(left, x, y) / Self::eval_two_vars_fast(right, x, y)
551            }
552            ASTRepr::Pow(base, exp) => {
553                Self::eval_two_vars_fast(base, x, y).powf(Self::eval_two_vars_fast(exp, x, y))
554            }
555            ASTRepr::Neg(inner) => -Self::eval_two_vars_fast(inner, x, y),
556            ASTRepr::Ln(inner) => Self::eval_two_vars_fast(inner, x, y).ln(),
557            ASTRepr::Exp(inner) => Self::eval_two_vars_fast(inner, x, y).exp(),
558            ASTRepr::Sin(inner) => Self::eval_two_vars_fast(inner, x, y).sin(),
559            ASTRepr::Cos(inner) => Self::eval_two_vars_fast(inner, x, y).cos(),
560            ASTRepr::Sqrt(inner) => Self::eval_two_vars_fast(inner, x, y).sqrt(),
561        }
562    }
563}
564
565impl MathExpr for DirectEval {
566    type Repr<T> = T;
567
568    fn constant<T: NumericType>(value: T) -> Self::Repr<T> {
569        value
570    }
571
572    fn var<T: NumericType>(name: &str) -> Self::Repr<T> {
573        // No longer register variables globally - use ExpressionBuilder for that
574        T::default()
575    }
576
577    fn var_by_index<T: NumericType>(_index: usize) -> Self::Repr<T> {
578        T::default()
579    }
580
581    fn add<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
582    where
583        L: NumericType + Add<R, Output = Output>,
584        R: NumericType,
585        Output: NumericType,
586    {
587        left + right
588    }
589
590    fn sub<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
591    where
592        L: NumericType + Sub<R, Output = Output>,
593        R: NumericType,
594        Output: NumericType,
595    {
596        left - right
597    }
598
599    fn mul<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
600    where
601        L: NumericType + Mul<R, Output = Output>,
602        R: NumericType,
603        Output: NumericType,
604    {
605        left * right
606    }
607
608    fn div<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
609    where
610        L: NumericType + Div<R, Output = Output>,
611        R: NumericType,
612        Output: NumericType,
613    {
614        left / right
615    }
616
617    fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T> {
618        base.powf(exp)
619    }
620
621    fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T> {
622        -expr
623    }
624
625    fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
626        expr.ln()
627    }
628
629    fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
630        expr.exp()
631    }
632
633    fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
634        expr.sqrt()
635    }
636
637    fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
638        expr.sin()
639    }
640
641    fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
642        expr.cos()
643    }
644}
645
646/// Extension trait for statistical operations
647pub trait StatisticalExpr: MathExpr {
648    /// Logistic function: 1 / (1 + exp(-x))
649    fn logistic<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
650        let one = Self::constant(T::one());
651        let neg_x = Self::neg(x);
652        let exp_neg_x = Self::exp(neg_x);
653        let denominator = Self::add(one, exp_neg_x);
654        Self::div(Self::constant(T::one()), denominator)
655    }
656
657    /// Softplus function: ln(1 + exp(x))
658    fn softplus<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
659        let one = Self::constant(T::one());
660        let exp_x = Self::exp(x);
661        let one_plus_exp_x = Self::add(one, exp_x);
662        Self::ln(one_plus_exp_x)
663    }
664
665    /// Sigmoid function (alias for logistic)
666    fn sigmoid<T: NumericType + Float>(x: Self::Repr<T>) -> Self::Repr<T> {
667        Self::logistic(x)
668    }
669}
670
671// Implement StatisticalExpr for DirectEval
672impl StatisticalExpr for DirectEval {}
673
/// String representation interpreter for mathematical expressions
///
/// This interpreter converts final tagless expressions into human-readable mathematical notation.
/// It generates fully parenthesized infix expressions that show the structure and precedence
/// of operations. This is useful for debugging, documentation, and displaying expressions to users.
///
/// # Output Format
///
/// - **Arithmetic operations**: Infix notation wrapped in parentheses `(a + b)`, `(a * b)`
/// - **Negation**: `(-a)`
/// - **Functions**: Function call notation `ln(x)`, `exp(x)`, `sqrt(x)`
/// - **Variables**: Variable names as provided `x`, `theta`, `data`
/// - **Constants**: The value's `Display` output, e.g. `2`, `3.14159`, `-1.5`
///
/// # Usage Examples
///
/// ## Basic Expression Formatting
///
/// ```rust
/// use mathcompile::final_tagless::{PrettyPrint, MathExpr};
///
/// // Simple quadratic: x² + 2x + 1
/// fn quadratic<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
/// where
///     E::Repr<f64>: Clone,
/// {
///     let x_squared = E::pow(x.clone(), E::constant(2.0));
///     let two_x = E::mul(E::constant(2.0), x);
///     E::add(E::add(x_squared, two_x), E::constant(1.0))
/// }
///
/// let pretty = quadratic::<PrettyPrint>(PrettyPrint::var("x"));
/// assert_eq!(pretty, "(((x ^ 2) + (2 * x)) + 1)");
/// ```
///
/// ## Complex Mathematical Expressions
///
/// ```rust
/// # use mathcompile::final_tagless::{PrettyPrint, MathExpr, StatisticalExpr};
/// // Logistic regression: 1 / (1 + exp(-θx))
/// fn logistic_regression<E: StatisticalExpr>(x: E::Repr<f64>, theta: E::Repr<f64>) -> E::Repr<f64> {
///     E::logistic(E::mul(theta, x))
/// }
///
/// let pretty = logistic_regression::<PrettyPrint>(
///     PrettyPrint::var("x"),
///     PrettyPrint::var("theta")
/// );
/// // The logistic function is expanded into its primitive operations.
/// assert_eq!(pretty, "(1 / (1 + exp((-(theta * x)))))");
/// ```
///
/// ## Transcendental Functions
///
/// ```rust
/// # use mathcompile::final_tagless::{PrettyPrint, MathExpr};
/// // Gaussian: exp(-x²/2) / sqrt(2π)
/// fn gaussian_kernel<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
/// where
///     E::Repr<f64>: Clone,
/// {
///     let x_squared = E::pow(x, E::constant(2.0));
///     let neg_half_x_squared = E::div(E::neg(x_squared), E::constant(2.0));
///     let numerator = E::exp(neg_half_x_squared);
///     let denominator = E::sqrt(E::mul(E::constant(2.0), E::constant(3.14159)));
///     E::div(numerator, denominator)
/// }
///
/// let pretty = gaussian_kernel::<PrettyPrint>(PrettyPrint::var("x"));
/// assert_eq!(pretty, "(exp(((-(x ^ 2)) / 2)) / sqrt((2 * 3.14159)))");
/// ```
pub struct PrettyPrint;

impl PrettyPrint {
    /// Create a variable for pretty printing; the result is the name itself.
    #[must_use]
    pub fn var(name: &str) -> String {
        name.to_string()
    }
}
755
756impl MathExpr for PrettyPrint {
757    type Repr<T> = String;
758
759    fn constant<T: NumericType>(value: T) -> Self::Repr<T> {
760        format!("{value}")
761    }
762
763    fn var<T: NumericType>(name: &str) -> Self::Repr<T> {
764        name.to_string()
765    }
766
767    fn var_by_index<T: NumericType>(_index: usize) -> Self::Repr<T> {
768        T::default().to_string()
769    }
770
771    fn add<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
772    where
773        L: NumericType + Add<R, Output = Output>,
774        R: NumericType,
775        Output: NumericType,
776    {
777        format!("({left} + {right})")
778    }
779
780    fn sub<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
781    where
782        L: NumericType + Sub<R, Output = Output>,
783        R: NumericType,
784        Output: NumericType,
785    {
786        format!("({left} - {right})")
787    }
788
789    fn mul<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
790    where
791        L: NumericType + Mul<R, Output = Output>,
792        R: NumericType,
793        Output: NumericType,
794    {
795        format!("({left} * {right})")
796    }
797
798    fn div<L, R, Output>(left: Self::Repr<L>, right: Self::Repr<R>) -> Self::Repr<Output>
799    where
800        L: NumericType + Div<R, Output = Output>,
801        R: NumericType,
802        Output: NumericType,
803    {
804        format!("({left} / {right})")
805    }
806
807    fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T> {
808        format!("({base} ^ {exp})")
809    }
810
811    fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T> {
812        format!("(-{expr})")
813    }
814
815    fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
816        format!("ln({expr})")
817    }
818
819    fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
820        format!("exp({expr})")
821    }
822
823    fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
824        format!("sqrt({expr})")
825    }
826
827    fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
828        format!("sin({expr})")
829    }
830
831    fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
832        format!("cos({expr})")
833    }
834}
835
836// Implement StatisticalExpr for PrettyPrint
837impl StatisticalExpr for PrettyPrint {}
838
/// Explicit expression tree for mathematical expressions.
///
/// Unlike the tagless interpreters, this enum materializes an expression as data. It is the
/// representation intended for JIT compilation with Cranelift, and it can also be walked by
/// evaluators such as `DirectEval::eval_with_vars`. Each variant is a single operation.
///
/// # Performance Note
///
/// Variables are identified by position rather than by name, so evaluating them with
/// `DirectEval` is a slice index instead of a string lookup:
///
/// ```rust
/// use mathcompile::final_tagless::{ASTRepr, DirectEval};
///
/// // Efficient: uses vector indexing
/// let expr = ASTRepr::Add(
///     Box::new(ASTRepr::Variable(0)), // x
///     Box::new(ASTRepr::Variable(1)), // y
/// );
/// let result = DirectEval::eval_with_vars(&expr, &[2.0, 3.0]);
/// assert_eq!(result, 5.0);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub enum ASTRepr<T> {
    // ---- Leaves ----
    /// A literal value.
    Constant(T),
    /// The variable at this position of the evaluation input.
    Variable(usize),

    // ---- Binary operations: (left operand, right operand) ----
    /// `left + right`
    Add(Box<Self>, Box<Self>),
    /// `left - right`
    Sub(Box<Self>, Box<Self>),
    /// `left * right`
    Mul(Box<Self>, Box<Self>),
    /// `left / right`
    Div(Box<Self>, Box<Self>),
    /// `base ^ exponent`
    Pow(Box<Self>, Box<Self>),

    // ---- Unary operations ----
    /// `-inner`
    Neg(Box<Self>),
    /// Natural logarithm of `inner`.
    Ln(Box<Self>),
    /// `e^inner`
    Exp(Box<Self>),
    /// Square root of `inner`.
    Sqrt(Box<Self>),
    /// Sine of `inner`.
    Sin(Box<Self>),
    /// Cosine of `inner`.
    Cos(Box<Self>),
}
890
891impl<T> ASTRepr<T> {
892    /// Count the total number of operations in the expression tree
893    pub fn count_operations(&self) -> usize {
894        match self {
895            ASTRepr::Constant(_) | ASTRepr::Variable(_) => 0,
896            ASTRepr::Add(left, right)
897            | ASTRepr::Sub(left, right)
898            | ASTRepr::Mul(left, right)
899            | ASTRepr::Div(left, right)
900            | ASTRepr::Pow(left, right) => 1 + left.count_operations() + right.count_operations(),
901            ASTRepr::Neg(inner)
902            | ASTRepr::Ln(inner)
903            | ASTRepr::Exp(inner)
904            | ASTRepr::Sin(inner)
905            | ASTRepr::Cos(inner)
906            | ASTRepr::Sqrt(inner) => 1 + inner.count_operations(),
907        }
908    }
909
910    /// Get the variable index if this is a variable, otherwise None
911    pub fn variable_index(&self) -> Option<usize> {
912        match self {
913            ASTRepr::Variable(index) => Some(*index),
914            _ => None,
915        }
916    }
917}
918
/// JIT evaluation interpreter that builds an intermediate representation
/// suitable for compilation with Cranelift or Rust codegen
///
/// This interpreter constructs an `ASTRepr` tree that can later be compiled
/// to native machine code for high-performance evaluation.
///
/// `ASTEval` is a stateless unit struct used only as a type-level interpreter selector:
/// all of its operations are associated functions (e.g. `ASTEval::var`, or
/// `<ASTEval as ASTMathExpr>::add`) rather than methods on a value.
pub struct ASTEval;
925
926impl ASTEval {
927    /// Create a variable reference for JIT compilation using an index (efficient)
928    #[must_use]
929    pub fn var<T: NumericType>(index: usize) -> ASTRepr<T> {
930        ASTRepr::Variable(index)
931    }
932
933    /// Convenience method for creating variables by name (for backward compatibility)
934    /// Note: This no longer registers variables - use `ExpressionBuilder` for proper variable management
935    #[must_use]
936    pub fn var_by_name(_name: &str) -> ASTRepr<f64> {
937        // Default to variable index 0 for backward compatibility
938        ASTRepr::Variable(0)
939    }
940}
941
/// Simplified trait for JIT compilation that works with homogeneous f64 types
/// This is a practical compromise for JIT compilation while maintaining the final tagless approach
///
/// Unlike the generic `MathExpr`, whose arithmetic methods combine operands of possibly
/// different types (`Repr<L>`, `Repr<R>` -> `Repr<Output>`), every method here uses one
/// non-generic representation type, and constants are always `f64`.
///
/// NOTE(review): `ASTMathExprf64` declares the same methods with identical signatures;
/// confirm whether both traits are still needed.
pub trait ASTMathExpr {
    /// The representation type built by the interpreter (`ASTRepr<f64>` for `ASTEval`).
    ///
    /// The trait places no bound on it; only the inputs (`f64` constants and `usize`
    /// variable indices) are fixed.
    type Repr;

    /// Create a constant value
    fn constant(value: f64) -> Self::Repr;

    /// Create a variable reference by index
    fn var(index: usize) -> Self::Repr;

    /// Addition operation
    fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// Subtraction operation
    fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// Multiplication operation
    fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// Division operation
    fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// Power operation (`base` raised to `exp`)
    fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr;

    /// Negation operation
    fn neg(expr: Self::Repr) -> Self::Repr;

    /// Natural logarithm
    fn ln(expr: Self::Repr) -> Self::Repr;

    /// Exponential function
    fn exp(expr: Self::Repr) -> Self::Repr;

    /// Square root
    fn sqrt(expr: Self::Repr) -> Self::Repr;

    /// Sine function
    fn sin(expr: Self::Repr) -> Self::Repr;

    /// Cosine function
    fn cos(expr: Self::Repr) -> Self::Repr;
}
987
988impl ASTMathExpr for ASTEval {
989    type Repr = ASTRepr<f64>;
990
991    fn constant(value: f64) -> Self::Repr {
992        ASTRepr::Constant(value)
993    }
994
995    fn var(index: usize) -> Self::Repr {
996        ASTRepr::Variable(index)
997    }
998
999    fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1000        ASTRepr::Add(Box::new(left), Box::new(right))
1001    }
1002
1003    fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1004        ASTRepr::Sub(Box::new(left), Box::new(right))
1005    }
1006
1007    fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1008        ASTRepr::Mul(Box::new(left), Box::new(right))
1009    }
1010
1011    fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1012        ASTRepr::Div(Box::new(left), Box::new(right))
1013    }
1014
1015    fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr {
1016        ASTRepr::Pow(Box::new(base), Box::new(exp))
1017    }
1018
1019    fn neg(expr: Self::Repr) -> Self::Repr {
1020        ASTRepr::Neg(Box::new(expr))
1021    }
1022
1023    fn ln(expr: Self::Repr) -> Self::Repr {
1024        ASTRepr::Ln(Box::new(expr))
1025    }
1026
1027    fn exp(expr: Self::Repr) -> Self::Repr {
1028        ASTRepr::Exp(Box::new(expr))
1029    }
1030
1031    fn sqrt(expr: Self::Repr) -> Self::Repr {
1032        ASTRepr::Sqrt(Box::new(expr))
1033    }
1034
1035    fn sin(expr: Self::Repr) -> Self::Repr {
1036        ASTRepr::Sin(Box::new(expr))
1037    }
1038
1039    fn cos(expr: Self::Repr) -> Self::Repr {
1040        ASTRepr::Cos(Box::new(expr))
1041    }
1042}
1043
1044/// For compatibility with the main `MathExpr` trait, we provide a limited implementation
1045/// that works only with f64 types
1046impl MathExpr for ASTEval {
1047    type Repr<T> = ASTRepr<T>;
1048
1049    fn constant<T: NumericType>(value: T) -> Self::Repr<T> {
1050        ASTRepr::Constant(value)
1051    }
1052
1053    fn var<T: NumericType>(_name: &str) -> Self::Repr<T> {
1054        // Default to variable index 0 for compatibility
1055        ASTRepr::Variable(0)
1056    }
1057
1058    fn var_by_index<T: NumericType>(index: usize) -> Self::Repr<T> {
1059        ASTRepr::Variable(index)
1060    }
1061
1062    fn add<L, R, Output>(_left: Self::Repr<L>, _right: Self::Repr<R>) -> Self::Repr<Output>
1063    where
1064        L: NumericType + Add<R, Output = Output>,
1065        R: NumericType,
1066        Output: NumericType,
1067    {
1068        // This is a placeholder implementation for the generic trait
1069        // In practice, you would use the specific f64 version
1070        unimplemented!("Use ASTMathExpr or ASTMathExprf64 for concrete implementations")
1071    }
1072
1073    fn sub<L, R, Output>(_left: Self::Repr<L>, _right: Self::Repr<R>) -> Self::Repr<Output>
1074    where
1075        L: NumericType + Sub<R, Output = Output>,
1076        R: NumericType,
1077        Output: NumericType,
1078    {
1079        unimplemented!("Use ASTMathExpr or ASTMathExprf64 for concrete implementations")
1080    }
1081
1082    fn mul<L, R, Output>(_left: Self::Repr<L>, _right: Self::Repr<R>) -> Self::Repr<Output>
1083    where
1084        L: NumericType + Mul<R, Output = Output>,
1085        R: NumericType,
1086        Output: NumericType,
1087    {
1088        unimplemented!("Use ASTMathExpr or ASTMathExprf64 for concrete implementations")
1089    }
1090
1091    fn div<L, R, Output>(_left: Self::Repr<L>, _right: Self::Repr<R>) -> Self::Repr<Output>
1092    where
1093        L: NumericType + Div<R, Output = Output>,
1094        R: NumericType,
1095        Output: NumericType,
1096    {
1097        unimplemented!("Use ASTMathExpr or ASTMathExprf64 for concrete implementations")
1098    }
1099
1100    fn pow<T: NumericType + Float>(base: Self::Repr<T>, exp: Self::Repr<T>) -> Self::Repr<T> {
1101        ASTRepr::Pow(Box::new(base), Box::new(exp))
1102    }
1103
1104    fn neg<T: NumericType + Neg<Output = T>>(expr: Self::Repr<T>) -> Self::Repr<T> {
1105        ASTRepr::Neg(Box::new(expr))
1106    }
1107
1108    fn ln<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
1109        ASTRepr::Ln(Box::new(expr))
1110    }
1111
1112    fn exp<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
1113        ASTRepr::Exp(Box::new(expr))
1114    }
1115
1116    fn sqrt<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
1117        ASTRepr::Sqrt(Box::new(expr))
1118    }
1119
1120    fn sin<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
1121        ASTRepr::Sin(Box::new(expr))
1122    }
1123
1124    fn cos<T: NumericType + Float>(expr: Self::Repr<T>) -> Self::Repr<T> {
1125        ASTRepr::Cos(Box::new(expr))
1126    }
1127}
1128
// `ASTEval` relies entirely on `StatisticalExpr`'s provided (default) methods; an empty
// impl block only compiles when every trait item has a default.
// NOTE(review): `ASTEval`'s `MathExpr::add`/`sub`/`mul`/`div` are `unimplemented!()`
// stubs, so any default statistical method built on them will panic for `ASTEval` —
// confirm which statistical functions are actually usable with this interpreter.
impl StatisticalExpr for ASTEval {}
1130
/// Simplified trait for f64 JIT compilation
///
/// Every method works with a single non-generic representation type, and constants are
/// always `f64`. `ASTEval` implements it with `Repr = ASTRepr<f64>`.
///
/// NOTE(review): this declares the same methods, with identical signatures, as
/// `ASTMathExpr`; confirm whether both traits are still needed.
pub trait ASTMathExprf64 {
    /// The representation type for f64 compilation (`ASTRepr<f64>` for `ASTEval`)
    type Repr;

    /// Create a constant value
    fn constant(value: f64) -> Self::Repr;

    /// Create a variable reference by index
    fn var(index: usize) -> Self::Repr;

    /// Addition operation
    fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// Subtraction operation
    fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// Multiplication operation
    fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// Division operation
    fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr;

    /// Power operation (`base` raised to `exp`)
    fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr;

    /// Negation operation
    fn neg(expr: Self::Repr) -> Self::Repr;

    /// Natural logarithm
    fn ln(expr: Self::Repr) -> Self::Repr;

    /// Exponential function
    fn exp(expr: Self::Repr) -> Self::Repr;

    /// Square root
    fn sqrt(expr: Self::Repr) -> Self::Repr;

    /// Sine function
    fn sin(expr: Self::Repr) -> Self::Repr;

    /// Cosine function
    fn cos(expr: Self::Repr) -> Self::Repr;
}
1175
1176impl ASTMathExprf64 for ASTEval {
1177    type Repr = ASTRepr<f64>;
1178
1179    fn constant(value: f64) -> Self::Repr {
1180        ASTRepr::Constant(value)
1181    }
1182
1183    fn var(index: usize) -> Self::Repr {
1184        ASTRepr::Variable(index)
1185    }
1186
1187    fn add(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1188        ASTRepr::Add(Box::new(left), Box::new(right))
1189    }
1190
1191    fn sub(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1192        ASTRepr::Sub(Box::new(left), Box::new(right))
1193    }
1194
1195    fn mul(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1196        ASTRepr::Mul(Box::new(left), Box::new(right))
1197    }
1198
1199    fn div(left: Self::Repr, right: Self::Repr) -> Self::Repr {
1200        ASTRepr::Div(Box::new(left), Box::new(right))
1201    }
1202
1203    fn pow(base: Self::Repr, exp: Self::Repr) -> Self::Repr {
1204        ASTRepr::Pow(Box::new(base), Box::new(exp))
1205    }
1206
1207    fn neg(expr: Self::Repr) -> Self::Repr {
1208        ASTRepr::Neg(Box::new(expr))
1209    }
1210
1211    fn ln(expr: Self::Repr) -> Self::Repr {
1212        ASTRepr::Ln(Box::new(expr))
1213    }
1214
1215    fn exp(expr: Self::Repr) -> Self::Repr {
1216        ASTRepr::Exp(Box::new(expr))
1217    }
1218
1219    fn sqrt(expr: Self::Repr) -> Self::Repr {
1220        ASTRepr::Sqrt(Box::new(expr))
1221    }
1222
1223    fn sin(expr: Self::Repr) -> Self::Repr {
1224        ASTRepr::Sin(Box::new(expr))
1225    }
1226
1227    fn cos(expr: Self::Repr) -> Self::Repr {
1228        ASTRepr::Cos(Box::new(expr))
1229    }
1230}
1231
1232// ============================================================================
1233// Summation Infrastructure Implementation
1234// ============================================================================
1235
/// Trait for range-like types in summations
///
/// This trait defines the interface for different types of ranges that can be used
/// in summations, e.g. `IntRange` (integer bounds) and `FloatRange` (floating-point
/// bounds with a step). `SymbolicRange` does not implement it, because its bounds are
/// expressions that must be evaluated before they are concrete.
pub trait RangeType: Clone + Send + Sync + 'static + std::fmt::Debug {
    /// The type of values in this range
    type IndexType: NumericType;

    /// Start of the range (inclusive)
    fn start(&self) -> Self::IndexType;

    /// End of the range (inclusive)
    fn end(&self) -> Self::IndexType;

    /// Check whether `value` lies within the range's bounds
    fn contains(&self, value: &Self::IndexType) -> bool;

    /// Number of values in the range, expressed as `IndexType`.
    ///
    /// For a non-empty `IntRange` this is `end - start + 1`; for `FloatRange` it is the
    /// number of `step`-spaced points from `start` up to `end`. Empty ranges report zero.
    fn len(&self) -> Self::IndexType;

    /// Check if the range is empty
    fn is_empty(&self) -> bool;
}
1259
/// Simple integer range for summations
///
/// Represents inclusive ranges like 1..=n, 0..=100, etc. This is the most common
/// type of range used in mathematical summations. A range whose `end` is below its
/// `start` is empty.
///
/// # Examples
///
/// ```rust
/// use mathcompile::final_tagless::{IntRange, RangeType};
///
/// let range = IntRange::new(1, 10);  // Range from 1 to 10 inclusive
/// assert_eq!(range.len(), 10);
/// assert!(range.contains(&5));
/// assert!(!range.contains(&15));
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntRange {
    /// First value of the range (inclusive).
    pub start: i64,
    /// Last value of the range (inclusive).
    pub end: i64,
}
1280
1281impl IntRange {
1282    /// Create a new integer range
1283    #[must_use]
1284    pub fn new(start: i64, end: i64) -> Self {
1285        Self { start, end }
1286    }
1287
1288    /// Create a range from 1 to n (common mathematical convention)
1289    #[must_use]
1290    pub fn one_to_n(n: i64) -> Self {
1291        Self::new(1, n)
1292    }
1293
1294    /// Create a range from 0 to n-1 (common programming convention)
1295    #[must_use]
1296    pub fn zero_to_n_minus_one(n: i64) -> Self {
1297        Self::new(0, n - 1)
1298    }
1299
1300    /// Iterate over the range values
1301    pub fn iter(&self) -> impl Iterator<Item = i64> {
1302        self.start..=self.end
1303    }
1304}
1305
1306impl RangeType for IntRange {
1307    type IndexType = i64;
1308
1309    fn start(&self) -> Self::IndexType {
1310        self.start
1311    }
1312
1313    fn end(&self) -> Self::IndexType {
1314        self.end
1315    }
1316
1317    fn contains(&self, value: &Self::IndexType) -> bool {
1318        *value >= self.start && *value <= self.end
1319    }
1320
1321    fn len(&self) -> Self::IndexType {
1322        if self.end >= self.start {
1323            self.end - self.start + 1
1324        } else {
1325            0
1326        }
1327    }
1328
1329    fn is_empty(&self) -> bool {
1330        self.end < self.start
1331    }
1332}
1333
/// Floating-point range for summations
///
/// Represents ranges with floating-point bounds. Less common than integer ranges
/// but useful for continuous approximations or when bounds are computed values.
///
/// The points of the range are `start`, `start + step`, `start + 2 * step`, ... not
/// exceeding `end`. A range whose `end` is below its `start`, or whose `step` is not
/// positive, has length zero. Note that `contains` checks only the bounds, not whether a
/// value lies on the `step` grid.
#[derive(Debug, Clone, PartialEq)]
pub struct FloatRange {
    /// Lower bound (inclusive) and first point of the range.
    pub start: f64,
    /// Upper bound (inclusive).
    pub end: f64,
    /// Spacing between consecutive points.
    pub step: f64,
}
1344
1345impl FloatRange {
1346    /// Create a new floating-point range with step size
1347    #[must_use]
1348    pub fn new(start: f64, end: f64, step: f64) -> Self {
1349        Self { start, end, step }
1350    }
1351
1352    /// Create a range with step size 1.0
1353    #[must_use]
1354    pub fn unit_step(start: f64, end: f64) -> Self {
1355        Self::new(start, end, 1.0)
1356    }
1357}
1358
1359impl RangeType for FloatRange {
1360    type IndexType = f64;
1361
1362    fn start(&self) -> Self::IndexType {
1363        self.start
1364    }
1365
1366    fn end(&self) -> Self::IndexType {
1367        self.end
1368    }
1369
1370    fn contains(&self, value: &Self::IndexType) -> bool {
1371        *value >= self.start && *value <= self.end
1372    }
1373
1374    fn len(&self) -> Self::IndexType {
1375        if self.end >= self.start && self.step > 0.0 {
1376            ((self.end - self.start) / self.step).floor() + 1.0
1377        } else {
1378            0.0
1379        }
1380    }
1381
1382    fn is_empty(&self) -> bool {
1383        self.end < self.start || self.step <= 0.0
1384    }
1385}
1386
/// Symbolic range with expression bounds
///
/// Represents ranges where the start and/or end are expressions rather than
/// concrete values. This enables symbolic manipulation of summation bounds.
/// Concrete bounds are obtained with `evaluate_bounds`, which evaluates both
/// expressions against a slice of variable values.
///
/// # Examples
///
/// ```rust
/// use mathcompile::final_tagless::{SymbolicRange, ASTRepr};
///
/// // Range from 1 to n (where n is a variable at index 0)
/// let range = SymbolicRange::new(
///     ASTRepr::Constant(1.0),
///     ASTRepr::Variable(0)
/// );
/// ```
#[derive(Debug, Clone)]
pub struct SymbolicRange<T> {
    /// Expression for the lower bound.
    pub start: Box<ASTRepr<T>>,
    /// Expression for the upper bound.
    pub end: Box<ASTRepr<T>>,
}
1408
1409impl<T: NumericType> SymbolicRange<T> {
1410    /// Create a new symbolic range
1411    pub fn new(start: ASTRepr<T>, end: ASTRepr<T>) -> Self {
1412        Self {
1413            start: Box::new(start),
1414            end: Box::new(end),
1415        }
1416    }
1417
1418    /// Create a range from 1 to a symbolic expression
1419    pub fn one_to_expr(end: ASTRepr<T>) -> Self
1420    where
1421        T: num_traits::One,
1422    {
1423        Self::new(ASTRepr::Constant(T::one()), end)
1424    }
1425
1426    /// Evaluate the range bounds with given variable values
1427    pub fn evaluate_bounds(&self, variables: &[T]) -> Option<(T, T)>
1428    where
1429        T: Float + Copy,
1430    {
1431        let start_val = DirectEval::eval_with_vars(&self.start, variables);
1432        let end_val = DirectEval::eval_with_vars(&self.end, variables);
1433        Some((start_val, end_val))
1434    }
1435}
1436
1437// Note: SymbolicRange doesn't implement RangeType because it requires evaluation
1438// to determine concrete bounds. It's used in a different way in the summation system.
1439
/// Trait for function-like expressions in summations
///
/// The function must not be opaque to enable factor extraction and algebraic
/// manipulation. This trait provides access to the function's internal structure:
/// the name of its index variable and its body expression.
pub trait SummandFunction<T>: Clone + std::fmt::Debug {
    /// The expression representing the function body
    type Body: Clone;

    /// The variable name for the summation index (e.g. `"i"`)
    fn index_var(&self) -> &str;

    /// Get the function body expression
    fn body(&self) -> &Self::Body;

    /// Apply the function to a specific index value (for evaluation).
    ///
    /// Returns a `Body` (an expression) with the index bound to `index`, not a number.
    fn apply(&self, index: T) -> Self::Body;

    /// Check if the function depends on the index variable
    fn depends_on_index(&self) -> bool;

    /// Extract factors that don't depend on the index variable
    /// Returns (`independent_factors`, `remaining_expression`)
    fn extract_independent_factors(&self) -> (Vec<Self::Body>, Self::Body);
}
1464
/// Concrete implementation for AST-based functions
///
/// This represents a function as an AST expression with a designated index variable.
/// It provides the foundation for algebraic manipulation of summands.
///
/// The index variable is referred to by name here, while the body refers to variables
/// by position (`ASTRepr::Variable(n)`); in the example below the name `"i"` corresponds
/// to `Variable(0)`.
///
/// # Examples
///
/// ```rust
/// use mathcompile::final_tagless::{ASTFunction, ASTRepr};
///
/// // Function f(i) = 2*i + 3 (where i is at index 0)
/// let func = ASTFunction::new(
///     "i",
///     ASTRepr::Add(
///         Box::new(ASTRepr::Mul(
///             Box::new(ASTRepr::Constant(2.0)),
///             Box::new(ASTRepr::Variable(0))
///         )),
///         Box::new(ASTRepr::Constant(3.0))
///     )
/// );
/// ```
#[derive(Debug, Clone)]
pub struct ASTFunction<T> {
    /// Name of the summation index variable.
    pub index_var: String,
    /// Expression computing the summand.
    pub body: ASTRepr<T>,
}
1492
1493impl<T: NumericType> ASTFunction<T> {
1494    /// Create a new AST-based function
1495    pub fn new(index_var: &str, body: ASTRepr<T>) -> Self {
1496        Self {
1497            index_var: index_var.to_string(),
1498            body,
1499        }
1500    }
1501
1502    /// Create a linear function: a*i + b
1503    pub fn linear(index_var: &str, coefficient: T, constant: T) -> Self {
1504        let body = ASTRepr::Add(
1505            Box::new(ASTRepr::Mul(
1506                Box::new(ASTRepr::Constant(coefficient)),
1507                Box::new(ASTRepr::Variable(0)), // Use index 0 for the index variable
1508            )),
1509            Box::new(ASTRepr::Constant(constant)),
1510        );
1511        Self::new(index_var, body)
1512    }
1513
1514    /// Create a power function: i^k
1515    pub fn power(index_var: &str, exponent: T) -> Self {
1516        let body = ASTRepr::Pow(
1517            Box::new(ASTRepr::Variable(0)), // Use index 0 for the index variable
1518            Box::new(ASTRepr::Constant(exponent)),
1519        );
1520        Self::new(index_var, body)
1521    }
1522
1523    /// Create a constant function (doesn't depend on index)
1524    pub fn constant_func(index_var: &str, value: T) -> Self {
1525        let body = ASTRepr::Constant(value);
1526        Self::new(index_var, body)
1527    }
1528}
1529
1530impl<T: NumericType + Float + Copy> SummandFunction<T> for ASTFunction<T> {
1531    type Body = ASTRepr<T>;
1532
1533    fn index_var(&self) -> &str {
1534        &self.index_var
1535    }
1536
1537    fn body(&self) -> &Self::Body {
1538        &self.body
1539    }
1540
1541    fn apply(&self, index: T) -> Self::Body {
1542        // Create a simple substitution - in a full implementation,
1543        // this would do proper variable substitution in the AST
1544        self.substitute_variable(&self.index_var, index)
1545    }
1546
1547    fn depends_on_index(&self) -> bool {
1548        self.contains_variable(&self.body, &self.index_var)
1549    }
1550
1551    fn extract_independent_factors(&self) -> (Vec<Self::Body>, Self::Body) {
1552        // Basic implementation - in practice, this would do sophisticated
1553        // algebraic analysis to extract factors
1554        self.extract_factors_recursive(&self.body)
1555    }
1556}
1557
1558impl<T: NumericType + Copy> ASTFunction<T> {
1559    /// Substitute a variable with a concrete value (simplified implementation)
1560    fn substitute_variable(&self, var_name: &str, value: T) -> ASTRepr<T> {
1561        self.substitute_in_expr(&self.body, var_name, value)
1562    }
1563
1564    /// Recursive variable substitution
1565    fn substitute_in_expr(&self, expr: &ASTRepr<T>, var_name: &str, value: T) -> ASTRepr<T> {
1566        match expr {
1567            ASTRepr::Constant(c) => ASTRepr::Constant(*c),
1568            ASTRepr::Variable(index) => {
1569                // Check if this variable index corresponds to our variable name
1570                let expected_index = match var_name {
1571                    "i" => 0,
1572                    "j" => 1,
1573                    "k" => 2,
1574                    "x" => 0,
1575                    "y" => 1,
1576                    "z" => 2,
1577                    _ => {
1578                        // Try to get from global registry
1579                        if let Some(idx) = get_variable_index(var_name) {
1580                            idx
1581                        } else {
1582                            // If not found, don't substitute
1583                            return expr.clone();
1584                        }
1585                    }
1586                };
1587
1588                if *index == expected_index {
1589                    ASTRepr::Constant(value)
1590                } else {
1591                    expr.clone()
1592                }
1593            }
1594            ASTRepr::Add(left, right) => ASTRepr::Add(
1595                Box::new(self.substitute_in_expr(left, var_name, value)),
1596                Box::new(self.substitute_in_expr(right, var_name, value)),
1597            ),
1598            ASTRepr::Sub(left, right) => ASTRepr::Sub(
1599                Box::new(self.substitute_in_expr(left, var_name, value)),
1600                Box::new(self.substitute_in_expr(right, var_name, value)),
1601            ),
1602            ASTRepr::Mul(left, right) => ASTRepr::Mul(
1603                Box::new(self.substitute_in_expr(left, var_name, value)),
1604                Box::new(self.substitute_in_expr(right, var_name, value)),
1605            ),
1606            ASTRepr::Div(left, right) => ASTRepr::Div(
1607                Box::new(self.substitute_in_expr(left, var_name, value)),
1608                Box::new(self.substitute_in_expr(right, var_name, value)),
1609            ),
1610            ASTRepr::Pow(base, exp) => ASTRepr::Pow(
1611                Box::new(self.substitute_in_expr(base, var_name, value)),
1612                Box::new(self.substitute_in_expr(exp, var_name, value)),
1613            ),
1614            ASTRepr::Neg(inner) => {
1615                ASTRepr::Neg(Box::new(self.substitute_in_expr(inner, var_name, value)))
1616            }
1617            ASTRepr::Ln(inner) => {
1618                ASTRepr::Ln(Box::new(self.substitute_in_expr(inner, var_name, value)))
1619            }
1620            ASTRepr::Exp(inner) => {
1621                ASTRepr::Exp(Box::new(self.substitute_in_expr(inner, var_name, value)))
1622            }
1623            ASTRepr::Sin(inner) => {
1624                ASTRepr::Sin(Box::new(self.substitute_in_expr(inner, var_name, value)))
1625            }
1626            ASTRepr::Cos(inner) => {
1627                ASTRepr::Cos(Box::new(self.substitute_in_expr(inner, var_name, value)))
1628            }
1629            ASTRepr::Sqrt(inner) => {
1630                ASTRepr::Sqrt(Box::new(self.substitute_in_expr(inner, var_name, value)))
1631            }
1632        }
1633    }
1634
1635    /// Check if an expression contains a variable by name
1636    /// Note: This only works for legacy named variables, not indexed variables
1637    fn contains_variable(&self, expr: &ASTRepr<T>, var_name: &str) -> bool {
1638        match expr {
1639            ASTRepr::Constant(_) => false,
1640            ASTRepr::Variable(index) => {
1641                // Check if this variable index corresponds to our variable name
1642                // For now, we use a simple mapping: "i" -> 0, "j" -> 1, etc.
1643                let expected_index = match var_name {
1644                    "i" => 0,
1645                    "j" => 1,
1646                    "k" => 2,
1647                    "x" => 0,
1648                    "y" => 1,
1649                    "z" => 2,
1650                    _ => {
1651                        // Try to get from global registry
1652                        if let Some(idx) = get_variable_index(var_name) {
1653                            idx
1654                        } else {
1655                            // If not found, assume it doesn't match
1656                            return false;
1657                        }
1658                    }
1659                };
1660                *index == expected_index
1661            }
1662            ASTRepr::Add(left, right)
1663            | ASTRepr::Sub(left, right)
1664            | ASTRepr::Mul(left, right)
1665            | ASTRepr::Div(left, right)
1666            | ASTRepr::Pow(left, right) => {
1667                self.contains_variable(left, var_name) || self.contains_variable(right, var_name)
1668            }
1669            ASTRepr::Neg(inner)
1670            | ASTRepr::Ln(inner)
1671            | ASTRepr::Exp(inner)
1672            | ASTRepr::Sin(inner)
1673            | ASTRepr::Cos(inner)
1674            | ASTRepr::Sqrt(inner) => self.contains_variable(inner, var_name),
1675        }
1676    }
1677
1678    /// Extract factors that don't depend on the index variable (simplified implementation)
1679    fn extract_factors_recursive(&self, expr: &ASTRepr<T>) -> (Vec<ASTRepr<T>>, ASTRepr<T>)
1680    where
1681        T: One,
1682    {
1683        match expr {
1684            // For multiplication, we can extract independent factors
1685            ASTRepr::Mul(left, right) => {
1686                let left_depends = self.contains_variable(left, &self.index_var);
1687                let right_depends = self.contains_variable(right, &self.index_var);
1688
1689                match (left_depends, right_depends) {
1690                    (false, false) => {
1691                        // Both factors are independent
1692                        (vec![expr.clone()], ASTRepr::Constant(T::one()))
1693                    }
1694                    (false, true) => {
1695                        // Left factor is independent
1696                        (vec![(**left).clone()], (**right).clone())
1697                    }
1698                    (true, false) => {
1699                        // Right factor is independent
1700                        (vec![(**right).clone()], (**left).clone())
1701                    }
1702                    (true, true) => {
1703                        // Both factors depend on index, can't extract
1704                        (vec![], expr.clone())
1705                    }
1706                }
1707            }
1708            // For other operations, basic handling
1709            _ => {
1710                if self.contains_variable(expr, &self.index_var) {
1711                    (vec![], expr.clone())
1712                } else {
1713                    (vec![expr.clone()], ASTRepr::Constant(T::one()))
1714                }
1715            }
1716        }
1717    }
1718}
1719
// Brings `num_traits::One` into scope for the `T: One` bound and `T::one()` calls in `extract_factors_recursive` above.
1721use num_traits::One;
1722
/// Extension trait for summation operations
///
/// This trait extends the final tagless approach to support summations with
/// algebraic manipulation capabilities. It provides methods for creating
/// various types of summations and will eventually support automatic simplification.
///
/// Like `MathExpr`, every argument and result is an interpreter representation
/// (`Self::Repr<_>`), so the same summation can be built for any implementing interpreter.
pub trait SummationExpr: MathExpr {
    /// Create a finite summation: Σ(i=start to end) f(i)
    ///
    /// This is the most general form of finite summation, where both the range
    /// and the function can be represented using any interpreter.
    /// `range` is any `RangeType`; `function` is the summand.
    fn sum_finite<T, R, F>(range: Self::Repr<R>, function: Self::Repr<F>) -> Self::Repr<T>
    where
        T: NumericType,
        R: RangeType,
        F: SummandFunction<T>,
        Self::Repr<T>: Clone;

    /// Create an infinite summation: Σ(i=start to ∞) f(i)
    ///
    /// For infinite summations, convergence analysis and special handling
    /// would be needed in a complete implementation.
    fn sum_infinite<T, F>(start: Self::Repr<T>, function: Self::Repr<F>) -> Self::Repr<T>
    where
        T: NumericType,
        F: SummandFunction<T>,
        Self::Repr<T>: Clone;

    /// Create a telescoping sum for automatic simplification
    ///
    /// Telescoping sums have the special property that consecutive terms cancel,
    /// allowing for closed-form evaluation: Σ(f(i+1) - f(i)) = f(end+1) - f(start)
    fn sum_telescoping<T, F>(range: Self::Repr<IntRange>, function: Self::Repr<F>) -> Self::Repr<T>
    where
        T: NumericType,
        F: SummandFunction<T>;

    /// Create a simple integer range for summations from `start` to `end`
    fn range_to<T: NumericType>(start: Self::Repr<T>, end: Self::Repr<T>) -> Self::Repr<IntRange>;

    /// Create a function representation for summands, with `index_var` naming the
    /// summation index used in `body`
    fn function<T: NumericType>(index_var: &str, body: Self::Repr<T>)
    -> Self::Repr<ASTFunction<T>>;
}
1766
1767// Extension to ASTRepr to support summation operations
1768impl<T> ASTRepr<T> {
1769    /// Add summation support to the AST representation
1770    ///
1771    /// These variants would be added to the enum in a complete implementation:
1772    /// - SumFinite(Box<`ASTRepr`<IntRange>>, Box<`ASTRepr`<`ASTFunction`<T>>>)
1773    /// - SumInfinite(Box<`ASTRepr`<T>>, Box<`ASTRepr`<`ASTFunction`<T>>>)
1774    /// - SumTelescoping(Box<`ASTRepr`<IntRange>>, Box<`ASTRepr`<`ASTFunction`<T>>>)
1775    /// - Range(i64, i64)
1776    /// - Function(String, Box<`ASTRepr`<T>>)
1777
1778    /// Placeholder for future summation operation counting
1779    pub fn count_summation_operations(&self) -> usize {
1780        // This would count summation-specific operations in addition to
1781        // the basic operations already counted by count_operations()
1782        0
1783    }
1784}
1785
#[cfg(test)]
mod tests {
    //! Unit tests for the final tagless interpreters, the summation
    //! infrastructure, the variable registries and the `ASTRepr` operators.

    use super::*;

    #[test]
    fn test_direct_eval() {
        // Interpreter bounds are declared once, in the generic parameter list;
        // repeating them in a `where` clause is redundant
        // (clippy::multiple_bound_locations).
        fn linear<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64> {
            E::add(E::mul(E::constant(2.0), x), E::constant(1.0))
        }

        let result = linear::<DirectEval>(DirectEval::var("x", 5.0));
        assert_eq!(result, 11.0); // 2*5 + 1 = 11
    }

    #[test]
    fn test_statistical_extension() {
        fn logistic_expr<E: StatisticalExpr>(x: E::Repr<f64>) -> E::Repr<f64> {
            E::logistic(x)
        }

        let result = logistic_expr::<DirectEval>(DirectEval::var("x", 0.0));
        assert!((result - 0.5).abs() < 1e-10); // logistic(0) = 0.5
    }

    #[test]
    fn test_pretty_print() {
        // Only the extra `Clone` requirement needs a `where` clause; the
        // `MathExpr` bound already lives on the generic parameter.
        fn quadratic<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
        where
            E::Repr<f64>: Clone,
        {
            let a = E::constant(2.0);
            let b = E::constant(3.0);
            let c = E::constant(1.0);

            E::add(
                E::add(E::mul(a, E::pow(x.clone(), E::constant(2.0))), E::mul(b, x)),
                c,
            )
        }

        let expr = quadratic::<PrettyPrint>(PrettyPrint::var("x"));
        assert!(expr.contains('x'));
        assert!(expr.contains('2'));
        assert!(expr.contains('3'));
        assert!(expr.contains('1'));
    }

    #[test]
    fn test_horner_polynomial() {
        // Test polynomial: 1 + 2x + 3x^2 at x = 2
        // Expected: 1 + 2(2) + 3(4) = 17
        let coeffs = [1.0, 2.0, 3.0];
        let x = DirectEval::var("x", 2.0);
        let result = polynomial::horner::<DirectEval, f64>(&coeffs, x);
        assert_eq!(result, 17.0);
    }

    #[test]
    fn test_horner_pretty_print() {
        let coeffs = [1.0, 2.0, 3.0];
        let x = PrettyPrint::var("x");
        let result = polynomial::horner::<PrettyPrint, f64>(&coeffs, x);
        assert!(result.contains('x'));
    }

    #[test]
    fn test_polynomial_from_roots() {
        // Polynomial with roots at 1 and 2: (x-1)(x-2) = x^2 - 3x + 2
        // At x=0: (0-1)(0-2) = 2
        let roots = [1.0, 2.0];
        let x = DirectEval::var("x", 0.0);
        let result = polynomial::from_roots::<DirectEval, f64>(&roots, x);
        assert_eq!(result, 2.0);

        // At x=3: (3-1)(3-2) = 2*1 = 2
        let x = DirectEval::var("x", 3.0);
        let result = polynomial::from_roots::<DirectEval, f64>(&roots, x);
        assert_eq!(result, 2.0);
    }

    #[test]
    fn test_division_operations() {
        let div_1_3: f64 = DirectEval::div(DirectEval::constant(1.0), DirectEval::constant(3.0));
        assert!((div_1_3 - 1.0 / 3.0).abs() < 1e-10);

        let div_10_2: f64 = DirectEval::div(DirectEval::constant(10.0), DirectEval::constant(2.0));
        assert!((div_10_2 - 5.0).abs() < 1e-10);

        // Test division by one
        let div_by_one: f64 =
            DirectEval::div(DirectEval::constant(42.0), DirectEval::constant(1.0));
        assert!((div_by_one - 42.0).abs() < 1e-10);
    }

    #[test]
    fn test_transcendental_functions() {
        // Test natural logarithm
        let ln_e: f64 = DirectEval::ln(DirectEval::constant(std::f64::consts::E));
        assert!((ln_e - 1.0).abs() < 1e-10);

        // Test exponential
        let exp_1: f64 = DirectEval::exp(DirectEval::constant(1.0));
        assert!((exp_1 - std::f64::consts::E).abs() < 1e-10);

        // Test square root
        let sqrt_4: f64 = DirectEval::sqrt(DirectEval::constant(4.0));
        assert!((sqrt_4 - 2.0).abs() < 1e-10);

        // Test sine
        let sin_pi_2: f64 = DirectEval::sin(DirectEval::constant(std::f64::consts::PI / 2.0));
        assert!((sin_pi_2 - 1.0).abs() < 1e-10);

        // Test cosine
        let cos_0: f64 = DirectEval::cos(DirectEval::constant(0.0));
        assert!((cos_0 - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pretty_print_basic() {
        // Test variable creation
        let var_x = PrettyPrint::var("x");
        assert_eq!(var_x, "x");

        // Test constant creation
        let const_5 = PrettyPrint::constant::<f64>(5.0);
        assert_eq!(const_5, "5");

        // Test addition
        let add_expr =
            PrettyPrint::add::<f64, f64, f64>(PrettyPrint::var("x"), PrettyPrint::constant(1.0));
        assert_eq!(add_expr, "(x + 1)");
    }

    #[test]
    fn test_efficient_variable_indexing() {
        // Test efficient index-based variables
        let expr = ASTRepr::Add(
            Box::new(ASTRepr::Variable(0)), // x
            Box::new(ASTRepr::Variable(1)), // y
        );
        let result = DirectEval::eval_with_vars(&expr, &[2.0, 3.0]);
        assert_eq!(result, 5.0);

        // Test multiplication with index-based variables
        let expr = ASTRepr::Mul(
            Box::new(ASTRepr::Variable(0)), // x
            Box::new(ASTRepr::Variable(1)), // y
        );
        let result = DirectEval::eval_with_vars(&expr, &[4.0, 5.0]);
        assert_eq!(result, 20.0);
    }

    #[test]
    fn test_mixed_variable_types() {
        // NOTE(review): despite its name this test only exercises index-based
        // variables; mixing named and indexed access is covered by
        // `test_mixed_variable_access`.
        let expr = ASTRepr::Add(
            Box::new(ASTRepr::Variable(0)), // x
            Box::new(ASTRepr::Variable(1)), // y
        );
        let result = DirectEval::eval_with_vars(&expr, &[2.0, 3.0]);
        assert_eq!(result, 5.0);
    }

    #[test]
    fn test_variable_index_access() {
        let expr: ASTRepr<f64> = ASTRepr::Variable(5);
        assert_eq!(expr.variable_index(), Some(5));

        let expr: ASTRepr<f64> = ASTRepr::Constant(42.0);
        assert_eq!(expr.variable_index(), None);
    }

    #[test]
    fn test_out_of_bounds_variable_index() {
        // Test behavior when variable index is out of bounds
        let expr = ASTRepr::Variable(10); // Index 10, but only 2 variables provided
        let result = DirectEval::eval_with_vars(&expr, &[1.0, 2.0]);
        assert_eq!(result, 0.0); // Should return zero for out-of-bounds index
    }

    // ============================================================================
    // Summation Infrastructure Tests
    // ============================================================================

    #[test]
    fn test_int_range() {
        let range = IntRange::new(1, 10);
        assert_eq!(range.start(), 1);
        assert_eq!(range.end(), 10);
        assert_eq!(range.len(), 10);
        assert!(range.contains(&5));
        assert!(!range.contains(&15));
        assert!(!range.is_empty());

        let empty_range = IntRange::new(5, 3);
        assert!(empty_range.is_empty());
        assert_eq!(empty_range.len(), 0);
    }

    #[test]
    fn test_float_range() {
        let range = FloatRange::new(1.0, 10.0, 1.0);
        assert_eq!(range.start(), 1.0);
        assert_eq!(range.end(), 10.0);
        assert_eq!(range.len(), 10.0);
        assert!(range.contains(&5.5));
        assert!(!range.contains(&15.0));

        let empty_range = FloatRange::new(5.0, 3.0, 1.0);
        assert!(empty_range.is_empty());
    }

    #[test]
    fn test_symbolic_range() {
        // Test with index-based variable (more reliable for evaluation)
        let range = SymbolicRange::new(
            ASTRepr::Constant(1.0),
            ASTRepr::Variable(0), // First variable in the array
        );

        // Test evaluation with variable at index 0 = 10
        let bounds = range.evaluate_bounds(&[10.0]);
        assert_eq!(bounds, Some((1.0, 10.0)));

        // Test with both bounds as variables
        let range2 = SymbolicRange::new(
            ASTRepr::Variable(0), // Start from first variable
            ASTRepr::Variable(1), // End at second variable
        );

        let bounds2 = range2.evaluate_bounds(&[2.0, 8.0]);
        assert_eq!(bounds2, Some((2.0, 8.0)));
    }

    #[test]
    fn test_ast_function_creation() {
        // Test linear function: 2*i + 3 (where i is at index 0)
        let func = ASTFunction::linear("i", 2.0, 3.0);
        assert_eq!(func.index_var(), "i");
        assert!(func.depends_on_index());

        // Test constant function
        let const_func = ASTFunction::constant_func("i", 42.0);
        assert!(!const_func.depends_on_index());
    }

    #[test]
    fn test_ast_function_substitution() {
        // Test function application: f(i) = 2*i + 3, evaluate at i = 5
        let func = ASTFunction::linear("i", 2.0, 3.0);
        let result = func.apply(5.0);

        // The result should be a constant expression with value 13.0
        let evaluated = DirectEval::eval_with_vars(&result, &[]);
        assert_eq!(evaluated, 13.0); // 2*5 + 3 = 13
    }

    #[test]
    fn test_ast_function_factor_extraction() {
        // Test factor extraction for: 3 * i (using indexed variable)
        let func = ASTFunction::new(
            "i",
            ASTRepr::Mul(
                Box::new(ASTRepr::Constant(3.0)),
                Box::new(ASTRepr::Variable(0)), // Use index 0 for variable i
            ),
        );

        let (factors, remaining) = func.extract_independent_factors();
        assert_eq!(factors.len(), 1); // Should extract the constant factor 3

        // Verify the extracted factor
        if let Some(ASTRepr::Constant(value)) = factors.first() {
            assert_eq!(*value, 3.0);
        } else {
            panic!("Expected constant factor");
        }
    }

    #[test]
    fn test_range_convenience_methods() {
        let range_1_to_n = IntRange::one_to_n(10);
        assert_eq!(range_1_to_n.start(), 1);
        assert_eq!(range_1_to_n.end(), 10);

        let range_0_to_n_minus_1 = IntRange::zero_to_n_minus_one(10);
        assert_eq!(range_0_to_n_minus_1.start(), 0);
        assert_eq!(range_0_to_n_minus_1.end(), 9);
    }

    #[test]
    fn test_power_function() {
        // Test power function: i^2
        let func = ASTFunction::power("i", 2.0);
        assert!(func.depends_on_index());

        // Test evaluation at i = 3 (should give 9)
        let result = func.apply(3.0);
        let evaluated = DirectEval::eval_with_vars(&result, &[]);
        assert_eq!(evaluated, 9.0);
    }

    #[test]
    fn test_variable_registry() {
        // Use ExpressionBuilder instead of global registry
        let mut builder = ExpressionBuilder::new();

        // Test variable registration
        let x_index = builder.register_variable("x");
        let y_index = builder.register_variable("y");
        let x_index_again = builder.register_variable("x"); // Should return same index

        // Check that indices are different for different variables
        assert_ne!(x_index, y_index);
        // Check that same variable returns same index
        assert_eq!(x_index_again, x_index);

        // Test lookups
        assert_eq!(builder.get_variable_index("x"), Some(x_index));
        assert_eq!(builder.get_variable_index("y"), Some(y_index));
        assert_eq!(builder.get_variable_index("z"), None);

        assert_eq!(builder.get_variable_name(x_index), Some("x"));
        assert_eq!(builder.get_variable_name(y_index), Some("y"));
        // Test an index that shouldn't exist
        let max_index = std::cmp::max(x_index, y_index);
        assert_eq!(builder.get_variable_name(max_index + 10), None);
    }

    #[test]
    fn test_named_variable_evaluation() {
        // Use ExpressionBuilder instead of global registry
        let mut builder = ExpressionBuilder::new();

        // Create expression using string-based variables
        let expr = ASTRepr::Add(Box::new(builder.var("x")), Box::new(builder.var("y")));

        // Evaluate with named variables using the builder
        let named_vars = vec![("x".to_string(), 3.0), ("y".to_string(), 4.0)];
        let result = builder.eval_with_named_vars(&expr, &named_vars);
        assert_eq!(result, 7.0);
    }

    #[test]
    fn test_mixed_variable_access() {
        // Use ExpressionBuilder instead of global registry
        let mut builder = ExpressionBuilder::new();

        // Register variables and get their indices
        let x_idx = builder.register_variable("x");
        let y_idx = builder.register_variable("y");

        // Create expression using indices directly (performance path)
        let expr = ASTRepr::Mul(
            Box::new(ASTRepr::Variable(x_idx)),
            Box::new(ASTRepr::Variable(y_idx)),
        );

        // Evaluate with indexed variables (fastest)
        let result1 = builder.eval_with_vars(&expr, &[2.0, 5.0]);
        assert_eq!(result1, 10.0);

        // Evaluate with named variables (convenient)
        let named_vars = vec![("x".to_string(), 2.0), ("y".to_string(), 5.0)];
        let result2 = builder.eval_with_named_vars(&expr, &named_vars);
        assert_eq!(result2, 10.0);
    }

    #[test]
    fn test_variable_registry_performance() {
        // Use ExpressionBuilder instead of global registry
        let mut builder = ExpressionBuilder::new();

        // Get the starting state (should be empty for new builder)
        let start_count = builder.num_variables();
        assert_eq!(start_count, 0); // Should start empty

        // Register many variables to test performance
        let mut indices = Vec::new();
        for i in 0..1000 {
            let var_name = format!("perf_test_var_{i}");
            let index = builder.register_variable(&var_name);
            indices.push(index);
            assert_eq!(index, i); // Should get sequential indices starting from 0
        }

        // Test lookups are fast
        for i in 0..1000 {
            let var_name = format!("perf_test_var_{i}");
            let found_index = builder.get_variable_index(&var_name);
            assert_eq!(found_index, Some(i));

            let found_name = builder.get_variable_name(i);
            assert_eq!(found_name, Some(var_name.as_str()));
        }

        // Test that we have exactly 1000 variables registered
        let final_count = builder.num_variables();
        assert_eq!(final_count, 1000);
    }

    #[test]
    fn test_generic_operator_overloading() {
        // Test with f64
        let x_f64 = ASTRepr::<f64>::Variable(0);
        let y_f64 = ASTRepr::<f64>::Variable(1);
        let const_f64 = ASTRepr::<f64>::Constant(2.5);

        let expr_f64 = &x_f64 + &y_f64 * &const_f64;
        assert_eq!(expr_f64.count_operations(), 2); // one add, one mul

        // Test with f32
        let x_f32 = ASTRepr::<f32>::Variable(0);
        let y_f32 = ASTRepr::<f32>::Variable(1);
        let const_f32 = ASTRepr::<f32>::Constant(2.5_f32);

        let expr_f32 = &x_f32 + &y_f32 * &const_f32;
        assert_eq!(expr_f32.count_operations(), 2); // one add, one mul

        // Test negation
        let neg_f64 = -&x_f64;
        let neg_f32 = -&x_f32;

        match neg_f64 {
            ASTRepr::Neg(_) => {}
            _ => panic!("Expected negation"),
        }

        match neg_f32 {
            ASTRepr::Neg(_) => {}
            _ => panic!("Expected negation"),
        }

        // Test transcendental functions (require Float trait)
        let sin_f64 = x_f64.sin();
        let exp_f32 = x_f32.exp();

        match sin_f64 {
            ASTRepr::Sin(_) => {}
            _ => panic!("Expected sine"),
        }

        match exp_f32 {
            ASTRepr::Exp(_) => {}
            _ => panic!("Expected exponential"),
        }
    }
}
2241
/// Bidirectional mapping between variable names and dense `usize` indices.
///
/// Indices are assigned sequentially in registration order, starting at 0, so
/// every registered index is a valid position in a vector of length
/// [`len`](Self::len). This lets user-facing code refer to variables by name
/// while evaluation works on index-addressed value slices.
///
/// The type itself is an ordinary value, not a global: the process-wide
/// instance lives behind [`global_registry`], and every `ExpressionBuilder`
/// owns a private one.
#[derive(Debug, Clone, Default)]
pub struct VariableRegistry {
    /// Mapping from variable names to indices
    name_to_index: HashMap<String, usize>,
    /// Mapping from indices to variable names (`index_to_name[i]` names variable `i`)
    index_to_name: Vec<String>,
}

impl VariableRegistry {
    /// Create a new empty variable registry
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Register a variable name and return its index.
    ///
    /// If the variable already exists, its existing index is returned and the
    /// registry is left unchanged. Otherwise the name receives the next free
    /// index, which equals the number of previously registered variables.
    pub fn register_variable(&mut self, name: &str) -> usize {
        if let Some(&index) = self.name_to_index.get(name) {
            return index;
        }
        let index = self.index_to_name.len();
        self.name_to_index.insert(name.to_string(), index);
        self.index_to_name.push(name.to_string());
        index
    }

    /// Get the index for a variable name, or `None` if it was never registered
    #[must_use]
    pub fn get_index(&self, name: &str) -> Option<usize> {
        self.name_to_index.get(name).copied()
    }

    /// Get the name for a variable index, or `None` if the index is unassigned
    #[must_use]
    pub fn get_name(&self, index: usize) -> Option<&str> {
        self.index_to_name.get(index).map(String::as_str)
    }

    /// Get all registered variable names, in index order
    #[must_use]
    pub fn get_all_names(&self) -> &[String] {
        &self.index_to_name
    }

    /// Get the number of registered variables
    #[must_use]
    pub fn len(&self) -> usize {
        self.index_to_name.len()
    }

    /// Check if the registry is empty
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.index_to_name.is_empty()
    }

    /// Remove all registered variables.
    ///
    /// Afterwards indices are handed out from 0 again, so indices obtained
    /// before the call may be reassigned to different names.
    pub fn clear(&mut self) {
        self.name_to_index.clear();
        self.index_to_name.clear();
    }

    /// Build an index-addressed value vector from `(name, value)` pairs.
    ///
    /// The result has exactly one slot per registered variable, in index
    /// order, ready to pass to `DirectEval::eval_with_vars`. Slots for
    /// variables missing from `values` are `0.0`; names that are not
    /// registered are ignored; if a name appears more than once, the last
    /// value wins.
    #[must_use]
    pub fn create_variable_map(&self, values: &[(String, f64)]) -> Vec<f64> {
        let mut result = vec![0.0; self.len()];
        for (name, value) in values {
            if let Some(index) = self.get_index(name) {
                // Registered indices are always < len(), so this cannot panic.
                result[index] = *value;
            }
        }
        result
    }

    /// Build a value vector from values given in registration (index) order.
    ///
    /// The result always has exactly [`len`](Self::len) entries: surplus
    /// values are ignored and missing trailing values are filled with `0.0`.
    #[must_use]
    pub fn create_ordered_variable_map(&self, values: &[f64]) -> Vec<f64> {
        let mut result = vec![0.0; self.len()];
        // Copy the overlapping prefix in one step instead of an indexed,
        // per-element bounds-checked loop.
        let shared = result.len().min(values.len());
        result[..shared].copy_from_slice(&values[..shared]);
        result
    }
}
2343        Self::new()
2344    }
2345}
2346
2347/// Thread-safe global variable registry
2348static GLOBAL_REGISTRY: std::sync::LazyLock<Arc<RwLock<VariableRegistry>>> =
2349    std::sync::LazyLock::new(|| Arc::new(RwLock::new(VariableRegistry::new())));
2350
2351/// Get a reference to the global variable registry
2352pub fn global_registry() -> Arc<RwLock<VariableRegistry>> {
2353    GLOBAL_REGISTRY.clone()
2354}
2355
2356/// Convenience function to register a variable globally and get its index
2357#[must_use]
2358pub fn register_variable(name: &str) -> usize {
2359    let registry = global_registry();
2360    let mut guard = registry.write().unwrap();
2361    guard.register_variable(name)
2362}
2363
2364/// Convenience function to get a variable index from the global registry
2365#[must_use]
2366pub fn get_variable_index(name: &str) -> Option<usize> {
2367    let registry = global_registry();
2368    let guard = registry.read().unwrap();
2369    guard.get_index(name)
2370}
2371
2372/// Convenience function to get a variable name from the global registry
2373#[must_use]
2374pub fn get_variable_name(index: usize) -> Option<String> {
2375    let registry = global_registry();
2376    let guard = registry.read().unwrap();
2377    guard.get_name(index).map(std::string::ToString::to_string)
2378}
2379
2380/// Convenience function to create a variable map for evaluation
2381#[must_use]
2382pub fn create_variable_map(values: &[(String, f64)]) -> Vec<f64> {
2383    let registry = global_registry();
2384    let guard = registry.read().unwrap();
2385    guard.create_variable_map(values)
2386}
2387
2388/// Clear the global variable registry (useful for testing)
2389pub fn clear_global_registry() {
2390    let registry = global_registry();
2391    let mut guard = registry.write().unwrap();
2392    guard.clear();
2393}
2394
2395/// Expression builder that maintains its own variable registry
2396/// This provides a clean API for building expressions with named variables
2397/// while using efficient indices internally.
2398#[derive(Debug, Clone)]
2399pub struct ExpressionBuilder {
2400    registry: VariableRegistry,
2401}
2402
2403impl ExpressionBuilder {
2404    /// Create a new expression builder with an empty variable registry
2405    #[must_use]
2406    pub fn new() -> Self {
2407        Self {
2408            registry: VariableRegistry::new(),
2409        }
2410    }
2411
2412    /// Register a variable and return its index
2413    pub fn register_variable(&mut self, name: &str) -> usize {
2414        self.registry.register_variable(name)
2415    }
2416
2417    /// Create a variable expression by name (registers automatically)
2418    pub fn var(&mut self, name: &str) -> ASTRepr<f64> {
2419        let index = self.register_variable(name);
2420        ASTRepr::Variable(index)
2421    }
2422
2423    /// Create a variable expression by index (for performance)
2424    #[must_use]
2425    pub fn var_by_index(&self, index: usize) -> ASTRepr<f64> {
2426        ASTRepr::Variable(index)
2427    }
2428
2429    /// Create a constant expression
2430    #[must_use]
2431    pub fn constant(&self, value: f64) -> ASTRepr<f64> {
2432        ASTRepr::Constant(value)
2433    }
2434
2435    /// Get the variable registry (for evaluation)
2436    #[must_use]
2437    pub fn registry(&self) -> &VariableRegistry {
2438        &self.registry
2439    }
2440
2441    /// Get a mutable reference to the variable registry
2442    pub fn registry_mut(&mut self) -> &mut VariableRegistry {
2443        &mut self.registry
2444    }
2445
2446    /// Evaluate an expression with named variables
2447    #[must_use]
2448    pub fn eval_with_named_vars(&self, expr: &ASTRepr<f64>, named_vars: &[(String, f64)]) -> f64 {
2449        let var_array = self.registry.create_variable_map(named_vars);
2450        DirectEval::eval_with_vars(expr, &var_array)
2451    }
2452
2453    /// Evaluate an expression with indexed variables (most efficient)
2454    #[must_use]
2455    pub fn eval_with_vars(&self, expr: &ASTRepr<f64>, variables: &[f64]) -> f64 {
2456        DirectEval::eval_with_vars(expr, variables)
2457    }
2458
2459    /// Get the number of registered variables
2460    #[must_use]
2461    pub fn num_variables(&self) -> usize {
2462        self.registry.len()
2463    }
2464
2465    /// Get all variable names in registration order
2466    #[must_use]
2467    pub fn variable_names(&self) -> &[String] {
2468        self.registry.get_all_names()
2469    }
2470
2471    /// Get the index of a variable by name
2472    #[must_use]
2473    pub fn get_variable_index(&self, name: &str) -> Option<usize> {
2474        self.registry.get_index(name)
2475    }
2476
2477    /// Get the name of a variable by index
2478    #[must_use]
2479    pub fn get_variable_name(&self, index: usize) -> Option<&str> {
2480        self.registry.get_name(index)
2481    }
2482}
2483
2484impl Default for ExpressionBuilder {
2485    fn default() -> Self {
2486        Self::new()
2487    }
2488}
2489
2490// ============================================================================
2491// Generic Operator Overloading for ASTRepr<T>
2492// ============================================================================
2493
2494/// Addition operator overloading for `ASTRepr<T>`
2495impl<T> Add for ASTRepr<T>
2496where
2497    T: NumericType + Add<Output = T>,
2498{
2499    type Output = ASTRepr<T>;
2500
2501    fn add(self, rhs: Self) -> Self::Output {
2502        ASTRepr::Add(Box::new(self), Box::new(rhs))
2503    }
2504}
2505
2506/// Addition with references
2507impl<T> Add<&ASTRepr<T>> for &ASTRepr<T>
2508where
2509    T: NumericType + Add<Output = T>,
2510{
2511    type Output = ASTRepr<T>;
2512
2513    fn add(self, rhs: &ASTRepr<T>) -> Self::Output {
2514        ASTRepr::Add(Box::new(self.clone()), Box::new(rhs.clone()))
2515    }
2516}
2517
2518/// Addition with mixed references
2519impl<T> Add<ASTRepr<T>> for &ASTRepr<T>
2520where
2521    T: NumericType + Add<Output = T>,
2522{
2523    type Output = ASTRepr<T>;
2524
2525    fn add(self, rhs: ASTRepr<T>) -> Self::Output {
2526        ASTRepr::Add(Box::new(self.clone()), Box::new(rhs))
2527    }
2528}
2529
2530impl<T> Add<&ASTRepr<T>> for ASTRepr<T>
2531where
2532    T: NumericType + Add<Output = T>,
2533{
2534    type Output = ASTRepr<T>;
2535
2536    fn add(self, rhs: &ASTRepr<T>) -> Self::Output {
2537        ASTRepr::Add(Box::new(self), Box::new(rhs.clone()))
2538    }
2539}
2540
2541/// Subtraction operator overloading for `ASTRepr<T>`
2542impl<T> Sub for ASTRepr<T>
2543where
2544    T: NumericType + Sub<Output = T>,
2545{
2546    type Output = ASTRepr<T>;
2547
2548    fn sub(self, rhs: Self) -> Self::Output {
2549        ASTRepr::Sub(Box::new(self), Box::new(rhs))
2550    }
2551}
2552
2553/// Subtraction with references
2554impl<T> Sub<&ASTRepr<T>> for &ASTRepr<T>
2555where
2556    T: NumericType + Sub<Output = T>,
2557{
2558    type Output = ASTRepr<T>;
2559
2560    fn sub(self, rhs: &ASTRepr<T>) -> Self::Output {
2561        ASTRepr::Sub(Box::new(self.clone()), Box::new(rhs.clone()))
2562    }
2563}
2564
2565/// Subtraction with mixed references
2566impl<T> Sub<ASTRepr<T>> for &ASTRepr<T>
2567where
2568    T: NumericType + Sub<Output = T>,
2569{
2570    type Output = ASTRepr<T>;
2571
2572    fn sub(self, rhs: ASTRepr<T>) -> Self::Output {
2573        ASTRepr::Sub(Box::new(self.clone()), Box::new(rhs))
2574    }
2575}
2576
2577impl<T> Sub<&ASTRepr<T>> for ASTRepr<T>
2578where
2579    T: NumericType + Sub<Output = T>,
2580{
2581    type Output = ASTRepr<T>;
2582
2583    fn sub(self, rhs: &ASTRepr<T>) -> Self::Output {
2584        ASTRepr::Sub(Box::new(self), Box::new(rhs.clone()))
2585    }
2586}
2587
2588/// Multiplication operator overloading for `ASTRepr<T>`
2589impl<T> Mul for ASTRepr<T>
2590where
2591    T: NumericType + Mul<Output = T>,
2592{
2593    type Output = ASTRepr<T>;
2594
2595    fn mul(self, rhs: Self) -> Self::Output {
2596        ASTRepr::Mul(Box::new(self), Box::new(rhs))
2597    }
2598}
2599
2600/// Multiplication with references
2601impl<T> Mul<&ASTRepr<T>> for &ASTRepr<T>
2602where
2603    T: NumericType + Mul<Output = T>,
2604{
2605    type Output = ASTRepr<T>;
2606
2607    fn mul(self, rhs: &ASTRepr<T>) -> Self::Output {
2608        ASTRepr::Mul(Box::new(self.clone()), Box::new(rhs.clone()))
2609    }
2610}
2611
2612/// Multiplication with mixed references
2613impl<T> Mul<ASTRepr<T>> for &ASTRepr<T>
2614where
2615    T: NumericType + Mul<Output = T>,
2616{
2617    type Output = ASTRepr<T>;
2618
2619    fn mul(self, rhs: ASTRepr<T>) -> Self::Output {
2620        ASTRepr::Mul(Box::new(self.clone()), Box::new(rhs))
2621    }
2622}
2623
2624impl<T> Mul<&ASTRepr<T>> for ASTRepr<T>
2625where
2626    T: NumericType + Mul<Output = T>,
2627{
2628    type Output = ASTRepr<T>;
2629
2630    fn mul(self, rhs: &ASTRepr<T>) -> Self::Output {
2631        ASTRepr::Mul(Box::new(self), Box::new(rhs.clone()))
2632    }
2633}
2634
2635/// Division operator overloading for `ASTRepr<T>`
2636impl<T> Div for ASTRepr<T>
2637where
2638    T: NumericType + Div<Output = T>,
2639{
2640    type Output = ASTRepr<T>;
2641
2642    fn div(self, rhs: Self) -> Self::Output {
2643        ASTRepr::Div(Box::new(self), Box::new(rhs))
2644    }
2645}
2646
2647/// Division with references
2648impl<T> Div<&ASTRepr<T>> for &ASTRepr<T>
2649where
2650    T: NumericType + Div<Output = T>,
2651{
2652    type Output = ASTRepr<T>;
2653
2654    fn div(self, rhs: &ASTRepr<T>) -> Self::Output {
2655        ASTRepr::Div(Box::new(self.clone()), Box::new(rhs.clone()))
2656    }
2657}
2658
2659/// Division with mixed references
2660impl<T> Div<ASTRepr<T>> for &ASTRepr<T>
2661where
2662    T: NumericType + Div<Output = T>,
2663{
2664    type Output = ASTRepr<T>;
2665
2666    fn div(self, rhs: ASTRepr<T>) -> Self::Output {
2667        ASTRepr::Div(Box::new(self.clone()), Box::new(rhs))
2668    }
2669}
2670
2671impl<T> Div<&ASTRepr<T>> for ASTRepr<T>
2672where
2673    T: NumericType + Div<Output = T>,
2674{
2675    type Output = ASTRepr<T>;
2676
2677    fn div(self, rhs: &ASTRepr<T>) -> Self::Output {
2678        ASTRepr::Div(Box::new(self), Box::new(rhs.clone()))
2679    }
2680}
2681
2682/// Negation operator overloading for `ASTRepr<T>`
2683impl<T> Neg for ASTRepr<T>
2684where
2685    T: NumericType + Neg<Output = T>,
2686{
2687    type Output = ASTRepr<T>;
2688
2689    fn neg(self) -> Self::Output {
2690        ASTRepr::Neg(Box::new(self))
2691    }
2692}
2693
2694/// Negation with references
2695impl<T> Neg for &ASTRepr<T>
2696where
2697    T: NumericType + Neg<Output = T>,
2698{
2699    type Output = ASTRepr<T>;
2700
2701    fn neg(self) -> Self::Output {
2702        ASTRepr::Neg(Box::new(self.clone()))
2703    }
2704}
2705
2706/// Additional convenience methods for `ASTRepr<T>` with generic types
2707impl<T> ASTRepr<T>
2708where
2709    T: NumericType,
2710{
2711    /// Power operation with natural syntax
2712    #[must_use]
2713    pub fn pow(self, exp: ASTRepr<T>) -> ASTRepr<T>
2714    where
2715        T: Float,
2716    {
2717        ASTRepr::Pow(Box::new(self), Box::new(exp))
2718    }
2719
2720    /// Power operation with reference
2721    #[must_use]
2722    pub fn pow_ref(&self, exp: &ASTRepr<T>) -> ASTRepr<T>
2723    where
2724        T: Float,
2725    {
2726        ASTRepr::Pow(Box::new(self.clone()), Box::new(exp.clone()))
2727    }
2728
2729    /// Natural logarithm
2730    #[must_use]
2731    pub fn ln(self) -> ASTRepr<T>
2732    where
2733        T: Float,
2734    {
2735        ASTRepr::Ln(Box::new(self))
2736    }
2737
2738    /// Natural logarithm with reference
2739    #[must_use]
2740    pub fn ln_ref(&self) -> ASTRepr<T>
2741    where
2742        T: Float,
2743    {
2744        ASTRepr::Ln(Box::new(self.clone()))
2745    }
2746
2747    /// Exponential function
2748    #[must_use]
2749    pub fn exp(self) -> ASTRepr<T>
2750    where
2751        T: Float,
2752    {
2753        ASTRepr::Exp(Box::new(self))
2754    }
2755
2756    /// Exponential function with reference
2757    #[must_use]
2758    pub fn exp_ref(&self) -> ASTRepr<T>
2759    where
2760        T: Float,
2761    {
2762        ASTRepr::Exp(Box::new(self.clone()))
2763    }
2764
2765    /// Square root
2766    #[must_use]
2767    pub fn sqrt(self) -> ASTRepr<T>
2768    where
2769        T: Float,
2770    {
2771        ASTRepr::Sqrt(Box::new(self))
2772    }
2773
2774    /// Square root with reference
2775    #[must_use]
2776    pub fn sqrt_ref(&self) -> ASTRepr<T>
2777    where
2778        T: Float,
2779    {
2780        ASTRepr::Sqrt(Box::new(self.clone()))
2781    }
2782
2783    /// Sine function
2784    #[must_use]
2785    pub fn sin(self) -> ASTRepr<T>
2786    where
2787        T: Float,
2788    {
2789        ASTRepr::Sin(Box::new(self))
2790    }
2791
2792    /// Sine function with reference
2793    #[must_use]
2794    pub fn sin_ref(&self) -> ASTRepr<T>
2795    where
2796        T: Float,
2797    {
2798        ASTRepr::Sin(Box::new(self.clone()))
2799    }
2800
2801    /// Cosine function
2802    #[must_use]
2803    pub fn cos(self) -> ASTRepr<T>
2804    where
2805        T: Float,
2806    {
2807        ASTRepr::Cos(Box::new(self))
2808    }
2809
2810    /// Cosine function with reference
2811    #[must_use]
2812    pub fn cos_ref(&self) -> ASTRepr<T>
2813    where
2814        T: Float,
2815    {
2816        ASTRepr::Cos(Box::new(self.clone()))
2817    }
2818}