1extern crate alloc;
2
3#[cfg(test)]
4use crate::Real;
5#[cfg(not(test))]
6use crate::{Real, String, ToString, Vec};
7#[cfg(not(test))]
8use alloc::rc::Rc;
9#[cfg(not(test))]
10use hashbrown::HashMap;
11#[cfg(test)]
12use std::collections::HashMap;
13#[cfg(test)]
14use std::rc::Rc;
15#[cfg(test)]
16use std::string::{String, ToString};
17#[cfg(test)]
18use std::vec::Vec;
19
20#[allow(dead_code)]
29#[derive(Default, Clone)]
30pub struct FunctionRegistry<'a> {
31 pub native_functions: HashMap<String, crate::types::NativeFunction<'a>>,
33 pub expression_functions: HashMap<String, crate::types::ExpressionFunction>,
35 pub user_functions: HashMap<String, UserFunction>,
37}
38
39use core::cell::RefCell;
40
41#[derive(Default)]
87pub struct EvalContext<'a> {
88 pub variables: HashMap<String, Real>,
90 pub constants: HashMap<String, Real>,
92 pub arrays: HashMap<String, Vec<Real>>,
94 pub attributes: HashMap<String, HashMap<String, Real>>,
96 pub nested_arrays: HashMap<String, HashMap<usize, Vec<Real>>>,
98 pub function_registry: Rc<FunctionRegistry<'a>>,
100 pub parent: Option<Rc<EvalContext<'a>>>,
102 pub ast_cache: Option<RefCell<HashMap<String, Rc<crate::types::AstExpr>>>>,
104}
105
106impl<'a> EvalContext<'a> {
107 pub fn new() -> Self {
121 Self {
122 variables: HashMap::new(),
123 constants: HashMap::new(),
124 arrays: HashMap::new(),
125 attributes: HashMap::new(),
126 nested_arrays: HashMap::new(),
127 function_registry: Rc::new(FunctionRegistry::default()),
128 parent: None,
129 ast_cache: None,
130 }
131 }
132
133 pub fn set_parameter(&mut self, name: &str, value: Real) -> Option<Real> {
161 self.variables.insert(name.to_string(), value)
162 }
163
164 pub fn register_native_function<F>(&mut self, name: &str, arity: usize, implementation: F)
211 where
212 F: Fn(&[Real]) -> Real + 'static,
213 {
214 Rc::make_mut(&mut self.function_registry)
215 .native_functions
216 .insert(
217 name.to_string(),
218 crate::types::NativeFunction {
219 arity,
220 implementation: Rc::new(implementation),
221 name: name.to_string().into(),
222 description: None,
223 },
224 );
225 }
226
227 pub fn register_expression_function(
283 &mut self,
284 name: &str,
285 params: &[&str],
286 expression: &str,
287 ) -> Result<(), crate::error::ExprError> {
288 let param_names: Vec<String> = params.iter().map(|&s| s.to_string()).collect();
290 let ast = crate::engine::parse_expression_with_reserved(expression, Some(¶m_names))?;
291
292 Rc::make_mut(&mut self.function_registry)
294 .expression_functions
295 .insert(
296 name.to_string(),
297 crate::types::ExpressionFunction {
298 name: name.to_string(),
299 params: param_names,
300 expression: expression.to_string(),
301 compiled_ast: ast,
302 description: None,
303 },
304 );
305
306 Ok(())
307 }
308
309 pub fn enable_ast_cache(&self) {
339 if self.ast_cache.is_none() {
340 let cache = RefCell::new(HashMap::new());
342 unsafe {
345 let self_mut = self as *const _ as *mut Self;
346 (*self_mut).ast_cache = Some(cache);
347 }
348 }
349 }
350
351 pub fn disable_ast_cache(&self) {
367 unsafe {
369 let self_mut = self as *const _ as *mut Self;
370 (*self_mut).ast_cache = None;
371 }
372 }
373
374 pub fn clear_ast_cache(&self) {
376 if let Some(cache) = self.ast_cache.as_ref() {
377 cache.borrow_mut().clear();
378 }
379 }
380
381 #[cfg(not(feature = "no-builtin-math"))]
401 pub fn register_default_math_functions(&mut self) {
402 self.register_native_function("abs", 1, |args| crate::functions::abs(args[0], 0.0));
403 self.register_native_function("acos", 1, |args| crate::functions::acos(args[0], 0.0));
404 self.register_native_function("asin", 1, |args| crate::functions::asin(args[0], 0.0));
405 self.register_native_function("atan", 1, |args| crate::functions::atan(args[0], 0.0));
406 self.register_native_function("atan2", 2, |args| crate::functions::atan2(args[0], args[1]));
407 self.register_native_function("ceil", 1, |args| crate::functions::ceil(args[0], 0.0));
408 self.register_native_function("cos", 1, |args| crate::functions::cos(args[0], 0.0));
409 self.register_native_function("cosh", 1, |args| crate::functions::cosh(args[0], 0.0));
410 self.register_native_function("e", 0, |_args| crate::functions::e(0.0, 0.0));
411 self.register_native_function("exp", 1, |args| crate::functions::exp(args[0], 0.0));
412 self.register_native_function("floor", 1, |args| crate::functions::floor(args[0], 0.0));
413 self.register_native_function("ln", 1, |args| crate::functions::ln(args[0], 0.0));
414 self.register_native_function("log", 1, |args| crate::functions::log(args[0], 0.0));
415 self.register_native_function("log10", 1, |args| crate::functions::log10(args[0], 0.0));
416 self.register_native_function("max", 2, |args| crate::functions::max(args[0], args[1]));
417 self.register_native_function("min", 2, |args| crate::functions::min(args[0], args[1]));
418 self.register_native_function("pi", 0, |_args| crate::functions::pi(0.0, 0.0));
419 self.register_native_function("pow", 2, |args| crate::functions::pow(args[0], args[1]));
420 self.register_native_function("^", 2, |args| crate::functions::pow(args[0], args[1]));
421 self.register_native_function("sin", 1, |args| crate::functions::sin(args[0], 0.0));
422 self.register_native_function("sinh", 1, |args| crate::functions::sinh(args[0], 0.0));
423 self.register_native_function("sqrt", 1, |args| crate::functions::sqrt(args[0], 0.0));
424 self.register_native_function("tan", 1, |args| crate::functions::tan(args[0], 0.0));
425 self.register_native_function("tanh", 1, |args| crate::functions::tanh(args[0], 0.0));
426 self.register_native_function("sign", 1, |args| crate::functions::sign(args[0], 0.0));
427 self.register_native_function("add", 2, |args| crate::functions::add(args[0], args[1]));
428 self.register_native_function("sub", 2, |args| crate::functions::sub(args[0], args[1]));
429 self.register_native_function("mul", 2, |args| crate::functions::mul(args[0], args[1]));
430 self.register_native_function("div", 2, |args| crate::functions::div(args[0], args[1]));
431 self.register_native_function("fmod", 2, |args| crate::functions::fmod(args[0], args[1]));
432 self.register_native_function("neg", 1, |args| crate::functions::neg(args[0], 0.0));
433 self.register_native_function("comma", 2, |args| crate::functions::comma(args[0], args[1]));
434 }
436
437 pub fn get_variable(&self, name: &str) -> Option<Real> {
460 if let Some(val) = self.variables.get(name) {
461 Some(*val)
462 } else if let Some(parent) = &self.parent {
463 parent.get_variable(name)
464 } else {
465 None
466 }
467 }
468
469 pub fn get_constant(&self, name: &str) -> Option<Real> {
470 if let Some(val) = self.constants.get(name) {
471 Some(*val)
472 } else if let Some(parent) = &self.parent {
473 parent.get_constant(name)
474 } else {
475 None
476 }
477 }
478
479 pub fn get_array(&self, name: &str) -> Option<&Vec<Real>> {
480 if let Some(arr) = self.arrays.get(name) {
481 Some(arr)
482 } else if let Some(parent) = &self.parent {
483 parent.get_array(name)
484 } else {
485 None
486 }
487 }
488
489 pub fn get_attribute_map(&self, base: &str) -> Option<&HashMap<String, Real>> {
490 if let Some(attr_map) = self.attributes.get(base) {
491 Some(attr_map)
492 } else if let Some(parent) = &self.parent {
493 parent.get_attribute_map(base)
494 } else {
495 None
496 }
497 }
498
499 pub fn get_native_function(&self, name: &str) -> Option<&crate::types::NativeFunction> {
500 if let Some(f) = self.function_registry.native_functions.get(name) {
501 Some(f)
502 } else if let Some(parent) = &self.parent {
503 parent.get_native_function(name)
504 } else {
505 None
506 }
507 }
508
509 pub fn get_user_function(&self, name: &str) -> Option<&crate::context::UserFunction> {
510 if let Some(f) = self.function_registry.user_functions.get(name) {
511 Some(f)
512 } else if let Some(parent) = &self.parent {
513 parent.get_user_function(name)
514 } else {
515 None
516 }
517 }
518
519 pub fn get_expression_function(&self, name: &str) -> Option<&crate::types::ExpressionFunction> {
520 if let Some(f) = self.function_registry.expression_functions.get(name) {
521 Some(f)
522 } else if let Some(parent) = &self.parent {
523 parent.get_expression_function(name)
524 } else {
525 None
526 }
527 }
528}
529
530impl<'a> Clone for EvalContext<'a> {
531 fn clone(&self) -> Self {
532 Self {
533 variables: self.variables.clone(),
534 constants: self.constants.clone(),
535 arrays: self.arrays.clone(),
536 attributes: self.attributes.clone(),
537 nested_arrays: self.nested_arrays.clone(),
538 function_registry: self.function_registry.clone(),
539 parent: self.parent.clone(),
540 ast_cache: self.ast_cache.clone(),
541 }
542 }
543}
544
545pub trait CloneShallowNativeFunctions<'a> {
547 fn clone_shallow(&self) -> HashMap<Cow<'a, str>, crate::types::NativeFunction<'a>>;
548}
549
550impl<'a> CloneShallowNativeFunctions<'a>
552 for HashMap<Cow<'a, str>, crate::types::NativeFunction<'a>>
553{
554 fn clone_shallow(&self) -> HashMap<Cow<'a, str>, crate::types::NativeFunction<'a>> {
555 self.iter()
559 .map(|(k, v)| {
560 (
561 k.clone(),
562 crate::types::NativeFunction {
563 arity: v.arity,
564 implementation: v.implementation.clone(), name: v.name.clone(),
566 description: v.description.clone(),
567 },
568 )
569 })
570 .collect()
571 }
572}
573
574use alloc::borrow::Cow;
575
/// A user-defined function: its parameter names and its body as a string.
///
/// Derives `Debug` and `PartialEq` in addition to `Clone` so values can be logged
/// and compared in assertions.
// NOTE(review): `dead_code` is allowed without a stated reason — presumably the
// fields are not read in every build configuration; confirm and document.
#[derive(Clone, Debug, PartialEq)]
#[allow(dead_code)]
pub struct UserFunction {
    /// Names of the function's parameters.
    pub params: Vec<String>,
    /// The function body, stored as a string.
    pub body: String,
}
583
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine;
    use crate::types::AstExpr;
    use std::rc::Rc;

    #[test]
    fn test_get_variable_parent_chain() {
        let mut parent_ctx = EvalContext::new();
        parent_ctx.set_parameter("parent_only", 1.0);
        parent_ctx.set_parameter("shadowed", 2.0);

        let mut child_ctx = EvalContext::new();
        child_ctx.set_parameter("child_only", 3.0);
        child_ctx.set_parameter("shadowed", 4.0);
        child_ctx.parent = Some(Rc::new(parent_ctx));

        assert_eq!(child_ctx.get_variable("parent_only"), Some(1.0));
        assert_eq!(child_ctx.get_variable("child_only"), Some(3.0));
        // The child's own binding shadows the parent's.
        assert_eq!(child_ctx.get_variable("shadowed"), Some(4.0));
        assert_eq!(child_ctx.get_variable("nonexistent"), None);
    }

    #[test]
    fn test_get_variable_deep_chain() {
        let mut grandparent_ctx = EvalContext::new();
        grandparent_ctx.set_parameter("grandparent_var", 1.0);
        grandparent_ctx.set_parameter("shadowed", 2.0);

        let mut parent_ctx = EvalContext::new();
        parent_ctx.set_parameter("parent_var", 3.0);
        parent_ctx.set_parameter("shadowed", 4.0);
        parent_ctx.parent = Some(Rc::new(grandparent_ctx));

        let mut child_ctx = EvalContext::new();
        child_ctx.set_parameter("child_var", 5.0);
        child_ctx.set_parameter("shadowed", 6.0);
        child_ctx.parent = Some(Rc::new(parent_ctx));

        assert_eq!(child_ctx.get_variable("child_var"), Some(5.0));
        assert_eq!(child_ctx.get_variable("parent_var"), Some(3.0));
        assert_eq!(child_ctx.get_variable("grandparent_var"), Some(1.0));
        // Nearest scope wins.
        assert_eq!(child_ctx.get_variable("shadowed"), Some(6.0));
    }

    #[test]
    fn test_get_variable_null_parent() {
        let mut ctx = EvalContext::new();
        ctx.set_parameter("x", 1.0);
        ctx.parent = None;

        assert_eq!(ctx.get_variable("x"), Some(1.0));
        assert_eq!(ctx.get_variable("nonexistent"), None);
    }

    #[test]
    fn test_get_variable_cyclic_reference_safety() {
        // No actual cycle is built here: this checks that a parent handed out via
        // `Rc::clone` (and so shared) is still resolved through the child.
        let mut ctx1 = EvalContext::new();
        let mut ctx2 = EvalContext::new();

        ctx1.set_parameter("var1", 1.0);
        ctx2.set_parameter("var2", 2.0);

        let ctx1_rc = Rc::new(ctx1);
        ctx2.parent = Some(Rc::clone(&ctx1_rc));

        assert_eq!(ctx2.get_variable("var2"), Some(2.0));
        assert_eq!(ctx2.get_variable("var1"), Some(1.0));
    }

    #[test]
    fn test_get_variable_in_function_scope() {
        let mut ctx = EvalContext::new();
        ctx.set_parameter("x", 100.0);

        let mut func_ctx = EvalContext::new();
        func_ctx.set_parameter("x", 5.0);
        func_ctx.parent = Some(Rc::new(ctx.clone()));

        assert_eq!(
            func_ctx.get_variable("x"),
            Some(5.0),
            "Function parameter should shadow parent variable"
        );
        assert_eq!(
            ctx.get_variable("x"),
            Some(100.0),
            "Shadowing in the function scope must not modify the outer context"
        );
        assert_eq!(
            func_ctx.parent.as_ref().and_then(|p| p.get_variable("x")),
            Some(100.0),
            "Parent scope should still hold its own value"
        );
    }

    #[test]
    fn test_get_variable_nested_scopes() {
        let mut root_ctx = EvalContext::new();
        root_ctx.set_parameter("x", 1.0);
        root_ctx.set_parameter("y", 1.0);

        let mut mid_ctx = EvalContext::new();
        mid_ctx.set_parameter("x", 2.0);
        mid_ctx.parent = Some(Rc::new(root_ctx));

        let mut leaf_ctx = EvalContext::new();
        leaf_ctx.set_parameter("x", 3.0);
        leaf_ctx.parent = Some(Rc::new(mid_ctx));

        assert_eq!(leaf_ctx.get_variable("x"), Some(3.0), "Should get leaf context value");
        assert_eq!(
            leaf_ctx.get_variable("y"),
            Some(1.0),
            "Should get root context value when not shadowed"
        );
        assert_eq!(
            leaf_ctx.parent.as_ref().and_then(|p| p.get_variable("x")),
            Some(2.0),
            "Mid scope should resolve its own binding"
        );
        assert_eq!(
            leaf_ctx
                .parent
                .as_ref()
                .and_then(|p| p.parent.as_ref())
                .and_then(|p| p.get_variable("x")),
            Some(1.0),
            "Root scope should resolve its own binding"
        );
    }

    #[test]
    fn test_get_variable_function_parameter_precedence() {
        let mut ctx = EvalContext::new();
        // Registering a function whose parameter is also named `x` must not
        // interfere with variable lookup.
        ctx.register_expression_function("f", &["x"], "x * 2").unwrap();
        ctx.set_parameter("x", 100.0);

        let mut func_ctx = EvalContext::new();
        func_ctx.set_parameter("x", 5.0);
        func_ctx.parent = Some(Rc::new(ctx));

        assert_eq!(
            func_ctx.get_variable("x"),
            Some(5.0),
            "Function parameter should take precedence over global x"
        );
        assert_eq!(
            func_ctx.parent.as_ref().and_then(|p| p.get_variable("x")),
            Some(100.0),
            "Global x should be unchanged in the parent scope"
        );
    }

    #[test]
    fn test_get_variable_temporary_scope() {
        let mut ctx = EvalContext::new();
        ctx.set_parameter("x", 1.0);

        let mut temp_ctx = EvalContext::new();
        temp_ctx.parent = Some(Rc::new(ctx));

        assert_eq!(temp_ctx.get_variable("x"), Some(1.0), "Should find variable in parent scope");

        temp_ctx.set_parameter("x", 2.0);

        assert_eq!(
            temp_ctx.get_variable("x"),
            Some(2.0),
            "Should find shadowed variable in local scope"
        );
        assert_eq!(
            temp_ctx.parent.as_ref().and_then(|p| p.get_variable("x")),
            Some(1.0),
            "Setting a local value must not modify the parent"
        );
    }

    #[test]
    fn test_native_function() {
        let mut ctx = EvalContext::new();
        ctx.register_native_function("add_all", 3, |args| args.iter().sum());

        let val = engine::interp("add_all(1, 2, 3)", Some(Rc::new(ctx))).unwrap();
        assert_eq!(val, 6.0);
    }

    #[test]
    fn test_expression_function() {
        let mut ctx = EvalContext::new();
        ctx.register_expression_function("double", &["x"], "x * 2").unwrap();
        ctx.variables.insert("value".to_string(), 5.0);

        let val = engine::interp("double(value)", Some(Rc::new(ctx.clone()))).unwrap();
        assert_eq!(val, 10.0);

        let val2 = engine::interp("double(7)", Some(Rc::new(ctx))).unwrap();
        assert_eq!(val2, 14.0);
    }

    #[test]
    fn test_array_access() {
        let mut ctx = EvalContext::new();
        ctx.arrays.insert("climb_wave_wait_time".to_string(), vec![10.0, 20.0, 30.0]);

        let val = engine::interp("climb_wave_wait_time[1]", Some(Rc::new(ctx))).unwrap();
        assert_eq!(val, 20.0);
    }

    #[test]
    fn test_array_access_ast_structure() {
        // Parsing alone needs no context.
        let ast = engine::parse_expression("climb_wave_wait_time[1]").unwrap();
        match ast {
            AstExpr::Array { name, index } => {
                assert_eq!(name, "climb_wave_wait_time");
                match *index {
                    AstExpr::Constant(val) => assert_eq!(val, 1.0),
                    _ => panic!("Expected constant index"),
                }
            }
            _ => panic!("Expected array AST node"),
        }
    }

    #[test]
    fn test_attribute_access() {
        let mut ctx = EvalContext::new();
        let mut foo_map = HashMap::new();
        foo_map.insert("bar".to_string(), 42.0);
        ctx.attributes.insert("foo".to_string(), foo_map);

        // Evaluating the parsed AST directly must agree with `interp`.
        let ast = engine::parse_expression("foo.bar").unwrap();
        let eval_result = crate::eval::eval_ast(&ast, Some(Rc::new(ctx.clone())));
        assert_eq!(eval_result.ok(), Some(42.0));

        let val = engine::interp("foo.bar", Some(Rc::new(ctx.clone()))).unwrap();
        assert_eq!(val, 42.0);

        // Unknown attribute on a known base.
        assert!(engine::interp("foo.baz", Some(Rc::new(ctx.clone()))).is_err());
        // Unknown base.
        assert!(engine::interp("nope.bar", Some(Rc::new(ctx))).is_err());
        // No context at all.
        assert!(engine::interp("foo.bar", None).is_err());
    }

    #[test]
    fn test_set_parameter() {
        let mut ctx = EvalContext::new();

        let prev = ctx.set_parameter("x", 10.0);
        assert_eq!(prev, None);

        let val = engine::interp("x", Some(Rc::new(ctx.clone()))).unwrap();
        assert_eq!(val, 10.0);

        let prev = ctx.set_parameter("x", 20.0);
        assert_eq!(prev, Some(10.0));

        let val = engine::interp("x", Some(Rc::new(ctx.clone()))).unwrap();
        assert_eq!(val, 20.0);

        let val = engine::interp("x * 2", Some(Rc::new(ctx))).unwrap();
        assert_eq!(val, 40.0);
    }
}