use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
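
/// Node saved on the autograd graph by [`where_`]: it keeps the boolean mask
/// and the two input tensors so the backward pass can route the incoming
/// gradient to `x` where the condition was true and to `y` where it was false.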
#[derive(Debug)]
pub struct WhereBackward<T: Float> {
    condition: Vec<bool>,
    x: Tensor<T>,
    y: Tensor<T>,
}

impl<T: Float> GradFn<T> for WhereBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let device = grad_output.device();
        let go = grad_output.data_vec()?;
        let zero = <T as num_traits::Zero>::zero();
        // d(out)/dx is 1 where the condition selected x and 0 elsewhere, so
        // the upstream gradient passes through only at the selected positions.
        let grad_x: Vec<T> = go
            .iter()
            .zip(self.condition.iter())
            .map(|(&g, &c)| if c { g } else { zero })
            .collect();
        // The gradient for y is the complement of the mask.
        let grad_y: Vec<T> = go
            .iter()
            .zip(self.condition.iter())
            .map(|(&g, &c)| if c { zero } else { g })
            .collect();
        // Both gradients are assembled on the CPU and moved back to the
        // original device afterwards, so one code path serves CPU and CUDA.
        let grad_x_tensor =
            Tensor::from_storage(TensorStorage::cpu(grad_x), self.x.shape().to_vec(), false)?;
        let grad_y_tensor =
            Tensor::from_storage(TensorStorage::cpu(grad_y), self.y.shape().to_vec(), false)?;
        if device.is_cuda() {
            Ok(vec![
                Some(grad_x_tensor.to(device)?),
                Some(grad_y_tensor.to(device)?),
            ])
        } else {
            Ok(vec![Some(grad_x_tensor), Some(grad_y_tensor)])
        }
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.x, &self.y]
    }

    fn name(&self) -> &'static str {
        "WhereBackward"
    }
}
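
/// Elementwise selection, analogous to `torch.where`: element `i` of the
/// result is `x[i]` where `condition[i]` is true and `y[i]` otherwise. The
/// trailing underscore avoids the reserved keyword `where`. `x` and `y` are
/// expected to share a shape and a device; the result is allocated on `x`'s
/// device.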
pub fn where_<T: Float>(
    condition: &[bool],
    x: &Tensor<T>,
    y: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let device = x.device();
    let x_data = x.data_vec()?;
    let y_data = y.data_vec()?;
    // `zip` below silently truncates to the shortest iterator, so length
    // mismatches are only caught here, in debug builds.
    debug_assert_eq!(condition.len(), x_data.len());
    debug_assert_eq!(condition.len(), y_data.len());
    let result: Vec<T> = condition
        .iter()
        .zip(x_data.iter().zip(y_data.iter()))
        .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
        .collect();
    // Record a graph node only when autograd is enabled and at least one
    // input actually requires a gradient.
    let needs_grad = is_grad_enabled() && (x.requires_grad() || y.requires_grad());
    let storage = TensorStorage::on_device(result, device)?;
    if needs_grad {
        let grad_fn = Arc::new(WhereBackward {
            condition: condition.to_vec(),
            x: x.clone(),
            y: y.clone(),
        });
        Tensor::from_operation(storage, x.shape().to_vec(), grad_fn)
    } else {
        Tensor::from_storage(storage, x.shape().to_vec(), false)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::graph::backward;
    use crate::storage::TensorStorage;

    /// Builds a CPU leaf tensor for the tests below.
    fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            shape.to_vec(),
            requires_grad,
        )
        .unwrap()
    }

    #[test]
    fn test_where_forward() {
        let cond = vec![true, false, true, false];
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
        let y = leaf(&[10.0, 20.0, 30.0, 40.0], &[4], false);
        let out = where_(&cond, &x, &y).unwrap();
        assert_eq!(out.data().unwrap(), &[1.0, 20.0, 3.0, 40.0]);
    }

    #[test]
    fn test_where_backward() {
        let cond = vec![true, false, true, false];
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], true);
        let y = leaf(&[10.0, 20.0, 30.0, 40.0], &[4], true);
        let out = where_(&cond, &x, &y).unwrap();
        let out_data = out.data().unwrap();
        let total: f32 = out_data.iter().sum();

        // Minimal stand-in for a sum reduction: its backward broadcasts a
        // gradient of ones back to the input, which lets us drive
        // `WhereBackward` without depending on a library-level sum op.
        #[derive(Debug)]
        struct SumBackward<T: Float> {
            input: Tensor<T>,
            numel: usize,
        }

        impl<T: Float> GradFn<T> for SumBackward<T> {
            fn backward(
                &self,
                _grad_output: &Tensor<T>,
            ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
                let ones = vec![<T as num_traits::One>::one(); self.numel];
                let t = Tensor::from_storage(
                    TensorStorage::cpu(ones),
                    self.input.shape().to_vec(),
                    false,
                )?;
                Ok(vec![Some(t)])
            }

            fn inputs(&self) -> Vec<&Tensor<T>> {
                vec![&self.input]
            }

            fn name(&self) -> &'static str {
                "SumBackward"
            }
        }

        let scalar = Tensor::from_operation(
            TensorStorage::cpu(vec![total]),
            vec![],
            Arc::new(SumBackward {
                input: out.clone(),
                numel: 4,
            }),
        )
        .unwrap();
        backward(&scalar).unwrap();

        // Each input receives the upstream gradient exactly where the
        // condition selected it.
        let x_grad = x.grad().unwrap().unwrap();
        let y_grad = y.grad().unwrap().unwrap();
        assert_eq!(x_grad.data().unwrap(), &[1.0, 0.0, 1.0, 0.0]);
        assert_eq!(y_grad.data().unwrap(), &[0.0, 1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_where_no_grad() {
        crate::autograd::no_grad::no_grad(|| {
            let cond = vec![true, false];
            let x = leaf(&[1.0, 2.0], &[2], true);
            let y = leaf(&[10.0, 20.0], &[2], true);
            let out = where_(&cond, &x, &y).unwrap();
            assert!(!out.requires_grad());
            assert!(out.grad_fn().is_none());
            assert_eq!(out.data().unwrap(), &[1.0, 20.0]);
        });
    }
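
    // A small additional check, sketched with only the helpers already used
    // above: an all-false condition must reproduce `y` exactly, exercising
    // the degenerate selection path.
    #[test]
    fn test_where_all_false_selects_y() {
        let cond = vec![false, false, false];
        let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
        let y = leaf(&[4.0, 5.0, 6.0], &[3], false);
        let out = where_(&cond, &x, &y).unwrap();
        assert_eq!(out.data().unwrap(), &[4.0, 5.0, 6.0]);
    }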
}