use crate::dtype::{DType, Element};
use crate::ops::GemmActivation;
/// Floating-point accumulator abstraction used by the backward kernel in this file.
///
/// The kernel is generic over the storage element type and over the type it
/// accumulates in; this trait is the surface the accumulator type must provide:
/// value semantics, ordering, the arithmetic operators the kernel uses, a few
/// constants, conversions to and from storage elements, and the transcendental
/// functions needed by the activation derivatives.
trait AccFloat:
    Copy
    + PartialOrd
    + std::ops::Neg<Output = Self>
    + std::ops::Add<Output = Self>
    + std::ops::Sub<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::AddAssign
{
    // ---- constants -------------------------------------------------------

    /// Returns the additive identity (`0`).
    fn zero() -> Self;

    /// Returns the multiplicative identity (`1`).
    fn one() -> Self;

    /// Returns the constant `0.5`.
    fn half() -> Self;

    /// Converts an `f64` constant into the accumulator type.
    ///
    /// Narrower accumulators may round the value.
    fn from_f64_const(v: f64) -> Self;

    // ---- conversions to / from storage elements --------------------------

    /// Converts a storage element into the accumulator type.
    fn from_elem<T: Element>(v: T) -> Self;

    /// Converts the accumulator value back into a storage element.
    fn to_elem<T: Element>(self) -> T;

    // ---- math ------------------------------------------------------------

    /// Hyperbolic tangent of `self`.
    fn tanh(self) -> Self;

    /// `e` raised to the power `self`.
    fn exp(self) -> Self;

    /// Reciprocal, `1 / self`.
    fn recip(self) -> Self;

    /// Returns `true` when `self` is neither infinite nor NaN.
    fn is_finite(self) -> bool;
}
/// Single-precision accumulator: every operation forwards to the `f32` primitive.
///
/// The method-call syntax used below (`self.tanh()` and friends) resolves to the
/// inherent `f32` methods, which take precedence over this trait's same-named
/// methods, so none of these bodies recurse.
impl AccFloat for f32 {
    #[inline]
    fn zero() -> Self {
        0.0_f32
    }

    #[inline]
    fn one() -> Self {
        1.0_f32
    }

    #[inline]
    fn half() -> Self {
        0.5_f32
    }

    /// Converts a storage element through its `to_f32` conversion.
    #[inline]
    fn from_elem<T: Element>(elem: T) -> Self {
        elem.to_f32()
    }

    /// Converts back through the element type's `from_f32` constructor.
    #[inline]
    fn to_elem<T: Element>(self) -> T {
        T::from_f32(self)
    }

    #[inline]
    fn tanh(self) -> Self {
        self.tanh()
    }

    #[inline]
    fn exp(self) -> Self {
        self.exp()
    }

    /// `f32::recip` is defined as `1.0 / self`.
    #[inline]
    fn recip(self) -> Self {
        self.recip()
    }

    /// Rounds the `f64` constant to single precision.
    #[inline]
    fn from_f64_const(constant: f64) -> Self {
        constant as Self
    }

    #[inline]
    fn is_finite(self) -> bool {
        self.is_finite()
    }
}
/// Double-precision accumulator: every operation forwards to the `f64` primitive.
///
/// The method-call syntax used below (`self.tanh()` and friends) resolves to the
/// inherent `f64` methods, which take precedence over this trait's same-named
/// methods, so none of these bodies recurse.
impl AccFloat for f64 {
    #[inline]
    fn zero() -> Self {
        0.0_f64
    }

    #[inline]
    fn one() -> Self {
        1.0_f64
    }

    #[inline]
    fn half() -> Self {
        0.5_f64
    }

    /// Converts a storage element through its `to_f64` conversion.
    #[inline]
    fn from_elem<T: Element>(elem: T) -> Self {
        elem.to_f64()
    }

    /// Converts back through the element type's `from_f64` constructor.
    #[inline]
    fn to_elem<T: Element>(self) -> T {
        T::from_f64(self)
    }

    #[inline]
    fn tanh(self) -> Self {
        self.tanh()
    }

    #[inline]
    fn exp(self) -> Self {
        self.exp()
    }

    /// `f64::recip` is defined as `1.0 / self`.
    #[inline]
    fn recip(self) -> Self {
        self.recip()
    }

    /// The constant is already an `f64`, so it is returned unchanged.
    #[inline]
    fn from_f64_const(constant: f64) -> Self {
        constant
    }

