use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::ops::cumulative::{
CumExtremeResult, cummax_forward, cummin_forward, cumprod_forward, cumsum_forward,
logcumsumexp_forward, reverse_cumsum,
};
use crate::shape::normalize_axis;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Autograd node recorded by [`cumsum`]: keeps what the backward pass needs
/// to route the gradient back to the forward input.
#[derive(Debug)]
pub struct CumsumBackward<T: Float> {
    /// Forward input; supplies the graph edge and the target device for the
    /// gradient produced in `backward`.
    input: Tensor<T>,
    /// Dimension the cumulative sum ran along (already normalized, 0-based).
    dim: usize,
}
impl<T: Float> GradFn<T> for CumsumBackward<T> {
    /// The adjoint of a forward cumulative sum is a cumulative sum taken in
    /// the reverse direction along the same dimension.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // The scan below runs on host memory, so materialize a CPU copy of
        // the upstream gradient when it lives on the GPU.
        let go_cpu = if grad_output.is_cuda() {
            grad_output.cpu()?
        } else {
            grad_output.clone()
        };
        let go_data = go_cpu.data()?;
        let shape = go_cpu.shape();
        let reversed = reverse_cumsum(go_data, shape, self.dim);
        let grad_cpu = Tensor::from_storage(TensorStorage::cpu(reversed), shape.to_vec(), false)?;
        // Hand the gradient back on the same device as the saved input.
        let grad_input = grad_cpu.to(self.input.device())?;
        Ok(vec![Some(grad_input)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "CumsumBackward"
    }
}
/// Cumulative sum of `input` along `dim` (negative indices count from the
/// back). Registers a [`CumsumBackward`] node when gradient tracking is on
/// and the input requires a gradient.
pub fn cumsum<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<Tensor<T>> {
    // Normalizing up front also validates `dim` before the forward runs.
    let norm_dim = normalize_axis(dim as isize, input.ndim())?;
    let result = cumsum_forward(input, dim)?;
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result);
    }
    let grad_fn = Arc::new(CumsumBackward {
        input: input.clone(),
        dim: norm_dim,
    });
    let (storage, shape) = result.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, grad_fn)
}
/// Autograd node recorded by [`cumprod`]: stores both endpoints of the
/// forward computation so the backward pass can reuse the running products.
#[derive(Debug)]
pub struct CumprodBackward<T: Float> {
    /// Forward input; gradients flow back to this tensor.
    input: Tensor<T>,
    /// Forward result (the cumulative products), saved so `backward` does
    /// not have to recompute them.
    output: Tensor<T>,
    /// Dimension the cumulative product ran along (normalized, 0-based).
    dim: usize,
}
impl<T: Float> GradFn<T> for CumprodBackward<T> {
    /// Gradient of the cumulative product along `self.dim`.
    ///
    /// For each 1-D slice along the scan dimension:
    /// * fast path (no zero in the slice): uses
    ///   dL/dx_i = (1 / x_i) * sum_{j >= i} g_j * out_j,
    ///   accumulated with a single reverse sweep;
    /// * slow path (slice contains a zero): division by x_i is unusable, so
    ///   the partial products prod_{k <= j, k != i} x_k are rebuilt directly
    ///   (O(dim_size^3) per slice, but exact even with zeros).
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // All three operands are processed element-wise on the CPU; copy any
        // of them down from the GPU first.
        let go_cpu = if grad_output.is_cuda() {
            grad_output.cpu()?
        } else {
            grad_output.clone()
        };
        let input_cpu = if self.input.is_cuda() {
            self.input.cpu()?
        } else {
            self.input.clone()
        };
        let output_cpu = if self.output.is_cuda() {
            self.output.cpu()?
        } else {
            self.output.clone()
        };
        let go_data = go_cpu.data()?;
        let in_data = input_cpu.data()?;
        let out_data = output_cpu.data()?;
        let shape = input_cpu.shape();
        // View the tensor as (outer, dim_size, inner); each (o, k) pair
        // addresses one independent 1-D slice along the scan dimension.
        let (outer, dim_size, inner) = dim_strides(shape, self.dim);
        let numel = in_data.len();
        let mut grad_input = vec![<T as num_traits::Zero>::zero(); numel];
        for o in 0..outer {
            for k in 0..inner {
                let base = o * dim_size * inner + k;
                let has_zero = (0..dim_size)
                    .any(|i| in_data[base + i * inner] == <T as num_traits::Zero>::zero());
                if !has_zero {
                    // Fast path: product[i] = g_i * out_i, then a reverse
                    // running sum gives sum_{j >= i} g_j * out_j.
                    let mut product = vec![<T as num_traits::Zero>::zero(); dim_size];
                    for (i, prod_elem) in product.iter_mut().enumerate().take(dim_size) {
                        let idx = base + i * inner;
                        *prod_elem = go_data[idx] * out_data[idx];
                    }
                    let mut rev_acc = <T as num_traits::Zero>::zero();
                    for i in (0..dim_size).rev() {
                        let idx = base + i * inner;
                        rev_acc += product[i];
                        // out_j / x_i = prod_{k <= j, k != i} x_k when j >= i.
                        grad_input[idx] = rev_acc / in_data[idx];
                    }
                } else {
                    // Slow path: compute each partial product while skipping
                    // factor i, so a zero at position i does not kill terms
                    // that legitimately exclude it.
                    for i in 0..dim_size {
                        let mut acc = <T as num_traits::Zero>::zero();
                        for j in i..dim_size {
                            let mut partial = <T as num_traits::One>::one();
                            for kk in 0..=j {
                                if kk != i {
                                    #[allow(clippy::assign_op_pattern)]
                                    {
                                        partial = partial * in_data[base + kk * inner];
                                    }
                                }
                            }
                            acc += go_data[base + j * inner] * partial;
                        }
                        grad_input[base + i * inner] = acc;
                    }
                }
            }
        }
        let grad_cpu = Tensor::from_storage(TensorStorage::cpu(grad_input), shape.to_vec(), false)?;
        // Return the gradient on the input's original device.
        let result = grad_cpu.to(self.input.device())?;
        Ok(vec![Some(result)])
    }
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }
    fn name(&self) -> &'static str {
        "CumprodBackward"
    }
}
/// Cumulative product of `input` along `dim` (negative indices count from
/// the back). Registers a [`CumprodBackward`] node when gradient tracking
/// is on and the input requires a gradient.
pub fn cumprod<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<Tensor<T>> {
    // Normalizing up front also validates `dim` before the forward runs.
    let norm_dim = normalize_axis(dim as isize, input.ndim())?;
    let result = cumprod_forward(input, dim)?;
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result);
    }
    // The backward pass needs both the forward input and the forward output.
    let grad_fn = Arc::new(CumprodBackward {
        input: input.clone(),
        output: result.clone(),
        dim: norm_dim,
    });
    let (storage, shape) = result.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, grad_fn)
}
/// Cumulative maximum of `input` along `dim`: running maxima together with
/// the index (along `dim`) of each current maximum.
///
/// NOTE(review): unlike `cumsum`/`cumprod`, no backward node is attached
/// here, so gradients do not flow through the returned values — confirm
/// this is intentional.
pub fn cummax<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<CumExtremeResult<T>> {
    cummax_forward(input, dim)
}
/// Cumulative minimum of `input` along `dim`: running minima together with
/// the index (along `dim`) of each current minimum.
///
/// NOTE(review): no backward node is attached (see `cummax`) — confirm
/// gradients are intentionally not supported for this op.
pub fn cummin<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<CumExtremeResult<T>> {
    cummin_forward(input, dim)
}
/// Autograd node recorded by [`logcumsumexp`]: stores both endpoints of the
/// forward computation so the backward pass can be expressed in terms of
/// the already-computed log-partial-sums.
#[derive(Debug)]
pub struct LogcumsumexpBackward<T: Float> {
    /// Forward input; gradients flow back to this tensor.
    input: Tensor<T>,
    /// Forward result (log of the running sum of exponentials).
    output: Tensor<T>,
    /// Dimension the scan ran along (normalized, 0-based).
    dim: usize,
}
impl<T: Float> GradFn<T> for LogcumsumexpBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let go_cpu = if grad_output.is_cuda() {
grad_output.cpu()?
} else {
grad_output.clone()
};
let input_cpu = if self.input.is_cuda() {
self.input.cpu()?
} else {
self.input.clone()
};
let output_cpu = if self.output.is_cuda() {
self.output.cpu()?
} else {
self.output.clone()
};
let go_data = go_cpu.data()?;
let in_data = input_cpu.data()?;
let out_data = output_cpu.data()?;
let shape = input_cpu.shape();
let product: Vec<T> = go_data
.iter()
.zip(out_data.iter())
.map(|(&g, &o)| g * (-o).exp())
.collect();
let rev = reverse_cumsum(&product, shape, self.dim);
let grad_data: Vec<T> = in_data
.iter()
.zip(rev.iter())
.map(|(&x, &r)| x.exp() * r)
.collect();
let grad_cpu = Tensor::from_storage(TensorStorage::cpu(grad_data), shape.to_vec(), false)?;
let grad_input = grad_cpu.to(self.input.device())?;
Ok(vec![Some(grad_input)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"LogcumsumexpBackward"
}
}
/// Log-cumulative-sum-exp of `input` along `dim` (negative indices count
/// from the back). Registers a [`LogcumsumexpBackward`] node when gradient
/// tracking is on and the input requires a gradient.
pub fn logcumsumexp<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<Tensor<T>> {
    // Normalizing up front also validates `dim` before the forward runs.
    let norm_dim = normalize_axis(dim as isize, input.ndim())?;
    let result = logcumsumexp_forward(input, dim)?;
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result);
    }
    // The backward pass needs both the forward input and the forward output.
    let grad_fn = Arc::new(LogcumsumexpBackward {
        input: input.clone(),
        output: result.clone(),
        dim: norm_dim,
    });
    let (storage, shape) = result.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, grad_fn)
}
/// Factor `shape` around axis `dim` into `(outer, dim_size, inner)`:
/// the element counts before the axis, along it, and after it. Empty
/// products are 1, so edge axes work. Panics if `dim >= shape.len()`
/// (callers pass a normalized axis).
fn dim_strides(shape: &[usize], dim: usize) -> (usize, usize, usize) {
    let (head, tail) = shape.split_at(dim);
    let (&dim_size, rest) = tail.split_first().expect("dim must be within shape");
    (head.iter().product(), dim_size, rest.iter().product())
}
#[cfg(test)]
mod tests {
    //! Unit tests for the cumulative ops: forward values, analytic gradients
    //! (checked both against hand-computed values and central finite
    //! differences), autograd bookkeeping (grad_fn presence / `no_grad`),
    //! and scalar / out-of-bounds error cases.
    use super::*;
    use crate::autograd::no_grad::no_grad;
    use crate::grad_fns::reduction::sum;
    use crate::storage::TensorStorage;

