use crate::autograd::GradFn;
use crate::autograd::var::Var;
use crate::autograd::var_ops::{var_matmul, var_mul, var_sum};
use crate::error::Result;
use crate::ops::{BinaryOps, GemmActivation, MatmulOps, ReduceOps, ScalarOps, TensorOps, UnaryOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
use std::sync::Arc;
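/// Backward node for the fused `matmul(a, b) + bias` + activation op.
/// The forward inputs `a`, `b`, and `bias` are saved so the pre-activation
/// can be recomputed when computing gradients.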
pub struct MatmulBiasActivationBackward<R: Runtime> {
input_ids: [TensorId; 3],
saved_tensors: Vec<Tensor<R>>,
activation: GemmActivation,
input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 3],
}
impl<R: Runtime> MatmulBiasActivationBackward<R> {
pub fn new(
a_id: TensorId,
b_id: TensorId,
bias_id: TensorId,
a: Tensor<R>,
b: Tensor<R>,
bias: Tensor<R>,
activation: GemmActivation,
a_grad_fn: Option<Arc<dyn GradFn<R>>>,
b_grad_fn: Option<Arc<dyn GradFn<R>>>,
bias_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_ids: [a_id, b_id, bias_id],
saved_tensors: vec![a, b, bias],
activation,
input_grad_fns: [a_grad_fn, b_grad_fn, bias_grad_fn],
}
}
}
impl<R: Runtime> GradFn<R> for MatmulBiasActivationBackward<R>
where
R::Client:
TensorOps<R> + ScalarOps<R> + BinaryOps<R> + ReduceOps<R> + UnaryOps<R> + MatmulOps<R>,
{
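// Gradients of out = act(a @ b + bias), with grad_pre = grad_output * act'(a @ b + bias):
//   d_a    = grad_pre @ b^T
//   d_b    = a^T @ grad_pre
//   d_bias = grad_pre summed over the leading (broadcast) dims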
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let a = &self.saved_tensors[0];
let b = &self.saved_tensors[1];
let bias = &self.saved_tensors[2];
let matmul_out = client.matmul(a, b)?;
let pre_act = client.add(&matmul_out, bias)?;
let grad_pre = apply_activation_grad(&client, grad_output, &pre_act, self.activation)?;
let b_t = b.transpose(-2, -1)?;
let d_a = client.matmul(&grad_pre, &b_t)?;
let a_t = a.transpose(-2, -1)?;
let d_b = client.matmul(&a_t, &grad_pre)?;
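// The bias is broadcast across all leading dims of the output, so its gradient
// is grad_pre reduced over those dims, leaving only the last dim.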
let ndim = grad_output.ndim();
let batch_dims: Vec<usize> = (0..ndim - 1).collect();
let d_bias = if batch_dims.is_empty() {
grad_pre
} else {
client.sum(&grad_pre, &batch_dims, false)?
};
Ok(vec![Some(d_a), Some(d_b), Some(d_bias)])
}
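// Same gradients as `backward`, but built from `Var` ops so the results stay
// in the autograd graph (e.g. for higher-order derivatives).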
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ BinaryOps<R>
+ ReduceOps<R>
+ UnaryOps<R>
+ MatmulOps<R>,
{
let client = R::default_client(grad_output.tensor().device());
let a = &self.saved_tensors[0];
let b = &self.saved_tensors[1];
let bias = &self.saved_tensors[2];
let matmul_out = client.matmul(a, b)?;
let pre_act = client.add(&matmul_out, bias)?;
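// `apply_activation_grad` computes grad * act'(pre_act); feeding it a tensor of
// ones (pre_act * 0 + 1) yields the bare activation derivative, which is then
// multiplied with grad_output through a differentiable Var op.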
let ones = client.add_scalar(&client.mul_scalar(&pre_act, 0.0)?, 1.0)?;
let act_grad = apply_activation_grad(&client, &ones, &pre_act, self.activation)?;
let act_grad_var = Var::new(act_grad, false);
let grad_pre = var_mul(grad_output, &act_grad_var, &client)?;
let b_t = b.transpose(-2, -1)?;
let b_t_var = Var::new(b_t, false);
let d_a = var_matmul(&grad_pre, &b_t_var, &client)?;
let a_t = a.transpose(-2, -1)?;
let a_t_var = Var::new(a_t, false);
let d_b = var_matmul(&a_t_var, &grad_pre, &client)?;
let ndim = grad_output.tensor().ndim();
let batch_dims: Vec<usize> = (0..ndim - 1).collect();
let d_bias = if batch_dims.is_empty() {
grad_pre
} else {
var_sum(&grad_pre, &batch_dims, false, &client)?
};
Ok(vec![Some(d_a), Some(d_b), Some(d_bias)])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn saved_tensors(&self) -> &[Tensor<R>] {
&self.saved_tensors
}
fn name(&self) -> &'static str {
"MatmulBiasActivationBackward"
}
}
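/// Returns `grad * f'(pre_act)`, where `f` is the activation selected by
/// `activation` and `pre_act` is the pre-activation value `a @ b + bias`.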
fn apply_activation_grad<R: Runtime>(
client: &R::Client,
grad: &Tensor<R>,
pre_act: &Tensor<R>,
activation: GemmActivation,
) -> Result<Tensor<R>>
where
R::Client: TensorOps<R> + ScalarOps<R> + BinaryOps<R> + UnaryOps<R>,
{
match activation {
GemmActivation::None => {
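// Identity: the derivative is 1, so the upstream gradient passes through unchanged.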
Ok(grad.clone())
}
GemmActivation::ReLU => {
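// Step mask built from elementwise ops: sign(x) ~= x / (|x| + eps),
// mask = (sign + 1) / 2, which is ~1 for x > 0, ~0 for x < 0, and 0.5 at x == 0.
// The 1e-30 epsilon avoids division by zero.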
let abs_x = client.abs(pre_act)?;
let abs_plus_eps = client.add_scalar(&abs_x, 1e-30)?;
let sign = client.div(pre_act, &abs_plus_eps)?;
let mask = client.mul_scalar(&client.add_scalar(&sign, 1.0)?, 0.5)?;
client.mul(grad, &mask)
}
GemmActivation::Sigmoid => {
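// sigmoid(x) = 1 / (1 + exp(-x)); sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)).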
let neg_x = client.neg(pre_act)?;
let exp_neg = client.exp(&neg_x)?;
let one_plus = client.add_scalar(&exp_neg, 1.0)?;
let sig = client.recip(&one_plus)?;
let one_minus_sig = client.rsub_scalar(&sig, 1.0)?;
let deriv = client.mul(&sig, &one_minus_sig)?;
client.mul(grad, &deriv)
}
GemmActivation::Tanh => {
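// tanh'(x) = 1 - tanh(x)^2.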
let t = client.tanh(pre_act)?;
let t_sq = client.mul(&t, &t)?;
let deriv = client.rsub_scalar(&t_sq, 1.0)?;
client.mul(grad, &deriv)
}
GemmActivation::SiLU => {
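// silu(x) = x * sigmoid(x); silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).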
let neg_x = client.neg(pre_act)?;
let exp_neg = client.exp(&neg_x)?;
let one_plus = client.add_scalar(&exp_neg, 1.0)?;
let sig = client.recip(&one_plus)?;
let one_minus_sig = client.rsub_scalar(&sig, 1.0)?;
let x_one_minus_sig = client.mul(pre_act, &one_minus_sig)?;
let inner = client.add_scalar(&x_one_minus_sig, 1.0)?;
let deriv = client.mul(&sig, &inner)?;
client.mul(grad, &deriv)
}
GemmActivation::GELU => {
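// Tanh approximation of GELU: gelu(x) ~= 0.5 * x * (1 + tanh(k)),
// with k = sqrt(2/pi) * (x + 0.044715 * x^3).
// gelu'(x) = 0.5 * (1 + tanh(k)) + 0.5 * x * sech^2(k) * dk/dx,
// where sech^2(k) = 1 - tanh^2(k) and dk/dx = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2).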
let sqrt_2_pi: f64 = (2.0f64 / std::f64::consts::PI).sqrt();
let x_sq = client.mul(pre_act, pre_act)?;
let x_cubed = client.mul(pre_act, &x_sq)?;
let inner = client.add(pre_act, &client.mul_scalar(&x_cubed, 0.044715)?)?;
let k = client.mul_scalar(&inner, sqrt_2_pi)?;
let tanh_k = client.tanh(&k)?;
let term1 = client.mul_scalar(&client.add_scalar(&tanh_k, 1.0)?, 0.5)?;
let tanh_sq = client.mul(&tanh_k, &tanh_k)?;
let sech_sq = client.rsub_scalar(&tanh_sq, 1.0)?;
let dk_dx = client.mul_scalar(
&client.add_scalar(&client.mul_scalar(&x_sq, 3.0 * 0.044715)?, 1.0)?,
sqrt_2_pi,
)?;
let term2 =
client.mul_scalar(&client.mul(pre_act, &client.mul(&sech_sq, &dk_dx)?)?, 0.5)?;
let deriv = client.add(&term1, &term2)?;
client.mul(grad, &deriv)
}
}
}