use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::device::Device;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::gpu_dispatch::gpu_backend;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
use crate::bool_tensor::BoolTensor;
use crate::int_tensor::{IntElement, IntTensor};
/// Uploads a host `f32` slice to the GPU identified by `ordinal`.
///
/// The slice is reinterpreted as raw bytes and passed to the backend's
/// `cpu_to_gpu`, tagged as `DType::F32`.
///
/// # Errors
/// Returns `FerrotorchError::DeviceUnavailable` when `gpu_backend()` yields
/// nothing; any error from `cpu_to_gpu` is returned unchanged.
fn upload_f32_to_gpu(
    data: &[f32],
    ordinal: usize,
) -> FerrotorchResult<crate::gpu_dispatch::GpuBufferHandle> {
    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let byte_len = std::mem::size_of_val(data);
    // SAFETY: `data` is a live, initialised `[f32]`; viewing the same memory as
    // `byte_len` bytes stays within the allocation, `u8` has alignment 1, and
    // the borrow of `data` outlives `raw`.
    let raw: &[u8] = unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<u8>(), byte_len) };
    backend.cpu_to_gpu(raw, crate::dtype::DType::F32, ordinal)
}
/// Builds a flat `f32` mask over `input_shape` with `1.0` at every position
/// addressed by `index` along `dim`, and `0.0` elsewhere.
///
/// `index` is read in row-major order over `index_shape`; for each entry the
/// coordinate along `dim` is replaced by the index value and the result is
/// flattened against `input_shape`.
fn scatter_write_mask(
    index: &[usize],
    index_shape: &[usize],
    input_shape: &[usize],
    dim: usize,
) -> Vec<f32> {
    let total_slots: usize = input_shape.iter().product();
    let index_numel: usize = index_shape.iter().product();
    let mut written = vec![0.0f32; total_slots];
    let mut cursor = vec![0usize; input_shape.len()];
    for pos in 0..index_numel {
        // Temporarily substitute the indexed coordinate, then restore it so
        // `cursor` keeps walking `index_shape` undisturbed.
        let saved = cursor[dim];
        cursor[dim] = index[pos];
        let slot = flat_index(&cursor, input_shape);
        cursor[dim] = saved;
        written[slot] = 1.0;
        if pos + 1 < index_numel {
            increment_coords(&mut cursor, index_shape);
        }
    }
    written
}
/// For every entry of `index` (row-major over `index_shape`), computes the
/// flat offset into a tensor of `input_shape` obtained by replacing the
/// coordinate along `dim` with the index value.
///
/// Offsets are returned as `f32` because that is the format the callers in
/// this file upload to the GPU.
fn gather_dst_flat_indices(
    index: &[usize],
    index_shape: &[usize],
    input_shape: &[usize],
    dim: usize,
) -> Vec<f32> {
    let index_numel: usize = index_shape.iter().product();
    let mut offsets = Vec::with_capacity(index_numel);
    let mut cursor = vec![0usize; input_shape.len()];
    for pos in 0..index_numel {
        // Patch the indexed coordinate in place and restore it afterwards.
        let saved = cursor[dim];
        cursor[dim] = index[pos];
        let slot = flat_index(&cursor, input_shape);
        cursor[dim] = saved;
        offsets.push(slot as f32);
        if pos + 1 < index_numel {
            increment_coords(&mut cursor, index_shape);
        }
    }
    offsets
}
/// Flat offsets (as `f32`) into a tensor of `input_shape` for each entry of
/// `index`, used by the scatter backward passes to read `grad_output` at the
/// positions that were written in the forward pass.
///
/// This is the same computation as `gather_dst_flat_indices` and simply
/// delegates to it with the arguments unchanged.
fn scatter_src_flat_indices(
index: &[usize],
index_shape: &[usize],
input_shape: &[usize],
dim: usize,
) -> Vec<f32> {
gather_dst_flat_indices(index, index_shape, input_shape, dim)
}
/// Converts multi-dimensional `coords` into a row-major flat offset for a
/// tensor of the given `shape`.
///
/// Only the first `shape.len()` entries of `coords` are used.
#[inline]
fn flat_index(coords: &[usize], shape: &[usize]) -> usize {
    let rank = shape.len();
    let mut offset = 0usize;
    let mut step = 1usize;
    // Walk from the innermost axis outwards, growing the stride as we go.
    for (&coord, &extent) in coords[..rank].iter().zip(shape).rev() {
        offset += coord * step;
        step *= extent;
    }
    offset
}
/// Advances `coords` to the next position in row-major order within `shape`.
///
/// Returns `true` when a next position exists. When the last position has
/// been passed, every visited coordinate is reset to `0` and `false` is
/// returned.
#[inline]
fn increment_coords(coords: &mut [usize], shape: &[usize]) -> bool {
    let mut axis = shape.len();
    while axis > 0 {
        axis -= 1;
        let bumped = coords[axis] + 1;
        if bumped < shape[axis] {
            coords[axis] = bumped;
            return true;
        }
        // This axis overflowed: wrap it and carry into the next outer axis.
        coords[axis] = 0;
    }
    false
}
/// Backward node recorded by `index_select_1d`.
#[derive(Debug)]
pub struct IndexSelectBackward<T: Float> {
    /// The tensor that was indexed in the forward pass.
    pub input: Tensor<T>,
    /// Positions read from `input`, one per output element.
    pub indices: Vec<usize>,
}

impl<T: Float> GradFn<T> for IndexSelectBackward<T> {
    /// Scatter-adds `grad_output` back into a zero tensor shaped like `input`,
    /// so repeated indices accumulate their gradients.
    ///
    /// Returns `[None]` when gradient tracking is disabled.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None]);
        }
        let source_len = self.input.numel();
        let storage = if grad_output.is_cuda() {
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let ordinal = match grad_output.device() {
                Device::Cuda(o) => o,
                _ => unreachable!(),
            };
            // The GPU kernel takes its indices as an f32 buffer.
            let positions: Vec<f32> = self.indices.iter().map(|&p| p as f32).collect();
            let positions_handle = upload_f32_to_gpu(&positions, ordinal)?;
            let summed = backend.scatter_add_1d_f32(
                grad_output.gpu_handle()?,
                &positions_handle,
                source_len,
            )?;
            TensorStorage::gpu(summed)
        } else {
            let upstream = grad_output.data()?;
            let mut accum = vec![<T as num_traits::Zero>::zero(); source_len];
            for (out_pos, &src_pos) in self.indices.iter().enumerate() {
                accum[src_pos] += upstream[out_pos];
            }
            TensorStorage::cpu(accum)
        };
        let grad_tensor = Tensor::from_storage(storage, self.input.shape().to_vec(), false)?;
        Ok(vec![Some(grad_tensor)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IndexSelectBackward"
    }
}
/// Selects elements of a 1-D `input` at the given `indices`.
///
/// The output has shape `[indices.len()]`. When `input` requires grad and
/// gradient tracking is enabled, the result is linked to an
/// [`IndexSelectBackward`] node.
///
/// # Errors
/// * `InvalidArgument` if `input` is not 1-D.
/// * `IndexOutOfBounds` for the first index that is `>= input.shape()[0]`.
/// * `DeviceUnavailable` if `input` is on CUDA but no GPU backend exists.
pub fn index_select_1d<T: Float>(
    input: &Tensor<T>,
    indices: &[usize],
) -> FerrotorchResult<Tensor<T>> {
    if input.ndim() != 1 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "index_select_1d requires a 1-D input, got shape {:?}",
                input.shape()
            ),
        });
    }
    let input_len = input.shape()[0];
    if let Some(&bad) = indices.iter().find(|&&candidate| candidate >= input_len) {
        return Err(FerrotorchError::IndexOutOfBounds {
            index: bad,
            axis: 0,
            size: input_len,
        });
    }
    let output_shape = vec![indices.len()];

    // Produce the selected values on whichever device `input` lives on.
    let storage = if input.is_cuda() {
        let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let ordinal = match input.device() {
            Device::Cuda(o) => o,
            _ => unreachable!(),
        };
        let positions: Vec<f32> = indices.iter().map(|&p| p as f32).collect();
        let positions_handle = upload_f32_to_gpu(&positions, ordinal)?;
        let picked = backend.index_select_1d_f32(input.gpu_handle()?, &positions_handle)?;
        TensorStorage::gpu(picked)
    } else {
        let source = input.data()?;
        let picked: Vec<T> = indices.iter().map(|&p| source[p]).collect();
        TensorStorage::cpu(picked)
    };

    if input.requires_grad() && is_grad_enabled() {
        let grad_fn = Arc::new(IndexSelectBackward {
            input: input.clone(),
            indices: indices.to_vec(),
        });
        Tensor::from_operation(storage, output_shape, grad_fn)
    } else {
        Tensor::from_storage(storage, output_shape, false)
    }
}
/// Backward node recorded by `masked_fill` and `masked_fill_bt`.
#[derive(Debug)]
pub struct MaskedFillBackward<T: Float> {
    /// The tensor whose masked positions were overwritten.
    pub input: Tensor<T>,
    /// `true` where the forward pass wrote the fill value.
    pub mask: BoolTensor,
}

impl<T: Float> GradFn<T> for MaskedFillBackward<T> {
    /// Passes `grad_output` through, zeroing it wherever `mask` is `true`.
    ///
    /// The GPU kernel is used only when both `grad_output` and `mask` are on
    /// CUDA; otherwise the host path runs. Returns `[None]` when gradient
    /// tracking is disabled.
    ///
    /// # Errors
    /// `DeviceMismatch` if both are on CUDA but on different devices.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None]);
        }
        let storage = if grad_output.is_cuda() && self.mask.is_cuda() {
            if grad_output.device() != self.mask.device() {
                return Err(FerrotorchError::DeviceMismatch {
                    expected: grad_output.device(),
                    got: self.mask.device(),
                });
            }
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let zeroed =
                backend.masked_fill_dt(grad_output.gpu_handle()?, self.mask.gpu_handle()?, 0.0)?;
            TensorStorage::gpu(zeroed)
        } else {
            let upstream = grad_output.data()?;
            let mask_host = self.mask.data()?;
            let zero = <T as num_traits::Zero>::zero();
            let mut passed: Vec<T> = upstream.to_vec();
            for (pos, &filled) in mask_host.iter().enumerate() {
                if filled {
                    passed[pos] = zero;
                }
            }
            TensorStorage::cpu(passed)
        };
        let grad_tensor = Tensor::from_storage(storage, self.input.shape().to_vec(), false)?;
        Ok(vec![Some(grad_tensor)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "MaskedFillBackward"
    }
}
/// Returns a copy of `input` with `value` written wherever `mask` is `true`.
///
/// `mask` is a flat host slice with one entry per element of `input`. When
/// the input requires grad and gradient tracking is enabled, the result is
/// linked to a [`MaskedFillBackward`] node.
///
/// # Errors
/// * `ShapeMismatch` if `mask.len() != input.numel()`.
/// * `DeviceUnavailable` if `input` is on CUDA but no GPU backend exists.
pub fn masked_fill<T: Float>(
    input: &Tensor<T>,
    mask: &[bool],
    value: T,
) -> FerrotorchResult<Tensor<T>> {
    let input_len = input.numel();
    if mask.len() != input_len {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "masked_fill: mask length {} does not match input length {}",
                mask.len(),
                input_len
            ),
        });
    }
    let output_shape = input.shape().to_vec();

    if !input.is_cuda() {
        let source = input.data()?;
        let filled: Vec<T> = source
            .iter()
            .zip(mask.iter())
            .map(|(&elem, &hit)| if hit { value } else { elem })
            .collect();
        let storage = TensorStorage::cpu(filled);
        if !(input.requires_grad() && is_grad_enabled()) {
            return Tensor::from_storage(storage, output_shape, false);
        }
        let grad_fn = Arc::new(MaskedFillBackward {
            input: input.clone(),
            mask: BoolTensor::from_slice(mask, &output_shape)?,
        });
        return Tensor::from_operation(storage, output_shape, grad_fn);
    }

    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let ordinal = match input.device() {
        Device::Cuda(o) => o,
        _ => unreachable!(),
    };
    // The f32 kernel expects the mask as 1.0 / 0.0 values.
    let mask_f32: Vec<f32> = mask.iter().map(|&hit| f32::from(u8::from(hit))).collect();
    let mask_handle = upload_f32_to_gpu(&mask_f32, ordinal)?;
    // A value that cannot be expressed as f32 falls back to 0.0.
    let fill_f32: f32 = num_traits::ToPrimitive::to_f32(&value).unwrap_or(0.0);
    // From here on the contiguous tensor stands in for `input`, including in
    // the recorded backward node.
    let contiguous = input.contiguous()?;
    let filled = backend.masked_fill_f32(contiguous.gpu_handle()?, &mask_handle, fill_f32)?;
    let storage = TensorStorage::gpu(filled);
    if contiguous.requires_grad() && is_grad_enabled() {
        let grad_fn = Arc::new(MaskedFillBackward {
            input: contiguous.clone(),
            mask: BoolTensor::from_slice(mask, &output_shape)?,
        });
        Tensor::from_operation(storage, output_shape, grad_fn)
    } else {
        Tensor::from_storage(storage, output_shape, false)
    }
}
/// Backward node for a gather along `dim` (the forward op is not in this
/// part of the file — presumably defined elsewhere in the crate).
#[derive(Debug)]
pub struct GatherBackward<T: Float> {
    /// The tensor that was gathered from.
    pub input: Tensor<T>,
    /// Axis along which `index` selects.
    pub dim: usize,
    /// Flattened (row-major) index values.
    pub index: Vec<usize>,
    /// Shape of the index tensor.
    pub index_shape: Vec<usize>,
}

impl<T: Float> GradFn<T> for GatherBackward<T> {
    /// Scatter-adds `grad_output` into a zero tensor shaped like `input`, at
    /// the positions named by `index` along `dim`.
    ///
    /// Returns `[None]` when gradient tracking is disabled.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None]);
        }
        let input_shape = self.input.shape();
        let input_numel: usize = input_shape.iter().product();

        let storage = if grad_output.is_cuda() {
            let ordinal = match grad_output.device() {
                Device::Cuda(o) => o,
                _ => unreachable!(),
            };
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let targets =
                gather_dst_flat_indices(&self.index, &self.index_shape, input_shape, self.dim);
            let targets_handle = upload_f32_to_gpu(&targets, ordinal)?;
            let summed = backend.scatter_add_1d_f32(
                grad_output.gpu_handle()?,
                &targets_handle,
                input_numel,
            )?;
            TensorStorage::gpu(summed)
        } else {
            let upstream = grad_output.data_vec()?;
            let index_numel: usize = self.index_shape.iter().product();
            let mut accum = vec![<T as num_traits::Zero>::zero(); input_numel];
            let mut cursor = vec![0usize; input_shape.len()];
            for (pos, &grad) in upstream.iter().enumerate().take(index_numel) {
                // Swap the indexed coordinate in, flatten, then restore it.
                let saved = cursor[self.dim];
                cursor[self.dim] = self.index[pos];
                let slot = flat_index(&cursor, input_shape);
                cursor[self.dim] = saved;
                accum[slot] += grad;
                if pos + 1 < index_numel {
                    increment_coords(&mut cursor, &self.index_shape);
                }
            }
            TensorStorage::cpu(accum)
        };

        let grad_tensor = Tensor::from_storage(storage, input_shape.to_vec(), false)?;
        Ok(vec![Some(grad_tensor)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "GatherBackward"
    }
}
/// Backward node for a scatter of `src` into `input` along `dim` (the forward
/// op is not in this part of the file — presumably defined elsewhere).
#[derive(Debug)]
pub struct ScatterBackward<T: Float> {
    /// The tensor that was scattered into.
    pub input: Tensor<T>,
    /// The tensor whose values were written.
    pub src: Tensor<T>,
    /// Axis along which `index` addresses `input`.
    pub dim: usize,
    /// Flattened (row-major) index values.
    pub index: Vec<usize>,
    /// Shape of the index tensor.
    pub index_shape: Vec<usize>,
}

