Skip to main content

ries_rs/
eval.rs

1//! Expression evaluation with automatic differentiation
2//!
3//! Evaluates postfix expressions and computes derivatives using forward-mode AD.
4//!
5//! # Performance
6//!
7//! For hot loops (generation, Newton-Raphson), use `evaluate_with_workspace()` with
8//! a reusable `EvalWorkspace` to avoid heap allocations on every call.
9
10use crate::expr::Expression;
11use crate::profile::UserConstant;
12use crate::symbol::{NumType, Seft, Symbol};
13use crate::udf::{UdfOp, UserFunction};
14
15/// Result of evaluating an expression
16#[derive(Debug, Clone, Copy)]
17pub struct EvalResult {
18    /// The computed value
19    pub value: f64,
20    /// Derivative with respect to x
21    pub derivative: f64,
22    /// Numeric type of the result
23    pub num_type: NumType,
24}
25
26/// Evaluation error types
27///
28/// These errors indicate what went wrong during expression evaluation.
29/// For more detailed context, use the error message methods.
30#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
31pub enum EvalError {
32    /// Stack underflow during evaluation
33    #[error("Stack underflow: not enough operands on stack")]
34    StackUnderflow,
35    /// User constant slot referenced by the expression is not configured
36    #[error("Missing user constant: slot u{0} is not configured")]
37    MissingUserConstant(usize),
38    /// Division by zero
39    #[error("Division by zero: divisor was zero or near-zero")]
40    DivisionByZero,
41    /// Logarithm of non-positive number
42    #[error("Logarithm domain error: argument was non-positive")]
43    LogDomain,
44    /// Square root of negative number
45    #[error("Square root domain error: argument was negative")]
46    SqrtDomain,
47    /// Overflow or NaN result
48    #[error("Overflow: result is infinite or NaN")]
49    Overflow,
50    /// Invalid expression
51    #[error("Invalid expression: malformed or incomplete")]
52    Invalid,
53    /// Error with position context
54    #[error("{err} at position {pos}")]
55    WithPosition {
56        #[source]
57        err: Box<EvalError>,
58        pos: usize,
59    },
60    /// Error with value context
61    #[error("{err} (value: {val})")]
62    WithValue {
63        #[source]
64        err: Box<EvalError>,
65        val: ordered_float::OrderedFloat<f64>,
66    },
67    /// Error with expression context
68    #[error("{err} in expression '{expr}'")]
69    WithExpression {
70        #[source]
71        err: Box<EvalError>,
72        expr: String,
73    },
74}
75
76impl EvalError {
77    /// Create a detailed error message with context (backward compatibility)
78    pub fn with_context(self, position: Option<usize>, value: Option<f64>) -> Self {
79        let mut err = self;
80        if let Some(pos) = position {
81            err = EvalError::WithPosition {
82                err: Box::new(err),
83                pos,
84            };
85        }
86        if let Some(val) = value {
87            err = EvalError::WithValue {
88                err: Box::new(err),
89                val: ordered_float::OrderedFloat(val),
90            };
91        }
92        err
93    }
94
95    /// Add expression context
96    pub fn with_expression(self, expr: String) -> Self {
97        EvalError::WithExpression {
98            err: Box::new(self),
99            expr,
100        }
101    }
102}
103
/// Mathematical constants that the evaluator pushes for built-in constant symbols.
pub mod constants {
    /// Archimedes' constant π.
    pub const PI: f64 = std::f64::consts::PI;
    /// Euler's number e, the base of the natural logarithm.
    pub const E: f64 = std::f64::consts::E;
    /// Golden ratio φ = (1 + √5) / 2, the positive root of x² = x + 1.
    pub const PHI: f64 = 1.618_033_988_749_895;
    /// Euler-Mascheroni constant γ.
    pub const GAMMA: f64 = 0.577_215_664_901_532_9;
    /// Plastic constant ρ, the real root of x³ = x + 1.
    pub const PLASTIC: f64 = 1.324_717_957_244_746;
    /// Apéry's constant ζ(3).
    pub const APERY: f64 = 1.202_056_903_159_594_2;
    /// Catalan's constant G.
    pub const CATALAN: f64 = 0.915_965_594_177_219;
}
118
/// Default argument scale `k` for `sinpi/cospi/tanpi`, which compute
/// `sin(k·x)`, `cos(k·x)` and `tan(k·x)`.
///
/// It is π, so `sinpi(x) = sin(πx)` as in the original semantics.
pub const DEFAULT_TRIG_ARGUMENT_SCALE: f64 = std::f64::consts::PI;
123
124/// Explicit evaluation context for a single run.
125///
126/// This keeps trig scaling and user-defined symbols inside the function
127/// signature instead of relying on process-global evaluator state.
128#[derive(Clone, Copy, Debug)]
129pub struct EvalContext<'a> {
130    /// Argument scale for `sinpi/cospi/tanpi`.
131    pub trig_argument_scale: f64,
132    /// User-defined constants available during evaluation.
133    pub user_constants: &'a [UserConstant],
134    /// User-defined functions available during evaluation.
135    pub user_functions: &'a [UserFunction],
136}
137
138impl Default for EvalContext<'static> {
139    fn default() -> Self {
140        Self {
141            trig_argument_scale: DEFAULT_TRIG_ARGUMENT_SCALE,
142            user_constants: &[],
143            user_functions: &[],
144        }
145    }
146}
147
148impl EvalContext<'static> {
149    /// Create a default context with built-in trig semantics and no user symbols.
150    pub fn new() -> Self {
151        Self::default()
152    }
153}
154
155impl<'a> EvalContext<'a> {
156    /// Create a context from user-defined constants and functions.
157    pub fn from_slices(
158        user_constants: &'a [UserConstant],
159        user_functions: &'a [UserFunction],
160    ) -> Self {
161        Self {
162            trig_argument_scale: DEFAULT_TRIG_ARGUMENT_SCALE,
163            user_constants,
164            user_functions,
165        }
166    }
167
168    /// Override the trig argument scale for this evaluation context.
169    pub fn with_trig_argument_scale(mut self, scale: f64) -> Self {
170        if scale.is_finite() && scale != 0.0 {
171            self.trig_argument_scale = scale;
172        }
173        self
174    }
175}
176
177/// Stack entry for evaluation with derivative tracking
178#[derive(Debug, Clone, Copy)]
179struct StackEntry {
180    val: f64,
181    deriv: f64,
182    num_type: NumType,
183}
184
185impl StackEntry {
186    fn new(val: f64, deriv: f64, num_type: NumType) -> Self {
187        Self {
188            val,
189            deriv,
190            num_type,
191        }
192    }
193
194    fn constant(val: f64, num_type: NumType) -> Self {
195        Self {
196            val,
197            deriv: 0.0,
198            num_type,
199        }
200    }
201}
202
203/// Reusable workspace for expression evaluation.
204///
205/// Using a workspace avoids heap allocations on every `evaluate()` call,
206/// which is critical for performance in hot loops (generation, Newton-Raphson).
207///
208/// # Example
209///
210/// ```no_run
211/// use ries_rs::eval::{EvalWorkspace, evaluate_with_workspace};
212/// use ries_rs::expr::Expression;
213/// let mut workspace = EvalWorkspace::new();
214/// let expressions: Vec<Expression> = vec![];
215/// let x = 1.0_f64;
216/// for expr in &expressions {
217///     let result = evaluate_with_workspace(expr, x, &mut workspace)?;
218///     // workspace is reused, no new allocations
219/// }
220/// # Ok::<(), ries_rs::eval::EvalError>(())
221/// ```
222pub struct EvalWorkspace {
223    stack: Vec<StackEntry>,
224}
225
226impl EvalWorkspace {
227    /// Create a new workspace with pre-allocated capacity.
228    ///
229    /// Capacity of 32 handles most expressions; grows automatically if needed.
230    pub fn new() -> Self {
231        Self {
232            stack: Vec::with_capacity(32),
233        }
234    }
235
236    /// Clear the workspace for reuse (keeps allocated capacity).
237    #[inline]
238    fn clear(&mut self) {
239        self.stack.clear();
240    }
241}
242
243impl Default for EvalWorkspace {
244    fn default() -> Self {
245        Self::new()
246    }
247}
248
249/// Evaluate an expression at a given value of x, using a reusable workspace.
250///
251/// This is the hot-path version that avoids heap allocations.
252/// Use this in loops where `evaluate()` is called many times.
253///
254/// Note: This is a convenience wrapper for the full `evaluate_with_workspace_and_constants_and_functions`
255/// when you don't need user constants or functions. It's provided as a simpler API for common cases.
256#[inline]
257pub fn evaluate_with_workspace(
258    expr: &Expression,
259    x: f64,
260    workspace: &mut EvalWorkspace,
261) -> Result<EvalResult, EvalError> {
262    evaluate_with_workspace_and_context(expr, x, workspace, &EvalContext::new())
263}
264
265/// Evaluate an expression with user constants, using a reusable workspace.
266///
267/// This is the hot-path version that avoids heap allocations.
268/// The `user_constants` slice provides values for `UserConstant0..15` symbols.
269///
270/// Note: This is a convenience wrapper for the full `evaluate_with_workspace_and_constants_and_functions`
271/// when you don't need user functions. It's provided as a simpler API for common cases.
272#[inline]
273pub fn evaluate_with_workspace_and_constants(
274    expr: &Expression,
275    x: f64,
276    workspace: &mut EvalWorkspace,
277    user_constants: &[UserConstant],
278) -> Result<EvalResult, EvalError> {
279    let context = EvalContext::from_slices(user_constants, &[]);
280    evaluate_with_workspace_and_context(expr, x, workspace, &context)
281}
282
283/// Evaluate an expression with user constants and user functions, using a reusable workspace.
284///
285/// This is the full hot-path version that avoids heap allocations.
286/// The `user_constants` slice provides values for `UserConstant0..15` symbols.
287/// The `user_functions` slice provides bodies for `UserFunction0..15` symbols.
288#[inline]
289pub fn evaluate_with_workspace_and_constants_and_functions(
290    expr: &Expression,
291    x: f64,
292    workspace: &mut EvalWorkspace,
293    user_constants: &[UserConstant],
294    user_functions: &[UserFunction],
295) -> Result<EvalResult, EvalError> {
296    let context = EvalContext::from_slices(user_constants, user_functions);
297    evaluate_with_workspace_and_context(expr, x, workspace, &context)
298}
299
300/// Evaluate an expression using an explicit evaluation context and reusable workspace.
301///
302/// This is the preferred hot-path API for library consumers that need explicit
303/// control over trig semantics or user-defined symbols.
304#[inline]
305pub fn evaluate_with_workspace_and_context(
306    expr: &Expression,
307    x: f64,
308    workspace: &mut EvalWorkspace,
309    context: &EvalContext<'_>,
310) -> Result<EvalResult, EvalError> {
311    workspace.clear();
312    let stack = &mut workspace.stack;
313
314    for &sym in expr.symbols() {
315        match sym.seft() {
316            Seft::A => {
317                let entry = eval_constant_with_user(sym, x, context.user_constants)?;
318                stack.push(entry);
319            }
320            Seft::B => {
321                // Check if this is a user function
322                if matches!(
323                    sym,
324                    Symbol::UserFunction0
325                        | Symbol::UserFunction1
326                        | Symbol::UserFunction2
327                        | Symbol::UserFunction3
328                        | Symbol::UserFunction4
329                        | Symbol::UserFunction5
330                        | Symbol::UserFunction6
331                        | Symbol::UserFunction7
332                        | Symbol::UserFunction8
333                        | Symbol::UserFunction9
334                        | Symbol::UserFunction10
335                        | Symbol::UserFunction11
336                        | Symbol::UserFunction12
337                        | Symbol::UserFunction13
338                        | Symbol::UserFunction14
339                        | Symbol::UserFunction15
340                ) {
341                    // Evaluate user function
342                    let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
343                    let result = eval_user_function(sym, a, context, x)?;
344                    stack.push(result);
345                } else {
346                    let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
347                    let result = eval_unary(sym, a, context.trig_argument_scale)?;
348                    stack.push(result);
349                }
350            }
351            Seft::C => {
352                let b = stack.pop().ok_or(EvalError::StackUnderflow)?;
353                let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
354                let result = eval_binary(sym, a, b)?;
355                stack.push(result);
356            }
357        }
358    }
359
360    if stack.len() != 1 {
361        return Err(EvalError::Invalid);
362    }
363
364    let result = stack.pop().unwrap();
365
366    // Check for invalid results
367    if result.val.is_nan() || result.val.is_infinite() {
368        return Err(EvalError::Overflow);
369    }
370
371    Ok(EvalResult {
372        value: result.val,
373        derivative: result.deriv,
374        num_type: result.num_type,
375    })
376}
377
378/// Evaluate an expression at a given value of x.
379///
380/// Convenience wrapper that allocates a new workspace. For hot loops,
381/// prefer `evaluate_with_workspace()` with a reusable `EvalWorkspace`.
382///
383/// Note: This is a convenience API for library users. Internal code uses
384/// `evaluate_fast_with_constants_and_functions` for performance.
385pub fn evaluate(expr: &Expression, x: f64) -> Result<EvalResult, EvalError> {
386    evaluate_with_context(expr, x, &EvalContext::new())
387}
388
389/// Evaluate an expression at a given value of x with user constants.
390///
391/// Convenience wrapper that allocates a new workspace.
392pub fn evaluate_with_constants(
393    expr: &Expression,
394    x: f64,
395    user_constants: &[UserConstant],
396) -> Result<EvalResult, EvalError> {
397    let context = EvalContext::from_slices(user_constants, &[]);
398    evaluate_with_context(expr, x, &context)
399}
400
401/// Evaluate an expression at a given value of x with user constants and user functions.
402///
403/// Convenience wrapper that allocates a new workspace.
404pub fn evaluate_with_constants_and_functions(
405    expr: &Expression,
406    x: f64,
407    user_constants: &[UserConstant],
408    user_functions: &[UserFunction],
409) -> Result<EvalResult, EvalError> {
410    let context = EvalContext::from_slices(user_constants, user_functions);
411    evaluate_with_context(expr, x, &context)
412}
413
414/// Evaluate an expression at a given value of x with an explicit evaluation context.
415///
416/// Convenience wrapper that allocates a new workspace.
417pub fn evaluate_with_context(
418    expr: &Expression,
419    x: f64,
420    context: &EvalContext<'_>,
421) -> Result<EvalResult, EvalError> {
422    let mut workspace = EvalWorkspace::new();
423    evaluate_with_workspace_and_context(expr, x, &mut workspace, context)
424}
425
426/// Evaluate an expression using a thread-local workspace (zero allocations after warmup).
427///
428/// This is ideal for parallel code where each thread needs its own workspace.
429/// Note: This version does NOT support user constants. For user constants,
430/// use `evaluate_with_constants()` or `evaluate_with_workspace_and_constants()`.
431///
432/// Note: This is a convenience wrapper for the full `evaluate_fast_with_constants_and_functions`
433/// when you don't need user constants or functions. It's provided as a simpler API for common cases.
434#[inline]
435pub fn evaluate_fast(expr: &Expression, x: f64) -> Result<EvalResult, EvalError> {
436    evaluate_fast_with_context(expr, x, &EvalContext::new())
437}
438
439/// Evaluate an expression using a thread-local workspace with user constants.
440///
441/// Note: This uses a global thread-local storage, so it's not safe to call recursively
442/// with different user_constants. For recursive calls, use `evaluate_with_workspace_and_constants`.
443///
444/// Note: This is a convenience wrapper for the full `evaluate_fast_with_constants_and_functions`
445/// when you don't need user functions. It's provided as a simpler API for common cases.
446#[inline]
447pub fn evaluate_fast_with_constants(
448    expr: &Expression,
449    x: f64,
450    user_constants: &[UserConstant],
451) -> Result<EvalResult, EvalError> {
452    let context = EvalContext::from_slices(user_constants, &[]);
453    evaluate_fast_with_context(expr, x, &context)
454}
455
456/// Evaluate an expression using a thread-local workspace with user constants and user functions.
457///
458/// # Thread-Local Workspace
459///
460/// This function uses a `thread_local!` static to cache an `EvalWorkspace` for each thread.
461/// The workspace is created on first use and reused for all subsequent calls from the same thread.
462/// This provides zero-allocation evaluation after the initial warmup, making it ideal for:
463///
464/// - Parallel code where each thread needs its own workspace
465/// - Hot loops where allocation overhead matters
466/// - High-throughput evaluation scenarios
467///
468/// # Limitations
469///
470/// - This uses a global thread-local storage, so it's not safe to call recursively
471///   with different `user_constants` or `user_functions`. The same workspace is shared.
472/// - For recursive calls or when user constants/functions vary per-call,
473///   use [`evaluate_with_workspace_and_constants_and_functions`] instead.
474///
475/// # Example
476///
477/// ```no_run
478/// use ries_rs::eval::evaluate_fast_with_constants_and_functions;
479/// use ries_rs::expr::Expression;
480/// let expr = Expression::new();
481/// let x = 1.0_f64;
482/// // First call allocates workspace (warmup)
483/// let result = evaluate_fast_with_constants_and_functions(&expr, x, &[], &[]);
484///
485/// // Subsequent calls reuse the same workspace (no allocations)
486/// for _ in 0..1000 {
487///     let _ = evaluate_fast_with_constants_and_functions(&expr, x, &[], &[]);
488/// }
489/// ```
490#[inline]
491pub fn evaluate_fast_with_constants_and_functions(
492    expr: &Expression,
493    x: f64,
494    user_constants: &[UserConstant],
495    user_functions: &[UserFunction],
496) -> Result<EvalResult, EvalError> {
497    let context = EvalContext::from_slices(user_constants, user_functions);
498    evaluate_fast_with_context(expr, x, &context)
499}
500
501/// Evaluate an expression using a thread-local workspace and explicit context.
502#[inline]
503pub fn evaluate_fast_with_context(
504    expr: &Expression,
505    x: f64,
506    context: &EvalContext<'_>,
507) -> Result<EvalResult, EvalError> {
508    thread_local! {
509        /// Thread-local evaluation workspace.
510        ///
511        /// Each thread gets its own workspace instance that's lazily allocated
512        /// on first use. The workspace maintains internal Vec storage that grows
513        /// as needed but is never deallocated, providing zero-allocation hot paths.
514        static WORKSPACE: std::cell::RefCell<EvalWorkspace> = std::cell::RefCell::new(EvalWorkspace::new());
515    }
516
517    WORKSPACE.with(|ws| {
518        let mut workspace = ws.borrow_mut();
519        evaluate_with_workspace_and_context(expr, x, &mut workspace, context)
520    })
521}
522
523/// Evaluate a constant or variable symbol with user constant lookup.
524fn eval_constant_with_user(
525    sym: Symbol,
526    x: f64,
527    user_constants: &[UserConstant],
528) -> Result<StackEntry, EvalError> {
529    use Symbol::*;
530    match sym {
531        One => Ok(StackEntry::constant(1.0, NumType::Integer)),
532        Two => Ok(StackEntry::constant(2.0, NumType::Integer)),
533        Three => Ok(StackEntry::constant(3.0, NumType::Integer)),
534        Four => Ok(StackEntry::constant(4.0, NumType::Integer)),
535        Five => Ok(StackEntry::constant(5.0, NumType::Integer)),
536        Six => Ok(StackEntry::constant(6.0, NumType::Integer)),
537        Seven => Ok(StackEntry::constant(7.0, NumType::Integer)),
538        Eight => Ok(StackEntry::constant(8.0, NumType::Integer)),
539        Nine => Ok(StackEntry::constant(9.0, NumType::Integer)),
540        Pi => Ok(StackEntry::constant(constants::PI, NumType::Transcendental)),
541        E => Ok(StackEntry::constant(constants::E, NumType::Transcendental)),
542        Phi => Ok(StackEntry::constant(constants::PHI, NumType::Algebraic)),
543        // New constants
544        Gamma => Ok(StackEntry::constant(
545            constants::GAMMA,
546            NumType::Transcendental,
547        )),
548        Plastic => Ok(StackEntry::constant(constants::PLASTIC, NumType::Algebraic)),
549        Apery => Ok(StackEntry::constant(
550            constants::APERY,
551            NumType::Transcendental,
552        )),
553        Catalan => Ok(StackEntry::constant(
554            constants::CATALAN,
555            NumType::Transcendental,
556        )),
557        X => Ok(StackEntry::new(x, 1.0, NumType::Integer)), // x can be any value, including integer
558        // User constants - look up value from the user_constants slice
559        UserConstant0 | UserConstant1 | UserConstant2 | UserConstant3 | UserConstant4
560        | UserConstant5 | UserConstant6 | UserConstant7 | UserConstant8 | UserConstant9
561        | UserConstant10 | UserConstant11 | UserConstant12 | UserConstant13 | UserConstant14
562        | UserConstant15 => {
563            // Get the index from the symbol
564            let idx = sym.user_constant_index().unwrap() as usize;
565            user_constants
566                .get(idx)
567                .map(|uc| StackEntry::constant(uc.value, uc.num_type))
568                .ok_or(EvalError::MissingUserConstant(idx))
569        }
570        _ => Err(EvalError::Invalid),
571    }
572}
573
574/// Evaluate a user-defined function
575///
576/// Takes the input argument and the user_functions slice, looks up the function
577/// definition, executes the body, and returns the result.
578fn eval_user_function(
579    sym: Symbol,
580    input: StackEntry,
581    context: &EvalContext<'_>,
582    x: f64,
583) -> Result<StackEntry, EvalError> {
584    // Get the function index
585    let idx = sym.user_function_index().ok_or(EvalError::Invalid)? as usize;
586
587    // Look up the function definition
588    let udf = context.user_functions.get(idx).ok_or(EvalError::Invalid)?;
589
590    // Reuse a thread-local scratch buffer rather than allocating a fresh Vec on every
591    // call. eval_user_function is invoked in the inner generation loop (potentially
592    // millions of times at high complexity), so avoiding the heap allocation matters.
593    // UDFs do not call other UDFs, so the borrow is never re-entered.
594    thread_local! {
595        static UDF_STACK: std::cell::RefCell<Vec<StackEntry>> =
596            std::cell::RefCell::new(Vec::with_capacity(16));
597    }
598
599    UDF_STACK.with(|cell| -> Result<StackEntry, EvalError> {
600        let mut stack = cell.borrow_mut();
601        stack.clear();
602        stack.push(input);
603
604        // Execute each operation in the function body
605        for op in &udf.body {
606            match op {
607                UdfOp::Symbol(sym) => {
608                    match sym.seft() {
609                        Seft::A => {
610                            // Constant - push onto stack
611                            let entry = eval_constant_with_user(*sym, x, context.user_constants)?;
612                            stack.push(entry);
613                        }
614                        Seft::B => {
615                            // Unary operator - pop one, push result
616                            let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
617                            let result = eval_unary(*sym, a, context.trig_argument_scale)?;
618                            stack.push(result);
619                        }
620                        Seft::C => {
621                            // Binary operator - pop two, push result
622                            let b = stack.pop().ok_or(EvalError::StackUnderflow)?;
623                            let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
624                            let result = eval_binary(*sym, a, b)?;
625                            stack.push(result);
626                        }
627                    }
628                }
629                UdfOp::Dup => {
630                    // Duplicate top of stack. Dereference immediately so the
631                    // immutable borrow ends before the mutable push.
632                    let top = *stack.last().ok_or(EvalError::StackUnderflow)?;
633                    stack.push(top);
634                }
635                UdfOp::Swap => {
636                    // Swap top two elements
637                    let len = stack.len();
638                    if len < 2 {
639                        return Err(EvalError::StackUnderflow);
640                    }
641                    stack.swap(len - 1, len - 2);
642                }
643            }
644        }
645
646        // Function should leave exactly one value on the stack
647        if stack.len() != 1 {
648            return Err(EvalError::Invalid);
649        }
650
651        let result = stack.pop().unwrap();
652
653        // Check for invalid results
654        if result.val.is_nan() || result.val.is_infinite() {
655            return Err(EvalError::Overflow);
656        }
657
658        Ok(result)
659    })
660}
661
662/// Evaluate a unary operator with derivative
663fn eval_unary(
664    sym: Symbol,
665    a: StackEntry,
666    trig_argument_scale: f64,
667) -> Result<StackEntry, EvalError> {
668    use Symbol::*;
669
670    let (val, deriv, num_type) = match sym {
671        // Negation: -a, d(-a)/dx = -da/dx
672        Neg => (-a.val, -a.deriv, a.num_type),
673
674        // Reciprocal: 1/a, d(1/a)/dx = -da/dx / a²
675        Recip => {
676            if a.val.abs() < f64::MIN_POSITIVE {
677                return Err(EvalError::DivisionByZero);
678            }
679            let val = 1.0 / a.val;
680            let deriv = -a.deriv / (a.val * a.val);
681            let num_type = if a.num_type == NumType::Integer {
682                NumType::Rational
683            } else {
684                a.num_type
685            };
686            (val, deriv, num_type)
687        }
688
689        // Square root: sqrt(a), d(sqrt(a))/dx = da/dx / (2*sqrt(a))
690        Sqrt => {
691            if a.val < 0.0 {
692                return Err(EvalError::SqrtDomain);
693            }
694            let val = a.val.sqrt();
695            let deriv = if val.abs() > f64::MIN_POSITIVE {
696                a.deriv / (2.0 * val)
697            } else {
698                0.0
699            };
700            let num_type = if a.num_type >= NumType::Constructible {
701                NumType::Constructible
702            } else {
703                a.num_type
704            };
705            (val, deriv, num_type)
706        }
707
708        // Square: a², d(a²)/dx = 2*a*da/dx
709        Square => {
710            let val = a.val * a.val;
711            let deriv = 2.0 * a.val * a.deriv;
712            (val, deriv, a.num_type)
713        }
714
715        // Natural log: ln(a), d(ln(a))/dx = da/dx / a
716        Ln => {
717            if a.val <= 0.0 {
718                return Err(EvalError::LogDomain);
719            }
720            let val = a.val.ln();
721            let deriv = a.deriv / a.val;
722            (val, deriv, NumType::Transcendental)
723        }
724
725        // Exponential: e^a, d(e^a)/dx = e^a * da/dx
726        Exp => {
727            let val = a.val.exp();
728            if val.is_infinite() {
729                return Err(EvalError::Overflow);
730            }
731            let deriv = val * a.deriv;
732            (val, deriv, NumType::Transcendental)
733        }
734
735        // sin(π*a), d(sin(πa))/dx = π*cos(πa)*da/dx
736        SinPi => {
737            let val = (trig_argument_scale * a.val).sin();
738            let deriv = trig_argument_scale * (trig_argument_scale * a.val).cos() * a.deriv;
739            (val, deriv, NumType::Transcendental)
740        }
741
742        // cos(π*a), d(cos(πa))/dx = -π*sin(πa)*da/dx
743        CosPi => {
744            let val = (trig_argument_scale * a.val).cos();
745            let deriv = -trig_argument_scale * (trig_argument_scale * a.val).sin() * a.deriv;
746            (val, deriv, NumType::Transcendental)
747        }
748
749        // tan(π*a), d(tan(πa))/dx = π*sec²(πa)*da/dx
750        TanPi => {
751            let cos_val = (trig_argument_scale * a.val).cos();
752            if cos_val.abs() < 1e-10 {
753                return Err(EvalError::Overflow);
754            }
755            let val = (trig_argument_scale * a.val).tan();
756            let deriv = trig_argument_scale * a.deriv / (cos_val * cos_val);
757            (val, deriv, NumType::Transcendental)
758        }
759
760        // Lambert W function (principal branch)
761        LambertW => {
762            let val = lambert_w(a.val)?;
763            // d(W(a))/dx = W(a) / (a * (1 + W(a))) * da/dx
764            // Special case: W'(0) = 1 (by L'Hôpital's rule, since W(x) ≈ x near 0)
765            let deriv = if a.val.abs() < 1e-10 {
766                a.deriv // W'(0) = 1
767            } else {
768                let denom = a.val * (1.0 + val);
769                if denom.abs() > f64::MIN_POSITIVE {
770                    val / denom * a.deriv
771                } else {
772                    0.0
773                }
774            };
775            (val, deriv, NumType::Transcendental)
776        }
777
778        // User functions are handled at the main evaluation loop level, not here
779        // If we reach this point, return an error
780        UserFunction0 | UserFunction1 | UserFunction2 | UserFunction3 | UserFunction4
781        | UserFunction5 | UserFunction6 | UserFunction7 | UserFunction8 | UserFunction9
782        | UserFunction10 | UserFunction11 | UserFunction12 | UserFunction13 | UserFunction14
783        | UserFunction15 => {
784            // This indicates a bug in the evaluation loop - user functions should be
785            // handled before calling eval_unary
786            return Err(EvalError::Invalid);
787        }
788
789        // Non-unary symbols should never be passed to this function
790        _ => return Err(EvalError::Invalid),
791    };
792
793    Ok(StackEntry::new(val, deriv, num_type))
794}
795
796/// Evaluate a binary operator with derivative
797fn eval_binary(sym: Symbol, a: StackEntry, b: StackEntry) -> Result<StackEntry, EvalError> {
798    use Symbol::*;
799
800    let (val, deriv, num_type) = match sym {
801        // Addition: a + b
802        Add => {
803            let val = a.val + b.val;
804            let deriv = a.deriv + b.deriv;
805            let num_type = a.num_type.combine(b.num_type);
806            (val, deriv, num_type)
807        }
808
809        // Subtraction: a - b
810        Sub => {
811            let val = a.val - b.val;
812            let deriv = a.deriv - b.deriv;
813            let num_type = a.num_type.combine(b.num_type);
814            (val, deriv, num_type)
815        }
816
817        // Multiplication: a * b, d(ab)/dx = a*db/dx + b*da/dx
818        Mul => {
819            let val = a.val * b.val;
820            let deriv = a.val * b.deriv + b.val * a.deriv;
821            let num_type = a.num_type.combine(b.num_type);
822            (val, deriv, num_type)
823        }
824
825        // Division: a / b, d(a/b)/dx = (b*da/dx - a*db/dx) / b²
826        Div => {
827            if b.val.abs() < f64::MIN_POSITIVE {
828                return Err(EvalError::DivisionByZero);
829            }
830            let val = a.val / b.val;
831            let deriv = (b.val * a.deriv - a.val * b.deriv) / (b.val * b.val);
832            let mut num_type = a.num_type.combine(b.num_type);
833            if num_type == NumType::Integer {
834                num_type = NumType::Rational;
835            }
836            (val, deriv, num_type)
837        }
838
839        // Power: a^b, d(a^b)/dx = a^b * (b*da/dx/a + ln(a)*db/dx)
840        Pow => {
841            if a.val <= 0.0 && b.val.fract() != 0.0 {
842                return Err(EvalError::SqrtDomain);
843            }
844            let val = a.val.powf(b.val);
845            if val.is_infinite() || val.is_nan() {
846                return Err(EvalError::Overflow);
847            }
848            // Guard for near-zero base to avoid numerical issues
849            let deriv = if a.val > f64::MIN_POSITIVE {
850                val * (b.val * a.deriv / a.val + a.val.ln() * b.deriv)
851            } else if a.val.abs() < f64::MIN_POSITIVE && b.val > 0.0 {
852                0.0
853            } else {
854                // Negative base, integer exponent (or near-zero base treated as 0).
855                // Full formula: val * (b * a.deriv/a + ln(a) * b.deriv).
856                // The ln(a) * b.deriv term is intentionally dropped here: ln(negative) is
857                // undefined in the reals (NaN), so it cannot contribute to Newton-Raphson.
858                // Dropping it gives 0 for the derivative w.r.t. x-in-the-exponent path,
859                // which is the correct safe fallback when x appears in the exponent of a
860                // negative base (e.g., (-2)^x is only real-valued at integer x).
861                if a.val.abs() < f64::MIN_POSITIVE {
862                    0.0
863                } else {
864                    val * b.val * a.deriv / a.val
865                }
866            };
867            let num_type = if b.num_type == NumType::Integer {
868                a.num_type
869            } else {
870                NumType::Transcendental
871            };
872            (val, deriv, num_type)
873        }
874
875        // a-th root of b: b^(1/a)
876        Root => {
877            if a.val.abs() < f64::MIN_POSITIVE {
878                return Err(EvalError::DivisionByZero);
879            }
880            let exp = 1.0 / a.val;
881
882            // For negative radicands, we need to check if the index is an odd integer
883            // Non-integer indices of negative numbers have no real value
884            if b.val < 0.0 {
885                // Check if the index is close to an integer
886                let rounded = a.val.round();
887                let is_integer = (a.val - rounded).abs() < 1e-10;
888
889                if !is_integer {
890                    // Non-integer index of negative number - no real value
891                    return Err(EvalError::SqrtDomain);
892                }
893
894                // Check if the integer is odd (odd roots of negatives are real)
895                let int_val = rounded as i64;
896                if int_val % 2 == 0 {
897                    // Even integer root of negative - no real value
898                    return Err(EvalError::SqrtDomain);
899                }
900                // Odd integer root of negative is OK
901            }
902
903            let val = if b.val < 0.0 {
904                // Odd root of negative number
905                -((-b.val).powf(exp))
906            } else {
907                b.val.powf(exp)
908            };
909            if val.is_infinite() || val.is_nan() {
910                return Err(EvalError::Overflow);
911            }
912            // d(b^(1/a))/dx = b^(1/a) * (db/dx/(a*b) - ln(b)*da/dx/a²)
913            let deriv = if b.val.abs() > f64::MIN_POSITIVE {
914                val * (b.deriv / (a.val * b.val) - b.val.abs().ln() * a.deriv / (a.val * a.val))
915            } else {
916                0.0
917            };
918            (val, deriv, NumType::Algebraic)
919        }
920
921        // Logarithm base a of b: ln(b) / ln(a)
922        Log => {
923            if a.val <= 0.0 || a.val == 1.0 || b.val <= 0.0 {
924                return Err(EvalError::LogDomain);
925            }
926            let ln_a = a.val.ln();
927            let ln_b = b.val.ln();
928            let val = ln_b / ln_a;
929            // d(log_a(b))/dx = (db/dx/(b*ln(a)) - ln(b)*da/dx/(a*ln(a)²))
930            let deriv = b.deriv / (b.val * ln_a) - ln_b * a.deriv / (a.val * ln_a * ln_a);
931            (val, deriv, NumType::Transcendental)
932        }
933
934        // atan2(a, b) = angle of point (b, a) from origin
935        Atan2 => {
936            let val = a.val.atan2(b.val);
937            // d(atan2(a,b))/dx = (b*da/dx - a*db/dx) / (a² + b²)
938            let denom = a.val * a.val + b.val * b.val;
939            let deriv = if denom.abs() > f64::MIN_POSITIVE {
940                (b.val * a.deriv - a.val * b.deriv) / denom
941            } else {
942                0.0
943            };
944            (val, deriv, NumType::Transcendental)
945        }
946
947        // Non-binary symbols should never be passed to this function
948        _ => return Err(EvalError::Invalid),
949    };
950
951    Ok(StackEntry::new(val, deriv, num_type))
952}
953
954/// Compute the Lambert W function (principal branch) using Halley's method
955///
956/// The Lambert W function satisfies W(x) * exp(W(x)) = x.
957/// This implementation handles the principal branch (W₀) for x ≥ -1/e.
958fn lambert_w(x: f64) -> Result<f64, EvalError> {
959    // Branch point: x = -1/e gives W = -1
960    const INV_E: f64 = 1.0 / std::f64::consts::E;
961    const NEG_INV_E: f64 = -INV_E; // -0.36787944117144233...
962
963    // Domain check
964    if x < NEG_INV_E {
965        return Err(EvalError::LogDomain);
966    }
967
968    // Special cases
969    if x == 0.0 {
970        return Ok(0.0); // W(0) = 0
971    }
972    if (x - NEG_INV_E).abs() < 1e-15 {
973        return Ok(-1.0); // W(-1/e) = -1
974    }
975    if x == constants::E {
976        return Ok(1.0); // W(e) = 1
977    }
978
979    // Initial guess - different approximations for different regimes
980    let mut w = if x < -0.3 {
981        // Near the branch point, use a series expansion around -1/e
982        // W(x) ≈ -1 + p - p²/3 + 11p³/72 where p = sqrt(2(ex + 1))
983        let p = (2.0 * (constants::E * x + 1.0)).sqrt();
984        -1.0 + p * (1.0 - p / 3.0 * (1.0 - 11.0 * p / 72.0))
985    } else if x < 0.25 {
986        // Near zero, use a polynomial approximation
987        // W(x) ≈ x - x² + 3x³/2 - 8x⁴/3 + ...
988        // For numerical stability, use a rational approximation
989        let x2 = x * x;
990        x * (1.0 - x + x2 * (1.5 - 2.6667 * x))
991    } else if x < 4.0 {
992        // Moderate range: use log-based approximation
993        // W(x) ≈ ln(x) - ln(ln(x)) + ln(ln(x))/ln(x)
994        let lnx = x.ln();
995        if lnx > 0.0 {
996            let lnlnx = lnx.ln().max(0.0);
997            lnx - lnlnx + lnlnx / lnx.max(1.0)
998        } else {
999            x // fallback for x near 1
1000        }
1001    } else {
1002        // Large x: W(x) ≈ ln(x) - ln(ln(x)) + ln(ln(x))/ln(x)
1003        let l1 = x.ln();
1004        let l2 = l1.ln();
1005        l1 - l2 + l2 / l1
1006    };
1007
1008    // Halley's method iteration
1009    // For well-chosen initial guesses, 10-15 iterations are usually enough
1010    for _ in 0..25 {
1011        let ew = w.exp();
1012
1013        // Handle potential overflow
1014        if !ew.is_finite() {
1015            // Back off to a more stable approach
1016            w = x.ln() - w.ln().max(1e-10);
1017            continue;
1018        }
1019
1020        let wew = w * ew;
1021        let diff = wew - x;
1022
1023        // Convergence check with relative tolerance
1024        let tol = 1e-15 * (1.0 + w.abs().max(x.abs()));
1025        if diff.abs() < tol {
1026            break;
1027        }
1028
1029        let w1 = w + 1.0;
1030        // Halley's correction
1031        let denom = ew * w1 - 0.5 * (w + 2.0) * diff / w1;
1032        if denom.abs() < f64::MIN_POSITIVE {
1033            break;
1034        }
1035
1036        let delta = diff / denom;
1037
1038        // Damping for stability near branch point
1039        let correction = if w < -0.5 && delta.abs() > 0.5 {
1040            delta * 0.5 // Damped update near branch point
1041        } else {
1042            delta
1043        };
1044
1045        w -= correction;
1046    }
1047
1048    // Final validation
1049    if !w.is_finite() {
1050        return Err(EvalError::Overflow);
1051    }
1052
1053    Ok(w)
1054}
1055
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10
    }

    #[test]
    fn test_basic_eval() {
        let expr = Expression::parse("32+").unwrap();
        let result = evaluate(&expr, 0.0).unwrap();
        assert!(approx_eq(result.value, 5.0));
        assert!(approx_eq(result.derivative, 0.0));
    }

    #[test]
    fn test_variable() {
        let expr = Expression::parse("x").unwrap();
        let result = evaluate(&expr, 3.5).unwrap();
        assert!(approx_eq(result.value, 3.5));
        assert!(approx_eq(result.derivative, 1.0));
    }

    #[test]
    fn test_x_squared() {
        let expr = Expression::parse("xs").unwrap(); // x^2
        let result = evaluate(&expr, 3.0).unwrap();
        assert!(approx_eq(result.value, 9.0));
        assert!(approx_eq(result.derivative, 6.0)); // 2x
    }

    #[test]
    fn test_sqrt_pi() {
        let expr = Expression::parse("pq").unwrap(); // sqrt(pi)
        let result = evaluate(&expr, 0.0).unwrap();
        assert!(approx_eq(result.value, constants::PI.sqrt()));
    }

    #[test]
    fn test_e_to_x() {
        let expr = Expression::parse("xE").unwrap(); // e^x
        let result = evaluate(&expr, 1.0).unwrap();
        assert!(approx_eq(result.value, constants::E));
        assert!(approx_eq(result.derivative, constants::E)); // d(e^x)/dx = e^x
    }

    #[test]
    fn test_complex_expr() {
        // x^2 + 2*x + 1 = (x+1)^2
        let expr = Expression::parse("xs2x*+1+").unwrap();
        let result = evaluate(&expr, 3.0).unwrap();
        assert!(approx_eq(result.value, 16.0)); // (3+1)^2
        assert!(approx_eq(result.derivative, 8.0)); // 2x + 2 = 8
    }

    #[test]
    fn test_lambert_w() {
        // W(1) = Ω ≈ 0.5671432904097838 (the omega constant)
        let w = lambert_w(1.0).unwrap();
        assert!((w - 0.567_143_290_409_783_8).abs() < 1e-12);

        // W(e) = 1
        let w = lambert_w(constants::E).unwrap();
        assert!((w - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_lambert_w_special_points_and_domain() {
        // W(0) = 0 and W(-1/e) = -1 (the branch point) are returned exactly.
        assert_eq!(lambert_w(0.0), Ok(0.0));
        assert_eq!(lambert_w(-1.0 / std::f64::consts::E), Ok(-1.0));
        // Below the branch point there is no real value on the principal branch.
        assert_eq!(lambert_w(-0.5), Err(EvalError::LogDomain));
        // NaN input is rejected rather than returned.
        assert_eq!(lambert_w(f64::NAN), Err(EvalError::Overflow));
    }

    #[test]
    fn test_lambert_w_satisfies_defining_identity() {
        // W(x) * e^W(x) = x across every initial-guess regime, including close to
        // the branch point, and the result stays on the principal branch (W >= -1).
        let xs = [
            -0.3678, -0.36, -0.3, -0.1, -1e-8, 1e-8, 0.1, 0.5, 2.0, 3.9, 10.0, 1e3, 1e6,
        ];
        for &x in xs.iter() {
            let w = lambert_w(x).unwrap();
            assert!(w >= -1.0, "W({}) = {} is off the principal branch", x, w);
            let residual = w * w.exp() - x;
            assert!(
                residual.abs() <= 1e-12 * (1.0 + x.abs()),
                "W({}) = {}: residual {}",
                x,
                w,
                residual
            );
        }
    }

    #[test]
    fn test_user_constant_evaluation() {
        // Create a user constant (Euler-Mascheroni gamma ≈ 0.57721)
        let user_constants = vec![UserConstant {
            weight: 8,
            name: "g".to_string(),
            description: "gamma".to_string(),
            value: 0.5772156649,
            num_type: NumType::Transcendental,
        }];

        // Create expression with UserConstant0 (byte 128)
        let expr = Expression::from_symbols(&[Symbol::UserConstant0]);

        // Evaluate with user constants
        let result = evaluate_with_constants(&expr, 0.0, &user_constants).unwrap();

        // Should match the user constant value
        assert!(approx_eq(result.value, 0.5772156649));
        // Derivative should be 0 (it's a constant)
        assert!(approx_eq(result.derivative, 0.0));
    }

    #[test]
    fn test_user_constant_in_expression() {
        // Create two user constants
        let user_constants = vec![
            UserConstant {
                weight: 8,
                name: "a".to_string(),
                description: "constant a".to_string(),
                value: 2.0,
                num_type: NumType::Integer,
            },
            UserConstant {
                weight: 8,
                name: "b".to_string(),
                description: "constant b".to_string(),
                value: 3.0,
                num_type: NumType::Integer,
            },
        ];

        // Create expression: u0 * x + u1 (in postfix: u0 x * u1 +)
        let expr = Expression::from_symbols(&[
            Symbol::UserConstant0,
            Symbol::X,
            Symbol::Mul,
            Symbol::UserConstant1,
            Symbol::Add,
        ]);

        // At x=4, should be 2*4 + 3 = 11
        let result = evaluate_with_constants(&expr, 4.0, &user_constants).unwrap();
        assert!(approx_eq(result.value, 11.0));
        // Derivative should be 2 (from u0 * x)
        assert!(approx_eq(result.derivative, 2.0));
    }

    #[test]
    fn test_user_constant_missing_returns_error() {
        // Missing user constant slots must fail explicitly instead of silently
        // changing the expression's meaning.
        let expr = Expression::from_symbols(&[Symbol::UserConstant0]);

        let result = evaluate_with_constants(&expr, 0.0, &[]);
        assert!(matches!(result, Err(EvalError::MissingUserConstant(0))));
    }

    #[test]
    fn test_user_function_sinh() {
        // sinh(x) = (e^x - e^-x) / 2
        // In postfix: E|r-2/ (exp, dup, recip, subtract, 2, divide)
        let user_functions = vec![UserFunction::parse("4:sinh:hyperbolic sine:E|r-2/").unwrap()];

        // Create expression: sinh(x) (in postfix: xF0 where F0 = UserFunction0)
        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);

        // sinh(1) = (e - e^-1) / 2 ≈ 1.1752
        let result =
            evaluate_with_constants_and_functions(&expr, 1.0, &[], &user_functions).unwrap();
        let expected = (constants::E - 1.0 / constants::E) / 2.0;
        assert!(approx_eq(result.value, expected));

        // Derivative: d(sinh(x))/dx = cosh(x) = (e^x + e^-x) / 2
        let expected_deriv = (constants::E + 1.0 / constants::E) / 2.0;
        assert!((result.derivative - expected_deriv).abs() < 1e-10);
    }

    #[test]
    fn test_user_function_xex() {
        // XeX(x) = x * e^x
        // In postfix: |E* (dup, exp, multiply)
        let user_functions = vec![UserFunction::parse("4:XeX:x*exp(x):|E*").unwrap()];

        // Create expression: XeX(x) (in postfix: xF0 where F0 = UserFunction0)
        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);

        // XeX(1) = 1 * e^1 = e
        let result =
            evaluate_with_constants_and_functions(&expr, 1.0, &[], &user_functions).unwrap();
        assert!(approx_eq(result.value, constants::E));

        // Derivative: d(x*e^x)/dx = e^x + x*e^x = e^x * (1 + x) = e * 2
        let expected_deriv = constants::E * 2.0;
        assert!((result.derivative - expected_deriv).abs() < 1e-10);
    }

    #[test]
    fn test_user_function_missing_returns_error() {
        // When no user functions are provided, user function evaluation should fail
        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);

        let result = evaluate_with_constants_and_functions(&expr, 1.0, &[], &[]);
        assert!(result.is_err());
    }
}