    // Helper: build a CPU leaf tensor from raw data.
    fn leaf(data: &[f64], shape: &[usize], requires_grad: bool) -> Tensor<f64> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            shape.to_vec(),
            requires_grad,
        )
        .unwrap()
    }

    // ---- cumsum: forward values ----

    #[test]
    fn test_cumsum_1d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
        let cs = cumsum(&x, 0).unwrap();
        assert_eq!(cs.shape(), &[4]);
        let d = cs.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
        assert!((d[2] - 6.0).abs() < 1e-12);
        assert!((d[3] - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumsum_2d_dim0() {
        // Scanning dim 0 sums down the columns.
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let cs = cumsum(&x, 0).unwrap();
        assert_eq!(cs.shape(), &[2, 3]);
        let d = cs.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - 3.0).abs() < 1e-12);
        assert!((d[3] - 5.0).abs() < 1e-12);
        assert!((d[4] - 7.0).abs() < 1e-12);
        assert!((d[5] - 9.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumsum_2d_dim1() {
        // Scanning dim 1 sums along each row independently.
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let cs = cumsum(&x, 1).unwrap();
        let d = cs.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
        assert!((d[2] - 6.0).abs() < 1e-12);
        assert!((d[3] - 4.0).abs() < 1e-12);
        assert!((d[4] - 9.0).abs() < 1e-12);
        assert!((d[5] - 15.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumsum_negative_dim() {
        // dim = -1 must behave like the last dimension (dim 1 here).
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let cs = cumsum(&x, -1).unwrap();
        let d = cs.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
        assert!((d[2] - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumsum_3d() {
        // Middle-dimension scan: inner stride > 1 must be handled.
        let data: Vec<f64> = (1..=12).map(|x| x as f64).collect();
        let x = leaf(&data, &[2, 2, 3], false);
        let cs = cumsum(&x, 1).unwrap();
        assert_eq!(cs.shape(), &[2, 2, 3]);
        let d = cs.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[3] - 5.0).abs() < 1e-12);
        assert!((d[4] - 7.0).abs() < 1e-12);
        assert!((d[5] - 9.0).abs() < 1e-12);
    }

    // ---- cumsum: backward ----

    #[test]
    fn test_cumsum_backward_1d() {
        // d(sum of cumsum)/dx_i = number of prefix sums containing x_i.
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cs = cumsum(&x, 0).unwrap();
        let loss = sum(&cs).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        assert!((gd[0] - 3.0).abs() < 1e-12, "got {}", gd[0]);
        assert!((gd[1] - 2.0).abs() < 1e-12, "got {}", gd[1]);
        assert!((gd[2] - 1.0).abs() < 1e-12, "got {}", gd[2]);
    }

    #[test]
    fn test_cumsum_backward_2d_dim0() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], true);
        let cs = cumsum(&x, 0).unwrap();
        let loss = sum(&cs).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        assert!((gd[0] - 2.0).abs() < 1e-12, "got {}", gd[0]);
        assert!((gd[1] - 2.0).abs() < 1e-12, "got {}", gd[1]);
        assert!((gd[2] - 1.0).abs() < 1e-12, "got {}", gd[2]);
        assert!((gd[3] - 1.0).abs() < 1e-12, "got {}", gd[3]);
    }

    // ---- cumsum: autograd bookkeeping ----

    #[test]
    fn test_cumsum_has_grad_fn() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cs = cumsum(&x, 0).unwrap();
        assert!(cs.grad_fn().is_some());
        assert_eq!(cs.grad_fn().unwrap().name(), "CumsumBackward");
    }

    #[test]
    fn test_cumsum_no_grad_fn_when_not_requires_grad() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
        let cs = cumsum(&x, 0).unwrap();
        assert!(cs.grad_fn().is_none());
    }

    #[test]
    fn test_cumsum_no_grad_fn_in_no_grad_context() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cs = no_grad(|| cumsum(&x, 0)).unwrap();
        assert!(cs.grad_fn().is_none());
    }

    // ---- cumprod: forward values ----

    #[test]
    fn test_cumprod_1d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
        let cp = cumprod(&x, 0).unwrap();
        let d = cp.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - 6.0).abs() < 1e-12);
        assert!((d[3] - 24.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumprod_2d_dim0() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let cp = cumprod(&x, 0).unwrap();
        let d = cp.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - 3.0).abs() < 1e-12);
        assert!((d[3] - 4.0).abs() < 1e-12);
        assert!((d[4] - 10.0).abs() < 1e-12);
        assert!((d[5] - 18.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumprod_2d_dim1() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let cp = cumprod(&x, 1).unwrap();
        let d = cp.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - 6.0).abs() < 1e-12);
        assert!((d[3] - 4.0).abs() < 1e-12);
        assert!((d[4] - 20.0).abs() < 1e-12);
        assert!((d[5] - 120.0).abs() < 1e-12);
    }

    // ---- cumprod: backward (fast path and zero path) ----

    #[test]
    fn test_cumprod_backward_1d() {
        // No zeros in the input: exercises the division-based fast path.
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cp = cumprod(&x, 0).unwrap();
        let loss = sum(&cp).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        assert!((gd[0] - 9.0).abs() < 1e-10, "got {}", gd[0]);
        assert!((gd[1] - 4.0).abs() < 1e-10, "got {}", gd[1]);
        assert!((gd[2] - 2.0).abs() < 1e-10, "got {}", gd[2]);
    }

    #[test]
    fn test_cumprod_backward_with_zero() {
        // A zero in the slice forces the O(n^3) product-skipping path.
        let x = leaf(&[2.0, 0.0, 3.0], &[3], true);
        let cp = cumprod(&x, 0).unwrap();
        let loss = sum(&cp).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        assert!((gd[0] - 1.0).abs() < 1e-10, "d/dx[0]: got {}", gd[0]);
        assert!((gd[1] - 8.0).abs() < 1e-10, "d/dx[1]: got {}", gd[1]);
        assert!((gd[2] - 0.0).abs() < 1e-10, "d/dx[2]: got {}", gd[2]);
    }

    #[test]
    fn test_cumprod_has_grad_fn() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cp = cumprod(&x, 0).unwrap();
        assert!(cp.grad_fn().is_some());
        assert_eq!(cp.grad_fn().unwrap().name(), "CumprodBackward");
    }

    #[test]
    fn test_cumprod_no_grad_fn_in_no_grad_context() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cp = no_grad(|| cumprod(&x, 0)).unwrap();
        assert!(cp.grad_fn().is_none());
    }

    // ---- cummax / cummin: values and argmax/argmin indices ----

    #[test]
    fn test_cummax_1d() {
        let x = leaf(&[3.0, 1.0, 4.0, 1.0, 5.0], &[5], false);
        let r = cummax(&x, 0).unwrap();
        let d = r.values.data().unwrap();
        assert!((d[0] - 3.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
        assert!((d[2] - 4.0).abs() < 1e-12);
        assert!((d[3] - 4.0).abs() < 1e-12);
        assert!((d[4] - 5.0).abs() < 1e-12);
        assert_eq!(r.indices, vec![0, 0, 2, 2, 4]);
    }

    #[test]
    fn test_cummax_2d_dim1() {
        let x = leaf(&[1.0, 3.0, 2.0, 5.0, 4.0, 6.0], &[2, 3], false);
        let r = cummax(&x, 1).unwrap();
        let d = r.values.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
        assert!((d[2] - 3.0).abs() < 1e-12);
        assert!((d[3] - 5.0).abs() < 1e-12);
        assert!((d[4] - 5.0).abs() < 1e-12);
        assert!((d[5] - 6.0).abs() < 1e-12);
        assert_eq!(r.indices, vec![0, 1, 1, 0, 0, 2]);
    }

    #[test]
    fn test_cummin_1d() {
        // Ties (the two 1.0s) keep the earlier index.
        let x = leaf(&[3.0, 1.0, 4.0, 1.0, 5.0], &[5], false);
        let r = cummin(&x, 0).unwrap();
        let d = r.values.data().unwrap();
        assert!((d[0] - 3.0).abs() < 1e-12);
        assert!((d[1] - 1.0).abs() < 1e-12);
        assert!((d[2] - 1.0).abs() < 1e-12);
        assert!((d[3] - 1.0).abs() < 1e-12);
        assert!((d[4] - 1.0).abs() < 1e-12);
        assert_eq!(r.indices, vec![0, 1, 1, 1, 1]);
    }

    #[test]
    fn test_cummin_2d_dim0() {
        let x = leaf(&[5.0, 2.0, 3.0, 4.0], &[2, 2], false);
        let r = cummin(&x, 0).unwrap();
        let d = r.values.data().unwrap();
        assert!((d[0] - 5.0).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - 3.0).abs() < 1e-12);
        assert!((d[3] - 2.0).abs() < 1e-12);
        assert_eq!(r.indices, vec![0, 0, 1, 0]);
    }

    // ---- logcumsumexp: forward, stability, backward ----

    #[test]
    fn test_logcumsumexp_1d() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
        let lcs = logcumsumexp(&x, 0).unwrap();
        let d = lcs.data().unwrap();
        let expected_0 = 1.0_f64;
        let expected_1 = (1.0_f64.exp() + 2.0_f64.exp()).ln();
        let expected_2 = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp()).ln();
        assert!((d[0] - expected_0).abs() < 1e-10, "got {}", d[0]);
        assert!((d[1] - expected_1).abs() < 1e-10, "got {}", d[1]);
        assert!((d[2] - expected_2).abs() < 1e-10, "got {}", d[2]);
    }

    #[test]
    fn test_logcumsumexp_2d_dim1() {
        let x = leaf(&[0.0, 1.0, 2.0, 3.0], &[2, 2], false);
        let lcs = logcumsumexp(&x, 1).unwrap();
        let d = lcs.data().unwrap();
        let e0 = 0.0_f64;
        let e1 = (0.0_f64.exp() + 1.0_f64.exp()).ln();
        let e2 = 2.0_f64;
        let e3 = (2.0_f64.exp() + 3.0_f64.exp()).ln();
        assert!((d[0] - e0).abs() < 1e-10, "got {}", d[0]);
        assert!((d[1] - e1).abs() < 1e-10, "got {}", d[1]);
        assert!((d[2] - e2).abs() < 1e-10, "got {}", d[2]);
        assert!((d[3] - e3).abs() < 1e-10, "got {}", d[3]);
    }

    #[test]
    fn test_logcumsumexp_numerical_stability() {
        // exp(1000) overflows f64; the forward must stay finite regardless.
        let x = leaf(&[1000.0, 1001.0, 1002.0], &[3], false);
        let lcs = logcumsumexp(&x, 0).unwrap();
        let d = lcs.data().unwrap();
        for &v in d {
            assert!(v.is_finite(), "got non-finite: {v}");
        }
        assert!((d[0] - 1000.0).abs() < 1e-10);
        let expected_1 = 1001.0 + ((-1.0_f64).exp() + 1.0).ln();
        assert!((d[1] - expected_1).abs() < 1e-8, "got {}", d[1]);
    }

    #[test]
    fn test_logcumsumexp_backward_1d() {
        // Analytic gradient vs. central finite differences.
        let x_vals = [1.0_f64, 2.0, 3.0];
        let eps = 1e-6;
        let x = leaf(&x_vals, &[3], true);
        let lcs = logcumsumexp(&x, 0).unwrap();
        let loss = sum(&lcs).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        for idx in 0..3 {
            let mut x_plus = x_vals.to_vec();
            let mut x_minus = x_vals.to_vec();
            x_plus[idx] += eps;
            x_minus[idx] -= eps;
            let tp = leaf(&x_plus, &[3], false);
            let lp = logcumsumexp(&tp, 0).unwrap();
            let sp = sum(&lp).unwrap().item().unwrap();
            let tm = leaf(&x_minus, &[3], false);
            let lm = logcumsumexp(&tm, 0).unwrap();
            let sm = sum(&lm).unwrap().item().unwrap();
            let numerical = (sp - sm) / (2.0 * eps);
            assert!(
                (gd[idx] - numerical).abs() < 1e-4,
                "index {idx}: analytic={}, numerical={}",
                gd[idx],
                numerical,
            );
        }
    }

    #[test]
    fn test_logcumsumexp_has_grad_fn() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let lcs = logcumsumexp(&x, 0).unwrap();
        assert!(lcs.grad_fn().is_some());
        assert_eq!(lcs.grad_fn().unwrap().name(), "LogcumsumexpBackward");
    }

    #[test]
    fn test_logcumsumexp_no_grad_fn_in_no_grad_context() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let lcs = no_grad(|| logcumsumexp(&x, 0)).unwrap();
        assert!(lcs.grad_fn().is_none());
    }

    // ---- error cases: scalars and out-of-bounds dims ----

    #[test]
    fn test_cumsum_scalar_error() {
        let x = Tensor::from_storage(TensorStorage::cpu(vec![1.0_f64]), vec![], false).unwrap();
        assert!(cumsum(&x, 0).is_err());
    }

    #[test]
    fn test_cumprod_scalar_error() {
        let x = Tensor::from_storage(TensorStorage::cpu(vec![1.0_f64]), vec![], false).unwrap();
        assert!(cumprod(&x, 0).is_err());
    }

    #[test]
    fn test_cummax_scalar_error() {
        let x = Tensor::from_storage(TensorStorage::cpu(vec![1.0_f64]), vec![], false).unwrap();
        assert!(cummax(&x, 0).is_err());
    }

    #[test]
    fn test_cummin_scalar_error() {
        let x = Tensor::from_storage(TensorStorage::cpu(vec![1.0_f64]), vec![], false).unwrap();
        assert!(cummin(&x, 0).is_err());
    }

    #[test]
    fn test_logcumsumexp_scalar_error() {
        let x = Tensor::from_storage(TensorStorage::cpu(vec![1.0_f64]), vec![], false).unwrap();
        assert!(logcumsumexp(&x, 0).is_err());
    }

    #[test]
    fn test_cumsum_dim_out_of_bounds() {
        // Valid dims for a 1-D tensor are 0 and -1 only.
        let x = leaf(&[1.0, 2.0], &[2], false);
        assert!(cumsum(&x, 1).is_err());
        assert!(cumsum(&x, -2).is_err());
    }

    // ---- finite-difference gradient checks ----

    #[test]
    fn test_cumprod_backward_numerical() {
        let x_vals = [2.0_f64, 3.0, 0.5];
        let eps = 1e-6;
        let x = leaf(&x_vals, &[3], true);
        let cp = cumprod(&x, 0).unwrap();
        let loss = sum(&cp).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        for idx in 0..3 {
            let mut x_plus = x_vals.to_vec();
            let mut x_minus = x_vals.to_vec();
            x_plus[idx] += eps;
            x_minus[idx] -= eps;
            let tp = leaf(&x_plus, &[3], false);
            let fp = sum(&cumprod(&tp, 0).unwrap()).unwrap().item().unwrap();
            let tm = leaf(&x_minus, &[3], false);
            let fm = sum(&cumprod(&tm, 0).unwrap()).unwrap().item().unwrap();
            let numerical = (fp - fm) / (2.0 * eps);
            assert!(
                (gd[idx] - numerical).abs() < 1e-4,
                "index {idx}: analytic={}, numerical={}",
                gd[idx],
                numerical,
            );
        }
    }

    #[test]
    fn test_cumsum_backward_numerical() {
        let x_vals = [1.0_f64, -2.0, 3.5, 0.7];
        let eps = 1e-6;
        let x = leaf(&x_vals, &[4], true);
        let cs = cumsum(&x, 0).unwrap();
        let loss = sum(&cs).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        for idx in 0..4 {
            let mut x_plus = x_vals.to_vec();
            let mut x_minus = x_vals.to_vec();
            x_plus[idx] += eps;
            x_minus[idx] -= eps;
            let tp = leaf(&x_plus, &[4], false);
            let fp = sum(&cumsum(&tp, 0).unwrap()).unwrap().item().unwrap();
            let tm = leaf(&x_minus, &[4], false);
            let fm = sum(&cumsum(&tm, 0).unwrap()).unwrap().item().unwrap();
            let numerical = (fp - fm) / (2.0 * eps);
            assert!(
                (gd[idx] - numerical).abs() < 1e-4,
                "index {idx}: analytic={}, numerical={}",
                gd[idx],
                numerical,
            );
        }
    }
}