1use crate::expr::Expression;
11use crate::profile::UserConstant;
12use crate::symbol::{NumType, Seft, Symbol};
13use crate::udf::{UdfOp, UserFunction};
14
15#[derive(Debug, Clone, Copy)]
17pub struct EvalResult {
18 pub value: f64,
20 pub derivative: f64,
22 pub num_type: NumType,
24}
25
26#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
31pub enum EvalError {
32 #[error("Stack underflow: not enough operands on stack")]
34 StackUnderflow,
35 #[error("Missing user constant: slot u{0} is not configured")]
37 MissingUserConstant(usize),
38 #[error("Division by zero: divisor was zero or near-zero")]
40 DivisionByZero,
41 #[error("Logarithm domain error: argument was non-positive")]
43 LogDomain,
44 #[error("Square root domain error: argument was negative")]
46 SqrtDomain,
47 #[error("Overflow: result is infinite or NaN")]
49 Overflow,
50 #[error("Invalid expression: malformed or incomplete")]
52 Invalid,
53 #[error("{err} at position {pos}")]
55 WithPosition {
56 #[source]
57 err: Box<EvalError>,
58 pos: usize,
59 },
60 #[error("{err} (value: {val})")]
62 WithValue {
63 #[source]
64 err: Box<EvalError>,
65 val: ordered_float::OrderedFloat<f64>,
66 },
67 #[error("{err} in expression '{expr}'")]
69 WithExpression {
70 #[source]
71 err: Box<EvalError>,
72 expr: String,
73 },
74}
75
76impl EvalError {
77 pub fn with_context(self, position: Option<usize>, value: Option<f64>) -> Self {
79 let mut err = self;
80 if let Some(pos) = position {
81 err = EvalError::WithPosition {
82 err: Box::new(err),
83 pos,
84 };
85 }
86 if let Some(val) = value {
87 err = EvalError::WithValue {
88 err: Box::new(err),
89 val: ordered_float::OrderedFloat(val),
90 };
91 }
92 err
93 }
94
95 pub fn with_expression(self, expr: String) -> Self {
97 EvalError::WithExpression {
98 err: Box::new(self),
99 expr,
100 }
101 }
102}
103
/// Named mathematical constants that expressions can use as leaf symbols.
pub mod constants {
    /// Archimedes' constant, π.
    pub const PI: f64 = std::f64::consts::PI;

    /// Euler's number, e.
    pub const E: f64 = std::f64::consts::E;

    /// Golden ratio, φ = (1 + √5) / 2, the positive root of x² = x + 1.
    pub const PHI: f64 = 1.618_033_988_749_895;

    /// Euler–Mascheroni constant, γ.
    pub const GAMMA: f64 = 0.577_215_664_901_532_9;

    /// Plastic number, the real root of x³ = x + 1.
    pub const PLASTIC: f64 = 1.324_717_957_244_746;

    /// Apéry's constant, ζ(3).
    pub const APERY: f64 = 1.202_056_903_159_594_2;

