use std::any::TypeId;
use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::device::Device;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Returns `true` when the generic float type `T` is exactly `f32`.
///
/// The answer comes from comparing [`TypeId`]s, so it is fixed for each
/// instantiation of `T`.
#[inline]
fn is_f32<T: Float>() -> bool {
    let single_precision = TypeId::of::<f32>();
    TypeId::of::<T>() == single_precision
}
/// Returns `true` when the generic float type `T` is exactly `f64`.
///
/// The answer comes from comparing [`TypeId`]s, so it is fixed for each
/// instantiation of `T`.
#[inline]
fn is_f64<T: Float>() -> bool {
    let double_precision = TypeId::of::<f64>();
    TypeId::of::<T>() == double_precision
}
#[inline]
fn ensure_cpu<T: Float>(input: &Tensor<T>) -> FerrotorchResult<(Tensor<T>, Device)> {
let device = input.device();
if input.is_cuda() {
return Err(crate::error::FerrotorchError::NotImplementedOnCuda {
op: "shape backward",
});
}
Ok((input.clone(), device))
}
/// Moves `tensor` to `device` when that device is CUDA; otherwise returns the
/// tensor untouched.
///
/// # Errors
///
/// Propagates whatever `Tensor::to` reports for the CUDA transfer.
#[inline]
fn restore_device<T: Float>(tensor: Tensor<T>, device: Device) -> FerrotorchResult<Tensor<T>> {
    if !device.is_cuda() {
        return Ok(tensor);
    }
    tensor.to(device)
}
/// Backward node attached by [`reshape`]: views the incoming gradient back to
/// the shape the input had before the reshape.
#[derive(Debug)]
pub struct ReshapeBackward<T: Float> {
    /// Tensor that was reshaped in the forward pass (the node's only input).
    input: Tensor<T>,
    /// Shape the gradient is viewed back to in `backward`.
    input_shape: Vec<usize>,
}

impl<T: Float> ReshapeBackward<T> {
    /// Creates the node from the forward input and the shape to restore.
    pub fn new(input: Tensor<T>, input_shape: Vec<usize>) -> Self {
        ReshapeBackward { input, input_shape }
    }
}

impl<T: Float> GradFn<T> for ReshapeBackward<T> {
    /// Produces a single gradient slot: `None` when the input does not require
    /// grad, otherwise `grad_output` viewed as `input_shape`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_input = if self.input.requires_grad() {
            Some(grad_output.view_reshape(self.input_shape.clone())?)
        } else {
            None
        };
        Ok(vec![grad_input])
    }

    /// The single forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "ReshapeBackward"
    }
}
/// Reshapes `input` to `new_shape` as a view.
///
/// `new_shape` is first resolved against the element count by `resolve_shape`.
/// When gradients are enabled and `input` requires grad, the result carries a
/// [`ReshapeBackward`] node; otherwise it is a plain view.
///
/// # Errors
///
/// Propagates errors from `resolve_shape` and from the view construction.
pub fn reshape<T: Float>(input: &Tensor<T>, new_shape: &[isize]) -> FerrotorchResult<Tensor<T>> {
    let target = resolve_shape(new_shape, input.numel())?;
    if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(ReshapeBackward::new(input.clone(), input.shape().to_vec()));
        input.view_operation(target, node)
    } else {
        input.view_reshape(target)
    }
}
/// Backward node attached by [`flatten`]: views the flat gradient back to the
/// input's original shape.
#[derive(Debug)]
pub struct FlattenBackward<T: Float> {
    /// Tensor that was flattened in the forward pass.
    input: Tensor<T>,
    /// Shape the gradient is viewed back to in `backward`.
    input_shape: Vec<usize>,
}

impl<T: Float> FlattenBackward<T> {
    /// Creates the node from the forward input and its pre-flatten shape.
    pub fn new(input: Tensor<T>, input_shape: Vec<usize>) -> Self {
        FlattenBackward { input, input_shape }
    }
}

impl<T: Float> GradFn<T> for FlattenBackward<T> {
    /// Produces a single gradient slot: `None` when the input does not require
    /// grad, otherwise `grad_output` viewed as `input_shape`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_input = self
            .input
            .requires_grad()
            .then(|| grad_output.view_reshape(self.input_shape.clone()))
            .transpose()?;
        Ok(vec![grad_input])
    }

    /// The single forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "FlattenBackward"
    }
}
/// Flattens `input` into a 1-D view of shape `[numel]`.
///
/// A [`FlattenBackward`] node is attached only when gradients are enabled and
/// `input` requires grad.
///
/// # Errors
///
/// Propagates errors from the underlying view construction.
pub fn flatten<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let flat_shape = vec![input.numel()];
    if !is_grad_enabled() || !input.requires_grad() {
        return input.view_reshape(flat_shape);
    }
    let node = Arc::new(FlattenBackward::new(input.clone(), input.shape().to_vec()));
    input.view_operation(flat_shape, node)
}
/// Backward node attached by [`squeeze`]: re-inserts the removed size-1 axis
/// into the gradient's shape.
#[derive(Debug)]
pub struct SqueezeBackward<T: Float> {
    /// Tensor that was squeezed in the forward pass.
    input: Tensor<T>,
    /// Normalized axis that the forward pass removed.
    axis: usize,
}

impl<T: Float> SqueezeBackward<T> {
    /// Creates the node from the forward input and the removed axis.
    pub fn new(input: Tensor<T>, axis: usize) -> Self {
        SqueezeBackward { input, axis }
    }
}

impl<T: Float> GradFn<T> for SqueezeBackward<T> {
    /// Produces a single gradient slot: `None` when the input does not require
    /// grad, otherwise `grad_output` viewed with a `1` inserted at `axis`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_input = if self.input.requires_grad() {
            let mut restored = grad_output.shape().to_vec();
            restored.insert(self.axis, 1);
            Some(grad_output.view_reshape(restored)?)
        } else {
            None
        };
        Ok(vec![grad_input])
    }

    /// The single forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "SqueezeBackward"
    }
}
/// Removes the size-1 dimension at `axis`, returning a view.
///
/// `axis` is normalized with `crate::shape::normalize_axis`. A
/// [`SqueezeBackward`] node is attached only when gradients are enabled and
/// `input` requires grad.
///
/// # Errors
///
/// * Whatever `normalize_axis` reports for an invalid `axis`.
/// * `InvalidArgument` when the selected dimension is not of size 1.
pub fn squeeze<T: Float>(input: &Tensor<T>, axis: isize) -> FerrotorchResult<Tensor<T>> {
    let norm_axis = crate::shape::normalize_axis(axis, input.ndim())?;
    let extent = input.shape()[norm_axis];
    if extent != 1 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "squeeze: dimension {} has size {}, expected 1",
                norm_axis, extent
            ),
        });
    }

    // Keep every dimension except the squeezed one.
    let squeezed: Vec<usize> = input
        .shape()
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != norm_axis)
        .map(|(_, &d)| d)
        .collect();

    if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(SqueezeBackward::new(input.clone(), norm_axis));
        input.view_operation(squeezed, node)
    } else {
        input.view_reshape(squeezed)
    }
}
/// Backward node attached by [`unsqueeze`]: drops the inserted size-1 axis
/// from the gradient's shape.
#[derive(Debug)]
pub struct UnsqueezeBackward<T: Float> {
    /// Tensor that was unsqueezed in the forward pass.
    input: Tensor<T>,
    /// Normalized position at which the forward pass inserted the new axis.
    axis: usize,
}

impl<T: Float> UnsqueezeBackward<T> {
    /// Creates the node from the forward input and the inserted axis.
    pub fn new(input: Tensor<T>, axis: usize) -> Self {
        UnsqueezeBackward { input, axis }
    }
}

impl<T: Float> GradFn<T> for UnsqueezeBackward<T> {
    /// Produces a single gradient slot: `None` when the input does not require
    /// grad, otherwise `grad_output` viewed with the entry at `axis` removed.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_input = if self.input.requires_grad() {
            let mut collapsed = grad_output.shape().to_vec();
            collapsed.remove(self.axis);
            Some(grad_output.view_reshape(collapsed)?)
        } else {
            None
        };
        Ok(vec![grad_input])
    }

    /// The single forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "UnsqueezeBackward"
    }
}
/// Inserts a size-1 dimension at `axis`, returning a view.
///
/// `axis` is interpreted against the *output* rank (`ndim + 1`), so the valid
/// range is `-(ndim + 1) ..= ndim`; negative values count from the end. An
/// [`UnsqueezeBackward`] node is attached only when gradients are enabled and
/// `input` requires grad.
///
/// # Errors
///
/// Returns `InvalidArgument` when `axis` is outside the valid range.
pub fn unsqueeze<T: Float>(input: &Tensor<T>, axis: isize) -> FerrotorchResult<Tensor<T>> {
    let ndim = input.ndim();
    let new_ndim = ndim + 1;
    let bound = new_ndim as isize;
    if !(-bound..bound).contains(&axis) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "unsqueeze: axis {axis} is out of bounds for tensor with {ndim} dimensions (new ndim = {new_ndim})"
            ),
        });
    }

    // Negative axes wrap around the output rank.
    let norm_axis = if axis >= 0 {
        axis as usize
    } else {
        (bound + axis) as usize
    };

    let mut widened = input.shape().to_vec();
    widened.insert(norm_axis, 1);

    if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(UnsqueezeBackward::new(input.clone(), norm_axis));
        input.view_operation(widened, node)
    } else {
        input.view_reshape(widened)
    }
}
/// Backward node for a 2-D transpose: the gradient is the incoming gradient
/// with its two axes swapped again.
#[derive(Debug)]
pub struct TransposeBackward<T: Float> {
    /// Tensor that was transposed in the forward pass.
    input: Tensor<T>,
}

impl<T: Float> TransposeBackward<T> {
    /// Creates the node from the forward input.
    pub fn new(input: Tensor<T>) -> Self {
        TransposeBackward { input }
    }
}

impl<T: Float> GradFn<T> for TransposeBackward<T> {
    /// Produces a single gradient slot: `None` when the input does not require
    /// grad, otherwise `grad_output` permuted with `[1, 0]`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_input = if self.input.requires_grad() {
            Some(crate::methods::permute_t(grad_output, &[1, 0])?)
        } else {
            None
        };
        Ok(vec![grad_input])
    }

    /// The single forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "TransposeBackward"
    }
}
/// Swaps the two axes of a 2-D tensor by delegating to
/// `crate::methods::permute_t` with the permutation `[1, 0]`.
///
/// # Errors
///
/// Returns `InvalidArgument` when `input` is not exactly 2-D; otherwise
/// propagates whatever `permute_t` reports.
pub fn transpose_2d<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    match input.ndim() {
        2 => crate::methods::permute_t(input, &[1, 0]),
        _ => Err(FerrotorchError::InvalidArgument {
            message: format!("transpose_2d requires 2-D tensor, got {:?}", input.shape()),
        }),
    }
}
/// Backward node attached by [`expand`]: reduces the expanded gradient back to
/// the input's original shape.
#[derive(Debug)]
pub struct ExpandBackward<T: Float> {
    /// Tensor that was expanded in the forward pass.
    input: Tensor<T>,
    /// Shape the gradient is reduced to in `backward`.
    input_shape: Vec<usize>,
}

impl<T: Float> ExpandBackward<T> {
    /// Creates the node from the forward input and its pre-expand shape.
    pub fn new(input: Tensor<T>, input_shape: Vec<usize>) -> Self {
        ExpandBackward { input, input_shape }
    }
}

