1use crate::expr::Expression;
11use crate::profile::UserConstant;
12use crate::symbol::{NumType, Seft, Symbol};
13use crate::udf::{UdfOp, UserFunction};
14
/// Outcome of successfully evaluating an expression at a given `x`.
#[derive(Debug, Clone, Copy)]
pub struct EvalResult {
    /// Value of the expression at `x`; guaranteed finite (non-finite results are
    /// reported as `EvalError::Overflow` instead).
    pub value: f64,
    /// Derivative of the expression with respect to `x`, propagated alongside the
    /// value (the variable `X` is seeded with derivative 1, constants with 0).
    pub derivative: f64,
    /// Number-type classification of the result, propagated through the operators.
    pub num_type: NumType,
}
25
/// Errors that can occur while evaluating an expression.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum EvalError {
    /// An operator needed more operands than were available on the stack.
    #[error("Stack underflow: not enough operands on stack")]
    StackUnderflow,
    /// A user-constant symbol referenced a slot with no configured value; holds the
    /// slot index.
    #[error("Missing user constant: slot u{0} is not configured")]
    MissingUserConstant(usize),
    /// A divisor (or a `Recip` operand / `Root` index) was zero or smaller in
    /// magnitude than `f64::MIN_POSITIVE`.
    #[error("Division by zero: divisor was zero or near-zero")]
    DivisionByZero,
    /// A logarithm-like operator (`Ln`, `Log`, `LambertW`) was applied outside its
    /// real domain.
    #[error("Logarithm domain error: argument was non-positive")]
    LogDomain,
    /// A root or fractional power without a real value was requested (`Sqrt` of a
    /// negative, a fractional power of a non-positive base, an even root of a
    /// negative).
    #[error("Square root domain error: argument was negative")]
    SqrtDomain,
    /// A result was infinite or NaN; also raised when `TanPi` is evaluated too close
    /// to a pole.
    #[error("Overflow: result is infinite or NaN")]
    Overflow,
    /// The expression was malformed: it did not leave exactly one value, used a
    /// symbol of the wrong kind, or referenced an unconfigured user function.
    #[error("Invalid expression: malformed or incomplete")]
    Invalid,
    /// Wraps another error with the position supplied to `EvalError::with_context`.
    #[error("{err} at position {pos}")]
    WithPosition {
        /// The underlying error.
        #[source]
        err: Box<EvalError>,
        /// Position reported by the caller.
        pos: usize,
    },
    /// Wraps another error with the offending value supplied to
    /// `EvalError::with_context`.
    #[error("{err} (value: {val})")]
    WithValue {
        /// The underlying error.
        #[source]
        err: Box<EvalError>,
        /// The value, wrapped so the enum can derive `Eq`.
        val: ordered_float::OrderedFloat<f64>,
    },
    /// Wraps another error with a textual rendering of the expression.
    #[error("{err} in expression '{expr}'")]
    WithExpression {
        /// The underlying error.
        #[source]
        err: Box<EvalError>,
        /// The expression text supplied by the caller.
        expr: String,
    },
}
75
76impl EvalError {
77 pub fn with_context(self, position: Option<usize>, value: Option<f64>) -> Self {
79 let mut err = self;
80 if let Some(pos) = position {
81 err = EvalError::WithPosition {
82 err: Box::new(err),
83 pos,
84 };
85 }
86 if let Some(val) = value {
87 err = EvalError::WithValue {
88 err: Box::new(err),
89 val: ordered_float::OrderedFloat(val),
90 };
91 }
92 err
93 }
94
95 pub fn with_expression(self, expr: String) -> Self {
97 EvalError::WithExpression {
98 err: Box::new(self),
99 expr,
100 }
101 }
102}
103
/// Built-in mathematical constants available as leaf symbols.
pub mod constants {
    /// π.
    pub const PI: f64 = core::f64::consts::PI;
    /// Euler's number e.
    pub const E: f64 = core::f64::consts::E;
    /// Golden ratio φ = (1 + √5) / 2.
    pub const PHI: f64 = 1.618_033_988_749_895;
    /// Euler–Mascheroni constant γ.
    pub const GAMMA: f64 = 0.577_215_664_901_532_9;
    /// Plastic number ρ (real root of x³ = x + 1).
    pub const PLASTIC: f64 = 1.324_717_957_244_746;
    /// Apéry's constant ζ(3).
    pub const APERY: f64 = 1.202_056_903_159_594_2;
    /// Catalan's constant G.
    pub const CATALAN: f64 = 0.915_965_594_177_219;
}
118
/// Default multiplier applied to the argument of the `SinPi`/`CosPi`/`TanPi`
/// operators, which evaluate e.g. `sin(scale * a)`; π gives `sin(π·a)`.
pub const DEFAULT_TRIG_ARGUMENT_SCALE: f64 = core::f64::consts::PI;
123
/// Borrowed configuration consulted while evaluating an expression.
#[derive(Clone, Copy, Debug)]
pub struct EvalContext<'a> {
    /// Multiplier applied to the argument of `SinPi`, `CosPi`, and `TanPi`.
    pub trig_argument_scale: f64,
    /// Values for the user-constant symbols, indexed by `Symbol::user_constant_index`.
    pub user_constants: &'a [UserConstant],
    /// Definitions for the user-function symbols, indexed by
    /// `Symbol::user_function_index`.
    pub user_functions: &'a [UserFunction],
}
137
138impl Default for EvalContext<'static> {
139 fn default() -> Self {
140 Self {
141 trig_argument_scale: DEFAULT_TRIG_ARGUMENT_SCALE,
142 user_constants: &[],
143 user_functions: &[],
144 }
145 }
146}
147
148impl EvalContext<'static> {
149 pub fn new() -> Self {
151 Self::default()
152 }
153}
154
155impl<'a> EvalContext<'a> {
156 pub fn from_slices(
158 user_constants: &'a [UserConstant],
159 user_functions: &'a [UserFunction],
160 ) -> Self {
161 Self {
162 trig_argument_scale: DEFAULT_TRIG_ARGUMENT_SCALE,
163 user_constants,
164 user_functions,
165 }
166 }
167
168 pub fn with_trig_argument_scale(mut self, scale: f64) -> Self {
170 if scale.is_finite() && scale != 0.0 {
171 self.trig_argument_scale = scale;
172 }
173 self
174 }
175}
176
/// One slot of the evaluation stack: a value together with its derivative with
/// respect to `x` and its number-type classification.
#[derive(Debug, Clone, Copy)]
struct StackEntry {
    val: f64,
    deriv: f64,
    num_type: NumType,
}
184
185impl StackEntry {
186 fn new(val: f64, deriv: f64, num_type: NumType) -> Self {
187 Self {
188 val,
189 deriv,
190 num_type,
191 }
192 }
193
194 fn constant(val: f64, num_type: NumType) -> Self {
195 Self {
196 val,
197 deriv: 0.0,
198 num_type,
199 }
200 }
201}
202
/// Reusable scratch space for expression evaluation.
///
/// Passing the same workspace to repeated `evaluate_with_workspace*` calls reuses
/// its stack allocation; each call clears the stack before evaluating.
pub struct EvalWorkspace {
    stack: Vec<StackEntry>,
}
225
226impl EvalWorkspace {
227 pub fn new() -> Self {
231 Self {
232 stack: Vec::with_capacity(32),
233 }
234 }
235
236 #[inline]
238 fn clear(&mut self) {
239 self.stack.clear();
240 }
241}
242
243impl Default for EvalWorkspace {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
249#[inline]
257pub fn evaluate_with_workspace(
258 expr: &Expression,
259 x: f64,
260 workspace: &mut EvalWorkspace,
261) -> Result<EvalResult, EvalError> {
262 evaluate_with_workspace_and_context(expr, x, workspace, &EvalContext::new())
263}
264
265#[inline]
273pub fn evaluate_with_workspace_and_constants(
274 expr: &Expression,
275 x: f64,
276 workspace: &mut EvalWorkspace,
277 user_constants: &[UserConstant],
278) -> Result<EvalResult, EvalError> {
279 let context = EvalContext::from_slices(user_constants, &[]);
280 evaluate_with_workspace_and_context(expr, x, workspace, &context)
281}
282
283#[inline]
289pub fn evaluate_with_workspace_and_constants_and_functions(
290 expr: &Expression,
291 x: f64,
292 workspace: &mut EvalWorkspace,
293 user_constants: &[UserConstant],
294 user_functions: &[UserFunction],
295) -> Result<EvalResult, EvalError> {
296 let context = EvalContext::from_slices(user_constants, user_functions);
297 evaluate_with_workspace_and_context(expr, x, workspace, &context)
298}
299
300#[inline]
305pub fn evaluate_with_workspace_and_context(
306 expr: &Expression,
307 x: f64,
308 workspace: &mut EvalWorkspace,
309 context: &EvalContext<'_>,
310) -> Result<EvalResult, EvalError> {
311 workspace.clear();
312 let stack = &mut workspace.stack;
313
314 for &sym in expr.symbols() {
315 match sym.seft() {
316 Seft::A => {
317 let entry = eval_constant_with_user(sym, x, context.user_constants)?;
318 stack.push(entry);
319 }
320 Seft::B => {
321 if matches!(
323 sym,
324 Symbol::UserFunction0
325 | Symbol::UserFunction1
326 | Symbol::UserFunction2
327 | Symbol::UserFunction3
328 | Symbol::UserFunction4
329 | Symbol::UserFunction5
330 | Symbol::UserFunction6
331 | Symbol::UserFunction7
332 | Symbol::UserFunction8
333 | Symbol::UserFunction9
334 | Symbol::UserFunction10
335 | Symbol::UserFunction11
336 | Symbol::UserFunction12
337 | Symbol::UserFunction13
338 | Symbol::UserFunction14
339 | Symbol::UserFunction15
340 ) {
341 let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
343 let result = eval_user_function(sym, a, context, x)?;
344 stack.push(result);
345 } else {
346 let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
347 let result = eval_unary(sym, a, context.trig_argument_scale)?;
348 stack.push(result);
349 }
350 }
351 Seft::C => {
352 let b = stack.pop().ok_or(EvalError::StackUnderflow)?;
353 let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
354 let result = eval_binary(sym, a, b)?;
355 stack.push(result);
356 }
357 }
358 }
359
360 if stack.len() != 1 {
361 return Err(EvalError::Invalid);
362 }
363
364 let result = stack.pop().unwrap();
365
366 if result.val.is_nan() || result.val.is_infinite() {
368 return Err(EvalError::Overflow);
369 }
370
371 Ok(EvalResult {
372 value: result.val,
373 derivative: result.deriv,
374 num_type: result.num_type,
375 })
376}
377
378pub fn evaluate(expr: &Expression, x: f64) -> Result<EvalResult, EvalError> {
386 evaluate_with_context(expr, x, &EvalContext::new())
387}
388
389pub fn evaluate_with_constants(
393 expr: &Expression,
394 x: f64,
395 user_constants: &[UserConstant],
396) -> Result<EvalResult, EvalError> {
397 let context = EvalContext::from_slices(user_constants, &[]);
398 evaluate_with_context(expr, x, &context)
399}
400
401pub fn evaluate_with_constants_and_functions(
405 expr: &Expression,
406 x: f64,
407 user_constants: &[UserConstant],
408 user_functions: &[UserFunction],
409) -> Result<EvalResult, EvalError> {
410 let context = EvalContext::from_slices(user_constants, user_functions);
411 evaluate_with_context(expr, x, &context)
412}
413
414pub fn evaluate_with_context(
418 expr: &Expression,
419 x: f64,
420 context: &EvalContext<'_>,
421) -> Result<EvalResult, EvalError> {
422 let mut workspace = EvalWorkspace::new();
423 evaluate_with_workspace_and_context(expr, x, &mut workspace, context)
424}
425
426#[inline]
435pub fn evaluate_fast(expr: &Expression, x: f64) -> Result<EvalResult, EvalError> {
436 evaluate_fast_with_context(expr, x, &EvalContext::new())
437}
438
439#[inline]
447pub fn evaluate_fast_with_constants(
448 expr: &Expression,
449 x: f64,
450 user_constants: &[UserConstant],
451) -> Result<EvalResult, EvalError> {
452 let context = EvalContext::from_slices(user_constants, &[]);
453 evaluate_fast_with_context(expr, x, &context)
454}
455
456#[inline]
491pub fn evaluate_fast_with_constants_and_functions(
492 expr: &Expression,
493 x: f64,
494 user_constants: &[UserConstant],
495 user_functions: &[UserFunction],
496) -> Result<EvalResult, EvalError> {
497 let context = EvalContext::from_slices(user_constants, user_functions);
498 evaluate_fast_with_context(expr, x, &context)
499}
500
501#[inline]
503pub fn evaluate_fast_with_context(
504 expr: &Expression,
505 x: f64,
506 context: &EvalContext<'_>,
507) -> Result<EvalResult, EvalError> {
508 thread_local! {
509 static WORKSPACE: std::cell::RefCell<EvalWorkspace> = std::cell::RefCell::new(EvalWorkspace::new());
515 }
516
517 WORKSPACE.with(|ws| {
518 let mut workspace = ws.borrow_mut();
519 evaluate_with_workspace_and_context(expr, x, &mut workspace, context)
520 })
521}
522
523fn eval_constant_with_user(
525 sym: Symbol,
526 x: f64,
527 user_constants: &[UserConstant],
528) -> Result<StackEntry, EvalError> {
529 use Symbol::*;
530 match sym {
531 One => Ok(StackEntry::constant(1.0, NumType::Integer)),
532 Two => Ok(StackEntry::constant(2.0, NumType::Integer)),
533 Three => Ok(StackEntry::constant(3.0, NumType::Integer)),
534 Four => Ok(StackEntry::constant(4.0, NumType::Integer)),
535 Five => Ok(StackEntry::constant(5.0, NumType::Integer)),
536 Six => Ok(StackEntry::constant(6.0, NumType::Integer)),
537 Seven => Ok(StackEntry::constant(7.0, NumType::Integer)),
538 Eight => Ok(StackEntry::constant(8.0, NumType::Integer)),
539 Nine => Ok(StackEntry::constant(9.0, NumType::Integer)),
540 Pi => Ok(StackEntry::constant(constants::PI, NumType::Transcendental)),
541 E => Ok(StackEntry::constant(constants::E, NumType::Transcendental)),
542 Phi => Ok(StackEntry::constant(constants::PHI, NumType::Algebraic)),
543 Gamma => Ok(StackEntry::constant(
545 constants::GAMMA,
546 NumType::Transcendental,
547 )),
548 Plastic => Ok(StackEntry::constant(constants::PLASTIC, NumType::Algebraic)),
549 Apery => Ok(StackEntry::constant(
550 constants::APERY,
551 NumType::Transcendental,
552 )),
553 Catalan => Ok(StackEntry::constant(
554 constants::CATALAN,
555 NumType::Transcendental,
556 )),
557 X => Ok(StackEntry::new(x, 1.0, NumType::Integer)), UserConstant0 | UserConstant1 | UserConstant2 | UserConstant3 | UserConstant4
560 | UserConstant5 | UserConstant6 | UserConstant7 | UserConstant8 | UserConstant9
561 | UserConstant10 | UserConstant11 | UserConstant12 | UserConstant13 | UserConstant14
562 | UserConstant15 => {
563 let idx = sym.user_constant_index().unwrap() as usize;
565 user_constants
566 .get(idx)
567 .map(|uc| StackEntry::constant(uc.value, uc.num_type))
568 .ok_or(EvalError::MissingUserConstant(idx))
569 }
570 _ => Err(EvalError::Invalid),
571 }
572}
573
574fn eval_user_function(
579 sym: Symbol,
580 input: StackEntry,
581 context: &EvalContext<'_>,
582 x: f64,
583) -> Result<StackEntry, EvalError> {
584 let idx = sym.user_function_index().ok_or(EvalError::Invalid)? as usize;
586
587 let udf = context.user_functions.get(idx).ok_or(EvalError::Invalid)?;
589
590 thread_local! {
595 static UDF_STACK: std::cell::RefCell<Vec<StackEntry>> =
596 std::cell::RefCell::new(Vec::with_capacity(16));
597 }
598
599 UDF_STACK.with(|cell| -> Result<StackEntry, EvalError> {
600 let mut stack = cell.borrow_mut();
601 stack.clear();
602 stack.push(input);
603
604 for op in &udf.body {
606 match op {
607 UdfOp::Symbol(sym) => {
608 match sym.seft() {
609 Seft::A => {
610 let entry = eval_constant_with_user(*sym, x, context.user_constants)?;
612 stack.push(entry);
613 }
614 Seft::B => {
615 let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
617 let result = eval_unary(*sym, a, context.trig_argument_scale)?;
618 stack.push(result);
619 }
620 Seft::C => {
621 let b = stack.pop().ok_or(EvalError::StackUnderflow)?;
623 let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
624 let result = eval_binary(*sym, a, b)?;
625 stack.push(result);
626 }
627 }
628 }
629 UdfOp::Dup => {
630 let top = *stack.last().ok_or(EvalError::StackUnderflow)?;
633 stack.push(top);
634 }
635 UdfOp::Swap => {
636 let len = stack.len();
638 if len < 2 {
639 return Err(EvalError::StackUnderflow);
640 }
641 stack.swap(len - 1, len - 2);
642 }
643 }
644 }
645
646 if stack.len() != 1 {
648 return Err(EvalError::Invalid);
649 }
650
651 let result = stack.pop().unwrap();
652
653 if result.val.is_nan() || result.val.is_infinite() {
655 return Err(EvalError::Overflow);
656 }
657
658 Ok(result)
659 })
660}
661
662fn eval_unary(
664 sym: Symbol,
665 a: StackEntry,
666 trig_argument_scale: f64,
667) -> Result<StackEntry, EvalError> {
668 use Symbol::*;
669
670 let (val, deriv, num_type) = match sym {
671 Neg => (-a.val, -a.deriv, a.num_type),
673
674 Recip => {
676 if a.val.abs() < f64::MIN_POSITIVE {
677 return Err(EvalError::DivisionByZero);
678 }
679 let val = 1.0 / a.val;
680 let deriv = -a.deriv / (a.val * a.val);
681 let num_type = if a.num_type == NumType::Integer {
682 NumType::Rational
683 } else {
684 a.num_type
685 };
686 (val, deriv, num_type)
687 }
688
689 Sqrt => {
691 if a.val < 0.0 {
692 return Err(EvalError::SqrtDomain);
693 }
694 let val = a.val.sqrt();
695 let deriv = if val.abs() > f64::MIN_POSITIVE {
696 a.deriv / (2.0 * val)
697 } else {
698 0.0
699 };
700 let num_type = if a.num_type >= NumType::Constructible {
701 NumType::Constructible
702 } else {
703 a.num_type
704 };
705 (val, deriv, num_type)
706 }
707
708 Square => {
710 let val = a.val * a.val;
711 let deriv = 2.0 * a.val * a.deriv;
712 (val, deriv, a.num_type)
713 }
714
715 Ln => {
717 if a.val <= 0.0 {
718 return Err(EvalError::LogDomain);
719 }
720 let val = a.val.ln();
721 let deriv = a.deriv / a.val;
722 (val, deriv, NumType::Transcendental)
723 }
724
725 Exp => {
727 let val = a.val.exp();
728 if val.is_infinite() {
729 return Err(EvalError::Overflow);
730 }
731 let deriv = val * a.deriv;
732 (val, deriv, NumType::Transcendental)
733 }
734
735 SinPi => {
737 let val = (trig_argument_scale * a.val).sin();
738 let deriv = trig_argument_scale * (trig_argument_scale * a.val).cos() * a.deriv;
739 (val, deriv, NumType::Transcendental)
740 }
741
742 CosPi => {
744 let val = (trig_argument_scale * a.val).cos();
745 let deriv = -trig_argument_scale * (trig_argument_scale * a.val).sin() * a.deriv;
746 (val, deriv, NumType::Transcendental)
747 }
748
749 TanPi => {
751 let cos_val = (trig_argument_scale * a.val).cos();
752 if cos_val.abs() < 1e-10 {
753 return Err(EvalError::Overflow);
754 }
755 let val = (trig_argument_scale * a.val).tan();
756 let deriv = trig_argument_scale * a.deriv / (cos_val * cos_val);
757 (val, deriv, NumType::Transcendental)
758 }
759
760 LambertW => {
762 let val = lambert_w(a.val)?;
763 let deriv = if a.val.abs() < 1e-10 {
766 a.deriv } else {
768 let denom = a.val * (1.0 + val);
769 if denom.abs() > f64::MIN_POSITIVE {
770 val / denom * a.deriv
771 } else {
772 0.0
773 }
774 };
775 (val, deriv, NumType::Transcendental)
776 }
777
778 UserFunction0 | UserFunction1 | UserFunction2 | UserFunction3 | UserFunction4
781 | UserFunction5 | UserFunction6 | UserFunction7 | UserFunction8 | UserFunction9
782 | UserFunction10 | UserFunction11 | UserFunction12 | UserFunction13 | UserFunction14
783 | UserFunction15 => {
784 return Err(EvalError::Invalid);
787 }
788
789 _ => return Err(EvalError::Invalid),
791 };
792
793 Ok(StackEntry::new(val, deriv, num_type))
794}
795
796fn eval_binary(sym: Symbol, a: StackEntry, b: StackEntry) -> Result<StackEntry, EvalError> {
798 use Symbol::*;
799
800 let (val, deriv, num_type) = match sym {
801 Add => {
803 let val = a.val + b.val;
804 let deriv = a.deriv + b.deriv;
805 let num_type = a.num_type.combine(b.num_type);
806 (val, deriv, num_type)
807 }
808
809 Sub => {
811 let val = a.val - b.val;
812 let deriv = a.deriv - b.deriv;
813 let num_type = a.num_type.combine(b.num_type);
814 (val, deriv, num_type)
815 }
816
817 Mul => {
819 let val = a.val * b.val;
820 let deriv = a.val * b.deriv + b.val * a.deriv;
821 let num_type = a.num_type.combine(b.num_type);
822 (val, deriv, num_type)
823 }
824
825 Div => {
827 if b.val.abs() < f64::MIN_POSITIVE {
828 return Err(EvalError::DivisionByZero);
829 }
830 let val = a.val / b.val;
831 let deriv = (b.val * a.deriv - a.val * b.deriv) / (b.val * b.val);
832 let mut num_type = a.num_type.combine(b.num_type);
833 if num_type == NumType::Integer {
834 num_type = NumType::Rational;
835 }
836 (val, deriv, num_type)
837 }
838
839 Pow => {
841 if a.val <= 0.0 && b.val.fract() != 0.0 {
842 return Err(EvalError::SqrtDomain);
843 }
844 let val = a.val.powf(b.val);
845 if val.is_infinite() || val.is_nan() {
846 return Err(EvalError::Overflow);
847 }
848 let deriv = if a.val > f64::MIN_POSITIVE {
850 val * (b.val * a.deriv / a.val + a.val.ln() * b.deriv)
851 } else if a.val.abs() < f64::MIN_POSITIVE && b.val > 0.0 {
852 0.0
853 } else {
854 if a.val.abs() < f64::MIN_POSITIVE {
862 0.0
863 } else {
864 val * b.val * a.deriv / a.val
865 }
866 };
867 let num_type = if b.num_type == NumType::Integer {
868 a.num_type
869 } else {
870 NumType::Transcendental
871 };
872 (val, deriv, num_type)
873 }
874
875 Root => {
877 if a.val.abs() < f64::MIN_POSITIVE {
878 return Err(EvalError::DivisionByZero);
879 }
880 let exp = 1.0 / a.val;
881
882 if b.val < 0.0 {
885 let rounded = a.val.round();
887 let is_integer = (a.val - rounded).abs() < 1e-10;
888
889 if !is_integer {
890 return Err(EvalError::SqrtDomain);
892 }
893
894 let int_val = rounded as i64;
896 if int_val % 2 == 0 {
897 return Err(EvalError::SqrtDomain);
899 }
900 }
902
903 let val = if b.val < 0.0 {
904 -((-b.val).powf(exp))
906 } else {
907 b.val.powf(exp)
908 };
909 if val.is_infinite() || val.is_nan() {
910 return Err(EvalError::Overflow);
911 }
912 let deriv = if b.val.abs() > f64::MIN_POSITIVE {
914 val * (b.deriv / (a.val * b.val) - b.val.abs().ln() * a.deriv / (a.val * a.val))
915 } else {
916 0.0
917 };
918 (val, deriv, NumType::Algebraic)
919 }
920
921 Log => {
923 if a.val <= 0.0 || a.val == 1.0 || b.val <= 0.0 {
924 return Err(EvalError::LogDomain);
925 }
926 let ln_a = a.val.ln();
927 let ln_b = b.val.ln();
928 let val = ln_b / ln_a;
929 let deriv = b.deriv / (b.val * ln_a) - ln_b * a.deriv / (a.val * ln_a * ln_a);
931 (val, deriv, NumType::Transcendental)
932 }
933
934 Atan2 => {
936 let val = a.val.atan2(b.val);
937 let denom = a.val * a.val + b.val * b.val;
939 let deriv = if denom.abs() > f64::MIN_POSITIVE {
940 (b.val * a.deriv - a.val * b.deriv) / denom
941 } else {
942 0.0
943 };
944 (val, deriv, NumType::Transcendental)
945 }
946
947 _ => return Err(EvalError::Invalid),
949 };
950
951 Ok(StackEntry::new(val, deriv, num_type))
952}
953
954fn lambert_w(x: f64) -> Result<f64, EvalError> {
959 const INV_E: f64 = 1.0 / std::f64::consts::E;
961 const NEG_INV_E: f64 = -INV_E; if x < NEG_INV_E {
965 return Err(EvalError::LogDomain);
966 }
967
968 if x == 0.0 {
970 return Ok(0.0); }
972 if (x - NEG_INV_E).abs() < 1e-15 {
973 return Ok(-1.0); }
975 if x == constants::E {
976 return Ok(1.0); }
978
979 let mut w = if x < -0.3 {
981 let p = (2.0 * (constants::E * x + 1.0)).sqrt();
984 -1.0 + p * (1.0 - p / 3.0 * (1.0 - 11.0 * p / 72.0))
985 } else if x < 0.25 {
986 let x2 = x * x;
990 x * (1.0 - x + x2 * (1.5 - 2.6667 * x))
991 } else if x < 4.0 {
992 let lnx = x.ln();
995 if lnx > 0.0 {
996 let lnlnx = lnx.ln().max(0.0);
997 lnx - lnlnx + lnlnx / lnx.max(1.0)
998 } else {
999 x }
1001 } else {
1002 let l1 = x.ln();
1004 let l2 = l1.ln();
1005 l1 - l2 + l2 / l1
1006 };
1007
1008 for _ in 0..25 {
1011 let ew = w.exp();
1012
1013 if !ew.is_finite() {
1015 w = x.ln() - w.ln().max(1e-10);
1017 continue;
1018 }
1019
1020 let wew = w * ew;
1021 let diff = wew - x;
1022
1023 let tol = 1e-15 * (1.0 + w.abs().max(x.abs()));
1025 if diff.abs() < tol {
1026 break;
1027 }
1028
1029 let w1 = w + 1.0;
1030 let denom = ew * w1 - 0.5 * (w + 2.0) * diff / w1;
1032 if denom.abs() < f64::MIN_POSITIVE {
1033 break;
1034 }
1035
1036 let delta = diff / denom;
1037
1038 let correction = if w < -0.5 && delta.abs() > 0.5 {
1040 delta * 0.5 } else {
1042 delta
1043 };
1044
1045 w -= correction;
1046 }
1047
1048 if !w.is_finite() {
1050 return Err(EvalError::Overflow);
1051 }
1052
1053 Ok(w)
1054}
1055
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10
    }

    #[test]
    fn test_basic_eval() {
        let expr = Expression::parse("32+").unwrap();
        let result = evaluate(&expr, 0.0).unwrap();
        assert!(approx_eq(result.value, 5.0));
        assert!(approx_eq(result.derivative, 0.0));
    }

    #[test]
    fn test_variable() {
        let expr = Expression::parse("x").unwrap();
        let result = evaluate(&expr, 3.5).unwrap();
        assert!(approx_eq(result.value, 3.5));
        assert!(approx_eq(result.derivative, 1.0));
    }

    #[test]
    fn test_x_squared() {
        let expr = Expression::parse("xs").unwrap();
        let result = evaluate(&expr, 3.0).unwrap();
        assert!(approx_eq(result.value, 9.0));
        assert!(approx_eq(result.derivative, 6.0));
    }

    #[test]
    fn test_sqrt_pi() {
        let expr = Expression::parse("pq").unwrap();
        let result = evaluate(&expr, 0.0).unwrap();
        assert!(approx_eq(result.value, constants::PI.sqrt()));
    }

    #[test]
    fn test_e_to_x() {
        let expr = Expression::parse("xE").unwrap();
        let result = evaluate(&expr, 1.0).unwrap();
        assert!(approx_eq(result.value, constants::E));
        assert!(approx_eq(result.derivative, constants::E));
    }

    #[test]
    fn test_complex_expr() {
        let expr = Expression::parse("xs2x*+1+").unwrap();
        let result = evaluate(&expr, 3.0).unwrap();
        assert!(approx_eq(result.value, 16.0));
        assert!(approx_eq(result.derivative, 8.0));
    }

    #[test]
    fn test_lambert_w() {
        let w = lambert_w(1.0).unwrap();
        assert!((w - 0.5671432904).abs() < 1e-9);

        let w = lambert_w(constants::E).unwrap();
        assert!((w - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_lambert_w_near_branch_point() {
        // x < -0.3 exercises the branch-point initial guess.
        let x = -0.35;
        let w = lambert_w(x).unwrap();
        assert!(w > -1.0 && w < -0.5);
        assert!((w * w.exp() - x).abs() < 1e-12);
    }

    #[test]
    fn test_lambert_w_domain() {
        assert!(matches!(lambert_w(-1.0), Err(EvalError::LogDomain)));

        let w = lambert_w(-1.0 / constants::E).unwrap();
        assert!((w + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_user_constant_evaluation() {
        use crate::profile::UserConstant;

        let user_constants = vec![UserConstant {
            weight: 8,
            name: "g".to_string(),
            description: "gamma".to_string(),
            value: 0.5772156649,
            num_type: NumType::Transcendental,
        }];

        let expr = Expression::from_symbols(&[Symbol::UserConstant0]);

        let result = evaluate_with_constants(&expr, 0.0, &user_constants).unwrap();

        assert!(approx_eq(result.value, 0.5772156649));
        assert!(approx_eq(result.derivative, 0.0));
    }

    #[test]
    fn test_user_constant_in_expression() {
        use crate::profile::UserConstant;

        let user_constants = vec![
            UserConstant {
                weight: 8,
                name: "a".to_string(),
                description: "constant a".to_string(),
                value: 2.0,
                num_type: NumType::Integer,
            },
            UserConstant {
                weight: 8,
                name: "b".to_string(),
                description: "constant b".to_string(),
                value: 3.0,
                num_type: NumType::Integer,
            },
        ];

        let expr = Expression::from_symbols(&[
            Symbol::UserConstant0,
            Symbol::X,
            Symbol::Mul,
            Symbol::UserConstant1,
            Symbol::Add,
        ]);

        let result = evaluate_with_constants(&expr, 4.0, &user_constants).unwrap();
        assert!(approx_eq(result.value, 11.0));
        assert!(approx_eq(result.derivative, 2.0));
    }

    #[test]
    fn test_user_constant_missing_returns_error() {
        let expr = Expression::from_symbols(&[Symbol::UserConstant0]);

        let result = evaluate_with_constants(&expr, 0.0, &[]);
        assert!(matches!(result, Err(EvalError::MissingUserConstant(0))));
    }

    #[test]
    fn test_user_function_sinh() {
        use crate::udf::UserFunction;

        let user_functions = vec![UserFunction::parse("4:sinh:hyperbolic sine:E|r-2/").unwrap()];

        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);

        let result =
            evaluate_with_constants_and_functions(&expr, 1.0, &[], &user_functions).unwrap();
        let expected = (constants::E - 1.0 / constants::E) / 2.0;
        assert!(approx_eq(result.value, expected));

        let expected_deriv = (constants::E + 1.0 / constants::E) / 2.0;
        assert!((result.derivative - expected_deriv).abs() < 1e-10);
    }

    #[test]
    fn test_user_function_xex() {
        use crate::udf::UserFunction;

        let user_functions = vec![UserFunction::parse("4:XeX:x*exp(x):|E*").unwrap()];

        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);

        let result =
            evaluate_with_constants_and_functions(&expr, 1.0, &[], &user_functions).unwrap();
        assert!(approx_eq(result.value, constants::E));

        let expected_deriv = constants::E * 2.0;
        assert!((result.derivative - expected_deriv).abs() < 1e-10);
    }

    #[test]
    fn test_user_function_missing_returns_error() {
        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);

        // An unconfigured user-function slot is reported as a malformed expression.
        let result = evaluate_with_constants_and_functions(&expr, 1.0, &[], &[]);
        assert!(matches!(result, Err(EvalError::Invalid)));
    }
}