use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::bool_tensor::BoolTensor;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::int_tensor::IntTensor;
use crate::ops::elementwise;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Autograd node recorded by the full `sum` reduction.
///
/// Holds the forward input; `backward` uses it for the gradient's element
/// count, shape and device.
#[derive(Debug)]
pub struct SumBackward<T: Float> {
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for SumBackward<T> {
    /// Builds a gradient shaped like the saved input in which every element
    /// equals element 0 of `grad_output`.
    ///
    /// For a CUDA input with `T` being `f32`/`f64` and a GPU backend present,
    /// the buffer is filled through the backend. Every other case builds the
    /// values on the host and moves them with `Tensor::to`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        use crate::device::Device;
        use crate::gpu_dispatch::gpu_backend;
        use std::any::TypeId;

        // The upstream value is read on the host, copying from the GPU first if needed.
        let upstream = if grad_output.is_cuda() {
            let host_copy = grad_output.cpu()?;
            host_copy.data()?[0]
        } else {
            grad_output.data()?[0]
        };
        let count = self.input.numel();

        if self.input.is_cuda() {
            // Non-CUDA devices cannot reach this branch; 0 is only a match fallback.
            let ordinal = match self.input.device() {
                Device::Cuda(o) => o,
                _ => 0,
            };
            let wants_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
            let wants_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
            if let Some(backend) = gpu_backend() {
                if wants_f32 {
                    let fill: f32 = <T as num_traits::ToPrimitive>::to_f32(&upstream).unwrap();
                    let buffer = backend.fill_f32(count, fill, ordinal)?;
                    let on_device = Tensor::from_storage(
                        TensorStorage::gpu(buffer),
                        self.input.shape().to_vec(),
                        false,
                    )?;
                    return Ok(vec![Some(on_device)]);
                }
                if wants_f64 {
                    let fill: f64 = <T as num_traits::ToPrimitive>::to_f64(&upstream).unwrap();
                    let buffer = backend.fill_f64(count, fill, ordinal)?;
                    let on_device = Tensor::from_storage(
                        TensorStorage::gpu(buffer),
                        self.input.shape().to_vec(),
                        false,
                    )?;
                    return Ok(vec![Some(on_device)]);
                }
            }
        }

        // Host fallback: materialise the constant buffer, then move it to the input's device.
        let host_grad = Tensor::from_storage(
            TensorStorage::cpu(vec![upstream; count]),
            self.input.shape().to_vec(),
            false,
        )?;
        let moved = host_grad.to(self.input.device())?;
        Ok(vec![Some(moved)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "SumBackward"
    }
}
/// Reduces `input` to the sum of all of its elements.
///
/// If `meta_propagate::reduce_all` yields a tensor it is returned directly
/// (presumably the meta-device shortcut; confirm against that module).
/// Otherwise the work runs in `sum_inner` inside a profiler scope labelled
/// `"sum"` / `"reduction"`.
pub fn sum<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    match crate::meta_propagate::reduce_all(input)? {
        Some(propagated) => Ok(propagated),
        None => {
            let shapes = [input.shape()];
            crate::profiler_hook::profile_op_scope("sum", "reduction", &shapes, || {
                sum_inner(input)
            })
        }
    }
}
/// Device-dispatching body of [`sum`].
///
/// CPU tensors go through `elementwise::sum`. CUDA tensors are made
/// contiguous and reduced by the GPU backend, failing with
/// `DeviceUnavailable` when no backend is registered. A `SumBackward` node is
/// attached when gradients are enabled and the input requires them.
fn sum_inner<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if !input.is_cuda() {
        let reduced = elementwise::sum(input)?;
        if !(is_grad_enabled() && input.requires_grad()) {
            return Ok(reduced);
        }
        let node = Arc::new(SumBackward {
            input: input.clone(),
        });
        let (storage, shape) = reduced.into_storage_and_shape()?;
        return Tensor::from_operation(storage, shape, node);
    }

    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    // From here on `input` is the contiguous tensor; the backward node saves this one.
    let input = input.contiguous()?;
    let handle: crate::gpu_dispatch::GpuBufferHandle = crate::dispatch_floating_dtype!(
        T,
        "sum",
        f32 => backend.sum_f32(input.gpu_handle()?, input.numel()),
        f64 => backend.sum_f64(input.gpu_handle()?, input.numel()),
        bf16 => backend.sum_bf16_bf16(input.gpu_handle()?),
        f16 => backend.sum_f16(input.gpu_handle()?),
    )?;
    let storage = TensorStorage::gpu(handle);
    // The result is 0-D, hence the empty shape.
    let scalar_shape = vec![];
    let track = is_grad_enabled() && input.requires_grad();
    if track {
        let node = Arc::new(SumBackward {
            input: input.clone(),
        });
        Tensor::from_operation(storage, scalar_shape, node)
    } else {
        Tensor::from_storage(storage, scalar_shape, false)
    }
}
/// Autograd node recorded by the full `mean` reduction.
///
/// Holds the forward input; `backward` uses it for the gradient's element
/// count, shape and device.
#[derive(Debug)]
pub struct MeanBackward<T: Float> {
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for MeanBackward<T> {
    /// Builds a gradient shaped like the saved input in which every element
    /// equals `grad_output[0] / numel`.
    ///
    /// For a CUDA input with `T` being `f32`/`f64` and a GPU backend present,
    /// the buffer is filled through the backend. Every other case builds the
    /// values on the host and moves them with `Tensor::to`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        use crate::device::Device;
        use crate::gpu_dispatch::gpu_backend;
        use std::any::TypeId;

        // The upstream value is read on the host, copying from the GPU first if needed.
        let upstream = if grad_output.is_cuda() {
            let host_copy = grad_output.cpu()?;
            host_copy.data()?[0]
        } else {
            grad_output.data()?[0]
        };
        let count = self.input.numel();
        let divisor = T::from(count).unwrap();
        let per_element = upstream / divisor;

        if self.input.is_cuda() {
            // Non-CUDA devices cannot reach this branch; 0 is only a match fallback.
            let ordinal = match self.input.device() {
                Device::Cuda(o) => o,
                _ => 0,
            };
            let wants_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
            let wants_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
            if let Some(backend) = gpu_backend() {
                if wants_f32 {
                    let fill: f32 = <T as num_traits::ToPrimitive>::to_f32(&per_element).unwrap();
                    let buffer = backend.fill_f32(count, fill, ordinal)?;
                    let on_device = Tensor::from_storage(
                        TensorStorage::gpu(buffer),
                        self.input.shape().to_vec(),
                        false,
                    )?;
                    return Ok(vec![Some(on_device)]);
                }
                if wants_f64 {
                    let fill: f64 = <T as num_traits::ToPrimitive>::to_f64(&per_element).unwrap();
                    let buffer = backend.fill_f64(count, fill, ordinal)?;
                    let on_device = Tensor::from_storage(
                        TensorStorage::gpu(buffer),
                        self.input.shape().to_vec(),
                        false,
                    )?;
                    return Ok(vec![Some(on_device)]);
                }
            }
        }

        // Host fallback: materialise the constant buffer, then move it to the input's device.
        let host_grad = Tensor::from_storage(
            TensorStorage::cpu(vec![per_element; count]),
            self.input.shape().to_vec(),
            false,
        )?;
        let moved = host_grad.to(self.input.device())?;
        Ok(vec![Some(moved)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "MeanBackward"
    }
}
/// Reduces `input` to the arithmetic mean of all of its elements.
///
/// If `meta_propagate::reduce_all` yields a tensor it is returned directly
/// (presumably the meta-device shortcut; confirm against that module).
/// Otherwise the work runs in `mean_inner` inside a profiler scope labelled
/// `"mean"` / `"reduction"`.
pub fn mean<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    match crate::meta_propagate::reduce_all(input)? {
        Some(propagated) => Ok(propagated),
        None => {
            let shapes = [input.shape()];
            crate::profiler_hook::profile_op_scope("mean", "reduction", &shapes, || {
                mean_inner(input)
            })
        }
    }
}
/// Device-dispatching body of [`mean`].
///
/// A CUDA tensor with a GPU backend available is reduced on the device
/// (sum then scale for `f32`/`f64`, dedicated mean kernels for `bf16`/`f16`).
/// Everything else, including a CUDA tensor without a backend, goes through
/// `elementwise::mean`. A `MeanBackward` node holding the *original* input is
/// attached when gradients are enabled and the input requires them.
fn mean_inner<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    // The backend is only looked up for CUDA tensors.
    let gpu = if input.is_cuda() {
        crate::gpu_dispatch::gpu_backend()
    } else {
        None
    };

    let averaged = match gpu {
        Some(backend) => {
            // This shadow is local to the arm; the backward node below sees the caller's tensor.
            let input = input.contiguous()?;
            let mean_handle: crate::gpu_dispatch::GpuBufferHandle = crate::dispatch_floating_dtype!(
                T,
                "mean",
                f32 => {
                    let total = backend.sum_f32(input.gpu_handle()?, input.numel())?;
                    let reciprocal = 1.0f32 / input.numel() as f32;
                    Ok::<_, crate::error::FerrotorchError>(backend.scale_f32(&total, reciprocal)?)
                },
                f64 => {
                    let total = backend.sum_f64(input.gpu_handle()?, input.numel())?;
                    let reciprocal = 1.0f64 / input.numel() as f64;
                    Ok::<_, crate::error::FerrotorchError>(backend.scale_f64(&total, reciprocal)?)
                },
                bf16 => Ok::<_, crate::error::FerrotorchError>(
                    backend.mean_bf16_bf16(input.gpu_handle()?)?
                ),
                f16 => Ok::<_, crate::error::FerrotorchError>(
                    backend.mean_f16(input.gpu_handle()?)?
                ),
            )?;
            Tensor::from_storage(TensorStorage::gpu(mean_handle), vec![], false)?
        }
        None => elementwise::mean(input)?,
    };

    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(averaged);
    }
    let node = Arc::new(MeanBackward {
        input: input.clone(),
    });
    let (storage, shape) = averaged.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
/// Autograd node recorded by the full `prod` reduction.
///
/// Holds the forward input, whose values are needed to form the gradient.
#[derive(Debug)]
pub struct ProdBackward<T: Float> {
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for ProdBackward<T> {
    /// Computes `grad_output[0] * prod(input[j] for j != i)` for every `i`.
    ///
    /// CUDA inputs are handled by the backend's `prod_backward_f32/f64`
    /// kernels. Any other CUDA case (other dtype, or no backend) returns
    /// `NotImplementedOnCuda`. On the CPU the "product of all the others" is
    /// formed from prefix and suffix products, so no division is used.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        use std::any::TypeId;

        let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
        let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();

        if self.input.is_cuda() {
            if is_f32 || is_f64 {
                if let Some(backend) = crate::gpu_dispatch::gpu_backend() {
                    // The kernel needs the upstream gradient on the device as well.
                    let upstream = if grad_output.is_cuda() {
                        grad_output.clone()
                    } else {
                        grad_output.to(self.input.device())?
                    };
                    let handle = if is_f32 {
                        backend.prod_backward_f32(self.input.gpu_handle()?, upstream.gpu_handle()?)?
                    } else {
                        backend.prod_backward_f64(self.input.gpu_handle()?, upstream.gpu_handle()?)?
                    };
                    let on_device = Tensor::from_storage(
                        TensorStorage::gpu(handle),
                        self.input.shape().to_vec(),
                        false,
                    )?;
                    return Ok(vec![Some(on_device)]);
                }
            }
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "prod backward",
            });
        }

        let upstream = grad_output.data()?[0];
        let values = self.input.data()?;
        let len = values.len();
        let one = <T as num_traits::One>::one();

        // left[i] = values[0] * ... * values[i - 1], accumulated left to right.
        let mut left = Vec::with_capacity(len);
        let mut running = one;
        for &v in values.iter() {
            left.push(running);
            running = running * v;
        }

        // right[i] = values[i + 1] * ... * values[len - 1], accumulated right to left.
        let mut right = vec![one; len];
        let mut running = one;
        for i in (0..len).rev() {
            right[i] = running;
            running = running * values[i];
        }

        let grad_values: Vec<T> = left
            .iter()
            .zip(right.iter())
            .map(|(&l, &r)| upstream * l * r)
            .collect();
        let host_grad = Tensor::from_storage(
            TensorStorage::cpu(grad_values),
            self.input.shape().to_vec(),
            false,
        )?;
        let moved = host_grad.to(self.input.device())?;
        Ok(vec![Some(moved)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "ProdBackward"
    }
}
/// Reduces `input` to the product of all of its elements.
///
/// If `meta_propagate::reduce_all` yields a tensor it is returned directly
/// (presumably the meta-device shortcut; confirm against that module).
/// Otherwise the work runs in `prod_inner` inside a profiler scope labelled
/// `"prod"` / `"reduction"`.
pub fn prod<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    match crate::meta_propagate::reduce_all(input)? {
        Some(propagated) => Ok(propagated),
        None => {
            let shapes = [input.shape()];
            crate::profiler_hook::profile_op_scope("prod", "reduction", &shapes, || {
                prod_inner(input)
            })
        }
    }
}
/// Device-dispatching body of [`prod`].
///
/// CUDA tensors are supported for `f32`/`f64` only; other dtypes return
/// `NotImplementedOnCuda`, and a missing backend returns `DeviceUnavailable`.
/// CPU tensors are multiplied left to right starting from one. A
/// `ProdBackward` node is attached when gradients are enabled and the input
/// requires them.
fn prod_inner<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    use std::any::TypeId;

    let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();

    if input.is_cuda() {
        if !(is_f32 || is_f64) {
            return Err(FerrotorchError::NotImplementedOnCuda { op: "prod" });
        }
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        // From here on `input` is the contiguous tensor; the backward node saves this one.
        let input = input.contiguous()?;
        let handle = if is_f32 {
            backend.prod_f32(input.gpu_handle()?, input.numel())?
        } else {
            backend.prod_f64(input.gpu_handle()?, input.numel())?
        };
        let storage = TensorStorage::gpu(handle);
        return if is_grad_enabled() && input.requires_grad() {
            let node = Arc::new(ProdBackward {
                input: input.clone(),
            });
            Tensor::from_operation(storage, vec![], node)
        } else {
            Tensor::from_storage(storage, vec![], false)
        };
    }

    let values = input.data()?;
    let mut total = <T as num_traits::One>::one();
    for &v in values.iter() {
        total = total * v;
    }
    let scalar = Tensor::from_storage(TensorStorage::cpu(vec![total]), vec![], false)?;
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(scalar);
    }
    let node = Arc::new(ProdBackward {
        input: input.clone(),
    });
    let (storage, shape) = scalar.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