impl<T: Float> GradFn<T> for ExpandBackward<T> {
    /// Produces a single gradient slot: `None` when the input does not require
    /// grad, otherwise the result of `reduce_grad_to_shape` on `grad_output`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_input = if self.input.requires_grad() {
            Some(super::arithmetic::reduce_grad_to_shape(
                grad_output,
                &self.input_shape,
            )?)
        } else {
            None
        };
        Ok(vec![grad_input])
    }

    /// The single forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "ExpandBackward"
    }
}
/// Broadcasts `input` to `new_shape`, materializing the result.
///
/// Shapes are compared right-aligned: every input dimension must either be `1`
/// or equal the corresponding target dimension. An [`ExpandBackward`] node is
/// attached when gradients are enabled and `input` requires grad.
///
/// CUDA tensors of type `f32`/`f64` are expanded on the device by
/// broadcast-adding a one-element zero buffer, provided a GPU backend is
/// registered. CPU tensors are expanded by gathering elements through
/// `broadcast_flat_index`.
///
/// # Errors
///
/// * `InvalidArgument` when `new_shape` has fewer dimensions than `input`.
/// * `ShapeMismatch` when a non-1 input dimension differs from its target.
/// * `NotImplementedOnCuda` for CUDA tensors that cannot take the GPU path.
/// * Any error propagated from the backend or tensor constructors.
pub fn expand<T: Float>(input: &Tensor<T>, new_shape: &[usize]) -> FerrotorchResult<Tensor<T>> {
    let in_shape = input.shape();
    let (in_rank, out_rank) = (in_shape.len(), new_shape.len());
    if out_rank < in_rank {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "expand: target shape {new_shape:?} has fewer dimensions than input {in_shape:?}"
            ),
        });
    }

    // Walk the input axes from last to first; `lead` is the number of extra
    // leading axes in the target shape.
    let lead = out_rank - in_rank;
    for (axis, &in_dim) in in_shape.iter().enumerate().rev() {
        let out_dim = new_shape[lead + axis];
        if in_dim != 1 && in_dim != out_dim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "expand: cannot expand dimension {} from {} to {}",
                    axis, in_dim, out_dim
                ),
            });
        }
    }

    // Device path: broadcast-add against a single zero element.
    if input.is_cuda()
        && (is_f32::<T>() || is_f64::<T>())
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
    {
        let handle = input.gpu_handle()?;
        let device_ord = handle.device_ordinal();
        let double = is_f64::<T>();
        let zeros_dtype = if double {
            crate::dtype::DType::F64
        } else {
            crate::dtype::DType::F32
        };
        let zeros = backend.alloc_zeros(1, zeros_dtype, device_ord)?;
        let expanded = if double {
            backend.broadcast_add_f64(handle, &zeros, in_shape, &[1], new_shape)?
        } else {
            backend.broadcast_add_f32(handle, &zeros, in_shape, &[1], new_shape)?
        };
        let storage = TensorStorage::gpu(expanded);
        if is_grad_enabled() && input.requires_grad() {
            let node = Arc::new(ExpandBackward::new(input.clone(), in_shape.to_vec()));
            return Tensor::from_operation(storage, new_shape.to_vec(), node);
        }
        return Tensor::from_storage(storage, new_shape.to_vec(), false);
    }

    // A CUDA tensor that did not qualify for the device path is unsupported.
    if input.is_cuda() {
        return Err(crate::error::FerrotorchError::NotImplementedOnCuda { op: "expand" });
    }

    // Host path: gather each output element from its broadcast source index.
    let in_data = input.data()?;
    let out_numel: usize = new_shape.iter().product();
    let out_data: Vec<T> = (0..out_numel)
        .map(|flat| in_data[broadcast_flat_index(flat, new_shape, in_shape)])
        .collect();
    let storage = TensorStorage::cpu(out_data);

    if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(ExpandBackward::new(input.clone(), in_shape.to_vec()));
        Tensor::from_operation(storage, new_shape.to_vec(), node)
    } else {
        Tensor::from_storage(storage, new_shape.to_vec(), false)
    }
}
/// Expands `input` to the shape of `other`.
///
/// Thin wrapper that forwards `other.shape()` to [`expand`]; validation and
/// errors are exactly those of `expand`. Only the shape of `other` is read.
pub fn expand_as<T: Float>(input: &Tensor<T>, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
expand(input, other.shape())
}
/// Splits dimension `dim` of `input` into the dimensions listed in `sizes`.
///
/// `sizes` may contain a single `-1`, which is inferred from the size of `dim`
/// (see `resolve_unflatten_sizes`). The result is produced by [`reshape`], so
/// it carries the same view and autograd behavior.
///
/// # Errors
///
/// * `InvalidArgument` when `sizes` is empty.
/// * Whatever `normalize_axis`, `resolve_unflatten_sizes`, or `reshape` report.
pub fn unflatten<T: Float>(
    input: &Tensor<T>,
    dim: isize,
    sizes: &[isize],
) -> FerrotorchResult<Tensor<T>> {
    if sizes.is_empty() {
        return Err(FerrotorchError::InvalidArgument {
            message: "unflatten: sizes must be non-empty".into(),
        });
    }
    let norm_dim = crate::shape::normalize_axis(dim, input.ndim())?;
    let old_shape = input.shape();
    let resolved_sizes = resolve_unflatten_sizes(sizes, old_shape[norm_dim])?;

    // prefix dims ++ resolved sizes ++ suffix dims
    let new_shape: Vec<isize> = old_shape[..norm_dim]
        .iter()
        .chain(resolved_sizes.iter())
        .chain(old_shape[norm_dim + 1..].iter())
        .map(|&d| d as isize)
        .collect();
    reshape(input, &new_shape)
}
/// Resolves the `sizes` argument of [`unflatten`] against a dimension of
/// length `dim_size`.
///
/// At most one entry may be `-1`; it is replaced by `dim_size / product`,
/// where `product` is the product of the explicit entries. Without a `-1`,
/// the product of all entries must equal `dim_size`.
///
/// # Errors
///
/// * `InvalidArgument` for a second `-1` or for any other negative entry.
/// * `ShapeMismatch` when the `-1` slot cannot be inferred (explicit product is
///   zero or does not divide `dim_size`), or when the explicit sizes do not
///   multiply to `dim_size`.
fn resolve_unflatten_sizes(sizes: &[isize], dim_size: usize) -> FerrotorchResult<Vec<usize>> {
    let mut inferred_idx: Option<usize> = None;
    let mut product: usize = 1;
    for (i, &s) in sizes.iter().enumerate() {
        match s {
            -1 if inferred_idx.is_some() => {
                return Err(FerrotorchError::InvalidArgument {
                    message: "unflatten: only one dimension can be -1".into(),
                });
            }
            -1 => inferred_idx = Some(i),
            s if s < 0 => {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!("unflatten: invalid size {s}"),
                });
            }
            s => product *= s as usize,
        }
    }

    // The `-1` placeholder is temporarily stored as 0 and overwritten below.
    let mut out: Vec<usize> = sizes.iter().map(|&s| s.max(0) as usize).collect();
    match inferred_idx {
        Some(idx) => {
            if product == 0 || !dim_size.is_multiple_of(product) {
                return Err(FerrotorchError::ShapeMismatch {
                    message: format!(
                        "unflatten: cannot infer -1 slot for dim of size {dim_size} from {sizes:?}"
                    ),
                });
            }
            out[idx] = dim_size / product;
        }
        None if product != dim_size => {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "unflatten: provided sizes {sizes:?} (product {product}) do not match dim size {dim_size}"
                ),
            });
        }
        None => {}
    }
    Ok(out)
}
/// Swaps dimensions `axis0` and `axis1` of `input`.
///
/// Thin wrapper around `Tensor::transpose`; the axes are passed through
/// unchanged (no negative-index normalization happens here) and any error
/// comes from `transpose` itself.
pub fn swapaxes<T: Float>(
input: &Tensor<T>,
axis0: usize,
axis1: usize,
) -> FerrotorchResult<Tensor<T>> {
input.transpose(axis0, axis1)
}
/// Swaps dimensions `dim0` and `dim1` of `input`.
///
/// Same body as `swapaxes`: a thin wrapper around `Tensor::transpose` that
/// passes the dimensions through unchanged and propagates its errors.
pub fn swapdims<T: Float>(
input: &Tensor<T>,
dim0: usize,
dim1: usize,
) -> FerrotorchResult<Tensor<T>> {
input.transpose(dim0, dim1)
}
/// Backward node attached by [`flip`]: the gradient is the incoming gradient
/// flipped along the same dimensions.
#[derive(Debug)]
pub struct FlipBackward<T: Float> {
    /// Tensor that was flipped in the forward pass.
    input: Tensor<T>,
    /// Normalized dimensions that were reversed.
    dims: Vec<usize>,
}

impl<T: Float> FlipBackward<T> {
    /// Creates the node from the forward input and the flipped dimensions.
    pub fn new(input: Tensor<T>, dims: Vec<usize>) -> Self {
        FlipBackward { input, dims }
    }
}

