use crate::expr::Expression;
use crate::profile::UserConstant;
use crate::symbol::{NumType, Seft, Symbol};
use crate::udf::{UdfOp, UserFunction};
/// Outcome of evaluating an [`Expression`] at a single point `x`.
///
/// Produced by the `evaluate*` family of functions in this module.
#[derive(Debug, Clone, Copy)]
pub struct EvalResult {
/// Value of the expression at `x`.
pub value: f64,
/// Derivative of the expression with respect to `x`, propagated alongside the
/// value through every operator (forward-mode differentiation seeded by `x`).
pub derivative: f64,
/// Number-class tag (e.g. integer, rational, algebraic, transcendental)
/// tracked through the evaluation.
pub num_type: NumType,
}
/// Errors that can occur while evaluating an expression.
///
/// The `With*` variants wrap another `EvalError` to attach extra context; they
/// are built by [`EvalError::with_context`] and [`EvalError::with_expression`].
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum EvalError {
/// An operator (or a user-function `Dup`/`Swap`) needed more operands than
/// were on the stack.
#[error("Stack underflow: not enough operands on stack")]
StackUnderflow,
/// A user-constant symbol referred to a slot with no entry in the context.
#[error("Missing user constant: slot u{0} is not configured")]
MissingUserConstant(usize),
/// Reciprocal or division by a value whose magnitude is below
/// `f64::MIN_POSITIVE`; also raised for a zero root index.
#[error("Division by zero: divisor was zero or near-zero")]
DivisionByZero,
/// Logarithm-style domain violation: non-positive argument, a logarithm base
/// of 1, or a Lambert W argument below -1/e.
#[error("Logarithm domain error: argument was non-positive")]
LogDomain,
/// Square-root-style domain violation: square root of a negative value, a
/// non-positive base with a fractional exponent, or an even/non-integer root
/// of a negative radicand.
#[error("Square root domain error: argument was negative")]
SqrtDomain,
/// A value became infinite or NaN (overall result, `exp`, `pow`, root, a
/// `tan` pole, or a failed Lambert W iteration).
#[error("Overflow: result is infinite or NaN")]
Overflow,
/// Malformed expression: unknown symbol for its arity, unconfigured user
/// function, or a stack that does not end with exactly one value.
#[error("Invalid expression: malformed or incomplete")]
Invalid,
/// Wraps an error with a position supplied by the caller
/// (presumably a symbol index — confirm at call sites).
#[error("{err} at position {pos}")]
WithPosition {
/// The underlying error.
#[source]
err: Box<EvalError>,
/// Caller-supplied position.
pos: usize,
},
/// Wraps an error with the numeric value involved in it.
#[error("{err} (value: {val})")]
WithValue {
/// The underlying error.
#[source]
err: Box<EvalError>,
/// Offending value; `OrderedFloat` keeps the enum `Eq`.
val: ordered_float::OrderedFloat<f64>,
},
/// Wraps an error with the text of the expression being evaluated.
#[error("{err} in expression '{expr}'")]
WithExpression {
/// The underlying error.
#[source]
err: Box<EvalError>,
/// Expression text supplied by the caller.
expr: String,
},
}
impl EvalError {
    /// Attaches optional location and value details to this error.
    ///
    /// A `Some(position)` wraps the error in [`EvalError::WithPosition`]. A
    /// `Some(value)` then wraps that result in [`EvalError::WithValue`], so
    /// the value is the outermost layer. With both `None`, `self` is returned
    /// unchanged.
    pub fn with_context(self, position: Option<usize>, value: Option<f64>) -> Self {
        let located = match position {
            Some(pos) => EvalError::WithPosition {
                err: Box::new(self),
                pos,
            },
            None => self,
        };
        match value {
            Some(val) => EvalError::WithValue {
                err: Box::new(located),
                val: ordered_float::OrderedFloat(val),
            },
            None => located,
        }
    }

    /// Wraps this error together with the text of the expression that produced it.
    pub fn with_expression(self, expr: String) -> Self {
        let inner = Box::new(self);
        EvalError::WithExpression { err: inner, expr }
    }
}
/// Named mathematical constants that expressions can reference.
pub mod constants {
    use std::f64::consts as std_consts;

    /// Archimedes' constant, π.
    pub const PI: f64 = std_consts::PI;
    /// Euler's number, e.
    pub const E: f64 = std_consts::E;
    /// Golden ratio, φ = (1 + √5) / 2.
    pub const PHI: f64 = 1.618_033_988_749_895;
    /// Euler–Mascheroni constant, γ.
    pub const GAMMA: f64 = 0.577_215_664_901_532_9;
    /// Plastic number, the real root of t³ = t + 1.
    pub const PLASTIC: f64 = 1.324_717_957_244_746;
    /// Apéry's constant, ζ(3).
    pub const APERY: f64 = 1.202_056_903_159_594_2;
    /// Catalan's constant, G.
    pub const CATALAN: f64 = 0.915_965_594_177_219;
}
/// Default multiplier for trigonometric arguments, so `SinPi(a)` = `sin(π·a)`.
pub const DEFAULT_TRIG_ARGUMENT_SCALE: f64 = std::f64::consts::PI;

