exp_rs/
expression.rs

1//! Batch expression evaluation builder for efficient real-time evaluation
2//!
3//! This module provides a builder pattern for evaluating multiple expressions
4//! with a shared set of parameters, optimized for real-time use cases.
5
6use crate::error::ExprError;
7use crate::eval::iterative::{EvalEngine, eval_with_engine};
8use crate::types::{BatchParamMap, TryIntoHeaplessString};
9use crate::{AstExpr, EvalContext, Real};
10use alloc::rc::Rc;
11use alloc::string::{String, ToString};
12use alloc::vec::Vec;
13use bumpalo::Bump;
14use core::cell::RefCell;
15
16/// A parameter with its name and current value
17#[derive(Clone, Debug)]
18pub struct Param {
19    pub name: String,
20    pub value: Real,
21}
22
23/// Arena-aware batch builder for zero-allocation expression evaluation
24///
25/// This structure is similar to BatchBuilder but uses an arena for all
26/// AST allocations, eliminating dynamic memory allocation during evaluation.
27pub struct Expression<'arena> {
28    /// The arena for all allocations
29    arena: &'arena Bump,
30
31    /// Pre-parsed expressions with their original strings
32    expressions: Vec<(&'arena str, &'arena AstExpr<'arena>)>,
33
34    /// Parameters with names and values together
35    params: Vec<Param>,
36
37    /// Results for each expression
38    results: Vec<Real>,
39
40    /// Reusable evaluation engine
41    engine: EvalEngine<'arena>,
42
43    /// Optional arena-allocated expression functions (lazy-initialized)
44    local_functions: Option<&'arena RefCell<crate::types::ExpressionFunctionMap>>,
45}
46
47/// Deprecated: Use `Expression` instead
48#[deprecated(since = "0.2.0", note = "renamed to Expression")]
49pub type ArenaBatchBuilder<'arena> = Expression<'arena>;
50
51impl<'arena> Expression<'arena> {
52    /// Create a new empty batch builder with arena
53    pub fn new(arena: &'arena Bump) -> Self {
54        Expression {
55            arena,
56            expressions: Vec::new(),
57            params: Vec::new(),
58            results: Vec::new(),
59            engine: EvalEngine::new(arena),
60            local_functions: None,
61        }
62    }
63
64    /// Add an expression to be evaluated
65    ///
66    /// The expression is parsed immediately into the arena.
67    /// Returns the index of the added expression.
68    pub fn add_expression(&mut self, expr: &str) -> Result<usize, ExprError> {
69        // Parse the expression into the arena
70        let ast = crate::engine::parse_expression(expr, self.arena)?;
71
72        // Allocate expression string in arena
73        let expr_str = self.arena.alloc_str(expr);
74
75        // Allocate the AST in the arena
76        let arena_ast = self.arena.alloc(ast);
77
78        let idx = self.expressions.len();
79        self.expressions.push((expr_str, arena_ast));
80        self.results.push(0.0); // Pre-allocate result slot
81        Ok(idx)
82    }
83
84    /// Add a parameter with an initial value
85    ///
86    /// Returns an error if a parameter with the same name already exists.
87    /// Returns the index of the added parameter.
88    pub fn add_parameter(&mut self, name: &str, initial_value: Real) -> Result<usize, ExprError> {
89        // Check for duplicates
90        if self.params.iter().any(|p| p.name == name) {
91            return Err(ExprError::DuplicateParameter(name.to_string()));
92        }
93        let idx = self.params.len();
94        self.params.push(Param {
95            name: name.to_string(),
96            value: initial_value,
97        });
98        Ok(idx)
99    }
100
101    /// Update a parameter value by index (fastest method)
102    pub fn set_param(&mut self, idx: usize, value: Real) -> Result<(), ExprError> {
103        self.params
104            .get_mut(idx)
105            .ok_or(ExprError::InvalidParameterIndex(idx))?
106            .value = value;
107        Ok(())
108    }
109
110    /// Update a parameter value by name (convenient but slower)
111    pub fn set_param_by_name(&mut self, name: &str, value: Real) -> Result<(), ExprError> {
112        self.params
113            .iter_mut()
114            .find(|p| p.name == name)
115            .ok_or_else(|| ExprError::UnknownVariable {
116                name: name.to_string(),
117            })?
118            .value = value;
119        Ok(())
120    }
121
122    /// Evaluate all expressions with current parameter values
123    pub fn eval(&mut self, base_ctx: &Rc<EvalContext>) -> Result<(), ExprError> {
124        // Build parameter override map
125        let mut param_map = BatchParamMap::new();
126        for param in &self.params {
127            let hname = param.name.as_str().try_into_heapless()?;
128            param_map
129                .insert(hname, param.value)
130                .map_err(|_| ExprError::CapacityExceeded("parameter overrides"))?;
131        }
132
133        // Set parameter overrides in engine
134        self.engine.set_param_overrides(param_map);
135
136        // Set local functions in engine
137        self.engine.set_local_functions(self.local_functions);
138
139        // Evaluate each expression with the original context
140        for (i, (_, ast)) in self.expressions.iter().enumerate() {
141            match eval_with_engine(ast, Some(base_ctx.clone()), &mut self.engine) {
142                Ok(value) => self.results[i] = value,
143                Err(e) => {
144                    // Clear overrides on error
145                    self.engine.clear_param_overrides();
146                    return Err(e);
147                }
148            }
149        }
150
151        // Clear parameter overrides when done
152        self.engine.clear_param_overrides();
153
154        Ok(())
155    }
156
157    /// Get the result of a specific expression by index
158    pub fn get_result(&self, expr_idx: usize) -> Option<Real> {
159        self.results.get(expr_idx).copied()
160    }
161
162    /// Get all results as a slice
163    pub fn get_all_results(&self) -> &[Real] {
164        &self.results
165    }
166
167    /// Get the number of parameters
168    pub fn param_count(&self) -> usize {
169        self.params.len()
170    }
171
172    /// Get the number of expressions
173    pub fn expression_count(&self) -> usize {
174        self.expressions.len()
175    }
176
177    /// Register a local expression function for this batch
178    ///
179    /// Expression functions are mathematical expressions that can call other functions.
180    /// They are specific to this batch and take precedence over context functions.
181    ///
182    /// # Arguments
183    /// * `name` - Function name
184    /// * `params` - Parameter names
185    /// * `body` - Expression string defining the function
186    pub fn register_expression_function(
187        &mut self,
188        name: &str,
189        params: &[&str],
190        body: &str,
191    ) -> Result<(), ExprError> {
192        use crate::types::{ExpressionFunction, ExpressionFunctionMap, TryIntoFunctionName};
193
194        // Lazy initialization - only allocate map when first function is added
195        if self.local_functions.is_none() {
196            let map = self.arena.alloc(RefCell::new(ExpressionFunctionMap::new()));
197            self.local_functions = Some(map);
198        }
199
200        // Pre-allocate parameter buffer in arena for zero-allocation evaluation
201        let param_buffer = if params.is_empty() {
202            None
203        } else {
204            // Pre-allocate parameter slice in arena
205            let slice: &mut [(crate::types::HString, crate::Real)] =
206                self.arena.alloc_slice_fill_default(params.len());
207
208            // Pre-fill parameter names (they never change)
209            for (i, param_name) in params.iter().enumerate() {
210                slice[i].0 = param_name.try_into_heapless()?;
211                slice[i].1 = 0.0; // Default value
212            }
213
214            Some(slice as *mut _)
215        };
216
217        // Create the function
218        let func_name = name.try_into_function_name()?;
219        let expr_func = ExpressionFunction {
220            name: func_name.clone(),
221            params: params.iter().map(|s| s.to_string()).collect(),
222            expression: body.to_string(),
223            description: None,
224            param_buffer,
225        };
226
227        // Add to map through RefCell
228        self.local_functions
229            .unwrap()
230            .borrow_mut()
231            .insert(func_name, expr_func)
232            .map_err(|_| ExprError::Other("Too many expression functions".to_string()))?;
233        Ok(())
234    }
235
236    /// Remove a local expression function from this batch
237    ///
238    /// # Arguments
239    /// * `name` - Function name to remove
240    ///
241    /// # Returns
242    /// * `Ok(true)` if the function was removed
243    /// * `Ok(false)` if the function didn't exist
244    /// * `Err` if the name is invalid
245    pub fn unregister_expression_function(&mut self, name: &str) -> Result<bool, ExprError> {
246        use crate::types::TryIntoFunctionName;
247
248        if let Some(map) = self.local_functions {
249            let func_name = name.try_into_function_name()?;
250            Ok(map.borrow_mut().remove(&func_name).is_some())
251        } else {
252            Ok(false)
253        }
254    }
255
256    /// Get the current number of bytes allocated in the arena
257    pub fn arena_allocated_bytes(&self) -> usize {
258        self.arena.allocated_bytes()
259    }
260
261    /// Clear all expressions, parameters, results, and local functions from this batch
262    ///
263    /// This allows the batch to be reused without recreating it. The arena memory
264    /// used by previous expressions remains allocated but unused until the arena
265    /// is reset. The evaluation engine is retained for reuse.
266    ///
267    /// # Example
268    /// ```
269    /// use bumpalo::Bump;
270    /// use exp_rs::expression::Expression;
271    ///
272    /// let arena = Bump::new();
273    /// let mut batch = Expression::new(&arena);
274    /// batch.add_expression("x + 1").unwrap();
275    /// batch.add_parameter("x", 5.0).unwrap();
276    ///
277    /// // Clear and reuse
278    /// batch.clear();
279    /// assert_eq!(batch.expression_count(), 0);
280    /// assert_eq!(batch.param_count(), 0);
281    /// ```
282    pub fn clear(&mut self) {
283        self.expressions.clear();
284        self.params.clear();
285        self.results.clear();
286
287        // Clear local functions if they exist
288        if let Some(funcs) = self.local_functions {
289            funcs.borrow_mut().clear();
290        }
291    }
292
293    // === Convenience Methods ===
294
295    /// Evaluate a single expression without parameters
296    ///
297    /// This is the simplest way to evaluate an expression that doesn't need variables.
298    ///
299    /// # Example
300    /// ```
301    /// use bumpalo::Bump;
302    /// use exp_rs::expression::Expression;
303    ///
304    /// let arena = Bump::new();
305    /// let result = Expression::eval_simple("2 + 3 * 4", &arena).unwrap();
306    /// assert_eq!(result, 14.0);
307    /// ```
308    pub fn eval_simple(expr: &str, arena: &'arena Bump) -> Result<Real, ExprError> {
309        let ctx = Rc::new(EvalContext::new());
310        Self::eval_with_context(expr, &ctx, arena)
311    }
312
313    /// Evaluate a single expression with context
314    ///
315    /// Use this when you have a context with pre-defined variables, constants, or functions.
316    ///
317    /// # Example
318    /// ```
319    /// use bumpalo::Bump;
320    /// use exp_rs::{expression::Expression, EvalContext};
321    /// use std::rc::Rc;
322    ///
323    /// let arena = Bump::new();
324    /// let mut ctx = EvalContext::new();
325    /// ctx.set_parameter("x", 5.0);
326    ///
327    /// let result = Expression::eval_with_context("x * 2", &Rc::new(ctx), &arena).unwrap();
328    /// assert_eq!(result, 10.0);
329    /// ```
330    pub fn eval_with_context(
331        expr: &str,
332        ctx: &Rc<EvalContext>,
333        arena: &'arena Bump,
334    ) -> Result<Real, ExprError> {
335        let mut builder = Self::new(arena);
336        builder.add_expression(expr)?;
337        builder.eval(ctx)?;
338        builder
339            .get_result(0)
340            .ok_or(ExprError::Other("No result".to_string()))
341    }
342
343    /// Evaluate a single expression with parameters
344    ///
345    /// This is convenient when you want to provide parameters inline without creating a context.
346    ///
347    /// # Example
348    /// ```
349    /// use bumpalo::Bump;
350    /// use exp_rs::{expression::Expression, EvalContext};
351    /// use std::rc::Rc;
352    ///
353    /// let arena = Bump::new();
354    /// let params = [("x", 3.0), ("y", 4.0)];
355    /// let ctx = Rc::new(EvalContext::new());
356    ///
357    /// let result = Expression::eval_with_params("x^2 + y^2", &params, &ctx, &arena).unwrap();
358    /// assert_eq!(result, 25.0); // 3^2 + 4^2 = 25
359    /// ```
360    pub fn eval_with_params(
361        expr: &str,
362        params: &[(&str, Real)],
363        ctx: &Rc<EvalContext>,
364        arena: &'arena Bump,
365    ) -> Result<Real, ExprError> {
366        let mut builder = Self::new(arena);
367
368        // Add all parameters
369        for (name, value) in params {
370            builder.add_parameter(name, *value)?;
371        }
372
373        builder.add_expression(expr)?;
374        builder.eval(ctx)?;
375        builder
376            .get_result(0)
377            .ok_or(ExprError::Other("No result".to_string()))
378    }
379
380    /// Convenience setter using string slices
381    ///
382    /// This is an alias for set_param_by_name with a shorter name for convenience.
383    ///
384    /// # Example
385    /// ```
386    /// use bumpalo::Bump;
387    /// use exp_rs::expression::Expression;
388    ///
389    /// let arena = Bump::new();
390    /// let mut builder = Expression::new(&arena);
391    /// builder.add_parameter("x", 0.0).unwrap();
392    /// builder.set("x", 5.0).unwrap();
393    /// ```
394    pub fn set(&mut self, name: &str, value: Real) -> Result<(), ExprError> {
395        self.set_param_by_name(name, value)
396    }
397}
398
399#[cfg(test)]
400mod tests {
401    use super::*;
402    use bumpalo::Bump;
403
404    // === Tests for Expression Convenience Methods ===
405
406    #[test]
407    fn test_arena_batch_eval_simple() {
408        let arena = Bump::new();
409
410        // Test basic arithmetic
411        assert_eq!(Expression::eval_simple("2 + 3 * 4", &arena).unwrap(), 14.0);
412        assert_eq!(
413            Expression::eval_simple("(2 + 3) * 4", &arena).unwrap(),
414            20.0
415        );
416        assert_eq!(Expression::eval_simple("10 / 2 - 3", &arena).unwrap(), 2.0);
417
418        // Test with constants
419        #[cfg(feature = "libm")]
420        {
421            assert!(Expression::eval_simple("pi", &arena).unwrap() - std::f64::consts::PI < 0.0001);
422            assert!(Expression::eval_simple("e", &arena).unwrap() - std::f64::consts::E < 0.0001);
423        }
424    }
425
426    #[test]
427    fn test_arena_batch_eval_with_context() {
428        let arena = Bump::new();
429        let mut ctx = EvalContext::new();
430
431        // Add some variables to context
432        let _ = ctx.set_parameter("x", 10.0);
433        let _ = ctx.set_parameter("y", 20.0);
434
435        let ctx_rc = Rc::new(ctx);
436
437        // Test evaluation with context variables
438        assert_eq!(
439            Expression::eval_with_context("x + y", &ctx_rc, &arena).unwrap(),
440            30.0
441        );
442        assert_eq!(
443            Expression::eval_with_context("x * 2 + y / 2", &ctx_rc, &arena).unwrap(),
444            30.0
445        );
446
447        // Test with functions if available
448        #[cfg(feature = "libm")]
449        {
450            assert_eq!(
451                Expression::eval_with_context("sin(0)", &ctx_rc, &arena).unwrap(),
452                0.0
453            );
454            assert_eq!(
455                Expression::eval_with_context("cos(0)", &ctx_rc, &arena).unwrap(),
456                1.0
457            );
458        }
459    }
460
461    #[test]
462    fn test_arena_batch_eval_with_params() {
463        let arena = Bump::new();
464        let ctx = Rc::new(EvalContext::new());
465
466        // Test with simple parameters
467        let params = [("x", 3.0), ("y", 4.0)];
468        assert_eq!(
469            Expression::eval_with_params("x + y", &params, &ctx, &arena).unwrap(),
470            7.0
471        );
472
473        // Test with complex expression
474        assert_eq!(
475            Expression::eval_with_params("x^2 + y^2", &params, &ctx, &arena).unwrap(),
476            25.0
477        );
478
479        // Test with multiple parameters
480        let params3 = [("a", 2.0), ("b", 3.0), ("c", 5.0)];
481        assert_eq!(
482            Expression::eval_with_params("a * b + c", &params3, &ctx, &arena).unwrap(),
483            11.0
484        );
485    }
486
487    #[test]
488    fn test_arena_batch_set_convenience_method() {
489        let arena = Bump::new();
490        let ctx = Rc::new(EvalContext::new());
491
492        let mut builder = Expression::new(&arena);
493        builder.add_parameter("a", 1.0).unwrap();
494        builder.add_parameter("b", 2.0).unwrap();
495        builder.add_expression("a + b").unwrap();
496
497        // Test initial evaluation
498        builder.eval(&ctx).unwrap();
499        assert_eq!(builder.get_result(0), Some(3.0));
500
501        // Test using set method
502        builder.set("a", 5.0).unwrap();
503        builder.eval(&ctx).unwrap();
504        assert_eq!(builder.get_result(0), Some(7.0));
505
506        builder.set("b", 10.0).unwrap();
507        builder.eval(&ctx).unwrap();
508        assert_eq!(builder.get_result(0), Some(15.0));
509
510        // Test error on unknown parameter
511        assert!(builder.set("c", 100.0).is_err());
512    }
513
514    #[test]
515    fn test_arena_batch_local_expression_functions() {
516        let arena = Bump::new();
517        let mut builder = Expression::new(&arena);
518
519        // Register a local function
520        builder
521            .register_expression_function("double", &["x"], "x * 2")
522            .unwrap();
523        builder
524            .register_expression_function("add_one", &["x"], "x + 1")
525            .unwrap();
526
527        // Use the functions in expressions
528        builder.add_expression("double(5)").unwrap();
529        builder.add_expression("add_one(10)").unwrap();
530        builder.add_expression("double(add_one(3))").unwrap(); // Nested
531
532        // Evaluate
533        let ctx = Rc::new(EvalContext::new());
534        builder.eval(&ctx).unwrap();
535
536        // Check results
537        assert_eq!(builder.get_result(0), Some(10.0)); // double(5) = 10
538        assert_eq!(builder.get_result(1), Some(11.0)); // add_one(10) = 11
539        assert_eq!(builder.get_result(2), Some(8.0)); // double(add_one(3)) = double(4) = 8
540
541        // Test removing a function
542        assert!(builder.unregister_expression_function("double").unwrap());
543        assert!(!builder.unregister_expression_function("double").unwrap()); // Already removed
544    }
545
546    #[test]
547    fn test_arena_batch_local_functions() {
548        let arena = Bump::new();
549
550        // Create basic context
551        let ctx = Rc::new(EvalContext::new());
552
553        // Test: Local arena function
554        {
555            let mut builder = Expression::new(&arena);
556            // Register local function
557            builder
558                .register_expression_function("calc", &["x"], "x * 3")
559                .unwrap();
560            builder.add_expression("calc(5)").unwrap();
561            builder.eval(&ctx).unwrap();
562            assert_eq!(builder.get_result(0), Some(15.0)); // x * 3 = 15
563        }
564    }
565}
566
567// Implement Drop to manually free heap-allocated strings in ExpressionFunction objects
568// This prevents memory leaks when the batch contains expression functions
569impl<'arena> Drop for Expression<'arena> {
570    fn drop(&mut self) {
571        // Manually clear local functions to ensure String objects are dropped
572        // This is important because if the arena is dropped before this builder,
573        // the String objects won't get their destructors called
574        if let Some(funcs) = self.local_functions {
575            funcs.borrow_mut().clear();
576        }
577    }
578}