use crate::common::IntegrateFloat;
use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
/// An elementary operation recorded as one node of a [`Tape`].
///
/// Every `usize` payload (except in [`Operation::Variable`]) is the index of
/// an operand node on the same tape. Binary variants store their operands in
/// the order they were passed to the corresponding `Tape` method.
#[derive(Debug, Clone)]
pub enum Operation<F: IntegrateFloat> {
    /// Independent input; the payload is the caller-chosen variable index.
    Variable(usize),
    /// Constant with the stored value.
    Constant(F),
    /// `a + b`.
    Add(usize, usize),
    /// `a - b`.
    Sub(usize, usize),
    /// `a * b`.
    Mul(usize, usize),
    /// `a / b`.
    Div(usize, usize),
    /// `-a`.
    Neg(usize),
    /// `a` raised to a constant exponent.
    Pow(usize, F),
    /// `a` raised to the value of node `b`.
    PowGeneral(usize, usize),
    /// `sin(a)`.
    Sin(usize),
    /// `cos(a)`.
    Cos(usize),
    /// `tan(a)`.
    Tan(usize),
    /// `exp(a)`.
    Exp(usize),
    /// Natural logarithm `ln(a)`.
    Ln(usize),
    /// `sqrt(a)`.
    Sqrt(usize),
    /// `tanh(a)`.
    Tanh(usize),
    /// `sinh(a)`.
    Sinh(usize),
    /// `cosh(a)`.
    Cosh(usize),
    /// `atan2(y, x)`; the first payload is `y`, the second is `x`.
    Atan2(usize, usize),
    /// `|a|`.
    Abs(usize),
    /// `max(a, b)`.
    Max(usize, usize),
    /// `min(a, b)`.
    Min(usize, usize),
}
/// One recorded node of a [`Tape`]: the forward value, the operation that
/// produced it, and a slot for the adjoint filled in by the backward sweep.
pub struct TapeNode<F: IntegrateFloat> {
    /// Forward (primal) value computed when the node was recorded.
    pub value: F,
    /// Operation, including operand node indices, that produced `value`.
    pub operation: Operation<F>,
    /// Accumulated adjoint. It lives in a `RefCell` so it can be updated
    /// through a shared reference to the node.
    pub gradient: RefCell<F>,
}
impl<F: IntegrateFloat> TapeNode<F> {
    /// Builds a node holding `value` and `operation`, with its adjoint
    /// initialised to zero.
    pub fn new(value: F, operation: Operation<F>) -> Self {
        let gradient = RefCell::new(F::zero());
        Self {
            value,
            operation,
            gradient,
        }
    }
}
/// Linear record of a computation used for reverse-mode differentiation.
pub struct Tape<F: IntegrateFloat> {
    /// Nodes in recording order; a node's index is its position here.
    nodes: Vec<Rc<TapeNode<F>>>,
    /// Maps a variable index to the index of the node registered for it.
    var_map: HashMap<usize, usize>,
}
impl<F: IntegrateFloat> Tape<F> {
pub fn new() -> Self {
Tape {
nodes: Vec::new(),
var_map: HashMap::new(),
}
}
pub fn variable(&mut self, idx: usize, value: F) -> usize {
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Variable(idx))));
self.var_map.insert(idx, nodeidx);
nodeidx
}
pub fn constant(&mut self, value: F) -> usize {
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Constant(value))));
nodeidx
}
pub fn add(&mut self, a: usize, b: usize) -> usize {
let value = self.nodes[a].value + self.nodes[b].value;
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Add(a, b))));
nodeidx
}
pub fn sub(&mut self, a: usize, b: usize) -> usize {
let value = self.nodes[a].value - self.nodes[b].value;
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Sub(a, b))));
nodeidx
}
pub fn mul(&mut self, a: usize, b: usize) -> usize {
let value = self.nodes[a].value * self.nodes[b].value;
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Mul(a, b))));
nodeidx
}
pub fn div(&mut self, a: usize, b: usize) -> usize {
let value = self.nodes[a].value / self.nodes[b].value;
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Div(a, b))));
nodeidx
}
pub fn neg(&mut self, a: usize) -> usize {
let value = -self.nodes[a].value;
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Neg(a))));
nodeidx
}
pub fn pow(&mut self, a: usize, n: F) -> usize {
let value = self.nodes[a].value.powf(n);
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Pow(a, n))));
nodeidx
}
pub fn sin(&mut self, a: usize) -> usize {
let value = self.nodes[a].value.sin();
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Sin(a))));
nodeidx
}
pub fn cos(&mut self, a: usize) -> usize {
let value = self.nodes[a].value.cos();
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Cos(a))));
nodeidx
}
pub fn exp(&mut self, a: usize) -> usize {
let value = self.nodes[a].value.exp();
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Exp(a))));
nodeidx
}
pub fn ln(&mut self, a: usize) -> usize {
let value = self.nodes[a].value.ln();
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Ln(a))));
nodeidx
}
pub fn sqrt(&mut self, a: usize) -> usize {
let value = self.nodes[a].value.sqrt();
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Sqrt(a))));
nodeidx
}
pub fn pow_general(&mut self, a: usize, b: usize) -> usize {
let value = self.nodes[a].value.powf(self.nodes[b].value);
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::PowGeneral(a, b))));
nodeidx
}
pub fn tan(&mut self, a: usize) -> usize {
let value = self.nodes[a].value.tan();
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Tan(a))));
nodeidx
}
pub fn tanh(&mut self, a: usize) -> usize {
let value = self.nodes[a].value.tanh();
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Tanh(a))));
nodeidx
}
pub fn sinh(&mut self, a: usize) -> usize {
let value = self.nodes[a].value.sinh();
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Sinh(a))));
nodeidx
}
pub fn cosh(&mut self, a: usize) -> usize {
let value = self.nodes[a].value.cosh();
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Cosh(a))));
nodeidx
}
pub fn atan2(&mut self, y: usize, x: usize) -> usize {
let value = self.nodes[y].value.atan2(self.nodes[x].value);
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Atan2(y, x))));
nodeidx
}
pub fn abs(&mut self, a: usize) -> usize {
let value = self.nodes[a].value.abs();
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Abs(a))));
nodeidx
}
pub fn max(&mut self, a: usize, b: usize) -> usize {
let value = self.nodes[a].value.max(self.nodes[b].value);
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Max(a, b))));
nodeidx
}
pub fn min(&mut self, a: usize, b: usize) -> usize {
let value = self.nodes[a].value.min(self.nodes[b].value);
let nodeidx = self.nodes.len();
self.nodes
.push(Rc::new(TapeNode::new(value, Operation::Min(a, b))));
nodeidx
}
pub fn value(&self, idx: usize) -> F {
self.nodes[idx].value
}
pub fn backward(&mut self, outputidx: usize, nvars: usize) -> Array1<F> {
for node in &self.nodes {
*node.gradient.borrow_mut() = F::zero();
}
*self.nodes[outputidx].gradient.borrow_mut() = F::one();
for i in (0..=outputidx).rev() {
let node = &self.nodes[i];
let grad = *node.gradient.borrow();
if grad.abs() < F::epsilon() {
continue;
}
match &node.operation {
Operation::Variable(_) | Operation::Constant(_) => {}
Operation::Add(a, b) => {
*self.nodes[*a].gradient.borrow_mut() += grad;
*self.nodes[*b].gradient.borrow_mut() += grad;
}
Operation::Sub(a, b) => {
*self.nodes[*a].gradient.borrow_mut() += grad;
*self.nodes[*b].gradient.borrow_mut() -= grad;
}
Operation::Mul(a, b) => {
*self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*b].value;
*self.nodes[*b].gradient.borrow_mut() += grad * self.nodes[*a].value;
}
Operation::Div(a, b) => {
let b_val = self.nodes[*b].value;
*self.nodes[*a].gradient.borrow_mut() += grad / b_val;
*self.nodes[*b].gradient.borrow_mut() -=
grad * self.nodes[*a].value / (b_val * b_val);
}
Operation::Neg(a) => {
*self.nodes[*a].gradient.borrow_mut() -= grad;
}
Operation::Pow(a, n) => {
*self.nodes[*a].gradient.borrow_mut() +=
grad * *n * self.nodes[*a].value.powf(*n - F::one());
}
Operation::Sin(a) => {
*self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*a].value.cos();
}
Operation::Cos(a) => {
*self.nodes[*a].gradient.borrow_mut() -= grad * self.nodes[*a].value.sin();
}
Operation::Exp(a) => {
*self.nodes[*a].gradient.borrow_mut() += grad * node.value;
}
Operation::Ln(a) => {
*self.nodes[*a].gradient.borrow_mut() += grad / self.nodes[*a].value;
}
Operation::Sqrt(a) => {
*self.nodes[*a].gradient.borrow_mut() += grad
/ (F::from(2.0).expect("Failed to convert constant to float") * node.value);
}
Operation::PowGeneral(a, b) => {
let a_val = self.nodes[*a].value;
let b_val = self.nodes[*b].value;
*self.nodes[*a].gradient.borrow_mut() +=
grad * b_val * a_val.powf(b_val - F::one());
*self.nodes[*b].gradient.borrow_mut() += grad * node.value * a_val.ln();
}
Operation::Tan(a) => {
let cos_val = self.nodes[*a].value.cos();
*self.nodes[*a].gradient.borrow_mut() += grad / (cos_val * cos_val);
}
Operation::Tanh(a) => {
let tanh_val = node.value;
*self.nodes[*a].gradient.borrow_mut() +=
grad * (F::one() - tanh_val * tanh_val);
}
Operation::Sinh(a) => {
*self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*a].value.cosh();
}
Operation::Cosh(a) => {
*self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*a].value.sinh();
}
Operation::Atan2(y, x) => {
let x_val = self.nodes[*x].value;
let y_val = self.nodes[*y].value;
let denom = x_val * x_val + y_val * y_val;
*self.nodes[*y].gradient.borrow_mut() += grad * x_val / denom;
*self.nodes[*x].gradient.borrow_mut() -= grad * y_val / denom;
}
Operation::Abs(a) => {
let sign = if self.nodes[*a].value >= F::zero() {
F::one()
} else {
-F::one()
};
*self.nodes[*a].gradient.borrow_mut() += grad * sign;
}
Operation::Max(a, b) => {
if self.nodes[*a].value >= self.nodes[*b].value {
*self.nodes[*a].gradient.borrow_mut() += grad;
} else {
*self.nodes[*b].gradient.borrow_mut() += grad;
}
}
Operation::Min(a, b) => {
if self.nodes[*a].value <= self.nodes[*b].value {
*self.nodes[*a].gradient.borrow_mut() += grad;
} else {
*self.nodes[*b].gradient.borrow_mut() += grad;
}
}
}
}
let mut gradients = Array1::zeros(nvars);
for (varidx, &nodeidx) in &self.var_map {
if *varidx < nvars {
gradients[*varidx] = *self.nodes[nodeidx].gradient.borrow();
}
}
gradients
}
}
impl<F: IntegrateFloat> Default for Tape<F> {
fn default() -> Self {
Self::new()
}
}
/// Checkpointing policy selectable on `ReverseAD`.
///
/// NOTE(review): in the code visible in this file the chosen strategy is only
/// stored; no method reads it. The per-variant notes below are inferred from
/// the names alone and should be confirmed against the intended design.
#[derive(Debug, Clone, Copy)]
pub enum CheckpointStrategy {
    /// No checkpointing (the value `ReverseAD::new` starts with).
    None,
    /// Presumably: checkpoint every fixed number of steps — TODO confirm.
    FixedInterval(usize),
    /// Presumably: logarithmically spaced checkpoints — TODO confirm.
    Logarithmic,
    /// Presumably: checkpoint so that at most `max_nodes` nodes are kept —
    /// TODO confirm.
    MemoryBased {
        /// Node budget carried by this strategy.
        max_nodes: usize,
    },
}
/// Reverse-mode automatic differentiation driver for functions of `nvars`
/// inputs, recorded on a [`Tape`].
pub struct ReverseAD<F: IntegrateFloat> {
    /// Number of input variables every evaluation point must have.
    nvars: usize,
    /// Selected checkpointing policy (stored via `with_checkpoint_strategy`).
    checkpoint_strategy: CheckpointStrategy,
    /// Ties the driver to the scalar type `F` without storing a value of it.
    _phantom: std::marker::PhantomData<F>,
}
impl<F: IntegrateFloat> ReverseAD<F> {
    /// Creates a driver for functions of `nvars` inputs with checkpointing
    /// set to [`CheckpointStrategy::None`].
    pub fn new(nvars: usize) -> Self {
        ReverseAD {
            nvars,
            checkpoint_strategy: CheckpointStrategy::None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Builder-style setter that stores `strategy` and returns the driver.
    pub fn with_checkpoint_strategy(mut self, strategy: CheckpointStrategy) -> Self {
        self.checkpoint_strategy = strategy;
        self
    }

    /// Checks that an evaluation point of length `len` matches `nvars`.
    fn ensure_input_len(&self, len: usize) -> IntegrateResult<()> {
        if len == self.nvars {
            Ok(())
        } else {
            Err(IntegrateError::DimensionMismatch(format!(
                "Expected {} variables, got {}",
                self.nvars, len
            )))
        }
    }

    /// Creates a fresh tape with one variable node per entry of `x` and
    /// returns it together with the node indices, in input order.
    fn seeded_tape(x: &ArrayView1<F>) -> (Tape<F>, Vec<usize>) {
        let mut tape = Tape::new();
        let var_indices = x
            .iter()
            .enumerate()
            .map(|(i, &val)| tape.variable(i, val))
            .collect();
        (tape, var_indices)
    }

    /// Computes the gradient of the scalar function `f` at `x`.
    ///
    /// `f` receives the tape and the node indices of the inputs, and must
    /// return the node index of its scalar output.
    ///
    /// # Errors
    ///
    /// Returns `IntegrateError::DimensionMismatch` if `x.len() != nvars`.
    pub fn gradient<Func>(&mut self, f: Func, x: ArrayView1<F>) -> IntegrateResult<Array1<F>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> usize,
    {
        self.ensure_input_len(x.len())?;
        let (mut tape, var_indices) = Self::seeded_tape(&x);
        let outputidx = f(&mut tape, &var_indices);
        Ok(tape.backward(outputidx, self.nvars))
    }

    /// Computes the Jacobian of the vector function `f` at `x`.
    ///
    /// The function is recorded once; one backward sweep per output fills the
    /// corresponding row. The result has shape `(outputs, nvars)`.
    ///
    /// # Errors
    ///
    /// Returns `IntegrateError::DimensionMismatch` if `x.len() != nvars`.
    pub fn jacobian<Func>(&mut self, f: Func, x: ArrayView1<F>) -> IntegrateResult<Array2<F>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
    {
        self.ensure_input_len(x.len())?;
        let (mut tape, var_indices) = Self::seeded_tape(&x);
        let output_indices = f(&mut tape, &var_indices);

        let mut jacobian = Array2::zeros((output_indices.len(), self.nvars));
        for (row, &outputidx) in output_indices.iter().enumerate() {
            let grad = tape.backward(outputidx, self.nvars);
            jacobian.row_mut(row).assign(&grad);
        }
        Ok(jacobian)
    }

    /// Approximates the Hessian of the scalar function `f` at `x`.
    ///
    /// Column `j` is a forward difference of reverse-mode gradients,
    /// `(grad(x + eps * e_j) - grad(x)) / eps` with `eps = 1e-8`. The result
    /// is then symmetrised by averaging each off-diagonal pair.
    ///
    /// # Errors
    ///
    /// Returns `IntegrateError::DimensionMismatch` if `x.len() != nvars`.
    pub fn hessian<Func>(&mut self, f: Func, x: ArrayView1<F>) -> IntegrateResult<Array2<F>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> usize + Clone,
    {
        self.ensure_input_len(x.len())?;
        let n = self.nvars;
        let mut hessian = Array2::zeros((n, n));
        if n == 0 {
            // Nothing to differentiate; do not evaluate `f` on empty input.
            return Ok(hessian);
        }

        let eps = F::from(1e-8).expect("Failed to convert constant to float");
        let two = F::from(2.0).expect("Failed to convert constant to float");

        // The base gradient does not depend on the column, so compute it once.
        let grad_base = self.gradient(&f, x)?;
        for j in 0..n {
            let mut x_plus = x.to_owned();
            x_plus[j] += eps;
            let grad_plus = self.gradient(&f, x_plus.view())?;
            for i in 0..n {
                hessian[[i, j]] = (grad_plus[i] - grad_base[i]) / eps;
            }
        }

        // Average each off-diagonal pair to remove finite-difference asymmetry.
        for i in 0..n {
            for j in (i + 1)..n {
                let avg = (hessian[[i, j]] + hessian[[j, i]]) / two;
                hessian[[i, j]] = avg;
                hessian[[j, i]] = avg;
            }
        }
        Ok(hessian)
    }

    /// Computes the gradient of `f` at every point of `x_batch`, in order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by [`Self::gradient`].
    pub fn batch_gradient<Func>(
        &mut self,
        f: Func,
        x_batch: &[Array1<F>],
    ) -> IntegrateResult<Vec<Array1<F>>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> usize + Clone,
    {
        let mut gradients = Vec::with_capacity(x_batch.len());
        for x in x_batch {
            gradients.push(self.gradient(&f, x.view())?);
        }
        Ok(gradients)
    }

    /// Approximates the Jacobian-vector product `J(x) * v` with a forward
    /// difference, `(f(x + eps * v) - f(x)) / eps` with `eps = 1e-8`.
    ///
    /// The base and perturbed evaluations are recorded on separate tapes, and
    /// each output value is read from the tape that produced it.
    ///
    /// # Errors
    ///
    /// Returns `IntegrateError::DimensionMismatch` if `x` or `v` does not
    /// have length `nvars`.
    pub fn jvp<Func>(
        &mut self,
        f: Func,
        x: ArrayView1<F>,
        v: ArrayView1<F>,
    ) -> IntegrateResult<Array1<F>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
    {
        if x.len() != self.nvars || v.len() != self.nvars {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Expected {} variables for both x and v",
                self.nvars
            )));
        }
        let eps = F::from(1e-8).expect("Failed to convert constant to float");
        let x_perturbed = &x + &(v.to_owned() * eps);

