use std::cell::RefCell;
use std::rc::Rc;
/// A single entry on the tape: either an independent input or the result of
/// a recorded operation.
///
/// Each parent slot holds `(parent_index, local_derivative)`: the tape index
/// of an operand together with the partial derivative of this node's value
/// with respect to that operand. Inputs have no parents; unary operations
/// fill only `parent1`; binary operations fill both.
#[derive(Debug, Clone)]
pub(crate) struct Node {
    /// Forward value of this node. The backward sweep only needs the stored
    /// local derivatives, so nothing in the tape module reads it — hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    pub value: f64,
    /// First operand and `∂self/∂operand`, if this node has any operand.
    pub parent1: Option<(usize, f64)>,
    /// Second operand and `∂self/∂operand`; `Some` only for binary operations.
    pub parent2: Option<(usize, f64)>,
}
/// A reverse-mode AD tape: an append-only list of [`Node`]s recorded while
/// evaluating an expression built from `Var`s.
///
/// Node indices are assigned in push order. Parents are always indices of
/// nodes that were already on the tape, so a single backward sweep over the
/// tape visits every node after all of its consumers.
#[derive(Clone, Debug, Default)]
pub struct Tape {
    /// All recorded nodes, in push order.
    pub(crate) nodes: Vec<Node>,
    /// Number of independent variables created with [`Tape::var`]
    /// (always equal to `input_indices.len()`).
    pub(crate) n_inputs: usize,
    /// Tape indices of the independent variables, in creation order.
    ///
    /// Inputs need not occupy the first `n_inputs` slots: a variable created
    /// after some operations is stored after those operations' nodes.
    pub(crate) input_indices: Vec<usize>,
}

/// Shared, interior-mutable handle to a [`Tape`]; `Var`s created by
/// [`Tape::var`] hold a clone of it.
pub type TapeRef = Rc<RefCell<Tape>>;

impl Tape {
    /// Creates an empty tape wrapped in a shared [`TapeRef`] handle.
    pub fn new() -> TapeRef {
        Rc::new(RefCell::new(Tape::default()))
    }

    /// Records a new independent variable holding `value`.
    ///
    /// The variable's position among all inputs (in creation order) is its
    /// position in the vector returned by [`Tape::gradient`].
    ///
    /// # Panics
    ///
    /// Panics if the tape is currently borrowed.
    pub fn var(tape: &TapeRef, value: f64) -> super::reverse::Var {
        let index = {
            let mut t = tape.borrow_mut();
            let index = t.nodes.len();
            t.nodes.push(Node {
                value,
                parent1: None,
                parent2: None,
            });
            // Remember where this input lives: inputs may be interleaved with
            // operation nodes, so their indices cannot be assumed contiguous.
            t.input_indices.push(index);
            t.n_inputs += 1;
            index
        };
        super::reverse::Var {
            index,
            value,
            tape: Rc::clone(tape),
        }
    }

    /// Appends `node` to the tape and returns its index.
    fn push_node(tape: &TapeRef, node: Node) -> usize {
        let mut t = tape.borrow_mut();
        t.nodes.push(node);
        t.nodes.len() - 1
    }

    /// Records a unary operation result `value` whose derivative with respect
    /// to the node at index `parent` is `deriv`.
    ///
    /// Returns `(index_of_new_node, value)`.
    pub(crate) fn push_unary(
        tape: &TapeRef,
        value: f64,
        parent: usize,
        deriv: f64,
    ) -> (usize, f64) {
        let node = Node {
            value,
            parent1: Some((parent, deriv)),
            parent2: None,
        };
        (Self::push_node(tape, node), value)
    }

    /// Records a binary operation result `value` with local derivatives `d1`
    /// and `d2` with respect to the nodes at indices `p1` and `p2`.
    ///
    /// Returns `(index_of_new_node, value)`.
    pub(crate) fn push_binary(
        tape: &TapeRef,
        value: f64,
        p1: usize,
        d1: f64,
        p2: usize,
        d2: f64,
    ) -> (usize, f64) {
        let node = Node {
            value,
            parent1: Some((p1, d1)),
            parent2: Some((p2, d2)),
        };
        (Self::push_node(tape, node), value)
    }

