use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
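/// Backward node for `index_select_1d`.
///
/// Scatters the incoming gradient back to the input positions named by
/// `indices`, summing contributions when an index was selected more than once.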
#[derive(Debug)]
pub struct IndexSelectBackward<T: Float> {
pub input: Tensor<T>,
pub indices: Vec<usize>,
}
impl<T: Float> GradFn<T> for IndexSelectBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if !is_grad_enabled() {
return Ok(vec![None]);
}
let input_len = self.input.numel();
let cpu_go = if grad_output.is_cuda() { grad_output.cpu()? } else { grad_output.clone() };
let go_data = cpu_go.data()?;
let mut grad_input = vec![<T as num_traits::Zero>::zero(); input_len];
for (i, &idx) in self.indices.iter().enumerate() {
grad_input[idx] += go_data[i];
}
let grad_tensor = Tensor::from_storage(
TensorStorage::cpu(grad_input),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(grad_tensor)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"IndexSelectBackward"
}
}
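/// Selects elements of a 1-D tensor at the positions given by `indices`,
/// returning a new 1-D tensor of length `indices.len()`.
///
/// Duplicate indices are allowed; during the backward pass their gradients
/// accumulate into the same input position. A CUDA input is copied to the CPU
/// first and the result is a CPU tensor. Returns an error if `input` is not
/// 1-D or if any index is out of bounds.
///
/// # Example
///
/// A minimal sketch of the intended usage, marked `ignore` because it assumes
/// `Tensor` and `TensorStorage` are reachable under their current paths:
///
/// ```ignore
/// let input = Tensor::from_storage(TensorStorage::cpu(vec![10.0f32, 20.0, 30.0]), vec![3], false)?;
/// let picked = index_select_1d(&input, &[2, 0])?;
/// assert_eq!(picked.data()?, &[30.0, 10.0]);
/// ```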
pub fn index_select_1d<T: Float>(
input: &Tensor<T>,
indices: &[usize],
) -> FerrotorchResult<Tensor<T>> {
if input.ndim() != 1 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"index_select_1d requires a 1-D input, got shape {:?}",
input.shape()
),
});
}
let cpu_input = if input.is_cuda() { input.cpu()? } else { input.clone() };
let input_data = cpu_input.data()?;
let input_len = input_data.len();
for &idx in indices {
if idx >= input_len {
return Err(FerrotorchError::IndexOutOfBounds {
index: idx,
axis: 0,
size: input_len,
});
}
}
let output_data: Vec<T> = indices.iter().map(|&idx| input_data[idx]).collect();
let output_shape = vec![indices.len()];
if input.requires_grad() && is_grad_enabled() {
let grad_fn = Arc::new(IndexSelectBackward {
input: input.clone(),
indices: indices.to_vec(),
});
Tensor::from_operation(TensorStorage::cpu(output_data), output_shape, grad_fn)
} else {
Tensor::from_storage(TensorStorage::cpu(output_data), output_shape, false)
}
}
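/// Backward node for `masked_fill`.
///
/// Passes the incoming gradient through unchanged at unmasked positions and
/// zeroes it at the positions recorded in `masked_indices`, since the output
/// there is the constant fill value and does not depend on the input.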
#[derive(Debug)]
pub struct MaskedFillBackward<T: Float> {
pub input: Tensor<T>,
pub masked_indices: Vec<usize>,
}
impl<T: Float> GradFn<T> for MaskedFillBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if !is_grad_enabled() {
return Ok(vec![None]);
}
let cpu_go = if grad_output.is_cuda() { grad_output.cpu()? } else { grad_output.clone() };
let go_data = cpu_go.data()?;
let mut grad_input: Vec<T> = go_data.to_vec();
for &idx in &self.masked_indices {
grad_input[idx] = <T as num_traits::Zero>::zero();
}
let grad_tensor = Tensor::from_storage(
TensorStorage::cpu(grad_input),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(grad_tensor)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"MaskedFillBackward"
}
}
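/// Returns a copy of `input` in which `input[i]` is replaced by `value`
/// wherever `mask[i]` is `true`.
///
/// The mask is a flat boolean slice with one entry per element of `input`;
/// a length mismatch yields a `ShapeMismatch` error. Gradients flow only
/// through the positions left untouched by the fill.
///
/// # Example
///
/// A minimal sketch of the intended usage, marked `ignore` because it assumes
/// `Tensor` and `TensorStorage` are reachable under their current paths:
///
/// ```ignore
/// let input = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]), vec![3], false)?;
/// let filled = masked_fill(&input, &[false, true, false], -1.0)?;
/// assert_eq!(filled.data()?, &[1.0, -1.0, 3.0]);
/// ```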
pub fn masked_fill<T: Float>(
input: &Tensor<T>,
mask: &[bool],
value: T,
) -> FerrotorchResult<Tensor<T>> {
let cpu_input = if input.is_cuda() { input.cpu()? } else { input.clone() };
let input_data = cpu_input.data()?;
if mask.len() != input_data.len() {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"masked_fill: mask length {} does not match input length {}",
mask.len(),
input_data.len()
),
});
}
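// Record which positions get overwritten; their gradient is zeroed in backward.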
let masked_indices: Vec<usize> = mask
.iter()
.enumerate()
.filter_map(|(i, &m)| if m { Some(i) } else { None })
.collect();
let output_data: Vec<T> = input_data
.iter()
.zip(mask.iter())
.map(|(&x, &m)| if m { value } else { x })
.collect();
let output_shape = input.shape().to_vec();
if input.requires_grad() && is_grad_enabled() {
let grad_fn = Arc::new(MaskedFillBackward {
input: input.clone(),
masked_indices,
});
Tensor::from_operation(TensorStorage::cpu(output_data), output_shape, grad_fn)
} else {
Tensor::from_storage(TensorStorage::cpu(output_data), output_shape, false)
}
}
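/// Backward node reserved for a future `gather` operation; `backward` is not
/// yet implemented and returns an error.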
#[derive(Debug)]
pub struct GatherBackward<T: Float> {
pub input: Tensor<T>,
}
impl<T: Float> GradFn<T> for GatherBackward<T> {
fn backward(&self, _grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
Err(FerrotorchError::InvalidArgument {
message: "GatherBackward is not yet implemented".into(),
})
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"GatherBackward"
}
}
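/// Backward node reserved for a future `scatter_add` operation; `backward` is
/// not yet implemented and returns an error.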
#[derive(Debug)]
pub struct ScatterAddBackward<T: Float> {
pub input: Tensor<T>,
}
impl<T: Float> GradFn<T> for ScatterAddBackward<T> {
fn backward(&self, _grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
Err(FerrotorchError::InvalidArgument {
message: "ScatterAddBackward is not yet implemented".into(),
})
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"ScatterAddBackward"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::autograd::graph::backward;
use crate::autograd::no_grad;
use crate::storage::TensorStorage;
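/// Builds a 1-D CPU leaf tensor for the tests below.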
fn leaf_1d(data: &[f32], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
vec![data.len()],
requires_grad,
)
.unwrap()
}
#[test]
fn test_index_select_1d_forward() {
let input = leaf_1d(&[10.0, 20.0, 30.0, 40.0, 50.0], false);
let result = index_select_1d(&input, &[0, 2, 4]).unwrap();
assert_eq!(result.shape(), &[3]);
assert_eq!(result.data().unwrap(), &[10.0, 30.0, 50.0]);
}
#[test]
fn test_index_select_1d_duplicate_indices() {
let input = leaf_1d(&[10.0, 20.0, 30.0], false);
let result = index_select_1d(&input, &[1, 1, 2, 0, 1]).unwrap();
assert_eq!(result.shape(), &[5]);
assert_eq!(result.data().unwrap(), &[20.0, 20.0, 30.0, 10.0, 20.0]);
}
#[test]
fn test_index_select_1d_out_of_bounds() {
let input = leaf_1d(&[10.0, 20.0, 30.0], false);
let result = index_select_1d(&input, &[0, 5]);
assert!(result.is_err());
}
#[test]
fn test_index_select_1d_non_1d_input() {
let input = Tensor::<f32>::from_storage(
TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0]),
vec![2, 2],
false,
)
.unwrap();
let result = index_select_1d(&input, &[0]);
assert!(result.is_err());
}
#[test]
fn test_index_select_1d_backward_simple() {
let input = leaf_1d(&[10.0, 20.0, 30.0, 40.0], true);
let selected = index_select_1d(&input, &[1, 3]).unwrap();
assert!(selected.requires_grad());
assert!(!selected.is_leaf());
assert_eq!(selected.grad_fn().unwrap().name(), "IndexSelectBackward");
let data = selected.data().unwrap();
let total: f32 = data.iter().sum();
let sum_storage = TensorStorage::cpu(vec![total]);
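// Minimal sum-reduction node defined inline so that `backward` has a scalar
// loss to start from; used only by this test.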
#[derive(Debug)]
struct SumBackward<T: Float> {
input: Tensor<T>,
}
impl<T: Float> GradFn<T> for SumBackward<T> {
fn backward(
&self,
grad_output: &Tensor<T>,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let go_val = grad_output.data()?[0];
let grad = vec![go_val; self.input.numel()];
let t = Tensor::from_storage(
TensorStorage::cpu(grad),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(t)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"SumBackward"
}
}
let loss = Tensor::from_operation(
sum_storage,
vec![],
Arc::new(SumBackward {
input: selected.clone(),
}),
)
.unwrap();
backward(&loss).unwrap();
let grad = input.grad().unwrap().unwrap();
let grad_data = grad.data().unwrap();
assert_eq!(grad_data.len(), 4);
assert!((grad_data[0] - 0.0).abs() < 1e-6, "grad[0] should be 0");
assert!((grad_data[1] - 1.0).abs() < 1e-6, "grad[1] should be 1");
assert!((grad_data[2] - 0.0).abs() < 1e-6, "grad[2] should be 0");
assert!((grad_data[3] - 1.0).abs() < 1e-6, "grad[3] should be 1");
}
#[test]
fn test_index_select_1d_backward_duplicate_indices() {
let input = leaf_1d(&[10.0, 20.0, 30.0], true);
let selected = index_select_1d(&input, &[0, 1, 1, 2, 1]).unwrap();
let grad_output = Tensor::from_storage(
TensorStorage::cpu(vec![1.0; 5]),
vec![5],
false,
)
.unwrap();
let grad_fn = selected.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let grad_input = grads[0].as_ref().unwrap();
let gd = grad_input.data().unwrap();
assert_eq!(gd.len(), 3);
assert!((gd[0] - 1.0).abs() < 1e-6, "grad[0] = {}, expected 1", gd[0]);
assert!((gd[1] - 3.0).abs() < 1e-6, "grad[1] = {}, expected 3", gd[1]);
assert!((gd[2] - 1.0).abs() < 1e-6, "grad[2] = {}, expected 1", gd[2]);
}
#[test]
fn test_index_select_1d_backward_weighted_grad() {
let input = leaf_1d(&[100.0, 200.0, 300.0], true);
let selected = index_select_1d(&input, &[2, 0]).unwrap();
let grad_output = Tensor::from_storage(
TensorStorage::cpu(vec![0.5, 2.0]),
vec![2],
false,
)
.unwrap();
let grad_fn = selected.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let grad_input = grads[0].as_ref().unwrap();
let gd = grad_input.data().unwrap();
assert!((gd[0] - 2.0).abs() < 1e-6, "grad[0] = {}, expected 2.0", gd[0]);
assert!((gd[1] - 0.0).abs() < 1e-6, "grad[1] = {}, expected 0.0", gd[1]);
assert!((gd[2] - 0.5).abs() < 1e-6, "grad[2] = {}, expected 0.5", gd[2]);
}
#[test]
fn test_index_select_1d_no_grad_context() {
let input = leaf_1d(&[10.0, 20.0, 30.0], true);
let result = no_grad(|| index_select_1d(&input, &[0, 2])).unwrap();
assert!(!result.requires_grad());
assert!(result.grad_fn().is_none());
}
#[test]
fn test_masked_fill_forward() {
let input = leaf_1d(&[1.0, 2.0, 3.0, 4.0], false);
let mask = [false, true, false, true];
let result = masked_fill(&input, &mask, -999.0).unwrap();
assert_eq!(result.data().unwrap(), &[1.0, -999.0, 3.0, -999.0]);
}
#[test]
fn test_masked_fill_backward() {
let input = leaf_1d(&[1.0, 2.0, 3.0, 4.0], true);
let mask = [false, true, false, true];
let filled = masked_fill(&input, &mask, 0.0).unwrap();
let grad_output = Tensor::from_storage(
TensorStorage::cpu(vec![1.0; 4]),
vec![4],
false,
)
.unwrap();
let grad_fn = filled.grad_fn().unwrap();
let grads = grad_fn.backward(&grad_output).unwrap();
let grad_input = grads[0].as_ref().unwrap();
let gd = grad_input.data().unwrap();
assert!((gd[0] - 1.0).abs() < 1e-6);
assert!((gd[1] - 0.0).abs() < 1e-6);
assert!((gd[2] - 1.0).abs() < 1e-6);
assert!((gd[3] - 0.0).abs() < 1e-6);
}
#[test]
fn test_masked_fill_shape_mismatch() {
let input = leaf_1d(&[1.0, 2.0, 3.0], false);
let mask = [true, false];
let result = masked_fill(&input, &mask, 0.0);
assert!(result.is_err());
}
#[test]
fn test_gather_backward_not_implemented() {
let input = leaf_1d(&[1.0, 2.0], false);
let gf = GatherBackward { input };
let dummy = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32]), vec![1], false).unwrap();
assert!(gf.backward(&dummy).is_err());
}
#[test]
fn test_scatter_add_backward_not_implemented() {
let input = leaf_1d(&[1.0, 2.0], false);
let gf = ScatterAddBackward { input };
let dummy = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32]), vec![1], false).unwrap();
assert!(gf.backward(&dummy).is_err());
}
}