use crate::error::AutogradError;
use crate::Result;
use std::collections::HashMap;
/// One instruction recorded on an automatic-differentiation tape.
///
/// The `a` / `b` fields are tape positions of the nodes whose values the
/// instruction consumes. Evaluation rejects any operand position that is not
/// strictly smaller than the position of the instruction itself.
#[derive(Debug, Clone)]
pub enum TapeOp {
    /// `value(a) + value(b)`.
    Add { a: usize, b: usize },
    /// `value(a) - value(b)`.
    Sub { a: usize, b: usize },
    /// `value(a) * value(b)`.
    Mul { a: usize, b: usize },
    /// `value(a) / value(b)`.
    Div { a: usize, b: usize },
    /// `-value(a)`.
    Neg { a: usize },
    /// `exp(value(a))`.
    Exp { a: usize },
    /// Natural logarithm of `value(a)`.
    Log { a: usize },
    /// `sin(value(a))`.
    Sin { a: usize },
    /// `cos(value(a))`.
    Cos { a: usize },
    /// Square root of `value(a)`.
    Sqrt { a: usize },
    /// `scalar * value(a)` for a fixed `scalar`.
    Scale { scalar: f64, a: usize },
    /// `value(a) * value(a)`.
    Square { a: usize },
    /// A literal constant.
    Constant { value: f64 },
    /// Reads element `input_idx` of the input slice supplied at evaluation time.
    Input { input_idx: usize },
}
/// Operation tape for reverse-mode automatic differentiation.
///
/// Nodes are appended with the `push_*` methods; each returns the position of
/// the node it created, which later operations use as an operand id.
/// [`ReverseTape::forward`] evaluates every node and
/// [`ReverseTape::backward`] propagates adjoints from one output node.
#[derive(Debug, Default)]
pub struct ReverseTape {
    /// Recorded operations; a node's id is its position in this vector.
    ops: Vec<TapeOp>,
}

impl ReverseTape {
    /// Creates an empty tape.
    pub fn new() -> Self {
        Self { ops: Vec::new() }
    }

    /// Appends `op` and returns the position it was stored at.
    fn record(&mut self, op: TapeOp) -> usize {
        let idx = self.ops.len();
        self.ops.push(op);
        idx
    }

    /// Records a node that reads `inputs[input_idx]` during evaluation.
    pub fn push_input(&mut self, input_idx: usize) -> usize {
        self.record(TapeOp::Input { input_idx })
    }

    /// Records a constant node holding `value`.
    pub fn push_constant(&mut self, value: f64) -> usize {
        self.record(TapeOp::Constant { value })
    }

    /// Records `a + b`.
    pub fn push_add(&mut self, a: usize, b: usize) -> usize {
        self.record(TapeOp::Add { a, b })
    }

    /// Records `a - b`.
    pub fn push_sub(&mut self, a: usize, b: usize) -> usize {
        self.record(TapeOp::Sub { a, b })
    }

    /// Records `a * b`.
    pub fn push_mul(&mut self, a: usize, b: usize) -> usize {
        self.record(TapeOp::Mul { a, b })
    }

    /// Records `a / b`.
    pub fn push_div(&mut self, a: usize, b: usize) -> usize {
        self.record(TapeOp::Div { a, b })
    }

    /// Records `-a`.
    pub fn push_neg(&mut self, a: usize) -> usize {
        self.record(TapeOp::Neg { a })
    }

    /// Records `exp(a)`.
    pub fn push_exp(&mut self, a: usize) -> usize {
        self.record(TapeOp::Exp { a })
    }

    /// Records the natural logarithm of `a`.
    pub fn push_log(&mut self, a: usize) -> usize {
        self.record(TapeOp::Log { a })
    }

    /// Records `sin(a)`.
    pub fn push_sin(&mut self, a: usize) -> usize {
        self.record(TapeOp::Sin { a })
    }

    /// Records `cos(a)`.
    pub fn push_cos(&mut self, a: usize) -> usize {
        self.record(TapeOp::Cos { a })
    }

    /// Records the square root of `a`.
    pub fn push_sqrt(&mut self, a: usize) -> usize {
        self.record(TapeOp::Sqrt { a })
    }

    /// Records `scalar * a`.
    pub fn push_scale(&mut self, scalar: f64, a: usize) -> usize {
        self.record(TapeOp::Scale { scalar, a })
    }

    /// Records `a * a`.
    pub fn push_square(&mut self, a: usize) -> usize {
        self.record(TapeOp::Square { a })
    }