impl<T: Float> GradFn<T> for ScatterBackward<T> {
    /// Produces `[grad_input, grad_src]`.
    ///
    /// * `grad_input` is `grad_output` with every scattered-to position set to
    ///   zero.
    /// * `grad_src` is `grad_output` read at the scattered-to positions, shaped
    ///   like the index tensor.
    ///
    /// Each entry is `None` when the corresponding tensor does not require
    /// grad; both are `None` when gradient tracking is disabled.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None, None]);
        }
        let input_shape = self.input.shape();
        let index_numel: usize = self.index_shape.iter().product();

        if grad_output.is_cuda() {
            let ordinal = match grad_output.device() {
                Device::Cuda(o) => o,
                _ => unreachable!(),
            };
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;

            let grad_input = if self.input.requires_grad() {
                let overwritten =
                    scatter_write_mask(&self.index, &self.index_shape, input_shape, self.dim);
                let overwritten_handle = upload_f32_to_gpu(&overwritten, ordinal)?;
                let zeroed =
                    backend.masked_zero_f32(grad_output.gpu_handle()?, &overwritten_handle)?;
                let tensor =
                    Tensor::from_storage(TensorStorage::gpu(zeroed), input_shape.to_vec(), false)?;
                Some(tensor)
            } else {
                None
            };

            let grad_src = if self.src.requires_grad() {
                let sources =
                    scatter_src_flat_indices(&self.index, &self.index_shape, input_shape, self.dim);
                let sources_handle = upload_f32_to_gpu(&sources, ordinal)?;
                let picked =
                    backend.index_select_1d_f32(grad_output.gpu_handle()?, &sources_handle)?;
                let tensor = Tensor::from_storage(
                    TensorStorage::gpu(picked),
                    self.index_shape.clone(),
                    false,
                )?;
                Some(tensor)
            } else {
                None
            };

            return Ok(vec![grad_input, grad_src]);
        }

        let rank = input_shape.len();
        let upstream = grad_output.data_vec()?;

        let grad_input = if self.input.requires_grad() {
            let zero = <T as num_traits::Zero>::zero();
            let mut passed = upstream.clone();
            let mut cursor = vec![0usize; rank];
            for pos in 0..index_numel {
                let saved = cursor[self.dim];
                cursor[self.dim] = self.index[pos];
                let slot = flat_index(&cursor, input_shape);
                cursor[self.dim] = saved;
                passed[slot] = zero;
                if pos + 1 < index_numel {
                    increment_coords(&mut cursor, &self.index_shape);
                }
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(passed),
                input_shape.to_vec(),
                false,
            )?)
        } else {
            None
        };

        let grad_src = if self.src.requires_grad() {
            let mut picked = vec![<T as num_traits::Zero>::zero(); index_numel];
            let mut cursor = vec![0usize; rank];
            for (pos, out) in picked.iter_mut().enumerate() {
                let saved = cursor[self.dim];
                cursor[self.dim] = self.index[pos];
                let slot = flat_index(&cursor, input_shape);
                cursor[self.dim] = saved;
                *out = upstream[slot];
                if pos + 1 < index_numel {
                    increment_coords(&mut cursor, &self.index_shape);
                }
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(picked),
                self.index_shape.clone(),
                false,
            )?)
        } else {
            None
        };

        Ok(vec![grad_input, grad_src])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input, &self.src]
    }

    fn name(&self) -> &'static str {
        "ScatterBackward"
    }
}
/// Backward node for a scatter-add of `src` into `input` along `dim` (the
/// forward op is not in this part of the file — presumably defined elsewhere).
#[derive(Debug)]
pub struct ScatterAddBackward<T: Float> {
    /// The tensor that was accumulated into.
    pub input: Tensor<T>,
    /// The tensor whose values were added.
    pub src: Tensor<T>,
    /// Axis along which `index` addresses `input`.
    pub dim: usize,
    /// Flattened (row-major) index values.
    pub index: Vec<usize>,
    /// Shape of the index tensor.
    pub index_shape: Vec<usize>,
}

impl<T: Float> GradFn<T> for ScatterAddBackward<T> {
    /// Produces `[grad_input, grad_src]`.
    ///
    /// * `grad_input` is a copy of `grad_output` (addition passes the gradient
    ///   through unchanged).
    /// * `grad_src` is `grad_output` read at the positions addressed by
    ///   `index` along `dim`, shaped like the index tensor.
    ///
    /// Each entry is `None` when the corresponding tensor does not require
    /// grad; both are `None` when gradient tracking is disabled.
    ///
    /// # Errors
    /// `DeviceUnavailable` if `grad_output` is on CUDA but no GPU backend
    /// exists; otherwise errors from the backend or tensor construction are
    /// forwarded.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None, None]);
        }
        let input_shape = self.input.shape();
        let index_numel: usize = self.index_shape.iter().product();

        if grad_output.is_cuda() {
            let ordinal = match grad_output.device() {
                Device::Cuda(o) => o,
                _ => unreachable!(),
            };
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let grad_input = if self.input.requires_grad() {
                let cloned_h = backend.clone_buffer(grad_output.gpu_handle()?)?;
                Some(Tensor::from_storage(
                    TensorStorage::gpu(cloned_h),
                    input_shape.to_vec(),
                    false,
                )?)
            } else {
                None
            };
            let grad_src = if self.src.requires_grad() {
                let src_indices =
                    scatter_src_flat_indices(&self.index, &self.index_shape, input_shape, self.dim);
                let idx_handle = upload_f32_to_gpu(&src_indices, ordinal)?;
                let result_h =
                    backend.index_select_1d_f32(grad_output.gpu_handle()?, &idx_handle)?;
                Some(Tensor::from_storage(
                    TensorStorage::gpu(result_h),
                    self.index_shape.clone(),
                    false,
                )?)
            } else {
                None
            };
            return Ok(vec![grad_input, grad_src]);
        }

        // Host path. (A second `is_cuda()` check used to sit here; it could
        // never fire because the CUDA case has already returned above.)
        let ndim = input_shape.len();
        let go_data = grad_output.data_vec()?;
        let grad_input = if self.input.requires_grad() {
            let t = Tensor::from_storage(
                TensorStorage::cpu(go_data.clone()),
                input_shape.to_vec(),
                false,
            )?;
            Some(t)
        } else {
            None
        };
        let grad_src = if self.src.requires_grad() {
            let mut gs = vec![<T as num_traits::Zero>::zero(); index_numel];
            let mut coords = vec![0usize; ndim];
            for (i, gs_elem) in gs.iter_mut().enumerate() {
                let idx_val = self.index[i];
                // Replace the coordinate along `dim` with the index value to
                // find where this source element landed in the output.
                let mut src_coords = coords.clone();
                src_coords[self.dim] = idx_val;
                let src_flat = flat_index(&src_coords, input_shape);
                *gs_elem = go_data[src_flat];
                if i + 1 < index_numel {
                    increment_coords(&mut coords, &self.index_shape);
                }
            }
            let t = Tensor::from_storage(TensorStorage::cpu(gs), self.index_shape.clone(), false)?;
            Some(t)
        } else {
            None
        };
        Ok(vec![grad_input, grad_src])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input, &self.src]
    }

    fn name(&self) -> &'static str {
        "ScatterAddBackward"
    }
}
/// Backward node for an element-wise select between `x` and `y` driven by
/// `condition` (the forward op is not in this part of the file — presumably
/// defined elsewhere).
#[derive(Debug)]
pub struct WhereCondBackward<T: Float> {
    /// Operand whose gradient is kept where `condition` is `true`.
    pub x: Tensor<T>,
    /// Operand whose gradient is kept where `condition` is `false`.
    pub y: Tensor<T>,
    /// Element-wise selector.
    pub condition: BoolTensor,
}

impl<T: Float> GradFn<T> for WhereCondBackward<T> {
    /// Produces `[grad_x, grad_y]`.
    ///
    /// * `grad_x` keeps `grad_output` where `condition` is `true` and is zero
    ///   elsewhere.
    /// * `grad_y` keeps `grad_output` where `condition` is `false` and is zero
    ///   elsewhere.
    ///
    /// The GPU kernels are used only when both `grad_output` and `condition`
    /// are on CUDA. Each entry is `None` when the corresponding operand does
    /// not require grad; both are `None` when gradient tracking is disabled.
    ///
    /// # Errors
    /// `DeviceMismatch` if both are on CUDA but on different devices.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None, None]);
        }
        if grad_output.is_cuda() && self.condition.is_cuda() {
            if grad_output.device() != self.condition.device() {
                return Err(FerrotorchError::DeviceMismatch {
                    expected: grad_output.device(),
                    got: self.condition.device(),
                });
            }
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let cond_h = self.condition.gpu_handle()?;
            let grad_h = grad_output.gpu_handle()?;
            let grad_x = if self.x.requires_grad() {
                // Zero the gradient where the condition is false, i.e. where
                // the negated condition is true.
                let not_cond = backend.bool_not(cond_h)?;
                let result_h = backend.masked_fill_dt(grad_h, &not_cond, 0.0)?;
                Some(Tensor::from_storage(
                    TensorStorage::gpu(result_h),
                    self.x.shape().to_vec(),
                    false,
                )?)
            } else {
                None
            };
            let grad_y = if self.y.requires_grad() {
                let result_h = backend.masked_fill_dt(grad_h, cond_h, 0.0)?;
                Some(Tensor::from_storage(
                    TensorStorage::gpu(result_h),
                    self.y.shape().to_vec(),
                    false,
                )?)
            } else {
                None
            };
            return Ok(vec![grad_x, grad_y]);
        }
        let go_data = grad_output.data_vec()?;
        let cond = self.condition.data()?;
        let zero = <T as num_traits::Zero>::zero();
        let grad_x = if self.x.requires_grad() {
            let gx: Vec<T> = cond
                .iter()
                .zip(go_data.iter())
                .map(|(&c, &g)| if c { g } else { zero })
                .collect();
            let t = Tensor::from_storage(TensorStorage::cpu(gx), self.x.shape().to_vec(), false)?;
            Some(t)
        } else {
            None
        };
        let grad_y = if self.y.requires_grad() {
            let gy: Vec<T> = cond
                .iter()
                .zip(go_data.iter())
                .map(|(&c, &g)| if c { zero } else { g })
                .collect();
            let t = Tensor::from_storage(TensorStorage::cpu(gy), self.y.shape().to_vec(), false)?;
            Some(t)
        } else {
            None
        };
        Ok(vec![grad_x, grad_y])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.x, &self.y]
    }

    fn name(&self) -> &'static str {
        "WhereCondBackward"
    }
}
/// Backward node for a masked select (the forward op is referenced in this
/// file as `crate::ops::indexing::masked_select` — presumably where this node
/// is created).
#[derive(Debug)]
pub struct MaskedSelectBackward<T: Float> {
    /// The tensor that elements were selected from.
    pub input: Tensor<T>,
    /// `true` at every position that was selected.
    pub mask: BoolTensor,
}

impl<T: Float> GradFn<T> for MaskedSelectBackward<T> {
    /// Spreads the compact `grad_output` back over a zero tensor shaped like
    /// `input`: the k-th `true` in `mask` receives the k-th gradient value.
    ///
    /// The GPU kernel is used only when both `grad_output` and `mask` are on
    /// CUDA. Returns `[None]` when gradient tracking is disabled.
    ///
    /// # Errors
    /// `DeviceMismatch` if both are on CUDA but on different devices.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None]);
        }
        let input_shape = self.input.shape().to_vec();
        let input_numel: usize = input_shape.iter().product();

        let storage = if grad_output.is_cuda() && self.mask.is_cuda() {
            if grad_output.device() != self.mask.device() {
                return Err(FerrotorchError::DeviceMismatch {
                    expected: grad_output.device(),
                    got: self.mask.device(),
                });
            }
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let spread = backend.masked_scatter(
                grad_output.gpu_handle()?,
                self.mask.gpu_handle()?,
                input_numel,
            )?;
            TensorStorage::gpu(spread)
        } else {
            let upstream = grad_output.data()?;
            let mask_host = self.mask.data()?;
            let mut spread: Vec<T> = vec![<T as num_traits::Zero>::zero(); input_numel];
            // `next` walks the compact gradient as selected positions are met.
            let mut next = 0usize;
            for (pos, &selected) in mask_host.iter().enumerate() {
                if !selected {
                    continue;
                }
                spread[pos] = upstream[next];
                next += 1;
            }
            TensorStorage::cpu(spread)
        };

        let grad_tensor = Tensor::from_storage(storage, input_shape, false)?;
        Ok(vec![Some(grad_tensor)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "MaskedSelectBackward"
    }
}
/// `masked_fill` taking the mask as a [`BoolTensor`].
///
/// When both `input` and `mask` are on CUDA the fill runs on the GPU via
/// `masked_fill_dt`; in every other case the mask's host data is handed to
/// [`masked_fill`].
///
/// # Errors
/// * `ShapeMismatch` if `mask.numel() != input.numel()`.
/// * `DeviceMismatch` if both are on CUDA but on different devices.
/// * `DeviceUnavailable` if the GPU path is taken but no backend exists.
pub fn masked_fill_bt<T: Float>(
    input: &Tensor<T>,
    mask: &BoolTensor,
    value: T,
) -> FerrotorchResult<Tensor<T>> {
    if mask.numel() != input.numel() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "masked_fill_bt: mask numel={} != input numel={}",
                mask.numel(),
                input.numel()
            ),
        });
    }
    // Anything other than "both on CUDA" goes through the slice-based path.
    if !input.is_cuda() || !mask.is_cuda() {
        return masked_fill(input, mask.data()?, value);
    }
    if input.device() != mask.device() {
        return Err(FerrotorchError::DeviceMismatch {
            expected: input.device(),
            got: mask.device(),
        });
    }
    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    // A value that cannot be expressed as f64 falls back to 0.0.
    let fill_f64 = num_traits::ToPrimitive::to_f64(&value).unwrap_or(0.0);
    // From here on the contiguous tensor stands in for `input`, including in
    // the recorded backward node.
    let contiguous = input.contiguous()?;
    let filled = backend.masked_fill_dt(contiguous.gpu_handle()?, mask.gpu_handle()?, fill_f64)?;
    let storage = TensorStorage::gpu(filled);
    let output_shape = contiguous.shape().to_vec();
    if contiguous.requires_grad() && is_grad_enabled() {
        let grad_fn = Arc::new(MaskedFillBackward {
            input: contiguous.clone(),
            mask: mask.clone(),
        });
        Tensor::from_operation(storage, output_shape, grad_fn)
    } else {
        Tensor::from_storage(storage, output_shape, false)
    }
}
/// `index_select_1d` taking its indices as a 1-D [`IntTensor`].
///
/// Each index is widened to `i64`, rejected if negative, converted to
/// `usize`, and the resulting list is passed to [`index_select_1d`].
///
/// # Errors
/// * `ShapeMismatch` if `indices` is not 1-D.
/// * `InvalidArgument` for the first negative index.
/// * Any error from [`index_select_1d`].
pub fn index_select_1d_it<T: Float, I: IntElement>(
    input: &Tensor<T>,
    indices: &IntTensor<I>,
) -> FerrotorchResult<Tensor<T>> {
    if indices.ndim() != 1 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "index_select_1d_it: indices must be 1-D, got shape {:?}",
                indices.shape()
            ),
        });
    }
    let mut positions: Vec<usize> = Vec::with_capacity(indices.numel());
    for raw in indices.data()? {
        match raw.to_i64() {
            i if i < 0 => {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!("index_select_1d_it: negative index {i} not allowed"),
                });
            }
            i => positions.push(i as usize),
        }
    }
    index_select_1d(input, &positions)
}
/// Backward node recorded by `index_select_dim`.
#[derive(Debug)]
pub struct IndexSelectDimBackward<T: Float> {
    /// The tensor that was indexed in the forward pass.
    pub input: Tensor<T>,
    /// Axis along which slices were selected.
    pub dim: usize,
    /// Selected positions along `dim`, one per output slice.
    pub indices: Vec<usize>,
}

impl<T: Float> GradFn<T> for IndexSelectDimBackward<T> {
    /// Scatter-adds each slice of `grad_output` along `dim` back into the
    /// slice of a zero tensor (shaped like `input`) it was read from, so
    /// repeated indices accumulate.
    ///
    /// Returns `[None]` when gradient tracking is disabled or `input` does
    /// not require grad.
    ///
    /// # Errors
    /// `NotImplementedOnCuda` if `grad_output` is on CUDA and `T` is neither
    /// `f32` nor `f64`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() || !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let input_shape = self.input.shape();
        let input_numel: usize = input_shape.iter().product();
        let dim = self.dim;
        // View the tensor as [outer, dim, inner] around the selected axis.
        let outer: usize = input_shape[..dim].iter().product();
        let inner: usize = input_shape[dim + 1..].iter().product();
        let in_dim_size = input_shape[dim];
        let out_dim_size = self.indices.len();

        let storage = if grad_output.is_cuda() {
            use std::any::TypeId;
            let is_t_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
            let is_t_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
            if !(is_t_f32 || is_t_f64) {
                return Err(FerrotorchError::NotImplementedOnCuda {
                    op: "IndexSelectDimBackward",
                });
            }
            let ordinal = match grad_output.device() {
                Device::Cuda(o) => o,
                _ => unreachable!(),
            };
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            // One flat destination offset per element of grad_output.
            let mut targets: Vec<f32> = Vec::with_capacity(outer * out_dim_size * inner);
            for o in 0..outer {
                for &dst_i in &self.indices {
                    let base = o * in_dim_size * inner + dst_i * inner;
                    targets.extend((0..inner).map(|k| (base + k) as f32));
                }
            }
            let targets_handle = upload_f32_to_gpu(&targets, ordinal)?;
            let summed = if is_t_f32 {
                backend.scatter_add_1d_f32(
                    grad_output.gpu_handle()?,
                    &targets_handle,
                    input_numel,
                )?
            } else {
                backend.scatter_add_1d_f64(
                    grad_output.gpu_handle()?,
                    &targets_handle,
                    input_numel,
                )?
            };
            TensorStorage::gpu(summed)
        } else {
            let upstream = grad_output.data_vec()?;
            let mut accum = vec![<T as num_traits::Zero>::zero(); input_numel];
            for o in 0..outer {
                for (i, &dst_i) in self.indices.iter().enumerate() {
                    let go_base = o * out_dim_size * inner + i * inner;
                    let gi_base = o * in_dim_size * inner + dst_i * inner;
                    for k in 0..inner {
                        accum[gi_base + k] += upstream[go_base + k];
                    }
                }
            }
            TensorStorage::cpu(accum)
        };