    /// Computes the gradient of `output` with respect to every input created
    /// by [`Tape::var`], returned in input-creation order.
    ///
    /// Performs one backward sweep starting at `output`, propagating each
    /// node's adjoint (`∂output/∂node`) to its parents through the local
    /// derivatives stored on the tape.
    ///
    /// # Panics
    ///
    /// Panics if `output.index` does not refer to a node on `tape` (for
    /// example, a `Var` created before [`Tape::clear`]), or if the tape is
    /// mutably borrowed. In debug builds, also panics if `output` was
    /// recorded on a different tape.
    pub fn gradient(tape: &TapeRef, output: &super::reverse::Var) -> Vec<f64> {
        debug_assert!(
            Rc::ptr_eq(tape, &output.tape),
            "output Var belongs to a different tape"
        );
        let t = tape.borrow();
        let n = t.nodes.len();
        assert!(
            output.index < n,
            "output index {} is out of range for a tape of {} nodes",
            output.index,
            n
        );

        let mut adjoints = vec![0.0; n];
        adjoints[output.index] = 1.0;
        // Nodes recorded after `output` cannot influence it, so the sweep
        // starts at `output` itself rather than at the end of the tape.
        for i in (0..=output.index).rev() {
            let adj = adjoints[i];
            if adj == 0.0 {
                continue;
            }
            let node = &t.nodes[i];
            for (p, d) in node.parent1.into_iter().chain(node.parent2) {
                adjoints[p] += adj * d;
            }
        }

        t.input_indices.iter().map(|&i| adjoints[i]).collect()
    }

    /// Computes the Jacobian of `outputs`: row `k` is the gradient of
    /// `outputs[k]` as returned by [`Tape::gradient`].
    ///
    /// Each row costs one backward sweep.
    pub fn jacobian(tape: &TapeRef, outputs: &[super::reverse::Var]) -> Vec<Vec<f64>> {
        outputs.iter().map(|o| Self::gradient(tape, o)).collect()
    }

    /// Removes every node and input from the tape so it can be reused.
    ///
    /// `Var`s created before the call refer to indices that no longer exist
    /// and must not be used with this tape afterwards.
    pub fn clear(tape: &TapeRef) {
        let mut t = tape.borrow_mut();
        t.nodes.clear();
        t.input_indices.clear();
        t.n_inputs = 0;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tape_basic() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let y = Tape::var(&tape, 3.0);
        let z = x * y;
        assert!((z.value - 6.0).abs() < 1e-14);
        let grad = Tape::gradient(&tape, &z);
        assert_eq!(grad.len(), 2);
        assert!((grad[0] - 3.0).abs() < 1e-14);
        assert!((grad[1] - 2.0).abs() < 1e-14);
    }

    #[test]
    fn test_tape_gradient_with_interleaved_inputs() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        // Recorded before `z` exists, so `z` is not at tape index 1.
        let sq = x.clone() * x;
        let z = Tape::var(&tape, 3.0);
        let f = sq * z;
        let grad = Tape::gradient(&tape, &f);
        assert_eq!(grad.len(), 2);
        // f = x^2 * z  =>  df/dx = 2xz = 12, df/dz = x^2 = 4
        assert!((grad[0] - 12.0).abs() < 1e-14);
        assert!((grad[1] - 4.0).abs() < 1e-14);
    }

    #[test]
    fn test_tape_clear() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 1.0);
        let _ = x.clone() + x;
        assert!(tape.borrow().nodes.len() > 1);
        Tape::clear(&tape);
        assert_eq!(tape.borrow().nodes.len(), 0);
        assert_eq!(tape.borrow().n_inputs, 0);
        assert!(tape.borrow().input_indices.is_empty());
    }

    #[test]
    fn test_tape_jacobian() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 1.0);
        let y = Tape::var(&tape, 2.0);
        let f1 = x.clone() + y.clone();
        let f2 = x * y;
        let jac = Tape::jacobian(&tape, &[f1, f2]);
        assert_eq!(jac.len(), 2);
        assert!((jac[0][0] - 1.0).abs() < 1e-14);
        assert!((jac[0][1] - 1.0).abs() < 1e-14);
        assert!((jac[1][0] - 2.0).abs() < 1e-14);
        assert!((jac[1][1] - 1.0).abs() < 1e-14);
    }
}