    /// Number of recorded nodes.
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// Returns `true` when nothing has been recorded.
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }

    /// Evaluates every node and returns one value per node, in tape order.
    ///
    /// # Errors
    ///
    /// * invalid argument: an `Input` node refers past the end of `inputs`,
    ///   or an operand id is not strictly smaller than the id of the node
    ///   using it;
    /// * compute error: a denominator has magnitude below
    ///   `f64::EPSILON * 1e6`, a logarithm argument is `<= 0`, or a square
    ///   root argument is `< 0`.
    pub fn forward(&self, inputs: &[f64]) -> Result<Vec<f64>> {
        let n = self.ops.len();
        let mut vals = vec![0.0f64; n];
        for (i, op) in self.ops.iter().enumerate() {
            vals[i] = match op {
                TapeOp::Constant { value } => *value,
                TapeOp::Input { input_idx } => {
                    inputs.get(*input_idx).copied().ok_or_else(|| {
                        AutogradError::invalid_argument(format!(
                            "ReverseTape::forward: input index {} out of range (len {})",
                            input_idx,
                            inputs.len()
                        ))
                    })?
                }
                TapeOp::Add { a, b } => {
                    check_idx(*a, i, "Add.a")?;
                    check_idx(*b, i, "Add.b")?;
                    vals[*a] + vals[*b]
                }
                TapeOp::Sub { a, b } => {
                    check_idx(*a, i, "Sub.a")?;
                    check_idx(*b, i, "Sub.b")?;
                    vals[*a] - vals[*b]
                }
                TapeOp::Mul { a, b } => {
                    check_idx(*a, i, "Mul.a")?;
                    check_idx(*b, i, "Mul.b")?;
                    vals[*a] * vals[*b]
                }
                TapeOp::Div { a, b } => {
                    check_idx(*a, i, "Div.a")?;
                    check_idx(*b, i, "Div.b")?;
                    let denom = vals[*b];
                    // Near-zero denominators are treated as division by zero.
                    if denom.abs() < f64::EPSILON * 1e6 {
                        return Err(AutogradError::compute_error(
                            "ReverseTape::forward: division by zero".to_string(),
                        ));
                    }
                    vals[*a] / denom
                }
                TapeOp::Neg { a } => {
                    check_idx(*a, i, "Neg.a")?;
                    -vals[*a]
                }
                TapeOp::Exp { a } => {
                    check_idx(*a, i, "Exp.a")?;
                    vals[*a].exp()
                }
                TapeOp::Log { a } => {
                    check_idx(*a, i, "Log.a")?;
                    let v = vals[*a];
                    if v <= 0.0 {
                        return Err(AutogradError::compute_error(format!(
                            "ReverseTape::forward: log({v}) undefined"
                        )));
                    }
                    v.ln()
                }
                TapeOp::Sin { a } => {
                    check_idx(*a, i, "Sin.a")?;
                    vals[*a].sin()
                }
                TapeOp::Cos { a } => {
                    check_idx(*a, i, "Cos.a")?;
                    vals[*a].cos()
                }
                TapeOp::Sqrt { a } => {
                    check_idx(*a, i, "Sqrt.a")?;
                    let v = vals[*a];
                    if v < 0.0 {
                        return Err(AutogradError::compute_error(format!(
                            "ReverseTape::forward: sqrt({v}) undefined"
                        )));
                    }
                    v.sqrt()
                }
                TapeOp::Scale { scalar, a } => {
                    check_idx(*a, i, "Scale.a")?;
                    scalar * vals[*a]
                }
                TapeOp::Square { a } => {
                    check_idx(*a, i, "Square.a")?;
                    vals[*a] * vals[*a]
                }
            };
        }
        Ok(vals)
    }

    /// Propagates adjoints from node `output_idx` back through the tape.
    ///
    /// `vals` must hold one value per node (normally the result of
    /// [`ReverseTape::forward`]). The returned vector has one entry per node:
    /// the derivative of the output node with respect to that node.
    ///
    /// A square-root node whose value has magnitude at most `f64::EPSILON`
    /// contributes no adjoint to its operand.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when `output_idx` is out of range,
    /// when `vals.len()` differs from the tape length, or when a node that
    /// receives a non-zero adjoint has an operand id that is not strictly
    /// smaller than its own id (the same ordering rule `forward` enforces).
    pub fn backward(&self, output_idx: usize, vals: &[f64]) -> Result<Vec<f64>> {
        let n = self.ops.len();
        if output_idx >= n {
            return Err(AutogradError::invalid_argument(format!(
                "ReverseTape::backward: output_idx {output_idx} >= tape len {n}"
            )));
        }
        if vals.len() != n {
            return Err(AutogradError::invalid_argument(format!(
                "ReverseTape::backward: vals.len() {} != tape len {n}",
                vals.len()
            )));
        }
        let mut grads = vec![0.0f64; n];
        grads[output_idx] = 1.0;
        // Operands always precede the node that uses them, so nodes recorded
        // after the output can never receive an adjoint: start at the output.
        for i in (0..=output_idx).rev() {
            let g = grads[i];
            if g == 0.0 {
                // Nothing to propagate from this node.
                continue;
            }
            match &self.ops[i] {
                TapeOp::Constant { .. } | TapeOp::Input { .. } => {}
                TapeOp::Add { a, b } => {
                    check_idx(*a, i, "Add.a")?;
                    check_idx(*b, i, "Add.b")?;
                    grads[*a] += g;
                    grads[*b] += g;
                }
                TapeOp::Sub { a, b } => {
                    check_idx(*a, i, "Sub.a")?;
                    check_idx(*b, i, "Sub.b")?;
                    grads[*a] += g;
                    grads[*b] -= g;
                }
                TapeOp::Mul { a, b } => {
                    check_idx(*a, i, "Mul.a")?;
                    check_idx(*b, i, "Mul.b")?;
                    grads[*a] += g * vals[*b];
                    grads[*b] += g * vals[*a];
                }
                TapeOp::Div { a, b } => {
                    check_idx(*a, i, "Div.a")?;
                    check_idx(*b, i, "Div.b")?;
                    let va = vals[*a];
                    let vb = vals[*b];
                    grads[*a] += g / vb;
                    grads[*b] -= g * va / (vb * vb);
                }
                TapeOp::Neg { a } => {
                    check_idx(*a, i, "Neg.a")?;
                    grads[*a] -= g;
                }
                TapeOp::Exp { a } => {
                    check_idx(*a, i, "Exp.a")?;
                    // d/dx exp(x) = exp(x), which is this node's own value.
                    grads[*a] += g * vals[i];
                }
                TapeOp::Log { a } => {
                    check_idx(*a, i, "Log.a")?;
                    grads[*a] += g / vals[*a];
                }
                TapeOp::Sin { a } => {
                    check_idx(*a, i, "Sin.a")?;
                    grads[*a] += g * vals[*a].cos();
                }
                TapeOp::Cos { a } => {
                    check_idx(*a, i, "Cos.a")?;
                    grads[*a] -= g * vals[*a].sin();
                }
                TapeOp::Sqrt { a } => {
                    check_idx(*a, i, "Sqrt.a")?;
                    // d/dx sqrt(x) = 1 / (2 sqrt(x)); skipped when the root is ~0.
                    let sv = vals[i];
                    if sv.abs() > f64::EPSILON {
                        grads[*a] += g / (2.0 * sv);
                    }
                }
                TapeOp::Scale { scalar, a } => {
                    check_idx(*a, i, "Scale.a")?;
                    grads[*a] += g * scalar;
                }
                TapeOp::Square { a } => {
                    check_idx(*a, i, "Square.a")?;
                    grads[*a] += g * 2.0 * vals[*a];
                }
            }
        }
        Ok(grads)
    }

    /// Read-only view of the recorded operations.
    pub fn ops(&self) -> &[TapeOp] {
        &self.ops
    }

    /// Removes every recorded operation.
    pub fn clear(&mut self) {
        self.ops.clear();
    }
}
/// Operation tape for forward-mode automatic differentiation.
///
/// [`ForwardTape::forward`] evaluates each node's value together with its
/// tangent (directional derivative) in a single pass.
#[derive(Debug, Default)]
pub struct ForwardTape {
    /// Recorded operations; a node's id is its position in this vector.
    ops: Vec<TapeOp>,
}

