Skip to main content

oximo_expr/
eval.rs

1use thiserror::Error;
2
3use crate::arena::{ExprArena, ExprId, ExprNode, ParamId, VarId};
4
5#[derive(Debug, Error)]
6pub enum EvalError {
7    #[error("variable {0:?} has no value bound in the evaluation context")]
8    UnboundVar(VarId),
9    #[error("parameter {0:?} has no value bound in the evaluation context")]
10    UnboundParam(ParamId),
11}
12
13/// Source of variable and parameter values during expression evaluation.
14pub trait EvalContext {
15    fn var(&self, v: VarId) -> Option<f64>;
16    fn param(&self, p: ParamId) -> Option<f64>;
17}
18
19impl EvalContext for &[f64] {
20    fn var(&self, v: VarId) -> Option<f64> {
21        self.get(v.index()).copied()
22    }
23    fn param(&self, _p: ParamId) -> Option<f64> {
24        None
25    }
26}
27
28/// Evaluate `id` to an `f64`, pulling variable / parameter values from `ctx`.
29///
30/// # Errors
31///
32/// Returns an [`EvalError`] if a needed variable or parameter is missing from the context.
33pub fn evaluate<C: EvalContext>(arena: &ExprArena, id: ExprId, ctx: &C) -> Result<f64, EvalError> {
34    Ok(match arena.get(id) {
35        ExprNode::Const(c) => *c,
36        ExprNode::Var(v) => ctx.var(*v).ok_or(EvalError::UnboundVar(*v))?,
37        ExprNode::Param(p) => ctx
38            .param(*p)
39            .or_else(|| arena.try_param_value(*p))
40            .ok_or(EvalError::UnboundParam(*p))?,
41        ExprNode::Add(children) => children
42            .iter()
43            .try_fold(0.0, |acc, c| Ok::<_, EvalError>(acc + evaluate(arena, *c, ctx)?))?,
44        ExprNode::Mul(children) => children
45            .iter()
46            .try_fold(1.0, |acc, c| Ok::<_, EvalError>(acc * evaluate(arena, *c, ctx)?))?,
47        ExprNode::Neg(inner) => -evaluate(arena, *inner, ctx)?,
48        ExprNode::Pow(base, exp) => evaluate(arena, *base, ctx)?.powf(evaluate(arena, *exp, ctx)?),
49        ExprNode::Div(num, den) => evaluate(arena, *num, ctx)? / evaluate(arena, *den, ctx)?,
50        ExprNode::Sin(inner) => evaluate(arena, *inner, ctx)?.sin(),
51        ExprNode::Cos(inner) => evaluate(arena, *inner, ctx)?.cos(),
52        ExprNode::Exp(inner) => evaluate(arena, *inner, ctx)?.exp(),
53        ExprNode::Log(inner) => evaluate(arena, *inner, ctx)?.ln(),
54        ExprNode::Abs(inner) => evaluate(arena, *inner, ctx)?.abs(),
55        ExprNode::Linear { coeffs, constant } => {
56            let mut acc = *constant;
57            for (v, c) in coeffs {
58                acc += c * ctx.var(*v).ok_or(EvalError::UnboundVar(*v))?;
59            }
60            acc
61        }
62    })
63}