use crate::error::ExprError;
use crate::eval::iterative::{EvalEngine, eval_with_engine};
use crate::types::{BatchParamMap, TryIntoHeaplessString};
use crate::{AstExpr, EvalContext, Real};
use alloc::rc::Rc;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use bumpalo::Bump;
use core::cell::RefCell;
/// A named parameter of an [`Expression`] batch together with its current value.
#[derive(Clone, Debug, PartialEq)]
pub struct Param {
    /// Name the parameter is referenced by inside the batch's expressions.
    pub name: String,
    /// Value supplied for the parameter on the next evaluation.
    pub value: Real,
}
/// Batch of arena-allocated expressions that are evaluated together against a
/// shared set of named parameters.
///
/// Expressions are added with `add_expression`, parameters with `add_parameter`,
/// and every expression is evaluated in one call to `eval`.
pub struct Expression<'arena> {
/// Arena holding expression source strings, parsed ASTs, expression-function
/// parameter buffers and the local function table.
arena: &'arena Bump,
/// `(source text, parsed AST)` pairs; the position is the index returned by
/// `add_expression`.
expressions: Vec<(&'arena str, &'arena AstExpr<'arena>)>,
/// Parameters installed on the engine as overrides for the duration of `eval`.
params: Vec<Param>,
/// One result slot per entry of `expressions`; `0.0` until evaluated.
results: Vec<Real>,
/// Evaluation engine reused across `eval` calls.
engine: EvalEngine<'arena>,
/// Expression functions registered on this batch only; `None` until the first
/// registration, after which the table lives in the arena. It is emptied in
/// `clear` and in `Drop` — presumably because bumpalo does not run destructors
/// for arena values, so the table's heap-owned contents would otherwise leak.
local_functions: Option<&'arena RefCell<crate::types::ExpressionFunctionMap>>,
}
/// Former name of [`Expression`], kept so existing callers continue to compile.
#[deprecated(since = "0.2.0", note = "renamed to Expression")]
pub type ArenaBatchBuilder<'a> = Expression<'a>;
impl<'arena> Expression<'arena> {
pub fn new(arena: &'arena Bump) -> Self {
Expression {
arena,
expressions: Vec::new(),
params: Vec::new(),
results: Vec::new(),
engine: EvalEngine::new(arena),
local_functions: None,
}
}
pub fn add_expression(&mut self, expr: &str) -> Result<usize, ExprError> {
let ast = crate::engine::parse_expression(expr, self.arena)?;
let expr_str = self.arena.alloc_str(expr);
let arena_ast = self.arena.alloc(ast);
let idx = self.expressions.len();
self.expressions.push((expr_str, arena_ast));
self.results.push(0.0); Ok(idx)
}
pub fn add_parameter(&mut self, name: &str, initial_value: Real) -> Result<usize, ExprError> {
if self.params.iter().any(|p| p.name == name) {
return Err(ExprError::DuplicateParameter(name.to_string()));
}
let idx = self.params.len();
self.params.push(Param {
name: name.to_string(),
value: initial_value,
});
Ok(idx)
}
pub fn set_param(&mut self, idx: usize, value: Real) -> Result<(), ExprError> {
self.params
.get_mut(idx)
.ok_or(ExprError::InvalidParameterIndex(idx))?
.value = value;
Ok(())
}
pub fn set_param_by_name(&mut self, name: &str, value: Real) -> Result<(), ExprError> {
self.params
.iter_mut()
.find(|p| p.name == name)
.ok_or_else(|| ExprError::UnknownVariable {
name: name.to_string(),
})?
.value = value;
Ok(())
}
pub fn eval(&mut self, base_ctx: &Rc<EvalContext>) -> Result<(), ExprError> {
let mut param_map = BatchParamMap::new();
for param in &self.params {
let hname = param.name.as_str().try_into_heapless()?;
param_map
.insert(hname, param.value)
.map_err(|_| ExprError::CapacityExceeded("parameter overrides"))?;
}
self.engine.set_param_overrides(param_map);
self.engine.set_local_functions(self.local_functions);
for (i, (_, ast)) in self.expressions.iter().enumerate() {
match eval_with_engine(ast, Some(base_ctx.clone()), &mut self.engine) {
Ok(value) => self.results[i] = value,
Err(e) => {
self.engine.clear_param_overrides();
return Err(e);
}
}
}
self.engine.clear_param_overrides();
Ok(())
}
pub fn get_result(&self, expr_idx: usize) -> Option<Real> {
self.results.get(expr_idx).copied()
}
pub fn get_all_results(&self) -> &[Real] {
&self.results
}
pub fn param_count(&self) -> usize {
self.params.len()
}
pub fn expression_count(&self) -> usize {
self.expressions.len()
}
pub fn register_expression_function(
&mut self,
name: &str,
params: &[&str],
body: &str,
) -> Result<(), ExprError> {
use crate::types::{ExpressionFunction, ExpressionFunctionMap, TryIntoFunctionName};
if self.local_functions.is_none() {
let map = self.arena.alloc(RefCell::new(ExpressionFunctionMap::new()));
self.local_functions = Some(map);
}
let param_buffer = if params.is_empty() {
None
} else {
let slice: &mut [(crate::types::HString, crate::Real)] =
self.arena.alloc_slice_fill_default(params.len());
for (i, param_name) in params.iter().enumerate() {
slice[i].0 = param_name.try_into_heapless()?;
slice[i].1 = 0.0; }
Some(slice as *mut _)
};
let func_name = name.try_into_function_name()?;
let expr_func = ExpressionFunction {
name: func_name.clone(),
params: params.iter().map(|s| s.to_string()).collect(),
expression: body.to_string(),
description: None,
param_buffer,
};
self.local_functions
.unwrap()
.borrow_mut()
.insert(func_name, expr_func)
.map_err(|_| ExprError::Other("Too many expression functions".to_string()))?;
Ok(())
}
pub fn unregister_expression_function(&mut self, name: &str) -> Result<bool, ExprError> {
use crate::types::TryIntoFunctionName;
if let Some(map) = self.local_functions {
let func_name = name.try_into_function_name()?;
Ok(map.borrow_mut().remove(&func_name).is_some())
} else {
Ok(false)
}
}
pub fn arena_allocated_bytes(&self) -> usize {
self.arena.allocated_bytes()
}
pub fn clear(&mut self) {
self.expressions.clear();
self.params.clear();
self.results.clear();
if let Some(funcs) = self.local_functions {
funcs.borrow_mut().clear();
}
}
pub fn eval_simple(expr: &str, arena: &'arena Bump) -> Result<Real, ExprError> {
let ctx = Rc::new(EvalContext::new());
Self::eval_with_context(expr, &ctx, arena)
}
pub fn eval_with_context(
expr: &str,
ctx: &Rc<EvalContext>,
arena: &'arena Bump,
) -> Result<Real, ExprError> {
let mut builder = Self::new(arena);
builder.add_expression(expr)?;
builder.eval(ctx)?;
builder
.get_result(0)
.ok_or(ExprError::Other("No result".to_string()))
}
pub fn eval_with_params(
expr: &str,
params: &[(&str, Real)],
ctx: &Rc<EvalContext>,
arena: &'arena Bump,
) -> Result<Real, ExprError> {
let mut builder = Self::new(arena);
for (name, value) in params {
builder.add_parameter(name, *value)?;
}
builder.add_expression(expr)?;
builder.eval(ctx)?;
builder
.get_result(0)
.ok_or(ExprError::Other("No result".to_string()))
}
pub fn set(&mut self, name: &str, value: Real) -> Result<(), ExprError> {
self.set_param_by_name(name, value)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bumpalo::Bump;

    #[test]
    fn test_arena_batch_eval_simple() {
        let arena = Bump::new();
        assert_eq!(Expression::eval_simple("2 + 3 * 4", &arena).unwrap(), 14.0);
        assert_eq!(
            Expression::eval_simple("(2 + 3) * 4", &arena).unwrap(),
            20.0
        );
        assert_eq!(Expression::eval_simple("10 / 2 - 3", &arena).unwrap(), 2.0);
        #[cfg(feature = "libm")]
        {
            // Compare the absolute difference: a signed difference would also
            // accept any result smaller than the constant.
            let pi = Expression::eval_simple("pi", &arena).unwrap();
            assert!((pi - std::f64::consts::PI).abs() < 0.0001);
            let e = Expression::eval_simple("e", &arena).unwrap();
            assert!((e - std::f64::consts::E).abs() < 0.0001);
        }
    }

    #[test]
    fn test_arena_batch_eval_with_context() {
        let arena = Bump::new();
        let mut ctx = EvalContext::new();
        let _ = ctx.set_parameter("x", 10.0);
        let _ = ctx.set_parameter("y", 20.0);
        let ctx_rc = Rc::new(ctx);
        assert_eq!(
            Expression::eval_with_context("x + y", &ctx_rc, &arena).unwrap(),
            30.0
        );
        assert_eq!(
            Expression::eval_with_context("x * 2 + y / 2", &ctx_rc, &arena).unwrap(),
            30.0
        );
        #[cfg(feature = "libm")]
        {
            assert_eq!(
                Expression::eval_with_context("sin(0)", &ctx_rc, &arena).unwrap(),
                0.0
            );
            assert_eq!(
                Expression::eval_with_context("cos(0)", &ctx_rc, &arena).unwrap(),
                1.0
            );
        }
    }

    #[test]
    fn test_arena_batch_eval_with_params() {
        let arena = Bump::new();
        let ctx = Rc::new(EvalContext::new());
        let params = [("x", 3.0), ("y", 4.0)];
        assert_eq!(
            Expression::eval_with_params("x + y", &params, &ctx, &arena).unwrap(),
            7.0
        );
        assert_eq!(
            Expression::eval_with_params("x^2 + y^2", &params, &ctx, &arena).unwrap(),
            25.0
        );
        let params3 = [("a", 2.0), ("b", 3.0), ("c", 5.0)];
        assert_eq!(
            Expression::eval_with_params("a * b + c", &params3, &ctx, &arena).unwrap(),
            11.0
        );
    }

    #[test]
    fn test_arena_batch_set_convenience_method() {
        let arena = Bump::new();
        let ctx = Rc::new(EvalContext::new());
        let mut builder = Expression::new(&arena);
        builder.add_parameter("a", 1.0).unwrap();
        builder.add_parameter("b", 2.0).unwrap();
        builder.add_expression("a + b").unwrap();
        builder.eval(&ctx).unwrap();
        assert_eq!(builder.get_result(0), Some(3.0));
        builder.set("a", 5.0).unwrap();
        builder.eval(&ctx).unwrap();
        assert_eq!(builder.get_result(0), Some(7.0));
        builder.set("b", 10.0).unwrap();
        builder.eval(&ctx).unwrap();
        assert_eq!(builder.get_result(0), Some(15.0));
        assert!(builder.set("c", 100.0).is_err());
    }

    #[test]
    fn test_arena_batch_local_expression_functions() {
        let arena = Bump::new();
        let mut builder = Expression::new(&arena);
        builder
            .register_expression_function("double", &["x"], "x * 2")
            .unwrap();
        builder
            .register_expression_function("add_one", &["x"], "x + 1")
            .unwrap();
        builder.add_expression("double(5)").unwrap();
        builder.add_expression("add_one(10)").unwrap();
        builder.add_expression("double(add_one(3))").unwrap();
        let ctx = Rc::new(EvalContext::new());
        builder.eval(&ctx).unwrap();
        assert_eq!(builder.get_result(0), Some(10.0));
        assert_eq!(builder.get_result(1), Some(11.0));
        assert_eq!(builder.get_result(2), Some(8.0));
        assert!(builder.unregister_expression_function("double").unwrap());
        assert!(!builder.unregister_expression_function("double").unwrap());
    }

    #[test]
    fn test_arena_batch_local_functions() {
        let arena = Bump::new();
        let ctx = Rc::new(EvalContext::new());
        // Inner scope drops the builder (clearing its arena-backed function
        // table) while the arena is still alive.
        {
            let mut builder = Expression::new(&arena);
            builder
                .register_expression_function("calc", &["x"], "x * 3")
                .unwrap();
            builder.add_expression("calc(5)").unwrap();
            builder.eval(&ctx).unwrap();
            assert_eq!(builder.get_result(0), Some(15.0));
        }
    }
}
impl<'arena> Drop for Expression<'arena> {
    /// Empties the arena-resident local function table, if one was created.
    ///
    /// The table itself lives in the arena, so it is cleared here rather than
    /// dropped; presumably this releases the heap memory owned by its entries,
    /// since bumpalo does not run destructors for arena values.
    fn drop(&mut self) {
        let Some(table) = self.local_functions else {
            return;
        };
        table.borrow_mut().clear();
    }
}