use crate::Scalar;
use super::graph::ExprGraph;
use super::node::{ExprId, Node};
impl ExprGraph {
    /// Evaluates the expression rooted at `expr` for the given variable values.
    ///
    /// Every node whose id is less than or equal to `expr` is evaluated, in id
    /// order, whether or not `expr` actually depends on it.
    ///
    /// # Panics
    ///
    /// Panics if a `Var` node among those evaluated indexes past the end of
    /// `inputs`, or if a node refers to an operand whose id is not smaller than
    /// its own. Presumably also panics if `expr` is not a node of this graph
    /// (depends on `ExprGraph::node`, not visible here).
    pub fn eval<S: Scalar>(&self, expr: ExprId, inputs: &[S]) -> S {
        let vals = eval_through(self, expr.0, inputs);
        vals[expr.0 as usize]
    }

    /// Evaluates several expressions in one pass, sharing intermediate results.
    ///
    /// The graph is evaluated once, up to the largest id in `exprs`, and the
    /// requested values are returned in the same order as `exprs` (duplicates
    /// are allowed). An empty `exprs` returns an empty `Vec` without evaluating
    /// any node.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`ExprGraph::eval`].
    pub fn eval_many<S: Scalar>(&self, exprs: &[ExprId], inputs: &[S]) -> Vec<S> {
        let last = match exprs.iter().map(|e| e.0).max() {
            Some(id) => id,
            None => return Vec::new(),
        };
        let vals = eval_through(self, last, inputs);
        exprs.iter().map(|e| vals[e.0 as usize]).collect()
    }
}

/// Evaluates nodes `0..=last` of `graph` in id order and returns their values, indexed by id.
///
/// Each node reads its operands from values already computed in this pass, so
/// operands must have smaller ids than the nodes that use them. This is
/// presumably guaranteed by how `ExprGraph` appends nodes, but it is not checked
/// here; a violation panics on the out-of-bounds index into `vals`.
fn eval_through<S: Scalar>(graph: &ExprGraph, last: u32, inputs: &[S]) -> Vec<S> {
    let mut vals: Vec<S> = Vec::with_capacity(last as usize + 1);
    for id in 0..=last {
        let v = match graph.node(ExprId(id)) {
            Node::Var(idx) => inputs[idx as usize],
            Node::Lit(bits) => S::from_f64(f64::from_bits(bits)),
            Node::Add(a, b) => vals[a.0 as usize] + vals[b.0 as usize],
            Node::Mul(a, b) => vals[a.0 as usize] * vals[b.0 as usize],
            Node::Neg(a) => -vals[a.0 as usize],
            Node::Recip(a) => vals[a.0 as usize].recip(),
            Node::Sqrt(a) => vals[a.0 as usize].sqrt(),
            Node::Sin(a) => vals[a.0 as usize].sin(),
            Node::Atan2(y, x) => vals[y.0 as usize].atan2(vals[x.0 as usize]),
            Node::Exp2(a) => {
                // 2^x computed as e^(x * ln 2).
                let x = vals[a.0 as usize];
                (x * S::from_f64(std::f64::consts::LN_2)).exp()
            }
            Node::Log2(a) => {
                // log2(x) computed as ln(x) * log2(e).
                let x = vals[a.0 as usize];
                x.ln() * S::from_f64(std::f64::consts::LOG2_E)
            }
            Node::Select(c, a, b) => {
                S::select(vals[c.0 as usize], vals[a.0 as usize], vals[b.0 as usize])
            }
        };
        vals.push(v);
    }
    vals
}
#[cfg(test)]
mod tests {
    // Re-use the parent module's import of `ExprGraph`. Writing
    // `super::graph::ExprGraph` here would resolve `graph` relative to the
    // parent module rather than to its sibling.
    use super::ExprGraph;

    /// Asserts that `actual` is within 1e-10 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn eval_add_lits() {
        let mut g = ExprGraph::new();
        let a = g.lit(3.0);
        let b = g.lit(4.0);
        let sum = g.add(a, b);
        let result: f64 = g.eval(sum, &[]);
        assert_close(result, 7.0);
    }

    #[test]
    fn eval_with_vars() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(sum, x);
        let result: f64 = g.eval(prod, &[3.0, 4.0]);
        assert_close(result, 21.0);
    }

    #[test]
    fn eval_sqrt() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let sq = g.sqrt(x);
        let result: f64 = g.eval(sq, &[9.0]);
        assert_close(result, 3.0);
    }

    #[test]
    fn eval_sin() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let result: f64 = g.eval(s, &[std::f64::consts::FRAC_PI_2]);
        assert_close(result, 1.0);
    }

    #[test]
    fn eval_select_positive_cond() {
        let mut g = ExprGraph::new();
        let cond = g.lit(1.0);
        let a = g.lit(3.0);
        let b = g.lit(7.0);
        let s = g.select(cond, a, b);
        let result: f64 = g.eval(s, &[]);
        assert_close(result, 3.0);
    }

    #[test]
    fn eval_select_negative_cond() {
        let mut g = ExprGraph::new();
        let cond = g.lit(-1.0);
        let a = g.lit(3.0);
        let b = g.lit(7.0);
        let s = g.select(cond, a, b);
        let result: f64 = g.eval(s, &[]);
        assert_close(result, 7.0);
    }

    #[test]
    fn eval_select_zero_cond() {
        let mut g = ExprGraph::new();
        let cond = g.lit(0.0);
        let a = g.lit(3.0);
        let b = g.lit(7.0);
        let s = g.select(cond, a, b);
        let result: f64 = g.eval(s, &[]);
        assert_close(result, 7.0);
    }

    #[test]
    fn eval_many_outputs() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);
        let results: Vec<f64> = g.eval_many(&[sum, prod], &[3.0, 4.0]);
        assert_eq!(results.len(), 2);
        assert_close(results[0], 7.0);
        assert_close(results[1], 12.0);
    }

    #[test]
    fn eval_many_empty_returns_empty() {
        let g = ExprGraph::new();
        let results: Vec<f64> = g.eval_many(&[], &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn eval_many_preserves_order_and_duplicates() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);
        let results: Vec<f64> = g.eval_many(&[prod, sum, prod], &[3.0, 4.0]);
        assert_eq!(results.len(), 3);
        assert_close(results[0], 12.0);
        assert_close(results[1], 7.0);
        assert_close(results[2], 12.0);
    }
}