impl<T: Float> GradFn<T> for FlipBackward<T> {
    /// Produces a single gradient slot: `None` when the input does not require
    /// grad, otherwise `grad_output` flipped along `dims` on the CPU.
    ///
    /// A CUDA `grad_output` is rejected by `ensure_cpu`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let (host_grad, origin) = ensure_cpu(grad_output)?;
        let values = host_grad.data_vec()?;
        let grad_shape = host_grad.shape();
        let reversed = flip_cpu_inner(&values, grad_shape, &self.dims);
        let grad_input =
            Tensor::from_storage(TensorStorage::cpu(reversed), grad_shape.to_vec(), false)?;
        let grad_input = restore_device(grad_input, origin)?;
        Ok(vec![Some(grad_input)])
    }

    /// The single forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "FlipBackward"
    }
}
/// Reverses the order of elements along each dimension in `dims`.
///
/// Every entry of `dims` is normalized with `crate::shape::normalize_axis`,
/// and a dimension may appear only once. The result is a new CPU tensor; a
/// [`FlipBackward`] node is attached when gradients are enabled and `input`
/// requires grad.
///
/// # Errors
///
/// * Whatever `normalize_axis` reports for an invalid dimension.
/// * `InvalidArgument` when a dimension is listed more than once.
/// * `NotImplementedOnCuda` for CUDA inputs.
pub fn flip<T: Float>(input: &Tensor<T>, dims: &[isize]) -> FerrotorchResult<Tensor<T>> {
    let rank = input.ndim();
    let mut axes: Vec<usize> = Vec::with_capacity(dims.len());
    for &raw in dims {
        let nd = crate::shape::normalize_axis(raw, rank)?;
        if axes.contains(&nd) {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("flip: dim {nd} appears multiple times in the list of dims"),
            });
        }
        axes.push(nd);
    }

    if input.is_cuda() {
        return Err(crate::error::FerrotorchError::NotImplementedOnCuda { op: "flip" });
    }

    let values = input.data_vec()?;
    let shape = input.shape();
    let storage = TensorStorage::cpu(flip_cpu_inner(&values, shape, &axes));

    if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(FlipBackward::new(input.clone(), axes));
        Tensor::from_operation(storage, shape.to_vec(), node)
    } else {
        Tensor::from_storage(storage, shape.to_vec(), false)
    }
}
/// Flips a buffer along the given dimensions.
///
/// `data` is indexed using the strides that `crate::shape::c_contiguous_strides`
/// returns for `shape`. For every output position the flat index is decomposed
/// into per-axis coordinates; coordinates on axes listed in `dims` are mirrored
/// (`extent - 1 - coord`) to find the source element. Entries of `dims` that do
/// not name an axis of `shape` have no effect.
fn flip_cpu_inner<T: Float>(data: &[T], shape: &[usize], dims: &[usize]) -> Vec<T> {
    let strides = crate::shape::c_contiguous_strides(shape);
    // One flag per axis so the inner loop does not rescan `dims`.
    let mirrored: Vec<bool> = (0..shape.len()).map(|axis| dims.contains(&axis)).collect();

    let mut out = vec![<T as num_traits::Zero>::zero(); data.len()];
    for (dst_flat, slot) in out.iter_mut().enumerate() {
        let mut remaining = dst_flat;
        let mut src_flat = 0usize;
        for (axis, &extent) in shape.iter().enumerate() {
            let step = strides[axis] as usize;
            let coord = remaining / step;
            remaining %= step;
            let src_coord = if mirrored[axis] {
                extent - 1 - coord
            } else {
                coord
            };
            src_flat += src_coord * step;
        }
        *slot = data[src_flat];
    }
    out
}
/// Flips `input` along dimension 1 (left/right).
///
/// # Errors
///
/// Returns `InvalidArgument` when `input` has fewer than two dimensions;
/// otherwise propagates errors from [`flip`].
pub fn fliplr<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let rank = input.ndim();
    if rank >= 2 {
        flip(input, &[1])
    } else {
        Err(FerrotorchError::InvalidArgument {
            message: format!("fliplr: input must be >= 2-D, got {}-D", rank),
        })
    }
}
/// Flips `input` along dimension 0 (up/down).
///
/// # Errors
///
/// Returns `InvalidArgument` for 0-D input; otherwise propagates errors from
/// [`flip`].
pub fn flipud<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if input.ndim() >= 1 {
        flip(input, &[0])
    } else {
        Err(FerrotorchError::InvalidArgument {
            message: "flipud: input must be >= 1-D, got 0-D".into(),
        })
    }
}
/// Rotates `input` by `k` quarter turns in the plane spanned by the two
/// dimensions in `dims`.
///
/// `k` is reduced modulo 4 (Euclidean, so negative values are accepted):
///
/// * `0` returns a clone of `input`,
/// * `1` flips along `dims[1]` and then transposes the two dimensions,
/// * `2` flips along both dimensions,
/// * `3` flips along `dims[0]` and then transposes the two dimensions.
///
/// # Errors
///
/// `InvalidArgument` when `dims` does not hold exactly two entries, when the
/// tensor has fewer than two dimensions, or when both entries normalize to the
/// same dimension. Errors from `normalize_axis`, [`flip`], and `transpose` are
/// propagated.
pub fn rot90<T: Float>(input: &Tensor<T>, k: i64, dims: &[isize]) -> FerrotorchResult<Tensor<T>> {
    let ndim = input.ndim();
    if dims.len() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "rot90: expected exactly 2 rotation dims, got {}",
                dims.len()
            ),
        });
    }
    if ndim < 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("rot90: expected total dims >= 2, got {ndim}"),
        });
    }

    let d0 = crate::shape::normalize_axis(dims[0], ndim)?;
    let d1 = crate::shape::normalize_axis(dims[1], ndim)?;
    if d0 == d1 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("rot90: rotation dims must differ, got dim0 = {d0}, dim1 = {d1}"),
        });
    }

    let (first, second) = (d0 as isize, d1 as isize);
    match k.rem_euclid(4) {
        1 => flip(input, &[second])?.transpose(d0, d1),
        2 => flip(input, &[first, second]),
        3 => flip(input, &[first])?.transpose(d0, d1),
        _ => Ok(input.clone()),
    }
}
/// Moves the dimensions listed in `source` to the positions listed in
/// `destination`; all other dimensions keep their relative order and fill the
/// remaining positions from left to right.
///
/// The resulting permutation is applied with `crate::methods::permute_t`.
/// A 0-D tensor is returned as a clone.
///
/// # Errors
///
/// `InvalidArgument` when the two lists differ in length or when either list
/// contains a repeated dimension after normalization. Errors from
/// `normalize_axis` and `permute_t` are propagated.
pub fn movedim<T: Float>(
    input: &Tensor<T>,
    source: &[isize],
    destination: &[isize],
) -> FerrotorchResult<Tensor<T>> {
    /// True when any value occurs more than once.
    fn has_repeats(dims: &[usize]) -> bool {
        dims.iter()
            .enumerate()
            .any(|(i, d)| dims[..i].contains(d))
    }

    if source.len() != destination.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "movedim: source ({} dims) and destination ({} dims) must match in length",
                source.len(),
                destination.len()
            ),
        });
    }

    let ndim = input.ndim();
    let mut norm_src: Vec<usize> = Vec::with_capacity(source.len());
    for &d in source {
        norm_src.push(crate::shape::normalize_axis(d, ndim)?);
    }
    let mut norm_dst: Vec<usize> = Vec::with_capacity(destination.len());
    for &d in destination {
        norm_dst.push(crate::shape::normalize_axis(d, ndim)?);
    }

    if has_repeats(&norm_src) {
        return Err(FerrotorchError::InvalidArgument {
            message: "movedim: repeated dim in `source`".into(),
        });
    }
    if has_repeats(&norm_dst) {
        return Err(FerrotorchError::InvalidArgument {
            message: "movedim: repeated dim in `destination`".into(),
        });
    }
    if ndim == 0 {
        return Ok(input.clone());
    }

    // Pin the explicitly moved dimensions to their destination slots.
    let mut placed: Vec<Option<usize>> = vec![None; ndim];
    let mut src_used = vec![false; ndim];
    for (&src, &dst) in norm_src.iter().zip(norm_dst.iter()) {
        placed[dst] = Some(src);
        src_used[src] = true;
    }

    // Fill the open slots, left to right, with the untouched dimensions in
    // ascending order.
    let mut leftover_src = (0..ndim).filter(|d| !src_used[*d]);
    let order: Vec<usize> = placed
        .into_iter()
        .map(|slot| {
            slot.unwrap_or_else(|| {
                leftover_src
                    .next()
                    .expect("movedim: leftover dim accounting")
            })
        })
        .collect();

    crate::methods::permute_t(input, &order)
}
/// Alias for [`movedim`]: forwards all arguments unchanged and returns its
/// result, including any error.
pub fn moveaxis<T: Float>(
input: &Tensor<T>,
source: &[isize],
destination: &[isize],
) -> FerrotorchResult<Tensor<T>> {
movedim(input, source, destination)
}
/// Broadcasts `input` to `shape`.
///
/// Alias for [`expand`]: forwards both arguments unchanged, so validation,
/// autograd handling, and errors are exactly those of `expand`.
pub fn broadcast_to<T: Float>(input: &Tensor<T>, shape: &[usize]) -> FerrotorchResult<Tensor<T>> {
expand(input, shape)
}
/// Expands every tensor in `tensors` to their common broadcast shape.
///
/// The common shape is accumulated pairwise with
/// `crate::shape::broadcast_shapes`, starting from the first tensor's shape;
/// each tensor is then passed through [`expand`]. Outputs are in input order.
///
/// # Errors
///
/// `InvalidArgument` for an empty list; errors from `broadcast_shapes` and
/// `expand` are propagated.
pub fn broadcast_tensors<T: Float>(tensors: &[Tensor<T>]) -> FerrotorchResult<Vec<Tensor<T>>> {
    let Some((first, rest)) = tensors.split_first() else {
        return Err(FerrotorchError::InvalidArgument {
            message: "broadcast_tensors: empty tensor list".into(),
        });
    };

    let mut common: Vec<usize> = first.shape().to_vec();
    for other in rest {
        common = crate::shape::broadcast_shapes(&common, other.shape())?;
    }

    let mut expanded = Vec::with_capacity(tensors.len());
    for t in tensors {
        expanded.push(expand(t, &common)?);
    }
    Ok(expanded)
}
/// Repeats `input` along each dimension the number of times given in
/// `repeats`.
///
/// `repeats` must have at least as many entries as `input` has dimensions.
/// Extra leading entries add new leading dimensions of size 1 (via
/// [`reshape`]) before repetition. Each axis with a count greater than 1 is
/// built by concatenating that many copies of the running result with `cat`.
///
/// # Errors
///
/// `InvalidArgument` when `repeats` is shorter than the tensor rank or contains
/// a negative value. Errors from `reshape` and `cat` are propagated.
pub fn repeat<T: Float>(input: &Tensor<T>, repeats: &[isize]) -> FerrotorchResult<Tensor<T>> {
    let in_ndim = input.ndim();
    if repeats.len() < in_ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "repeat: number of repeat dims ({}) cannot be smaller than tensor dims ({})",
                repeats.len(),
                in_ndim
            ),
        });
    }
    if let Some(&r) = repeats.iter().find(|&&r| r < 0) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("repeat: repeat value {r} must be non-negative"),
        });
    }

    // Left-pad the shape with 1s so that every repeat count has an axis.
    let extra_axes = repeats.len() - in_ndim;
    let mut cur = if extra_axes == 0 {
        input.clone()
    } else {
        let mut padded: Vec<isize> = vec![1; extra_axes];
        padded.extend(input.shape().iter().map(|&d| d as isize));
        reshape(input, &padded)?
    };

    for (ax, &count) in repeats.iter().enumerate() {
        match count as usize {
            // A single copy leaves the axis untouched.
            1 => {}
            0 => {
                // NOTE(review): this reshapes to a shape with a zero along
                // `ax`; whether `reshape` accepts that for a non-empty tensor
                // depends on `resolve_shape`, which is not visible here —
                // confirm that a zero repeat count works as intended.
                let mut zero_shape: Vec<isize> =
                    cur.shape().iter().map(|&d| d as isize).collect();
                zero_shape[ax] = 0;
                cur = reshape(&cur, &zero_shape)?;
            }
            r => {
                let copies = vec![cur.clone(); r];
                cur = cat(&copies, ax as isize)?;
            }
        }
    }
    Ok(cur)
}
/// Tiles `input` according to `reps` by delegating to [`repeat`].
///
/// When `reps` has fewer entries than `input` has dimensions, it is left-padded
/// with 1s to the tensor rank first; otherwise it is forwarded unchanged.
///
/// # Errors
///
/// Propagates errors from `repeat`.
pub fn tile<T: Float>(input: &Tensor<T>, reps: &[isize]) -> FerrotorchResult<Tensor<T>> {
    let missing = input.ndim().saturating_sub(reps.len());
    if missing == 0 {
        return repeat(input, reps);
    }
    let mut padded: Vec<isize> = vec![1; missing];
    padded.extend_from_slice(reps);
    repeat(input, &padded)
}
/// Splits `input` into one tensor per index along `dim`, with that dimension
/// removed from each piece.
///
/// Each piece is a length-1 `narrow_t` slice passed through [`squeeze`].
///
/// # Errors
///
/// `InvalidArgument` for 0-D input; errors from `normalize_axis`, `narrow_t`,
/// and `squeeze` are propagated.
pub fn unbind<T: Float>(input: &Tensor<T>, dim: isize) -> FerrotorchResult<Vec<Tensor<T>>> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "unbind: cannot unbind a 0-D tensor".into(),
        });
    }
    let norm_dim = crate::shape::normalize_axis(dim, ndim)?;
    let count = input.shape()[norm_dim];
    (0..count)
        .map(|i| {
            let slice = crate::methods::narrow_t(input, norm_dim, i, 1)?;
            squeeze(&slice, norm_dim as isize)
        })
        .collect()
}
/// Splits `input` along `dim` at the given `indices`, returning
/// `indices.len() + 1` pieces produced by `narrow_t`.
///
/// Each index is clamped to the range between the previous cut and the size of
/// the dimension, so out-of-range or non-increasing indices yield zero-length
/// pieces rather than an error from this function.
///
/// # Errors
///
/// `InvalidArgument` for 0-D input; errors from `normalize_axis` and `narrow_t`
/// are propagated.
pub fn tensor_split<T: Float>(
    input: &Tensor<T>,
    indices: &[usize],
    dim: isize,
) -> FerrotorchResult<Vec<Tensor<T>>> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "tensor_split: expected at least a 1-dimensional tensor".into(),
        });
    }
    let norm_dim = crate::shape::normalize_axis(dim, ndim)?;
    let dim_size = input.shape()[norm_dim];

    let mut pieces = Vec::with_capacity(indices.len() + 1);
    let mut cursor = 0usize;
    // The final stop at `dim_size` is a no-op under the clamp and yields the
    // trailing piece.
    for &cut in indices.iter().chain(std::iter::once(&dim_size)) {
        let stop = cut.clamp(cursor, dim_size);
        let piece = crate::methods::narrow_t(input, norm_dim, cursor, stop - cursor)?;
        pieces.push(piece);
        cursor = stop;
    }
    Ok(pieces)
}
/// Backward node attached by [`repeat_interleave`]: sums the gradient of the
/// `repeats` consecutive copies of each slice back into that slice.
#[derive(Debug)]
pub struct RepeatInterleaveBackward<T: Float> {
    /// Tensor that was repeated in the forward pass.
    input: Tensor<T>,
    /// Normalized dimension along which elements were repeated.
    dim: usize,
    /// Number of consecutive copies made of each slice.
    repeats: usize,
}

impl<T: Float> RepeatInterleaveBackward<T> {
    /// Creates the node from the forward input, dimension, and repeat count.
    pub fn new(input: Tensor<T>, dim: usize, repeats: usize) -> Self {
        RepeatInterleaveBackward {
            input,
            dim,
            repeats,
        }
    }
}

impl<T: Float> GradFn<T> for RepeatInterleaveBackward<T> {
    /// Produces a single gradient slot: `None` when the input does not require
    /// grad, otherwise a tensor of the input's shape holding the per-slice sums.
    ///
    /// A CUDA `grad_output` is rejected by `ensure_cpu`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let (host_grad, origin) = ensure_cpu(grad_output)?;
        let go_data = host_grad.data()?;

        // View the input as [outer, dim_size, inner].
        let in_shape = self.input.shape();
        let dim_size = in_shape[self.dim];
        let outer: usize = in_shape[..self.dim].iter().product();
        let inner: usize = in_shape[self.dim + 1..].iter().product();
        let in_numel: usize = in_shape.iter().product();

        let mut grad = vec![<T as num_traits::Zero>::zero(); in_numel];
        for o in 0..outer {
            for d in 0..dim_size {
                // Input slice `row` maps to output slices
                // `row * repeats .. (row + 1) * repeats`.
                let row = o * dim_size + d;
                let dst = row * inner;
                for r in 0..self.repeats {
                    let src = (row * self.repeats + r) * inner;
                    for i in 0..inner {
                        grad[dst + i] += go_data[src + i];
                    }
                }
            }
        }