impl ForwardTape {
    /// Creates an empty tape.
    pub fn new() -> Self {
        Self { ops: Vec::new() }
    }

    /// Appends `op` and returns the position it was stored at.
    fn record(&mut self, op: TapeOp) -> usize {
        let idx = self.ops.len();
        self.ops.push(op);
        idx
    }

    /// Records a node that reads `inputs[input_idx]` during evaluation.
    ///
    /// `_tangent` is accepted but not stored; the seed tangents actually used
    /// are the `input_tangents` passed to [`ForwardTape::forward`].
    pub fn push_input(&mut self, input_idx: usize, _tangent: f64) -> usize {
        self.record(TapeOp::Input { input_idx })
    }

    /// Records a constant node holding `value`.
    pub fn push_constant(&mut self, value: f64) -> usize {
        self.record(TapeOp::Constant { value })
    }

    /// Records `a + b`.
    pub fn push_add(&mut self, a: usize, b: usize) -> usize {
        self.record(TapeOp::Add { a, b })
    }

    /// Records `a - b`.
    pub fn push_sub(&mut self, a: usize, b: usize) -> usize {
        self.record(TapeOp::Sub { a, b })
    }

    /// Records `a * b`.
    pub fn push_mul(&mut self, a: usize, b: usize) -> usize {
        self.record(TapeOp::Mul { a, b })
    }

    /// Records `a / b`.
    pub fn push_div(&mut self, a: usize, b: usize) -> usize {
        self.record(TapeOp::Div { a, b })
    }

    /// Records `-a`.
    pub fn push_neg(&mut self, a: usize) -> usize {
        self.record(TapeOp::Neg { a })
    }

    /// Records `exp(a)`.
    pub fn push_exp(&mut self, a: usize) -> usize {
        self.record(TapeOp::Exp { a })
    }

    /// Records the natural logarithm of `a`.
    pub fn push_log(&mut self, a: usize) -> usize {
        self.record(TapeOp::Log { a })
    }

    /// Records `sin(a)`.
    pub fn push_sin(&mut self, a: usize) -> usize {
        self.record(TapeOp::Sin { a })
    }

    /// Records `cos(a)`.
    pub fn push_cos(&mut self, a: usize) -> usize {
        self.record(TapeOp::Cos { a })
    }

    /// Records the square root of `a`.
    pub fn push_sqrt(&mut self, a: usize) -> usize {
        self.record(TapeOp::Sqrt { a })
    }

    /// Records `scalar * a`.
    pub fn push_scale(&mut self, scalar: f64, a: usize) -> usize {
        self.record(TapeOp::Scale { scalar, a })
    }

    /// Records `a * a`.
    ///
    /// `forward` already evaluates `TapeOp::Square`; this is the matching
    /// recording method.
    pub fn push_square(&mut self, a: usize) -> usize {
        self.record(TapeOp::Square { a })
    }

