use std::collections::HashMap;
use super::graph::ExprGraph;
use super::node::{ExprId, Node};
impl ExprGraph {
    /// Builds the partial derivative of `expr` with respect to the variable
    /// whose index is `var`, returning the id of the derivative expression.
    ///
    /// The derivative is assembled through this graph's own constructor
    /// methods. A memo table that lives only for the duration of this call
    /// makes sure a sub-expression reachable through several paths is
    /// differentiated once.
    pub fn diff(&mut self, expr: ExprId, var: u16) -> ExprId {
        let mut cache: HashMap<(ExprId, u16), ExprId> = HashMap::new();
        self.diff_inner(expr, var, &mut cache)
    }

    /// Recursive worker behind [`ExprGraph::diff`].
    ///
    /// `memo` maps `(expression, variable)` to an already-built derivative;
    /// it is consulted before any work is done and updated before returning.
    fn diff_inner(
        &mut self,
        expr: ExprId,
        var: u16,
        memo: &mut HashMap<(ExprId, u16), ExprId>,
    ) -> ExprId {
        use std::f64::consts::{FRAC_PI_2, LN_2};

        // Builds `numer * recip(denom)`; the graph expresses division this way.
        fn over(graph: &mut ExprGraph, numer: ExprId, denom: ExprId) -> ExprId {
            let inverse = graph.recip(denom);
            graph.mul(numer, inverse)
        }

        let key = (expr, var);
        if let Some(hit) = memo.get(&key).copied() {
            return hit;
        }

        let node = self.node(expr);
        let derivative = match node {
            // d(x_i)/d(x_j) is 1 when i == j and 0 otherwise; literals are constant.
            Node::Var(index) if index == var => ExprId::ONE,
            Node::Var(_) | Node::Lit(_) => ExprId::ZERO,

            // Sum rule: (l + r)' = l' + r'.
            Node::Add(lhs, rhs) => {
                let d_lhs = self.diff_inner(lhs, var, memo);
                let d_rhs = self.diff_inner(rhs, var, memo);
                self.add(d_lhs, d_rhs)
            }

            // Product rule: (l * r)' = l' * r + l * r'.
            Node::Mul(lhs, rhs) => {
                let d_lhs = self.diff_inner(lhs, var, memo);
                let d_rhs = self.diff_inner(rhs, var, memo);
                let left_term = self.mul(d_lhs, rhs);
                let right_term = self.mul(lhs, d_rhs);
                self.add(left_term, right_term)
            }

            // (-u)' = -(u').
            Node::Neg(inner) => {
                let d_inner = self.diff_inner(inner, var, memo);
                self.neg(d_inner)
            }

            // (1 / u)' = -(u' / u^2).
            Node::Recip(inner) => {
                let d_inner = self.diff_inner(inner, var, memo);
                let squared = self.mul(inner, inner);
                let quotient = over(self, d_inner, squared);
                self.neg(quotient)
            }

            // (sqrt u)' = u' / (2 * sqrt u).
            Node::Sqrt(inner) => {
                let d_inner = self.diff_inner(inner, var, memo);
                let root = self.sqrt(inner);
                let doubled = self.mul(ExprId::TWO, root);
                over(self, d_inner, doubled)
            }

            // (sin u)' = cos(u) * u', with cos(u) written as sin(u + pi/2).
            Node::Sin(inner) => {
                let d_inner = self.diff_inner(inner, var, memo);
                let quarter_turn = self.lit(FRAC_PI_2);
                let phase = self.add(inner, quarter_turn);
                let cosine = self.sin(phase);
                self.mul(cosine, d_inner)
            }

            // (atan2(y, x))' = (x * y' - y * x') / (x^2 + y^2).
            Node::Atan2(y, x) => {
                let d_y = self.diff_inner(y, var, memo);
                let d_x = self.diff_inner(x, var, memo);
                let x_times_dy = self.mul(x, d_y);
                let y_times_dx = self.mul(y, d_x);
                let minus_y_times_dx = self.neg(y_times_dx);
                let numer = self.add(x_times_dy, minus_y_times_dx);
                let x_squared = self.mul(x, x);
                let y_squared = self.mul(y, y);
                let denom = self.add(x_squared, y_squared);
                over(self, numer, denom)
            }

            // (2^u)' = ln(2) * 2^u * u'.
            Node::Exp2(inner) => {
                let d_inner = self.diff_inner(inner, var, memo);
                let ln_two = self.lit(LN_2);
                let power = self.exp2(inner);
                let scaled = self.mul(ln_two, power);
                self.mul(scaled, d_inner)
            }

            // (log2 u)' = u' / (ln(2) * u).
            Node::Log2(inner) => {
                let d_inner = self.diff_inner(inner, var, memo);
                let ln_two = self.lit(LN_2);
                let scaled = self.mul(ln_two, inner);
                over(self, d_inner, scaled)
            }

            // Differentiate each branch; the condition itself is reused unchanged.
            Node::Select(cond, on_a, on_b) => {
                let d_a = self.diff_inner(on_a, var, memo);
                let d_b = self.diff_inner(on_b, var, memo);
                self.select(cond, d_a, d_b)
            }
        };

        memo.insert(key, derivative);
        derivative
    }
}
#[cfg(test)]
mod tests {
    // Inside this module `super` is the enclosing file's module, which has no
    // `graph` / `node` children of its own. Re-use the names the parent module
    // already imports instead of `super::graph::…` / `super::node::…`.
    use super::{ExprGraph, ExprId};

