use crate::dtype::Element;
use crate::ops::GemmActivation;
#[inline]
#[allow(clippy::too_many_arguments)]
/// Fused GEMM + bias + activation entry point.
///
/// Dispatches to the plain bias kernel when no activation is requested,
/// to SIMD-backed f32/f64 kernels on x86_64, and otherwise to a generic
/// scalar fallback. Pointers follow BLAS-style row-major layout with
/// leading dimensions `lda`/`ldb`/`ldc`.
pub unsafe fn matmul_bias_activation_kernel<T: Element>(
    a: *const T,
    b: *const T,
    bias: *const T,
    out: *mut T,
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ldc: usize,
    activation: GemmActivation,
) {
    // Identity activation: the non-fused bias kernel already does the job.
    if matches!(activation, GemmActivation::None) {
        crate::runtime::cpu::kernels::matmul_bias_kernel(a, b, bias, out, m, n, k, lda, ldb, ldc);
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        use crate::dtype::DType;
        // The pointer casts below assume T::DTYPE faithfully identifies T's
        // concrete type (f32/f64) — guaranteed by the Element implementations.
        match T::DTYPE {
            DType::F32 => {
                matmul_bias_activation_simd_f32(
                    a as *const f32,
                    b as *const f32,
                    bias as *const f32,
                    out as *mut f32,
                    m,
                    n,
                    k,
                    lda,
                    ldb,
                    ldc,
                    activation,
                );
                return;
            }
            DType::F64 => {
                matmul_bias_activation_simd_f64(
                    a as *const f64,
                    b as *const f64,
                    bias as *const f64,
                    out as *mut f64,
                    m,
                    n,
                    k,
                    lda,
                    ldb,
                    ldc,
                    activation,
                );
                return;
            }
            // Other dtypes fall through to the scalar path below.
            _ => {}
        }
    }
    matmul_bias_activation_scalar(a, b, bias, out, m, n, k, lda, ldb, ldc, activation);
}
#[inline]
#[allow(clippy::too_many_arguments)]
/// GEMM fused with a bias add and a residual add:
/// `out[i][j] = bias[j] + residual[i][j] + sum_p a[i][p] * b[p][j]`.
///
/// NOTE(review): `residual` is indexed with the output leading dimension
/// `ldc`, i.e. it is assumed to share `out`'s layout.
pub unsafe fn matmul_bias_residual_kernel<T: Element>(
    a: *const T,
    b: *const T,
    bias: *const T,
    residual: *const T,
    out: *mut T,
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ldc: usize,
) {
    // Pass 1: seed every output element with bias + residual.
    for row in 0..m {
        let out_row = out.add(row * ldc);
        let res_row = residual.add(row * ldc);
        for col in 0..n {
            *out_row.add(col) = *bias.add(col) + *res_row.add(col);
        }
    }
    // Pass 2: accumulate A*B. Keeping k as the middle loop streams each row
    // of B sequentially, same traversal order as the original.
    for row in 0..m {
        let out_row = out.add(row * ldc);
        for p in 0..k {
            let a_rp = *a.add(row * lda + p);
            let b_row = b.add(p * ldb);
            for col in 0..n {
                *out_row.add(col) = *out_row.add(col) + a_rp * *b_row.add(col);
            }
        }
    }
}
#[cfg(target_arch = "x86_64")]
#[allow(clippy::too_many_arguments, dead_code)]
/// f32 GEMM + bias via the SIMD kernel, then the activation applied in place.
///
/// Fix: the previous version activated `m * n` contiguous elements, but the
/// GEMM writes rows at stride `ldc`. When `ldc != n` that pass touched the
/// padding between rows and missed the tail of the valid data. We now only
/// take the contiguous fast path when `ldc == n`, and otherwise activate each
/// row's `n` valid elements individually.
unsafe fn matmul_bias_activation_simd_f32(
    a: *const f32,
    b: *const f32,
    bias: *const f32,
    out: *mut f32,
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ldc: usize,
    activation: GemmActivation,
) {
    use super::super::simd::matmul;
    matmul::matmul_bias_f32(a, b, bias, out, m, n, k, lda, ldb, ldc);
    if ldc == n {
        // Rows are contiguous: one pass over the whole m*n buffer.
        apply_activation_inplace_f32(out, m * n, activation);
    } else {
        // Padded rows: never read or write the ldc - n padding elements.
        for i in 0..m {
            apply_activation_inplace_f32(out.add(i * ldc), n, activation);
        }
    }
}
#[cfg(target_arch = "x86_64")]
#[allow(clippy::too_many_arguments, dead_code)]
/// f64 GEMM + bias via the SIMD kernel, then the activation applied in place.
///
/// Fix: the previous version activated `m * n` contiguous elements, but the
/// GEMM writes rows at stride `ldc`. When `ldc != n` that pass touched the
/// padding between rows and missed the tail of the valid data. We now only
/// take the contiguous fast path when `ldc == n`, and otherwise activate each
/// row's `n` valid elements individually.
unsafe fn matmul_bias_activation_simd_f64(
    a: *const f64,
    b: *const f64,
    bias: *const f64,
    out: *mut f64,
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ldc: usize,
    activation: GemmActivation,
) {
    use super::super::simd::matmul;
    matmul::matmul_bias_f64(a, b, bias, out, m, n, k, lda, ldb, ldc);
    if ldc == n {
        // Rows are contiguous: one pass over the whole m*n buffer.
        apply_activation_inplace_f64(out, m * n, activation);
    } else {
        // Padded rows: never read or write the ldc - n padding elements.
        for i in 0..m {
            apply_activation_inplace_f64(out.add(i * ldc), n, activation);
        }
    }
}
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
/// Applies `activation` in place to `len` contiguous f32 values at `buf`.
unsafe fn apply_activation_inplace_f32(buf: *mut f32, len: usize, activation: GemmActivation) {
    use super::super::simd::activations;
    match activation {
        // Identity: nothing to do.
        GemmActivation::None => {}
        // ReLU: clamp negatives to zero. A NaN compares false against 0.0,
        // so it is deliberately left untouched — same as the SIMD-free form.
        GemmActivation::ReLU => {
            for idx in 0..len {
                let p = buf.add(idx);
                if *p < 0.0 {
                    *p = 0.0;
                }
            }
        }
        // The SIMD routines take (src, dst, len); aliasing src and dst makes
        // them operate in place.
        GemmActivation::GELU => activations::gelu_f32(buf as *const f32, buf, len),
        GemmActivation::SiLU => activations::silu_f32(buf as *const f32, buf, len),
        GemmActivation::Sigmoid => activations::sigmoid_f32(buf as *const f32, buf, len),
        // Scalar std tanh — no SIMD variant is used here.
        GemmActivation::Tanh => {
            for idx in 0..len {
                let p = buf.add(idx);
                *p = (*p).tanh();
            }
        }
    }
}
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
/// Applies `activation` in place to `len` contiguous f64 values at `buf`.
unsafe fn apply_activation_inplace_f64(buf: *mut f64, len: usize, activation: GemmActivation) {
    use super::super::simd::activations;
    match activation {
        // Identity: nothing to do.
        GemmActivation::None => {}
        // ReLU: clamp negatives to zero. A NaN compares false against 0.0,
        // so it is deliberately left untouched — same as the SIMD-free form.
        GemmActivation::ReLU => {
            for idx in 0..len {
                let p = buf.add(idx);
                if *p < 0.0 {
                    *p = 0.0;
                }
            }
        }
        // The SIMD routines take (src, dst, len); aliasing src and dst makes
        // them operate in place.
        GemmActivation::GELU => activations::gelu_f64(buf as *const f64, buf, len),
        GemmActivation::SiLU => activations::silu_f64(buf as *const f64, buf, len),
        GemmActivation::Sigmoid => activations::sigmoid_f64(buf as *const f64, buf, len),
        // Scalar std tanh — no SIMD variant is used here.
        GemmActivation::Tanh => {
            for idx in 0..len {
                let p = buf.add(idx);
                *p = (*p).tanh();
            }
        }
    }
}
#[allow(clippy::too_many_arguments, dead_code)]
/// Generic scalar fallback: GEMM + bias, then the activation applied in place.
///
/// Fix: the activation pass previously covered `m * n` contiguous elements,
/// but the output rows are written at stride `ldc`. When `ldc != n` that pass
/// hit the padding between rows and missed the tail of the valid data. The
/// activation is now applied per row over the `n` valid elements (with a
/// single contiguous pass kept for the common `ldc == n` case).
unsafe fn matmul_bias_activation_scalar<T: Element>(
    a: *const T,
    b: *const T,
    bias: *const T,
    out: *mut T,
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ldc: usize,
    activation: GemmActivation,
) {
    // Seed each output row with the bias vector.
    for i in 0..m {
        for j in 0..n {
            *out.add(i * ldc + j) = *bias.add(j);
        }
    }
    // Accumulate A*B with k as the middle loop so B streams row by row.
    for i in 0..m {
        for kk in 0..k {
            let a_val = *a.add(i * lda + kk);
            for j in 0..n {
                let out_ptr = out.add(i * ldc + j);
                *out_ptr = *out_ptr + a_val * *b.add(kk * ldb + j);
            }
        }
    }
    if ldc == n {
        // Rows are contiguous: one pass over the whole m*n buffer.
        apply_activation_scalar(out, m * n, activation);
    } else {
        // Padded rows: never read or write the ldc - n padding elements.
        for i in 0..m {
            apply_activation_scalar(out.add(i * ldc), n, activation);
        }
    }
}
#[allow(dead_code)]
/// Applies `activation` in place to `len` contiguous `T` values at `buf`.
///
/// All non-trivial activations round-trip through f64 via the Element
/// conversion hooks, so precision matches the original implementation.
unsafe fn apply_activation_scalar<T: Element>(buf: *mut T, len: usize, activation: GemmActivation) {
    match activation {
        // Identity: leave the buffer untouched.
        GemmActivation::None => {}
        // ReLU: only negative values are overwritten with zero.
        GemmActivation::ReLU => {
            for idx in 0..len {
                let p = buf.add(idx);
                if *p < T::zero() {
                    *p = T::zero();
                }
            }
        }
        // tanh-based GELU approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
        GemmActivation::GELU => {
            for idx in 0..len {
                let p = buf.add(idx);
                let x = (*p).to_f64();
                let t = (0.7978845608028654 * (x + 0.044715 * x * x * x)).tanh();
                *p = T::from_f64(0.5 * x * (1.0 + t));
            }
        }
        // SiLU (a.k.a. swish): x * sigmoid(x).
        GemmActivation::SiLU => {
            for idx in 0..len {
                let p = buf.add(idx);
                let x = (*p).to_f64();
                *p = T::from_f64(x / (1.0 + (-x).exp()));
            }
        }
        // Logistic sigmoid: 1 / (1 + e^-x).
        GemmActivation::Sigmoid => {
            for idx in 0..len {
                let p = buf.add(idx);
                let x = (*p).to_f64();
                *p = T::from_f64(1.0 / (1.0 + (-x).exp()));
            }
        }
        GemmActivation::Tanh => {
            for idx in 0..len {
                let p = buf.add(idx);
                *p = T::from_f64((*p).to_f64().tanh());
            }
        }
    }
}