        let grad_tensor = Tensor::from_storage(storage, input_shape.to_vec(), false)?;
        Ok(vec![Some(grad_tensor)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IndexSelectDimBackward"
    }
}
/// Selects slices of `input` along axis `dim` at the positions in `indices`.
///
/// The output has the shape of `input` with axis `dim` resized to the number
/// of indices. When `input` requires grad and gradient tracking is enabled,
/// the result is linked to an [`IndexSelectDimBackward`] node.
///
/// # Errors
/// * `InvalidArgument` if `input` is 0-d, `dim` is out of range, or an index
///   is negative.
/// * `ShapeMismatch` if `indices` is not 1-D.
/// * `IndexOutOfBounds` if an index is `>=` the size of axis `dim`.
/// * `NotImplementedOnCuda` if `input` is on CUDA and `T` is neither `f32`
///   nor `f64`.
pub fn index_select_dim<T: Float, I: IntElement>(
    input: &Tensor<T>,
    dim: usize,
    indices: &IntTensor<I>,
) -> FerrotorchResult<Tensor<T>> {
    let input_shape = input.shape();
    let ndim = input_shape.len();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "index_select_dim: input must have at least 1 dimension".into(),
        });
    }
    if dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("index_select_dim: dim {dim} out of range for shape {input_shape:?}"),
        });
    }
    if indices.ndim() != 1 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "index_select_dim: indices must be 1-D, got shape {:?}",
                indices.shape()
            ),
        });
    }

    // Validate and convert every index before touching any tensor data.
    let in_dim_size = input_shape[dim];
    let mut positions: Vec<usize> = Vec::with_capacity(indices.numel());
    for raw in indices.data()? {
        let i = raw.to_i64();
        if i < 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("index_select_dim: negative index {i} not allowed"),
            });
        }
        let position = i as usize;
        if position >= in_dim_size {
            return Err(FerrotorchError::IndexOutOfBounds {
                index: position,
                axis: dim,
                size: in_dim_size,
            });
        }
        positions.push(position);
    }

    let out_dim_size = positions.len();
    let mut output_shape = input_shape.to_vec();
    output_shape[dim] = out_dim_size;
    // View the tensor as [outer, dim, inner] around the selected axis.
    let outer: usize = input_shape[..dim].iter().product();
    let inner: usize = input_shape[dim + 1..].iter().product();

    let storage = if input.is_cuda() {
        use std::any::TypeId;
        let is_t_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
        let is_t_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
        if !(is_t_f32 || is_t_f64) {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "index_select_dim",
            });
        }
        let ordinal = match input.device() {
            Device::Cuda(o) => o,
            _ => unreachable!("input.is_cuda() but device() not Cuda"),
        };
        let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let positions_f32: Vec<f32> = positions.iter().map(|&p| p as f32).collect();
        let positions_handle = upload_f32_to_gpu(&positions_f32, ordinal)?;
        let picked = if is_t_f32 {
            backend.index_select_dim_f32(
                input.gpu_handle()?,
                &positions_handle,
                outer,
                in_dim_size,
                out_dim_size,
                inner,
            )?
        } else {
            backend.index_select_dim_f64(
                input.gpu_handle()?,
                &positions_handle,
                outer,
                in_dim_size,
                out_dim_size,
                inner,
            )?
        };
        TensorStorage::gpu(picked)
    } else {
        let out_numel: usize = output_shape.iter().product();
        let source = input.data_vec()?;
        let mut picked = vec![<T as num_traits::Zero>::zero(); out_numel];
        for o in 0..outer {
            for (i, &src_i) in positions.iter().enumerate() {
                let in_base = o * in_dim_size * inner + src_i * inner;
                let out_base = o * out_dim_size * inner + i * inner;
                picked[out_base..out_base + inner]
                    .copy_from_slice(&source[in_base..in_base + inner]);
            }
        }
        TensorStorage::cpu(picked)
    };

    if input.requires_grad() && is_grad_enabled() {
        let grad_fn = Arc::new(IndexSelectDimBackward {
            input: input.clone(),
            dim,
            indices: positions,
        });
        Tensor::from_operation(storage, output_shape, grad_fn)
    } else {
        Tensor::from_storage(storage, output_shape, false)
    }
}
/// Backward node recorded by `index_fill`.
#[derive(Debug)]
pub struct IndexFillBackward<T: Float> {
    /// The tensor whose slices were overwritten.
    pub input: Tensor<T>,
    /// Axis along which `index` addresses `input`.
    pub dim: usize,
    /// Normalised (non-negative) positions along `dim` that were filled.
    pub index: Vec<usize>,
}

impl<T: Float> GradFn<T> for IndexFillBackward<T> {
    /// Passes `grad_output` through with every filled slice set to zero.
    ///
    /// For a 0-d `input`, the single gradient value is zeroed when `index`
    /// is non-empty. Returns `[None]` when gradient tracking is disabled or
    /// `input` does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() || !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let input_shape = self.input.shape();
        let dim = self.dim;
        let zero = <T as num_traits::Zero>::zero();

        if input_shape.is_empty() {
            // Scalar case: there is exactly one element to keep or zero.
            let mut passed = grad_output.data_vec()?;
            if !self.index.is_empty() {
                passed[0] = zero;
            }
            let grad_tensor = Tensor::from_storage(TensorStorage::cpu(passed), vec![], false)?;
            return Ok(vec![Some(grad_tensor)]);
        }

        // View the tensor as [outer, dim, inner] around the filled axis.
        let outer: usize = input_shape[..dim].iter().product();
        let inner: usize = input_shape[dim + 1..].iter().product();
        let dim_size = input_shape[dim];
        let mut passed = grad_output.data_vec()?;
        for o in 0..outer {
            for &idx in &self.index {
                let base = o * dim_size * inner + idx * inner;
                passed[base..base + inner].fill(zero);
            }
        }
        let grad_tensor =
            Tensor::from_storage(TensorStorage::cpu(passed), input_shape.to_vec(), false)?;
        Ok(vec![Some(grad_tensor)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IndexFillBackward"
    }
}
/// Returns a copy of `input` in which every slice along `dim` addressed by
/// `index` is overwritten with `value`.
///
/// Both `dim` and the entries of `index` may be negative; negative values
/// are normalised by adding the relevant size. The result is built on the
/// host via `TensorStorage::cpu`. When `input` requires grad and gradient
/// tracking is enabled, the result is linked to an [`IndexFillBackward`]
/// node.
///
/// # Errors
/// * `InvalidArgument` if `dim` is out of range, or if `value` cannot be
///   converted to `T`.
/// * `ShapeMismatch` if `input` is at least 1-D and `index` has more than one
///   dimension.
/// * `IndexOutOfBounds` if an index lies outside `[-size, size)` for the
///   addressed axis (size 1 for a 0-d input).
pub fn index_fill<T: Float>(
input: &Tensor<T>,
dim: i64,
index: &IntTensor<i64>,
value: f64,
) -> FerrotorchResult<Tensor<T>> {
let input_shape = input.shape();
let ndim = input_shape.len();
// ---- 0-d (scalar) input: treated as a single-element axis ----
if ndim == 0 {
// Only dim 0 and -1 are accepted for a scalar.
let dim_for_0d = match dim {
0 | -1 => 0i64,
_ => {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"index_fill: dim {dim} out of range for 0-d input \
(valid range: [-1, 0])"
),
});
}
};
let scalar_val = input.data_vec()?[0];
let mut result_val = scalar_val;
let mut any_filled = false;
for v in index.data()? {
let i_raw = v.to_i64();
// With size 1, the only valid indices are 0 and -1.
let i = if i_raw < 0 { i_raw + 1 } else { i_raw };
if !(0..1).contains(&i) {
return Err(FerrotorchError::IndexOutOfBounds {
// The error field is unsigned, so a negative index is
// reported by its magnitude.
index: if i_raw < 0 {
i_raw.unsigned_abs() as usize
} else {
i_raw as usize
},
axis: dim_for_0d as usize,
size: 1,
});
}
// The conversion runs only once a valid index has been seen, so an
// empty `index` never fails on an unrepresentable `value`.
result_val = <T as num_traits::NumCast>::from(value).ok_or_else(|| {
FerrotorchError::InvalidArgument {
message: format!("index_fill: value {value} not representable in target dtype"),
}
})?;
any_filled = true;
}
let out_storage = TensorStorage::cpu(vec![result_val]);
if input.requires_grad() && is_grad_enabled() {
// A non-empty saved index tells the backward pass to zero the gradient.
let saved_index: Vec<usize> = if any_filled { vec![0] } else { vec![] };
let grad_fn = Arc::new(IndexFillBackward {
input: input.clone(),
dim: 0,
index: saved_index,
});
return Tensor::from_operation(out_storage, vec![], grad_fn);
}
return Tensor::from_storage(out_storage, vec![], false);
}
// ---- N-d input ----
if index.ndim() > 1 {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"index_fill: index must be 1-D or scalar, got shape {:?}",
index.shape()
),
});
}
// Normalise a negative `dim` and validate it against the input rank.
let ndim_i64 = ndim as i64;
let dim_norm = if dim < 0 { dim + ndim_i64 } else { dim };
if !(0..ndim_i64).contains(&dim_norm) {
return Err(FerrotorchError::InvalidArgument {
message: format!("index_fill: dim {dim} out of range for input ndim {ndim}"),
});
}
let dim_usize = dim_norm as usize;
let dim_size = input_shape[dim_usize];
let dim_size_i64 = dim_size as i64;
// Validate every index and normalise negatives before writing anything.
let mut idx_usize: Vec<usize> = Vec::with_capacity(index.numel());
for v in index.data()? {
let i_raw = v.to_i64();
if i_raw < -dim_size_i64 || i_raw >= dim_size_i64 {
return Err(FerrotorchError::IndexOutOfBounds {
// The error field is unsigned, so a negative index is
// reported by its magnitude.
index: if i_raw < 0 {
i_raw.unsigned_abs() as usize
} else {
i_raw as usize
},
axis: dim_usize,
size: dim_size,
});
}
let i = if i_raw < 0 {
i_raw + dim_size_i64
} else {
i_raw
};
idx_usize.push(i as usize);
}
// View the tensor as [outer, dim, inner] around the filled axis.
let outer: usize = input_shape[..dim_usize].iter().product();
let inner: usize = input_shape[dim_usize + 1..].iter().product();
let in_data = input.data_vec()?;
let mut out = in_data.clone();
// Unlike the 0-d path, the conversion here runs even when `index` is empty.
let value_t = <T as num_traits::NumCast>::from(value).ok_or_else(|| {
FerrotorchError::InvalidArgument {
message: format!("index_fill: value {value} not representable in target dtype"),
}
})?;
for o in 0..outer {
for &idx in &idx_usize {
let base = o * dim_size * inner + idx * inner;
for k in 0..inner {
out[base + k] = value_t;
}
}
}
let output_shape = input_shape.to_vec();
if input.requires_grad() && is_grad_enabled() {
let grad_fn = Arc::new(IndexFillBackward {
input: input.clone(),
dim: dim_usize,
index: idx_usize,
});
Tensor::from_operation(TensorStorage::cpu(out), output_shape, grad_fn)
} else {
Tensor::from_storage(TensorStorage::cpu(out), output_shape, false)
}
}
/// Maps a row-major flat offset in a tensor of `out_shape` to the flat offset
/// of the element it broadcasts from in a tensor of `in_shape`.
///
/// Shapes are aligned at their trailing axes. An input axis of extent 1 is
/// broadcast (it contributes nothing to the input offset); output axes with
/// no matching input axis are ignored.
///
/// The input stride is carried as a running product while walking the axes,
/// so no per-call allocation is needed — this function is called once per
/// output element.
///
/// # Panics
/// Panics with a division by zero if `out_shape` contains a `0` extent.
#[inline]
fn broadcast_in_flat(flat: usize, out_shape: &[usize], in_shape: &[usize]) -> usize {
    let mut rem = flat;
    let mut in_idx = 0usize;
    // Stride of the input axis currently being visited, and the extent of the
    // previously visited (more inner) input axis that feeds the next stride.
    let mut in_stride = 1usize;
    let mut inner_extent = 1usize;
    let mut in_dims = in_shape.iter().rev();
    for &out_dim in out_shape.iter().rev() {
        let coord = rem % out_dim;
        rem /= out_dim;
        if let Some(&in_dim) = in_dims.next() {
            in_stride *= inner_extent;
            if in_dim != 1 {
                in_idx += coord * in_stride;
            }
            inner_extent = in_dim;
        }
    }
    in_idx
}
/// Materialises `mask` broadcast to `out_shape`.
///
/// A mask that already has `out_shape` is cloned. A CUDA mask is broadcast by
/// the GPU backend. Otherwise the host data is expanded element by element
/// after checking that each trailing axis either matches or has extent 1.
///
/// # Errors
/// * `DeviceUnavailable` if `mask` is on CUDA but no GPU backend exists.
/// * `ShapeMismatch` if `mask` has more axes than `out_shape`, or an axis is
///   neither 1 nor equal to the corresponding target extent.
fn broadcast_bool_tensor(mask: &BoolTensor, out_shape: &[usize]) -> FerrotorchResult<BoolTensor> {
    if mask.shape() == out_shape {
        return Ok(mask.clone());
    }
    if mask.is_cuda() {
        let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let handle = backend.broadcast_bool(mask.gpu_handle()?, mask.shape(), out_shape)?;
        return Ok(BoolTensor::from_gpu_handle(handle, out_shape.to_vec()));
    }

    let in_data = mask.data()?;
    let in_shape: Vec<usize> = mask.shape().to_vec();
    // The product of an empty shape is 1, which covers the 0-d target.
    let out_numel: usize = out_shape.iter().product();
    let out_ndim = out_shape.len();
    let in_ndim = in_shape.len();
    if in_ndim > out_ndim {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "broadcast_bool_tensor: input ndim {in_ndim} > target ndim {out_ndim} \
(shapes {in_shape:?} -> {out_shape:?})"
            ),
        });
    }

    // Compare axes pairwise from the trailing end.
    let trailing_pairs = in_shape.iter().rev().zip(out_shape.iter().rev());
    for (d_in_off, (&in_dim, &out_dim)) in trailing_pairs.enumerate() {
        if in_dim != 1 && in_dim != out_dim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "broadcast_bool_tensor: cannot broadcast {in_shape:?} -> {out_shape:?} \
(axis {} mismatch: {in_dim} vs {out_dim})",
                    out_ndim - 1 - d_in_off
                ),
            });
        }
    }

    let expanded: Vec<_> = (0..out_numel)
        .map(|flat| in_data[broadcast_in_flat(flat, out_shape, &in_shape)])
        .collect();
    BoolTensor::from_vec(expanded, out_shape.to_vec())
}
/// Broadcasting front end for `masked_fill_bt`.
///
/// When `input` and `mask` already share a shape, the call is forwarded as is.
/// Otherwise both operands are brought to the shape returned by
/// `crate::shape::broadcast_shapes` before `masked_fill_bt` is invoked.
///
/// # Errors
///
/// Propagates any error from computing the broadcast shape, expanding either
/// operand, or from `masked_fill_bt` itself.
pub fn masked_fill_bcast<T: Float>(
    input: &Tensor<T>,
    mask: &BoolTensor,
    value: T,
) -> FerrotorchResult<Tensor<T>> {
    if input.shape() != mask.shape() {
        let target = crate::shape::broadcast_shapes(input.shape(), mask.shape())?;
        let expanded_input = crate::grad_fns::shape::expand(input, &target)?;
        let expanded_mask = broadcast_bool_tensor(mask, &target)?;
        masked_fill_bt(&expanded_input, &expanded_mask, value)
    } else {
        masked_fill_bt(input, mask, value)
    }
}
/// Broadcasting front end for `crate::ops::indexing::masked_select`.
///
/// Operands with equal shapes are passed straight through. Otherwise `input`
/// is expanded and `mask` is broadcast to their common broadcast shape first.
///
/// # Errors
///
/// Propagates any error from computing the broadcast shape, expanding either
/// operand, or from `masked_select`.
pub fn masked_select_bcast<T: Float>(
    input: &Tensor<T>,
    mask: &BoolTensor,
) -> FerrotorchResult<Tensor<T>> {
    if input.shape() != mask.shape() {
        let target = crate::shape::broadcast_shapes(input.shape(), mask.shape())?;
        let expanded_input = crate::grad_fns::shape::expand(input, &target)?;
        let expanded_mask = broadcast_bool_tensor(mask, &target)?;
        crate::ops::indexing::masked_select(&expanded_input, &expanded_mask)
    } else {
        crate::ops::indexing::masked_select(input, mask)
    }
}
/// Broadcasting front end for `crate::ops::indexing::where_cond_bt`.
///
/// If `cond`, `x` and `y` all share one shape, the call is forwarded unchanged.
/// Otherwise the common shape is computed in two steps (`x` with `y`, then the
/// result with `cond`) and every operand is brought to that shape.
///
/// # Errors
///
/// Propagates any error from computing the broadcast shapes, expanding an
/// operand, or from `where_cond_bt`.
pub fn where_cond_bcast<T: Float>(
    cond: &BoolTensor,
    x: &Tensor<T>,
    y: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let shapes_agree = cond.shape() == x.shape() && x.shape() == y.shape();
    if !shapes_agree {
        let value_shape = crate::shape::broadcast_shapes(x.shape(), y.shape())?;
        let target = crate::shape::broadcast_shapes(cond.shape(), &value_shape)?;
        let expanded_cond = broadcast_bool_tensor(cond, &target)?;
        let expanded_x = crate::grad_fns::shape::expand(x, &target)?;
        let expanded_y = crate::grad_fns::shape::expand(y, &target)?;
        crate::ops::indexing::where_cond_bt(&expanded_cond, &expanded_x, &expanded_y)
    } else {
        crate::ops::indexing::where_cond_bt(cond, x, y)
    }
}
/// Reduction mode used when several source elements are scattered into the
/// same destination slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScatterReduce {
    /// Colliding values are added together.
    Sum,
    /// Colliding values are multiplied together.
    Prod,
    /// The larger of the colliding values is kept.
    Amax,
    /// The smaller of the colliding values is kept.
    Amin,
}

