use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::ops::cumulative::{
CumExtremeResult, cummax_forward, cummin_forward, cumprod_forward, cumsum_forward,
logcumsumexp_forward, reverse_cumsum,
};
use crate::shape::normalize_axis;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Backward node for [`cumsum`].
///
/// An inclusive prefix sum `y[j] = sum_{k <= j} x[k]` has gradient
/// `dx[i] = sum_{j >= i} dy[j]`, i.e. a suffix sum of the upstream gradient
/// along the same axis.
#[derive(Debug)]
pub struct CumsumBackward<T: Float> {
    /// Forward input; reported through `inputs()` so autograd can route the gradient.
    input: Tensor<T>,
    /// Already-normalized (non-negative) axis the forward ran along.
    dim: usize,
}

impl<T: Float> GradFn<T> for CumsumBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // cumsum of a 0-d tensor is the identity, so the gradient passes straight through.
        let grad_input = if self.input.ndim() == 0 {
            grad_output.clone()
        } else {
            if grad_output.is_cuda() {
                return Err(crate::error::FerrotorchError::NotImplementedOnCuda {
                    op: "CumsumBackward",
                });
            }
            let upstream = grad_output.data()?;
            let dims = grad_output.shape();
            let suffix_sums = reverse_cumsum(upstream, dims, self.dim);
            Tensor::from_storage(TensorStorage::cpu(suffix_sums), dims.to_vec(), false)?
        };
        Ok(vec![Some(grad_input)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "CumsumBackward"
    }
}
/// Inclusive cumulative sum of `input` along `dim`; negative `dim` counts from the end.
///
/// When autograd is recording and `input` requires grad, the result carries a
/// [`CumsumBackward`] node. A 0-d `input` is returned unchanged, and `dim` must
/// then be `0` or `-1`.
///
/// # Errors
///
/// Returns an error when `dim` is out of range, or when the forward kernel or
/// tensor construction fails.
pub fn cumsum<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<Tensor<T>> {
    if input.ndim() == 0 {
        return cumulative_scalar_identity(input, dim, "cumsum", ScalarBackwardKind::Cumsum);
    }
    let axis = normalize_axis(dim as isize, input.ndim())?;
    let summed = cumsum_forward(input, dim)?;
    if !is_grad_enabled() || !input.requires_grad() {
        return Ok(summed);
    }
    let node = Arc::new(CumsumBackward {
        input: input.clone(),
        dim: axis,
    });
    let (storage, shape) = summed.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
#[derive(Clone, Copy)]
enum ScalarBackwardKind {
Cumsum,
Cumprod,
Logcumsumexp,
}
fn cumulative_scalar_identity<T: Float>(
input: &Tensor<T>,
dim: i64,
op_name: &str,
kind: ScalarBackwardKind,
) -> FerrotorchResult<Tensor<T>> {
if dim != 0 && dim != -1 {
let _ = op_name; return Err(crate::error::FerrotorchError::InvalidArgument {
message: format!(
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
),
});
}
let scalar_val = input.item()?;
let result = Tensor::from_storage(TensorStorage::cpu(vec![scalar_val]), Vec::new(), false)?;
if !(is_grad_enabled() && input.requires_grad()) {
return Ok(result);
}
let (storage, shape) = result.into_storage_and_shape()?;
match kind {
ScalarBackwardKind::Cumsum => {
let grad_fn = Arc::new(CumsumBackward {
input: input.clone(),
dim: 0,
});
Tensor::from_operation(storage, shape, grad_fn)
}
ScalarBackwardKind::Cumprod => {
let saved_output =
Tensor::from_storage(TensorStorage::cpu(vec![scalar_val]), Vec::new(), false)?;
let grad_fn = Arc::new(CumprodBackward {
input: input.clone(),
output: saved_output,
dim: 0,
});
Tensor::from_operation(storage, shape, grad_fn)
}
ScalarBackwardKind::Logcumsumexp => {
let saved_output =
Tensor::from_storage(TensorStorage::cpu(vec![scalar_val]), Vec::new(), false)?;
let grad_fn = Arc::new(LogcumsumexpBackward {
input: input.clone(),
output: saved_output,
dim: 0,
});
Tensor::from_operation(storage, shape, grad_fn)
}
}
}
/// Backward node for [`cumprod`].
///
/// For `y[j] = prod_{k <= j} x[k]` along `dim`, the gradient is
/// `dx[i] = sum_{j >= i} dy[j] * prod_{k <= j, k != i} x[k]`.
#[derive(Debug)]
pub struct CumprodBackward<T: Float> {
    /// Forward input.
    input: Tensor<T>,
    /// Forward output, reused by the division-based fast path.
    output: Tensor<T>,
    /// Already-normalized (non-negative) axis the forward ran along.
    dim: usize,
}

impl<T: Float> GradFn<T> for CumprodBackward<T> {
    // `a = a * b` is kept on purpose, matching the arithmetic style used elsewhere in this file.
    #[allow(clippy::assign_op_pattern)]
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // cumprod of a 0-d tensor is the identity.
        if self.input.ndim() == 0 {
            return Ok(vec![Some(grad_output.clone())]);
        }
        if grad_output.is_cuda() || self.input.is_cuda() || self.output.is_cuda() {
            return Err(crate::error::FerrotorchError::NotImplementedOnCuda {
                op: "CumprodBackward",
            });
        }
        let zero = <T as num_traits::Zero>::zero();
        let one = <T as num_traits::One>::one();
        let go_data = grad_output.data()?;
        let in_data = self.input.data()?;
        let out_data = self.output.data()?;
        let shape = self.input.shape();
        let (outer, dim_size, inner) = dim_strides(shape, self.dim);
        let mut grad_input = vec![zero; in_data.len()];
        for o in 0..outer {
            for k in 0..inner {
                // Element i of this lane lives at `base + i * inner`.
                let base = o * dim_size * inner + k;
                let has_zero = (0..dim_size).any(|i| in_data[base + i * inner] == zero);
                if has_zero {
                    // Dividing by x[i] is undefined here, so form the partial products
                    // explicitly. `prefix` holds prod_{k < i} x[k]. Extending a copy of it
                    // by x[i+1], ..., x[j] gives prod_{k <= j, k != i} x[k] in O(1) per j,
                    // using the same multiplication order as a from-scratch product.
                    // Each lane therefore costs O(dim_size^2) instead of O(dim_size^3).
                    let mut prefix = one;
                    for i in 0..dim_size {
                        let idx_i = base + i * inner;
                        let mut partial = prefix;
                        let mut acc = zero;
                        acc += go_data[idx_i] * partial;
                        for j in (i + 1)..dim_size {
                            let idx_j = base + j * inner;
                            partial = partial * in_data[idx_j];
                            acc += go_data[idx_j] * partial;
                        }
                        grad_input[idx_i] = acc;
                        prefix = prefix * in_data[idx_i];
                    }
                } else {
                    // No zeros: dx[i] = (sum_{j >= i} dy[j] * y[j]) / x[i], computed with
                    // a single reverse scan.
                    let mut rev_acc = zero;
                    for i in (0..dim_size).rev() {
                        let idx = base + i * inner;
                        rev_acc += go_data[idx] * out_data[idx];
                        grad_input[idx] = rev_acc / in_data[idx];
                    }
                }
            }
        }
        let result = Tensor::from_storage(TensorStorage::cpu(grad_input), shape.to_vec(), false)?;
        Ok(vec![Some(result)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "CumprodBackward"
    }
}
/// Inclusive cumulative product of `input` along `dim`; negative `dim` counts from the end.
///
/// When autograd is recording and `input` requires grad, the result carries a
/// [`CumprodBackward`] node that keeps both the input and the forward output.
/// A 0-d `input` is returned unchanged, and `dim` must then be `0` or `-1`.
///
/// # Errors
///
/// Returns an error when `dim` is out of range, or when the forward kernel or
/// tensor construction fails.
pub fn cumprod<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<Tensor<T>> {
    if input.ndim() == 0 {
        return cumulative_scalar_identity(input, dim, "cumprod", ScalarBackwardKind::Cumprod);
    }
    let axis = normalize_axis(dim as isize, input.ndim())?;
    let products = cumprod_forward(input, dim)?;
    if !is_grad_enabled() || !input.requires_grad() {
        return Ok(products);
    }
    // The backward reuses the forward output, so the node keeps a handle to it.
    let node = Arc::new(CumprodBackward {
        input: input.clone(),
        output: products.clone(),
        dim: axis,
    });
    let (storage, shape) = products.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
#[derive(Debug)]
pub struct CummaxBackward<T: Float> {
input: Tensor<T>,
indices: Vec<usize>,
input_shape: Vec<usize>,
dim: usize,
}
impl<T: Float> GradFn<T> for CummaxBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if self.input.ndim() == 0 {
return Ok(vec![Some(grad_output.clone())]);
}
cummaxmin_backward_impl(grad_output, &self.input_shape, &self.indices, self.dim)
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"CummaxBackward"
}
}
#[derive(Debug)]
pub struct CumminBackward<T: Float> {
input: Tensor<T>,
indices: Vec<usize>,
input_shape: Vec<usize>,
dim: usize,
}
impl<T: Float> GradFn<T> for CumminBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if self.input.ndim() == 0 {
return Ok(vec![Some(grad_output.clone())]);
}
cummaxmin_backward_impl(grad_output, &self.input_shape, &self.indices, self.dim)
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"CumminBackward"
}
}
fn cummaxmin_backward_impl<T: Float>(
grad_output: &Tensor<T>,
input_shape: &[usize],
indices: &[usize],
dim: usize,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if grad_output.is_cuda() {
return Err(crate::error::FerrotorchError::NotImplementedOnCuda {
op: "CummaxBackward",
});
}
let numel: usize = input_shape.iter().product();
if numel == 0 {
let empty = Tensor::from_storage(
TensorStorage::cpu(Vec::<T>::new()),
input_shape.to_vec(),
false,
)?;
return Ok(vec![Some(empty)]);
}
let zeros = crate::creation::zeros::<T>(input_shape)?;
let grad_input =
crate::ops::indexing::scatter_add(&zeros, dim as isize, indices, input_shape, grad_output)?;
Ok(vec![Some(grad_input)])
}
/// Cumulative maximum of `input` along `dim`; negative `dim` counts from the end.
///
/// `values` holds the running maxima. `indices` holds, for every output
/// element, the position along `dim` that supplied it. When autograd is
/// recording and `input` requires grad, `values` carries a [`CummaxBackward`]
/// node.
///
/// A 0-d `input` is returned unchanged with index `0`, and `dim` must then be
/// `0` or `-1`. That result is attached to the autograd graph in the same way
/// as on the n-d path, so gradients still reach a 0-d input.
///
/// # Errors
///
/// Returns an error when `dim` is out of range, or when the forward kernel or
/// tensor construction fails.
pub fn cummax<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<CumExtremeResult<T>> {
    let (result, norm_dim) = if input.ndim() == 0 {
        // A scalar is its own running maximum. Its gradient is the identity,
        // which `CummaxBackward` handles through its 0-d short-circuit.
        (cumextreme_scalar_identity(input, dim, "cummax")?, 0)
    } else {
        let norm_dim = normalize_axis(dim as isize, input.ndim())?;
        (cummax_forward(input, dim)?, norm_dim)
    };
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result);
    }
    let CumExtremeResult { values, indices } = result;
    let input_shape = values.shape().to_vec();
    let grad_fn = Arc::new(CummaxBackward {
        input: input.clone(),
        indices: indices.clone(),
        input_shape,
        dim: norm_dim,
    });
    let (storage, shape) = values.into_storage_and_shape()?;
    let values = Tensor::from_operation(storage, shape, grad_fn)?;
    Ok(CumExtremeResult { values, indices })
}
/// Cumulative minimum of `input` along `dim`; negative `dim` counts from the end.
///
/// `values` holds the running minima. `indices` holds, for every output
/// element, the position along `dim` that supplied it. When autograd is
/// recording and `input` requires grad, `values` carries a [`CumminBackward`]
/// node.
///
/// A 0-d `input` is returned unchanged with index `0`, and `dim` must then be
/// `0` or `-1`. That result is attached to the autograd graph in the same way
/// as on the n-d path, so gradients still reach a 0-d input.
///
/// # Errors
///
/// Returns an error when `dim` is out of range, or when the forward kernel or
/// tensor construction fails.
pub fn cummin<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<CumExtremeResult<T>> {
    let (result, norm_dim) = if input.ndim() == 0 {
        // A scalar is its own running minimum. Its gradient is the identity,
        // which `CumminBackward` handles through its 0-d short-circuit.
        (cumextreme_scalar_identity(input, dim, "cummin")?, 0)
    } else {
        let norm_dim = normalize_axis(dim as isize, input.ndim())?;
        (cummin_forward(input, dim)?, norm_dim)
    };
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result);
    }
    let CumExtremeResult { values, indices } = result;
    let input_shape = values.shape().to_vec();
    let grad_fn = Arc::new(CumminBackward {
        input: input.clone(),
        indices: indices.clone(),
        input_shape,
        dim: norm_dim,
    });
    let (storage, shape) = values.into_storage_and_shape()?;
    let values = Tensor::from_operation(storage, shape, grad_fn)?;
    Ok(CumExtremeResult { values, indices })
}
/// 0-d forward shared by [`cummax`] and [`cummin`].
///
/// Validates that `dim` is `0` or `-1`, then returns the scalar as a new 0-d
/// CPU tensor with index `0`. `op_name` prefixes the error message. No
/// autograd node is attached here.
///
/// # Errors
///
/// Returns `InvalidArgument` for any other `dim`. Otherwise it propagates
/// failures from reading the scalar or building the tensor.
fn cumextreme_scalar_identity<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    op_name: &str,
) -> FerrotorchResult<CumExtremeResult<T>> {
    if !matches!(dim, -1 | 0) {
        return Err(crate::error::FerrotorchError::InvalidArgument {
            message: format!(
                "{op_name}(): Expected reduction dim -1 or 0 for scalar but got {dim}"
            ),
        });
    }
    let value = input.item()?;
    Ok(CumExtremeResult {
        values: Tensor::from_storage(TensorStorage::cpu(vec![value]), Vec::new(), false)?,
        // A single element is always its own running extreme.
        indices: vec![0],
    })
}
#[derive(Debug)]
pub struct LogcumsumexpBackward<T: Float> {
input: Tensor<T>,
output: Tensor<T>,
dim: usize,
}
impl<T: Float> GradFn<T> for LogcumsumexpBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if self.input.ndim() == 0 {
return Ok(vec![Some(grad_output.clone())]);
}
if grad_output.is_cuda() || self.input.is_cuda() || self.output.is_cuda() {
return Err(crate::error::FerrotorchError::NotImplementedOnCuda {
op: "LogcumsumexpBackward",
});
}
let go_data = grad_output.data()?;
let in_data = self.input.data()?;
let out_data = self.output.data()?;
let shape = self.input.shape();
let product: Vec<T> = go_data
.iter()
.zip(out_data.iter())
.map(|(&g, &o)| g * (-o).exp())
.collect();
let rev = reverse_cumsum(&product, shape, self.dim);
let grad_data: Vec<T> = in_data
.iter()
.zip(rev.iter())
.map(|(&x, &r)| x.exp() * r)
.collect();
let grad_input =
Tensor::from_storage(TensorStorage::cpu(grad_data), shape.to_vec(), false)?;
Ok(vec![Some(grad_input)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"LogcumsumexpBackward"
}
}
/// Cumulative `log(sum(exp(x)))` of `input` along `dim`; negative `dim` counts from the end.
///
/// When autograd is recording and `input` requires grad, the result carries a
/// [`LogcumsumexpBackward`] node that keeps both the input and the forward
/// output. A 0-d `input` is returned unchanged, and `dim` must then be `0` or `-1`.
///
/// # Errors
///
/// Returns an error when `dim` is out of range, or when the forward kernel or
/// tensor construction fails.
pub fn logcumsumexp<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<Tensor<T>> {
    if input.ndim() == 0 {
        return cumulative_scalar_identity(
            input,
            dim,
            "logcumsumexp",
            ScalarBackwardKind::Logcumsumexp,
        );
    }
    let axis = normalize_axis(dim as isize, input.ndim())?;
    let lse = logcumsumexp_forward(input, dim)?;
    if !is_grad_enabled() || !input.requires_grad() {
        return Ok(lse);
    }
    // The backward needs the forward output, so the node keeps a handle to it.
    let node = Arc::new(LogcumsumexpBackward {
        input: input.clone(),
        output: lse.clone(),
        dim: axis,
    });
    let (storage, shape) = lse.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
/// Splits `shape` around axis `dim` into `(outer, dim_size, inner)`.
///
/// `outer` is the product of the extents before `dim`, `inner` the product of
/// those after it, and `dim_size` the extent of `dim` itself. An empty product
/// is `1`.
///
/// # Panics
///
/// Panics if `dim >= shape.len()`.
fn dim_strides(shape: &[usize], dim: usize) -> (usize, usize, usize) {
    let extent = |axes: &[usize]| axes.iter().copied().product::<usize>();
    (extent(&shape[..dim]), shape[dim], extent(&shape[dim + 1..]))
}
#[cfg(test)]
mod tests {
    // Unit tests for the cumulative ops. They cover:
    // - forward values;
    // - gradients, checked against closed-form expectations and central finite differences;
    // - autograd graph wiring;
    // - the 0-d scalar passthrough paths.
    use super::*;
    use crate::autograd::no_grad::no_grad;
    use crate::grad_fns::reduction::sum;
    use crate::storage::TensorStorage;

    /// Builds a CPU `f64` tensor from `data` with the given `shape` and grad flag.
    fn leaf(data: &[f64], shape: &[usize], requires_grad: bool) -> Tensor<f64> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            shape.to_vec(),
            requires_grad,
        )
        .unwrap()
    }

    // ---- cumsum: forward ----

    #[test]
    fn test_cumsum_1d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
        let cs = cumsum(&x, 0).unwrap();
        assert_eq!(cs.shape(), &[4]);
        let d = cs.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
        assert!((d[2] - 6.0).abs() < 1e-12);
        assert!((d[3] - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumsum_2d_dim0() {
        // Down the rows: the second row becomes row0 + row1.
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let cs = cumsum(&x, 0).unwrap();
        assert_eq!(cs.shape(), &[2, 3]);
        let d = cs.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - 3.0).abs() < 1e-12);
        assert!((d[3] - 5.0).abs() < 1e-12);
        assert!((d[4] - 7.0).abs() < 1e-12);
        assert!((d[5] - 9.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumsum_2d_dim1() {
        // Along each row independently.
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let cs = cumsum(&x, 1).unwrap();
        let d = cs.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
        assert!((d[2] - 6.0).abs() < 1e-12);
        assert!((d[3] - 4.0).abs() < 1e-12);
        assert!((d[4] - 9.0).abs() < 1e-12);
        assert!((d[5] - 15.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumsum_negative_dim() {
        // dim = -1 addresses the last axis, matching dim = 1 for a 2-D tensor.
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let cs = cumsum(&x, -1).unwrap();
        let d = cs.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
        assert!((d[2] - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumsum_3d() {
        // Middle axis of a 3-D tensor, exercising outer/inner strides > 1.
        let data: Vec<f64> = (1..=12).map(|x| x as f64).collect();
        let x = leaf(&data, &[2, 2, 3], false);
        let cs = cumsum(&x, 1).unwrap();
        assert_eq!(cs.shape(), &[2, 2, 3]);
        let d = cs.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[3] - 5.0).abs() < 1e-12);
        assert!((d[4] - 7.0).abs() < 1e-12);
        assert!((d[5] - 9.0).abs() < 1e-12);
    }

    // ---- cumsum: backward and graph wiring ----

    #[test]
    fn test_cumsum_backward_1d() {
        // x[i] appears in outputs i..n, so d(sum)/dx[i] = n - i.
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cs = cumsum(&x, 0).unwrap();
        let loss = sum(&cs).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        assert!((gd[0] - 3.0).abs() < 1e-12, "got {}", gd[0]);
        assert!((gd[1] - 2.0).abs() < 1e-12, "got {}", gd[1]);
        assert!((gd[2] - 1.0).abs() < 1e-12, "got {}", gd[2]);
    }

    #[test]
    fn test_cumsum_backward_2d_dim0() {
        // Row 0 feeds both output rows (grad 2); row 1 feeds only the last (grad 1).
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], true);
        let cs = cumsum(&x, 0).unwrap();
        let loss = sum(&cs).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        assert!((gd[0] - 2.0).abs() < 1e-12, "got {}", gd[0]);
        assert!((gd[1] - 2.0).abs() < 1e-12, "got {}", gd[1]);
        assert!((gd[2] - 1.0).abs() < 1e-12, "got {}", gd[2]);
        assert!((gd[3] - 1.0).abs() < 1e-12, "got {}", gd[3]);
    }

    #[test]
    fn test_cumsum_has_grad_fn() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cs = cumsum(&x, 0).unwrap();
        assert!(cs.grad_fn().is_some());
        assert_eq!(cs.grad_fn().unwrap().name(), "CumsumBackward");
    }

    #[test]
    fn test_cumsum_no_grad_fn_when_not_requires_grad() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
        let cs = cumsum(&x, 0).unwrap();
        assert!(cs.grad_fn().is_none());
    }

    #[test]
    fn test_cumsum_no_grad_fn_in_no_grad_context() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cs = no_grad(|| cumsum(&x, 0)).unwrap();
        assert!(cs.grad_fn().is_none());
    }

    // ---- cumprod: forward ----

    #[test]
    fn test_cumprod_1d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
        let cp = cumprod(&x, 0).unwrap();
        let d = cp.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - 6.0).abs() < 1e-12);
        assert!((d[3] - 24.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumprod_2d_dim0() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let cp = cumprod(&x, 0).unwrap();
        let d = cp.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - 3.0).abs() < 1e-12);
        assert!((d[3] - 4.0).abs() < 1e-12);
        assert!((d[4] - 10.0).abs() < 1e-12);
        assert!((d[5] - 18.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumprod_2d_dim1() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let cp = cumprod(&x, 1).unwrap();
        let d = cp.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - 6.0).abs() < 1e-12);
        assert!((d[3] - 4.0).abs() < 1e-12);
        assert!((d[4] - 20.0).abs() < 1e-12);
        assert!((d[5] - 120.0).abs() < 1e-12);
    }

    // ---- cumprod: backward and graph wiring ----

    #[test]
    fn test_cumprod_backward_1d() {
        // sum(cumprod(x)) = x0 + x0*x1 + x0*x1*x2, so the gradient is
        // [1 + x1 + x1*x2, x0 + x0*x2, x0*x1] = [9, 4, 2] for x = [1, 2, 3].
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cp = cumprod(&x, 0).unwrap();
        let loss = sum(&cp).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        assert!((gd[0] - 9.0).abs() < 1e-10, "got {}", gd[0]);
        assert!((gd[1] - 4.0).abs() < 1e-10, "got {}", gd[1]);
        assert!((gd[2] - 2.0).abs() < 1e-10, "got {}", gd[2]);
    }

    #[test]
    fn test_cumprod_backward_with_zero() {
        // Exercises the zero-containing (non-division) path. Using the same
        // formula as above with x = [2, 0, 3] gives [1, 8, 0].
        let x = leaf(&[2.0, 0.0, 3.0], &[3], true);
        let cp = cumprod(&x, 0).unwrap();
        let loss = sum(&cp).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        assert!((gd[0] - 1.0).abs() < 1e-10, "d/dx[0]: got {}", gd[0]);
        assert!((gd[1] - 8.0).abs() < 1e-10, "d/dx[1]: got {}", gd[1]);
        assert!((gd[2] - 0.0).abs() < 1e-10, "d/dx[2]: got {}", gd[2]);
    }

    #[test]
    fn test_cumprod_has_grad_fn() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cp = cumprod(&x, 0).unwrap();
        assert!(cp.grad_fn().is_some());
        assert_eq!(cp.grad_fn().unwrap().name(), "CumprodBackward");
    }

    #[test]
    fn test_cumprod_no_grad_fn_in_no_grad_context() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let cp = no_grad(|| cumprod(&x, 0)).unwrap();
        assert!(cp.grad_fn().is_none());
    }

    // ---- cummax / cummin: forward values and indices ----

    #[test]
    fn test_cummax_1d() {
        let x = leaf(&[3.0, 1.0, 4.0, 1.0, 5.0], &[5], false);
        let r = cummax(&x, 0).unwrap();
        let d = r.values.data().unwrap();
        assert!((d[0] - 3.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
        assert!((d[2] - 4.0).abs() < 1e-12);
        assert!((d[3] - 4.0).abs() < 1e-12);
        assert!((d[4] - 5.0).abs() < 1e-12);
        assert_eq!(r.indices, vec![0, 0, 2, 2, 4]);
    }

    #[test]
    fn test_cummax_2d_dim1() {
        // Indices are positions along dim 1, one per output element, row-major.
        let x = leaf(&[1.0, 3.0, 2.0, 5.0, 4.0, 6.0], &[2, 3], false);
        let r = cummax(&x, 1).unwrap();
        let d = r.values.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
        assert!((d[2] - 3.0).abs() < 1e-12);
        assert!((d[3] - 5.0).abs() < 1e-12);
        assert!((d[4] - 5.0).abs() < 1e-12);
        assert!((d[5] - 6.0).abs() < 1e-12);
        assert_eq!(r.indices, vec![0, 1, 1, 0, 0, 2]);
    }

    #[test]
    fn test_cummin_1d() {
        let x = leaf(&[3.0, 1.0, 4.0, 1.0, 5.0], &[5], false);
        let r = cummin(&x, 0).unwrap();
        let d = r.values.data().unwrap();
        assert!((d[0] - 3.0).abs() < 1e-12);
        assert!((d[1] - 1.0).abs() < 1e-12);
        assert!((d[2] - 1.0).abs() < 1e-12);
        assert!((d[3] - 1.0).abs() < 1e-12);
        assert!((d[4] - 1.0).abs() < 1e-12);
        assert_eq!(r.indices, vec![0, 1, 1, 3, 3]);
    }

    #[test]
    fn test_cummin_2d_dim0() {
        let x = leaf(&[5.0, 2.0, 3.0, 4.0], &[2, 2], false);
        let r = cummin(&x, 0).unwrap();
        let d = r.values.data().unwrap();
        assert!((d[0] - 5.0).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - 3.0).abs() < 1e-12);
        assert!((d[3] - 2.0).abs() < 1e-12);
        assert_eq!(r.indices, vec![0, 0, 1, 0]);
    }

    // ---- logcumsumexp: forward ----

    #[test]
    fn test_logcumsumexp_1d() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
        let lcs = logcumsumexp(&x, 0).unwrap();
        let d = lcs.data().unwrap();
        let expected_0 = 1.0_f64;
        let expected_1 = (1.0_f64.exp() + 2.0_f64.exp()).ln();
        let expected_2 = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp()).ln();
        assert!((d[0] - expected_0).abs() < 1e-10, "got {}", d[0]);
        assert!((d[1] - expected_1).abs() < 1e-10, "got {}", d[1]);
        assert!((d[2] - expected_2).abs() < 1e-10, "got {}", d[2]);
    }

    #[test]
    fn test_logcumsumexp_2d_dim1() {
        let x = leaf(&[0.0, 1.0, 2.0, 3.0], &[2, 2], false);
        let lcs = logcumsumexp(&x, 1).unwrap();
        let d = lcs.data().unwrap();
        let e0 = 0.0_f64;
        let e1 = (0.0_f64.exp() + 1.0_f64.exp()).ln();
        let e2 = 2.0_f64;
        let e3 = (2.0_f64.exp() + 3.0_f64.exp()).ln();
        assert!((d[0] - e0).abs() < 1e-10, "got {}", d[0]);
        assert!((d[1] - e1).abs() < 1e-10, "got {}", d[1]);
        assert!((d[2] - e2).abs() < 1e-10, "got {}", d[2]);
        assert!((d[3] - e3).abs() < 1e-10, "got {}", d[3]);
    }

    #[test]
    fn test_logcumsumexp_numerical_stability() {
        // Naive exp(1000) overflows in f64. This checks that the forward stays
        // finite and accurate; it does not exercise the backward.
        let x = leaf(&[1000.0, 1001.0, 1002.0], &[3], false);
        let lcs = logcumsumexp(&x, 0).unwrap();
        let d = lcs.data().unwrap();
        for &v in d {
            assert!(v.is_finite(), "got non-finite: {v}");
        }
        assert!((d[0] - 1000.0).abs() < 1e-10);
        let expected_1 = 1001.0 + ((-1.0_f64).exp() + 1.0).ln();
        assert!((d[1] - expected_1).abs() < 1e-8, "got {}", d[1]);
    }

    // ---- logcumsumexp: backward and graph wiring ----

    #[test]
    fn test_logcumsumexp_backward_1d() {
        // Compare the analytic gradient against central finite differences.
        let x_vals = [1.0_f64, 2.0, 3.0];
        let eps = 1e-6;
        let x = leaf(&x_vals, &[3], true);
        let lcs = logcumsumexp(&x, 0).unwrap();
        let loss = sum(&lcs).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        for idx in 0..3 {
            let mut x_plus = x_vals.to_vec();
            let mut x_minus = x_vals.to_vec();
            x_plus[idx] += eps;
            x_minus[idx] -= eps;
            let tp = leaf(&x_plus, &[3], false);
            let lp = logcumsumexp(&tp, 0).unwrap();
            let sp = sum(&lp).unwrap().item().unwrap();
            let tm = leaf(&x_minus, &[3], false);
            let lm = logcumsumexp(&tm, 0).unwrap();
            let sm = sum(&lm).unwrap().item().unwrap();
            let numerical = (sp - sm) / (2.0 * eps);
            assert!(
                (gd[idx] - numerical).abs() < 1e-4,
                "index {idx}: analytic={}, numerical={}",
                gd[idx],
                numerical,
            );
        }
    }

    #[test]
    fn test_logcumsumexp_has_grad_fn() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let lcs = logcumsumexp(&x, 0).unwrap();
        assert!(lcs.grad_fn().is_some());
        assert_eq!(lcs.grad_fn().unwrap().name(), "LogcumsumexpBackward");
    }

    #[test]
    fn test_logcumsumexp_no_grad_fn_in_no_grad_context() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let lcs = no_grad(|| logcumsumexp(&x, 0)).unwrap();
        assert!(lcs.grad_fn().is_none());
    }

    // ---- 0-d (scalar) passthrough ----

    // Fixed scalar inputs and the expected index for the 0-d passthrough tests.
    const PT_SCALAR_5_0: f64 = 5.0;
    const PT_SCALAR_NEG_3_5: f64 = -3.5;
    const PT_SCALAR_INDEX_0: usize = 0;

    #[test]
    fn test_cumsum_scalar_passthrough() {
        // Both dim = 0 and dim = -1 are accepted on a 0-d tensor.
        let x =
            Tensor::from_storage(TensorStorage::cpu(vec![PT_SCALAR_5_0]), vec![], false).unwrap();
        let r = cumsum(&x, 0).unwrap();
        assert_eq!(r.shape(), &[] as &[usize]);
        assert!((r.item().unwrap() - PT_SCALAR_5_0).abs() < 1e-12);
        let r_neg = cumsum(&x, -1).unwrap();
        assert!((r_neg.item().unwrap() - PT_SCALAR_5_0).abs() < 1e-12);
    }

    #[test]
    fn test_cumprod_scalar_passthrough() {
        let x = Tensor::from_storage(TensorStorage::cpu(vec![PT_SCALAR_NEG_3_5]), vec![], false)
            .unwrap();
        let r = cumprod(&x, 0).unwrap();
        assert_eq!(r.shape(), &[] as &[usize]);
        assert!((r.item().unwrap() - PT_SCALAR_NEG_3_5).abs() < 1e-12);
        let r_neg = cumprod(&x, -1).unwrap();
        assert!((r_neg.item().unwrap() - PT_SCALAR_NEG_3_5).abs() < 1e-12);
    }

    #[test]
    fn test_cummax_scalar_passthrough() {
        let x =
            Tensor::from_storage(TensorStorage::cpu(vec![PT_SCALAR_5_0]), vec![], false).unwrap();
        let r = cummax(&x, 0).unwrap();
        assert_eq!(r.values.shape(), &[] as &[usize]);
        assert!((r.values.item().unwrap() - PT_SCALAR_5_0).abs() < 1e-12);
        assert_eq!(r.indices, vec![PT_SCALAR_INDEX_0]);
    }

    #[test]
    fn test_cummin_scalar_passthrough() {
        let x = Tensor::from_storage(TensorStorage::cpu(vec![PT_SCALAR_NEG_3_5]), vec![], false)
            .unwrap();
        let r = cummin(&x, 0).unwrap();
        assert!((r.values.item().unwrap() - PT_SCALAR_NEG_3_5).abs() < 1e-12);
        assert_eq!(r.indices, vec![PT_SCALAR_INDEX_0]);
    }

    #[test]
    fn test_logcumsumexp_scalar_passthrough() {
        let x =
            Tensor::from_storage(TensorStorage::cpu(vec![PT_SCALAR_5_0]), vec![], false).unwrap();
        let r = logcumsumexp(&x, 0).unwrap();
        assert!((r.item().unwrap() - PT_SCALAR_5_0).abs() < 1e-12);
    }

    #[test]
    fn test_cumsum_scalar_dim_out_of_range() {
        // Only 0 and -1 are valid axes for a 0-d tensor.
        let x = Tensor::from_storage(TensorStorage::cpu(vec![1.0_f64]), vec![], false).unwrap();
        assert!(cumsum(&x, 1).is_err());
        assert!(cumsum(&x, -2).is_err());
    }

    #[test]
    fn test_cummax_scalar_dim_out_of_range() {
        // Covers both cummax and cummin on the 0-d path.
        let x = Tensor::from_storage(TensorStorage::cpu(vec![1.0_f64]), vec![], false).unwrap();
        assert!(cummax(&x, 1).is_err());
        assert!(cummin(&x, -2).is_err());
    }

    #[test]
    fn test_cumsum_scalar_backward_is_identity() {
        let x = leaf(&[PT_SCALAR_5_0], &[], true);
        let cs = cumsum(&x, 0).unwrap();
        let loss = sum(&cs).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert!((g.item().unwrap() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_cumprod_scalar_backward_is_identity() {
        let x = leaf(&[PT_SCALAR_NEG_3_5], &[], true);
        let cp = cumprod(&x, 0).unwrap();
        let loss = sum(&cp).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert!((g.item().unwrap() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_logcumsumexp_scalar_backward_is_identity() {
        let x = leaf(&[PT_SCALAR_5_0], &[], true);
        let lcs = logcumsumexp(&x, 0).unwrap();
        let loss = sum(&lcs).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert!((g.item().unwrap() - 1.0).abs() < 1e-12);
    }

    // ---- argument validation and finite-difference gradient checks ----

    #[test]
    fn test_cumsum_dim_out_of_bounds() {
        let x = leaf(&[1.0, 2.0], &[2], false);
        assert!(cumsum(&x, 1).is_err());
        assert!(cumsum(&x, -2).is_err());
    }

    #[test]
    fn test_cumprod_backward_numerical() {
        // No zeros in the input, so this exercises the division-based path.
        let x_vals = [2.0_f64, 3.0, 0.5];
        let eps = 1e-6;
        let x = leaf(&x_vals, &[3], true);
        let cp = cumprod(&x, 0).unwrap();
        let loss = sum(&cp).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        for idx in 0..3 {
            let mut x_plus = x_vals.to_vec();
            let mut x_minus = x_vals.to_vec();
            x_plus[idx] += eps;
            x_minus[idx] -= eps;
            let tp = leaf(&x_plus, &[3], false);
            let fp = sum(&cumprod(&tp, 0).unwrap()).unwrap().item().unwrap();
            let tm = leaf(&x_minus, &[3], false);
            let fm = sum(&cumprod(&tm, 0).unwrap()).unwrap().item().unwrap();
            let numerical = (fp - fm) / (2.0 * eps);
            assert!(
                (gd[idx] - numerical).abs() < 1e-4,
                "index {idx}: analytic={}, numerical={}",
                gd[idx],
                numerical,
            );
        }
    }

    #[test]
    fn test_cumsum_backward_numerical() {
        let x_vals = [1.0_f64, -2.0, 3.5, 0.7];
        let eps = 1e-6;
        let x = leaf(&x_vals, &[4], true);
        let cs = cumsum(&x, 0).unwrap();
        let loss = sum(&cs).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        for idx in 0..4 {
            let mut x_plus = x_vals.to_vec();
            let mut x_minus = x_vals.to_vec();
            x_plus[idx] += eps;
            x_minus[idx] -= eps;
            let tp = leaf(&x_plus, &[4], false);
            let fp = sum(&cumsum(&tp, 0).unwrap()).unwrap().item().unwrap();
            let tm = leaf(&x_minus, &[4], false);
            let fm = sum(&cumsum(&tm, 0).unwrap()).unwrap().item().unwrap();
            let numerical = (fp - fm) / (2.0 * eps);
            assert!(
                (gd[idx] - numerical).abs() < 1e-4,
                "index {idx}: analytic={}, numerical={}",
                gd[idx],
                numerical,
            );
        }
    }

    // ---- cummax / cummin: backward, ties and NaN ----

    #[test]
    fn test_cummax_backward_monotonic() {
        // Strictly increasing input: every element is its own running max,
        // so each input position receives exactly one unit of gradient.
        const EXPECTED: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], true);
        let r = cummax(&x, 0).unwrap();
        assert_eq!(r.indices, vec![0, 1, 2, 3]);
        let loss = sum(&r.values).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        for (i, expected) in EXPECTED.iter().enumerate() {
            assert!(
                (gd[i] - expected).abs() < 1e-12,
                "cummax_backward_monotonic: idx={i} got={} expected={}",
                gd[i],
                expected,
            );
        }
        assert_eq!(r.values.grad_fn().unwrap().name(), "CummaxBackward");
    }

    #[test]
    fn test_cummax_backward_tie() {
        // On a tie (x[1] == x[2]) the later index is recorded, so the gradient
        // is split one-per-position rather than piling onto the first occurrence.
        const EXPECTED_INDICES: [usize; 4] = [0, 1, 2, 3];
        const EXPECTED_GRAD: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let x = leaf(&[1.0, 2.0, 2.0, 3.0], &[4], true);
        let r = cummax(&x, 0).unwrap();
        assert_eq!(r.indices, EXPECTED_INDICES.to_vec());
        let loss = sum(&r.values).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        for (i, expected) in EXPECTED_GRAD.iter().enumerate() {
            assert!(
                (gd[i] - expected).abs() < 1e-12,
                "cummax_backward_tie: idx={i} got={} expected={}",
                gd[i],
                expected,
            );
        }
    }

    #[test]
    fn test_cummin_backward_tie() {
        // Mirror of the cummax tie case: the later index wins for cummin too.
        const EXPECTED_INDICES: [usize; 4] = [0, 1, 2, 3];
        const EXPECTED_GRAD: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let x = leaf(&[5.0, 2.0, 2.0, 1.0], &[4], true);
        let r = cummin(&x, 0).unwrap();
        assert_eq!(r.indices, EXPECTED_INDICES.to_vec());
        let loss = sum(&r.values).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        for (i, expected) in EXPECTED_GRAD.iter().enumerate() {
            assert!(
                (gd[i] - expected).abs() < 1e-12,
                "cummin_backward_tie: idx={i} got={} expected={}",
                gd[i],
                expected,
            );
        }
        assert_eq!(r.values.grad_fn().unwrap().name(), "CumminBackward");
    }

    #[test]
    fn test_cummax_forward_nan_propagates() {
        // Once a NaN appears it is carried forward, and its index is kept.
        const EXPECTED_INDICES: [usize; 4] = [0, 1, 1, 1];
        let x = leaf(&[1.0, f64::NAN, 3.0, 4.0], &[4], false);
        let r = cummax(&x, 0).unwrap();
        let d = r.values.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-12, "values[0] should be 1.0");
        assert!(d[1].is_nan(), "values[1] should be NaN (input is NaN)");
        assert!(d[2].is_nan(), "values[2] should propagate NaN");
        assert!(d[3].is_nan(), "values[3] should propagate NaN");
        assert_eq!(r.indices, EXPECTED_INDICES.to_vec());
    }
}