/// Autograd node recorded by the full `amin` reduction.
///
/// Holds the forward input, whose values decide which elements receive
/// gradient.
#[derive(Debug)]
pub struct AminBackward<T: Float> {
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for AminBackward<T> {
    /// Splits `grad_output[0]` evenly between every element equal to the
    /// minimum of the saved input; all other elements get zero.
    ///
    /// The computation runs on the host. The result is then moved to the
    /// input's device, so a CUDA input (which `amin` does record this node
    /// for) receives a CUDA gradient, matching the other backward nodes in
    /// this file.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let go = if grad_output.is_cuda() {
            grad_output.cpu()?.data()?[0]
        } else {
            grad_output.data()?[0]
        };
        let input_data = self.input.data_vec()?;
        let zero = <T as num_traits::Zero>::zero();
        // `b < a` is false for NaN, so NaN elements never become the minimum here.
        let mn = input_data
            .iter()
            .copied()
            .fold(
                T::from(f64::INFINITY).unwrap(),
                |a, b| if b < a { b } else { a },
            );
        // `max(1.0)` guards the division when nothing matches (e.g. empty input).
        let count = input_data.iter().filter(|&&v| v == mn).count() as f64;
        let scale = T::from(go.to_f64().unwrap() / count.max(1.0)).unwrap();
        let result: Vec<T> = input_data
            .iter()
            .map(|&v| if v == mn { scale } else { zero })
            .collect();
        let grad_host = Tensor::from_storage(
            TensorStorage::cpu(result),
            self.input.shape().to_vec(),
            false,
        )?;
        // Return the gradient on the same device as the tensor it belongs to.
        let grad_input = grad_host.to(self.input.device())?;
        Ok(vec![Some(grad_input)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "AminBackward"
    }
}
/// Autograd node recorded by the full `amax` reduction.
///
/// Holds the forward input, whose values decide which elements receive
/// gradient.
#[derive(Debug)]
pub struct AmaxBackward<T: Float> {
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for AmaxBackward<T> {
    /// Splits `grad_output[0]` evenly between every element equal to the
    /// maximum of the saved input; all other elements get zero.
    ///
    /// The computation runs on the host. The result is then moved to the
    /// input's device, so a CUDA input (which `amax` does record this node
    /// for) receives a CUDA gradient, matching the other backward nodes in
    /// this file.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let go = if grad_output.is_cuda() {
            grad_output.cpu()?.data()?[0]
        } else {
            grad_output.data()?[0]
        };
        let input_data = self.input.data_vec()?;
        let zero = <T as num_traits::Zero>::zero();
        // `b > a` is false for NaN, so NaN elements never become the maximum here.
        let mx = input_data
            .iter()
            .copied()
            .fold(T::from(f64::NEG_INFINITY).unwrap(), |a, b| {
                if b > a { b } else { a }
            });
        // `max(1.0)` guards the division when nothing matches (e.g. empty input).
        let count = input_data.iter().filter(|&&v| v == mx).count() as f64;
        let scale = T::from(go.to_f64().unwrap() / count.max(1.0)).unwrap();
        let result: Vec<T> = input_data
            .iter()
            .map(|&v| if v == mx { scale } else { zero })
            .collect();
        let grad_host = Tensor::from_storage(
            TensorStorage::cpu(result),
            self.input.shape().to_vec(),
            false,
        )?;
        // Return the gradient on the same device as the tensor it belongs to.
        let grad_input = grad_host.to(self.input.device())?;
        Ok(vec![Some(grad_input)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "AmaxBackward"
    }
}
/// Returns the minimum over all elements of `input` as a 0-D tensor.
///
/// CUDA tensors are supported for `f32`/`f64` only; other dtypes return
/// `NotImplementedOnCuda`, and a missing backend returns `DeviceUnavailable`.
/// On the CPU the scan starts from `+inf` and uses `<`, so NaN values are
/// never selected and an empty input yields `+inf`. An `AminBackward` node is
/// attached when gradients are enabled and the input requires them.
pub fn amin<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    use std::any::TypeId;

    let on_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let on_f64 = TypeId::of::<T>() == TypeId::of::<f64>();

    if input.is_cuda() {
        if !(on_f32 || on_f64) {
            return Err(FerrotorchError::NotImplementedOnCuda { op: "amin" });
        }
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        // From here on `input` is the contiguous tensor; the backward node saves this one.
        let input = input.contiguous()?;
        let handle = if on_f32 {
            backend.min_f32(input.gpu_handle()?, input.numel())?
        } else {
            backend.min_f64(input.gpu_handle()?, input.numel())?
        };
        let storage = TensorStorage::gpu(handle);
        return if is_grad_enabled() && input.requires_grad() {
            let node = Arc::new(AminBackward {
                input: input.clone(),
            });
            Tensor::from_operation(storage, vec![], node)
        } else {
            Tensor::from_storage(storage, vec![], false)
        };
    }

    let values = input.data_vec()?;
    let mut smallest = T::from(f64::INFINITY).unwrap();
    for &candidate in values.iter() {
        if candidate < smallest {
            smallest = candidate;
        }
    }
    let storage = TensorStorage::cpu(vec![smallest]);
    if !(is_grad_enabled() && input.requires_grad()) {
        return Tensor::from_storage(storage, vec![], false);
    }
    let node = Arc::new(AminBackward {
        input: input.clone(),
    });
    Tensor::from_operation(storage, vec![], node)
}
/// Returns the maximum over all elements of `input` as a 0-D tensor.
///
/// CUDA tensors are supported for `f32`/`f64` only; other dtypes return
/// `NotImplementedOnCuda`, and a missing backend returns `DeviceUnavailable`.
/// On the CPU the scan starts from `-inf` and uses `>`, so NaN values are
/// never selected and an empty input yields `-inf`. An `AmaxBackward` node is
/// attached when gradients are enabled and the input requires them.
pub fn amax<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    use std::any::TypeId;

    let on_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let on_f64 = TypeId::of::<T>() == TypeId::of::<f64>();

    if input.is_cuda() {
        if !(on_f32 || on_f64) {
            return Err(FerrotorchError::NotImplementedOnCuda { op: "amax" });
        }
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        // From here on `input` is the contiguous tensor; the backward node saves this one.
        let input = input.contiguous()?;
        let handle = if on_f32 {
            backend.max_f32(input.gpu_handle()?, input.numel())?
        } else {
            backend.max_f64(input.gpu_handle()?, input.numel())?
        };
        let storage = TensorStorage::gpu(handle);
        return if is_grad_enabled() && input.requires_grad() {
            let node = Arc::new(AmaxBackward {
                input: input.clone(),
            });
            Tensor::from_operation(storage, vec![], node)
        } else {
            Tensor::from_storage(storage, vec![], false)
        };
    }

    let values = input.data_vec()?;
    let mut largest = T::from(f64::NEG_INFINITY).unwrap();
    for &candidate in values.iter() {
        if candidate > largest {
            largest = candidate;
        }
    }
    let storage = TensorStorage::cpu(vec![largest]);
    if !(is_grad_enabled() && input.requires_grad()) {
        return Tensor::from_storage(storage, vec![], false);
    }
    let node = Arc::new(AmaxBackward {
        input: input.clone(),
    });
    Tensor::from_operation(storage, vec![], node)
}
/// Autograd node recorded by `sum_dim`.
///
/// Stores the forward input, the normalised (non-negative) reduced dimension
/// and whether the forward call kept that dimension with size 1.
#[derive(Debug)]
pub struct SumDimBackward<T: Float> {
    input: Tensor<T>,
    dim: usize,
    keepdim: bool,
}

impl<T: Float> GradFn<T> for SumDimBackward<T> {
    /// Expands `grad_output` back to the input shape by repeating it along the
    /// reduced dimension.
    ///
    /// CUDA gradients use the backend's `repeat_along_dim_f32/f64` kernels;
    /// other dtypes on CUDA return `NotImplementedOnCuda`. On the CPU the
    /// gradient is viewed as `[outer, inner]` and each `inner`-sized row is
    /// copied `input_shape[dim]` times.
    ///
    /// # Errors
    /// On the CPU path, returns `InvalidArgument` when `grad_output` does not
    /// have the shape of the forward result, or its buffer is too short.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let input_shape = self.input.shape();
        let repeat_count = input_shape[self.dim];
        let t_is_f32 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>();
        let t_is_f64 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>();

        if grad_output.is_cuda() && (t_is_f32 || t_is_f64) {
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            // Kernel extents are clamped to at least 1, as before.
            let outer: usize = input_shape[..self.dim].iter().product::<usize>().max(1);
            let inner: usize = input_shape[(self.dim + 1)..]
                .iter()
                .product::<usize>()
                .max(1);
            let result_h = if t_is_f32 {
                backend.repeat_along_dim_f32(
                    grad_output.gpu_handle()?,
                    outer,
                    repeat_count,
                    inner,
                )?
            } else {
                backend.repeat_along_dim_f64(
                    grad_output.gpu_handle()?,
                    outer,
                    repeat_count,
                    inner,
                )?
            };
            let grad_input =
                Tensor::from_storage(TensorStorage::gpu(result_h), input_shape.to_vec(), false)?;
            return Ok(vec![Some(grad_input)]);
        }
        if grad_output.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "sum_dim backward",
            });
        }

        // The upstream gradient must have the shape of the forward result.
        let mut expected_shape = input_shape.to_vec();
        if self.keepdim {
            expected_shape[self.dim] = 1;
        } else {
            expected_shape.remove(self.dim);
        }
        if !grad_output.shape().iter().eq(expected_shape.iter()) {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "sum_dim backward: expected grad_output of shape {:?}, got {:?}",
                    expected_shape,
                    grad_output.shape()
                ),
            });
        }

        // With the reduced axis collapsed, the gradient is a row-major
        // [outer, inner] matrix whether or not the size-1 axis was kept.
        let outer: usize = input_shape[..self.dim].iter().product();
        let inner: usize = input_shape[(self.dim + 1)..].iter().product();
        let grad_data = grad_output.data()?;
        if grad_data.len() < outer * inner {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "sum_dim backward: grad_output holds {} elements, expected {}",
                    grad_data.len(),
                    outer * inner
                ),
            });
        }

        // Input element (o, a, i) receives grad (o, i): copy each row once per
        // position on the reduced axis. No per-element index arithmetic needed.
        let mut result = Vec::with_capacity(outer * repeat_count * inner);
        for o in 0..outer {
            let row = &grad_data[o * inner..(o + 1) * inner];
            for _ in 0..repeat_count {
                result.extend_from_slice(row);
            }
        }
        let grad_input =
            Tensor::from_storage(TensorStorage::cpu(result), input_shape.to_vec(), false)?;
        Ok(vec![Some(grad_input)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "SumDimBackward"
    }
}
/// Sums `input` along dimension `dim` (negative values count from the end).
///
/// With `keepdim` the reduced dimension stays with size 1; otherwise it is
/// removed. If `meta_propagate::reduce_dim` yields a tensor it is returned
/// directly (presumably the meta-device shortcut; confirm against that
/// module). Otherwise the work runs in `sum_dim_inner` inside a profiler
/// scope labelled `"sum_dim"` / `"reduction"`.
pub fn sum_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    match crate::meta_propagate::reduce_dim(input, dim, keepdim)? {
        Some(propagated) => Ok(propagated),
        None => {
            let shapes = [input.shape()];
            crate::profiler_hook::profile_op_scope("sum_dim", "reduction", &shapes, || {
                sum_dim_inner(input, dim, keepdim)
            })
        }
    }
}
/// Sums a row-major `[outer, axis, inner]` buffer over its middle axis.
///
/// Returns `outer * inner` sums laid out as `[outer, inner]`. When
/// `inner == 1` each output is the sum of one contiguous row: `f32`/`f64`
/// rows are copied into a scratch buffer and summed with
/// `crate::simd_reduce`, other dtypes are added sequentially. When
/// `inner != 1` rows of length `inner` are accumulated element-wise.
fn reduce_axis_sum_contiguous<T: Float>(
    in_data: &[T],
    outer: usize,
    axis: usize,
    inner: usize,
) -> Vec<T> {
    use std::any::TypeId;

    let mut sums = vec![<T as num_traits::Zero>::zero(); outer * inner];

    if inner != 1 {
        // Strided case: add each [inner]-long slice into the output row of its outer index.
        for outer_idx in 0..outer {
            let out_start = outer_idx * inner;
            for axis_idx in 0..axis {
                let in_start = (outer_idx * axis + axis_idx) * inner;
                let source = &in_data[in_start..in_start + inner];
                let target = &mut sums[out_start..out_start + inner];
                for (slot, &value) in target.iter_mut().zip(source.iter()) {
                    *slot += value;
                }
            }
        }
        return sums;
    }

    // Contiguous case: output `r` is the sum of in_data[r * axis .. (r + 1) * axis].
    let dtype = TypeId::of::<T>();
    if dtype == TypeId::of::<f32>() {
        // One scratch row, reused for every output.
        let mut scratch: Vec<f32> = vec![0.0; axis];
        for (row_idx, slot) in sums.iter_mut().enumerate() {
            let start = row_idx * axis;
            let row = &in_data[start..start + axis];
            for (dst, &src) in scratch.iter_mut().zip(row.iter()) {
                *dst = num_traits::ToPrimitive::to_f32(&src).unwrap_or(0.0);
            }
            let total = crate::simd_reduce::sum_f32(&scratch);
            // If the cast back fails, the slot keeps its initial zero.
            *slot = <T as num_traits::NumCast>::from(total).unwrap_or(*slot);
        }
    } else if dtype == TypeId::of::<f64>() {
        let mut scratch: Vec<f64> = vec![0.0; axis];
        for (row_idx, slot) in sums.iter_mut().enumerate() {
            let start = row_idx * axis;
            let row = &in_data[start..start + axis];
            for (dst, &src) in scratch.iter_mut().zip(row.iter()) {
                *dst = num_traits::ToPrimitive::to_f64(&src).unwrap_or(0.0);
            }
            let total = crate::simd_reduce::sum_f64(&scratch);
            *slot = <T as num_traits::NumCast>::from(total).unwrap_or(*slot);
        }
    } else {
        for (row_idx, slot) in sums.iter_mut().enumerate() {
            let start = row_idx * axis;
            let mut total = <T as num_traits::Zero>::zero();
            for &value in &in_data[start..start + axis] {
                total += value;
            }
            *slot = total;
        }
    }
    sums
}
/// Splits a shape around `norm_dim` into `(outer, axis, inner)`.
///
/// `outer` is the product of the extents before `norm_dim`, `axis` is the
/// extent at `norm_dim`, and `inner` is the product of the extents after it
/// (an empty product is 1).
///
/// # Panics
/// Panics if `norm_dim` is not a valid index into `in_shape`.
fn outer_axis_inner(in_shape: &[usize], norm_dim: usize) -> (usize, usize, usize) {
    let (leading, rest) = in_shape.split_at(norm_dim);
    let axis = rest[0];
    let trailing = &rest[1..];
    let outer = leading.iter().product::<usize>();
    let inner = trailing.iter().product::<usize>();
    (outer, axis, inner)
}
/// Validating, device-dispatching body of [`sum_dim`].
///
/// Rejects 0-D inputs and out-of-range `dim` with `InvalidArgument`. CUDA
/// tensors are reduced by the backend's `sum_axis_*` kernels, or fail with
/// `DeviceUnavailable` when no backend exists. CPU tensors are made contiguous
/// and reduced by `reduce_axis_sum_contiguous`. A `SumDimBackward` node is
/// attached when gradients are enabled and the input requires them.
fn sum_dim_inner<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "sum_dim: cannot reduce a scalar (0-D) tensor along a dimension".into(),
        });
    }

    // Negative dims count from the end. A value that is still negative after
    // the shift wraps to a huge usize and is rejected by the bound check.
    let shifted = if dim < 0 { dim + ndim as i64 } else { dim };
    let norm_dim = shifted as usize;
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "sum_dim: dim {dim} is out of bounds for tensor with {ndim} dimensions"
            ),
        });
    }

    let in_shape = input.shape();
    let out_shape: Vec<usize> = if keepdim {
        in_shape
            .iter()
            .enumerate()
            .map(|(d, &extent)| if d == norm_dim { 1 } else { extent })
            .collect()
    } else {
        in_shape
            .iter()
            .enumerate()
            .filter(|&(d, _)| d != norm_dim)
            .map(|(_, &extent)| extent)
            .collect()
    };

    if input.is_cuda() {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        // From here on `input` is the contiguous tensor; the backward node saves this one.
        let input = input.contiguous()?;
        let handle: crate::gpu_dispatch::GpuBufferHandle = crate::dispatch_floating_dtype!(
            T,
            "sum_dim",
            f32 => backend.sum_axis_f32(input.gpu_handle()?, in_shape, norm_dim),
            f64 => backend.sum_axis_f64(input.gpu_handle()?, in_shape, norm_dim),
            bf16 => backend.sum_axis_bf16_bf16(input.gpu_handle()?, in_shape, norm_dim),
            f16 => backend.sum_axis_f16(input.gpu_handle()?, in_shape, norm_dim),
        )?;
        let storage = TensorStorage::gpu(handle);
        if !(is_grad_enabled() && input.requires_grad()) {
            return Tensor::from_storage(storage, out_shape, false);
        }
        let node = Arc::new(SumDimBackward {
            input: input.clone(),
            dim: norm_dim,
            keepdim,
        });
        return Tensor::from_operation(storage, out_shape, node);
    }

    let dense = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = dense.data()?;
    let (outer, axis, inner) = outer_axis_inner(in_shape, norm_dim);
    let sums = reduce_axis_sum_contiguous(&in_data[..input.numel()], outer, axis, inner);
    let storage = TensorStorage::on_device(sums, input.device())?;
    if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(SumDimBackward {
            input: input.clone(),
            dim: norm_dim,
            keepdim,
        });
        Tensor::from_operation(storage, out_shape, node)
    } else {
        Tensor::from_storage(storage, out_shape, false)
    }
}
#[derive(Debug)]
pub struct MeanDimBackward<T: Float> {
input: Tensor<T>,
dim: usize,
keepdim: bool,
}
impl<T: Float> GradFn<T> for MeanDimBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let input_shape = self.input.shape();
let dim_size = input_shape[self.dim];
let is_f32 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>();
if grad_output.is_cuda()
&& is_f32
&& let Some(backend) = crate::gpu_dispatch::gpu_backend()
{
let grad_shape_keepdim: Vec<usize> = if self.keepdim {
grad_output.shape().to_vec()
} else {
let mut s = grad_output.shape().to_vec();
s.insert(self.dim, 1);
s
};
let input_numel: usize = input_shape.iter().product();
let inv_n = 1.0f32 / (dim_size as f32);
let ones_handle = backend.fill_f32(input_numel, inv_n, 0)?;
let grad_handle = grad_output.gpu_handle()?;
let grad_input_handle = backend.broadcast_mul_f32(
&ones_handle,
grad_handle,
input_shape,
&grad_shape_keepdim,
input_shape,
)?;
let storage = TensorStorage::gpu(grad_input_handle);
let grad_input = Tensor::from_storage(storage, input_shape.to_vec(), false)?;
return Ok(vec![Some(grad_input)]);
}
let is_f64 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>();
if grad_output.is_cuda()
&& is_f64
&& let Some(backend) = crate::gpu_dispatch::gpu_backend()
{
let outer: usize = input_shape[..self.dim].iter().product::<usize>().max(1);
let inner: usize = input_shape[(self.dim + 1)..]
.iter()
.product::<usize>()
.max(1);
let repeat_count = dim_size;
let expanded = backend.repeat_along_dim_f64(
grad_output.gpu_handle()?,
outer,
repeat_count,
inner,
)?;
let scaled = backend.scale_f64(&expanded, 1.0 / repeat_count as f64)?;
let grad_input =
Tensor::from_storage(TensorStorage::gpu(scaled), input_shape.to_vec(), false)?;
return Ok(vec![Some(grad_input)]);
}
if grad_output.is_cuda() {
return Err(FerrotorchError::NotImplementedOnCuda {
op: "mean_dim backward",
});
}
let n = T::from(dim_size).unwrap();
let grad = if self.keepdim {
grad_output.clone()
} else {
let mut unsqueezed_shape = grad_output.shape().to_vec();
unsqueezed_shape.insert(self.dim, 1);
let data = grad_output.data()?.to_vec();
Tensor::from_storage(TensorStorage::cpu(data), unsqueezed_shape, false)?
};
let grad_data = grad.data()?;
let grad_shape = grad.shape();
let out_numel: usize = input_shape.iter().product();
let mut result = Vec::with_capacity(out_numel);
for flat in 0..out_numel {
let mut rem = flat;
let mut coords = vec![0usize; input_shape.len()];
for d in (0..input_shape.len()).rev() {
coords[d] = rem % input_shape[d];
rem /= input_shape[d];
}
let mut grad_flat = 0usize;
let mut stride = 1usize;
for d in (0..grad_shape.len()).rev() {
let c = if d == self.dim { 0 } else { coords[d] };
grad_flat += c * stride;
stride *= grad_shape[d];
}
result.push(grad_data[grad_flat] / n);
}
let grad_input =
Tensor::from_storage(TensorStorage::cpu(result), input_shape.to_vec(), false)?;
Ok(vec![Some(grad_input)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"MeanDimBackward"
}
}
/// Averages `input` along dimension `dim` (negative values count from the end).
///
/// With `keepdim` the reduced dimension stays with size 1; otherwise it is
/// removed. If `meta_propagate::reduce_dim` yields a tensor it is returned
/// directly (presumably the meta-device shortcut; confirm against that
/// module). Otherwise the work runs in `mean_dim_inner` inside a profiler
/// scope labelled `"mean_dim"` / `"reduction"`.
pub fn mean_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    match crate::meta_propagate::reduce_dim(input, dim, keepdim)? {
        Some(propagated) => Ok(propagated),
        None => {
            let shapes = [input.shape()];
            crate::profiler_hook::profile_op_scope("mean_dim", "reduction", &shapes, || {
                mean_dim_inner(input, dim, keepdim)
            })
        }
    }
}
/// Validating, device-dispatching body of [`mean_dim`].
///
/// Rejects 0-D inputs and out-of-range `dim` with `InvalidArgument`. CUDA
/// tensors are reduced on the device (sum-axis then scale for `f32`/`f64`,
/// dedicated mean-axis kernels for `bf16`/`f16`), or fail with
/// `DeviceUnavailable` when no backend exists. CPU tensors are summed with
/// `reduce_axis_sum_contiguous` and divided by the reduced extent. A
/// `MeanDimBackward` node is attached when gradients are enabled and the
/// input requires them.
fn mean_dim_inner<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "mean_dim: cannot reduce a scalar (0-D) tensor along a dimension".into(),
        });
    }

    // Negative dims count from the end. A value that is still negative after
    // the shift wraps to a huge usize and is rejected by the bound check.
    let shifted = if dim < 0 { dim + ndim as i64 } else { dim };
    let norm_dim = shifted as usize;
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "mean_dim: dim {dim} is out of bounds for tensor with {ndim} dimensions"
            ),
        });
    }

    let in_shape = input.shape();
    let dim_size = in_shape[norm_dim];
    // Converted up front, before any device dispatch.
    let divisor = T::from(dim_size).unwrap();
    let out_shape: Vec<usize> = if keepdim {
        in_shape
            .iter()
            .enumerate()
            .map(|(d, &extent)| if d == norm_dim { 1 } else { extent })
            .collect()
    } else {
        in_shape
            .iter()
            .enumerate()
            .filter(|&(d, _)| d != norm_dim)
            .map(|(_, &extent)| extent)
            .collect()
    };

    if input.is_cuda() {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        // From here on `input` is the contiguous tensor; the backward node saves this one.
        let input = input.contiguous()?;
        let mean_handle: crate::gpu_dispatch::GpuBufferHandle = crate::dispatch_floating_dtype!(
            T,
            "mean_dim",
            f32 => {
                let summed = backend.sum_axis_f32(input.gpu_handle()?, in_shape, norm_dim)?;
                Ok::<_, crate::error::FerrotorchError>(backend.scale_f32(&summed, 1.0 / dim_size as f32)?)
            },
            f64 => {
                let summed = backend.sum_axis_f64(input.gpu_handle()?, in_shape, norm_dim)?;
                Ok::<_, crate::error::FerrotorchError>(backend.scale_f64(&summed, 1.0 / dim_size as f64)?)
            },
            bf16 => Ok::<_, crate::error::FerrotorchError>(
                backend.mean_axis_bf16_bf16(input.gpu_handle()?, in_shape, norm_dim)?
            ),
            f16 => Ok::<_, crate::error::FerrotorchError>(
                backend.mean_axis_f16(input.gpu_handle()?, in_shape, norm_dim)?
            ),
        )?;
        let storage = TensorStorage::gpu(mean_handle);
        if !(is_grad_enabled() && input.requires_grad()) {
            return Tensor::from_storage(storage, out_shape, false);
        }
        let node = Arc::new(MeanDimBackward {
            input: input.clone(),
            dim: norm_dim,
            keepdim,
        });
        return Tensor::from_operation(storage, out_shape, node);
    }

    let dense = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = dense.data()?;
    let (outer, axis, inner) = outer_axis_inner(in_shape, norm_dim);
    let mut means = reduce_axis_sum_contiguous(&in_data[..input.numel()], outer, axis, inner);
    for slot in means.iter_mut() {
        *slot = *slot / divisor;
    }

    let storage = TensorStorage::cpu(means);
    let averaged = if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(MeanDimBackward {
            input: input.clone(),
            dim: norm_dim,
            keepdim,
        });
        Tensor::from_operation(storage, out_shape, node)?
    } else {
        Tensor::from_storage(storage, out_shape, false)?
    };
    averaged.to(input.device())
}
/// Converts an `f64` into the tensor dtype `T`.
///
/// # Errors
/// Returns `InvalidArgument` when `NumCast` cannot represent `v` in `T`.
#[inline]
fn float_from_f64<T: Float>(v: f64) -> FerrotorchResult<T> {
    // `ok_or_else` builds the formatted message only on failure; `ok_or`
    // allocated it on every successful conversion as well.
    <T as num_traits::NumCast>::from(v).ok_or_else(|| FerrotorchError::InvalidArgument {
        message: format!("reduction: value {v} not representable in target Float dtype"),
    })
}
/// Converts a value of the tensor dtype `T` into `f64`.
///
/// # Errors
/// Returns `InvalidArgument` when `ToPrimitive::to_f64` yields `None`.
#[inline]
fn to_f64<T: Float>(v: T) -> FerrotorchResult<f64> {
    // `ok_or_else` allocates the error message only on failure; `ok_or`
    // built the `String` on every call.
    <T as num_traits::ToPrimitive>::to_f64(&v).ok_or_else(|| FerrotorchError::InvalidArgument {
        message: "reduction: cannot convert Float to f64".into(),
    })
}
/// Autograd node recorded by the full `logsumexp` reduction.
///
/// Stores the forward input and the forward result; both are needed to form
/// `exp(x - logsumexp(x))`.
#[derive(Debug)]
pub struct LogsumexpBackward<T: Float> {
    input: Tensor<T>,
    result: Tensor<T>,
}

