use std::collections::{HashMap, VecDeque};
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::{Tensor, TensorId};
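/// Computes the gradients of a scalar `outputs` tensor with respect to each
/// tensor in `inputs`, without accumulating into the leaves' `.grad` fields.
///
/// Returns one `Option<Tensor<T>>` per input: `Some(gradient)` if the input
/// participates in the graph of `outputs`, `None` otherwise. Pass
/// `create_graph = true` to keep the returned gradients differentiable for
/// higher-order derivatives.
///
/// A minimal usage sketch (not doctested; paths assume the ops under
/// `crate::grad_fns` used elsewhere in this module):
///
/// ```ignore
/// let x = Tensor::from_storage(TensorStorage::cpu(vec![2.0f32]), vec![], true)?;
/// let y = crate::grad_fns::arithmetic::pow(&x, 3.0)?; // y = x^3
/// let grads = grad(&y, &[&x], false, false)?;         // dy/dx = 3x^2 = 12
/// assert!((grads[0].as_ref().unwrap().item()? - 12.0).abs() < 1e-4);
/// ```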
pub fn grad<T: Float>(
outputs: &Tensor<T>,
inputs: &[&Tensor<T>],
retain_graph: bool,
create_graph: bool,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
    // Reverse-mode needs a single scalar to seed the backward pass with 1.0,
    // mirroring PyTorch's restriction on `backward()` for non-scalar outputs.
    if !outputs.is_scalar() && outputs.numel() != 1 {
        return Err(FerrotorchError::BackwardNonScalar {
            shape: outputs.shape().to_vec(),
        });
    }
let input_ids: HashMap<TensorId, usize> = inputs
.iter()
.enumerate()
.map(|(i, t)| (t.id(), i))
.collect();
    // Seed gradient d(output)/d(output) = 1; it is tracked only when a graph
    // is needed for higher-order derivatives.
    let seed = Tensor::from_storage(
        TensorStorage::cpu(vec![<T as num_traits::One>::one()]),
        vec![],
        create_graph,
    )?;
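    // First pass (BFS from the output): record every reachable node and, for
    // each tensor, count how many consumers will feed gradient into it. These
    // counts become the in-degrees for the topological sort below.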
let mut in_degree: HashMap<TensorId, usize> = HashMap::new();
let mut node_map: HashMap<TensorId, &Tensor<T>> = HashMap::new();
let mut queue: VecDeque<&Tensor<T>> = VecDeque::new();
queue.push_back(outputs);
in_degree.entry(outputs.id()).or_insert(0);
node_map.insert(outputs.id(), outputs);
while let Some(node) = queue.pop_front() {
if let Some(grad_fn) = node.grad_fn() {
for input in grad_fn.inputs() {
let input_id = input.id();
let count = in_degree.entry(input_id).or_insert(0);
*count += 1;
if let std::collections::hash_map::Entry::Vacant(e) = node_map.entry(input_id) {
e.insert(input);
queue.push_back(input);
}
}
}
}
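    // Second pass: Kahn's algorithm. A node is emitted only once all of its
    // consumers have been emitted, so every gradient is fully accumulated
    // before that node's own backward runs.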
let mut topo_order: Vec<TensorId> = Vec::new();
let mut bfs_queue: VecDeque<TensorId> = VecDeque::new();
    for (&id, &deg) in &in_degree {
if deg == 0 {
bfs_queue.push_back(id);
}
}
while let Some(id) = bfs_queue.pop_front() {
topo_order.push(id);
if let Some(node) = node_map.get(&id) {
if let Some(grad_fn) = node.grad_fn() {
for input in grad_fn.inputs() {
if let Some(deg) = in_degree.get_mut(&input.id()) {
*deg -= 1;
if *deg == 0 {
bfs_queue.push_back(input.id());
}
}
}
}
}
}
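    // Backward sweep in topological order: take each node's accumulated
    // gradient, run its grad_fn, and scatter the per-input gradients back
    // into the map, summing where a tensor is consumed more than once.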
let mut grads: HashMap<TensorId, Tensor<T>> = HashMap::new();
grads.insert(outputs.id(), seed);
let mut result: Vec<Option<Tensor<T>>> = vec![None; inputs.len()];
for &id in &topo_order {
let node = match node_map.get(&id) {
Some(n) => *n,
None => continue,
};
let grad_output = match grads.remove(&id) {
Some(g) => g,
None => continue,
};
if let Some(&idx) = input_ids.get(&id) {
result[idx] = Some(grad_output.clone());
}
if let Some(grad_fn) = node.grad_fn() {
let input_grads = grad_fn.backward(&grad_output)?;
let fn_inputs = grad_fn.inputs();
for (input, maybe_grad) in fn_inputs.iter().zip(input_grads.into_iter()) {
if let Some(ig) = maybe_grad {
if input.requires_grad() {
let grad_tensor = if create_graph && !ig.requires_grad() {
ig.requires_grad_(true)
} else {
ig
};
if let Some(existing) = grads.remove(&input.id()) {
if create_graph {
let summed =
differentiable_add(&existing, &grad_tensor, create_graph)?;
grads.insert(input.id(), summed);
} else {
let a = existing.data()?;
let b = grad_tensor.data()?;
let summed: Vec<T> =
a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect();
let storage = TensorStorage::cpu(summed);
let combined = Tensor::from_storage(
storage,
existing.shape().to_vec(),
false,
)?;
grads.insert(input.id(), combined);
}
} else {
grads.insert(input.id(), grad_tensor);
}
}
}
}
}
}
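    // Defensive fallback: surface any gradients still in the map for
    // requested inputs that were never drained in the sweep above.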
for (id, g) in grads {
if let Some(&idx) = input_ids.get(&id) {
if result[idx].is_none() {
result[idx] = Some(g);
}
}
}
    // `retain_graph` is accepted for API parity but is not consulted here,
    // since this implementation never frees the graph after a backward pass.
    let _ = retain_graph;
Ok(result)
}
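/// Sums two gradient contributions through the tracked `add` op so the
/// accumulated gradient stays differentiable when `create_graph` is set.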
fn differentiable_add<T: Float>(
a: &Tensor<T>,
b: &Tensor<T>,
_create_graph: bool,
) -> FerrotorchResult<Tensor<T>> {
crate::grad_fns::arithmetic::add(a, b)
}
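/// Computes the Jacobian of `f` at `input`, one output row per backward pass:
/// for each output element `y_i`, `f` is re-evaluated on a fresh
/// gradient-tracked copy of `input` and `grad(y_i, input)` fills row `i`.
/// The result has shape `[m, n]` for an `m`-element output and an
/// `n`-element input, at the cost of `m + 1` evaluations of `f`.
///
/// A minimal sketch (not doctested), assuming `f` builds its output through
/// tracked ops:
///
/// ```ignore
/// // f(x) = sum(x), so J = [[1, 1]] for a 2-element input.
/// let x = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0]), vec![2], false)?;
/// let jac = jacobian(|t| crate::grad_fns::reduction::sum(t), &x)?;
/// assert_eq!(jac.shape(), &[1, 2]);
/// ```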
pub fn jacobian<T: Float, F>(f: F, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
F: Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>>,
{
let n = input.numel();
    // One forward evaluation just to learn the output size `m`.
    let x = Tensor::from_storage(
        TensorStorage::cpu(input.data()?.to_vec()),
        input.shape().to_vec(),
        true,
    )?;
    let y = f(&x)?;
    let m = y.numel();
let mut jac_data: Vec<T> = Vec::with_capacity(m * n);
for i in 0..m {
let x_fresh = Tensor::from_storage(
TensorStorage::cpu(input.data()?.to_vec()),
input.shape().to_vec(),
true,
)?;
let y_fresh = f(&x_fresh)?;
let y_i = extract_element(&y_fresh, i)?;
let grads = grad(&y_i, &[&x_fresh], false, false)?;
match &grads[0] {
Some(g) => {
let g_data = g.data()?;
jac_data.extend_from_slice(g_data);
}
None => {
jac_data.extend(std::iter::repeat_n(<T as num_traits::Zero>::zero(), n));
}
}
}
Tensor::from_storage(TensorStorage::cpu(jac_data), vec![m, n], false)
}
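/// Computes the Hessian of a scalar-valued `f` at `input`, row by row: the
/// first-order gradient is built with `create_graph = true`, then element `i`
/// of that gradient is differentiated a second time to produce row `i`.
/// The result has shape `[n, n]` for an `n`-element input.
///
/// A minimal sketch (not doctested), assuming `f` returns a scalar through
/// tracked ops:
///
/// ```ignore
/// // f(x) = x^2 at x = 3, so H = [[2]].
/// let x = Tensor::from_storage(TensorStorage::cpu(vec![3.0f32]), vec![1], false)?;
/// let hess = hessian(
///     |t| {
///         let e = extract_element(t, 0)?;
///         crate::grad_fns::arithmetic::pow(&e, 2.0)
///     },
///     &x,
/// )?;
/// assert_eq!(hess.shape(), &[1, 1]);
/// ```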
pub fn hessian<T: Float, F>(f: F, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
F: Fn(&Tensor<T>) -> FerrotorchResult<Tensor<T>>,
{
let n = input.numel();
let mut hess_data: Vec<T> = Vec::with_capacity(n * n);
for i in 0..n {
let x = Tensor::from_storage(
TensorStorage::cpu(input.data()?.to_vec()),
input.shape().to_vec(),
true,
)?;
let y = f(&x)?;
let grads = grad(&y, &[&x], true, true)?;
let grad_vec = match &grads[0] {
Some(g) => g.clone(),
None => {
hess_data.extend(std::iter::repeat_n(<T as num_traits::Zero>::zero(), n));
continue;
}
};
let grad_i = extract_element(&grad_vec, i)?;
let grads2 = grad(&grad_i, &[&x], false, false)?;
match &grads2[0] {
Some(g2) => {
let g2_data = g2.data()?;
hess_data.extend_from_slice(g2_data);
}
None => {
hess_data.extend(std::iter::repeat_n(<T as num_traits::Zero>::zero(), n));
}
}
}
Tensor::from_storage(TensorStorage::cpu(hess_data), vec![n, n], false)
}
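/// Extracts element `index` of `tensor` as a rank-0 scalar tensor. If the
/// source is gradient-tracked, the result carries an `IndexSelectBackward`
/// node so gradient flows back to the selected position.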
fn extract_element<T: Float>(tensor: &Tensor<T>, index: usize) -> FerrotorchResult<Tensor<T>> {
let data = tensor.data()?;
if index >= data.len() {
return Err(FerrotorchError::IndexOutOfBounds {
index,
axis: 0,
size: data.len(),
});
}
let val = data[index];
let scalar = Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], false)?;
if tensor.requires_grad() || tensor.grad_fn().is_some() {
let grad_fn = std::sync::Arc::new(IndexSelectBackward {
input: tensor.clone(),
index,
});
Tensor::from_operation(TensorStorage::cpu(vec![val]), vec![], grad_fn)
} else {
Ok(scalar)
}
}
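/// Backward node for `extract_element`: routes the incoming scalar gradient
/// into a zero tensor shaped like the original input, at `index`.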
#[derive(Debug)]
struct IndexSelectBackward<T: Float> {
input: Tensor<T>,
index: usize,
}
impl<T: Float> crate::tensor::GradFn<T> for IndexSelectBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let numel = self.input.numel();
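        // Differentiable path: express the scatter as
        // broadcast(grad_output) * one_hot(index) so the result keeps a graph
        // for higher-order derivatives.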
if grad_output.requires_grad() || grad_output.grad_fn().is_some() {
let one = <T as num_traits::One>::one();
let zero = <T as num_traits::Zero>::zero();
let mut basis = vec![zero; numel];
basis[self.index] = one;
let basis_tensor = Tensor::from_storage(
TensorStorage::cpu(basis),
self.input.shape().to_vec(),
false,
)?;
let go_val = grad_output.data()?[0];
let broadcast_tracked = Tensor::from_operation(
TensorStorage::cpu(vec![go_val; numel]),
self.input.shape().to_vec(),
std::sync::Arc::new(BroadcastScalarBackward {
scalar_input: grad_output.clone(),
numel,
}),
)?;
let grad_input = crate::grad_fns::arithmetic::mul(&broadcast_tracked, &basis_tensor)?;
return Ok(vec![Some(grad_input)]);
}
let go = grad_output.data()?[0];
let mut grad_data = vec![<T as num_traits::Zero>::zero(); numel];
grad_data[self.index] = go;
let grad_input = Tensor::from_storage(
TensorStorage::cpu(grad_data),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(grad_input)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"IndexSelectBackward"
}
}
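/// Backward node for broadcasting a scalar to `numel` elements: the gradient
/// of a broadcast is the sum of the incoming gradient over all positions.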
#[derive(Debug)]
struct BroadcastScalarBackward<T: Float> {
scalar_input: Tensor<T>,
#[allow(dead_code)]
numel: usize,
}
impl<T: Float> crate::tensor::GradFn<T> for BroadcastScalarBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let go_data = grad_output.data()?;
let total: T = go_data.iter().copied().fold(<T as num_traits::Zero>::zero(), |a, b| a + b);
let grad_scalar = Tensor::from_storage(TensorStorage::cpu(vec![total]), vec![], false)?;
Ok(vec![Some(grad_scalar)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.scalar_input]
}
fn name(&self) -> &'static str {
"BroadcastScalarBackward"
}
}
impl<T: Float> Tensor<T> {
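    /// Convenience wrapper over [`grad`], using `self` as the scalar output.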
pub fn grad_wrt(
&self,
inputs: &[&Tensor<T>],
retain_graph: bool,
create_graph: bool,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
grad(self, inputs, retain_graph, create_graph)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::grad_fns::arithmetic::{add, mul, pow};
use crate::grad_fns::reduction::sum;
use crate::storage::TensorStorage;
fn leaf_scalar(val: f32, requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], requires_grad).unwrap()
}
fn leaf_vec(data: &[f32], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![data.len()], requires_grad)
.unwrap()
}
fn assert_approx(actual: f32, expected: f32, tol: f32, msg: &str) {
assert!(
(actual - expected).abs() < tol,
"{msg}: expected {expected}, got {actual}"
);
}
#[test]
fn test_grad_simple_pow() {
let x = leaf_scalar(2.0, true);
let y = pow(&x, 3.0).unwrap();
let grads = grad(&y, &[&x], false, false).unwrap();
let dy_dx = grads[0].as_ref().unwrap();
assert_approx(dy_dx.item().unwrap(), 12.0, 1e-4, "f'(2) for x^3");
}
#[test]
fn test_grad_add() {
let x = leaf_scalar(3.0, true);
let y = leaf_scalar(5.0, true);
let z = add(&x, &y).unwrap();
let grads = grad(&z, &[&x, &y], false, false).unwrap();
assert_approx(
grads[0].as_ref().unwrap().item().unwrap(),
1.0,
1e-6,
"dz/dx",
);
assert_approx(
grads[1].as_ref().unwrap().item().unwrap(),
1.0,
1e-6,
"dz/dy",
);
}
#[test]
fn test_grad_mul() {
let x = leaf_scalar(3.0, true);
let y = leaf_scalar(5.0, true);
let z = mul(&x, &y).unwrap();
let grads = grad(&z, &[&x, &y], false, false).unwrap();
assert_approx(
grads[0].as_ref().unwrap().item().unwrap(),
5.0,
1e-6,
"dz/dx = y",
);
assert_approx(
grads[1].as_ref().unwrap().item().unwrap(),
3.0,
1e-6,
"dz/dy = x",
);
}
#[test]
fn test_grad_x_squared_plus_y_squared() {
let x = leaf_scalar(3.0, true);
let y = leaf_scalar(4.0, true);
let x2 = pow(&x, 2.0).unwrap();
let y2 = pow(&y, 2.0).unwrap();
let z = add(&x2, &y2).unwrap();
let grads = grad(&z, &[&x, &y], false, false).unwrap();
assert_approx(
grads[0].as_ref().unwrap().item().unwrap(),
6.0,
1e-4,
"dz/dx = 2x",
);
assert_approx(
grads[1].as_ref().unwrap().item().unwrap(),
8.0,
1e-4,
"dz/dy = 2y",
);
}
#[test]
fn test_grad_does_not_accumulate_on_leaves() {
let x = leaf_scalar(2.0, true);
let y = pow(&x, 2.0).unwrap();
let _grads = grad(&y, &[&x], false, false).unwrap();
assert!(
x.grad().unwrap().is_none(),
"grad() should not accumulate on leaf tensors"
);
}
#[test]
fn test_grad_no_create_graph_returns_detached() {
let x = leaf_scalar(2.0, true);
let y = pow(&x, 3.0).unwrap();
let grads = grad(&y, &[&x], false, false).unwrap();
let dy_dx = grads[0].as_ref().unwrap();
assert!(
dy_dx.grad_fn().is_none(),
"create_graph=false: gradient should not have grad_fn"
);
}
#[test]
fn test_grad_create_graph_returns_differentiable() {
let x = leaf_scalar(2.0, true);
let y = pow(&x, 3.0).unwrap();
let grads = grad(&y, &[&x], true, true).unwrap();
let dy_dx = grads[0].as_ref().unwrap();
assert!(
dy_dx.requires_grad(),
"create_graph=true: gradient should require grad"
);
}
#[test]
fn test_higher_order_x_cubed() {
let x = leaf_scalar(2.0, true);
let y = pow(&x, 3.0).unwrap();
let grads1 = grad(&y, &[&x], true, true).unwrap();
let dy_dx = grads1[0].as_ref().unwrap();
assert_approx(dy_dx.item().unwrap(), 12.0, 1e-4, "f'(2) = 3*4 = 12");
let grads2 = grad(dy_dx, &[&x], false, false).unwrap();
let d2y_dx2 = grads2[0].as_ref().unwrap();
assert_approx(d2y_dx2.item().unwrap(), 12.0, 1e-3, "f''(2) = 6*2 = 12");
}
#[test]
fn test_higher_order_x_squared() {
let x = leaf_scalar(5.0, true);
let y = pow(&x, 2.0).unwrap();
let grads1 = grad(&y, &[&x], true, true).unwrap();
let dy_dx = grads1[0].as_ref().unwrap();
assert_approx(dy_dx.item().unwrap(), 10.0, 1e-4, "f'(5) = 2*5 = 10");
let grads2 = grad(dy_dx, &[&x], false, false).unwrap();
let d2y_dx2 = grads2[0].as_ref().unwrap();
assert_approx(d2y_dx2.item().unwrap(), 2.0, 1e-3, "f''(5) = 2");
}
#[test]
fn test_higher_order_product() {
let x = leaf_scalar(3.0, true);
let y = leaf_scalar(4.0, true);
let z = mul(&x, &y).unwrap();
let grads1 = grad(&z, &[&x, &y], true, true).unwrap();
let dz_dx = grads1[0].as_ref().unwrap();
let dz_dy = grads1[1].as_ref().unwrap();
assert_approx(dz_dx.item().unwrap(), 4.0, 1e-6, "dz/dx = y = 4");
assert_approx(dz_dy.item().unwrap(), 3.0, 1e-6, "dz/dy = x = 3");
let grads2 = grad(dz_dx, &[&y], false, false).unwrap();
let d2z_dxdy = grads2[0].as_ref().unwrap();
assert_approx(d2z_dxdy.item().unwrap(), 1.0, 1e-4, "d2z/dxdy = 1");
}
#[test]
fn test_jacobian_quadratic() {
let input = leaf_vec(&[2.0, 3.0], false);
let jac = jacobian(
|x| {
let e0 = extract_element(x, 0).unwrap();
let e1 = extract_element(x, 1).unwrap();
                let f0 = pow(&e0, 2.0).unwrap();
                let f1 = mul(&e0, &e1).unwrap();
let f0_val = f0.item().unwrap();
let f1_val = f1.item().unwrap();
let out = Tensor::from_operation(
TensorStorage::cpu(vec![f0_val, f1_val]),
vec![2],
std::sync::Arc::new(ConcatBackward2 {
input0: f0,
input1: f1,
}),
)
.unwrap();
Ok(out)
},
&input,
)
.unwrap();
assert_eq!(jac.shape(), &[2, 2]);
let j = jac.data().unwrap();
assert_approx(j[0], 4.0, 1e-4, "J[0,0] = 2x = 4");
assert_approx(j[1], 0.0, 1e-4, "J[0,1] = 0");
assert_approx(j[2], 3.0, 1e-4, "J[1,0] = y = 3");
assert_approx(j[3], 2.0, 1e-4, "J[1,1] = x = 2");
}
#[test]
fn test_jacobian_identity() {
let input = leaf_vec(&[3.0], false);
let jac = jacobian(
|x| {
let s = sum(x).unwrap();
Ok(s)
},
&input,
)
.unwrap();
assert_eq!(jac.shape(), &[1, 1]);
assert_approx(jac.data().unwrap()[0], 1.0, 1e-6, "J[0,0]");
}
#[test]
fn test_jacobian_scaled() {
let input = leaf_vec(&[5.0], false);
let jac = jacobian(
|x| {
let doubled = add(x, x).unwrap();
let s = sum(&doubled).unwrap();
Ok(s)
},
&input,
)
.unwrap();
assert_eq!(jac.shape(), &[1, 1]);
assert_approx(jac.data().unwrap()[0], 2.0, 1e-5, "J[0,0] = 2");
}
#[test]
fn test_jacobian_vector_to_vector() {
let input = leaf_vec(&[1.0, 1.0], false);
let jac = jacobian(
|x| {
let e0 = extract_element(x, 0).unwrap();
let e1 = extract_element(x, 1).unwrap();
                let f0 = pow(&e0, 2.0).unwrap();
                let f1 = mul(&e0, &e1).unwrap();
let f0_val = f0.item().unwrap();
let f1_val = f1.item().unwrap();
let out = Tensor::from_operation(
TensorStorage::cpu(vec![f0_val, f1_val]),
vec![2],
std::sync::Arc::new(ConcatBackward2 {
input0: f0,
input1: f1,
}),
)
.unwrap();
Ok(out)
},
&input,
)
.unwrap();
assert_eq!(jac.shape(), &[2, 2]);
let j = jac.data().unwrap();
assert_approx(j[0], 2.0, 1e-4, "J[0,0] = 2x = 2");
assert_approx(j[1], 0.0, 1e-4, "J[0,1] = 0");
assert_approx(j[2], 1.0, 1e-4, "J[1,0] = y = 1");
assert_approx(j[3], 1.0, 1e-4, "J[1,1] = x = 1");
}
#[derive(Debug)]
struct ConcatBackward2<T: Float> {
input0: Tensor<T>,
input1: Tensor<T>,
}
impl<T: Float> crate::tensor::GradFn<T> for ConcatBackward2<T> {
fn backward(
&self,
grad_output: &Tensor<T>,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let go = grad_output.data()?;
let g0 = Tensor::from_storage(TensorStorage::cpu(vec![go[0]]), vec![], false)?;
let g1 = Tensor::from_storage(TensorStorage::cpu(vec![go[1]]), vec![], false)?;
Ok(vec![Some(g0), Some(g1)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input0, &self.input1]
}
fn name(&self) -> &'static str {
"ConcatBackward2"
}
}
#[test]
fn test_hessian_x_squared_plus_y_squared() {
let input = leaf_vec(&[3.0, 4.0], false);
let hess = hessian(
|x| {
let e0 = extract_element(x, 0).unwrap();
let e1 = extract_element(x, 1).unwrap();
let f0 = pow(&e0, 2.0).unwrap();
let f1 = pow(&e1, 2.0).unwrap();
let result = add(&f0, &f1).unwrap();
Ok(result)
},
&input,
)
.unwrap();
assert_eq!(hess.shape(), &[2, 2]);
let h = hess.data().unwrap();
assert_approx(h[0], 2.0, 1e-3, "H[0,0]");
assert_approx(h[1], 0.0, 1e-3, "H[0,1]");
assert_approx(h[2], 0.0, 1e-3, "H[1,0]");
assert_approx(h[3], 2.0, 1e-3, "H[1,1]");
}
#[test]
fn test_hessian_x_cubed() {
let input = leaf_vec(&[2.0], false);
let hess = hessian(
|x| {
let e = extract_element(x, 0).unwrap();
pow(&e, 3.0)
},
&input,
)
.unwrap();
assert_eq!(hess.shape(), &[1, 1]);
assert_approx(hess.data().unwrap()[0], 12.0, 1e-2, "H[0,0] = 6*2 = 12");
}
#[test]
fn test_hessian_xy() {
let input = leaf_vec(&[2.0, 3.0], false);
let hess = hessian(
|x| {
let e0 = extract_element(x, 0).unwrap();
let e1 = extract_element(x, 1).unwrap();
mul(&e0, &e1)
},
&input,
)
.unwrap();
assert_eq!(hess.shape(), &[2, 2]);
let h = hess.data().unwrap();
assert_approx(h[0], 0.0, 1e-3, "H[0,0] = 0");
assert_approx(h[1], 1.0, 1e-3, "H[0,1] = 1");
assert_approx(h[2], 1.0, 1e-3, "H[1,0] = 1");
assert_approx(h[3], 0.0, 1e-3, "H[1,1] = 0");
}
#[test]
fn test_grad_non_scalar_error() {
let x = leaf_vec(&[1.0, 2.0, 3.0], true);
let result = grad(&x, &[&x], false, false);
assert!(result.is_err());
}
#[test]
fn test_grad_no_dependency() {
let x = leaf_scalar(1.0, true);
let y = leaf_scalar(2.0, true);
let z = pow(&y, 2.0).unwrap();
let grads = grad(&z, &[&x], false, false).unwrap();
assert!(grads[0].is_none(), "x is not in the graph of z");
}
#[test]
fn test_grad_wrt_convenience() {
let x = leaf_scalar(3.0, true);
let y = pow(&x, 2.0).unwrap();
let grads = y.grad_wrt(&[&x], false, false).unwrap();
assert_approx(
grads[0].as_ref().unwrap().item().unwrap(),
6.0,
1e-4,
"dy/dx = 2x = 6",
);
}
}