    #[inline]
    fn is_finite(self) -> bool {
        self.is_finite()
    }
}
/// Backward kernel for a fused `activation(a · b + bias)` operation.
///
/// Selects the accumulator precision from the element type and forwards every
/// argument unchanged to `bwd_in`: elements whose `DTYPE` is [`DType::F64`]
/// accumulate in `f64`, every other element type accumulates in `f32`.
///
/// * `grad` – upstream gradient, `m × n`, row stride `ld_grad`.
/// * `a` – left operand, `m × k`, row stride `lda`.
/// * `b` – right operand, `k × n`, row stride `ldb`.
/// * `bias` – `n` bias values.
/// * `d_a`, `d_b`, `d_bias` – outputs, written densely as `m × k`, `k × n`
///   and `n` elements respectively.
/// * `activation` – the activation whose derivative is applied.
///
/// # Safety
///
/// The pointers are dereferenced without any checks. Based on the accesses
/// performed by `bwd_in`, the caller must guarantee that:
///
/// * `grad` is readable at `i * ld_grad + j`, `a` at `i * lda + kk`, `b` at
///   `kk * ldb + j` and `bias` at `j`, for all `i < m`, `j < n`, `kk < k`;
/// * `d_a` is writable for `m * k` elements, `d_b` for `k * n` elements and
///   `d_bias` for `n` elements.
#[inline]
// One parameter per pointer, dimension and stride; bundling them would change the signature.
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul_bias_activation_bwd_kernel<T: Element>(
    grad: *const T,
    a: *const T,
    b: *const T,
    bias: *const T,
    d_a: *mut T,
    d_b: *mut T,
    d_bias: *mut T,
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ld_grad: usize,
    activation: GemmActivation,
) {
    let needs_f64_accumulator = T::DTYPE == DType::F64;

    // Common case first: anything that is not f64 accumulates in f32.
    if !needs_f64_accumulator {
        return bwd_in::<T, f32>(
            grad, a, b, bias, d_a, d_b, d_bias, m, n, k, lda, ldb, ld_grad, activation,
        );
    }

    bwd_in::<T, f64>(
        grad, a, b, bias, d_a, d_b, d_bias, m, n, k, lda, ldb, ld_grad, activation,
    );
}
/// Backward pass of `activation(a · b + bias)`, accumulating in `A`.
///
/// The work happens in this order:
///
/// 1. Recompute the pre-activation `a · b + bias` (`m × n`) in the accumulator type.
/// 2. Replace each pre-activation by `grad * activation'(pre-activation)`; a
///    derivative that is not finite is treated as zero.
/// 3. `d_a = grad_pre · bᵀ`, stored densely (row stride `k`).
/// 4. `d_b = aᵀ · grad_pre`, stored densely (row stride `n`).
/// 5. `d_bias[j]` = sum over rows of `grad_pre[·, j]`.
///
/// Inputs are read with their own row strides (`lda`, `ldb`, `ld_grad`).
///
/// # Safety
///
/// No pointer is checked. The caller must guarantee that `grad` is readable at
/// `i * ld_grad + j`, `a` at `i * lda + kk`, `b` at `kk * ldb + j` and `bias` at
/// `j` for all `i < m`, `j < n`, `kk < k`, and that `d_a`, `d_b` and `d_bias` are
/// writable for `m * k`, `k * n` and `n` elements respectively.
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn bwd_in<T: Element, A: AccFloat>(
    grad: *const T,
    a: *const T,
    b: *const T,
    bias: *const T,
    d_a: *mut T,
    d_b: *mut T,
    d_bias: *mut T,
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ld_grad: usize,
    activation: GemmActivation,
) {
    let elem_count = m * n;

    // Phase 1a: seed every row of the pre-activation with the bias vector.
    let mut grad_pre = vec![A::zero(); elem_count];
    for row in 0..m {
        let seeded = &mut grad_pre[row * n..(row + 1) * n];
        for (col, slot) in seeded.iter_mut().enumerate() {
            *slot = A::from_elem(*bias.add(col));
        }
    }

    // Phase 1b: add `a · b` on top of the bias.
    for row in 0..m {
        let out_row = &mut grad_pre[row * n..(row + 1) * n];
        for depth in 0..k {
            let lhs: A = A::from_elem(*a.add(row * lda + depth));
            for (col, slot) in out_row.iter_mut().enumerate() {
                *slot += lhs * A::from_elem(*b.add(depth * ldb + col));
            }
        }
    }

    // Phase 2: turn the pre-activation into the gradient w.r.t. the pre-activation.
    for row in 0..m {
        let pre_row = &mut grad_pre[row * n..(row + 1) * n];
        for (col, slot) in pre_row.iter_mut().enumerate() {
            let upstream: A = A::from_elem(*grad.add(row * ld_grad + col));
            // A NaN/infinite derivative contributes no gradient.
            let slope = match activation_derivative(*slot, activation) {
                d if d.is_finite() => d,
                _ => A::zero(),
            };
            *slot = upstream * slope;
        }
    }

    // Phase 3: d_a = grad_pre · bᵀ, accumulated per output row, then stored densely.
    let mut d_a_buf = vec![A::zero(); m * k];
    for row in 0..m {
        let pre_row = &grad_pre[row * n..(row + 1) * n];
        let acc_row = &mut d_a_buf[row * k..(row + 1) * k];
        for (col, &gp) in pre_row.iter().enumerate() {
            for (depth, acc) in acc_row.iter_mut().enumerate() {
                *acc += gp * A::from_elem(*b.add(depth * ldb + col));
            }
        }
    }
    for (offset, &value) in d_a_buf.iter().enumerate() {
        *d_a.add(offset) = value.to_elem::<T>();
    }

    // Phase 4: d_b = aᵀ · grad_pre, stored densely.
    let mut d_b_buf = vec![A::zero(); k * n];
    for row in 0..m {
        let pre_row = &grad_pre[row * n..(row + 1) * n];
        for depth in 0..k {
            let lhs: A = A::from_elem(*a.add(row * lda + depth));
            let acc_row = &mut d_b_buf[depth * n..(depth + 1) * n];
            for (acc, &gp) in acc_row.iter_mut().zip(pre_row) {
                *acc += lhs * gp;
            }
        }
    }
    for (offset, &value) in d_b_buf.iter().enumerate() {
        *d_b.add(offset) = value.to_elem::<T>();
    }

    // Phase 5: d_bias = column sums of grad_pre.
    let mut d_bias_buf = vec![A::zero(); n];
    for row in 0..m {
        let pre_row = &grad_pre[row * n..(row + 1) * n];
        for (acc, &gp) in d_bias_buf.iter_mut().zip(pre_row) {
            *acc += gp;
        }
    }
    for (col, &value) in d_bias_buf.iter().enumerate() {
        *d_bias.add(col) = value.to_elem::<T>();
    }
}
/// Evaluates the derivative of `activation` at the pre-activation value `pre_act`.
///
/// * `None` – always `1`.
/// * `ReLU` – `1` for strictly positive inputs, `0` otherwise (including NaN).
/// * `GELU` – derivative of the tanh approximation
///   `0.5 * x * (1 + tanh(u))` with `u = sqrt(2/pi) * (x + 0.044715 * x^3)`.
/// * `SiLU` – `s + x * s * (1 - s)` with `s = 1 / (1 + exp(-x))`.
/// * `Sigmoid` – `s * (1 - s)`.
/// * `Tanh` – `1 - tanh(x)^2`.
///
/// The GELU branch stays finite for every finite input: once `tanh` saturates
/// the `sech^2` term is exactly zero and is dropped, so an overflowing
/// polynomial factor can no longer turn the result into `0 * inf = NaN`.
/// NaN inputs still produce NaN, and the SiLU branch still evaluates `inf * 0`
/// for infinite inputs, so callers should keep treating non-finite results
/// defensively.
fn activation_derivative<A: AccFloat>(pre_act: A, activation: GemmActivation) -> A {
    match activation {
        GemmActivation::None => A::one(),
        GemmActivation::ReLU => {
            if pre_act > A::zero() {
                A::one()
            } else {
                A::zero()
            }
        }
        GemmActivation::GELU => {
            let sqrt_2_over_pi = A::from_f64_const(0.7978845608028654);
            let coef = A::from_f64_const(0.044715);
            let x = pre_act;
            let inner = sqrt_2_over_pi * (x + coef * x * x * x);
            let tanh_val = inner.tanh();
            let sech2 = A::one() - tanh_val * tanh_val;
            // First product-rule term: 0.5 * (1 + tanh(u)).
            let gate = A::half() * (A::one() + tanh_val);
            if sech2 == A::zero() {
                // tanh has saturated, so the second term is mathematically zero.
                // Skipping it avoids `0 * inf` when `x * x` overflows; whenever
                // `d_inner` is finite the skipped term would have been 0 anyway.
                gate
            } else {
                let three = A::from_f64_const(3.0);
                // du/dx = sqrt(2/pi) * (1 + 3 * c * x^2)
                let d_inner = sqrt_2_over_pi * (A::one() + three * coef * x * x);
                gate + A::half() * x * sech2 * d_inner
            }
        }
        GemmActivation::SiLU => {
            let sig = (A::one() + (-pre_act).exp()).recip();
            sig + pre_act * sig * (A::one() - sig)
        }
        GemmActivation::Sigmoid => {
            let sig = (A::one() + (-pre_act).exp()).recip();
            sig * (A::one() - sig)
        }
        GemmActivation::Tanh => {
            let t = pre_act.tanh();
            A::one() - t * t
        }
    }
}