Skip to main content

arael_sym/
eval.rs

1use std::collections::{BTreeSet, HashMap};
2use super::{Expr, E};
3
4impl Expr {
5    /// Evaluate the expression numerically given variable bindings.
6    ///
7    /// Returns `Err` if any symbol in the expression is not bound in `vars`.
8    pub fn eval(&self, vars: &HashMap<&str, f64>) -> Result<f64, String> {
9        match self {
10            Expr::Sym(name) => {
11                vars.get(name.as_str()).copied()
12                    .ok_or_else(|| format!("unbound symbol: {name}"))
13            }
14            Expr::Const(v) => Ok(*v),
15            Expr::NamedConst { value, .. } => Ok(*value),
16            Expr::Neg(a) => Ok(-a.eval(vars)?),
17            Expr::Add(a, b) => Ok(a.eval(vars)? + b.eval(vars)?),
18            Expr::Sub(a, b) => Ok(a.eval(vars)? - b.eval(vars)?),
19            Expr::Mul(a, b) => Ok(a.eval(vars)? * b.eval(vars)?),
20            Expr::Div(a, b) => Ok(a.eval(vars)? / b.eval(vars)?),
21            Expr::Pow(a, b) => Ok(a.eval(vars)?.powf(b.eval(vars)?)),
22            Expr::Sin(a) => Ok(a.eval(vars)?.sin()),
23            Expr::Cos(a) => Ok(a.eval(vars)?.cos()),
24            Expr::Tan(a) => Ok(a.eval(vars)?.tan()),
25            Expr::Asin(a) => Ok(a.eval(vars)?.asin()),
26            Expr::Acos(a) => Ok(a.eval(vars)?.acos()),
27            Expr::Atan(a) => Ok(a.eval(vars)?.atan()),
28            Expr::Atan2(y, x) => Ok(y.eval(vars)?.atan2(x.eval(vars)?)),
29            Expr::Sinh(a) => Ok(a.eval(vars)?.sinh()),
30            Expr::Cosh(a) => Ok(a.eval(vars)?.cosh()),
31            Expr::Tanh(a) => Ok(a.eval(vars)?.tanh()),
32            Expr::Exp(a) => Ok(a.eval(vars)?.exp()),
33            Expr::Ln(a) => Ok(a.eval(vars)?.ln()),
34            Expr::Log2(a) => Ok(a.eval(vars)?.log2()),
35            Expr::Log10(a) => Ok(a.eval(vars)?.log10()),
36            Expr::Sqrt(a) => Ok(a.eval(vars)?.sqrt()),
37            Expr::Abs(a) => Ok(a.eval(vars)?.abs()),
38            Expr::Heaviside(a) => {
39                let v = a.eval(vars)?;
40                Ok(if v < 0.0 { 0.0 } else { 1.0 })
41            }
42            Expr::Clamp(val, lo, hi) => {
43                let v = val.eval(vars)?;
44                let l = lo.eval(vars)?;
45                let h = hi.eval(vars)?;
46                Ok(v.clamp(l, h))
47            }
48            Expr::Func { params, kind, args, .. } => {
49                if let Some(f) = kind.eval_fn() {
50                    let vals: Result<Vec<f64>, _> = args.iter().map(|a| a.eval(vars)).collect();
51                    Ok(f(&vals?))
52                } else {
53                    let body = kind.body().expect("FuncKind must have body or eval_fn");
54                    super::expand_func(params, body, args).eval(vars)
55                }
56            }
57        }
58    }
59
60    /// Substitute all occurrences of the named variable with `replacement`.
61    ///
62    /// `var` can be any [`AsVarName`] -- a `&str`, a `String`, or an
63    /// [`E`] handle wrapping a `Sym` node. Returns a new expression
64    /// with the substitution applied throughout.
65    pub fn subs(&self, var: impl crate::AsVarName, replacement: &E) -> E {
66        self.subs_by_name(var.var_name(), replacement)
67    }
68
69    fn subs_by_name(&self, var: &str, replacement: &E) -> E {
70        match self {
71            Expr::Sym(name) if name == var => replacement.clone(),
72            Expr::Sym(_) | Expr::Const(_) | Expr::NamedConst { .. } => E::new(self.clone()),
73            Expr::Neg(a) => -a.subs_by_name(var, replacement),
74            Expr::Add(a, b) => a.subs_by_name(var, replacement) + b.subs_by_name(var, replacement),
75            Expr::Sub(a, b) => a.subs_by_name(var, replacement) - b.subs_by_name(var, replacement),
76            Expr::Mul(a, b) => a.subs_by_name(var, replacement) * b.subs_by_name(var, replacement),
77            Expr::Div(a, b) => a.subs_by_name(var, replacement) / b.subs_by_name(var, replacement),
78            Expr::Pow(a, b) => E::new(Expr::Pow(a.subs_by_name(var, replacement), b.subs_by_name(var, replacement))),
79            Expr::Sin(a) => E::new(Expr::Sin(a.subs_by_name(var, replacement))),
80            Expr::Cos(a) => E::new(Expr::Cos(a.subs_by_name(var, replacement))),
81            Expr::Tan(a) => E::new(Expr::Tan(a.subs_by_name(var, replacement))),
82            Expr::Asin(a) => E::new(Expr::Asin(a.subs_by_name(var, replacement))),
83            Expr::Acos(a) => E::new(Expr::Acos(a.subs_by_name(var, replacement))),
84            Expr::Atan(a) => E::new(Expr::Atan(a.subs_by_name(var, replacement))),
85            Expr::Atan2(y, x) => E::new(Expr::Atan2(y.subs_by_name(var, replacement), x.subs_by_name(var, replacement))),
86            Expr::Sinh(a) => E::new(Expr::Sinh(a.subs_by_name(var, replacement))),
87            Expr::Cosh(a) => E::new(Expr::Cosh(a.subs_by_name(var, replacement))),
88            Expr::Tanh(a) => E::new(Expr::Tanh(a.subs_by_name(var, replacement))),
89            Expr::Exp(a) => E::new(Expr::Exp(a.subs_by_name(var, replacement))),
90            Expr::Ln(a) => E::new(Expr::Ln(a.subs_by_name(var, replacement))),
91            Expr::Log2(a) => E::new(Expr::Log2(a.subs_by_name(var, replacement))),
92            Expr::Log10(a) => E::new(Expr::Log10(a.subs_by_name(var, replacement))),
93            Expr::Sqrt(a) => E::new(Expr::Sqrt(a.subs_by_name(var, replacement))),
94            Expr::Abs(a) => E::new(Expr::Abs(a.subs_by_name(var, replacement))),
95            Expr::Heaviside(a) => E::new(Expr::Heaviside(a.subs_by_name(var, replacement))),
96            Expr::Clamp(a, lo, hi) => E::new(Expr::Clamp(a.subs_by_name(var, replacement), lo.subs_by_name(var, replacement), hi.subs_by_name(var, replacement))),
97            Expr::Func { name, params, kind, args } => {
98                let new_args = args.iter().map(|a| a.subs_by_name(var, replacement)).collect();
99                E::new(Expr::Func { name: name.clone(), params: params.clone(), kind: kind.clone(), args: new_args })
100            }
101        }
102    }
103
104    /// Collect all free (unbound) variable names in the expression.
105    ///
106    /// Returns a sorted set of variable name strings.
107    pub fn free_vars(&self) -> BTreeSet<String> {
108        let mut set = BTreeSet::new();
109        self.collect_vars(&mut set);
110        set
111    }
112
113    fn collect_vars(&self, set: &mut BTreeSet<String>) {
114        match self {
115            Expr::Sym(name) => { set.insert(name.clone()); }
116            Expr::Const(_) | Expr::NamedConst { .. } => {}
117            Expr::Neg(a) => a.collect_vars(set),
118            Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b)
119            | Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
120                a.collect_vars(set);
121                b.collect_vars(set);
122            }
123            Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a)
124            | Expr::Asin(a) | Expr::Acos(a) | Expr::Atan(a)
125            | Expr::Sinh(a) | Expr::Cosh(a) | Expr::Tanh(a)
126            | Expr::Exp(a) | Expr::Ln(a) | Expr::Log2(a) | Expr::Log10(a)
127            | Expr::Sqrt(a) | Expr::Abs(a)
128            | Expr::Heaviside(a) => {
129                a.collect_vars(set);
130            }
131            Expr::Clamp(a, lo, hi) => {
132                a.collect_vars(set);
133                lo.collect_vars(set);
134                hi.collect_vars(set);
135            }
136            Expr::Func { args, .. } => {
137                for arg in args { arg.collect_vars(set); }
138            }
139        }
140    }
141
142    /// Differentiate with respect to multiple variables, returning a vector.
143    ///
144    /// Equivalent to calling [`Expr::diff`] for each variable in order.
145    pub fn diff_all(&self, vars: &[&str]) -> Vec<E> {
146        vars.iter().map(|v| self.diff(v)).collect()
147    }
148}