impl ScatterReduce {
    /// Parses a reduction name.
    ///
    /// Accepts exactly `"sum"`, `"prod"`, `"amax"` and `"amin"` (case-sensitive)
    /// and returns `None` for any other string.
    pub fn parse_str(s: &str) -> Option<Self> {
        const NAMED_MODES: [(&str, ScatterReduce); 4] = [
            ("sum", ScatterReduce::Sum),
            ("prod", ScatterReduce::Prod),
            ("amax", ScatterReduce::Amax),
            ("amin", ScatterReduce::Amin),
        ];
        NAMED_MODES
            .iter()
            .find(|&&(name, _)| name == s)
            .map(|&(_, mode)| mode)
    }
}
/// Autograd node recorded by `scatter_reduce`.
///
/// It keeps the forward operands together with the forward output so the
/// backward pass can derive gradients for both `input` and `src`.
#[derive(Debug)]
pub struct ScatterReduceBackward<T: Float> {
/// Tensor that was scattered into (first forward operand).
pub input: Tensor<T>,
/// Tensor whose elements were scattered (second forward operand).
pub src: Tensor<T>,
/// Normalised (non-negative) axis along which `index` selects destinations;
/// stored as 0 for a 0-d `input`.
pub dim: usize,
/// Index values in row-major order over `index_shape`.
pub index: Vec<usize>,
/// Shape of the index tensor.
pub index_shape: Vec<usize>,
/// Reduction that was applied in the forward pass.
pub reduce: ScatterReduce,
/// Whether the original `input` values took part in the reduction.
pub include_self: bool,
/// Flat copy of the forward output, laid out like `input`.
pub result: Vec<T>,
}
impl<T: Float> GradFn<T> for ScatterReduceBackward<T> {
    /// Computes the gradients for `[input, src]`.
    ///
    /// Returns `[None, None]` while gradient recording is disabled. A 0-d
    /// `input` is handled by `backward_0d`; otherwise the work is dispatched on
    /// the stored reduction mode.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if is_grad_enabled() {
            if self.input.shape().is_empty() {
                return self.backward_0d(grad_output);
            }
            match self.reduce {
                ScatterReduce::Amax | ScatterReduce::Amin => self.backward_amax_amin(grad_output),
                ScatterReduce::Prod => self.backward_prod(grad_output),
                ScatterReduce::Sum => self.backward_sum(grad_output),
            }
        } else {
            Ok(vec![None, None])
        }
    }

    /// The forward operands, in the same order as the gradients from `backward`.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let operands = [&self.input, &self.src];
        operands.to_vec()
    }

    /// Name of this autograd node.
    fn name(&self) -> &'static str {
        "ScatterReduceBackward"
    }
}
impl<T: Float> ScatterReduceBackward<T> {
    /// Visits every element of the index tensor in row-major order.
    ///
    /// The callback receives `(i, idx_val, coords, dst_flat)`:
    /// * `i` is the flat position inside the index tensor,
    /// * `idx_val` is the index value stored there,
    /// * `coords` are the coordinates of that position in index space,
    /// * `dst_flat` is the flat offset in `input` of the slot it scatters into,
    ///   i.e. `coords` with the `dim` coordinate replaced by `idx_val`.
    fn for_each_index<F: FnMut(usize, usize, &[usize], usize)>(&self, mut f: F) {
        let input_shape = self.input.shape();
        let index_numel: usize = self.index_shape.iter().product();
        let mut coords = vec![0usize; input_shape.len()];
        for i in 0..index_numel {
            let idx_val = self.index[i];
            // Temporarily patch the `dim` coordinate instead of cloning the whole
            // coordinate vector for every element.
            let own_coord = std::mem::replace(&mut coords[self.dim], idx_val);
            let dst_flat = flat_index(&coords, input_shape);
            coords[self.dim] = own_coord;
            f(i, idx_val, &coords, dst_flat);
            if i + 1 < index_numel {
                increment_coords(&mut coords, &self.index_shape);
            }
        }
    }

    /// Backward pass for a 0-d `input`.
    ///
    /// The upstream gradient is forwarded to `src` unchanged. `input` receives
    /// it as well, unless it was excluded from the reduction (`include_self`
    /// false with a non-empty index), in which case its gradient is zero.
    fn backward_0d(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let go_data = grad_output.data_vec()?;
        let zero = <T as num_traits::Zero>::zero();
        let mut grad_input_data = go_data.clone();
        if !self.include_self && !self.index.is_empty() {
            grad_input_data[0] = zero;
        }
        let grad_input = if self.input.requires_grad() {
            Some(Tensor::from_storage(
                TensorStorage::cpu(grad_input_data),
                vec![],
                false,
            )?)
        } else {
            None
        };
        let grad_src = if self.src.requires_grad() {
            Some(Tensor::from_storage(
                TensorStorage::cpu(go_data),
                self.src.shape().to_vec(),
                false,
            )?)
        } else {
            None
        };
        Ok(vec![grad_input, grad_src])
    }

