use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::ops::elementwise;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Backward node for the full-tensor `sum` reduction.
///
/// Stores a clone of the forward input; its shape, element count and device
/// determine how the scalar upstream gradient is broadcast back.
#[derive(Debug)]
pub struct SumBackward<T: Float> {
    /// Forward input captured at op time.
    input: Tensor<T>,
}
impl<T: Float> GradFn<T> for SumBackward<T> {
    /// Gradient of a full reduction: every input element receives the scalar
    /// upstream gradient unchanged (d(sum)/dx_i = 1).
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // Pull the scalar upstream gradient onto the host first.
        let upstream = if grad_output.is_cuda() {
            let host = grad_output.cpu()?;
            host.data()?[0]
        } else {
            grad_output.data()?[0]
        };
        // Broadcast it across the input's full shape, then move the gradient
        // back to whichever device the input lives on.
        let filled = vec![upstream; self.input.numel()];
        let host_grad = Tensor::from_storage(
            TensorStorage::cpu(filled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(host_grad.to(self.input.device())?)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "SumBackward"
    }
}
/// Sum all elements of `input` into a scalar (0-D) tensor.
///
/// CUDA inputs use the GPU backend's reduction kernel; CPU inputs defer to
/// `elementwise::sum`. A `SumBackward` node is attached when gradients are
/// enabled and the input requires them.
pub fn sum<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    // Decide once whether this call participates in autograd.
    let track = is_grad_enabled() && input.requires_grad();
    if input.is_cuda() {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let handle = backend.sum_f32(input.gpu_handle()?, input.numel())?;
        let storage = TensorStorage::gpu(handle);
        if track {
            Tensor::from_operation(
                storage,
                vec![],
                Arc::new(SumBackward {
                    input: input.clone(),
                }),
            )
        } else {
            Tensor::from_storage(storage, vec![], false)
        }
    } else {
        let forward = elementwise::sum(input)?;
        if track {
            let (storage, shape) = forward.into_storage_and_shape()?;
            Tensor::from_operation(
                storage,
                shape,
                Arc::new(SumBackward {
                    input: input.clone(),
                }),
            )
        } else {
            Ok(forward)
        }
    }
}
/// Backward node for the full-tensor `mean` reduction.
///
/// Stores a clone of the forward input; its element count gives the divisor
/// for the backward pass, and its shape/device define the gradient layout.
#[derive(Debug)]
pub struct MeanBackward<T: Float> {
    /// Forward input captured at op time.
    input: Tensor<T>,
}
impl<T: Float> GradFn<T> for MeanBackward<T> {
    /// Gradient of the mean: d(mean)/dx_i = 1/N, so each input element
    /// receives `upstream / N`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // Read the scalar upstream gradient on the host.
        let upstream = if grad_output.is_cuda() {
            let host = grad_output.cpu()?;
            host.data()?[0]
        } else {
            grad_output.data()?[0]
        };
        let count = self.input.numel();
        let per_element = upstream / T::from(count).unwrap();
        let host_grad = Tensor::from_storage(
            TensorStorage::cpu(vec![per_element; count]),
            self.input.shape().to_vec(),
            false,
        )?;
        // Return the gradient on the input's own device.
        Ok(vec![Some(host_grad.to(self.input.device())?)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "MeanBackward"
    }
}
/// Arithmetic mean of all elements, reducing to a scalar (0-D) tensor.
///
/// Delegates the forward computation to `elementwise::mean` and attaches a
/// `MeanBackward` node when gradient tracking applies.
pub fn mean<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = elementwise::mean(input)?;
    // No autograd bookkeeping needed: hand the forward result back as-is.
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(forward);
    }
    let (storage, shape) = forward.into_storage_and_shape()?;
    Tensor::from_operation(
        storage,
        shape,
        Arc::new(MeanBackward {
            input: input.clone(),
        }),
    )
}
/// Backward node for the full-tensor `prod` reduction.
///
/// Stores a clone of the forward input: the gradient of a product w.r.t.
/// each element is the product of all *other* elements, so the backward pass
/// needs the original values (not just the shape).
#[derive(Debug)]
pub struct ProdBackward<T: Float> {
    /// Forward input captured at op time.
    input: Tensor<T>,
}
impl<T: Float> GradFn<T> for ProdBackward<T> {
    /// Gradient of a full product: d(prod)/dx_i is the product of every
    /// element except x_i, computed in O(n) with a prefix sweep followed by
    /// a suffix sweep over a single buffer. This stays correct when the
    /// input contains zeros (no division involved).
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // Scalar upstream gradient, read on the host.
        let upstream = if grad_output.is_cuda() {
            let host = grad_output.cpu()?;
            host.data()?[0]
        } else {
            grad_output.data()?[0]
        };
        // The forward values themselves, also on the host.
        let host_input = if self.input.is_cuda() {
            self.input.cpu()?
        } else {
            self.input.clone()
        };
        let values = host_input.data()?;
        let len = values.len();
        let one = <T as num_traits::One>::one();
        // Pass 1: grad[i] = product of values[..i] (exclusive prefix).
        let mut grad = vec![one; len];
        let mut running = one;
        for i in 0..len {
            grad[i] = running;
            running = running * values[i];
        }
        // Pass 2: fold in the exclusive suffix product and the upstream
        // gradient, preserving the (go * prefix) * suffix multiply order.
        let mut running = one;
        for i in (0..len).rev() {
            grad[i] = upstream * grad[i] * running;
            running = running * values[i];
        }
        let host_grad = Tensor::from_storage(
            TensorStorage::cpu(grad),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(host_grad.to(self.input.device())?)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "ProdBackward"
    }
}
/// Product of all elements, reducing to a scalar (0-D) tensor.
///
/// The reduction always runs on the host (a CUDA input is copied to CPU
/// first); the result is then moved back to the input's device so callers
/// get the same device in/device out behavior as `sum_dim`/`mean_dim`.
/// Attaches a `ProdBackward` node when gradient tracking applies.
pub fn prod<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let cpu_input = if input.is_cuda() {
        input.cpu()?
    } else {
        input.clone()
    };
    let data = cpu_input.data()?;
    // Empty input folds to the multiplicative identity (1), matching the
    // usual empty-product convention.
    let total = data
        .iter()
        .copied()
        .fold(<T as num_traits::One>::one(), |a, b| a * b);
    let result = Tensor::from_storage(TensorStorage::cpu(vec![total]), vec![], false)?;
    if is_grad_enabled() && input.requires_grad() {
        let grad_fn = Arc::new(ProdBackward {
            input: input.clone(),
        });
        let (storage, shape) = result.into_storage_and_shape()?;
        let result = Tensor::from_operation(storage, shape, grad_fn)?;
        // Fix: previously the result stayed on CPU even for CUDA inputs;
        // move it back like the other reductions do.
        result.to(input.device())
    } else {
        result.to(input.device())
    }
}
/// Backward node for `sum_dim` (sum along a single dimension).
#[derive(Debug)]
pub struct SumDimBackward<T: Float> {
    /// Forward input; defines the shape/device the gradient is expanded to.
    input: Tensor<T>,
    /// Reduced dimension, already normalized to [0, ndim).
    dim: usize,
    /// Whether the forward pass kept the reduced dim with extent 1.
    keepdim: bool,
}
impl<T: Float> GradFn<T> for SumDimBackward<T> {
    /// Gradient of a one-axis sum: the upstream gradient is broadcast back
    /// along the reduced dimension — every input element that contributed to
    /// an output element receives that output's gradient unchanged.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let input_shape = self.input.shape();
        // If the forward pass dropped the reduced dim, re-insert it with
        // extent 1 so the index math below handles both cases uniformly.
        let grad = if self.keepdim {
            grad_output.clone()
        } else {
            let go_cpu = if grad_output.is_cuda() {
                grad_output.cpu()?
            } else {
                grad_output.clone()
            };
            let mut unsqueezed_shape = go_cpu.shape().to_vec();
            unsqueezed_shape.insert(self.dim, 1);
            let data = go_cpu.data()?.to_vec();
            Tensor::from_storage(TensorStorage::cpu(data), unsqueezed_shape, false)?
        };
        let grad_cpu = if grad.is_cuda() { grad.cpu()? } else { grad };
        let grad_data = grad_cpu.data()?;
        let grad_shape = grad_cpu.shape();
        let out_numel: usize = input_shape.iter().product();
        let mut result = Vec::with_capacity(out_numel);
        // Coordinate scratch buffer hoisted out of the loop (it is fully
        // overwritten every iteration) to avoid one heap allocation per
        // element.
        let mut coords = vec![0usize; input_shape.len()];
        for flat in 0..out_numel {
            // Decode the flat input index into row-major coordinates.
            let mut rem = flat;
            for d in (0..input_shape.len()).rev() {
                coords[d] = rem % input_shape[d];
                rem /= input_shape[d];
            }
            // Re-encode against the keepdim-shaped gradient, pinning the
            // reduced axis to 0 (its extent there is 1).
            let mut grad_flat = 0usize;
            let mut stride = 1usize;
            for d in (0..grad_shape.len()).rev() {
                let c = if d == self.dim { 0 } else { coords[d] };
                grad_flat += c * stride;
                stride *= grad_shape[d];
            }
            result.push(grad_data[grad_flat]);
        }
        let grad_cpu =
            Tensor::from_storage(TensorStorage::cpu(result), input_shape.to_vec(), false)?;
        let grad_input = grad_cpu.to(self.input.device())?;
        Ok(vec![Some(grad_input)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "SumDimBackward"
    }
}
/// Sum `input` along one dimension.
///
/// `dim` may be negative (counted from the end, PyTorch-style). With
/// `keepdim` the reduced axis is retained with extent 1; otherwise it is
/// removed from the output shape. Returns `InvalidArgument` for 0-D inputs
/// or out-of-range dims. CUDA inputs are reduced on the host and the result
/// is moved back to the input's device.
pub fn sum_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "sum_dim: cannot reduce a scalar (0-D) tensor along a dimension".into(),
        });
    }
    // Normalize a negative dim into [0, ndim). A dim below -ndim wraps to a
    // huge usize and is rejected by the bounds check below.
    let norm_dim = if dim < 0 {
        (ndim as i64 + dim) as usize
    } else {
        dim as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "sum_dim: dim {} is out of bounds for tensor with {} dimensions",
                dim, ndim
            ),
        });
    }
    let input_cpu = if input.is_cuda() {
        input.cpu()?
    } else {
        input.clone()
    };
    let in_data = input_cpu.data()?;
    let in_shape = input_cpu.shape();
    let mut out_shape: Vec<usize> = in_shape.to_vec();
    if keepdim {
        out_shape[norm_dim] = 1;
    } else {
        out_shape.remove(norm_dim);
    }
    // Accumulate into a keepdim-shaped buffer (reduced axis pinned to
    // extent 1); the requested output shape is applied at construction.
    let mut accum_shape: Vec<usize> = in_shape.to_vec();
    accum_shape[norm_dim] = 1;
    let accum_numel: usize = accum_shape.iter().product();
    let mut accum = vec![<T as num_traits::Zero>::zero(); accum_numel];
    // Coordinate scratch buffer hoisted out of the loop (fully overwritten
    // each iteration) to avoid one heap allocation per element. The former
    // `.take(input.numel())` was a no-op (in_data is exactly the input's
    // data) and has been dropped.
    let mut coords = vec![0usize; in_shape.len()];
    for (flat, &val) in in_data.iter().enumerate() {
        // Decode the flat index into row-major coordinates.
        let mut rem = flat;
        for d in (0..in_shape.len()).rev() {
            coords[d] = rem % in_shape[d];
            rem /= in_shape[d];
        }
        // Re-encode against the accumulator shape with the reduced axis at 0.
        let mut oi = 0usize;
        let mut os = 1usize;
        for d in (0..accum_shape.len()).rev() {
            let c = if d == norm_dim { 0 } else { coords[d] };
            oi += c * os;
            os *= accum_shape[d];
        }
        accum[oi] += val;
    }
    if is_grad_enabled() && input.requires_grad() {
        let grad_fn = Arc::new(SumDimBackward {
            input: input.clone(),
            dim: norm_dim,
            keepdim,
        });
        let result = Tensor::from_operation(TensorStorage::cpu(accum), out_shape, grad_fn)?;
        result.to(input.device())
    } else {
        let result = Tensor::from_storage(TensorStorage::cpu(accum), out_shape, false)?;
        result.to(input.device())
    }
}
/// Backward node for `mean_dim` (mean along a single dimension).
#[derive(Debug)]
pub struct MeanDimBackward<T: Float> {
    /// Forward input; defines the shape/device the gradient is expanded to,
    /// and the reduced axis length used as the divisor.
    input: Tensor<T>,
    /// Reduced dimension, already normalized to [0, ndim).
    dim: usize,
    /// Whether the forward pass kept the reduced dim with extent 1.
    keepdim: bool,
}
impl<T: Float> GradFn<T> for MeanDimBackward<T> {
    /// Gradient of a one-axis mean: the upstream gradient is broadcast back
    /// along the reduced dimension and scaled by 1/len(dim).
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let input_shape = self.input.shape();
        let dim_size = input_shape[self.dim];
        let n = T::from(dim_size).unwrap();
        // If the forward pass dropped the reduced dim, re-insert it with
        // extent 1 so the index math below handles both cases uniformly.
        let grad = if self.keepdim {
            grad_output.clone()
        } else {
            let go_cpu = if grad_output.is_cuda() {
                grad_output.cpu()?
            } else {
                grad_output.clone()
            };
            let mut unsqueezed_shape = go_cpu.shape().to_vec();
            unsqueezed_shape.insert(self.dim, 1);
            let data = go_cpu.data()?.to_vec();
            Tensor::from_storage(TensorStorage::cpu(data), unsqueezed_shape, false)?
        };
        let grad_cpu = if grad.is_cuda() { grad.cpu()? } else { grad };
        let grad_data = grad_cpu.data()?;
        let grad_shape = grad_cpu.shape();
        let out_numel: usize = input_shape.iter().product();
        let mut result = Vec::with_capacity(out_numel);
        // Coordinate scratch buffer hoisted out of the loop (fully
        // overwritten every iteration) to avoid one heap allocation per
        // element.
        let mut coords = vec![0usize; input_shape.len()];
        for flat in 0..out_numel {
            // Decode the flat input index into row-major coordinates.
            let mut rem = flat;
            for d in (0..input_shape.len()).rev() {
                coords[d] = rem % input_shape[d];
                rem /= input_shape[d];
            }
            // Re-encode against the keepdim-shaped gradient, pinning the
            // reduced axis to 0 (its extent there is 1).
            let mut grad_flat = 0usize;
            let mut stride = 1usize;
            for d in (0..grad_shape.len()).rev() {
                let c = if d == self.dim { 0 } else { coords[d] };
                grad_flat += c * stride;
                stride *= grad_shape[d];
            }
            result.push(grad_data[grad_flat] / n);
        }
        let grad_cpu =
            Tensor::from_storage(TensorStorage::cpu(result), input_shape.to_vec(), false)?;
        let grad_input = grad_cpu.to(self.input.device())?;
        Ok(vec![Some(grad_input)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "MeanDimBackward"
    }
}
/// Mean of `input` along one dimension.
///
/// `dim` may be negative (counted from the end, PyTorch-style). With
/// `keepdim` the reduced axis is retained with extent 1; otherwise it is
/// removed from the output shape. Returns `InvalidArgument` for 0-D inputs
/// or out-of-range dims. CUDA inputs are reduced on the host and the result
/// is moved back to the input's device.
pub fn mean_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "mean_dim: cannot reduce a scalar (0-D) tensor along a dimension".into(),
        });
    }
    // Normalize a negative dim into [0, ndim). A dim below -ndim wraps to a
    // huge usize and is rejected by the bounds check below.
    let norm_dim = if dim < 0 {
        (ndim as i64 + dim) as usize
    } else {
        dim as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "mean_dim: dim {} is out of bounds for tensor with {} dimensions",
                dim, ndim
            ),
        });
    }
    let input_cpu = if input.is_cuda() {
        input.cpu()?
    } else {
        input.clone()
    };
    let in_data = input_cpu.data()?;
    let in_shape = input_cpu.shape();
    let dim_size = in_shape[norm_dim];
    let n = T::from(dim_size).unwrap();
    let mut out_shape: Vec<usize> = in_shape.to_vec();
    if keepdim {
        out_shape[norm_dim] = 1;
    } else {
        out_shape.remove(norm_dim);
    }
    // Accumulate sums into a keepdim-shaped buffer, then divide by the
    // reduced axis length to get the mean.
    let mut accum_shape: Vec<usize> = in_shape.to_vec();
    accum_shape[norm_dim] = 1;
    let accum_numel: usize = accum_shape.iter().product();
    let mut accum = vec![<T as num_traits::Zero>::zero(); accum_numel];
    // Coordinate scratch buffer hoisted out of the loop (fully overwritten
    // each iteration) to avoid one heap allocation per element. The former
    // `.take(input.numel())` was a no-op (in_data is exactly the input's
    // data) and has been dropped.
    let mut coords = vec![0usize; in_shape.len()];
    for (flat, &val) in in_data.iter().enumerate() {
        // Decode the flat index into row-major coordinates.
        let mut rem = flat;
        for d in (0..in_shape.len()).rev() {
            coords[d] = rem % in_shape[d];
            rem /= in_shape[d];
        }
        // Re-encode against the accumulator shape with the reduced axis at 0.
        let mut oi = 0usize;
        let mut os = 1usize;
        for d in (0..accum_shape.len()).rev() {
            let c = if d == norm_dim { 0 } else { coords[d] };
            oi += c * os;
            os *= accum_shape[d];
        }
        accum[oi] += val;
    }
    for v in &mut accum {
        *v = *v / n;
    }
    if is_grad_enabled() && input.requires_grad() {
        let grad_fn = Arc::new(MeanDimBackward {
            input: input.clone(),
            dim: norm_dim,
            keepdim,
        });
        let result = Tensor::from_operation(TensorStorage::cpu(accum), out_shape, grad_fn)?;
        result.to(input.device())
    } else {
        let result = Tensor::from_storage(TensorStorage::cpu(accum), out_shape, false)?;
        result.to(input.device())
    }
}
// Unit tests: forward values, backward gradients (including prod's
// zero-handling), autograd graph wiring (grad_fn presence / no_grad), and a
// central-difference numerical gradient check for the scalar reductions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::no_grad::no_grad;
    use crate::storage::TensorStorage;

    /// Build a CPU leaf tensor with the given data and shape.
    fn leaf(data: &[f64], shape: &[usize], requires_grad: bool) -> Tensor<f64> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            shape.to_vec(),
            requires_grad,
        )
        .unwrap()
    }

    /// Build a 0-D (scalar) CPU leaf tensor.
    fn leaf_scalar(val: f64, requires_grad: bool) -> Tensor<f64> {
        Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], requires_grad).unwrap()
    }

    // ---- forward values -------------------------------------------------

    #[test]
    fn test_sum_forward_1d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
        let s = sum(&x).unwrap();
        assert!(s.is_scalar());
        assert!((s.item().unwrap() - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_sum_forward_2d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let s = sum(&x).unwrap();
        assert!((s.item().unwrap() - 21.0).abs() < 1e-12);
    }

    #[test]
    fn test_mean_forward() {
        let x = leaf(&[2.0, 4.0, 6.0, 8.0], &[4], false);
        let m = mean(&x).unwrap();
        assert!((m.item().unwrap() - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_prod_forward() {
        let x = leaf(&[2.0, 3.0, 4.0], &[3], false);
        let p = prod(&x).unwrap();
        assert!((p.item().unwrap() - 24.0).abs() < 1e-12);
    }

    #[test]
    fn test_prod_forward_scalar() {
        let x = leaf_scalar(7.0, false);
        let p = prod(&x).unwrap();
        assert!((p.item().unwrap() - 7.0).abs() < 1e-12);
    }

    #[test]
    fn test_prod_forward_with_zero() {
        let x = leaf(&[3.0, 0.0, 5.0], &[3], false);
        let p = prod(&x).unwrap();
        assert!((p.item().unwrap()).abs() < 1e-12);
    }

    // ---- backward gradients --------------------------------------------

    #[test]
    fn test_sum_backward_scalar_input() {
        let x = leaf_scalar(5.0, true);
        let s = sum(&x).unwrap();
        s.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert!((g.item().unwrap() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_sum_backward_1d() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let s = sum(&x).unwrap();
        s.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        assert_eq!(gd.len(), 3);
        // d(sum)/dx_i = 1 everywhere.
        for &v in gd {
            assert!((v - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn test_sum_backward_2d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let s = sum(&x).unwrap();
        s.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert_eq!(g.shape(), &[2, 3]);
        for &v in g.data().unwrap() {
            assert!((v - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn test_mean_backward_scalar_input() {
        let x = leaf_scalar(5.0, true);
        let m = mean(&x).unwrap();
        m.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert!((g.item().unwrap() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_mean_backward_1d() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let m = mean(&x).unwrap();
        m.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        // d(mean)/dx_i = 1/N.
        let expected = 1.0 / 3.0;
        for &v in gd {
            assert!((v - expected).abs() < 1e-12);
        }
    }

    #[test]
    fn test_mean_backward_2d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let m = mean(&x).unwrap();
        m.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert_eq!(g.shape(), &[2, 3]);
        let expected = 1.0 / 6.0;
        for &v in g.data().unwrap() {
            assert!((v - expected).abs() < 1e-12);
        }
    }

    #[test]
    fn test_prod_backward_scalar_input() {
        let x = leaf_scalar(5.0, true);
        let p = prod(&x).unwrap();
        p.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert!((g.item().unwrap() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_prod_backward_1d() {
        let x = leaf(&[2.0, 3.0, 4.0], &[3], true);
        let p = prod(&x).unwrap();
        p.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        // d(prod)/dx_i = product of all other elements.
        assert!(
            (gd[0] - 12.0).abs() < 1e-12,
            "d/da = 3*4 = 12, got {}",
            gd[0]
        );
        assert!((gd[1] - 8.0).abs() < 1e-12, "d/db = 2*4 = 8, got {}", gd[1]);
        assert!((gd[2] - 6.0).abs() < 1e-12, "d/dc = 2*3 = 6, got {}", gd[2]);
    }

    #[test]
    fn test_prod_backward_with_zero() {
        // Only the zero element has a nonzero gradient (product of the rest).
        let x = leaf(&[3.0, 0.0, 5.0], &[3], true);
        let p = prod(&x).unwrap();
        p.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        assert!((gd[0] - 0.0).abs() < 1e-12, "got {}", gd[0]);
        assert!((gd[1] - 15.0).abs() < 1e-12, "got {}", gd[1]);
        assert!((gd[2] - 0.0).abs() < 1e-12, "got {}", gd[2]);
    }

    #[test]
    fn test_prod_backward_two_zeros() {
        // With two zeros every partial product (and thus every gradient) is 0.
        let x = leaf(&[0.0, 0.0, 5.0], &[3], true);
        let p = prod(&x).unwrap();
        p.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        let gd = g.data().unwrap();
        for &v in gd {
            assert!((v).abs() < 1e-12, "expected 0, got {v}");
        }
    }

    // ---- autograd graph wiring -----------------------------------------

    #[test]
    fn test_sum_no_grad_fn_when_input_not_requires_grad() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
        let s = sum(&x).unwrap();
        assert!(s.grad_fn().is_none());
        assert!(!s.requires_grad());
    }

    #[test]
    fn test_sum_has_grad_fn_when_input_requires_grad() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let s = sum(&x).unwrap();
        assert!(s.grad_fn().is_some());
        assert_eq!(s.grad_fn().unwrap().name(), "SumBackward");
        assert!(s.requires_grad());
    }

    #[test]
    fn test_mean_has_grad_fn_when_input_requires_grad() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let m = mean(&x).unwrap();
        assert!(m.grad_fn().is_some());
        assert_eq!(m.grad_fn().unwrap().name(), "MeanBackward");
    }

    #[test]
    fn test_prod_has_grad_fn_when_input_requires_grad() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let p = prod(&x).unwrap();
        assert!(p.grad_fn().is_some());
        assert_eq!(p.grad_fn().unwrap().name(), "ProdBackward");
    }

    #[test]
    fn test_sum_no_grad_fn_in_no_grad_context() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let s = no_grad(|| sum(&x)).unwrap();
        assert!(s.grad_fn().is_none());
        assert!(!s.requires_grad());
    }

    #[test]
    fn test_mean_no_grad_fn_in_no_grad_context() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let m = no_grad(|| mean(&x)).unwrap();
        assert!(m.grad_fn().is_none());
    }

    #[test]
    fn test_prod_no_grad_fn_in_no_grad_context() {
        let x = leaf(&[2.0, 3.0], &[2], true);
        let p = no_grad(|| prod(&x)).unwrap();
        assert!(p.grad_fn().is_none());
    }

    /// Central-difference check: compare an analytic gradient at `x_val`
    /// against (f(x+eps) - f(x-eps)) / (2*eps) within `tol`.
    fn numerical_grad_check(
        f: impl Fn(&Tensor<f64>) -> FerrotorchResult<Tensor<f64>>,
        x_val: f64,
        expected_analytic: f64,
        tol: f64,
    ) {
        let eps = 1e-7;
        let x_plus = leaf_scalar(x_val + eps, false);
        let x_minus = leaf_scalar(x_val - eps, false);
        let f_plus = f(&x_plus).unwrap().item().unwrap();
        let f_minus = f(&x_minus).unwrap().item().unwrap();
        let numerical = (f_plus - f_minus) / (2.0 * eps);
        assert!(
            (numerical - expected_analytic).abs() < tol,
            "numerical gradient {numerical} differs from analytic {expected_analytic} by more than {tol}"
        );
    }

    #[test]
    fn test_sum_numerical_gradient() {
        let x = leaf_scalar(3.0, true);
        let s = sum(&x).unwrap();
        s.backward().unwrap();
        let analytic = x.grad().unwrap().unwrap().item().unwrap();
        numerical_grad_check(sum, 3.0, analytic, 1e-5);
    }

    #[test]
    fn test_mean_numerical_gradient() {
        let x = leaf_scalar(3.0, true);
        let m = mean(&x).unwrap();
        m.backward().unwrap();
        let analytic = x.grad().unwrap().unwrap().item().unwrap();
        numerical_grad_check(mean, 3.0, analytic, 1e-5);
    }

    #[test]
    fn test_prod_numerical_gradient() {
        let x = leaf_scalar(3.0, true);
        let p = prod(&x).unwrap();
        p.backward().unwrap();
        let analytic = x.grad().unwrap().unwrap().item().unwrap();
        numerical_grad_check(prod, 3.0, analytic, 1e-5);
    }

    // ---- sum_dim --------------------------------------------------------

    #[test]
    fn test_sum_dim_axis0_2d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let s = sum_dim(&x, 0, false).unwrap();
        assert_eq!(s.shape(), &[3]);
        let d = s.data().unwrap();
        assert!((d[0] - 5.0).abs() < 1e-12);
        assert!((d[1] - 7.0).abs() < 1e-12);
        assert!((d[2] - 9.0).abs() < 1e-12);
    }

    #[test]
    fn test_sum_dim_axis1_2d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let s = sum_dim(&x, 1, false).unwrap();
        assert_eq!(s.shape(), &[2]);
        let d = s.data().unwrap();
        assert!((d[0] - 6.0).abs() < 1e-12);
        assert!((d[1] - 15.0).abs() < 1e-12);
    }

    #[test]
    fn test_sum_dim_keepdim_true() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let s = sum_dim(&x, 0, true).unwrap();
        assert_eq!(s.shape(), &[1, 3]);
        let d = s.data().unwrap();
        assert!((d[0] - 5.0).abs() < 1e-12);
        assert!((d[1] - 7.0).abs() < 1e-12);
        assert!((d[2] - 9.0).abs() < 1e-12);
    }

    #[test]
    fn test_sum_dim_negative_dim() {
        // -1 should resolve to the last axis (axis 1 here).
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let s = sum_dim(&x, -1, false).unwrap();
        assert_eq!(s.shape(), &[2]);
        let d = s.data().unwrap();
        assert!((d[0] - 6.0).abs() < 1e-12);
        assert!((d[1] - 15.0).abs() < 1e-12);
    }

    #[test]
    fn test_sum_dim_1d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
        let s = sum_dim(&x, 0, false).unwrap();
        assert!(s.is_scalar());
        assert!((s.item().unwrap() - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_sum_dim_1d_keepdim() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
        let s = sum_dim(&x, 0, true).unwrap();
        assert_eq!(s.shape(), &[1]);
        assert!((s.data().unwrap()[0] - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_sum_dim_3d() {
        let data: Vec<f64> = (1..=12).map(|x| x as f64).collect();
        let x = leaf(&data, &[2, 2, 3], false);
        let s = sum_dim(&x, 1, false).unwrap();
        assert_eq!(s.shape(), &[2, 3]);
        let d = s.data().unwrap();
        assert!((d[0] - 5.0).abs() < 1e-12);
        assert!((d[1] - 7.0).abs() < 1e-12);
        assert!((d[2] - 9.0).abs() < 1e-12);
        assert!((d[3] - 17.0).abs() < 1e-12);
        assert!((d[4] - 19.0).abs() < 1e-12);
        assert!((d[5] - 21.0).abs() < 1e-12);
    }

    #[test]
    fn test_sum_dim_backward_axis0_no_keepdim() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let s = sum_dim(&x, 0, false).unwrap();
        let loss = sum(&s).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert_eq!(g.shape(), &[2, 3]);
        for &v in g.data().unwrap() {
            assert!((v - 1.0).abs() < 1e-12, "expected 1.0, got {v}");
        }
    }

    #[test]
    fn test_sum_dim_backward_axis1_keepdim() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let s = sum_dim(&x, 1, true).unwrap();
        assert_eq!(s.shape(), &[2, 1]);
        let loss = sum(&s).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert_eq!(g.shape(), &[2, 3]);
        for &v in g.data().unwrap() {
            assert!((v - 1.0).abs() < 1e-12, "expected 1.0, got {v}");
        }
    }

    #[test]
    fn test_sum_dim_has_grad_fn() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let s = sum_dim(&x, 0, false).unwrap();
        assert!(s.grad_fn().is_some());
        assert_eq!(s.grad_fn().unwrap().name(), "SumDimBackward");
    }

    #[test]
    fn test_sum_dim_no_grad_fn_when_not_requires_grad() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
        let s = sum_dim(&x, 0, false).unwrap();
        assert!(s.grad_fn().is_none());
    }

    #[test]
    fn test_sum_dim_no_grad_fn_in_no_grad_context() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let s = no_grad(|| sum_dim(&x, 0, false)).unwrap();
        assert!(s.grad_fn().is_none());
    }

    // ---- mean_dim -------------------------------------------------------

    #[test]
    fn test_mean_dim_axis0_2d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let m = mean_dim(&x, 0, false).unwrap();
        assert_eq!(m.shape(), &[3]);
        let d = m.data().unwrap();
        assert!((d[0] - 2.5).abs() < 1e-12);
        assert!((d[1] - 3.5).abs() < 1e-12);
        assert!((d[2] - 4.5).abs() < 1e-12);
    }

    #[test]
    fn test_mean_dim_axis1_2d() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let m = mean_dim(&x, 1, false).unwrap();
        assert_eq!(m.shape(), &[2]);
        let d = m.data().unwrap();
        assert!((d[0] - 2.0).abs() < 1e-12);
        assert!((d[1] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_mean_dim_keepdim() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let m = mean_dim(&x, 0, true).unwrap();
        assert_eq!(m.shape(), &[1, 3]);
        let d = m.data().unwrap();
        assert!((d[0] - 2.5).abs() < 1e-12);
        assert!((d[1] - 3.5).abs() < 1e-12);
        assert!((d[2] - 4.5).abs() < 1e-12);
    }

    #[test]
    fn test_mean_dim_negative_dim() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let m = mean_dim(&x, -1, false).unwrap();
        assert_eq!(m.shape(), &[2]);
        let d = m.data().unwrap();
        assert!((d[0] - 2.0).abs() < 1e-12);
        assert!((d[1] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_mean_dim_backward_axis0() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let m = mean_dim(&x, 0, false).unwrap();
        let loss = sum(&m).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert_eq!(g.shape(), &[2, 3]);
        // Gradient is 1/len(axis 0) = 1/2 for every element.
        let expected = 1.0 / 2.0;
        for &v in g.data().unwrap() {
            assert!((v - expected).abs() < 1e-12, "expected {expected}, got {v}");
        }
    }

    #[test]
    fn test_mean_dim_backward_axis1_keepdim() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let m = mean_dim(&x, 1, true).unwrap();
        assert_eq!(m.shape(), &[2, 1]);
        let loss = sum(&m).unwrap();
        loss.backward().unwrap();
        let g = x.grad().unwrap().unwrap();
        assert_eq!(g.shape(), &[2, 3]);
        // Gradient is 1/len(axis 1) = 1/3 for every element.
        let expected = 1.0 / 3.0;
        for &v in g.data().unwrap() {
            assert!((v - expected).abs() < 1e-12, "expected {expected}, got {v}");
        }
    }

    #[test]
    fn test_mean_dim_has_grad_fn() {
        let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let m = mean_dim(&x, 0, false).unwrap();
        assert!(m.grad_fn().is_some());
        assert_eq!(m.grad_fn().unwrap().name(), "MeanDimBackward");
    }
}