use super::expr::{BinaryOp, ExprNode, NaryOp, UnaryOp};
/// One instruction of a tape.
///
/// Operand fields are indices of other ops in the same tape; the op's value is computed
/// from their values as described on each variant (see `Tape::forward`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TapeOp {
    /// The constant value.
    Const(f64),
    /// Decision variable `x[i]`.
    Var(usize),
    /// `a + b`.
    Add(usize, usize),
    /// `a - b`.
    Sub(usize, usize),
    /// `a * b`.
    Mul(usize, usize),
    /// `a / b`.
    Div(usize, usize),
    /// `a.powf(b)`.
    Pow(usize, usize),
    /// `a % b` (Rust's truncated remainder: `a - b * trunc(a / b)`).
    Mod(usize, usize),
    /// `a.atan2(b)`, i.e. `a` is the y-coordinate and `b` the x-coordinate.
    Atan2(usize, usize),
    /// `min(a, b)` despite the name: yields `a` when `a < b`, otherwise `b`.
    Less(usize, usize),
    /// `floor(a / b)`.
    IntDiv(usize, usize),
    /// `-a`.
    Neg(usize),
    /// `|a|`.
    Abs(usize),
    /// `floor(a)`.
    Floor(usize),
    /// `ceil(a)`.
    Ceil(usize),
    /// `sqrt(a)`.
    Sqrt(usize),
    /// `exp(a)`.
    Exp(usize),
    /// Natural logarithm `ln(a)`.
    Log(usize),
    /// Base-10 logarithm.
    Log10(usize),
    /// `sin(a)`.
    Sin(usize),
    /// `cos(a)`.
    Cos(usize),
    /// `tan(a)`.
    Tan(usize),
    /// `asin(a)`.
    Asin(usize),
    /// `acos(a)`.
    Acos(usize),
    /// `atan(a)`.
    Atan(usize),
    /// `sinh(a)`.
    Sinh(usize),
    /// `cosh(a)`.
    Cosh(usize),
    /// `tanh(a)`.
    Tanh(usize),
    /// `asinh(a)`.
    Asinh(usize),
    /// `acosh(a)`.
    Acosh(usize),
    /// `atanh(a)`.
    Atanh(usize),
}
/// Straight-line program ("tape") recorded from an expression tree, used for
/// evaluation and automatic differentiation.
///
/// `eval`, `reverse` and `hessian_accumulate` treat the LAST op in `ops` as the
/// tape's output.
#[derive(Debug, Clone)]
pub struct Tape {
/// Ops in recording order; operand fields are indices into this same vector (the
/// builders in this module record operands before the ops that use them).
pub ops: Vec<TapeOp>,
/// Variable count given to the builder: `ExprNode::Var(i)` with `i < n_vars` is a
/// decision variable, larger indices refer to common expressions.
pub n_vars: usize,
}
/// Prebuilt tape fragments for the common (shared) expressions of a model.
///
/// `entries[k]` holds the recorded ops of common expression `k` together with the index of
/// its result op inside that fragment; `None` means no fragment is available.
#[derive(Debug, Clone)]
pub struct CommonExprCache {
    entries: Vec<Option<(Vec<TapeOp>, usize)>>,
}
impl CommonExprCache {
    /// Records one tape fragment per entry of `common_exprs`, in index order.
    ///
    /// Entry `i` stores the ops recorded for `common_exprs[i]` plus the index of its result
    /// op. Each fragment is recorded against the entries built so far. References to earlier
    /// common expressions are therefore spliced in from the cache. References to entry `i`
    /// itself or to later entries find no cache slot and are recorded as the constant `0.0`
    /// by `build_recursive_cached`.
    pub fn build(common_exprs: &[ExprNode], n_vars: usize) -> Self {
        let mut entries: Vec<Option<(Vec<TapeOp>, usize)>> =
            Vec::with_capacity(common_exprs.len());
        for expr in common_exprs {
            let mut fragment = Vec::new();
            let root = build_recursive_cached(expr, common_exprs, n_vars, &mut fragment, &entries);
            entries.push(Some((fragment, root)));
        }
        Self { entries }
    }
}
impl Tape {
pub fn build(expr: &ExprNode, common_exprs: &[ExprNode], n_vars: usize) -> Self {
let mut ops = Vec::new();
build_recursive(expr, common_exprs, n_vars, &mut ops);
Tape { ops, n_vars }
}
pub fn build_cached(expr: &ExprNode, common_exprs: &[ExprNode], n_vars: usize, cache: &CommonExprCache) -> Self {
let mut ops = Vec::new();
build_recursive_cached(expr, common_exprs, n_vars, &mut ops, &cache.entries);
Tape { ops, n_vars }
}
pub fn forward(&self, x: &[f64]) -> Vec<f64> {
let mut vals: Vec<f64> = Vec::with_capacity(self.ops.len());
for op in &self.ops {
let v = match op {
TapeOp::Const(c) => *c,
TapeOp::Var(i) => x[*i],
TapeOp::Add(a, b) => vals[*a] + vals[*b],
TapeOp::Sub(a, b) => vals[*a] - vals[*b],
TapeOp::Mul(a, b) => vals[*a] * vals[*b],
TapeOp::Div(a, b) => vals[*a] / vals[*b],
TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
TapeOp::Mod(a, b) => vals[*a] % vals[*b],
TapeOp::Atan2(a, b) => vals[*a].atan2(vals[*b]),
TapeOp::Less(a, b) => {
if vals[*a] < vals[*b] {
vals[*a]
} else {
vals[*b]
}
}
TapeOp::IntDiv(a, b) => (vals[*a] / vals[*b]).floor(),
TapeOp::Neg(a) => -vals[*a],
TapeOp::Abs(a) => vals[*a].abs(),
TapeOp::Floor(a) => vals[*a].floor(),
TapeOp::Ceil(a) => vals[*a].ceil(),
TapeOp::Sqrt(a) => vals[*a].sqrt(),
TapeOp::Exp(a) => vals[*a].exp(),
TapeOp::Log(a) => vals[*a].ln(),
TapeOp::Log10(a) => vals[*a].log10(),
TapeOp::Sin(a) => vals[*a].sin(),
TapeOp::Cos(a) => vals[*a].cos(),
TapeOp::Tan(a) => vals[*a].tan(),
TapeOp::Asin(a) => vals[*a].asin(),
TapeOp::Acos(a) => vals[*a].acos(),
TapeOp::Atan(a) => vals[*a].atan(),
TapeOp::Sinh(a) => vals[*a].sinh(),
TapeOp::Cosh(a) => vals[*a].cosh(),
TapeOp::Tanh(a) => vals[*a].tanh(),
TapeOp::Asinh(a) => vals[*a].asinh(),
TapeOp::Acosh(a) => vals[*a].acosh(),
TapeOp::Atanh(a) => vals[*a].atanh(),
};
vals.push(v);
}
vals
}
pub fn eval(&self, x: &[f64]) -> f64 {
let vals = self.forward(x);
*vals.last().unwrap_or(&0.0)
}
pub fn gradient(&self, x: &[f64], grad: &mut [f64]) {
let vals = self.forward(x);
self.reverse(&vals, grad);
}
pub fn variables(&self) -> Vec<usize> {
let mut vars: std::collections::BTreeSet<usize> = std::collections::BTreeSet::new();
for op in &self.ops {
if let TapeOp::Var(j) = op {
vars.insert(*j);
}
}
vars.into_iter().collect()
}
pub fn reverse(&self, vals: &[f64], grad: &mut [f64]) {
let n = self.ops.len();
if n == 0 {
return;
}
let mut adj = vec![0.0f64; n];
adj[n - 1] = 1.0;
for i in (0..n).rev() {
let a = adj[i];
if a == 0.0 {
continue;
}
match &self.ops[i] {
TapeOp::Const(_) => {}
TapeOp::Var(j) => {
if *j < grad.len() {
grad[*j] += a;
}
}
TapeOp::Add(l, r) => {
adj[*l] += a;
adj[*r] += a;
}
TapeOp::Sub(l, r) => {
adj[*l] += a;
adj[*r] -= a;
}
TapeOp::Mul(l, r) => {
adj[*l] += a * vals[*r];
adj[*r] += a * vals[*l];
}
TapeOp::Div(l, r) => {
let rv = vals[*r];
adj[*l] += a / rv;
adj[*r] -= a * vals[*l] / (rv * rv);
}
TapeOp::Pow(l, r) => {
let lv = vals[*l];
let rv = vals[*r];
if rv != 0.0 {
adj[*l] += a * rv * lv.powf(rv - 1.0);
}
if lv > 0.0 {
adj[*r] += a * vals[i] * lv.ln();
}
}
TapeOp::Mod(l, _r) => {
adj[*l] += a;
}
TapeOp::Atan2(l, r) => {
let lv = vals[*l];
let rv = vals[*r];
let denom = lv * lv + rv * rv;
if denom > 0.0 {
adj[*l] += a * rv / denom;
adj[*r] -= a * lv / denom;
}
}
TapeOp::Less(l, r) => {
if vals[*l] < vals[*r] {
adj[*l] += a;
} else {
adj[*r] += a;
}
}
TapeOp::IntDiv(l, _r) => {
adj[*l] += a;
}
TapeOp::Neg(j) => {
adj[*j] -= a;
}
TapeOp::Abs(j) => {
if vals[*j] >= 0.0 {
adj[*j] += a;
} else {
adj[*j] -= a;
}
}
TapeOp::Floor(_) | TapeOp::Ceil(_) => {
}
TapeOp::Sqrt(j) => {
let sv = vals[i];
if sv > 0.0 {
adj[*j] += a * 0.5 / sv;
}
}
TapeOp::Exp(j) => {
adj[*j] += a * vals[i]; }
TapeOp::Log(j) => {
adj[*j] += a / vals[*j];
}
TapeOp::Log10(j) => {
adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
}
TapeOp::Sin(j) => {
adj[*j] += a * vals[*j].cos();
}
TapeOp::Cos(j) => {
adj[*j] -= a * vals[*j].sin();
}
TapeOp::Tan(j) => {
let c = vals[*j].cos();
adj[*j] += a / (c * c);
}
TapeOp::Asin(j) => {
adj[*j] += a / (1.0 - vals[*j] * vals[*j]).sqrt();
}
TapeOp::Acos(j) => {
adj[*j] -= a / (1.0 - vals[*j] * vals[*j]).sqrt();
}
TapeOp::Atan(j) => {
adj[*j] += a / (1.0 + vals[*j] * vals[*j]);
}
TapeOp::Sinh(j) => {
adj[*j] += a * vals[*j].cosh();
}
TapeOp::Cosh(j) => {
adj[*j] += a * vals[*j].sinh();
}
TapeOp::Tanh(j) => {
let tv = vals[i];
adj[*j] += a * (1.0 - tv * tv);
}
TapeOp::Asinh(j) => {
adj[*j] += a / (vals[*j] * vals[*j] + 1.0).sqrt();
}
TapeOp::Acosh(j) => {
adj[*j] += a / (vals[*j] * vals[*j] - 1.0).sqrt();
}
TapeOp::Atanh(j) => {
adj[*j] += a / (1.0 - vals[*j] * vals[*j]);
}
}
}
}
fn forward_tangent(&self, vals: &[f64], seed_var: usize) -> Vec<f64> {
let n = self.ops.len();
let mut dot = vec![0.0f64; n];
for i in 0..n {
dot[i] = match &self.ops[i] {
TapeOp::Const(_) => 0.0,
TapeOp::Var(k) => if *k == seed_var { 1.0 } else { 0.0 },
TapeOp::Add(a, b) => dot[*a] + dot[*b],
TapeOp::Sub(a, b) => dot[*a] - dot[*b],
TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
TapeOp::Div(a, b) => {
let vb = vals[*b];
(dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
}
TapeOp::Pow(a, b) => {
let u = vals[*a];
let r = vals[*b];
let du = dot[*a];
let dr = dot[*b];
let mut result = 0.0;
if r != 0.0 && u != 0.0 {
result += r * u.powf(r - 1.0) * du;
}
if u > 0.0 {
result += vals[i] * u.ln() * dr;
}
result
}
TapeOp::Mod(a, _) => dot[*a],
TapeOp::Atan2(a, b) => {
let y = vals[*a]; let xv = vals[*b];
let d = y * y + xv * xv;
if d > 0.0 { (dot[*a] * xv - y * dot[*b]) / d } else { 0.0 }
}
TapeOp::Less(a, b) => if vals[*a] < vals[*b] { dot[*a] } else { dot[*b] },
TapeOp::IntDiv(a, _) => dot[*a],
TapeOp::Neg(a) => -dot[*a],
TapeOp::Abs(a) => if vals[*a] >= 0.0 { dot[*a] } else { -dot[*a] },
TapeOp::Floor(_) | TapeOp::Ceil(_) => 0.0,
TapeOp::Sqrt(a) => {
let sv = vals[i];
if sv > 0.0 { dot[*a] * 0.5 / sv } else { 0.0 }
}
TapeOp::Exp(a) => dot[*a] * vals[i],
TapeOp::Log(a) => dot[*a] / vals[*a],
TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
TapeOp::Tan(a) => { let c = vals[*a].cos(); dot[*a] / (c * c) }
TapeOp::Asin(a) => dot[*a] / (1.0 - vals[*a] * vals[*a]).sqrt(),
TapeOp::Acos(a) => -dot[*a] / (1.0 - vals[*a] * vals[*a]).sqrt(),
TapeOp::Atan(a) => dot[*a] / (1.0 + vals[*a] * vals[*a]),
TapeOp::Sinh(a) => dot[*a] * vals[*a].cosh(),
TapeOp::Cosh(a) => dot[*a] * vals[*a].sinh(),
TapeOp::Tanh(a) => { let tv = vals[i]; dot[*a] * (1.0 - tv * tv) }
TapeOp::Asinh(a) => dot[*a] / (vals[*a] * vals[*a] + 1.0).sqrt(),
TapeOp::Acosh(a) => dot[*a] / (vals[*a] * vals[*a] - 1.0).sqrt(),
TapeOp::Atanh(a) => dot[*a] / (1.0 - vals[*a] * vals[*a]),
};
}
dot
}
pub fn hessian_accumulate(
&self,
x: &[f64],
weight: f64,
hess_map: &std::collections::HashMap<(usize, usize), usize>,
vals: &mut [f64],
) {
let n = self.ops.len();
if n == 0 || weight == 0.0 {
return;
}
let v = self.forward(x);
let var_indices = self.variables();
for &j in &var_indices {
let dot = self.forward_tangent(&v, j);
let mut adj = vec![0.0f64; n];
let mut adj_dot = vec![0.0f64; n];
adj[n - 1] = 1.0;
for i in (0..n).rev() {
let w = adj[i];
let wd = adj_dot[i];
if w == 0.0 && wd == 0.0 {
continue;
}
match &self.ops[i] {
TapeOp::Const(_) => {}
TapeOp::Var(k) => {
if wd != 0.0 && *k >= j {
if let Some(&pos) = hess_map.get(&(*k, j)) {
vals[pos] += weight * wd;
}
}
}
TapeOp::Add(a, b) => {
adj[*a] += w; adj[*b] += w;
adj_dot[*a] += wd; adj_dot[*b] += wd;
}
TapeOp::Sub(a, b) => {
adj[*a] += w; adj[*b] -= w;
adj_dot[*a] += wd; adj_dot[*b] -= wd;
}
TapeOp::Mul(a, b) => {
adj[*a] += w * v[*b];
adj[*b] += w * v[*a];
adj_dot[*a] += wd * v[*b] + w * dot[*b];
adj_dot[*b] += wd * v[*a] + w * dot[*a];
}
TapeOp::Div(a, b) => {
let vb = v[*b];
let vb2 = vb * vb;
let vb3 = vb2 * vb;
adj[*a] += w / vb;
adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
adj[*b] += w * (-v[*a] / vb2);
adj_dot[*b] += wd * (-v[*a] / vb2)
+ w * (-dot[*a] / vb2 + 2.0 * v[*a] * dot[*b] / vb3);
}
TapeOp::Pow(a, b) => {
let u = v[*a];
let r = v[*b];
let du = dot[*a];
let dr = dot[*b];
if r != 0.0 {
if u != 0.0 {
let p_a = r * u.powf(r - 1.0);
adj[*a] += w * p_a;
let mut dp_a = dr * u.powf(r - 1.0);
if u > 0.0 {
dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
} else {
dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
}
adj_dot[*a] += wd * p_a + w * dp_a;
} else if r >= 2.0 {
let p_a = 0.0; adj[*a] += w * p_a;
let dp_a = if r == 2.0 {
2.0 * du } else {
r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
};
adj_dot[*a] += wd * p_a + w * dp_a;
}
}
if u > 0.0 {
let ln_u = u.ln();
let p_b = v[i] * ln_u;
adj[*b] += w * p_b;
let dur = v[i] * (r * du / u + dr * ln_u);
let dp_b = dur * ln_u + v[i] * du / u;
adj_dot[*b] += wd * p_b + w * dp_b;
}
}
TapeOp::Atan2(a, b) => {
let y = v[*a]; let xv = v[*b];
let d = y * y + xv * xv;
if d > 0.0 {
let d2 = d * d;
let dy = dot[*a]; let dx = dot[*b];
let dd = 2.0 * y * dy + 2.0 * xv * dx;
adj[*a] += w * xv / d;
let dp_a = (dx * d - xv * dd) / d2;
adj_dot[*a] += wd * xv / d + w * dp_a;
adj[*b] += w * (-y / d);
let dp_b = (-dy * d + y * dd) / d2;
adj_dot[*b] += wd * (-y / d) + w * dp_b;
}
}
TapeOp::Mod(a, _) | TapeOp::IntDiv(a, _) => {
adj[*a] += w;
adj_dot[*a] += wd;
}
TapeOp::Less(a, b) => {
if v[*a] < v[*b] {
adj[*a] += w; adj_dot[*a] += wd;
} else {
adj[*b] += w; adj_dot[*b] += wd;
}
}
TapeOp::Neg(a) => {
adj[*a] -= w;
adj_dot[*a] -= wd;
}
TapeOp::Abs(a) => {
let s = if v[*a] >= 0.0 { 1.0 } else { -1.0 };
adj[*a] += w * s;
adj_dot[*a] += wd * s; }
TapeOp::Floor(_) | TapeOp::Ceil(_) => {} TapeOp::Sqrt(a) => {
let sv = v[i];
if sv > 0.0 {
let fp = 0.5 / sv; let fpp = -0.25 / (v[*a] * sv); adj[*a] += w * fp;
adj_dot[*a] += wd * fp + w * fpp * dot[*a];
}
}
TapeOp::Exp(a) => {
let ev = v[i]; adj[*a] += w * ev;
adj_dot[*a] += wd * ev + w * ev * dot[*a]; }
TapeOp::Log(a) => {
let u = v[*a];
adj[*a] += w / u;
adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
}
TapeOp::Log10(a) => {
let u = v[*a];
let ln10 = std::f64::consts::LN_10;
adj[*a] += w / (u * ln10);
adj_dot[*a] += wd / (u * ln10) + w * (-1.0 / (u * u * ln10)) * dot[*a];
}
TapeOp::Sin(a) => {
let u = v[*a];
let cu = u.cos();
adj[*a] += w * cu;
adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
}
TapeOp::Cos(a) => {
let u = v[*a];
let su = u.sin();
adj[*a] -= w * su;
adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
}
TapeOp::Tan(a) => {
let u = v[*a];
let c = u.cos();
let sec2 = 1.0 / (c * c);
let t = u.tan();
adj[*a] += w * sec2;
adj_dot[*a] += wd * sec2 + w * 2.0 * t * sec2 * dot[*a];
}
TapeOp::Asin(a) => {
let u = v[*a];
let s = (1.0 - u * u).sqrt();
adj[*a] += w / s;
adj_dot[*a] += wd / s + w * (u / (s * s * s)) * dot[*a];
}
TapeOp::Acos(a) => {
let u = v[*a];
let s = (1.0 - u * u).sqrt();
adj[*a] -= w / s;
adj_dot[*a] += wd * (-1.0 / s) + w * (-u / (s * s * s)) * dot[*a];
}
TapeOp::Atan(a) => {
let u = v[*a];
let d = 1.0 + u * u;
adj[*a] += w / d;
adj_dot[*a] += wd / d + w * (-2.0 * u / (d * d)) * dot[*a];
}
TapeOp::Sinh(a) => {
let u = v[*a];
let ch = u.cosh();
adj[*a] += w * ch;
adj_dot[*a] += wd * ch + w * u.sinh() * dot[*a];
}
TapeOp::Cosh(a) => {
let u = v[*a];
let sh = u.sinh();
adj[*a] += w * sh;
adj_dot[*a] += wd * sh + w * u.cosh() * dot[*a];
}
TapeOp::Tanh(a) => {
let tv = v[i]; let sech2 = 1.0 - tv * tv;
adj[*a] += w * sech2;
adj_dot[*a] += wd * sech2 + w * (-2.0 * tv * sech2) * dot[*a];
}
TapeOp::Asinh(a) => {
let u = v[*a];
let s = (u * u + 1.0).sqrt();
adj[*a] += w / s;
adj_dot[*a] += wd / s + w * (-u / (s * s * s)) * dot[*a];
}
TapeOp::Acosh(a) => {
let u = v[*a];
let s = (u * u - 1.0).sqrt();
adj[*a] += w / s;
adj_dot[*a] += wd / s + w * (-u / (s * s * s)) * dot[*a];
}
TapeOp::Atanh(a) => {
let u = v[*a];
let d = 1.0 - u * u;
adj[*a] += w / d;
adj_dot[*a] += wd / d + w * (2.0 * u / (d * d)) * dot[*a];
}
}
}
}
}
pub fn hessian_sparsity(&self) -> std::collections::BTreeSet<(usize, usize)> {
use std::collections::BTreeSet;
let n = self.ops.len();
let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
let mut hess_pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
let emit_cross = |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
for &v1 in s1 {
for &v2 in s2 {
let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
pairs.insert((r, c));
}
}
};
let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
let vars: Vec<usize> = s.iter().copied().collect();
for (ai, &vi) in vars.iter().enumerate() {
for &vj in &vars[..=ai] {
let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
pairs.insert((r, c));
}
}
};
for op in &self.ops {
let vset = match op {
TapeOp::Const(_) => BTreeSet::new(),
TapeOp::Var(j) => {
let mut s = BTreeSet::new();
s.insert(*j);
s
}
TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
var_sets[*a].union(&var_sets[*b]).copied().collect()
}
TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
TapeOp::Floor(_) | TapeOp::Ceil(_) => BTreeSet::new(),
TapeOp::Mul(a, b) => {
emit_cross(&var_sets[*a], &var_sets[*b], &mut hess_pairs);
var_sets[*a].union(&var_sets[*b]).copied().collect()
}
TapeOp::Div(a, b) => {
emit_cross(&var_sets[*a], &var_sets[*b], &mut hess_pairs);
emit_self(&var_sets[*b], &mut hess_pairs);
var_sets[*a].union(&var_sets[*b]).copied().collect()
}
TapeOp::Pow(a, b) | TapeOp::Atan2(a, b) => {
let combined: BTreeSet<usize> = var_sets[*a].union(&var_sets[*b]).copied().collect();
emit_self(&combined, &mut hess_pairs);
combined
}
TapeOp::Mod(a, b) | TapeOp::IntDiv(a, b) | TapeOp::Less(a, b) => {
var_sets[*a].union(&var_sets[*b]).copied().collect()
}
TapeOp::Sqrt(a) | TapeOp::Exp(a) | TapeOp::Log(a) | TapeOp::Log10(a)
| TapeOp::Sin(a) | TapeOp::Cos(a) | TapeOp::Tan(a)
| TapeOp::Asin(a) | TapeOp::Acos(a) | TapeOp::Atan(a)
| TapeOp::Sinh(a) | TapeOp::Cosh(a) | TapeOp::Tanh(a)
| TapeOp::Asinh(a) | TapeOp::Acosh(a) | TapeOp::Atanh(a) => {
emit_self(&var_sets[*a], &mut hess_pairs);
var_sets[*a].clone()
}
};
var_sets.push(vset);
}
hess_pairs
}
}
/// Records `expr` onto `ops` in evaluation order and returns the index of its result op.
///
/// Invariant: the returned index is always the LAST op pushed by this call. A tape built
/// from one top-level call therefore ends with its output, which is what `Tape::eval`,
/// `Tape::reverse` and `Tape::hessian_accumulate` read.
///
/// - `Var(i)` with `i < n_vars` records a variable.
/// - `Var(i)` with `i >= n_vars` re-expands `common_exprs[i - n_vars]` in place. A common
///   expression that references itself would recurse without bound. Indices past
///   `common_exprs` record the constant `0.0`.
/// - `Nary` folds left: `Sum` via `Add`, `Min` via `Less` (min), and `Max` as
///   `-min(-acc, -next)`. With no arguments it records the reduction's identity.
/// - `If` records all three sub-expressions but always yields the `then` branch; the tape
///   has no select op, so the condition is not evaluated.
///   NOTE(review): confirm this approximation is intended.
/// - `StringLiteral` records `0.0`.
fn build_recursive(
    expr: &ExprNode,
    common_exprs: &[ExprNode],
    n_vars: usize,
    ops: &mut Vec<TapeOp>,
) -> usize {
    // Appends `op` and returns its index.
    fn push(ops: &mut Vec<TapeOp>, op: TapeOp) -> usize {
        ops.push(op);
        ops.len() - 1
    }

    match expr {
        ExprNode::Const(c) => push(ops, TapeOp::Const(*c)),
        ExprNode::Var(i) if *i < n_vars => push(ops, TapeOp::Var(*i)),
        ExprNode::Var(i) => match common_exprs.get(*i - n_vars) {
            // Inline the common expression; its result is the last op it pushes.
            Some(common) => build_recursive(common, common_exprs, n_vars, ops),
            // Dangling reference: fall back to a zero constant.
            None => push(ops, TapeOp::Const(0.0)),
        },
        ExprNode::Binary(op, left, right) => {
            let l = build_recursive(left, common_exprs, n_vars, ops);
            let r = build_recursive(right, common_exprs, n_vars, ops);
            push(
                ops,
                match op {
                    BinaryOp::Add => TapeOp::Add(l, r),
                    BinaryOp::Sub => TapeOp::Sub(l, r),
                    BinaryOp::Mul => TapeOp::Mul(l, r),
                    BinaryOp::Div => TapeOp::Div(l, r),
                    BinaryOp::Mod => TapeOp::Mod(l, r),
                    BinaryOp::Pow => TapeOp::Pow(l, r),
                    BinaryOp::Atan2 => TapeOp::Atan2(l, r),
                    BinaryOp::Less => TapeOp::Less(l, r),
                    BinaryOp::IntDiv => TapeOp::IntDiv(l, r),
                },
            )
        }
        ExprNode::Unary(op, arg) => {
            let a = build_recursive(arg, common_exprs, n_vars, ops);
            push(
                ops,
                match op {
                    UnaryOp::Abs => TapeOp::Abs(a),
                    UnaryOp::Neg => TapeOp::Neg(a),
                    UnaryOp::Floor => TapeOp::Floor(a),
                    UnaryOp::Ceil => TapeOp::Ceil(a),
                    UnaryOp::Tanh => TapeOp::Tanh(a),
                    UnaryOp::Tan => TapeOp::Tan(a),
                    UnaryOp::Sqrt => TapeOp::Sqrt(a),
                    UnaryOp::Sinh => TapeOp::Sinh(a),
                    UnaryOp::Sin => TapeOp::Sin(a),
                    UnaryOp::Log10 => TapeOp::Log10(a),
                    UnaryOp::Log => TapeOp::Log(a),
                    UnaryOp::Exp => TapeOp::Exp(a),
                    UnaryOp::Cosh => TapeOp::Cosh(a),
                    UnaryOp::Cos => TapeOp::Cos(a),
                    UnaryOp::Atanh => TapeOp::Atanh(a),
                    UnaryOp::Atan => TapeOp::Atan(a),
                    UnaryOp::Asinh => TapeOp::Asinh(a),
                    UnaryOp::Asin => TapeOp::Asin(a),
                    UnaryOp::Acosh => TapeOp::Acosh(a),
                    UnaryOp::Acos => TapeOp::Acos(a),
                },
            )
        }
        ExprNode::Nary(op, args) => match args.split_first() {
            // Empty reduction: record its identity element.
            None => push(
                ops,
                TapeOp::Const(match op {
                    NaryOp::Sum => 0.0,
                    NaryOp::Min => f64::INFINITY,
                    NaryOp::Max => f64::NEG_INFINITY,
                }),
            ),
            Some((first, rest)) => {
                let mut acc = build_recursive(first, common_exprs, n_vars, ops);
                for arg in rest {
                    let next = build_recursive(arg, common_exprs, n_vars, ops);
                    acc = match op {
                        NaryOp::Sum => push(ops, TapeOp::Add(acc, next)),
                        NaryOp::Min => push(ops, TapeOp::Less(acc, next)),
                        // max(acc, next) = -min(-acc, -next)
                        NaryOp::Max => {
                            let neg_acc = push(ops, TapeOp::Neg(acc));
                            let neg_next = push(ops, TapeOp::Neg(next));
                            let min = push(ops, TapeOp::Less(neg_acc, neg_next));
                            push(ops, TapeOp::Neg(min))
                        }
                    };
                }
                acc
            }
        },
        ExprNode::If(cond, then_expr, else_expr) => {
            // The condition and `else` branch are still recorded, so their variables appear in
            // the tape, but the `then` branch is the result. Recording it LAST keeps the
            // result at the end of the tape.
            build_recursive(cond, common_exprs, n_vars, ops);
            build_recursive(else_expr, common_exprs, n_vars, ops);
            build_recursive(then_expr, common_exprs, n_vars, ops)
        }
        ExprNode::StringLiteral(_) => push(ops, TapeOp::Const(0.0)),
    }
}
fn remap_op(op: &TapeOp, offset: usize) -> TapeOp {
match op {
TapeOp::Const(c) => TapeOp::Const(*c),
TapeOp::Var(i) => TapeOp::Var(*i), TapeOp::Add(a, b) => TapeOp::Add(a + offset, b + offset),
TapeOp::Sub(a, b) => TapeOp::Sub(a + offset, b + offset),
TapeOp::Mul(a, b) => TapeOp::Mul(a + offset, b + offset),
TapeOp::Div(a, b) => TapeOp::Div(a + offset, b + offset),
TapeOp::Pow(a, b) => TapeOp::Pow(a + offset, b + offset),
TapeOp::Mod(a, b) => TapeOp::Mod(a + offset, b + offset),
TapeOp::Atan2(a, b) => TapeOp::Atan2(a + offset, b + offset),
TapeOp::Less(a, b) => TapeOp::Less(a + offset, b + offset),
TapeOp::IntDiv(a, b) => TapeOp::IntDiv(a + offset, b + offset),
TapeOp::Neg(a) => TapeOp::Neg(a + offset),
TapeOp::Abs(a) => TapeOp::Abs(a + offset),
TapeOp::Floor(a) => TapeOp::Floor(a + offset),
TapeOp::Ceil(a) => TapeOp::Ceil(a + offset),
TapeOp::Sqrt(a) => TapeOp::Sqrt(a + offset),
TapeOp::Exp(a) => TapeOp::Exp(a + offset),
TapeOp::Log(a) => TapeOp::Log(a + offset),
TapeOp::Log10(a) => TapeOp::Log10(a + offset),
TapeOp::Sin(a) => TapeOp::Sin(a + offset),
TapeOp::Cos(a) => TapeOp::Cos(a + offset),
TapeOp::Tan(a) => TapeOp::Tan(a + offset),
TapeOp::Asin(a) => TapeOp::Asin(a + offset),
TapeOp::Acos(a) => TapeOp::Acos(a + offset),
TapeOp::Atan(a) => TapeOp::Atan(a + offset),
TapeOp::Sinh(a) => TapeOp::Sinh(a + offset),
TapeOp::Cosh(a) => TapeOp::Cosh(a + offset),
TapeOp::Tanh(a) => TapeOp::Tanh(a + offset),
TapeOp::Asinh(a) => TapeOp::Asinh(a + offset),
TapeOp::Acosh(a) => TapeOp::Acosh(a + offset),
TapeOp::Atanh(a) => TapeOp::Atanh(a + offset),
}
}
/// Records `expr` onto `ops` like `build_recursive`, but splices common expressions from
/// the prebuilt fragments in `cache` instead of re-expanding them. Returns the index of the
/// result op.
///
/// Invariant: the returned index is the LAST op pushed by this call, provided every cache
/// fragment's result is its own last op. Fragments produced by this function satisfy that,
/// so tapes built through `CommonExprCache` end with their output.
///
/// - `Var(i)` with `i >= n_vars` looks up `cache[i - n_vars]`. A missing index or a `None`
///   entry records the constant `0.0`.
/// - `If` records all three sub-expressions but always yields the `then` branch (recorded
///   last). NOTE(review): the condition is not evaluated — confirm this is intended.
/// - `StringLiteral` records `0.0`.
fn build_recursive_cached(
    expr: &ExprNode,
    common_exprs: &[ExprNode],
    n_vars: usize,
    ops: &mut Vec<TapeOp>,
    cache: &[Option<(Vec<TapeOp>, usize)>],
) -> usize {
    // Appends `op` and returns its index.
    fn push(ops: &mut Vec<TapeOp>, op: TapeOp) -> usize {
        ops.push(op);
        ops.len() - 1
    }

    match expr {
        ExprNode::Const(c) => push(ops, TapeOp::Const(*c)),
        ExprNode::Var(i) if *i < n_vars => push(ops, TapeOp::Var(*i)),
        ExprNode::Var(i) => match cache.get(*i - n_vars) {
            Some(Some((fragment, fragment_result))) => {
                // Splice the fragment, shifting its operand indices past the ops already here.
                let offset = ops.len();
                ops.extend(fragment.iter().map(|op| remap_op(op, offset)));
                offset + *fragment_result
            }
            // Unknown or not-yet-built common expression: record a zero constant.
            _ => push(ops, TapeOp::Const(0.0)),
        },
        ExprNode::Binary(op, left, right) => {
            let l = build_recursive_cached(left, common_exprs, n_vars, ops, cache);
            let r = build_recursive_cached(right, common_exprs, n_vars, ops, cache);
            push(
                ops,
                match op {
                    BinaryOp::Add => TapeOp::Add(l, r),
                    BinaryOp::Sub => TapeOp::Sub(l, r),
                    BinaryOp::Mul => TapeOp::Mul(l, r),
                    BinaryOp::Div => TapeOp::Div(l, r),
                    BinaryOp::Mod => TapeOp::Mod(l, r),
                    BinaryOp::Pow => TapeOp::Pow(l, r),
                    BinaryOp::Atan2 => TapeOp::Atan2(l, r),
                    BinaryOp::Less => TapeOp::Less(l, r),
                    BinaryOp::IntDiv => TapeOp::IntDiv(l, r),
                },
            )
        }
        ExprNode::Unary(op, arg) => {
            let a = build_recursive_cached(arg, common_exprs, n_vars, ops, cache);
            push(
                ops,
                match op {
                    UnaryOp::Abs => TapeOp::Abs(a),
                    UnaryOp::Neg => TapeOp::Neg(a),
                    UnaryOp::Floor => TapeOp::Floor(a),
                    UnaryOp::Ceil => TapeOp::Ceil(a),
                    UnaryOp::Tanh => TapeOp::Tanh(a),
                    UnaryOp::Tan => TapeOp::Tan(a),
                    UnaryOp::Sqrt => TapeOp::Sqrt(a),
                    UnaryOp::Sinh => TapeOp::Sinh(a),
                    UnaryOp::Sin => TapeOp::Sin(a),
                    UnaryOp::Log10 => TapeOp::Log10(a),
                    UnaryOp::Log => TapeOp::Log(a),
                    UnaryOp::Exp => TapeOp::Exp(a),
                    UnaryOp::Cosh => TapeOp::Cosh(a),
                    UnaryOp::Cos => TapeOp::Cos(a),
                    UnaryOp::Atanh => TapeOp::Atanh(a),
                    UnaryOp::Atan => TapeOp::Atan(a),
                    UnaryOp::Asinh => TapeOp::Asinh(a),
                    UnaryOp::Asin => TapeOp::Asin(a),
                    UnaryOp::Acosh => TapeOp::Acosh(a),
                    UnaryOp::Acos => TapeOp::Acos(a),
                },
            )
        }
        ExprNode::Nary(op, args) => match args.split_first() {
            // Empty reduction: record its identity element.
            None => push(
                ops,
                TapeOp::Const(match op {
                    NaryOp::Sum => 0.0,
                    NaryOp::Min => f64::INFINITY,
                    NaryOp::Max => f64::NEG_INFINITY,
                }),
            ),
            Some((first, rest)) => {
                let mut acc = build_recursive_cached(first, common_exprs, n_vars, ops, cache);
                for arg in rest {
                    let next = build_recursive_cached(arg, common_exprs, n_vars, ops, cache);
                    acc = match op {
                        NaryOp::Sum => push(ops, TapeOp::Add(acc, next)),
                        NaryOp::Min => push(ops, TapeOp::Less(acc, next)),
                        // max(acc, next) = -min(-acc, -next)
                        NaryOp::Max => {
                            let neg_acc = push(ops, TapeOp::Neg(acc));
                            let neg_next = push(ops, TapeOp::Neg(next));
                            let min = push(ops, TapeOp::Less(neg_acc, neg_next));
                            push(ops, TapeOp::Neg(min))
                        }
                    };
                }
                acc
            }
        },
        ExprNode::If(cond, then_expr, else_expr) => {
            // Record the `then` branch LAST so the result stays at the end of the tape.
            build_recursive_cached(cond, common_exprs, n_vars, ops, cache);
            build_recursive_cached(else_expr, common_exprs, n_vars, ops, cache);
            build_recursive_cached(then_expr, common_exprs, n_vars, ops, cache)
        }
        ExprNode::StringLiteral(_) => push(ops, TapeOp::Const(0.0)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::expr::*;

    // ---- Expression-tree shorthands: each test reads like the formula it checks. ----

    fn num(value: f64) -> ExprNode {
        ExprNode::Const(value)
    }

    fn var(index: usize) -> ExprNode {
        ExprNode::Var(index)
    }

    fn bin(op: BinaryOp, lhs: ExprNode, rhs: ExprNode) -> ExprNode {
        ExprNode::Binary(op, Box::new(lhs), Box::new(rhs))
    }

    fn un(op: UnaryOp, arg: ExprNode) -> ExprNode {
        ExprNode::Unary(op, Box::new(arg))
    }

    fn sum(terms: Vec<ExprNode>) -> ExprNode {
        ExprNode::Nary(NaryOp::Sum, terms)
    }

    /// Asserts `actual` is within 1e-10 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    /// 3·x0² + 2·x1
    fn poly() -> ExprNode {
        bin(
            BinaryOp::Add,
            bin(BinaryOp::Mul, num(3.0), bin(BinaryOp::Pow, var(0), num(2.0))),
            bin(BinaryOp::Mul, num(2.0), var(1)),
        )
    }

    #[test]
    fn tape_build_and_eval_polynomial() {
        let tape = Tape::build(&poly(), &[], 2);
        assert_close(tape.eval(&[2.0, 3.0]), 18.0);
    }

    #[test]
    fn tape_gradient_polynomial() {
        let tape = Tape::build(&poly(), &[], 2);
        let mut grad = vec![0.0; 2];
        tape.gradient(&[2.0, 3.0], &mut grad);
        assert_close(grad[0], 12.0);
        assert_close(grad[1], 2.0);
    }

    #[test]
    fn tape_gradient_transcendental() {
        // (exp(x0) + sin(x1)) + (ln(x0) + sqrt(x1))
        let expr = bin(
            BinaryOp::Add,
            bin(BinaryOp::Add, un(UnaryOp::Exp, var(0)), un(UnaryOp::Sin, var(1))),
            bin(BinaryOp::Add, un(UnaryOp::Log, var(0)), un(UnaryOp::Sqrt, var(1))),
        );
        let tape = Tape::build(&expr, &[], 2);
        let point = [1.0, 1.0];
        assert_close(tape.eval(&point), 1.0_f64.exp() + 1.0_f64.sin() + 0.0 + 1.0);
        let mut grad = vec![0.0; 2];
        tape.gradient(&point, &mut grad);
        assert_close(grad[0], 1.0_f64.exp() + 1.0);
        assert_close(grad[1], 1.0_f64.cos() + 0.5);
    }

    #[test]
    fn tape_common_expr_inlining() {
        // Var(2) refers to common expression 0, i.e. x0 + x1.
        let common_exprs = vec![bin(BinaryOp::Add, var(0), var(1))];
        let expr = bin(BinaryOp::Pow, var(2), num(2.0));
        let tape = Tape::build(&expr, &common_exprs, 2);
        assert_close(tape.eval(&[3.0, 4.0]), 49.0);
        let mut grad = vec![0.0; 2];
        tape.gradient(&[3.0, 4.0], &mut grad);
        assert_close(grad[0], 14.0);
        assert_close(grad[1], 14.0);
    }

    #[test]
    fn tape_nary_max_min() {
        let max_expr = ExprNode::Nary(NaryOp::Max, vec![num(1.0), var(0), num(2.0)]);
        assert_close(Tape::build(&max_expr, &[], 1).eval(&[3.0]), 3.0);
        let min_expr = ExprNode::Nary(NaryOp::Min, vec![num(5.0), var(0), num(2.0)]);
        assert_close(Tape::build(&min_expr, &[], 1).eval(&[3.0]), 2.0);
    }

    /// Compares `Tape::hessian_accumulate` against central differences of `Tape::gradient`
    /// on the lower triangle of every variable pair the tape touches.
    fn check_hessian_vs_fd(tape: &Tape, x: &[f64], tol: f64) {
        use std::collections::HashMap;
        let vars = tape.variables();

        // One slot per lower-triangle pair, numbered in first-seen order.
        let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
        for (ai, &vi) in vars.iter().enumerate() {
            for &vj in &vars[..=ai] {
                let key = if vi >= vj { (vi, vj) } else { (vj, vi) };
                let next_slot = hess_map.len();
                hess_map.entry(key).or_insert(next_slot);
            }
        }
        let nnz = hess_map.len();

        let mut vals_ad = vec![0.0; nnz];
        tape.hessian_accumulate(x, 1.0, &hess_map, &mut vals_ad);

        // Central differences of the gradient, one column per perturbed variable.
        let mut vals_fd = vec![0.0; nnz];
        let mut x_pert = x.to_vec();
        let mut grad_plus = vec![0.0; x.len()];
        let mut grad_minus = vec![0.0; x.len()];
        for &j in &vars {
            let h = (1e-7_f64).max(x[j].abs() * 1e-7);

            x_pert[j] = x[j] + h;
            for g in grad_plus.iter_mut() {
                *g = 0.0;
            }
            tape.gradient(&x_pert, &mut grad_plus);

            x_pert[j] = x[j] - h;
            for g in grad_minus.iter_mut() {
                *g = 0.0;
            }
            tape.gradient(&x_pert, &mut grad_minus);

            x_pert[j] = x[j];
            for &i in vars.iter().filter(|&&i| i >= j) {
                if let Some(&pos) = hess_map.get(&(i, j)) {
                    vals_fd[pos] = (grad_plus[i] - grad_minus[i]) / (2.0 * h);
                }
            }
        }

        for (&(r, c), &pos) in &hess_map {
            let ad = vals_ad[pos];
            let fd = vals_fd[pos];
            let err = (ad - fd).abs();
            let scale = fd.abs().max(1.0);
            assert!(
                err / scale < tol,
                "Hessian mismatch at ({},{}): AD={:.10e}, FD={:.10e}, err={:.2e}",
                r, c, ad, fd, err
            );
        }
    }

    #[test]
    fn hessian_quadratic() {
        // 3·x0² + 2·(x0·x1) + x1²
        let expr = sum(vec![
            bin(BinaryOp::Mul, num(3.0), bin(BinaryOp::Pow, var(0), num(2.0))),
            bin(BinaryOp::Mul, num(2.0), bin(BinaryOp::Mul, var(0), var(1))),
            bin(BinaryOp::Pow, var(1), num(2.0)),
        ]);
        check_hessian_vs_fd(&Tape::build(&expr, &[], 2), &[2.0, 3.0], 1e-5);
    }

    #[test]
    fn hessian_transcendental() {
        let expr = sum(vec![
            un(UnaryOp::Exp, var(0)),
            un(UnaryOp::Sin, var(1)),
            un(UnaryOp::Log, var(0)),
            un(UnaryOp::Sqrt, var(1)),
            bin(BinaryOp::Mul, var(0), var(1)),
        ]);
        check_hessian_vs_fd(&Tape::build(&expr, &[], 2), &[1.5, 2.0], 1e-5);
    }

    #[test]
    fn hessian_division_and_trig() {
        let expr = sum(vec![
            bin(BinaryOp::Div, var(0), var(1)),
            un(UnaryOp::Cos, var(0)),
            un(UnaryOp::Tan, var(1)),
            un(UnaryOp::Atan, var(0)),
        ]);
        check_hessian_vs_fd(&Tape::build(&expr, &[], 2), &[0.5, 1.2], 1e-5);
    }

    #[test]
    fn hessian_hyperbolic() {
        let expr = sum(vec![
            un(UnaryOp::Sinh, var(0)),
            un(UnaryOp::Cosh, var(0)),
            un(UnaryOp::Tanh, var(0)),
            un(UnaryOp::Asinh, var(0)),
            un(UnaryOp::Acosh, bin(BinaryOp::Add, var(0), num(2.0))),
            un(UnaryOp::Atanh, bin(BinaryOp::Div, var(0), num(2.0))),
        ]);
        check_hessian_vs_fd(&Tape::build(&expr, &[], 1), &[0.5], 1e-5);
    }

    #[test]
    fn hessian_rosenbrock() {
        // (1 - x0)² + 100·(x1 - x0²)²
        let square = |e: ExprNode| bin(BinaryOp::Pow, e, num(2.0));
        let expr = sum(vec![
            square(bin(BinaryOp::Sub, num(1.0), var(0))),
            bin(
                BinaryOp::Mul,
                num(100.0),
                square(bin(BinaryOp::Sub, var(1), square(var(0)))),
            ),
        ]);
        let tape = Tape::build(&expr, &[], 2);
        check_hessian_vs_fd(&tape, &[1.0, 1.0], 1e-5);
        check_hessian_vs_fd(&tape, &[-1.5, 2.3], 1e-5);
    }

    #[test]
    fn hessian_sparsity_separable() {
        // sin(x0) + x1·x2
        let expr = sum(vec![un(UnaryOp::Sin, var(0)), bin(BinaryOp::Mul, var(1), var(2))]);
        let sparsity = Tape::build(&expr, &[], 3).hessian_sparsity();
        assert!(sparsity.contains(&(0, 0)), "should have (0,0) from sin(x0)");
        assert!(sparsity.contains(&(2, 1)), "should have (2,1) from x1*x2");
        assert!(!sparsity.contains(&(1, 0)), "should NOT have (1,0) - separable");
        assert!(!sparsity.contains(&(2, 0)), "should NOT have (2,0) - separable");
    }

    #[test]
    fn hessian_sparsity_matches_numerical() {
        // exp(x0·x1) + x2²
        let expr = sum(vec![
            un(UnaryOp::Exp, bin(BinaryOp::Mul, var(0), var(1))),
            bin(BinaryOp::Pow, var(2), num(2.0)),
        ]);
        let sparsity = Tape::build(&expr, &[], 3).hessian_sparsity();
        assert!(sparsity.contains(&(0, 0)));
        assert!(sparsity.contains(&(1, 0)));
        assert!(sparsity.contains(&(1, 1)));
        assert!(sparsity.contains(&(2, 2)));
        assert!(!sparsity.contains(&(2, 0)));
        assert!(!sparsity.contains(&(2, 1)));
        assert_eq!(sparsity.len(), 4);
    }
}