    /// Backward pass for `ScatterReduce::Sum`.
    ///
    /// `input` receives the upstream gradient, zeroed at every scattered-to slot
    /// when `include_self` is false. `src` receives the upstream gradient
    /// gathered at each destination slot, shaped like the index tensor.
    fn backward_sum(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let input_shape = self.input.shape();
        let go_data = grad_output.data_vec()?;
        let zero = <T as num_traits::Zero>::zero();
        let index_numel: usize = self.index_shape.iter().product();
        let grad_input = if self.input.requires_grad() {
            let mut gi = go_data.clone();
            if !self.include_self {
                self.for_each_index(|_, _, _, dst_flat| {
                    gi[dst_flat] = zero;
                });
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gi),
                input_shape.to_vec(),
                false,
            )?)
        } else {
            None
        };
        let grad_src = if self.src.requires_grad() {
            let mut gs = vec![zero; index_numel];
            self.for_each_index(|i, _, _, dst_flat| {
                gs[i] = go_data[dst_flat];
            });
            Some(Tensor::from_storage(
                TensorStorage::cpu(gs),
                self.index_shape.clone(),
                false,
            )?)
        } else {
            None
        };
        Ok(vec![grad_input, grad_src])
    }

    /// Backward pass for `ScatterReduce::Amax` and `ScatterReduce::Amin`.
    ///
    /// The upstream gradient of each output slot is divided evenly among all
    /// contributors (the original `input` value and every scattered `src`
    /// element) whose value equals the forward result in that slot.
    fn backward_amax_amin(
        &self,
        grad_output: &Tensor<T>,
    ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let input_shape = self.input.shape();
        let go_data = grad_output.data_vec()?;
        let in_data = self.input.data_vec()?;
        let src_data = self.src.data_vec()?;
        let src_shape = self.src.shape();
        let zero = <T as num_traits::Zero>::zero();
        let one = <T as num_traits::One>::one();
        let input_numel: usize = input_shape.iter().product();
        let index_numel: usize = self.index_shape.iter().product();

        // 1 where the original input value equals the forward result.
        let mut self_is_result = vec![zero; input_numel];
        for p in 0..input_numel {
            if in_data[p] == self.result[p] {
                self_is_result[p] = one;
            }
        }

        // 1 where the scattered source value equals the result of its slot.
        let read_src_at = |coords: &[usize]| -> T { src_data[flat_index(coords, src_shape)] };
        let mut src_is_result = vec![zero; index_numel];
        self.for_each_index(|i, _, coords, dst_flat| {
            if read_src_at(coords) == self.result[dst_flat] {
                src_is_result[i] = one;
            }
        });

        // Number of contributors per slot that share the gradient.
        let mut n_to_distribute = self_is_result.clone();
        self.for_each_index(|i, _, _, dst_flat| {
            n_to_distribute[dst_flat] += src_is_result[i];
        });

        // Slots without any contributor keep a zero share, which avoids a
        // division by zero.
        let mut grad_distributed = vec![zero; input_numel];
        for p in 0..input_numel {
            if n_to_distribute[p] != zero {
                grad_distributed[p] = go_data[p] / n_to_distribute[p];
            }
        }

        let grad_input = if self.input.requires_grad() {
            let mut gi = vec![zero; input_numel];
            for p in 0..input_numel {
                if self_is_result[p] != zero {
                    gi[p] = grad_distributed[p];
                }
            }
            if !self.include_self {
                self.for_each_index(|_, _, _, dst_flat| {
                    gi[dst_flat] = zero;
                });
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gi),
                input_shape.to_vec(),
                false,
            )?)
        } else {
            None
        };
        let grad_src = if self.src.requires_grad() {
            let mut gs = vec![zero; index_numel];
            self.for_each_index(|i, _, _, dst_flat| {
                if src_is_result[i] != zero {
                    gs[i] = grad_distributed[dst_flat];
                }
            });
            Some(Tensor::from_storage(
                TensorStorage::cpu(gs),
                self.index_shape.clone(),
                false,
            )?)
        } else {
            None
        };
        Ok(vec![grad_input, grad_src])
    }

    /// Backward pass for `ScatterReduce::Prod`.
    ///
    /// Dividing the forward product by an operand breaks down when that operand
    /// is zero, so zeros are handled explicitly:
    /// * for `input`, zeros are replaced by one and the product is recomputed
    ///   from those masked values;
    /// * for `src`, an element that is the only zero scattered into its slot
    ///   gets the product of all other contributors; every other element gets
    ///   `grad * result / src`, with a zero `src` replaced by one in the divisor.
    fn backward_prod(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let input_shape = self.input.shape();
        let go_data = grad_output.data_vec()?;
        let in_data = self.input.data_vec()?;
        let src_data = self.src.data_vec()?;
        let src_shape = self.src.shape();
        let zero = <T as num_traits::Zero>::zero();
        let one = <T as num_traits::One>::one();
        let input_numel: usize = input_shape.iter().product();
        let index_numel: usize = self.index_shape.iter().product();
        let read_src_at = |coords: &[usize]| -> T { src_data[flat_index(coords, src_shape)] };

        // `input` with zeros replaced by one.
        let mut masked_self = in_data.clone();
        for v in &mut masked_self {
            if *v == zero {
                *v = one;
            }
        }

        // Forward product recomputed from `masked_self`.
        let mut masked_self_result = masked_self.clone();
        if !self.include_self {
            self.for_each_index(|_, _, _, dst_flat| {
                masked_self_result[dst_flat] = one;
            });
        }
        self.for_each_index(|_, _, coords, dst_flat| {
            masked_self_result[dst_flat] = masked_self_result[dst_flat] * read_src_at(coords);
        });

        // 1 where the source element is zero.
        let mut src_zero = vec![zero; index_numel];
        self.for_each_index(|i, _, coords, _| {
            if read_src_at(coords) == zero {
                src_zero[i] = one;
            }
        });

        // Number of zero source elements scattered into each destination slot.
        let mut zero_count_per_dst = vec![zero; input_numel];
        self.for_each_index(|i, _, _, dst_flat| {
            zero_count_per_dst[dst_flat] += src_zero[i];
        });

        // 1 where the source element is the only zero scattered into its slot.
        let mut src_single_zero = vec![zero; index_numel];
        self.for_each_index(|i, _, _, dst_flat| {
            if src_zero[i] != zero && zero_count_per_dst[dst_flat] == one {
                src_single_zero[i] = one;
            }
        });

        // Forward product with every "single zero" source element replaced by one.
        let mut masked_src_result = in_data;
        if !self.include_self {
            self.for_each_index(|_, _, _, dst_flat| {
                masked_src_result[dst_flat] = one;
            });
        }
        self.for_each_index(|i, _, coords, dst_flat| {
            let factor = if src_single_zero[i] == zero {
                read_src_at(coords)
            } else {
                one
            };
            masked_src_result[dst_flat] = masked_src_result[dst_flat] * factor;
        });

        let grad_input = if self.input.requires_grad() {
            let mut gi = vec![zero; input_numel];
            for p in 0..input_numel {
                if masked_self[p] != zero {
                    gi[p] = go_data[p] * masked_self_result[p] / masked_self[p];
                }
            }
            if !self.include_self {
                self.for_each_index(|_, _, _, dst_flat| {
                    gi[dst_flat] = zero;
                });
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gi),
                input_shape.to_vec(),
                false,
            )?)
        } else {
            None
        };
        let grad_src = if self.src.requires_grad() {
            let mut gs = vec![zero; index_numel];
            self.for_each_index(|i, _, coords, dst_flat| {
                let s_raw = read_src_at(coords);
                let denom = if s_raw == zero { one } else { s_raw };
                let primary = (go_data[dst_flat] * self.result[dst_flat]) / denom;
                let single_zero_branch = go_data[dst_flat] * masked_src_result[dst_flat];
                gs[i] = if src_single_zero[i] == zero {
                    primary
                } else {
                    single_zero_branch
                };
            });
            Some(Tensor::from_storage(
                TensorStorage::cpu(gs),
                self.index_shape.clone(),
                false,
            )?)
        } else {
            None
        };
        Ok(vec![grad_input, grad_src])
    }
}
pub fn scatter_reduce<T: Float>(
input: &Tensor<T>,
dim: i64,
index: &[usize],
index_shape: &[usize],
src: &Tensor<T>,
reduce: ScatterReduce,
include_self: bool,
) -> FerrotorchResult<Tensor<T>> {
let input_shape = input.shape();
let ndim = input_shape.len();
if ndim == 0 {
let in_data = input.data_vec()?;
let src_data = src.data_vec()?;
let zero = <T as num_traits::Zero>::zero();
let one = <T as num_traits::One>::one();
let mut out = in_data[0];
if !include_self && !index.is_empty() {
out = match reduce {
ScatterReduce::Sum => zero,
ScatterReduce::Prod => one,
ScatterReduce::Amax | ScatterReduce::Amin => src_data[0],
};
}
for (i, &_idx) in index.iter().enumerate() {
let s = src_data[i.min(src_data.len() - 1)];
out = apply_reduce(reduce, out, s);
}
let out_storage = TensorStorage::cpu(vec![out]);
if (input.requires_grad() || src.requires_grad()) && is_grad_enabled() {
let grad_fn = Arc::new(ScatterReduceBackward {
input: input.clone(),
src: src.clone(),
dim: 0,
index: index.to_vec(),
index_shape: index_shape.to_vec(),
reduce,
include_self,
result: vec![out],
});
return Tensor::from_operation(out_storage, vec![], grad_fn);
}
return Tensor::from_storage(out_storage, vec![], false);
}
let ndim_i64 = ndim as i64;
let dim_norm = if dim < 0 { dim + ndim_i64 } else { dim };
if !(0..ndim_i64).contains(&dim_norm) {
return Err(FerrotorchError::InvalidArgument {
message: format!("scatter_reduce: dim {dim} out of range for input ndim {ndim}"),
});
}
let dim_usize = dim_norm as usize;
if index_shape.len() != ndim {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"scatter_reduce: index ndim {} != input ndim {}",
index_shape.len(),
ndim
),
});
}
let index_numel: usize = index_shape.iter().product();
if src.numel() < index_numel {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"scatter_reduce: src numel {} < index numel {}",
src.numel(),
index_numel
),
});
}
let in_data = input.data_vec()?;
let src_data = src.data_vec()?;
let src_shape = src.shape();
let mut out = in_data.clone();
let read_src_at = |coords: &[usize]| -> T { src_data[flat_index(coords, src_shape)] };
let zero = <T as num_traits::Zero>::zero();
let one = <T as num_traits::One>::one();
if !include_self {
let identity = match reduce {
ScatterReduce::Sum => Some(zero),
ScatterReduce::Prod => Some(one),
ScatterReduce::Amax | ScatterReduce::Amin => None,
};
if let Some(id) = identity {
let mut coords = vec![0usize; ndim];
for i in 0..index_numel {
let idx_val = index[i];
let mut dst_coords = coords.clone();
dst_coords[dim_usize] = idx_val;
let dst_flat = flat_index(&dst_coords, input_shape);
out[dst_flat] = id;
if i + 1 < index_numel {
increment_coords(&mut coords, index_shape);
}
}
} else {
let input_numel: usize = input_shape.iter().product();
let mut touched = vec![false; input_numel];
let mut coords = vec![0usize; ndim];
for i in 0..index_numel {
let idx_val = index[i];
let mut dst_coords = coords.clone();
dst_coords[dim_usize] = idx_val;
let dst_flat = flat_index(&dst_coords, input_shape);
let s = read_src_at(&coords);
out[dst_flat] = if touched[dst_flat] {
apply_reduce(reduce, out[dst_flat], s)
} else {
touched[dst_flat] = true;
s
};
if i + 1 < index_numel {
increment_coords(&mut coords, index_shape);
}
}
let output_shape = input_shape.to_vec();
if (input.requires_grad() || src.requires_grad()) && is_grad_enabled() {
let grad_fn = Arc::new(ScatterReduceBackward {
input: input.clone(),
src: src.clone(),
dim: dim_usize,
index: index.to_vec(),
index_shape: index_shape.to_vec(),
reduce,
include_self,
result: out.clone(),
});
return Tensor::from_operation(TensorStorage::cpu(out), output_shape, grad_fn);
}
return Tensor::from_storage(TensorStorage::cpu(out), output_shape, false);
}
}
let mut coords = vec![0usize; ndim];
for i in 0..index_numel {
let idx_val = index[i];
let mut dst_coords = coords.clone();
dst_coords[dim_usize] = idx_val;
let dst_flat = flat_index(&dst_coords, input_shape);
out[dst_flat] = apply_reduce(reduce, out[dst_flat], read_src_at(&coords));
if i + 1 < index_numel {
increment_coords(&mut coords, index_shape);
}
}
let output_shape = input_shape.to_vec();
if (input.requires_grad() || src.requires_grad()) && is_grad_enabled() {
let grad_fn = Arc::new(ScatterReduceBackward {
input: input.clone(),
src: src.clone(),
dim: dim_usize,
index: index.to_vec(),
index_shape: index_shape.to_vec(),
reduce,
include_self,
result: out.clone(),
});
Tensor::from_operation(TensorStorage::cpu(out), output_shape, grad_fn)
} else {
Tensor::from_storage(TensorStorage::cpu(out), output_shape, false)
}
}
/// Combines the accumulated value `a` with the incoming value `b` under `mode`.
///
/// For `Amax` and `Amin`, `b` replaces `a` only when the comparison is strictly
/// decided in its favour; ties and unordered comparisons keep `a`.
#[inline]
fn apply_reduce<T: Float>(mode: ScatterReduce, a: T, b: T) -> T {
    use std::cmp::Ordering::Less;

    let incoming_wins = match mode {
        ScatterReduce::Sum => return a + b,
        ScatterReduce::Prod => return a * b,
        ScatterReduce::Amax => matches!(a.partial_cmp(&b), Some(Less)),
        ScatterReduce::Amin => matches!(b.partial_cmp(&a), Some(Less)),
    };
    if incoming_wins {
        b
    } else {
        a
    }
}
/// Shared argument validation for `index_add` and `index_copy` on inputs with
/// at least one axis.
///
/// Checks are applied in this order: `index` must be 0-d or 1-d; `dim` (negative
/// values count from the last axis) must address an axis of `input`; every index
/// value must lie in `0..input.shape()[dim]` (negative indices are rejected);
/// the number of indices must equal `source.shape()[dim]` (or 1 for a 0-d
/// source); a 0-d source is accepted only when `accept_0d_source` is set; and
/// otherwise `source` must match `input` on every axis except `dim` and have the
/// same rank.
///
/// Returns the normalised axis together with the index values as `usize`.
/// `op_name` is only used as the prefix of error messages.
///
/// # Errors
///
/// `ShapeMismatch`, `InvalidArgument` or `IndexOutOfBounds`, depending on which
/// check fails; errors from reading the index data are propagated.
fn strict_index_add_copy_validate<T: Float>(
    op_name: &'static str,
    input: &Tensor<T>,
    dim: i64,
    index: &IntTensor<i64>,
    source: &Tensor<T>,
    accept_0d_source: bool,
) -> FerrotorchResult<(usize, Vec<usize>)> {
    let input_shape = input.shape();
    let ndim = input_shape.len();

    if index.ndim() > 1 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "{op_name}: index must be 1-D or scalar, got shape {:?}",
                index.shape()
            ),
        });
    }

    let rank = ndim as i64;
    let dim_norm = if dim < 0 { dim + rank } else { dim };
    if dim_norm < 0 || dim_norm >= rank {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("{op_name}: dim {dim} out of range for input ndim {ndim}"),
        });
    }
    let dim_usize = dim_norm as usize;

    // Convert the indices, rejecting anything outside `0..in_dim_size`.
    let in_dim_size = input_shape[dim_usize];
    let valid_indices = 0..in_dim_size as i64;
    let mut idx_usize: Vec<usize> = Vec::with_capacity(index.numel());
    for v in index.data()? {
        let i_raw = v.to_i64();
        if !valid_indices.contains(&i_raw) {
            return Err(FerrotorchError::IndexOutOfBounds {
                // The magnitude is reported for negative indices.
                index: i_raw.unsigned_abs() as usize,
                axis: dim_usize,
                size: in_dim_size,
            });
        }
        idx_usize.push(i_raw as usize);
    }

    let source_shape = source.shape();
    let source_ndim = source_shape.len();
    let n_indices = index.numel();
    let shape_mismatch = || FerrotorchError::ShapeMismatch {
        message: format!(
            "{op_name}: source tensor shape must match self tensor shape, \
             excluding the specified dimension. Got self.shape = {input_shape:?} \
             source.shape = {source_shape:?}"
        ),
    };

    // A 0-d source counts as having extent 1 along `dim`.
    let expected_src_at_dim = match source_shape.get(dim_usize) {
        _ if source_ndim == 0 => 1,
        Some(&extent) => extent,
        None => {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "{op_name}: source.dim() ({source_ndim}) does not contain dim {dim_usize}"
                ),
            });
        }
    };
    if n_indices != expected_src_at_dim {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "{op_name}: Number of indices ({n_indices}) should be equal to \
                 source.size(dim): ({expected_src_at_dim}), for dim: {dim_usize}"
            ),
        });
    }

    let scalar_source = source_ndim == 0 && ndim > 0 && n_indices > 0;
    if scalar_source {
        if !accept_0d_source {
            return Err(shape_mismatch());
        }
        if n_indices != 1 {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "{op_name}: When source is scalar, index should have one element \
                     (got {n_indices})"
                ),
            });
        }
        return Ok((dim_usize, idx_usize));
    }

    if source_ndim != 0 && ndim != 0 {
        // Every axis other than `dim` must exist in `source` with the same extent.
        let off_axis_differs = (0..ndim)
            .filter(|&d| d != dim_usize)
            .any(|d| source_shape.get(d) != Some(&input_shape[d]));
        if off_axis_differs {
            return Err(shape_mismatch());
        }
        if source_ndim != ndim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "{op_name}: source.dim() ({source_ndim}) must match self.dim() \
                     ({ndim}) (excluding 0-d source on 0-d self)"
                ),
            });
        }
    }
    Ok((dim_usize, idx_usize))
}
/// Autograd node recorded by `index_add`.
#[derive(Debug)]
pub struct IndexAddBackward<T: Float> {
/// Tensor that was accumulated into (first forward operand).
pub input: Tensor<T>,
/// Tensor whose slices were added (second forward operand).
pub source: Tensor<T>,
/// Normalised (non-negative) axis the indices refer to; 0 for a 0-d `input`.
pub dim: usize,
/// Validated destination indices along `dim`, one per slice of `source`.
pub index: Vec<usize>,
/// Scale factor that was applied to `source` before adding.
pub alpha: f64,
}
impl<T: Float> GradFn<T> for IndexAddBackward<T> {
    /// Computes the gradients for `[input, source]`.
    ///
    /// `input` receives the upstream gradient unchanged. `source` receives, for
    /// each of its slices along `dim`, the upstream gradient of the destination
    /// slice it was added to, scaled by `alpha`. When the stored index is empty
    /// the forward pass never read `source`, so its gradient is zero.
    ///
    /// Returns `[None, None]` while gradient recording is disabled.
    ///
    /// # Errors
    ///
    /// `InvalidArgument` if `alpha` cannot be represented in `T`; errors from
    /// reading the gradient data or building the result tensors are propagated.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None, None]);
        }
        let input_shape = self.input.shape();
        let ndim = input_shape.len();
        let zero = <T as num_traits::Zero>::zero();

        let grad_input = if self.input.requires_grad() {
            let go = grad_output.data_vec()?;
            Some(Tensor::from_storage(
                TensorStorage::cpu(go),
                input_shape.to_vec(),
                false,
            )?)
        } else {
            None
        };

        let grad_source = if self.source.requires_grad() {
            let go = grad_output.data_vec()?;
            let source_shape = self.source.shape();
            let alpha_t = <T as num_traits::NumCast>::from(self.alpha).ok_or_else(|| {
                FerrotorchError::InvalidArgument {
                    message: format!(
                        "IndexAddBackward: alpha {} not representable in target dtype",
                        self.alpha
                    ),
                }
            })?;
            let gs = if ndim == 0 || source_shape.is_empty() {
                // Scalar case: the source contributed only if an index was applied.
                let v = match go.first() {
                    Some(&g) if !self.index.is_empty() => g * alpha_t,
                    _ => zero,
                };
                vec![v]
            } else {
                let outer: usize = input_shape[..self.dim].iter().product();
                let inner: usize = input_shape[self.dim + 1..].iter().product();
                let in_dim_size = input_shape[self.dim];
                let src_dim_size = if source_shape.len() == ndim {
                    source_shape[self.dim]
                } else {
                    self.index.len()
                };
                let mut out = vec![zero; source_shape.iter().product::<usize>()];
                for o in 0..outer {
                    for (i, &dst_i) in self.index.iter().enumerate().take(src_dim_size) {
                        let go_base = o * in_dim_size * inner + dst_i * inner;
                        let src_base = o * src_dim_size * inner + i * inner;
                        for k in 0..inner {
                            out[src_base + k] = go[go_base + k] * alpha_t;
                        }
                    }
                }
                out
            };
            Some(Tensor::from_storage(
                TensorStorage::cpu(gs),
                source_shape.to_vec(),
                false,
            )?)
        } else {
            None
        };
        Ok(vec![grad_input, grad_source])
    }

    /// The forward operands, in the same order as the gradients from `backward`.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input, &self.source]
    }

    /// Name of this autograd node.
    fn name(&self) -> &'static str {
        "IndexAddBackward"
    }
}
/// Returns a copy of `input` in which `alpha * source` has been accumulated into
/// the slices along `dim` selected by `index`.
///
/// Slice `i` of `source` along `dim` is added to slice `index[i]` of the result.
/// For a 0-d `input`, `dim` must be 0 or -1, `source` must be 0-d as well, and
/// `index` may hold at most one element, which must be 0.
///
/// An autograd node is attached when either operand requires a gradient and
/// gradient recording is enabled.
///
/// # Errors
///
/// * `InvalidArgument` if `dim` is out of range or `alpha` is not representable
///   in `T`.
/// * `ShapeMismatch` if `source` and `index` do not fit `input` (see
///   `strict_index_add_copy_validate` for inputs with at least one axis).
/// * `IndexOutOfBounds` if an index value is outside the addressed axis.
/// * `Internal` if a 0-d source unexpectedly passes validation.
pub fn index_add<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    index: &IntTensor<i64>,
    source: &Tensor<T>,
    alpha: f64,
) -> FerrotorchResult<Tensor<T>> {
    let input_shape = input.shape();
    let alpha_as_t = || -> FerrotorchResult<T> {
        <T as num_traits::NumCast>::from(alpha).ok_or_else(|| FerrotorchError::InvalidArgument {
            message: format!("index_add: alpha {alpha} not representable"),
        })
    };

    if input_shape.is_empty() {
        if dim != 0 && dim != -1 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "index_add: dim {dim} out of range for 0-d input (valid: -1, 0)"
                ),
            });
        }
        let source_shape = source.shape();
        if !source_shape.is_empty() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "index_add: source tensor shape must match self tensor shape, \
                     excluding the specified dimension. Got self.shape = [] \
                     source.shape = {source_shape:?}"
                ),
            });
        }
        let mut acc = input.data_vec()?[0];
        let alpha_t = alpha_as_t()?;
        let src_data = source.data_vec()?;
        let n_indices = index.numel();
        if n_indices > 1 {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "index_add: Number of indices ({n_indices}) should be equal to \
                     source.size(dim): (1), for dim: 0"
                ),
            });
        }
        let mut saved_index: Vec<usize> = Vec::new();
        for v in index.data()? {
            let i_raw = v.to_i64();
            if i_raw != 0 {
                return Err(FerrotorchError::IndexOutOfBounds {
                    // The magnitude is reported for negative indices.
                    index: i_raw.unsigned_abs() as usize,
                    axis: 0,
                    size: 1,
                });
            }
            let src_v = match src_data.first() {
                Some(&first) => first,
                None => <T as num_traits::Zero>::zero(),
            };
            acc += alpha_t * src_v;
            saved_index.push(0);
        }
        let storage = TensorStorage::cpu(vec![acc]);
        let record = (input.requires_grad() || source.requires_grad()) && is_grad_enabled();
        if !record {
            return Tensor::from_storage(storage, vec![], false);
        }
        let grad_fn = Arc::new(IndexAddBackward {
            input: input.clone(),
            source: source.clone(),
            dim: 0,
            index: saved_index,
            alpha,
        });
        return Tensor::from_operation(storage, vec![], grad_fn);
    }

    let (dim_usize, idx_usize) =
        strict_index_add_copy_validate("index_add", input, dim, index, source, false)?;
    let in_dim_size = input_shape[dim_usize];
    let alpha_t = alpha_as_t()?;
    let outer: usize = input_shape[..dim_usize].iter().product();
    let inner: usize = input_shape[dim_usize + 1..].iter().product();
    let mut out = input.data_vec()?;
    let src_data = source.data_vec()?;
    let source_shape = source.shape();
    if source_shape.is_empty() {
        return Err(FerrotorchError::Internal {
            message: "index_add: unexpected 0-d source after strict validation".into(),
        });
    }
    let src_dim_size = source_shape[dim_usize];

    for o in 0..outer {
        for (i, &dst_i) in idx_usize.iter().enumerate() {
            let dst_base = o * in_dim_size * inner + dst_i * inner;
            let src_base = o * src_dim_size * inner + i * inner;
            let dst_slice = &mut out[dst_base..dst_base + inner];
            let src_slice = &src_data[src_base..src_base + inner];
            for (dst, &s) in dst_slice.iter_mut().zip(src_slice) {
                *dst += alpha_t * s;
            }
        }
    }

    let output_shape = input_shape.to_vec();
    let record = (input.requires_grad() || source.requires_grad()) && is_grad_enabled();
    if !record {
        return Tensor::from_storage(TensorStorage::cpu(out), output_shape, false);
    }
    let grad_fn = Arc::new(IndexAddBackward {
        input: input.clone(),
        source: source.clone(),
        dim: dim_usize,
        index: idx_usize,
        alpha,
    });
    Tensor::from_operation(TensorStorage::cpu(out), output_shape, grad_fn)
}
/// Autograd node recorded by `index_copy`.
#[derive(Debug)]
pub struct IndexCopyBackward<T: Float> {
/// Tensor that was copied into (first forward operand).
pub input: Tensor<T>,
/// Tensor whose slices were copied (second forward operand).
pub source: Tensor<T>,
/// Normalised (non-negative) axis the indices refer to; 0 for a 0-d `input`.
pub dim: usize,
/// Validated destination indices along `dim`, one per slice of `source`.
pub index: Vec<usize>,
}
impl<T: Float> GradFn<T> for IndexCopyBackward<T> {
    /// Computes the gradients for `[input, source]`.
    ///
    /// `input` receives the upstream gradient with every overwritten slice
    /// zeroed. `source` receives the upstream gradient of the slices it was
    /// copied into:
    /// * for a 0-d `input`, the single upstream value;
    /// * for a 0-d `source` copied into an `input` with axes, the sum over the
    ///   whole destination slice, because the scalar was written to each of its
    ///   elements;
    /// * otherwise, slice `i` of the result is the upstream slice `index[i]`.
    ///
    /// When the stored index is empty nothing was copied, so the `source`
    /// gradient is zero.
    ///
    /// Returns `[None, None]` while gradient recording is disabled.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None, None]);
        }
        let input_shape = self.input.shape();
        let ndim = input_shape.len();
        let zero = <T as num_traits::Zero>::zero();

        let grad_input = if self.input.requires_grad() {
            let mut gi = grad_output.data_vec()?;
            if ndim == 0 {
                if !self.index.is_empty() {
                    gi[0] = zero;
                }
            } else {
                let outer: usize = input_shape[..self.dim].iter().product();
                let inner: usize = input_shape[self.dim + 1..].iter().product();
                let dim_size = input_shape[self.dim];
                for o in 0..outer {
                    for &idx in &self.index {
                        let base = o * dim_size * inner + idx * inner;
                        for k in 0..inner {
                            gi[base + k] = zero;
                        }
                    }
                }
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gi),
                input_shape.to_vec(),
                false,
            )?)
        } else {
            None
        };

        let grad_source = if self.source.requires_grad() {
            let go = grad_output.data_vec()?;
            let source_shape = self.source.shape();
            let src_numel = source_shape.iter().product::<usize>();
            let mut gs = vec![zero; src_numel];
            if ndim == 0 || source_shape.is_empty() {
                // Only the first source element is ever read by the forward pass,
                // and only when at least one index was supplied.
                if let (Some(slot), Some(&dst_i)) = (gs.first_mut(), self.index.first()) {
                    *slot = if ndim == 0 {
                        match go.first() {
                            Some(&g) => g,
                            None => zero,
                        }
                    } else {
                        // The scalar source was written to every element of
                        // destination slice `dst_i`; sum the gradient over it.
                        let outer: usize = input_shape[..self.dim].iter().product();
                        let inner: usize = input_shape[self.dim + 1..].iter().product();
                        let dim_size = input_shape[self.dim];
                        let mut total = zero;
                        for o in 0..outer {
                            let base = o * dim_size * inner + dst_i * inner;
                            for &g in &go[base..base + inner] {
                                total = total + g;
                            }
                        }
                        total
                    };
                }
            } else {
                let outer: usize = input_shape[..self.dim].iter().product();
                let inner: usize = input_shape[self.dim + 1..].iter().product();
                let in_dim_size = input_shape[self.dim];
                let src_dim_size = if source_shape.len() == ndim {
                    source_shape[self.dim]
                } else {
                    self.index.len()
                };
                for o in 0..outer {
                    for (i, &dst_i) in self.index.iter().enumerate().take(src_dim_size) {
                        let go_base = o * in_dim_size * inner + dst_i * inner;
                        let src_base = o * src_dim_size * inner + i * inner;
                        gs[src_base..src_base + inner]
                            .copy_from_slice(&go[go_base..go_base + inner]);
                    }
                }
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gs),
                source_shape.to_vec(),
                false,
            )?)
        } else {
            None
        };
        Ok(vec![grad_input, grad_source])
    }

    /// The forward operands, in the same order as the gradients from `backward`.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input, &self.source]
    }

    /// Name of this autograd node.
    fn name(&self) -> &'static str {
        "IndexCopyBackward"
    }
}
/// Returns a copy of `input` in which the slices along `dim` selected by `index`
/// have been overwritten with the slices of `source`.
///
/// Slice `i` of `source` along `dim` replaces slice `index[i]` of the result. A
/// 0-d `source` (with exactly one index) is written to every element of the
/// selected slice. For a 0-d `input`, `dim` must be 0 or -1, `source` must be
/// 0-d or 1-d with at most one element, and every index must be 0.
///
/// An autograd node is attached when either operand requires a gradient and
/// gradient recording is enabled.
///
/// # Errors
///
/// * `InvalidArgument` if `dim` is out of range.
/// * `ShapeMismatch` if `source` and `index` do not fit `input` (see
///   `strict_index_add_copy_validate` for inputs with at least one axis).
/// * `IndexOutOfBounds` if an index value is outside the addressed axis.
pub fn index_copy<T: Float>(
    input: &Tensor<T>,
    dim: i64,
    index: &IntTensor<i64>,
    source: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let input_shape = input.shape();

    if input_shape.is_empty() {
        if dim != 0 && dim != -1 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "index_copy: dim {dim} out of range for 0-d input (valid: -1, 0)"
                ),
            });
        }
        let source_shape = source.shape();
        let source_is_0d_compatible = match source_shape.len() {
            0 => true,
            1 => source_shape[0] <= 1,
            _ => false,
        };
        if !source_is_0d_compatible {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "index_copy: When source and destination are not scalars, \
                     their dimensionality must match. Source dimensionality \
                     ({}), destination dimensionality (0)",
                    source_shape.len()
                ),
            });
        }
        let n_indices = index.numel();
        if source_shape.is_empty() && n_indices > 1 {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "index_copy: When source is scalar, index should have one element \
                     (got {n_indices})"
                ),
            });
        }
        let mut result_val = input.data_vec()?[0];
        let src_data = source.data_vec()?;
        let mut saved_index: Vec<usize> = Vec::new();
        for v in index.data()? {
            let i_raw = v.to_i64();
            if i_raw != 0 {
                return Err(FerrotorchError::IndexOutOfBounds {
                    // The magnitude is reported for negative indices.
                    index: i_raw.unsigned_abs() as usize,
                    axis: 0,
                    size: 1,
                });
            }
            // An empty source yields zero.
            result_val = match src_data.first() {
                Some(&first) => first,
                None => <T as num_traits::Zero>::zero(),
            };
            saved_index.push(0);
        }
        let storage = TensorStorage::cpu(vec![result_val]);
        let record = (input.requires_grad() || source.requires_grad()) && is_grad_enabled();
        if !record {
            return Tensor::from_storage(storage, vec![], false);
        }
        let grad_fn = Arc::new(IndexCopyBackward {
            input: input.clone(),
            source: source.clone(),
            dim: 0,
            index: saved_index,
        });
        return Tensor::from_operation(storage, vec![], grad_fn);
    }

    let (dim_usize, idx_usize) =
        strict_index_add_copy_validate("index_copy", input, dim, index, source, true)?;
    let in_dim_size = input_shape[dim_usize];
    let outer: usize = input_shape[..dim_usize].iter().product();
    let inner: usize = input_shape[dim_usize + 1..].iter().product();
    let mut out = input.data_vec()?;
    let src_data = source.data_vec()?;
    let source_shape = source.shape();

    if let Some(&src_dim_size) = source_shape.get(dim_usize) {
        for o in 0..outer {
            for (i, &dst_i) in idx_usize.iter().enumerate() {
                let dst_base = o * in_dim_size * inner + dst_i * inner;
                let src_base = o * src_dim_size * inner + i * inner;
                out[dst_base..dst_base + inner]
                    .copy_from_slice(&src_data[src_base..src_base + inner]);
            }
        }
    } else if source_shape.is_empty() {
        // Scalar source: fill the whole selected slice with the one value.
        let scalar = match src_data.first() {
            Some(&first) => first,
            None => <T as num_traits::Zero>::zero(),
        };
        let dst_i = idx_usize[0];
        for o in 0..outer {
            let dst_base = o * in_dim_size * inner + dst_i * inner;
            out[dst_base..dst_base + inner].fill(scalar);
        }
    } else {
        // Mirrors the out-of-range panic of indexing a source without axis `dim`.
        let _ = source_shape[dim_usize];
    }

    let output_shape = input_shape.to_vec();
    let record = (input.requires_grad() || source.requires_grad()) && is_grad_enabled();
    if !record {
        return Tensor::from_storage(TensorStorage::cpu(out), output_shape, false);
    }
    let grad_fn = Arc::new(IndexCopyBackward {
        input: input.clone(),
        source: source.clone(),
        dim: dim_usize,
        index: idx_usize,
    });
    Tensor::from_operation(TensorStorage::cpu(out), output_shape, grad_fn)
}
/// Autograd node recorded by `masked_scatter`.
#[derive(Debug)]
pub struct MaskedScatterBackward<T: Float> {
/// Input operand as used by the forward pass, i.e. already expanded to the
/// common broadcast shape.
pub input: Tensor<T>,
/// Tensor whose leading elements were written to the masked positions.
pub source: Tensor<T>,
/// Mask as used by the forward pass, i.e. already broadcast to the common
/// shape.
pub mask: BoolTensor,
}
impl<T: Float> GradFn<T> for MaskedScatterBackward<T> {
    /// Computes the gradients for `[input, source]`.
    ///
    /// `input` receives the upstream gradient with every masked position zeroed.
    /// `source` receives, in order, the upstream gradient values found at the
    /// masked positions; source elements beyond the number of masked positions
    /// get zero. A CUDA mask is moved to the CPU first, and both gradients are
    /// built as CPU tensors.
    ///
    /// Returns `[None, None]` while gradient recording is disabled.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None, None]);
        }
        let host_mask = if self.mask.is_cuda() {
            self.mask.to(Device::Cpu)?
        } else {
            self.mask.clone()
        };
        let mask_bits = host_mask.data()?;
        let upstream = grad_output.data_vec()?;
        let zero = <T as num_traits::Zero>::zero();

        let mut grad_input = None;
        if self.input.requires_grad() {
            let mut kept = upstream.clone();
            for (pos, &selected) in mask_bits.iter().enumerate() {
                if !selected {
                    continue;
                }
                kept[pos] = zero;
            }
            grad_input = Some(Tensor::from_storage(
                TensorStorage::cpu(kept),
                self.input.shape().to_vec(),
                false,
            )?);
        }

        let mut grad_source = None;
        if self.source.requires_grad() {
            let capacity = self.source.numel();
            let mut gathered = vec![zero; capacity];
            let mut filled = 0usize;
            for (pos, &selected) in mask_bits.iter().enumerate() {
                if !selected || filled >= capacity {
                    continue;
                }
                gathered[filled] = upstream[pos];
                filled += 1;
            }
            grad_source = Some(Tensor::from_storage(
                TensorStorage::cpu(gathered),
                self.source.shape().to_vec(),
                false,
            )?);
        }
        Ok(vec![grad_input, grad_source])
    }

    /// The forward operands, in the same order as the gradients from `backward`.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let operands = [&self.input, &self.source];
        operands.to_vec()
    }

    /// Name of this autograd node.
    fn name(&self) -> &'static str {
        "MaskedScatterBackward"
    }
}
/// Returns a copy of `input` (broadcast against `mask`) in which each position
/// where the mask is true has been replaced by the next unused element of
/// `source`, taken in order.
///
/// When `input`, `mask` and `source` are all on the same CUDA device and `T` is
/// `f32` or `f64`, the work is handed to the GPU backend's
/// `masked_scatter_forward`. Every other combination runs on host data.
///
/// An autograd node is attached when the (expanded) input or `source` requires
/// a gradient and gradient recording is enabled.
///
/// # Errors
///
/// * `DeviceUnavailable` if the GPU path is selected but no backend is
///   registered.
/// * `ShapeMismatch` if `source` has fewer elements than the mask has true
///   positions (host path); errors from broadcasting are propagated.
pub fn masked_scatter<T: Float>(
    input: &Tensor<T>,
    mask: &BoolTensor,
    source: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let common = if input.shape() != mask.shape() {
        crate::shape::broadcast_shapes(input.shape(), mask.shape())?
    } else {
        input.shape().to_vec()
    };
    let input_b = if input.shape() != common.as_slice() {
        crate::grad_fns::shape::expand(input, &common)?
    } else {
        input.clone()
    };
    let mask_b = if mask.shape() != common.as_slice() {
        broadcast_bool_tensor(mask, &common)?
    } else {
        mask.clone()
    };

    let all_on_gpu = input_b.is_cuda() && mask_b.is_cuda() && source.is_cuda();
    if all_on_gpu {
        use std::any::TypeId;
        let elem = TypeId::of::<T>();
        let gpu_dtype = elem == TypeId::of::<f32>() || elem == TypeId::of::<f64>();
        if gpu_dtype
            && input_b.device() == mask_b.device()
            && input_b.device() == source.device()
        {
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let input_c = input_b.contiguous()?;
            let source_c = source.contiguous()?;
            let n = input_c.numel();
            let result_handle = backend.masked_scatter_forward(
                input_c.gpu_handle()?,
                source_c.gpu_handle()?,
                mask_b.gpu_handle()?,
                n,
            )?;
            let storage = TensorStorage::gpu(result_handle);
            let output_shape = common.clone();
            let record = (input_c.requires_grad() || source.requires_grad()) && is_grad_enabled();
            if !record {
                return Tensor::from_storage(storage, output_shape, false);
            }
            let grad_fn = Arc::new(MaskedScatterBackward {
                input: input_c.clone(),
                source: source.clone(),
                mask: mask_b.clone(),
            });
            return Tensor::from_operation(storage, output_shape, grad_fn);
        }
    }

    // Host path.
    let mask_h = mask_b.data()?;
    let true_count = mask_h.iter().filter(|&&b| b).count();
    if source.numel() < true_count {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "masked_scatter: source has {} elements, but mask has {} true positions",
                source.numel(),
                true_count
            ),
        });
    }
    let mut out = input_b.data_vec()?;
    let src_data = source.data_vec()?;
    let mut next_src = 0usize;
    for (pos, &selected) in mask_h.iter().enumerate() {
        if !selected {
            continue;
        }
        out[pos] = src_data[next_src];
        next_src += 1;
    }

    let output_shape = common;
    let record = (input_b.requires_grad() || source.requires_grad()) && is_grad_enabled();
    if !record {
        return Tensor::from_storage(TensorStorage::cpu(out), output_shape, false);
    }
    let grad_fn = Arc::new(MaskedScatterBackward {
        input: input_b.clone(),
        source: source.clone(),
        mask: mask_b.clone(),
    });
    Tensor::from_operation(TensorStorage::cpu(out), output_shape, grad_fn)
}
/// Autograd node recorded by `take`.
///
/// `backward` scatter-adds the upstream gradient into a zero buffer shaped
/// like `input`.
#[derive(Debug)]
pub struct TakeBackward<T: Float> {
/// Tensor that was read from; its shape is the shape of the produced gradient.
pub input: Tensor<T>,
/// Flat positions that were read; `take` stores them after wrapping negative
/// indices, one entry per output element.
pub index: Vec<usize>,
}
impl<T: Float> GradFn<T> for TakeBackward<T> {
    /// Routes each upstream gradient element back to the flat input position
    /// it was read from, accumulating where a position was read several times.
    ///
    /// Returns a single-entry vector: `None` when grad mode is off or the
    /// input does not require grad, otherwise a CPU tensor shaped like the
    /// input.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() || !self.input.requires_grad() {
            return Ok(vec![None]);
        }

        let dims = self.input.shape().to_vec();
        // An empty product is 1, so a 0-d shape counts as one element.
        let total: usize = dims.iter().product();
        let upstream = grad_output.data_vec()?;

        let mut accum = vec![<T as num_traits::Zero>::zero(); total];
        // Zipping drops index entries that have no matching upstream element;
        // `get_mut` skips positions outside the input.
        for (&pos, &g) in self.index.iter().zip(upstream.iter()) {
            if let Some(slot) = accum.get_mut(pos) {
                *slot += g;
            }
        }

        let grad = Tensor::from_storage(TensorStorage::cpu(accum), dims, false)?;
        Ok(vec![Some(grad)])
    }

    /// The single differentiable operand.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Node name reported by the autograd graph.
    fn name(&self) -> &'static str {
        "TakeBackward"
    }
}
/// Reads elements of `input` at the given flat positions.
///
/// `input` is treated as a flat buffer; each entry of `index` selects one
/// element, and negative entries count from the end. The result has the shape
/// of `index` and is stored on the CPU. A `TakeBackward` node is attached
/// when `input` requires grad and grad mode is enabled.
///
/// # Errors
/// Returns `FerrotorchError::IndexOutOfBounds` (with `axis: 0`) when an index
/// lies outside `[-numel, numel)`. For a negative index the reported `index`
/// field is its absolute value.
pub fn take<T: Float>(input: &Tensor<T>, index: &IntTensor<i64>) -> FerrotorchResult<Tensor<T>> {
    let flat_input = input.data_vec()?;
    // An empty product is 1, so a 0-d shape counts as one element.
    let total: usize = input.shape().iter().product();
    let limit = total as i64;

    // Validate every index and wrap negatives into `[0, total)`.
    let mut positions: Vec<usize> = Vec::with_capacity(index.numel());
    for v in index.data()? {
        let raw = v.to_i64();
        let in_range = raw >= -limit && raw < limit;
        if !in_range {
            let reported = if raw < 0 {
                raw.unsigned_abs() as usize
            } else {
                raw as usize
            };
            return Err(FerrotorchError::IndexOutOfBounds {
                index: reported,
                axis: 0,
                size: total,
            });
        }
        let wrapped = if raw < 0 { raw + limit } else { raw };
        positions.push(wrapped as usize);
    }

    let output_shape = index.shape().to_vec();
    let output_numel: usize = output_shape.iter().product();

    let mut out = Vec::with_capacity(output_numel);
    out.extend(positions.iter().map(|&p| flat_input[p]));
    // A one-element output shape with no gathered data is padded with a zero.
    if out.is_empty() && output_numel == 1 {
        out.push(<T as num_traits::Zero>::zero());
    }

    if !(input.requires_grad() && is_grad_enabled()) {
        return Tensor::from_storage(TensorStorage::cpu(out), output_shape, false);
    }
    let node = Arc::new(TakeBackward {
        input: input.clone(),
        index: positions,
    });
    Tensor::from_operation(TensorStorage::cpu(out), output_shape, node)
}
/// Autograd node recorded by `put`.
///
/// `backward` returns gradients for `input` and `source`, in that order.
#[derive(Debug)]
pub struct PutBackward<T: Float> {
/// Tensor that was written into; its shape is the shape of the input gradient.
pub input: Tensor<T>,
/// Tensor that supplied the written values.
pub source: Tensor<T>,
/// Flat destination positions; `put` stores them after wrapping negative
/// indices. Entry `i` is where `source` element `i` was written.
pub index: Vec<usize>,
/// Whether the forward added to the destination (`true`) or overwrote it
/// (`false`). When `false`, `backward` zeroes the input gradient at the
/// written positions.
pub accumulate: bool,
}
impl<T: Float> GradFn<T> for PutBackward<T> {
    /// Splits `grad_output` between the two operands of `put`.
    ///
    /// * `input` receives the upstream gradient unchanged when the forward
    ///   accumulated; otherwise the written positions are zeroed, because the
    ///   original values there were replaced.
    /// * `source` element `i` receives the upstream gradient found at flat
    ///   position `index[i]`.
    ///
    /// Returns `[grad_input, grad_source]`; an entry is `None` when the
    /// corresponding operand does not require grad, and both are `None` when
    /// grad mode is disabled. Both gradients are built as CPU storage.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None, None]);
        }

        let dims = self.input.shape().to_vec();
        // An empty product is 1, so a 0-d shape counts as one element.
        let total: usize = dims.iter().product();
        let upstream = grad_output.data_vec()?;
        let zero = <T as num_traits::Zero>::zero();

        let grad_input = self
            .input
            .requires_grad()
            .then(|| {
                let mut passthrough = upstream.clone();
                if !self.accumulate {
                    for &pos in self.index.iter().filter(|&&p| p < total) {
                        passthrough[pos] = zero;
                    }
                }
                Tensor::from_storage(TensorStorage::cpu(passthrough), dims, false)
            })
            .transpose()?;

        let grad_source = self
            .source
            .requires_grad()
            .then(|| {
                let mut gathered = vec![zero; self.source.numel()];
                // Zipping bounds the walk by the source length; `get` skips
                // positions that fall outside the upstream gradient.
                for (slot, &pos) in gathered.iter_mut().zip(self.index.iter()) {
                    if let Some(&g) = upstream.get(pos) {
                        *slot = g;
                    }
                }
                Tensor::from_storage(
                    TensorStorage::cpu(gathered),
                    self.source.shape().to_vec(),
                    false,
                )
            })
            .transpose()?;

        Ok(vec![grad_input, grad_source])
    }

    /// The differentiable operands, in the same order as the gradients.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input, &self.source]
    }

    /// Node name reported by the autograd graph.
    fn name(&self) -> &'static str {
        "PutBackward"
    }
}
/// Out-of-place flat-index write.
///
/// Copies `input`, then for each `i` writes `source[i]` (flat order) to the
/// flat position `index[i]`, adding to the existing value when `accumulate`
/// is `true` and replacing it otherwise. Negative indices count from the end.
/// The result has the shape of `input` and is stored on the CPU. A
/// `PutBackward` node is attached when grad mode is enabled and either
/// operand requires grad.
///
/// # Errors
/// * `FerrotorchError::IndexOutOfBounds` (with `axis: 0`) when an index lies
///   outside `[-numel, numel)`. For a negative index the reported `index`
///   field is its absolute value.
/// * `FerrotorchError::ShapeMismatch` when `source` has fewer elements than
///   `index`.
pub fn put<T: Float>(
    input: &Tensor<T>,
    index: &IntTensor<i64>,
    source: &Tensor<T>,
    accumulate: bool,
) -> FerrotorchResult<Tensor<T>> {
    let dims = input.shape().to_vec();
    // An empty product is 1, so a 0-d shape counts as one element.
    let total: usize = dims.iter().product();
    let limit = total as i64;

    // Validate every index and wrap negatives into `[0, total)`.
    let mut positions: Vec<usize> = Vec::with_capacity(index.numel());
    for v in index.data()? {
        let raw = v.to_i64();
        let in_range = raw >= -limit && raw < limit;
        if !in_range {
            let reported = if raw < 0 {
                raw.unsigned_abs() as usize
            } else {
                raw as usize
            };
            return Err(FerrotorchError::IndexOutOfBounds {
                index: reported,
                axis: 0,
                size: total,
            });
        }
        let wrapped = if raw < 0 { raw + limit } else { raw };
        positions.push(wrapped as usize);
    }

    let needed = positions.len();
    if source.numel() < needed {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "put: source numel {} < index numel {}",
                source.numel(),
                needed
            ),
        });
    }

    let mut out = input.data_vec()?;
    // A one-element shape with an empty buffer is padded with a zero.
    if out.is_empty() && total == 1 {
        out.push(<T as num_traits::Zero>::zero());
    }
    let values = source.data_vec()?;

    // Writes happen in index order, so a later duplicate wins when replacing.
    if accumulate {
        for (i, &pos) in positions.iter().enumerate() {
            let s = values[i];
            out[pos] += s;
        }
    } else {
        for (i, &pos) in positions.iter().enumerate() {
            let s = values[i];
            out[pos] = s;
        }
    }

    if !((input.requires_grad() || source.requires_grad()) && is_grad_enabled()) {
        return Tensor::from_storage(TensorStorage::cpu(out), dims, false);
    }
    let node = Arc::new(PutBackward {
        input: input.clone(),
        source: source.clone(),
        index: positions,
        accumulate,
    });
    Tensor::from_operation(TensorStorage::cpu(out), dims, node)
}
#[cfg(test)]
mod first_class_wrappers_tests {
    //! Tests for the `BoolTensor` / `IntTensor` typed wrappers and their
    //! broadcasting (`*_bcast`) counterparts.
    use super::*;