impl<T: Float> GradFn<T> for LogsumexpBackward<T> {
    /// Computes `grad_output[0] * exp(input[i] - result[0])` for every element.
    ///
    /// CPU only: a CUDA input returns `NotImplementedOnCuda`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "logsumexp backward",
            });
        }

        let upstream = grad_output.data()?[0];
        let log_total = self.result.data()?[0];
        let values = self.input.data()?;

        let mut grad_values: Vec<T> = Vec::with_capacity(values.len());
        for &x in values.iter() {
            // exp(x - logsumexp) is the softmax weight of this element.
            grad_values.push(upstream * (x - log_total).exp());
        }

        let shape = self.input.shape().to_vec();
        let grad_input = Tensor::from_storage(TensorStorage::cpu(grad_values), shape, false)?;
        Ok(vec![Some(grad_input)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "LogsumexpBackward"
    }
}
/// Computes `log(sum(exp(input)))` over all elements via
/// `elementwise::logsumexp`.
///
/// If `meta_propagate::reduce_all` yields a tensor it is returned directly
/// (presumably the meta-device shortcut; confirm against that module). A
/// `LogsumexpBackward` node, holding the input and a clone of the result, is
/// attached when gradients are enabled and the input requires them.
pub fn logsumexp<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if let Some(propagated) = crate::meta_propagate::reduce_all(input)? {
        return Ok(propagated);
    }

    let forward = elementwise::logsumexp(input)?;
    let track = is_grad_enabled() && input.requires_grad();
    if !track {
        return Ok(forward);
    }

    let node = Arc::new(LogsumexpBackward {
        input: input.clone(),
        result: forward.clone(),
    });
    let (storage, shape) = forward.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
/// Autograd node recorded by `logsumexp_dim`.
///
/// Stores the forward input, the forward result computed with the reduced
/// dimension kept at size 1, and the normalised reduced dimension.
#[derive(Debug)]
pub struct LogsumexpDimBackward<T: Float> {
    input: Tensor<T>,
    result_keepdim: Tensor<T>,
    dim: usize,
}

impl<T: Float> GradFn<T> for LogsumexpDimBackward<T> {
    /// Computes `grad[o, i] * exp(input[o, a, i] - result[o, i])`, where the
    /// input is viewed as `[outer, axis, inner]` around the reduced dimension
    /// and both the gradient and the saved result as `[outer, inner]`.
    ///
    /// The gradient buffer is read in that flat layout, so it does not matter
    /// whether the forward call kept the reduced dimension. CPU only: a CUDA
    /// input returns `NotImplementedOnCuda`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when any of the three buffers is shorter
    /// than the input shape requires.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "logsumexp_dim backward",
            });
        }
        let input_shape = self.input.shape();
        let input_data = self.input.data()?;
        let result_data = self.result_keepdim.data()?;
        // Borrow the gradient buffer directly; no copy is required.
        let grad_data = grad_output.data()?;

        let outer: usize = input_shape[..self.dim].iter().product();
        let axis = input_shape[self.dim];
        let inner: usize = input_shape[(self.dim + 1)..].iter().product();
        let reduced_len = outer * inner;
        let in_numel = reduced_len * axis;
        if input_data.len() < in_numel
            || result_data.len() < reduced_len
            || grad_data.len() < reduced_len
        {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "logsumexp_dim backward: buffers too short for input shape {:?} \
                     (input {}, result {}, grad_output {})",
                    input_shape,
                    input_data.len(),
                    result_data.len(),
                    grad_data.len()
                ),
            });
        }

        // Walk the input row by row; each [inner]-long row pairs element-wise
        // with the matching row of the reduced result and gradient. This
        // avoids allocating a coordinate vector for every element.
        let mut out = Vec::with_capacity(in_numel);
        for o in 0..outer {
            let reduced_start = o * inner;
            let r_row = &result_data[reduced_start..reduced_start + inner];
            let g_row = &grad_data[reduced_start..reduced_start + inner];
            for a in 0..axis {
                let in_start = (o * axis + a) * inner;
                let x_row = &input_data[in_start..in_start + inner];
                out.extend(
                    x_row
                        .iter()
                        .zip(g_row.iter().zip(r_row.iter()))
                        .map(|(&x, (&g, &r))| g * (x - r).exp()),
                );
            }
        }
        let grad_input =
            Tensor::from_storage(TensorStorage::cpu(out), input_shape.to_vec(), false)?;
        Ok(vec![Some(grad_input)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "LogsumexpDimBackward"
    }
}
/// Computes `log(sum(exp(input)))` along dimension `dim` (negative values
/// count from the end).
///
/// If `meta_propagate::reduce_dim` yields a tensor it is returned directly
/// (presumably the meta-device shortcut; confirm against that module). The
/// reduction is always computed with the dimension kept. When `keepdim` is
/// false the values are copied into a CPU tensor with that dimension removed.
/// A `LogsumexpDimBackward` node holding the keep-dim result is attached when
/// gradients are enabled and the input requires them.
///
/// # Errors
/// Returns `InvalidArgument` for 0-D inputs and out-of-range `dim`.
pub fn logsumexp_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    if let Some(propagated) = crate::meta_propagate::reduce_dim(input, dim, keepdim)? {
        return Ok(propagated);
    }

    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "logsumexp_dim: cannot reduce a 0-D tensor along a dimension".into(),
        });
    }

    // Negative dims count from the end. A value that is still negative after
    // the shift wraps to a huge usize and is rejected by the bound check.
    let shifted = if dim < 0 { dim + ndim as i64 } else { dim };
    let norm_dim = shifted as usize;
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "logsumexp_dim: dim {dim} is out of bounds for tensor with {ndim} dimensions"
            ),
        });
    }

    let result_keepdim = elementwise::logsumexp_dim(input, norm_dim, true)?;

    let output = if keepdim {
        result_keepdim.clone()
    } else {
        // Same values, with the reduced dimension dropped from the shape.
        let squeezed_shape: Vec<usize> = input
            .shape()
            .iter()
            .enumerate()
            .filter(|&(d, _)| d != norm_dim)
            .map(|(_, &extent)| extent)
            .collect();
        let values = result_keepdim.data()?.to_vec();
        Tensor::from_storage(TensorStorage::cpu(values), squeezed_shape, false)?
    };

    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(output);
    }
    let node = Arc::new(LogsumexpDimBackward {
        input: input.clone(),
        result_keepdim,
        dim: norm_dim,
    });
    let (storage, shape) = output.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
