Skip to main content

ries_rs/
eval.rs

1//! Expression evaluation with automatic differentiation
2//!
3//! Evaluates postfix expressions and computes derivatives using forward-mode AD.
4//!
5//! # Performance
6//!
7//! For hot loops (generation, Newton-Raphson), use `evaluate_with_workspace()` with
8//! a reusable `EvalWorkspace` to avoid heap allocations on every call.
9
10use crate::expr::Expression;
11use crate::profile::UserConstant;
12use crate::symbol::{NumType, Seft, Symbol};
13use crate::udf::{UdfOp, UserFunction};
14
15/// Result of evaluating an expression
16#[derive(Debug, Clone, Copy)]
17pub struct EvalResult {
18    /// The computed value
19    pub value: f64,
20    /// Derivative with respect to x
21    pub derivative: f64,
22    /// Numeric type of the result
23    pub num_type: NumType,
24}
25
26/// Evaluation error types
27///
28/// These errors indicate what went wrong during expression evaluation.
29/// For more detailed context, use the error message methods.
30#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
31pub enum EvalError {
32    /// Stack underflow during evaluation
33    #[error("Stack underflow: not enough operands on stack")]
34    StackUnderflow,
35    /// User constant slot referenced by the expression is not configured
36    #[error("Missing user constant: slot u{0} is not configured")]
37    MissingUserConstant(usize),
38    /// Division by zero
39    #[error("Division by zero: divisor was zero or near-zero")]
40    DivisionByZero,
41    /// Logarithm of non-positive number
42    #[error("Logarithm domain error: argument was non-positive")]
43    LogDomain,
44    /// Square root of negative number
45    #[error("Square root domain error: argument was negative")]
46    SqrtDomain,
47    /// Overflow or NaN result
48    #[error("Overflow: result is infinite or NaN")]
49    Overflow,
50    /// Invalid expression
51    #[error("Invalid expression: malformed or incomplete")]
52    Invalid,
53    /// Error with position context
54    #[error("{err} at position {pos}")]
55    WithPosition {
56        #[source]
57        err: Box<EvalError>,
58        pos: usize,
59    },
60    /// Error with value context
61    #[error("{err} (value: {val})")]
62    WithValue {
63        #[source]
64        err: Box<EvalError>,
65        val: ordered_float::OrderedFloat<f64>,
66    },
67    /// Error with expression context
68    #[error("{err} in expression '{expr}'")]
69    WithExpression {
70        #[source]
71        err: Box<EvalError>,
72        expr: String,
73    },
74}
75
76impl EvalError {
77    /// Create a detailed error message with context (backward compatibility)
78    pub fn with_context(self, position: Option<usize>, value: Option<f64>) -> Self {
79        let mut err = self;
80        if let Some(pos) = position {
81            err = EvalError::WithPosition {
82                err: Box::new(err),
83                pos,
84            };
85        }
86        if let Some(val) = value {
87            err = EvalError::WithValue {
88                err: Box::new(err),
89                val: ordered_float::OrderedFloat(val),
90            };
91        }
92        err
93    }
94
95    /// Add expression context
96    pub fn with_expression(self, expr: String) -> Self {
97        EvalError::WithExpression {
98            err: Box::new(self),
99            expr,
100        }
101    }
102}
103
/// Named mathematical constants that can appear as leaf symbols.
pub mod constants {
    /// Archimedes' constant π (re-exported from the standard library).
    pub const PI: f64 = std::f64::consts::PI;
    /// Euler's number e (re-exported from the standard library).
    pub const E: f64 = std::f64::consts::E;
    /// Golden ratio φ = (1 + √5) / 2.
    pub const PHI: f64 = 1.618_033_988_749_895;
    /// Euler–Mascheroni constant γ.
    pub const GAMMA: f64 = 0.577_215_664_901_532_9;
    /// Plastic constant ρ, the real root of x³ = x + 1.
    pub const PLASTIC: f64 = 1.324_717_957_244_746;
    /// Apéry's constant ζ(3).
    pub const APERY: f64 = 1.202_056_903_159_594_2;
    /// Catalan's constant G.
    pub const CATALAN: f64 = 0.915_965_594_177_219;
}
118
/// Multiplier applied to the argument of `sinpi`, `cospi` and `tanpi` unless
/// an [`EvalContext`] overrides it.
///
/// With this value, `sinpi(x)` evaluates `sin(π·x)`, preserving the original
/// `sinpi` semantics.
pub const DEFAULT_TRIG_ARGUMENT_SCALE: f64 = std::f64::consts::PI;
123
124/// Explicit evaluation context for a single run.
125///
126/// This keeps trig scaling and user-defined symbols inside the function
127/// signature instead of relying on process-global evaluator state.
128#[derive(Clone, Copy, Debug)]
129pub struct EvalContext<'a> {
130    /// Argument scale for `sinpi/cospi/tanpi`.
131    pub trig_argument_scale: f64,
132    /// User-defined constants available during evaluation.
133    pub user_constants: &'a [UserConstant],
134    /// User-defined functions available during evaluation.
135    pub user_functions: &'a [UserFunction],
136}
137
138impl Default for EvalContext<'static> {
139    fn default() -> Self {
140        Self {
141            trig_argument_scale: DEFAULT_TRIG_ARGUMENT_SCALE,
142            user_constants: &[],
143            user_functions: &[],
144        }
145    }
146}
147
148impl EvalContext<'static> {
149    /// Create a default context with built-in trig semantics and no user symbols.
150    pub fn new() -> Self {
151        Self::default()
152    }
153}
154
155impl<'a> EvalContext<'a> {
156    /// Create a context from user-defined constants and functions.
157    pub fn from_slices(
158        user_constants: &'a [UserConstant],
159        user_functions: &'a [UserFunction],
160    ) -> Self {
161        Self {
162            trig_argument_scale: DEFAULT_TRIG_ARGUMENT_SCALE,
163            user_constants,
164            user_functions,
165        }
166    }
167
168    /// Override the trig argument scale for this evaluation context.
169    pub fn with_trig_argument_scale(mut self, scale: f64) -> Self {
170        if scale.is_finite() && scale != 0.0 {
171            self.trig_argument_scale = scale;
172        }
173        self
174    }
175}
176
177/// Stack entry for evaluation with derivative tracking
178#[derive(Debug, Clone, Copy)]
179struct StackEntry {
180    val: f64,
181    deriv: f64,
182    num_type: NumType,
183}
184
185impl StackEntry {
186    fn new(val: f64, deriv: f64, num_type: NumType) -> Self {
187        Self {
188            val,
189            deriv,
190            num_type,
191        }
192    }
193
194    fn constant(val: f64, num_type: NumType) -> Self {
195        Self {
196            val,
197            deriv: 0.0,
198            num_type,
199        }
200    }
201}
202
203/// Reusable workspace for expression evaluation.
204///
205/// Using a workspace avoids heap allocations on every `evaluate()` call,
206/// which is critical for performance in hot loops (generation, Newton-Raphson).
207///
208/// # Example
209///
210/// ```no_run
211/// use ries_rs::eval::{EvalWorkspace, evaluate_with_workspace};
212/// use ries_rs::expr::Expression;
213/// let mut workspace = EvalWorkspace::new();
214/// let expressions: Vec<Expression> = vec![];
215/// let x = 1.0_f64;
216/// for expr in &expressions {
217///     let result = evaluate_with_workspace(expr, x, &mut workspace)?;
218///     // workspace is reused, no new allocations
219/// }
220/// # Ok::<(), ries_rs::eval::EvalError>(())
221/// ```
222pub struct EvalWorkspace {
223    stack: Vec<StackEntry>,
224}
225
226impl EvalWorkspace {
227    /// Create a new workspace with pre-allocated capacity.
228    ///
229    /// Capacity of 32 handles most expressions; grows automatically if needed.
230    pub fn new() -> Self {
231        Self {
232            stack: Vec::with_capacity(32),
233        }
234    }
235
236    /// Clear the workspace for reuse (keeps allocated capacity).
237    #[inline]
238    fn clear(&mut self) {
239        self.stack.clear();
240    }
241}
242
243impl Default for EvalWorkspace {
244    fn default() -> Self {
245        Self::new()
246    }
247}
248
249/// Evaluate an expression at a given value of x, using a reusable workspace.
250///
251/// This is the hot-path version that avoids heap allocations.
252/// Use this in loops where `evaluate()` is called many times.
253///
254/// Note: This is a convenience wrapper for the full `evaluate_with_workspace_and_constants_and_functions`
255/// when you don't need user constants or functions. It's provided as a simpler API for common cases.
256#[inline]
257pub fn evaluate_with_workspace(
258    expr: &Expression,
259    x: f64,
260    workspace: &mut EvalWorkspace,
261) -> Result<EvalResult, EvalError> {
262    evaluate_with_workspace_and_context(expr, x, workspace, &EvalContext::new())
263}
264
265/// Evaluate an expression with user constants, using a reusable workspace.
266///
267/// This is the hot-path version that avoids heap allocations.
268/// The `user_constants` slice provides values for `UserConstant0..15` symbols.
269///
270/// Note: This is a convenience wrapper for the full `evaluate_with_workspace_and_constants_and_functions`
271/// when you don't need user functions. It's provided as a simpler API for common cases.
272#[inline]
273pub fn evaluate_with_workspace_and_constants(
274    expr: &Expression,
275    x: f64,
276    workspace: &mut EvalWorkspace,
277    user_constants: &[UserConstant],
278) -> Result<EvalResult, EvalError> {
279    let context = EvalContext::from_slices(user_constants, &[]);
280    evaluate_with_workspace_and_context(expr, x, workspace, &context)
281}
282
283/// Evaluate an expression with user constants and user functions, using a reusable workspace.
284///
285/// This is the full hot-path version that avoids heap allocations.
286/// The `user_constants` slice provides values for `UserConstant0..15` symbols.
287/// The `user_functions` slice provides bodies for `UserFunction0..15` symbols.
288#[inline]
289pub fn evaluate_with_workspace_and_constants_and_functions(
290    expr: &Expression,
291    x: f64,
292    workspace: &mut EvalWorkspace,
293    user_constants: &[UserConstant],
294    user_functions: &[UserFunction],
295) -> Result<EvalResult, EvalError> {
296    let context = EvalContext::from_slices(user_constants, user_functions);
297    evaluate_with_workspace_and_context(expr, x, workspace, &context)
298}
299
300/// Evaluate an expression using an explicit evaluation context and reusable workspace.
301///
302/// This is the preferred hot-path API for library consumers that need explicit
303/// control over trig semantics or user-defined symbols.
304#[inline]
305pub fn evaluate_with_workspace_and_context(
306    expr: &Expression,
307    x: f64,
308    workspace: &mut EvalWorkspace,
309    context: &EvalContext<'_>,
310) -> Result<EvalResult, EvalError> {
311    workspace.clear();
312    let stack = &mut workspace.stack;
313
314    for &sym in expr.symbols() {
315        match sym.seft() {
316            Seft::A => {
317                let entry = eval_constant_with_user(sym, x, context.user_constants)?;
318                stack.push(entry);
319            }
320            Seft::B => {
321                // Check if this is a user function
322                if matches!(
323                    sym,
324                    Symbol::UserFunction0
325                        | Symbol::UserFunction1
326                        | Symbol::UserFunction2
327                        | Symbol::UserFunction3
328                        | Symbol::UserFunction4
329                        | Symbol::UserFunction5
330                        | Symbol::UserFunction6
331                        | Symbol::UserFunction7
332                        | Symbol::UserFunction8
333                        | Symbol::UserFunction9
334                        | Symbol::UserFunction10
335                        | Symbol::UserFunction11
336                        | Symbol::UserFunction12
337                        | Symbol::UserFunction13
338                        | Symbol::UserFunction14
339                        | Symbol::UserFunction15
340                ) {
341                    // Evaluate user function
342                    let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
343                    let result = eval_user_function(sym, a, context, x)?;
344                    stack.push(result);
345                } else {
346                    let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
347                    let result = eval_unary(sym, a, context.trig_argument_scale)?;
348                    stack.push(result);
349                }
350            }
351            Seft::C => {
352                let b = stack.pop().ok_or(EvalError::StackUnderflow)?;
353                let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
354                let result = eval_binary(sym, a, b)?;
355                stack.push(result);
356            }
357        }
358    }
359
360    if stack.len() != 1 {
361        return Err(EvalError::Invalid);
362    }
363
364    // SAFETY: len == 1 check above guarantees pop succeeds
365    let result = stack.pop().unwrap();
366
367    // Check for invalid results
368    if result.val.is_nan() || result.val.is_infinite() {
369        return Err(EvalError::Overflow);
370    }
371
372    Ok(EvalResult {
373        value: result.val,
374        derivative: result.deriv,
375        num_type: result.num_type,
376    })
377}
378
379/// Evaluate an expression at a given value of x.
380///
381/// Convenience wrapper that allocates a new workspace. For hot loops,
382/// prefer `evaluate_with_workspace()` with a reusable `EvalWorkspace`.
383///
384/// Note: This is a convenience API for library users. Internal code uses
385/// `evaluate_fast_with_constants_and_functions` for performance.
386pub fn evaluate(expr: &Expression, x: f64) -> Result<EvalResult, EvalError> {
387    evaluate_with_context(expr, x, &EvalContext::new())
388}
389
390/// Evaluate an expression at a given value of x with user constants.
391///
392/// Convenience wrapper that allocates a new workspace.
393pub fn evaluate_with_constants(
394    expr: &Expression,
395    x: f64,
396    user_constants: &[UserConstant],
397) -> Result<EvalResult, EvalError> {
398    let context = EvalContext::from_slices(user_constants, &[]);
399    evaluate_with_context(expr, x, &context)
400}
401
402/// Evaluate an expression at a given value of x with user constants and user functions.
403///
404/// Convenience wrapper that allocates a new workspace.
405pub fn evaluate_with_constants_and_functions(
406    expr: &Expression,
407    x: f64,
408    user_constants: &[UserConstant],
409    user_functions: &[UserFunction],
410) -> Result<EvalResult, EvalError> {
411    let context = EvalContext::from_slices(user_constants, user_functions);
412    evaluate_with_context(expr, x, &context)
413}
414
415/// Evaluate an expression at a given value of x with an explicit evaluation context.
416///
417/// Convenience wrapper that allocates a new workspace.
418pub fn evaluate_with_context(
419    expr: &Expression,
420    x: f64,
421    context: &EvalContext<'_>,
422) -> Result<EvalResult, EvalError> {
423    let mut workspace = EvalWorkspace::new();
424    evaluate_with_workspace_and_context(expr, x, &mut workspace, context)
425}
426
427/// Evaluate an expression using a thread-local workspace (zero allocations after warmup).
428///
429/// This is ideal for parallel code where each thread needs its own workspace.
430/// Note: This version does NOT support user constants. For user constants,
431/// use `evaluate_with_constants()` or `evaluate_with_workspace_and_constants()`.
432///
433/// Note: This is a convenience wrapper for the full `evaluate_fast_with_constants_and_functions`
434/// when you don't need user constants or functions. It's provided as a simpler API for common cases.
435#[inline]
436pub fn evaluate_fast(expr: &Expression, x: f64) -> Result<EvalResult, EvalError> {
437    evaluate_fast_with_context(expr, x, &EvalContext::new())
438}
439
440/// Evaluate an expression using a thread-local workspace with user constants.
441///
442/// Note: This uses a global thread-local storage, so it's not safe to call recursively
443/// with different user_constants. For recursive calls, use `evaluate_with_workspace_and_constants`.
444///
445/// Note: This is a convenience wrapper for the full `evaluate_fast_with_constants_and_functions`
446/// when you don't need user functions. It's provided as a simpler API for common cases.
447#[inline]
448pub fn evaluate_fast_with_constants(
449    expr: &Expression,
450    x: f64,
451    user_constants: &[UserConstant],
452) -> Result<EvalResult, EvalError> {
453    let context = EvalContext::from_slices(user_constants, &[]);
454    evaluate_fast_with_context(expr, x, &context)
455}
456
457/// Evaluate an expression using a thread-local workspace with user constants and user functions.
458///
459/// # Thread-Local Workspace
460///
461/// This function uses a `thread_local!` static to cache an `EvalWorkspace` for each thread.
462/// The workspace is created on first use and reused for all subsequent calls from the same thread.
463/// This provides zero-allocation evaluation after the initial warmup, making it ideal for:
464///
465/// - Parallel code where each thread needs its own workspace
466/// - Hot loops where allocation overhead matters
467/// - High-throughput evaluation scenarios
468///
469/// # Limitations
470///
471/// - This uses a global thread-local storage, so it's not safe to call recursively
472///   with different `user_constants` or `user_functions`. The same workspace is shared.
473/// - For recursive calls or when user constants/functions vary per-call,
474///   use [`evaluate_with_workspace_and_constants_and_functions`] instead.
475///
476/// # Example
477///
478/// ```no_run
479/// use ries_rs::eval::evaluate_fast_with_constants_and_functions;
480/// use ries_rs::expr::Expression;
481/// let expr = Expression::new();
482/// let x = 1.0_f64;
483/// // First call allocates workspace (warmup)
484/// let result = evaluate_fast_with_constants_and_functions(&expr, x, &[], &[]);
485///
486/// // Subsequent calls reuse the same workspace (no allocations)
487/// for _ in 0..1000 {
488///     let _ = evaluate_fast_with_constants_and_functions(&expr, x, &[], &[]);
489/// }
490/// ```
491#[inline]
492pub fn evaluate_fast_with_constants_and_functions(
493    expr: &Expression,
494    x: f64,
495    user_constants: &[UserConstant],
496    user_functions: &[UserFunction],
497) -> Result<EvalResult, EvalError> {
498    let context = EvalContext::from_slices(user_constants, user_functions);
499    evaluate_fast_with_context(expr, x, &context)
500}
501
502/// Evaluate an expression using a thread-local workspace and explicit context.
503#[inline]
504pub fn evaluate_fast_with_context(
505    expr: &Expression,
506    x: f64,
507    context: &EvalContext<'_>,
508) -> Result<EvalResult, EvalError> {
509    thread_local! {
510        /// Thread-local evaluation workspace.
511        ///
512        /// Each thread gets its own workspace instance that's lazily allocated
513        /// on first use. The workspace maintains internal Vec storage that grows
514        /// as needed but is never deallocated, providing zero-allocation hot paths.
515        static WORKSPACE: std::cell::RefCell<EvalWorkspace> = std::cell::RefCell::new(EvalWorkspace::new());
516    }
517
518    WORKSPACE.with(|ws| {
519        let mut workspace = ws.borrow_mut();
520        evaluate_with_workspace_and_context(expr, x, &mut workspace, context)
521    })
522}
523
524/// Evaluate a constant or variable symbol with user constant lookup.
525fn eval_constant_with_user(
526    sym: Symbol,
527    x: f64,
528    user_constants: &[UserConstant],
529) -> Result<StackEntry, EvalError> {
530    use Symbol::*;
531    match sym {
532        One => Ok(StackEntry::constant(1.0, NumType::Integer)),
533        Two => Ok(StackEntry::constant(2.0, NumType::Integer)),
534        Three => Ok(StackEntry::constant(3.0, NumType::Integer)),
535        Four => Ok(StackEntry::constant(4.0, NumType::Integer)),
536        Five => Ok(StackEntry::constant(5.0, NumType::Integer)),
537        Six => Ok(StackEntry::constant(6.0, NumType::Integer)),
538        Seven => Ok(StackEntry::constant(7.0, NumType::Integer)),
539        Eight => Ok(StackEntry::constant(8.0, NumType::Integer)),
540        Nine => Ok(StackEntry::constant(9.0, NumType::Integer)),
541        Pi => Ok(StackEntry::constant(constants::PI, NumType::Transcendental)),
542        E => Ok(StackEntry::constant(constants::E, NumType::Transcendental)),
543        Phi => Ok(StackEntry::constant(constants::PHI, NumType::Algebraic)),
544        // New constants
545        Gamma => Ok(StackEntry::constant(
546            constants::GAMMA,
547            NumType::Transcendental,
548        )),
549        Plastic => Ok(StackEntry::constant(constants::PLASTIC, NumType::Algebraic)),
550        Apery => Ok(StackEntry::constant(
551            constants::APERY,
552            NumType::Transcendental,
553        )),
554        Catalan => Ok(StackEntry::constant(
555            constants::CATALAN,
556            NumType::Transcendental,
557        )),
558        X => Ok(StackEntry::new(x, 1.0, NumType::Integer)), // x can be any value, including integer
559        // User constants - look up value from the user_constants slice
560        UserConstant0 | UserConstant1 | UserConstant2 | UserConstant3 | UserConstant4
561        | UserConstant5 | UserConstant6 | UserConstant7 | UserConstant8 | UserConstant9
562        | UserConstant10 | UserConstant11 | UserConstant12 | UserConstant13 | UserConstant14
563        | UserConstant15 => {
564            // Get the index from the symbol
565            let idx = sym.user_constant_index().unwrap() as usize;
566            user_constants
567                .get(idx)
568                .map(|uc| StackEntry::constant(uc.value, uc.num_type))
569                .ok_or(EvalError::MissingUserConstant(idx))
570        }
571        _ => Err(EvalError::Invalid),
572    }
573}
574
575/// Evaluate a user-defined function
576///
577/// Takes the input argument and the user_functions slice, looks up the function
578/// definition, executes the body, and returns the result.
579fn eval_user_function(
580    sym: Symbol,
581    input: StackEntry,
582    context: &EvalContext<'_>,
583    x: f64,
584) -> Result<StackEntry, EvalError> {
585    // Get the function index
586    let idx = sym.user_function_index().ok_or(EvalError::Invalid)? as usize;
587
588    // Look up the function definition
589    let udf = context.user_functions.get(idx).ok_or(EvalError::Invalid)?;
590
591    // Reuse a thread-local scratch buffer rather than allocating a fresh Vec on every
592    // call. eval_user_function is invoked in the inner generation loop (potentially
593    // millions of times at high complexity), so avoiding the heap allocation matters.
594    // UDFs do not call other UDFs, so the borrow is never re-entered.
595    thread_local! {
596        static UDF_STACK: std::cell::RefCell<Vec<StackEntry>> =
597            std::cell::RefCell::new(Vec::with_capacity(16));
598    }
599
600    UDF_STACK.with(|cell| -> Result<StackEntry, EvalError> {
601        let mut stack = cell.borrow_mut();
602        stack.clear();
603        stack.push(input);
604
605        // Execute each operation in the function body
606        for op in &udf.body {
607            match op {
608                UdfOp::Symbol(sym) => {
609                    match sym.seft() {
610                        Seft::A => {
611                            // Constant - push onto stack
612                            let entry = eval_constant_with_user(*sym, x, context.user_constants)?;
613                            stack.push(entry);
614                        }
615                        Seft::B => {
616                            // Unary operator - pop one, push result
617                            let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
618                            let result = eval_unary(*sym, a, context.trig_argument_scale)?;
619                            stack.push(result);
620                        }
621                        Seft::C => {
622                            // Binary operator - pop two, push result
623                            let b = stack.pop().ok_or(EvalError::StackUnderflow)?;
624                            let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
625                            let result = eval_binary(*sym, a, b)?;
626                            stack.push(result);
627                        }
628                    }
629                }
630                UdfOp::Dup => {
631                    // Duplicate top of stack. Dereference immediately so the
632                    // immutable borrow ends before the mutable push.
633                    let top = *stack.last().ok_or(EvalError::StackUnderflow)?;
634                    stack.push(top);
635                }
636                UdfOp::Swap => {
637                    // Swap top two elements
638                    let len = stack.len();
639                    if len < 2 {
640                        return Err(EvalError::StackUnderflow);
641                    }
642                    stack.swap(len - 1, len - 2);
643                }
644            }
645        }
646
647        // Function should leave exactly one value on the stack
648        if stack.len() != 1 {
649            return Err(EvalError::Invalid);
650        }
651
652        // SAFETY: len == 1 check above guarantees pop succeeds
653        let result = stack.pop().unwrap();
654
655        // Check for invalid results
656        if result.val.is_nan() || result.val.is_infinite() {
657            return Err(EvalError::Overflow);
658        }
659
660        Ok(result)
661    })
662}
663
664/// Evaluate a unary operator with derivative
665fn eval_unary(
666    sym: Symbol,
667    a: StackEntry,
668    trig_argument_scale: f64,
669) -> Result<StackEntry, EvalError> {
670    use Symbol::*;
671
672    let (val, deriv, num_type) = match sym {
673        // Negation: -a, d(-a)/dx = -da/dx
674        Neg => (-a.val, -a.deriv, a.num_type),
675
676        // Reciprocal: 1/a, d(1/a)/dx = -da/dx / a²
677        Recip => {
678            if a.val.abs() < f64::MIN_POSITIVE {
679                return Err(EvalError::DivisionByZero);
680            }
681            let val = 1.0 / a.val;
682            let deriv = -a.deriv / (a.val * a.val);
683            let num_type = if a.num_type == NumType::Integer {
684                NumType::Rational
685            } else {
686                a.num_type
687            };
688            (val, deriv, num_type)
689        }
690
691        // Square root: sqrt(a), d(sqrt(a))/dx = da/dx / (2*sqrt(a))
692        Sqrt => {
693            if a.val < 0.0 {
694                return Err(EvalError::SqrtDomain);
695            }
696            let val = a.val.sqrt();
697            let deriv = if val.abs() > f64::MIN_POSITIVE {
698                a.deriv / (2.0 * val)
699            } else {
700                0.0
701            };
702            let num_type = if a.num_type >= NumType::Constructible {
703                NumType::Constructible
704            } else {
705                a.num_type
706            };
707            (val, deriv, num_type)
708        }
709
710        // Square: a², d(a²)/dx = 2*a*da/dx
711        Square => {
712            let val = a.val * a.val;
713            let deriv = 2.0 * a.val * a.deriv;
714            (val, deriv, a.num_type)
715        }
716
717        // Natural log: ln(a), d(ln(a))/dx = da/dx / a
718        Ln => {
719            if a.val <= 0.0 {
720                return Err(EvalError::LogDomain);
721            }
722            let val = a.val.ln();
723            let deriv = a.deriv / a.val;
724            (val, deriv, NumType::Transcendental)
725        }
726
727        // Exponential: e^a, d(e^a)/dx = e^a * da/dx
728        Exp => {
729            let val = a.val.exp();
730            if val.is_infinite() {
731                return Err(EvalError::Overflow);
732            }
733            let deriv = val * a.deriv;
734            (val, deriv, NumType::Transcendental)
735        }
736
737        // sin(π*a), d(sin(πa))/dx = π*cos(πa)*da/dx
738        SinPi => {
739            let val = (trig_argument_scale * a.val).sin();
740            let deriv = trig_argument_scale * (trig_argument_scale * a.val).cos() * a.deriv;
741            (val, deriv, NumType::Transcendental)
742        }
743
744        // cos(π*a), d(cos(πa))/dx = -π*sin(πa)*da/dx
745        CosPi => {
746            let val = (trig_argument_scale * a.val).cos();
747            let deriv = -trig_argument_scale * (trig_argument_scale * a.val).sin() * a.deriv;
748            (val, deriv, NumType::Transcendental)
749        }
750
751        // tan(π*a), d(tan(πa))/dx = π*sec²(πa)*da/dx
752        TanPi => {
753            let cos_val = (trig_argument_scale * a.val).cos();
754            if cos_val.abs() < 1e-10 {
755                return Err(EvalError::Overflow);
756            }
757            let val = (trig_argument_scale * a.val).tan();
758            let deriv = trig_argument_scale * a.deriv / (cos_val * cos_val);
759            (val, deriv, NumType::Transcendental)
760        }
761
762        // Lambert W function (principal branch)
763        LambertW => {
764            let val = lambert_w(a.val)?;
765            // d(W(a))/dx = W(a) / (a * (1 + W(a))) * da/dx
766            // Special case: W'(0) = 1 (by L'Hôpital's rule, since W(x) ≈ x near 0)
767            let deriv = if a.val.abs() < 1e-10 {
768                a.deriv // W'(0) = 1
769            } else {
770                let denom = a.val * (1.0 + val);
771                if denom.abs() > f64::MIN_POSITIVE {
772                    val / denom * a.deriv
773                } else {
774                    0.0
775                }
776            };
777            (val, deriv, NumType::Transcendental)
778        }
779
780        // User functions are handled at the main evaluation loop level, not here
781        // If we reach this point, return an error
782        UserFunction0 | UserFunction1 | UserFunction2 | UserFunction3 | UserFunction4
783        | UserFunction5 | UserFunction6 | UserFunction7 | UserFunction8 | UserFunction9
784        | UserFunction10 | UserFunction11 | UserFunction12 | UserFunction13 | UserFunction14
785        | UserFunction15 => {
786            // This indicates a bug in the evaluation loop - user functions should be
787            // handled before calling eval_unary
788            return Err(EvalError::Invalid);
789        }
790
791        // Non-unary symbols should never be passed to this function
792        _ => return Err(EvalError::Invalid),
793    };
794
795    Ok(StackEntry::new(val, deriv, num_type))
796}
797
798/// Evaluate a binary operator with derivative
799fn eval_binary(sym: Symbol, a: StackEntry, b: StackEntry) -> Result<StackEntry, EvalError> {
800    use Symbol::*;
801
802    let (val, deriv, num_type) = match sym {
803        // Addition: a + b
804        Add => {
805            let val = a.val + b.val;
806            let deriv = a.deriv + b.deriv;
807            let num_type = a.num_type.combine(b.num_type);
808            (val, deriv, num_type)
809        }
810
811        // Subtraction: a - b
812        Sub => {
813            let val = a.val - b.val;
814            let deriv = a.deriv - b.deriv;
815            let num_type = a.num_type.combine(b.num_type);
816            (val, deriv, num_type)
817        }
818
819        // Multiplication: a * b, d(ab)/dx = a*db/dx + b*da/dx
820        Mul => {
821            let val = a.val * b.val;
822            let deriv = a.val * b.deriv + b.val * a.deriv;
823            let num_type = a.num_type.combine(b.num_type);
824            (val, deriv, num_type)
825        }
826
827        // Division: a / b, d(a/b)/dx = (b*da/dx - a*db/dx) / b²
828        Div => {
829            if b.val.abs() < f64::MIN_POSITIVE {
830                return Err(EvalError::DivisionByZero);
831            }
832            let val = a.val / b.val;
833            let deriv = (b.val * a.deriv - a.val * b.deriv) / (b.val * b.val);
834            let mut num_type = a.num_type.combine(b.num_type);
835            if num_type == NumType::Integer {
836                num_type = NumType::Rational;
837            }
838            (val, deriv, num_type)
839        }
840
841        // Power: a^b, d(a^b)/dx = a^b * (b*da/dx/a + ln(a)*db/dx)
842        Pow => {
843            if a.val <= 0.0 && b.val.fract() != 0.0 {
844                return Err(EvalError::SqrtDomain);
845            }
846            let val = a.val.powf(b.val);
847            if val.is_infinite() || val.is_nan() {
848                return Err(EvalError::Overflow);
849            }
850            // Guard for near-zero base to avoid numerical issues
851            let deriv = if a.val > f64::MIN_POSITIVE {
852                val * (b.val * a.deriv / a.val + a.val.ln() * b.deriv)
853            } else if a.val.abs() < f64::MIN_POSITIVE && b.val > 0.0 {
854                0.0
855            } else {
856                // Negative base, integer exponent (or near-zero base treated as 0).
857                // Full formula: val * (b * a.deriv/a + ln(a) * b.deriv).
858                // The ln(a) * b.deriv term is intentionally dropped here: ln(negative) is
859                // undefined in the reals (NaN), so it cannot contribute to Newton-Raphson.
860                // Dropping it gives 0 for the derivative w.r.t. x-in-the-exponent path,
861                // which is the correct safe fallback when x appears in the exponent of a
862                // negative base (e.g., (-2)^x is only real-valued at integer x).
863                if a.val.abs() < f64::MIN_POSITIVE {
864                    0.0
865                } else {
866                    val * b.val * a.deriv / a.val
867                }
868            };
869            let num_type = if b.num_type == NumType::Integer {
870                a.num_type
871            } else {
872                NumType::Transcendental
873            };
874            (val, deriv, num_type)
875        }
876
877        // a-th root of b: b^(1/a)
878        Root => {
879            if a.val.abs() < f64::MIN_POSITIVE {
880                return Err(EvalError::DivisionByZero);
881            }
882            let exp = 1.0 / a.val;
883
884            // For negative radicands, we need to check if the index is an odd integer
885            // Non-integer indices of negative numbers have no real value
886            if b.val < 0.0 {
887                // Check if the index is close to an integer
888                let rounded = a.val.round();
889                let is_integer = (a.val - rounded).abs() < 1e-10;
890
891                if !is_integer {
892                    // Non-integer index of negative number - no real value
893                    return Err(EvalError::SqrtDomain);
894                }
895
896                // Check if the integer is odd (odd roots of negatives are real)
897                let int_val = rounded as i64;
898                if int_val % 2 == 0 {
899                    // Even integer root of negative - no real value
900                    return Err(EvalError::SqrtDomain);
901                }
902                // Odd integer root of negative is OK
903            }
904
905            let val = if b.val < 0.0 {
906                // Odd root of negative number
907                -((-b.val).powf(exp))
908            } else {
909                b.val.powf(exp)
910            };
911            if val.is_infinite() || val.is_nan() {
912                return Err(EvalError::Overflow);
913            }
914            // d(b^(1/a))/dx = b^(1/a) * (db/dx/(a*b) - ln(b)*da/dx/a²)
915            let deriv = if b.val.abs() > f64::MIN_POSITIVE {
916                val * (b.deriv / (a.val * b.val) - b.val.abs().ln() * a.deriv / (a.val * a.val))
917            } else {
918                0.0
919            };
920            (val, deriv, NumType::Algebraic)
921        }
922
923        // Logarithm base a of b: ln(b) / ln(a)
924        Log => {
925            if a.val <= 0.0 || a.val == 1.0 || b.val <= 0.0 {
926                return Err(EvalError::LogDomain);
927            }
928            let ln_a = a.val.ln();
929            let ln_b = b.val.ln();
930            let val = ln_b / ln_a;
931            // d(log_a(b))/dx = (db/dx/(b*ln(a)) - ln(b)*da/dx/(a*ln(a)²))
932            let deriv = b.deriv / (b.val * ln_a) - ln_b * a.deriv / (a.val * ln_a * ln_a);
933            (val, deriv, NumType::Transcendental)
934        }
935
936        // atan2(a, b) = angle of point (b, a) from origin
937        Atan2 => {
938            let val = a.val.atan2(b.val);
939            // d(atan2(a,b))/dx = (b*da/dx - a*db/dx) / (a² + b²)
940            let denom = a.val * a.val + b.val * b.val;
941            let deriv = if denom.abs() > f64::MIN_POSITIVE {
942                (b.val * a.deriv - a.val * b.deriv) / denom
943            } else {
944                0.0
945            };
946            (val, deriv, NumType::Transcendental)
947        }
948
949        // Non-binary symbols should never be passed to this function
950        _ => return Err(EvalError::Invalid),
951    };
952
953    Ok(StackEntry::new(val, deriv, num_type))
954}
955
956/// Compute the Lambert W function (principal branch) using Halley's method
957///
958/// The Lambert W function satisfies W(x) * exp(W(x)) = x.
959/// This implementation handles the principal branch (W₀) for x ≥ -1/e.
960fn lambert_w(x: f64) -> Result<f64, EvalError> {
961    // Branch point: x = -1/e gives W = -1
962    const INV_E: f64 = 1.0 / std::f64::consts::E;
963    const NEG_INV_E: f64 = -INV_E; // -0.36787944117144233...
964
965    // Domain check
966    if x < NEG_INV_E {
967        return Err(EvalError::LogDomain);
968    }
969
970    // Special cases
971    if x == 0.0 {
972        return Ok(0.0); // W(0) = 0
973    }
974    if (x - NEG_INV_E).abs() < 1e-15 {
975        return Ok(-1.0); // W(-1/e) = -1
976    }
977    if x == constants::E {
978        return Ok(1.0); // W(e) = 1
979    }
980
981    // Initial guess - different approximations for different regimes
982    let mut w = if x < -0.3 {
983        // Near the branch point, use a series expansion around -1/e
984        // W(x) ≈ -1 + p - p²/3 + 11p³/72 where p = sqrt(2(ex + 1))
985        let p = (2.0 * (constants::E * x + 1.0)).sqrt();
986        -1.0 + p * (1.0 - p / 3.0 * (1.0 - 11.0 * p / 72.0))
987    } else if x < 0.25 {
988        // Near zero, use a polynomial approximation
989        // W(x) ≈ x - x² + 3x³/2 - 8x⁴/3 + ...
990        // For numerical stability, use a rational approximation
991        let x2 = x * x;
992        x * (1.0 - x + x2 * (1.5 - 2.6667 * x))
993    } else if x < 4.0 {
994        // Moderate range: use log-based approximation
995        // W(x) ≈ ln(x) - ln(ln(x)) + ln(ln(x))/ln(x)
996        let lnx = x.ln();
997        if lnx > 0.0 {
998            let lnlnx = lnx.ln().max(0.0);
999            lnx - lnlnx + lnlnx / lnx.max(1.0)
1000        } else {
1001            x // fallback for x near 1
1002        }
1003    } else {
1004        // Large x: W(x) ≈ ln(x) - ln(ln(x)) + ln(ln(x))/ln(x)
1005        let l1 = x.ln();
1006        let l2 = l1.ln();
1007        l1 - l2 + l2 / l1
1008    };
1009
1010    // Halley's method iteration
1011    // For well-chosen initial guesses, 10-15 iterations are usually enough
1012    for _ in 0..25 {
1013        let ew = w.exp();
1014
1015        // Handle potential overflow
1016        if !ew.is_finite() {
1017            // Back off to a more stable approach
1018            w = x.ln() - w.ln().max(1e-10);
1019            continue;
1020        }
1021
1022        let wew = w * ew;
1023        let diff = wew - x;
1024
1025        // Convergence check with relative tolerance
1026        let tol = 1e-15 * (1.0 + w.abs().max(x.abs()));
1027        if diff.abs() < tol {
1028            break;
1029        }
1030
1031        let w1 = w + 1.0;
1032        // Halley's correction
1033        let denom = ew * w1 - 0.5 * (w + 2.0) * diff / w1;
1034        if denom.abs() < f64::MIN_POSITIVE {
1035            break;
1036        }
1037
1038        let delta = diff / denom;
1039
1040        // Damping for stability near branch point
1041        let correction = if w < -0.5 && delta.abs() > 0.5 {
1042            delta * 0.5 // Damped update near branch point
1043        } else {
1044            delta
1045        };
1046
1047        w -= correction;
1048    }
1049
1050    // Final validation
1051    if !w.is_finite() {
1052        return Err(EvalError::Overflow);
1053    }
1054
1055    Ok(w)
1056}
1057
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within 1e-10 of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    #[test]
    fn test_basic_eval() {
        // 3 + 2 in postfix; a constant expression has zero slope.
        let sum = evaluate(&Expression::parse("32+").unwrap(), 0.0).unwrap();
        assert!(close(sum.value, 5.0));
        assert!(close(sum.derivative, 0.0));
    }

    #[test]
    fn test_variable() {
        // The bare variable evaluates to x itself with slope 1.
        let ident = evaluate(&Expression::parse("x").unwrap(), 3.5).unwrap();
        assert!(close(ident.value, 3.5));
        assert!(close(ident.derivative, 1.0));
    }

    #[test]
    fn test_x_squared() {
        // "xs" is x^2: at x = 3 the value is 9 and the slope 2x = 6.
        let square = evaluate(&Expression::parse("xs").unwrap(), 3.0).unwrap();
        assert!(close(square.value, 9.0));
        assert!(close(square.derivative, 6.0));
    }

    #[test]
    fn test_sqrt_pi() {
        // "pq" is sqrt(pi).
        let root_pi = evaluate(&Expression::parse("pq").unwrap(), 0.0).unwrap();
        assert!(close(root_pi.value, constants::PI.sqrt()));
    }

    #[test]
    fn test_e_to_x() {
        // "xE" is e^x, which is its own derivative.
        let exp_one = evaluate(&Expression::parse("xE").unwrap(), 1.0).unwrap();
        assert!(close(exp_one.value, constants::E));
        assert!(close(exp_one.derivative, constants::E));
    }

    #[test]
    fn test_complex_expr() {
        // x^2 + 2x + 1 = (x + 1)^2; at x = 3 the value is 16 and the slope 2x + 2 = 8.
        let poly = evaluate(&Expression::parse("xs2x*+1+").unwrap(), 3.0).unwrap();
        assert!(close(poly.value, 16.0));
        assert!(close(poly.derivative, 8.0));
    }

    #[test]
    fn test_lambert_w() {
        // The omega constant: W(1) ≈ 0.5671432904
        let omega = lambert_w(1.0).unwrap();
        assert!((omega - 0.5671432904).abs() < 1e-9);

        // W(e) = 1 because 1 * e^1 = e
        assert!(close(lambert_w(constants::E).unwrap(), 1.0));
    }

    #[test]
    fn test_user_constant_evaluation() {
        use crate::profile::UserConstant;

        // Slot u0 holds the Euler-Mascheroni constant gamma ≈ 0.57721
        let gamma = UserConstant {
            weight: 8,
            name: "g".to_string(),
            description: "gamma".to_string(),
            value: 0.5772156649,
            num_type: NumType::Transcendental,
        };
        let user_constants = vec![gamma];

        // UserConstant0 is symbol byte 128
        let expr = Expression::from_symbols(&[Symbol::UserConstant0]);
        let outcome = evaluate_with_constants(&expr, 0.0, &user_constants).unwrap();

        // The value comes straight from the slot, and a constant has zero slope.
        assert!(close(outcome.value, 0.5772156649));
        assert!(close(outcome.derivative, 0.0));
    }

    #[test]
    fn test_user_constant_in_expression() {
        use crate::profile::UserConstant;

        // Both slots share weight and type; only name, description and value differ.
        let make_integer = |name: &str, description: &str, value: f64| UserConstant {
            weight: 8,
            name: name.to_string(),
            description: description.to_string(),
            value,
            num_type: NumType::Integer,
        };
        let user_constants = vec![
            make_integer("a", "constant a", 2.0),
            make_integer("b", "constant b", 3.0),
        ];

        // Postfix "u0 x * u1 +" is u0 * x + u1.
        let affine = Expression::from_symbols(&[
            Symbol::UserConstant0,
            Symbol::X,
            Symbol::Mul,
            Symbol::UserConstant1,
            Symbol::Add,
        ]);

        // At x = 4: 2 * 4 + 3 = 11, with slope u0 = 2.
        let outcome = evaluate_with_constants(&affine, 4.0, &user_constants).unwrap();
        assert!(close(outcome.value, 11.0));
        assert!(close(outcome.derivative, 2.0));
    }

    #[test]
    fn test_user_constant_missing_returns_error() {
        // An unconfigured slot must be reported rather than silently
        // reinterpreted as some other value.
        let lone_slot = Expression::from_symbols(&[Symbol::UserConstant0]);
        let outcome = evaluate_with_constants(&lone_slot, 0.0, &[]);
        assert!(matches!(outcome, Err(EvalError::MissingUserConstant(0))));
    }

    #[test]
    fn test_user_function_sinh() {
        use crate::udf::UserFunction;

        // sinh(x) = (e^x - e^-x) / 2
        // In postfix: E|r-2/ (exp, dup, recip, subtract, 2, divide)
        let user_functions = vec![UserFunction::parse("4:sinh:hyperbolic sine:E|r-2/").unwrap()];

        // x F0, where F0 = UserFunction0
        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);
        let outcome =
            evaluate_with_constants_and_functions(&expr, 1.0, &[], &user_functions).unwrap();

        // sinh(1) = (e - 1/e) / 2 ≈ 1.1752, and its slope is cosh(1) = (e + 1/e) / 2.
        let (e, inv_e) = (constants::E, 1.0 / constants::E);
        assert!(close(outcome.value, (e - inv_e) / 2.0));
        assert!(close(outcome.derivative, (e + inv_e) / 2.0));
    }

    #[test]
    fn test_user_function_xex() {
        use crate::udf::UserFunction;

        // XeX(x) = x * e^x
        // In postfix: |E* (dup, exp, multiply)
        let user_functions = vec![UserFunction::parse("4:XeX:x*exp(x):|E*").unwrap()];

        // x F0, where F0 = UserFunction0
        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);
        let outcome =
            evaluate_with_constants_and_functions(&expr, 1.0, &[], &user_functions).unwrap();

        // XeX(1) = e; the slope e^x * (1 + x) at x = 1 is 2e.
        assert!(close(outcome.value, constants::E));
        assert!(close(outcome.derivative, 2.0 * constants::E));
    }

    #[test]
    fn test_user_function_missing_returns_error() {
        // Referencing F0 with no user functions supplied must fail.
        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);
        assert!(evaluate_with_constants_and_functions(&expr, 1.0, &[], &[]).is_err());
    }
}