    /// Builds a CPU `f32` tensor that does not track gradients.
    fn bcast_cpu_f32(data: Vec<f32>, shape: Vec<usize>) -> FerrotorchResult<Tensor<f32>> {
        Tensor::from_storage(TensorStorage::cpu(data), shape, false)
    }

    /// Builds a CPU `f32` leaf tensor that tracks gradients.
    fn bcast_cpu_f32_grad(data: Vec<f32>, shape: Vec<usize>) -> FerrotorchResult<Tensor<f32>> {
        Tensor::from_storage(TensorStorage::cpu(data), shape, true)
    }

    #[test]
    fn masked_fill_bt_replaces_true_positions() {
        let tensor = bcast_cpu_f32(vec![1.0_f32, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let selector = BoolTensor::from_vec(vec![true, false, true, false], vec![4]).unwrap();
        let filled = masked_fill_bt(&tensor, &selector, -1.0).unwrap();
        assert_eq!(filled.data().unwrap(), &[-1.0, 2.0, -1.0, 4.0]);
    }

    #[test]
    fn masked_fill_bt_rejects_shape_mismatch() {
        let tensor = bcast_cpu_f32(vec![1.0_f32, 2.0], vec![2]).unwrap();
        let selector = BoolTensor::from_vec(vec![true, false, true], vec![3]).unwrap();
        let failure = masked_fill_bt(&tensor, &selector, 0.0).unwrap_err();
        assert!(matches!(failure, FerrotorchError::ShapeMismatch { .. }));
    }

    #[test]
    fn index_select_1d_it_picks_at_indices() {
        let tensor = bcast_cpu_f32(vec![10.0_f32, 20.0, 30.0, 40.0], vec![4]).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![3, 0, 2], vec![3]).unwrap();
        let gathered = index_select_1d_it(&tensor, &picks).unwrap();
        assert_eq!(gathered.data().unwrap(), &[40.0, 10.0, 30.0]);
    }

    #[test]
    fn index_select_1d_it_rejects_2d_indices() {
        let tensor = bcast_cpu_f32(vec![1.0_f32; 4], vec![4]).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![0, 1, 2, 3], vec![2, 2]).unwrap();
        let failure = index_select_1d_it(&tensor, &picks).unwrap_err();
        assert!(matches!(failure, FerrotorchError::ShapeMismatch { .. }));
    }

