#![cfg(feature = "bytecode")]
use approx::assert_relative_eq;
use echidna::{record, record_multi, BReverse, Scalar};
use num_traits::Float;
/// Multiplying two recorded constants (`2 * 3`) must be folded at record time,
/// so `x * (2 * 3)` yields exactly as many tape ops as `x * 6` does.
#[test]
fn constant_folding_reduces_ops() {
    let (mut tape, val) = record(
        |x| {
            let two = BReverse::constant(2.0);
            let three = BReverse::constant(3.0);
            let six = two * three;
            x[0] * six
        },
        &[5.0_f64],
    );
    assert_relative_eq!(val, 30.0, max_relative = 1e-12);

    // Reference tape in which the constant is supplied already folded.
    let (reference, _) = record(|x| x[0] * BReverse::constant(6.0), &[5.0_f64]);
    assert_eq!(
        tape.num_ops(),
        reference.num_ops(),
        "constant subexpression 2 * 3 should be folded away at record time"
    );

    // The folded constant must still scale the derivative: d/dx (6x) = 6.
    let g = tape.gradient(&[5.0]);
    assert_relative_eq!(g[0], 6.0, max_relative = 1e-12);
}
/// `3^2` involves only constants, so it can be folded before it meets the
/// input; the recorded function is `x + 9`, with unit derivative.
#[test]
fn constant_folding_powi() {
    let input = [1.0_f64];
    let (mut tape, value) = record(
        |x| {
            let base = BReverse::constant(3.0);
            let squared = base.powi(2);
            x[0] + squared
        },
        &input,
    );
    assert_relative_eq!(value, 10.0, max_relative = 1e-12);

    let grad = tape.gradient(&input);
    assert_relative_eq!(grad[0], 1.0, max_relative = 1e-12);
}
/// Folding must never touch ops that depend on an input: the gradient of
/// `x * x` has to follow the evaluation point (2x) instead of staying frozen.
#[test]
fn constant_folding_preserves_input_ops() {
    let (mut tape, square) = record(|x| x[0] * x[0], &[3.0_f64]);
    assert_relative_eq!(square, 9.0, max_relative = 1e-12);

    for (point, slope) in [(3.0_f64, 6.0_f64), (5.0, 10.0)] {
        let grad = tape.gradient(&[point]);
        assert_relative_eq!(grad[0], slope, max_relative = 1e-12);
    }
}
/// `sin(x)` and `exp(x)` are computed but never reach the output. DCE must
/// drop them and leave the live `x * x` path, and its gradient, intact.
#[test]
fn dce_removes_unused_intermediates() {
    let (mut tape, val) = record(
        |x| {
            let _unused = x[0].sin();
            let _also_unused = x[0].exp();
            x[0] * x[0]
        },
        &[3.0_f64],
    );
    assert_relative_eq!(val, 9.0, max_relative = 1e-12);

    let before = tape.num_ops();
    tape.dead_code_elimination();
    let after = tape.num_ops();
    assert!(
        after < before,
        "DCE should reduce tape size: before={}, after={}",
        before,
        after
    );

    let grad = tape.gradient(&[3.0]);
    assert_relative_eq!(grad[0], 6.0, max_relative = 1e-12);
}
/// DCE must keep every input slot, including inputs the output never reads,
/// so the tape's calling convention (one value per recorded input) is unchanged.
#[test]
fn dce_preserves_all_inputs() {
    let (mut tape, val) = record(|x| x[0] * x[0], &[3.0_f64, 4.0]);
    assert_relative_eq!(val, 9.0, max_relative = 1e-12);

    tape.dead_code_elimination();
    assert_eq!(tape.num_inputs(), 2, "DCE must preserve all inputs");

    // The tape must still accept a full-length input vector after DCE, and
    // the unused second input must get a zero partial derivative.
    let g = tape.gradient(&[3.0, 4.0]);
    assert_eq!(g.len(), 2);
    assert_relative_eq!(g[0], 6.0, max_relative = 1e-12);
    assert_relative_eq!(g[1], 0.0, max_relative = 1e-12);
}
/// Two separately recorded `x * x` nodes are structurally identical, so CSE
/// should merge them. The gradient of `2x^2` (4x) must survive the merge.
#[test]
fn cse_deduplicates_common_subexpressions() {
    let (mut tape, val) = record(
        |x| {
            let first_square = x[0] * x[0];
            let second_square = x[0] * x[0];
            first_square + second_square
        },
        &[3.0_f64],
    );
    assert_relative_eq!(val, 18.0, max_relative = 1e-12);

    let before = tape.num_ops();
    tape.cse();
    let after = tape.num_ops();
    assert!(
        after < before,
        "CSE should reduce tape size: before={}, after={}",
        before,
        after
    );

    let grad = tape.gradient(&[3.0]);
    assert_relative_eq!(grad[0], 12.0, max_relative = 1e-12);
}
/// `x0 * x1` and `x1 * x0` differ only in operand order. Multiplication is
/// commutative, so CSE should treat them as one node.
#[test]
fn cse_commutative_order() {
    let point = [2.0_f64, 3.0];
    let (mut tape, val) = record(
        |x| {
            let lhs_first = x[0] * x[1];
            let rhs_first = x[1] * x[0];
            lhs_first + rhs_first
        },
        &point,
    );
    assert_relative_eq!(val, 12.0, max_relative = 1e-12);

    let before = tape.num_ops();
    tape.cse();
    let after = tape.num_ops();
    assert!(
        after < before,
        "CSE should deduplicate commutative ops: before={}, after={}",
        before,
        after
    );

    // f = 2 * x0 * x1  =>  grad = (2 x1, 2 x0)
    let grad = tape.gradient(&point);
    for (i, &want) in [6.0_f64, 4.0].iter().enumerate() {
        assert_relative_eq!(grad[i], want, max_relative = 1e-12);
    }
}
/// CSE must not merge `x0 - x1` with `x1 - x0`. Subtraction is not
/// commutative, so `(x0 - x1) * (x1 - x0)` must keep its original gradient.
#[test]
fn cse_non_commutative_preserved() {
    let point = [5.0_f64, 3.0];
    let (mut tape, val) = record(
        |x| {
            let forward_diff = x[0] - x[1];
            let backward_diff = x[1] - x[0];
            forward_diff * backward_diff
        },
        &point,
    );
    assert_relative_eq!(val, -4.0, max_relative = 1e-12);

    tape.cse();
    let grad = tape.gradient(&point);
    assert_relative_eq!(grad[0], -4.0, max_relative = 1e-12);
    assert_relative_eq!(grad[1], 4.0, max_relative = 1e-12);
}
/// After the dead `cos(x)` is removed, the gradient of `x sin x` must still be
/// `sin x + x cos x`.
#[test]
fn gradient_correct_after_dce() {
    let x0 = 1.5_f64;
    let (mut tape, _) = record(
        |x| {
            let _dead = x[0].cos();
            x[0].sin() * x[0]
        },
        &[x0],
    );
    tape.dead_code_elimination();

    let expected = x0.sin() + x0 * x0.cos();
    let grad = tape.gradient(&[x0]);
    assert_relative_eq!(grad[0], expected, max_relative = 1e-12);
}
/// Merging the two `sin(x)` nodes turns the product into `sin(x)^2`. Its
/// derivative, `2 sin x cos x`, must be unchanged by CSE.
#[test]
fn gradient_correct_after_cse() {
    let x0 = 1.0_f64;
    let (mut tape, _) = record(
        |x| {
            let sine = x[0].sin();
            let sine_again = x[0].sin();
            sine * sine_again
        },
        &[x0],
    );
    tape.cse();

    let expected = 2.0 * x0.sin() * x0.cos();
    let grad = tape.gradient(&[x0]);
    assert_relative_eq!(grad[0], expected, max_relative = 1e-12);
}
/// `optimize` must not change the Rosenbrock value or gradient at the
/// recording point. The optimised tape must also re-evaluate correctly at a
/// fresh point, checked against the closed form and a central finite difference.
#[test]
fn optimize_rosenbrock() {
    let x = [1.5_f64, 2.0];
    let (mut tape, _) = record(
        |v| {
            let one = BReverse::constant(1.0);
            let hundred = BReverse::constant(100.0);
            let t1 = one - v[0];
            let t2 = v[1] - v[0] * v[0];
            t1 * t1 + hundred * t2 * t2
        },
        &x,
    );

    let grad_before = tape.gradient(&x);
    let value_before = tape.output_value();
    tape.optimize();
    let grad_after = tape.gradient(&x);
    let value_after = tape.output_value();

    assert_relative_eq!(value_before, value_after, max_relative = 1e-12);
    for (&b, &a) in grad_before.iter().zip(grad_after.iter()).take(x.len()) {
        assert_relative_eq!(b, a, max_relative = 1e-12);
    }

    // Re-evaluate the optimised tape away from the recording point.
    let x2 = [0.5_f64, 1.0];
    let grad2 = tape.gradient(&x2);
    let value2 = tape.output_value();
    let closed_form = (1.0 - x2[0]).powi(2) + 100.0 * (x2[1] - x2[0] * x2[0]).powi(2);
    assert_relative_eq!(value2, closed_form, max_relative = 1e-12);

    let h = 1e-7;
    for (i, &analytic) in grad2.iter().enumerate().take(x2.len()) {
        let mut plus = x2;
        let mut minus = x2;
        plus[i] += h;
        minus[i] -= h;

        tape.forward(&plus);
        let f_plus = tape.output_value();
        tape.forward(&minus);
        let f_minus = tape.output_value();

        let central = (f_plus - f_minus) / (2.0 * h);
        assert_relative_eq!(analytic, central, max_relative = 1e-5);
    }
}
/// `optimize` should remove the dead `exp`/`cos` nodes and merge the duplicate
/// `sin` nodes. A smaller tape is only useful if it still computes `2 sin x`,
/// so the value and the gradient are checked afterwards as well.
#[test]
fn optimize_reduces_tape_size() {
    let (mut tape, val) = record(
        |x| {
            let _dead1 = x[0].exp();
            let _dead2 = x[0].cos();
            let a = x[0].sin();
            let b = x[0].sin();
            a + b
        },
        &[1.0_f64],
    );
    assert_relative_eq!(val, 2.0 * 1.0_f64.sin(), max_relative = 1e-12);

    let ops_before = tape.num_ops();
    tape.optimize();
    let ops_after = tape.num_ops();
    assert!(
        ops_after < ops_before,
        "optimize should reduce tape size: before={}, after={}",
        ops_before,
        ops_after
    );

    // Guard against an "optimisation" that shrinks the tape by dropping live ops:
    // d/dx (2 sin x) = 2 cos x.
    let g = tape.gradient(&[1.0]);
    assert_relative_eq!(g[0], 2.0 * 1.0_f64.cos(), max_relative = 1e-12);
    assert_relative_eq!(tape.output_value(), 2.0 * 1.0_f64.sin(), max_relative = 1e-12);
}
/// `optimize` on a two-output tape (`v0 + v1`, `v0 * v1`, plus a dead `sin`)
/// must leave the Jacobian unchanged, and that Jacobian must equal the
/// analytic one, `[[1, 1], [v1, v0]]`.
#[test]
fn optimize_preserves_multi_output_correctness() {
    let x = [2.0_f64, 3.0];
    let (mut tape, values) = record_multi(
        |v| {
            let sum = v[0] + v[1];
            let prod = v[0] * v[1];
            let _dead = v[0].sin();
            vec![sum, prod]
        },
        &x,
    );
    assert_relative_eq!(values[0], 5.0, max_relative = 1e-12);
    assert_relative_eq!(values[1], 6.0, max_relative = 1e-12);

    let jac_before = tape.jacobian(&x);
    tape.optimize();
    let jac_after = tape.jacobian(&x);

    // Comparing before/after alone would accept a Jacobian that is wrong both
    // times, so also pin the analytic values (row = output, column = input).
    let expected = [[1.0, 1.0], [x[1], x[0]]];
    for i in 0..2 {
        for j in 0..2 {
            assert_relative_eq!(jac_before[i][j], jac_after[i][j], max_relative = 1e-12);
            assert_relative_eq!(jac_after[i][j], expected[i][j], max_relative = 1e-12);
        }
    }
}
/// `x + 0` and `0 + x` are both the identity, so the sum is `2x` with slope 2
/// at every evaluation point.
#[test]
fn algebraic_add_zero() {
    let (mut tape, val) = record(
        |x| {
            let right_zero = x[0] + 0.0_f64;
            let left_zero = 0.0_f64 + x[0];
            right_zero + left_zero
        },
        &[3.0_f64],
    );
    assert_relative_eq!(val, 6.0, max_relative = 1e-12);

    for point in [3.0_f64, 5.0] {
        let grad = tape.gradient(&[point]);
        assert_relative_eq!(grad[0], 2.0, max_relative = 1e-12);
    }
}
/// `x * 1` and `1 * x` are both the identity, so the sum is `2x`.
#[test]
fn algebraic_mul_one() {
    let (mut tape, val) = record(
        |x| {
            let right_one = x[0] * 1.0_f64;
            let left_one = 1.0_f64 * x[0];
            right_one + left_one
        },
        &[4.0_f64],
    );
    assert_relative_eq!(val, 8.0, max_relative = 1e-12);

    let grad = tape.gradient(&[4.0]);
    assert_relative_eq!(grad[0], 2.0, max_relative = 1e-12);
}
/// `x - 0` is the identity: the value passes through with unit slope.
#[test]
fn algebraic_sub_zero() {
    let input = [7.0_f64];
    let (mut tape, val) = record(|x| x[0] - 0.0_f64, &input);
    assert_relative_eq!(val, 7.0, max_relative = 1e-12);

    let grad = tape.gradient(&input);
    assert_relative_eq!(grad[0], 1.0, max_relative = 1e-12);
}
/// `x / 1` is the identity: the value passes through with unit slope.
#[test]
fn algebraic_div_one() {
    let input = [5.0_f64];
    let (mut tape, val) = record(|x| x[0] / 1.0_f64, &input);
    assert_relative_eq!(val, 5.0, max_relative = 1e-12);

    let grad = tape.gradient(&input);
    assert_relative_eq!(grad[0], 1.0, max_relative = 1e-12);
}
/// Multiplying a finite input by zero, on either side, gives zero with zero
/// slope.
#[test]
fn algebraic_mul_zero() {
    let (mut tape, val) = record(
        |x| {
            let right_zero = x[0] * 0.0_f64;
            let left_zero = 0.0_f64 * x[0];
            right_zero + left_zero
        },
        &[42.0_f64],
    );
    assert_relative_eq!(val, 0.0, max_relative = 1e-12);

    let grad = tape.gradient(&[42.0]);
    assert_relative_eq!(grad[0], 0.0, max_relative = 1e-12);
}
/// `x - x` is identically zero for finite `x`, so value and slope are both 0.
#[test]
fn algebraic_sub_self() {
    let input = [5.0_f64];
    let (mut tape, val) = record(|x| x[0] - x[0], &input);
    assert_relative_eq!(val, 0.0, max_relative = 1e-12);

    let grad = tape.gradient(&input);
    assert_relative_eq!(grad[0], 0.0, max_relative = 1e-12);
}
/// `x / x` is identically one for nonzero finite `x`, so its slope is 0.
#[test]
fn algebraic_div_self() {
    let input = [3.0_f64];
    let (mut tape, val) = record(|x| x[0] / x[0], &input);
    assert_relative_eq!(val, 1.0, max_relative = 1e-12);

    let grad = tape.gradient(&input);
    assert_relative_eq!(grad[0], 0.0, max_relative = 1e-12);
}
/// `x^0` is the constant 1, so the slope is 0.
#[test]
fn algebraic_powi_zero() {
    let input = [7.0_f64];
    let (mut tape, val) = record(|x| x[0].powi(0), &input);
    assert_relative_eq!(val, 1.0, max_relative = 1e-12);

    let grad = tape.gradient(&input);
    assert_relative_eq!(grad[0], 0.0, max_relative = 1e-12);
}
/// `x^1` is the identity: unit slope at the recording point and at a
/// different evaluation point.
#[test]
fn algebraic_powi_one() {
    let (mut tape, val) = record(|x| x[0].powi(1), &[5.0_f64]);
    assert_relative_eq!(val, 5.0, max_relative = 1e-12);

    for point in [5.0_f64, 3.0] {
        let grad = tape.gradient(&[point]);
        assert_relative_eq!(grad[0], 1.0, max_relative = 1e-12);
    }
}
/// `x^-1` is the reciprocal, with derivative `-1 / x^2`. Checked at the
/// recording point and at a different point.
#[test]
fn algebraic_powi_neg_one() {
    let (mut tape, val) = record(|x| x[0].powi(-1), &[2.0_f64]);
    assert_relative_eq!(val, 0.5, max_relative = 1e-12);

    for (point, slope) in [(2.0_f64, -0.25_f64), (4.0, -1.0 / 16.0)] {
        let grad = tape.gradient(&[point]);
        assert_relative_eq!(grad[0], slope, max_relative = 1e-12);
    }
}
/// `NaN * 0` is NaN under IEEE-754, so a `y * 0 -> 0` rewrite must not fire
/// when `y` is NaN. Here the NaN comes from `(x - x) / (x - x)`.
#[test]
fn algebraic_nan_mul_zero_guard() {
    let (tape, val) = record(
        |x| {
            let diff = x[0] - x[0];
            let quotient = diff / diff;
            quotient * 0.0_f64
        },
        &[1.0_f64],
    );
    assert!(val.is_nan(), "NaN * 0.0 should produce NaN, not be folded");
    assert!(tape.num_ops() > 2, "NaN * 0 should not be folded away");
}
/// `exp(exp(1000))` overflows to +Inf, and `Inf - Inf` is NaN, so a
/// `y - y -> 0` rewrite must not be applied here.
#[test]
fn algebraic_inf_sub_self_guard() {
    let (tape, val) = record(
        |x| {
            let overflowed = x[0].exp().exp();
            overflowed - overflowed
        },
        &[1000.0_f64],
    );
    assert!(
        val.is_nan(),
        "Inf - Inf should produce NaN, not be folded to 0"
    );
    assert!(tape.num_ops() > 2);
}
/// `0 / 0` is NaN, so a `y / y -> 1` rewrite must not be applied when `y`
/// evaluates to zero.
#[test]
fn algebraic_zero_div_self_guard() {
    let (tape, val) = record(
        |x| {
            let vanishing = x[0] - x[0];
            vanishing / vanishing
        },
        &[5.0_f64],
    );
    assert!(val.is_nan(), "0/0 should produce NaN, not be folded to 1");
    assert!(tape.num_ops() > 2);
}
/// Three chained `+ 0.0` additions should be simplified away during recording.
/// NOTE(review): the bound of 7 presumably corresponds to the unsimplified
/// size (input + three constants + three adds); confirm against the tape layout.
#[test]
fn algebraic_tape_size_reduction() {
    // Left fold, so the ops are recorded as ((x + 0) + 0) + 0, the same as
    // writing the chain out by hand.
    let (tape, val) = record(
        |x| (0..3).fold(x[0], |acc, _| acc + 0.0_f64),
        &[5.0_f64],
    );
    assert_relative_eq!(val, 5.0, max_relative = 1e-12);

    let op_count = tape.num_ops();
    assert!(
        op_count < 7,
        "algebraic simplification should reduce tape: got {} ops",
        op_count
    );
}
/// After `x + 0` is simplified, the tape must still track new inputs: a
/// forward pass at 5 yields 5, and the slope stays 1.
#[test]
fn algebraic_reeval_after_simplify() {
    let (mut tape, val) = record(|x| x[0] + 0.0_f64, &[3.0_f64]);
    assert_relative_eq!(val, 3.0, max_relative = 1e-12);

    let fresh = [5.0_f64];
    tape.forward(&fresh);
    assert_relative_eq!(tape.output_value(), 5.0, max_relative = 1e-12);

    let grad = tape.gradient(&fresh);
    assert_relative_eq!(grad[0], 1.0, max_relative = 1e-12);
}
/// Once `x + 0` is rewritten to `x`, the product is just `x * x`. `optimize`
/// must not grow the tape, and the gradient must stay 2x.
#[test]
fn algebraic_enables_cse() {
    let (mut tape, val) = record(
        |x| {
            let shifted = x[0] + 0.0_f64;
            let plain = x[0];
            shifted * plain
        },
        &[3.0_f64],
    );
    assert_relative_eq!(val, 9.0, max_relative = 1e-12);

    let before = tape.num_ops();
    tape.optimize();
    assert!(tape.num_ops() <= before);

    let grad = tape.gradient(&[3.0]);
    assert_relative_eq!(grad[0], 6.0, max_relative = 1e-12);
}
/// Rosenbrock is written with plain `f64` literals mixed into `BReverse`
/// arithmetic. At several points, including the minimum (1, 1) where the
/// gradient vanishes, the reverse-mode gradient must match a central finite
/// difference.
#[test]
fn algebraic_rosenbrock() {
    fn rosenbrock(v: &[BReverse<f64>]) -> BReverse<f64> {
        let t1 = 1.0_f64 - v[0];
        let t2 = v[1] - v[0] * v[0];
        t1 * t1 + 100.0_f64 * t2 * t2
    }

    let points = [[1.5, 2.0], [0.0, 0.0], [-1.0, 1.0], [1.0, 1.0]];
    let h = 1e-7;
    for point in points.iter() {
        let (mut tape, _) = record(|v| rosenbrock(v), point);
        let grad = tape.gradient(point);
        for axis in 0..2 {
            let mut plus = *point;
            let mut minus = *point;
            plus[axis] += h;
            minus[axis] -= h;

            tape.forward(&plus);
            let f_plus = tape.output_value();
            tape.forward(&minus);
            let f_minus = tape.output_value();

            let central = (f_plus - f_minus) / (2.0 * h);
            assert_relative_eq!(grad[axis], central, max_relative = 1e-5, epsilon = 1e-10);
        }
    }
}
/// Merging the duplicate `x0 * x1` products makes the two `+ x2` nodes
/// identical as well, so CSE should collapse the whole chain. The gradient of
/// `2 (x0 x1 + x2)` must survive.
#[test]
fn cse_deep_chains() {
    let point = [2.0_f64, 3.0, 1.0];
    let (mut tape, val) = record(
        |x| {
            let p = x[0] * x[1];
            let q = x[0] * x[1];
            let left = p + x[2];
            let right = q + x[2];
            left + right
        },
        &point,
    );
    assert_relative_eq!(val, 14.0, max_relative = 1e-12);

    let before = tape.num_ops();
    tape.cse();
    let after = tape.num_ops();
    assert!(
        after < before,
        "deep CSE should reduce tape: before={}, after={}",
        before,
        after
    );

    // grad = (2 x1, 2 x0, 2)
    let grad = tape.gradient(&point);
    for (i, &want) in [6.0_f64, 4.0, 2.0].iter().enumerate() {
        assert_relative_eq!(grad[i], want, max_relative = 1e-12);
    }
}
/// CSE must treat the `powi` exponent as part of the node's identity. The two
/// `x^3` nodes merge, `x^2` stays separate, and the gradient is `6x^2 + 2x`.
#[test]
fn cse_powi_dedup_and_distinct() {
    let (mut tape, val) = record(
        |x| {
            let cube = x[0].powi(3);
            let cube_again = x[0].powi(3);
            let square = x[0].powi(2);
            cube + cube_again + square
        },
        &[2.0_f64],
    );
    assert_relative_eq!(val, 20.0, max_relative = 1e-12);

    let before = tape.num_ops();
    tape.cse();
    let after = tape.num_ops();
    assert!(
        after < before,
        "CSE should dedup identical powi: before={}, after={}",
        before,
        after
    );

    let grad = tape.gradient(&[2.0]);
    assert_relative_eq!(grad[0], 28.0, max_relative = 1e-12);
}
/// CSE on a multi-output tape: the two `x0 * x1` products feed different
/// outputs, and merging them must leave the full 2x3 Jacobian unchanged.
#[test]
fn cse_preserves_multi_output() {
    fn shared_sub<T: Scalar>(x: &[T]) -> Vec<T> {
        let product = x[0] * x[1];
        let product_again = x[0] * x[1];
        vec![product + x[2], product_again * x[2]]
    }

    let x = [2.0_f64, 3.0, 4.0];
    let (mut tape, outputs) = record_multi(|v| shared_sub(v), &x);
    assert_relative_eq!(outputs[0], 10.0, max_relative = 1e-12);
    assert_relative_eq!(outputs[1], 24.0, max_relative = 1e-12);

    let jac_before = tape.jacobian(&x);
    tape.cse();
    let jac_after = tape.jacobian(&x);
    for row in 0..2 {
        for col in 0..3 {
            assert_relative_eq!(
                jac_before[row][col],
                jac_after[row][col],
                max_relative = 1e-12,
                epsilon = 1e-14
            );
        }
    }
}
/// Targeted DCE that keeps only the first output (`v0 + v1`) should discard
/// the product and sine branches. The remaining gradient is then `(1, 1)`.
#[test]
fn targeted_dce_prunes_unused_output() {
    let x = [2.0_f64, 3.0];
    let (mut tape, values) = record_multi(
        |v| {
            let total = v[0] + v[1];
            let product = v[0] * v[1];
            let sine = v[0].sin();
            vec![total, product, sine]
        },
        &x,
    );
    assert_eq!(values.len(), 3);

    let before = tape.num_ops();
    let outputs: Vec<u32> = tape.all_output_indices().to_vec();
    assert_eq!(outputs.len(), 3);

    let keep = [outputs[0]];
    tape.dead_code_elimination_for_outputs(&keep);
    let after = tape.num_ops();
    assert!(
        after < before,
        "targeted DCE should reduce tape: before={}, after={}",
        before,
        after
    );

    let grad = tape.gradient(&x);
    assert_relative_eq!(grad[0], 1.0, max_relative = 1e-12);
    assert_relative_eq!(grad[1], 1.0, max_relative = 1e-12);
}
/// Targeted DCE must keep the selected output (`v0 * v1`) fully functional.
/// Its gradient `(v1, v0)` is checked at the recording point and at a new one.
#[test]
fn targeted_dce_preserves_active_output() {
    let x = [3.0_f64, 4.0];
    let (mut tape, _) = record_multi(
        |v| {
            let product = v[0] * v[1];
            let total = v[0] + v[1];
            vec![product, total]
        },
        &x,
    );

    let first_output: u32 = tape.all_output_indices().to_vec()[0];
    tape.dead_code_elimination_for_outputs(&[first_output]);

    let cases = [([3.0_f64, 4.0], [4.0_f64, 3.0]), ([5.0, 6.0], [6.0, 5.0])];
    for (point, expected) in cases.iter() {
        let grad = tape.gradient(point);
        assert_relative_eq!(grad[0], expected[0], max_relative = 1e-12);
        assert_relative_eq!(grad[1], expected[1], max_relative = 1e-12);
    }
}
/// Identity wrappers (`+ 0`, `* 1`, `- 0`) around `x * y` are removed by
/// `optimize`. The optimised tape's gradient must still match central finite
/// differences at several points, including one with a zero coordinate.
#[test]
fn optimize_with_algebraic_simplification() {
    fn wrapped_product(v: &[BReverse<f64>]) -> BReverse<f64> {
        let (x, y) = (v[0], v[1]);
        let x_id = x + 0.0_f64;
        let y_id = y * 1.0_f64;
        let product = x_id * y_id;
        product - 0.0_f64
    }

    let (mut tape, val) = record(|v| wrapped_product(v), &[3.0_f64, 4.0]);
    assert_relative_eq!(val, 12.0, max_relative = 1e-12);
    tape.optimize();

    let points = [[3.0, 4.0], [1.0, 2.0], [0.0, 5.0], [-1.0, -3.0]];
    let h = 1e-7;
    for point in points.iter() {
        let grad = tape.gradient(point);
        for axis in 0..2 {
            let mut plus = *point;
            let mut minus = *point;
            plus[axis] += h;
            minus[axis] -= h;

            tape.forward(&plus);
            let f_plus = tape.output_value();
            tape.forward(&minus);
            let f_minus = tape.output_value();

            let central = (f_plus - f_minus) / (2.0 * h);
            assert_relative_eq!(grad[axis], central, max_relative = 1e-5);
        }
    }
}
/// Exercises algebraic simplification, DCE and CSE together on
/// `x*x + x` (with identity wrappers and a dead `cos`). The tape must shrink,
/// and the gradient must equal `2x + 1` both analytically and by finite
/// differences.
#[test]
fn all_optimizations_combined() {
    fn combined(v: &[BReverse<f64>]) -> BReverse<f64> {
        let x = v[0];
        let plus_zero = x + 0.0_f64;
        let times_one = x * 1.0_f64;
        let _dead = x.cos();
        let square = plus_zero * times_one;
        let shifted = square + x;
        shifted / 1.0_f64
    }

    let (mut tape, val) = record(|v| combined(v), &[3.0_f64]);
    assert_relative_eq!(val, 12.0, max_relative = 1e-12);

    let before = tape.num_ops();
    tape.optimize();
    assert!(tape.num_ops() < before);

    let h = 1e-7;
    for &x in [3.0_f64, -2.0, 0.0, 10.0].iter() {
        let grad = tape.gradient(&[x]);
        assert_relative_eq!(grad[0], 2.0 * x + 1.0, max_relative = 1e-12);

        tape.forward(&[x + h]);
        let f_plus = tape.output_value();
        tape.forward(&[x - h]);
        let f_minus = tape.output_value();

        let central = (f_plus - f_minus) / (2.0 * h);
        assert_relative_eq!(grad[0], central, max_relative = 1e-5);
    }
}