/// Read-only configuration consulted by every evaluation step.
#[derive(Clone, Copy, Debug)]
pub struct EvalContext<'a> {
    /// Factor multiplied into the argument of `SinPi`, `CosPi` and `TanPi`.
    pub trig_argument_scale: f64,
    /// Values for the user-constant slots, indexed by slot number.
    pub user_constants: &'a [UserConstant],
    /// Definitions for the user-function slots, indexed by slot number.
    pub user_functions: &'a [UserFunction],
}

impl Default for EvalContext<'static> {
    /// Context with the default trig scale and no user constants or functions.
    fn default() -> Self {
        EvalContext::from_slices(&[], &[])
    }
}

impl EvalContext<'static> {
    /// Equivalent to [`EvalContext::default`].
    pub fn new() -> Self {
        Default::default()
    }
}

impl<'a> EvalContext<'a> {
    /// Builds a context that borrows the given user constants and functions and
    /// uses [`DEFAULT_TRIG_ARGUMENT_SCALE`].
    pub fn from_slices(
        user_constants: &'a [UserConstant],
        user_functions: &'a [UserFunction],
    ) -> Self {
        EvalContext {
            user_constants,
            user_functions,
            trig_argument_scale: DEFAULT_TRIG_ARGUMENT_SCALE,
        }
    }

    /// Returns this context with `scale` as the trigonometric argument multiplier.
    ///
    /// A non-finite or zero `scale` is ignored and the current scale is kept.
    pub fn with_trig_argument_scale(self, scale: f64) -> Self {
        if !scale.is_finite() || scale == 0.0 {
            return self;
        }
        EvalContext {
            trig_argument_scale: scale,
            ..self
        }
    }
}
/// One operand on the evaluation stack: a value, its derivative with respect
/// to `x`, and its number-class tag.
#[derive(Debug, Clone, Copy)]
struct StackEntry {
    val: f64,
    deriv: f64,
    num_type: NumType,
}

impl StackEntry {
    /// Entry with an explicit derivative.
    fn new(val: f64, deriv: f64, num_type: NumType) -> Self {
        StackEntry {
            val,
            deriv,
            num_type,
        }
    }

    /// Entry for a quantity independent of `x` (derivative zero).
    fn constant(val: f64, num_type: NumType) -> Self {
        Self::new(val, 0.0, num_type)
    }
}
/// Reusable scratch space for evaluation, so repeated calls can reuse one
/// operand-stack allocation instead of allocating each time.
pub struct EvalWorkspace {
    stack: Vec<StackEntry>,
}

impl EvalWorkspace {
    /// Number of operand slots reserved up front.
    const INITIAL_STACK_CAPACITY: usize = 32;

    /// Creates a workspace with a pre-reserved operand stack.
    pub fn new() -> Self {
        let stack = Vec::with_capacity(Self::INITIAL_STACK_CAPACITY);
        EvalWorkspace { stack }
    }

    /// Empties the operand stack while keeping its allocation.
    #[inline]
    fn clear(&mut self) {
        self.stack.clear()
    }
}

impl Default for EvalWorkspace {
    fn default() -> Self {
        EvalWorkspace::new()
    }
}
/// Evaluates `expr` at `x` in `workspace` with the default context
/// (no user constants or functions, default trig scale).
#[inline]
pub fn evaluate_with_workspace(
    expr: &Expression,
    x: f64,
    workspace: &mut EvalWorkspace,
) -> Result<EvalResult, EvalError> {
    let context = EvalContext::default();
    evaluate_with_workspace_and_context(expr, x, workspace, &context)
}

/// Evaluates `expr` at `x` in `workspace` with the given user constants and
/// no user functions.
#[inline]
pub fn evaluate_with_workspace_and_constants(
    expr: &Expression,
    x: f64,
    workspace: &mut EvalWorkspace,
    user_constants: &[UserConstant],
) -> Result<EvalResult, EvalError> {
    evaluate_with_workspace_and_constants_and_functions(expr, x, workspace, user_constants, &[])
}