        let grad_input = Tensor::from_storage(TensorStorage::cpu(grad), in_shape.to_vec(), false)?;
        Ok(vec![Some(restore_device(grad_input, origin)?)])
    }

    /// The single forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Static node name.
    fn name(&self) -> &'static str {
        "RepeatInterleaveBackward"
    }
}
/// Repeats every slice of `input` along `dim` `repeats` times in place, so
/// copies of the same slice are adjacent in the output.
///
/// The output has the input's shape with dimension `dim` multiplied by
/// `repeats`. A [`RepeatInterleaveBackward`] node is attached when gradients
/// are enabled and `input` requires grad.
///
/// # Errors
///
/// * `InvalidArgument` for 0-D input.
/// * Whatever `normalize_axis` reports for an invalid `dim`.
/// * `NotImplementedOnCuda` for CUDA inputs.
pub fn repeat_interleave<T: Float>(
    input: &Tensor<T>,
    repeats: usize,
    dim: isize,
) -> FerrotorchResult<Tensor<T>> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "repeat_interleave: cannot repeat a 0-D tensor along a dim".into(),
        });
    }
    let norm_dim = crate::shape::normalize_axis(dim, ndim)?;
    if input.is_cuda() {
        return Err(crate::error::FerrotorchError::NotImplementedOnCuda {
            op: "repeat_interleave",
        });
    }

    // View the input as [outer, dim_size, inner].
    let in_data = input.data_vec()?;
    let in_shape = input.shape();
    let dim_size = in_shape[norm_dim];
    let outer: usize = in_shape[..norm_dim].iter().product();
    let inner: usize = in_shape[norm_dim + 1..].iter().product();
    let out_dim_size = dim_size * repeats;

    // Emit each contiguous `inner`-sized slice `repeats` times in a row.
    let mut out_data = Vec::with_capacity(outer * out_dim_size * inner);
    for row in 0..outer * dim_size {
        let start = row * inner;
        let slice = &in_data[start..start + inner];
        for _ in 0..repeats {
            out_data.extend_from_slice(slice);
        }
    }

    let mut out_shape = in_shape.to_vec();
    out_shape[norm_dim] = out_dim_size;
    let storage = TensorStorage::cpu(out_data);

    if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(RepeatInterleaveBackward::new(
            input.clone(),
            norm_dim,
            repeats,
        ));
        Tensor::from_operation(storage, out_shape, node)
    } else {
        Tensor::from_storage(storage, out_shape, false)
    }
}
/// Stacks tensors vertically: each input is promoted with `atleast_2d` and the
/// results are concatenated along axis 0.
///
/// # Errors
///
/// `InvalidArgument` for an empty list; errors from the promotion and from
/// `cat` are propagated.
pub fn vstack<T: Float>(tensors: &[Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
    match tensors {
        [] => Err(FerrotorchError::InvalidArgument {
            message: "vstack: empty tensor list".into(),
        }),
        _ => {
            let rows = tensors
                .iter()
                .map(atleast_2d)
                .collect::<FerrotorchResult<Vec<Tensor<T>>>>()?;
            cat(&rows, 0)
        }
    }
}
pub fn hstack<T: Float>(tensors: &[Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
if tensors.is_empty() {
return Err(FerrotorchError::InvalidArgument {
message: "hstack: empty tensor list".into(),
});
}
let promoted: Vec<Tensor<T>> = tensors
.iter()
.map(atleast_1d)
.collect::<FerrotorchResult<_>>()?;
let axis: isize = isize::from(promoted[0].ndim() != 1);
cat(&promoted, axis)
}
pub fn dstack<T: Float>(tensors: &[Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
if tensors.is_empty() {
return Err(FerrotorchError::InvalidArgument {
message: "dstack: empty tensor list".into(),
});
}
let promoted: Vec<Tensor<T>> = tensors
.iter()
.map(atleast_3d)
.collect::<FerrotorchResult<_>>()?;
cat(&promoted, 2)
}
/// Stacks tensors as columns: 0-D and 1-D inputs are reshaped to
/// `[numel, 1]`, higher-rank inputs are used as-is, and the results are passed
/// to [`hstack`].
///
/// # Errors
///
/// `InvalidArgument` for an empty list; errors from `reshape` and `hstack` are
/// propagated.
pub fn column_stack<T: Float>(tensors: &[Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
    if tensors.is_empty() {
        return Err(FerrotorchError::InvalidArgument {
            message: "column_stack: empty tensor list".into(),
        });
    }
    let mut columns: Vec<Tensor<T>> = Vec::with_capacity(tensors.len());
    for t in tensors {
        let column = if t.ndim() <= 1 {
            reshape(t, &[t.numel() as isize, 1])?
        } else {
            t.clone()
        };
        columns.push(column);
    }
    hstack(&columns)
}
/// Promotes a 0-D tensor to shape `[1]`; tensors of any other rank are cloned
/// unchanged.
fn atleast_1d<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    match input.ndim() {
        0 => reshape(input, &[1]),
        _ => Ok(input.clone()),
    }
}
/// Promotes `input` to at least two dimensions.
///
/// A 0-D tensor becomes `[1, 1]`, a 1-D tensor gains a leading axis, and
/// tensors of rank 2 or more are cloned unchanged.
fn atleast_2d<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let rank = input.ndim();
    if rank == 0 {
        reshape(input, &[1, 1])
    } else if rank == 1 {
        unsqueeze(input, 0)
    } else {
        Ok(input.clone())
    }
}
/// Promotes `input` to at least three dimensions.
///
/// A 0-D tensor becomes `[1, 1, 1]`, a 1-D tensor gains a leading and a
/// trailing axis, a 2-D tensor gains a trailing axis, and tensors of rank 3 or
/// more are cloned unchanged.
fn atleast_3d<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    match input.ndim() {
        0 => reshape(input, &[1, 1, 1]),
        1 => {
            let as_row = unsqueeze(input, 0)?;
            unsqueeze(&as_row, -1)
        }
        2 => unsqueeze(input, -1),
        _ => Ok(input.clone()),
    }
}
/// Backward node attached by `cat`: splits the incoming gradient along the
/// concatenation axis into one chunk per forward input.
#[derive(Debug)]
pub struct CatBackward<T: Float> {
// Forward inputs, in concatenation order.
inputs: Vec<Tensor<T>>,
// Normalized axis along which the inputs were concatenated.
axis: usize,
// Size of each input along `axis`, in the same order as `inputs`.
split_sizes: Vec<usize>,
}
impl<T: Float> CatBackward<T> {
/// Creates the node from the forward inputs, the concatenation axis, and each
/// input's extent along that axis.
pub fn new(inputs: Vec<Tensor<T>>, axis: usize, split_sizes: Vec<usize>) -> Self {
Self {
inputs,
axis,
split_sizes,
}
}
}
impl<T: Float> GradFn<T> for CatBackward<T> {
/// Returns one gradient slot per forward input: `None` for inputs that do not
/// require grad, otherwise the matching chunk of `grad_output` shaped like
/// that input.
///
/// A CUDA `grad_output` of type `f32`/`f64` is split on the device when a GPU
/// backend is registered. Any other CUDA gradient reaches `ensure_cpu`, which
/// rejects it with `NotImplementedOnCuda`.
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
// Device path: only taken when all three conditions hold.
if grad_output.is_cuda()
&& (is_f32::<T>() || is_f64::<T>())
&& let Some(backend) = crate::gpu_dispatch::gpu_backend()
{
let go_shape = grad_output.shape();
let ndim = go_shape.len();
let axis = self.axis;
let total_along_axis = go_shape[axis];
// Number of elements spanned by one step along `axis`.
let inner: usize = if axis + 1 < ndim {
go_shape[axis + 1..].iter().product()
} else {
1
};
let go_handle = grad_output.gpu_handle()?;
let f64_path = is_f64::<T>();
let mut result = Vec::with_capacity(self.inputs.len());
// On this path `offset` counts positions along `axis` (not flat elements).
let mut offset = 0usize;
for (i, &split_size) in self.split_sizes.iter().enumerate() {
if !self.inputs[i].requires_grad() {
// Still advance past this input's span so later chunks line up.
result.push(None);
offset += split_size;
continue;
}
let chunk_numel = self.inputs[i].numel();
let chunk_handle = if f64_path {
backend.strided_split_f64(
go_handle,
total_along_axis,
offset,
split_size,
inner,
chunk_numel,
)?
} else {
backend.strided_split_f32(
go_handle,
total_along_axis,
offset,
split_size,
inner,
chunk_numel,
)?
};
let grad_tensor = Tensor::from_storage(
TensorStorage::gpu(chunk_handle),
self.inputs[i].shape().to_vec(),
false,
)?;
result.push(Some(grad_tensor));
offset += split_size;
}
return Ok(result);
}
// Host path: errors out here if `grad_output` is a CUDA tensor.
let (cpu_go, device) = ensure_cpu(grad_output)?;
let grad_data = cpu_go.data()?;
let out_shape = cpu_go.shape();
let ndim = out_shape.len();
let axis = self.axis;
// Treat the gradient as [outer, axis, inner].
let outer: usize = out_shape[..axis].iter().product();
let inner: usize = if axis + 1 < ndim {
out_shape[axis + 1..].iter().product()
} else {
1
};
let mut result = Vec::with_capacity(self.inputs.len());
// On this path `offset` is a flat element offset within one outer row
// (positions along `axis` times `inner`).
let mut offset = 0usize;
for (i, split_size) in self.split_sizes.iter().enumerate() {
if !self.inputs[i].requires_grad() {
result.push(None);
offset += split_size * inner;
continue;
}
let chunk_numel: usize = self.inputs[i].numel();
let mut grad_chunk = vec![<T as num_traits::Zero>::zero(); chunk_numel];
for o in 0..outer {
// Copy this input's contiguous run out of each outer row.
let src_row_start = o * out_shape[axis] * inner + offset;
let dst_row_start = o * split_size * inner;
let row_len = split_size * inner;
grad_chunk[dst_row_start..dst_row_start + row_len]
.copy_from_slice(&grad_data[src_row_start..src_row_start + row_len]);
}
let grad_tensor = Tensor::from_storage(
TensorStorage::cpu(grad_chunk),
self.inputs[i].shape().to_vec(),
false,
)?;
result.push(Some(restore_device(grad_tensor, device)?));
offset += split_size * inner;
}
Ok(result)
}
/// All forward inputs, in concatenation order.
fn inputs(&self) -> Vec<&Tensor<T>> {
self.inputs.iter().collect()
}
/// Static node name.
fn name(&self) -> &'static str {
"CatBackward"
}
}
/// Backward node for one chunk of a split: scatters the chunk's gradient into
/// a zero tensor of the original input's shape.
#[derive(Debug)]
pub struct SplitBackward<T: Float> {
// Tensor that was split in the forward pass.
input: Tensor<T>,
// Dimension along which the chunk was taken.
dim: usize,
// Start position of the chunk along `dim`.
offset: usize,
// Extent of the chunk along `dim`.
chunk_size: usize,
}
impl<T: Float> SplitBackward<T> {
/// Creates the node from the forward input, the split dimension, and the
/// chunk's start and length along that dimension.
pub fn new(input: Tensor<T>, dim: usize, offset: usize, chunk_size: usize) -> Self {
Self {
input,
dim,
offset,
chunk_size,
}
}
}
impl<T: Float> GradFn<T> for SplitBackward<T> {
/// Returns a single gradient slot: `None` when the input does not require
/// grad, otherwise a tensor of the input's shape that is zero everywhere
/// except for the chunk's region, which holds `grad_output`.
///
/// A CUDA `grad_output` with a 2-, 4-, or 8-byte element type is handled on the
/// device when a GPU backend is registered. Any other CUDA gradient reaches
/// `ensure_cpu`, which rejects it with `NotImplementedOnCuda`.
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if !self.input.requires_grad() {
return Ok(vec![None]);
}
let elem_size = std::mem::size_of::<T>();
// Device path: only taken when all three conditions hold.
if grad_output.is_cuda()
&& matches!(elem_size, 2 | 4 | 8)
&& let Some(backend) = crate::gpu_dispatch::gpu_backend()
{
let orig_shape = self.input.shape();
let ndim = orig_shape.len();
// Number of elements spanned by one step along `dim`.
let inner: usize = if self.dim + 1 < ndim {
orig_shape[self.dim + 1..].iter().product()
} else {
1
};
let total_along_dim = orig_shape[self.dim];
let orig_numel: usize = orig_shape.iter().product();
let device_ord = grad_output.gpu_handle()?.device_ordinal();
// Start from a zeroed buffer of the full input size, then write the chunk in.
let mut zeros_handle = backend.alloc_zeros(orig_numel, T::dtype(), device_ord)?;
let go_handle = grad_output.gpu_handle()?;
let chunk_numel = grad_output.numel();
backend.strided_cat(
go_handle,
&mut zeros_handle,
total_along_dim,
self.offset,
self.chunk_size,
inner,
chunk_numel,
elem_size,
)?;
let grad_tensor =
Tensor::from_storage(TensorStorage::gpu(zeros_handle), orig_shape.to_vec(), false)?;
return Ok(vec![Some(grad_tensor)]);
}
// Host path: errors out here if `grad_output` is a CUDA tensor.
let (cpu_go, device) = ensure_cpu(grad_output)?;
let grad_data = cpu_go.data()?;
let orig_shape = self.input.shape();
let ndim = orig_shape.len();
// Treat the input as [outer, dim, inner].
let outer: usize = orig_shape[..self.dim].iter().product();
let inner: usize = if self.dim + 1 < ndim {
orig_shape[self.dim + 1..].iter().product()
} else {
1
};
let total_along_dim = orig_shape[self.dim];
let orig_numel: usize = orig_shape.iter().product();
let mut result = vec![<T as num_traits::Zero>::zero(); orig_numel];
for o in 0..outer {
// Copy one contiguous run of the chunk into its place in each outer row.
let dst_start = o * total_along_dim * inner + self.offset * inner;
let src_start = o * self.chunk_size * inner;
let row_len = self.chunk_size * inner;
result[dst_start..dst_start + row_len]
.copy_from_slice(&grad_data[src_start..src_start + row_len]);
}
let grad_tensor =
Tensor::from_storage(TensorStorage::cpu(result), orig_shape.to_vec(), false)?;
Ok(vec![Some(restore_device(grad_tensor, device)?)])
}
/// The single forward input.
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
/// Static node name.
fn name(&self) -> &'static str {
"SplitBackward"
}
}
/// Concatenates `tensors` along `axis`.
///
/// All tensors must have the same rank as the first one and identical sizes on
/// every dimension except the concatenation axis. The device of the first
/// tensor selects the implementation:
///
/// * CUDA with a 2-, 4-, or 8-byte element type and a registered GPU backend:
///   each input is written into a zeroed device buffer with `strided_cat`.
/// * CPU: each input's contiguous runs are copied into the output buffer.
///
/// A `CatBackward` node is attached when gradients are enabled and at least
/// one input requires grad.
///
/// # Errors
///
/// * `InvalidArgument` for an empty list or when the first tensor is 0-D.
/// * Whatever `normalize_axis` reports for an invalid `axis`.
/// * `ShapeMismatch` when ranks differ or a non-concatenated dimension differs.
/// * `NotImplementedOnCuda` for CUDA inputs that cannot take the GPU path.
/// * Any error propagated from the backend or tensor constructors.
pub fn cat<T: Float>(tensors: &[Tensor<T>], axis: isize) -> FerrotorchResult<Tensor<T>> {
    if tensors.is_empty() {
        return Err(FerrotorchError::InvalidArgument {
            message: "cat: empty tensor list".into(),
        });
    }
    let ndim = tensors[0].ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "cat: cannot concatenate scalar (0-D) tensors".into(),
        });
    }
    let norm_axis = crate::shape::normalize_axis(axis, ndim)?;

    // Every tensor must agree with the first one outside the concat axis.
    for (i, t) in tensors.iter().enumerate().skip(1) {
        if t.ndim() != ndim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!("cat: tensor {} has {} dims, expected {}", i, t.ndim(), ndim),
            });
        }
        for d in 0..ndim {
            if d != norm_axis && t.shape()[d] != tensors[0].shape()[d] {
                return Err(FerrotorchError::ShapeMismatch {
                    message: format!(
                        "cat: tensor {} has shape {:?}, incompatible with {:?} on axis {}",
                        i,
                        t.shape(),
                        tensors[0].shape(),
                        d
                    ),
                });
            }
        }
    }

    let mut out_shape = tensors[0].shape().to_vec();
    let split_sizes: Vec<usize> = tensors.iter().map(|t| t.shape()[norm_axis]).collect();
    let total_along_axis: usize = split_sizes.iter().sum();
    out_shape[norm_axis] = total_along_axis;

    // View the output as [outer, total_along_axis, inner]; an empty trailing
    // slice yields a product of 1.
    let inner: usize = out_shape[norm_axis + 1..].iter().product();
    let out_numel: usize = out_shape.iter().product();

    let device = tensors[0].device();
    let elem_size = std::mem::size_of::<T>();

    // Device path.
    if device.is_cuda()
        && matches!(elem_size, 2 | 4 | 8)
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
    {
        let device_ord = tensors[0].gpu_handle()?.device_ordinal();
        let mut out_handle = backend.alloc_zeros(out_numel, T::dtype(), device_ord)?;
        // Position along the concat axis at which the next input starts.
        let mut offset = 0usize;
        for t in tensors {
            let t_axis_size = t.shape()[norm_axis];
            let t_numel = t.numel();
            let t_handle = t.gpu_handle()?;
            backend.strided_cat(
                t_handle,
                &mut out_handle,
                total_along_axis,
                offset,
                t_axis_size,
                inner,
                t_numel,
                elem_size,
            )?;
            offset += t_axis_size;
        }
        let any_requires_grad = tensors.iter().any(|t| t.requires_grad());
        let storage = TensorStorage::gpu(out_handle);
        return if is_grad_enabled() && any_requires_grad {
            let grad_fn = Arc::new(CatBackward::new(tensors.to_vec(), norm_axis, split_sizes));
            Tensor::from_operation(storage, out_shape, grad_fn)
        } else {
            Tensor::from_storage(storage, out_shape, false)
        };
    }

    // A CUDA input that did not qualify for the device path is unsupported, so
    // everything below runs with a non-CUDA `device` and needs no transfer.
    if device.is_cuda() {
        return Err(crate::error::FerrotorchError::NotImplementedOnCuda { op: "cat" });
    }

    // Host path: copy each input's contiguous run into every outer row.
    let outer: usize = out_shape[..norm_axis].iter().product();
    let mut out_data = vec![<T as num_traits::Zero>::zero(); out_numel];
    // Flat element offset of the next input within one outer row.
    let mut offset = 0usize;
    for t in tensors {
        let t_data = t.data()?;
        let t_axis_size = t.shape()[norm_axis];
        let row_len = t_axis_size * inner;
        for o in 0..outer {
            let src_start = o * row_len;
            let dst_start = o * total_along_axis * inner + offset;
            out_data[dst_start..dst_start + row_len]
                .copy_from_slice(&t_data[src_start..src_start + row_len]);
        }
        offset += row_len;
    }

    let any_requires_grad = tensors.iter().any(|t| t.requires_grad());
    let storage = TensorStorage::cpu(out_data);
    if is_grad_enabled() && any_requires_grad {
        let grad_fn = Arc::new(CatBackward::new(tensors.to_vec(), norm_axis, split_sizes));
        Tensor::from_operation(storage, out_shape, grad_fn)
    } else {
        Tensor::from_storage(storage, out_shape, false)
    }
}
/// Backward node for a roll along a single dimension.
///
/// Stores what is needed to apply the opposite shift to the incoming
/// gradient.
#[derive(Debug)]
pub struct RollBackward<T: Float> {
// Tensor that was rolled in the forward pass; its shape is reused for the
// gradient and its `requires_grad` flag gates the backward computation.
input: Tensor<T>,
// Signed shift used in the forward pass; backward negates it.
shifts: i64,
// Index into `input.shape()` of the dimension that was rolled.
dim: usize,
}
impl<T: Float> RollBackward<T> {
pub fn new(input: Tensor<T>, shifts: i64, dim: usize) -> Self {
Self { input, shifts, dim }
}
}
impl<T: Float> GradFn<T> for RollBackward<T> {
    /// Computes the gradient of a roll by rolling `grad_output` in the
    /// opposite direction along the same dimension.
    ///
    /// Returns `vec![None]` when gradient tracking is disabled or the input
    /// does not require grad; otherwise a single gradient with the input's
    /// shape.
    ///
    /// # Errors
    /// * `NotImplementedOnCuda` when `grad_output` is on CUDA and either `T`
    ///   is not `f32` or no GPU backend is registered.
    /// * Any error propagated from data access, the GPU backend or tensor
    ///   construction.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // NOTE(review): `ReshapeBackward` has no `is_grad_enabled()` guard;
        // confirm that dropping the gradient under no-grad is intended here.
        if !is_grad_enabled() || !self.input.requires_grad() {
            return Ok(vec![None]);
        }