    /// Number of recorded nodes.
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// Returns `true` when nothing has been recorded.
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }

    /// Evaluates values and tangents of every node.
    ///
    /// Returns `(values, tangents)`, each with one entry per node in tape
    /// order. An input whose index is missing from `input_tangents` is seeded
    /// with a tangent of `0.0`. A square-root node whose value has magnitude
    /// at most `f64::EPSILON` gets a tangent of `0.0`.
    ///
    /// # Errors
    ///
    /// * invalid argument: an `Input` node refers past the end of `inputs`,
    ///   or an operand id is not strictly smaller than the id of the node
    ///   using it;
    /// * compute error: a denominator has magnitude below
    ///   `f64::EPSILON * 1e6`, a logarithm argument is `<= 0`, or a square
    ///   root argument is `< 0`.
    pub fn forward(
        &self,
        inputs: &[f64],
        input_tangents: &[f64],
    ) -> Result<(Vec<f64>, Vec<f64>)> {
        let n = self.ops.len();
        let mut pvals = vec![0.0f64; n];
        let mut tvals = vec![0.0f64; n];
        for (i, op) in self.ops.iter().enumerate() {
            match op {
                TapeOp::Constant { value } => {
                    pvals[i] = *value;
                    tvals[i] = 0.0;
                }
                TapeOp::Input { input_idx } => {
                    pvals[i] = inputs.get(*input_idx).copied().ok_or_else(|| {
                        AutogradError::invalid_argument(format!(
                            "ForwardTape::forward: input index {input_idx} out of range"
                        ))
                    })?;
                    tvals[i] = input_tangents.get(*input_idx).copied().unwrap_or(0.0);
                }
                TapeOp::Add { a, b } => {
                    check_idx(*a, i, "Add.a")?;
                    check_idx(*b, i, "Add.b")?;
                    pvals[i] = pvals[*a] + pvals[*b];
                    tvals[i] = tvals[*a] + tvals[*b];
                }
                TapeOp::Sub { a, b } => {
                    check_idx(*a, i, "Sub.a")?;
                    check_idx(*b, i, "Sub.b")?;
                    pvals[i] = pvals[*a] - pvals[*b];
                    tvals[i] = tvals[*a] - tvals[*b];
                }
                TapeOp::Mul { a, b } => {
                    check_idx(*a, i, "Mul.a")?;
                    check_idx(*b, i, "Mul.b")?;
                    pvals[i] = pvals[*a] * pvals[*b];
                    // Product rule.
                    tvals[i] = tvals[*a] * pvals[*b] + pvals[*a] * tvals[*b];
                }
                TapeOp::Div { a, b } => {
                    check_idx(*a, i, "Div.a")?;
                    check_idx(*b, i, "Div.b")?;
                    let vb = pvals[*b];
                    // Near-zero denominators are treated as division by zero.
                    if vb.abs() < f64::EPSILON * 1e6 {
                        return Err(AutogradError::compute_error(
                            "ForwardTape::forward: division by zero".to_string(),
                        ));
                    }
                    pvals[i] = pvals[*a] / vb;
                    // Quotient rule.
                    tvals[i] = (tvals[*a] * vb - pvals[*a] * tvals[*b]) / (vb * vb);
                }
                TapeOp::Neg { a } => {
                    check_idx(*a, i, "Neg.a")?;
                    pvals[i] = -pvals[*a];
                    tvals[i] = -tvals[*a];
                }
                TapeOp::Exp { a } => {
                    check_idx(*a, i, "Exp.a")?;
                    let ev = pvals[*a].exp();
                    pvals[i] = ev;
                    tvals[i] = ev * tvals[*a];
                }
                TapeOp::Log { a } => {
                    check_idx(*a, i, "Log.a")?;
                    let v = pvals[*a];
                    if v <= 0.0 {
                        return Err(AutogradError::compute_error(format!(
                            "ForwardTape::forward: log({v}) undefined"
                        )));
                    }
                    pvals[i] = v.ln();
                    tvals[i] = tvals[*a] / v;
                }
                TapeOp::Sin { a } => {
                    check_idx(*a, i, "Sin.a")?;
                    pvals[i] = pvals[*a].sin();
                    tvals[i] = pvals[*a].cos() * tvals[*a];
                }
                TapeOp::Cos { a } => {
                    check_idx(*a, i, "Cos.a")?;
                    pvals[i] = pvals[*a].cos();
                    tvals[i] = -pvals[*a].sin() * tvals[*a];
                }
                TapeOp::Sqrt { a } => {
                    check_idx(*a, i, "Sqrt.a")?;
                    let v = pvals[*a];
                    if v < 0.0 {
                        return Err(AutogradError::compute_error(format!(
                            "ForwardTape::forward: sqrt({v}) undefined"
                        )));
                    }
                    let sv = v.sqrt();
                    pvals[i] = sv;
                    // d/dx sqrt(x) = 1 / (2 sqrt(x)); zero tangent when the root is ~0.
                    tvals[i] = if sv.abs() > f64::EPSILON {
                        tvals[*a] / (2.0 * sv)
                    } else {
                        0.0
                    };
                }
                TapeOp::Scale { scalar, a } => {
                    check_idx(*a, i, "Scale.a")?;
                    pvals[i] = scalar * pvals[*a];
                    tvals[i] = scalar * tvals[*a];
                }
                TapeOp::Square { a } => {
                    check_idx(*a, i, "Square.a")?;
                    pvals[i] = pvals[*a] * pvals[*a];
                    tvals[i] = 2.0 * pvals[*a] * tvals[*a];
                }
            }
        }
        Ok((pvals, tvals))
    }

    /// Read-only view of the recorded operations.
    pub fn ops(&self) -> &[TapeOp] {
        &self.ops
    }

    /// Removes every recorded operation.
    pub fn clear(&mut self) {
        self.ops.clear();
    }
}
/// Jacobian builder backed by a [`ReverseTape`].
///
/// Inputs are numbered in the order they are pushed, outputs in the order
/// they are registered. [`MixedMode::jacobian`] runs one reverse sweep per
/// registered output.
#[derive(Debug, Default)]
pub struct MixedMode {
    /// Tape holding the recorded computation.
    tape: ReverseTape,
    /// Tape positions of the registered outputs (one Jacobian row each).
    outputs: Vec<usize>,
    /// Number of inputs pushed so far (number of Jacobian columns).
    n_inputs: usize,
}

impl MixedMode {
    /// Creates a builder with no inputs, operations or outputs.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records the next input variable and returns its tape position.
    ///
    /// The k-th call (counting from zero) reads `inputs[k]` at evaluation.
    pub fn push_input(&mut self) -> usize {
        let idx = self.tape.push_input(self.n_inputs);
        self.n_inputs += 1;
        idx
    }

    /// Records a constant node holding `value`.
    pub fn push_constant(&mut self, value: f64) -> usize {
        self.tape.push_constant(value)
    }

    /// Records `a + b`.
    pub fn push_add(&mut self, a: usize, b: usize) -> usize {
        self.tape.push_add(a, b)
    }

    /// Records `a - b`.
    pub fn push_sub(&mut self, a: usize, b: usize) -> usize {
        self.tape.push_sub(a, b)
    }

    /// Records `a * b`.
    pub fn push_mul(&mut self, a: usize, b: usize) -> usize {
        self.tape.push_mul(a, b)
    }

    /// Records `a / b`.
    pub fn push_div(&mut self, a: usize, b: usize) -> usize {
        self.tape.push_div(a, b)
    }