/// Returns the flat index of the maximum (`find_max`) or minimum element of
/// `input` as a 0-D `i64` tensor.
///
/// Ties resolve to the first occurrence. NaN is treated as the extreme value
/// in both modes, and with several NaNs the index of the *first* one is
/// returned, consistent with the first-occurrence rule.
///
/// # Errors
/// Returns `NotImplementedOnCuda` for CUDA inputs and `InvalidArgument` for
/// empty inputs.
fn argmax_argmin_full<T: Float>(
    input: &Tensor<T>,
    find_max: bool,
) -> FerrotorchResult<IntTensor<i64>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: if find_max { "argmax" } else { "argmin" },
        });
    }
    let data = input.data()?;
    if data.is_empty() {
        return Err(FerrotorchError::InvalidArgument {
            message: "argmax/argmin: cannot reduce an empty tensor".into(),
        });
    }
    let mut best_idx = 0i64;
    let mut best_val = data[0];
    for (i, &v) in data.iter().enumerate().skip(1) {
        // Once the best value is NaN nothing replaces it, so the first NaN
        // wins. Previously every later NaN moved the index forward.
        let take = !best_val.is_nan()
            && (v.is_nan() || if find_max { v > best_val } else { v < best_val });
        if take {
            best_idx = i as i64;
            best_val = v;
        }
    }
    Ok(IntTensor::<i64>::scalar(best_idx))
}
/// Arg-reduction along one axis, shared by `argmax_dim` and `argmin_dim`.
///
/// For every lane along `dim` this records the position of the largest
/// (`find_max == true`) or smallest (`find_max == false`) element. Equal values keep the
/// earliest position, and a NaN wins its lane at the position of the first NaN.
/// A negative `dim` counts from the last axis; `keepdim` keeps the reduced axis with
/// size 1 instead of removing it. A 0-D input yields the scalar index 0.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs; `InvalidArgument` when `dim` is out of range
/// or when the reduced axis has size 0 (there is no element whose index could be
/// returned).
fn argmax_argmin_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
    find_max: bool,
) -> FerrotorchResult<IntTensor<i64>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: if find_max { "argmax_dim" } else { "argmin_dim" },
        });
    }
    let ndim = input.ndim();
    if ndim == 0 {
        return Ok(IntTensor::<i64>::scalar(0));
    }
    let norm_dim = if dim < 0 {
        // A too-negative `dim` wraps to a huge value and is rejected by the bound check.
        (ndim as i64 + dim) as usize
    } else {
        dim as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "argmax/argmin: dim {dim} is out of bounds for tensor with {ndim} dimensions"
            ),
        });
    }
    // The index arithmetic below assumes a contiguous row-major layout.
    let input_ref = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = input_ref.data()?;
    let in_shape = input_ref.shape();
    let dim_size = in_shape[norm_dim];
    if dim_size == 0 {
        // Without this guard the seed read `in_data[base]` would index an empty buffer
        // and panic whenever the other axes are non-empty.
        return Err(FerrotorchError::InvalidArgument {
            message: format!("argmax/argmin: cannot reduce over dim {dim} because it has size 0"),
        });
    }
    let outer: usize = in_shape[..norm_dim].iter().product();
    let inner: usize = in_shape[norm_dim + 1..].iter().product();
    let mut out = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            let base = o * dim_size * inner + i;
            let mut best_idx = 0i64;
            let mut best_val = in_data[base];
            for d in 1..dim_size {
                // A NaN winner is final; stopping keeps the first NaN's position.
                if best_val.is_nan() {
                    break;
                }
                let v = in_data[base + d * inner];
                let beats_best = if find_max {
                    v > best_val
                } else {
                    v < best_val
                };
                if v.is_nan() || beats_best {
                    best_idx = d as i64;
                    best_val = v;
                }
            }
            out.push(best_idx);
        }
    }
    let mut out_shape: Vec<usize> = in_shape.to_vec();
    if keepdim {
        out_shape[norm_dim] = 1;
    } else {
        out_shape.remove(norm_dim);
    }
    IntTensor::<i64>::from_vec(out, out_shape)
}
/// Returns the position of the maximum over all elements of `input` as a 0-D `i64`
/// tensor. Delegates to `argmax_argmin_full` in "max" mode.
pub fn argmax<T: Float>(input: &Tensor<T>) -> FerrotorchResult<IntTensor<i64>> {
    let find_max = true;
    argmax_argmin_full(input, find_max)
}
/// Returns, for every lane along `dim`, the position of its maximum element.
///
/// `keepdim` keeps the reduced axis with size 1. Delegates to `argmax_argmin_dim` in
/// "max" mode.
pub fn argmax_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<IntTensor<i64>> {
    let find_max = true;
    argmax_argmin_dim(input, dim, keepdim, find_max)
}
/// Returns the position of the minimum over all elements of `input` as a 0-D `i64`
/// tensor. Delegates to `argmax_argmin_full` in "min" mode.
pub fn argmin<T: Float>(input: &Tensor<T>) -> FerrotorchResult<IntTensor<i64>> {
    let find_max = false;
    argmax_argmin_full(input, find_max)
}
/// Returns, for every lane along `dim`, the position of its minimum element.
///
/// `keepdim` keeps the reduced axis with size 1. Delegates to `argmax_argmin_dim` in
/// "min" mode.
pub fn argmin_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<IntTensor<i64>> {
    let find_max = false;
    argmax_argmin_dim(input, dim, keepdim, find_max)
}
/// Backward node for the full-tensor variance computed by `var_inner`.
#[derive(Debug)]
pub struct VarBackward<T: Float> {
/// The tensor that was reduced; its elements are re-read during `backward`.
input: Tensor<T>,
/// Mean of all input elements, as computed by the forward pass.
mean: T,
/// Divisor used by the forward pass: element count minus the correction, clamped at 0.
denom: f64,
}
impl<T: Float> GradFn<T> for VarBackward<T> {
    /// Propagates a scalar upstream gradient through the variance:
    /// `grad_i = grad_output * (2 / denom) * (x_i - mean)`.
    ///
    /// Only the first element of `grad_output` is read. CUDA inputs are rejected with
    /// `NotImplementedOnCuda`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda { op: "var backward" });
        }
        let upstream = grad_output.data()?[0];
        let two_over_denom = float_from_f64::<T>(2.0 / self.denom)?;
        // The factor shared by every element; only the centred value varies.
        let coeff = upstream * two_over_denom;
        let values = self.input.data()?;
        let mut grads: Vec<T> = Vec::with_capacity(values.len());
        for &x in values {
            grads.push(coeff * (x - self.mean));
        }
        let shape = self.input.shape().to_vec();
        let grad_input = Tensor::from_storage(TensorStorage::cpu(grads), shape, false)?;
        Ok(vec![Some(grad_input)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        Vec::from([&self.input])
    }

    /// Name reported for this node.
    fn name(&self) -> &'static str {
        "VarBackward"
    }
}
/// Backward node for the full-tensor standard deviation computed by `var_inner`.
#[derive(Debug)]
pub struct StdBackward<T: Float> {
/// The tensor that was reduced; its elements are re-read during `backward`.
input: Tensor<T>,
/// Mean of all input elements, as computed by the forward pass.
mean: T,
/// Divisor used by the forward pass: element count minus the correction, clamped at 0.
denom: f64,
/// The forward result (the standard deviation); `backward` divides by it.
result: T,
}
impl<T: Float> GradFn<T> for StdBackward<T> {
    /// Propagates a scalar upstream gradient through the standard deviation:
    /// `grad_i = grad_output * (x_i - mean) / (denom * std)`.
    ///
    /// When the stored result is exactly zero the division is undefined, so an all-zero
    /// gradient is returned instead. Only the first element of `grad_output` is read.
    /// CUDA inputs are rejected with `NotImplementedOnCuda`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda { op: "std backward" });
        }
        let upstream = grad_output.data()?[0];
        let values = self.input.data()?;
        let zero = <T as num_traits::Zero>::zero();

        let grads: Vec<T> = if self.result == zero {
            vec![zero; values.len()]
        } else {
            let inv_denom_std = float_from_f64::<T>(1.0 / self.denom)? / self.result;
            let coeff = upstream * inv_denom_std;
            values.iter().map(|&x| coeff * (x - self.mean)).collect()
        };

        let shape = self.input.shape().to_vec();
        let grad_input = Tensor::from_storage(TensorStorage::cpu(grads), shape, false)?;
        Ok(vec![Some(grad_input)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        Vec::from([&self.input])
    }

    /// Name reported for this node.
    fn name(&self) -> &'static str {
        "StdBackward"
    }
}
/// Shared implementation of the full-tensor `var` / `std` family.
///
/// Accumulates in `f64` using two passes (mean first, then squared deviations), divides
/// by `max(n - correction, 0)` and optionally takes the square root. An empty input
/// produces a 0-D NaN tensor with no backward node. When gradients are tracked the
/// result is wired to `StdBackward` or `VarBackward`.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs, plus any error from the meta-propagation
/// hook, the `f64` conversions, or tensor construction.
fn var_inner<T: Float>(
    input: &Tensor<T>,
    correction: f64,
    take_sqrt: bool,
) -> FerrotorchResult<Tensor<T>> {
    if let Some(meta_out) = crate::meta_propagate::reduce_all(input)? {
        return Ok(meta_out);
    }
    if input.is_cuda() {
        let op = if take_sqrt { "std" } else { "var" };
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }

    let data = input.data()?;
    let count = data.len();
    if count == 0 {
        let nan = <T as num_traits::Float>::nan();
        return Tensor::from_storage(TensorStorage::cpu(vec![nan]), vec![], false);
    }
    let denom_f = (count as f64 - correction).max(0.0);

    // Pass 1: mean.
    let total = data
        .iter()
        .try_fold(0.0_f64, |acc, &v| to_f64::<T>(v).map(|x| acc + x))?;
    let mean_f = total / count as f64;

    // Pass 2: sum of squared deviations from the mean.
    let sum_sq = data.iter().try_fold(0.0_f64, |acc, &v| {
        to_f64::<T>(v).map(|x| {
            let centered = x - mean_f;
            acc + centered * centered
        })
    })?;

    let var_f = sum_sq / denom_f;
    let final_f = if take_sqrt { var_f.sqrt() } else { var_f };
    let result_val = float_from_f64::<T>(final_f)?;
    let result = Tensor::from_storage(TensorStorage::cpu(vec![result_val]), vec![], false)?;

    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result);
    }
    let mean_t = float_from_f64::<T>(mean_f)?;
    if take_sqrt {
        let grad_fn = Arc::new(StdBackward {
            input: input.clone(),
            mean: mean_t,
            denom: denom_f,
            result: result_val,
        });
        let (storage, shape) = result.into_storage_and_shape()?;
        Tensor::from_operation(storage, shape, grad_fn)
    } else {
        let grad_fn = Arc::new(VarBackward {
            input: input.clone(),
            mean: mean_t,
            denom: denom_f,
        });
        let (storage, shape) = result.into_storage_and_shape()?;
        Tensor::from_operation(storage, shape, grad_fn)
    }
}
/// Variance over all elements of `input`.
///
/// `unbiased == true` divides by `n - 1`; `false` divides by `n`.
pub fn var<T: Float>(input: &Tensor<T>, unbiased: bool) -> FerrotorchResult<Tensor<T>> {
    let correction = if unbiased { 1.0 } else { 0.0 };
    var_inner(input, correction, false)
}
/// Variance over all elements of `input`, dividing by `max(n - correction, 0)`.
pub fn var_with_correction<T: Float>(
    input: &Tensor<T>,
    correction: f64,
) -> FerrotorchResult<Tensor<T>> {
    let take_sqrt = false;
    var_inner(input, correction, take_sqrt)
}
/// Standard deviation over all elements of `input`.
///
/// `unbiased == true` divides by `n - 1` before the square root; `false` divides by `n`.
pub fn std<T: Float>(input: &Tensor<T>, unbiased: bool) -> FerrotorchResult<Tensor<T>> {
    let correction = if unbiased { 1.0 } else { 0.0 };
    var_inner(input, correction, true)
}
/// Standard deviation over all elements of `input`, using the divisor
/// `max(n - correction, 0)` before the square root.
pub fn std_with_correction<T: Float>(
    input: &Tensor<T>,
    correction: f64,
) -> FerrotorchResult<Tensor<T>> {
    let take_sqrt = true;
    var_inner(input, correction, take_sqrt)
}
/// Forward pass shared by `var_dim` and `std_dim`.
///
/// Reduces every lane along `dim` to its variance (or standard deviation when
/// `take_sqrt` is set), accumulating in `f64` with a two-pass mean / squared-deviation
/// scheme and the divisor `max(dim_size - correction, 0)`.
///
/// Returns `(values, input_shape, normalized_dim, output_shape, dim_size)`.
///
/// # Errors
/// `InvalidArgument` for a 0-D input or an out-of-range `dim`; conversion and
/// contiguity errors are forwarded.
#[allow(
    clippy::type_complexity,
    reason = "single-use forward helper returning (result, in_shape, norm_dim, \
    out_shape, dim_size); a struct adds boilerplate without aiding \
    the two callers (var_dim/std_dim)."
)]
fn std_var_dim_forward<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
    correction: f64,
    take_sqrt: bool,
) -> FerrotorchResult<(Vec<T>, Vec<usize>, usize, Vec<usize>, usize)> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "std_dim/var_dim: cannot reduce a 0-D tensor".into(),
        });
    }
    let norm_dim = if dim >= 0 {
        dim as usize
    } else {
        (ndim as i64 + dim) as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("std_dim/var_dim: dim {dim} out of bounds"),
        });
    }

    // The lane addressing below assumes a contiguous row-major layout.
    let input_ref = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = input_ref.data()?;
    let in_shape = input_ref.shape().to_vec();
    let dim_size = in_shape[norm_dim];
    let outer: usize = in_shape[..norm_dim].iter().product();
    let inner: usize = in_shape[norm_dim + 1..].iter().product();
    let denom = (dim_size as f64 - correction).max(0.0);
    let lane_len = dim_size as f64;

    let mut result = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            // First element of this lane; successive elements are `inner` apart.
            let base = o * dim_size * inner + i;

            let mut lane_sum = 0.0_f64;
            for d in 0..dim_size {
                lane_sum += to_f64::<T>(in_data[base + d * inner])?;
            }
            let mean_f = lane_sum / lane_len;

            let mut sq_dev = 0.0_f64;
            for d in 0..dim_size {
                let centered = to_f64::<T>(in_data[base + d * inner])? - mean_f;
                sq_dev += centered * centered;
            }

            let var_f = sq_dev / denom;
            let reduced = if take_sqrt { var_f.sqrt() } else { var_f };
            result.push(float_from_f64::<T>(reduced)?);
        }
    }

    // Output shape: the input shape with the reduced axis either set to 1 or dropped.
    let mut out_shape: Vec<usize> = Vec::with_capacity(ndim);
    out_shape.extend_from_slice(&in_shape[..norm_dim]);
    if keepdim {
        out_shape.push(1);
    }
    out_shape.extend_from_slice(&in_shape[norm_dim + 1..]);

    Ok((result, in_shape, norm_dim, out_shape, dim_size))
}
/// Variance along `dim`, dividing by `max(dim_size - correction, 0)`.
///
/// `keepdim` keeps the reduced axis with size 1. CUDA inputs are rejected with
/// `NotImplementedOnCuda`.
///
/// NOTE(review): the result is built with `Tensor::from_storage(.., false)` and no
/// backward node is attached in this function — confirm whether gradient tracking for
/// inputs that require grad is handled elsewhere.
pub fn var_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    correction: f64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "var_dim" });
    }
    let take_sqrt = false;
    let (values, _in_shape, _norm_dim, out_shape, _dim_size) =
        std_var_dim_forward(input, dim, keepdim, correction, take_sqrt)?;
    Tensor::from_storage(TensorStorage::cpu(values), out_shape, false)
}
/// Standard deviation along `dim`, using the divisor `max(dim_size - correction, 0)`
/// before the square root.
///
/// `keepdim` keeps the reduced axis with size 1. CUDA inputs are rejected with
/// `NotImplementedOnCuda`.
///
/// NOTE(review): the result is built with `Tensor::from_storage(.., false)` and no
/// backward node is attached in this function — confirm whether gradient tracking for
/// inputs that require grad is handled elsewhere.
pub fn std_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    correction: f64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "std_dim" });
    }
    let take_sqrt = true;
    let (values, _in_shape, _norm_dim, out_shape, _dim_size) =
        std_var_dim_forward(input, dim, keepdim, correction, take_sqrt)?;
    Tensor::from_storage(TensorStorage::cpu(values), out_shape, false)
}
/// Truthiness test used by the boolean reductions: `true` for anything that does not
/// compare equal to zero (which includes NaN).
fn is_nonzero_float<T: Float>(v: T) -> bool {
    let zero = <T as num_traits::Zero>::zero();
    v != zero
}
pub fn any<T: Float>(input: &Tensor<T>) -> FerrotorchResult<BoolTensor> {
if input.is_cuda() {
return Err(FerrotorchError::NotImplementedOnCuda { op: "any" });
}
let data = input.data()?;
let result = data.iter().copied().any(is_nonzero_float::<T>);
BoolTensor::from_vec(vec![result], vec![])
}
pub fn all<T: Float>(input: &Tensor<T>) -> FerrotorchResult<BoolTensor> {
if input.is_cuda() {
return Err(FerrotorchError::NotImplementedOnCuda { op: "all" });
}
let data = input.data()?;
let result = data.iter().copied().all(is_nonzero_float::<T>);
BoolTensor::from_vec(vec![result], vec![])
}
/// Counts the nonzero elements of `input` and returns the total as a 0-D `i64` tensor.
/// CUDA inputs are rejected with `NotImplementedOnCuda`.
pub fn count_nonzero<T: Float>(input: &Tensor<T>) -> FerrotorchResult<IntTensor<i64>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "count_nonzero",
        });
    }
    let values = input.data()?;
    let mut nonzero = 0usize;
    for &v in values {
        if is_nonzero_float(v) {
            nonzero += 1;
        }
    }
    Ok(IntTensor::<i64>::scalar(nonzero as i64))
}
/// Generic boolean reduction along one axis.
///
/// Every lane along `dim` is folded with `fold`, starting from `init`, visiting the
/// lane's elements in increasing index order. A 0-D input is folded once with its only
/// element. `keepdim` keeps the reduced axis with size 1; `op_name` is used in error
/// values.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs and `InvalidArgument` for an out-of-range
/// `dim`.
fn reduce_dim_loop_bool<T, F>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
    init: bool,
    op_name: &'static str,
    fold: F,
) -> FerrotorchResult<BoolTensor>
where
    T: Float,
    F: Fn(bool, T) -> bool,
{
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: op_name });
    }
    let ndim = input.ndim();
    if ndim == 0 {
        let only = input.data()?[0];
        return BoolTensor::from_vec(vec![fold(init, only)], vec![]);
    }
    let norm_dim = if dim >= 0 {
        dim as usize
    } else {
        (ndim as i64 + dim) as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("{op_name}_dim: dim {dim} out of bounds for {ndim}-D tensor"),
        });
    }

    // The lane addressing below assumes a contiguous row-major layout.
    let input_ref = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = input_ref.data()?;
    let in_shape = input_ref.shape();
    let dim_size = in_shape[norm_dim];
    let outer: usize = in_shape[..norm_dim].iter().product();
    let inner: usize = in_shape[norm_dim + 1..].iter().product();

    let mut out = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            let base = o * dim_size * inner + i;
            let reduced = (0..dim_size).fold(init, |acc, d| fold(acc, in_data[base + d * inner]));
            out.push(reduced);
        }
    }

    // Output shape: the input shape with the reduced axis either set to 1 or dropped.
    let mut out_shape: Vec<usize> = Vec::with_capacity(ndim);
    out_shape.extend_from_slice(&in_shape[..norm_dim]);
    if keepdim {
        out_shape.push(1);
    }
    out_shape.extend_from_slice(&in_shape[norm_dim + 1..]);
    BoolTensor::from_vec(out, out_shape)
}
/// For every lane along `dim`, reports whether at least one element is nonzero.
///
/// `keepdim` keeps the reduced axis with size 1.
pub fn any_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<BoolTensor> {
    // Logical OR fold: starts at `false` and turns `true` on the first nonzero value.
    let or_fold = |seen: bool, v: T| seen || is_nonzero_float(v);
    reduce_dim_loop_bool(input, dim, keepdim, false, "any", or_fold)
}
/// For every lane along `dim`, reports whether all elements are nonzero.
///
/// `keepdim` keeps the reduced axis with size 1.
pub fn all_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<BoolTensor> {
    // Logical AND fold: starts at `true` and turns `false` on the first zero value.
    let and_fold = |so_far: bool, v: T| so_far && is_nonzero_float(v);
    reduce_dim_loop_bool(input, dim, keepdim, true, "all", and_fold)
}
/// Counts the nonzero elements of every lane along `dim`.
///
/// The reduced axis is always removed from the output shape. A 0-D input yields a
/// scalar 0 or 1 depending on its only element.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs and `InvalidArgument` for an out-of-range
/// `dim`.
pub fn count_nonzero_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
) -> FerrotorchResult<IntTensor<i64>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "count_nonzero_dim",
        });
    }
    let ndim = input.ndim();
    if ndim == 0 {
        let only = input.data()?[0];
        return Ok(IntTensor::<i64>::scalar(i64::from(is_nonzero_float(only))));
    }
    let norm_dim = if dim >= 0 {
        dim as usize
    } else {
        (ndim as i64 + dim) as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("count_nonzero_dim: dim {dim} out of bounds for {ndim}-D tensor"),
        });
    }

    // The lane addressing below assumes a contiguous row-major layout.
    let input_ref = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = input_ref.data()?;
    let in_shape = input_ref.shape();
    let dim_size = in_shape[norm_dim];
    let outer: usize = in_shape[..norm_dim].iter().product();
    let inner: usize = in_shape[norm_dim + 1..].iter().product();

    let mut out: Vec<i64> = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            let base = o * dim_size * inner + i;
            let mut lane_count: i64 = 0;
            for d in 0..dim_size {
                if is_nonzero_float(in_data[base + d * inner]) {
                    lane_count += 1;
                }
            }
            out.push(lane_count);
        }
    }

    // Output shape: the input shape without the reduced axis.
    let mut out_shape: Vec<usize> = Vec::with_capacity(ndim - 1);
    out_shape.extend_from_slice(&in_shape[..norm_dim]);
    out_shape.extend_from_slice(&in_shape[norm_dim + 1..]);
    IntTensor::<i64>::from_vec(out, out_shape)
}
/// Backward node for `amin_dim`.
#[derive(Debug)]
pub struct AminDimBackward<T: Float> {
/// The tensor that was reduced.
input: Tensor<T>,
/// The per-lane minimum repeated along the reduced axis, so it has the input's shape.
expanded_result: Tensor<T>,
/// Normalized (non-negative) index of the reduced axis.
dim: usize,
/// Whether the forward pass kept the reduced axis; carried along but not needed by the
/// shared backward routine.
keepdim: bool,
}
/// Backward node for `amax_dim`.
#[derive(Debug)]
pub struct AmaxDimBackward<T: Float> {
/// The tensor that was reduced.
input: Tensor<T>,
/// The per-lane maximum repeated along the reduced axis, so it has the input's shape.
expanded_result: Tensor<T>,
/// Normalized (non-negative) index of the reduced axis.
dim: usize,
/// Whether the forward pass kept the reduced axis; carried along but not needed by the
/// shared backward routine.
keepdim: bool,
}
/// Shared backward pass for `amin_dim` / `amax_dim`.
///
/// Within each lane along `dim`, every input element equal to the lane's extreme value
/// receives an equal share of that lane's upstream gradient (`grad / tie_count`); all
/// other elements receive zero. `expanded` must hold the extreme value broadcast to the
/// input's shape. `keepdim` is accepted for signature symmetry and is not used.
///
/// NOTE(review): the indexing assumes `input.data()` is laid out contiguously in
/// row-major order — confirm for inputs that were non-contiguous in the forward pass.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs; data-access and conversion errors are
/// forwarded.
fn amin_amax_dim_backward<T: Float>(
    input: &Tensor<T>,
    expanded: &Tensor<T>,
    grad_output: &Tensor<T>,
    dim: usize,
    keepdim: bool,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "amin/amax_dim backward",
        });
    }
    let input_data = input.data()?;
    let expanded_data = expanded.data()?;
    let in_shape = input.shape();
    let dim_size = in_shape[dim];
    let outer: usize = in_shape[..dim].iter().product();
    let inner: usize = in_shape[dim + 1..].iter().product();

    // Pass 1: how many elements of each lane tie for the extreme value.
    let mut counts = vec![0i64; outer * inner];
    for o in 0..outer {
        for i in 0..inner {
            let base = o * dim_size * inner + i;
            let target = expanded_data[base];
            let ties = (0..dim_size)
                .filter(|&d| input_data[base + d * inner] == target)
                .count();
            counts[o * inner + i] = ties as i64;
        }
    }

    let grad_data = grad_output.data()?;
    let _ = keepdim;
    let zero = <T as num_traits::Zero>::zero();
    let in_numel: usize = in_shape.iter().product();

    // Pass 2: emit gradients in the input's memory order (outer, dim, inner).
    let mut out = Vec::with_capacity(in_numel);
    for o in 0..outer {
        for d in 0..dim_size {
            for i in 0..inner {
                let slot = o * inner + i;
                let target = expanded_data[o * dim_size * inner + i];
                let val = input_data[o * dim_size * inner + d * inner + i];
                if val != target {
                    out.push(zero);
                    continue;
                }
                // `max(1)` keeps the divisor non-zero.
                let tie_count = counts[slot].max(1) as f64;
                let upstream = grad_data[slot];
                let share = float_from_f64::<T>(1.0 / tie_count)?;
                out.push(upstream * share);
            }
        }
    }
    let grad_input = Tensor::from_storage(TensorStorage::cpu(out), in_shape.to_vec(), false)?;
    Ok(vec![Some(grad_input)])
}
impl<T: Float> GradFn<T> for AminDimBackward<T> {
    /// Routes the upstream gradient to the elements that attained the lane minimum,
    /// via the shared `amin_amax_dim_backward` routine.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let Self {
            input,
            expanded_result,
            dim,
            keepdim,
        } = self;
        amin_amax_dim_backward(input, expanded_result, grad_output, *dim, *keepdim)
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        Vec::from([&self.input])
    }

    /// Name reported for this node.
    fn name(&self) -> &'static str {
        "AminDimBackward"
    }
}
impl<T: Float> GradFn<T> for AmaxDimBackward<T> {
    /// Routes the upstream gradient to the elements that attained the lane maximum,
    /// via the shared `amin_amax_dim_backward` routine.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let Self {
            input,
            expanded_result,
            dim,
            keepdim,
        } = self;
        amin_amax_dim_backward(input, expanded_result, grad_output, *dim, *keepdim)
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        Vec::from([&self.input])
    }

    /// Name reported for this node.
    fn name(&self) -> &'static str {
        "AmaxDimBackward"
    }
}
/// Forward pass shared by `amin_dim` and `amax_dim`.
///
/// Reduces every lane along `dim` to its minimum (`find_max == false`) or maximum
/// (`find_max == true`); a NaN anywhere in a lane makes that lane's result NaN.
///
/// Returns `(values, expanded, normalized_dim, output_shape)` where `expanded` repeats
/// each lane's result along the reduced axis so that it has the input's shape (this is
/// what the backward pass compares against).
///
/// # Errors
/// `InvalidArgument` when `dim` is out of range or when the reduced axis has size 0
/// (min/max has no identity element, and the seed read would otherwise index an empty
/// buffer and panic).
#[allow(
    clippy::type_complexity,
    reason = "Single-use helper returning (result, expanded, norm_dim, out_shape); \
    a named tuple struct adds boilerplate without clarifying the local \
    flow at the two callers (amin_dim/amax_dim)."
)]
fn amin_amax_dim_forward<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
    find_max: bool,
) -> FerrotorchResult<(Vec<T>, Vec<T>, usize, Vec<usize>)> {
    let ndim = input.ndim();
    let norm_dim = if dim < 0 {
        // A too-negative `dim` wraps to a huge value and is rejected by the bound check.
        (ndim as i64 + dim) as usize
    } else {
        dim as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("amin/amax_dim: dim {dim} out of bounds for {ndim}-D tensor"),
        });
    }
    // The lane addressing below assumes a contiguous row-major layout.
    let input_ref = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = input_ref.data()?;
    let in_shape = input_ref.shape();
    let dim_size = in_shape[norm_dim];
    if dim_size == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("amin/amax_dim: cannot reduce over dim {dim} because it has size 0"),
        });
    }
    let outer: usize = in_shape[..norm_dim].iter().product();
    let inner: usize = in_shape[norm_dim + 1..].iter().product();

    let mut result = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            let base = o * dim_size * inner + i;
            let mut best = in_data[base];
            for d in 1..dim_size {
                let v = in_data[base + d * inner];
                // NaN always wins; otherwise a strict comparison keeps the first extreme.
                let take = if find_max {
                    v.is_nan() || (!best.is_nan() && v > best)
                } else {
                    v.is_nan() || (!best.is_nan() && v < best)
                };
                if take {
                    best = v;
                }
            }
            result.push(best);
        }
    }

    // Broadcast each outer block's results back along the reduced axis.
    let mut expanded = Vec::with_capacity(in_shape.iter().product());
    for o in 0..outer {
        let block = &result[o * inner..(o + 1) * inner];
        for _ in 0..dim_size {
            expanded.extend_from_slice(block);
        }
    }

    let mut out_shape: Vec<usize> = in_shape.to_vec();
    if keepdim {
        out_shape[norm_dim] = 1;
    } else {
        out_shape.remove(norm_dim);
    }
    Ok((result, expanded, norm_dim, out_shape))
}
/// Minimum of every lane along `dim`, with gradient support.
///
/// `keepdim` keeps the reduced axis with size 1. A 0-D input is returned as a copy made
/// by `grad_clone`.
///
/// NOTE(review): the 0-D path returns a tensor built without a backward node — confirm
/// that dropping gradient tracking there is intended.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs; errors from the forward helper and tensor
/// construction are forwarded.
pub fn amin_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "amin_dim" });
    }
    if input.ndim() == 0 {
        return grad_clone(input);
    }

    let find_max = false;
    let (result, expanded, norm_dim, out_shape) =
        amin_amax_dim_forward(input, dim, keepdim, find_max)?;
    let result_t = Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)?;

    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result_t);
    }
    // The backward node needs the minima broadcast to the input's shape.
    let expanded_t =
        Tensor::from_storage(TensorStorage::cpu(expanded), input.shape().to_vec(), false)?;
    let grad_fn = Arc::new(AminDimBackward {
        input: input.clone(),
        expanded_result: expanded_t,
        dim: norm_dim,
        keepdim,
    });
    let (storage, shape) = result_t.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, grad_fn)
}
/// Maximum of every lane along `dim`, with gradient support.
///
/// `keepdim` keeps the reduced axis with size 1. A 0-D input is returned as a copy made
/// by `grad_clone`.
///
/// NOTE(review): the 0-D path returns a tensor built without a backward node — confirm
/// that dropping gradient tracking there is intended.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs; errors from the forward helper and tensor
/// construction are forwarded.
pub fn amax_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "amax_dim" });
    }
    if input.ndim() == 0 {
        return grad_clone(input);
    }

    let find_max = true;
    let (result, expanded, norm_dim, out_shape) =
        amin_amax_dim_forward(input, dim, keepdim, find_max)?;
    let result_t = Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)?;

    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result_t);
    }
    // The backward node needs the maxima broadcast to the input's shape.
    let expanded_t =
        Tensor::from_storage(TensorStorage::cpu(expanded), input.shape().to_vec(), false)?;
    let grad_fn = Arc::new(AmaxDimBackward {
        input: input.clone(),
        expanded_result: expanded_t,
        dim: norm_dim,
        keepdim,
    });
    let (storage, shape) = result_t.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, grad_fn)
}
/// Copies `input`'s values and shape into a fresh CPU tensor.
///
/// The copy is created through `Tensor::from_storage(.., false)`; no backward node is
/// attached here.
fn grad_clone<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let values = input.data()?.to_vec();
    let shape = input.shape().to_vec();
    let storage = TensorStorage::cpu(values);
    Tensor::from_storage(storage, shape, false)
}
/// Backward node for `prod_dim`.
#[derive(Debug)]
pub struct ProdDimBackward<T: Float> {
/// The tensor that was reduced; its elements are re-read during `backward`.
input: Tensor<T>,
/// Normalized (non-negative) index of the reduced axis.
dim: usize,
/// Whether the forward pass kept the reduced axis; stored but not needed by `backward`.
keepdim: bool,
}
impl<T: Float> GradFn<T> for ProdDimBackward<T> {
    /// Gradient of a product along one axis.
    ///
    /// For each lane, the gradient of element `d` is the upstream gradient times the
    /// product of all *other* elements in the lane. That product is formed from a
    /// prefix product (elements before `d`) and a suffix product (elements after `d`),
    /// so no division is involved and zeros in the input are handled exactly.
    ///
    /// NOTE(review): the indexing assumes `self.input.data()` is laid out contiguously
    /// in row-major order — confirm for inputs that were non-contiguous in the forward
    /// pass.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "prod_dim backward",
            });
        }
        let input_data = self.input.data()?;
        let in_shape = self.input.shape();
        let dim_size = in_shape[self.dim];
        let outer: usize = in_shape[..self.dim].iter().product();
        let inner: usize = in_shape[self.dim + 1..].iter().product();
        let _ = self.keepdim;
        let go_data = grad_output.data()?;
        let one = <T as num_traits::One>::one();
        let mut out = vec![<T as num_traits::Zero>::zero(); in_shape.iter().product()];

        for o in 0..outer {
            for i in 0..inner {
                let base = o * dim_size * inner + i;
                let at = |d: usize| input_data[base + d * inner];

                // suffix[d] = product of the lane's elements strictly after position d.
                let mut suffix = vec![one; dim_size];
                for d in (1..dim_size).rev() {
                    suffix[d - 1] = suffix[d] * at(d);
                }

                let upstream = go_data[o * inner + i];
                // Running product of the lane's elements strictly before position d.
                let mut prefix = one;
                for d in 0..dim_size {
                    out[base + d * inner] = upstream * prefix * suffix[d];
                    prefix = prefix * at(d);
                }
            }
        }
        let grad_input = Tensor::from_storage(TensorStorage::cpu(out), in_shape.to_vec(), false)?;
        Ok(vec![Some(grad_input)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        Vec::from([&self.input])
    }

    /// Name reported for this node.
    fn name(&self) -> &'static str {
        "ProdDimBackward"
    }
}
/// Product of every lane along `dim`, with gradient support.
///
/// A negative `dim` counts from the last axis; `keepdim` keeps the reduced axis with
/// size 1. A lane of length 0 multiplies to one. A 0-D input is returned as a copy made
/// by `grad_clone`.
///
/// NOTE(review): the 0-D path returns a tensor built without a backward node — confirm
/// that dropping gradient tracking there is intended.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs and `InvalidArgument` for an out-of-range
/// `dim`.
pub fn prod_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "prod_dim" });
    }
    let ndim = input.ndim();
    if ndim == 0 {
        return grad_clone(input);
    }
    let norm_dim = if dim >= 0 {
        dim as usize
    } else {
        (ndim as i64 + dim) as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("prod_dim: dim {dim} out of bounds for {ndim}-D tensor"),
        });
    }

    // The lane addressing below assumes a contiguous row-major layout.
    let input_ref = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = input_ref.data()?;
    let in_shape = input_ref.shape();
    let dim_size = in_shape[norm_dim];
    let outer: usize = in_shape[..norm_dim].iter().product();
    let inner: usize = in_shape[norm_dim + 1..].iter().product();
    let one = <T as num_traits::One>::one();

    let mut result = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            let base = o * dim_size * inner + i;
            let lane_product = (0..dim_size).fold(one, |acc, d| acc * in_data[base + d * inner]);
            result.push(lane_product);
        }
    }

    // Output shape: the input shape with the reduced axis either set to 1 or dropped.
    let mut out_shape: Vec<usize> = Vec::with_capacity(ndim);
    out_shape.extend_from_slice(&in_shape[..norm_dim]);
    if keepdim {
        out_shape.push(1);
    }
    out_shape.extend_from_slice(&in_shape[norm_dim + 1..]);

    let result_t = Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)?;
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result_t);
    }
    let grad_fn = Arc::new(ProdDimBackward {
        input: input.clone(),
        dim: norm_dim,
        keepdim,
    });
    let (storage, shape) = result_t.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, grad_fn)
}
/// Backward node shared by the index-selecting reductions in this file
/// (`max_with_dim`, `min_with_dim`, `median_with_dim`, `nanmedian_with_dim`).
#[derive(Debug)]
pub struct MaxMinDimBackward<T: Float> {
/// The tensor that was reduced.
input: Tensor<T>,
/// For each output slot (outer-major, then inner), the position along `dim` of the
/// element that was selected in the forward pass.
indices_flat: Vec<i64>,
/// Normalized (non-negative) index of the reduced axis.
dim: usize,
/// Whether the forward pass kept the reduced axis; stored but not needed by `backward`.
keepdim: bool,
/// Node name reported by `GradFn::name`, set by the constructing operation.
name: &'static str,
}
impl<T: Float> GradFn<T> for MaxMinDimBackward<T> {
    /// Scatters the upstream gradient back to the selected elements.
    ///
    /// The gradient of each output slot is written to the single input position whose
    /// index along `dim` was recorded in the forward pass; every other input position
    /// receives zero.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "max_with_dim/min_with_dim backward",
            });
        }
        let in_shape = self.input.shape();
        let dim_size = in_shape[self.dim];
        let outer: usize = in_shape[..self.dim].iter().product();
        let inner: usize = in_shape[self.dim + 1..].iter().product();
        let go = grad_output.data()?;
        let _ = self.keepdim;
        let zero = <T as num_traits::Zero>::zero();
        let in_numel: usize = in_shape.iter().product();
        let mut out = vec![zero; in_numel];

        // `slot` enumerates output elements in (outer, inner) order. The range is empty
        // when `inner == 0`, so the division below never sees a zero divisor.
        for slot in 0..outer * inner {
            let (o, i) = (slot / inner, slot % inner);
            let d = self.indices_flat[slot] as usize;
            debug_assert!(d < dim_size);
            out[o * dim_size * inner + d * inner + i] = go[slot];
        }

        let grad_input = Tensor::from_storage(TensorStorage::cpu(out), in_shape.to_vec(), false)?;
        Ok(vec![Some(grad_input)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        Vec::from([&self.input])
    }

    /// Name supplied by the operation that created this node.
    fn name(&self) -> &'static str {
        self.name
    }
}
/// Forward pass shared by `max_with_dim` and `min_with_dim`.
///
/// For every lane along `dim` this finds the largest (`find_max == true`) or smallest
/// (`find_max == false`) element and its position. Equal values keep the earliest
/// position; a NaN wins its lane and is reported at the position of the first NaN.
///
/// Returns `(values, indices_flat, indices_tensor, output_shape, normalized_dim)`.
///
/// # Errors
/// `InvalidArgument` for a 0-D input, an out-of-range `dim`, or a reduced axis of size 0
/// (there is no element to select, and the seed read would otherwise index an empty
/// buffer and panic).
#[allow(
    clippy::type_complexity,
    reason = "single-use helper returning (values, indices_flat, indices_int, out_shape, dim); \
    wrapping in a struct adds boilerplate without aiding the two callers."
)]
fn max_min_with_dim_forward<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
    find_max: bool,
) -> FerrotorchResult<(Vec<T>, Vec<i64>, IntTensor<i64>, Vec<usize>, usize)> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "max/min_with_dim: cannot reduce a 0-D tensor along a dimension".into(),
        });
    }
    let norm_dim = if dim < 0 {
        // A too-negative `dim` wraps to a huge value and is rejected by the bound check.
        (ndim as i64 + dim) as usize
    } else {
        dim as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("max/min_with_dim: dim {dim} out of bounds for {ndim}-D tensor"),
        });
    }
    // The lane addressing below assumes a contiguous row-major layout.
    let input_ref = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = input_ref.data()?;
    let in_shape = input_ref.shape();
    let dim_size = in_shape[norm_dim];
    if dim_size == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "max/min_with_dim: cannot reduce over dim {dim} because it has size 0"
            ),
        });
    }
    let outer: usize = in_shape[..norm_dim].iter().product();
    let inner: usize = in_shape[norm_dim + 1..].iter().product();
    let mut values = Vec::with_capacity(outer * inner);
    let mut indices = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            let base = o * dim_size * inner + i;
            let mut best = in_data[base];
            let mut best_idx: i64 = 0;
            for d in 1..dim_size {
                // A NaN winner is final; stopping keeps the first NaN's position.
                if best.is_nan() {
                    break;
                }
                let v = in_data[base + d * inner];
                let beats_best = if find_max { v > best } else { v < best };
                if v.is_nan() || beats_best {
                    best = v;
                    best_idx = d as i64;
                }
            }
            values.push(best);
            indices.push(best_idx);
        }
    }
    let mut out_shape: Vec<usize> = in_shape.to_vec();
    if keepdim {
        out_shape[norm_dim] = 1;
    } else {
        out_shape.remove(norm_dim);
    }
    // The flat copy feeds the backward node; the tensor copy is returned to the caller.
    let indices_int = IntTensor::<i64>::from_vec(indices.clone(), out_shape.clone())?;
    Ok((values, indices, indices_int, out_shape, norm_dim))
}
/// Maximum of every lane along `dim` together with the position of that maximum.
///
/// Returns `(values, indices)`. When gradients are tracked, `values` carries a
/// `MaxMinDimBackward` node named `"MaxDimBackward"`.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs; errors from the forward helper and tensor
/// construction are forwarded.
pub fn max_with_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<(Tensor<T>, IntTensor<i64>)> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "max_with_dim" });
    }
    let find_max = true;
    let (values, indices_flat, indices_int, out_shape, norm_dim) =
        max_min_with_dim_forward(input, dim, keepdim, find_max)?;
    let plain = Tensor::from_storage(TensorStorage::cpu(values), out_shape, false)?;

    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok((plain, indices_int));
    }
    let grad_fn = Arc::new(MaxMinDimBackward {
        input: input.clone(),
        indices_flat,
        dim: norm_dim,
        keepdim,
        name: "MaxDimBackward",
    });
    let (storage, shape) = plain.into_storage_and_shape()?;
    let tracked = Tensor::from_operation(storage, shape, grad_fn)?;
    Ok((tracked, indices_int))
}
/// Minimum of every lane along `dim` together with the position of that minimum.
///
/// Returns `(values, indices)`. When gradients are tracked, `values` carries a
/// `MaxMinDimBackward` node named `"MinDimBackward"`.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs; errors from the forward helper and tensor
/// construction are forwarded.
pub fn min_with_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<(Tensor<T>, IntTensor<i64>)> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "min_with_dim" });
    }
    let find_max = false;
    let (values, indices_flat, indices_int, out_shape, norm_dim) =
        max_min_with_dim_forward(input, dim, keepdim, find_max)?;
    let plain = Tensor::from_storage(TensorStorage::cpu(values), out_shape, false)?;

    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok((plain, indices_int));
    }
    let grad_fn = Arc::new(MaxMinDimBackward {
        input: input.clone(),
        indices_flat,
        dim: norm_dim,
        keepdim,
        name: "MinDimBackward",
    });
    let (storage, shape) = plain.into_storage_and_shape()?;
    let tracked = Tensor::from_operation(storage, shape, grad_fn)?;
    Ok((tracked, indices_int))
}
/// Forward pass shared by `median_with_dim` and `nanmedian_with_dim`.
///
/// For every lane along `dim` this selects the lower median (the element of rank
/// `(n - 1) / 2` in ascending order, ties broken by position) and records its position.
///
/// NaN handling:
/// * `ignore_nan == false`: a lane containing NaN yields NaN at the position of the
///   first NaN.
/// * `ignore_nan == true`: NaNs sort last and are excluded from `n`; a lane consisting
///   only of NaNs yields the NaN of rank `(dim_size - 1) / 2` by position.
///
/// Returns `(values, indices_flat, indices_tensor, output_shape, normalized_dim)`.
///
/// # Errors
/// `InvalidArgument` for a 0-D input, an out-of-range `dim`, or a reduced axis of size 0
/// (the rank computation would otherwise underflow and index an empty ordering).
#[allow(
    clippy::type_complexity,
    reason = "single-use helper returning (values, indices_flat, indices_int, out_shape, dim); \
    matches max_min_with_dim_forward's tuple shape for the two callers."
)]
fn median_with_dim_forward<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
    ignore_nan: bool,
) -> FerrotorchResult<(Vec<T>, Vec<i64>, IntTensor<i64>, Vec<usize>, usize)> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "median/nanmedian_with_dim: cannot reduce a 0-D tensor along a dimension"
                .into(),
        });
    }
    let norm_dim = if dim < 0 {
        // A too-negative `dim` wraps to a huge value and is rejected by the bound check.
        (ndim as i64 + dim) as usize
    } else {
        dim as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "median/nanmedian_with_dim: dim {dim} out of bounds for {ndim}-D tensor"
            ),
        });
    }
    // The lane addressing below assumes a contiguous row-major layout.
    let input_ref = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = input_ref.data()?;
    let in_shape = input_ref.shape();
    let dim_size = in_shape[norm_dim];
    if dim_size == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "median/nanmedian_with_dim: cannot reduce over dim {dim} because it has size 0"
            ),
        });
    }
    let outer: usize = in_shape[..norm_dim].iter().product();
    let inner: usize = in_shape[norm_dim + 1..].iter().product();
    let mut values = Vec::with_capacity(outer * inner);
    let mut indices = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            let base = o * dim_size * inner + i;
            // Gather the lane so it can be ranked by position.
            let lane: Vec<T> = (0..dim_size).map(|d| in_data[base + d * inner]).collect();

            // Plain median: the first NaN poisons the lane.
            let first_nan = if ignore_nan {
                None
            } else {
                lane.iter().position(|v| v.is_nan())
            };
            if let Some(nan_idx) = first_nan {
                values.push(lane[nan_idx]);
                indices.push(nan_idx as i64);
                continue;
            }

            // Rank positions by value; NaNs go last and ties fall back to position, so
            // the ordering is total and an unstable sort gives a deterministic result.
            let mut order: Vec<usize> = (0..dim_size).collect();
            order.sort_unstable_by(|&a, &b| {
                let (va, vb) = (lane[a], lane[b]);
                match (va.is_nan(), vb.is_nan()) {
                    (true, true) => a.cmp(&b),
                    (true, false) => std::cmp::Ordering::Greater,
                    (false, true) => std::cmp::Ordering::Less,
                    (false, false) => match va.partial_cmp(&vb) {
                        Some(std::cmp::Ordering::Equal) | None => a.cmp(&b),
                        Some(ord) => ord,
                    },
                }
            });

            let num_nan = if ignore_nan {
                lane.iter().filter(|v| v.is_nan()).count()
            } else {
                0
            };
            let effective = dim_size - num_nan;
            // All-NaN lane: fall back to the middle of the whole lane.
            let rank = if effective == 0 {
                (dim_size - 1) / 2
            } else {
                (effective - 1) / 2
            };
            let median_local = order[rank];
            values.push(lane[median_local]);
            indices.push(median_local as i64);
        }
    }
    let mut out_shape: Vec<usize> = in_shape.to_vec();
    if keepdim {
        out_shape[norm_dim] = 1;
    } else {
        out_shape.remove(norm_dim);
    }
    // The flat copy feeds the backward node; the tensor copy is returned to the caller.
    let indices_int = IntTensor::<i64>::from_vec(indices.clone(), out_shape.clone())?;
    Ok((values, indices, indices_int, out_shape, norm_dim))
}
/// Lower median of every lane along `dim` together with its position; a lane that
/// contains NaN yields NaN.
///
/// Returns `(values, indices)`. When gradients are tracked, `values` carries a
/// `MaxMinDimBackward` node named `"MedianDimBackward"`.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs; errors from the forward helper and tensor
/// construction are forwarded.
pub fn median_with_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<(Tensor<T>, IntTensor<i64>)> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "median_with_dim",
        });
    }
    let ignore_nan = false;
    let (values, indices_flat, indices_int, out_shape, norm_dim) =
        median_with_dim_forward(input, dim, keepdim, ignore_nan)?;
    let plain = Tensor::from_storage(TensorStorage::cpu(values), out_shape, false)?;

    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok((plain, indices_int));
    }
    let grad_fn = Arc::new(MaxMinDimBackward {
        input: input.clone(),
        indices_flat,
        dim: norm_dim,
        keepdim,
        name: "MedianDimBackward",
    });
    let (storage, shape) = plain.into_storage_and_shape()?;
    let tracked = Tensor::from_operation(storage, shape, grad_fn)?;
    Ok((tracked, indices_int))
}
/// Lower median of every lane along `dim`, ignoring NaN values, together with its
/// position.
///
/// Returns `(values, indices)`. When gradients are tracked, `values` carries a
/// `MaxMinDimBackward` node named `"NanmedianDimBackward"`.
///
/// # Errors
/// `NotImplementedOnCuda` for CUDA inputs; errors from the forward helper and tensor
/// construction are forwarded.
pub fn nanmedian_with_dim<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<(Tensor<T>, IntTensor<i64>)> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "nanmedian_with_dim",
        });
    }
    let ignore_nan = true;
    let (values, indices_flat, indices_int, out_shape, norm_dim) =
        median_with_dim_forward(input, dim, keepdim, ignore_nan)?;
    let plain = Tensor::from_storage(TensorStorage::cpu(values), out_shape, false)?;

    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok((plain, indices_int));
    }
    let grad_fn = Arc::new(MaxMinDimBackward {
        input: input.clone(),
        indices_flat,
        dim: norm_dim,
        keepdim,
        name: "NanmedianDimBackward",
    });
    let (storage, shape) = plain.into_storage_and_shape()?;
    let tracked = Tensor::from_operation(storage, shape, grad_fn)?;
    Ok((tracked, indices_int))
}
/// Backward node for the p-norm reduction along one axis.
#[derive(Debug)]
pub struct NormDimBackward<T: Float> {
/// The tensor that was reduced; its elements are re-read during `backward`.
input: Tensor<T>,
/// Exponent of the norm.
p: f64,
/// Forward result, read by `backward` with one value per (outer, inner) slot.
result_keepdim: Tensor<T>,
/// Normalized (non-negative) index of the reduced axis.
dim: usize,
}
impl<T: Float> GradFn<T> for NormDimBackward<T> {
    /// Gradient of the p-norm along one axis, computed in `f64`:
    ///
    /// `grad_i = grad_output * sign(x_i) * |x_i|^(p - 1) * norm^(1 - p)`
    ///
    /// Two cases are short-circuited to a zero gradient:
    /// * lanes whose norm is exactly 0 (the formula would divide by zero), and
    /// * elements that are exactly 0: their sign factor is 0, but for `p < 1` the power
    ///   term `0^(p - 1)` is infinite and `inf * 0` would otherwise produce NaN.
    ///
    /// NaN inputs are not short-circuited and still propagate NaN.
    ///
    /// NOTE(review): the indexing assumes `self.input.data()` is laid out contiguously
    /// in row-major order — confirm for inputs that were non-contiguous in the forward
    /// pass.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "norm_with_dim backward",
            });
        }
        let in_shape = self.input.shape();
        let dim_size = in_shape[self.dim];
        let outer: usize = in_shape[..self.dim].iter().product();
        let inner: usize = in_shape[self.dim + 1..].iter().product();
        let in_data = self.input.data()?;
        let go = grad_output.data()?;
        let res = self.result_keepdim.data()?;
        let p = self.p;
        let zero = <T as num_traits::Zero>::zero();
        let in_numel: usize = in_shape.iter().product();
        // Pre-filled with zeros so skipped lanes and elements need no explicit write.
        let mut out = vec![zero; in_numel];
        for o in 0..outer {
            for i in 0..inner {
                let slot = o * inner + i;
                let r_f = to_f64::<T>(res[slot])?;
                if r_f == 0.0 {
                    continue;
                }
                let g_f = to_f64::<T>(go[slot])?;
                // norm^(1 - p) is shared by every element of the lane.
                let scale_pow = r_f.powf(1.0 - p);
                for d in 0..dim_size {
                    let flat = o * dim_size * inner + d * inner + i;
                    let xf = to_f64::<T>(in_data[flat])?;
                    if xf == 0.0 {
                        // sign(0) = 0: the gradient is exactly zero, and skipping here
                        // avoids `inf * 0 = NaN` when p < 1.
                        continue;
                    }
                    let sign = if xf > 0.0 {
                        1.0
                    } else if xf < 0.0 {
                        -1.0
                    } else {
                        // Only reachable for NaN; the NaN power term still propagates.
                        0.0
                    };
                    let grad_xf = g_f * xf.abs().powf(p - 1.0) * sign * scale_pow;
                    out[flat] = float_from_f64::<T>(grad_xf)?;
                }
            }
        }
        let grad_input = Tensor::from_storage(TensorStorage::cpu(out), in_shape.to_vec(), false)?;
        Ok(vec![Some(grad_input)])
    }

    /// The single tensor this node differentiates with respect to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Name reported for this node.
    fn name(&self) -> &'static str {
        "NormDimBackward"
    }
}
/// Computes the vector p-norm `(sum |x|^p)^(1/p)` of `input` along one dimension.
///
/// # Arguments
///
/// * `input` - CPU tensor to reduce; non-contiguous inputs are made contiguous first.
/// * `p` - norm order; must be finite and strictly positive.
/// * `dim` - dimension to reduce; negative values count back from the last dimension.
/// * `keepdim` - when `true` the reduced dimension is kept with size 1, otherwise
///   it is removed from the output shape.
///
/// When autograd is enabled and `input` requires grad, the result carries a
/// [`NormDimBackward`] node.
///
/// # Errors
///
/// * [`FerrotorchError::NotImplementedOnCuda`] when `input` is on a CUDA device.
/// * [`FerrotorchError::InvalidArgument`] when `p` is not finite and positive,
///   when `input` is 0-D, or when `dim` is out of range.
/// * Any error propagated from tensor access, numeric conversion or tensor
///   construction.
pub fn norm_with_dim<T: Float>(
    input: &Tensor<T>,
    p: f64,
    dim: i64,
    keepdim: bool,
) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "norm_with_dim",
        });
    }
    if !(p.is_finite() && p > 0.0) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "norm_with_dim: p must be finite and > 0; got {p}. Other norms (inf, 0) \
                 are non-differentiable / piecewise and tracked separately."
            ),
        });
    }
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "norm_with_dim: cannot reduce a 0-D tensor along a dimension".into(),
        });
    }
    // A `dim` below `-ndim` yields a negative sum whose cast to `usize` is a huge
    // value, so the range check right after rejects it as well.
    let norm_dim = if dim < 0 {
        (ndim as i64 + dim) as usize
    } else {
        dim as usize
    };
    if norm_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("norm_with_dim: dim {dim} out of bounds for {ndim}-D tensor"),
        });
    }

    // The flat index arithmetic below requires a dense row-major buffer.
    let input_ref = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()?
    };
    let in_data = input_ref.data()?;
    let in_shape = input_ref.shape();
    let dim_size = in_shape[norm_dim];
    let outer: usize = in_shape[..norm_dim].iter().product();
    let inner: usize = in_shape[norm_dim + 1..].iter().product();
    let mut result_keepdim_data = Vec::with_capacity(outer * inner);

    let t_is_f32 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>();
    #[allow(
        clippy::float_cmp,
        reason = "exact `p == 2.0` mirrors torch's norm-kernel dispatch `if (val == 2.0)` at ReduceOpsKernel.cpp:195 — the L2 vectorized path is selected by exact equality, not an epsilon band; a margin compare would mis-route p values near 2.0 that torch routes to the generic NormOps path"
    )]
    let is_l2_lastdim_f32 = p == 2.0 && t_is_f32 && inner == 1;

    if is_l2_lastdim_f32 {
        // One scratch row reused for every slice instead of a fresh allocation each time.
        let mut row: Vec<f32> = Vec::with_capacity(dim_size);
        for o in 0..outer {
            let slice_start = o * dim_size;
            row.clear();
            for v in &in_data[slice_start..slice_start + dim_size] {
                // `ok_or_else` keeps the error (and its String) from being built
                // for every successfully converted element.
                row.push(num_traits::ToPrimitive::to_f32(v).ok_or_else(|| {
                    FerrotorchError::InvalidArgument {
                        message: "norm_with_dim: f32 element not representable".into(),
                    }
                })?);
            }
            let norm_f32 = crate::simd_reduce::l2_norm_f32_torch(&row);
            result_keepdim_data.push(float_from_f64::<T>(f64::from(norm_f32))?);
        }
    } else {
        for o in 0..outer {
            for i in 0..inner {
                let base = o * dim_size * inner + i;
                let mut acc = 0.0_f64;
                for d in 0..dim_size {
                    let v = to_f64::<T>(in_data[base + d * inner])?;
                    acc += v.abs().powf(p);
                }
                let r = acc.powf(1.0 / p);
                result_keepdim_data.push(float_from_f64::<T>(r)?);
            }
        }
    }

    let mut keepdim_shape: Vec<usize> = in_shape.to_vec();
    keepdim_shape[norm_dim] = 1;
    let mut out_shape = keepdim_shape.clone();
    if !keepdim {
        out_shape.remove(norm_dim);
    }

    if is_grad_enabled() && input.requires_grad() {
        // Only the backward node needs the keepdim-shaped copy of the result, so
        // the extra buffer clone is paid on this branch alone.
        let result_keepdim = Tensor::from_storage(
            TensorStorage::cpu(result_keepdim_data.clone()),
            keepdim_shape,
            false,
        )?;
        let final_result =
            Tensor::from_storage(TensorStorage::cpu(result_keepdim_data), out_shape, false)?;
        let grad_fn = Arc::new(NormDimBackward {
            // Keep the caller's tensor (not the contiguous copy) so gradients
            // are routed to the tensor the caller actually holds.
            input: input.clone(),
            p,
            result_keepdim,
            dim: norm_dim,
        });
        let (s, sh) = final_result.into_storage_and_shape()?;
        Tensor::from_operation(s, sh, grad_fn)
    } else {
        Tensor::from_storage(TensorStorage::cpu(result_keepdim_data), out_shape, false)
    }
}
// Unit tests for the reductions in this file: full reductions (`sum`, `mean`,
// `prod`), per-dimension reductions (`sum_dim`, `mean_dim`) and the
// median / nanmedian variants. All tests use f64 tensors.
#[cfg(test)]
mod tests {
use super::*;
use crate::autograd::no_grad::no_grad;
use crate::storage::TensorStorage;
// Builds an f64 tensor from `data` with the given `shape` and grad flag.
fn leaf(data: &[f64], shape: &[usize], requires_grad: bool) -> Tensor<f64> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
shape.to_vec(),
requires_grad,
)
.unwrap()
}
// Builds a 0-D (empty shape) f64 tensor holding `val`.
fn leaf_scalar(val: f64, requires_grad: bool) -> Tensor<f64> {
Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], requires_grad).unwrap()
}
// ---- Forward values of the full reductions ----
// sum of [1, 2, 3, 4] is the scalar 10.
#[test]
fn test_sum_forward_1d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let s = sum(&x).unwrap();
assert!(s.is_scalar());
assert!((s.item().unwrap() - 10.0).abs() < 1e-12);
}
// sum over every element of a 2x3 tensor is 21.
#[test]
fn test_sum_forward_2d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let s = sum(&x).unwrap();
assert!((s.item().unwrap() - 21.0).abs() < 1e-12);
}
// mean of [2, 4, 6, 8] is 5.
#[test]
fn test_mean_forward() {
let x = leaf(&[2.0, 4.0, 6.0, 8.0], &[4], false);
let m = mean(&x).unwrap();
assert!((m.item().unwrap() - 5.0).abs() < 1e-12);
}
// prod of [2, 3, 4] is 24.
#[test]
fn test_prod_forward() {
let x = leaf(&[2.0, 3.0, 4.0], &[3], false);
let p = prod(&x).unwrap();
assert!((p.item().unwrap() - 24.0).abs() < 1e-12);
}
// prod of a 0-D tensor returns its single value.
#[test]
fn test_prod_forward_scalar() {
let x = leaf_scalar(7.0, false);
let p = prod(&x).unwrap();
assert!((p.item().unwrap() - 7.0).abs() < 1e-12);
}
// A zero element makes the product zero.
#[test]
fn test_prod_forward_with_zero() {
let x = leaf(&[3.0, 0.0, 5.0], &[3], false);
let p = prod(&x).unwrap();
assert!((p.item().unwrap()).abs() < 1e-12);
}
// ---- Backward of the full reductions ----
// d(sum)/dx for a 0-D input is 1.
#[test]
fn test_sum_backward_scalar_input() {
let x = leaf_scalar(5.0, true);
let s = sum(&x).unwrap();
s.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert!((g.item().unwrap() - 1.0).abs() < 1e-12);
}
// Every element of a 1-D input gets gradient 1 from sum.
#[test]
fn test_sum_backward_1d() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let s = sum(&x).unwrap();
s.backward().unwrap();
let g = x.grad().unwrap().unwrap();
let gd = g.data().unwrap();
assert_eq!(gd.len(), 3);
for &v in gd {
assert!((v - 1.0).abs() < 1e-12);
}
}
// The sum gradient keeps the 2x3 input shape and is all ones.
#[test]
fn test_sum_backward_2d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let s = sum(&x).unwrap();
s.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert_eq!(g.shape(), &[2, 3]);
for &v in g.data().unwrap() {
assert!((v - 1.0).abs() < 1e-12);
}
}
// d(mean)/dx for a 0-D input is 1.
#[test]
fn test_mean_backward_scalar_input() {
let x = leaf_scalar(5.0, true);
let m = mean(&x).unwrap();
m.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert!((g.item().unwrap() - 1.0).abs() < 1e-12);
}
// Each of 3 elements gets gradient 1/3 from mean.
#[test]
fn test_mean_backward_1d() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let m = mean(&x).unwrap();
m.backward().unwrap();
let g = x.grad().unwrap().unwrap();
let gd = g.data().unwrap();
let expected = 1.0 / 3.0;
for &v in gd {
assert!((v - expected).abs() < 1e-12);
}
}
// Each of 6 elements gets gradient 1/6 from mean; shape is preserved.
#[test]
fn test_mean_backward_2d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let m = mean(&x).unwrap();
m.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert_eq!(g.shape(), &[2, 3]);
let expected = 1.0 / 6.0;
for &v in g.data().unwrap() {
assert!((v - expected).abs() < 1e-12);
}
}
// d(prod)/dx for a 0-D input is 1.
#[test]
fn test_prod_backward_scalar_input() {
let x = leaf_scalar(5.0, true);
let p = prod(&x).unwrap();
p.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert!((g.item().unwrap() - 1.0).abs() < 1e-12);
}
// Each element's gradient is the product of the other elements.
#[test]
fn test_prod_backward_1d() {
let x = leaf(&[2.0, 3.0, 4.0], &[3], true);
let p = prod(&x).unwrap();
p.backward().unwrap();
let g = x.grad().unwrap().unwrap();
let gd = g.data().unwrap();
assert!(
(gd[0] - 12.0).abs() < 1e-12,
"d/da = 3*4 = 12, got {}",
gd[0]
);
assert!((gd[1] - 8.0).abs() < 1e-12, "d/db = 2*4 = 8, got {}", gd[1]);
assert!((gd[2] - 6.0).abs() < 1e-12, "d/dc = 2*3 = 6, got {}", gd[2]);
}
// With exactly one zero, only the zero element has a non-zero gradient (3*5 = 15).
#[test]
fn test_prod_backward_with_zero() {
let x = leaf(&[3.0, 0.0, 5.0], &[3], true);
let p = prod(&x).unwrap();
p.backward().unwrap();
let g = x.grad().unwrap().unwrap();
let gd = g.data().unwrap();
assert!((gd[0] - 0.0).abs() < 1e-12, "got {}", gd[0]);
assert!((gd[1] - 15.0).abs() < 1e-12, "got {}", gd[1]);
assert!((gd[2] - 0.0).abs() < 1e-12, "got {}", gd[2]);
}
// With two zeros every gradient entry is zero.
#[test]
fn test_prod_backward_two_zeros() {
let x = leaf(&[0.0, 0.0, 5.0], &[3], true);
let p = prod(&x).unwrap();
p.backward().unwrap();
let g = x.grad().unwrap().unwrap();
let gd = g.data().unwrap();
for &v in gd {
assert!((v).abs() < 1e-12, "expected 0, got {v}");
}
}
// ---- Graph bookkeeping: when a grad_fn is (not) attached ----
// No grad_fn and no requires_grad when the input is not tracked.
#[test]
fn test_sum_no_grad_fn_when_input_not_requires_grad() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
let s = sum(&x).unwrap();
assert!(s.grad_fn().is_none());
assert!(!s.requires_grad());
}
// A tracked input yields a "SumBackward" node and a tracked output.
#[test]
fn test_sum_has_grad_fn_when_input_requires_grad() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let s = sum(&x).unwrap();
assert!(s.grad_fn().is_some());
assert_eq!(s.grad_fn().unwrap().name(), "SumBackward");
assert!(s.requires_grad());
}
// A tracked input yields a "MeanBackward" node.
#[test]
fn test_mean_has_grad_fn_when_input_requires_grad() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let m = mean(&x).unwrap();
assert!(m.grad_fn().is_some());
assert_eq!(m.grad_fn().unwrap().name(), "MeanBackward");
}
// A tracked input yields a "ProdBackward" node.
#[test]
fn test_prod_has_grad_fn_when_input_requires_grad() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let p = prod(&x).unwrap();
assert!(p.grad_fn().is_some());
assert_eq!(p.grad_fn().unwrap().name(), "ProdBackward");
}
// Inside `no_grad` a tracked input still produces an untracked sum.
#[test]
fn test_sum_no_grad_fn_in_no_grad_context() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let s = no_grad(|| sum(&x)).unwrap();
assert!(s.grad_fn().is_none());
assert!(!s.requires_grad());
}
// Inside `no_grad` mean attaches no grad_fn.
#[test]
fn test_mean_no_grad_fn_in_no_grad_context() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let m = no_grad(|| mean(&x)).unwrap();
assert!(m.grad_fn().is_none());
}
// Inside `no_grad` prod attaches no grad_fn.
#[test]
fn test_prod_no_grad_fn_in_no_grad_context() {
let x = leaf(&[2.0, 3.0], &[2], true);
let p = no_grad(|| prod(&x)).unwrap();
assert!(p.grad_fn().is_none());
}
// ---- Numerical gradient checks on 0-D inputs ----
// Compares a central finite difference of `f` at `x_val` (step 1e-7) with
// `expected_analytic`, panicking if they differ by `tol` or more.
fn numerical_grad_check(
f: impl Fn(&Tensor<f64>) -> FerrotorchResult<Tensor<f64>>,
x_val: f64,
expected_analytic: f64,
tol: f64,
) {
let eps = 1e-7;
let x_plus = leaf_scalar(x_val + eps, false);
let x_minus = leaf_scalar(x_val - eps, false);
let f_plus = f(&x_plus).unwrap().item().unwrap();
let f_minus = f(&x_minus).unwrap().item().unwrap();
let numerical = (f_plus - f_minus) / (2.0 * eps);
assert!(
(numerical - expected_analytic).abs() < tol,
"numerical gradient {numerical} differs from analytic {expected_analytic} by more than {tol}"
);
}
// Analytic sum gradient at 3.0 matches the finite difference.
#[test]
fn test_sum_numerical_gradient() {
let x = leaf_scalar(3.0, true);
let s = sum(&x).unwrap();
s.backward().unwrap();
let analytic = x.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(sum, 3.0, analytic, 1e-5);
}
// Analytic mean gradient at 3.0 matches the finite difference.
#[test]
fn test_mean_numerical_gradient() {
let x = leaf_scalar(3.0, true);
let m = mean(&x).unwrap();
m.backward().unwrap();
let analytic = x.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(mean, 3.0, analytic, 1e-5);
}
// Analytic prod gradient at 3.0 matches the finite difference.
#[test]
fn test_prod_numerical_gradient() {
let x = leaf_scalar(3.0, true);
let p = prod(&x).unwrap();
p.backward().unwrap();
let analytic = x.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(prod, 3.0, analytic, 1e-5);
}
// ---- sum_dim forward ----
// Reducing dim 0 of a 2x3 tensor gives the column sums [5, 7, 9].
#[test]
fn test_sum_dim_axis0_2d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let s = sum_dim(&x, 0, false).unwrap();
assert_eq!(s.shape(), &[3]);
let d = s.data().unwrap();
assert!((d[0] - 5.0).abs() < 1e-12);
assert!((d[1] - 7.0).abs() < 1e-12);
assert!((d[2] - 9.0).abs() < 1e-12);
}
// Reducing dim 1 of a 2x3 tensor gives the row sums [6, 15].
#[test]
fn test_sum_dim_axis1_2d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let s = sum_dim(&x, 1, false).unwrap();
assert_eq!(s.shape(), &[2]);
let d = s.data().unwrap();
assert!((d[0] - 6.0).abs() < 1e-12);
assert!((d[1] - 15.0).abs() < 1e-12);
}
// keepdim = true keeps the reduced dim with size 1 (shape [1, 3]).
#[test]
fn test_sum_dim_keepdim_true() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let s = sum_dim(&x, 0, true).unwrap();
assert_eq!(s.shape(), &[1, 3]);
let d = s.data().unwrap();
assert!((d[0] - 5.0).abs() < 1e-12);
assert!((d[1] - 7.0).abs() < 1e-12);
assert!((d[2] - 9.0).abs() < 1e-12);
}
// dim = -1 on a 2x3 tensor gives the same result as dim = 1.
#[test]
fn test_sum_dim_negative_dim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let s = sum_dim(&x, -1, false).unwrap();
assert_eq!(s.shape(), &[2]);
let d = s.data().unwrap();
assert!((d[0] - 6.0).abs() < 1e-12);
assert!((d[1] - 15.0).abs() < 1e-12);
}
// Reducing the only dim of a 1-D tensor without keepdim yields a scalar.
#[test]
fn test_sum_dim_1d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let s = sum_dim(&x, 0, false).unwrap();
assert!(s.is_scalar());
assert!((s.item().unwrap() - 10.0).abs() < 1e-12);
}
// Same reduction with keepdim yields shape [1].
#[test]
fn test_sum_dim_1d_keepdim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let s = sum_dim(&x, 0, true).unwrap();
assert_eq!(s.shape(), &[1]);
assert!((s.data().unwrap()[0] - 10.0).abs() < 1e-12);
}
// Reducing the middle dim of a 2x2x3 tensor filled with 1..=12.
#[test]
fn test_sum_dim_3d() {
let data: Vec<f64> = (1..=12).map(|x| x as f64).collect();
let x = leaf(&data, &[2, 2, 3], false);
let s = sum_dim(&x, 1, false).unwrap();
assert_eq!(s.shape(), &[2, 3]);
let d = s.data().unwrap();
assert!((d[0] - 5.0).abs() < 1e-12);
assert!((d[1] - 7.0).abs() < 1e-12);
assert!((d[2] - 9.0).abs() < 1e-12);
assert!((d[3] - 17.0).abs() < 1e-12);
assert!((d[4] - 19.0).abs() < 1e-12);
assert!((d[5] - 21.0).abs() < 1e-12);
}
// ---- sum_dim backward and graph bookkeeping ----
// sum(sum_dim(x, 0)) gives an all-ones gradient of the input shape.
#[test]
fn test_sum_dim_backward_axis0_no_keepdim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let s = sum_dim(&x, 0, false).unwrap();
let loss = sum(&s).unwrap();
loss.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert_eq!(g.shape(), &[2, 3]);
for &v in g.data().unwrap() {
assert!((v - 1.0).abs() < 1e-12, "expected 1.0, got {v}");
}
}
// Same check through the keepdim = true path along dim 1.
#[test]
fn test_sum_dim_backward_axis1_keepdim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let s = sum_dim(&x, 1, true).unwrap();
assert_eq!(s.shape(), &[2, 1]);
let loss = sum(&s).unwrap();
loss.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert_eq!(g.shape(), &[2, 3]);
for &v in g.data().unwrap() {
assert!((v - 1.0).abs() < 1e-12, "expected 1.0, got {v}");
}
}
// A tracked input yields a "SumDimBackward" node.
#[test]
fn test_sum_dim_has_grad_fn() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let s = sum_dim(&x, 0, false).unwrap();
assert!(s.grad_fn().is_some());
assert_eq!(s.grad_fn().unwrap().name(), "SumDimBackward");
}
// An untracked input yields no grad_fn.
#[test]
fn test_sum_dim_no_grad_fn_when_not_requires_grad() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
let s = sum_dim(&x, 0, false).unwrap();
assert!(s.grad_fn().is_none());
}
// Inside `no_grad` sum_dim attaches no grad_fn.
#[test]
fn test_sum_dim_no_grad_fn_in_no_grad_context() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let s = no_grad(|| sum_dim(&x, 0, false)).unwrap();
assert!(s.grad_fn().is_none());
}
// ---- mean_dim forward ----
// Column means of a 2x3 tensor are [2.5, 3.5, 4.5].
#[test]
fn test_mean_dim_axis0_2d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let m = mean_dim(&x, 0, false).unwrap();
assert_eq!(m.shape(), &[3]);
let d = m.data().unwrap();
assert!((d[0] - 2.5).abs() < 1e-12);
assert!((d[1] - 3.5).abs() < 1e-12);
assert!((d[2] - 4.5).abs() < 1e-12);
}
// Row means of a 2x3 tensor are [2, 5].
#[test]
fn test_mean_dim_axis1_2d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let m = mean_dim(&x, 1, false).unwrap();
assert_eq!(m.shape(), &[2]);
let d = m.data().unwrap();
assert!((d[0] - 2.0).abs() < 1e-12);
assert!((d[1] - 5.0).abs() < 1e-12);
}
// keepdim = true keeps the reduced dim with size 1 (shape [1, 3]).
#[test]
fn test_mean_dim_keepdim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let m = mean_dim(&x, 0, true).unwrap();
assert_eq!(m.shape(), &[1, 3]);
let d = m.data().unwrap();
assert!((d[0] - 2.5).abs() < 1e-12);
assert!((d[1] - 3.5).abs() < 1e-12);
assert!((d[2] - 4.5).abs() < 1e-12);
}
// dim = -1 on a 2x3 tensor gives the same result as dim = 1.
#[test]
fn test_mean_dim_negative_dim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let m = mean_dim(&x, -1, false).unwrap();
assert_eq!(m.shape(), &[2]);
let d = m.data().unwrap();
assert!((d[0] - 2.0).abs() < 1e-12);
assert!((d[1] - 5.0).abs() < 1e-12);
}
// ---- mean_dim backward and graph bookkeeping ----
// Averaging over 2 rows gives every input element gradient 1/2.
#[test]
fn test_mean_dim_backward_axis0() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let m = mean_dim(&x, 0, false).unwrap();
let loss = sum(&m).unwrap();
loss.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert_eq!(g.shape(), &[2, 3]);
let expected = 1.0 / 2.0;
for &v in g.data().unwrap() {
assert!((v - expected).abs() < 1e-12, "expected {expected}, got {v}");
}
}
// Averaging over 3 columns (keepdim) gives every input element gradient 1/3.
#[test]
fn test_mean_dim_backward_axis1_keepdim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let m = mean_dim(&x, 1, true).unwrap();
assert_eq!(m.shape(), &[2, 1]);
let loss = sum(&m).unwrap();
loss.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert_eq!(g.shape(), &[2, 3]);
let expected = 1.0 / 3.0;
for &v in g.data().unwrap() {
assert!((v - expected).abs() < 1e-12, "expected {expected}, got {v}");
}
}
// A tracked input yields a "MeanDimBackward" node.
#[test]
fn test_mean_dim_has_grad_fn() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let m = mean_dim(&x, 0, false).unwrap();
assert!(m.grad_fn().is_some());
assert_eq!(m.grad_fn().unwrap().name(), "MeanDimBackward");
}
// ---- median_with_dim / nanmedian_with_dim ----
// Odd length: median of [3, 1, 2] is 2.0, found at index 2.
#[test]
fn test_median_with_dim_odd_lower_median() {
let x = leaf(&[3.0, 1.0, 2.0], &[3], false);
let (vals, inds) = median_with_dim(&x, 0, false).unwrap();
assert_eq!(vals.data().unwrap(), &[2.0]);
assert_eq!(inds.data().unwrap(), &[2]);
}
// Even length: the lower of the two middle values (2.0 at index 1) is returned.
#[test]
fn test_median_with_dim_even_takes_lower() {
let x = leaf(&[4.0, 2.0, 1.0, 3.0], &[4], false);
let (vals, inds) = median_with_dim(&x, 0, false).unwrap();
assert_eq!(vals.data().unwrap(), &[2.0]);
assert_eq!(inds.data().unwrap(), &[1]);
}
// Row-wise medians of a 2x3 tensor, with indices into each row.
#[test]
fn test_median_with_dim_2d_axis1() {
let x = leaf(&[5.0, 3.0, 4.0, 1.0, 9.0, 2.0], &[2, 3], false);
let (vals, inds) = median_with_dim(&x, 1, false).unwrap();
assert_eq!(vals.shape(), &[2]);
assert_eq!(vals.data().unwrap(), &[4.0, 2.0]);
assert_eq!(inds.data().unwrap(), &[2, 2]);
}
// A NaN in the slice makes the median NaN; the reported index is the NaN's position.
#[test]
fn test_median_nan_poisons_slice() {
let x = leaf(&[1.0, f64::NAN, 3.0], &[3], false);
let (vals, inds) = median_with_dim(&x, 0, false).unwrap();
assert!(vals.data().unwrap()[0].is_nan());
assert_eq!(inds.data().unwrap(), &[1]);
}
// nanmedian ignores the NaN: [1, NaN, 3, 2] gives 2.0 at index 3.
#[test]
fn test_nanmedian_skips_nan() {
let x = leaf(&[1.0, f64::NAN, 3.0, 2.0], &[4], false);
let (vals, inds) = nanmedian_with_dim(&x, 0, false).unwrap();
assert_eq!(vals.data().unwrap(), &[2.0]);
assert_eq!(inds.data().unwrap(), &[3]);
}
// With keepdim both values and indices have shape [2, 1].
#[test]
fn test_median_with_dim_keepdim_shape() {
let x = leaf(&[5.0, 3.0, 4.0, 1.0, 9.0, 2.0], &[2, 3], false);
let (vals, inds) = median_with_dim(&x, 1, true).unwrap();
assert_eq!(vals.shape(), &[2, 1]);
assert_eq!(inds.shape(), &[2, 1]);
}
// The node is named "MedianDimBackward" and the gradient lands only on the selected index.
#[test]
fn test_median_backward_scatters_to_selected_index() {
let x = leaf(&[3.0, 1.0, 2.0], &[3], true);
let (vals, _inds) = median_with_dim(&x, 0, false).unwrap();
assert_eq!(vals.grad_fn().unwrap().name(), "MedianDimBackward");
vals.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert_eq!(g.data().unwrap(), &[0.0, 0.0, 1.0]);
}
// NOTE(review): despite the name, no finite difference is computed here; the
// test asserts the exact one-hot-per-row gradient pattern.
#[test]
fn test_median_backward_2d_finite_difference() {
let x = leaf(&[5.0, 3.0, 4.0, 1.0, 9.0, 2.0], &[2, 3], true);
let (vals, _inds) = median_with_dim(&x, 1, false).unwrap();
let loss = sum(&vals).unwrap();
loss.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert_eq!(g.data().unwrap(), &[0.0, 0.0, 1.0, 0.0, 0.0, 1.0]);
}
// The node is named "NanmedianDimBackward"; the gradient lands on the non-NaN median (index 3).
#[test]
fn test_nanmedian_backward_scatters_to_nonnan_median() {
let x = leaf(&[1.0, f64::NAN, 3.0, 2.0], &[4], true);
let (vals, _inds) = nanmedian_with_dim(&x, 0, false).unwrap();
assert_eq!(vals.grad_fn().unwrap().name(), "NanmedianDimBackward");
vals.backward().unwrap();
let g = x.grad().unwrap().unwrap();
assert_eq!(g.data().unwrap(), &[0.0, 0.0, 0.0, 1.0]);
}
}