Skip to main content

oximo_expr/
eval.rs

1use thiserror::Error;
2
3use crate::arena::{ExprArena, ExprId, ExprNode, ParamId, VarId};
4
5#[derive(Debug, Error)]
6pub enum EvalError {
7    #[error("variable {0:?} has no value bound in the evaluation context")]
8    UnboundVar(VarId),
9    #[error("parameter {0:?} has no value bound in the evaluation context")]
10    UnboundParam(ParamId),
11}
12
13/// Source of variable and parameter values during expression evaluation.
14pub trait EvalContext {
15    fn var(&self, v: VarId) -> Option<f64>;
16    fn param(&self, p: ParamId) -> Option<f64>;
17}
18
19impl EvalContext for &[f64] {
20    fn var(&self, v: VarId) -> Option<f64> {
21        self.get(v.index()).copied()
22    }
23    fn param(&self, _p: ParamId) -> Option<f64> {
24        None
25    }
26}
27
28/// Evaluate `id` to an `f64`, pulling variable / parameter values from `ctx`.
29///
30/// # Errors
31///
32/// Returns an [`EvalError`] if a needed variable or parameter is missing from the context.
33pub fn evaluate<C: EvalContext>(arena: &ExprArena, id: ExprId, ctx: &C) -> Result<f64, EvalError> {
34    Ok(match arena.get(id) {
35        ExprNode::Const(c) => *c,
36        ExprNode::Var(v) => ctx.var(*v).ok_or(EvalError::UnboundVar(*v))?,
37        ExprNode::Param(p) => ctx.param(*p).ok_or(EvalError::UnboundParam(*p))?,
38        ExprNode::Add(children) => children
39            .iter()
40            .try_fold(0.0, |acc, c| Ok::<_, EvalError>(acc + evaluate(arena, *c, ctx)?))?,
41        ExprNode::Mul(children) => children
42            .iter()
43            .try_fold(1.0, |acc, c| Ok::<_, EvalError>(acc * evaluate(arena, *c, ctx)?))?,
44        ExprNode::Neg(inner) => -evaluate(arena, *inner, ctx)?,
45        ExprNode::Pow(base, exp) => evaluate(arena, *base, ctx)?.powf(evaluate(arena, *exp, ctx)?),
46        ExprNode::Sin(inner) => evaluate(arena, *inner, ctx)?.sin(),
47        ExprNode::Cos(inner) => evaluate(arena, *inner, ctx)?.cos(),
48        ExprNode::Exp(inner) => evaluate(arena, *inner, ctx)?.exp(),
49        ExprNode::Log(inner) => evaluate(arena, *inner, ctx)?.ln(),
50        ExprNode::Linear { coeffs, constant } => {
51            let mut acc = *constant;
52            for (v, c) in coeffs {
53                acc += c * ctx.var(*v).ok_or(EvalError::UnboundVar(*v))?;
54            }
55            acc
56        }
57    })
58}