    /// Records `-a`.
    pub fn push_neg(&mut self, a: usize) -> usize {
        self.tape.push_neg(a)
    }

    /// Records `exp(a)`.
    pub fn push_exp(&mut self, a: usize) -> usize {
        self.tape.push_exp(a)
    }

    /// Records the natural logarithm of `a`.
    pub fn push_log(&mut self, a: usize) -> usize {
        self.tape.push_log(a)
    }

    /// Records `sin(a)`.
    pub fn push_sin(&mut self, a: usize) -> usize {
        self.tape.push_sin(a)
    }

    /// Records `cos(a)`.
    pub fn push_cos(&mut self, a: usize) -> usize {
        self.tape.push_cos(a)
    }

    /// Records the square root of `a`.
    pub fn push_sqrt(&mut self, a: usize) -> usize {
        self.tape.push_sqrt(a)
    }

    /// Records `scalar * a`.
    pub fn push_scale(&mut self, scalar: f64, a: usize) -> usize {
        self.tape.push_scale(scalar, a)
    }

    /// Records `a * a`.
    pub fn push_square(&mut self, a: usize) -> usize {
        self.tape.push_square(a)
    }

    /// Marks tape node `node` as an output; it becomes the next Jacobian row.
    pub fn register_output(&mut self, node: usize) {
        self.outputs.push(node);
    }

    /// Evaluates the Jacobian at `inputs`.
    ///
    /// The result has one row per registered output and one column per
    /// pushed input: `jac[row][col]` is the derivative of output `row` with
    /// respect to input `col`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when no input or no output has been
    /// registered, and forwards any error from the tape's forward or
    /// backward pass.
    pub fn jacobian(&self, inputs: &[f64]) -> Result<Vec<Vec<f64>>> {
        let n = self.n_inputs;
        let m = self.outputs.len();
        if n == 0 || m == 0 {
            return Err(AutogradError::invalid_argument(
                "MixedMode::jacobian: no inputs or outputs registered".to_string(),
            ));
        }

        // Locate the tape node of every input once instead of rescanning the
        // tape for each (row, col) pair. The first node found for an input
        // index wins.
        let mut input_nodes: Vec<Option<usize>> = vec![None; n];
        for (k, op) in self.tape.ops().iter().enumerate() {
            if let TapeOp::Input { input_idx } = op {
                if let Some(slot) = input_nodes.get_mut(*input_idx) {
                    if slot.is_none() {
                        *slot = Some(k);
                    }
                }
            }
        }

        let vals = self.tape.forward(inputs)?;
        let mut jac = Vec::with_capacity(m);
        for &out_idx in &self.outputs {
            let grads = self.tape.backward(out_idx, &vals)?;
            let row: Vec<f64> = input_nodes
                .iter()
                .map(|node| node.map_or(0.0, |k| grads[k]))
                .collect();
            jac.push(row);
        }
        Ok(jac)
    }
}
/// A [`ReverseTape`] together with a record of checkpoint boundaries.
///
/// Each call to [`TapeCheckpoint::checkpoint`] closes the current segment of
/// the tape. The segments are bookkeeping only:
/// [`TapeCheckpoint::backward_checkpointed`] evaluates the whole tape in one
/// forward pass followed by one backward pass.
#[derive(Debug, Default)]
pub struct TapeCheckpoint {
    /// The underlying tape.
    tape: ReverseTape,
    /// Closed segments as half-open `(start, end)` ranges of tape positions.
    segments: Vec<(usize, usize)>,
    /// Tape position at which the currently open segment begins.
    current_segment_start: usize,
}

impl TapeCheckpoint {
    /// Creates an empty tape with no segments.
    pub fn new() -> Self {
        Self::default()
    }

    /// Mutable access to the underlying tape, used to record operations.
    pub fn tape_mut(&mut self) -> &mut ReverseTape {
        &mut self.tape
    }

    /// Closes the current segment if any node was recorded since the last
    /// checkpoint; otherwise does nothing.
    pub fn checkpoint(&mut self) {
        let current_end = self.tape.len();
        if current_end > self.current_segment_start {
            self.segments
                .push((self.current_segment_start, current_end));
            self.current_segment_start = current_end;
        }
    }

    /// Computes the gradient of node `output_idx` with respect to every node.
    ///
    /// Any open segment is closed first, so calling this also records a
    /// checkpoint.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when no segment has been recorded,
    /// and forwards any error from the tape's forward or backward pass.
    pub fn backward_checkpointed(
        &mut self,
        output_idx: usize,
        inputs: &[f64],
    ) -> Result<Vec<f64>> {
        self.checkpoint();
        if self.segments.is_empty() {
            return Err(AutogradError::invalid_argument(
                "TapeCheckpoint::backward_checkpointed: no segments recorded".to_string(),
            ));
        }
        let full_vals = self.tape.forward(inputs)?;
        self.tape.backward(output_idx, &full_vals)
    }

    /// Segments closed so far, as `(start, end)` tape-position pairs.
    pub fn segments(&self) -> &[(usize, usize)] {
        &self.segments
    }

    /// Shared access to the underlying tape.
    pub fn tape(&self) -> &ReverseTape {
        &self.tape
    }
}
/// Counters describing what a tape optimisation pass removed.
#[derive(Debug, Default, Clone)]
pub struct TapeOptimizationReport {
    /// Nodes dropped because no requested output depends on them.
    pub dead_nodes_eliminated: usize,
    /// Nodes dropped because an identical expression was already recorded.
    pub cse_eliminated: usize,
}

impl TapeOptimizationReport {
    /// Sum of both counters.
    pub fn total(&self) -> usize {
        let Self {
            dead_nodes_eliminated: dead,
            cse_eliminated: merged,
        } = self;
        dead + merged
    }
}
/// Namespace for tape optimisation passes.
pub struct TapeOptimizer;

