1extern crate alloc;
7use crate::Real;
8use crate::context::EvalContext;
9use crate::error::Result;
10use crate::eval::eval_ast;
11use crate::types::AstExpr;
12#[cfg(not(test))]
13use alloc::rc::Rc;
14#[cfg(test)]
15use std::rc::Rc;
16use alloc::borrow::Cow;
17use alloc::string::ToString;
18
19pub fn eval_expression_function<'a>(
23 ast: &AstExpr,
24 param_names: &[Cow<'a , str>],
25 arg_values: &[Real],
26 parent_ctx: Option<Rc<EvalContext<'a>>>,
27) -> Result<Real> {
28 let mut temp_ctx = EvalContext::new();
29 if let Some(parent) = parent_ctx {
30 temp_ctx.parent = Some(Rc::clone(&parent));
31 }
32 for (param_name, &arg_val) in param_names.iter().zip(arg_values.iter()) {
33 temp_ctx.variables.insert(param_name.to_string(), arg_val);
34 }
35 eval_ast(ast, Some(Rc::new(temp_ctx)))
36}
37
38
#[cfg(test)]
mod tests {
    //! Tests for expression functions registered on an [`EvalContext`] and
    //! invoked through [`interp`].
    use super::*;
    use crate::constants;
    use crate::engine::interp;
    use crate::assert_approx_eq;
    use crate::Real;

    #[test]
    fn test_simple_expression_function() {
        let mut ctx = EvalContext::new();

        ctx.register_expression_function("double", &["x"], "x * 2")
            .unwrap();

        let result = interp("double(5)", Some(Rc::new(ctx.clone()))).unwrap();
        assert_eq!(result, 10.0);
    }

    #[test]
    fn test_nested_expression_functions() {
        let mut ctx = EvalContext::new();

        ctx.register_expression_function("square", &["x"], "x * x")
            .unwrap();
        // `cube` calls `square`, so evaluation must resolve one registered
        // expression function from inside another's body.
        ctx.register_expression_function("cube", &["x"], "x * square(x)")
            .unwrap();

        assert!(
            ctx.function_registry.expression_functions.contains_key("square"),
            "Context missing 'square' function"
        );
        assert!(
            ctx.function_registry.expression_functions.contains_key("cube"),
            "Context missing 'cube' function"
        );

        crate::engine::parse_expression("cube(3)").expect("cube(3) should parse");

        let result = interp("cube(3)", Some(Rc::new(ctx.clone()))).unwrap_or_else(|e| {
            panic!(
                "Failed to evaluate cube(3): {:?}; expression_functions={:?}",
                e,
                ctx.function_registry.expression_functions.keys().collect::<Vec<_>>()
            )
        });
        assert_eq!(result, 27.0);
    }

    #[test]
    fn test_expression_function_with_multiple_params() {
        let mut ctx = EvalContext::new();

        ctx.register_expression_function("weighted_sum", &["a", "b", "w"], "a * w + b * (1 - w)")
            .unwrap();

        // With the parameters reserved, `w` must parse as a plain variable
        // rather than as the name of a function call.
        let body_ast = crate::engine::parse_expression_with_reserved(
            "a * w + b * (1 - w)",
            Some(&["a".to_string(), "b".to_string(), "w".to_string()]),
        )
        .unwrap();

        // NOTE(review): this walker only descends into `Function` and `Array`
        // nodes; if the body's root is an operator node nothing below it is
        // visited and the check is vacuous — confirm against `AstExpr`.
        fn assert_no_function_w(ast: &AstExpr) {
            match ast {
                AstExpr::Function { name, args } => {
                    assert_ne!(name, "w", "Parameter 'w' should not be parsed as a function");
                    for arg in args {
                        assert_no_function_w(arg);
                    }
                }
                AstExpr::Array { index, .. } => assert_no_function_w(index),
                _ => {}
            }
        }
        assert_no_function_w(&body_ast);

        let w_ast =
            crate::engine::parse_expression_with_reserved("w", Some(&["w".to_string()])).unwrap();
        match w_ast {
            AstExpr::Variable(ref name) => assert_eq!(name, "w"),
            _ => panic!("Expected variable node for 'w'"),
        }

        let w_b_ast =
            crate::engine::parse_expression_with_reserved("w b", Some(&["w".to_string()]));
        assert!(
            w_b_ast.is_err(),
            "Expected parse error for 'w b' when 'w' is a reserved parameter, got {:?}",
            w_b_ast
        );

        let cases: [(&str, Real); 2] = [
            ("weighted_sum(10, 20, 0.3)", 10.0 * 0.3 + 20.0 * 0.7),
            ("weighted_sum(10, 20, 0.7)", 10.0 * 0.7 + 20.0 * 0.3),
        ];
        for &(expr, expected) in cases.iter() {
            let val = interp(expr, Some(Rc::new(ctx.clone())))
                .unwrap_or_else(|e| panic!("Failed to evaluate {}: {:?}", expr, e));
            // The interpreter computes `1 - w` at runtime, which need not round
            // exactly like the literal complementary weight, so compare with a
            // tolerance instead of exact equality.
            assert_approx_eq!(val, expected, constants::TEST_PRECISION);
        }
    }

