use crate::autograd::{GradFn, Var, var_div, var_log, var_mul, var_neg, var_sum};
use crate::error::Result;
use crate::ops::{BinaryOps, ReduceOps, TensorOps, UnaryOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
use std::sync::Arc;
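/// Backward pass for element-wise addition `a + b`.
///
/// The upstream gradient flows through unchanged to both inputs; each side is
/// then summed back down over any broadcast dimensions to recover its original shape.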
pub struct AddBackward<R: Runtime> {
input_ids: [TensorId; 2],
a_shape: Vec<usize>,
b_shape: Vec<usize>,
input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 2],
}
impl<R: Runtime> AddBackward<R> {
pub fn new(
a_id: TensorId,
b_id: TensorId,
a_shape: &[usize],
b_shape: &[usize],
a_grad_fn: Option<Arc<dyn GradFn<R>>>,
b_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_ids: [a_id, b_id],
a_shape: a_shape.to_vec(),
b_shape: b_shape.to_vec(),
input_grad_fns: [a_grad_fn, b_grad_fn],
}
}
}
impl<R: Runtime> GradFn<R> for AddBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let grad_a = reduce_grad_for_broadcast::<R>(grad_output, &self.a_shape)?;
let grad_b = reduce_grad_for_broadcast::<R>(grad_output, &self.b_shape)?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let grad_a = reduce_var_for_broadcast(grad_output, &self.a_shape, &client)?;
let grad_b = reduce_var_for_broadcast(grad_output, &self.b_shape, &client)?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn name(&self) -> &'static str {
"AddBackward"
}
}
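/// Backward pass for element-wise subtraction `a - b`.
///
/// `a` receives the upstream gradient as-is; `b` receives its negation, with
/// both reduced over broadcast dimensions.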
pub struct SubBackward<R: Runtime> {
input_ids: [TensorId; 2],
a_shape: Vec<usize>,
b_shape: Vec<usize>,
input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 2],
}
impl<R: Runtime> SubBackward<R> {
pub fn new(
a_id: TensorId,
b_id: TensorId,
a_shape: &[usize],
b_shape: &[usize],
a_grad_fn: Option<Arc<dyn GradFn<R>>>,
b_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_ids: [a_id, b_id],
a_shape: a_shape.to_vec(),
b_shape: b_shape.to_vec(),
input_grad_fns: [a_grad_fn, b_grad_fn],
}
}
}
impl<R: Runtime> GradFn<R> for SubBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let grad_a = reduce_grad_for_broadcast::<R>(grad_output, &self.a_shape)?;
let neg_grad = client.neg(grad_output)?;
let grad_b = reduce_grad_for_broadcast::<R>(&neg_grad, &self.b_shape)?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let grad_a = reduce_var_for_broadcast(grad_output, &self.a_shape, &client)?;
let neg_grad = var_neg(grad_output, &client)?;
let grad_b = reduce_var_for_broadcast(&neg_grad, &self.b_shape, &client)?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn name(&self) -> &'static str {
"SubBackward"
}
}
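/// Backward pass for element-wise multiplication `a * b`.
///
/// Saves both inputs so the gradients `grad * b` (for `a`) and `grad * a` (for `b`)
/// can be formed, then reduced over broadcast dimensions.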
pub struct MulBackward<R: Runtime> {
input_ids: [TensorId; 2],
    saved_tensors: Vec<Tensor<R>>,
    input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 2],
}
impl<R: Runtime> MulBackward<R> {
pub fn new(
a_id: TensorId,
b_id: TensorId,
a: Tensor<R>,
b: Tensor<R>,
a_grad_fn: Option<Arc<dyn GradFn<R>>>,
b_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_ids: [a_id, b_id],
saved_tensors: vec![a, b],
input_grad_fns: [a_grad_fn, b_grad_fn],
}
}
}
impl<R: Runtime> GradFn<R> for MulBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let saved_a = &self.saved_tensors[0];
let saved_b = &self.saved_tensors[1];
let grad_a_full = client.mul(grad_output, saved_b)?;
let grad_a = reduce_grad_for_broadcast::<R>(&grad_a_full, saved_a.shape())?;
let grad_b_full = client.mul(grad_output, saved_a)?;
let grad_b = reduce_grad_for_broadcast::<R>(&grad_b_full, saved_b.shape())?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let saved_a = &self.saved_tensors[0];
let saved_b = &self.saved_tensors[1];
let a_var = Var::with_id_and_grad_fn(
saved_a.clone(),
self.input_ids[0],
self.input_grad_fns[0].clone(),
);
let b_var = Var::with_id_and_grad_fn(
saved_b.clone(),
self.input_ids[1],
self.input_grad_fns[1].clone(),
);
let grad_a_full = var_mul(grad_output, &b_var, &client)?;
let grad_a = reduce_var_for_broadcast(&grad_a_full, saved_a.shape(), &client)?;
let grad_b_full = var_mul(grad_output, &a_var, &client)?;
let grad_b = reduce_var_for_broadcast(&grad_b_full, saved_b.shape(), &client)?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn saved_tensors(&self) -> &[Tensor<R>] {
&self.saved_tensors
}
fn name(&self) -> &'static str {
"MulBackward"
}
}
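/// Backward pass for element-wise division `a / b`.
///
/// Uses d(a/b)/da = 1/b and d(a/b)/db = -a/b^2, with both gradients reduced
/// over broadcast dimensions.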
pub struct DivBackward<R: Runtime> {
input_ids: [TensorId; 2],
    saved_tensors: Vec<Tensor<R>>,
    input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 2],
}
impl<R: Runtime> DivBackward<R> {
pub fn new(
a_id: TensorId,
b_id: TensorId,
a: Tensor<R>,
b: Tensor<R>,
a_grad_fn: Option<Arc<dyn GradFn<R>>>,
b_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_ids: [a_id, b_id],
saved_tensors: vec![a, b],
input_grad_fns: [a_grad_fn, b_grad_fn],
}
}
}
impl<R: Runtime> GradFn<R> for DivBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let saved_a = &self.saved_tensors[0];
let saved_b = &self.saved_tensors[1];
let grad_a_full = client.div(grad_output, saved_b)?;
let grad_a = reduce_grad_for_broadcast::<R>(&grad_a_full, saved_a.shape())?;
let neg_grad = client.neg(grad_output)?;
let neg_grad_a = client.mul(&neg_grad, saved_a)?;
let b_squared = client.mul(saved_b, saved_b)?;
let grad_b_full = client.div(&neg_grad_a, &b_squared)?;
let grad_b = reduce_grad_for_broadcast::<R>(&grad_b_full, saved_b.shape())?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let saved_a = &self.saved_tensors[0];
let saved_b = &self.saved_tensors[1];
let a_var = Var::with_id_and_grad_fn(
saved_a.clone(),
self.input_ids[0],
self.input_grad_fns[0].clone(),
);
let b_var = Var::with_id_and_grad_fn(
saved_b.clone(),
self.input_ids[1],
self.input_grad_fns[1].clone(),
);
let grad_a_full = var_div(grad_output, &b_var, &client)?;
let grad_a = reduce_var_for_broadcast(&grad_a_full, saved_a.shape(), &client)?;
let neg_grad = var_neg(grad_output, &client)?;
let neg_grad_a = var_mul(&neg_grad, &a_var, &client)?;
let b_squared = var_mul(&b_var, &b_var, &client)?;
let grad_b_full = var_div(&neg_grad_a, &b_squared, &client)?;
let grad_b = reduce_var_for_broadcast(&grad_b_full, saved_b.shape(), &client)?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn saved_tensors(&self) -> &[Tensor<R>] {
&self.saved_tensors
}
fn name(&self) -> &'static str {
"DivBackward"
}
}
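/// Backward pass for element-wise power `a ^ b`.
///
/// Saves `a`, `b`, and the forward output so the gradients can be formed as
/// d(a^b)/da = b * a^(b-1) = b * output / a and d(a^b)/db = output * ln(a).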
pub struct PowBackward<R: Runtime> {
input_ids: [TensorId; 2],
    saved_tensors: Vec<Tensor<R>>,
    input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 2],
}
impl<R: Runtime> PowBackward<R> {
pub fn new(
a_id: TensorId,
b_id: TensorId,
a: Tensor<R>,
b: Tensor<R>,
output: Tensor<R>,
a_grad_fn: Option<Arc<dyn GradFn<R>>>,
b_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_ids: [a_id, b_id],
saved_tensors: vec![a, b, output],
input_grad_fns: [a_grad_fn, b_grad_fn],
}
}
}
impl<R: Runtime> GradFn<R> for PowBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let saved_a = &self.saved_tensors[0];
let saved_b = &self.saved_tensors[1];
let saved_output = &self.saved_tensors[2];
let grad_a_temp = client.mul(grad_output, saved_b)?;
let grad_a_temp2 = client.mul(&grad_a_temp, saved_output)?;
let grad_a_full = client.div(&grad_a_temp2, saved_a)?;
let grad_a = reduce_grad_for_broadcast::<R>(&grad_a_full, saved_a.shape())?;
let ln_a = client.log(saved_a)?;
let grad_b_temp = client.mul(grad_output, saved_output)?;
let grad_b_full = client.mul(&grad_b_temp, &ln_a)?;
let grad_b = reduce_grad_for_broadcast::<R>(&grad_b_full, saved_b.shape())?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let saved_a = &self.saved_tensors[0];
let saved_b = &self.saved_tensors[1];
let saved_output = &self.saved_tensors[2];
let a_var = Var::with_id_and_grad_fn(
saved_a.clone(),
self.input_ids[0],
self.input_grad_fns[0].clone(),
);
let b_var = Var::with_id_and_grad_fn(
saved_b.clone(),
self.input_ids[1],
self.input_grad_fns[1].clone(),
);
let output_var = Var::new(saved_output.clone(), false);
let grad_a_temp = var_mul(grad_output, &b_var, &client)?;
let grad_a_temp2 = var_mul(&grad_a_temp, &output_var, &client)?;
let grad_a_full = var_div(&grad_a_temp2, &a_var, &client)?;
let grad_a = reduce_var_for_broadcast(&grad_a_full, saved_a.shape(), &client)?;
let ln_a = var_log(&a_var, &client)?;
let grad_b_temp = var_mul(grad_output, &output_var, &client)?;
let grad_b_full = var_mul(&grad_b_temp, &ln_a, &client)?;
let grad_b = reduce_var_for_broadcast(&grad_b_full, saved_b.shape(), &client)?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn saved_tensors(&self) -> &[Tensor<R>] {
&self.saved_tensors
}
fn name(&self) -> &'static str {
"PowBackward"
}
}
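/// Sums a gradient tensor back down to `target_shape`, undoing broadcasting:
/// dimensions that were expanded from size 1 are summed with keepdim, and
/// leading dimensions that were prepended are summed away entirely before a
/// final reshape to the exact target shape.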
fn reduce_grad_for_broadcast<R: Runtime>(
grad: &Tensor<R>,
target_shape: &[usize],
) -> Result<Tensor<R>>
where
R::Client: TensorOps<R> + ReduceOps<R>,
{
let grad_shape = grad.shape();
if grad_shape == target_shape {
return Ok(grad.clone());
}
let client = R::default_client(grad.device());
let grad_ndim = grad_shape.len();
let target_ndim = target_shape.len();
let mut padded_target = vec![1usize; grad_ndim];
let offset = grad_ndim.saturating_sub(target_ndim);
for (i, &dim) in target_shape.iter().enumerate() {
padded_target[offset + i] = dim;
}
let mut reduce_dims = Vec::new();
for (i, (&grad_dim, &target_dim)) in grad_shape.iter().zip(padded_target.iter()).enumerate() {
if target_dim == 1 && grad_dim > 1 {
reduce_dims.push(i);
}
}
let mut result = grad.clone();
if !reduce_dims.is_empty() {
result = client.sum(&result, &reduce_dims, true)?;
}
if target_ndim < grad_ndim {
let extra_dims: Vec<usize> = (0..(grad_ndim - target_ndim)).collect();
if !extra_dims.is_empty() {
result = client.sum(&result, &extra_dims, false)?;
}
}
if result.shape() != target_shape {
result = result.reshape(target_shape)?;
}
Ok(result)
}
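/// `Var`-based counterpart of `reduce_grad_for_broadcast`: the same broadcast
/// reduction, expressed through `var_sum`/`var_reshape` so it stays inside the
/// autograd graph.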
fn reduce_var_for_broadcast<R, C>(
var: &Var<R>,
target_shape: &[usize],
client: &C,
) -> Result<Var<R>>
where
R: Runtime,
C: RuntimeClient<R> + TensorOps<R>,
R::Client: TensorOps<R>,
{
let var_shape = var.shape();
if var_shape == target_shape {
return Ok(var.clone());
}
let var_ndim = var_shape.len();
let target_ndim = target_shape.len();
let mut padded_target = vec![1usize; var_ndim];
let offset = var_ndim.saturating_sub(target_ndim);
for (i, &dim) in target_shape.iter().enumerate() {
padded_target[offset + i] = dim;
}
let mut reduce_dims = Vec::new();
for (i, (&var_dim, &target_dim)) in var_shape.iter().zip(padded_target.iter()).enumerate() {
if target_dim == 1 && var_dim > 1 {
reduce_dims.push(i);
}
}
let mut result = var.clone();
if !reduce_dims.is_empty() {
result = var_sum(&result, &reduce_dims, true, client)?;
}
if target_ndim < var_ndim {
let extra_dims: Vec<usize> = (0..(var_ndim - target_ndim)).collect();
if !extra_dims.is_empty() {
result = var_sum(&result, &extra_dims, false, client)?;
}
}
if result.shape() != target_shape {
result = super::shape::var_reshape(&result, target_shape)?;
}
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dtype::DType;
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
#[test]
fn test_add_backward_same_shape() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
let b = Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0], &[2], &device);
let grad_out = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device);
let backward =
AddBackward::<CpuRuntime>::new(a.id(), b.id(), a.shape(), b.shape(), None, None);
let grads = backward.backward(&grad_out).unwrap();
assert_eq!(grads.len(), 2);
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
let grad_b: Vec<f32> = grads[1].as_ref().unwrap().to_vec();
assert_eq!(grad_a, vec![1.0, 1.0]);
assert_eq!(grad_b, vec![1.0, 1.0]);
}
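    // A sketch of a broadcast case, assuming `from_slice` accepts the 2-D shape
    // shown here and `to_vec` flattens row-major as in the tests above: `a` has
    // shape [2, 2] and `b` shape [2], so the gradient w.r.t. `b` should be
    // summed over the broadcast (leading) dimension.
    #[test]
    fn test_add_backward_broadcast() {
        let device = CpuDevice::new();
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[5.0f32, 6.0], &[2], &device);
        let grad_out = Tensor::<CpuRuntime>::ones(&[2, 2], DType::F32, &device);
        let backward =
            AddBackward::<CpuRuntime>::new(a.id(), b.id(), a.shape(), b.shape(), None, None);
        let grads = backward.backward(&grad_out).unwrap();
        let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
        let grad_b: Vec<f32> = grads[1].as_ref().unwrap().to_vec();
        assert_eq!(grad_a, vec![1.0, 1.0, 1.0, 1.0]);
        // Each element of `b` accumulates the gradient from both rows of the output.
        assert_eq!(grad_b, vec![2.0, 2.0]);
    }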
#[test]
fn test_mul_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device);
let b = Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[1], DType::F32, &device);
let backward =
MulBackward::<CpuRuntime>::new(a.id(), b.id(), a.clone(), b.clone(), None, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
let grad_b: Vec<f32> = grads[1].as_ref().unwrap().to_vec();
        assert_eq!(grad_a, vec![3.0]);
        assert_eq!(grad_b, vec![2.0]);
    }
#[test]
fn test_div_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[6.0f32], &[1], &device);
let b = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[1], DType::F32, &device);
let backward =
DivBackward::<CpuRuntime>::new(a.id(), b.id(), a.clone(), b.clone(), None, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
let grad_b: Vec<f32> = grads[1].as_ref().unwrap().to_vec();
assert!((grad_a[0] - 0.5).abs() < 1e-6);
assert!((grad_b[0] - (-1.5)).abs() < 1e-6);
}
}