/// Evaluates `expr` at `x` in `workspace` with the given user constants and
/// user functions and the default trig scale.
#[inline]
pub fn evaluate_with_workspace_and_constants_and_functions(
    expr: &Expression,
    x: f64,
    workspace: &mut EvalWorkspace,
    user_constants: &[UserConstant],
    user_functions: &[UserFunction],
) -> Result<EvalResult, EvalError> {
    evaluate_with_workspace_and_context(
        expr,
        x,
        workspace,
        &EvalContext::from_slices(user_constants, user_functions),
    )
}
/// Evaluates the postfix (RPN) expression `expr` at `x`, returning its value,
/// its derivative with respect to `x`, and its number-class tag.
///
/// `workspace` is cleared first, so leftovers from a previous call never leak in.
///
/// # Errors
/// - [`EvalError::StackUnderflow`] if an operator finds too few operands.
/// - [`EvalError::Invalid`] if the expression does not leave exactly one value.
/// - [`EvalError::Overflow`] if the final value is NaN or infinite.
/// - Any error raised by an individual operator or user function.
#[inline]
pub fn evaluate_with_workspace_and_context(
    expr: &Expression,
    x: f64,
    workspace: &mut EvalWorkspace,
    context: &EvalContext<'_>,
) -> Result<EvalResult, EvalError> {
    workspace.clear();
    let operands = &mut workspace.stack;

    for &sym in expr.symbols() {
        let produced = match sym.seft() {
            // Nullary: digits, named constants, `x`, user constants.
            Seft::A => eval_constant_with_user(sym, x, context.user_constants)?,
            // Unary: user-defined functions are routed separately from built-ins.
            Seft::B => {
                let arg = operands.pop().ok_or(EvalError::StackUnderflow)?;
                let is_user_function = matches!(
                    sym,
                    Symbol::UserFunction0
                        | Symbol::UserFunction1
                        | Symbol::UserFunction2
                        | Symbol::UserFunction3
                        | Symbol::UserFunction4
                        | Symbol::UserFunction5
                        | Symbol::UserFunction6
                        | Symbol::UserFunction7
                        | Symbol::UserFunction8
                        | Symbol::UserFunction9
                        | Symbol::UserFunction10
                        | Symbol::UserFunction11
                        | Symbol::UserFunction12
                        | Symbol::UserFunction13
                        | Symbol::UserFunction14
                        | Symbol::UserFunction15
                );
                if is_user_function {
                    eval_user_function(sym, arg, context, x)?
                } else {
                    eval_unary(sym, arg, context.trig_argument_scale)?
                }
            }
            // Binary: the right-hand operand sits on top of the stack.
            Seft::C => {
                let rhs = operands.pop().ok_or(EvalError::StackUnderflow)?;
                let lhs = operands.pop().ok_or(EvalError::StackUnderflow)?;
                eval_binary(sym, lhs, rhs)?
            }
        };
        operands.push(produced);
    }

    // A well-formed expression leaves exactly one operand behind.
    let top = match operands.as_slice() {
        [only] => *only,
        _ => return Err(EvalError::Invalid),
    };
    operands.clear();

    if !top.val.is_finite() {
        return Err(EvalError::Overflow);
    }
    Ok(EvalResult {
        value: top.val,
        derivative: top.deriv,
        num_type: top.num_type,
    })
}
/// Evaluates `expr` at `x` with the default context, using a freshly
/// allocated workspace.
pub fn evaluate(expr: &Expression, x: f64) -> Result<EvalResult, EvalError> {
    evaluate_with_constants_and_functions(expr, x, &[], &[])
}

/// Evaluates `expr` at `x` with the given user constants and no user functions.
pub fn evaluate_with_constants(
    expr: &Expression,
    x: f64,
    user_constants: &[UserConstant],
) -> Result<EvalResult, EvalError> {
    evaluate_with_constants_and_functions(expr, x, user_constants, &[])
}

/// Evaluates `expr` at `x` with the given user constants and user functions.
pub fn evaluate_with_constants_and_functions(
    expr: &Expression,
    x: f64,
    user_constants: &[UserConstant],
    user_functions: &[UserFunction],
) -> Result<EvalResult, EvalError> {
    evaluate_with_context(
        expr,
        x,
        &EvalContext::from_slices(user_constants, user_functions),
    )
}

/// Evaluates `expr` at `x` under `context`, allocating a new workspace for
/// this call only.
pub fn evaluate_with_context(
    expr: &Expression,
    x: f64,
    context: &EvalContext<'_>,
) -> Result<EvalResult, EvalError> {
    evaluate_with_workspace_and_context(expr, x, &mut EvalWorkspace::default(), context)
}
/// Like [`evaluate`], but reuses a per-thread workspace instead of allocating.
#[inline]
pub fn evaluate_fast(expr: &Expression, x: f64) -> Result<EvalResult, EvalError> {
    evaluate_fast_with_context(expr, x, &EvalContext::default())
}

/// Like [`evaluate_with_constants`], but reuses a per-thread workspace.
#[inline]
pub fn evaluate_fast_with_constants(
    expr: &Expression,
    x: f64,
    user_constants: &[UserConstant],
) -> Result<EvalResult, EvalError> {
    evaluate_fast_with_constants_and_functions(expr, x, user_constants, &[])
}

/// Like [`evaluate_with_constants_and_functions`], but reuses a per-thread workspace.
#[inline]
pub fn evaluate_fast_with_constants_and_functions(
    expr: &Expression,
    x: f64,
    user_constants: &[UserConstant],
    user_functions: &[UserFunction],
) -> Result<EvalResult, EvalError> {
    let ctx = EvalContext::from_slices(user_constants, user_functions);
    evaluate_fast_with_context(expr, x, &ctx)
}