    #[test]
    fn index_select_1d_it_rejects_negative() {
        let tensor = bcast_cpu_f32(vec![1.0_f32; 4], vec![4]).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![0, -1, 2], vec![3]).unwrap();
        let failure = index_select_1d_it(&tensor, &picks).unwrap_err();
        assert!(matches!(failure, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn masked_fill_bcast_passthrough_same_shape() -> FerrotorchResult<()> {
        let tensor = bcast_cpu_f32(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])?;
        let selector = BoolTensor::from_vec(vec![true, false, false, true], vec![2, 2])?;
        let filled = masked_fill_bcast(&tensor, &selector, -1.0)?;
        assert_eq!(filled.shape(), &[2, 2]);
        assert_eq!(filled.data()?, &[-1.0, 2.0, 3.0, -1.0]);
        Ok(())
    }

    #[test]
    fn masked_fill_bcast_broadcasts_row_mask_to_matrix() -> FerrotorchResult<()> {
        let tensor = bcast_cpu_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;
        let row_selector = BoolTensor::from_vec(vec![true, false, true], vec![3])?;
        let filled = masked_fill_bcast(&tensor, &row_selector, 0.0)?;
        assert_eq!(filled.shape(), &[2, 3]);
        assert_eq!(filled.data()?, &[0.0, 2.0, 0.0, 0.0, 5.0, 0.0]);
        Ok(())
    }

    #[test]
    fn masked_fill_bcast_broadcasts_scalar_input_against_2d_mask() -> FerrotorchResult<()> {
        let scalar = bcast_cpu_f32(vec![7.0], vec![])?;
        let selector = BoolTensor::from_vec(vec![true, false, true, true], vec![2, 2])?;
        let filled = masked_fill_bcast(&scalar, &selector, -1.0)?;
        assert_eq!(filled.shape(), &[2, 2]);
        assert_eq!(filled.data()?, &[-1.0, 7.0, -1.0, -1.0]);
        Ok(())
    }

    #[test]
    fn masked_fill_bcast_rejects_incompatible_shapes() -> FerrotorchResult<()> {
        let tensor = bcast_cpu_f32(vec![1.0_f32; 6], vec![2, 3])?;
        let selector = BoolTensor::from_vec(vec![true; 4], vec![2, 2])?;
        let failure = masked_fill_bcast(&tensor, &selector, 0.0).err();
        assert!(matches!(
            failure,
            Some(FerrotorchError::ShapeMismatch { .. })
        ));
        Ok(())
    }

    #[test]
    fn masked_select_bcast_passthrough_same_shape() -> FerrotorchResult<()> {
        let tensor = bcast_cpu_f32(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])?;
        let selector = BoolTensor::from_vec(vec![true, false, false, true], vec![2, 2])?;
        let chosen = masked_select_bcast(&tensor, &selector)?;
        assert_eq!(chosen.shape(), &[2]);
        assert_eq!(chosen.data()?, &[1.0, 4.0]);
        Ok(())
    }

    #[test]
    fn masked_select_bcast_broadcasts_1d_mask_against_2d_input() -> FerrotorchResult<()> {
        let tensor = bcast_cpu_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;
        let row_selector = BoolTensor::from_vec(vec![true, false, true], vec![3])?;
        let chosen = masked_select_bcast(&tensor, &row_selector)?;
        assert_eq!(chosen.shape(), &[4]);
        assert_eq!(chosen.data()?, &[1.0, 3.0, 4.0, 6.0]);
        Ok(())
    }

    #[test]
    fn masked_select_bcast_broadcasts_1d_input_against_2d_mask() -> FerrotorchResult<()> {
        let row = bcast_cpu_f32(vec![10.0, 20.0, 30.0], vec![3])?;
        let selector =
            BoolTensor::from_vec(vec![true, true, false, false, true, true], vec![2, 3])?;
        let chosen = masked_select_bcast(&row, &selector)?;
        assert_eq!(chosen.shape(), &[4]);
        assert_eq!(chosen.data()?, &[10.0, 20.0, 20.0, 30.0]);
        Ok(())
    }