    #[test]
    fn test_expression_function_with_context_variables() {
        let mut ctx = EvalContext::new();

        ctx.variables.insert("base".to_string().into(), 10.0);
        ctx.constants.insert("FACTOR".to_string().into(), 2.5);

        ctx.register_expression_function("scaled_value", &["x"], "base + x * FACTOR")
            .unwrap();

        // Registering a function must leave existing variables and constants intact.
        assert!(ctx.variables.contains_key("base"), "Context missing 'base' variable");
        assert!(ctx.constants.contains_key("FACTOR"), "Context missing 'FACTOR' constant");

        crate::engine::parse_expression("scaled_value(4)").expect("scaled_value(4) should parse");

        let result = interp("scaled_value(4)", Some(Rc::new(ctx.clone()))).unwrap_or_else(|e| {
            panic!(
                "Failed to evaluate scaled_value(4): {:?}; variables={:?}, constants={:?}",
                e, ctx.variables, ctx.constants
            )
        });
        assert_eq!(result, 10.0 + 4.0 * 2.5);
    }

    /// A recursive definition needs a base case, which requires comparison or
    /// ternary syntax; registration must reject such bodies. The accepted
    /// definitions below are fixed-depth products, not true recursion.
    #[test]
    fn test_recursive_expression_function() {
        let mut ctx = EvalContext::new();

        let rejected = [
            ("factorial", "n <= 1 ? 1 : n * factorial(n - 1)"),
            (
                "factorial",
                "n * (n - 1) * (n - 2) * (n - 3) * (n - 4) + (n <= 4 ? 0 : factorial(n - 5))",
            ),
            ("factorial5", "n <= 1 ? 1 : n * (n - 1) * (n - 2) * (n - 3) * (n - 4) / 24 * 120"),
            ("factorial", "n * (n - 1) * (n - 2) * (n - 3) * (n - 4) * (n <= 5 ? 1 : 120)"),
        ];
        for &(name, body) in rejected.iter() {
            assert!(
                ctx.register_expression_function(name, &["n"], body).is_err(),
                "Should reject expressions with comparison operators or ternary syntax: {}",
                body
            );
        }

        ctx.register_expression_function(
            "factorial",
            &["n"],
            "n * (n - 1) * (n - 2) * (n - 3) * (n - 4)",
        )
        .unwrap();

        let result = interp("factorial(5)", Some(Rc::new(ctx.clone()))).unwrap();
        assert_eq!(result, 120.0);

        ctx.register_expression_function(
            "factorial6",
            &["n"],
            "n * (n - 1) * (n - 2) * (n - 3) * (n - 4) * (n - 5)",
        )
        .unwrap();

        let result2 = interp("factorial6(6)", Some(Rc::new(ctx.clone()))).unwrap();
        assert_eq!(result2, 720.0);
    }

    #[test]
    fn test_expression_function_with_constants() {
        let mut ctx = EvalContext::new();

        ctx.register_expression_function("circle_area", &["radius"], "pi * radius^2")
            .unwrap();

        ctx.register_expression_function("sphere_volume", &["radius"], "(4/3) * pi * radius^3")
            .unwrap();

        let result = interp("circle_area(2)", Some(Rc::new(ctx.clone()))).unwrap();
        assert_approx_eq!(result, constants::PI * 4.0, constants::TEST_PRECISION);

        let result2 = interp("sphere_volume(3)", Some(Rc::new(ctx.clone()))).unwrap();
        let expected = (4.0 / 3.0) * constants::PI * 27.0;
        assert_approx_eq!(result2, expected, constants::TEST_PRECISION);
    }

    #[test]
    fn test_expression_function_error_handling() {
        let mut ctx = EvalContext::new();

        ctx.register_expression_function("safe_divide", &["x", "y"], "x / y")
            .unwrap();

        let result = interp("safe_divide(10, 2)", Some(Rc::new(ctx.clone()))).unwrap();
        assert_eq!(result, 5.0);

        // Division by zero is not reported as an error; the result is infinite.
        let result2 = interp("safe_divide(10, 0)", Some(Rc::new(ctx.clone()))).unwrap();
        assert!(result2.is_infinite(), "Division by zero should return infinity");

        let rejected = ["y == 0 ? 0 : x / y", "x / (y + (y == 0) * 1e-10)"];
        for &body in rejected.iter() {
            assert!(
                ctx.register_expression_function("better_divide", &["x", "y"], body).is_err(),
                "Should reject expressions with comparison operators or ternary syntax: {}",
                body
            );
        }

        // A native `max` lets the body keep the denominator away from zero
        // without any comparison syntax in the expression itself.
        ctx.register_native_function(
            "max",
            2,
            |args| {
                if args[0] > args[1] {
                    args[0]
                } else {
                    args[1]
                }
            },
        );

        ctx.register_expression_function("better_divide", &["x", "y"], "x / max(y, 1e-10)")
            .unwrap();

        let clamped = interp("better_divide(10, 0)", Some(Rc::new(ctx.clone()))).unwrap();
        // 10 / 1e-10 = 1e11. The wide tolerance absorbs f32 rounding of both
        // 1e-10 and the quotient, so one assertion serves f32 and f64 builds.
        assert_approx_eq!(clamped, 1e11 as Real, 1e6 as Real);
    }
}