use std::collections::{BTreeSet, HashMap};
use super::{Expr, E};
impl Expr {
/// Numerically evaluates the expression to an `f64`.
///
/// `vars` supplies the value for every `Expr::Sym` node, keyed by symbol
/// name. Arithmetic and elementary functions follow plain `f64` semantics,
/// so domain violations (division by zero, logarithm of a negative number,
/// ...) are not reported as errors; they surface as `inf`/`NaN` results.
///
/// # Errors
///
/// Returns `Err` when a symbol has no entry in `vars`, or when the bounds of
/// a `Clamp` node are invalid (`lo > hi`, or either bound is `NaN`).
///
/// # Panics
///
/// Panics if a `Func` node's kind provides neither an evaluation function
/// nor a body.
pub fn eval(&self, vars: &HashMap<&str, f64>) -> Result<f64, String> {
    match self {
        Expr::Sym(name) => vars
            .get(name.as_str())
            .copied()
            .ok_or_else(|| format!("unbound symbol: {name}")),
        Expr::Const(v) => Ok(*v),
        Expr::NamedConst { value, .. } => Ok(*value),
        Expr::Neg(a) => Ok(-a.eval(vars)?),
        Expr::Add(a, b) => Ok(a.eval(vars)? + b.eval(vars)?),
        Expr::Sub(a, b) => Ok(a.eval(vars)? - b.eval(vars)?),
        Expr::Mul(a, b) => Ok(a.eval(vars)? * b.eval(vars)?),
        Expr::Div(a, b) => Ok(a.eval(vars)? / b.eval(vars)?),
        Expr::Pow(a, b) => Ok(a.eval(vars)?.powf(b.eval(vars)?)),
        Expr::Sin(a) => Ok(a.eval(vars)?.sin()),
        Expr::Cos(a) => Ok(a.eval(vars)?.cos()),
        Expr::Tan(a) => Ok(a.eval(vars)?.tan()),
        Expr::Asin(a) => Ok(a.eval(vars)?.asin()),
        Expr::Acos(a) => Ok(a.eval(vars)?.acos()),
        Expr::Atan(a) => Ok(a.eval(vars)?.atan()),
        Expr::Atan2(y, x) => Ok(y.eval(vars)?.atan2(x.eval(vars)?)),
        Expr::Sinh(a) => Ok(a.eval(vars)?.sinh()),
        Expr::Cosh(a) => Ok(a.eval(vars)?.cosh()),
        Expr::Tanh(a) => Ok(a.eval(vars)?.tanh()),
        Expr::Exp(a) => Ok(a.eval(vars)?.exp()),
        Expr::Ln(a) => Ok(a.eval(vars)?.ln()),
        Expr::Log2(a) => Ok(a.eval(vars)?.log2()),
        Expr::Log10(a) => Ok(a.eval(vars)?.log10()),
        Expr::Sqrt(a) => Ok(a.eval(vars)?.sqrt()),
        Expr::Abs(a) => Ok(a.eval(vars)?.abs()),
        Expr::Heaviside(a) => {
            let v = a.eval(vars)?;
            // Step function with H(0) = 1. A NaN input also yields 1.0,
            // because `NaN < 0.0` is false.
            Ok(if v < 0.0 { 0.0 } else { 1.0 })
        }
        Expr::Clamp(val, lo, hi) => {
            let v = val.eval(vars)?;
            let l = lo.eval(vars)?;
            let h = hi.eval(vars)?;
            // `f64::clamp` panics when `min > max` or when either bound is
            // NaN; report those cases through the `Result` instead.
            if l.is_nan() || h.is_nan() || l > h {
                return Err(format!("invalid clamp bounds: lo = {l}, hi = {h}"));
            }
            Ok(v.clamp(l, h))
        }
        Expr::Func { params, kind, args, .. } => {
            if let Some(f) = kind.eval_fn() {
                // Direct numeric implementation: evaluate every argument
                // first (stopping at the first error), then apply it.
                let vals = args
                    .iter()
                    .map(|a| a.eval(vars))
                    .collect::<Result<Vec<f64>, _>>()?;
                Ok(f(&vals))
            } else {
                // No numeric implementation: expand the function via
                // `expand_func` and evaluate the resulting expression.
                let body = kind.body().expect("FuncKind must have body or eval_fn");
                super::expand_func(params, body, args).eval(vars)
            }
        }
    }
}
/// Returns a new expression in which the symbol identified by `var` is
/// replaced by `replacement`.
///
/// `var` may be anything implementing `crate::AsVarName`; its name is
/// resolved once and the work is delegated to `subs_by_name`.
pub fn subs(&self, var: impl crate::AsVarName, replacement: &E) -> E {
    let target = var.var_name();
    self.subs_by_name(target, replacement)
}
/// Rebuilds the expression with every `Sym` named `var` replaced by a clone
/// of `replacement`.
///
/// Other leaves (`Sym` with a different name, `Const`, `NamedConst`) are
/// cloned unchanged. `Neg`, `Add`, `Sub`, `Mul` and `Div` are rebuilt through
/// the operators defined on `E`; all remaining nodes are reconstructed
/// directly with `E::new`. For `Func` nodes only the arguments are
/// substituted; `name`, `params` and `kind` are cloned as they are.
fn subs_by_name(&self, var: &str, replacement: &E) -> E {
    // Recursive step applied to a child node.
    let go = |child: &E| child.subs_by_name(var, replacement);

    match self {
        Expr::Sym(name) => {
            if name == var {
                replacement.clone()
            } else {
                E::new(self.clone())
            }
        }
        Expr::Const(_) | Expr::NamedConst { .. } => E::new(self.clone()),

        // Arithmetic nodes go through the operator overloads on `E`.
        Expr::Neg(a) => -a.subs_by_name(var, replacement),
        Expr::Add(a, b) => go(a) + go(b),
        Expr::Sub(a, b) => go(a) - go(b),
        Expr::Mul(a, b) => go(a) * go(b),
        Expr::Div(a, b) => go(a) / go(b),

        // Everything else is reconstructed node-for-node.
        Expr::Pow(a, b) => E::new(Expr::Pow(go(a), go(b))),
        Expr::Sin(a) => E::new(Expr::Sin(go(a))),
        Expr::Cos(a) => E::new(Expr::Cos(go(a))),
        Expr::Tan(a) => E::new(Expr::Tan(go(a))),
        Expr::Asin(a) => E::new(Expr::Asin(go(a))),
        Expr::Acos(a) => E::new(Expr::Acos(go(a))),
        Expr::Atan(a) => E::new(Expr::Atan(go(a))),
        Expr::Atan2(y, x) => E::new(Expr::Atan2(go(y), go(x))),
        Expr::Sinh(a) => E::new(Expr::Sinh(go(a))),
        Expr::Cosh(a) => E::new(Expr::Cosh(go(a))),
        Expr::Tanh(a) => E::new(Expr::Tanh(go(a))),
        Expr::Exp(a) => E::new(Expr::Exp(go(a))),
        Expr::Ln(a) => E::new(Expr::Ln(go(a))),
        Expr::Log2(a) => E::new(Expr::Log2(go(a))),
        Expr::Log10(a) => E::new(Expr::Log10(go(a))),
        Expr::Sqrt(a) => E::new(Expr::Sqrt(go(a))),
        Expr::Abs(a) => E::new(Expr::Abs(go(a))),
        Expr::Heaviside(a) => E::new(Expr::Heaviside(go(a))),
        Expr::Clamp(a, lo, hi) => E::new(Expr::Clamp(go(a), go(lo), go(hi))),
        Expr::Func { name, params, kind, args } => {
            let substituted = args.iter().map(|arg| go(arg)).collect();
            E::new(Expr::Func {
                name: name.clone(),
                params: params.clone(),
                kind: kind.clone(),
                args: substituted,
            })
        }
    }
}
/// Returns the names of all symbols occurring in the expression.
///
/// The result is a `BTreeSet`, so each name appears once and iteration is
/// in sorted order.
pub fn free_vars(&self) -> BTreeSet<String> {
    let mut names = BTreeSet::new();
    self.collect_vars(&mut names);
    names
}
/// Walks the expression tree and inserts the name of every `Sym` node
/// into `set`.
///
/// For `Func` nodes only the argument expressions are visited.
fn collect_vars(&self, set: &mut BTreeSet<String>) {
    match self {
        // Leaves: constants contribute nothing, symbols contribute their name.
        Expr::Const(_) | Expr::NamedConst { .. } => {}
        Expr::Sym(name) => {
            set.insert(name.clone());
        }

        // One child.
        Expr::Neg(a) => a.collect_vars(set),
        Expr::Sin(a)
        | Expr::Cos(a)
        | Expr::Tan(a)
        | Expr::Asin(a)
        | Expr::Acos(a)
        | Expr::Atan(a)
        | Expr::Sinh(a)
        | Expr::Cosh(a)
        | Expr::Tanh(a)
        | Expr::Exp(a)
        | Expr::Ln(a)
        | Expr::Log2(a)
        | Expr::Log10(a)
        | Expr::Sqrt(a)
        | Expr::Abs(a)
        | Expr::Heaviside(a) => a.collect_vars(set),

        // Two children, visited left to right.
        Expr::Add(a, b)
        | Expr::Sub(a, b)
        | Expr::Mul(a, b)
        | Expr::Div(a, b)
        | Expr::Pow(a, b)
        | Expr::Atan2(a, b) => {
            for child in [a, b] {
                child.collect_vars(set);
            }
        }

        // Three children: the value and both bounds.
        Expr::Clamp(a, lo, hi) => {
            for child in [a, lo, hi] {
                child.collect_vars(set);
            }
        }

        // Function application: only the arguments are inspected.
        Expr::Func { args, .. } => args.iter().for_each(|arg| arg.collect_vars(set)),
    }
}
/// Calls `diff` once for every name in `vars` and returns the results in
/// the same order as `vars`.
///
/// NOTE(review): `diff` is defined outside this section; presumably it is
/// the derivative with respect to the named variable — confirm there.
pub fn diff_all(&self, vars: &[&str]) -> Vec<E> {
    let mut results = Vec::with_capacity(vars.len());
    for v in vars {
        results.push(self.diff(v));
    }
    results
}
}