    /// Derivative of a literal is the canonical zero node.
    #[test]
    fn diff_constant() {
        let mut g = ExprGraph::new();
        let c = g.lit(5.0);
        let dc = g.diff(c, 0);
        assert_eq!(dc, ExprId::ZERO);
    }

    /// Derivative of a variable with respect to itself is the canonical one node.
    #[test]
    fn diff_var_self() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let dx = g.diff(x, 0);
        assert_eq!(dx, ExprId::ONE);
    }

    /// Derivative of a variable with respect to a different variable is zero.
    #[test]
    fn diff_var_other() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let dx = g.diff(x, 1);
        assert_eq!(dx, ExprId::ZERO);
    }

    /// d/dx (x + 3) evaluates to 1 regardless of the input value.
    #[test]
    fn diff_add() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let c = g.lit(3.0);
        let sum = g.add(x, c);
        let d = g.diff(sum, 0);
        let result: f64 = g.eval(d, &[99.0]);
        assert!((result - 1.0).abs() < 1e-10);
    }

    /// d/dx (x * x) = 2x, checked at x = 3.
    #[test]
    fn diff_mul_product_rule() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let d = g.diff(xx, 0);
        let result: f64 = g.eval(d, &[3.0]);
        assert!((result - 6.0).abs() < 1e-10);
    }

    /// d/dx sin(x) = cos(x), checked at x = 0.
    #[test]
    fn diff_sin() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let ds = g.diff(s, 0);
        let result: f64 = g.eval(ds, &[0.0]);
        assert!((result - 1.0).abs() < 1e-10);
    }

    /// d/dx sin(x^2) = 2x * cos(x^2), checked at x = 1.
    #[test]
    fn diff_chain_rule() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let s = g.sin(xx);
        let ds = g.diff(s, 0);
        let expected = 2.0 * 1.0_f64.cos();
        let result: f64 = g.eval(ds, &[1.0]);
        assert!((result - expected).abs() < 1e-10);
    }

    /// d/dx sqrt(x) = 1 / (2 * sqrt(x)), checked at x = 4.
    #[test]
    fn diff_sqrt() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let sq = g.sqrt(x);
        let d = g.diff(sq, 0);
        let result: f64 = g.eval(d, &[4.0]);
        assert!((result - 0.25).abs() < 1e-10);
    }

    /// d/dx (1 / x) = -1 / x^2, checked at x = 2.
    #[test]
    fn diff_recip() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let r = g.recip(x);
        let d = g.diff(r, 0);
        let result: f64 = g.eval(d, &[2.0]);
        assert!((result - (-0.25)).abs() < 1e-10);
    }

    /// A sub-expression used twice (x*x + x*x) still differentiates to 4x.
    #[test]
    fn diff_memoization() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let sum = g.add(xx, xx);
        let d = g.diff(sum, 0);
        let result: f64 = g.eval(d, &[3.0]);
        assert!((result - 12.0).abs() < 1e-10);
    }

    /// Differentiating a select differentiates both branches and keeps the
    /// condition: x = 2 must yield d(x*x) = 4 and x = -1 must yield d(x+1) = 1.
    #[test]
    fn diff_select() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let xp1 = g.add(x, ExprId::ONE);
        let s = g.select(x, xx, xp1);
        let ds = g.diff(s, 0);
        let result: f64 = g.eval(ds, &[2.0]);
        assert!((result - 4.0).abs() < 1e-10);
        let result2: f64 = g.eval(ds, &[-1.0]);
        assert!((result2 - 1.0).abs() < 1e-10);
    }

    /// For dot = x0*x3 + x1*x4 + x2*x5, the partial with respect to x0 is x3.
    #[test]
    fn diff_dot_product() {
        let mut g = ExprGraph::new();
        let x0 = g.var(0);
        let x1 = g.var(1);
        let x2 = g.var(2);
        let x3 = g.var(3);
        let x4 = g.var(4);
        let x5 = g.var(5);
        let t0 = g.mul(x0, x3);
        let t1 = g.mul(x1, x4);
        let t2 = g.mul(x2, x5);
        let s01 = g.add(t0, t1);
        let dot = g.add(s01, t2);
        let d0 = g.diff(dot, 0);
        let result: f64 = g.eval(d0, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert!((result - 4.0).abs() < 1e-10);
    }
}