impl TapeOptimizer {
    /// Builds a smaller tape that still computes the requested outputs.
    ///
    /// Two passes are applied:
    ///
    /// 1. dead-node elimination: nodes that no entry of `output_indices`
    ///    depends on are dropped (out-of-range output indices are ignored);
    /// 2. common-subexpression elimination: a node that repeats an
    ///    operation already recorded on the same operands is merged into
    ///    the first occurrence.
    ///
    /// Returns `(new_tape, index_map, report)`. `index_map[i]` is
    /// `Some(j)` when the value of original node `i` is produced by node `j`
    /// of `new_tape`. This includes nodes merged by pass 2, which map to the
    /// node they were merged into. It is `None` for dead nodes.
    /// `report.dead_nodes_eliminated` counts the nodes removed by pass 1 and
    /// `report.cse_eliminated` those merged by pass 2.
    pub fn optimize(
        tape: &ReverseTape,
        output_indices: &[usize],
    ) -> (ReverseTape, Vec<Option<usize>>, TapeOptimizationReport) {
        let n = tape.len();
        let ops = tape.ops();
        let mut report = TapeOptimizationReport::default();

        // Pass 1: mark everything reachable from the outputs.
        let mut live = vec![false; n];
        let mut stack: Vec<usize> = output_indices
            .iter()
            .copied()
            .filter(|&i| i < n)
            .collect();
        for &i in &stack {
            live[i] = true;
        }
        while let Some(node) = stack.pop() {
            for d in deps_of(ops, node) {
                if !live[d] {
                    live[d] = true;
                    stack.push(d);
                }
            }
        }
        report.dead_nodes_eliminated = live.iter().filter(|&&x| !x).count();

        // Pass 2: merge repeated expressions. `canonical[i]` is the node that
        // stands in for `i`; keys are built from canonical operand ids so that
        // duplicates of duplicates are recognised too.
        let mut cse_map: HashMap<(u8, Vec<u64>), usize> = HashMap::new();
        let mut canonical: Vec<usize> = (0..n).collect();
        for i in 0..n {
            if !live[i] {
                continue;
            }
            if let Some(key) = cse_key_of(ops, i, &canonical) {
                match cse_map.entry(key) {
                    std::collections::hash_map::Entry::Occupied(seen) => {
                        canonical[i] = *seen.get();
                        live[i] = false;
                        report.cse_eliminated += 1;
                    }
                    std::collections::hash_map::Entry::Vacant(slot) => {
                        slot.insert(i);
                    }
                }
            }
        }

        // Pass 3: copy the surviving nodes, rewriting operand ids.
        let mut new_tape = ReverseTape::new();
        let mut index_map: Vec<Option<usize>> = vec![None; n];
        for i in 0..n {
            if !live[i] {
                continue;
            }
            let new_op = remap_op(ops, i, &canonical, &index_map);
            index_map[i] = Some(new_tape.ops.len());
            new_tape.ops.push(new_op);
        }

        // A merged node's value lives on in its representative; point the map
        // there so a caller can still find an output that was a duplicate.
        for i in 0..n {
            let representative = canonical[i];
            if representative != i {
                index_map[i] = index_map[representative];
            }
        }

        (new_tape, index_map, report)
    }
}
/// Lists the operand node ids of `tape_ops[idx]`.
///
/// Binary operations yield `[a, b]`, unary operations `[a]`, and constants
/// and inputs an empty list.
fn deps_of(tape_ops: &[TapeOp], idx: usize) -> Vec<usize> {
    let mut operands = Vec::new();
    match &tape_ops[idx] {
        TapeOp::Constant { .. } | TapeOp::Input { .. } => {}
        TapeOp::Neg { a }
        | TapeOp::Exp { a }
        | TapeOp::Log { a }
        | TapeOp::Sin { a }
        | TapeOp::Cos { a }
        | TapeOp::Sqrt { a }
        | TapeOp::Square { a }
        | TapeOp::Scale { a, .. } => operands.push(*a),
        TapeOp::Add { a, b }
        | TapeOp::Sub { a, b }
        | TapeOp::Mul { a, b }
        | TapeOp::Div { a, b } => {
            operands.push(*a);
            operands.push(*b);
        }
    }
    operands
}
/// Builds the lookup key used to detect repeated expressions.
///
/// The key is an operation tag plus the canonical ids of the operands
/// (looked up through `canonical`). `Add` and `Mul` store their operands in
/// ascending order so that swapped operands produce the same key. `Scale`
/// additionally carries the bit pattern of its scalar. Constants and inputs
/// have no key (`None`).
fn cse_key_of(tape_ops: &[TapeOp], idx: usize, canonical: &[usize]) -> Option<(u8, Vec<u64>)> {
    let canon = |node: &usize| canonical[*node] as u64;
    // Order-insensitive operand pair for commutative operations.
    let unordered = |x: &usize, y: &usize| {
        let (first, second) = (canon(x), canon(y));
        if first <= second {
            vec![first, second]
        } else {
            vec![second, first]
        }
    };

    let (tag, operands): (u8, Vec<u64>) = match &tape_ops[idx] {
        TapeOp::Constant { .. } | TapeOp::Input { .. } => return None,
        TapeOp::Add { a, b } => (0, unordered(a, b)),
        TapeOp::Sub { a, b } => (1, vec![canon(a), canon(b)]),
        TapeOp::Mul { a, b } => (2, unordered(a, b)),
        TapeOp::Div { a, b } => (3, vec![canon(a), canon(b)]),
        TapeOp::Neg { a } => (4, vec![canon(a)]),
        TapeOp::Exp { a } => (5, vec![canon(a)]),
        TapeOp::Log { a } => (6, vec![canon(a)]),
        TapeOp::Sin { a } => (7, vec![canon(a)]),
        TapeOp::Cos { a } => (8, vec![canon(a)]),
        TapeOp::Sqrt { a } => (9, vec![canon(a)]),
        TapeOp::Square { a } => (10, vec![canon(a)]),
        TapeOp::Scale { scalar, a } => (11, vec![scalar.to_bits(), canon(a)]),
    };
    Some((tag, operands))
}
/// Returns a copy of `tape_ops[idx]` with its operand ids rewritten.
///
/// Each operand is first replaced by its canonical node (`canonical`) and
/// then translated through `index_map`; when `index_map` has no entry for
/// the canonical node, the canonical id itself is kept. Constants and inputs
/// are copied unchanged.
fn remap_op(
    tape_ops: &[TapeOp],
    idx: usize,
    canonical: &[usize],
    index_map: &[Option<usize>],
) -> TapeOp {
    let translate = |old: usize| -> usize {
        let representative = canonical[old];
        match index_map[representative] {
            Some(new_idx) => new_idx,
            None => representative,
        }
    };

    // Copy the operation, then patch its operand fields in place.
    let mut rewritten = tape_ops[idx].clone();
    match &mut rewritten {
        TapeOp::Constant { .. } | TapeOp::Input { .. } => {}
        TapeOp::Add { a, b }
        | TapeOp::Sub { a, b }
        | TapeOp::Mul { a, b }
        | TapeOp::Div { a, b } => {
            *a = translate(*a);
            *b = translate(*b);
        }
        TapeOp::Neg { a }
        | TapeOp::Exp { a }
        | TapeOp::Log { a }
        | TapeOp::Sin { a }
        | TapeOp::Cos { a }
        | TapeOp::Sqrt { a }
        | TapeOp::Square { a }
        | TapeOp::Scale { a, .. } => {
            *a = translate(*a);
        }
    }
    rewritten
}
/// Verifies that operand `op_idx` was recorded before node `node_idx`.
///
/// `label` names the operand in the error message.
///
/// # Errors
///
/// Returns an invalid-argument error when `op_idx >= node_idx`.
fn check_idx(op_idx: usize, node_idx: usize, label: &str) -> Result<()> {
    if op_idx < node_idx {
        return Ok(());
    }
    Err(AutogradError::invalid_argument(format!(
        "tape operand {label}: index {op_idx} >= current node {node_idx} (forward ordering violated)"
    )))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// f(x, y) = x + y at (3, 5): value 8, both partials 1.
    #[test]
    fn test_reverse_tape_add() {
        let mut tape = ReverseTape::new();
        let x = tape.push_input(0);
        let y = tape.push_input(1);
        let sum = tape.push_add(x, y);

        let values = tape.forward(&[3.0, 5.0]).expect("forward");
        assert!((values[sum] - 8.0).abs() < 1e-12);

        let adjoints = tape.backward(sum, &values).expect("backward");
        assert!((adjoints[x] - 1.0).abs() < 1e-12);
        assert!((adjoints[y] - 1.0).abs() < 1e-12);
    }

    /// f(x, y) = x * y at (3, 5): value 15, df/dx = y, df/dy = x.
    #[test]
    fn test_reverse_tape_mul() {
        let mut tape = ReverseTape::new();
        let x = tape.push_input(0);
        let y = tape.push_input(1);
        let product = tape.push_mul(x, y);

        let values = tape.forward(&[3.0, 5.0]).expect("forward");
        assert!((values[product] - 15.0).abs() < 1e-12);

        let adjoints = tape.backward(product, &values).expect("backward");
        assert!((adjoints[x] - 5.0).abs() < 1e-12);
        assert!((adjoints[y] - 3.0).abs() < 1e-12);
    }

    /// f(x) = exp(x) at 0: value 1, derivative 1.
    #[test]
    fn test_reverse_tape_exp() {
        let mut tape = ReverseTape::new();
        let x = tape.push_input(0);
        let result = tape.push_exp(x);

        let values = tape.forward(&[0.0]).expect("forward");
        assert!((values[result] - 1.0).abs() < 1e-12);

        let adjoints = tape.backward(result, &values).expect("backward");
        assert!((adjoints[x] - 1.0).abs() < 1e-12);
    }

    /// f(x) = x^2 + 3x at 4: value 28, derivative 2x + 3 = 11.
    #[test]
    fn test_reverse_tape_chain() {
        let mut tape = ReverseTape::new();
        let x = tape.push_input(0);
        let x_squared = tape.push_square(x);
        let three = tape.push_constant(3.0);
        let three_x = tape.push_mul(three, x);
        let result = tape.push_add(x_squared, three_x);

        let values = tape.forward(&[4.0]).expect("forward");
        assert!((values[result] - 28.0).abs() < 1e-9);

        let adjoints = tape.backward(result, &values).expect("backward");
        assert!((adjoints[x] - 11.0).abs() < 1e-9);
    }

    /// f(x, y) = x / y at (6, 2): value 3, df/dx = 1/y, df/dy = -x/y^2.
    #[test]
    fn test_reverse_tape_div() {
        let mut tape = ReverseTape::new();
        let x = tape.push_input(0);
        let y = tape.push_input(1);
        let quotient = tape.push_div(x, y);

        let values = tape.forward(&[6.0, 2.0]).expect("forward");
        assert!((values[quotient] - 3.0).abs() < 1e-12);

        let adjoints = tape.backward(quotient, &values).expect("backward");
        assert!((adjoints[x] - 0.5).abs() < 1e-12);
        assert!((adjoints[y] - (-1.5)).abs() < 1e-12);
    }

    /// f(x) = ln(x) at 2: value ln 2, derivative 1/2.
    #[test]
    fn test_reverse_tape_log() {
        let mut tape = ReverseTape::new();
        let x = tape.push_input(0);
        let result = tape.push_log(x);

        let values = tape.forward(&[2.0]).expect("forward");
        assert!((values[result] - 2.0_f64.ln()).abs() < 1e-12);

        let adjoints = tape.backward(result, &values).expect("backward");
        assert!((adjoints[x] - 0.5).abs() < 1e-12);
    }

    /// An empty tape has no valid output index.
    #[test]
    fn test_reverse_tape_invalid_output_idx() {
        let tape = ReverseTape::new();
        let no_values: Vec<f64> = Vec::new();
        assert!(tape.backward(0, &no_values).is_err());
    }

    /// Tangent of x * y in direction (1, 0) at (3, 5) is y = 5.
    #[test]
    fn test_forward_tape_mul() {
        let mut tape = ForwardTape::new();
        let x = tape.push_input(0, 1.0);
        let y = tape.push_input(1, 0.0);
        let product = tape.push_mul(x, y);

        let (primals, tangents) = tape.forward(&[3.0, 5.0], &[1.0, 0.0]).expect("forward");
        assert!((primals[product] - 15.0).abs() < 1e-12);
        assert!((tangents[product] - 5.0).abs() < 1e-12);
    }

    /// exp(x) at 0 has value 1 and tangent 1.
    #[test]
    fn test_forward_tape_exp() {
        let mut tape = ForwardTape::new();
        let x = tape.push_input(0, 1.0);
        let result = tape.push_exp(x);

        let (primals, tangents) = tape.forward(&[0.0], &[1.0]).expect("forward");
        assert!((primals[result] - 1.0).abs() < 1e-12);
        assert!((tangents[result] - 1.0).abs() < 1e-12);
    }

    /// sin(x * x) at 1: value sin(1), tangent cos(1) * 2.
    #[test]
    fn test_forward_tape_chain_rule() {
        let mut tape = ForwardTape::new();
        let x = tape.push_input(0, 1.0);
        let x_squared = tape.push_mul(x, x);
        let result = tape.push_sin(x_squared);

        let (primals, tangents) = tape.forward(&[1.0], &[1.0]).expect("forward");
        let expected_pval = 1.0_f64.sin();
        let expected_tval = 1.0_f64.cos() * 2.0;
        assert!((primals[result] - expected_pval).abs() < 1e-12);
        assert!((tangents[result] - expected_tval).abs() < 1e-9);
    }

    /// Outputs (x + y, x * y) at (2, 3): Jacobian [[1, 1], [3, 2]].
    #[test]
    fn test_mixed_mode_jacobian() {
        let mut mm = MixedMode::new();
        let x = mm.push_input();
        let y = mm.push_input();
        let sum = mm.push_add(x, y);
        let product = mm.push_mul(x, y);
        mm.register_output(sum);
        mm.register_output(product);

        let jac = mm.jacobian(&[2.0, 3.0]).expect("jacobian");
        assert_eq!(jac.len(), 2);
        assert_eq!(jac[0].len(), 2);
        assert!((jac[0][0] - 1.0).abs() < 1e-9, "J[0][0] = {}", jac[0][0]);
        assert!((jac[0][1] - 1.0).abs() < 1e-9, "J[0][1] = {}", jac[0][1]);
        assert!((jac[1][0] - 3.0).abs() < 1e-9, "J[1][0] = {}", jac[1][0]);
        assert!((jac[1][1] - 2.0).abs() < 1e-9, "J[1][1] = {}", jac[1][1]);
    }

    /// f(x) = x^2 + 1 recorded across a checkpoint; derivative at 3 is 6.
    #[test]
    fn test_tape_checkpoint_basic() {
        let mut cp = TapeCheckpoint::new();

        let first_segment = cp.tape_mut();
        let x = first_segment.push_input(0);
        let x_squared = first_segment.push_square(x);
        cp.checkpoint();

        let second_segment = cp.tape_mut();
        let one = second_segment.push_constant(1.0);
        let result = second_segment.push_add(x_squared, one);

        let adjoints = cp
            .backward_checkpointed(result, &[3.0])
            .expect("backward checkpointed");
        assert!((adjoints[x] - 6.0).abs() < 1e-9);
    }

    /// One recorded node followed by a checkpoint yields one segment.
    #[test]
    fn test_tape_checkpoint_segments() {
        let mut cp = TapeCheckpoint::new();
        let tape = cp.tape_mut();
        tape.push_input(0);
        cp.checkpoint();
        assert_eq!(cp.segments().len(), 1);
    }

    /// With only x^2 requested, the second input and the unused sum are dead.
    #[test]
    fn test_tape_optimizer_dce() {
        let mut tape = ReverseTape::new();
        let x = tape.push_input(0);
        let y = tape.push_input(1);
        let _dead = tape.push_add(x, y);
        let result = tape.push_square(x);

        let (new_tape, _idx_map, report) = TapeOptimizer::optimize(&tape, &[result]);
        assert!(
            report.dead_nodes_eliminated >= 2,
            "Expected >=2 dead, got {}",
            report.dead_nodes_eliminated
        );
        assert!(new_tape.len() < tape.len());
    }

    /// The same sum recorded twice is merged into one node.
    #[test]
    fn test_tape_optimizer_cse() {
        let mut tape = ReverseTape::new();
        let x = tape.push_input(0);
        let y = tape.push_input(1);
        let first_sum = tape.push_add(x, y);
        let second_sum = tape.push_add(x, y);
        let result = tape.push_add(first_sum, second_sum);

        let (new_tape, _idx_map, report) = TapeOptimizer::optimize(&tape, &[result]);
        assert!(
            report.cse_eliminated >= 1,
            "Expected >=1 CSE, got {}",
            report.cse_eliminated
        );
        assert!(new_tape.len() < tape.len());
    }
}