use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::{DType, Float};
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::gpu_dispatch::GpuBufferHandle;
use crate::shape::normalize_axis;
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// Splits `shape` around axis `dim` into `(outer, dim_size, inner)`.
///
/// `outer` is the product of all extents before `dim`, `dim_size` is the
/// extent of `dim` itself, and `inner` is the product of all extents after
/// it. A side with no axes contributes a factor of 1.
///
/// # Panics
///
/// Panics if `dim >= shape.len()`.
#[inline]
fn factor(shape: &[usize], dim: usize) -> (usize, usize, usize) {
    let (leading, rest) = shape.split_at(dim);
    let outer = leading.iter().product::<usize>();
    let dim_size = rest[0];
    let inner = rest[1..].iter().product::<usize>();
    (outer, dim_size, inner)
}
/// Widens host-side `usize` indices to `i64` and uploads them to the GPU
/// device identified by `ordinal`.
///
/// # Errors
///
/// - `DeviceUnavailable` when no GPU backend is registered.
/// - `InvalidArgument` when an index value does not fit in an `i64`
///   (previously such a value was silently reinterpreted as a negative
///   index before being sent to the device).
/// - Any error reported by the backend's `cpu_to_gpu`.
fn upload_index_i64(index: &[usize], ordinal: usize) -> FerrotorchResult<GpuBufferHandle> {
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let widened = index
        .iter()
        .map(|&v| {
            i64::try_from(v).map_err(|_| FerrotorchError::InvalidArgument {
                message: format!("index value {} does not fit in i64", v),
            })
        })
        .collect::<FerrotorchResult<Vec<i64>>>()?;
    // SAFETY: `widened` is a live, initialized `Vec<i64>` that outlives
    // `bytes`; the pointer is valid for `size_of_val` bytes of it, `u8` has
    // alignment 1, and every bit pattern of an `i64` is a valid run of `u8`s.
    let bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(
            widened.as_ptr().cast::<u8>(),
            std::mem::size_of_val(widened.as_slice()),
        )
    };
    backend.cpu_to_gpu(bytes, DType::I64, ordinal)
}
/// Reports whether a binary op on `a` and `b` has to record a backward node:
/// gradient tracking must be enabled and at least one operand must require
/// a gradient.
#[inline]
fn needs_grad<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> bool {
    if !is_grad_enabled() {
        return false;
    }
    a.requires_grad() || b.requires_grad()
}
/// Converts multi-dimensional `coords` into a row-major flat offset for a
/// tensor of the given `shape`.
///
/// Only the first `shape.len()` coordinates are read; the last axis has
/// stride 1 and each earlier axis has the product of the later extents.
///
/// # Panics
///
/// Panics if `coords` has fewer entries than `shape`.
#[inline]
fn flat_index(coords: &[usize], shape: &[usize]) -> usize {
    let mut offset = 0;
    let mut stride = 1;
    // Walk from the innermost axis outwards, growing the stride as we go.
    for (&coord, &extent) in coords[..shape.len()].iter().zip(shape).rev() {
        offset += coord * stride;
        stride *= extent;
    }
    offset
}
/// Advances `coords` to the next position of a row-major walk over `shape`,
/// odometer style: the last axis moves fastest and carries into earlier axes.
///
/// Returns `true` when a new position was produced, and `false` when the walk
/// wrapped around (every coordinate is then reset to 0).
///
/// # Panics
///
/// Panics if `coords` has fewer entries than `shape`.
#[inline]
fn increment_coords(coords: &mut [usize], shape: &[usize]) -> bool {
    for (coord, &extent) in coords[..shape.len()].iter_mut().zip(shape).rev() {
        *coord += 1;
        if *coord < extent {
            return true;
        }
        // This axis overflowed: reset it and carry into the next outer axis.
        *coord = 0;
    }
    false
}
/// Validates the index arguments shared by the gather/scatter family.
///
/// * `input_shape` – shape of the tensor being read from / written into.
/// * `dim` – already-normalized axis the index values refer to.
/// * `index_shape` – logical shape of the index grid.
/// * `index_data` – row-major contents of the index grid.
/// * `axis_size` – extent of `input_shape` along `dim`.
///
/// # Errors
///
/// - `InvalidArgument` when the index grid and the input differ in rank.
/// - `IndexOutOfBounds` when an index value is not below `axis_size`.
/// - `ShapeMismatch` when `index_data` does not hold exactly the number of
///   elements implied by `index_shape`, or when a non-empty index grid is
///   larger than the input along an axis other than `dim`.
fn validate_gather_shapes(
    input_shape: &[usize],
    dim: usize,
    index_shape: &[usize],
    index_data: &[usize],
    axis_size: usize,
) -> FerrotorchResult<()> {
    if input_shape.len() != index_shape.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "gather/scatter: input ndim ({}) must equal index ndim ({})",
                input_shape.len(),
                index_shape.len()
            ),
        });
    }
    if let Some(&v) = index_data.iter().find(|&&v| v >= axis_size) {
        return Err(FerrotorchError::IndexOutOfBounds {
            index: v,
            axis: dim,
            size: axis_size,
        });
    }
    // Callers walk `index_data` by flat position over the index grid, so a
    // buffer that disagrees with `index_shape` would make them index out of
    // bounds instead of reporting an error.
    let index_numel: usize = index_shape.iter().product();
    if index_data.len() != index_numel {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "gather/scatter: index has {} elements but index shape {:?} implies {}",
                index_data.len(),
                index_shape,
                index_numel
            ),
        });
    }
    // Away from `dim`, grid coordinates are used verbatim as input
    // coordinates. A larger grid extent would either run past the input
    // buffer or silently alias an element of a neighbouring row.
    if index_numel > 0 {
        for (d, (&idx_extent, &in_extent)) in index_shape.iter().zip(input_shape).enumerate() {
            if d != dim && idx_extent > in_extent {
                return Err(FerrotorchError::ShapeMismatch {
                    message: format!(
                        "gather/scatter: index size {} exceeds input size {} on axis {}",
                        idx_extent, in_extent, d
                    ),
                });
            }
        }
    }
    Ok(())
}
/// Gathers elements of `input` along axis `dim` at the positions in `index`.
///
/// `index` holds the row-major contents of an index grid of shape
/// `index_shape`. For every grid position `p`, the output element at `p` is
/// the input element whose coordinates equal `p` except on `dim`, where the
/// coordinate is the index value stored at `p`. The result has shape
/// `index_shape`.
///
/// `dim` is resolved through `normalize_axis`. On the CPU path a 0-D input is
/// treated as a one-element 1-D tensor (and an empty `index_shape` as `[1]`).
/// CUDA inputs are dispatched to the GPU backend for `f32`/`f64` only.
///
/// When gradient tracking is enabled and `input` requires a gradient, the
/// result carries a `GatherBackward` node.
///
/// # Errors
///
/// - `InvalidArgument` for a 0-D CUDA input, or when `index_shape` and
///   `input` differ in rank.
/// - `IndexOutOfBounds` when an index value is not below the extent of `dim`.
/// - `NotImplementedOnCuda` for CUDA tensors that are neither `f32` nor `f64`.
/// - `DeviceUnavailable` when no GPU backend is registered.
/// - Any error from `normalize_axis`, `validate_gather_shapes`, storage
///   access, or the backend.
pub fn gather<T: Float>(
    input: &Tensor<T>,
    dim: isize,
    index: &[usize],
    index_shape: &[usize],
) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda() {
        match T::dtype() {
            DType::F32 | DType::F64 => {
                let ndim = input.ndim();
                if ndim == 0 {
                    return Err(FerrotorchError::InvalidArgument {
                        message: "gather: 0-D CUDA input not supported".into(),
                    });
                }
                let dim = normalize_axis(dim, ndim)?;
                let input = input.contiguous()?;
                let input_shape = input.shape().to_vec();
                let (outer, in_dim, inner) = factor(&input_shape, dim);
                let out_dim = if index_shape.is_empty() {
                    // NOTE(review): an empty `index_shape` is kept as the
                    // pre-existing "single gathered position" form and is not
                    // validated; confirm whether any caller relies on it.
                    1
                } else {
                    // Apply the same host-side checks as the CPU path before
                    // anything reaches the device. This also guarantees that
                    // `index_shape[dim]` below cannot panic.
                    validate_gather_shapes(&input_shape, dim, index_shape, index, in_dim)?;
                    index_shape[dim]
                };
                let input_handle = input.gpu_handle()?;
                let ordinal = input_handle.device_ordinal();
                let idx_handle = upload_index_i64(index, ordinal)?;
                let backend =
                    crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
                let h = if T::dtype() == DType::F32 {
                    backend.gather_dim_f32(
                        input_handle,
                        &idx_handle,
                        outer,
                        in_dim,
                        out_dim,
                        inner,
                    )?
                } else {
                    backend.gather_dim_f64(
                        input_handle,
                        &idx_handle,
                        outer,
                        in_dim,
                        out_dim,
                        inner,
                    )?
                };
                let output_shape = index_shape.to_vec();
                let storage = TensorStorage::gpu(h);
                if input.requires_grad() && is_grad_enabled() {
                    let grad_fn = Arc::new(crate::grad_fns::indexing::GatherBackward {
                        input: input.clone(),
                        dim,
                        index: index.to_vec(),
                        index_shape: index_shape.to_vec(),
                    });
                    return Tensor::from_operation(storage, output_shape, grad_fn);
                }
                return Tensor::from_storage(storage, output_shape, false);
            }
            _ => return Err(FerrotorchError::NotImplementedOnCuda { op: "gather" }),
        }
    }

    // CPU path. A 0-D input is handled as a one-element 1-D tensor so the
    // coordinate walk below has an axis to work with.
    let ndim = input.ndim();
    let effective_input_shape: Vec<usize> = if ndim == 0 {
        vec![1]
    } else {
        input.shape().to_vec()
    };
    let effective_ndim = effective_input_shape.len();
    let effective_index_shape: Vec<usize> = if ndim == 0 && index_shape.is_empty() {
        vec![1]
    } else {
        index_shape.to_vec()
    };
    let dim = normalize_axis(dim, effective_ndim)?;
    validate_gather_shapes(
        &effective_input_shape,
        dim,
        &effective_index_shape,
        index,
        effective_input_shape[dim],
    )?;
    let input_shape: &[usize] = &effective_input_shape;
    let input_data = input.data_vec()?;
    let out_numel: usize = index_shape.iter().product();
    let mut output = vec![<T as num_traits::Zero>::zero(); out_numel];
    let mut coords = vec![0usize; effective_ndim];
    for out_flat in 0..out_numel {
        // Swap the gathered index into `dim` just long enough to compute the
        // source offset, then restore the walk coordinate. This avoids
        // allocating a fresh coordinate vector for every output element.
        let walk_coord = std::mem::replace(&mut coords[dim], index[out_flat]);
        let src_flat = flat_index(&coords, input_shape);
        coords[dim] = walk_coord;
        output[out_flat] = input_data[src_flat];
        if out_flat + 1 < out_numel {
            increment_coords(&mut coords, &effective_index_shape);
        }
    }
    let output_shape = index_shape.to_vec();
    if input.requires_grad() && is_grad_enabled() {
        let grad_fn = Arc::new(crate::grad_fns::indexing::GatherBackward {
            input: input.clone(),
            dim,
            index: index.to_vec(),
            index_shape: index_shape.to_vec(),
        });
        Tensor::from_operation(TensorStorage::cpu(output), output_shape, grad_fn)
    } else {
        Tensor::from_storage(TensorStorage::cpu(output), output_shape, false)
    }
}
/// Returns a copy of `input` in which selected elements are overwritten with
/// values from `src`.
///
/// `index` holds the row-major contents of an index grid of shape
/// `index_shape`. On the CPU path, for the `i`-th grid position `p` (in
/// row-major order) the destination is `p` with its `dim` coordinate replaced
/// by `index[i]`, and the value written is the `i`-th element of `src`'s
/// data. When several positions target the same destination, the CPU path
/// keeps the last one written. The result has the shape of `input`.
///
/// CUDA execution requires both `input` and `src` to be CUDA tensors on the
/// same device with dtype `f32` or `f64`. The index and `src` size checks run
/// on the host before dispatch, for every backend.
///
/// When gradient tracking is enabled and either tensor requires a gradient,
/// the result carries a `ScatterBackward` node.
///
/// # Errors
///
/// - `InvalidArgument` for a 0-D `input` or a rank mismatch between
///   `index_shape` and `input`.
/// - `IndexOutOfBounds` when an index value is not below the extent of `dim`.
/// - `ShapeMismatch` when `src` has fewer elements than the index grid.
/// - `DeviceMismatch` when both tensors are on CUDA but on different devices.
/// - `NotImplementedOnCuda` for mixed CPU/CUDA operands or other dtypes.
/// - `DeviceUnavailable` when no GPU backend is registered.
/// - Any error from `normalize_axis`, `validate_gather_shapes`, storage
///   access, or the backend.
pub fn scatter<T: Float>(
    input: &Tensor<T>,
    dim: isize,
    index: &[usize],
    index_shape: &[usize],
    src: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "scatter: input must have at least 1 dimension".into(),
        });
    }
    let dim = normalize_axis(dim, ndim)?;
    let input_shape = input.shape();

    // Validate once, before choosing a backend: the CUDA path used to skip
    // these checks, so a rank mismatch could panic on `index_shape[dim]` and
    // unchecked indices were handed to the device.
    validate_gather_shapes(input_shape, dim, index_shape, index, input_shape[dim])?;
    let index_numel: usize = index_shape.iter().product();
    if src.numel() < index_numel {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "scatter: src has {} elements but index has {}",
                src.numel(),
                index_numel
            ),
        });
    }

    if input.is_cuda() || src.is_cuda() {
        match T::dtype() {
            DType::F32 | DType::F64 if input.is_cuda() && src.is_cuda() => {
                if input.device() != src.device() {
                    return Err(FerrotorchError::DeviceMismatch {
                        expected: input.device(),
                        got: src.device(),
                    });
                }
                let input = input.contiguous()?;
                let src = src.contiguous()?;
                let input_shape: &[usize] = input.shape();
                let (outer, out_dim, inner) = factor(input_shape, dim);
                let idx_dim = index_shape[dim];
                let input_handle = input.gpu_handle()?;
                let ordinal = input_handle.device_ordinal();
                let idx_handle = upload_index_i64(index, ordinal)?;
                let backend =
                    crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
                let src_handle = src.gpu_handle()?;
                let h = if T::dtype() == DType::F32 {
                    backend.scatter_dim_f32(
                        input_handle,
                        &idx_handle,
                        src_handle,
                        outer,
                        out_dim,
                        idx_dim,
                        inner,
                    )?
                } else {
                    backend.scatter_dim_f64(
                        input_handle,
                        &idx_handle,
                        src_handle,
                        outer,
                        out_dim,
                        idx_dim,
                        inner,
                    )?
                };
                let output_shape = input_shape.to_vec();
                let storage = TensorStorage::gpu(h);
                if needs_grad(&input, &src) {
                    let grad_fn = Arc::new(crate::grad_fns::indexing::ScatterBackward {
                        input: input.clone(),
                        src: src.clone(),
                        dim,
                        index: index.to_vec(),
                        index_shape: index_shape.to_vec(),
                    });
                    return Tensor::from_operation(storage, output_shape, grad_fn);
                }
                return Tensor::from_storage(storage, output_shape, false);
            }
            // Mixed CPU/CUDA operands and unsupported dtypes land here.
            _ => return Err(FerrotorchError::NotImplementedOnCuda { op: "scatter" }),
        }
    }

    let mut output = input.data_vec()?;
    let src_data = src.data_vec()?;
    let mut coords = vec![0usize; ndim];
    for i in 0..index_numel {
        // Swap the target index into `dim` only while computing the offset,
        // instead of cloning the coordinate vector for every element.
        let walk_coord = std::mem::replace(&mut coords[dim], index[i]);
        let dst_flat = flat_index(&coords, input_shape);
        coords[dim] = walk_coord;
        output[dst_flat] = src_data[i];
        if i + 1 < index_numel {
            increment_coords(&mut coords, index_shape);
        }
    }
    let output_shape = input_shape.to_vec();
    if needs_grad(input, src) {
        let grad_fn = Arc::new(crate::grad_fns::indexing::ScatterBackward {
            input: input.clone(),
            src: src.clone(),
            dim,
            index: index.to_vec(),
            index_shape: index_shape.to_vec(),
        });
        Tensor::from_operation(TensorStorage::cpu(output), output_shape, grad_fn)
    } else {
        Tensor::from_storage(TensorStorage::cpu(output), output_shape, false)
    }
}
/// Returns a copy of `input` in which selected elements are overwritten with
/// the scalar `value`.
///
/// `index` holds the row-major contents of an index grid of shape
/// `index_shape`. For every grid position `p`, the destination is `p` with
/// its `dim` coordinate replaced by the index value stored at `p`. The result
/// has the shape of `input`.
///
/// CUDA inputs are dispatched to the GPU backend for `f32`/`f64` only. The
/// index checks run on the host before dispatch, for every backend.
///
/// When gradient tracking is enabled and `input` requires a gradient, the
/// result carries a `ScatterBackward` node whose `src` is an all-zero CPU
/// tensor of shape `index_shape` that does not require a gradient.
///
/// # Errors
///
/// - `InvalidArgument` for a 0-D `input`, a rank mismatch between
///   `index_shape` and `input`, or a `value` that cannot be converted to the
///   kernel's float type.
/// - `IndexOutOfBounds` when an index value is not below the extent of `dim`.
/// - `NotImplementedOnCuda` for CUDA tensors that are neither `f32` nor `f64`.
/// - `DeviceUnavailable` when no GPU backend is registered.
/// - Any error from `normalize_axis`, `validate_gather_shapes`, storage
///   access, or the backend.
pub fn scatter_value<T: Float>(
    input: &Tensor<T>,
    dim: isize,
    index: &[usize],
    index_shape: &[usize],
    value: T,
) -> FerrotorchResult<Tensor<T>> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "scatter_value: input must have at least 1 dimension".into(),
        });
    }
    let dim = normalize_axis(dim, ndim)?;
    let input_shape = input.shape();

    // Validate once, before choosing a backend: the CUDA path used to skip
    // this, so a rank mismatch could panic on `index_shape[dim]` and
    // unchecked indices were handed to the device.
    validate_gather_shapes(input_shape, dim, index_shape, index, input_shape[dim])?;
    let index_numel: usize = index_shape.iter().product();

    if input.is_cuda() {
        match T::dtype() {
            DType::F32 | DType::F64 => {
                let input = input.contiguous()?;
                let input_shape: &[usize] = input.shape();
                let (outer, out_dim, inner) = factor(input_shape, dim);
                let idx_dim = index_shape[dim];
                let input_handle = input.gpu_handle()?;
                let ordinal = input_handle.device_ordinal();
                let idx_handle = upload_index_i64(index, ordinal)?;
                let backend =
                    crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
                let h = if T::dtype() == DType::F32 {
                    // `ok_or_else` keeps the error message from being
                    // allocated on the success path.
                    let fill = value
                        .to_f32()
                        .ok_or_else(|| FerrotorchError::InvalidArgument {
                            message: "scatter_value: value not representable as f32".into(),
                        })?;
                    backend.scatter_value_dim_f32(
                        input_handle,
                        &idx_handle,
                        fill,
                        outer,
                        out_dim,
                        idx_dim,
                        inner,
                    )?
                } else {
                    let fill = value
                        .to_f64()
                        .ok_or_else(|| FerrotorchError::InvalidArgument {
                            message: "scatter_value: value not representable as f64".into(),
                        })?;
                    backend.scatter_value_dim_f64(
                        input_handle,
                        &idx_handle,
                        fill,
                        outer,
                        out_dim,
                        idx_dim,
                        inner,
                    )?
                };
                let output_shape = input_shape.to_vec();
                let storage = TensorStorage::gpu(h);
                if is_grad_enabled() && input.requires_grad() {
                    // There is no source tensor for a scalar fill, so the
                    // backward node gets a zero-filled stand-in.
                    let zero = <T as num_traits::Zero>::zero();
                    let zeros_src = Tensor::from_storage(
                        TensorStorage::cpu(vec![zero; index_numel]),
                        index_shape.to_vec(),
                        false,
                    )?;
                    let grad_fn = Arc::new(crate::grad_fns::indexing::ScatterBackward {
                        input: input.clone(),
                        src: zeros_src,
                        dim,
                        index: index.to_vec(),
                        index_shape: index_shape.to_vec(),
                    });
                    return Tensor::from_operation(storage, output_shape, grad_fn);
                }
                return Tensor::from_storage(storage, output_shape, false);
            }
            _ => {
                return Err(FerrotorchError::NotImplementedOnCuda {
                    op: "scatter_value",
                });
            }
        }
    }

    let mut output = input.data_vec()?;
    let mut coords = vec![0usize; ndim];
    for i in 0..index_numel {
        // Swap the target index into `dim` only while computing the offset,
        // instead of cloning the coordinate vector for every element.
        let walk_coord = std::mem::replace(&mut coords[dim], index[i]);
        let dst_flat = flat_index(&coords, input_shape);
        coords[dim] = walk_coord;
        output[dst_flat] = value;
        if i + 1 < index_numel {
            increment_coords(&mut coords, index_shape);
        }
    }
    let output_shape = input_shape.to_vec();
    if is_grad_enabled() && input.requires_grad() {
        // There is no source tensor for a scalar fill, so the backward node
        // gets a zero-filled stand-in.
        let zero = <T as num_traits::Zero>::zero();
        let zeros_src = Tensor::from_storage(
            TensorStorage::cpu(vec![zero; index_numel]),
            index_shape.to_vec(),
            false,
        )?;
        let grad_fn = Arc::new(crate::grad_fns::indexing::ScatterBackward {
            input: input.clone(),
            src: zeros_src,
            dim,
            index: index.to_vec(),
            index_shape: index_shape.to_vec(),
        });
        Tensor::from_operation(TensorStorage::cpu(output), output_shape, grad_fn)
    } else {
        Tensor::from_storage(TensorStorage::cpu(output), output_shape, false)
    }
}
/// Returns a copy of `input` with values from `src` accumulated into
/// selected elements.
///
/// `index` holds the row-major contents of an index grid of shape
/// `index_shape`. On the CPU path, for the `i`-th grid position `p` (in
/// row-major order) the destination is `p` with its `dim` coordinate replaced
/// by `index[i]`, and the `i`-th element of `src`'s data is added to it, so
/// repeated destinations accumulate. The result has the shape of `input`.
///
/// CUDA execution requires both `input` and `src` to be CUDA tensors on the
/// same device with dtype `f32` or `f64`. The index and `src` size checks run
/// on the host before dispatch, for every backend.
///
/// When gradient tracking is enabled and either tensor requires a gradient,
/// the result carries a `ScatterAddBackward` node.
///
/// # Errors
///
/// - `InvalidArgument` for a 0-D `input` or a rank mismatch between
///   `index_shape` and `input`.
/// - `IndexOutOfBounds` when an index value is not below the extent of `dim`.
/// - `ShapeMismatch` when `src` has fewer elements than the index grid.
/// - `DeviceMismatch` when both tensors are on CUDA but on different devices.
/// - `NotImplementedOnCuda` for mixed CPU/CUDA operands or other dtypes.
/// - `DeviceUnavailable` when no GPU backend is registered.
/// - Any error from `normalize_axis`, `validate_gather_shapes`, storage
///   access, or the backend.
pub fn scatter_add<T: Float>(
    input: &Tensor<T>,
    dim: isize,
    index: &[usize],
    index_shape: &[usize],
    src: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "scatter_add: input must have at least 1 dimension".into(),
        });
    }
    let dim = normalize_axis(dim, ndim)?;
    let input_shape = input.shape();

    // Validate once, before choosing a backend: the CUDA path used to skip
    // these checks, so a rank mismatch could panic on `index_shape[dim]` and
    // unchecked indices were handed to the device.
    validate_gather_shapes(input_shape, dim, index_shape, index, input_shape[dim])?;
    let index_numel: usize = index_shape.iter().product();
    if src.numel() < index_numel {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "scatter_add: src has {} elements but index has {}",
                src.numel(),
                index_numel
            ),
        });
    }

    if input.is_cuda() || src.is_cuda() {
        match T::dtype() {
            DType::F32 | DType::F64 if input.is_cuda() && src.is_cuda() => {
                if input.device() != src.device() {
                    return Err(FerrotorchError::DeviceMismatch {
                        expected: input.device(),
                        got: src.device(),
                    });
                }
                let input = input.contiguous()?;
                let src = src.contiguous()?;
                let input_shape: &[usize] = input.shape();
                let (outer, out_dim, inner) = factor(input_shape, dim);
                let idx_dim = index_shape[dim];
                let input_handle = input.gpu_handle()?;
                let ordinal = input_handle.device_ordinal();
                let idx_handle = upload_index_i64(index, ordinal)?;
                let backend =
                    crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
                let src_handle = src.gpu_handle()?;
                let h = if T::dtype() == DType::F32 {
                    backend.scatter_add_dim_f32(
                        input_handle,
                        &idx_handle,
                        src_handle,
                        outer,
                        out_dim,
                        idx_dim,
                        inner,
                    )?
                } else {
                    backend.scatter_add_dim_f64(
                        input_handle,
                        &idx_handle,
                        src_handle,
                        outer,
                        out_dim,
                        idx_dim,
                        inner,
                    )?
                };
                let output_shape = input_shape.to_vec();
                let storage = TensorStorage::gpu(h);
                if needs_grad(&input, &src) {
                    let grad_fn = Arc::new(crate::grad_fns::indexing::ScatterAddBackward {
                        input: input.clone(),
                        src: src.clone(),
                        dim,
                        index: index.to_vec(),
                        index_shape: index_shape.to_vec(),
                    });
                    return Tensor::from_operation(storage, output_shape, grad_fn);
                }
                return Tensor::from_storage(storage, output_shape, false);
            }
            // Mixed CPU/CUDA operands and unsupported dtypes land here.
            _ => return Err(FerrotorchError::NotImplementedOnCuda { op: "scatter_add" }),
        }
    }

    let mut output = input.data_vec()?;
    let src_data = src.data_vec()?;
    let mut coords = vec![0usize; ndim];
    for i in 0..index_numel {
        // Swap the target index into `dim` only while computing the offset,
        // instead of cloning the coordinate vector for every element.
        let walk_coord = std::mem::replace(&mut coords[dim], index[i]);
        let dst_flat = flat_index(&coords, input_shape);
        coords[dim] = walk_coord;
        output[dst_flat] += src_data[i];
        if i + 1 < index_numel {
            increment_coords(&mut coords, index_shape);
        }
    }
    let output_shape = input_shape.to_vec();
    if needs_grad(input, src) {
        let grad_fn = Arc::new(crate::grad_fns::indexing::ScatterAddBackward {
            input: input.clone(),
            src: src.clone(),
            dim,
            index: index.to_vec(),
            index_shape: index_shape.to_vec(),
        });
        Tensor::from_operation(TensorStorage::cpu(output), output_shape, grad_fn)
    } else {
        Tensor::from_storage(TensorStorage::cpu(output), output_shape, false)
    }
}
/// Element-wise select on CPU tensors: takes the element of `x` where
/// `condition` is `true` and the element of `y` where it is `false`.
///
/// `condition` is a flat slice with one entry per element of `x`. The result
/// has the shape of `x`. When gradient tracking is enabled and either operand
/// requires a gradient, the result carries a `WhereCondBackward` node.
///
/// # Errors
///
/// - `ShapeMismatch` when `x` and `y` differ in shape, or when
///   `condition.len()` differs from the number of elements of `x`.
/// - `NotImplementedOnCuda` when either operand is a CUDA tensor.
/// - Any error from storage access or tensor construction.
pub fn where_cond<T: Float>(
    condition: &[bool],
    x: &Tensor<T>,
    y: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    if x.shape() != y.shape() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "where_cond: x shape {:?} != y shape {:?}",
                x.shape(),
                y.shape()
            ),
        });
    }
    if x.is_cuda() || y.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "where_cond" });
    }
    let numel = x.numel();
    if condition.len() != numel {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "where_cond: condition length {} != tensor numel {}",
                condition.len(),
                numel
            ),
        });
    }

    let x_data = x.data_vec()?;
    let y_data = y.data_vec()?;
    let mut selected: Vec<T> = Vec::with_capacity(numel);
    for ((&take_x, &xv), &yv) in condition.iter().zip(x_data.iter()).zip(y_data.iter()) {
        selected.push(if take_x { xv } else { yv });
    }

    let output_shape = x.shape().to_vec();
    if !needs_grad(x, y) {
        return Tensor::from_storage(TensorStorage::cpu(selected), output_shape, false);
    }
    // The backward pass needs the mask to route the incoming gradient.
    let grad_fn = Arc::new(crate::grad_fns::indexing::WhereCondBackward {
        x: x.clone(),
        y: y.clone(),
        condition: crate::bool_tensor::BoolTensor::from_slice(condition, &output_shape)?,
    });
    Tensor::from_operation(TensorStorage::cpu(selected), output_shape, grad_fn)
}
/// Element-wise select driven by a `BoolTensor` condition.
///
/// When `x`, `y` and `cond` are all CUDA tensors the selection is performed
/// by the GPU backend; in every other case the condition's host data is
/// handed to [`where_cond`]. The result has the shape of `x`, and carries a
/// `WhereCondBackward` node when gradient tracking is enabled and either
/// operand requires a gradient.
///
/// # Errors
///
/// - `ShapeMismatch` when `x`, `y` and `cond` do not all share one shape.
/// - `DeviceMismatch` when all three are on CUDA but not on the same device.
/// - `DeviceUnavailable` when no GPU backend is registered.
/// - Any error from [`where_cond`], storage access, or the backend.
pub fn where_cond_bt<T: Float>(
    cond: &crate::bool_tensor::BoolTensor,
    x: &Tensor<T>,
    y: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    if x.shape() != y.shape() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "where_cond_bt: x shape {:?} != y shape {:?}",
                x.shape(),
                y.shape()
            ),
        });
    }
    if cond.shape() != x.shape() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "where_cond_bt: cond shape {:?} != x shape {:?}",
                cond.shape(),
                x.shape()
            ),
        });
    }

    // Anything that is not entirely GPU-resident goes through the host path.
    let all_on_gpu = x.is_cuda() && y.is_cuda() && cond.is_cuda();
    if !all_on_gpu {
        return where_cond(cond.data()?, x, y);
    }

    if x.device() != y.device() {
        return Err(FerrotorchError::DeviceMismatch {
            expected: x.device(),
            got: y.device(),
        });
    }
    if x.device() != cond.device() {
        return Err(FerrotorchError::DeviceMismatch {
            expected: x.device(),
            got: cond.device(),
        });
    }

    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let x = x.contiguous()?;
    let y = y.contiguous()?;
    let h = backend.where_cond(cond.gpu_handle()?, x.gpu_handle()?, y.gpu_handle()?)?;
    let storage = TensorStorage::gpu(h);
    let output_shape = x.shape().to_vec();
    if !needs_grad(&x, &y) {
        return Tensor::from_storage(storage, output_shape, false);
    }
    let grad_fn = Arc::new(crate::grad_fns::indexing::WhereCondBackward {
        x: x.clone(),
        y: y.clone(),
        condition: cond.clone(),
    });
    Tensor::from_operation(storage, output_shape, grad_fn)
}
/// Collects the elements of `input` whose corresponding `mask` entry is
/// `true` into a new 1-D tensor.
///
/// `mask` must have as many elements as `input`. When both are CUDA tensors
/// the selection is performed by the GPU backend; otherwise the host data of
/// both is walked in order. The result's length is the number of selected
/// elements. When gradient tracking is enabled and `input` requires a
/// gradient, the result carries a `MaskedSelectBackward` node.
///
/// # Errors
///
/// - `ShapeMismatch` when `mask` and `input` differ in element count.
/// - `DeviceMismatch` when both are on CUDA but on different devices.
/// - `DeviceUnavailable` when no GPU backend is registered.
/// - Any error from storage access, tensor construction, or the backend.
pub fn masked_select<T: Float>(
    input: &Tensor<T>,
    mask: &crate::bool_tensor::BoolTensor,
) -> FerrotorchResult<Tensor<T>> {
    if mask.numel() != input.numel() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "masked_select: mask numel {} != input numel {}",
                mask.numel(),
                input.numel()
            ),
        });
    }

    // Host path: taken unless both operands live on a GPU.
    if !(input.is_cuda() && mask.is_cuda()) {
        let data = input.data_vec()?;
        let mask_h = mask.data()?;
        let mut kept: Vec<T> = Vec::new();
        for (&v, &keep) in data.iter().zip(mask_h.iter()) {
            if keep {
                kept.push(v);
            }
        }
        let len = kept.len();
        let storage = TensorStorage::cpu(kept);
        if !(input.requires_grad() && is_grad_enabled()) {
            return Tensor::from_storage(storage, vec![len], false);
        }
        let grad_fn = Arc::new(crate::grad_fns::indexing::MaskedSelectBackward {
            input: input.clone(),
            mask: mask.clone(),
        });
        return Tensor::from_operation(storage, vec![len], grad_fn);
    }

    // Device path.
    if input.device() != mask.device() {
        return Err(FerrotorchError::DeviceMismatch {
            expected: input.device(),
            got: mask.device(),
        });
    }
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let input = input.contiguous()?;
    let (handle, len) = backend.masked_select(input.gpu_handle()?, mask.gpu_handle()?)?;
    let storage = TensorStorage::gpu(handle);
    if !(input.requires_grad() && is_grad_enabled()) {
        return Tensor::from_storage(storage, vec![len], false);
    }
    let grad_fn = Arc::new(crate::grad_fns::indexing::MaskedSelectBackward {
        input: input.clone(),
        mask: mask.clone(),
    });
    Tensor::from_operation(storage, vec![len], grad_fn)
}
// Unit tests for the CPU paths of gather, scatter, scatter_add and
// where_cond: forward results, argument validation, the backward nodes each
// op records, and `no_grad` behaviour. No test here exercises a CUDA path.
#[cfg(test)]
mod tests {
use super::*;
use crate::autograd::graph::backward;
use crate::autograd::no_grad;
use crate::storage::TensorStorage;
use crate::tensor::GradFn;
// Builds a CPU leaf tensor from `data` with the given `shape` and
// `requires_grad` flag; panics if construction fails.
fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
shape.to_vec(),
requires_grad,
)
.unwrap()
}
// ---- gather: forward ----
// 1-D gather picks elements by position: indices [3, 0, 2].
#[test]
fn test_gather_1d() {
let input = leaf(&[10.0, 20.0, 30.0, 40.0], &[4], false);
let index = &[3, 0, 2];
let result = gather(&input, 0, index, &[3]).unwrap();
assert_eq!(result.shape(), &[3]);
assert_eq!(result.data().unwrap(), &[40.0, 10.0, 30.0]);
}
// Along dim 0 the index selects the row; the column comes from the output position.
#[test]
fn test_gather_2d_dim0() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], false);
let index = &[2, 0, 1, 1];
let result = gather(&input, 0, index, &[2, 2]).unwrap();
assert_eq!(result.shape(), &[2, 2]);
assert_eq!(result.data().unwrap(), &[5.0, 2.0, 3.0, 4.0]);
}
// Along dim 1 the index selects the column; the row comes from the output position.
#[test]
fn test_gather_2d_dim1() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let index = &[0, 2, 1, 0];
let result = gather(&input, 1, index, &[2, 2]).unwrap();
assert_eq!(result.shape(), &[2, 2]);
assert_eq!(result.data().unwrap(), &[1.0, 3.0, 5.0, 4.0]);
}
// An index value beyond the axis extent must be reported as an error.
#[test]
fn test_gather_out_of_bounds() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
let result = gather(&input, 0, &[5], &[1]);
assert!(result.is_err());
}
// A 1-D index shape against a 2-D input must be rejected.
#[test]
fn test_gather_ndim_mismatch() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let result = gather(&input, 0, &[0, 1], &[2]);
assert!(result.is_err());
}
// ---- scatter: forward ----
// Each src element lands at the position named by its index; untouched slots keep the input value.
#[test]
fn test_scatter_1d() {
let input = leaf(&[0.0; 5], &[5], false);
let src = leaf(&[10.0, 20.0, 30.0], &[3], false);
let result = scatter(&input, 0, &[1, 3, 0], &[3], &src).unwrap();
assert_eq!(result.data().unwrap(), &[30.0, 10.0, 0.0, 20.0, 0.0]);
}
// Along dim 0 the index names the destination row for each column.
#[test]
fn test_scatter_2d_dim0() {
let input = leaf(&[0.0; 6], &[3, 2], false);
let src = leaf(&[1.0, 2.0], &[1, 2], false);
let result = scatter(&input, 0, &[2, 0], &[1, 2], &src).unwrap();
assert_eq!(result.shape(), &[3, 2]);
assert_eq!(result.data().unwrap(), &[0.0, 2.0, 0.0, 0.0, 1.0, 0.0]);
}
// Along dim 1 the index names the destination column for each row.
#[test]
fn test_scatter_2d_dim1() {
let input = leaf(&[0.0; 6], &[2, 3], false);
let src = leaf(&[5.0, 6.0], &[2, 1], false);
let result = scatter(&input, 1, &[2, 0], &[2, 1], &src).unwrap();
assert_eq!(result.data().unwrap(), &[0.0, 0.0, 5.0, 6.0, 0.0, 0.0]);
}
// ---- scatter_add: forward ----
// Repeated destinations accumulate: slot 0 receives 1 + 10 + 30.
#[test]
fn test_scatter_add_1d() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
let src = leaf(&[10.0, 20.0, 30.0], &[3], false);
let result = scatter_add(&input, 0, &[0, 2, 0], &[3], &src).unwrap();
assert_eq!(result.data().unwrap(), &[41.0, 2.0, 23.0]);
}
// The index grid may be longer than the input along `dim`; contributions are summed per destination.
#[test]
fn test_scatter_add_2d_dim0() {
let input = leaf(&[0.0; 4], &[2, 2], false);
let src = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], false);
let result = scatter_add(&input, 0, &[0, 1, 1, 0, 0, 0], &[3, 2], &src).unwrap();
assert_eq!(result.shape(), &[2, 2]);
assert_eq!(result.data().unwrap(), &[6.0, 10.0, 3.0, 2.0]);
}
// ---- where_cond: forward ----
// `true` takes from x, `false` takes from y.
#[test]
fn test_where_cond_basic() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let y = leaf(&[10.0, 20.0, 30.0, 40.0], &[4], false);
let cond = [true, false, true, false];
let result = where_cond(&cond, &x, &y).unwrap();
assert_eq!(result.data().unwrap(), &[1.0, 20.0, 3.0, 40.0]);
}
// An all-true condition reproduces x.
#[test]
fn test_where_cond_all_true() {
let x = leaf(&[1.0, 2.0], &[2], false);
let y = leaf(&[10.0, 20.0], &[2], false);
let result = where_cond(&[true, true], &x, &y).unwrap();
assert_eq!(result.data().unwrap(), &[1.0, 2.0]);
}
// An all-false condition reproduces y.
#[test]
fn test_where_cond_all_false() {
let x = leaf(&[1.0, 2.0], &[2], false);
let y = leaf(&[10.0, 20.0], &[2], false);
let result = where_cond(&[false, false], &x, &y).unwrap();
assert_eq!(result.data().unwrap(), &[10.0, 20.0]);
}
// x and y with different shapes must be rejected.
#[test]
fn test_where_cond_shape_mismatch() {
let x = leaf(&[1.0, 2.0], &[2], false);
let y = leaf(&[1.0, 2.0, 3.0], &[3], false);
let result = where_cond(&[true, false], &x, &y);
assert!(result.is_err());
}
// A condition slice shorter than the tensors must be rejected.
#[test]
fn test_where_cond_cond_length_mismatch() {
let x = leaf(&[1.0, 2.0], &[2], false);
let y = leaf(&[10.0, 20.0], &[2], false);
let result = where_cond(&[true], &x, &y);
assert!(result.is_err());
}
// ---- backward nodes, called directly through `grad_fn().backward(..)` ----
// Gather backward accumulates per source slot: slot 0 was read twice, slot 1 never, slot 2 once.
#[test]
fn test_gather_backward_1d() {
let input = leaf(&[10.0, 20.0, 30.0], &[3], true);
let result = gather(&input, 0, &[2, 0, 0], &[3]).unwrap();
assert!(result.requires_grad());
let grad_output = leaf(&[1.0, 1.0, 1.0], &[3], false);
let grad_fn = result.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let gi = grads[0].as_ref().unwrap();
let gd = gi.data().unwrap();
assert!((gd[0] - 2.0).abs() < 1e-6, "grad[0]={}, expected 2", gd[0]);
assert!((gd[1] - 0.0).abs() < 1e-6, "grad[1]={}, expected 0", gd[1]);
assert!((gd[2] - 1.0).abs() < 1e-6, "grad[2]={}, expected 1", gd[2]);
}
// 2-D gather backward: the gradient has the input's shape and is 1 exactly where an element was read.
#[test]
fn test_gather_backward_2d() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
let result = gather(&input, 1, &[0, 2, 1, 0], &[2, 2]).unwrap();
let grad_output = leaf(&[1.0, 1.0, 1.0, 1.0], &[2, 2], false);
let grad_fn = result.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let gi = grads[0].as_ref().unwrap();
let gd = gi.data().unwrap();
assert_eq!(gi.shape(), &[2, 3]);
assert!((gd[0] - 1.0).abs() < 1e-6);
assert!((gd[1] - 0.0).abs() < 1e-6);
assert!((gd[2] - 1.0).abs() < 1e-6);
assert!((gd[3] - 1.0).abs() < 1e-6);
assert!((gd[4] - 1.0).abs() < 1e-6);
assert!((gd[5] - 0.0).abs() < 1e-6);
}
// Scatter backward w.r.t. input: overwritten slots (1 and 3) receive zero gradient.
#[test]
fn test_scatter_backward_input() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[5], true);
let src = leaf(&[10.0, 20.0], &[2], false);
let result = scatter(&input, 0, &[1, 3], &[2], &src).unwrap();
let grad_output = leaf(&[1.0; 5], &[5], false);
let grad_fn = result.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let gi = grads[0].as_ref().unwrap();
let gd = gi.data().unwrap();
assert_eq!(gd, &[1.0, 0.0, 1.0, 0.0, 1.0]);
}
// Scatter backward w.r.t. src: each src element gets the gradient of the slot it was written to;
// the non-tracking input gets no gradient.
#[test]
fn test_scatter_backward_src() {
let input = leaf(&[0.0; 3], &[3], false);
let src = leaf(&[1.0, 2.0], &[2], true);
let result = scatter(&input, 0, &[2, 0], &[2], &src).unwrap();
let grad_output = leaf(&[10.0, 20.0, 30.0], &[3], false);
let grad_fn = result.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
assert!(grads[0].is_none());
let gs = grads[1].as_ref().unwrap();
let gd = gs.data().unwrap();
assert_eq!(gd, &[30.0, 10.0]);
}
// scatter_add backward w.r.t. input: the gradient passes through unchanged.
#[test]
fn test_scatter_add_backward_input() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], true);
let src = leaf(&[10.0, 20.0], &[2], false);
let result = scatter_add(&input, 0, &[0, 2], &[2], &src).unwrap();
let grad_output = leaf(&[5.0, 6.0, 7.0], &[3], false);
let grad_fn = result.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let gi = grads[0].as_ref().unwrap();
assert_eq!(gi.data().unwrap(), &[5.0, 6.0, 7.0]);
}
// scatter_add backward w.r.t. src: each src element gets the gradient of its destination slot.
#[test]
fn test_scatter_add_backward_src() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
let src = leaf(&[10.0, 20.0], &[2], true);
let result = scatter_add(&input, 0, &[2, 0], &[2], &src).unwrap();
let grad_output = leaf(&[5.0, 6.0, 7.0], &[3], false);
let grad_fn = result.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
assert!(grads[0].is_none());
let gs = grads[1].as_ref().unwrap();
assert_eq!(gs.data().unwrap(), &[7.0, 5.0]);
}
// where_cond backward, only x tracked: gradient flows to x where the condition is true.
#[test]
fn test_where_cond_backward_x() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], true);
let y = leaf(&[10.0, 20.0, 30.0, 40.0], &[4], false);
let cond = [true, false, true, false];
let result = where_cond(&cond, &x, &y).unwrap();
let grad_output = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let grad_fn = result.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let gx = grads[0].as_ref().unwrap();
assert_eq!(gx.data().unwrap(), &[1.0, 0.0, 3.0, 0.0]);
assert!(grads[1].is_none());
}
// where_cond backward, only y tracked: gradient flows to y where the condition is false.
#[test]
fn test_where_cond_backward_y() {
let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let y = leaf(&[10.0, 20.0, 30.0, 40.0], &[4], true);
let cond = [true, false, true, false];
let result = where_cond(&cond, &x, &y).unwrap();
let grad_output = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let grad_fn = result.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
assert!(grads[0].is_none());
let gy = grads[1].as_ref().unwrap();
assert_eq!(gy.data().unwrap(), &[0.0, 2.0, 0.0, 4.0]);
}
// where_cond backward, both tracked: the two gradients partition grad_output by the condition.
#[test]
fn test_where_cond_backward_both() {
let x = leaf(&[1.0, 2.0, 3.0], &[3], true);
let y = leaf(&[10.0, 20.0, 30.0], &[3], true);
let cond = [false, true, false];
let result = where_cond(&cond, &x, &y).unwrap();
let grad_output = leaf(&[5.0, 6.0, 7.0], &[3], false);
let grad_fn = result.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let gx = grads[0].as_ref().unwrap();
assert_eq!(gx.data().unwrap(), &[0.0, 6.0, 0.0]);
let gy = grads[1].as_ref().unwrap();
assert_eq!(gy.data().unwrap(), &[5.0, 0.0, 7.0]);
}
// ---- no_grad ----
// Inside `no_grad`, gather must not record a backward node even for a tracking input.
#[test]
fn test_gather_no_grad() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], true);
let result = no_grad(|| gather(&input, 0, &[2, 0], &[2])).unwrap();
assert!(!result.requires_grad());
assert!(result.grad_fn().is_none());
}
// Inside `no_grad`, where_cond yields a tensor that does not require grad.
#[test]
fn test_where_cond_no_grad() {
let x = leaf(&[1.0, 2.0], &[2], true);
let y = leaf(&[3.0, 4.0], &[2], true);
let result = no_grad(|| where_cond(&[true, false], &x, &y)).unwrap();
assert!(!result.requires_grad());
}
// ---- end-to-end through the autograd engine ----
// Builds loss = sum(gather(input)) with a hand-written sum node, runs the
// engine's `backward`, and checks that only the gathered slots (1 and 3) of
// the leaf receive gradient 1.
#[test]
fn test_gather_end_to_end_backward() {
let input = leaf(&[10.0, 20.0, 30.0, 40.0], &[4], true);
let gathered = gather(&input, 0, &[1, 3], &[2]).unwrap();
let data = gathered.data().unwrap();
let total: f32 = data.iter().sum();
// Minimal backward node for a full reduction: broadcasts the scalar
// upstream gradient to every element of its input.
#[derive(Debug)]
struct SumBackward<T: Float> {
input: Tensor<T>,
}
impl<T: Float> GradFn<T> for SumBackward<T> {
fn backward(
&self,
grad_output: &Tensor<T>,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let go_val = grad_output.data()?[0];
let grad = vec![go_val; self.input.numel()];
let t = Tensor::from_storage(
TensorStorage::cpu(grad),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(t)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"SumBackward"
}
}
// The loss is a 0-D tensor (empty shape) holding the sum.
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(SumBackward {
input: gathered.clone(),
}),
)
.unwrap();
backward(&loss).unwrap();
let grad = input.grad().unwrap().unwrap();
let gd = grad.data().unwrap();
assert!((gd[0] - 0.0).abs() < 1e-6);
assert!((gd[1] - 1.0).abs() < 1e-6);
assert!((gd[2] - 0.0).abs() < 1e-6);
assert!((gd[3] - 1.0).abs() < 1e-6);
}
}