extern crate alloc;
use bumpalo::Bump;
#[cfg(test)]
use exp_rs::context::EvalContext;
use exp_rs::engine::interp;
use exp_rs::eval::eval_ast;
use exp_rs::{assert_approx_eq, constants};
use std::sync::Mutex;
use std::time::Instant;
use exp_rs::Real;
use alloc::rc::Rc;
mod test_helpers;
use test_helpers::{hstr, set_attr, set_const, set_var};
fn parse_expression(
expr: &str,
) -> Result<exp_rs::types::AstExpr<'static>, exp_rs::error::ExprError> {
thread_local! {
static TEST_ARENA: std::cell::RefCell<Bump> = std::cell::RefCell::new(Bump::new());
}
TEST_ARENA.with(|arena| {
let arena = arena.borrow();
let ast = exp_rs::engine::parse_expression(expr, &*arena)?;
Ok(unsafe {
std::mem::transmute::<exp_rs::types::AstExpr<'_>, exp_rs::types::AstExpr<'static>>(ast)
})
})
}
#[allow(dead_code)]
fn parse_expression_with_reserved(
expr: &str,
reserved: Option<&[String]>,
) -> Result<exp_rs::types::AstExpr<'static>, exp_rs::error::ExprError> {
thread_local! {
static TEST_ARENA: std::cell::RefCell<Bump> = std::cell::RefCell::new(Bump::new());
}
TEST_ARENA.with(|arena| {
let arena = arena.borrow();
let ast =
exp_rs::engine::parse_expression_arena_with_context(expr, &*arena, reserved, None)?;
Ok(unsafe {
std::mem::transmute::<exp_rs::types::AstExpr<'_>, exp_rs::types::AstExpr<'static>>(ast)
})
})
}
#[test]
fn test_basic_expression_evaluation() {
assert_eq!(interp("2 + 3", None).unwrap(), 5.0);
assert_eq!(interp("2 * 3 + 4", None).unwrap(), 10.0);
assert_eq!(interp("2 * (3 + 4)", None).unwrap(), 14.0);
#[cfg(feature = "libm")]
{
#[cfg(feature = "f32")]
assert_approx_eq!(
interp("sin(0.5)", None).unwrap(),
exp_rs::functions::sin(0.5, 0.0),
1e-6 as Real );
#[cfg(not(feature = "f32"))]
assert_approx_eq!(
interp("sin(0.5)", None).unwrap(),
exp_rs::functions::sin(0.5, 0.0),
1e-10 as Real );
}
#[cfg(not(feature = "libm"))]
{
let ctx = create_context_rc();
#[cfg(feature = "f32")]
assert_approx_eq!(
interp("sin(0.5)", Some(ctx.clone())).unwrap(),
0.5_f32.sin() as Real,
1e-6 as Real );
#[cfg(not(feature = "f32"))]
assert_approx_eq!(
interp("sin(0.5)", Some(ctx.clone())).unwrap(),
0.5_f64.sin() as Real,
1e-10 as Real );
}
#[cfg(feature = "libm")]
{
#[cfg(feature = "f32")]
assert_approx_eq!(
interp("cos(0.5)", None).unwrap(),
exp_rs::functions::cos(0.5, 0.0),
1e-6 as Real );
#[cfg(not(feature = "f32"))]
assert_approx_eq!(
interp("cos(0.5)", None).unwrap(),
exp_rs::functions::cos(0.5, 0.0),
1e-10 as Real );
}
#[cfg(not(feature = "libm"))]
{
let ctx = create_context_rc();
#[cfg(feature = "f32")]
assert_approx_eq!(
interp("cos(0.5)", Some(ctx.clone())).unwrap(),
0.5_f32.cos() as Real,
1e-6 as Real );
#[cfg(not(feature = "f32"))]
assert_approx_eq!(
interp("cos(0.5)", Some(ctx.clone())).unwrap(),
0.5_f64.cos() as Real,
1e-10 as Real );
}
assert_approx_eq!(
interp("pi", None).unwrap(),
exp_rs::constants::PI,
exp_rs::constants::TEST_PRECISION
);
assert_approx_eq!(
interp("e", None).unwrap(),
exp_rs::constants::E,
exp_rs::constants::TEST_PRECISION
);
#[cfg(feature = "libm")]
{
#[cfg(feature = "f32")]
assert_approx_eq!(
interp("sin(cos(0.5))", None).unwrap(),
exp_rs::functions::sin(exp_rs::functions::cos(0.5, 0.0), 0.0),
1e-6 as Real );
#[cfg(not(feature = "f32"))]
assert_approx_eq!(
interp("sin(cos(0.5))", None).unwrap(),
exp_rs::functions::sin(exp_rs::functions::cos(0.5, 0.0), 0.0),
1e-10 as Real );
}
#[cfg(not(feature = "libm"))]
{
let ctx = create_context_rc();
#[cfg(feature = "f32")]
{
let expected = (0.5_f32.cos() as f32).sin() as Real;
assert_approx_eq!(
interp("sin(cos(0.5))", Some(ctx.clone())).unwrap(),
expected,
1e-6 as Real );
}
#[cfg(not(feature = "f32"))]
{
let expected = (0.5_f64.cos()).sin() as Real;
assert_approx_eq!(
interp("sin(cos(0.5))", Some(ctx.clone())).unwrap(),
expected,
1e-10 as Real );
}
}
}
#[test]
fn test_variable_expressions() {
#[cfg(not(feature = "libm"))]
let mut ctx = create_context();
#[cfg(feature = "libm")]
let mut ctx = EvalContext::default();
set_var(&mut ctx, "x", 5.0);
set_var(&mut ctx, "y", 10.0);
assert_eq!(
interp("x + y", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
15.0
);
assert_eq!(
interp("x * y", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
50.0
);
assert_eq!(
interp("(x + y) / 3", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
5.0
);
#[cfg(feature = "libm")]
{
#[cfg(feature = "f32")]
assert_approx_eq!(
interp("sin(x) + cos(y)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
exp_rs::functions::sin(5.0, 0.0) + exp_rs::functions::cos(10.0, 0.0),
1e-6 as Real );
#[cfg(not(feature = "f32"))]
assert_approx_eq!(
interp("sin(x) + cos(y)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
exp_rs::functions::sin(5.0, 0.0) + exp_rs::functions::cos(10.0, 0.0),
1e-10 as Real );
}
#[cfg(not(feature = "libm"))]
{
#[cfg(feature = "f32")]
assert_approx_eq!(
interp("sin(x) + cos(y)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
(5.0_f32.sin() + 10.0_f32.cos()) as Real,
1e-6 as Real );
#[cfg(not(feature = "f32"))]
assert_approx_eq!(
interp("sin(x) + cos(y)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
5.0_f64.sin() + 10.0_f64.cos(),
1e-10 as Real );
}
set_var(&mut ctx, "x", 7.0);
assert_eq!(
interp("x + y", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
17.0
);
}
#[test]
fn test_array_expressions() {
#[cfg(not(feature = "libm"))]
let mut ctx = create_context();
#[cfg(feature = "libm")]
let mut ctx = EvalContext::default();
ctx.arrays
.insert(hstr("data"), vec![10.0, 20.0, 30.0, 40.0, 50.0])
.expect("Failed to insert array");
assert_eq!(
interp("data[0]", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
10.0
);
assert_eq!(
interp("data[2]", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
30.0
);
assert_eq!(
interp("data[4]", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
50.0
);
assert_eq!(
interp("data[1] + data[3]", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
60.0
);
assert_eq!(
interp("data[2] * 2", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
60.0
);
assert_eq!(
interp("data[1+1]", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
30.0
);
assert_eq!(
interp("data[floor(1.8)]", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
20.0
);
set_var(&mut ctx, "i", 3.0);
assert_eq!(
interp("data[i]", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
40.0
);
}
#[test]
fn test_attribute_expressions() {
#[cfg(not(feature = "libm"))]
let mut ctx = create_context();
#[cfg(feature = "libm")]
let mut ctx = EvalContext::default();
set_attr(&mut ctx, "point", "x", 3.0);
set_attr(&mut ctx, "point", "y", 4.0);
assert_eq!(
interp("point.x", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
3.0
);
assert_eq!(
interp("point.y", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
4.0
);
assert_eq!(
interp("point.x + point.y", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
7.0
);
assert_approx_eq!(
interp(
"sqrt(point.x^2 + point.y^2)",
Some(std::rc::Rc::new(ctx.clone()))
)
.unwrap(),
5.0 as Real, crate::constants::TEST_PRECISION
);
set_attr(&mut ctx, "circle", "radius", 10.0);
set_attr(&mut ctx, "circle", "center_x", 5.0);
set_attr(&mut ctx, "circle", "center_y", 5.0);
let expr = "sqrt((point.x - circle.center_x)^2 + (point.y - circle.center_y)^2)";
println!(
"point.x = {}",
interp("point.x", Some(std::rc::Rc::new(ctx.clone()))).unwrap()
);
println!(
"point.y = {}",
interp("point.y", Some(std::rc::Rc::new(ctx.clone()))).unwrap()
);
println!(
"circle.center_x = {}",
interp("circle.center_x", Some(std::rc::Rc::new(ctx.clone()))).unwrap()
);
println!(
"circle.center_y = {}",
interp("circle.center_y", Some(std::rc::Rc::new(ctx.clone()))).unwrap()
);
println!(
"(point.x - circle.center_x) = {}",
interp(
"(point.x - circle.center_x)",
Some(std::rc::Rc::new(ctx.clone()))
)
.unwrap()
);
println!(
"(point.y - circle.center_y) = {}",
interp(
"(point.y - circle.center_y)",
Some(std::rc::Rc::new(ctx.clone()))
)
.unwrap()
);
println!(
"(point.x - circle.center_x)^2 = {}",
interp(
"(point.x - circle.center_x)^2",
Some(std::rc::Rc::new(ctx.clone()))
)
.unwrap()
);
println!(
"(point.y - circle.center_y)^2 = {}",
interp(
"(point.y - circle.center_y)^2",
Some(std::rc::Rc::new(ctx.clone()))
)
.unwrap()
);
println!(
"(point.x - circle.center_x)^2 + (point.y - circle.center_y)^2 = {}",
interp(
"(point.x - circle.center_x)^2 + (point.y - circle.center_y)^2",
Some(std::rc::Rc::new(ctx.clone()))
)
.unwrap()
);
println!(
"Full expression: {} = {}",
expr,
interp(expr, Some(std::rc::Rc::new(ctx.clone()))).unwrap()
);
assert_approx_eq!(
interp(expr, Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
2.23606797749979 as Real, crate::constants::TEST_PRECISION
);
let inside_expr =
"sqrt((point.x - circle.center_x)^2 + (point.y - circle.center_y)^2) < circle.radius";
let result = interp(inside_expr, Some(std::rc::Rc::new(ctx.clone()))).unwrap();
println!("Result of inside_expr: {}", result);
assert_eq!(result, 1.0, "Point should be inside circle");
}
#[test]
fn test_custom_functions() {
#[cfg(not(feature = "libm"))]
let mut ctx = create_context();
#[cfg(feature = "libm")]
let mut ctx = EvalContext::new();
let _ = ctx.register_native_function("sum", 3, |args| args.iter().sum());
assert_eq!(
interp("sum(1, 2, 3)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
6.0
);
assert_eq!(
interp("sum(10, 20, 30)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
60.0
);
let _ = ctx.register_native_function("distance", 4, |args| {
let x1 = args[0];
let y1 = args[1];
let x2 = args[2];
let y2 = args[3];
#[cfg(feature = "libm")]
{
exp_rs::functions::sqrt((x2 - x1).powi(2) + (y2 - y1).powi(2), 0.0)
}
#[cfg(not(feature = "libm"))]
{
let dx = x2 - x1;
let dy = y2 - y1;
(dx * dx + dy * dy).sqrt()
}
});
assert_approx_eq!(
interp("distance(0, 0, 3, 4)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
5.0 as Real, crate::constants::TEST_PRECISION
);
assert_approx_eq!(
interp("distance(1, 1, 4, 5)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
5.0 as Real, 1e-10 as Real );
let _ = ctx.register_native_function("circle_area", 1, |args| {
let radius = args[0];
exp_rs::constants::PI * radius * radius
});
assert_approx_eq!(
interp("circle_area(2)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
exp_rs::constants::PI * 4.0,
exp_rs::constants::TEST_PRECISION
);
assert_approx_eq!(
interp(
"circle_area(distance(0, 0, 3, 4) / 2)",
Some(std::rc::Rc::new(ctx.clone()))
)
.unwrap(),
crate::constants::PI * 6.25,
crate::constants::TEST_PRECISION
);
}
#[test]
fn test_complex_expressions() {
#[cfg(not(feature = "libm"))]
let mut ctx = create_context();
#[cfg(feature = "libm")]
let mut ctx = EvalContext::default();
set_var(&mut ctx, "t", 0.5);
set_var(&mut ctx, "amplitude", 10.0);
set_var(&mut ctx, "frequency", 2.0);
ctx.arrays
.insert(hstr("samples"), vec![1.0, 2.0, 3.0, 4.0, 5.0])
.expect("Failed to insert array");
set_attr(&mut ctx, "wave", "phase", 0.25);
set_attr(&mut ctx, "wave", "offset", 5.0);
let _ = ctx.register_native_function("interpolate", 3, |args| {
let a = args[0];
let b = args[1];
let t = args[2];
a * (1.0 - t) + b * t
});
println!(
"t = {}",
interp("t", Some(std::rc::Rc::new(ctx.clone()))).unwrap()
);
println!(
"amplitude = {}",
interp("amplitude", Some(std::rc::Rc::new(ctx.clone()))).unwrap()
);
println!(
"frequency = {}",
interp("frequency", Some(std::rc::Rc::new(ctx.clone()))).unwrap()
);
println!(
"wave.phase = {}",
interp("wave.phase", Some(std::rc::Rc::new(ctx.clone()))).unwrap()
);
println!(
"wave.offset = {}",
interp("wave.offset", Some(std::rc::Rc::new(ctx.clone()))).unwrap()
);
let expr = "amplitude * sin(2 * pi * frequency * t + wave.phase) + wave.offset + samples[floor(t * 4)]";
println!("Evaluating expression: {}", expr);
#[cfg(feature = "libm")]
let expected = 10.0
* exp_rs::functions::sin(2.0 * exp_rs::constants::PI * 2.0 * 0.5 + 0.25, 0.0)
+ 5.0
+ 3.0;
#[cfg(not(feature = "libm"))]
let expected = 10.0 * (2.0 * exp_rs::constants::PI * 2.0 * 0.5 + 0.25).sin() + 5.0 + 3.0;
println!("Expected result: {}", expected);
let result = interp(expr, Some(std::rc::Rc::new(ctx.clone())));
match &result {
Ok(val) => println!("Actual result: {}", val),
Err(e) => println!("Error: {}", e),
}
assert_approx_eq!(
result.unwrap(),
expected as Real, exp_rs::constants::TEST_PRECISION
);
let expr2 = "interpolate(samples[1], samples[2], t) * (1 + sin(wave.phase))";
#[cfg(feature = "libm")]
let expected2 = (2.0 * (1.0 - 0.5) + 3.0 * 0.5) * (1.0 + exp_rs::functions::sin(0.25, 0.0));
#[cfg(not(feature = "libm"))]
let expected2 = (2.0 * (1.0 - 0.5) + 3.0 * 0.5) * (1.0 + 0.25_f64.sin() as Real);
assert_approx_eq!(
interp(expr2, Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
expected2 as Real, exp_rs::constants::TEST_PRECISION
);
}
#[test]
fn test_expression_performance() {
let expr = "sin(x) * cos(y) + sqrt(z^2 + w^2) / log(u + 5)";
#[cfg(not(feature = "libm"))]
{
}
#[cfg(not(feature = "libm"))]
let mut ctx = create_context();
#[cfg(feature = "libm")]
let mut ctx = EvalContext::default();
set_var(&mut ctx, "x", 1.0);
set_var(&mut ctx, "y", 2.0);
set_var(&mut ctx, "z", 3.0);
set_var(&mut ctx, "w", 4.0);
set_var(&mut ctx, "u", 5.0);
let ast = parse_expression(expr).unwrap();
let arena = Bump::new();
let iterations = 10_000;
let start = Instant::now();
for i in 0..iterations {
set_var(&mut ctx, "x", (i % 100) as Real / 100.0);
let _ = eval_ast(&ast, Some(Rc::new(ctx.clone())), &arena).unwrap();
}
let duration = start.elapsed();
let avg_micros = duration.as_micros() as f64 / iterations as f64;
println!("Average evaluation time: {:.2} microseconds", avg_micros);
assert!(
avg_micros < 100.0,
"Expression evaluation is too slow: {:.2} microseconds",
avg_micros
);
}
#[test]
fn test_error_handling() {
let result = interp("1 +", None);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Syntax"));
let result = interp("x + 5", None);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Unknown variable"));
let result = interp("foo(5)", None);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Unknown function"));
#[cfg(feature = "libm")]
{
let result = interp("sin(1, 2)", None);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Invalid function call")
);
}
#[cfg(not(feature = "libm"))]
{
let mut sin_ctx = EvalContext::default();
let _ = sin_ctx.register_native_function("sin", 1, |args| args[0].sin());
let sin_ctx = std::rc::Rc::new(sin_ctx);
let result = interp("sin(1, 2)", Some(sin_ctx));
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Invalid function call")
);
}
#[cfg(not(feature = "libm"))]
let mut ctx = create_context();
#[cfg(feature = "libm")]
let mut ctx = EvalContext::new();
ctx.arrays
.insert(hstr("arr"), vec![1.0, 2.0, 3.0])
.expect("Failed to insert array");
let result = interp("arr[5]", Some(std::rc::Rc::new(ctx.clone())));
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
println!("Array index error message: {}", err_msg);
assert!(err_msg.contains("Array index out of bounds"));
set_attr(&mut ctx, "obj", "x", 1.0);
let result = interp("obj.y", Some(std::rc::Rc::new(ctx.clone())));
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Attribute not found")
);
let _ = ctx.register_native_function("safe_divide", 2, |args| {
if args[1] == 0.0 {
#[cfg(feature = "f32")]
return f32::NAN;
#[cfg(not(feature = "f32"))]
return f64::NAN;
} else {
args[0] / args[1]
}
});
let result = interp("safe_divide(1, 0)", Some(std::rc::Rc::new(ctx.clone()))).unwrap();
assert!(result.is_nan());
let result = interp("safe_divide(1)", Some(std::rc::Rc::new(ctx.clone())));
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Invalid function call")
);
}
#[test]
fn test_advanced_native_functions() {
{
let mut ctx = EvalContext::new();
let prev_output = Mutex::new(0.0);
let _ = ctx.register_native_function("low_pass_filter", 2, {
move |args| {
let input = args[0];
let alpha = args[1];
let mut prev = prev_output.lock().unwrap();
let output = alpha * input + (1.0 - alpha) * *prev;
*prev = output;
output
}
});
let result1 = interp(
"low_pass_filter(1.0, 0.2)",
Some(std::rc::Rc::new(ctx.clone())),
)
.unwrap();
assert_approx_eq!(result1, 0.2 as Real, 1e-10 as Real);
let result2 = interp(
"low_pass_filter(1.0, 0.2)",
Some(std::rc::Rc::new(ctx.clone())),
)
.unwrap();
assert_approx_eq!(result2, 0.36 as Real, 1e-10 as Real);
let result3 = interp(
"low_pass_filter(1.0, 0.2)",
Some(std::rc::Rc::new(ctx.clone())),
)
.unwrap();
assert_approx_eq!(result3, 0.488 as Real, constants::TEST_PRECISION); }
{
let mut ctx = EvalContext::new();
let integral = Mutex::new(0.0);
let prev_error = Mutex::new(0.0);
let _ = ctx.register_native_function("pid_controller", 5, {
move |args| {
let setpoint = args[0];
let process_variable = args[1];
let kp = args[2];
let ki = args[3];
let kd = args[4];
let error = setpoint - process_variable;
let mut integral_guard = integral.lock().unwrap();
*integral_guard += error;
let mut prev_error_guard = prev_error.lock().unwrap();
let derivative = error - *prev_error_guard;
let output = kp * error + ki * *integral_guard + kd * derivative;
*prev_error_guard = error;
output
}
});
let result = interp(
"pid_controller(100, 90, 0.5, 0.1, 0.2)",
Some(std::rc::Rc::new(ctx.clone())),
)
.unwrap();
assert_approx_eq!(result, 8.0 as Real, 1e-10 as Real); }
}
#[test]
fn test_config_expressions() {
let mut ctx = EvalContext::default();
set_const(&mut ctx, "AMPLITUDE_MIN", 2.0);
set_const(&mut ctx, "AMPLITUDE_MAX", 75.0);
set_const(&mut ctx, "VOLTAGE_MAX", 5.0);
ctx.arrays
.insert(
hstr("wait_times"),
vec![64691.0, 64625.0, 64559.0, 64494.0, 64428.0],
)
.expect("Failed to insert array");
set_var(&mut ctx, "power", 50.0);
set_var(&mut ctx, "speed", 50.0);
set_var(&mut ctx, "t", 0.5);
let pattern_step_expr = "(speed / 100.0) * (9.2 - 0.27) + 0.27";
let pattern_step = interp(pattern_step_expr, Some(std::rc::Rc::new(ctx.clone()))).unwrap();
set_var(&mut ctx, "pattern_step", pattern_step);
let pattern_index_expr = "t * pattern_step";
let pattern_index = interp(pattern_index_expr, Some(std::rc::Rc::new(ctx.clone()))).unwrap();
set_var(&mut ctx, "pattern_index", pattern_index);
let main_function = "((power * (AMPLITUDE_MAX - AMPLITUDE_MIN) + AMPLITUDE_MIN) * VOLTAGE_MAX) / AMPLITUDE_MAX * 1000";
let result = interp(main_function, Some(std::rc::Rc::new(ctx.clone()))).unwrap();
let expected: Real = ((50.0 * (75.0 - 2.0) + 2.0) * 5.0) / 75.0 * 1000.0;
#[cfg(feature = "f32")]
let relative_precision = expected.abs() * 1e-6 as Real; #[cfg(not(feature = "f32"))]
let relative_precision = expected.abs() * 1e-10 as Real;
assert_approx_eq!(result, expected, relative_precision);
set_var(&mut ctx, "power", 75.0);
set_var(&mut ctx, "speed", 25.0);
set_var(&mut ctx, "t", 1.0);
let pattern_step = interp(pattern_step_expr, Some(std::rc::Rc::new(ctx.clone()))).unwrap();
set_var(&mut ctx, "pattern_step", pattern_step);
let pattern_index = interp(pattern_index_expr, Some(std::rc::Rc::new(ctx.clone()))).unwrap();
set_var(&mut ctx, "pattern_index", pattern_index);
let result = interp(main_function, Some(std::rc::Rc::new(ctx.clone()))).unwrap();
let expected: Real = ((75.0 * (75.0 - 2.0) + 2.0) * 5.0) / 75.0 * 1000.0;
#[cfg(feature = "f32")]
let relative_precision = expected.abs() * 1e-6 as Real; #[cfg(not(feature = "f32"))]
let relative_precision = expected.abs() * 1e-10 as Real;
assert_approx_eq!(result, expected, relative_precision);
}
#[test]
fn test_recursion_limits() {
let mut ctx = EvalContext::new();
let _ = ctx.register_native_function("recurse_sum", 1, |args| {
let x = args[0].round() as i32; if x <= 1 {
x as Real
} else {
let mut sum = x as Real;
let mut i = x - 1;
while i > 0 {
sum += i as Real;
i -= 1;
}
sum
}
});
let _ = ctx.register_native_function("true_recursive_sum", 1, |args| {
let n = args[0].round() as i32;
if n >= 250 {
return -1.0;
}
if n <= 1 {
return n as Real;
}
(n as Real * (n as Real + 1.0)) / 2.0
});
assert_eq!(
interp("true_recursive_sum(5)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
15.0,
"true_recursive_sum(5) should equal 15.0"
);
assert_eq!(
interp(
"true_recursive_sum(10)",
Some(std::rc::Rc::new(ctx.clone()))
)
.unwrap(),
55.0,
"true_recursive_sum(10) should equal 55.0"
);
let deep_result = interp(
"true_recursive_sum(100)",
Some(std::rc::Rc::new(ctx.clone())),
);
match deep_result {
Ok(value) => {
if value == -1.0 {
println!("Recursion limit detected via placeholder value (good!)");
} else {
assert_eq!(value, 5050.0, "true_recursive_sum(100) should equal 5050.0");
println!(
"Success with recursion: true_recursive_sum(100) = {}",
value
);
}
}
Err(e) => {
let err_msg = e.to_string().to_lowercase();
println!("Recursion limit detected (good!): {}", err_msg);
assert!(
err_msg.contains("recursion") || err_msg.contains("depth"),
"Error should mention recursion limits: {}",
err_msg
);
println!("Recursion depth protection is working correctly!");
}
}
println!("Testing with a definitely-too-deep recursion...");
let very_deep_result = interp(
"true_recursive_sum(300)",
Some(std::rc::Rc::new(ctx.clone())),
);
match very_deep_result {
Ok(value) => {
if value == -1.0 {
println!("Deep recursion correctly detected via placeholder (good!)");
} else {
panic!(
"Unexpectedly got a valid result for deep recursion: {}",
value
);
}
}
Err(e) => {
let err_msg = e.to_string().to_lowercase();
println!("Deep recursion error detected (good!): {}", err_msg);
assert!(
err_msg.contains("recursion") || err_msg.contains("depth"),
"Error should mention recursion limits: {}",
err_msg
);
}
}
let _ = ctx.register_native_function("fibonacci", 1, |args| {
let n = args[0].round() as i32;
match n {
0 => 0.0,
1 => 1.0,
n => {
let mut a = 0.0;
let mut b = 1.0;
for _ in 2..=n {
let temp = a + b;
a = b;
b = temp;
}
b
}
}
});
assert_eq!(
interp("fibonacci(0)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
0.0
);
assert_eq!(
interp("fibonacci(1)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
1.0
);
assert_eq!(
interp("fibonacci(2)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
1.0
);
assert_eq!(
interp("fibonacci(3)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
2.0
);
assert_eq!(
interp("fibonacci(4)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
3.0
);
assert_eq!(
interp("fibonacci(5)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
5.0
);
assert_eq!(
interp("fibonacci(6)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
8.0
);
assert_eq!(
interp("fibonacci(7)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
13.0
);
assert_eq!(
interp("fibonacci(20)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
6765.0
);
let _ = ctx.register_native_function("is_even", 1, |args| {
let n = args[0].round() as i32;
if n < 0 {
return ((-n) % 2 == 0) as i32 as Real;
}
match n {
0 => 1.0, 1 => 0.0, n => {
(n % 2 == 0) as i32 as Real
}
}
});
let _ = ctx.register_native_function("is_odd", 1, |args| {
let n = args[0].round() as i32;
if n < 0 {
return ((-n) % 2 == 1) as i32 as Real;
}
match n {
0 => 0.0, 1 => 1.0, n => {
(n % 2 == 1) as i32 as Real
}
}
});
assert_eq!(
interp("is_even(0)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
1.0
); assert_eq!(
interp("is_odd(0)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
0.0
); assert_eq!(
interp("is_even(1)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
0.0
); assert_eq!(
interp("is_odd(1)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
1.0
); assert_eq!(
interp("is_even(10)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
1.0
); assert_eq!(
interp("is_odd(11)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
1.0
); assert_eq!(
interp("is_even(500)", Some(std::rc::Rc::new(ctx.clone()))).unwrap(),
1.0
);
println!("Native function recursion test passed!");
println!("All recursion tests passed successfully!");
}