/// Like [`evaluate_with_context`], but reuses a per-thread workspace.
///
/// # Panics
/// The workspace lives in a `RefCell`, so re-entering this function on the
/// same thread while an evaluation is in progress would panic on the borrow.
#[inline]
pub fn evaluate_fast_with_context(
    expr: &Expression,
    x: f64,
    context: &EvalContext<'_>,
) -> Result<EvalResult, EvalError> {
    thread_local! {
        // One workspace per thread, reused across calls.
        static WORKSPACE: std::cell::RefCell<EvalWorkspace> =
            std::cell::RefCell::new(EvalWorkspace::new());
    }
    WORKSPACE.with(|cell| {
        evaluate_with_workspace_and_context(expr, x, &mut cell.borrow_mut(), context)
    })
}
/// Produces the stack entry for a nullary symbol: a digit, a named constant,
/// the variable `x`, or a user-constant slot.
///
/// # Errors
/// - [`EvalError::MissingUserConstant`] if a user-constant slot has no entry in
///   `user_constants`.
/// - [`EvalError::Invalid`] for any other symbol.
fn eval_constant_with_user(
    sym: Symbol,
    x: f64,
    user_constants: &[UserConstant],
) -> Result<StackEntry, EvalError> {
    use Symbol::*;
    let (value, kind) = match sym {
        One => (1.0, NumType::Integer),
        Two => (2.0, NumType::Integer),
        Three => (3.0, NumType::Integer),
        Four => (4.0, NumType::Integer),
        Five => (5.0, NumType::Integer),
        Six => (6.0, NumType::Integer),
        Seven => (7.0, NumType::Integer),
        Eight => (8.0, NumType::Integer),
        Nine => (9.0, NumType::Integer),
        Pi => (constants::PI, NumType::Transcendental),
        E => (constants::E, NumType::Transcendental),
        Phi => (constants::PHI, NumType::Algebraic),
        Gamma => (constants::GAMMA, NumType::Transcendental),
        Plastic => (constants::PLASTIC, NumType::Algebraic),
        Apery => (constants::APERY, NumType::Transcendental),
        Catalan => (constants::CATALAN, NumType::Transcendental),
        // The variable is the only leaf that seeds a derivative: dx/dx = 1.
        X => return Ok(StackEntry::new(x, 1.0, NumType::Integer)),
        UserConstant0 | UserConstant1 | UserConstant2 | UserConstant3 | UserConstant4
        | UserConstant5 | UserConstant6 | UserConstant7 | UserConstant8 | UserConstant9
        | UserConstant10 | UserConstant11 | UserConstant12 | UserConstant13 | UserConstant14
        | UserConstant15 => {
            let slot = sym.user_constant_index().unwrap() as usize;
            return match user_constants.get(slot) {
                Some(uc) => Ok(StackEntry::constant(uc.value, uc.num_type)),
                None => Err(EvalError::MissingUserConstant(slot)),
            };
        }
        _ => return Err(EvalError::Invalid),
    };
    Ok(StackEntry::constant(value, kind))
}
/// Applies the user-defined function `sym` to `input`.
///
/// The body runs on a private operand stack that starts as `[input]`. Body
/// symbols follow the same arity rules as top-level symbols, while `Dup` and
/// `Swap` manipulate the local stack. User functions cannot call other user
/// functions: such symbols reach `eval_unary`, which rejects them.
///
/// # Errors
/// - [`EvalError::Invalid`] if `sym` is not a user-function symbol, its slot is
///   not configured, or the body does not leave exactly one value.
/// - [`EvalError::StackUnderflow`] if the body pops from too small a stack.
/// - [`EvalError::Overflow`] if the result is NaN or infinite.
fn eval_user_function(
    sym: Symbol,
    input: StackEntry,
    context: &EvalContext<'_>,
    x: f64,
) -> Result<StackEntry, EvalError> {
    let slot = sym.user_function_index().ok_or(EvalError::Invalid)? as usize;
    let udf = context.user_functions.get(slot).ok_or(EvalError::Invalid)?;

    thread_local! {
        // Scratch stack reused across calls on this thread to avoid allocating.
        static UDF_STACK: std::cell::RefCell<Vec<StackEntry>> =
            std::cell::RefCell::new(Vec::with_capacity(16));
    }

    UDF_STACK.with(|cell| -> Result<StackEntry, EvalError> {
        let mut local = cell.borrow_mut();
        local.clear();
        local.push(input);

        for op in &udf.body {
            match op {
                UdfOp::Dup => {
                    let top = local.last().copied().ok_or(EvalError::StackUnderflow)?;
                    local.push(top);
                }
                UdfOp::Swap => {
                    let depth = local.len();
                    if depth < 2 {
                        return Err(EvalError::StackUnderflow);
                    }
                    local.swap(depth - 2, depth - 1);
                }
                UdfOp::Symbol(inner) => {
                    let inner = *inner;
                    let produced = match inner.seft() {
                        Seft::A => eval_constant_with_user(inner, x, context.user_constants)?,
                        Seft::B => {
                            let arg = local.pop().ok_or(EvalError::StackUnderflow)?;
                            eval_unary(inner, arg, context.trig_argument_scale)?
                        }
                        Seft::C => {
                            let rhs = local.pop().ok_or(EvalError::StackUnderflow)?;
                            let lhs = local.pop().ok_or(EvalError::StackUnderflow)?;
                            eval_binary(inner, lhs, rhs)?
                        }
                    };
                    local.push(produced);
                }
            }
        }

        if local.len() != 1 {
            return Err(EvalError::Invalid);
        }
        let result = local.pop().unwrap();
        if result.val.is_finite() {
            Ok(result)
        } else {
            Err(EvalError::Overflow)
        }
    })
}
/// Applies built-in unary operator `sym` to operand `a`, propagating value,
/// derivative (chain rule) and number-class tag.
///
/// `trig_argument_scale` multiplies the argument of `SinPi`, `CosPi` and
/// `TanPi`; for example, `SinPi` computes `sin(scale · a)`.
///
/// # Errors
/// - [`EvalError::DivisionByZero`] for `Recip` of a value below `f64::MIN_POSITIVE`
///   in magnitude.
/// - [`EvalError::SqrtDomain`] for `Sqrt` of a negative value.
/// - [`EvalError::LogDomain`] for `Ln` of a non-positive value, or `LambertW` below `-1/e`.
/// - [`EvalError::Overflow`] when `Exp` overflows or `TanPi` is too close to a pole.
/// - [`EvalError::Invalid`] for any other symbol, including user-function
///   symbols (those are evaluated by `eval_user_function`).
fn eval_unary(
    sym: Symbol,
    a: StackEntry,
    trig_argument_scale: f64,
) -> Result<StackEntry, EvalError> {
    use Symbol::*;
    let scale = trig_argument_scale;
    let (val, deriv, num_type) = match sym {
        Neg => (-a.val, -a.deriv, a.num_type),
        Recip => {
            if a.val.abs() < f64::MIN_POSITIVE {
                return Err(EvalError::DivisionByZero);
            }
            // The reciprocal of an Integer-tagged value is tagged Rational; other tags are kept.
            let kind = if a.num_type == NumType::Integer {
                NumType::Rational
            } else {
                a.num_type
            };
            (1.0 / a.val, -a.deriv / (a.val * a.val), kind)
        }
        Sqrt => {
            if a.val < 0.0 {
                return Err(EvalError::SqrtDomain);
            }
            let root = a.val.sqrt();
            // d√u = u' / (2√u); reported as 0 where √u vanishes.
            let d = if root.abs() > f64::MIN_POSITIVE {
                a.deriv / (2.0 * root)
            } else {
                0.0
            };
            // Tags above Constructible are capped at Constructible; lower tags are kept.
            let kind = if a.num_type >= NumType::Constructible {
                NumType::Constructible
            } else {
                a.num_type
            };
            (root, d, kind)
        }
        Square => (a.val * a.val, 2.0 * a.val * a.deriv, a.num_type),
        Ln => {
            if a.val <= 0.0 {
                return Err(EvalError::LogDomain);
            }
            (a.val.ln(), a.deriv / a.val, NumType::Transcendental)
        }
        Exp => {
            let grown = a.val.exp();
            if grown.is_infinite() {
                return Err(EvalError::Overflow);
            }
            (grown, grown * a.deriv, NumType::Transcendental)
        }
        SinPi => {
            let theta = scale * a.val;
            (theta.sin(), scale * theta.cos() * a.deriv, NumType::Transcendental)
        }
        CosPi => {
            let theta = scale * a.val;
            (theta.cos(), -scale * theta.sin() * a.deriv, NumType::Transcendental)
        }
        TanPi => {
            let theta = scale * a.val;
            let c = theta.cos();
            // Arguments this close to a pole are reported as overflow.
            if c.abs() < 1e-10 {
                return Err(EvalError::Overflow);
            }
            (theta.tan(), scale * a.deriv / (c * c), NumType::Transcendental)
        }
        LambertW => {
            let w = lambert_w(a.val)?;
            // W'(u) = W(u) / (u (1 + W(u))), whose limit at u = 0 is 1.
            let d = if a.val.abs() < 1e-10 {
                a.deriv
            } else {
                let denom = a.val * (1.0 + w);
                if denom.abs() > f64::MIN_POSITIVE {
                    w / denom * a.deriv
                } else {
                    0.0
                }
            };
            (w, d, NumType::Transcendental)
        }
        _ => return Err(EvalError::Invalid),
    };
    Ok(StackEntry::new(val, deriv, num_type))
}
/// Applies binary operator `sym` to `a` (second from top of the stack) and
/// `b` (top of the stack), propagating value, derivative and number-class tag.
///
/// Operand roles: `Pow` is `a^b`; `Root` is the `a`-th root of `b`; `Log` is
/// the logarithm of `b` in base `a`; `Atan2` is `atan2(a, b)`.
///
/// # Errors
/// - [`EvalError::DivisionByZero`] for `Div` by, or `Root` with an index of,
///   a value below `f64::MIN_POSITIVE` in magnitude.
/// - [`EvalError::SqrtDomain`] for a non-positive base with a fractional
///   exponent, or an even/non-integer root of a negative radicand.
/// - [`EvalError::LogDomain`] for `Log` with a non-positive argument or base, or base 1.
/// - [`EvalError::Overflow`] when `Pow`/`Root` produce a non-finite value.
/// - [`EvalError::Invalid`] for any non-binary symbol.
fn eval_binary(sym: Symbol, a: StackEntry, b: StackEntry) -> Result<StackEntry, EvalError> {
    use Symbol::*;
    let (val, deriv, num_type) = match sym {
        Add => {
            let val = a.val + b.val;
            let deriv = a.deriv + b.deriv;
            let num_type = a.num_type.combine(b.num_type);
            (val, deriv, num_type)
        }
        Sub => {
            let val = a.val - b.val;
            let deriv = a.deriv - b.deriv;
            let num_type = a.num_type.combine(b.num_type);
            (val, deriv, num_type)
        }
        Mul => {
            let val = a.val * b.val;
            let deriv = a.val * b.deriv + b.val * a.deriv;
            let num_type = a.num_type.combine(b.num_type);
            (val, deriv, num_type)
        }
        Div => {
            if b.val.abs() < f64::MIN_POSITIVE {
                return Err(EvalError::DivisionByZero);
            }
            let val = a.val / b.val;
            let deriv = (b.val * a.deriv - a.val * b.deriv) / (b.val * b.val);
            let mut num_type = a.num_type.combine(b.num_type);
            // A quotient of integers is in general only rational.
            if num_type == NumType::Integer {
                num_type = NumType::Rational;
            }
            (val, deriv, num_type)
        }
        Pow => {
            if a.val <= 0.0 && b.val.fract() != 0.0 {
                return Err(EvalError::SqrtDomain);
            }
            let val = a.val.powf(b.val);
            if val.is_infinite() || val.is_nan() {
                return Err(EvalError::Overflow);
            }
            let deriv = if a.val > f64::MIN_POSITIVE {
                // d(a^b) = a^b * (b·a'/a + ln(a)·b')
                val * (b.val * a.deriv / a.val + a.val.ln() * b.deriv)
            } else if a.val.abs() < f64::MIN_POSITIVE {
                // Base is (numerically) zero. With the checks above, a positive
                // exponent is an integer, so d(a^b) = b·a^(b-1)·a'. This is a' for
                // b = 1 and 0 for b > 1. The ln(a)·b' term vanishes in the limit.
                if b.val > 0.0 {
                    b.val * a.val.powf(b.val - 1.0) * a.deriv
                } else {
                    0.0
                }
            } else {
                // Negative base: b is an integer here, so only the power-rule term applies.
                val * b.val * a.deriv / a.val
            };
            let num_type = if b.num_type != NumType::Integer {
                NumType::Transcendental
            } else if b.val < 0.0 && a.num_type == NumType::Integer {
                // An integer raised to a negative integer power is a reciprocal, i.e. rational.
                NumType::Rational
            } else {
                a.num_type
            };
            (val, deriv, num_type)
        }
        Root => {
            if a.val.abs() < f64::MIN_POSITIVE {
                return Err(EvalError::DivisionByZero);
            }
            let exp = 1.0 / a.val;
            if b.val < 0.0 {
                // Real roots of negative numbers exist only for odd integer indices.
                let rounded = a.val.round();
                let is_integer = (a.val - rounded).abs() < 1e-10;
                if !is_integer {
                    return Err(EvalError::SqrtDomain);
                }
                let int_val = rounded as i64;
                if int_val % 2 == 0 {
                    return Err(EvalError::SqrtDomain);
                }
            }
            let val = if b.val < 0.0 {
                -((-b.val).powf(exp))
            } else {
                b.val.powf(exp)
            };
            if val.is_infinite() || val.is_nan() {
                return Err(EvalError::Overflow);
            }
            let deriv = if b.val.abs() > f64::MIN_POSITIVE {
                val * (b.deriv / (a.val * b.val) - b.val.abs().ln() * a.deriv / (a.val * a.val))
            } else {
                0.0
            };
            // A rational-indexed root of an algebraic number is algebraic. A
            // lower-tagged radicand keeps its tag, mirroring `Sqrt`, and an
            // irrational index is treated like a non-integer `Pow` exponent.
            let rational_index =
                a.num_type == NumType::Integer || a.num_type == NumType::Rational;
            let num_type = if !rational_index {
                NumType::Transcendental
            } else if b.num_type >= NumType::Algebraic {
                NumType::Algebraic
            } else {
                b.num_type
            };
            (val, deriv, num_type)
        }
        Log => {
            if a.val <= 0.0 || a.val == 1.0 || b.val <= 0.0 {
                return Err(EvalError::LogDomain);
            }
            let ln_a = a.val.ln();
            let ln_b = b.val.ln();
            let val = ln_b / ln_a;
            let deriv = b.deriv / (b.val * ln_a) - ln_b * a.deriv / (a.val * ln_a * ln_a);
            (val, deriv, NumType::Transcendental)
        }
        Atan2 => {
            let val = a.val.atan2(b.val);
            let denom = a.val * a.val + b.val * b.val;
            let deriv = if denom.abs() > f64::MIN_POSITIVE {
                (b.val * a.deriv - a.val * b.deriv) / denom
            } else {
                0.0
            };
            (val, deriv, NumType::Transcendental)
        }
        _ => return Err(EvalError::Invalid),
    };
    Ok(StackEntry::new(val, deriv, num_type))
}
/// Principal branch `W₀` of the Lambert W function: the `w >= -1` with `w·eʷ = x`.
///
/// A piecewise initial estimate is refined by at most 25 Halley iterations.
/// The loop stops early once `|w·eʷ - x| < 1e-15·(1 + max(|w|, |x|))`.
///
/// # Errors
/// - [`EvalError::LogDomain`] if `x < -1/e`, where `W₀` has no real value.
/// - [`EvalError::Overflow`] if the iteration ends on a non-finite value.
fn lambert_w(x: f64) -> Result<f64, EvalError> {
    const INV_E: f64 = 1.0 / std::f64::consts::E;
    const NEG_INV_E: f64 = -INV_E;
    if x < NEG_INV_E {
        return Err(EvalError::LogDomain);
    }
    // Exactly known values: W(0) = 0, W(-1/e) = -1, W(e) = 1.
    if x == 0.0 {
        return Ok(0.0);
    }
    if (x - NEG_INV_E).abs() < 1e-15 {
        return Ok(-1.0);
    }
    if x == constants::E {
        return Ok(1.0);
    }
    let mut w = if x < -0.3 {
        // Series about the branch point -1/e in p = sqrt(2(e·x + 1)):
        // W = -1 + p - p²/3 + 11p³/72 - …
        let p = (2.0 * (constants::E * x + 1.0)).sqrt();
        -1.0 + p * (1.0 - p * (1.0 / 3.0 - 11.0 * p / 72.0))
    } else if x < 0.25 {
        // Maclaurin series: W = x - x² + 3x³/2 - 8x⁴/3 + …
        let x2 = x * x;
        x * (1.0 - x + x2 * (1.5 - 8.0 / 3.0 * x))
    } else if x < 4.0 {
        let lnx = x.ln();
        if lnx > 0.0 {
            let lnlnx = lnx.ln().max(0.0);
            lnx - lnlnx + lnlnx / lnx.max(1.0)
        } else {
            // For x <= 1 the logarithmic estimate is unusable; start from x itself.
            x
        }
    } else {
        // Asymptotic expansion for large x: W ≈ L1 - L2 + L2/L1, L1 = ln x, L2 = ln ln x.
        let l1 = x.ln();
        let l2 = l1.ln();
        l1 - l2 + l2 / l1
    };
    for _ in 0..25 {
        let ew = w.exp();
        if !ew.is_finite() {
            // e^w overflowed: fall back to the fixed-point form w = ln(x) - ln(w).
            w = x.ln() - w.ln().max(1e-10);
            continue;
        }
        let wew = w * ew;
        let diff = wew - x;
        let tol = 1e-15 * (1.0 + w.abs().max(x.abs()));
        if diff.abs() < tol {
            break;
        }
        // Halley step for f(w) = w·eʷ - x.
        let w1 = w + 1.0;
        let denom = ew * w1 - 0.5 * (w + 2.0) * diff / w1;
        if denom.abs() < f64::MIN_POSITIVE {
            break;
        }
        let delta = diff / denom;
        // Damp large steps near the branch point, where the curve is steep.
        let correction = if w < -0.5 && delta.abs() > 0.5 {
            delta * 0.5
        } else {
            delta
        };
        w -= correction;
    }
    if !w.is_finite() {
        return Err(EvalError::Overflow);
    }
    Ok(w)
}
// Unit tests for the evaluator: values, forward-mode derivatives, user
// constants and user functions.
#[cfg(test)]
mod tests {
use super::*;
/// Absolute-tolerance (1e-10) float comparison shared by the tests below.
fn approx_eq(a: f64, b: f64) -> bool {
(a - b).abs() < 1e-10
}
/// Postfix `32+` evaluates to 5. It does not involve `x`, so its derivative is zero.
#[test]
fn test_basic_eval() {
let expr = Expression::parse("32+").unwrap();
let result = evaluate(&expr, 0.0).unwrap();
assert!(approx_eq(result.value, 5.0));
assert!(approx_eq(result.derivative, 0.0));
}
/// `x` evaluates to its input with derivative 1.
#[test]
fn test_variable() {
let expr = Expression::parse("x").unwrap();
let result = evaluate(&expr, 3.5).unwrap();
assert!(approx_eq(result.value, 3.5));
assert!(approx_eq(result.derivative, 1.0));
}
/// `xs` is expected to be x²: value 9 and derivative 2x = 6 at x = 3.
#[test]
fn test_x_squared() {
let expr = Expression::parse("xs").unwrap(); let result = evaluate(&expr, 3.0).unwrap();
assert!(approx_eq(result.value, 9.0));
assert!(approx_eq(result.derivative, 6.0)); }
/// `pq` is expected to be √π.
#[test]
fn test_sqrt_pi() {
let expr = Expression::parse("pq").unwrap(); let result = evaluate(&expr, 0.0).unwrap();
assert!(approx_eq(result.value, constants::PI.sqrt()));
}
/// `xE` is expected to be eˣ: value and derivative are both e at x = 1.
#[test]
fn test_e_to_x() {
let expr = Expression::parse("xE").unwrap(); let result = evaluate(&expr, 1.0).unwrap();
assert!(approx_eq(result.value, constants::E));
assert!(approx_eq(result.derivative, constants::E)); }
/// `xs2x*+1+` is expected to be x² + 2x + 1: value 16 and derivative 2x + 2 = 8 at x = 3.
#[test]
fn test_complex_expr() {
let expr = Expression::parse("xs2x*+1+").unwrap();
let result = evaluate(&expr, 3.0).unwrap();
assert!(approx_eq(result.value, 16.0)); assert!(approx_eq(result.derivative, 8.0)); }
/// W(1) is the omega constant (≈ 0.5671432904), and W(e) = 1.
#[test]
fn test_lambert_w() {
let w = lambert_w(1.0).unwrap();
assert!((w - 0.5671432904).abs() < 1e-9);
let w = lambert_w(constants::E).unwrap();
assert!((w - 1.0).abs() < 1e-10);
}
/// A configured user constant evaluates to its value with zero derivative.
#[test]
fn test_user_constant_evaluation() {
use crate::profile::UserConstant;
let user_constants = vec![UserConstant {
weight: 8,
name: "g".to_string(),
description: "gamma".to_string(),
value: 0.5772156649,
num_type: NumType::Transcendental,
}];
let expr = Expression::from_symbols(&[Symbol::UserConstant0]);
let result = evaluate_with_constants(&expr, 0.0, &user_constants).unwrap();
assert!(approx_eq(result.value, 0.5772156649));
assert!(approx_eq(result.derivative, 0.0));
}
/// u0·x + u1 with u0 = 2, u1 = 3: value 11 and derivative 2 at x = 4.
#[test]
fn test_user_constant_in_expression() {
use crate::profile::UserConstant;
let user_constants = vec![
UserConstant {
weight: 8,
name: "a".to_string(),
description: "constant a".to_string(),
value: 2.0,
num_type: NumType::Integer,
},
UserConstant {
weight: 8,
name: "b".to_string(),
description: "constant b".to_string(),
value: 3.0,
num_type: NumType::Integer,
},
];
let expr = Expression::from_symbols(&[
Symbol::UserConstant0,
Symbol::X,
Symbol::Mul,
Symbol::UserConstant1,
Symbol::Add,
]);
let result = evaluate_with_constants(&expr, 4.0, &user_constants).unwrap();
assert!(approx_eq(result.value, 11.0));
assert!(approx_eq(result.derivative, 2.0));
}
/// Referencing an unconfigured user-constant slot reports that slot.
#[test]
fn test_user_constant_missing_returns_error() {
let expr = Expression::from_symbols(&[Symbol::UserConstant0]);
let result = evaluate_with_constants(&expr, 0.0, &[]);
assert!(matches!(result, Err(EvalError::MissingUserConstant(0))));
}
/// The UDF body `E|r-2/` is expected to compute sinh: (eˣ - e⁻ˣ)/2, whose derivative is cosh.
#[test]
fn test_user_function_sinh() {
use crate::udf::UserFunction;
let user_functions = vec![UserFunction::parse("4:sinh:hyperbolic sine:E|r-2/").unwrap()];
let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);
let result =
evaluate_with_constants_and_functions(&expr, 1.0, &[], &user_functions).unwrap();
let expected = (constants::E - 1.0 / constants::E) / 2.0;
assert!(approx_eq(result.value, expected));
let expected_deriv = (constants::E + 1.0 / constants::E) / 2.0;
assert!((result.derivative - expected_deriv).abs() < 1e-10);
}
/// The UDF body `|E*` is expected to compute x·eˣ: value e and derivative (1 + x)eˣ = 2e at x = 1.
#[test]
fn test_user_function_xex() {
use crate::udf::UserFunction;
let user_functions = vec![UserFunction::parse("4:XeX:x*exp(x):|E*").unwrap()];
let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);
let result =
evaluate_with_constants_and_functions(&expr, 1.0, &[], &user_functions).unwrap();
assert!(approx_eq(result.value, constants::E));
let expected_deriv = constants::E * 2.0;
assert!((result.derivative - expected_deriv).abs() < 1e-10);
}
/// Calling an unconfigured user-function slot is an error.
#[test]
fn test_user_function_missing_returns_error() {
let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);
let result = evaluate_with_constants_and_functions(&expr, 1.0, &[], &[]);
assert!(result.is_err());
}
}