use crate::dtype::Element;
/// Fused SiLU-gated multiply: `out[i] = silu(a[i]) * b[i]` for every `i in 0..len`.
///
/// On x86_64 and aarch64, `F32` and `F64` inputs go to the matching
/// `fused_activation_mul::silu_mul_*` routine. `F16` and `BF16` inputs also do, but only
/// when the `f16` feature is enabled. Any other dtype or target uses the scalar fallback.
/// That fallback evaluates `x / (1 + e^(-x))` in `f64` and converts the product back to `T`.
///
/// # Safety
///
/// - `a` and `b` must be non-null, aligned, and valid for reads of `len` elements of `T`.
/// - `out` must be non-null, aligned, and valid for writes of `len` elements of `T`.
/// - `out` should not overlap `a` or `b`; this function does not promise in-place support.
#[inline]
pub unsafe fn silu_mul_kernel<T: Element>(a: *const T, b: *const T, out: *mut T, len: usize) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::super::simd::fused_activation_mul as simd;
        use crate::dtype::DType;

        // `true` once a vectorised routine has filled `out`; `false` means fall through.
        let dispatched = match T::DTYPE {
            DType::F32 => {
                simd::silu_mul_f32(a.cast::<f32>(), b.cast::<f32>(), out.cast::<f32>(), len);
                true
            }
            DType::F64 => {
                simd::silu_mul_f64(a.cast::<f64>(), b.cast::<f64>(), out.cast::<f64>(), len);
                true
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                simd::silu_mul_f16(
                    a.cast::<half::f16>(),
                    b.cast::<half::f16>(),
                    out.cast::<half::f16>(),
                    len,
                );
                true
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                simd::silu_mul_bf16(
                    a.cast::<half::bf16>(),
                    b.cast::<half::bf16>(),
                    out.cast::<half::bf16>(),
                    len,
                );
                true
            }
            _ => false,
        };
        if dispatched {
            return;
        }
    }

    let silu = |x: f64| x / (1.0 + (-x).exp());
    fused_scalar(a, b, out, len, silu);
}
/// Fused GELU-gated multiply: `out[i] = gelu(a[i]) * b[i]` for every `i in 0..len`.
///
/// On x86_64 and aarch64, `F32` and `F64` inputs go to the matching
/// `fused_activation_mul::gelu_mul_*` routine. `F16` and `BF16` inputs also do, but only
/// when the `f16` feature is enabled. Any other dtype or target uses the scalar fallback,
/// which evaluates the tanh approximation in `f64`:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// # Safety
///
/// - `a` and `b` must be non-null, aligned, and valid for reads of `len` elements of `T`.
/// - `out` must be non-null, aligned, and valid for writes of `len` elements of `T`.
/// - `out` should not overlap `a` or `b`; this function does not promise in-place support.
#[inline]
pub unsafe fn gelu_mul_kernel<T: Element>(a: *const T, b: *const T, out: *mut T, len: usize) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::super::simd::fused_activation_mul as simd;
        use crate::dtype::DType;

        // `true` once a vectorised routine has filled `out`; `false` means fall through.
        let dispatched = match T::DTYPE {
            DType::F32 => {
                simd::gelu_mul_f32(a.cast::<f32>(), b.cast::<f32>(), out.cast::<f32>(), len);
                true
            }
            DType::F64 => {
                simd::gelu_mul_f64(a.cast::<f64>(), b.cast::<f64>(), out.cast::<f64>(), len);
                true
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                simd::gelu_mul_f16(
                    a.cast::<half::f16>(),
                    b.cast::<half::f16>(),
                    out.cast::<half::f16>(),
                    len,
                );
                true
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                simd::gelu_mul_bf16(
                    a.cast::<half::bf16>(),
                    b.cast::<half::bf16>(),
                    out.cast::<half::bf16>(),
                    len,
                );
                true
            }
            _ => false,
        };
        if dispatched {
            return;
        }
    }

    // sqrt(2 / pi) and the cubic coefficient of the tanh-based GELU approximation.
    const SQRT_TWO_OVER_PI: f64 = 0.7978845608028654;
    const CUBIC_COEF: f64 = 0.044715;
    // Evaluation order is kept exactly as written so results stay bit-identical.
    let gelu = |x: f64| {
        let inner = SQRT_TWO_OVER_PI * (x + CUBIC_COEF * x * x * x);
        0.5 * x * (1.0 + inner.tanh())
    };
    fused_scalar(a, b, out, len, gelu);
}
/// Fused ReLU-gated multiply: `out[i] = relu(a[i]) * b[i]` for every `i in 0..len`.
///
/// On x86_64 and aarch64, `F32` and `F64` inputs go to the matching
/// `fused_activation_mul::relu_mul_*` routine. `F16` and `BF16` inputs also do, but only
/// when the `f16` feature is enabled. Any other dtype or target uses the scalar fallback,
/// which in `f64` yields `x` when `x > 0` and `0` otherwise.
///
/// # Safety
///
/// - `a` and `b` must be non-null, aligned, and valid for reads of `len` elements of `T`.
/// - `out` must be non-null, aligned, and valid for writes of `len` elements of `T`.
/// - `out` should not overlap `a` or `b`; this function does not promise in-place support.
#[inline]
pub unsafe fn relu_mul_kernel<T: Element>(a: *const T, b: *const T, out: *mut T, len: usize) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::super::simd::fused_activation_mul as simd;
        use crate::dtype::DType;

        // `true` once a vectorised routine has filled `out`; `false` means fall through.
        let dispatched = match T::DTYPE {
            DType::F32 => {
                simd::relu_mul_f32(a.cast::<f32>(), b.cast::<f32>(), out.cast::<f32>(), len);
                true
            }
            DType::F64 => {
                simd::relu_mul_f64(a.cast::<f64>(), b.cast::<f64>(), out.cast::<f64>(), len);
                true
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                simd::relu_mul_f16(
                    a.cast::<half::f16>(),
                    b.cast::<half::f16>(),
                    out.cast::<half::f16>(),
                    len,
                );
                true
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                simd::relu_mul_bf16(
                    a.cast::<half::bf16>(),
                    b.cast::<half::bf16>(),
                    out.cast::<half::bf16>(),
                    len,
                );
                true
            }
            _ => false,
        };
        if dispatched {
            return;
        }
    }

    // NaN fails `x > 0.0`, so the scalar fallback maps NaN inputs to 0.
    let relu = |x: f64| if x > 0.0 { x } else { 0.0 };
    fused_scalar(a, b, out, len, relu);
}
/// Fused sigmoid-gated multiply: `out[i] = sigmoid(a[i]) * b[i]` for every `i in 0..len`.
///
/// On x86_64 and aarch64, `F32` and `F64` inputs go to the matching
/// `fused_activation_mul::sigmoid_mul_*` routine. `F16` and `BF16` inputs also do, but only
/// when the `f16` feature is enabled. Any other dtype or target uses the scalar fallback,
/// which evaluates `1 / (1 + e^(-x))` in `f64`.
///
/// # Safety
///
/// - `a` and `b` must be non-null, aligned, and valid for reads of `len` elements of `T`.
/// - `out` must be non-null, aligned, and valid for writes of `len` elements of `T`.
/// - `out` should not overlap `a` or `b`; this function does not promise in-place support.
#[inline]
pub unsafe fn sigmoid_mul_kernel<T: Element>(a: *const T, b: *const T, out: *mut T, len: usize) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::super::simd::fused_activation_mul as simd;
        use crate::dtype::DType;

        // `true` once a vectorised routine has filled `out`; `false` means fall through.
        let dispatched = match T::DTYPE {
            DType::F32 => {
                simd::sigmoid_mul_f32(a.cast::<f32>(), b.cast::<f32>(), out.cast::<f32>(), len);
                true
            }
            DType::F64 => {
                simd::sigmoid_mul_f64(a.cast::<f64>(), b.cast::<f64>(), out.cast::<f64>(), len);
                true
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                simd::sigmoid_mul_f16(
                    a.cast::<half::f16>(),
                    b.cast::<half::f16>(),
                    out.cast::<half::f16>(),
                    len,
                );
                true
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                simd::sigmoid_mul_bf16(
                    a.cast::<half::bf16>(),
                    b.cast::<half::bf16>(),
                    out.cast::<half::bf16>(),
                    len,
                );
                true
            }
            _ => false,
        };
        if dispatched {
            return;
        }
    }

    // `f64::recip` is defined as `1.0 / self`, matching `1 / (1 + e^(-x))` exactly.
    let sigmoid = |x: f64| (1.0 + (-x).exp()).recip();
    fused_scalar(a, b, out, len, sigmoid);
}
/// Portable scalar fallback shared by the fused activation-multiply kernels.
///
/// For every `i in 0..len` it computes `out[i] = T::from_f64(activation(a[i]) * b[i])`.
/// Both inputs are widened with `to_f64` first, so the arithmetic runs in `f64`.
///
/// # Safety
///
/// - `a` and `b` must be valid for reads of `len` elements of `T`.
/// - `out` must be valid for writes of `len` elements of `T`.
/// - `out` may be the same pointer as `a` and/or `b` (element-wise in-place use).
///   Each output element is written only after its two inputs have been read.
/// - When `len == 0` nothing is dereferenced, so null or dangling pointers are accepted.
#[inline]
unsafe fn fused_scalar<T: Element, F: Fn(f64) -> f64>(
    a: *const T,
    b: *const T,
    out: *mut T,
    len: usize,
    activation: F,
) {
    // Raw element access instead of `slice::from_raw_parts{,_mut}` has two benefits:
    // - It never materialises a `&mut [T]` that aliases `&[T]` views of the inputs, which
    //   would be undefined behaviour for in-place calls.
    // - It drops the slice requirement of a non-null, aligned pointer even when `len == 0`.
    for i in 0..len {
        let x = (*a.add(i)).to_f64();
        let y = (*b.add(i)).to_f64();
        *out.add(i) = T::from_f64(activation(x) * y);
    }
}