    #[test]
    fn where_cond_bcast_passthrough_same_shape() -> FerrotorchResult<()> {
        let condition = BoolTensor::from_vec(vec![true, false, true, false], vec![2, 2])?;
        let when_true = bcast_cpu_f32(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])?;
        let when_false = bcast_cpu_f32(vec![10.0, 20.0, 30.0, 40.0], vec![2, 2])?;
        let merged = where_cond_bcast(&condition, &when_true, &when_false)?;
        assert_eq!(merged.shape(), &[2, 2]);
        assert_eq!(merged.data()?, &[1.0, 20.0, 3.0, 40.0]);
        Ok(())
    }

    #[test]
    fn where_cond_bcast_three_way_broadcast_with_scalars() -> FerrotorchResult<()> {
        let condition = BoolTensor::from_vec(vec![true, false, false, true], vec![2, 2])?;
        let when_true = bcast_cpu_f32(vec![7.0], vec![])?;
        let when_false = bcast_cpu_f32(vec![100.0, 200.0], vec![1, 2])?;
        let merged = where_cond_bcast(&condition, &when_true, &when_false)?;
        assert_eq!(merged.shape(), &[2, 2]);
        assert_eq!(merged.data()?, &[7.0, 200.0, 100.0, 7.0]);
        Ok(())
    }

    #[test]
    fn where_cond_bcast_rejects_incompatible_shapes() -> FerrotorchResult<()> {
        let condition = BoolTensor::from_vec(vec![true; 6], vec![2, 3])?;
        let when_true = bcast_cpu_f32(vec![1.0_f32; 6], vec![2, 3])?;
        let when_false = bcast_cpu_f32(vec![0.0_f32; 8], vec![2, 4])?;
        let failure = where_cond_bcast(&condition, &when_true, &when_false).err();
        assert!(matches!(
            failure,
            Some(FerrotorchError::ShapeMismatch { .. })
        ));
        Ok(())
    }

    #[test]
    fn masked_select_bcast_backward_reduces_to_input_shape() -> FerrotorchResult<()> {
        use crate::autograd::graph::backward;

        // Stand-in for a sum reduction: ignores the upstream gradient and
        // hands back a gradient of ones shaped like its input.
        #[derive(Debug)]
        struct BcastSumBackward<T: Float> {
            input: Tensor<T>,
            numel: usize,
        }
        impl<T: Float> GradFn<T> for BcastSumBackward<T> {
            fn backward(
                &self,
                _grad_output: &Tensor<T>,
            ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
                let unit = vec![<T as num_traits::One>::one(); self.numel];
                let grad = Tensor::from_storage(
                    TensorStorage::cpu(unit),
                    self.input.shape().to_vec(),
                    false,
                )?;
                Ok(vec![Some(grad)])
            }
            fn inputs(&self) -> Vec<&Tensor<T>> {
                vec![&self.input]
            }
            fn name(&self) -> &'static str {
                "BcastTestSumBackward"
            }
        }

        let leaf = bcast_cpu_f32_grad(vec![10.0, 20.0, 30.0], vec![3])?;
        let selector =
            BoolTensor::from_vec(vec![true, false, true, false, true, true], vec![2, 3])?;
        let chosen = masked_select_bcast(&leaf, &selector)?;

        let chosen_len = chosen.numel();
        let total: f32 = chosen.data()?.iter().sum();
        let reduction = Arc::new(BcastSumBackward {
            input: chosen.clone(),
            numel: chosen_len,
        });
        let scalar = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], reduction)?;
        backward(&scalar)?;

        let leaf_grad = leaf.grad()?.ok_or_else(|| FerrotorchError::Internal {
            message: "no grad on leaf".into(),
        })?;
        // Columns 0 and 1 are selected once each, column 2 twice.
        assert_eq!(leaf_grad.shape(), &[3]);
        assert_eq!(leaf_grad.data()?, &[1.0, 1.0, 2.0]);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    //! Forward and backward tests for `index_select_1d`, `masked_fill`,
    //! `index_select_dim`, and smoke tests for the gather / scatter-add
    //! backward nodes.
    use super::*;
    use crate::autograd::graph::backward;
    use crate::autograd::no_grad;
    use crate::storage::TensorStorage;

    /// Builds a 1-D CPU `f32` tensor from a slice.
    fn leaf_1d(data: &[f32], requires_grad: bool) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data.to_vec());
        Tensor::from_storage(storage, vec![data.len()], requires_grad).unwrap()
    }

    #[test]
    fn test_index_select_1d_forward() {
        let source = leaf_1d(&[10.0, 20.0, 30.0, 40.0, 50.0], false);
        let picked = index_select_1d(&source, &[0, 2, 4]).unwrap();
        assert_eq!(picked.shape(), &[3]);
        assert_eq!(picked.data().unwrap(), &[10.0, 30.0, 50.0]);
    }

    #[test]
    fn test_index_select_1d_duplicate_indices() {
        let source = leaf_1d(&[10.0, 20.0, 30.0], false);
        let picked = index_select_1d(&source, &[1, 1, 2, 0, 1]).unwrap();
        assert_eq!(picked.shape(), &[5]);
        assert_eq!(picked.data().unwrap(), &[20.0, 20.0, 30.0, 10.0, 20.0]);
    }

    #[test]
    fn test_index_select_1d_out_of_bounds() {
        let source = leaf_1d(&[10.0, 20.0, 30.0], false);
        assert!(index_select_1d(&source, &[0, 5]).is_err());
    }

    #[test]
    fn test_index_select_1d_non_1d_input() {
        let matrix = Tensor::<f32>::from_storage(
            TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0]),
            vec![2, 2],
            false,
        )
        .unwrap();
        assert!(index_select_1d(&matrix, &[0]).is_err());
    }

    #[test]
    fn test_index_select_1d_backward_simple() {
        // Stand-in for a sum reduction: spreads the scalar upstream gradient
        // over every element of its input.
        #[derive(Debug)]
        struct SumBackward<T: Float> {
            input: Tensor<T>,
        }
        impl<T: Float> GradFn<T> for SumBackward<T> {
            fn backward(
                &self,
                grad_output: &Tensor<T>,
            ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
                let seed = grad_output.data()?[0];
                let spread = vec![seed; self.input.numel()];
                let grad = Tensor::from_storage(
                    TensorStorage::cpu(spread),
                    self.input.shape().to_vec(),
                    false,
                )?;
                Ok(vec![Some(grad)])
            }
            fn inputs(&self) -> Vec<&Tensor<T>> {
                vec![&self.input]
            }
            fn name(&self) -> &'static str {
                "SumBackward"
            }
        }

        let source = leaf_1d(&[10.0, 20.0, 30.0, 40.0], true);
        let picked = index_select_1d(&source, &[1, 3]).unwrap();
        assert!(picked.requires_grad());
        assert!(!picked.is_leaf());
        assert_eq!(picked.grad_fn().unwrap().name(), "IndexSelectBackward");

        let total: f32 = picked.data().unwrap().iter().sum();
        let loss = Tensor::from_operation(
            TensorStorage::cpu(vec![total]),
            vec![],
            Arc::new(SumBackward {
                input: picked.clone(),
            }),
        )
        .unwrap();
        backward(&loss).unwrap();

        let grad = source.grad().unwrap().unwrap();
        let grad_data = grad.data().unwrap();
        assert_eq!(grad_data.len(), 4);
        assert!((grad_data[0] - 0.0).abs() < 1e-6, "grad[0] should be 0");
        assert!((grad_data[1] - 1.0).abs() < 1e-6, "grad[1] should be 1");
        assert!((grad_data[2] - 0.0).abs() < 1e-6, "grad[2] should be 0");
        assert!((grad_data[3] - 1.0).abs() < 1e-6, "grad[3] should be 1");
    }

    #[test]
    fn test_index_select_1d_backward_duplicate_indices() {
        let source = leaf_1d(&[10.0, 20.0, 30.0], true);
        let picked = index_select_1d(&source, &[0, 1, 1, 2, 1]).unwrap();
        let upstream =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0; 5]), vec![5], false).unwrap();

        let grads = picked.grad_fn().unwrap().backward(&upstream).unwrap();
        let gd = grads[0].as_ref().unwrap().data().unwrap();
        assert_eq!(gd.len(), 3);
        assert!(
            (gd[0] - 1.0).abs() < 1e-6,
            "grad[0] = {}, expected 1",
            gd[0]
        );
        assert!(
            (gd[1] - 3.0).abs() < 1e-6,
            "grad[1] = {}, expected 3",
            gd[1]
        );
        assert!(
            (gd[2] - 1.0).abs() < 1e-6,
            "grad[2] = {}, expected 1",
            gd[2]
        );
    }

    #[test]
    fn test_index_select_1d_backward_weighted_grad() {
        let source = leaf_1d(&[100.0, 200.0, 300.0], true);
        let picked = index_select_1d(&source, &[2, 0]).unwrap();
        let upstream =
            Tensor::from_storage(TensorStorage::cpu(vec![0.5, 2.0]), vec![2], false).unwrap();

        let grads = picked.grad_fn().unwrap().backward(&upstream).unwrap();
        let gd = grads[0].as_ref().unwrap().data().unwrap();
        assert!(
            (gd[0] - 2.0).abs() < 1e-6,
            "grad[0] = {}, expected 2.0",
            gd[0]
        );
        assert!(
            (gd[1] - 0.0).abs() < 1e-6,
            "grad[1] = {}, expected 0.0",
            gd[1]
        );
        assert!(
            (gd[2] - 0.5).abs() < 1e-6,
            "grad[2] = {}, expected 0.5",
            gd[2]
        );
    }

    #[test]
    fn test_index_select_1d_no_grad_context() {
        let source = leaf_1d(&[10.0, 20.0, 30.0], true);
        let picked = no_grad(|| index_select_1d(&source, &[0, 2])).unwrap();
        assert!(!picked.requires_grad());
        assert!(picked.grad_fn().is_none());
    }

    #[test]
    fn test_masked_fill_forward() {
        let source = leaf_1d(&[1.0, 2.0, 3.0, 4.0], false);
        let selector = [false, true, false, true];
        let filled = masked_fill(&source, &selector, -999.0).unwrap();
        assert_eq!(filled.data().unwrap(), &[1.0, -999.0, 3.0, -999.0]);
    }

    #[test]
    fn test_masked_fill_backward() {
        let source = leaf_1d(&[1.0, 2.0, 3.0, 4.0], true);
        let selector = [false, true, false, true];
        let filled = masked_fill(&source, &selector, 0.0).unwrap();
        let upstream =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0; 4]), vec![4], false).unwrap();

        let grads = filled.grad_fn().unwrap().backward(&upstream).unwrap();
        let gd = grads[0].as_ref().unwrap().data().unwrap();
        assert!((gd[0] - 1.0).abs() < 1e-6);
        assert!((gd[1] - 0.0).abs() < 1e-6);
        assert!((gd[2] - 1.0).abs() < 1e-6);
        assert!((gd[3] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_masked_fill_shape_mismatch() {
        let source = leaf_1d(&[1.0, 2.0, 3.0], false);
        let short_selector = [true, false];
        assert!(masked_fill(&source, &short_selector, 0.0).is_err());
    }

    #[test]
    fn test_gather_backward_stub() {
        let node = GatherBackward {
            input: leaf_1d(&[1.0, 2.0], true),
            dim: 0,
            index: vec![0, 1],
            index_shape: vec![2],
        };
        let upstream =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0, 1.0]), vec![2], false).unwrap();
        assert!(node.backward(&upstream).is_ok());
    }

    #[test]
    fn test_scatter_add_backward_stub() {
        let node = ScatterAddBackward {
            input: leaf_1d(&[1.0, 2.0], true),
            src: leaf_1d(&[3.0], false),
            dim: 0,
            index: vec![0],
            index_shape: vec![1],
        };
        let upstream =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0, 1.0]), vec![2], false).unwrap();
        assert!(node.backward(&upstream).is_ok());
    }

    #[test]
    fn test_index_select_dim_2d_dim0_forward() {
        let rows = vec![
            10.0_f32, 11.0, 12.0, 20.0, 21.0, 22.0, 30.0, 31.0, 32.0, 40.0, 41.0, 42.0,
        ];
        let matrix = Tensor::from_storage(TensorStorage::cpu(rows), vec![4, 3], false).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![3, 0, 2], vec![3]).unwrap();
        let gathered = index_select_dim(&matrix, 0, &picks).unwrap();
        assert_eq!(gathered.shape(), &[3, 3]);
        assert_eq!(
            gathered.data().unwrap(),
            &[40.0, 41.0, 42.0, 10.0, 11.0, 12.0, 30.0, 31.0, 32.0]
        );
    }

    #[test]
    fn test_index_select_dim_2d_dim1_forward() {
        let cells = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let matrix = Tensor::from_storage(TensorStorage::cpu(cells), vec![2, 4], false).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![1, 3, 0], vec![3]).unwrap();
        let gathered = index_select_dim(&matrix, 1, &picks).unwrap();
        assert_eq!(gathered.shape(), &[2, 3]);
        assert_eq!(gathered.data().unwrap(), &[2.0, 4.0, 1.0, 6.0, 8.0, 5.0]);
    }

    #[test]
    fn test_index_select_dim_registers_grad_fn() {
        let cells = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let matrix = Tensor::from_storage(TensorStorage::cpu(cells), vec![3, 2], true).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![0, 2], vec![2]).unwrap();
        let gathered = index_select_dim(&matrix, 0, &picks).unwrap();
        assert!(gathered.requires_grad());
        assert!(!gathered.is_leaf());
        assert_eq!(
            gathered.grad_fn().unwrap().name(),
            "IndexSelectDimBackward"
        );
    }

    #[test]
    fn test_index_select_dim_backward_simple_2d() {
        let matrix =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0_f32; 8]), vec![4, 2], true).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![2, 0, 2], vec![3]).unwrap();
        let gathered = index_select_dim(&matrix, 0, &picks).unwrap();
        let upstream = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0_f32, 10.0, 100.0, 1000.0, 10000.0, 100000.0]),
            vec![3, 2],
            false,
        )
        .unwrap();

        let grads = gathered.grad_fn().unwrap().backward(&upstream).unwrap();
        let grad = grads[0].as_ref().unwrap();
        assert_eq!(grad.shape(), &[4, 2]);
        // Row 2 was selected twice, so its gradient is the sum of two rows.
        let expected = [100.0_f32, 1000.0, 0.0, 0.0, 10001.0, 100010.0, 0.0, 0.0];
        for (i, (&got, &exp)) in grad.data().unwrap().iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-3,
                "grad[{i}] = {got}, expected {exp}"
            );
        }
    }

    #[test]
    fn test_index_select_dim_backward_dim1() {
        let matrix =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0_f32; 8]), vec![2, 4], true).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![3, 1], vec![2]).unwrap();
        let gathered = index_select_dim(&matrix, 1, &picks).unwrap();
        let upstream = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0_f32, 10.0, 100.0, 1000.0]),
            vec![2, 2],
            false,
        )
        .unwrap();

        let grads = gathered.grad_fn().unwrap().backward(&upstream).unwrap();
        let grad = grads[0].as_ref().unwrap();
        assert_eq!(grad.shape(), &[2, 4]);
        let expected = [0.0_f32, 10.0, 0.0, 1.0, 0.0, 1000.0, 0.0, 100.0];
        for (i, (&got, &exp)) in grad.data().unwrap().iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "grad[{i}] = {got}, expected {exp}"
            );
        }
    }

    #[test]
    fn test_index_select_dim_e2e_via_autograd() {
        // Stand-in for a sum reduction: ignores the upstream gradient and
        // hands back a gradient of ones shaped like its input.
        #[derive(Debug)]
        struct SumBackward<T: Float> {
            input: Tensor<T>,
        }
        impl<T: Float> GradFn<T> for SumBackward<T> {
            fn backward(&self, _go: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
                let unit = vec![<T as num_traits::One>::one(); self.input.numel()];
                let grad = Tensor::from_storage(
                    TensorStorage::cpu(unit),
                    self.input.shape().to_vec(),
                    false,
                )?;
                Ok(vec![Some(grad)])
            }
            fn inputs(&self) -> Vec<&Tensor<T>> {
                vec![&self.input]
            }
            fn name(&self) -> &'static str {
                "SumBackward"
            }
        }

        let leaf = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]),
            vec![3, 2],
            true,
        )
        .unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![0, 2, 0], vec![3]).unwrap();
        let gathered = index_select_dim(&leaf, 0, &picks).unwrap();

        let total: f32 = gathered.data().unwrap().iter().sum();
        let loss = Tensor::from_operation(
            TensorStorage::cpu(vec![total]),
            vec![],
            Arc::new(SumBackward {
                input: gathered.clone(),
            }),
        )
        .unwrap();
        backward(&loss).unwrap();

        let grad = leaf.grad().unwrap().expect("x.grad() should be Some");
        assert_eq!(grad.shape(), &[3, 2]);
        // Row 0 was selected twice, row 2 once, row 1 never.
        let expected = [2.0_f32, 2.0, 0.0, 0.0, 1.0, 1.0];
        for (i, (&got, &exp)) in grad.data().unwrap().iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "grad[{i}] = {got}, expected {exp}"
            );
        }
    }

    #[test]
    fn test_index_select_dim_rejects_2d_indices() {
        let matrix =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0_f32; 6]), vec![3, 2], false).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![0, 1, 0, 1], vec![2, 2]).unwrap();
        let failure = index_select_dim(&matrix, 0, &picks).unwrap_err();
        assert!(matches!(failure, FerrotorchError::ShapeMismatch { .. }));
    }

    #[test]
    fn test_index_select_dim_rejects_oob() {
        let matrix =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0_f32; 6]), vec![3, 2], false).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![0, 7], vec![2]).unwrap();
        let failure = index_select_dim(&matrix, 0, &picks).unwrap_err();
        assert!(matches!(failure, FerrotorchError::IndexOutOfBounds { .. }));
    }

    #[test]
    fn test_index_select_dim_rejects_negative() {
        let matrix =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0_f32; 6]), vec![3, 2], false).unwrap();
        let picks = IntTensor::<i64>::from_vec(vec![0, -1], vec![2]).unwrap();
        let failure = index_select_dim(&matrix, 0, &picks).unwrap_err();
        assert!(matches!(failure, FerrotorchError::InvalidArgument { .. }));
    }
}
#[cfg(test)]
mod index_fill_tests {
    //! Forward, backward, and argument-validation tests for `index_fill`.
    use super::*;
    use crate::autograd::graph::backward;

    /// Builds a CPU `f32` tensor that does not track gradients.
    fn cpu_f32(data: Vec<f32>, shape: Vec<usize>) -> FerrotorchResult<Tensor<f32>> {
        Tensor::from_storage(TensorStorage::cpu(data), shape, false)
    }

    /// Builds a CPU `f32` leaf tensor that tracks gradients.
    fn cpu_f32_grad(data: Vec<f32>, shape: Vec<usize>) -> FerrotorchResult<Tensor<f32>> {
        Tensor::from_storage(TensorStorage::cpu(data), shape, true)
    }

    /// Builds an `i64` index tensor.
    fn idx_i64(values: Vec<i64>, shape: Vec<usize>) -> FerrotorchResult<IntTensor<i64>> {
        IntTensor::from_vec(values, shape)
    }

    #[test]
    fn index_fill_forward_2d_dim1_matches_torch_docstring() -> FerrotorchResult<()> {
        let matrix = cpu_f32(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            vec![3, 3],
        )?;
        let columns = idx_i64(vec![0, 2], vec![2])?;
        let filled = index_fill(&matrix, 1, &columns, -1.0)?;
        assert_eq!(filled.shape(), &[3, 3]);
        let expected = [-1.0_f32, 2.0, -1.0, -1.0, 5.0, -1.0, -1.0, 8.0, -1.0];
        assert_eq!(filled.data()?, &expected);
        Ok(())
    }

    #[test]
    fn index_fill_forward_2d_dim0_replaces_row() -> FerrotorchResult<()> {
        let matrix = cpu_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;
        let rows = idx_i64(vec![1], vec![1])?;
        let filled = index_fill(&matrix, 0, &rows, -9.0)?;
        assert_eq!(filled.shape(), &[2, 3]);
        let expected = [1.0_f32, 2.0, 3.0, -9.0, -9.0, -9.0];
        assert_eq!(filled.data()?, &expected);
        Ok(())
    }

    #[test]
    fn index_fill_backward_zeros_at_fill_positions() -> FerrotorchResult<()> {
        let leaf = cpu_f32_grad(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;
        let columns = idx_i64(vec![0, 2], vec![2])?;
        let filled = index_fill(&leaf, 1, &columns, -1.0)?;

        let node = filled.grad_fn().ok_or_else(|| FerrotorchError::Internal {
            message: "expected grad_fn on requires_grad output".into(),
        })?;
        assert_eq!(node.name(), "IndexFillBackward");

        let upstream = cpu_f32(vec![1.0_f32; 6], vec![2, 3])?;
        let grads = node.backward(&upstream)?;
        let grad = grads[0].as_ref().ok_or_else(|| FerrotorchError::Internal {
            message: "expected Some(grad_input)".into(),
        })?;
        assert_eq!(grad.shape(), &[2, 3]);
        let expected = [0.0_f32, 1.0, 0.0, 0.0, 1.0, 0.0];
        assert_eq!(grad.data()?, &expected);
        Ok(())
    }

    #[test]
    fn index_fill_negative_dim_wraps() -> FerrotorchResult<()> {
        let matrix = cpu_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;
        let columns = idx_i64(vec![0, 2], vec![2])?;
        let via_negative = index_fill(&matrix, -1, &columns, -7.0)?;
        let via_positive = index_fill(&matrix, 1, &columns, -7.0)?;
        assert_eq!(via_negative.data()?, via_positive.data()?);
        let expected = [-7.0_f32, 2.0, -7.0, -7.0, 5.0, -7.0];
        assert_eq!(via_negative.data()?, &expected);
        Ok(())
    }

    #[test]
    fn index_fill_rejects_out_of_bounds() -> FerrotorchResult<()> {
        let matrix = cpu_f32(vec![1.0_f32; 6], vec![2, 3])?;
        let columns = idx_i64(vec![0, 7], vec![2])?;
        let failure = index_fill(&matrix, 1, &columns, 0.0).err();
        assert!(matches!(
            failure,
            Some(FerrotorchError::IndexOutOfBounds { .. })
        ));
        Ok(())
    }

    #[test]
    fn index_fill_wraps_negative_index_per_upstream() -> FerrotorchResult<()> {
        let matrix = cpu_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;

        // `-1` addresses the last column.
        let last_column = idx_i64(vec![-1], vec![1])?;
        let filled = index_fill(&matrix, 1, &last_column, -1.0)?;
        let expected = [1.0_f32, 2.0, -1.0, 4.0, 5.0, -1.0];
        assert_eq!(filled.data()?, &expected);

        // `-4` is one step past the wrap-around range of a size-3 axis.
        let too_negative = idx_i64(vec![-4], vec![1])?;
        let failure = index_fill(&matrix, 1, &too_negative, 0.0).err();
        assert!(matches!(
            failure,
            Some(FerrotorchError::IndexOutOfBounds { .. })
        ));
        Ok(())
    }

    #[test]
    fn index_fill_rejects_out_of_range_dim() -> FerrotorchResult<()> {
        let matrix = cpu_f32(vec![1.0_f32; 6], vec![2, 3])?;
        let picks = idx_i64(vec![0], vec![1])?;
        let failure = index_fill(&matrix, 5, &picks, 0.0).err();
        assert!(matches!(
            failure,
            Some(FerrotorchError::InvalidArgument { .. })
        ));
        Ok(())
    }

    #[test]
    fn index_fill_zero_dim_input_succeeds_per_upstream() -> FerrotorchResult<()> {
        let scalar = cpu_f32(vec![1.0_f32], vec![])?;
        let picks = idx_i64(vec![0], vec![1])?;
        let filled = index_fill(&scalar, 0, &picks, 0.0)?;
        assert_eq!(
            filled.shape(),
            &[] as &[usize],
            "0-d output must remain 0-d"
        );
        assert_eq!(filled.data()?, &[0.0_f32], "filled value must be 0.0");
        Ok(())
    }

    #[test]
    fn index_fill_rejects_multi_d_index() -> FerrotorchResult<()> {
        let matrix = cpu_f32(vec![1.0_f32; 6], vec![2, 3])?;
        let picks = idx_i64(vec![0, 1, 0, 1], vec![2, 2])?;
        let failure = index_fill(&matrix, 1, &picks, 0.0).err();
        assert!(matches!(
            failure,
            Some(FerrotorchError::ShapeMismatch { .. })
        ));
        Ok(())
    }

    #[test]
    fn index_fill_e2e_via_autograd() -> FerrotorchResult<()> {
        // Stand-in for a sum reduction: ignores the upstream gradient and
        // hands back a gradient of ones shaped like its input.
        #[derive(Debug)]
        struct SumBackward<T: Float> {
            input: Tensor<T>,
        }
        impl<T: Float> GradFn<T> for SumBackward<T> {
            fn backward(&self, _go: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
                let unit = vec![<T as num_traits::One>::one(); self.input.numel()];
                let grad = Tensor::from_storage(
                    TensorStorage::cpu(unit),
                    self.input.shape().to_vec(),
                    false,
                )?;
                Ok(vec![Some(grad)])
            }
            fn inputs(&self) -> Vec<&Tensor<T>> {
                vec![&self.input]
            }
            fn name(&self) -> &'static str {
                "SumBackward"
            }
        }

        let leaf = cpu_f32_grad(vec![10.0, 20.0, 30.0, 40.0], vec![4])?;
        let picks = idx_i64(vec![1, 3], vec![2])?;
        let filled = index_fill(&leaf, 0, &picks, -1.0)?;

        let total: f32 = filled.data()?.iter().sum();
        let loss = Tensor::from_operation(
            TensorStorage::cpu(vec![total]),
            vec![],
            Arc::new(SumBackward {
                input: filled.clone(),
            }),
        )?;
        backward(&loss)?;

        let grad = leaf.grad()?.ok_or_else(|| FerrotorchError::Internal {
            message: "expected leaf grad".into(),
        })?;
        assert_eq!(grad.shape(), &[4]);
        let expected = [1.0_f32, 0.0, 1.0, 0.0];
        assert_eq!(grad.data()?, &expected);
        Ok(())
    }
}