exp_rs/
expression_functions.rs

1//! Expression functions implementation for the exp-rs library.
2//!
3//! This module provides functionality for defining and evaluating functions
4//! that are specified as expression strings rather than native Rust code.
5
6extern crate alloc;
7use crate::Real;
8use crate::context::EvalContext;
9use crate::error::Result;
10use crate::eval::eval_ast;
11use crate::types::AstExpr;
12#[cfg(not(test))]
13use alloc::rc::Rc;
14#[cfg(test)]
15use std::rc::Rc;
16use alloc::borrow::Cow;
17use alloc::string::ToString;
18
19/// Evaluates an expression function with the given arguments.
20///
21/// This is a helper function used internally by the evaluation logic.
22pub fn eval_expression_function<'a>(
23    ast: &AstExpr,
24    param_names: &[Cow<'a , str>],
25    arg_values: &[Real],
26    parent_ctx: Option<Rc<EvalContext<'a>>>,
27) -> Result<Real> {
28    let mut temp_ctx = EvalContext::new();
29    if let Some(parent) = parent_ctx {
30        temp_ctx.parent = Some(Rc::clone(&parent));
31    }
32    for (param_name, &arg_val) in param_names.iter().zip(arg_values.iter()) {
33        temp_ctx.variables.insert(param_name.to_string(), arg_val);
34    }
35    eval_ast(ast, Some(Rc::new(temp_ctx)))
36}
37
38
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants;
    use crate::engine::interp;
    // `assert_approx_eq!` lives at the crate root; bring it into this module's scope.
    use crate::assert_approx_eq;
    // Named explicitly because expected values below are annotated or cast to `Real`.
    use crate::Real;

    /// A single-parameter expression function can be registered and then called.
    #[test]
    fn test_simple_expression_function() {
        let mut ctx = EvalContext::new();
        ctx.register_expression_function("double", &["x"], "x * 2")
            .unwrap();

        let result = interp("double(5)", Some(Rc::new(ctx))).unwrap();
        assert_eq!(result, 10.0);
    }

    /// An expression function may call another expression function registered
    /// in the same context.
    #[test]
    fn test_nested_expression_functions() {
        let mut ctx = EvalContext::new();
        ctx.register_expression_function("square", &["x"], "x * x")
            .unwrap();
        ctx.register_expression_function("cube", &["x"], "x * square(x)")
            .unwrap();

        assert!(
            ctx.function_registry.expression_functions.contains_key("square"),
            "Context missing 'square' function"
        );
        assert!(
            ctx.function_registry.expression_functions.contains_key("cube"),
            "Context missing 'cube' function"
        );
        assert!(
            crate::engine::parse_expression("cube(3)").is_ok(),
            "cube(3) should parse"
        );

        // Share the context through one `Rc` so it is still available for the
        // failure message without cloning the whole context.
        let ctx = Rc::new(ctx);
        let result = interp("cube(3)", Some(Rc::clone(&ctx))).unwrap_or_else(|e| {
            panic!(
                "Failed to evaluate cube(3): {:?} (registered expression functions: {:?})",
                e,
                ctx.function_registry
                    .expression_functions
                    .keys()
                    .collect::<Vec<_>>()
            )
        });
        assert_eq!(result, 27.0);
    }

    /// Multi-parameter functions: parameters are parsed as variables (never as
    /// function calls) and are bound positionally at call time.
    #[test]
    fn test_expression_function_with_multiple_params() {
        /// Fails if any function node reachable through function arguments or
        /// array indices is named `w`.
        fn assert_no_function_w(ast: &AstExpr) {
            match ast {
                AstExpr::Function { name, args } => {
                    assert_ne!(name, "w", "Parameter 'w' should not be parsed as a function");
                    for arg in args {
                        assert_no_function_w(arg);
                    }
                }
                AstExpr::Array { index, .. } => assert_no_function_w(index),
                _ => {}
            }
        }

        let mut ctx = EvalContext::new();
        ctx.register_expression_function("weighted_sum", &["a", "b", "w"], "a * w + b * (1 - w)")
            .unwrap();

        // With the parameter names reserved, `w` followed by `(` must not be
        // treated as a call to a function named `w`.
        let body_ast = crate::engine::parse_expression_with_reserved(
            "a * w + b * (1 - w)",
            Some(&["a".to_string(), "b".to_string(), "w".to_string()]),
        )
        .unwrap();
        assert_no_function_w(&body_ast);

        // A bare reserved name parses to a plain variable node.
        let w_ast =
            crate::engine::parse_expression_with_reserved("w", Some(&["w".to_string()])).unwrap();
        match &w_ast {
            AstExpr::Variable(name) => assert_eq!(name, "w"),
            other => panic!("Expected variable node for 'w', got {:?}", other),
        }

        // Juxtaposing a reserved name with another identifier is a parse error.
        let w_b_ast =
            crate::engine::parse_expression_with_reserved("w b", Some(&["w".to_string()]));
        assert!(
            w_b_ast.is_err(),
            "Expected parse error for 'w b' when 'w' is a reserved parameter"
        );

        let ctx = Rc::new(ctx);

        // 0.3 and 0.7 are not exactly representable and the function body computes
        // `1 - w` rather than using the literal, so compare with a tolerance
        // instead of exact float equality.
        let expected1: Real = 10.0 * 0.3 + 20.0 * 0.7;
        let result1 = interp("weighted_sum(10, 20, 0.3)", Some(Rc::clone(&ctx)))
            .unwrap_or_else(|e| panic!("Failed to evaluate weighted_sum(10, 20, 0.3): {:?}", e));
        assert_approx_eq!(result1, expected1, constants::TEST_PRECISION);

        let expected2: Real = 10.0 * 0.7 + 20.0 * 0.3;
        let result2 = interp("weighted_sum(10, 20, 0.7)", Some(Rc::clone(&ctx)))
            .unwrap_or_else(|e| panic!("Failed to evaluate weighted_sum(10, 20, 0.7): {:?}", e));
        assert_approx_eq!(result2, expected2, constants::TEST_PRECISION);
    }

    /// A function body can read variables and constants that live in the
    /// context it was registered in, alongside its own parameters.
    #[test]
    fn test_expression_function_with_context_variables() {
        let mut ctx = EvalContext::new();
        ctx.variables.insert("base".to_string().into(), 10.0);
        ctx.constants.insert("FACTOR".to_string().into(), 2.5);

        ctx.register_expression_function("scaled_value", &["x"], "base + x * FACTOR")
            .unwrap();

        // Registering the function must leave the existing bindings in place.
        assert!(ctx.variables.contains_key("base"), "Context missing 'base' variable");
        assert!(ctx.constants.contains_key("FACTOR"), "Context missing 'FACTOR' constant");
        assert!(
            crate::engine::parse_expression("scaled_value(4)").is_ok(),
            "scaled_value(4) should parse"
        );

        let ctx = Rc::new(ctx);
        let result = interp("scaled_value(4)", Some(Rc::clone(&ctx))).unwrap_or_else(|e| {
            panic!(
                "Failed to evaluate scaled_value(4): {:?} (variables={:?}, constants={:?}, expression_functions={:?})",
                e,
                ctx.variables,
                ctx.constants,
                ctx.function_registry
                    .expression_functions
                    .keys()
                    .collect::<Vec<_>>()
            )
        });
        assert_eq!(result, 10.0 + 4.0 * 2.5);
    }

    /// Recursive factorial definitions need a base case, which here is written
    /// with comparison/ternary syntax. This test asserts that registration of
    /// such bodies fails, then exercises closed-form products that do register.
    #[test]
    fn test_recursive_expression_function() {
        let mut ctx = EvalContext::new();

        // Every body below uses `<=` together with `?:`; registration is
        // expected to return an error for each of them.
        let rejected_bodies = [
            ("factorial", "n <= 1 ? 1 : n * factorial(n - 1)"),
            (
                "factorial",
                "n * (n - 1) * (n - 2) * (n - 3) * (n - 4) + (n <= 4 ? 0 : factorial(n - 5))",
            ),
            (
                "factorial5",
                "n <= 1 ? 1 : n * (n - 1) * (n - 2) * (n - 3) * (n - 4) / 24 * 120",
            ),
            (
                "factorial",
                "n * (n - 1) * (n - 2) * (n - 3) * (n - 4) * (n <= 5 ? 1 : 120)",
            ),
        ];
        for &(name, body) in rejected_bodies.iter() {
            assert!(
                ctx.register_expression_function(name, &["n"], body).is_err(),
                "Should reject expressions with comparison operators and ternary syntax: {}",
                body
            );
        }

        // A fixed five-term product; it equals n! only for n = 5.
        ctx.register_expression_function(
            "factorial",
            &["n"],
            "n * (n - 1) * (n - 2) * (n - 3) * (n - 4)",
        )
        .unwrap();

        // `ctx` is mutated again below, so evaluate against a snapshot of it.
        let result = interp("factorial(5)", Some(Rc::new(ctx.clone()))).unwrap();
        assert_eq!(result, 120.0); // 5! = 120

        // A fixed six-term product; it equals n! only for n = 6.
        ctx.register_expression_function(
            "factorial6",
            &["n"],
            "n * (n - 1) * (n - 2) * (n - 3) * (n - 4) * (n - 5)",
        )
        .unwrap();

        let result2 = interp("factorial6(6)", Some(Rc::new(ctx))).unwrap();
        assert_eq!(result2, 720.0); // 6! = 720
    }

    /// Function bodies can reference the constant `pi` without registering it
    /// in the test context.
    #[test]
    fn test_expression_function_with_constants() {
        let mut ctx = EvalContext::new();

        // Area of a circle.
        ctx.register_expression_function("circle_area", &["radius"], "pi * radius^2")
            .unwrap();
        // Volume of a sphere.
        ctx.register_expression_function("sphere_volume", &["radius"], "(4/3) * pi * radius^3")
            .unwrap();

        let ctx = Rc::new(ctx);

        let area = interp("circle_area(2)", Some(Rc::clone(&ctx))).unwrap();
        assert_approx_eq!(area, constants::PI * 4.0, constants::TEST_PRECISION);

        let volume = interp("sphere_volume(3)", Some(Rc::clone(&ctx))).unwrap();
        let expected = (4.0 / 3.0) * constants::PI * 27.0;
        assert_approx_eq!(volume, expected, constants::TEST_PRECISION);
    }

    /// Division by zero inside an expression function yields infinity rather
    /// than an error, and a native helper can be used to guard the denominator.
    #[test]
    fn test_expression_function_error_handling() {
        let mut ctx = EvalContext::new();

        ctx.register_expression_function("safe_divide", &["x", "y"], "x / y")
            .unwrap();

        // `ctx` is mutated again below, so these evaluations run on snapshots.
        let quotient = interp("safe_divide(10, 2)", Some(Rc::new(ctx.clone()))).unwrap();
        assert_eq!(quotient, 5.0);

        let by_zero = interp("safe_divide(10, 0)", Some(Rc::new(ctx.clone()))).unwrap();
        assert!(
            by_zero.is_infinite(),
            "Division by zero should return infinity"
        );

        // Guarding the denominator with `==` (with or without `?:`) is expected
        // to be rejected at registration time.
        let rejected_bodies = ["y == 0 ? 0 : x / y", "x / (y + (y == 0) * 1e-10)"];
        for &body in rejected_bodies.iter() {
            assert!(
                ctx.register_expression_function("better_divide", &["x", "y"], body)
                    .is_err(),
                "Should reject expressions with comparison operators: {}",
                body
            );
        }

        // Provide `max` as a native two-argument function for the guard below.
        ctx.register_native_function(
            "max",
            2,
            |args| {
                if args[0] > args[1] {
                    args[0]
                } else {
                    args[1]
                }
            },
        );

        // Clamp the denominator from below instead of branching on it.
        ctx.register_expression_function("better_divide", &["x", "y"], "x / max(y, 1e-10)")
            .unwrap();

        // max(0, 1e-10) is 1e-10, so the result is 10 / 1e-10 = 1e11. The same
        // tolerance applies whether or not the `f32` feature is enabled.
        let guarded = interp("better_divide(10, 0)", Some(Rc::new(ctx))).unwrap();
        assert_approx_eq!(guarded, 1e11 as Real, 1e6 as Real);
    }
}