        let (mut base_tape, base_vars) = Self::seeded_tape(&x);
        let output_base = f(&mut base_tape, &base_vars);

        let (mut perturbed_tape, perturbed_vars) = Self::seeded_tape(&x_perturbed.view());
        let output_perturbed = f(&mut perturbed_tape, &perturbed_vars);

        let mut jvp = Array1::zeros(output_base.len());
        for (i, (&idx_base, &idx_pert)) in
            output_base.iter().zip(output_perturbed.iter()).enumerate()
        {
            // Each index is only meaningful on the tape that produced it.
            jvp[i] = (perturbed_tape.value(idx_pert) - base_tape.value(idx_base)) / eps;
        }
        Ok(jvp)
    }

    /// Computes the vector-Jacobian product `v^T * J(x)` with a single
    /// backward sweep.
    ///
    /// The weighted sum `sum_i v[i] * f_i(x)` is appended to the tape and
    /// differentiated with respect to the inputs.
    ///
    /// # Errors
    ///
    /// Returns `IntegrateError::DimensionMismatch` if `x.len() != nvars` or
    /// if `v.len()` differs from the number of outputs returned by `f`.
    pub fn vjp<Func>(
        &mut self,
        f: Func,
        x: ArrayView1<F>,
        v: ArrayView1<F>,
    ) -> IntegrateResult<Array1<F>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
    {
        if x.len() != self.nvars {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Expected {} variables",
                self.nvars
            )));
        }
        let (mut tape, var_indices) = Self::seeded_tape(&x);
        let output_indices = f(&mut tape, &var_indices);
        if v.len() != output_indices.len() {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Vector v length {} doesn't match output dimension {}",
                v.len(),
                output_indices.len()
            )));
        }

        let mut weighted_sum = tape.constant(F::zero());
        for (&outputidx, &weight) in output_indices.iter().zip(v.iter()) {
            let weight_node = tape.constant(weight);
            let term = tape.mul(weight_node, outputidx);
            weighted_sum = tape.add(weighted_sum, term);
        }
        Ok(tape.backward(weighted_sum, self.nvars))
    }
}
/// Convenience wrapper: gradient of the scalar function `f` at `x`, using a
/// `ReverseAD` driver sized to `x.len()`.
///
/// `f` receives the tape and the node indices of the inputs and returns the
/// node index of its output.
// NOTE(review): the allow presumably silences an unused-item warning when the
// crate has no internal caller — confirm it is still needed.
#[allow(dead_code)]
pub fn reverse_gradient<F, Func>(f: Func, x: ArrayView1<F>) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
    Func: Fn(&mut Tape<F>, &[usize]) -> usize,
{
    let nvars = x.len();
    ReverseAD::<F>::new(nvars).gradient(f, x)
}
/// Convenience wrapper: Jacobian of the vector function `f` at `x`, using a
/// `ReverseAD` driver sized to `x.len()`.
///
/// `f` receives the tape and the node indices of the inputs and returns the
/// node indices of its outputs, one per Jacobian row.
// NOTE(review): the allow presumably silences an unused-item warning when the
// crate has no internal caller — confirm it is still needed.
#[allow(dead_code)]
pub fn reverse_jacobian<F, Func>(f: Func, x: ArrayView1<F>) -> IntegrateResult<Array2<F>>
where
    F: IntegrateFloat,
    Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
{
    let nvars = x.len();
    ReverseAD::<F>::new(nvars).jacobian(f, x)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed and exact derivatives.
    const TOL: f64 = 1e-10;

    /// Gradient of `x^2 + y^2` at `(3, 4)` is `(6, 8)`.
    #[test]
    fn test_reverse_gradient() {
        let sum_of_squares = |tape: &mut Tape<f64>, vars: &[usize]| {
            let x_sq = tape.mul(vars[0], vars[0]);
            let y_sq = tape.mul(vars[1], vars[1]);
            tape.add(x_sq, y_sq)
        };
        let point = Array1::from_vec(vec![3.0, 4.0]);
        let grad = reverse_gradient(sum_of_squares, point.view()).expect("Operation failed");

        let expected: [f64; 2] = [6.0, 8.0];
        for (i, &want) in expected.iter().enumerate() {
            assert!((grad[i] - want).abs() < TOL);
        }
    }

    /// Jacobian of `(x^2, x*y, y^2)` at `(2, 3)` is `[[4, 0], [3, 2], [0, 6]]`.
    #[test]
    fn test_reverse_jacobian() {
        let products = |tape: &mut Tape<f64>, vars: &[usize]| {
            let x_sq = tape.mul(vars[0], vars[0]);
            let xy = tape.mul(vars[0], vars[1]);
            let y_sq = tape.mul(vars[1], vars[1]);
            vec![x_sq, xy, y_sq]
        };
        let point = Array1::from_vec(vec![2.0, 3.0]);
        let jac = reverse_jacobian(products, point.view()).expect("Operation failed");

        let expected: [[f64; 2]; 3] = [[4.0, 0.0], [3.0, 2.0], [0.0, 6.0]];
        for (row, wanted_row) in expected.iter().enumerate() {
            for (col, &want) in wanted_row.iter().enumerate() {
                assert!((jac[[row, col]] - want).abs() < TOL);
            }
        }
    }
}