1use crate::error::ExprError;
7use crate::eval::iterative::{EvalEngine, eval_with_engine};
8use crate::types::{BatchParamMap, TryIntoHeaplessString};
9use crate::{AstExpr, EvalContext, Real};
10use alloc::rc::Rc;
11use alloc::string::{String, ToString};
12use alloc::vec::Vec;
13use bumpalo::Bump;
14use core::cell::RefCell;
15
16#[derive(Clone, Debug)]
18pub struct Param {
19 pub name: String,
20 pub value: Real,
21}
22
/// A batch of expressions parsed into an arena and evaluated together against
/// a shared set of parameters.
///
/// Source strings and ASTs are allocated in `'arena`; one result slot is kept
/// per expression and refreshed on every `eval`.
pub struct Expression<'arena> {
    /// Arena holding the expression sources, their ASTs and (lazily) the
    /// local function map.
    arena: &'arena Bump,

    /// Source text and parsed AST of each expression, in insertion order.
    expressions: Vec<(&'arena str, &'arena AstExpr<'arena>)>,

    /// Batch parameters; passed to the engine as overrides on each `eval`.
    params: Vec<Param>,

    /// Most recent value of each expression; kept the same length as
    /// `expressions` (both are pushed in `add_expression` and cleared in `clear`).
    results: Vec<Real>,

    /// Evaluation engine reused across `eval` calls.
    engine: EvalEngine<'arena>,

    /// Expression functions visible only to this batch. `None` until the first
    /// `register_expression_function`, which allocates the map in the arena.
    local_functions: Option<&'arena RefCell<crate::types::ExpressionFunctionMap>>,
}
46
/// Former name of [`Expression`], kept as an alias so existing code still compiles.
#[deprecated(since = "0.2.0", note = "renamed to Expression")]
pub type ArenaBatchBuilder<'arena> = Expression<'arena>;
50
51impl<'arena> Expression<'arena> {
52 pub fn new(arena: &'arena Bump) -> Self {
54 Expression {
55 arena,
56 expressions: Vec::new(),
57 params: Vec::new(),
58 results: Vec::new(),
59 engine: EvalEngine::new(arena),
60 local_functions: None,
61 }
62 }
63
64 pub fn add_expression(&mut self, expr: &str) -> Result<usize, ExprError> {
69 let ast = crate::engine::parse_expression(expr, self.arena)?;
71
72 let expr_str = self.arena.alloc_str(expr);
74
75 let arena_ast = self.arena.alloc(ast);
77
78 let idx = self.expressions.len();
79 self.expressions.push((expr_str, arena_ast));
80 self.results.push(0.0); Ok(idx)
82 }
83
84 pub fn add_parameter(&mut self, name: &str, initial_value: Real) -> Result<usize, ExprError> {
89 if self.params.iter().any(|p| p.name == name) {
91 return Err(ExprError::DuplicateParameter(name.to_string()));
92 }
93 let idx = self.params.len();
94 self.params.push(Param {
95 name: name.to_string(),
96 value: initial_value,
97 });
98 Ok(idx)
99 }
100
101 pub fn set_param(&mut self, idx: usize, value: Real) -> Result<(), ExprError> {
103 self.params
104 .get_mut(idx)
105 .ok_or(ExprError::InvalidParameterIndex(idx))?
106 .value = value;
107 Ok(())
108 }
109
110 pub fn set_param_by_name(&mut self, name: &str, value: Real) -> Result<(), ExprError> {
112 self.params
113 .iter_mut()
114 .find(|p| p.name == name)
115 .ok_or_else(|| ExprError::UnknownVariable {
116 name: name.to_string(),
117 })?
118 .value = value;
119 Ok(())
120 }
121
122 pub fn eval(&mut self, base_ctx: &Rc<EvalContext>) -> Result<(), ExprError> {
124 let mut param_map = BatchParamMap::new();
126 for param in &self.params {
127 let hname = param.name.as_str().try_into_heapless()?;
128 param_map
129 .insert(hname, param.value)
130 .map_err(|_| ExprError::CapacityExceeded("parameter overrides"))?;
131 }
132
133 self.engine.set_param_overrides(param_map);
135
136 self.engine.set_local_functions(self.local_functions);
138
139 for (i, (_, ast)) in self.expressions.iter().enumerate() {
141 match eval_with_engine(ast, Some(base_ctx.clone()), &mut self.engine) {
142 Ok(value) => self.results[i] = value,
143 Err(e) => {
144 self.engine.clear_param_overrides();
146 return Err(e);
147 }
148 }
149 }
150
151 self.engine.clear_param_overrides();
153
154 Ok(())
155 }
156
157 pub fn get_result(&self, expr_idx: usize) -> Option<Real> {
159 self.results.get(expr_idx).copied()
160 }
161
162 pub fn get_all_results(&self) -> &[Real] {
164 &self.results
165 }
166
167 pub fn param_count(&self) -> usize {
169 self.params.len()
170 }
171
172 pub fn expression_count(&self) -> usize {
174 self.expressions.len()
175 }
176
177 pub fn register_expression_function(
187 &mut self,
188 name: &str,
189 params: &[&str],
190 body: &str,
191 ) -> Result<(), ExprError> {
192 use crate::types::{ExpressionFunction, ExpressionFunctionMap, TryIntoFunctionName};
193
194 if self.local_functions.is_none() {
196 let map = self.arena.alloc(RefCell::new(ExpressionFunctionMap::new()));
197 self.local_functions = Some(map);
198 }
199
200 let param_buffer = if params.is_empty() {
202 None
203 } else {
204 let slice: &mut [(crate::types::HString, crate::Real)] =
206 self.arena.alloc_slice_fill_default(params.len());
207
208 for (i, param_name) in params.iter().enumerate() {
210 slice[i].0 = param_name.try_into_heapless()?;
211 slice[i].1 = 0.0; }
213
214 Some(slice as *mut _)
215 };
216
217 let func_name = name.try_into_function_name()?;
219 let expr_func = ExpressionFunction {
220 name: func_name.clone(),
221 params: params.iter().map(|s| s.to_string()).collect(),
222 expression: body.to_string(),
223 description: None,
224 param_buffer,
225 };
226
227 self.local_functions
229 .unwrap()
230 .borrow_mut()
231 .insert(func_name, expr_func)
232 .map_err(|_| ExprError::Other("Too many expression functions".to_string()))?;
233 Ok(())
234 }
235
236 pub fn unregister_expression_function(&mut self, name: &str) -> Result<bool, ExprError> {
246 use crate::types::TryIntoFunctionName;
247
248 if let Some(map) = self.local_functions {
249 let func_name = name.try_into_function_name()?;
250 Ok(map.borrow_mut().remove(&func_name).is_some())
251 } else {
252 Ok(false)
253 }
254 }
255
256 pub fn arena_allocated_bytes(&self) -> usize {
258 self.arena.allocated_bytes()
259 }
260
261 pub fn clear(&mut self) {
283 self.expressions.clear();
284 self.params.clear();
285 self.results.clear();
286
287 if let Some(funcs) = self.local_functions {
289 funcs.borrow_mut().clear();
290 }
291 }
292
293 pub fn eval_simple(expr: &str, arena: &'arena Bump) -> Result<Real, ExprError> {
309 let ctx = Rc::new(EvalContext::new());
310 Self::eval_with_context(expr, &ctx, arena)
311 }
312
313 pub fn eval_with_context(
331 expr: &str,
332 ctx: &Rc<EvalContext>,
333 arena: &'arena Bump,
334 ) -> Result<Real, ExprError> {
335 let mut builder = Self::new(arena);
336 builder.add_expression(expr)?;
337 builder.eval(ctx)?;
338 builder
339 .get_result(0)
340 .ok_or(ExprError::Other("No result".to_string()))
341 }
342
343 pub fn eval_with_params(
361 expr: &str,
362 params: &[(&str, Real)],
363 ctx: &Rc<EvalContext>,
364 arena: &'arena Bump,
365 ) -> Result<Real, ExprError> {
366 let mut builder = Self::new(arena);
367
368 for (name, value) in params {
370 builder.add_parameter(name, *value)?;
371 }
372
373 builder.add_expression(expr)?;
374 builder.eval(ctx)?;
375 builder
376 .get_result(0)
377 .ok_or(ExprError::Other("No result".to_string()))
378 }
379
380 pub fn set(&mut self, name: &str, value: Real) -> Result<(), ExprError> {
395 self.set_param_by_name(name, value)
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use bumpalo::Bump;

    #[test]
    fn test_arena_batch_eval_simple() {
        let arena = Bump::new();

        assert_eq!(Expression::eval_simple("2 + 3 * 4", &arena).unwrap(), 14.0);
        assert_eq!(
            Expression::eval_simple("(2 + 3) * 4", &arena).unwrap(),
            20.0
        );
        assert_eq!(Expression::eval_simple("10 / 2 - 3", &arena).unwrap(), 2.0);

        #[cfg(feature = "libm")]
        {
            // Absolute tolerance: a one-sided `a - b < eps` would also accept
            // any result far below the expected constant.
            assert!(
                (Expression::eval_simple("pi", &arena).unwrap() - std::f64::consts::PI).abs()
                    < 0.0001
            );
            assert!(
                (Expression::eval_simple("e", &arena).unwrap() - std::f64::consts::E).abs()
                    < 0.0001
            );
        }
    }

    #[test]
    fn test_arena_batch_eval_with_context() {
        let arena = Bump::new();
        let mut ctx = EvalContext::new();

        let _ = ctx.set_parameter("x", 10.0);
        let _ = ctx.set_parameter("y", 20.0);

        let ctx_rc = Rc::new(ctx);

        assert_eq!(
            Expression::eval_with_context("x + y", &ctx_rc, &arena).unwrap(),
            30.0
        );
        assert_eq!(
            Expression::eval_with_context("x * 2 + y / 2", &ctx_rc, &arena).unwrap(),
            30.0
        );

        #[cfg(feature = "libm")]
        {
            assert_eq!(
                Expression::eval_with_context("sin(0)", &ctx_rc, &arena).unwrap(),
                0.0
            );
            assert_eq!(
                Expression::eval_with_context("cos(0)", &ctx_rc, &arena).unwrap(),
                1.0
            );
        }
    }

    #[test]
    fn test_arena_batch_eval_with_params() {
        let arena = Bump::new();
        let ctx = Rc::new(EvalContext::new());

        let params = [("x", 3.0), ("y", 4.0)];
        assert_eq!(
            Expression::eval_with_params("x + y", &params, &ctx, &arena).unwrap(),
            7.0
        );

        assert_eq!(
            Expression::eval_with_params("x^2 + y^2", &params, &ctx, &arena).unwrap(),
            25.0
        );

        let params3 = [("a", 2.0), ("b", 3.0), ("c", 5.0)];
        assert_eq!(
            Expression::eval_with_params("a * b + c", &params3, &ctx, &arena).unwrap(),
            11.0
        );
    }

    #[test]
    fn test_arena_batch_set_convenience_method() {
        let arena = Bump::new();
        let ctx = Rc::new(EvalContext::new());

        let mut builder = Expression::new(&arena);
        builder.add_parameter("a", 1.0).unwrap();
        builder.add_parameter("b", 2.0).unwrap();
        builder.add_expression("a + b").unwrap();

        builder.eval(&ctx).unwrap();
        assert_eq!(builder.get_result(0), Some(3.0));

        builder.set("a", 5.0).unwrap();
        builder.eval(&ctx).unwrap();
        assert_eq!(builder.get_result(0), Some(7.0));

        builder.set("b", 10.0).unwrap();
        builder.eval(&ctx).unwrap();
        assert_eq!(builder.get_result(0), Some(15.0));

        assert!(builder.set("c", 100.0).is_err());
    }

    #[test]
    fn test_arena_batch_local_expression_functions() {
        let arena = Bump::new();
        let mut builder = Expression::new(&arena);

        builder
            .register_expression_function("double", &["x"], "x * 2")
            .unwrap();
        builder
            .register_expression_function("add_one", &["x"], "x + 1")
            .unwrap();

        builder.add_expression("double(5)").unwrap();
        builder.add_expression("add_one(10)").unwrap();
        builder.add_expression("double(add_one(3))").unwrap();

        let ctx = Rc::new(EvalContext::new());
        builder.eval(&ctx).unwrap();

        assert_eq!(builder.get_result(0), Some(10.0));
        assert_eq!(builder.get_result(1), Some(11.0));
        assert_eq!(builder.get_result(2), Some(8.0));

        // The first removal succeeds; a second removal of the same name reports absence.
        assert!(builder.unregister_expression_function("double").unwrap());
        assert!(!builder.unregister_expression_function("double").unwrap());
    }

    #[test]
    fn test_arena_batch_local_functions() {
        let arena = Bump::new();

        let ctx = Rc::new(EvalContext::new());

        {
            let mut builder = Expression::new(&arena);
            builder
                .register_expression_function("calc", &["x"], "x * 3")
                .unwrap();
            builder.add_expression("calc(5)").unwrap();
            builder.eval(&ctx).unwrap();
            assert_eq!(builder.get_result(0), Some(15.0));
        }
    }
}
566
567impl<'arena> Drop for Expression<'arena> {
570 fn drop(&mut self) {
571 if let Some(funcs) = self.local_functions {
575 funcs.borrow_mut().clear();
576 }
577 }
578}