use crate::autograd::{
GradFn, Var, var_abs, var_cos, var_div, var_mul, var_mul_scalar, var_neg, var_sin, var_square,
var_sub,
};
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::{BinaryOps, CompareOps, ScalarOps, TensorOps, UnaryOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
use std::sync::Arc;
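/// Backward function for negation: `y = -x`, so `dL/dx = -dL/dy`.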
pub struct NegBackward<R: Runtime> {
input_id: TensorId,
_marker: std::marker::PhantomData<R>,
}
impl<R: Runtime> NegBackward<R> {
pub fn new(input_id: TensorId) -> Self {
Self {
input_id,
_marker: std::marker::PhantomData,
}
}
}
impl<R: Runtime> GradFn<R> for NegBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let grad = client.neg(grad_output)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let grad = var_neg(grad_output, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn name(&self) -> &'static str {
"NegBackward"
}
}
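/// Backward function for the exponential: `y = exp(x)`, so `dL/dx = dL/dy * y`.
/// The forward output is saved to avoid recomputing `exp(x)`.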
pub struct ExpBackward<R: Runtime> {
input_id: TensorId,
saved_output: Tensor<R>,
}
impl<R: Runtime> ExpBackward<R> {
pub fn new(input_id: TensorId, output: Tensor<R>) -> Self {
Self {
input_id,
saved_output: output,
}
}
}
impl<R: Runtime> GradFn<R> for ExpBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let grad = client.mul(grad_output, &self.saved_output)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let output_var = Var::new(self.saved_output.clone(), false);
let grad = var_mul(grad_output, &output_var, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_output)
}
fn name(&self) -> &'static str {
"ExpBackward"
}
}
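/// Backward function for the natural logarithm: `y = ln(x)`, so `dL/dx = dL/dy / x`.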
pub struct LogBackward<R: Runtime> {
input_id: TensorId,
saved_input: Tensor<R>,
}
impl<R: Runtime> LogBackward<R> {
pub fn new(input_id: TensorId, input: Tensor<R>) -> Self {
Self {
input_id,
saved_input: input,
}
}
}
impl<R: Runtime> GradFn<R> for LogBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let grad = client.div(grad_output, &self.saved_input)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let input_var = Var::with_id(self.saved_input.clone(), self.input_id, true);
let grad = var_div(grad_output, &input_var, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_input)
}
fn name(&self) -> &'static str {
"LogBackward"
}
}
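/// Backward function for the square root: `y = sqrt(x)`, so `dL/dx = dL/dy / (2 * y)`.
/// The forward output is saved so the derivative can reuse `sqrt(x)`.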
pub struct SqrtBackward<R: Runtime> {
input_id: TensorId,
saved_output: Tensor<R>,
}
impl<R: Runtime> SqrtBackward<R> {
pub fn new(input_id: TensorId, output: Tensor<R>) -> Self {
Self {
input_id,
saved_output: output,
}
}
}
impl<R: Runtime> GradFn<R> for SqrtBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let two_sqrt = client.mul_scalar(&self.saved_output, 2.0)?;
let grad = client.div(grad_output, &two_sqrt)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let output_var = Var::new(self.saved_output.clone(), false);
let two_sqrt = var_mul_scalar(&output_var, 2.0, &client)?;
let grad = var_div(grad_output, &two_sqrt, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_output)
}
fn name(&self) -> &'static str {
"SqrtBackward"
}
}
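/// Backward function for sine: `y = sin(x)`, so `dL/dx = dL/dy * cos(x)`.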
pub struct SinBackward<R: Runtime> {
input_id: TensorId,
saved_input: Tensor<R>,
}
impl<R: Runtime> SinBackward<R> {
pub fn new(input_id: TensorId, input: Tensor<R>) -> Self {
Self {
input_id,
saved_input: input,
}
}
}
impl<R: Runtime> GradFn<R> for SinBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let cos_a = client.cos(&self.saved_input)?;
let grad = client.mul(grad_output, &cos_a)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let input_var = Var::with_id(self.saved_input.clone(), self.input_id, true);
let cos_a = var_cos(&input_var, &client)?;
let grad = var_mul(grad_output, &cos_a, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_input)
}
fn name(&self) -> &'static str {
"SinBackward"
}
}
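/// Backward function for cosine: `y = cos(x)`, so `dL/dx = -dL/dy * sin(x)`.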
pub struct CosBackward<R: Runtime> {
input_id: TensorId,
saved_input: Tensor<R>,
}
impl<R: Runtime> CosBackward<R> {
pub fn new(input_id: TensorId, input: Tensor<R>) -> Self {
Self {
input_id,
saved_input: input,
}
}
}
impl<R: Runtime> GradFn<R> for CosBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let sin_a = client.sin(&self.saved_input)?;
let neg_sin = client.neg(&sin_a)?;
let grad = client.mul(grad_output, &neg_sin)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let input_var = Var::with_id(self.saved_input.clone(), self.input_id, true);
let sin_a = var_sin(&input_var, &client)?;
let neg_sin = var_neg(&sin_a, &client)?;
let grad = var_mul(grad_output, &neg_sin, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_input)
}
fn name(&self) -> &'static str {
"CosBackward"
}
}
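/// Backward function for the hyperbolic tangent: `y = tanh(x)`, so
/// `dL/dx = dL/dy * (1 - y^2)`. The forward output is saved to reuse `tanh(x)`.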
pub struct TanhBackward<R: Runtime> {
input_id: TensorId,
saved_output: Tensor<R>,
}
impl<R: Runtime> TanhBackward<R> {
pub fn new(input_id: TensorId, output: Tensor<R>) -> Self {
Self {
input_id,
saved_output: output,
}
}
}
impl<R: Runtime<DType = DType>> GradFn<R> for TanhBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let tanh_squared = client.square(&self.saved_output)?;
let one = Tensor::<R>::ones(
self.saved_output.shape(),
self.saved_output.dtype(),
self.saved_output.device(),
);
let one_minus_tanh2 = client.sub(&one, &tanh_squared)?;
let grad = client.mul(grad_output, &one_minus_tanh2)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let output_var = Var::new(self.saved_output.clone(), false);
let tanh_squared = var_square(&output_var, &client)?;
let one = Tensor::<R>::ones(
self.saved_output.shape(),
self.saved_output.dtype(),
self.saved_output.device(),
);
let one_var = Var::new(one, false);
let one_minus_tanh2 = var_sub(&one_var, &tanh_squared, &client)?;
let grad = var_mul(grad_output, &one_minus_tanh2, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_output)
}
fn name(&self) -> &'static str {
"TanhBackward"
}
}
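/// Backward function for squaring: `y = x^2`, so `dL/dx = dL/dy * 2x`.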
pub struct SquareBackward<R: Runtime> {
input_id: TensorId,
saved_input: Tensor<R>,
}
impl<R: Runtime> SquareBackward<R> {
pub fn new(input_id: TensorId, input: Tensor<R>) -> Self {
Self {
input_id,
saved_input: input,
}
}
}
impl<R: Runtime> GradFn<R> for SquareBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let two_a = client.mul_scalar(&self.saved_input, 2.0)?;
let grad = client.mul(grad_output, &two_a)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let input_var = Var::with_id(self.saved_input.clone(), self.input_id, true);
let two_a = var_mul_scalar(&input_var, 2.0, &client)?;
let grad = var_mul(grad_output, &two_a, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_input)
}
fn name(&self) -> &'static str {
"SquareBackward"
}
}
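/// Backward function for the reciprocal: `y = 1 / x`, so `dL/dx = -dL/dy * y^2`.
/// The forward output is saved so the derivative can reuse `1 / x`.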
pub struct RecipBackward<R: Runtime> {
input_id: TensorId,
saved_output: Tensor<R>,
}
impl<R: Runtime> RecipBackward<R> {
pub fn new(input_id: TensorId, output: Tensor<R>) -> Self {
Self {
input_id,
saved_output: output,
}
}
}
impl<R: Runtime> GradFn<R> for RecipBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let z_squared = client.square(&self.saved_output)?;
let neg_grad = client.neg(grad_output)?;
let grad = client.mul(&neg_grad, &z_squared)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let output_var = Var::new(self.saved_output.clone(), false);
let z_squared = var_square(&output_var, &client)?;
let neg_grad = var_neg(grad_output, &client)?;
let grad = var_mul(&neg_grad, &z_squared, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_output)
}
fn name(&self) -> &'static str {
"RecipBackward"
}
}
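/// Backward function for tangent: `y = tan(x)`, so `dL/dx = dL/dy / cos(x)^2`.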
pub struct TanBackward<R: Runtime> {
input_id: TensorId,
saved_input: Tensor<R>,
}
impl<R: Runtime> TanBackward<R> {
pub fn new(input_id: TensorId, input: Tensor<R>) -> Self {
Self {
input_id,
saved_input: input,
}
}
}
impl<R: Runtime> GradFn<R> for TanBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let cos_a = client.cos(&self.saved_input)?;
let cos_squared = client.square(&cos_a)?;
let grad = client.div(grad_output, &cos_squared)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let input_var = Var::with_id(self.saved_input.clone(), self.input_id, true);
let cos_a = var_cos(&input_var, &client)?;
let cos_squared = var_square(&cos_a, &client)?;
let grad = var_div(grad_output, &cos_squared, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_input)
}
fn name(&self) -> &'static str {
"TanBackward"
}
}
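/// Backward function for absolute value: `y = |x|`, so `dL/dx = dL/dy * sign(x)`,
/// computed here as `x / |x|`. Note the gradient is NaN at `x = 0`.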
pub struct AbsBackward<R: Runtime> {
input_id: TensorId,
saved_input: Tensor<R>,
}
impl<R: Runtime> AbsBackward<R> {
pub fn new(input_id: TensorId, input: Tensor<R>) -> Self {
Self {
input_id,
saved_input: input,
}
}
}
impl<R: Runtime> GradFn<R> for AbsBackward<R>
where
R::Client: TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let abs_a = client.abs(&self.saved_input)?;
let grad_sign = client.div(&self.saved_input, &abs_a)?;
let grad = client.mul(grad_output, &grad_sign)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let input_var = Var::with_id(self.saved_input.clone(), self.input_id, true);
let abs_a = var_abs(&input_var, &client)?;
let grad_sign = var_div(&input_var, &abs_a, &client)?;
let grad = var_mul(grad_output, &grad_sign, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_input)
}
fn name(&self) -> &'static str {
"AbsBackward"
}
}
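/// Backward function for clamp: the gradient passes through unchanged where
/// `min < x < max` and is zero elsewhere, so the clamped boundary values are
/// masked out. Also retains the input's grad fn so the graph can be chained.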
pub struct ClampBackward<R: Runtime> {
input_id: TensorId,
saved_input: Tensor<R>,
min_val: f64,
max_val: f64,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> ClampBackward<R> {
pub fn new(
input_id: TensorId,
input: Tensor<R>,
min_val: f64,
max_val: f64,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
saved_input: input,
min_val,
max_val,
input_grad_fn,
}
}
}
impl<R: Runtime<DType = DType>> GradFn<R> for ClampBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R> + CompareOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let min_tensor = Tensor::<R>::full_scalar(
self.saved_input.shape(),
self.saved_input.dtype(),
self.min_val,
self.saved_input.device(),
);
let max_tensor = Tensor::<R>::full_scalar(
self.saved_input.shape(),
self.saved_input.dtype(),
self.max_val,
self.saved_input.device(),
);
let gt_min = client.gt(&self.saved_input, &min_tensor)?;
let lt_max = client.lt(&self.saved_input, &max_tensor)?;
let mask = client.mul(>_min, <_max)?;
let grad = client.mul(grad_output, &mask)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let min_tensor = Tensor::<R>::full_scalar(
self.saved_input.shape(),
self.saved_input.dtype(),
self.min_val,
self.saved_input.device(),
);
let max_tensor = Tensor::<R>::full_scalar(
self.saved_input.shape(),
self.saved_input.dtype(),
self.max_val,
self.saved_input.device(),
);
let gt_min = client.gt(&self.saved_input, &min_tensor)?;
let lt_max = client.lt(&self.saved_input, &max_tensor)?;
let mask = client.mul(>_min, <_max)?;
let mask_var = Var::new(mask, false);
let grad = var_mul(grad_output, &mask_var, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_input)
}
fn name(&self) -> &'static str {
"ClampBackward"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dtype::DType;
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
#[test]
fn test_neg_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[3], DType::F32, &device);
let backward = NegBackward::<CpuRuntime>::new(a.id());
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert_eq!(grad_a, vec![-1.0, -1.0, -1.0]);
}
#[test]
fn test_exp_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let a = Tensor::<CpuRuntime>::from_slice(&[0.0f32], &[1], &device);
let output = client.exp(&a).unwrap();
let grad_out = Tensor::<CpuRuntime>::ones(&[1], DType::F32, &device);
let backward = ExpBackward::<CpuRuntime>::new(a.id(), output);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert!((grad_a[0] - 1.0).abs() < 1e-6);
}
#[test]
fn test_log_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[1], DType::F32, &device);
let backward = LogBackward::<CpuRuntime>::new(a.id(), a.clone());
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert!((grad_a[0] - 0.5).abs() < 1e-6);
}
#[test]
fn test_sqrt_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let a = Tensor::<CpuRuntime>::from_slice(&[4.0f32], &[1], &device);
let output = client.sqrt(&a).unwrap();
let grad_out = Tensor::<CpuRuntime>::ones(&[1], DType::F32, &device);
let backward = SqrtBackward::<CpuRuntime>::new(a.id(), output);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert!((grad_a[0] - 0.25).abs() < 1e-6);
}
#[test]
fn test_tanh_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let a = Tensor::<CpuRuntime>::from_slice(&[0.0f32], &[1], &device);
let output = client.tanh(&a).unwrap();
let grad_out = Tensor::<CpuRuntime>::ones(&[1], DType::F32, &device);
let backward = TanhBackward::<CpuRuntime>::new(a.id(), output);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert!((grad_a[0] - 1.0).abs() < 1e-6);
}
#[test]
fn test_square_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[1], DType::F32, &device);
let backward = SquareBackward::<CpuRuntime>::new(a.id(), a.clone());
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert!((grad_a[0] - 6.0).abs() < 1e-6);
}
#[test]
fn test_tan_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[0.0f32], &[1], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[1], DType::F32, &device);
let backward = TanBackward::<CpuRuntime>::new(a.id(), a.clone());
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert!((grad_a[0] - 1.0).abs() < 1e-6);
}
}