use super::GradFn;
use crate::tensor::Tensor;
use num_traits::Float;
/// Saved state for the backward pass of element-wise addition (`input0 + input1`).
pub struct AddBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Data of the left operand (the addition gradient itself does not depend on it).
    pub input0_data: crate::tensor::Tensor<T>,
    /// Data of the right operand (the addition gradient itself does not depend on it).
    pub input1_data: crate::tensor::Tensor<T>,
    /// Graph node of the left operand; receives the upstream gradient when it requires grad.
    pub input0_var: crate::autograd::Variable<T>,
    /// Graph node of the right operand; receives the upstream gradient when it requires grad.
    pub input1_var: crate::autograd::Variable<T>,
}
impl<T> GradFn<T> for AddBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Backward pass for `out = input0 + input1`.
    ///
    /// Because d(out)/d(input0) = d(out)/d(input1) = 1, the upstream gradient `g` is forwarded
    /// unchanged to every input that requires grad, and `[Some(g), Some(g)]` is returned.
    /// An empty `grad_outputs` slice yields an empty vector, matching the other backward
    /// functions that guard against it, instead of panicking on `grad_outputs[0]`.
    ///
    /// NOTE(review): the gradient is not reduced back to each input's shape. If the forward
    /// addition broadcast its operands, confirm that the caller handles the reduction.
    /// NOTE(review): gradients are both pushed via `backward_with_grad` and returned. Confirm
    /// that the caller does not propagate the returned values a second time.
    fn apply(&self, grad_outputs: &[Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        let grad_output = match grad_outputs.first() {
            Some(grad) => grad,
            None => return vec![],
        };
        if self.input0_var.requires_grad() {
            self.input0_var
                .backward_with_grad(Some(grad_output.clone()));
        }
        if self.input1_var.requires_grad() {
            self.input1_var
                .backward_with_grad(Some(grad_output.clone()));
        }
        vec![Some(grad_output.clone()), Some(grad_output.clone())]
    }
}
/// Saved state for the backward pass of element-wise subtraction (`input0 - input1`).
pub struct SubBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Data of the minuend (the subtraction gradient itself does not depend on it).
    pub input0_data: crate::tensor::Tensor<T>,
    /// Data of the subtrahend (the subtraction gradient itself does not depend on it).
    pub input1_data: crate::tensor::Tensor<T>,
    /// Graph node of the minuend; receives `+grad` when it requires grad.
    pub input0_var: crate::autograd::Variable<T>,
    /// Graph node of the subtrahend; receives `-grad` when it requires grad.
    pub input1_var: crate::autograd::Variable<T>,
}
impl<T> GradFn<T> for SubBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Backward pass for `out = input0 - input1`.
    ///
    /// Returns `[Some(g), Some(-g)]` for upstream gradient `g`. Each gradient is also pushed
    /// to the corresponding input when that input requires grad. An empty `grad_outputs`
    /// slice yields an empty vector instead of panicking.
    ///
    /// NOTE(review): gradients are both pushed via `backward_with_grad` and returned. Confirm
    /// that the caller does not propagate the returned values a second time.
    fn apply(&self, grad_outputs: &[Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        let grad_output = match grad_outputs.first() {
            Some(grad) => grad,
            None => return vec![],
        };
        let grad_input0 = grad_output.clone();
        // Negate a copy in place. This keeps `grad_output`'s exact shape, avoids relying on
        // broadcasting against a shape-[1] tensor, and needs no fallible `T::from(-1)` cast.
        let mut grad_input1 = grad_output.clone();
        grad_input1.as_array_mut().mapv_inplace(|v| -v);
        if self.input0_var.requires_grad() {
            self.input0_var
                .backward_with_grad(Some(grad_input0.clone()));
        }
        if self.input1_var.requires_grad() {
            self.input1_var
                .backward_with_grad(Some(grad_input1.clone()));
        }
        vec![Some(grad_input0), Some(grad_input1)]
    }
}
/// Saved state for the backward pass of element-wise multiplication (`input0 * input1`).
pub struct MulBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Data of the left factor; scales the gradient sent to the right factor.
    pub input0_data: Tensor<T>,
    /// Data of the right factor; scales the gradient sent to the left factor.
    pub input1_data: Tensor<T>,
    /// Graph node of the left factor.
    pub input0_var: crate::autograd::Variable<T>,
    /// Graph node of the right factor.
    pub input1_var: crate::autograd::Variable<T>,
}
impl<T> GradFn<T> for MulBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Backward pass for `out = input0 * input1` (element-wise product rule).
    ///
    /// For upstream gradient `g`, returns `[Some(g * input1), Some(g * input0)]`. Each gradient
    /// is also pushed to the corresponding input when that input requires grad. An empty
    /// `grad_outputs` slice yields an empty vector instead of panicking on `grad_outputs[0]`.
    ///
    /// NOTE(review): gradients are both pushed via `backward_with_grad` and returned. Confirm
    /// that the caller does not propagate the returned values a second time.
    fn apply(&self, grad_outputs: &[Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        let grad_output = match grad_outputs.first() {
            Some(grad) => grad,
            None => return vec![],
        };
        let grad_input0 = grad_output * &self.input1_data;
        let grad_input1 = grad_output * &self.input0_data;
        if self.input0_var.requires_grad() {
            self.input0_var
                .backward_with_grad(Some(grad_input0.clone()));
        }
        if self.input1_var.requires_grad() {
            self.input1_var
                .backward_with_grad(Some(grad_input1.clone()));
        }
        vec![Some(grad_input0), Some(grad_input1)]
    }
}
/// Saved state for the backward pass of matrix multiplication (`input0.matmul(input1)`).
pub struct MatMulBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Left operand; its transpose multiplies the upstream gradient for the right operand.
    pub input0_data: Tensor<T>,
    /// Right operand; its transpose multiplies the upstream gradient for the left operand.
    pub input1_data: Tensor<T>,
    /// Optional graph node of the left operand (`None` means nothing is propagated to it).
    pub input0_var: Option<crate::autograd::Variable<T>>,
    /// Optional graph node of the right operand (`None` means nothing is propagated to it).
    pub input1_var: Option<crate::autograd::Variable<T>>,
}
impl<T> GradFn<T> for MatMulBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Backward pass for `out = a.matmul(b)`.
    ///
    /// For upstream gradient `g`, returns `[Some(g · bᵀ), Some(aᵀ · g)]`, or an empty vector
    /// when `grad_outputs` is empty. Each gradient is also pushed to the matching input
    /// variable when that variable is present and requires grad.
    ///
    /// # Panics
    /// If a transpose or a matmul reports failure.
    fn apply(&self, grad_outputs: &[Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        let upstream = match grad_outputs.first() {
            Some(g) => g,
            None => return Vec::new(),
        };

        // d(out)/d(a) = g · bᵀ
        let rhs_t = self.input1_data.transpose().expect("Transpose failed");
        let lhs_grad = upstream.matmul(&rhs_t).expect("MatMul failed");

        // d(out)/d(b) = aᵀ · g
        let lhs_t = self.input0_data.transpose().expect("Transpose failed");
        let rhs_grad = lhs_t.matmul(upstream).expect("MatMul failed");

        // Push gradients in operand order: left first, then right.
        let targets = [
            (&self.input0_var, &lhs_grad),
            (&self.input1_var, &rhs_grad),
        ];
        for &(slot, grad) in targets.iter() {
            if let Some(var) = slot.as_ref().filter(|v| v.requires_grad()) {
                var.backward_with_grad(Some(grad.clone()));
            }
        }

        vec![Some(lhs_grad), Some(rhs_grad)]
    }
}
/// Saved state for the backward pass of a sum reduction over an input tensor.
pub struct SumBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Shape of the summed input; the gradient is expanded back to this shape.
    pub input_shape: Vec<usize>,
    /// Graph node of the summed input; its stored gradient is accumulated into.
    pub input_var: crate::autograd::Variable<T>,
    /// Type marker for `T` (appears redundant, since `input_var` already mentions `T`).
    pub _phantom: std::marker::PhantomData<T>,
}
impl<T> GradFn<T> for SumBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Backward pass for `out = sum(x)`.
    ///
    /// Every element of `x` contributes to `out` with coefficient 1. The gradient is therefore
    /// the upstream value (the first element of `grad_outputs[0]`, or zero if that tensor is
    /// empty) repeated over `input_shape`. An empty `grad_outputs` slice yields an empty vector
    /// instead of panicking.
    ///
    /// When `input_var` requires grad, the result is added to its stored gradient, or installed
    /// as that gradient if none is present yet.
    ///
    /// # Panics
    /// If `input_var`'s gradient lock is poisoned.
    fn apply(&self, grad_outputs: &[Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        let grad_output = match grad_outputs.first() {
            Some(grad) => grad,
            None => return vec![],
        };
        let grad_value = grad_output
            .as_array()
            .iter()
            .next()
            .copied()
            .unwrap_or_else(T::zero);

        // Build the filled tensor directly instead of allocating ones and then overwriting them.
        let numel: usize = self.input_shape.iter().product();
        let result = Tensor::from_vec(vec![grad_value; numel], self.input_shape.clone());

        // NOTE(review): unlike the other backward functions in this file, this writes into
        // `input_var`'s gradient directly rather than calling `backward_with_grad`. Nothing is
        // propagated further from here, so confirm that this is intended.
        if self.input_var.requires_grad() {
            let grad = self.input_var.grad();
            let mut slot = grad.write().expect("gradient lock poisoned");
            let accumulated = match slot.take() {
                Some(existing) => &existing + &result,
                None => result.clone(),
            };
            *slot = Some(accumulated);
        }
        vec![Some(result)]
    }
}
/// Saved state for the backward pass of a mean reduction.
pub struct MeanBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Graph node of the averaged input; `None` makes the backward pass return `[None]`.
    pub input_var: Option<crate::autograd::Variable<T>>,
    /// Number of averaged elements, stored as `T`; the upstream gradient is divided by it.
    pub numel: T,
}
impl<T> GradFn<T> for MeanBackward<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Backward pass for `out = mean(x)`.
    ///
    /// Each element of `x` receives `g / numel`, where `g` is the first element of
    /// `grad_outputs[0]` (zero if that tensor is empty). The result has the current shape of
    /// `input_var`'s data.
    ///
    /// Return values:
    /// - an empty vector when `grad_outputs` is empty;
    /// - `[None]` when there is no input variable;
    /// - otherwise `[Some(grad)]`, with `grad` also pushed to the input when it requires grad.
    ///
    /// # Panics
    /// If the input data lock is poisoned.
    fn apply(&self, grad_outputs: &[Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        let upstream = match grad_outputs.first() {
            Some(g) => g,
            None => return Vec::new(),
        };
        // The upstream gradient is split evenly across all averaged elements.
        let share = upstream
            .as_array()
            .iter()
            .next()
            .copied()
            .unwrap_or_else(T::zero)
            / self.numel;

        let var = match self.input_var.as_ref() {
            Some(v) => v,
            None => return vec![None],
        };

        // Hold the read lock only long enough to copy the shape out.
        let data = var.data();
        let shape = {
            let guard = data.read().unwrap();
            guard.shape().to_vec()
        };

        let count: usize = shape.iter().product();
        let grad_input = Tensor::from_vec(vec![share; count], shape);
        if var.requires_grad() {
            var.backward_with_grad(Some(grad_input.clone()));
        }
        vec![Some(grad_input)]
    }
}