1extern crate alloc;
2use crate::Real;
3#[cfg(not(test))]
4use crate::Vec;
5use crate::context::EvalContext;
6use crate::error::ExprError;
7
8#[cfg(test)]
10use crate::abs;
11#[cfg(test)]
12use crate::cos;
13#[cfg(test)]
14use crate::max;
15#[cfg(test)]
16use crate::min;
17#[cfg(test)]
18use crate::neg;
19#[cfg(test)]
20use crate::pow;
21#[cfg(test)]
22use crate::sin;
23use crate::types::AstExpr;
25#[cfg(not(test))]
26use alloc::format;
27#[cfg(not(test))]
28use alloc::rc::Rc;
29#[cfg(test)]
30use std::rc::Rc;
31#[cfg(test)]
32use std::vec::Vec;
33
34use alloc::collections::BTreeMap;
35use alloc::string::{String, ToString};
36
/// Owned copy of a `crate::types::NativeFunction` that carries no borrowed lifetime, so it can
/// be stored inside [`FunctionCacheEntry`] for the duration of an evaluation.
///
/// - `arity`: the exact number of arguments a call must supply; callers compare it against the
///   argument count before invoking `implementation`.
/// - `implementation`: reference-counted handle to the callable, invoked with the already
///   evaluated argument values as a slice.
/// - `name` / `description`: owned copies of the function's metadata.
struct OwnedNativeFunction {
    pub arity: usize,
    pub implementation: Rc<dyn Fn(&[Real]) -> Real>,
    pub name: String, pub description: Option<String>,
}
50
51impl<'a> From<&crate::types::NativeFunction<'a>> for OwnedNativeFunction {
53 fn from(nf: &crate::types::NativeFunction<'a>) -> Self {
54 OwnedNativeFunction {
55 arity: nf.arity,
56 implementation: nf.implementation.clone(),
57 name: nf.name.to_string(), description: nf.description.clone(),
59 }
60 }
61}
62
/// A function definition resolved from the evaluation context, as stored in the per-evaluation
/// function cache (`BTreeMap<String, Option<FunctionCacheEntry>>`).
///
/// A cache value of `None` records that the context had no function registered under that name.
enum FunctionCacheEntry {
    /// A natively implemented function (obtained via `get_native_function`).
    Native(OwnedNativeFunction),
    /// An expression-defined function (obtained via `get_expression_function`).
    Expression(crate::types::ExpressionFunction),
    /// A user-defined function (obtained via `get_user_function`).
    User(crate::context::UserFunction),
}
68
69impl Clone for FunctionCacheEntry {
70 fn clone(&self) -> Self {
71 match self {
72 FunctionCacheEntry::Native(nf) => {
73 FunctionCacheEntry::Native(OwnedNativeFunction {
74 arity: nf.arity,
75 implementation: nf.implementation.clone(),
76 name: nf.name.clone(),
77 description: nf.description.clone(),
78 })
79 }
80 FunctionCacheEntry::Expression(ef) => FunctionCacheEntry::Expression(ef.clone()),
81 FunctionCacheEntry::User(uf) => FunctionCacheEntry::User(uf.clone()),
82 }
83 }
84}
85
/// Function-pointer signature taking two `Real` operands and returning a `Real`.
///
/// Only defined when built-in math is enabled. Nothing in the visible part of this file refers
/// to the alias, which is presumably why it carries `#[allow(dead_code)]`.
#[cfg(not(feature = "no-builtin-math"))]
#[allow(dead_code)]
type MathFunc = fn(Real, Real) -> Real;
89
90pub fn eval_ast<'a>(ast: &AstExpr, ctx: Option<Rc<EvalContext<'a>>>) -> Result<Real, ExprError> {
91 let mut func_cache: BTreeMap<String, Option<FunctionCacheEntry>> = BTreeMap::new();
93 let mut var_cache: BTreeMap<String, Real> = BTreeMap::new();
95
96 eval_ast_inner(ast, ctx, &mut func_cache, &mut var_cache)
97}
98
99fn eval_variable(name: &str, ctx: Option<Rc<EvalContext<'_>>>, var_cache: &mut BTreeMap<String, Real>) -> Result<Real, ExprError> {
100 if name == "pi" {
102 #[cfg(feature = "f32")]
103 return Ok(core::f32::consts::PI);
104 #[cfg(not(feature = "f32"))]
105 return Ok(core::f64::consts::PI);
106 } else if name == "e" {
107 #[cfg(feature = "f32")]
108 return Ok(core::f32::consts::E);
109 #[cfg(not(feature = "f32"))]
110 return Ok(core::f64::consts::E);
111 }
112
113 if let Some(val) = var_cache.get(name).copied() {
115 return Ok(val);
116 }
117
118 if let Some(ctx_ref) = ctx.as_deref() {
120 if let Some(val) = ctx_ref.variables.get(name) {
123 var_cache.insert(name.to_string(), *val);
125 return Ok(*val);
126 }
127
128 if let Some(val) = ctx_ref.constants.get(name) {
130 var_cache.insert(name.to_string(), *val);
131 return Ok(*val);
132 }
133
134 if let Some(val) = ctx_ref.get_variable(name) {
137 var_cache.insert(name.to_string(), val);
138 return Ok(val);
139 } else if let Some(val) = ctx_ref.get_constant(name) {
140 var_cache.insert(name.to_string(), val);
141 return Ok(val);
142 }
143 }
144
145 let is_potential_function_name = match name {
147 "sin" | "cos" | "tan" | "asin" | "acos" | "atan" | "atan2" | "sinh" | "cosh" | "tanh"
148 | "exp" | "log" | "log10" | "ln" | "sqrt" | "abs" | "ceil" | "floor" | "pow" | "neg"
149 | "," | "comma" | "+" | "-" | "*" | "/" | "%" | "^" | "max" | "min" => true,
150 _ => false,
151 };
152
153 if is_potential_function_name && name.len() > 1 {
154 return Err(ExprError::Syntax(format!(
155 "Function '{}' used without arguments",
156 name
157 )));
158 }
159
160 Err(ExprError::UnknownVariable {
161 name: name.to_string(),
162 })
163}
164
165fn eval_function<'a>(
166 name: &str,
167 args: &[AstExpr],
168 ctx: Option<Rc<EvalContext<'a>>>,
169 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>,
170 var_cache: &mut BTreeMap<String, Real>,
171) -> Result<Real, ExprError> {
172 if let Some(entry) = func_cache.get(name) {
174 if let Some(cached_fn) = entry {
175 return match cached_fn.clone() {
176 FunctionCacheEntry::User(user_fn) => {
177 eval_custom_function(name, args, ctx.clone(), func_cache, var_cache, &user_fn)
178 }
179 FunctionCacheEntry::Expression(expr_fn) => {
180 eval_custom_function(name, args, ctx.clone(), func_cache, var_cache, &expr_fn)
181 }
182 FunctionCacheEntry::Native(native_fn) => {
183 eval_native_function(name, args, ctx.clone(), func_cache, var_cache, &native_fn)
184 }
185 };
186 }
187 }
188
189 let entry = if let Some(ctx_ref) = ctx.as_ref() {
191 if let Some(expr_fn) = ctx_ref.get_expression_function(name) {
193 Some(FunctionCacheEntry::Expression(expr_fn.clone()))
195 } else if let Some(native_fn) = ctx_ref.get_native_function(name) {
196 let owned_fn = OwnedNativeFunction::from(native_fn);
198 Some(FunctionCacheEntry::Native(owned_fn))
199 } else if let Some(user_fn) = ctx_ref.get_user_function(name) {
200 Some(FunctionCacheEntry::User(user_fn.clone()))
202 } else {
203 None
204 }
205 } else {
206 None
207 };
208
209 func_cache.insert(name.to_string(), entry.clone());
211
212 if let Some(func_entry) = entry {
214 match func_entry {
215 FunctionCacheEntry::User(user_fn) => {
216 eval_custom_function(name, args, ctx.clone(), func_cache, var_cache, &user_fn)
217 }
218 FunctionCacheEntry::Expression(expr_fn) => {
219 eval_custom_function(name, args, ctx.clone(), func_cache, var_cache, &expr_fn)
220 }
221 FunctionCacheEntry::Native(native_fn) => {
222 eval_native_function(name, args, ctx.clone(), func_cache, var_cache, &native_fn)
223 }
224 }
225 } else {
226 #[cfg(not(feature = "no-builtin-math"))]
229 {
230 let single_arg_funcs = ["sin", "cos", "tan", "asin", "acos", "atan",
233 "sinh", "cosh", "tanh", "exp", "log", "ln", "log10", "sqrt",
234 "abs", "ceil", "floor", "neg"];
235
236 let two_arg_funcs = ["+", "-", "*", "/", "^", "pow", "max", "min",
237 "%", ",", "comma", "atan2"];
238
239 if single_arg_funcs.contains(&name) && args.len() != 1 {
241 return Err(ExprError::InvalidFunctionCall {
242 name: name.to_string(),
243 expected: 1,
244 found: args.len(),
245 });
246 }
247
248 if two_arg_funcs.contains(&name) && args.len() != 2 {
250 return Err(ExprError::InvalidFunctionCall {
251 name: name.to_string(),
252 expected: 2,
253 found: args.len(),
254 });
255 }
256
257 if args.len() == 1 {
259 let arg_val = eval_ast_inner(&args[0], ctx.clone(), func_cache, var_cache)?;
261 match name {
262 "sin" => {
263 #[cfg(feature = "f32")]
264 {
265 return Ok(libm::sinf(arg_val));
266 }
267 #[cfg(not(feature = "f32"))]
268 {
269 return Ok(libm::sin(arg_val));
270 }
271 }
272 "cos" => {
273 #[cfg(feature = "f32")]
274 {
275 return Ok(libm::cosf(arg_val));
276 }
277 #[cfg(not(feature = "f32"))]
278 {
279 return Ok(libm::cos(arg_val));
280 }
281 }
282 "tan" => {
283 #[cfg(feature = "f32")]
284 {
285 return Ok(libm::tanf(arg_val));
286 }
287 #[cfg(not(feature = "f32"))]
288 {
289 return Ok(libm::tan(arg_val));
290 }
291 }
292 "asin" => {
293 #[cfg(feature = "f32")]
294 {
295 return Ok(libm::asinf(arg_val));
296 }
297 #[cfg(not(feature = "f32"))]
298 {
299 return Ok(libm::asin(arg_val));
300 }
301 }
302 "acos" => {
303 #[cfg(feature = "f32")]
304 {
305 return Ok(libm::acosf(arg_val));
306 }
307 #[cfg(not(feature = "f32"))]
308 {
309 return Ok(libm::acos(arg_val));
310 }
311 }
312 "atan" => {
313 #[cfg(feature = "f32")]
314 {
315 return Ok(libm::atanf(arg_val));
316 }
317 #[cfg(not(feature = "f32"))]
318 {
319 return Ok(libm::atan(arg_val));
320 }
321 }
322 "sinh" => {
323 #[cfg(feature = "f32")]
324 {
325 return Ok(libm::sinhf(arg_val));
326 }
327 #[cfg(not(feature = "f32"))]
328 {
329 return Ok(libm::sinh(arg_val));
330 }
331 }
332 "cosh" => {
333 #[cfg(feature = "f32")]
334 {
335 return Ok(libm::coshf(arg_val));
336 }
337 #[cfg(not(feature = "f32"))]
338 {
339 return Ok(libm::cosh(arg_val));
340 }
341 }
342 "tanh" => {
343 #[cfg(feature = "f32")]
344 {
345 return Ok(libm::tanhf(arg_val));
346 }
347 #[cfg(not(feature = "f32"))]
348 {
349 return Ok(libm::tanh(arg_val));
350 }
351 }
352 "exp" => {
353 #[cfg(feature = "f32")]
354 {
355 return Ok(libm::expf(arg_val));
356 }
357 #[cfg(not(feature = "f32"))]
358 {
359 return Ok(libm::exp(arg_val));
360 }
361 }
362 "log" | "ln" => {
363 #[cfg(feature = "f32")]
364 {
365 return Ok(libm::logf(arg_val));
366 }
367 #[cfg(not(feature = "f32"))]
368 {
369 return Ok(libm::log(arg_val));
370 }
371 }
372 "log10" => {
373 #[cfg(feature = "f32")]
374 {
375 return Ok(libm::log10f(arg_val));
376 }
377 #[cfg(not(feature = "f32"))]
378 {
379 return Ok(libm::log10(arg_val));
380 }
381 }
382 "sqrt" => {
383 #[cfg(feature = "f32")]
384 {
385 return Ok(libm::sqrtf(arg_val));
386 }
387 #[cfg(not(feature = "f32"))]
388 {
389 return Ok(libm::sqrt(arg_val));
390 }
391 }
392 "abs" => return Ok(arg_val.abs()),
393 "ceil" => {
394 #[cfg(feature = "f32")]
395 {
396 return Ok(libm::ceilf(arg_val));
397 }
398 #[cfg(not(feature = "f32"))]
399 {
400 return Ok(libm::ceil(arg_val));
401 }
402 }
403 "floor" => {
404 #[cfg(feature = "f32")]
405 {
406 return Ok(libm::floorf(arg_val));
407 }
408 #[cfg(not(feature = "f32"))]
409 {
410 return Ok(libm::floor(arg_val));
411 }
412 }
413 "neg" => return Ok(-arg_val),
414 _ => {}
415 }
416 } else if args.len() == 2 {
417 let mut arg_vals = [0.0; 2];
419 arg_vals[0] = eval_ast_inner(&args[0], ctx.clone(), func_cache, var_cache)?;
420 arg_vals[1] = eval_ast_inner(&args[1], ctx.clone(), func_cache, var_cache)?;
421 match name {
422 "+" => return Ok(arg_vals[0] + arg_vals[1]),
423 "-" => return Ok(arg_vals[0] - arg_vals[1]),
424 "*" => return Ok(arg_vals[0] * arg_vals[1]),
425 "/" => return Ok(arg_vals[0] / arg_vals[1]),
426 "^" | "pow" => {
427 #[cfg(feature = "f32")]
428 {
429 return Ok(libm::powf(arg_vals[0], arg_vals[1]));
430 }
431 #[cfg(not(feature = "f32"))]
432 {
433 return Ok(libm::pow(arg_vals[0], arg_vals[1]));
434 }
435 }
436 "max" => return Ok(arg_vals[0].max(arg_vals[1])),
437 "min" => return Ok(arg_vals[0].min(arg_vals[1])),
438 "%" => return Ok(arg_vals[0] % arg_vals[1]),
439 "," | "comma" => return Ok(arg_vals[1]),
440 "atan2" => {
441 #[cfg(feature = "f32")]
442 {
443 return Ok(libm::atan2f(arg_vals[0], arg_vals[1]));
444 }
445 #[cfg(not(feature = "f32"))]
446 {
447 return Ok(libm::atan2(arg_vals[0], arg_vals[1]));
448 }
449 }
450 _ => {}
451 }
452 }
453 }
454
455 return Err(ExprError::UnknownFunction {
457 name: name.to_string(),
458 });
459 }
460}
461
/// Common view over function definitions whose body is an expression, so that
/// `eval_custom_function` can evaluate them uniformly.
pub trait CustomFunction {
    /// Returns the parameter names as an owned list; arguments are bound to them by position.
    fn params(&self) -> Vec<String>;
    /// Returns the source text of the function body.
    fn body_str(&self) -> String;
    /// Returns an already-parsed AST of the body, if the implementor keeps one.
    ///
    /// The default is `None`, in which case the evaluator parses `body_str()` instead.
    fn compiled_ast(&self) -> Option<&AstExpr> {
        None
    }
}
471
472impl CustomFunction for crate::context::UserFunction {
473 fn params(&self) -> Vec<String> {
474 self.params.clone()
475 }
476 fn body_str(&self) -> String {
477 self.body.clone()
478 }
479}
480
481impl CustomFunction for crate::types::ExpressionFunction {
482 fn params(&self) -> Vec<String> {
483 self.params.clone()
484 }
485 fn body_str(&self) -> String {
486 self.expression.clone()
487 }
488 fn compiled_ast(&self) -> Option<&AstExpr> {
489 Some(&self.compiled_ast)
490 }
491}
492
493fn eval_custom_function<'a, F>(
494 name: &str,
495 args: &[AstExpr],
496 ctx: Option<Rc<EvalContext<'a>>>,
497 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>,
498 var_cache: &mut BTreeMap<String, Real>,
499 func: &F,
500) -> Result<Real, ExprError>
501where
502 F: CustomFunction,
503{
504 if args.len() != func.params().len() {
505 return Err(ExprError::InvalidFunctionCall {
506 name: name.to_string(),
507 expected: func.params().len(),
508 found: args.len(),
509 });
510 }
511 let mut arg_values = Vec::with_capacity(args.len());
513 for arg in args {
514 let arg_val = eval_ast_inner(arg, ctx.clone(), func_cache, var_cache)?;
515 arg_values.push(arg_val);
516 }
517
518 let mut func_ctx = EvalContext::new();
520
521 for (i, param_name) in func.params().iter().enumerate() {
523 func_ctx.set_parameter(param_name, arg_values[i]);
524 }
525
526 if let Some(parent) = &ctx {
528 func_ctx.function_registry = parent.function_registry.clone();
529 }
530
531 let body_ast = if let Some(ast) = func.compiled_ast() {
535 ast.clone()
536 } else {
537 let param_names_str: Vec<String> = func.params().iter().map(|c| c.to_string()).collect();
538 crate::engine::parse_expression_with_reserved(&func.body_str(), Some(¶m_names_str))?
539 };
540
541 fn eval_custom_function_ast<'b>(
546 ast: &AstExpr,
547 func_ctx: &EvalContext<'b>,
548 global_ctx: Option<&Rc<EvalContext<'b>>>,
549 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>,
550 var_cache: &mut BTreeMap<String, Real>,
551 ) -> Result<Real, ExprError> {
552 match ast {
553 AstExpr::Constant(val) => Ok(*val),
554 AstExpr::Variable(name) => {
555 if let Some(val) = func_ctx.variables.get(name) {
557 return Ok(*val);
558 }
559
560 if let Some(ctx) = global_ctx {
562 return eval_variable(name, Some(ctx.clone()), var_cache);
563 }
564
565 Err(ExprError::UnknownVariable { name: name.to_string() })
567 },
568 AstExpr::Function { name, args } => {
569 let mut arg_values = Vec::with_capacity(args.len());
571
572 for arg in args {
574 let arg_val = eval_custom_function_ast(arg, func_ctx, global_ctx, func_cache, var_cache)?;
576 arg_values.push(arg_val);
577 }
578
579 if let Some(native_fn) = func_ctx.get_native_function(name) {
581 if arg_values.len() != native_fn.arity {
583 return Err(ExprError::InvalidFunctionCall {
584 name: name.to_string(),
585 expected: native_fn.arity,
586 found: arg_values.len(),
587 });
588 }
589 let owned_fn = OwnedNativeFunction::from(native_fn);
591 return Ok((owned_fn.implementation)(&arg_values));
592 } else if let Some(expr_fn) = func_ctx.get_expression_function(name) {
593 return eval_expression_function(name, &arg_values, expr_fn, global_ctx.cloned(), func_cache, var_cache);
595 } else {
596 let ast_args: Vec<AstExpr> = args.iter()
599 .zip(arg_values.iter())
600 .map(|(_, val)| AstExpr::Constant(*val))
601 .collect();
602
603 return eval_function(name, &ast_args, global_ctx.cloned(), func_cache, var_cache);
604 }
605 },
606 AstExpr::Array { name, index } => {
607 let idx_val = eval_custom_function_ast(index, func_ctx, global_ctx, func_cache, var_cache)? as usize;
608
609 if let Some(arr) = func_ctx.get_array(name) {
611 if idx_val < arr.len() {
612 return Ok(arr[idx_val]);
613 } else {
614 return Err(ExprError::ArrayIndexOutOfBounds {
615 name: name.to_string(),
616 index: idx_val,
617 len: arr.len(),
618 });
619 }
620 }
621
622 if let Some(global) = global_ctx {
624 if let Some(arr) = global.get_array(name) {
625 if idx_val < arr.len() {
626 return Ok(arr[idx_val]);
627 } else {
628 return Err(ExprError::ArrayIndexOutOfBounds {
629 name: name.to_string(),
630 index: idx_val,
631 len: arr.len(),
632 });
633 }
634 }
635 }
636
637 Err(ExprError::UnknownVariable { name: name.to_string() })
638 },
639 AstExpr::Attribute { base, attr } => {
640 if let Some(global) = global_ctx {
642 return eval_attribute(base, attr, Some(global.clone()));
643 }
644
645 Err(ExprError::AttributeNotFound {
646 base: base.to_string(),
647 attr: attr.to_string(),
648 })
649 }
650 }
651 }
652
653 fn eval_expression_function<'b>(
655 name: &str,
656 arg_values: &[Real],
657 expr_fn: &crate::types::ExpressionFunction,
658 global_ctx: Option<Rc<EvalContext<'b>>>,
659 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>,
660 var_cache: &mut BTreeMap<String, Real>,
661 ) -> Result<Real, ExprError> {
662 if arg_values.len() != expr_fn.params.len() {
663 return Err(ExprError::InvalidFunctionCall {
664 name: name.to_string(),
665 expected: expr_fn.params.len(),
666 found: arg_values.len(),
667 });
668 }
669
670 let mut nested_func_ctx = EvalContext::new();
672
673 for (i, param_name) in expr_fn.params.iter().enumerate() {
675 nested_func_ctx.set_parameter(param_name, arg_values[i]);
676 }
677
678 if let Some(parent) = &global_ctx {
680 nested_func_ctx.function_registry = parent.function_registry.clone();
681 }
682
683 eval_custom_function_ast(
685 &expr_fn.compiled_ast,
686 &nested_func_ctx,
687 global_ctx.as_ref(),
688 func_cache,
689 var_cache
690 )
691 }
692
693 #[cfg(test)]
695 if name == "polynomial" && arg_values.len() == 1 {
696 let x = arg_values[0];
697
698 let x_cubed = x * x * x;
700 let two_x_squared = 2.0 * x * x;
701 let three_x = 3.0 * x;
702 let expected = x_cubed + two_x_squared + three_x + 4.0;
703
704 #[cfg(test)]
706 {
707 eprintln!("Polynomial calculation breakdown for x={}:", x);
708 eprintln!(" x^3 = {}", x_cubed);
709 eprintln!(" 2*x^2 = {}", two_x_squared);
710 eprintln!(" 3*x = {}", three_x);
711 eprintln!(" 4 = 4");
712 eprintln!(" Total expected: {}", expected);
713
714 eprintln!("Function body string: {}", func.body_str());
716
717 eprintln!("AST structure for polynomial body:");
719 match &body_ast {
720 AstExpr::Function { name, args } => {
721 eprintln!("Top-level function: {}", name);
722 for (i, arg) in args.iter().enumerate() {
723 eprintln!(" Arg {}: {:?}", i, arg);
724
725 if let AstExpr::Function { name: inner_name, args: inner_args } = arg {
727 eprintln!(" Inner function: {}", inner_name);
728 for (j, inner_arg) in inner_args.iter().enumerate() {
729 eprintln!(" Inner arg {}: {:?}", j, inner_arg);
730 }
731 }
732 }
733 },
734 _ => eprintln!("Not a function at top level: {:?}", body_ast),
735 }
736
737 if let Some(x_val) = func_ctx.variables.get("x") {
739 eprintln!("Value of 'x' in function context: {}", x_val);
740 } else {
741 eprintln!("ERROR: 'x' not found in function context!");
742 }
743 }
744
745 let x_var = AstExpr::Variable("x".to_string());
747 let result_x = eval_custom_function_ast(&x_var, &func_ctx, ctx.as_ref(), func_cache, &mut BTreeMap::new());
748
749 #[cfg(test)]
750 eprintln!("Custom evaluating 'x': {:?}", result_x);
751
752 let x_cubed_ast = AstExpr::Function {
754 name: "^".to_string(),
755 args: alloc::vec![x_var.clone(), AstExpr::Constant(3.0)],
756 };
757 let result_x_cubed = eval_custom_function_ast(&x_cubed_ast, &func_ctx, ctx.as_ref(), func_cache, &mut BTreeMap::new());
758
759 #[cfg(test)]
760 eprintln!("Custom evaluating 'x^3': {:?}", result_x_cubed);
761 }
762
763 let result = eval_custom_function_ast(&body_ast, &func_ctx, ctx.as_ref(), func_cache, &mut BTreeMap::new());
765
766 #[cfg(test)]
768 if name == "polynomial" && arg_values.len() == 1 {
769 let x = arg_values[0];
770 let expected = x*x*x + 2.0*x*x + 3.0*x + 4.0;
771 eprintln!("polynomial({}) = {} (expected {})", x, result.as_ref().unwrap_or(&0.0), expected);
772 }
773
774 result
775}
776
777fn eval_native_function<'a>(
778 name: &str,
779 args: &[AstExpr],
780 ctx: Option<Rc<EvalContext<'a>>>,
781 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>,
782 var_cache: &mut BTreeMap<String, Real>,
783 native_fn: &OwnedNativeFunction,
784) -> Result<Real, ExprError> {
785 if args.len() != native_fn.arity {
786 return Err(ExprError::InvalidFunctionCall {
787 name: name.to_string(),
788 expected: native_fn.arity,
789 found: args.len(),
790 });
791 }
792 let mut arg_values = Vec::with_capacity(args.len());
793 for arg in args.iter() {
794 let arg_val = eval_ast_inner(arg, ctx.clone(), func_cache, var_cache)?;
795 arg_values.push(arg_val);
796 }
797 Ok((native_fn.implementation)(&arg_values))
798}
799
800fn eval_array<'a>(
801 name: &str,
802 index: &AstExpr,
803 ctx: Option<Rc<EvalContext<'a>>>,
804 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>,
805 var_cache: &mut BTreeMap<String, Real>,
806) -> Result<Real, ExprError> {
807 let idx = eval_ast_inner(index, ctx.clone(), func_cache, var_cache)? as usize;
808
809 if let Some(ctx_ref) = ctx.as_ref() {
810 if let Some(arr) = ctx_ref.get_array(name) {
811 if idx < arr.len() {
812 return Ok(arr[idx]);
813 } else {
814 return Err(ExprError::ArrayIndexOutOfBounds {
815 name: name.to_string(),
816 index: idx,
817 len: arr.len(),
818 });
819 }
820 }
821 }
822 Err(ExprError::UnknownVariable {
823 name: name.to_string(),
824 })
825}
826
827fn eval_attribute(
828 base: &str,
829 attr: &str,
830 ctx: Option<Rc<EvalContext<'_>>>,
831) -> Result<Real, ExprError> {
832 if let Some(ctx_ref) = ctx.as_ref() {
833 if let Some(attr_map) = ctx_ref.get_attribute_map(base) {
834 if let Some(val) = attr_map.get(attr) {
835 return Ok(*val);
836 } else {
837 return Err(ExprError::AttributeNotFound {
838 base: base.to_string(),
839 attr: attr.to_string(),
840 });
841 }
842 }
843 }
844 Err(ExprError::AttributeNotFound {
845 base: base.to_string(),
846 attr: attr.to_string(),
847 })
848}
849
850fn eval_ast_inner<'a>(
851 ast: &AstExpr,
852 ctx: Option<Rc<EvalContext<'a>>>,
853 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>,
854 var_cache: &mut BTreeMap<String, Real>,
855) -> Result<Real, ExprError> {
856 match ast {
857 AstExpr::Constant(val) => Ok(*val),
858 AstExpr::Variable(name) => eval_variable(name, ctx.clone(), var_cache),
859 AstExpr::Function { name, args } => {
860 eval_function(name, args, ctx.clone(), func_cache, var_cache)
861 }
862 AstExpr::Array { name, index } => eval_array(name, index, ctx.clone(), func_cache, var_cache),
863 AstExpr::Attribute { base, attr } => eval_attribute(base, attr, ctx),
864 }
865}
866
867#[cfg(test)]
868mod tests {
869 use super::*;
870 use crate::engine::{interp, parse_expression};
871
872 fn test_eval_variable(name: &str, ctx: Option<Rc<EvalContext>>) -> Result<Real, ExprError> {
874 let mut var_cache = BTreeMap::new();
875 super::eval_variable(name, ctx, &mut var_cache)
876 }
877
878 fn test_eval_function(
879 name: &str,
880 args: &[AstExpr],
881 ctx: Option<Rc<EvalContext>>,
882 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>
883 ) -> Result<Real, ExprError> {
884 let mut var_cache = BTreeMap::new();
885 super::eval_function(name, args, ctx, func_cache, &mut var_cache)
886 }
887
888 fn test_eval_array(
889 name: &str,
890 index: &AstExpr,
891 ctx: Option<Rc<EvalContext>>,
892 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>
893 ) -> Result<Real, ExprError> {
894 let mut var_cache = BTreeMap::new();
895 super::eval_array(name, index, ctx, func_cache, &mut var_cache)
896 }
897
898 fn test_eval_custom_function<F>(
899 name: &str,
900 args: &[AstExpr],
901 ctx: Option<Rc<EvalContext>>,
902 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>,
903 func: &F
904 ) -> Result<Real, ExprError>
905 where
906 F: super::CustomFunction
907 {
908 let mut var_cache = BTreeMap::new();
909 super::eval_custom_function(name, args, ctx, func_cache, &mut var_cache, func)
910 }
911
912 fn test_eval_native_function(
913 name: &str,
914 args: &[AstExpr],
915 ctx: Option<Rc<EvalContext>>,
916 func_cache: &mut BTreeMap<String, Option<FunctionCacheEntry>>,
917 native_fn: &OwnedNativeFunction
918 ) -> Result<Real, ExprError> {
919 let mut var_cache = BTreeMap::new();
920 super::eval_native_function(name, args, ctx, func_cache, &mut var_cache, native_fn)
921 }
922 use crate::error::ExprError;
923
924 #[test]
925 fn test_eval_user_function_polynomial() {
926 let mut ctx = EvalContext::new();
927 ctx.register_expression_function("polynomial", &["x"], "x^3 + 2*x^2 + 3*x + 4")
928 .unwrap();
929 let mut func_cache = std::collections::BTreeMap::new();
930 let _ast = AstExpr::Function {
931 name: "polynomial".to_string(),
932 args: vec![AstExpr::Constant(3.0)],
933 };
934 let expr_fn = ctx.get_expression_function("polynomial").unwrap().clone();
936 let val = test_eval_custom_function(
937 "polynomial",
938 &[AstExpr::Constant(3.0)],
939 Some(Rc::new(ctx.clone())),
940 &mut func_cache,
941 &expr_fn,
942 )
943 .unwrap();
944 assert_eq!(val, 58.0); }
946
947 #[test]
948 fn test_eval_expression_function_simple() {
949 let mut ctx = EvalContext::new();
950 ctx.register_expression_function("double", &["x"], "x*2")
951 .unwrap();
952 let mut func_cache = std::collections::BTreeMap::new();
953 let expr_fn = ctx.get_expression_function("double").unwrap().clone();
955 let val = test_eval_custom_function(
956 "double",
957 &[AstExpr::Constant(7.0)],
958 Some(Rc::new(ctx.clone())),
959 &mut func_cache,
960 &expr_fn,
961 )
962 .unwrap();
963 assert_eq!(val, 14.0);
964 }
965
966 #[test]
967 fn test_eval_native_function_simple() {
968 let mut ctx = EvalContext::new();
969 ctx.register_native_function("triple", 1, |args| args[0] * 3.0);
970 let mut func_cache = std::collections::BTreeMap::new();
971 let native_fn = {
973 let nf = ctx.function_registry.native_functions.get("triple").unwrap();
975 OwnedNativeFunction {
976 arity: nf.arity,
977 implementation: nf.implementation.clone(),
978 name: nf.name.to_string(), description: nf.description.clone(),
980 }
981 };
982 let val = test_eval_native_function(
985 "triple",
986 &[AstExpr::Constant(4.0)],
987 Some(Rc::new(ctx.clone())),
988 &mut func_cache,
989 &native_fn,
990 )
991 .unwrap();
992 assert_eq!(val, 12.0);
993 }
994
995 fn create_test_context<'a>() -> EvalContext<'a> {
997 let mut ctx = EvalContext::new();
998 #[cfg(not(feature = "no-builtin-math"))]
1000 {
1001 ctx.register_native_function("sin", 1, |args| sin(args[0], 0.0));
1004 ctx.register_native_function("cos", 1, |args| cos(args[0], 0.0));
1005 ctx.register_native_function("pow", 2, |args| pow(args[0], args[1]));
1006 ctx.register_native_function("^", 2, |args| pow(args[0], args[1]));
1007 ctx.register_native_function("min", 2, |args| min(args[0], args[1]));
1008 ctx.register_native_function("max", 2, |args| max(args[0], args[1]));
1009 ctx.register_native_function("neg", 1, |args| neg(args[0], 0.0));
1010 ctx.register_native_function("abs", 1, |args| abs(args[0], 0.0));
1011 }
1013 ctx
1014 }
1015
1016 #[test]
1017 fn test_eval_variable_builtin_constants() {
1018 #[cfg(feature = "f32")]
1020 {
1021 assert!((test_eval_variable("pi", None).unwrap() - std::f32::consts::PI).abs() < 1e-5);
1022 assert!((test_eval_variable("e", None).unwrap() - std::f32::consts::E).abs() < 1e-5);
1023 }
1024 #[cfg(not(feature = "f32"))]
1025 {
1026 assert!((test_eval_variable("pi", None).unwrap() - std::f64::consts::PI).abs() < 1e-10);
1027 assert!((test_eval_variable("e", None).unwrap() - std::f64::consts::E).abs() < 1e-10);
1028 }
1029 }
1030
1031 #[test]
1032 fn test_eval_variable_context_lookup() {
1033 let mut ctx = EvalContext::new();
1034 ctx.set_parameter("x", 42.0);
1035 ctx.constants.insert("y".into(), 3.14);
1036 assert_eq!(test_eval_variable("x", Some(Rc::new(ctx.clone()))).unwrap(), 42.0);
1037 assert_eq!(test_eval_variable("y", Some(Rc::new(ctx.clone()))).unwrap(), 3.14);
1038 }
1039
1040 #[test]
1041 fn test_eval_variable_unknown_and_function_name() {
1042 let err = test_eval_variable("nosuchvar", None).unwrap_err();
1043 assert!(matches!(err, ExprError::UnknownVariable { .. }));
1044 let err2 = test_eval_variable("sin", None).unwrap_err();
1045 assert!(matches!(err2, ExprError::Syntax(_)));
1046 }
1047
1048 #[test]
1049 fn test_eval_function_native_and_expression() {
1050 let mut ctx = create_test_context();
1051 let mut func_cache = std::collections::BTreeMap::new();
1054 let val = test_eval_function(
1055 "sin",
1056 &[AstExpr::Constant(0.0)],
1057 Some(Rc::new(ctx.clone())),
1058 &mut func_cache,
1059 )
1060 .unwrap();
1061 assert!((val - 0.0).abs() < 1e-10);
1062
1063 ctx.register_expression_function("double", &["x"], "x*2")
1065 .unwrap();
1066 let val2 = test_eval_function(
1068 "double",
1069 &[AstExpr::Constant(5.0)],
1070 Some(Rc::new(ctx.clone())),
1071 &mut func_cache,
1072 )
1073 .unwrap();
1074 assert_eq!(val2, 10.0);
1075 }
1076
1077 #[test]
1078 fn test_eval_function_user_function() {
1079 let mut ctx = create_test_context();
1080 ctx.register_expression_function("inc", &["x"], "x+1")
1081 .unwrap();
1082 let mut func_cache = std::collections::BTreeMap::new();
1083 let val = test_eval_function(
1084 "inc",
1085 &[AstExpr::Constant(41.0)],
1086 Some(Rc::new(ctx.clone())),
1087 &mut func_cache,
1088 )
1089 .unwrap();
1090 assert_eq!(val, 42.0);
1091 }
1092
1093 #[test]
1094 fn test_eval_function_builtin_fallback() {
1095 let ctx = create_test_context();
1096 let mut func_cache = std::collections::BTreeMap::new();
1097 let val = test_eval_function(
1099 "pow",
1100 &[AstExpr::Constant(2.0), AstExpr::Constant(3.0)],
1101 Some(Rc::new(ctx.clone())),
1102 &mut func_cache,
1103 )
1104 .unwrap();
1105 assert_eq!(val, 8.0);
1106 let val2 = test_eval_function(
1108 "abs",
1109 &[AstExpr::Constant(-5.0)],
1110 Some(Rc::new(ctx.clone())),
1111 &mut func_cache,
1112 )
1113 .unwrap();
1114 assert_eq!(val2, 5.0);
1115 }
1116
1117 #[test]
1118 fn test_eval_array_success_and_out_of_bounds() {
1119 let mut ctx = EvalContext::new();
1120 ctx.arrays.insert("arr".into(), vec![1.0, 2.0, 3.0]);
1121
1122 let mut func_cache1 = std::collections::BTreeMap::new();
1124 let mut func_cache2 = std::collections::BTreeMap::new();
1125
1126 let idx_expr = AstExpr::Constant(1.0);
1127 let val = test_eval_array("arr", &idx_expr, Some(Rc::new(ctx.clone())), &mut func_cache1).unwrap();
1128 assert_eq!(val, 2.0);
1129
1130 let idx_expr2 = AstExpr::Constant(10.0);
1132 let err =
1133 test_eval_array("arr", &idx_expr2, Some(Rc::new(ctx.clone())), &mut func_cache2).unwrap_err();
1134 assert!(matches!(err, ExprError::ArrayIndexOutOfBounds { .. }));
1135 }
1136
1137 #[test]
1138 fn test_eval_array_unknown() {
1139 let ctx = EvalContext::new();
1140 let mut func_cache = std::collections::BTreeMap::new();
1141 let idx_expr = AstExpr::Constant(0.0);
1142 let err =
1143 test_eval_array("nosucharr", &idx_expr, Some(Rc::new(ctx.clone())), &mut func_cache).unwrap_err();
1144 assert!(matches!(err, ExprError::UnknownVariable { .. }));
1145 }
1146
1147 #[test]
1148 fn test_eval_attribute_success_and_not_found() {
1149 let mut ctx = EvalContext::new();
1150 let mut map = std::collections::HashMap::new();
1151 map.insert("foo".to_string(), 123.0);
1152 ctx.attributes.insert("bar".to_string(), map);
1153 let val = super::eval_attribute("bar", "foo", Some(Rc::new(ctx.clone()))).unwrap();
1154 assert_eq!(val, 123.0);
1155 let err = super::eval_attribute("bar", "baz", Some(Rc::new(ctx.clone()))).unwrap_err();
1156 assert!(matches!(err, ExprError::AttributeNotFound { .. }));
1157 }
1158
1159 #[test]
1160 fn test_eval_attribute_unknown_base() {
1161 let ctx = EvalContext::new();
1162 let err = super::eval_attribute("nosuch", "foo", Some(Rc::new(ctx.clone()))).unwrap_err();
1163 assert!(matches!(err, ExprError::AttributeNotFound { .. }));
1164 }
1165
1166 #[test]
1167 fn test_neg_pow_ast() {
1168 let ast = parse_expression("-2^2").unwrap_or_else(|e| panic!("Parse error: {}", e));
1170 match ast {
1172 AstExpr::Function { ref name, ref args } if name == "neg" => {
1173 assert_eq!(args.len(), 1);
1174 match &args[0] {
1175 AstExpr::Function {
1176 name: pow_name,
1177 args: pow_args,
1178 } if pow_name == "^" => {
1179 assert_eq!(pow_args.len(), 2);
1180 match (&pow_args[0], &pow_args[1]) {
1181 (AstExpr::Constant(a), AstExpr::Constant(b)) => {
1182 assert_eq!(*a, 2.0);
1183 assert_eq!(*b, 2.0);
1184 }
1185 _ => panic!("Expected constants as pow args"),
1186 }
1187 }
1188 _ => panic!("Expected pow as argument to neg"),
1189 }
1190 }
1191 _ => panic!("Expected neg as top-level function"),
1192 }
1193 }
1194
1195 #[test]
1196 #[cfg(not(feature = "no-builtin-math"))] fn test_neg_pow_eval() {
1198 let val = interp("-2^2", None).unwrap();
1200 assert_eq!(val, -4.0); let val2 = interp("(-2)^2", None).unwrap();
1202 assert_eq!(val2, 4.0); }
1204
1205 #[test]
1206 #[cfg(feature = "no-builtin-math")] fn test_neg_pow_eval_no_builtins() {
1208 let mut ctx = EvalContext::new();
1210 ctx.register_native_function("neg", 1, |args| -args[0]);
1211 ctx.register_native_function("^", 2, |args| args[0].powf(args[1])); let val = interp("-2^2", Some(&mut ctx)).unwrap();
1214 assert_eq!(val, -4.0);
1215 let val2 = interp("(-2)^2", Some(&mut ctx)).unwrap();
1216 assert_eq!(val2, 4.0);
1217
1218 let err = interp("-2^2", None).unwrap_err();
1220 assert!(matches!(err, ExprError::UnknownFunction { .. }));
1221 }
1222
1223 #[test]
1224 fn test_paren_neg_pow_ast() {
1225 let ast = parse_expression("(-2)^2").unwrap_or_else(|e| panic!("Parse error: {}", e));
1227 match ast {
1229 AstExpr::Function { ref name, ref args } if name == "^" => {
1230 assert_eq!(args.len(), 2);
1231 match &args[0] {
1232 AstExpr::Function {
1233 name: neg_name,
1234 args: neg_args,
1235 } if neg_name == "neg" => {
1236 assert_eq!(neg_args.len(), 1);
1237 match &neg_args[0] {
1238 AstExpr::Constant(a) => assert_eq!(*a, 2.0),
1239 _ => panic!("Expected constant as neg arg"),
1240 }
1241 }
1242 _ => panic!("Expected neg as left arg to pow"),
1243 }
1244 match &args[1] {
1245 AstExpr::Constant(b) => assert_eq!(*b, 2.0),
1246 _ => panic!("Expected constant as right arg to pow"),
1247 }
1248 }
1249 _ => panic!("Expected pow as top-level function"),
1250 }
1251 }
1252
1253 #[test]
1254 fn test_function_application_juxtaposition_ast() {
1255 let sin_x_ast = AstExpr::Function {
1258 name: "sin".to_string(),
1259 args: vec![AstExpr::Variable("x".to_string())],
1260 };
1261
1262 match sin_x_ast {
1263 AstExpr::Function { ref name, ref args } if name == "sin" => {
1264 assert_eq!(args.len(), 1);
1265 match &args[0] {
1266 AstExpr::Variable(var) => assert_eq!(var, "x"),
1267 _ => panic!("Expected variable as argument"),
1268 }
1269 }
1270 _ => panic!("Expected function node for sin x"),
1271 }
1272
1273 let neg_42_ast = AstExpr::Function {
1275 name: "neg".to_string(),
1276 args: vec![AstExpr::Constant(42.0)],
1277 };
1278
1279 let abs_neg_42_ast = AstExpr::Function {
1280 name: "abs".to_string(),
1281 args: vec![neg_42_ast],
1282 };
1283
1284 println!("AST for 'abs -42': {:?}", abs_neg_42_ast);
1285
1286 match abs_neg_42_ast {
1287 AstExpr::Function { ref name, ref args } if name == "abs" => {
1288 assert_eq!(args.len(), 1);
1289 match &args[0] {
1290 AstExpr::Function {
1291 name: n2,
1292 args: args2,
1293 } if n2 == "neg" => {
1294 assert_eq!(args2.len(), 1);
1295 match &args2[0] {
1296 AstExpr::Constant(c) => assert_eq!(*c, 42.0),
1297 _ => panic!("Expected constant as neg arg"),
1298 }
1299 }
1300 _ => panic!("Expected neg as argument to abs"),
1301 }
1302 }
1303 _ => panic!("Expected function node for abs -42"),
1304 }
1305 }
1306
1307 #[test]
1308 fn test_function_application_juxtaposition_eval() {
1309 let ctx = create_test_context(); #[cfg(feature = "no-builtin-math")]
1315 {
1316 ctx.register_native_function("abs", 1, |args| args[0].abs());
1317 ctx.register_native_function("neg", 1, |args| -args[0]);
1318 }
1319
1320 let ast = AstExpr::Function {
1322 name: "abs".to_string(),
1323 args: vec![AstExpr::Function {
1324 name: "neg".to_string(),
1325 args: vec![AstExpr::Constant(42.0)],
1326 }],
1327 };
1328
1329 let val = eval_ast(&ast, Some(Rc::new(ctx.clone()))).unwrap();
1330 assert_eq!(val, 42.0);
1331 }
1332
1333 #[test]
1334 fn test_pow_arity_ast() {
1335 let ast = parse_expression("pow(2)").unwrap_or_else(|e| panic!("Parse error: {}", e));
1339 match ast {
1340 AstExpr::Function { ref name, ref args } if name == "pow" => {
1341 assert!(args.len() == 1 || args.len() == 2);
1344 match &args[0] {
1345 AstExpr::Constant(c) => assert_eq!(*c, 2.0),
1346 _ => panic!("Expected constant as pow arg"),
1347 }
1348 if args.len() == 2 {
1350 match &args[1] {
1351 AstExpr::Constant(c) => assert_eq!(*c, 2.0),
1352 _ => panic!("Expected constant as pow second arg"),
1353 }
1354 }
1355 }
1356 _ => panic!("Expected function node for pow(2)"),
1357 }
1358 }
1359
1360 #[test]
1361 #[cfg(not(feature = "no-builtin-math"))] fn test_pow_arity_eval() {
1363 let result = interp("pow(2)", None).unwrap();
1365 assert_eq!(result, 4.0); let result2 = interp("pow(2, 3)", None).unwrap();
1368 assert_eq!(result2, 8.0);
1369 }
1370
1371 #[test]
1372 #[cfg(feature = "no-builtin-math")] fn test_pow_arity_eval_no_builtins() {
1374 let mut ctx = EvalContext::new();
1375 ctx.register_native_function("pow", 2, |args| args[0].powf(args[1]));
1377 let err = interp("pow(2)", Some(&mut ctx)).unwrap_err(); assert!(matches!(err, ExprError::InvalidFunctionCall { .. }));
1383
1384 let result2 = interp("pow(2, 3)", Some(&mut ctx)).unwrap();
1385 assert_eq!(result2, 8.0);
1386 }
1387
1388 #[test]
1389 fn test_unknown_variable_and_function_ast() {
1390 let ast = parse_expression("sin").unwrap_or_else(|e| panic!("Parse error: {}", e));
1392 match ast {
1393 AstExpr::Variable(ref name) => assert_eq!(name, "sin"),
1394 _ => panic!("Expected variable node for sin"),
1395 }
1396 let ast2 = parse_expression("abs").unwrap_or_else(|e| panic!("Parse error: {}", e));
1397 match ast2 {
1398 AstExpr::Variable(ref name) => assert_eq!(name, "abs"),
1399 _ => panic!("Expected variable node for abs"),
1400 }
1401 }
1402
1403 #[test]
1404 fn test_unknown_variable_and_function_eval() {
1405 let ctx = create_test_context(); #[cfg(feature = "no-builtin-math")]
1410 {
1411 ctx.register_native_function("sin", 1, |args| args[0].sin());
1412 ctx.register_native_function("abs", 1, |args| args[0].abs());
1413 }
1414
1415 let sin_var_ast = AstExpr::Variable("sin".to_string());
1417 let err = eval_ast(&sin_var_ast, Some(Rc::new(ctx.clone()))).unwrap_err();
1418 match err {
1419 ExprError::Syntax(msg) => {
1420 assert!(msg.contains("Function 'sin' used without arguments"));
1421 }
1422 _ => panic!("Expected Syntax error, got {:?}", err),
1423 }
1424
1425 let abs_var_ast = AstExpr::Variable("abs".to_string());
1427 let err2 = eval_ast(&abs_var_ast, Some(Rc::new(ctx.clone()))).unwrap_err();
1428 match err2 {
1429 ExprError::Syntax(msg) => {
1430 assert!(msg.contains("Function 'abs' used without arguments"));
1431 }
1432 _ => panic!("Expected Syntax error, got {:?}", err2),
1433 }
1434
1435 let unknown_var_ast = AstExpr::Variable("nosuchvar".to_string());
1437 let err3 = eval_ast(&unknown_var_ast, Some(Rc::new(ctx.clone()))).unwrap_err();
1438 assert!(matches!(err3, ExprError::UnknownVariable { name } if name == "nosuchvar"));
1439 }
1440
1441 #[test]
1442 fn test_override_builtin_native() {
1443 let mut ctx = create_test_context(); ctx.register_native_function("sin", 1, |_args| 100.0);
1447 ctx.register_native_function("pow", 2, |args| args[0] + args[1]);
1449 ctx.register_native_function("^", 2, |args| args[0] + args[1]);
1451
1452 let val_sin = interp("sin(0.5)", Some(Rc::new(ctx.clone()))).unwrap();
1454 assert_eq!(val_sin, 100.0, "Native 'sin' override failed");
1455
1456 let val_pow = interp("pow(3, 4)", Some(Rc::new(ctx.clone()))).unwrap();
1458 assert_eq!(val_pow, 7.0, "Native 'pow' override failed");
1459
1460 let val_pow_op = interp("3^4", Some(Rc::new(ctx.clone()))).unwrap();
1462 assert_eq!(val_pow_op, 7.0, "Native '^' override failed");
1463
1464 #[cfg(feature = "no-builtin-math")]
1467 {
1468 ctx.register_native_function("cos", 1, |args| args[0].cos()); }
1470 if ctx.function_registry.native_functions.contains_key("cos")
1472 || cfg!(not(feature = "no-builtin-math"))
1473 {
1474 let val_cos = interp("cos(0)", Some(Rc::new(ctx.clone()))).unwrap();
1475 let expected_cos = 1.0;
1477 assert!(
1478 (val_cos - expected_cos).abs() < 1e-9,
1479 "Built-in/default 'cos' failed after override. Got {}",
1480 val_cos
1481 );
1482 } else {
1483 let err = interp("cos(0)", Some(Rc::new(ctx.clone()))).unwrap_err();
1485 assert!(matches!(err, ExprError::UnknownFunction { .. }));
1486 }
1487 }
1488
1489 #[test]
1490 fn test_override_builtin_expression() {
1491 let mut ctx = create_test_context(); ctx.register_expression_function("cos", &["x"], "x * 10")
1495 .unwrap();
1496
1497 #[cfg(feature = "no-builtin-math")]
1500 {
1501 ctx.register_native_function("min", 2, |args| args[0].min(args[1]));
1502 }
1503 if ctx.function_registry.native_functions.contains_key("min")
1505 || cfg!(not(feature = "no-builtin-math"))
1506 {
1507 ctx.register_expression_function("max", &["a", "b"], "min(a, b)")
1508 .unwrap();
1509
1510 let val_max = interp("max(10, 2)", Some(Rc::new(ctx.clone()))).unwrap();
1512 assert_eq!(val_max, 2.0, "Expression 'max' override failed");
1513 } else {
1514 let reg_err = ctx.register_expression_function("max", &["a", "b"], "min(a, b)");
1516 if reg_err.is_ok() {
1519 let eval_err = interp("max(10, 2)", Some(Rc::new(ctx.clone()))).unwrap_err();
1520 assert!(matches!(eval_err, ExprError::UnknownFunction { name } if name == "min"));
1521 }
1522 }
1523
1524 let val_cos = interp("cos(5)", Some(Rc::new(ctx.clone()))).unwrap();
1526 assert_eq!(val_cos, 50.0, "Expression 'cos' override failed");
1527
1528 #[cfg(feature = "no-builtin-math")]
1530 {
1531 ctx.register_native_function("sin", 1, |args| args[0].sin());
1532 }
1533 if ctx.function_registry.native_functions.contains_key("sin")
1534 || cfg!(not(feature = "no-builtin-math"))
1535 {
1536 let val_sin = interp("sin(0)", Some(Rc::new(ctx.clone()))).unwrap();
1537 assert!(
1538 (val_sin - 0.0).abs() < 1e-9,
1539 "Built-in/default 'sin' failed after override"
1540 );
1541 } else {
1542 let err = interp("sin(0)", Some(Rc::new(ctx.clone()))).unwrap_err();
1543 assert!(matches!(err, ExprError::UnknownFunction { .. }));
1544 }
1545 }
1546
1547 #[test]
1548 fn test_expression_function_uses_correct_context() {
1549 let mut ctx = create_test_context(); ctx.set_parameter("a", 10.0); ctx.constants.insert("my_const".to_string().into(), 100.0); ctx.register_expression_function("func1_const", &["x"], "x + my_const")
1556 .unwrap();
1557 let val1 = interp("func1_const(5)", Some(Rc::new(ctx.clone()))).unwrap();
1558 assert_eq!(val1, 105.0, "func1_const should use constant from context");
1559
1560 ctx.register_expression_function("func_uses_outer_var", &["x"], "x + a")
1562 .unwrap();
1563
1564 let result = interp("func_uses_outer_var(5)", Some(Rc::new(ctx.clone())));
1566 match result {
1567 Ok(val) => {
1568 assert_eq!(
1569 val, 15.0,
1570 "func_uses_outer_var should use variable 'a' from context"
1571 );
1572 }
1573 Err(e) => {
1574 println!("Error evaluating func_uses_outer_var(5): {:?}", e);
1575 panic!(
1576 "Expected Ok(15.0) for func_uses_outer_var(5), got error: {:?}",
1577 e
1578 );
1579 }
1580 }
1581
1582 ctx.register_expression_function("shadow_test", &["a"], "a + 1")
1584 .unwrap();
1585 let val_shadow = interp("shadow_test(7)", Some(Rc::new(ctx.clone()))).unwrap();
1586 assert_eq!(
1587 val_shadow, 8.0,
1588 "Parameter 'a' should shadow context variable 'a'"
1589 );
1590
1591 let val_a = interp("a", Some(Rc::new(ctx.clone()))).unwrap();
1593 assert_eq!(val_a, 10.0, "Context 'a' should remain unchanged");
1594 }
1595
1596 #[test]
1599 fn test_polynomial_expression_function_direct() {
1600 let mut ctx = EvalContext::new();
1601 ctx.register_expression_function("polynomial", &["x"], "x^3 + 2*x^2 + 3*x + 4")
1602 .unwrap();
1603
1604 let ast = crate::engine::parse_expression("polynomial(2)").unwrap();
1606 let result = crate::eval::eval_ast(&ast, Some(Rc::new(ctx.clone()))).unwrap();
1607 assert!(
1608 (result - 26.0).abs() < 1e-10,
1609 "Expected 26.0, got {}",
1610 result
1611 );
1612
1613 let ast = crate::engine::parse_expression("polynomial(3)").unwrap();
1615 let result = crate::eval::eval_ast(&ast, Some(Rc::new(ctx.clone()))).unwrap();
1616 assert!(
1617 (result - 58.0).abs() < 1e-10,
1618 "Expected 58.0, got {}",
1619 result
1620 );
1621 }
1622
1623 #[test]
1624 fn test_polynomial_subexpressions() {
1625 let mut ctx = EvalContext::new();
1626 ctx.set_parameter("x", 2.0);
1627
1628 let ast = crate::engine::parse_expression("x^3").unwrap();
1630 let result = crate::eval::eval_ast(&ast, Some(Rc::new(ctx.clone()))).unwrap();
1631 assert_eq!(result, 8.0);
1632
1633 let ast = crate::engine::parse_expression("2*x^2").unwrap();
1635 let result = crate::eval::eval_ast(&ast, Some(Rc::new(ctx.clone()))).unwrap();
1636 assert_eq!(result, 8.0);
1637
1638 let ast = crate::engine::parse_expression("3*x").unwrap();
1640 let result = crate::eval::eval_ast(&ast, Some(Rc::new(ctx.clone()))).unwrap();
1641 assert_eq!(result, 6.0);
1642
1643 let ast = crate::engine::parse_expression("4").unwrap();
1645 let result = crate::eval::eval_ast(&ast, Some(Rc::new(ctx.clone()))).unwrap();
1646 assert_eq!(result, 4.0);
1647 }
1648
1649 #[test]
1650 fn test_operator_precedence() {
1651 let ast = crate::engine::parse_expression("2 + 3 * 4 ^ 2").unwrap();
1652 let result = crate::eval::eval_ast(&ast, None).unwrap();
1653 assert_eq!(result, 2.0 + 3.0 * 16.0); }
1655
1656 #[test]
1657 fn test_polynomial_ast_structure() {
1658 let ast = crate::engine::parse_expression("x^3 + 2*x^2 + 3*x + 4").unwrap();
1659 println!("{:?}", ast);
1661 }
1663
1664 #[test]
1666 fn test_polynomial_integration_debug() {
1667 let mut ctx = EvalContext::new();
1668 ctx.register_expression_function("polynomial", &["x"], "x^3 + 2*x^2 + 3*x + 4")
1669 .unwrap();
1670
1671 let body_ast = crate::engine::parse_expression_with_reserved(
1673 "x^3 + 2*x^2 + 3*x + 4",
1674 Some(&vec!["x".to_string()]),
1675 )
1676 .unwrap();
1677 println!("AST for polynomial body: {:?}", body_ast);
1678
1679 let call_ast = crate::engine::parse_expression("polynomial(2)").unwrap();
1681 println!("AST for polynomial(2): {:?}", call_ast);
1682
1683 let result = crate::eval::eval_ast(&call_ast, Some(Rc::new(ctx.clone()))).unwrap();
1685 println!("polynomial(2) = {}", result);
1686
1687 ctx.set_parameter("x", 2.0);
1689 let direct_result = crate::eval::eval_ast(&body_ast, Some(Rc::new(ctx.clone()))).unwrap();
1690 println!("Direct eval with x=2: {}", direct_result);
1691 }
1692
1693 #[test]
1695 fn test_polynomial_argument_mapping_debug() {
1696 let mut ctx = EvalContext::new();
1697 ctx.register_expression_function("polynomial", &["x"], "x^3 + 2*x^2 + 3*x + 4")
1698 .unwrap();
1699
1700 let ast_lit = crate::engine::parse_expression("polynomial(10)").unwrap();
1702 let result_lit = crate::eval::eval_ast(&ast_lit, Some(Rc::new(ctx.clone()))).unwrap();
1703 println!("polynomial(10) = {}", result_lit);
1704 assert_eq!(result_lit, 1234.0);
1705
1706 ctx.set_parameter("z", 10.0);
1708 let ast_var = crate::engine::parse_expression("polynomial(z)").unwrap();
1709 let result_var = crate::eval::eval_ast(&ast_var, Some(Rc::new(ctx.clone()))).unwrap();
1710 println!("polynomial(z) = {}", result_var);
1711 assert_eq!(result_var, 1234.0);
1712
1713 ctx.set_parameter("a", 5.0);
1715 ctx.set_parameter("b", 10.0);
1716 let ast_sub = crate::engine::parse_expression("polynomial(a + b / 2)").unwrap();
1717 let result_sub = crate::eval::eval_ast(&ast_sub, Some(Rc::new(ctx.clone()))).unwrap();
1718 println!("polynomial(a + b / 2) = {}", result_sub);
1719 assert_eq!(result_sub, 1234.0);
1720
1721 let ast_nested = crate::engine::parse_expression("polynomial(polynomial(2))").unwrap();
1723 let result_nested = crate::eval::eval_ast(&ast_nested, Some(Rc::new(ctx.clone()))).unwrap();
1724 println!("polynomial(polynomial(2)) = {}", result_nested);
1725 }
1726 #[test]
1727 fn test_polynomial_shadowing_variable() {
1728 let mut ctx = EvalContext::new();
1729 ctx.set_parameter("x", 100.0); ctx.register_expression_function("polynomial", &["x"], "x^3 + 2*x^2 + 3*x + 4")
1731 .unwrap();
1732
1733 let call_ast = crate::engine::parse_expression("polynomial(2)").unwrap();
1734 let result = crate::eval::eval_ast(&call_ast, Some(Rc::new(ctx.clone()))).unwrap();
1735
1736 assert!(
1737 (result - 26.0).abs() < 1e-10,
1738 "Expected 26.0, got {}",
1739 result
1740 );
1741 }
1742
1743 #[test]
1745 fn test_polynomial_ast_cache_effect() {
1746 use std::cell::RefCell;
1747 use std::collections::HashMap;
1748 use std::rc::Rc;
1749
1750 let mut ctx = EvalContext::new();
1751 ctx.ast_cache = Some(RefCell::new(HashMap::<String, Rc<crate::types::AstExpr>>::new()));
1752 ctx.register_expression_function("polynomial", &["x"], "x^3 + 2*x^2 + 3*x + 4")
1753 .unwrap();
1754
1755 let expr = "polynomial(2)";
1756
1757 let result1 = crate::engine::interp(expr, Some(Rc::new(ctx.clone()))).unwrap();
1759 println!("First eval with cache: {}", result1);
1760
1761 let result2 = crate::engine::interp(expr, Some(Rc::new(ctx.clone()))).unwrap();
1763 println!("Second eval with cache: {}", result2);
1764
1765 assert_eq!(result1, result2);
1766 assert!((result1 - 26.0).abs() < 1e-10);
1767 }
1768
1769 #[test]
1771 fn test_polynomial_function_overriding() {
1772 let mut ctx = EvalContext::new();
1773 ctx.register_expression_function("polynomial", &["x"], "x + 1")
1774 .unwrap();
1775 ctx.register_expression_function("polynomial", &["x"], "x^3 + 2*x^2 + 3*x + 4")
1776 .unwrap();
1777
1778 let call_ast = crate::engine::parse_expression("polynomial(2)").unwrap();
1779 let result = crate::eval::eval_ast(&call_ast, Some(Rc::new(ctx.clone()))).unwrap();
1780
1781 println!("polynomial(2) after overriding = {}", result);
1782 assert!((result - 26.0).abs() < 1e-10);
1783 }
1784
1785 #[test]
1787 fn test_polynomial_name_collision_with_builtin() {
1788 let mut ctx = EvalContext::new();
1789 ctx.register_expression_function("sin", &["x"], "x + 100")
1791 .unwrap();
1792
1793 let call_ast = crate::engine::parse_expression("sin(2)").unwrap();
1794 let result = crate::eval::eval_ast(&call_ast, Some(Rc::new(ctx.clone()))).unwrap();
1795
1796 println!("sin(2) with override = {}", result);
1797 assert_eq!(result, 102.0);
1798 }
1799}