    /// Catalan's constant, G.
    pub const CATALAN: f64 = 0.915_965_594_177_219;
}
118
/// Default multiplier applied to the argument of the `SinPi`, `CosPi` and
/// `TanPi` operators. With this value they compute `sin(πx)`, `cos(πx)` and
/// `tan(πx)`.
pub const DEFAULT_TRIG_ARGUMENT_SCALE: f64 = std::f64::consts::PI;
123
124#[derive(Clone, Copy, Debug)]
129pub struct EvalContext<'a> {
130 pub trig_argument_scale: f64,
132 pub user_constants: &'a [UserConstant],
134 pub user_functions: &'a [UserFunction],
136}
137
138impl Default for EvalContext<'static> {
139 fn default() -> Self {
140 Self {
141 trig_argument_scale: DEFAULT_TRIG_ARGUMENT_SCALE,
142 user_constants: &[],
143 user_functions: &[],
144 }
145 }
146}
147
148impl EvalContext<'static> {
149 pub fn new() -> Self {
151 Self::default()
152 }
153}
154
155impl<'a> EvalContext<'a> {
156 pub fn from_slices(
158 user_constants: &'a [UserConstant],
159 user_functions: &'a [UserFunction],
160 ) -> Self {
161 Self {
162 trig_argument_scale: DEFAULT_TRIG_ARGUMENT_SCALE,
163 user_constants,
164 user_functions,
165 }
166 }
167
168 pub fn with_trig_argument_scale(mut self, scale: f64) -> Self {
170 if scale.is_finite() && scale != 0.0 {
171 self.trig_argument_scale = scale;
172 }
173 self
174 }
175}
176
177#[derive(Debug, Clone, Copy)]
179struct StackEntry {
180 val: f64,
181 deriv: f64,
182 num_type: NumType,
183}
184
185impl StackEntry {
186 fn new(val: f64, deriv: f64, num_type: NumType) -> Self {
187 Self {
188 val,
189 deriv,
190 num_type,
191 }
192 }
193
194 fn constant(val: f64, num_type: NumType) -> Self {
195 Self {
196 val,
197 deriv: 0.0,
198 num_type,
199 }
200 }
201}
202
203pub struct EvalWorkspace {
223 stack: Vec<StackEntry>,
224}
225
226impl EvalWorkspace {
227 pub fn new() -> Self {
231 Self {
232 stack: Vec::with_capacity(32),
233 }
234 }
235
236 #[inline]
238 fn clear(&mut self) {
239 self.stack.clear();
240 }
241}
242
243impl Default for EvalWorkspace {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
249#[inline]
257pub fn evaluate_with_workspace(
258 expr: &Expression,
259 x: f64,
260 workspace: &mut EvalWorkspace,
261) -> Result<EvalResult, EvalError> {
262 evaluate_with_workspace_and_context(expr, x, workspace, &EvalContext::new())
263}
264
265#[inline]
273pub fn evaluate_with_workspace_and_constants(
274 expr: &Expression,
275 x: f64,
276 workspace: &mut EvalWorkspace,
277 user_constants: &[UserConstant],
278) -> Result<EvalResult, EvalError> {
279 let context = EvalContext::from_slices(user_constants, &[]);
280 evaluate_with_workspace_and_context(expr, x, workspace, &context)
281}
282
283#[inline]
289pub fn evaluate_with_workspace_and_constants_and_functions(
290 expr: &Expression,
291 x: f64,
292 workspace: &mut EvalWorkspace,
293 user_constants: &[UserConstant],
294 user_functions: &[UserFunction],
295) -> Result<EvalResult, EvalError> {
296 let context = EvalContext::from_slices(user_constants, user_functions);
297 evaluate_with_workspace_and_context(expr, x, workspace, &context)
298}
299
300#[inline]
305pub fn evaluate_with_workspace_and_context(
306 expr: &Expression,
307 x: f64,
308 workspace: &mut EvalWorkspace,
309 context: &EvalContext<'_>,
310) -> Result<EvalResult, EvalError> {
311 workspace.clear();
312 let stack = &mut workspace.stack;
313
314 for &sym in expr.symbols() {
315 match sym.seft() {
316 Seft::A => {
317 let entry = eval_constant_with_user(sym, x, context.user_constants)?;
318 stack.push(entry);
319 }
320 Seft::B => {
321 if matches!(
323 sym,
324 Symbol::UserFunction0
325 | Symbol::UserFunction1
326 | Symbol::UserFunction2
327 | Symbol::UserFunction3
328 | Symbol::UserFunction4
329 | Symbol::UserFunction5
330 | Symbol::UserFunction6
331 | Symbol::UserFunction7
332 | Symbol::UserFunction8
333 | Symbol::UserFunction9
334 | Symbol::UserFunction10
335 | Symbol::UserFunction11
336 | Symbol::UserFunction12
337 | Symbol::UserFunction13
338 | Symbol::UserFunction14
339 | Symbol::UserFunction15
340 ) {
341 let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
343 let result = eval_user_function(sym, a, context, x)?;
344 stack.push(result);
345 } else {
346 let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
347 let result = eval_unary(sym, a, context.trig_argument_scale)?;
348 stack.push(result);
349 }
350 }
351 Seft::C => {
352 let b = stack.pop().ok_or(EvalError::StackUnderflow)?;
353 let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
354 let result = eval_binary(sym, a, b)?;
355 stack.push(result);
356 }
357 }
358 }
359
360 if stack.len() != 1 {
361 return Err(EvalError::Invalid);
362 }
363
364 let result = stack.pop().unwrap();
366
367 if result.val.is_nan() || result.val.is_infinite() {
369 return Err(EvalError::Overflow);
370 }
371
372 Ok(EvalResult {
373 value: result.val,
374 derivative: result.deriv,
375 num_type: result.num_type,
376 })
377}
378
379pub fn evaluate(expr: &Expression, x: f64) -> Result<EvalResult, EvalError> {
387 evaluate_with_context(expr, x, &EvalContext::new())
388}
389
390pub fn evaluate_with_constants(
394 expr: &Expression,
395 x: f64,
396 user_constants: &[UserConstant],
397) -> Result<EvalResult, EvalError> {
398 let context = EvalContext::from_slices(user_constants, &[]);
399 evaluate_with_context(expr, x, &context)
400}
401
402pub fn evaluate_with_constants_and_functions(
406 expr: &Expression,
407 x: f64,
408 user_constants: &[UserConstant],
409 user_functions: &[UserFunction],
410) -> Result<EvalResult, EvalError> {
411 let context = EvalContext::from_slices(user_constants, user_functions);
412 evaluate_with_context(expr, x, &context)
413}
414
415pub fn evaluate_with_context(
419 expr: &Expression,
420 x: f64,
421 context: &EvalContext<'_>,
422) -> Result<EvalResult, EvalError> {
423 let mut workspace = EvalWorkspace::new();
424 evaluate_with_workspace_and_context(expr, x, &mut workspace, context)
425}
426
427#[inline]
436pub fn evaluate_fast(expr: &Expression, x: f64) -> Result<EvalResult, EvalError> {
437 evaluate_fast_with_context(expr, x, &EvalContext::new())
438}
439
440#[inline]
448pub fn evaluate_fast_with_constants(
449 expr: &Expression,
450 x: f64,
451 user_constants: &[UserConstant],
452) -> Result<EvalResult, EvalError> {
453 let context = EvalContext::from_slices(user_constants, &[]);
454 evaluate_fast_with_context(expr, x, &context)
455}
456
457#[inline]
492pub fn evaluate_fast_with_constants_and_functions(
493 expr: &Expression,
494 x: f64,
495 user_constants: &[UserConstant],
496 user_functions: &[UserFunction],
497) -> Result<EvalResult, EvalError> {
498 let context = EvalContext::from_slices(user_constants, user_functions);
499 evaluate_fast_with_context(expr, x, &context)
500}
501
502#[inline]
504pub fn evaluate_fast_with_context(
505 expr: &Expression,
506 x: f64,
507 context: &EvalContext<'_>,
508) -> Result<EvalResult, EvalError> {
509 thread_local! {
510 static WORKSPACE: std::cell::RefCell<EvalWorkspace> = std::cell::RefCell::new(EvalWorkspace::new());
516 }
517
518 WORKSPACE.with(|ws| {
519 let mut workspace = ws.borrow_mut();
520 evaluate_with_workspace_and_context(expr, x, &mut workspace, context)
521 })
522}
523
524fn eval_constant_with_user(
526 sym: Symbol,
527 x: f64,
528 user_constants: &[UserConstant],
529) -> Result<StackEntry, EvalError> {
530 use Symbol::*;
531 match sym {
532 One => Ok(StackEntry::constant(1.0, NumType::Integer)),
533 Two => Ok(StackEntry::constant(2.0, NumType::Integer)),
534 Three => Ok(StackEntry::constant(3.0, NumType::Integer)),
535 Four => Ok(StackEntry::constant(4.0, NumType::Integer)),
536 Five => Ok(StackEntry::constant(5.0, NumType::Integer)),
537 Six => Ok(StackEntry::constant(6.0, NumType::Integer)),
538 Seven => Ok(StackEntry::constant(7.0, NumType::Integer)),
539 Eight => Ok(StackEntry::constant(8.0, NumType::Integer)),
540 Nine => Ok(StackEntry::constant(9.0, NumType::Integer)),
541 Pi => Ok(StackEntry::constant(constants::PI, NumType::Transcendental)),
542 E => Ok(StackEntry::constant(constants::E, NumType::Transcendental)),
543 Phi => Ok(StackEntry::constant(constants::PHI, NumType::Algebraic)),
544 Gamma => Ok(StackEntry::constant(
546 constants::GAMMA,
547 NumType::Transcendental,
548 )),
549 Plastic => Ok(StackEntry::constant(constants::PLASTIC, NumType::Algebraic)),
550 Apery => Ok(StackEntry::constant(
551 constants::APERY,
552 NumType::Transcendental,
553 )),
554 Catalan => Ok(StackEntry::constant(
555 constants::CATALAN,
556 NumType::Transcendental,
557 )),
558 X => Ok(StackEntry::new(x, 1.0, NumType::Integer)), UserConstant0 | UserConstant1 | UserConstant2 | UserConstant3 | UserConstant4
561 | UserConstant5 | UserConstant6 | UserConstant7 | UserConstant8 | UserConstant9
562 | UserConstant10 | UserConstant11 | UserConstant12 | UserConstant13 | UserConstant14
563 | UserConstant15 => {
564 let idx = sym.user_constant_index().unwrap() as usize;
566 user_constants
567 .get(idx)
568 .map(|uc| StackEntry::constant(uc.value, uc.num_type))
569 .ok_or(EvalError::MissingUserConstant(idx))
570 }
571 _ => Err(EvalError::Invalid),
572 }
573}
574
575fn eval_user_function(
580 sym: Symbol,
581 input: StackEntry,
582 context: &EvalContext<'_>,
583 x: f64,
584) -> Result<StackEntry, EvalError> {
585 let idx = sym.user_function_index().ok_or(EvalError::Invalid)? as usize;
587
588 let udf = context.user_functions.get(idx).ok_or(EvalError::Invalid)?;
590
591 thread_local! {
596 static UDF_STACK: std::cell::RefCell<Vec<StackEntry>> =
597 std::cell::RefCell::new(Vec::with_capacity(16));
598 }
599
600 UDF_STACK.with(|cell| -> Result<StackEntry, EvalError> {
601 let mut stack = cell.borrow_mut();
602 stack.clear();
603 stack.push(input);
604
605 for op in &udf.body {
607 match op {
608 UdfOp::Symbol(sym) => {
609 match sym.seft() {
610 Seft::A => {
611 let entry = eval_constant_with_user(*sym, x, context.user_constants)?;
613 stack.push(entry);
614 }
615 Seft::B => {
616 let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
618 let result = eval_unary(*sym, a, context.trig_argument_scale)?;
619 stack.push(result);
620 }
621 Seft::C => {
622 let b = stack.pop().ok_or(EvalError::StackUnderflow)?;
624 let a = stack.pop().ok_or(EvalError::StackUnderflow)?;
625 let result = eval_binary(*sym, a, b)?;
626 stack.push(result);
627 }
628 }
629 }
630 UdfOp::Dup => {
631 let top = *stack.last().ok_or(EvalError::StackUnderflow)?;
634 stack.push(top);
635 }
636 UdfOp::Swap => {
637 let len = stack.len();
639 if len < 2 {
640 return Err(EvalError::StackUnderflow);
641 }
642 stack.swap(len - 1, len - 2);
643 }
644 }
645 }
646
647 if stack.len() != 1 {
649 return Err(EvalError::Invalid);
650 }
651
652 let result = stack.pop().unwrap();
654
655 if result.val.is_nan() || result.val.is_infinite() {
657 return Err(EvalError::Overflow);
658 }
659
660 Ok(result)
661 })
662}
663
664fn eval_unary(
666 sym: Symbol,
667 a: StackEntry,
668 trig_argument_scale: f64,
669) -> Result<StackEntry, EvalError> {
670 use Symbol::*;
671
672 let (val, deriv, num_type) = match sym {
673 Neg => (-a.val, -a.deriv, a.num_type),
675
676 Recip => {
678 if a.val.abs() < f64::MIN_POSITIVE {
679 return Err(EvalError::DivisionByZero);
680 }
681 let val = 1.0 / a.val;
682 let deriv = -a.deriv / (a.val * a.val);
683 let num_type = if a.num_type == NumType::Integer {
684 NumType::Rational
685 } else {
686 a.num_type
687 };
688 (val, deriv, num_type)
689 }
690
691 Sqrt => {
693 if a.val < 0.0 {
694 return Err(EvalError::SqrtDomain);
695 }
696 let val = a.val.sqrt();
697 let deriv = if val.abs() > f64::MIN_POSITIVE {
698 a.deriv / (2.0 * val)
699 } else {
700 0.0
701 };
702 let num_type = if a.num_type >= NumType::Constructible {
703 NumType::Constructible
704 } else {
705 a.num_type
706 };
707 (val, deriv, num_type)
708 }
709
710 Square => {
712 let val = a.val * a.val;
713 let deriv = 2.0 * a.val * a.deriv;
714 (val, deriv, a.num_type)
715 }
716
717 Ln => {
719 if a.val <= 0.0 {
720 return Err(EvalError::LogDomain);
721 }
722 let val = a.val.ln();
723 let deriv = a.deriv / a.val;
724 (val, deriv, NumType::Transcendental)
725 }
726
727 Exp => {
729 let val = a.val.exp();
730 if val.is_infinite() {
731 return Err(EvalError::Overflow);
732 }
733 let deriv = val * a.deriv;
734 (val, deriv, NumType::Transcendental)
735 }
736
737 SinPi => {
739 let val = (trig_argument_scale * a.val).sin();
740 let deriv = trig_argument_scale * (trig_argument_scale * a.val).cos() * a.deriv;
741 (val, deriv, NumType::Transcendental)
742 }
743
744 CosPi => {
746 let val = (trig_argument_scale * a.val).cos();
747 let deriv = -trig_argument_scale * (trig_argument_scale * a.val).sin() * a.deriv;
748 (val, deriv, NumType::Transcendental)
749 }
750
751 TanPi => {
753 let cos_val = (trig_argument_scale * a.val).cos();
754 if cos_val.abs() < 1e-10 {
755 return Err(EvalError::Overflow);
756 }
757 let val = (trig_argument_scale * a.val).tan();
758 let deriv = trig_argument_scale * a.deriv / (cos_val * cos_val);
759 (val, deriv, NumType::Transcendental)
760 }
761
762 LambertW => {
764 let val = lambert_w(a.val)?;
765 let deriv = if a.val.abs() < 1e-10 {
768 a.deriv } else {
770 let denom = a.val * (1.0 + val);
771 if denom.abs() > f64::MIN_POSITIVE {
772 val / denom * a.deriv
773 } else {
774 0.0
775 }
776 };
777 (val, deriv, NumType::Transcendental)
778 }
779
780 UserFunction0 | UserFunction1 | UserFunction2 | UserFunction3 | UserFunction4
783 | UserFunction5 | UserFunction6 | UserFunction7 | UserFunction8 | UserFunction9
784 | UserFunction10 | UserFunction11 | UserFunction12 | UserFunction13 | UserFunction14
785 | UserFunction15 => {
786 return Err(EvalError::Invalid);
789 }
790
791 _ => return Err(EvalError::Invalid),
793 };
794
795 Ok(StackEntry::new(val, deriv, num_type))
796}
797
798fn eval_binary(sym: Symbol, a: StackEntry, b: StackEntry) -> Result<StackEntry, EvalError> {
800 use Symbol::*;
801
802 let (val, deriv, num_type) = match sym {
803 Add => {
805 let val = a.val + b.val;
806 let deriv = a.deriv + b.deriv;
807 let num_type = a.num_type.combine(b.num_type);
808 (val, deriv, num_type)
809 }
810
811 Sub => {
813 let val = a.val - b.val;
814 let deriv = a.deriv - b.deriv;
815 let num_type = a.num_type.combine(b.num_type);
816 (val, deriv, num_type)
817 }
818
819 Mul => {
821 let val = a.val * b.val;
822 let deriv = a.val * b.deriv + b.val * a.deriv;
823 let num_type = a.num_type.combine(b.num_type);
824 (val, deriv, num_type)
825 }
826
827 Div => {
829 if b.val.abs() < f64::MIN_POSITIVE {
830 return Err(EvalError::DivisionByZero);
831 }
832 let val = a.val / b.val;
833 let deriv = (b.val * a.deriv - a.val * b.deriv) / (b.val * b.val);
834 let mut num_type = a.num_type.combine(b.num_type);
835 if num_type == NumType::Integer {
836 num_type = NumType::Rational;
837 }
838 (val, deriv, num_type)
839 }
840
841 Pow => {
843 if a.val <= 0.0 && b.val.fract() != 0.0 {
844 return Err(EvalError::SqrtDomain);
845 }
846 let val = a.val.powf(b.val);
847 if val.is_infinite() || val.is_nan() {
848 return Err(EvalError::Overflow);
849 }
850 let deriv = if a.val > f64::MIN_POSITIVE {
852 val * (b.val * a.deriv / a.val + a.val.ln() * b.deriv)
853 } else if a.val.abs() < f64::MIN_POSITIVE && b.val > 0.0 {
854 0.0
855 } else {
856 if a.val.abs() < f64::MIN_POSITIVE {
864 0.0
865 } else {
866 val * b.val * a.deriv / a.val
867 }
868 };
869 let num_type = if b.num_type == NumType::Integer {
870 a.num_type
871 } else {
872 NumType::Transcendental
873 };
874 (val, deriv, num_type)
875 }
876
877 Root => {
879 if a.val.abs() < f64::MIN_POSITIVE {
880 return Err(EvalError::DivisionByZero);
881 }
882 let exp = 1.0 / a.val;
883
884 if b.val < 0.0 {
887 let rounded = a.val.round();
889 let is_integer = (a.val - rounded).abs() < 1e-10;
890
891 if !is_integer {
892 return Err(EvalError::SqrtDomain);
894 }
895
896 let int_val = rounded as i64;
898 if int_val % 2 == 0 {
899 return Err(EvalError::SqrtDomain);
901 }
902 }
904
905 let val = if b.val < 0.0 {
906 -((-b.val).powf(exp))
908 } else {
909 b.val.powf(exp)
910 };
911 if val.is_infinite() || val.is_nan() {
912 return Err(EvalError::Overflow);
913 }
914 let deriv = if b.val.abs() > f64::MIN_POSITIVE {
916 val * (b.deriv / (a.val * b.val) - b.val.abs().ln() * a.deriv / (a.val * a.val))
917 } else {
918 0.0
919 };
920 (val, deriv, NumType::Algebraic)
921 }
922
923 Log => {
925 if a.val <= 0.0 || a.val == 1.0 || b.val <= 0.0 {
926 return Err(EvalError::LogDomain);
927 }
928 let ln_a = a.val.ln();
929 let ln_b = b.val.ln();
930 let val = ln_b / ln_a;
931 let deriv = b.deriv / (b.val * ln_a) - ln_b * a.deriv / (a.val * ln_a * ln_a);
933 (val, deriv, NumType::Transcendental)
934 }
935
936 Atan2 => {
938 let val = a.val.atan2(b.val);
939 let denom = a.val * a.val + b.val * b.val;
941 let deriv = if denom.abs() > f64::MIN_POSITIVE {
942 (b.val * a.deriv - a.val * b.deriv) / denom
943 } else {
944 0.0
945 };
946 (val, deriv, NumType::Transcendental)
947 }
948
949 _ => return Err(EvalError::Invalid),
951 };
952
953 Ok(StackEntry::new(val, deriv, num_type))
954}
955
956fn lambert_w(x: f64) -> Result<f64, EvalError> {
961 const INV_E: f64 = 1.0 / std::f64::consts::E;
963 const NEG_INV_E: f64 = -INV_E; if x < NEG_INV_E {
967 return Err(EvalError::LogDomain);
968 }
969
970 if x == 0.0 {
972 return Ok(0.0); }
974 if (x - NEG_INV_E).abs() < 1e-15 {
975 return Ok(-1.0); }
977 if x == constants::E {
978 return Ok(1.0); }
980
981 let mut w = if x < -0.3 {
983 let p = (2.0 * (constants::E * x + 1.0)).sqrt();
986 -1.0 + p * (1.0 - p / 3.0 * (1.0 - 11.0 * p / 72.0))
987 } else if x < 0.25 {
988 let x2 = x * x;
992 x * (1.0 - x + x2 * (1.5 - 2.6667 * x))
993 } else if x < 4.0 {
994 let lnx = x.ln();
997 if lnx > 0.0 {
998 let lnlnx = lnx.ln().max(0.0);
999 lnx - lnlnx + lnlnx / lnx.max(1.0)
1000 } else {
1001 x }
1003 } else {
1004 let l1 = x.ln();
1006 let l2 = l1.ln();
1007 l1 - l2 + l2 / l1
1008 };
1009
1010 for _ in 0..25 {
1013 let ew = w.exp();
1014
1015 if !ew.is_finite() {
1017 w = x.ln() - w.ln().max(1e-10);
1019 continue;
1020 }
1021
1022 let wew = w * ew;
1023 let diff = wew - x;
1024
1025 let tol = 1e-15 * (1.0 + w.abs().max(x.abs()));
1027 if diff.abs() < tol {
1028 break;
1029 }
1030
1031 let w1 = w + 1.0;
1032 let denom = ew * w1 - 0.5 * (w + 2.0) * diff / w1;
1034 if denom.abs() < f64::MIN_POSITIVE {
1035 break;
1036 }
1037
1038 let delta = diff / denom;
1039
1040 let correction = if w < -0.5 && delta.abs() > 0.5 {
1042 delta * 0.5 } else {
1044 delta
1045 };
1046
1047 w -= correction;
1048 }
1049
1050 if !w.is_finite() {
1052 return Err(EvalError::Overflow);
1053 }
1054
1055 Ok(w)
1056}
1057
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10
    }

    #[test]
    fn test_basic_eval() {
        let expr = Expression::parse("32+").unwrap();
        let result = evaluate(&expr, 0.0).unwrap();
        assert!(approx_eq(result.value, 5.0));
        assert!(approx_eq(result.derivative, 0.0));
    }

    #[test]
    fn test_variable() {
        let expr = Expression::parse("x").unwrap();
        let result = evaluate(&expr, 3.5).unwrap();
        assert!(approx_eq(result.value, 3.5));
        assert!(approx_eq(result.derivative, 1.0));
    }

    #[test]
    fn test_x_squared() {
        // x^2 at x = 3: value 9, derivative 2x = 6.
        let expr = Expression::parse("xs").unwrap();
        let result = evaluate(&expr, 3.0).unwrap();
        assert!(approx_eq(result.value, 9.0));
        assert!(approx_eq(result.derivative, 6.0));
    }

    #[test]
    fn test_sqrt_pi() {
        let expr = Expression::parse("pq").unwrap();
        let result = evaluate(&expr, 0.0).unwrap();
        assert!(approx_eq(result.value, constants::PI.sqrt()));
    }

    #[test]
    fn test_e_to_x() {
        // e^x at x = 1: value and derivative are both e.
        let expr = Expression::parse("xE").unwrap();
        let result = evaluate(&expr, 1.0).unwrap();
        assert!(approx_eq(result.value, constants::E));
        assert!(approx_eq(result.derivative, constants::E));
    }

    #[test]
    fn test_complex_expr() {
        // x^2 + 2x + 1 at x = 3: value 16, derivative 2x + 2 = 8.
        let expr = Expression::parse("xs2x*+1+").unwrap();
        let result = evaluate(&expr, 3.0).unwrap();
        assert!(approx_eq(result.value, 16.0));
        assert!(approx_eq(result.derivative, 8.0));
    }

    #[test]
    fn test_lambert_w() {
        // W(1) is the omega constant.
        let w = lambert_w(1.0).unwrap();
        assert!((w - 0.5671432904).abs() < 1e-9);

        let w = lambert_w(constants::E).unwrap();
        assert!((w - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_user_constant_evaluation() {
        use crate::profile::UserConstant;

        let user_constants = vec![UserConstant {
            weight: 8,
            name: "g".to_string(),
            description: "gamma".to_string(),
            value: 0.5772156649,
            num_type: NumType::Transcendental,
        }];

        let expr = Expression::from_symbols(&[Symbol::UserConstant0]);
        let result = evaluate_with_constants(&expr, 0.0, &user_constants).unwrap();

        assert!(approx_eq(result.value, 0.5772156649));
        // A constant does not depend on x.
        assert!(approx_eq(result.derivative, 0.0));
    }

    #[test]
    fn test_user_constant_in_expression() {
        use crate::profile::UserConstant;

        let user_constants = vec![
            UserConstant {
                weight: 8,
                name: "a".to_string(),
                description: "constant a".to_string(),
                value: 2.0,
                num_type: NumType::Integer,
            },
            UserConstant {
                weight: 8,
                name: "b".to_string(),
                description: "constant b".to_string(),
                value: 3.0,
                num_type: NumType::Integer,
            },
        ];

        // a * x + b
        let expr = Expression::from_symbols(&[
            Symbol::UserConstant0,
            Symbol::X,
            Symbol::Mul,
            Symbol::UserConstant1,
            Symbol::Add,
        ]);

        let result = evaluate_with_constants(&expr, 4.0, &user_constants).unwrap();
        assert!(approx_eq(result.value, 11.0));
        assert!(approx_eq(result.derivative, 2.0));
    }

    #[test]
    fn test_user_constant_missing_returns_error() {
        let expr = Expression::from_symbols(&[Symbol::UserConstant0]);

        let result = evaluate_with_constants(&expr, 0.0, &[]);
        assert!(matches!(result, Err(EvalError::MissingUserConstant(0))));
    }

    #[test]
    fn test_user_function_sinh() {
        use crate::udf::UserFunction;

        // Body: (e^x - 1/e^x) / 2.
        let user_functions = vec![UserFunction::parse("4:sinh:hyperbolic sine:E|r-2/").unwrap()];

        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);

        let result =
            evaluate_with_constants_and_functions(&expr, 1.0, &[], &user_functions).unwrap();
        let expected = (constants::E - 1.0 / constants::E) / 2.0;
        assert!(approx_eq(result.value, expected));

        // d/dx sinh(x) = cosh(x).
        let expected_deriv = (constants::E + 1.0 / constants::E) / 2.0;
        assert!((result.derivative - expected_deriv).abs() < 1e-10);
    }

    #[test]
    fn test_user_function_xex() {
        use crate::udf::UserFunction;

        let user_functions = vec![UserFunction::parse("4:XeX:x*exp(x):|E*").unwrap()];

        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);

        let result =
            evaluate_with_constants_and_functions(&expr, 1.0, &[], &user_functions).unwrap();
        assert!(approx_eq(result.value, constants::E));

        // d/dx (x e^x) = (1 + x) e^x, which is 2e at x = 1.
        let expected_deriv = constants::E * 2.0;
        assert!((result.derivative - expected_deriv).abs() < 1e-10);
    }

    #[test]
    fn test_user_function_missing_returns_error() {
        let expr = Expression::from_symbols(&[Symbol::X, Symbol::UserFunction0]);

        // No function is configured in slot 0, which is reported as an invalid expression.
        let result = evaluate_with_constants_and_functions(&expr, 1.0, &[], &[]);
        assert!(matches!(result, Err(EvalError::Invalid)));
    }

    #[test]
    fn test_error_context_messages() {
        let err = EvalError::DivisionByZero.with_context(Some(3), None);
        assert_eq!(
            err.to_string(),
            "Division by zero: divisor was zero or near-zero at position 3"
        );

        let err = EvalError::Invalid.with_expression("x+".to_string());
        assert_eq!(
            err.to_string(),
            "Invalid expression: malformed or incomplete in expression 'x+'"
        );
    }

    #[test]
    fn test_invalid_trig_scale_is_ignored() {
        let ctx = EvalContext::new().with_trig_argument_scale(0.0);
        assert_eq!(ctx.trig_argument_scale, DEFAULT_TRIG_ARGUMENT_SCALE);

        let ctx = EvalContext::new().with_trig_argument_scale(f64::NAN);
        assert_eq!(ctx.trig_argument_scale, DEFAULT_TRIG_ARGUMENT_SCALE);

        let ctx = EvalContext::new().with_trig_argument_scale(1.0);
        assert_eq!(ctx.trig_argument_scale, 1.0);
    }
}