extern crate alloc;
#[cfg(test)]
use crate::Real;
#[cfg(not(test))]
use crate::{Real, String, ToString, Vec};
#[cfg(not(test))]
use alloc::rc::Rc;
use crate::types::{TryIntoFunctionName, TryIntoHeaplessString};
#[cfg(test)]
use std::rc::Rc;
#[cfg(test)]
use std::string::{String, ToString};
#[cfg(test)]
use std::vec::Vec;
/// Bundles a native-function table with an expression-function table.
///
/// NOTE(review): this type is not referenced anywhere else in this file; the
/// `#[allow(dead_code)]` below presumably silences an unused-type warning —
/// confirm whether other modules still construct it before removing either.
#[allow(dead_code)]
#[derive(Default, Clone)]
pub struct FunctionRegistry {
/// Functions implemented in Rust, stored in the same map type that
/// `EvalContext::native_functions` uses.
pub native_functions: crate::types::NativeFunctionMap,
/// Functions defined by expressions (presumably those registered through
/// `Expression::register_expression_function` — confirm in `crate::types`).
pub expression_functions: crate::types::ExpressionFunctionMap,
}
/// A single evaluation scope: named values, arrays, object attributes and
/// native functions, optionally chained to an enclosing `parent` scope.
///
/// All `get_*` lookups search this scope first and then fall back to the
/// parent chain, so a local entry shadows a same-named entry further up.
pub struct EvalContext {
/// Scalar variables; written by `set_parameter`, read by `get_variable`.
pub variables: crate::types::VariableMap,
/// Named constants; read by `get_constant` (nothing in this file writes them).
pub constants: crate::types::ConstantMap,
/// Named arrays of `Real`; read by `get_array`.
pub arrays: crate::types::ArrayMap,
/// Per-object attribute maps (`object.attr`); written by `set_attribute`,
/// read by `get_attribute_map`.
pub attributes: crate::types::AttributeMap,
/// Nested arrays; not read or written by any method in this file.
pub nested_arrays: crate::types::NestedArrayMap,
/// Native function table. Held in an `Rc` so cloned contexts share it;
/// `register_native_function` uses `Rc::make_mut` before inserting.
pub native_functions: Rc<crate::types::NativeFunctionMap>,
/// Enclosing scope consulted when a name is not found in this one.
pub parent: Option<Rc<EvalContext>>,
}
impl EvalContext {
pub fn new() -> Self {
let mut ctx = Self {
variables: crate::types::VariableMap::new(),
constants: crate::types::ConstantMap::new(),
arrays: crate::types::ArrayMap::new(),
attributes: crate::types::AttributeMap::new(),
nested_arrays: crate::types::NestedArrayMap::new(),
native_functions: Rc::new(crate::types::NativeFunctionMap::new()),
parent: None,
};
ctx.register_default_math_functions();
ctx
}
pub fn with_default_functions() -> Self {
Self::new()
}
pub fn empty() -> Self {
Self {
variables: crate::types::VariableMap::new(),
constants: crate::types::ConstantMap::new(),
arrays: crate::types::ArrayMap::new(),
attributes: crate::types::AttributeMap::new(),
nested_arrays: crate::types::NestedArrayMap::new(),
native_functions: Rc::new(crate::types::NativeFunctionMap::new()),
parent: None,
}
}
pub fn set_parameter(
&mut self,
name: &str,
value: Real,
) -> Result<Option<Real>, crate::error::ExprError> {
let key = name.try_into_heapless()?;
match self.variables.insert(key, value) {
Ok(old_value) => Ok(old_value),
Err(_) => Err(crate::error::ExprError::CapacityExceeded("variables")),
}
}
pub fn register_native_function<F>(
&mut self,
name: &str,
arity: usize,
implementation: F,
) -> Result<(), crate::error::ExprError>
where
F: Fn(&[Real]) -> Real + 'static,
{
let key = name.try_into_function_name()?;
let function = crate::types::NativeFunction {
arity,
implementation: Rc::new(implementation),
name: key.clone(),
description: None,
};
match Rc::make_mut(&mut self.native_functions).insert(key, function) {
Ok(_) => Ok(()),
Err(_) => Err(crate::error::ExprError::CapacityExceeded(
"native_functions",
)),
}
}
pub fn enable_default_functions(&mut self) {
self.register_default_math_functions();
}
pub fn register_default_math_functions(&mut self) {
let _ = self.register_native_function("+", 2, |args| args[0] + args[1]);
let _ = self.register_native_function("-", 2, |args| args[0] - args[1]);
let _ = self.register_native_function("*", 2, |args| args[0] * args[1]);
let _ = self.register_native_function("/", 2, |args| args[0] / args[1]);
let _ = self.register_native_function("%", 2, |args| args[0] % args[1]);
let _ =
self.register_native_function("<", 2, |args| if args[0] < args[1] { 1.0 } else { 0.0 });
let _ =
self.register_native_function(">", 2, |args| if args[0] > args[1] { 1.0 } else { 0.0 });
let _ = self.register_native_function(
"<=",
2,
|args| if args[0] <= args[1] { 1.0 } else { 0.0 },
);
let _ = self.register_native_function(
">=",
2,
|args| if args[0] >= args[1] { 1.0 } else { 0.0 },
);
let _ = self.register_native_function(
"==",
2,
|args| if args[0] == args[1] { 1.0 } else { 0.0 },
);
let _ = self.register_native_function(
"!=",
2,
|args| if args[0] != args[1] { 1.0 } else { 0.0 },
);
let _ = self.register_native_function("&&", 2, |args| {
if args[0] != 0.0 && args[1] != 0.0 {
1.0
} else {
0.0
}
});
let _ = self.register_native_function("||", 2, |args| {
if args[0] != 0.0 || args[1] != 0.0 {
1.0
} else {
0.0
}
});
let _ = self.register_native_function("add", 2, |args| args[0] + args[1]);
let _ = self.register_native_function("sub", 2, |args| args[0] - args[1]);
let _ = self.register_native_function("mul", 2, |args| args[0] * args[1]);
let _ = self.register_native_function("div", 2, |args| args[0] / args[1]);
let _ = self.register_native_function("fmod", 2, |args| args[0] % args[1]);
let _ = self.register_native_function("neg", 1, |args| -args[0]);
let _ = self.register_native_function(",", 2, |args| args[1]); let _ = self.register_native_function("comma", 2, |args| args[1]);
let _ = self.register_native_function("abs", 1, |args| args[0].abs());
let _ = self.register_native_function("max", 2, |args| args[0].max(args[1]));
let _ = self.register_native_function("min", 2, |args| args[0].min(args[1]));
let _ = self.register_native_function("sign", 1, |args| {
if args[0] > 0.0 {
1.0
} else if args[0] < 0.0 {
-1.0
} else {
0.0
}
});
#[cfg(feature = "f32")]
let _ = self.register_native_function("e", 0, |_| core::f32::consts::E);
#[cfg(not(feature = "f32"))]
let _ = self.register_native_function("e", 0, |_| core::f64::consts::E);
#[cfg(feature = "f32")]
let _ = self.register_native_function("pi", 0, |_| core::f32::consts::PI);
#[cfg(not(feature = "f32"))]
let _ = self.register_native_function("pi", 0, |_| core::f64::consts::PI);
#[cfg(feature = "libm")]
{
let _ = self
.register_native_function("acos", 1, |args| crate::functions::acos(args[0], 0.0));
let _ = self
.register_native_function("asin", 1, |args| crate::functions::asin(args[0], 0.0));
let _ = self
.register_native_function("atan", 1, |args| crate::functions::atan(args[0], 0.0));
let _ = self.register_native_function("atan2", 2, |args| {
crate::functions::atan2(args[0], args[1])
});
let _ = self
.register_native_function("ceil", 1, |args| crate::functions::ceil(args[0], 0.0));
let _ =
self.register_native_function("cos", 1, |args| crate::functions::cos(args[0], 0.0));
let _ = self
.register_native_function("cosh", 1, |args| crate::functions::cosh(args[0], 0.0));
let _ =
self.register_native_function("exp", 1, |args| crate::functions::exp(args[0], 0.0));
let _ = self
.register_native_function("floor", 1, |args| crate::functions::floor(args[0], 0.0));
let _ = self
.register_native_function("round", 1, |args| crate::functions::round(args[0], 0.0));
let _ =
self.register_native_function("ln", 1, |args| crate::functions::ln(args[0], 0.0));
let _ =
self.register_native_function("log", 1, |args| crate::functions::log(args[0], 0.0));
let _ = self
.register_native_function("log10", 1, |args| crate::functions::log10(args[0], 0.0));
let _ = self
.register_native_function("pow", 2, |args| crate::functions::pow(args[0], args[1]));
let _ = self
.register_native_function("^", 2, |args| crate::functions::pow(args[0], args[1]));
let _ =
self.register_native_function("sin", 1, |args| crate::functions::sin(args[0], 0.0));
let _ = self
.register_native_function("sinh", 1, |args| crate::functions::sinh(args[0], 0.0));
let _ = self
.register_native_function("sqrt", 1, |args| crate::functions::sqrt(args[0], 0.0));
let _ =
self.register_native_function("tan", 1, |args| crate::functions::tan(args[0], 0.0));
let _ = self
.register_native_function("tanh", 1, |args| crate::functions::tanh(args[0], 0.0));
}
#[cfg(all(not(feature = "libm"), test))]
{
let _ = self.register_native_function("acos", 1, |args| args[0].acos());
let _ = self.register_native_function("asin", 1, |args| args[0].asin());
let _ = self.register_native_function("atan", 1, |args| args[0].atan());
let _ = self.register_native_function("atan2", 2, |args| args[0].atan2(args[1]));
let _ = self.register_native_function("ceil", 1, |args| args[0].ceil());
let _ = self.register_native_function("cos", 1, |args| args[0].cos());
let _ = self.register_native_function("cosh", 1, |args| args[0].cosh());
let _ = self.register_native_function("exp", 1, |args| args[0].exp());
let _ = self.register_native_function("floor", 1, |args| args[0].floor());
let _ = self.register_native_function("round", 1, |args| args[0].round());
let _ = self.register_native_function("ln", 1, |args| args[0].ln());
let _ = self.register_native_function("log", 1, |args| args[0].log10());
let _ = self.register_native_function("log10", 1, |args| args[0].log10());
let _ = self.register_native_function("pow", 2, |args| args[0].powf(args[1]));
let _ = self.register_native_function("^", 2, |args| args[0].powf(args[1]));
let _ = self.register_native_function("sin", 1, |args| args[0].sin());
let _ = self.register_native_function("sinh", 1, |args| args[0].sinh());
let _ = self.register_native_function("sqrt", 1, |args| args[0].sqrt());
let _ = self.register_native_function("tan", 1, |args| args[0].tan());
let _ = self.register_native_function("tanh", 1, |args| args[0].tanh());
}
}
pub fn get_variable(&self, name: &str) -> Option<Real> {
if let Ok(key) = name.try_into_heapless() {
if let Some(val) = self.variables.get(&key) {
return Some(*val);
}
}
if let Some(parent) = &self.parent {
parent.get_variable(name)
} else {
None
}
}
pub fn get_constant(&self, name: &str) -> Option<Real> {
if let Ok(key) = name.try_into_heapless() {
if let Some(val) = self.constants.get(&key) {
return Some(*val);
}
}
if let Some(parent) = &self.parent {
parent.get_constant(name)
} else {
None
}
}
pub fn get_array(&self, name: &str) -> Option<&alloc::vec::Vec<crate::Real>> {
if let Ok(key) = name.try_into_heapless() {
if let Some(arr) = self.arrays.get(&key) {
return Some(arr);
}
}
if let Some(parent) = &self.parent {
parent.get_array(name)
} else {
None
}
}
pub fn set_attribute(
&mut self,
object_name: &str,
attr_name: &str,
value: Real,
) -> Result<Option<Real>, crate::error::ExprError> {
let obj_key = object_name.try_into_heapless()?;
let attr_key = attr_name.try_into_heapless()?;
if !self.attributes.contains_key(&obj_key) {
let attr_map = heapless::FnvIndexMap::<
crate::types::HString,
Real,
{ crate::types::EXP_RS_MAX_ATTR_KEYS },
>::new();
self.attributes
.insert(obj_key.clone(), attr_map)
.map_err(|_| crate::error::ExprError::CapacityExceeded("attributes"))?;
}
if let Some(attr_map) = self.attributes.get_mut(&obj_key) {
attr_map
.insert(attr_key, value)
.map_err(|_| crate::error::ExprError::CapacityExceeded("object attributes"))
} else {
unreachable!("Just inserted the object")
}
}
pub fn get_attribute_map(
&self,
base: &str,
) -> Option<
&heapless::FnvIndexMap<crate::types::HString, Real, { crate::types::EXP_RS_MAX_ATTR_KEYS }>,
> {
if let Ok(key) = base.try_into_heapless() {
if let Some(attr_map) = self.attributes.get(&key) {
return Some(attr_map);
}
}
if let Some(parent) = &self.parent {
parent.get_attribute_map(base)
} else {
None
}
}
pub fn get_native_function(&self, name: &str) -> Option<&crate::types::NativeFunction> {
if let Ok(key) = name.try_into_function_name() {
if let Some(f) = self.native_functions.get(&key) {
return Some(f);
}
}
if let Some(parent) = &self.parent {
parent.get_native_function(name)
} else {
None
}
}
pub fn list_native_functions(&self) -> Vec<String> {
let mut functions = Vec::new();
let mut seen = alloc::collections::BTreeSet::new();
for (name, _) in self.native_functions.iter() {
let name_str = name.to_string();
if seen.insert(name_str.clone()) {
functions.push(name_str);
}
}
if let Some(parent) = &self.parent {
for name in parent.list_native_functions() {
if seen.insert(name.clone()) {
functions.push(name);
}
}
}
functions.sort();
functions
}
pub fn list_expression_functions(&self) -> Vec<String> {
let mut functions = Vec::new();
let mut seen = alloc::collections::BTreeSet::new();
if let Some(parent) = &self.parent {
for name in parent.list_expression_functions() {
if seen.insert(name.clone()) {
functions.push(name);
}
}
}
functions.sort();
functions
}
}
impl Clone for EvalContext {
fn clone(&self) -> Self {
Self {
variables: self.variables.clone(),
constants: self.constants.clone(),
arrays: self.arrays.clone(),
attributes: self.attributes.clone(),
nested_arrays: self.nested_arrays.clone(),
native_functions: self.native_functions.clone(),
parent: self.parent.clone(),
}
}
}
impl Default for EvalContext {
fn default() -> Self {
EvalContext::new()
}
}
// Unit tests for EvalContext: parent-chain variable lookup, native-function
// registration, array and attribute access, and set_parameter.
#[cfg(test)]
mod tests {
use super::*;
use crate::engine;
use crate::types::AstExpr;
use crate::types::TryIntoHeaplessString;
use std::rc::Rc;
// Child lookups fall back to the parent; a child value shadows the parent's.
#[test]
fn test_get_variable_parent_chain() {
let mut parent_ctx = EvalContext::new();
let _ = parent_ctx.set_parameter("parent_only", 1.0);
let _ = parent_ctx.set_parameter("shadowed", 2.0);
let mut child_ctx = EvalContext::new();
let _ = child_ctx.set_parameter("child_only", 3.0);
let _ = child_ctx.set_parameter("shadowed", 4.0); child_ctx.parent = Some(Rc::new(parent_ctx));
assert_eq!(child_ctx.get_variable("parent_only"), Some(1.0));
assert_eq!(child_ctx.get_variable("child_only"), Some(3.0));
assert_eq!(child_ctx.get_variable("shadowed"), Some(4.0));
assert_eq!(child_ctx.get_variable("nonexistent"), None);
}
// Lookup walks more than one level (child -> parent -> grandparent), and
// the innermost value of a shadowed name wins.
#[test]
fn test_get_variable_deep_chain() {
let mut grandparent_ctx = EvalContext::new();
let _ = grandparent_ctx.set_parameter("grandparent_var", 1.0);
let _ = grandparent_ctx.set_parameter("shadowed", 2.0);
let mut parent_ctx = EvalContext::new();
let _ = parent_ctx.set_parameter("parent_var", 3.0);
let _ = parent_ctx.set_parameter("shadowed", 4.0);
parent_ctx.parent = Some(Rc::new(grandparent_ctx));
let mut child_ctx = EvalContext::new();
let _ = child_ctx.set_parameter("child_var", 5.0);
let _ = child_ctx.set_parameter("shadowed", 6.0);
child_ctx.parent = Some(Rc::new(parent_ctx));
assert_eq!(child_ctx.get_variable("child_var"), Some(5.0));
assert_eq!(child_ctx.get_variable("parent_var"), Some(3.0));
assert_eq!(child_ctx.get_variable("grandparent_var"), Some(1.0));
assert_eq!(child_ctx.get_variable("shadowed"), Some(6.0));
}
// With no parent, only local variables are visible.
#[test]
fn test_get_variable_null_parent() {
let mut ctx = EvalContext::new();
let _ = ctx.set_parameter("x", 1.0);
ctx.parent = None;
assert_eq!(ctx.get_variable("x"), Some(1.0));
assert_eq!(ctx.get_variable("nonexistent"), None);
}
// NOTE(review): despite the name, no cycle is built here. ctx1 has no parent,
// so this only checks a two-level chain.
#[test]
fn test_get_variable_cyclic_reference_safety() {
let mut ctx1 = EvalContext::new();
let mut ctx2 = EvalContext::new();
let _ = ctx1.set_parameter("var1", 1.0);
let _ = ctx2.set_parameter("var2", 2.0);
let ctx1_rc = Rc::new(ctx1);
ctx2.parent = Some(Rc::clone(&ctx1_rc));
assert_eq!(ctx2.get_variable("var2"), Some(2.0));
assert_eq!(ctx2.get_variable("var1"), Some(1.0));
}
// A local parameter shadows a same-named variable in a cloned parent scope.
#[test]
fn test_get_variable_in_function_scope() {
let mut ctx = EvalContext::new();
let _ = ctx.set_parameter("x", 100.0);
let mut func_ctx = EvalContext::new();
let _ = func_ctx.set_parameter("x", 5.0); func_ctx.parent = Some(Rc::new(ctx.clone()));
assert_eq!(
func_ctx.get_variable("x"),
Some(5.0),
"Function parameter should shadow parent variable"
);
println!("Parent context x = {:?}", ctx.get_variable("x"));
println!("Function context x = {:?}", func_ctx.get_variable("x"));
println!("Function context variables: {:?}", func_ctx.variables);
println!(
"Function context parent variables: {:?}",
func_ctx.parent.as_ref().map(|p| &p.variables)
);
}
// Three-level chain: `x` resolves to the leaf value; `y` is defined only at
// the root and resolves from there.
#[test]
fn test_get_variable_nested_scopes() {
let mut root_ctx = EvalContext::new();
let _ = root_ctx.set_parameter("x", 1.0);
let _ = root_ctx.set_parameter("y", 1.0);
let mut mid_ctx = EvalContext::new();
let _ = mid_ctx.set_parameter("x", 2.0);
mid_ctx.parent = Some(Rc::new(root_ctx));
let mut leaf_ctx = EvalContext::new();
let _ = leaf_ctx.set_parameter("x", 3.0);
leaf_ctx.parent = Some(Rc::new(mid_ctx));
assert_eq!(
leaf_ctx.get_variable("x"),
Some(3.0),
"Should get leaf context value"
);
assert_eq!(
leaf_ctx.get_variable("y"),
Some(1.0),
"Should get root context value when not shadowed"
);
println!("Variable lookup in nested scopes:");
println!("leaf x = {:?}", leaf_ctx.get_variable("x"));
println!("leaf y = {:?}", leaf_ctx.get_variable("y"));
println!("leaf variables: {:?}", leaf_ctx.variables);
println!(
"mid variables: {:?}",
leaf_ctx.parent.as_ref().map(|p| &p.variables)
);
println!(
"root variables: {:?}",
leaf_ctx
.parent
.as_ref()
.and_then(|p| p.parent.as_ref())
.map(|p| &p.variables)
);
}
// NOTE(review): `batch` registers expression function `f` but is never
// evaluated. Only the manual func_ctx shadowing below is asserted.
#[test]
fn test_get_variable_function_parameter_precedence() {
let mut ctx = EvalContext::new();
let arena = bumpalo::Bump::new();
let mut batch = crate::expression::Expression::new(&arena);
batch
.register_expression_function("f", &["x"], "x * 2")
.unwrap();
let _ = ctx.set_parameter("x", 100.0);
let mut func_ctx = EvalContext::new();
let _ = func_ctx.set_parameter("x", 5.0); func_ctx.parent = Some(Rc::new(ctx));
println!("Function parameter context:");
println!("func_ctx x = {:?}", func_ctx.get_variable("x"));
println!("func_ctx variables: {:?}", func_ctx.variables);
println!(
"parent variables: {:?}",
func_ctx.parent.as_ref().map(|p| &p.variables)
);
assert_eq!(
func_ctx.get_variable("x"),
Some(5.0),
"Function parameter should take precedence over global x"
);
}
// The parent value is visible until a local value is set, which then
// shadows it.
#[test]
fn test_get_variable_temporary_scope() {
let mut ctx = EvalContext::new();
let _ = ctx.set_parameter("x", 1.0);
let mut temp_ctx = EvalContext::new();
temp_ctx.parent = Some(Rc::new(ctx));
assert_eq!(
temp_ctx.get_variable("x"),
Some(1.0),
"Should find variable in parent scope"
);
let _ = temp_ctx.set_parameter("x", 2.0);
assert_eq!(
temp_ctx.get_variable("x"),
Some(2.0),
"Should find shadowed variable in local scope"
);
println!("Temporary scope variable lookup:");
println!("temp x = {:?}", temp_ctx.get_variable("x"));
println!("temp variables: {:?}", temp_ctx.variables);
println!(
"parent variables: {:?}",
temp_ctx.parent.as_ref().map(|p| &p.variables)
);
}
// A user-registered 3-argument native function is callable through
// engine::interp.
#[test]
fn test_native_function() {
let mut ctx = EvalContext::new();
ctx.register_native_function("add_all", 3, |args| args.iter().sum())
.unwrap();
let val = engine::interp("add_all(1, 2, 3)", Some(Rc::new(ctx))).unwrap();
assert_eq!(val, 6.0);
}
// Indexing a context array through the interpreter: the result shows the
// index is 0-based.
#[test]
fn test_array_access() {
let mut ctx = EvalContext::new();
ctx.arrays
.insert(
"climb_wave_wait_time".try_into_heapless().unwrap(),
vec![10.0, 20.0, 30.0],
)
.expect("Failed to insert array");
let val = engine::interp("climb_wave_wait_time[1]", Some(Rc::new(ctx.clone()))).unwrap();
assert_eq!(val, 20.0);
}
// The parser produces AstExpr::Array with a constant index.
// NOTE(review): `ctx` is built but never used by this test.
#[test]
fn test_array_access_ast_structure() {
let mut ctx = EvalContext::new();
ctx.arrays
.insert(
"climb_wave_wait_time".try_into_heapless().unwrap(),
vec![10.0, 20.0, 30.0],
)
.expect("Failed to insert array");
use bumpalo::Bump;
let arena = Bump::new();
let ast = engine::parse_expression("climb_wave_wait_time[1]", &arena).unwrap();
match ast {
AstExpr::Array { name, index } => {
assert_eq!(name, "climb_wave_wait_time");
match *index {
AstExpr::Constant(val) => assert_eq!(val, 1.0),
_ => panic!("Expected constant index"),
}
}
_ => panic!("Expected array AST node"),
}
}
// `foo.bar` resolves through the attributes map. The missing-attribute,
// missing-object and no-context cases must fail (via `unwrap_err`).
// NOTE(review): the specific error kinds and `eval_result` are only
// printed, not asserted.
#[test]
fn test_attribute_access() {
let mut ctx = EvalContext::new();
let mut foo_map = heapless::FnvIndexMap::<
crate::types::HString,
crate::Real,
{ crate::types::EXP_RS_MAX_ATTR_KEYS },
>::new();
foo_map
.insert("bar".try_into_heapless().unwrap(), 42.0)
.unwrap();
ctx.attributes
.insert("foo".try_into_heapless().unwrap(), foo_map)
.unwrap();
use bumpalo::Bump;
let arena = Bump::new();
let ast = engine::parse_expression("foo.bar", &arena).unwrap();
println!("AST for foo.bar: {:?}", ast);
let ctx_copy = ctx.clone();
let eval_result = crate::eval::eval_ast(&ast, Some(Rc::new(ctx_copy)), &arena);
println!("Direct eval_ast result: {:?}", eval_result);
let ctx_copy2 = ctx.clone();
let val = engine::interp("foo.bar", Some(Rc::new(ctx_copy2))).unwrap();
assert_eq!(val, 42.0);
let ctx_copy3 = ctx.clone();
let err = engine::interp("foo.baz", Some(Rc::new(ctx_copy3))).unwrap_err();
println!("Error for foo.baz: {:?}", err);
let ctx_copy4 = ctx.clone();
let err2 = engine::interp("nope.bar", Some(Rc::new(ctx_copy4))).unwrap_err();
println!("Error for nope.bar: {:?}", err2);
let err3 = engine::interp("foo.bar", None).unwrap_err();
println!("Error for foo.bar with None context: {:?}", err3);
}
// set_parameter returns the previous value, and the interpreter sees each
// updated value.
#[test]
fn test_set_parameter() {
let mut ctx = EvalContext::new();
let prev = ctx.set_parameter("x", 10.0);
assert_eq!(prev.unwrap(), None);
let val = engine::interp("x", Some(Rc::new(ctx.clone()))).unwrap();
assert_eq!(val, 10.0);
let prev = ctx.set_parameter("x", 20.0);
assert_eq!(prev.unwrap(), Some(10.0));
let val = engine::interp("x", Some(Rc::new(ctx.clone()))).unwrap();
assert_eq!(val, 20.0);
let val = engine::interp("x * 2", Some(Rc::new(ctx.clone()))).unwrap();
assert_eq!(val, 40.0);
}
}