// Submodules that make up this module.
pub mod ast;
pub mod context_stack;
pub mod iterative;
pub mod recursion;
pub mod stack_ops;
pub mod types;
// Glob re-exports so callers can reach the public items of `ast`, `recursion`
// and `types` directly through this module.
pub use ast::*;
pub use recursion::*;
pub use types::*;
// Explicit re-export of the recursion-depth helpers. These names come from the
// same `recursion` module that is glob re-exported above; listing them here
// spells out which helpers are re-exported by name.
pub use recursion::{
check_and_increment_recursion_depth, decrement_recursion_depth, get_recursion_depth,
reset_recursion_depth, set_max_recursion_depth,
};
#[cfg(test)]
mod tests {
use crate::types::{TryIntoFunctionName, TryIntoHeaplessString};
use super::*;
use crate::AstExpr;
use crate::Real;
use crate::context::EvalContext;
use crate::engine::interp;
use crate::error::ExprError;
use crate::parse_expression;
use std::rc::Rc;
use std::sync::atomic::Ordering;
#[cfg(feature = "libm")]
use crate::{abs, cos, max, min, neg, pow, sin};
/// Test helper: evaluates `name` as a complete expression by forwarding it,
/// together with the optional context `ctx`, straight to `interp`.
///
/// Returns whatever `interp` returns; no extra handling is done here.
fn test_eval_variable(name: &str, ctx: Option<Rc<EvalContext>>) -> Result<Real, ExprError> {
interp(name, ctx)
}
/// A one-argument native function registered on a fresh context can be
/// called from an expression string.
#[test]
fn test_eval_native_function_simple() {
    let mut context = EvalContext::new();
    let _ = context.register_native_function("triple", 1, |args| args[0] * 3.0);

    let shared = Rc::new(context);
    let tripled = interp("triple(4)", Some(shared)).unwrap();

    assert_eq!(tripled, 12.0);
}
/// Builds the `EvalContext` shared by the tests in this module.
///
/// The set of registered native functions depends on the `libm` feature:
///
/// * without `libm`, a broad set of math functions is registered, each backed
///   by the corresponding method on the float type;
/// * with `libm`, only `sin`, `cos`, `pow`, `^`, `min`, `max`, `neg` and `abs`
///   are registered, each backed by the crate-level function of the same
///   purpose imported at the top of the test module.
///
/// Registration results are deliberately discarded (`let _ =`).
fn create_test_context() -> EvalContext {
    let mut ctx = EvalContext::new();

    // This module is already compiled only under `cfg(test)`, so the feature
    // gate alone is sufficient here.
    #[cfg(not(feature = "libm"))]
    {
        // Trigonometric functions.
        let _ = ctx.register_native_function("sin", 1, |args| args[0].sin());
        let _ = ctx.register_native_function("cos", 1, |args| args[0].cos());
        let _ = ctx.register_native_function("tan", 1, |args| args[0].tan());
        let _ = ctx.register_native_function("asin", 1, |args| args[0].asin());
        let _ = ctx.register_native_function("acos", 1, |args| args[0].acos());
        let _ = ctx.register_native_function("atan", 1, |args| args[0].atan());
        let _ = ctx.register_native_function("atan2", 2, |args| args[0].atan2(args[1]));

        // Hyperbolic functions.
        let _ = ctx.register_native_function("sinh", 1, |args| args[0].sinh());
        let _ = ctx.register_native_function("cosh", 1, |args| args[0].cosh());
        let _ = ctx.register_native_function("tanh", 1, |args| args[0].tanh());

        // Exponentials and logarithms. Note that `log` is bound to the natural
        // logarithm here, exactly like `ln`.
        let _ = ctx.register_native_function("exp", 1, |args| args[0].exp());
        let _ = ctx.register_native_function("ln", 1, |args| args[0].ln());
        let _ = ctx.register_native_function("log", 1, |args| args[0].ln());
        let _ = ctx.register_native_function("log10", 1, |args| args[0].log10());
        let _ = ctx.register_native_function("log2", 1, |args| args[0].log2());

        // Roots, magnitude and rounding.
        let _ = ctx.register_native_function("sqrt", 1, |args| args[0].sqrt());
        let _ = ctx.register_native_function("abs", 1, |args| args[0].abs());
        let _ = ctx.register_native_function("floor", 1, |args| args[0].floor());
        let _ = ctx.register_native_function("ceil", 1, |args| args[0].ceil());
        let _ = ctx.register_native_function("round", 1, |args| args[0].round());

        // Power is registered both under its function name and its operator name.
        let _ = ctx.register_native_function("pow", 2, |args| args[0].powf(args[1]));
        let _ = ctx.register_native_function("^", 2, |args| args[0].powf(args[1]));

        let _ = ctx.register_native_function("min", 2, |args| args[0].min(args[1]));
        let _ = ctx.register_native_function("max", 2, |args| args[0].max(args[1]));
        let _ = ctx.register_native_function("neg", 1, |args| -args[0]);

        // Three-way sign: unlike `signum`, an input of zero maps to zero.
        let _ = ctx.register_native_function("sign", 1, |args| {
            if args[0] > 0.0 {
                1.0
            } else if args[0] < 0.0 {
                -1.0
            } else {
                0.0
            }
        });
    }

    #[cfg(feature = "libm")]
    {
        // The crate-level unary helpers take a second argument; it is always
        // passed as 0.0 here.
        let _ = ctx.register_native_function("sin", 1, |args| sin(args[0], 0.0));
        let _ = ctx.register_native_function("cos", 1, |args| cos(args[0], 0.0));
        let _ = ctx.register_native_function("pow", 2, |args| pow(args[0], args[1]));
        let _ = ctx.register_native_function("^", 2, |args| pow(args[0], args[1]));
        let _ = ctx.register_native_function("min", 2, |args| min(args[0], args[1]));
        let _ = ctx.register_native_function("max", 2, |args| max(args[0], args[1]));
        let _ = ctx.register_native_function("neg", 1, |args| neg(args[0], 0.0));
        let _ = ctx.register_native_function("abs", 1, |args| abs(args[0], 0.0));
    }

    ctx
}
/// `pi` and `e` resolve to the standard constants even without a context.
/// The reference values and tolerance follow the float width selected by the
/// `f32` feature.
#[test]
fn test_eval_variable_builtin_constants() {
    #[cfg(feature = "f32")]
    let (expected_pi, expected_e, tolerance): (Real, Real, Real) =
        (std::f32::consts::PI, std::f32::consts::E, 1e-5);
    #[cfg(not(feature = "f32"))]
    let (expected_pi, expected_e, tolerance): (Real, Real, Real) =
        (std::f64::consts::PI, std::f64::consts::E, 1e-10);

    let pi_value = test_eval_variable("pi", None).unwrap();
    assert!((pi_value - expected_pi).abs() < tolerance);

    let e_value = test_eval_variable("e", None).unwrap();
    assert!((e_value - expected_e).abs() < tolerance);
}
/// Variables resolve from both the context's parameters (`x`) and its
/// constants table (`y`).
#[test]
fn test_eval_variable_context_lookup() {
    let mut context = EvalContext::new();
    let _ = context.set_parameter("x", 42.0);

    let constant_key = "y".try_into_heapless().unwrap();
    context
        .constants
        .insert(constant_key, crate::constants::PI)
        .expect("Failed to insert constant");

    let x_value = test_eval_variable("x", Some(Rc::new(context.clone()))).unwrap();
    assert_eq!(x_value, 42.0);

    let y_value = test_eval_variable("y", Some(Rc::new(context.clone()))).unwrap();
    assert_eq!(y_value, crate::constants::PI);
}
/// With no context, an unknown identifier is reported as `UnknownVariable`,
/// while the bare name `sin` is reported as a `Syntax` error.
#[test]
fn test_eval_variable_unknown_and_function_name() {
    let unknown_err = test_eval_variable("nosuchvar", None).unwrap_err();
    assert!(matches!(unknown_err, ExprError::UnknownVariable { .. }));

    let bare_function_err = test_eval_variable("sin", None).unwrap_err();
    assert!(matches!(bare_function_err, ExprError::Syntax(_)));
}
/// `sin(0)` evaluates to (approximately) zero with the shared test context.
#[test]
fn test_eval_function_native() {
    let shared = Rc::new(create_test_context());
    let sine_of_zero = interp("sin(0)", Some(shared)).unwrap();
    assert!((sine_of_zero - 0.0).abs() < 1e-10);
}
/// `pow` and `abs` both evaluate correctly against the shared test context.
#[test]
fn test_eval_function_builtin_fallback() {
    let context = create_test_context();
    let cases: [(&str, Real); 2] = [("pow(2,3)", 8.0), ("abs(-5)", 5.0)];

    for &(expression, expected) in cases.iter() {
        let actual = interp(expression, Some(Rc::new(context.clone()))).unwrap();
        assert_eq!(actual, expected);
    }
}
/// Indexing a registered array returns the element; an index past the end
/// yields `ArrayIndexOutOfBounds`.
#[test]
fn test_eval_array_success_and_out_of_bounds() {
    let mut context = EvalContext::new();
    let array_key = "arr".try_into_heapless().unwrap();
    context
        .arrays
        .insert(array_key, vec![1.0, 2.0, 3.0])
        .expect("Failed to insert array");

    let in_bounds = interp("arr[1]", Some(Rc::new(context.clone()))).unwrap();
    assert_eq!(in_bounds, 2.0);

    let out_of_bounds = interp("arr[10]", Some(Rc::new(context))).unwrap_err();
    assert!(matches!(
        out_of_bounds,
        ExprError::ArrayIndexOutOfBounds { .. }
    ));
}
/// Indexing an array name that was never registered reports `UnknownVariable`.
#[test]
fn test_eval_array_unknown() {
    let empty = Rc::new(EvalContext::new());
    let failure = interp("nosucharr[0]", Some(empty)).unwrap_err();
    assert!(matches!(failure, ExprError::UnknownVariable { .. }));
}
/// `base.attr` resolves a stored attribute; a missing attribute on a known
/// base yields `AttributeNotFound`.
#[test]
fn test_eval_attribute_success_and_not_found() {
    let mut context = EvalContext::new();
    context
        .set_attribute("bar", "foo", 123.0)
        .expect("Failed to set attribute");

    let present = interp("bar.foo", Some(Rc::new(context.clone()))).unwrap();
    assert_eq!(present, 123.0);

    let missing = interp("bar.baz", Some(Rc::new(context.clone()))).unwrap_err();
    assert!(matches!(missing, ExprError::AttributeNotFound { .. }));
}
/// Attribute access on a base object that does not exist also reports
/// `AttributeNotFound`.
#[test]
fn test_eval_attribute_unknown_base() {
    let context = EvalContext::new();
    let failure = interp("nosuch.foo", Some(Rc::new(context.clone()))).unwrap_err();
    assert!(matches!(failure, ExprError::AttributeNotFound { .. }));
}
/// `-2^2` must parse as `neg(^(2, 2))`: exponentiation binds tighter than
/// unary minus.
#[test]
fn test_neg_pow_ast() {
    use bumpalo::Bump;

    let arena = Bump::new();
    let ast = parse_expression("-2^2", &arena).unwrap_or_else(|e| panic!("Parse error: {}", e));

    // Outermost node: the unary negation.
    let neg_args = match &ast {
        AstExpr::Function { name, args } if *name == "neg" => args,
        _ => panic!("Expected neg as top-level function"),
    };
    assert_eq!(neg_args.len(), 1);

    // Its single operand: the power operator.
    let pow_args = match &neg_args[0] {
        AstExpr::Function { name, args } if *name == "^" => args,
        _ => panic!("Expected pow as argument to neg"),
    };
    assert_eq!(pow_args.len(), 2);

    // Both power operands are the literal 2.
    match (&pow_args[0], &pow_args[1]) {
        (AstExpr::Constant(base), AstExpr::Constant(exponent)) => {
            assert_eq!(*base, 2.0);
            assert_eq!(*exponent, 2.0);
        }
        _ => panic!("Expected constants as pow args"),
    }
}
/// With `libm` built-ins available and no context, unary minus applies after
/// exponentiation unless parentheses say otherwise.
#[test]
#[cfg(feature = "libm")]
fn test_neg_pow_eval() {
    let negated_square = interp("-2^2", None).unwrap();
    assert_eq!(negated_square, -4.0);

    let squared_negative = interp("(-2)^2", None).unwrap();
    assert_eq!(squared_negative, 4.0);
}
/// Without `libm`, `-2^2` and `(-2)^2` evaluate only when `neg` and `^` have
/// been registered explicitly; a context with a default function registry
/// reports `UnknownFunction`.
#[test]
#[cfg(not(feature = "libm"))]
fn test_neg_pow_eval_no_builtins() {
    /// Builds a context whose fields are all defaulted, with a default
    /// function registry, no parent and no AST cache.
    fn bare_context() -> EvalContext {
        EvalContext {
            variables: Default::default(),
            constants: Default::default(),
            arrays: Default::default(),
            attributes: Default::default(),
            nested_arrays: Default::default(),
            function_registry: Rc::new(FunctionRegistry::default()),
            parent: None,
            ast_cache: None,
        }
    }

    let mut ctx = bare_context();
    // Both results are discarded, matching how registrations are handled
    // throughout this test module.
    let _ = ctx.register_native_function("neg", 1, |args| -args[0]);
    let _ = ctx.register_native_function("^", 2, |args| args[0].powf(args[1]));

    // `ctx` is not used again, so it is moved into the `Rc` rather than cloned.
    let ctx_rc = Rc::new(ctx);

    let val = interp("-2^2", Some(ctx_rc.clone())).unwrap();
    assert_eq!(val, -4.0);

    let val2 = interp("(-2)^2", Some(ctx_rc)).unwrap();
    assert_eq!(val2, 4.0);

    // Nothing registered: the operator functions cannot be resolved.
    let err = interp("-2^2", Some(Rc::new(bare_context()))).unwrap_err();
    assert!(matches!(err, ExprError::UnknownFunction { .. }));
}
/// `(-2)^2` must parse as `^(neg(2), 2)`: the parentheses force the negation
/// to become the left operand of the power operator.
#[test]
fn test_paren_neg_pow_ast() {
    use bumpalo::Bump;

    let arena = Bump::new();
    let ast =
        parse_expression("(-2)^2", &arena).unwrap_or_else(|e| panic!("Parse error: {}", e));

    // Outermost node: the power operator.
    let pow_args = match &ast {
        AstExpr::Function { name, args } if *name == "^" => args,
        _ => panic!("Expected pow as top-level function"),
    };
    assert_eq!(pow_args.len(), 2);

    // Left operand: neg(2).
    let neg_args = match &pow_args[0] {
        AstExpr::Function { name, args } if *name == "neg" => args,
        _ => panic!("Expected neg as left arg to pow"),
    };
    assert_eq!(neg_args.len(), 1);
    match &neg_args[0] {
        AstExpr::Constant(operand) => assert_eq!(*operand, 2.0),
        _ => panic!("Expected constant as neg arg"),
    }

    // Right operand: the literal exponent.
    match &pow_args[1] {
        AstExpr::Constant(exponent) => assert_eq!(*exponent, 2.0),
        _ => panic!("Expected constant as right arg to pow"),
    }
}
/// Checks the AST shape of two call expressions: `sin(x)` is a function node
/// with one variable argument, and `abs(-42)` is `abs(neg(42))`.
#[test]
fn test_function_application_juxtaposition_ast() {
    use bumpalo::Bump;

    let arena = Bump::new();

    // --- sin(x) ---
    let sin_ast = crate::engine::parse_expression("sin(x)", &arena).unwrap();
    let sin_args = match &sin_ast {
        AstExpr::Function { name, args } if *name == "sin" => args,
        _ => panic!("Expected function node for sin x"),
    };
    assert_eq!(sin_args.len(), 1);
    match &sin_args[0] {
        AstExpr::Variable(var) => assert_eq!(*var, "x"),
        _ => panic!("Expected variable as argument"),
    }

    // --- abs(-42) ---
    let abs_ast = crate::engine::parse_expression("abs(-42)", &arena).unwrap();
    println!("AST for 'abs(-42)': {:?}", abs_ast);
    let abs_args = match &abs_ast {
        AstExpr::Function { name, args } if *name == "abs" => args,
        _ => panic!("Expected function node for abs -42"),
    };
    assert_eq!(abs_args.len(), 1);

    let neg_args = match &abs_args[0] {
        AstExpr::Function { name, args } if *name == "neg" => args,
        _ => panic!("Expected neg as argument to abs"),
    };
    assert_eq!(neg_args.len(), 1);
    match &neg_args[0] {
        AstExpr::Constant(literal) => assert_eq!(*literal, 42.0),
        _ => panic!("Expected constant as neg arg"),
    }
}
/// Evaluates `abs(-42)` through `parse_expression` + `eval_ast` and expects 42.
#[test]
fn test_function_application_juxtaposition_eval() {
    use bumpalo::Bump;

    // The binding has to be mutable for the registrations below, which only
    // exist when `libm` is disabled; with `libm` enabled the `mut` is unused,
    // hence the lint allowance.
    #[allow(unused_mut)]
    let mut ctx = create_test_context();
    #[cfg(not(feature = "libm"))]
    {
        // Registration results are discarded, consistent with the rest of this
        // test module.
        let _ = ctx.register_native_function("abs", 1, |args| args[0].abs());
        let _ = ctx.register_native_function("neg", 1, |args| -args[0]);
    }

    let arena = Bump::new();
    let ast = crate::engine::parse_expression("abs(-42)", &arena).unwrap();
    let val = crate::eval::ast::eval_ast(&ast, Some(Rc::new(ctx)), &arena).unwrap();
    assert_eq!(val, 42.0);
}
/// `pow(2)` parses to a `pow` function node carrying one or two arguments;
/// every argument present must be the constant 2.
#[test]
fn test_pow_arity_ast() {
    use bumpalo::Bump;

    let arena = Bump::new();
    let ast =
        parse_expression("pow(2)", &arena).unwrap_or_else(|e| panic!("Parse error: {}", e));

    let pow_args = match &ast {
        AstExpr::Function { name, args } if *name == "pow" => args,
        _ => panic!("Expected function node for pow(2)"),
    };

    let arg_count = pow_args.len();
    assert!(arg_count == 1 || arg_count == 2);

    match &pow_args[0] {
        AstExpr::Constant(first) => assert_eq!(*first, 2.0),
        _ => panic!("Expected constant as pow arg"),
    }

    // A second argument is optional at the AST level.
    if arg_count == 2 {
        match &pow_args[1] {
            AstExpr::Constant(second) => assert_eq!(*second, 2.0),
            _ => panic!("Expected constant as pow second arg"),
        }
    }
}
/// With `libm` built-ins and no context, `pow(2)` evaluates to 4 and
/// `pow(2, 3)` to 8.
#[test]
#[cfg(feature = "libm")]
fn test_pow_arity_eval() {
    let single_arg = interp("pow(2)", None).unwrap();
    assert_eq!(single_arg, 4.0);

    let two_args = interp("pow(2, 3)", None).unwrap();
    assert_eq!(two_args, 8.0);
}
/// Without `libm`, registers a two-argument `pow` on an otherwise defaulted
/// context and checks that `pow(2)` evaluates to 4 and `pow(2, 3)` to 8.
#[test]
#[cfg(not(feature = "libm"))]
fn test_pow_arity_eval_no_builtins() {
    use bumpalo::Bump;

    let mut ctx = EvalContext {
        variables: Default::default(),
        constants: Default::default(),
        arrays: Default::default(),
        attributes: Default::default(),
        nested_arrays: Default::default(),
        function_registry: Rc::new(FunctionRegistry::default()),
        parent: None,
        ast_cache: None,
    };
    let _ = ctx.register_native_function("pow", 2, |args| args[0].powf(args[1]));
    let ctx_rc = Rc::new(ctx);

    // `parse_expression` allocates the AST in a caller-supplied arena, as at
    // every other call site in this module.
    let arena = Bump::new();
    let ast = crate::engine::parse_expression("pow(2)", &arena).unwrap();
    println!("Parsed expression: {:?}", ast);

    let result = interp("pow(2)", Some(ctx_rc.clone())).unwrap();
    assert_eq!(result, 4.0, "pow(2) should be interpreted as pow(2,2) = 4");

    let result2 = interp("pow(2, 3)", Some(ctx_rc)).unwrap();
    assert_eq!(result2, 8.0);
}
/// A bare function name with no call syntax parses as a plain variable node.
#[test]
fn test_unknown_variable_and_function_ast() {
    use bumpalo::Bump;

    let arena = Bump::new();

    let sin_ast =
        parse_expression("sin", &arena).unwrap_or_else(|e| panic!("Parse error: {}", e));
    if let AstExpr::Variable(name) = &sin_ast {
        assert_eq!(*name, "sin");
    } else {
        panic!("Expected variable node for sin");
    }

    let abs_ast =
        parse_expression("abs", &arena).unwrap_or_else(|e| panic!("Parse error: {}", e));
    if let AstExpr::Variable(name) = &abs_ast {
        assert_eq!(*name, "abs");
    } else {
        panic!("Expected variable node for abs");
    }
}
/// Evaluating a bare function name (`sin`, `abs`) must fail with a `Syntax`
/// error carrying one of two accepted messages, while a genuinely unknown
/// identifier fails with `UnknownVariable` naming that identifier.
#[test]
fn test_unknown_variable_and_function_eval() {
    // The binding has to be mutable for the registrations below, which only
    // exist when `libm` is disabled; with `libm` enabled the `mut` is unused,
    // hence the lint allowance.
    #[allow(unused_mut)]
    let mut ctx = create_test_context();
    #[cfg(not(feature = "libm"))]
    {
        // Registration results are discarded, consistent with the rest of this
        // test module.
        let _ = ctx.register_native_function("sin", 1, |args| args[0].sin());
        let _ = ctx.register_native_function("abs", 1, |args| args[0].abs());
    }
    let ctx_rc = Rc::new(ctx);

    let err = interp("sin", Some(ctx_rc.clone())).unwrap_err();
    match err {
        ExprError::Syntax(msg) => {
            assert!(
                msg.contains("Unexpected token")
                    || msg.contains("Function 'sin' used without arguments")
            );
        }
        _ => panic!("Expected Syntax error, got {:?}", err),
    }

    let err2 = interp("abs", Some(ctx_rc.clone())).unwrap_err();
    match err2 {
        ExprError::Syntax(msg) => {
            assert!(
                msg.contains("Unexpected token")
                    || msg.contains("Function 'abs' used without arguments")
            );
        }
        _ => panic!("Expected Syntax error, got {:?}", err2),
    }

    let err3 = interp("nosuchvar", Some(ctx_rc)).unwrap_err();
    assert!(matches!(err3, ExprError::UnknownVariable { name } if name == "nosuchvar"));
}
/// Re-registering `sin`, `pow` and `^` must replace their behaviour, while a
/// function that was not overridden (`cos`) keeps working.
#[test]
fn test_override_builtin_native() {
    let mut ctx = create_test_context();
    // Overrides with deliberately "wrong" implementations so that a hit on the
    // override is distinguishable from the default.
    let _ = ctx.register_native_function("sin", 1, |_args| 100.0);
    let _ = ctx.register_native_function("pow", 2, |args| args[0] + args[1]);
    let _ = ctx.register_native_function("^", 2, |args| args[0] + args[1]);
    let ctx_rc = Rc::new(ctx.clone());

    let val_sin = interp("sin(0.5)", Some(ctx_rc.clone())).unwrap();
    assert_eq!(val_sin, 100.0, "Native 'sin' override failed");
    let val_pow = interp("pow(3, 4)", Some(ctx_rc.clone())).unwrap();
    assert_eq!(val_pow, 7.0, "Native 'pow' override failed");
    let val_pow_op = interp("3^4", Some(ctx_rc.clone())).unwrap();
    assert_eq!(val_pow_op, 7.0, "Native '^' override failed");

    // Without `libm`, register `cos` explicitly and rebuild the shared context
    // so the checks below see it. The rebinding lives at function scope:
    // shadowing `ctx_rc` inside a nested block would be discarded at the
    // block's closing brace.
    #[cfg(not(feature = "libm"))]
    let ctx_rc = {
        let _ = ctx.register_native_function("cos", 1, |args| args[0].cos());
        Rc::new(ctx.clone())
    };

    if ctx
        .native_functions
        .contains_key(&"cos".try_into_function_name().unwrap())
        || cfg!(feature = "libm")
    {
        let val_cos = interp("cos(0)", Some(ctx_rc.clone())).unwrap();
        let expected_cos = 1.0;
        assert!(
            (val_cos - expected_cos).abs() < 1e-9,
            "Built-in/default 'cos' failed after override. Got {}",
            val_cos
        );
    } else {
        let err = interp("cos(0)", Some(ctx_rc)).unwrap_err();
        assert!(matches!(err, ExprError::UnknownFunction { .. }));
    }
}
/// Evaluates each term of `x^3 + 2*x^2 + 3*x + 4` separately at `x = 2`.
#[test]
fn test_polynomial_subexpressions() {
    use bumpalo::Bump;

    let mut context = EvalContext::new();
    let _ = context.set_parameter("x", 2.0);
    let shared = Rc::new(context);
    let arena = Bump::new();

    // (term, expected value at x = 2)
    let terms: [(&str, Real); 4] = [("x^3", 8.0), ("2*x^2", 8.0), ("3*x", 6.0), ("4", 4.0)];

    for &(term, expected) in terms.iter() {
        let parsed = crate::engine::parse_expression(term, &arena).unwrap();
        let value = crate::eval::ast::eval_ast(&parsed, Some(shared.clone()), &arena).unwrap();
        assert_eq!(value, expected);
    }
}
/// `2 + 3 * 4 ^ 2` must honour precedence: `^` before `*` before `+`.
/// The context starts from an empty native-function map, so only the three
/// operators registered here are available.
#[test]
fn test_operator_precedence() {
    use bumpalo::Bump;

    let mut context = EvalContext::new();
    context.native_functions = Rc::new(crate::types::NativeFunctionMap::new());
    let _ = context.register_native_function("+", 2, |args| args[0] + args[1]);
    let _ = context.register_native_function("*", 2, |args| args[0] * args[1]);
    let _ = context.register_native_function("^", 2, |args| args[0].powf(args[1]));

    let arena = Bump::new();
    let parsed = crate::engine::parse_expression("2 + 3 * 4 ^ 2", &arena).unwrap();
    let value = crate::eval::eval_ast(&parsed, Some(Rc::new(context)), &arena).unwrap();

    assert_eq!(value, 2.0 + 3.0 * 16.0);
}
/// Smoke test: the full polynomial parses, and its AST is printed for manual
/// inspection. No structural assertions are made.
#[test]
fn test_polynomial_ast_structure() {
    use bumpalo::Bump;

    let arena = Bump::new();
    let polynomial = crate::engine::parse_expression("x^3 + 2*x^2 + 3*x + 4", &arena).unwrap();
    println!("{:?}", polynomial);
}
/// Evaluating a hand-built constant node without a context succeeds and
/// returns the constant.
#[test]
fn test_recursion_depth_tracking_reset() {
    let arena = bumpalo::Bump::new();
    let constant_node = AstExpr::Constant(42.0);

    let outcome = crate::eval::ast::eval_ast(&constant_node, None, &arena);

    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap(), 42.0);
}
/// A flat arithmetic expression evaluates correctly and leaves the recursion
/// depth counter below 10 afterwards.
#[test]
fn test_recursion_depth_with_non_recursive_expressions() {
    // Start from a known counter value.
    RECURSION_DEPTH.store(0, Ordering::Relaxed);

    let outcome = interp("1 + 2 * 3 + 4 * 5 + 6 * 7 + 8 * 9 + 10", None);
    assert!(
        outcome.is_ok(),
        "Failed to evaluate non-recursive expression: {:?}",
        outcome.err()
    );

    let expected = 1.0 + 2.0 * 3.0 + 4.0 * 5.0 + 6.0 * 7.0 + 8.0 * 9.0 + 10.0;
    assert_eq!(outcome.unwrap(), expected);

    let depth_after = RECURSION_DEPTH.load(Ordering::Relaxed);
    assert!(
        depth_after < 10,
        "Unexpectedly high recursion depth for non-recursive expr: {}",
        depth_after
    );
}
/// Evaluates a product of parenthesised sums and inspects the recursion depth
/// counter afterwards. With `libm` the depth must stay below 5; without it the
/// depth is only printed.
#[test]
fn test_recursion_tracking_function_specific() {
    // Start from a known counter value.
    RECURSION_DEPTH.store(0, Ordering::Relaxed);
    let expr = "(1 + 2) * (3 + 4) * (5 + 6) * (7 + 8) * (9 + 10) * (11 + 12) * (13 + 14)";

    // The binding has to be mutable for the registrations below, which only
    // exist when `libm` is disabled; with `libm` enabled the `mut` is unused,
    // hence the lint allowance.
    #[allow(unused_mut)]
    let mut ctx = EvalContext::new();
    #[cfg(not(feature = "libm"))]
    {
        // Registration results are discarded, consistent with the rest of this
        // test module.
        let _ = ctx.register_native_function("+", 2, |args| args[0] + args[1]);
        let _ = ctx.register_native_function("*", 2, |args| args[0] * args[1]);
    }

    let result = interp(expr, Some(Rc::new(ctx)));
    assert!(result.is_ok());

    // Read the counter once; exactly one of the two branches below is compiled.
    let depth = RECURSION_DEPTH.load(Ordering::Relaxed);

    #[cfg(feature = "libm")]
    {
        assert!(
            depth < 5,
            "Recursion tracking shouldn't count non-function AST nodes, got depth: {}",
            depth
        );
    }
    #[cfg(not(feature = "libm"))]
    {
        println!(
            "Without libm, recursion depth is higher due to operator functions: {}",
            depth
        );
    }
}
}