        let shape = self.input.shape();
        let dim_len = shape[self.dim];

        // Inverse shift normalized into `0..dim_len`. `rem_euclid` yields a
        // non-negative remainder without negating `shifts` first, so
        // `i64::MIN` cannot overflow.
        let inverse_shift: usize = if dim_len == 0 {
            0
        } else {
            let len = dim_len as i64;
            ((len - self.shifts.rem_euclid(len)) % len) as usize
        };

        if grad_output.is_cuda() {
            if is_f32::<T>()
                && let Some(backend) = crate::gpu_dispatch::gpu_backend()
            {
                let handle = if inverse_shift == 0 {
                    // Nothing to move: the gradient is a copy of the upstream one.
                    backend.clone_buffer(grad_output.gpu_handle()?)?
                } else {
                    let outer: usize = shape[..self.dim].iter().product();
                    let inner: usize = shape[self.dim + 1..].iter().product();
                    backend.roll_f32(
                        grad_output.gpu_handle()?,
                        outer,
                        dim_len,
                        inner,
                        inverse_shift,
                    )?
                };
                let grad_tensor =
                    Tensor::from_storage(TensorStorage::gpu(handle), shape.to_vec(), false)?;
                return Ok(vec![Some(grad_tensor)]);
            }
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "roll backward",
            });
        }

        // CPU path: materialize the upstream gradient and roll it back.
        let go_data = grad_output.data_vec()?;
        let grad = if inverse_shift == 0 {
            go_data
        } else {
            crate::ops::tensor_ops::roll_cpu_inner(&go_data, shape, inverse_shift, self.dim)
        };
        let grad_tensor = Tensor::from_storage(TensorStorage::cpu(grad), shape.to_vec(), false)?;
        Ok(vec![Some(grad_tensor)])
    }

    /// The single tensor this node propagates a gradient to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Name reported for this node.
    fn name(&self) -> &'static str {
        "RollBackward"
    }
}
/// Resolves a requested reshape target into concrete dimensions.
///
/// `shape` may contain at most one `-1`, which is replaced by whatever
/// extent makes the total element count equal `numel`. All other entries
/// must be non-negative.
///
/// # Errors
/// * `InvalidArgument` if more than one entry is `-1`, if any entry is
///   negative other than `-1`, or if a `-1` is combined with a zero-sized
///   dimension (the inferred extent would be ambiguous).
/// * `ShapeMismatch` if the requested shape cannot hold exactly `numel`
///   elements, including when the product of the explicit dimensions does
///   not fit in `usize`.
fn resolve_shape(shape: &[isize], numel: usize) -> FerrotorchResult<Vec<usize>> {
    let mismatch = || FerrotorchError::ShapeMismatch {
        message: format!(
            "reshape: cannot reshape tensor of {numel} elements into shape {shape:?}"
        ),
    };

    let mut inferred_idx: Option<usize> = None;
    let mut has_zero_dim = false;
    // Product of the explicit non-zero dimensions; `None` once it overflows.
    // Zeros are tracked separately so that a shape such as `[huge, huge, 0]`
    // still resolves to zero elements instead of reporting an overflow.
    let mut nonzero_product: Option<usize> = Some(1);
    let mut result: Vec<usize> = Vec::with_capacity(shape.len());

    for (i, &dim) in shape.iter().enumerate() {
        if dim == -1 {
            if inferred_idx.is_some() {
                return Err(FerrotorchError::InvalidArgument {
                    message: "reshape: only one dimension can be -1".into(),
                });
            }
            inferred_idx = Some(i);
            result.push(0); // placeholder, filled in below
        } else if dim < 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("reshape: invalid dimension {dim}"),
            });
        } else {
            let extent = dim as usize; // non-negative, so the cast is lossless
            if extent == 0 {
                has_zero_dim = true;
            } else {
                nonzero_product = nonzero_product.and_then(|p| p.checked_mul(extent));
            }
            result.push(extent);
        }
    }

    if let Some(idx) = inferred_idx {
        if has_zero_dim {
            return Err(FerrotorchError::InvalidArgument {
                message: "reshape: cannot infer dimension with zero-size dimensions".into(),
            });
        }
        // An overflowing product cannot evenly divide any representable `numel`.
        let product = nonzero_product.ok_or_else(&mismatch)?;
        if !numel.is_multiple_of(product) {
            return Err(mismatch());
        }
        result[idx] = numel / product;
    } else {
        let product = if has_zero_dim { Some(0) } else { nonzero_product };
        if product != Some(numel) {
            return Err(mismatch());
        }
    }
    Ok(result)
}
/// Maps a row-major flat index into a tensor of shape `out_shape` to the
/// row-major flat index of the matching element in a tensor of shape
/// `in_shape` that was broadcast to `out_shape`.
///
/// Shapes are aligned on their trailing axes. An input axis of extent 1
/// always contributes coordinate 0; any other axis reuses the output
/// coordinate. Leading output axes with no input counterpart are ignored.
///
/// Panics if `in_shape` has more axes than `out_shape`, or if an aligned
/// output axis has extent 0.
fn broadcast_flat_index(flat: usize, out_shape: &[usize], in_shape: &[usize]) -> usize {
    // Trailing output axes that line up with the input's axes.
    let aligned_out = &out_shape[out_shape.len() - in_shape.len()..];

    let mut source_index = 0usize;
    let mut source_stride = 1usize;
    let mut target_stride = 1usize;
    // Walk from the innermost axis outwards, accumulating strides as we go.
    for (&target_extent, &source_extent) in aligned_out.iter().rev().zip(in_shape.iter().rev()) {
        let target_coord = (flat / target_stride) % target_extent;
        if source_extent != 1 {
            source_index += target_coord * source_stride;
        }
        source_stride *= source_extent;
        target_stride *= target_extent;
    }
    source_index
}
// Unit tests for the shape ops. All fixtures are CPU f32 tensors; gradients
// are driven through `backward` using the small test-only grad fns below.
#[cfg(test)]
mod tests {
use super::*;
use crate::autograd::backward;
// Builds a CPU f32 tensor from `data` with the given `shape` and
// `requires_grad` flag; panics if construction fails.
fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
shape.to_vec(),
requires_grad,
)
.unwrap()
}
// Test-only grad fn for a sum reduction: backward ignores the incoming
// gradient and returns a tensor of ones with the input's shape.
#[derive(Debug)]
struct SumBackward<T: Float> {
input: Tensor<T>,
}
impl<T: Float> GradFn<T> for SumBackward<T> {
fn backward(&self, _grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let n = self.input.numel();
let ones = vec![<T as num_traits::One>::one(); n];
let g =
Tensor::from_storage(TensorStorage::cpu(ones), self.input.shape().to_vec(), false)?;
Ok(vec![Some(g)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"SumBackward"
}
}
// Sums all elements of `t` into a 0-D tensor with `SumBackward` attached, so
// it can serve as a scalar loss.
fn sum_to_scalar(t: &Tensor<f32>) -> Tensor<f32> {
let data = t.data().unwrap();
let total: f32 = data.iter().sum();
Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(SumBackward { input: t.clone() }),
)
.unwrap()
}
// reshape changes the shape and keeps the element order.
#[test]
fn test_reshape_forward() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let y = reshape(&x, &[3, 2]).unwrap();
assert_eq!(y.shape(), &[3, 2]);
assert_eq!(y.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
// A -1 entry is inferred from the element count.
#[test]
fn test_reshape_infer_dim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], false);
let y = reshape(&x, &[2, -1]).unwrap();
assert_eq!(y.shape(), &[2, 3]);
}
// The gradient through reshape has the input's original shape and is all ones
// for a sum loss.
#[test]
fn test_reshape_backward() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let y = reshape(&x, &[3, 2]).unwrap();
let loss = sum_to_scalar(&y);
backward(&loss).unwrap();
let grad = x.grad().unwrap().expect("x should have a gradient");
assert_eq!(grad.shape(), &[2, 3]);
for &v in grad.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6, "expected 1.0, got {v}");
}
}
// Reshaping 3 elements into 2x2 is rejected.
#[test]
fn test_reshape_shape_mismatch() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
assert!(reshape(&x, &[2, 2]).is_err());
}
// flatten yields a 1-D tensor with the same element order.
#[test]
fn test_flatten_forward() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let y = flatten(&x).unwrap();
assert_eq!(y.shape(), &[4]);
assert_eq!(y.data().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
}
// The gradient through flatten is restored to the input's shape.
#[test]
fn test_flatten_backward() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let y = flatten(&x).unwrap();
let loss = sum_to_scalar(&y);
backward(&loss).unwrap();
let grad = x.grad().unwrap().expect("x should have a gradient");
assert_eq!(grad.shape(), &[2, 3]);
}
// squeeze removes a size-1 axis.
#[test]
fn test_squeeze_forward() {
let x = leaf(&[1.0, 2.0, 3.0], &[1, 3], false);
let y = squeeze(&x, 0).unwrap();
assert_eq!(y.shape(), &[3]);
}
// squeeze on an axis whose extent is not 1 is an error.
#[test]
fn test_squeeze_non_one_error() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
assert!(squeeze(&x, 0).is_err());
}
// unsqueeze inserts a size-1 axis; a negative axis counts from the end.
#[test]
fn test_unsqueeze_forward() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
let y = unsqueeze(&x, 0).unwrap();
assert_eq!(y.shape(), &[1, 3]);
let z = unsqueeze(&x, -1).unwrap();
assert_eq!(z.shape(), &[3, 1]);
}
// unsqueeze followed by squeeze on the same axis restores shape and data.
#[test]
fn test_squeeze_unsqueeze_roundtrip() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let y = unsqueeze(&x, 1).unwrap();
assert_eq!(y.shape(), &[3, 1]);
let z = squeeze(&y, 1).unwrap();
assert_eq!(z.shape(), &[3]);
assert_eq!(z.data().unwrap(), &[1.0, 2.0, 3.0]);
}
// transpose_2d swaps rows and columns.
#[test]
fn test_transpose_2d_forward() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let y = transpose_2d(&x).unwrap();
assert_eq!(y.shape(), &[3, 2]);
assert_eq!(y.data_vec().unwrap(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
// cat along axis 0 appends the rows of the second tensor.
#[test]
fn test_cat_forward_axis0() {
let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let b = leaf(&[5.0, 6.0], &[1, 2], false);
let c = cat(&[a, b], 0).unwrap();
assert_eq!(c.shape(), &[3, 2]);
assert_eq!(c.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
// cat along axis 1 interleaves the inputs row by row.
#[test]
fn test_cat_forward_axis1() {
let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let b = leaf(&[5.0, 6.0], &[2, 1], false);
let c = cat(&[a, b], 1).unwrap();
assert_eq!(c.shape(), &[2, 3]);
assert_eq!(c.data().unwrap(), &[1.0, 2.0, 5.0, 3.0, 4.0, 6.0]);
}
// Each cat input receives an all-ones gradient of its own shape (axis 0).
#[test]
fn test_cat_backward_axis0() {
let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], true);
let b = leaf(&[5.0, 6.0], &[1, 2], true);
let c = cat(&[a.clone(), b.clone()], 0).unwrap();
let loss = sum_to_scalar(&c);
backward(&loss).unwrap();
let a_grad = a.grad().unwrap().expect("a should have gradient");
assert_eq!(a_grad.shape(), &[2, 2]);
for &v in a_grad.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6);
}
let b_grad = b.grad().unwrap().expect("b should have gradient");
assert_eq!(b_grad.shape(), &[1, 2]);
for &v in b_grad.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6);
}
}
// Same as above, concatenating along axis 1.
#[test]
fn test_cat_backward_axis1() {
let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], true);
let b = leaf(&[5.0, 6.0], &[2, 1], true);
let c = cat(&[a.clone(), b.clone()], 1).unwrap();
let loss = sum_to_scalar(&c);
backward(&loss).unwrap();
let a_grad = a.grad().unwrap().expect("a should have gradient");
assert_eq!(a_grad.shape(), &[2, 2]);
for &v in a_grad.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6);
}
let b_grad = b.grad().unwrap().expect("b should have gradient");
assert_eq!(b_grad.shape(), &[2, 1]);
for &v in b_grad.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6);
}
}
// Only the input that requires grad ends up with a gradient.
#[test]
fn test_cat_backward_mixed_requires_grad() {
let a = leaf(&[1.0, 2.0], &[2], true);
let b = leaf(&[3.0, 4.0], &[2], false);
let c = cat(&[a.clone(), b.clone()], 0).unwrap();
let loss = sum_to_scalar(&c);
backward(&loss).unwrap();
let a_grad = a.grad().unwrap().expect("a should have gradient");
assert_eq!(a_grad.shape(), &[2]);
for &v in a_grad.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6);
}
assert!(b.grad().unwrap().is_none());
}
// cat of an empty list is an error.
#[test]
fn test_cat_empty_error() {
let result: FerrotorchResult<Tensor<f32>> = cat(&[], 0);
assert!(result.is_err());
}
// cat of 1-D tensors with different lengths.
#[test]
fn test_cat_1d() {
let a = leaf(&[1.0, 2.0], &[2], false);
let b = leaf(&[3.0, 4.0, 5.0], &[3], false);
let c = cat(&[a, b], 0).unwrap();
assert_eq!(c.shape(), &[5]);
assert_eq!(c.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
}
// Inside no_grad, reshape attaches no grad_fn even if the input requires grad.
#[test]
fn test_reshape_no_grad() {
crate::autograd::no_grad(|| {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], true);
let y = reshape(&x, &[2, 2]).unwrap();
assert!(y.grad_fn().is_none());
});
}
// resolve_shape passes a fully specified, matching shape through.
#[test]
fn test_resolve_shape_basic() {
assert_eq!(resolve_shape(&[2, 3], 6).unwrap(), vec![2, 3]);
}
// resolve_shape infers a single -1 in any position.
#[test]
fn test_resolve_shape_infer() {
assert_eq!(resolve_shape(&[2, -1], 6).unwrap(), vec![2, 3]);
assert_eq!(resolve_shape(&[-1, 2], 6).unwrap(), vec![3, 2]);
assert_eq!(resolve_shape(&[-1], 6).unwrap(), vec![6]);
}
// Two -1 entries are rejected.
#[test]
fn test_resolve_shape_multiple_infer_error() {
assert!(resolve_shape(&[-1, -1], 6).is_err());
}
// A shape whose product differs from the element count is rejected.
#[test]
fn test_resolve_shape_mismatch() {
assert!(resolve_shape(&[2, 2], 6).is_err());
}
// squeeze of a grad-requiring input yields a non-leaf output with a grad_fn.
#[test]
fn test_squeeze_preserves_grad_fn() {
let x = leaf(&[1.0, 2.0, 3.0], &[1, 3], true);
let y = squeeze(&x, 0).unwrap();
assert!(y.grad_fn().is_some(), "squeeze must attach a grad_fn");
assert!(!y.is_leaf(), "squeeze output must be non-leaf");
assert!(y.requires_grad(), "squeeze output must require grad");
}
// Same autograd wiring check for unsqueeze.
#[test]
fn test_unsqueeze_preserves_grad_fn() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let y = unsqueeze(&x, 0).unwrap();
assert!(y.grad_fn().is_some(), "unsqueeze must attach a grad_fn");
assert!(!y.is_leaf(), "unsqueeze output must be non-leaf");
assert!(y.requires_grad(), "unsqueeze output must require grad");
}
// Same autograd wiring check for flatten.
#[test]
fn test_flatten_preserves_grad_fn() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], true);
let y = flatten(&x).unwrap();
assert!(y.grad_fn().is_some(), "flatten must attach a grad_fn");
assert!(!y.is_leaf(), "flatten output must be non-leaf");
assert!(y.requires_grad(), "flatten output must require grad");
}
// Backward through squeeze reaches the leaf with the leaf's original shape.
#[test]
fn test_squeeze_backward_reaches_leaf() {
let x = leaf(&[1.0, 2.0, 3.0], &[3, 1], true);
let squeezed = squeeze(&x, 1).unwrap();
let loss = sum_to_scalar(&squeezed);
backward(&loss).unwrap();
let grad = x
.grad()
.unwrap()
.expect("squeeze must propagate gradients to leaf input");
assert_eq!(grad.shape(), &[3, 1]);
for &v in grad.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6, "expected gradient 1.0, got {v}");
}
}
// Backward through unsqueeze reaches the leaf with the leaf's original shape.
#[test]
fn test_unsqueeze_backward_reaches_leaf() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let unsqueezed = unsqueeze(&x, 1).unwrap();
let loss = sum_to_scalar(&unsqueezed);
backward(&loss).unwrap();
let grad = x
.grad()
.unwrap()
.expect("unsqueeze must propagate gradients to leaf input");
assert_eq!(grad.shape(), &[3]);
for &v in grad.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6, "expected gradient 1.0, got {v}");
}
}
// squeeze at the end of a mul -> mm chain: since x is scaled by 2 and then
// summed, every element of x gets gradient 2.
#[test]
fn test_squeeze_in_longer_chain() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], true);
let two = leaf(&[2.0; 6], &[3, 2], false);
let scaled = crate::grad_fns::arithmetic::mul(&x, &two).unwrap();
let ones = leaf(&[1.0, 1.0], &[2, 1], false);
let row_sums = crate::grad_fns::linalg::mm_differentiable(&scaled, &ones).unwrap();
let squeezed = squeeze(&row_sums, 1).unwrap();
assert!(squeezed.grad_fn().is_some(), "squeeze must preserve graph");
let loss = sum_to_scalar(&squeezed);
backward(&loss).unwrap();
let grad = x
.grad()
.unwrap()
.expect("backward through squeeze in a longer chain must reach leaf parameters");
assert_eq!(grad.shape(), &[3, 2]);
for &v in grad.data().unwrap() {
assert!((v - 2.0).abs() < 1e-6, "expected gradient 2.0, got {v}");
}
}
// flatten, squeeze and unsqueeze must share storage with their input.
#[test]
fn test_shape_ops_share_storage_with_input() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let flat = flatten(&x).unwrap();
assert_eq!(flat.data().unwrap(), x.data().unwrap());
assert_eq!(flat.shape(), &[6]);
assert!(
flat.shares_storage(&x),
"flatten should share storage with input (zero-copy)"
);
let orig = leaf(&[1.0, 2.0, 3.0], &[1, 3], true);
let sq2 = squeeze(&orig, 0).unwrap();
assert!(
sq2.shares_storage(&orig),
"squeeze should share storage with input (zero-copy)"
);
let orig3 = leaf(&[1.0, 2.0, 3.0], &[3], true);
let us = unsqueeze(&orig3, 0).unwrap();
assert!(
us.shares_storage(&orig3),
"unsqueeze should share storage with input (zero-copy)"
);
}
// squeeze of an input that does not require grad has no grad_fn.
#[test]
fn test_squeeze_no_grad_is_view() {
let x = leaf(&[1.0, 2.0, 3.0], &[1, 3], false);
let y = squeeze(&x, 0).unwrap();
assert!(y.grad_fn().is_none());
assert_eq!(y.shape(), &[3]);
assert_eq!(y.data().unwrap(), &[1.0, 2.0, 3.0]);
}
// roll of a grad-requiring input attaches a RollBackward node.
#[test]
fn test_roll_forward_registers_grad_fn() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[5], true);
let y = crate::ops::tensor_ops::roll(&x, 2, 0).unwrap();
assert_eq!(y.data().unwrap(), &[4.0, 5.0, 1.0, 2.0, 3.0]);
assert!(y.requires_grad());
assert!(!y.is_leaf());
assert_eq!(y.grad_fn().unwrap().name(), "RollBackward");
}
// A shift of 0, or of a full dimension length, leaves the data unchanged.
#[test]
fn test_roll_zero_shift_early_return() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let y = crate::ops::tensor_ops::roll(&x, 0, 0).unwrap();
assert_eq!(y.data().unwrap(), &[1.0, 2.0, 3.0]);
let y2 = crate::ops::tensor_ops::roll(&x, 3, 0).unwrap();
assert_eq!(y2.data().unwrap(), &[1.0, 2.0, 3.0]);
}
// With y = roll(x, 2) and loss = sum(w[i] * y[i]), x[j] lands at y[(j + 2) % 5],
// so dloss/dx[j] = w[(j + 2) % 5] = [3, 4, 5, 1, 2].
#[test]
fn test_roll_backward_simple_1d_hand_computed() {
let x = leaf(&[10.0, 20.0, 30.0, 40.0, 50.0], &[5], true);
let y = crate::ops::tensor_ops::roll(&x, 2, 0).unwrap();
// Test-only grad fn whose backward returns the fixed `weights` as the
// gradient for its input, ignoring the incoming gradient.
#[derive(Debug)]
struct WeightedSumBackward<T: Float> {
input: Tensor<T>,
weights: Vec<T>,
}
impl<T: Float> GradFn<T> for WeightedSumBackward<T> {
fn backward(
&self,
_grad_output: &Tensor<T>,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let g = Tensor::from_storage(
TensorStorage::cpu(self.weights.clone()),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(g)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"WeightedSumBackward"
}
}
let w = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
let total: f32 = y
.data()
.unwrap()
.iter()
.zip(w.iter())
.map(|(yi, wi)| yi * wi)
.sum();
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(WeightedSumBackward {
input: y.clone(),
weights: w,
}),
)
.unwrap();
backward(&loss).unwrap();
let grad = x.grad().unwrap().expect("x should have a gradient");
let gd = grad.data().unwrap();
let expected = [3.0, 4.0, 5.0, 1.0, 2.0];
for (i, (&g, &e)) in gd.iter().zip(expected.iter()).enumerate() {
assert!((g - e).abs() < 1e-6, "grad[{i}] = {g}, expected {e}");
}
}
// Calls RollBackward::backward directly: for a forward shift of -1 along
// dim 1 the gradient is the upstream gradient shifted by +1 within each row.
#[test]
fn test_roll_backward_negative_shift_2d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let y = crate::ops::tensor_ops::roll(&x, -1, 1).unwrap();
assert_eq!(y.data().unwrap(), &[2.0, 3.0, 1.0, 5.0, 6.0, 4.0]);
let grad_output = Tensor::from_storage(
TensorStorage::cpu(vec![1.0_f32, 10.0, 100.0, 1000.0, 10000.0, 100000.0]),
vec![2, 3],
false,
)
.unwrap();
let grad_fn = y.grad_fn().expect("y must carry RollBackward");
let grads = grad_fn.backward(&grad_output).unwrap();
let g = grads[0].as_ref().expect("grad must be Some");
let gd = g.data().unwrap();
let expected = [100.0_f32, 1.0, 10.0, 100000.0, 1000.0, 10000.0];
for (i, (&got, &exp)) in gd.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < 1e-6,
"grad[{i}] = {got}, expected {exp}"
);
}
}
// unflatten splits one axis into the given sizes without reordering data.
#[test]
fn test_unflatten_forward() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let y = unflatten(&x, 0, &[2, 1]).unwrap();
assert_eq!(y.shape(), &[2, 1, 3]);
assert_eq!(y.data_vec().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
// unflatten accepts a -1 size that is inferred.
#[test]
fn test_unflatten_infer_slot() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], false);
let y = unflatten(&x, 0, &[2, -1]).unwrap();
assert_eq!(y.shape(), &[2, 3]);
assert_eq!(y.data_vec().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
// A negative dim addresses a middle axis; the new sizes are spliced in place.
#[test]
fn test_unflatten_negative_dim_and_middle_splice() {
let x = leaf(
&(0..24).map(|v| v as f32).collect::<Vec<_>>(),
&[2, 12, 1],
false,
);
let y = unflatten(&x, -2, &[3, 4]).unwrap();
assert_eq!(y.shape(), &[2, 3, 4, 1]);
}
// An empty size list is rejected.
#[test]
fn test_unflatten_empty_sizes_errors() {
let x = leaf(&[1.0, 2.0], &[2], false);
assert!(unflatten(&x, 0, &[]).is_err());
}
// Sizes whose product differs from the axis extent are rejected.
#[test]
fn test_unflatten_product_mismatch_errors() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], false);
assert!(unflatten(&x, 0, &[2, 4]).is_err());
}
// Backward through unflatten gives the leaf an all-ones gradient of its shape.
#[test]
fn test_unflatten_backward_reaches_leaf() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], true);
let y = unflatten(&x, 0, &[2, 3]).unwrap();
let loss = sum_to_scalar(&y);
backward(&loss).unwrap();
let g = x.grad().unwrap().expect("x should have gradient");
assert_eq!(g.shape(), &[6]);
for &v in g.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6);
}
}
// swapaxes matches Tensor::transpose on a 2-D input.
#[test]
fn test_swapaxes_equals_transpose() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let via_swap = swapaxes(&x, 0, 1).unwrap();
let via_transpose = x.transpose(0, 1).unwrap();
assert_eq!(via_swap.shape(), via_transpose.shape());
assert_eq!(
via_swap.data_vec().unwrap(),
via_transpose.data_vec().unwrap()
);
assert_eq!(via_swap.shape(), &[3, 2]);
assert_eq!(
via_swap.data_vec().unwrap(),
&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
);
}
// swapdims matches Tensor::transpose on a 3-D input.
#[test]
fn test_swapdims_equals_transpose() {
let x = leaf(
&(0..24).map(|v| v as f32).collect::<Vec<_>>(),
&[2, 3, 4],
false,
);
let via_swap = swapdims(&x, 0, 2).unwrap();
let via_transpose = x.transpose(0, 2).unwrap();
assert_eq!(via_swap.shape(), &[4, 3, 2]);
assert_eq!(
via_swap.data_vec().unwrap(),
via_transpose.data_vec().unwrap()
);
}
// Backward through swapaxes (made contiguous before summing) reaches the leaf.
#[test]
fn test_swapaxes_backward_reaches_leaf() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let y = crate::methods::contiguous_t(&swapaxes(&x, 0, 1).unwrap()).unwrap();
let loss = sum_to_scalar(&y);
backward(&loss).unwrap();
let g = x.grad().unwrap().expect("x should have gradient");
assert_eq!(g.shape(), &[2, 3]);
for &v in g.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6);
}
}
// expand_as(x, other) matches expand(x, other.shape()).
#[test]
fn test_expand_as_equals_expand() {
let x = leaf(&[1.0, 2.0, 3.0], &[1, 3], false);
let other = leaf(&[0.0; 12], &[4, 3], false);
let via_expand_as = expand_as(&x, &other).unwrap();
let via_expand = expand(&x, &[4, 3]).unwrap();
assert_eq!(via_expand_as.shape(), &[4, 3]);
assert_eq!(
via_expand_as.data_vec().unwrap(),
via_expand.data_vec().unwrap()
);
assert_eq!(
via_expand_as.data_vec().unwrap(),
&[1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
);
}
// Expanding 1x3 to 4x3 and summing gives each input element gradient 4.
#[test]
fn test_expand_as_backward_sums_broadcast_axes() {
let x = leaf(&[1.0, 2.0, 3.0], &[1, 3], true);
let other = leaf(&[0.0; 12], &[4, 3], false);
let y = expand_as(&x, &other).unwrap();
let loss = sum_to_scalar(&y);
backward(&loss).unwrap();
let g = x.grad().unwrap().expect("x should have gradient");
assert_eq!(g.shape(), &[1, 3]);
for &v in g.data().unwrap() {
assert!((v - 4.0).abs() < 1e-6, "expected 4.0, got {v}");
}
}
// flip reverses a 1-D tensor.
#[test]
fn test_flip_forward_1d() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let y = flip(&x, &[0]).unwrap();
assert_eq!(y.shape(), &[4]);
assert_eq!(y.data_vec().unwrap(), &[4.0, 3.0, 2.0, 1.0]);
}
// Flipping both axes of a 2-D tensor reverses the flat element order.
#[test]
fn test_flip_forward_2d_both_dims() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let y = flip(&x, &[0, 1]).unwrap();
assert_eq!(y.shape(), &[2, 3]);
assert_eq!(y.data_vec().unwrap(), &[6.0, 5.0, 4.0, 3.0, 2.0, 1.0]);
}
// Flipping only axis 1 reverses each row.
#[test]
fn test_flip_forward_2d_single_dim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let y = flip(&x, &[1]).unwrap();
assert_eq!(y.data_vec().unwrap(), &[3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
}
// Listing the same axis twice is rejected.
#[test]
fn test_flip_rejects_duplicate_dim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
assert!(flip(&x, &[0, 0]).is_err());
}
// With upstream weights [1, 2, 3] on the flipped tensor, the gradient on x is
// those weights flipped back: [3, 2, 1].
#[test]
fn test_flip_backward_is_self_inverse() {
let x = leaf(&[10.0, 20.0, 30.0], &[3], true);
let y = flip(&x, &[0]).unwrap();
assert_eq!(y.data_vec().unwrap(), &[30.0, 20.0, 10.0]);
// Test-only grad fn whose backward returns the fixed weights `w`.
#[derive(Debug)]
struct WSum {
input: Tensor<f32>,
w: Vec<f32>,
}
impl GradFn<f32> for WSum {
fn backward(&self, _g: &Tensor<f32>) -> FerrotorchResult<Vec<Option<Tensor<f32>>>> {
Ok(vec![Some(
Tensor::from_storage(
TensorStorage::cpu(self.w.clone()),
self.input.shape().to_vec(),
false,
)
.unwrap(),
)])
}
fn inputs(&self) -> Vec<&Tensor<f32>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"WSum"
}
}
let w = vec![1.0_f32, 2.0, 3.0];
let total: f32 = y.data().unwrap().iter().zip(&w).map(|(a, b)| a * b).sum();
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(WSum {
input: y.clone(),
w,
}),
)
.unwrap();
backward(&loss).unwrap();
let g = x.grad().unwrap().expect("x should have gradient");
assert_eq!(g.data().unwrap(), &[3.0, 2.0, 1.0]);
}
// fliplr matches flip on axis 1 and rejects a 1-D input.
#[test]
fn test_fliplr_equals_flip_dim1() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
assert_eq!(
fliplr(&x).unwrap().data_vec().unwrap(),
flip(&x, &[1]).unwrap().data_vec().unwrap()
);
assert!(fliplr(&leaf(&[1.0, 2.0], &[2], false)).is_err());
}
// flipud matches flip on axis 0.
#[test]
fn test_flipud_equals_flip_dim0() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
assert_eq!(
flipud(&x).unwrap().data_vec().unwrap(),
flip(&x, &[0]).unwrap().data_vec().unwrap()
);
}
// rot90 with k = 1 over axes (0, 1) on a 2x2 input.
#[test]
fn test_rot90_k1() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let y = crate::methods::contiguous_t(&rot90(&x, 1, &[0, 1]).unwrap()).unwrap();
assert_eq!(y.shape(), &[2, 2]);
assert_eq!(y.data_vec().unwrap(), &[2.0, 4.0, 1.0, 3.0]);
}
// rot90 with k = 2 reverses the flat order of a 2x2 input.
#[test]
fn test_rot90_k2_is_flip_both() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let y = rot90(&x, 2, &[0, 1]).unwrap();
assert_eq!(y.data_vec().unwrap(), &[4.0, 3.0, 2.0, 1.0]);
}
// k = 0 and k = 4 both leave the data unchanged.
#[test]
fn test_rot90_k0_and_k4_identity() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
assert_eq!(
rot90(&x, 0, &[0, 1]).unwrap().data_vec().unwrap(),
&[1.0, 2.0, 3.0, 4.0]
);
assert_eq!(
rot90(&x, 4, &[0, 1]).unwrap().data_vec().unwrap(),
&[1.0, 2.0, 3.0, 4.0]
);
}
// A negative k rotates in the opposite direction.
#[test]
fn test_rot90_negative_k() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let y = crate::methods::contiguous_t(&rot90(&x, -1, &[0, 1]).unwrap()).unwrap();
assert_eq!(y.data_vec().unwrap(), &[3.0, 1.0, 4.0, 2.0]);
}
// Backward through rot90 gives the leaf an all-ones gradient of its shape.
#[test]
fn test_rot90_backward_reaches_leaf() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], true);
let y = crate::methods::contiguous_t(&rot90(&x, 1, &[0, 1]).unwrap()).unwrap();
let loss = sum_to_scalar(&y);
backward(&loss).unwrap();
let g = x.grad().unwrap().expect("x should have gradient");
assert_eq!(g.shape(), &[2, 2]);
for &v in g.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6);
}
}
// Moving axis 0 to position 2 matches permute with [1, 2, 0].
#[test]
fn test_movedim_single() {
let x = leaf(
&(0..24).map(|v| v as f32).collect::<Vec<_>>(),
&[2, 3, 4],
false,
);
let y = movedim(&x, &[0], &[2]).unwrap();
assert_eq!(y.shape(), &[3, 4, 2]);
let viap = crate::methods::permute_t(&x, &[1, 2, 0]).unwrap();
assert_eq!(
crate::methods::contiguous_t(&y)
.unwrap()
.data_vec()
.unwrap(),
crate::methods::contiguous_t(&viap)
.unwrap()
.data_vec()
.unwrap()
);
}
// Moving two axes at once gives the same shape as the equivalent permute.
#[test]
fn test_movedim_multi() {
let x = leaf(&vec![0.0; 2 * 3 * 4 * 5 * 6], &[2, 3, 4, 5, 6], false);
let y = movedim(&x, &[0, 1], &[2, 4]).unwrap();
let viap = crate::methods::permute_t(&x, &[2, 3, 0, 4, 1]).unwrap();
assert_eq!(y.shape(), viap.shape());
assert_eq!(y.shape(), &[4, 5, 2, 6, 3]);
}
// moveaxis and movedim produce the same shape.
#[test]
fn test_moveaxis_equals_movedim() {
let x = leaf(
&(0..24).map(|v| v as f32).collect::<Vec<_>>(),
&[2, 3, 4],
false,
);
assert_eq!(
moveaxis(&x, &[2], &[0]).unwrap().shape(),
movedim(&x, &[2], &[0]).unwrap().shape()
);
}
// Backward through movedim gives the leaf an all-ones gradient of its shape.
#[test]
fn test_movedim_backward_reaches_leaf() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let y = crate::methods::contiguous_t(&movedim(&x, &[0], &[1]).unwrap()).unwrap();
let loss = sum_to_scalar(&y);
backward(&loss).unwrap();
let g = x.grad().unwrap().expect("x should have gradient");
assert_eq!(g.shape(), &[2, 3]);
for &v in g.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6);
}
}
// Repeated source or destination axes, or lists of unequal length, are rejected.
#[test]
fn test_movedim_rejects_repeated_dim() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
assert!(movedim(&x, &[0, 0], &[0, 1]).is_err());
assert!(movedim(&x, &[0, 1], &[1, 1]).is_err());
assert!(movedim(&x, &[0], &[0, 1]).is_err());
}
// broadcast_to matches expand for the same target shape.
#[test]
fn test_broadcast_to_equals_expand() {
let x = leaf(&[1.0, 2.0, 3.0], &[1, 3], false);
let y = broadcast_to(&x, &[2, 3]).unwrap();
let e = expand(&x, &[2, 3]).unwrap();
assert_eq!(y.shape(), &[2, 3]);
assert_eq!(y.data_vec().unwrap(), e.data_vec().unwrap());
assert_eq!(y.data_vec().unwrap(), &[1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
}
// Shapes [3, 1] and [1, 4] are both broadcast to the common shape [3, 4].
#[test]
fn test_broadcast_tensors_common_shape() {
let a = leaf(&[1.0, 2.0, 3.0], &[3, 1], false);
let b = leaf(&[10.0, 20.0, 30.0, 40.0], &[1, 4], false);
let out = broadcast_tensors(&[a, b]).unwrap();
assert_eq!(out.len(), 2);
assert_eq!(out[0].shape(), &[3, 4]);
assert_eq!(out[1].shape(), &[3, 4]);
assert_eq!(
out[0].data_vec().unwrap(),
&[1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
);
}
// repeat tiles the whole 1-D tensor.
#[test]
fn test_repeat_1d() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
let y = repeat(&x, &[2]).unwrap();
assert_eq!(y.shape(), &[6]);
assert_eq!(y.data_vec().unwrap(), &[1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
}
// More repeat counts than input axes adds a new leading axis.
#[test]
fn test_repeat_2d_with_new_leading_dim() {
let x = leaf(&[1.0, 2.0], &[2], false);
let y = repeat(&x, &[2, 2]).unwrap();
assert_eq!(y.shape(), &[2, 4]);
assert_eq!(
y.data_vec().unwrap(),
&[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
);
}
// Fewer repeat counts than input axes is rejected by repeat.
#[test]
fn test_repeat_rejects_too_few_dims() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
assert!(repeat(&x, &[2]).is_err());
}
// tile accepts fewer reps than axes; here only the last axis is tiled.
#[test]
fn test_tile_pads_reps() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let y = tile(&x, &[2]).unwrap();
assert_eq!(y.shape(), &[2, 4]);
assert_eq!(
y.data_vec().unwrap(),
&[1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]
);
}
// Repeating 3 times and summing gives each input element gradient 3.
#[test]
fn test_repeat_backward_accumulates() {
let x = leaf(&[1.0, 2.0], &[2], true);
let y = repeat(&x, &[3]).unwrap();
assert_eq!(y.shape(), &[6]);
let loss = sum_to_scalar(&y);
backward(&loss).unwrap();
let g = x.grad().unwrap().expect("x should have gradient");
assert_eq!(g.shape(), &[2]);
for &v in g.data().unwrap() {
assert!((v - 3.0).abs() < 1e-6, "expected 3.0, got {v}");
}
}
// repeat_interleave duplicates each element in place.
#[test]
fn test_repeat_interleave_1d() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
let y = repeat_interleave(&x, 2, 0).unwrap();
assert_eq!(y.shape(), &[6]);
assert_eq!(y.data_vec().unwrap(), &[1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
}
// repeat_interleave along axis 1 of a 2-D tensor.
#[test]
fn test_repeat_interleave_2d_dim1() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let y = repeat_interleave(&x, 2, 1).unwrap();
assert_eq!(y.shape(), &[2, 4]);
assert_eq!(
y.data_vec().unwrap(),
&[1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
);
}
// repeat_interleave and repeat order their output differently.
#[test]
fn test_repeat_interleave_differs_from_repeat() {
let x = leaf(&[1.0, 2.0], &[2], false);
assert_eq!(
repeat_interleave(&x, 2, 0).unwrap().data_vec().unwrap(),
&[1.0, 1.0, 2.0, 2.0]
);
assert_eq!(
repeat(&x, &[2]).unwrap().data_vec().unwrap(),
&[1.0, 2.0, 1.0, 2.0]
);
}
// Interleaving 3 times and summing gives each input element gradient 3.
#[test]
fn test_repeat_interleave_backward_sums_segments() {
let x = leaf(&[1.0, 2.0], &[2], true);
let y = repeat_interleave(&x, 3, 0).unwrap();
let loss = sum_to_scalar(&y);
backward(&loss).unwrap();
let g = x.grad().unwrap().expect("x should have gradient");
assert_eq!(g.shape(), &[2]);
for &v in g.data().unwrap() {
assert!((v - 3.0).abs() < 1e-6, "expected 3.0, got {v}");
}
}
// unbind along axis 0 yields one tensor per row.
#[test]
fn test_unbind_dim0() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let parts = unbind(&x, 0).unwrap();
assert_eq!(parts.len(), 2);
assert_eq!(parts[0].shape(), &[3]);
assert_eq!(
crate::methods::contiguous_t(&parts[0])
.unwrap()
.data_vec()
.unwrap(),
&[1.0, 2.0, 3.0]
);
assert_eq!(
crate::methods::contiguous_t(&parts[1])
.unwrap()
.data_vec()
.unwrap(),
&[4.0, 5.0, 6.0]
);
}
// unbind along axis 1 yields one tensor per column.
#[test]
fn test_unbind_dim1() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let parts = unbind(&x, 1).unwrap();
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].shape(), &[2]);
assert_eq!(
crate::methods::contiguous_t(&parts[1])
.unwrap()
.data_vec()
.unwrap(),
&[2.0, 5.0]
);
}
// Only the slice used in the loss gets a non-zero gradient.
#[test]
fn test_unbind_backward_scatters() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], true);
let parts = unbind(&x, 0).unwrap();
let loss = sum_to_scalar(&crate::methods::contiguous_t(&parts[1]).unwrap());
backward(&loss).unwrap();
let g = x.grad().unwrap().expect("x should have gradient");
assert_eq!(g.data().unwrap(), &[0.0, 0.0, 1.0, 1.0]);
}
// tensor_split at indices [2, 4] yields three sections.
#[test]
fn test_tensor_split_indices() {
let x = leaf(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[6], false);
let parts = tensor_split(&x, &[2, 4], 0).unwrap();
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].data_vec().unwrap(), &[0.0, 1.0]);
assert_eq!(parts[1].data_vec().unwrap(), &[2.0, 3.0]);
assert_eq!(parts[2].data_vec().unwrap(), &[4.0, 5.0]);
}
// A repeated split index produces an empty middle section.
#[test]
fn test_tensor_split_empty_section() {
let x = leaf(&[0.0, 1.0, 2.0, 3.0], &[4], false);
let parts = tensor_split(&x, &[2, 2], 0).unwrap();
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].shape(), &[2]);
assert_eq!(parts[1].shape(), &[0]);
assert_eq!(parts[2].shape(), &[2]);
}
// Only the section used in the loss gets a non-zero gradient.
#[test]
fn test_tensor_split_backward() {
let x = leaf(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[6], true);
let parts = tensor_split(&x, &[2, 4], 0).unwrap();
let loss = sum_to_scalar(&crate::methods::contiguous_t(&parts[1]).unwrap());
backward(&loss).unwrap();
let g = x.grad().unwrap().expect("x should have gradient");
assert_eq!(g.data().unwrap(), &[0.0, 0.0, 1.0, 1.0, 0.0, 0.0]);
}
// vstack of two 1-D tensors of length 3 gives shape [2, 3].
#[test]
fn test_vstack_1d_inputs() {
let a = leaf(&[1.0, 2.0, 3.0], &[3], false);
let b = leaf(&[4.0, 5.0, 6.0], &[3], false);
let y = vstack(&[a, b]).unwrap();
assert_eq!(y.shape(), &[2, 3]);
assert_eq!(y.data_vec().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
// hstack of 1-D tensors concatenates them end to end.
#[test]
fn test_hstack_1d_inputs() {
let a = leaf(&[1.0, 2.0], &[2], false);
let b = leaf(&[3.0, 4.0, 5.0], &[3], false);
let y = hstack(&[a, b]).unwrap();
assert_eq!(y.shape(), &[5]);
assert_eq!(y.data_vec().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
}
// hstack of 2-D tensors concatenates along axis 1.
#[test]
fn test_hstack_2d_inputs() {
let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let b = leaf(&[5.0, 6.0], &[2, 1], false);
let y = hstack(&[a, b]).unwrap();
assert_eq!(y.shape(), &[2, 3]);
assert_eq!(y.data_vec().unwrap(), &[1.0, 2.0, 5.0, 3.0, 4.0, 6.0]);
}
// dstack of two 1-D tensors of length 3 gives shape [1, 3, 2].
#[test]
fn test_dstack_1d_inputs() {
let a = leaf(&[1.0, 2.0, 3.0], &[3], false);
let b = leaf(&[4.0, 5.0, 6.0], &[3], false);
let y = dstack(&[a, b]).unwrap();
assert_eq!(y.shape(), &[1, 3, 2]);
assert_eq!(y.data_vec().unwrap(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
// column_stack places each 1-D input as a column of the [3, 2] result.
#[test]
fn test_column_stack_1d_inputs() {
let a = leaf(&[1.0, 2.0, 3.0], &[3], false);
let b = leaf(&[4.0, 5.0, 6.0], &[3], false);
let y = column_stack(&[a, b]).unwrap();
assert_eq!(y.shape(), &[3, 2]);
assert_eq!(y.data_vec().unwrap(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
// Backward through vstack gives each 1-D input an all-ones gradient.
#[test]
fn test_vstack_backward() {
let a = leaf(&[1.0, 2.0, 3.0], &[3], true);
let b = leaf(&[4.0, 5.0, 6.0], &[3], true);
let y = vstack(&[a.clone(), b.clone()]).unwrap();
let loss = sum_to_scalar(&y);
backward(&loss).unwrap();
for t in [&a, &b] {
let g = t.grad().unwrap().expect("should have gradient");
assert_eq!(g.shape(), &[3]);
for &v in g.data().unwrap() {
assert!((v - 1.0).abs() < 1e-6);
}
}
}
}