use std::collections::{BTreeSet, HashMap, HashSet};
use std::rc::Rc;
use std::sync::Arc;
use super::nl_external::{ExternalArg, ExternalLibrary, ExternalResolver};
use super::nl_reader::{BinOp, Expr, FuncallArg, UnaryOp};
/// One instruction of a flattened expression tape.
///
/// Every `usize` operand of an arithmetic/unary variant is the index of an earlier
/// slot on the same tape (evaluators index into the slot values computed so far).
/// The exception is `Var(j)`, whose operand is an index into the input vector `x`.
#[derive(Debug, Clone)]
pub enum TapeOp {
/// Literal constant.
Const(f64),
/// Reads `x[j]`; reverse sweeps accumulate gradient contributions at index `j`.
Var(usize),
Add(usize, usize),
Sub(usize, usize),
Mul(usize, usize),
Div(usize, usize),
/// `lhs.powf(rhs)`.
Pow(usize, usize),
Neg(usize),
Abs(usize),
Sqrt(usize),
Exp(usize),
/// Natural logarithm.
Log(usize),
Log10(usize),
Sin(usize),
Cos(usize),
/// Call into an external (AMPL) function library through `ExternalLibrary::eval`.
/// Differentiation sweeps re-invoke it requesting first and/or second derivatives.
Funcall {
lib: Arc<ExternalLibrary>,
name: String,
args: Vec<TapeFuncallArg>,
},
}
/// Argument of a [`TapeOp::Funcall`].
#[derive(Debug, Clone)]
pub enum TapeFuncallArg {
/// Real-valued argument read from the given tape slot; derivative sweeps
/// propagate through these arguments.
Tape(usize),
/// String literal passed through verbatim; it carries no derivative.
Str(String),
}
/// Converts tape-level funcall arguments into the borrowed form expected by
/// `ExternalLibrary::eval`.
///
/// Slot arguments are replaced by their current value in `vals`. String arguments
/// are borrowed from `args`, so the result lives as long as `args`, not `vals`.
fn funcall_to_ext_args<'a>(args: &'a [TapeFuncallArg], vals: &[f64]) -> Vec<ExternalArg<'a>> {
    let mut converted = Vec::with_capacity(args.len());
    for arg in args {
        let ext = match arg {
            TapeFuncallArg::Str(text) => ExternalArg::Str(text.as_str()),
            TapeFuncallArg::Tape(slot) => ExternalArg::Real(vals[*slot]),
        };
        converted.push(ext);
    }
    converted
}
/// A linear tape of [`TapeOp`]s for a single expression.
///
/// Ops are evaluated in order. The value of the final op is the expression's value
/// (`eval` returns it), and reverse sweeps seed the final slot.
#[derive(Debug, Clone)]
pub struct Tape {
/// Instructions in evaluation order.
pub ops: Vec<TapeOp>,
}
impl Tape {
/// Flattens `expr` into a tape using a default-constructed [`ExternalResolver`].
///
/// NOTE(review): presumably the default resolver knows no external functions, so any
/// funcall would hit the unresolved-id panic in `build_recursive`. TODO confirm.
pub fn build(expr: &Expr) -> Self {
    let default_resolver = ExternalResolver::default();
    Self::build_with_externals(expr, &default_resolver)
}
/// Flattens `expr` into a tape, resolving external funcall ids through `resolver`.
///
/// Shared (CSE) subexpressions are emitted only once per tape.
pub fn build_with_externals(expr: &Expr, resolver: &ExternalResolver) -> Self {
    let mut tape = Tape { ops: Vec::new() };
    let mut cse_slots = HashMap::new();
    build_recursive(expr, &mut tape.ops, &mut cse_slots, resolver);
    tape
}
/// Evaluates every slot of the tape at the point `x`.
///
/// Returns all slot values in tape order; the last entry is the expression's value.
///
/// # Panics
/// If an external function evaluation fails, or if `x` is too short for a
/// referenced variable index.
pub fn forward(&self, x: &[f64]) -> Vec<f64> {
    let mut slots: Vec<f64> = Vec::with_capacity(self.ops.len());
    for op in self.ops.iter() {
        let value = match *op {
            TapeOp::Const(c) => c,
            TapeOp::Var(j) => x[j],
            TapeOp::Add(l, r) => slots[l] + slots[r],
            TapeOp::Sub(l, r) => slots[l] - slots[r],
            TapeOp::Mul(l, r) => slots[l] * slots[r],
            TapeOp::Div(l, r) => slots[l] / slots[r],
            TapeOp::Pow(l, r) => slots[l].powf(slots[r]),
            TapeOp::Neg(u) => -slots[u],
            TapeOp::Abs(u) => slots[u].abs(),
            TapeOp::Sqrt(u) => slots[u].sqrt(),
            TapeOp::Exp(u) => slots[u].exp(),
            TapeOp::Log(u) => slots[u].ln(),
            TapeOp::Log10(u) => slots[u].log10(),
            TapeOp::Sin(u) => slots[u].sin(),
            TapeOp::Cos(u) => slots[u].cos(),
            TapeOp::Funcall {
                ref lib,
                ref name,
                ref args,
            } => {
                let call_args = funcall_to_ext_args(args, &slots);
                match lib.eval(name, &call_args, false, false) {
                    Ok(res) => res.value,
                    Err(e) => panic!("external function '{name}' forward eval failed: {e}"),
                }
            }
        };
        slots.push(value);
    }
    slots
}
/// Value of the tape's final slot at `x`, or `0.0` for an empty tape.
pub fn eval(&self, x: &[f64]) -> f64 {
    self.forward(x).last().copied().unwrap_or(0.0)
}
/// Adds `seed * ∇f(x)` into `grad`, where `f` is the value of the tape's final slot.
///
/// Does nothing when `seed == 0.0` or the tape is empty. `grad` must be indexable by
/// every variable index that appears on the tape.
pub fn gradient_seed(&self, x: &[f64], seed: f64, grad: &mut [f64]) {
    let has_work = seed != 0.0 && !self.ops.is_empty();
    if has_work {
        let primal = self.forward(x);
        self.reverse(&primal, seed, grad);
    }
}
/// Reverse (adjoint) sweep over the tape, given primal slot values `vals`.
///
/// Seeds the last slot with `seed` and walks the ops from last to first, propagating
/// adjoints to their operands. Each `Var(j)` slot's adjoint is added into `grad[j]`,
/// so repeated occurrences of a variable accumulate. Slots whose adjoint is exactly
/// `0.0` are skipped.
///
/// Assumes a non-empty tape, because `adj[n - 1]` is indexed unconditionally.
/// `gradient_seed` checks for emptiness before calling.
fn reverse(&self, vals: &[f64], seed: f64, grad: &mut [f64]) {
let n = self.ops.len();
let mut adj = vec![0.0f64; n];
adj[n - 1] = seed;
for i in (0..n).rev() {
let a = adj[i];
if a == 0.0 {
continue;
}
match &self.ops[i] {
TapeOp::Const(_) => {}
TapeOp::Var(j) => {
grad[*j] += a;
}
TapeOp::Add(l, r) => {
adj[*l] += a;
adj[*r] += a;
}
TapeOp::Sub(l, r) => {
adj[*l] += a;
adj[*r] -= a;
}
TapeOp::Mul(l, r) => {
adj[*l] += a * vals[*r];
adj[*r] += a * vals[*l];
}
TapeOp::Div(l, r) => {
let rv = vals[*r];
adj[*l] += a / rv;
adj[*r] -= a * vals[*l] / (rv * rv);
}
TapeOp::Pow(l, r) => {
let lv = vals[*l];
let rv = vals[*r];
// Base partial r * l^(r-1), skipped only when r == 0. Unlike
// `forward_tangent` there is no `lv != 0` guard, so at lv == 0 with
// r < 1 this contribution is non-finite.
// NOTE(review): confirm the asymmetry with the tangent sweeps is intended.
if rv != 0.0 {
adj[*l] += a * rv * lv.powf(rv - 1.0);
}
// Exponent partial l^r * ln(l) is only defined for a positive base.
if lv > 0.0 {
adj[*r] += a * vals[i] * lv.ln();
}
}
TapeOp::Neg(j) => {
adj[*j] -= a;
}
TapeOp::Abs(j) => {
// Subgradient convention: the slope at 0 is taken as +1.
if vals[*j] >= 0.0 {
adj[*j] += a;
} else {
adj[*j] -= a;
}
}
TapeOp::Sqrt(j) => {
// d sqrt(u) = 0.5 / sqrt(u). Skipped when the result is not positive,
// to avoid dividing by zero.
let sv = vals[i];
if sv > 0.0 {
adj[*j] += a * 0.5 / sv;
}
}
TapeOp::Exp(j) => {
adj[*j] += a * vals[i];
}
TapeOp::Log(j) => {
adj[*j] += a / vals[*j];
}
TapeOp::Log10(j) => {
adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
}
TapeOp::Sin(j) => {
adj[*j] += a * vals[*j].cos();
}
TapeOp::Cos(j) => {
adj[*j] -= a * vals[*j].sin();
}
TapeOp::Funcall { lib, name, args } => {
// Re-evaluates the external function, requesting first derivatives.
let call_args = funcall_to_ext_args(args, vals);
let res = lib.eval(name, &call_args, true, false).unwrap_or_else(|e| {
panic!("external function '{name}' reverse eval failed: {e}")
});
let derivs = res.derivs.expect("want_derivs=true returns derivs");
// `derivs` is indexed over the `Tape` (real) arguments only, in
// argument order; string arguments are skipped.
let mut k = 0usize;
for arg in args {
if let TapeFuncallArg::Tape(idx) = arg {
adj[*idx] += a * derivs[k];
k += 1;
}
}
}
}
}
}
/// Sorted, deduplicated indices of every variable read by the tape.
pub fn variables(&self) -> Vec<usize> {
    let unique: BTreeSet<usize> = self
        .ops
        .iter()
        .filter_map(|op| if let TapeOp::Var(j) = op { Some(*j) } else { None })
        .collect();
    unique.into_iter().collect()
}
/// Forward-mode tangent sweep along the unit direction of variable `seed_var`.
///
/// Given primal slot values `vals`, overwrites `dot[i]` with the derivative of slot
/// `i` with respect to `x[seed_var]`. `dot` must have exactly one entry per op; this
/// is checked only in debug builds.
fn forward_tangent(&self, vals: &[f64], seed_var: usize, dot: &mut [f64]) {
let n = self.ops.len();
debug_assert_eq!(dot.len(), n);
for i in 0..n {
dot[i] = match &self.ops[i] {
TapeOp::Const(_) => 0.0,
TapeOp::Var(k) => {
if *k == seed_var {
1.0
} else {
0.0
}
}
TapeOp::Add(a, b) => dot[*a] + dot[*b],
TapeOp::Sub(a, b) => dot[*a] - dot[*b],
TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
TapeOp::Div(a, b) => {
let vb = vals[*b];
(dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
}
TapeOp::Pow(a, b) => {
let u = vals[*a];
let r = vals[*b];
let du = dot[*a];
let dr = dot[*b];
let mut result = 0.0;
// The base term is dropped when r == 0 or u == 0, which keeps 0^(r-1)
// from producing inf. NOTE(review): for u == 0 and r == 1 the true
// partial is 1, and this drops it; `reverse` guards only r == 0. Confirm.
if r != 0.0 && u != 0.0 {
result += r * u.powf(r - 1.0) * du;
}
// The exponent term needs ln(u), so only a positive base contributes.
if u > 0.0 {
result += vals[i] * u.ln() * dr;
}
result
}
TapeOp::Neg(a) => -dot[*a],
TapeOp::Abs(a) => {
// Same +1 slope-at-zero convention as `reverse`.
if vals[*a] >= 0.0 {
dot[*a]
} else {
-dot[*a]
}
}
TapeOp::Sqrt(a) => {
let sv = vals[i];
if sv > 0.0 {
dot[*a] * 0.5 / sv
} else {
0.0
}
}
TapeOp::Exp(a) => dot[*a] * vals[i],
TapeOp::Log(a) => dot[*a] / vals[*a],
TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
TapeOp::Funcall { lib, name, args } => {
// Directional derivative = sum of partials over the real (`Tape`)
// arguments times their tangents.
let call_args = funcall_to_ext_args(args, vals);
let res = lib.eval(name, &call_args, true, false).unwrap_or_else(|e| {
panic!("external function '{name}' tangent eval failed: {e}")
});
let derivs = res.derivs.expect("want_derivs=true returns derivs");
let mut acc = 0.0;
let mut k = 0usize;
for arg in args {
if let TapeFuncallArg::Tape(idx) = arg {
acc += derivs[k] * dot[*idx];
k += 1;
}
}
acc
}
};
}
}
/// Allocation-free variant of `forward`: writes slot `i`'s value at `x` into `vals[i]`.
///
/// `vals` must have at least one entry per op; this is checked only in debug builds.
/// Entries beyond the tape length are left untouched.
///
/// # Panics
/// If an external function evaluation fails.
pub fn forward_into(&self, x: &[f64], vals: &mut [f64]) {
    debug_assert!(vals.len() >= self.ops.len());
    for (slot, op) in self.ops.iter().enumerate() {
        let value = match *op {
            TapeOp::Const(c) => c,
            TapeOp::Var(j) => x[j],
            TapeOp::Add(l, r) => vals[l] + vals[r],
            TapeOp::Sub(l, r) => vals[l] - vals[r],
            TapeOp::Mul(l, r) => vals[l] * vals[r],
            TapeOp::Div(l, r) => vals[l] / vals[r],
            TapeOp::Pow(l, r) => vals[l].powf(vals[r]),
            TapeOp::Neg(u) => -vals[u],
            TapeOp::Abs(u) => vals[u].abs(),
            TapeOp::Sqrt(u) => vals[u].sqrt(),
            TapeOp::Exp(u) => vals[u].exp(),
            TapeOp::Log(u) => vals[u].ln(),
            TapeOp::Log10(u) => vals[u].log10(),
            TapeOp::Sin(u) => vals[u].sin(),
            TapeOp::Cos(u) => vals[u].cos(),
            TapeOp::Funcall {
                ref lib,
                ref name,
                ref args,
            } => {
                let call_args = funcall_to_ext_args(args, &*vals);
                match lib.eval(name, &call_args, false, false) {
                    Ok(res) => res.value,
                    Err(e) => panic!("external function '{name}' forward_into failed: {e}"),
                }
            }
        };
        vals[slot] = value;
    }
}
/// Accumulates a weighted Hessian-vector product of the tape output into `out`.
///
/// First runs a forward tangent sweep along the direction `seed` (indexed by
/// variable). Then runs a reverse sweep that carries both first-order adjoints (`adj`)
/// and their directional derivatives (`adj_dot`). At every `Var(k)` slot,
/// `weight * adj_dot` is added to `out[k]`, so on return `out += weight * (∇²f · seed)`.
///
/// * `vals`: primal slot values for this tape (e.g. from `forward_into`), at least
///   one entry per op.
/// * `dot`, `adj`, `adj_dot`: caller-provided scratch with at least one entry per op
///   each. They are overwritten.
/// * Returns immediately for an empty tape or `weight == 0.0`.
pub fn hessian_directional(
&self,
vals: &[f64],
seed: &[f64],
weight: f64,
out: &mut [f64],
dot: &mut [f64],
adj: &mut [f64],
adj_dot: &mut [f64],
) {
let n = self.ops.len();
if n == 0 || weight == 0.0 {
return;
}
debug_assert!(vals.len() >= n);
debug_assert!(dot.len() >= n);
debug_assert!(adj.len() >= n);
debug_assert!(adj_dot.len() >= n);
// Pass 1: tangent sweep. dot[i] = directional derivative of slot i along `seed`.
for i in 0..n {
dot[i] = match &self.ops[i] {
TapeOp::Const(_) => 0.0,
TapeOp::Var(k) => seed[*k],
TapeOp::Add(a, b) => dot[*a] + dot[*b],
TapeOp::Sub(a, b) => dot[*a] - dot[*b],
TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
TapeOp::Div(a, b) => {
let vb = vals[*b];
(dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
}
TapeOp::Pow(a, b) => {
let u = vals[*a];
let r = vals[*b];
let du = dot[*a];
let dr = dot[*b];
let mut result = 0.0;
if r != 0.0 && u != 0.0 {
result += r * u.powf(r - 1.0) * du;
}
if u > 0.0 {
result += vals[i] * u.ln() * dr;
}
result
}
TapeOp::Neg(a) => -dot[*a],
TapeOp::Abs(a) => {
if vals[*a] >= 0.0 {
dot[*a]
} else {
-dot[*a]
}
}
TapeOp::Sqrt(a) => {
let sv = vals[i];
if sv > 0.0 {
dot[*a] * 0.5 / sv
} else {
0.0
}
}
TapeOp::Exp(a) => vals[i] * dot[*a],
TapeOp::Log(a) => dot[*a] / vals[*a],
TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
TapeOp::Sin(a) => vals[*a].cos() * dot[*a],
TapeOp::Cos(a) => -vals[*a].sin() * dot[*a],
TapeOp::Funcall { lib, name, args } => {
let call_args = funcall_to_ext_args(args, vals);
let res = lib.eval(name, &call_args, true, false).unwrap_or_else(|e| {
panic!("external function '{name}' tangent eval failed: {e}")
});
let derivs = res.derivs.expect("want_derivs=true returns derivs");
let mut acc = 0.0;
let mut k = 0usize;
for arg in args {
if let TapeFuncallArg::Tape(idx) = arg {
acc += derivs[k] * dot[*idx];
k += 1;
}
}
acc
}
};
}
// Pass 2: reset the adjoint scratch for the tape's slots only.
for slot in adj.iter_mut().take(n) {
*slot = 0.0;
}
for slot in adj_dot.iter_mut().take(n) {
*slot = 0.0;
}
// Pass 3: second-order reverse sweep. The root adjoint is 1.0 (unweighted);
// `weight` is applied only when writing into `out`. For each op, adj_dot gets
// wd * f' + w * d(f')/dt, where d(f')/dt uses the tangents from pass 1.
adj[n - 1] = 1.0;
for i in (0..n).rev() {
let w = adj[i];
let wd = adj_dot[i];
if w == 0.0 && wd == 0.0 {
continue;
}
match &self.ops[i] {
TapeOp::Const(_) => {}
TapeOp::Var(k) => {
if wd != 0.0 {
out[*k] += weight * wd;
}
}
TapeOp::Add(a, b) => {
adj[*a] += w;
adj[*b] += w;
adj_dot[*a] += wd;
adj_dot[*b] += wd;
}
TapeOp::Sub(a, b) => {
adj[*a] += w;
adj[*b] -= w;
adj_dot[*a] += wd;
adj_dot[*b] -= wd;
}
TapeOp::Mul(a, b) => {
adj[*a] += w * vals[*b];
adj[*b] += w * vals[*a];
adj_dot[*a] += wd * vals[*b] + w * dot[*b];
adj_dot[*b] += wd * vals[*a] + w * dot[*a];
}
TapeOp::Div(a, b) => {
let vb = vals[*b];
let vb2 = vb * vb;
let vb3 = vb2 * vb;
adj[*a] += w / vb;
adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
adj[*b] += w * (-vals[*a] / vb2);
adj_dot[*b] += wd * (-vals[*a] / vb2)
+ w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
}
TapeOp::Pow(a, b) => {
let u = vals[*a];
let r = vals[*b];
let du = dot[*a];
let dr = dot[*b];
if r != 0.0 {
if u != 0.0 {
// p_a = r * u^(r-1); dp_a is its total derivative along the
// tangent (the ln(u) term only exists for u > 0).
let p_a = r * u.powf(r - 1.0);
adj[*a] += w * p_a;
let mut dp_a = dr * u.powf(r - 1.0);
if u > 0.0 {
dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
} else {
dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
}
adj_dot[*a] += wd * p_a + w * dp_a;
} else if r >= 2.0 {
// At u == 0 with r >= 2 the first partial is 0; its derivative is
// 2 * du for r == 2 and 0 for r > 2. Other r at u == 0 add nothing.
let p_a = 0.0;
adj[*a] += w * p_a;
let dp_a = if r == 2.0 {
2.0 * du
} else {
r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
};
adj_dot[*a] += wd * p_a + w * dp_a;
}
}
if u > 0.0 {
// p_b = u^r * ln(u); dur is the tangent of u^r.
let ln_u = u.ln();
let p_b = vals[i] * ln_u;
adj[*b] += w * p_b;
let dur = vals[i] * (r * du / u + dr * ln_u);
let dp_b = dur * ln_u + vals[i] * du / u;
adj_dot[*b] += wd * p_b + w * dp_b;
}
}
TapeOp::Neg(a) => {
adj[*a] -= w;
adj_dot[*a] -= wd;
}
TapeOp::Abs(a) => {
let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
adj[*a] += w * s;
adj_dot[*a] += wd * s;
}
TapeOp::Sqrt(a) => {
let sv = vals[i];
if sv > 0.0 {
// f' = 0.5 / sqrt(u), f'' = -0.25 / (u * sqrt(u)).
let fp = 0.5 / sv;
let fpp = -0.25 / (vals[*a] * sv);
adj[*a] += w * fp;
adj_dot[*a] += wd * fp + w * fpp * dot[*a];
}
}
TapeOp::Exp(a) => {
let ev = vals[i];
adj[*a] += w * ev;
adj_dot[*a] += wd * ev + w * ev * dot[*a];
}
TapeOp::Log(a) => {
let u = vals[*a];
adj[*a] += w / u;
adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
}
TapeOp::Log10(a) => {
let u = vals[*a];
let c = std::f64::consts::LN_10;
adj[*a] += w / (u * c);
adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
}
TapeOp::Sin(a) => {
let u = vals[*a];
let cu = u.cos();
adj[*a] += w * cu;
adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
}
TapeOp::Cos(a) => {
let u = vals[*a];
let su = u.sin();
adj[*a] -= w * su;
adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
}
TapeOp::Funcall { lib, name, args } => {
let call_args = funcall_to_ext_args(args, vals);
let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
panic!("external function '{name}' 2nd-order eval failed: {e}")
});
let derivs = res.derivs.expect("want_derivs=true returns derivs");
let hes = res.hessian.expect("want_hes=true returns hessian");
// Slots of the real (`Tape`) arguments, in argument order; both
// `derivs` and `hes` are indexed over these only.
let real_tape: Vec<usize> = args
.iter()
.filter_map(|a| match a {
TapeFuncallArg::Tape(t) => Some(*t),
TapeFuncallArg::Str(_) => None,
})
.collect();
for (k, &tk) in real_tape.iter().enumerate() {
adj[tk] += w * derivs[k];
let mut second_term = 0.0;
for (l, &tl) in real_tape.iter().enumerate() {
// `hes` is read as packed triangular storage:
// entry (lo, hi) with lo <= hi lives at lo + hi * (hi + 1) / 2.
let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
let h_kl = hes[lo + hi * (hi + 1) / 2];
second_term += h_kl * dot[tl];
}
adj_dot[tk] += wd * derivs[k] + w * second_term;
}
}
}
}
}
/// Adds `weight` times the lower triangle of the tape output's Hessian at `x` into
/// `values`.
///
/// For each variable `j` on the tape, runs a tangent sweep along `e_j` followed by a
/// second-order reverse sweep. Every `Var(k)` slot with `k >= j` then contributes
/// `weight * adj_dot` to `values[hess_map[&(k, j)]]`. Pairs absent from `hess_map`
/// are silently skipped. Returns immediately for an empty tape or `weight == 0.0`.
///
/// Allocates its own primal and scratch buffers on every call.
pub fn hessian_accumulate(
&self,
x: &[f64],
weight: f64,
hess_map: &HashMap<(usize, usize), usize>,
values: &mut [f64],
) {
let n = self.ops.len();
if n == 0 || weight == 0.0 {
return;
}
let v = self.forward(x);
let var_indices = self.variables();
let mut dot = vec![0.0f64; n];
let mut adj = vec![0.0f64; n];
let mut adj_dot = vec![0.0f64; n];
// One Hessian column per variable on the tape.
for &j in &var_indices {
self.forward_tangent(&v, j, &mut dot);
adj.fill(0.0);
adj_dot.fill(0.0);
adj[n - 1] = 1.0;
// Same second-order reverse rules as `hessian_directional`.
for i in (0..n).rev() {
let w = adj[i];
let wd = adj_dot[i];
if w == 0.0 && wd == 0.0 {
continue;
}
match &self.ops[i] {
TapeOp::Const(_) => {}
TapeOp::Var(k) => {
// Only the lower triangle (row k >= column j) is stored.
if wd != 0.0 && *k >= j {
if let Some(&pos) = hess_map.get(&(*k, j)) {
values[pos] += weight * wd;
}
}
}
TapeOp::Add(a, b) => {
adj[*a] += w;
adj[*b] += w;
adj_dot[*a] += wd;
adj_dot[*b] += wd;
}
TapeOp::Sub(a, b) => {
adj[*a] += w;
adj[*b] -= w;
adj_dot[*a] += wd;
adj_dot[*b] -= wd;
}
TapeOp::Mul(a, b) => {
adj[*a] += w * v[*b];
adj[*b] += w * v[*a];
adj_dot[*a] += wd * v[*b] + w * dot[*b];
adj_dot[*b] += wd * v[*a] + w * dot[*a];
}
TapeOp::Div(a, b) => {
let vb = v[*b];
let vb2 = vb * vb;
let vb3 = vb2 * vb;
adj[*a] += w / vb;
adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
adj[*b] += w * (-v[*a] / vb2);
adj_dot[*b] += wd * (-v[*a] / vb2)
+ w * (-dot[*a] / vb2 + 2.0 * v[*a] * dot[*b] / vb3);
}
TapeOp::Pow(a, b) => {
let u = v[*a];
let r = v[*b];
let du = dot[*a];
let dr = dot[*b];
if r != 0.0 {
if u != 0.0 {
let p_a = r * u.powf(r - 1.0);
adj[*a] += w * p_a;
let mut dp_a = dr * u.powf(r - 1.0);
if u > 0.0 {
dp_a +=
r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
} else {
dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
}
adj_dot[*a] += wd * p_a + w * dp_a;
} else if r >= 2.0 {
// u == 0: first partial is 0; its derivative is 2 * du for
// r == 2 and 0 for r > 2.
let p_a = 0.0;
adj[*a] += w * p_a;
let dp_a = if r == 2.0 {
2.0 * du
} else {
r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
};
adj_dot[*a] += wd * p_a + w * dp_a;
}
}
if u > 0.0 {
let ln_u = u.ln();
let p_b = v[i] * ln_u;
adj[*b] += w * p_b;
let dur = v[i] * (r * du / u + dr * ln_u);
let dp_b = dur * ln_u + v[i] * du / u;
adj_dot[*b] += wd * p_b + w * dp_b;
}
}
TapeOp::Neg(a) => {
adj[*a] -= w;
adj_dot[*a] -= wd;
}
TapeOp::Abs(a) => {
let s = if v[*a] >= 0.0 { 1.0 } else { -1.0 };
adj[*a] += w * s;
adj_dot[*a] += wd * s;
}
TapeOp::Sqrt(a) => {
let sv = v[i];
if sv > 0.0 {
let fp = 0.5 / sv;
let fpp = -0.25 / (v[*a] * sv);
adj[*a] += w * fp;
adj_dot[*a] += wd * fp + w * fpp * dot[*a];
}
}
TapeOp::Exp(a) => {
let ev = v[i];
adj[*a] += w * ev;
adj_dot[*a] += wd * ev + w * ev * dot[*a];
}
TapeOp::Log(a) => {
let u = v[*a];
adj[*a] += w / u;
adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
}
TapeOp::Log10(a) => {
let u = v[*a];
let c = std::f64::consts::LN_10;
adj[*a] += w / (u * c);
adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
}
TapeOp::Sin(a) => {
let u = v[*a];
let cu = u.cos();
adj[*a] += w * cu;
adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
}
TapeOp::Cos(a) => {
let u = v[*a];
let su = u.sin();
adj[*a] -= w * su;
adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
}
TapeOp::Funcall { lib, name, args } => {
let call_args = funcall_to_ext_args(args, &v);
let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
panic!("external function '{name}' 2nd-order eval failed: {e}")
});
let derivs = res.derivs.expect("want_derivs=true returns derivs");
let hes = res.hessian.expect("want_hes=true returns hessian");
let real_tape: Vec<usize> = args
.iter()
.filter_map(|a| match a {
TapeFuncallArg::Tape(t) => Some(*t),
TapeFuncallArg::Str(_) => None,
})
.collect();
for (k, &tk) in real_tape.iter().enumerate() {
adj[tk] += w * derivs[k];
let mut second_term = 0.0;
for (l, &tl) in real_tape.iter().enumerate() {
// Packed triangular index for (lo, hi), lo <= hi.
let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
let h_kl = hes[lo + hi * (hi + 1) / 2];
second_term += h_kl * dot[tl];
}
adj_dot[tk] += wd * derivs[k] + w * second_term;
}
}
}
}
}
}
/// Conservative lower-triangular sparsity pattern of the tape output's Hessian.
///
/// Every returned pair `(row, col)` satisfies `row >= col`. Variable dependency sets
/// are propagated forward slot by slot, and each nonlinear op records the pairs its
/// second derivatives can touch:
/// * `Mul`: every cross pair between its operands' variables.
/// * `Div`: cross pairs plus all pairs within the denominator's variables.
/// * `Pow`, nonlinear unary ops and external calls: all pairs within the combined set.
///
/// `Add`, `Sub`, `Neg` and `Abs` only propagate dependencies.
pub fn hessian_sparsity(&self) -> BTreeSet<(usize, usize)> {
    /// Records `(max, min)` for every pair drawn from `lhs` x `rhs`.
    fn cross(lhs: &BTreeSet<usize>, rhs: &BTreeSet<usize>, into: &mut BTreeSet<(usize, usize)>) {
        for &p in lhs {
            for &q in rhs {
                into.insert((p.max(q), p.min(q)));
            }
        }
    }
    /// Records every lower-triangular pair (diagonal included) within `set`.
    fn dense(set: &BTreeSet<usize>, into: &mut BTreeSet<(usize, usize)>) {
        for &hi in set {
            for &lo in set.range(..=hi) {
                into.insert((hi, lo));
            }
        }
    }
    fn union(lhs: &BTreeSet<usize>, rhs: &BTreeSet<usize>) -> BTreeSet<usize> {
        lhs | rhs
    }

    let mut deps: Vec<BTreeSet<usize>> = Vec::with_capacity(self.ops.len());
    let mut pattern: BTreeSet<(usize, usize)> = BTreeSet::new();
    for op in &self.ops {
        let here = match op {
            TapeOp::Const(_) => BTreeSet::new(),
            TapeOp::Var(j) => BTreeSet::from([*j]),
            TapeOp::Add(a, b) | TapeOp::Sub(a, b) => union(&deps[*a], &deps[*b]),
            TapeOp::Neg(a) | TapeOp::Abs(a) => deps[*a].clone(),
            TapeOp::Mul(a, b) => {
                cross(&deps[*a], &deps[*b], &mut pattern);
                union(&deps[*a], &deps[*b])
            }
            TapeOp::Div(a, b) => {
                cross(&deps[*a], &deps[*b], &mut pattern);
                dense(&deps[*b], &mut pattern);
                union(&deps[*a], &deps[*b])
            }
            TapeOp::Pow(a, b) => {
                let both = union(&deps[*a], &deps[*b]);
                dense(&both, &mut pattern);
                both
            }
            TapeOp::Sqrt(a)
            | TapeOp::Exp(a)
            | TapeOp::Log(a)
            | TapeOp::Log10(a)
            | TapeOp::Sin(a)
            | TapeOp::Cos(a) => {
                dense(&deps[*a], &mut pattern);
                deps[*a].clone()
            }
            TapeOp::Funcall { args, .. } => {
                let touched: BTreeSet<usize> = args
                    .iter()
                    .filter_map(|arg| match arg {
                        TapeFuncallArg::Tape(t) => Some(*t),
                        TapeFuncallArg::Str(_) => None,
                    })
                    .flat_map(|t| deps[t].iter().copied())
                    .collect();
                dense(&touched, &mut pattern);
                touched
            }
        };
        deps.push(here);
    }
    pattern
}
}
/// Appends the ops for `expr` to `ops` in post-order and returns the slot of its value.
///
/// * `cache` maps a CSE body's `Rc` pointer to the slot where it was first emitted,
///   so each shared subexpression is materialised once per tape.
/// * `Pow` with a constant exponent is first offered to [`try_emit_const_pow`] for
///   strength reduction.
/// * `Sum` is lowered to a left-leaning chain of `Add`s; an empty sum becomes `0.0`.
///
/// # Panics
/// If a funcall id is not present in `resolver.funcs_by_id`.
fn build_recursive(
    expr: &Expr,
    ops: &mut Vec<TapeOp>,
    cache: &mut HashMap<*const Expr, usize>,
    resolver: &ExternalResolver,
) -> usize {
    /// Pushes `op` and returns the slot it occupies.
    fn emit(ops: &mut Vec<TapeOp>, op: TapeOp) -> usize {
        ops.push(op);
        ops.len() - 1
    }

    match expr {
        Expr::Const(c) => emit(ops, TapeOp::Const(*c)),
        Expr::Var(i) => emit(ops, TapeOp::Var(*i)),
        Expr::Binary(op, a, b) => {
            if matches!(op, BinOp::Pow) {
                if let Some(exponent) = peek_const(b) {
                    if let Some(slot) = try_emit_const_pow(a, exponent, ops, cache, resolver) {
                        return slot;
                    }
                }
            }
            let lhs = build_recursive(a, ops, cache, resolver);
            let rhs = build_recursive(b, ops, cache, resolver);
            let node = match op {
                BinOp::Add => TapeOp::Add(lhs, rhs),
                BinOp::Sub => TapeOp::Sub(lhs, rhs),
                BinOp::Mul => TapeOp::Mul(lhs, rhs),
                BinOp::Div => TapeOp::Div(lhs, rhs),
                BinOp::Pow => TapeOp::Pow(lhs, rhs),
            };
            emit(ops, node)
        }
        Expr::Unary(op, a) => {
            let arg = build_recursive(a, ops, cache, resolver);
            let node = match op {
                UnaryOp::Neg => TapeOp::Neg(arg),
                UnaryOp::Sqrt => TapeOp::Sqrt(arg),
                UnaryOp::Log => TapeOp::Log(arg),
                UnaryOp::Log10 => TapeOp::Log10(arg),
                UnaryOp::Exp => TapeOp::Exp(arg),
                UnaryOp::Abs => TapeOp::Abs(arg),
                UnaryOp::Sin => TapeOp::Sin(arg),
                UnaryOp::Cos => TapeOp::Cos(arg),
            };
            emit(ops, node)
        }
        Expr::Sum(args) => match args.split_first() {
            None => emit(ops, TapeOp::Const(0.0)),
            Some((head, tail)) => {
                let mut total = build_recursive(head, ops, cache, resolver);
                for term in tail {
                    let next = build_recursive(term, ops, cache, resolver);
                    total = emit(ops, TapeOp::Add(total, next));
                }
                total
            }
        },
        Expr::Cse(body) => {
            let key = Rc::as_ptr(body) as *const Expr;
            if let Some(&slot) = cache.get(&key) {
                return slot;
            }
            let slot = build_recursive(body, ops, cache, resolver);
            cache.insert(key, slot);
            slot
        }
        Expr::Funcall { id, args } => {
            let (lib, name) = match resolver.funcs_by_id.get(id) {
                Some(entry) => entry,
                None => panic!("unresolved AMPL funcall id {id}"),
            };
            let mut tape_args = Vec::with_capacity(args.len());
            for arg in args {
                let converted = match arg {
                    FuncallArg::Real(e) => {
                        TapeFuncallArg::Tape(build_recursive(e, ops, cache, resolver))
                    }
                    FuncallArg::Str(s) => TapeFuncallArg::Str(s.clone()),
                };
                tape_args.push(converted);
            }
            emit(
                ops,
                TapeOp::Funcall {
                    lib: Arc::clone(lib),
                    name: name.clone(),
                    args: tape_args,
                },
            )
        }
    }
}
/// Returns the literal value of `e` if it is a constant, looking through any number
/// of CSE wrappers; `None` for every other expression.
fn peek_const(e: &Expr) -> Option<f64> {
    let mut cur = e;
    loop {
        match cur {
            Expr::Cse(body) => cur = &**body,
            Expr::Const(c) => return Some(*c),
            _ => return None,
        }
    }
}
/// Strength-reduces `base_expr ^ c` for a compile-time constant exponent `c`.
///
/// Returns `Some(slot)` of the emitted result when a cheaper equivalent exists:
/// * `c == 0`: the constant `1.0`. The base is not emitted at all.
/// * `c == 1`: the base itself.
/// * `c == 0.5`: `sqrt(base)`.
/// * integer `c` with `|c| <= 8`: repeated squaring via [`emit_int_pow`], followed
///   by the reciprocal `1 / base^|c|` when `c` is negative.
///
/// Returns `None` without emitting anything for every other exponent, so the caller
/// can fall back to a general `Pow` node.
fn try_emit_const_pow(
    base_expr: &Expr,
    c: f64,
    ops: &mut Vec<TapeOp>,
    cache: &mut HashMap<*const Expr, usize>,
    resolver: &ExternalResolver,
) -> Option<usize> {
    if c == 0.0 {
        let idx = ops.len();
        ops.push(TapeOp::Const(1.0));
        return Some(idx);
    }
    if c == 1.0 {
        return Some(build_recursive(base_expr, ops, cache, resolver));
    }
    if c == 0.5 {
        let b = build_recursive(base_expr, ops, cache, resolver);
        let idx = ops.len();
        ops.push(TapeOp::Sqrt(b));
        return Some(idx);
    }
    if !(c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0) {
        return None;
    }
    // Zero (including -0.0) returned above, so `c` is a non-zero integer in [-8, 8].
    // The cast is therefore exact and `n >= 1`, as `emit_int_pow` requires. The old
    // `n == 0` branch here was unreachable.
    let n = c.abs() as u32;
    let b = build_recursive(base_expr, ops, cache, resolver);
    let pos = emit_int_pow(b, n, ops);
    if c > 0.0 {
        return Some(pos);
    }
    // Negative exponent: 1 / base^|c|.
    let one_idx = ops.len();
    ops.push(TapeOp::Const(1.0));
    let idx = ops.len();
    ops.push(TapeOp::Div(one_idx, pos));
    Some(idx)
}
/// Emits `base^n` (requires `n >= 1`) using left-to-right binary exponentiation and
/// returns the slot of the result. `n == 1` emits nothing and returns `base`.
///
/// Each bit of `n` below the leading one costs one squaring `Mul`, plus one extra
/// `Mul` by `base` when that bit is set.
fn emit_int_pow(base: usize, n: u32, ops: &mut Vec<TapeOp>) -> usize {
    debug_assert!(n >= 1);
    let leading_bit = u32::BITS - 1 - n.leading_zeros();
    let mut acc = base;
    for bit in (0..leading_bit).rev() {
        let squared = ops.len();
        ops.push(TapeOp::Mul(acc, acc));
        acc = squared;
        if (n >> bit) & 1 == 1 {
            let with_base = ops.len();
            ops.push(TapeOp::Mul(acc, base));
            acc = with_base;
        }
    }
    acc
}
/// One slot of a summand's local tape.
#[derive(Debug, Clone)]
pub enum SummandOp {
/// An ordinary op whose operand indices refer to slots of the same local tape.
Local(TapeOp),
/// Reads slot `k` of the shared prelude tape. Used for CSEs that appear in two or
/// more summands.
Shared(usize),
}
/// One summand of a [`HybridTape`]: its local tape plus precomputed reachability.
#[derive(Debug, Clone)]
pub struct Summand {
/// Local ops in evaluation order.
pub ops: Vec<SummandOp>,
/// Local slot holding the summand's value.
pub root_slot: usize,
/// Sorted local slots reachable from `root_slot`.
pub local_reach: Vec<usize>,
/// Sorted prelude slots reachable through this summand's `Shared` references.
pub prelude_reach: Vec<usize>,
/// Sorted, deduplicated variable indices read by reachable local `Var` ops.
pub local_vars: Vec<usize>,
/// Variable indices reachable through the prelude (as computed by `vars_in`).
pub prelude_vars: Vec<usize>,
/// Sorted union of `local_vars` and `prelude_vars`.
pub all_vars: Vec<usize>,
}
/// A partially separable tape: a shared prelude of promoted common subexpressions
/// plus one local tape per summand.
#[derive(Debug)]
pub struct HybridTape {
/// Ops shared by several summands, in evaluation order.
pub prelude: Vec<TapeOp>,
/// One entry per expression passed to `build_multi`, in the same order.
pub summands: Vec<Summand>,
}
impl HybridTape {
/// Builds a hybrid tape with one summand per expression in `exprs`.
///
/// 1. Counts, for each CSE node, how many roots it appears in (at most once per root).
/// 2. Builds each root into its own local tape. CSEs appearing in two or more roots
///    are emitted once into the shared `prelude` and referenced via
///    `SummandOp::Shared`.
/// 3. Precomputes, for each summand, the local and prelude slots reachable from its
///    root and the variables those slots read.
///
/// External funcalls are not supported here; `build_into_summand` panics on them.
pub fn build_multi(exprs: &[Expr]) -> Self {
let mut cse_count: HashMap<*const Expr, usize> = HashMap::new();
for e in exprs {
// Fresh per root so each CSE is counted at most once per root.
let mut seen_in_root: HashSet<*const Expr> = HashSet::new();
count_cse_appearances(e, &mut seen_in_root, &mut cse_count);
}
let mut prelude: Vec<TapeOp> = Vec::new();
let mut prelude_map: HashMap<*const Expr, usize> = HashMap::new();
let mut summands: Vec<Summand> = Vec::with_capacity(exprs.len());
for e in exprs {
let mut local: Vec<SummandOp> = Vec::new();
let mut local_cache: HashMap<*const Expr, usize> = HashMap::new();
let root_slot = build_into_summand(
e,
&mut local,
&mut local_cache,
&mut prelude,
&mut prelude_map,
&cse_count,
);
summands.push(Summand {
ops: local,
root_slot,
local_reach: Vec::new(),
prelude_reach: Vec::new(),
local_vars: Vec::new(),
prelude_vars: Vec::new(),
all_vars: Vec::new(),
});
}
// Epoch-stamped visit marks let one buffer serve every summand's prelude walk
// without clearing it in between.
let mut p_visited: Vec<u32> = vec![0; prelude.len()];
let mut p_epoch: u32 = 0;
let mut p_stack: Vec<usize> = Vec::new();
for s in &mut summands {
let (local_reach, shared_refs) = compute_local_reach(&s.ops, s.root_slot);
s.local_reach = local_reach;
let mut lv: BTreeSet<usize> = BTreeSet::new();
for &i in &s.local_reach {
if let SummandOp::Local(TapeOp::Var(j)) = &s.ops[i] {
lv.insert(*j);
}
}
s.local_vars = lv.iter().copied().collect();
if !shared_refs.is_empty() {
p_epoch += 1;
let mut preach: Vec<usize> = Vec::new();
for &start in &shared_refs {
bfs_prelude(
&prelude,
start,
&mut p_visited,
p_epoch,
&mut p_stack,
&mut preach,
);
}
// Ascending slot order is a valid evaluation order for the prelude.
preach.sort_unstable();
s.prelude_vars = vars_in(&prelude, &preach);
s.prelude_reach = preach;
}
let mut av: BTreeSet<usize> = lv;
for &v in &s.prelude_vars {
av.insert(v);
}
s.all_vars = av.into_iter().collect();
}
HybridTape { prelude, summands }
}
/// Number of ops in the shared prelude.
pub fn n_prelude_ops(&self) -> usize {
    Vec::len(&self.prelude)
}
/// Number of summands (one per expression given to `build_multi`).
pub fn n_summands(&self) -> usize {
    Vec::len(&self.summands)
}
/// Length of the longest local tape, or `0` when there are no summands.
/// Useful for sizing per-summand scratch buffers.
pub fn max_summand_ops(&self) -> usize {
    self.summands
        .iter()
        .fold(0, |longest, summand| longest.max(summand.ops.len()))
}
/// Total number of local ops across all summands (prelude ops excluded).
pub fn total_local_ops(&self) -> usize {
    self.summands
        .iter()
        .fold(0, |total, summand| total + summand.ops.len())
}
/// Evaluates every prelude slot at `x` into `prelude_vals`.
///
/// `prelude_vals` must have exactly one entry per prelude op (checked only in debug).
pub fn forward_prelude(&self, x: &[f64], prelude_vals: &mut [f64]) {
    debug_assert_eq!(prelude_vals.len(), self.prelude.len());
    for (slot, op) in self.prelude.iter().enumerate() {
        let value = fwd_step(op, x, prelude_vals);
        prelude_vals[slot] = value;
    }
}
/// Evaluates every local slot of summand `s` at `x` into `local_vals`.
///
/// `Shared` slots copy their value from `prelude_vals`, which is expected to have
/// been filled by `forward_prelude` at the same `x`. `local_vals` needs at least one
/// entry per local op.
pub fn forward_summand(
    &self,
    s: &Summand,
    x: &[f64],
    prelude_vals: &[f64],
    local_vals: &mut [f64],
) {
    debug_assert!(local_vals.len() >= s.ops.len());
    for (slot, op) in s.ops.iter().enumerate() {
        let value = match op {
            SummandOp::Shared(k) => prelude_vals[*k],
            SummandOp::Local(inner) => fwd_step(inner, x, local_vals),
        };
        local_vals[slot] = value;
    }
}
#[inline]
pub fn root_value(&self, s: &Summand, local_vals: &[f64]) -> f64 {
local_vals[s.root_slot]
}
/// Adds `seed * ∇(value of summand s)` into `grad`.
///
/// Expects `prelude_vals` and `local_vals` to hold primal values at the same point
/// (e.g. from `forward_prelude` / `forward_summand`). `local_adj` and `prelude_adj`
/// are scratch buffers indexed by local and prelude slot; only the reachable entries
/// are reset and used. Adjoints flow through the local tape first. `Shared` slots
/// hand their adjoint to the prelude, which is then swept in reverse.
/// Does nothing when `seed == 0.0`.
#[allow(clippy::too_many_arguments)]
pub fn gradient_summand(
&self,
s: &Summand,
prelude_vals: &[f64],
local_vals: &[f64],
seed: f64,
grad: &mut [f64],
local_adj: &mut [f64],
prelude_adj: &mut [f64],
) {
if seed == 0.0 || s.local_reach.is_empty() {
return;
}
for &i in &s.local_reach {
local_adj[i] = 0.0;
}
for &i in &s.prelude_reach {
prelude_adj[i] = 0.0;
}
local_adj[s.root_slot] = seed;
// Reach lists are sorted ascending, so iterating them in reverse is a valid
// reverse topological order.
for &i in s.local_reach.iter().rev() {
let a = local_adj[i];
if a == 0.0 {
continue;
}
match &s.ops[i] {
SummandOp::Local(op) => rev_step(op, i, local_vals, local_adj, a, grad),
SummandOp::Shared(k) => {
prelude_adj[*k] += a;
}
}
}
for &i in s.prelude_reach.iter().rev() {
let a = prelude_adj[i];
if a == 0.0 {
continue;
}
rev_step(&self.prelude[i], i, prelude_vals, prelude_adj, a, grad);
}
}
/// Accumulates `weight`-scaled second-order information for summand `s` into
/// `values`, via `hess_map`.
///
/// For each variable `j` in `s.all_vars`, this:
/// 1. resets the reachable entries of all six scratch buffers,
/// 2. runs a tangent sweep along `e_j` (prelude first, then local; `Shared` slots
///    copy the prelude tangent),
/// 3. runs a second-order reverse sweep with `ror_step` (local first, then prelude),
///    seeded with a root adjoint of `1.0`.
///
/// Writing into `values` is delegated to `ror_step`, which receives `j`, `weight` and
/// `hess_map`. Presumably it records only `(k, j)` pairs with `k >= j`, as
/// `Tape::hessian_accumulate` does. TODO confirm.
/// Does nothing when `weight == 0.0`.
#[allow(clippy::too_many_arguments)]
pub fn hessian_summand(
&self,
s: &Summand,
prelude_vals: &[f64],
local_vals: &[f64],
weight: f64,
hess_map: &HashMap<(usize, usize), usize>,
values: &mut [f64],
local_dot: &mut [f64],
local_adj: &mut [f64],
local_adj_dot: &mut [f64],
prelude_dot: &mut [f64],
prelude_adj: &mut [f64],
prelude_adj_dot: &mut [f64],
) {
if weight == 0.0 || s.local_reach.is_empty() {
return;
}
for &j in &s.all_vars {
for &i in &s.local_reach {
local_dot[i] = 0.0;
local_adj[i] = 0.0;
local_adj_dot[i] = 0.0;
}
for &i in &s.prelude_reach {
prelude_dot[i] = 0.0;
prelude_adj[i] = 0.0;
prelude_adj_dot[i] = 0.0;
}
// Tangents must be ready in the prelude before local `Shared` slots read them.
for &i in &s.prelude_reach {
prelude_dot[i] = fwd_tan_step(&self.prelude[i], j, prelude_vals, prelude_dot, i);
}
for &i in &s.local_reach {
local_dot[i] = match &s.ops[i] {
SummandOp::Local(op) => fwd_tan_step(op, j, local_vals, local_dot, i),
SummandOp::Shared(k) => prelude_dot[*k],
};
}
local_adj[s.root_slot] = 1.0;
for &i in s.local_reach.iter().rev() {
let w = local_adj[i];
let wd = local_adj_dot[i];
if w == 0.0 && wd == 0.0 {
continue;
}
match &s.ops[i] {
SummandOp::Local(op) => {
ror_step(
op,
i,
j,
local_vals,
local_dot,
local_adj,
local_adj_dot,
w,
wd,
weight,
hess_map,
values,
);
}
SummandOp::Shared(k) => {
prelude_adj[*k] += w;
prelude_adj_dot[*k] += wd;
}
}
}
for &i in s.prelude_reach.iter().rev() {
let w = prelude_adj[i];
let wd = prelude_adj_dot[i];
if w == 0.0 && wd == 0.0 {
continue;
}
ror_step(
&self.prelude[i],
i,
j,
prelude_vals,
prelude_dot,
prelude_adj,
prelude_adj_dot,
w,
wd,
weight,
hess_map,
values,
);
}
}
}
/// Union of the Hessian sparsity pairs contributed by the prelude and by every summand.
pub fn hessian_sparsity_all(&self) -> BTreeSet<(usize, usize)> {
    let mut combined = hessian_sparsity_impl(&self.prelude);
    let shared_var_sets = compute_var_sets(&self.prelude);
    self.summands
        .iter()
        .for_each(|summand| summand_sparsity(&summand.ops, &shared_var_sets, &mut combined));
    combined
}
}
/// Counts, for each CSE node (keyed by its `Rc` pointer), the number of roots it
/// appears in.
///
/// `seen_in_root` must be fresh for each root. It ensures a CSE is counted at most
/// once per root and that its body is traversed only on first sighting.
fn count_cse_appearances(
    e: &Expr,
    seen_in_root: &mut HashSet<*const Expr>,
    counts: &mut HashMap<*const Expr, usize>,
) {
    match e {
        Expr::Const(_) | Expr::Var(_) => {}
        Expr::Binary(_, lhs, rhs) => {
            for child in [lhs, rhs] {
                count_cse_appearances(child, seen_in_root, counts);
            }
        }
        Expr::Unary(_, inner) => count_cse_appearances(inner, seen_in_root, counts),
        Expr::Sum(terms) => {
            for term in terms {
                count_cse_appearances(term, seen_in_root, counts);
            }
        }
        Expr::Cse(body) => {
            let key = Rc::as_ptr(body) as *const Expr;
            if !seen_in_root.insert(key) {
                return;
            }
            *counts.entry(key).or_insert(0) += 1;
            count_cse_appearances(body, seen_in_root, counts);
        }
        Expr::Funcall { args, .. } => {
            args.iter()
                .filter_map(|arg| match arg {
                    FuncallArg::Real(inner) => Some(inner),
                    FuncallArg::Str(_) => None,
                })
                .for_each(|inner| count_cse_appearances(inner, seen_in_root, counts));
        }
    }
}
/// Appends the ops for `expr` to a summand's local tape and returns the slot holding
/// its value.
///
/// * `local_cache` maps a CSE pointer to a local slot, so a CSE is materialised once
///   per summand.
/// * `cse_count` holds the number of roots each CSE appears in. CSEs with a count of
///   2 or more are built into `prelude` (memoised in `prelude_map`) and referenced
///   through a `SummandOp::Shared` slot.
///
/// # Panics
/// On an `Expr::Funcall` reached in the local (non-promoted) part of the expression.
fn build_into_summand(
expr: &Expr,
local: &mut Vec<SummandOp>,
local_cache: &mut HashMap<*const Expr, usize>,
prelude: &mut Vec<TapeOp>,
prelude_map: &mut HashMap<*const Expr, usize>,
cse_count: &HashMap<*const Expr, usize>,
) -> usize {
match expr {
Expr::Const(c) => {
let i = local.len();
local.push(SummandOp::Local(TapeOp::Const(*c)));
i
}
Expr::Var(j) => {
let i = local.len();
local.push(SummandOp::Local(TapeOp::Var(*j)));
i
}
Expr::Binary(op, a, b) => {
if let BinOp::Pow = op {
if let Some(c) = peek_const(b) {
if let Some(i) = try_emit_const_pow_summand(
a,
c,
local,
local_cache,
prelude,
prelude_map,
cse_count,
) {
return i;
}
}
}
let l = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
let r = build_into_summand(b, local, local_cache, prelude, prelude_map, cse_count);
let i = local.len();
local.push(SummandOp::Local(match op {
BinOp::Add => TapeOp::Add(l, r),
BinOp::Sub => TapeOp::Sub(l, r),
BinOp::Mul => TapeOp::Mul(l, r),
BinOp::Div => TapeOp::Div(l, r),
BinOp::Pow => TapeOp::Pow(l, r),
}));
i
}
Expr::Unary(op, a) => {
let v = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
let i = local.len();
local.push(SummandOp::Local(match op {
UnaryOp::Neg => TapeOp::Neg(v),
UnaryOp::Sqrt => TapeOp::Sqrt(v),
UnaryOp::Log => TapeOp::Log(v),
UnaryOp::Log10 => TapeOp::Log10(v),
UnaryOp::Exp => TapeOp::Exp(v),
UnaryOp::Abs => TapeOp::Abs(v),
UnaryOp::Sin => TapeOp::Sin(v),
UnaryOp::Cos => TapeOp::Cos(v),
}));
i
}
Expr::Sum(args) => {
if args.is_empty() {
let i = local.len();
local.push(SummandOp::Local(TapeOp::Const(0.0)));
return i;
}
let mut acc = build_into_summand(
&args[0],
local,
local_cache,
prelude,
prelude_map,
cse_count,
);
for a in &args[1..] {
let nxt =
build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
let i = local.len();
local.push(SummandOp::Local(TapeOp::Add(acc, nxt)));
acc = i;
}
acc
}
Expr::Cse(body) => {
let key = Rc::as_ptr(body) as *const Expr;
if let Some(&li) = local_cache.get(&key) {
return li;
}
let promoted = cse_count.get(&key).copied().unwrap_or(0) >= 2;
if promoted {
// Build the whole CSE (body plus nested CSEs) into the prelude; `prelude_map`
// makes this a no-op if another summand already did.
// NOTE(review): this uses a default `ExternalResolver`, so a funcall nested
// in a promoted CSE goes through `build_recursive`'s resolver lookup
// instead of the explicit panic below. TODO confirm the resulting message.
let pslot =
build_recursive(expr, prelude, prelude_map, &ExternalResolver::default());
let li = local.len();
local.push(SummandOp::Shared(pslot));
local_cache.insert(key, li);
li
} else {
let li =
build_into_summand(body, local, local_cache, prelude, prelude_map, cse_count);
local_cache.insert(key, li);
li
}
}
Expr::Funcall { .. } => {
panic!(
"HybridTape: AMPL external function calls are not supported on the \
hybrid (partial-separability) tape path. Build with Tape::build_with_externals \
instead."
);
}
}
}
/// Summand-tape counterpart of [`try_emit_const_pow`]: strength-reduces
/// `base_expr ^ c` for a constant exponent `c` onto the local tape.
///
/// Returns `Some(slot)` in these cases:
/// * `c == 0`: the constant `1.0`. The base is not emitted.
/// * `c == 1`: the base itself.
/// * `c == 0.5`: `sqrt(base)`.
/// * integer `c` with `|c| <= 8`: binary exponentiation, then `1 / base^|c|` for
///   negative `c`.
///
/// Returns `None` without emitting anything for every other exponent.
fn try_emit_const_pow_summand(
    base_expr: &Expr,
    c: f64,
    local: &mut Vec<SummandOp>,
    local_cache: &mut HashMap<*const Expr, usize>,
    prelude: &mut Vec<TapeOp>,
    prelude_map: &mut HashMap<*const Expr, usize>,
    cse_count: &HashMap<*const Expr, usize>,
) -> Option<usize> {
    if c == 0.0 {
        let i = local.len();
        local.push(SummandOp::Local(TapeOp::Const(1.0)));
        return Some(i);
    }
    if c == 1.0 {
        return Some(build_into_summand(
            base_expr,
            local,
            local_cache,
            prelude,
            prelude_map,
            cse_count,
        ));
    }
    if c == 0.5 {
        let b = build_into_summand(
            base_expr,
            local,
            local_cache,
            prelude,
            prelude_map,
            cse_count,
        );
        let i = local.len();
        local.push(SummandOp::Local(TapeOp::Sqrt(b)));
        return Some(i);
    }
    if !(c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0) {
        return None;
    }
    // Zero (including -0.0) returned above, so `c` is a non-zero integer in [-8, 8].
    // The cast is therefore exact and `n >= 1`, as `emit_int_pow_summand` requires.
    // The old `n == 0` branch here was unreachable.
    let n = c.abs() as u32;
    let b = build_into_summand(
        base_expr,
        local,
        local_cache,
        prelude,
        prelude_map,
        cse_count,
    );
    let pos = emit_int_pow_summand(b, n, local);
    if c > 0.0 {
        return Some(pos);
    }
    // Negative exponent: 1 / base^|c|.
    let one_idx = local.len();
    local.push(SummandOp::Local(TapeOp::Const(1.0)));
    let i = local.len();
    local.push(SummandOp::Local(TapeOp::Div(one_idx, pos)));
    Some(i)
}
/// Emits the summand-local ops computing `base^n` by binary exponentiation and returns the
/// index of the slot holding the result.
///
/// The bits of `n` are consumed from the most significant one downwards. Every further bit
/// squares the running value, and a set bit also multiplies it by `base`. For `n == 1`
/// nothing is pushed and `base` itself is returned. `n` must be at least 1.
fn emit_int_pow_summand(base: usize, n: u32, local: &mut Vec<SummandOp>) -> usize {
    debug_assert!(n >= 1);
    // Position of the leading one bit; the bits below it drive the square/multiply steps.
    let highest_bit = u32::BITS - 1 - n.leading_zeros();
    let mut acc = base;
    for bit in (0..highest_bit).rev() {
        let squared = local.len();
        local.push(SummandOp::Local(TapeOp::Mul(acc, acc)));
        acc = squared;
        if (n >> bit) & 1 == 1 {
            let times_base = local.len();
            local.push(SummandOp::Local(TapeOp::Mul(acc, base)));
            acc = times_base;
        }
    }
    acc
}
/// Finds everything a summand's `root` slot depends on.
///
/// Returns `(reach, shared)`. `reach` lists every summand-local slot reachable from `root`
/// through `Local` operands, with `root` included, sorted ascending. `shared` lists the
/// distinct prelude indices referenced by `Shared` entries met along the way, also ascending.
/// `Shared` entries are leaves of this walk: their own slot appears in `reach`, but the
/// prelude is not traversed.
fn compute_local_reach(ops: &[SummandOp], root: usize) -> (Vec<usize>, Vec<usize>) {
    let mut seen = vec![false; ops.len()];
    let mut reached: Vec<usize> = vec![root];
    let mut prelude_slots: BTreeSet<usize> = BTreeSet::new();
    let mut pending: Vec<usize> = Vec::with_capacity(16);
    seen[root] = true;
    pending.push(root);
    while let Some(slot) = pending.pop() {
        let op = match &ops[slot] {
            SummandOp::Shared(k) => {
                prelude_slots.insert(*k);
                continue;
            }
            SummandOp::Local(op) => op,
        };
        let (lhs, rhs) = op_operands(op);
        for child in lhs.into_iter().chain(rhs) {
            if !seen[child] {
                seen[child] = true;
                reached.push(child);
                pending.push(child);
            }
        }
    }
    reached.sort_unstable();
    (reached, prelude_slots.into_iter().collect())
}
/// Appends to `out` every prelude slot reachable from `start` that has not yet been stamped
/// with the generation marker `cur` in `visited`.
///
/// Slots are stamped as they are discovered, so calling this repeatedly with the same `cur`
/// never reports a slot twice. Bumping `cur` invalidates all previous stamps without clearing
/// `visited`. Despite the name, the walk uses `stack` as a LIFO worklist, so discovery order is
/// depth-first. `stack` is caller-provided scratch space and is left empty on return when
/// `start` was unvisited. If `start` already carries `cur`, nothing happens.
fn bfs_prelude(
    prelude: &[TapeOp],
    start: usize,
    visited: &mut [u32],
    cur: u32,
    stack: &mut Vec<usize>,
    out: &mut Vec<usize>,
) {
    if visited[start] != cur {
        visited[start] = cur;
        out.push(start);
        stack.push(start);
        while let Some(node) = stack.pop() {
            let (lhs, rhs) = op_operands(&prelude[node]);
            for child in lhs.into_iter().chain(rhs) {
                if visited[child] != cur {
                    visited[child] = cur;
                    out.push(child);
                    stack.push(child);
                }
            }
        }
    }
}
/// For every slot of a prelude tape, computes the set of variable indices it depends on.
///
/// The result is parallel to `ops`: entry `i` is the union of the sets of slot `i`'s operands,
/// `{j}` for `Var(j)`, and empty for constants. Operands must precede their users.
///
/// # Panics
/// On `TapeOp::Funcall`, which the HybridTape builder never places in the prelude.
fn compute_var_sets(ops: &[TapeOp]) -> Vec<BTreeSet<usize>> {
    let mut sets: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
    for op in ops {
        let deps: BTreeSet<usize> = match *op {
            TapeOp::Const(_) => BTreeSet::new(),
            TapeOp::Var(j) => std::iter::once(j).collect(),
            TapeOp::Add(l, r)
            | TapeOp::Sub(l, r)
            | TapeOp::Mul(l, r)
            | TapeOp::Div(l, r)
            | TapeOp::Pow(l, r) => &sets[l] | &sets[r],
            TapeOp::Neg(u)
            | TapeOp::Abs(u)
            | TapeOp::Sqrt(u)
            | TapeOp::Exp(u)
            | TapeOp::Log(u)
            | TapeOp::Log10(u)
            | TapeOp::Sin(u)
            | TapeOp::Cos(u) => sets[u].clone(),
            TapeOp::Funcall { .. } => unreachable!(
                "HybridTape prelude cannot contain TapeOp::Funcall; \
                 build_into_summand panics on Expr::Funcall."
            ),
        };
        sets.push(deps);
    }
    sets
}
/// Adds to `pairs` the lower-triangular Hessian pattern `(row, col)` (`row >= col`) that one
/// summand's local ops can generate.
///
/// Dependency sets are tracked per local slot. `Shared(k)` entries import
/// `prelude_var_sets[k]` without emitting any pairs here; pairs produced inside the prelude
/// are presumably accounted for by the caller — confirm. Linear ops (`Add`, `Sub`, `Neg`,
/// `Abs`) only merge sets. `Mul` adds operand-cross pairs. `Div` adds cross pairs plus all
/// pairs within the denominator's set. `Pow` and the unary transcendental ops add all pairs
/// within their operands' combined set.
///
/// # Panics
/// On a local `TapeOp::Funcall`, which the HybridTape builder never emits.
fn summand_sparsity(
    ops: &[SummandOp],
    prelude_var_sets: &[BTreeSet<usize>],
    pairs: &mut BTreeSet<(usize, usize)>,
) {
    /// Records `(max, min)` for every pairing of a variable in `xs` with one in `ys`.
    fn cross_into(xs: &BTreeSet<usize>, ys: &BTreeSet<usize>, sink: &mut BTreeSet<(usize, usize)>) {
        for &p in xs {
            for &q in ys {
                sink.insert((p.max(q), p.min(q)));
            }
        }
    }
    /// Records every lower-triangular pair, diagonal included, among the variables of `xs`.
    fn dense_into(xs: &BTreeSet<usize>, sink: &mut BTreeSet<(usize, usize)>) {
        for (rank, &p) in xs.iter().enumerate() {
            for &q in xs.iter().take(rank + 1) {
                sink.insert((p.max(q), p.min(q)));
            }
        }
    }

    let mut deps: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
    for entry in ops {
        let op = match entry {
            SummandOp::Shared(k) => {
                deps.push(prelude_var_sets[*k].clone());
                continue;
            }
            SummandOp::Local(op) => op,
        };
        let here: BTreeSet<usize> = match *op {
            TapeOp::Const(_) => BTreeSet::new(),
            TapeOp::Var(j) => std::iter::once(j).collect(),
            TapeOp::Add(l, r) | TapeOp::Sub(l, r) => &deps[l] | &deps[r],
            TapeOp::Neg(u) | TapeOp::Abs(u) => deps[u].clone(),
            TapeOp::Mul(l, r) => {
                cross_into(&deps[l], &deps[r], pairs);
                &deps[l] | &deps[r]
            }
            TapeOp::Div(l, r) => {
                cross_into(&deps[l], &deps[r], pairs);
                dense_into(&deps[r], pairs);
                &deps[l] | &deps[r]
            }
            TapeOp::Pow(l, r) => {
                let both = &deps[l] | &deps[r];
                dense_into(&both, pairs);
                both
            }
            TapeOp::Sqrt(u)
            | TapeOp::Exp(u)
            | TapeOp::Log(u)
            | TapeOp::Log10(u)
            | TapeOp::Sin(u)
            | TapeOp::Cos(u) => {
                dense_into(&deps[u], pairs);
                deps[u].clone()
            }
            TapeOp::Funcall { .. } => unreachable!(
                "HybridTape summand cannot contain TapeOp::Funcall; \
                 build_into_summand panics on Expr::Funcall."
            ),
        };
        deps.push(here);
    }
}
/// Returns the tape-slot operands of `op` as `(first, second)`.
///
/// Binary ops yield both slots, unary ops only the first, and leaves (`Const`, `Var`) none.
/// `Funcall` also reports no operands, even though its `Tape` arguments refer to slots. The
/// dependency walks in this section that use this helper (`compute_local_reach`,
/// `bfs_prelude`) appear to run only on HybridTape data, where `Funcall` is rejected at build
/// time.
/// NOTE(review): do not reuse this for full-tape reachability without handling `Funcall`.
#[inline]
fn op_operands(op: &TapeOp) -> (Option<usize>, Option<usize>) {
    match *op {
        TapeOp::Add(l, r)
        | TapeOp::Sub(l, r)
        | TapeOp::Mul(l, r)
        | TapeOp::Div(l, r)
        | TapeOp::Pow(l, r) => (Some(l), Some(r)),
        TapeOp::Neg(u)
        | TapeOp::Abs(u)
        | TapeOp::Sqrt(u)
        | TapeOp::Exp(u)
        | TapeOp::Log(u)
        | TapeOp::Log10(u)
        | TapeOp::Sin(u)
        | TapeOp::Cos(u) => (Some(u), None),
        TapeOp::Const(_) | TapeOp::Var(_) | TapeOp::Funcall { .. } => (None, None),
    }
}
/// Returns the distinct variable indices, ascending, read by `Var` ops among the slots of
/// `ops` listed in `reach`. Non-`Var` slots are ignored.
fn vars_in(ops: &[TapeOp], reach: &[usize]) -> Vec<usize> {
    let distinct: BTreeSet<usize> = reach
        .iter()
        .filter_map(|&slot| {
            if let TapeOp::Var(j) = ops[slot] {
                Some(j)
            } else {
                None
            }
        })
        .collect();
    distinct.into_iter().collect()
}
/// Computes the forward value of one tape op.
///
/// `x` is the variable vector and `vals` holds the already computed values of earlier slots.
/// External `Funcall`s are evaluated through their library without derivatives.
///
/// # Panics
/// If an external function evaluation fails.
#[inline]
fn fwd_step(op: &TapeOp, x: &[f64], vals: &[f64]) -> f64 {
    // Value of an earlier tape slot.
    let at = |slot: &usize| vals[*slot];
    match op {
        TapeOp::Const(c) => *c,
        TapeOp::Var(j) => x[*j],
        TapeOp::Add(l, r) => at(l) + at(r),
        TapeOp::Sub(l, r) => at(l) - at(r),
        TapeOp::Mul(l, r) => at(l) * at(r),
        TapeOp::Div(l, r) => at(l) / at(r),
        TapeOp::Pow(l, r) => at(l).powf(at(r)),
        TapeOp::Neg(u) => -at(u),
        TapeOp::Abs(u) => at(u).abs(),
        TapeOp::Sqrt(u) => at(u).sqrt(),
        TapeOp::Exp(u) => at(u).exp(),
        TapeOp::Log(u) => at(u).ln(),
        TapeOp::Log10(u) => at(u).log10(),
        TapeOp::Sin(u) => at(u).sin(),
        TapeOp::Cos(u) => at(u).cos(),
        TapeOp::Funcall { lib, name, args } => {
            let ext_args = funcall_to_ext_args(args, vals);
            match lib.eval(name, &ext_args, false, false) {
                Ok(res) => res.value,
                Err(e) => panic!("external function '{name}' eval failed: {e}"),
            }
        }
    }
}
/// Reverse-mode step for tape slot `i`: scatters its adjoint `a` into its operands' adjoints.
///
/// Variable leaves add `a` into `grad` instead. `vals` must hold the forward values of all
/// slots; several rules reuse the slot's own value `vals[i]`. Some singular derivatives are
/// skipped rather than propagated. `Pow` skips the base term when the exponent is 0, and the
/// exponent term unless the base is positive. `Sqrt` skips when its result is not positive.
///
/// # Panics
/// If an external function evaluation fails.
#[inline]
fn rev_step(op: &TapeOp, i: usize, vals: &[f64], adj: &mut [f64], a: f64, grad: &mut [f64]) {
    match op {
        TapeOp::Const(_) => {}
        TapeOp::Var(j) => grad[*j] += a,
        TapeOp::Add(lhs, rhs) => {
            adj[*lhs] += a;
            adj[*rhs] += a;
        }
        TapeOp::Sub(lhs, rhs) => {
            adj[*lhs] += a;
            adj[*rhs] -= a;
        }
        TapeOp::Mul(lhs, rhs) => {
            let (x, y) = (vals[*lhs], vals[*rhs]);
            adj[*lhs] += a * y;
            adj[*rhs] += a * x;
        }
        TapeOp::Div(lhs, rhs) => {
            let den = vals[*rhs];
            adj[*lhs] += a / den;
            adj[*rhs] -= a * vals[*lhs] / (den * den);
        }
        TapeOp::Pow(base, expo) => {
            let (bv, ev) = (vals[*base], vals[*expo]);
            if ev != 0.0 {
                adj[*base] += a * ev * bv.powf(ev - 1.0);
            }
            if bv > 0.0 {
                adj[*expo] += a * vals[i] * bv.ln();
            }
        }
        TapeOp::Neg(u) => adj[*u] -= a,
        TapeOp::Abs(u) => adj[*u] += if vals[*u] >= 0.0 { a } else { -a },
        TapeOp::Sqrt(u) => {
            let root = vals[i];
            if root > 0.0 {
                adj[*u] += a * 0.5 / root;
            }
        }
        TapeOp::Exp(u) => adj[*u] += a * vals[i],
        TapeOp::Log(u) => adj[*u] += a / vals[*u],
        TapeOp::Log10(u) => adj[*u] += a / (vals[*u] * std::f64::consts::LN_10),
        TapeOp::Sin(u) => adj[*u] += a * vals[*u].cos(),
        TapeOp::Cos(u) => adj[*u] -= a * vals[*u].sin(),
        TapeOp::Funcall { lib, name, args } => {
            let ext_args = funcall_to_ext_args(args, vals);
            let res = match lib.eval(name, &ext_args, true, false) {
                Ok(res) => res,
                Err(e) => panic!("external function '{name}' reverse eval failed: {e}"),
            };
            let derivs = res.derivs.expect("want_derivs=true returns derivs");
            // `derivs` holds one partial per real (tape) argument, in argument order; string
            // arguments are skipped.
            let tape_slots = args.iter().filter_map(|arg| match arg {
                TapeFuncallArg::Tape(slot) => Some(*slot),
                TapeFuncallArg::Str(_) => None,
            });
            for (k, slot) in tape_slots.enumerate() {
                adj[slot] += a * derivs[k];
            }
        }
    }
}
/// Forward-mode tangent of tape slot `i` along the unit direction of variable `seed_var`.
///
/// `vals` holds the forward values of all slots and `dot` the tangents of the slots before
/// `i`; the return value is the tangent of slot `i`. Singular derivative contributions are
/// dropped rather than propagated: `Sqrt` with a non-positive result, the exponent term of
/// `Pow` for a non-positive base, and the base term of `Pow` at `u == 0` with `r < 1`.
///
/// # Panics
/// If an external function evaluation fails.
#[inline]
fn fwd_tan_step(op: &TapeOp, seed_var: usize, vals: &[f64], dot: &[f64], i: usize) -> f64 {
    match op {
        TapeOp::Const(_) => 0.0,
        TapeOp::Var(k) => {
            if *k == seed_var {
                1.0
            } else {
                0.0
            }
        }
        TapeOp::Add(a, b) => dot[*a] + dot[*b],
        TapeOp::Sub(a, b) => dot[*a] - dot[*b],
        TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
        TapeOp::Div(a, b) => {
            let vb = vals[*b];
            (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
        }
        TapeOp::Pow(a, b) => {
            let u = vals[*a];
            let r = vals[*b];
            let du = dot[*a];
            let dr = dot[*b];
            let mut result = 0.0;
            // d(u^r)/du = r * u^(r-1). At u == 0 this is finite only for r >= 1 (exactly 1 when
            // r == 1, because 0^0 == 1), so only r < 1 is skipped there. The old `u != 0.0`
            // guard wrongly dropped the r == 1 case and disagreed with `rev_step`.
            if r != 0.0 && (u != 0.0 || r >= 1.0) {
                result += r * u.powf(r - 1.0) * du;
            }
            // d(u^r)/dr = u^r * ln(u), defined only for a positive base.
            if u > 0.0 {
                result += vals[i] * u.ln() * dr;
            }
            result
        }
        TapeOp::Neg(a) => -dot[*a],
        TapeOp::Abs(a) => {
            if vals[*a] >= 0.0 {
                dot[*a]
            } else {
                -dot[*a]
            }
        }
        TapeOp::Sqrt(a) => {
            let sv = vals[i];
            if sv > 0.0 {
                dot[*a] * 0.5 / sv
            } else {
                0.0
            }
        }
        TapeOp::Exp(a) => dot[*a] * vals[i],
        TapeOp::Log(a) => dot[*a] / vals[*a],
        TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
        TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
        TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
        TapeOp::Funcall { lib, name, args } => {
            let call_args = funcall_to_ext_args(args, vals);
            let res = lib
                .eval(name, &call_args, true, false)
                .unwrap_or_else(|e| panic!("external function '{name}' tangent eval failed: {e}"));
            let derivs = res.derivs.expect("want_derivs=true returns derivs");
            // One partial per real (tape) argument, in argument order; strings are skipped.
            let mut acc = 0.0;
            let mut k = 0usize;
            for arg in args {
                if let TapeFuncallArg::Tape(idx) = arg {
                    acc += derivs[k] * dot[*idx];
                    k += 1;
                }
            }
            acc
        }
    }
}
/// Second-order adjoint step for tape slot `i`.
///
/// Inputs:
/// * `vals`: forward values of all slots.
/// * `dot`: forward tangents seeded on variable `seed_var`.
/// * `w`: this slot's adjoint.
/// * `wd`: the tangent of that adjoint.
///
/// It propagates `w` into `adj` and `wd` into `adj_dot` for the slot's operands, using the
/// product rule `adj_dot += wd * f' + w * d(f')`. When the slot is a variable `k >= seed_var`
/// and `wd != 0`, it adds `weight * wd` to `values[hess_map[(k, seed_var)]]`, but only if that
/// lower-triangular entry is present in `hess_map`.
///
/// `Pow` treats the singular points of `u^r` explicitly. For `u == 0`, only `r == 1` and
/// `r >= 2` contribute to the base's adjoints. The exponent's adjoints are only updated for
/// `u > 0`, where `ln u` exists.
///
/// # Panics
/// If an external function evaluation fails.
// Hot kernel: every scratch buffer is passed explicitly instead of being bundled in a struct.
#[allow(clippy::too_many_arguments)]
#[inline]
fn ror_step(
    op: &TapeOp,
    i: usize,
    seed_var: usize,
    vals: &[f64],
    dot: &[f64],
    adj: &mut [f64],
    adj_dot: &mut [f64],
    w: f64,
    wd: f64,
    weight: f64,
    hess_map: &HashMap<(usize, usize), usize>,
    values: &mut [f64],
) {
    match op {
        TapeOp::Const(_) => {}
        TapeOp::Var(k) => {
            if wd != 0.0 && *k >= seed_var {
                if let Some(&pos) = hess_map.get(&(*k, seed_var)) {
                    values[pos] += weight * wd;
                }
            }
        }
        TapeOp::Add(a, b) => {
            adj[*a] += w;
            adj[*b] += w;
            adj_dot[*a] += wd;
            adj_dot[*b] += wd;
        }
        TapeOp::Sub(a, b) => {
            adj[*a] += w;
            adj[*b] -= w;
            adj_dot[*a] += wd;
            adj_dot[*b] -= wd;
        }
        TapeOp::Mul(a, b) => {
            adj[*a] += w * vals[*b];
            adj[*b] += w * vals[*a];
            adj_dot[*a] += wd * vals[*b] + w * dot[*b];
            adj_dot[*b] += wd * vals[*a] + w * dot[*a];
        }
        TapeOp::Div(a, b) => {
            let vb = vals[*b];
            let vb2 = vb * vb;
            let vb3 = vb2 * vb;
            adj[*a] += w / vb;
            adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
            adj[*b] += w * (-vals[*a] / vb2);
            adj_dot[*b] +=
                wd * (-vals[*a] / vb2) + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
        }
        TapeOp::Pow(a, b) => {
            let u = vals[*a];
            let r = vals[*b];
            let du = dot[*a];
            let dr = dot[*b];
            if r != 0.0 {
                if u != 0.0 {
                    let p_a = r * u.powf(r - 1.0);
                    adj[*a] += w * p_a;
                    let mut dp_a = dr * u.powf(r - 1.0);
                    if u > 0.0 {
                        dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
                    } else {
                        dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
                    }
                    adj_dot[*a] += wd * p_a + w * dp_a;
                } else if r >= 2.0 {
                    let p_a = 0.0;
                    adj[*a] += w * p_a;
                    let dp_a = if r == 2.0 {
                        2.0 * du
                    } else {
                        r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
                    };
                    adj_dot[*a] += wd * p_a + w * dp_a;
                } else if r == 1.0 {
                    // u^1 at u == 0: the partial w.r.t. the base is exactly 1 (0^0 == 1), so the
                    // adjoint must still flow into `a` (it was previously dropped). Of d(p_a), only
                    // the `dr * u^(r-1)` term is finite here; the `ln u` / `u^(r-2)` terms are
                    // singular and dropped, matching how the exponent side is skipped for u <= 0.
                    let p_a = 1.0;
                    adj[*a] += w * p_a;
                    let dp_a = dr;
                    adj_dot[*a] += wd * p_a + w * dp_a;
                }
                // For 0 < r < 1 (and 1 < r < 2) at u == 0 the derivatives are singular or their
                // tangents are; nothing is propagated into the base.
            }
            if u > 0.0 {
                let ln_u = u.ln();
                let p_b = vals[i] * ln_u;
                adj[*b] += w * p_b;
                let dur = vals[i] * (r * du / u + dr * ln_u);
                let dp_b = dur * ln_u + vals[i] * du / u;
                adj_dot[*b] += wd * p_b + w * dp_b;
            }
        }
        TapeOp::Neg(a) => {
            adj[*a] -= w;
            adj_dot[*a] -= wd;
        }
        TapeOp::Abs(a) => {
            let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
            adj[*a] += w * s;
            adj_dot[*a] += wd * s;
        }
        TapeOp::Sqrt(a) => {
            let sv = vals[i];
            if sv > 0.0 {
                let fp = 0.5 / sv;
                let fpp = -0.25 / (vals[*a] * sv);
                adj[*a] += w * fp;
                adj_dot[*a] += wd * fp + w * fpp * dot[*a];
            }
        }
        TapeOp::Exp(a) => {
            let ev = vals[i];
            adj[*a] += w * ev;
            adj_dot[*a] += wd * ev + w * ev * dot[*a];
        }
        TapeOp::Log(a) => {
            let u = vals[*a];
            adj[*a] += w / u;
            adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
        }
        TapeOp::Log10(a) => {
            let u = vals[*a];
            let c = std::f64::consts::LN_10;
            adj[*a] += w / (u * c);
            adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
        }
        TapeOp::Sin(a) => {
            let u = vals[*a];
            let cu = u.cos();
            adj[*a] += w * cu;
            adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
        }
        TapeOp::Cos(a) => {
            let u = vals[*a];
            let su = u.sin();
            adj[*a] -= w * su;
            adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
        }
        TapeOp::Funcall { lib, name, args } => {
            let call_args = funcall_to_ext_args(args, vals);
            let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
                panic!("external function '{name}' 2nd-order eval failed: {e}")
            });
            let derivs = res.derivs.expect("want_derivs=true returns derivs");
            let hes = res.hessian.expect("want_hes=true returns hessian");
            // Tape slots of the real arguments, in the order `derivs`/`hes` index them.
            let real_tape: Vec<usize> = args
                .iter()
                .filter_map(|a| match a {
                    TapeFuncallArg::Tape(t) => Some(*t),
                    TapeFuncallArg::Str(_) => None,
                })
                .collect();
            for (k, &tk) in real_tape.iter().enumerate() {
                adj[tk] += w * derivs[k];
                let mut second_term = 0.0;
                for (l, &tl) in real_tape.iter().enumerate() {
                    // `hes` is read as a packed triangle: entry (lo, hi) with lo <= hi lives at
                    // lo + hi * (hi + 1) / 2.
                    let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
                    let h_kl = hes[lo + hi * (hi + 1) / 2];
                    second_term += h_kl * dot[tl];
                }
                adj_dot[tk] += wd * derivs[k] + w * second_term;
            }
        }
    }
}
/// Conservatively computes the lower-triangular Hessian sparsity pattern `(row, col)`, with
/// `row >= col`, of a whole tape.
///
/// Walks `ops` in order, tracking for every slot the set of variables it depends on. Linear
/// ops (`Add`, `Sub`, `Neg`, `Abs`) only merge dependency sets. Nonlinear ops record the pairs
/// that their own second derivatives can touch:
/// * `Mul`: the operand-cross pairs only.
/// * `Div`: the operand-cross pairs plus all pairs within the denominator's set.
/// * `Pow`, the unary transcendental ops, and external `Funcall`s: all pairs within the union
///   of their operands' sets. String arguments of a `Funcall` are ignored.
fn hessian_sparsity_impl(ops: &[TapeOp]) -> BTreeSet<(usize, usize)> {
    /// Records `(max, min)` for every pairing of a variable in `xs` with one in `ys`.
    fn cross_into(xs: &BTreeSet<usize>, ys: &BTreeSet<usize>, sink: &mut BTreeSet<(usize, usize)>) {
        for &p in xs {
            for &q in ys {
                sink.insert((p.max(q), p.min(q)));
            }
        }
    }
    /// Records every lower-triangular pair, diagonal included, among the variables of `xs`.
    fn dense_into(xs: &BTreeSet<usize>, sink: &mut BTreeSet<(usize, usize)>) {
        for (rank, &p) in xs.iter().enumerate() {
            for &q in xs.iter().take(rank + 1) {
                sink.insert((p.max(q), p.min(q)));
            }
        }
    }

    let mut deps: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
    let mut pattern: BTreeSet<(usize, usize)> = BTreeSet::new();
    for op in ops {
        let here: BTreeSet<usize> = match op {
            TapeOp::Const(_) => BTreeSet::new(),
            TapeOp::Var(j) => std::iter::once(*j).collect(),
            TapeOp::Add(l, r) | TapeOp::Sub(l, r) => &deps[*l] | &deps[*r],
            TapeOp::Neg(u) | TapeOp::Abs(u) => deps[*u].clone(),
            TapeOp::Mul(l, r) => {
                cross_into(&deps[*l], &deps[*r], &mut pattern);
                &deps[*l] | &deps[*r]
            }
            TapeOp::Div(l, r) => {
                cross_into(&deps[*l], &deps[*r], &mut pattern);
                dense_into(&deps[*r], &mut pattern);
                &deps[*l] | &deps[*r]
            }
            TapeOp::Pow(l, r) => {
                let both = &deps[*l] | &deps[*r];
                dense_into(&both, &mut pattern);
                both
            }
            TapeOp::Sqrt(u)
            | TapeOp::Exp(u)
            | TapeOp::Log(u)
            | TapeOp::Log10(u)
            | TapeOp::Sin(u)
            | TapeOp::Cos(u) => {
                dense_into(&deps[*u], &mut pattern);
                deps[*u].clone()
            }
            TapeOp::Funcall { args, .. } => {
                let mut merged = BTreeSet::new();
                let tape_slots = args.iter().filter_map(|arg| match arg {
                    TapeFuncallArg::Tape(t) => Some(*t),
                    TapeFuncallArg::Str(_) => None,
                });
                for slot in tape_slots {
                    merged.extend(deps[slot].iter().copied());
                }
                dense_into(&merged, &mut pattern);
                merged
            }
        };
        deps.push(here);
    }
    pattern
}
// Unit tests covering several areas of `Tape`:
// - construction: CSE sharing and constant-exponent `Pow` lowering;
// - evaluation and gradients;
// - Hessians, checked against finite differences and against the directional Hessian path;
// - Hessian sparsity.
#[cfg(test)]
mod tests {
use super::*;
// Small `Expr` constructors that keep the test expressions readable.
fn cnst(c: f64) -> Expr {
Expr::Const(c)
}
fn var(i: usize) -> Expr {
Expr::Var(i)
}
fn add(a: Expr, b: Expr) -> Expr {
Expr::Binary(BinOp::Add, Box::new(a), Box::new(b))
}
fn mul(a: Expr, b: Expr) -> Expr {
Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b))
}
fn pow(a: Expr, b: Expr) -> Expr {
Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b))
}
fn div(a: Expr, b: Expr) -> Expr {
Expr::Binary(BinOp::Div, Box::new(a), Box::new(b))
}
fn unary(op: UnaryOp, a: Expr) -> Expr {
Expr::Unary(op, Box::new(a))
}
// f = 3*x0^2 + 2*x1 at (2, 3): f = 18, grad = (12, 2).
#[test]
fn polynomial_eval_and_grad() {
let e = add(
mul(cnst(3.0), pow(var(0), cnst(2.0))),
mul(cnst(2.0), var(1)),
);
let t = Tape::build(&e);
assert!((t.eval(&[2.0, 3.0]) - 18.0).abs() < 1e-12);
let mut g = vec![0.0; 2];
t.gradient_seed(&[2.0, 3.0], 1.0, &mut g);
assert!((g[0] - 12.0).abs() < 1e-12);
assert!((g[1] - 2.0).abs() < 1e-12);
}
// f = s^2 + s with s = x0 + x1 shared via `Expr::Cse`. The shared body must appear once on the
// tape. At (1, 2): s = 3, f = 12, df/dx0 = df/dx1 = 2s + 1 = 7.
#[test]
fn cse_shared_body_evaluated_once() {
let body = Rc::new(add(var(0), var(1)));
let e = add(
pow(Expr::Cse(body.clone()), cnst(2.0)),
Expr::Cse(body.clone()),
);
let t = Tape::build(&e);
let n_body_adds = t
.ops
.iter()
.filter(|op| {
matches!(op, TapeOp::Add(a, b) if {
matches!(t.ops[*a], TapeOp::Var(0)) && matches!(t.ops[*b], TapeOp::Var(1))
})
})
.count();
assert_eq!(n_body_adds, 1, "CSE body should be emitted exactly once");
assert!((t.eval(&[1.0, 2.0]) - 12.0).abs() < 1e-12);
let mut g = vec![0.0; 2];
t.gradient_seed(&[1.0, 2.0], 1.0, &mut g);
assert!((g[0] - 7.0).abs() < 1e-12);
assert!((g[1] - 7.0).abs() < 1e-12);
}
// Compares `hessian_accumulate` (weight 1) with central differences of `gradient_seed`.
// Every lower-triangular pair among `tape.variables()` is checked. The step is
// h = max(1e-7, 1e-7*|x_j|), and the error is measured relative to max(|FD|, 1).
fn fd_check(tape: &Tape, x: &[f64], n: usize, tol: f64) {
let vars = tape.variables();
let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
let mut pairs = Vec::new();
for (ai, &vi) in vars.iter().enumerate() {
for &vj in &vars[..=ai] {
let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
hess_map.entry((r, c)).or_insert_with(|| {
let p = pairs.len();
pairs.push((r, c));
p
});
}
}
let nnz = pairs.len();
let mut ad = vec![0.0; nnz];
tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
let mut fd = vec![0.0; nnz];
let mut xp = x.to_vec();
let mut gp = vec![0.0; n];
let mut gm = vec![0.0; n];
for &j in &vars {
let h = (1e-7_f64).max(x[j].abs() * 1e-7);
xp[j] = x[j] + h;
gp.iter_mut().for_each(|v| *v = 0.0);
tape.gradient_seed(&xp, 1.0, &mut gp);
xp[j] = x[j] - h;
gm.iter_mut().for_each(|v| *v = 0.0);
tape.gradient_seed(&xp, 1.0, &mut gm);
xp[j] = x[j];
// Column j of the FD Hessian; only the lower triangle (i >= j) is stored.
for &i in &vars {
if i >= j {
if let Some(&pos) = hess_map.get(&(i, j)) {
fd[pos] = (gp[i] - gm[i]) / (2.0 * h);
}
}
}
}
for (k, &(r, c)) in pairs.iter().enumerate() {
let scale = fd[k].abs().max(1.0);
assert!(
(ad[k] - fd[k]).abs() / scale < tol,
"H[{},{}]: AD={:.6e} FD={:.6e}",
r,
c,
ad[k],
fd[k]
);
}
}
// 3*x0^2 + 2*x0*x1 + x1^2: a constant Hessian.
#[test]
fn hessian_quadratic_matches_fd() {
let e = add(
add(
mul(cnst(3.0), pow(var(0), cnst(2.0))),
mul(cnst(2.0), mul(var(0), var(1))),
),
pow(var(1), cnst(2.0)),
);
let t = Tape::build(&e);
fd_check(&t, &[2.0, 3.0], 2, 1e-5);
}
// Exercises the second-order rules of Exp, Sin, Log, Sqrt and a bilinear term.
#[test]
fn hessian_transcendental_matches_fd() {
let e = Expr::Sum(vec![
unary(UnaryOp::Exp, var(0)),
unary(UnaryOp::Sin, var(1)),
unary(UnaryOp::Log, var(0)),
unary(UnaryOp::Sqrt, var(1)),
mul(var(0), var(1)),
]);
let t = Tape::build(&e);
fd_check(&t, &[1.5, 2.0], 2, 1e-5);
}
// Exercises the Div and Cos second-order rules.
#[test]
fn hessian_division_matches_fd() {
let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
let t = Tape::build(&e);
fd_check(&t, &[0.5, 1.2], 2, 1e-5);
}
// Checks `hessian_directional` with each unit seed e_j against `hessian_accumulate`. For every
// pair of variables (i, j), the directional result col[i] must equal the accumulated H[i][j]
// to within 1e-10.
fn directional_matches_accumulate(tape: &Tape, x: &[f64], n: usize) {
let vars = tape.variables();
let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
let mut pairs = Vec::new();
for (ai, &vi) in vars.iter().enumerate() {
for &vj in &vars[..=ai] {
let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
hess_map.entry((r, c)).or_insert_with(|| {
let p = pairs.len();
pairs.push((r, c));
p
});
}
}
let nnz = pairs.len();
let mut ad = vec![0.0; nnz];
tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
let nops = tape.ops.len();
let mut vals = vec![0.0; nops];
tape.forward_into(x, &mut vals);
// Scratch buffers reused across all seed directions.
let mut dot = vec![0.0; nops];
let mut adj = vec![0.0; nops];
let mut adj_dot = vec![0.0; nops];
for &j in &vars {
let mut seed = vec![0.0; n];
seed[j] = 1.0;
let mut col = vec![0.0; n];
tape.hessian_directional(
&vals,
&seed,
1.0,
&mut col,
&mut dot,
&mut adj,
&mut adj_dot,
);
for &i in &vars {
let (r, c) = if i >= j { (i, j) } else { (j, i) };
let expect = ad[hess_map[&(r, c)]];
assert!(
(col[i] - expect).abs() < 1e-10,
"directional H[{i},{j}] = {} vs accumulate {}",
col[i],
expect
);
}
}
}
#[test]
fn directional_quadratic_matches_accumulate() {
let e = add(
add(
mul(cnst(3.0), pow(var(0), cnst(2.0))),
mul(mul(cnst(2.0), var(0)), var(1)),
),
pow(var(1), cnst(2.0)),
);
let t = Tape::build(&e);
directional_matches_accumulate(&t, &[0.5, -0.3], 2);
}
#[test]
fn directional_transcendental_matches_accumulate() {
let e = Expr::Sum(vec![
unary(UnaryOp::Exp, var(0)),
unary(UnaryOp::Sin, var(1)),
unary(UnaryOp::Log, var(0)),
unary(UnaryOp::Sqrt, var(1)),
mul(var(0), var(1)),
]);
let t = Tape::build(&e);
directional_matches_accumulate(&t, &[1.5, 2.0], 2);
}
#[test]
fn directional_with_division_matches_accumulate() {
let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
let t = Tape::build(&e);
directional_matches_accumulate(&t, &[0.5, 1.2], 2);
}
// sin(x0) + x1*x2 couples only x1 with x2, plus the x0 diagonal.
#[test]
fn hessian_sparsity_separable() {
let e = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
let t = Tape::build(&e);
let s = t.hessian_sparsity();
assert!(s.contains(&(0, 0)));
assert!(s.contains(&(2, 1)));
assert!(!s.contains(&(1, 0)));
assert!(!s.contains(&(2, 0)));
}
// Counts the tape ops matching `pred`.
fn count_op<F: Fn(&TapeOp) -> bool>(t: &Tape, pred: F) -> usize {
t.ops.iter().filter(|o| pred(o)).count()
}
// The following tests pin how `Tape::build` lowers `x ^ c` for a constant `c`:
// 0 -> Const(1), 1 -> x, 0.5 -> Sqrt, small integers -> Mul chains (negative ones via Div),
// and everything else stays a generic Pow.
#[test]
fn pow_zero_const_folds_to_one() {
let e = pow(var(0), cnst(0.0));
let t = Tape::build(&e);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Var(_))), 0);
assert!((t.eval(&[7.0]) - 1.0).abs() < 1e-12);
}
#[test]
fn pow_one_passes_through() {
let e = pow(var(0), cnst(1.0));
let t = Tape::build(&e);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Const(_))), 0);
assert!((t.eval(&[3.5]) - 3.5).abs() < 1e-12);
}
#[test]
fn pow_half_lowers_to_sqrt() {
let e = pow(var(0), cnst(0.5));
let t = Tape::build(&e);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Sqrt(_))), 1);
assert!((t.eval(&[16.0]) - 4.0).abs() < 1e-12);
}
#[test]
fn pow_two_lowers_to_single_mul() {
let e = pow(var(0), cnst(2.0));
let t = Tape::build(&e);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
assert!((t.eval(&[3.0]) - 9.0).abs() < 1e-12);
}
#[test]
fn pow_three_lowers_to_two_muls() {
let e = pow(var(0), cnst(3.0));
let t = Tape::build(&e);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 2);
assert!((t.eval(&[2.0]) - 8.0).abs() < 1e-12);
}
// x^8 is three successive squarings.
#[test]
fn pow_eight_lowers_to_three_muls() {
let e = pow(var(0), cnst(8.0));
let t = Tape::build(&e);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 3);
assert!((t.eval(&[2.0]) - 256.0).abs() < 1e-12);
}
#[test]
fn pow_negative_two_lowers_to_div() {
let e = pow(var(0), cnst(-2.0));
let t = Tape::build(&e);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Div(..))), 1);
assert!((t.eval(&[4.0]) - (1.0 / 16.0)).abs() < 1e-12);
}
#[test]
fn pow_large_const_stays_generic() {
let e = pow(var(0), cnst(9.0));
let t = Tape::build(&e);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
}
#[test]
fn pow_non_integer_const_stays_generic() {
let e = pow(var(0), cnst(1.5));
let t = Tape::build(&e);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
}
// The constant exponent is recognized even when wrapped in `Expr::Cse`.
#[test]
fn pow_const_through_cse_const() {
let two = Rc::new(cnst(2.0));
let e = Expr::Binary(BinOp::Pow, Box::new(var(0)), Box::new(Expr::Cse(two)));
let t = Tape::build(&e);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
}
// Second-order correctness of the lowered forms (Mul chain, Div, Sqrt).
#[test]
fn hessian_pow_three_matches_fd() {
let e = add(mul(cnst(5.0), pow(var(0), cnst(3.0))), mul(var(0), var(1)));
let t = Tape::build(&e);
fd_check(&t, &[1.7, 0.8], 2, 1e-5);
}
#[test]
fn hessian_pow_negative_matches_fd() {
let e = add(pow(var(0), cnst(-2.0)), pow(var(1), cnst(2.0)));
let t = Tape::build(&e);
fd_check(&t, &[1.3, 2.4], 2, 1e-5);
}
#[test]
fn hessian_pow_half_matches_fd() {
let e = add(pow(var(0), cnst(0.5)), mul(var(0), var(1)));
let t = Tape::build(&e);
fd_check(&t, &[2.5, 1.1], 2, 1e-5);
}
// (x0 + x1)^2 + (x0 + x1) through a shared CSE body yields the full 2x2 lower triangle.
#[test]
fn hessian_sparsity_through_cse() {
let body = Rc::new(add(var(0), var(1)));
let e = add(
pow(Expr::Cse(body.clone()), cnst(2.0)),
Expr::Cse(body.clone()),
);
let t = Tape::build(&e);
let s = t.hessian_sparsity();
assert!(s.contains(&(0, 0)));
assert!(s.contains(&(1, 0)));
assert!(s.contains(&(1, 1)));
assert_eq!(s.len(), 3);
}
}