use echidna::{grad, Dual, Scalar};
#[cfg(feature = "bytecode")]
use echidna::{record, BReverse};
fn finite_diff_grad(f: impl Fn(&[f64]) -> f64, x: &[f64], h: f64) -> Vec<f64> {
let n = x.len();
let mut grad = vec![0.0; n];
for i in 0..n {
let mut xp = x.to_vec();
let mut xm = x.to_vec();
xp[i] += h;
xm[i] -= h;
grad[i] = (f(&xp) - f(&xm)) / (2.0 * h);
}
grad
}
fn forward_grad(f: impl Fn(&[Dual<f64>]) -> Dual<f64>, x: &[f64]) -> Vec<f64> {
let n = x.len();
let mut grad = vec![0.0; n];
for i in 0..n {
let inputs: Vec<Dual<f64>> = x
.iter()
.enumerate()
.map(|(k, &xi)| {
if k == i {
Dual::variable(xi)
} else {
Dual::constant(xi)
}
})
.collect();
grad[i] = f(&inputs).eps;
}
grad
}
fn rosenbrock<T: Scalar>(x: &[T]) -> T {
let one = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(1.0).unwrap());
let hundred = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(100.0).unwrap());
let mut sum = T::zero();
for i in 0..x.len() - 1 {
let t1 = one - x[i];
let t2 = x[i + 1] - x[i] * x[i];
sum = sum + t1 * t1 + hundred * t2 * t2;
}
sum
}
fn beale<T: Scalar>(x: &[T]) -> T {
let x0 = x[0];
let x1 = x[1];
let c1 = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(1.5).unwrap());
let c2 = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(2.25).unwrap());
let c3 = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(2.625).unwrap());
let t1 = c1 - x0 + x0 * x1;
let t2 = c2 - x0 + x0 * x1 * x1;
let t3 = c3 - x0 + x0 * x1 * x1 * x1;
t1 * t1 + t2 * t2 + t3 * t3
}
fn sphere<T: Scalar>(x: &[T]) -> T {
let mut sum = T::zero();
for &xi in x {
sum = sum + xi * xi;
}
sum
}
fn booth<T: Scalar>(x: &[T]) -> T {
let two = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(2.0).unwrap());
let five = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(5.0).unwrap());
let seven = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(7.0).unwrap());
let t1 = x[0] + two * x[1] - seven;
let t2 = two * x[0] + x[1] - five;
t1 * t1 + t2 * t2
}
fn cross_validate(
f_dual: impl Fn(&[Dual<f64>]) -> Dual<f64>,
f_rev: impl FnOnce(&[echidna::Reverse<f64>]) -> echidna::Reverse<f64>,
f_f64: impl Fn(&[f64]) -> f64,
x: &[f64],
label: &str,
) {
let fwd_grad = forward_grad(&f_dual, x);
let rev_grad = grad(f_rev, x);
let fd_grad = finite_diff_grad(&f_f64, x, 1e-7);
for i in 0..x.len() {
assert!(
(fwd_grad[i] - rev_grad[i]).abs() <= 1e-10 * fwd_grad[i].abs().max(1e-12),
"{} fwd vs rev, component {}: fwd={}, rev={}",
label,
i,
fwd_grad[i],
rev_grad[i]
);
}
for i in 0..x.len() {
let scale = fwd_grad[i].abs().max(1.0);
assert!(
(fwd_grad[i] - fd_grad[i]).abs() <= 1e-4 * scale,
"{} fwd vs fd, component {}: fwd={}, fd={}",
label,
i,
fwd_grad[i],
fd_grad[i]
);
}
}
#[test]
fn cross_validate_rosenbrock_2d() {
let x = [1.5, 2.0];
cross_validate(
|v| rosenbrock(v),
|v| rosenbrock(v),
|v| rosenbrock(v),
&x,
"rosenbrock-2d",
);
}
#[test]
fn cross_validate_rosenbrock_10d() {
let x: Vec<f64> = (0..10).map(|i| 0.5 + 0.1 * i as f64).collect();
cross_validate(
|v| rosenbrock(v),
|v| rosenbrock(v),
|v| rosenbrock(v),
&x,
"rosenbrock-10d",
);
}
#[test]
fn cross_validate_beale() {
let x = [1.0, 0.5];
cross_validate(|v| beale(v), |v| beale(v), |v| beale(v), &x, "beale");
}
#[test]
fn cross_validate_sphere_5d() {
let x = [1.0, -2.0, 3.0, -0.5, 0.7];
cross_validate(|v| sphere(v), |v| sphere(v), |v| sphere(v), &x, "sphere-5d");
}
#[test]
fn cross_validate_booth() {
let x = [2.0, 3.0];
cross_validate(|v| booth(v), |v| booth(v), |v| booth(v), &x, "booth");
}
fn trig_mix<T: Scalar>(x: &[T]) -> T {
x[0].sin() * x[1].exp() + x[0].cos() * x[1].ln()
}
#[test]
fn cross_validate_trig_mix() {
let x = [1.0, 2.0];
cross_validate(
|v| trig_mix(v),
|v| trig_mix(v),
|v| trig_mix(v),
&x,
"trig-mix",
);
}
fn ackley<T: Scalar>(x: &[T]) -> T {
let a = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(20.0).unwrap());
let b = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(0.2).unwrap());
let c = T::from_f(
<T::Float as num_traits::FromPrimitive>::from_f64(std::f64::consts::TAU).unwrap(),
);
let half = T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(0.5).unwrap());
let e_const =
T::from_f(<T::Float as num_traits::FromPrimitive>::from_f64(std::f64::consts::E).unwrap());
let sum_sq = x[0] * x[0] + x[1] * x[1];
let sum_cos = (c * x[0]).cos() + (c * x[1]).cos();
let term1 = (-(b) * (half * sum_sq).sqrt()).exp();
let term2 = (half * sum_cos).exp();
-a * term1 - term2 + a + e_const
}
#[test]
fn cross_validate_ackley() {
let x = [0.5, -0.3];
cross_validate(|v| ackley(v), |v| ackley(v), |v| ackley(v), &x, "ackley");
}
fn deep_nest<T: Scalar>(x: &[T]) -> T {
let mut y = x[0];
for _ in 0..10 {
y = y.sin().exp().ln_1p();
}
y
}
#[test]
fn cross_validate_deep_nest() {
let x = [0.5];
cross_validate(
|v| deep_nest(v),
|v| deep_nest(v),
|v| deep_nest(v),
&x,
"deep-nest",
);
}
fn pairwise_products<T: Scalar>(x: &[T]) -> T {
let mut sum = T::zero();
for i in 0..x.len() {
for j in i + 1..x.len() {
sum = sum + x[i] * x[j];
}
}
sum
}
#[test]
fn cross_validate_pairwise_products() {
let x = [1.0, 2.0, 3.0, 4.0, 5.0];
cross_validate(
|v| pairwise_products(v),
|v| pairwise_products(v),
|v| pairwise_products(v),
&x,
"pairwise-products",
);
}
fn logistic_chain<T: Scalar>(x: &[T]) -> T {
let one = T::one();
let mut y = x[0];
for _ in 0..5 {
y = one / (one + (-y).exp());
}
y
}
#[test]
fn cross_validate_logistic_chain() {
let x = [0.5];
cross_validate(
|v| logistic_chain(v),
|v| logistic_chain(v),
|v| logistic_chain(v),
&x,
"logistic-chain",
);
}
#[cfg(feature = "bytecode")]
fn bytecode_grad(f: impl FnOnce(&[BReverse<f64>]) -> BReverse<f64>, x: &[f64]) -> Vec<f64> {
let (mut tape, _) = record(f, x);
tape.gradient(x)
}
#[cfg(feature = "bytecode")]
fn cross_validate_all(
f_dual: impl Fn(&[Dual<f64>]) -> Dual<f64>,
f_rev: impl FnOnce(&[echidna::Reverse<f64>]) -> echidna::Reverse<f64>,
f_brev: impl FnOnce(&[BReverse<f64>]) -> BReverse<f64>,
f_f64: impl Fn(&[f64]) -> f64,
x: &[f64],
label: &str,
) {
let fwd_grad = forward_grad(&f_dual, x);
let rev_grad = grad(f_rev, x);
let btape_grad = bytecode_grad(f_brev, x);
let fd_grad = finite_diff_grad(&f_f64, x, 1e-7);
for i in 0..x.len() {
assert!(
(fwd_grad[i] - rev_grad[i]).abs() <= 1e-10 * fwd_grad[i].abs().max(1e-12),
"{} fwd vs rev, component {}: fwd={}, rev={}",
label,
i,
fwd_grad[i],
rev_grad[i]
);
assert!(
(rev_grad[i] - btape_grad[i]).abs() <= 1e-10 * rev_grad[i].abs().max(1e-12),
"{} rev vs btape, component {}: rev={}, btape={}",
label,
i,
rev_grad[i],
btape_grad[i]
);
let scale = fwd_grad[i].abs().max(1.0);
assert!(
(fwd_grad[i] - fd_grad[i]).abs() <= 1e-4 * scale,
"{} fwd vs fd, component {}: fwd={}, fd={}",
label,
i,
fwd_grad[i],
fd_grad[i]
);
}
}
#[cfg(feature = "bytecode")]
#[test]
fn cross_validate_all_rosenbrock_2d() {
let x = [1.5, 2.0];
cross_validate_all(
|v| rosenbrock(v),
|v| rosenbrock(v),
|v| rosenbrock(v),
|v| rosenbrock(v),
&x,
"all-rosenbrock-2d",
);
}
#[cfg(feature = "bytecode")]
#[test]
fn cross_validate_all_trig_mix() {
let x = [1.0, 2.0];
cross_validate_all(
|v| trig_mix(v),
|v| trig_mix(v),
|v| trig_mix(v),
|v| trig_mix(v),
&x,
"all-trig-mix",
);
}
#[cfg(feature = "bytecode")]
#[test]
fn cross_validate_all_deep_nest() {
let x = [0.5];
cross_validate_all(
|v| deep_nest(v),
|v| deep_nest(v),
|v| deep_nest(v),
|v| deep_nest(v),
&x,
"all-deep-nest",
);
}
#[cfg(feature = "bytecode")]
#[test]
fn cross_validate_all_logistic_chain() {
let x = [0.5];
cross_validate_all(
|v| logistic_chain(v),
|v| logistic_chain(v),
|v| logistic_chain(v),
|v| logistic_chain(v),
&x,
"all-logistic-chain",
);
}
#[cfg(feature = "bytecode")]
fn finite_diff_hvp(
tape: &mut echidna::BytecodeTape<f64>,
x: &[f64],
v: &[f64],
h: f64,
) -> Vec<f64> {
let n = x.len();
let mut xp = x.to_vec();
let mut xm = x.to_vec();
for i in 0..n {
xp[i] = x[i] + h * v[i];
xm[i] = x[i] - h * v[i];
}
let gp = tape.gradient(&xp);
let gm = tape.gradient(&xm);
(0..n).map(|i| (gp[i] - gm[i]) / (2.0 * h)).collect()
}
#[cfg(feature = "bytecode")]
fn cross_validate_hvp(
f: impl FnOnce(&[BReverse<f64>]) -> BReverse<f64>,
x: &[f64],
v: &[f64],
label: &str,
) {
let (mut tape, _) = record(f, x);
let (_, analytic_hv) = tape.hvp(x, v);
let fd_hv = finite_diff_hvp(&mut tape, x, v, 1e-5);
for i in 0..x.len() {
let scale = analytic_hv[i].abs().max(1.0);
assert!(
(analytic_hv[i] - fd_hv[i]).abs() <= 1e-4 * scale,
"{} hvp vs fd, component {}: analytic={}, fd={}",
label,
i,
analytic_hv[i],
fd_hv[i]
);
}
}
#[cfg(feature = "bytecode")]
#[test]
fn hvp_cross_validate_rosenbrock() {
let x = [1.5, 2.0];
let v = [0.7, -0.3];
cross_validate_hvp(|v| rosenbrock(v), &x, &v, "rosenbrock-hvp");
}
#[cfg(feature = "bytecode")]
#[test]
fn hvp_cross_validate_beale() {
let x = [1.0, 0.5];
let v = [1.0, -1.0];
cross_validate_hvp(|v| beale(v), &x, &v, "beale-hvp");
}
#[cfg(feature = "bytecode")]
#[test]
fn hvp_cross_validate_sphere() {
let x = [1.0, -2.0, 3.0, -0.5, 0.7];
let v = [0.1, 0.2, -0.3, 0.4, -0.5];
cross_validate_hvp(|v| sphere(v), &x, &v, "sphere-hvp");
}
#[cfg(feature = "bytecode")]
#[test]
fn hvp_cross_validate_booth() {
let x = [2.0, 3.0];
let v = [1.0, 1.0];
cross_validate_hvp(|v| booth(v), &x, &v, "booth-hvp");
}
#[cfg(feature = "bytecode")]
#[test]
fn hvp_cross_validate_trig_mix() {
let x = [1.0, 2.0];
let v = [0.5, -0.5];
cross_validate_hvp(|v| trig_mix(v), &x, &v, "trig-mix-hvp");
}
fn finite_diff_grad_f32(f: impl Fn(&[f32]) -> f32, x: &[f32], h: f32) -> Vec<f32> {
let n = x.len();
let mut grad = vec![0.0f32; n];
for i in 0..n {
let mut xp = x.to_vec();
let mut xm = x.to_vec();
xp[i] += h;
xm[i] -= h;
grad[i] = (f(&xp) - f(&xm)) / (2.0 * h);
}
grad
}
fn forward_grad_f32(f: impl Fn(&[Dual<f32>]) -> Dual<f32>, x: &[f32]) -> Vec<f32> {
let n = x.len();
let mut grad = vec![0.0f32; n];
for i in 0..n {
let inputs: Vec<Dual<f32>> = x
.iter()
.enumerate()
.map(|(k, &xi)| {
if k == i {
Dual::variable(xi)
} else {
Dual::constant(xi)
}
})
.collect();
grad[i] = f(&inputs).eps;
}
grad
}
fn cross_validate_f32(
f_dual: impl Fn(&[Dual<f32>]) -> Dual<f32>,
f_rev: impl FnOnce(&[echidna::Reverse<f32>]) -> echidna::Reverse<f32>,
f_f32: impl Fn(&[f32]) -> f32,
x: &[f32],
label: &str,
) {
let fwd_grad = forward_grad_f32(&f_dual, x);
let rev_grad = echidna::grad(f_rev, x);
let fd_grad = finite_diff_grad_f32(&f_f32, x, 1e-4);
for i in 0..x.len() {
let scale = fwd_grad[i].abs().max(1e-6);
assert!(
(fwd_grad[i] - rev_grad[i]).abs() <= 1e-5 * scale,
"f32 {label} fwd vs rev, component {i}: fwd={}, rev={}",
fwd_grad[i],
rev_grad[i]
);
}
for i in 0..x.len() {
let scale = fwd_grad[i].abs().max(1.0);
assert!(
(fwd_grad[i] - fd_grad[i]).abs() <= 1e-2 * scale,
"f32 {label} fwd vs fd, component {i}: fwd={}, fd={}",
fwd_grad[i],
fd_grad[i]
);
}
}
#[test]
fn cross_validate_f32_rosenbrock() {
let x: Vec<f32> = vec![1.5, 2.0];
cross_validate_f32(
|v| rosenbrock(v),
|v| rosenbrock(v),
|v| rosenbrock(v),
&x,
"rosenbrock-f32",
);
}
#[test]
fn cross_validate_f32_trig_mix() {
let x: Vec<f32> = vec![1.0, 2.0];
cross_validate_f32(
|v| trig_mix(v),
|v| trig_mix(v),
|v| trig_mix(v),
&x,
"trig-mix-f32",
);
}
#[test]
fn cross_validate_f32_sphere() {
let x: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
cross_validate_f32(
|v| sphere(v),
|v| sphere(v),
|v| sphere(v),
&x,
"sphere-5d-f32",
);
}
#[cfg(feature = "stde")]
#[test]
fn f32_diagonal_kth_order_boundary() {
use echidna::BReverse;
let f = |x: &[BReverse<f32>]| x[0] * x[0] + x[1] * x[1];
let x: Vec<f32> = vec![1.0, 2.0];
let (tape, _) = echidna::record(f, &x);
let (_, diag_k2) = echidna::stde::diagonal_kth_order(&tape, &x, 2);
assert!(
(diag_k2[0] - 2.0).abs() < 0.01,
"k=2 f32 diag[0] = {}",
diag_k2[0]
);
assert!(
(diag_k2[1] - 2.0).abs() < 0.01,
"k=2 f32 diag[1] = {}",
diag_k2[1]
);
}