use crate::dtype::Element;
/// Element-wise sigmoid over `len` values read from `a` and written to `out`.
///
/// On `x86_64` and `aarch64`, the `F32` and `F64` dtypes (plus `F16`/`BF16`
/// when the `f16` feature is enabled) are forwarded to the matching routine
/// in `simd::activations`. Every other dtype, and every other architecture,
/// runs the generic scalar loop, which evaluates `1 / (1 + e^(-x))` in `f64`.
///
/// # Safety
///
/// - `a` must be valid for reads of `len` elements of `T`.
/// - `out` must be valid for writes of `len` elements of `T`.
/// - The pointer casts trust `T::DTYPE` to describe `T`'s real representation
///   (for example `DType::F32` only when `T` is `f32`).
/// - The scalar fallback views both buffers as slices, so on that path the
///   pointers must be non-null and aligned even when `len == 0`, and the two
///   ranges must not overlap.
#[inline]
pub unsafe fn sigmoid_kernel<T: Element>(a: *const T, out: *mut T, len: usize) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::super::simd::activations;
        use crate::dtype::DType;

        // `true` means a dtype-specific routine has already filled `out`.
        let handled = match T::DTYPE {
            DType::F32 => {
                activations::sigmoid_f32(a.cast::<f32>(), out.cast::<f32>(), len);
                true
            }
            DType::F64 => {
                activations::sigmoid_f64(a.cast::<f64>(), out.cast::<f64>(), len);
                true
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                activations::sigmoid_f16(a.cast::<half::f16>(), out.cast::<half::f16>(), len);
                true
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                activations::sigmoid_bf16(a.cast::<half::bf16>(), out.cast::<half::bf16>(), len);
                true
            }
            _ => false,
        };
        if handled {
            return;
        }
    }

    // No specialised routine for this dtype/architecture.
    sigmoid_scalar(a, out, len);
}
/// Generic scalar sigmoid: `out[i] = 1 / (1 + e^(-a[i]))`.
///
/// Each element is widened with `to_f64`, evaluated in `f64`, and converted
/// back through `T::from_f64`.
///
/// # Safety
///
/// - `a` must be non-null, aligned, and valid for reads of `len` initialized
///   elements of `T`.
/// - `out` must be non-null, aligned, and valid for reads and writes of `len`
///   initialized elements of `T`.
/// - The two ranges must not overlap: a shared slice over `a` and a mutable
///   slice over `out` are alive at the same time.
#[inline]
unsafe fn sigmoid_scalar<T: Element>(a: *const T, out: *mut T, len: usize) {
    let src = std::slice::from_raw_parts(a, len);
    let dst = std::slice::from_raw_parts_mut(out, len);

    for (input, slot) in src.iter().zip(dst.iter_mut()) {
        let x = input.to_f64();
        *slot = T::from_f64(1.0 / (1.0 + (-x).exp()));
    }
}
/// Element-wise SiLU over `len` values read from `a` and written to `out`.
///
/// On `x86_64` and `aarch64`, the `F32` and `F64` dtypes (plus `F16`/`BF16`
/// when the `f16` feature is enabled) are forwarded to the matching routine
/// in `simd::activations`. Every other dtype, and every other architecture,
/// runs the generic scalar loop, which evaluates `x / (1 + e^(-x))` in `f64`.
///
/// # Safety
///
/// - `a` must be valid for reads of `len` elements of `T`.
/// - `out` must be valid for writes of `len` elements of `T`.
/// - The pointer casts trust `T::DTYPE` to describe `T`'s real representation
///   (for example `DType::F32` only when `T` is `f32`).
/// - The scalar fallback views both buffers as slices, so on that path the
///   pointers must be non-null and aligned even when `len == 0`, and the two
///   ranges must not overlap.
#[inline]
pub unsafe fn silu_kernel<T: Element>(a: *const T, out: *mut T, len: usize) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::super::simd::activations;
        use crate::dtype::DType;

        // `true` means a dtype-specific routine has already filled `out`.
        let handled = match T::DTYPE {
            DType::F32 => {
                activations::silu_f32(a.cast::<f32>(), out.cast::<f32>(), len);
                true
            }
            DType::F64 => {
                activations::silu_f64(a.cast::<f64>(), out.cast::<f64>(), len);
                true
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                activations::silu_f16(a.cast::<half::f16>(), out.cast::<half::f16>(), len);
                true
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                activations::silu_bf16(a.cast::<half::bf16>(), out.cast::<half::bf16>(), len);
                true
            }
            _ => false,
        };
        if handled {
            return;
        }
    }

    // No specialised routine for this dtype/architecture.
    silu_scalar(a, out, len);
}
/// Generic scalar SiLU: `out[i] = a[i] / (1 + e^(-a[i]))`.
///
/// Each element is widened with `to_f64`, evaluated in `f64`, and converted
/// back through `T::from_f64`.
///
/// # Safety
///
/// - `a` must be non-null, aligned, and valid for reads of `len` initialized
///   elements of `T`.
/// - `out` must be non-null, aligned, and valid for reads and writes of `len`
///   initialized elements of `T`.
/// - The two ranges must not overlap: a shared slice over `a` and a mutable
///   slice over `out` are alive at the same time.
#[inline]
unsafe fn silu_scalar<T: Element>(a: *const T, out: *mut T, len: usize) {
    let src = std::slice::from_raw_parts(a, len);
    let dst = std::slice::from_raw_parts_mut(out, len);

    for (input, slot) in src.iter().zip(dst.iter_mut()) {
        let value = input.to_f64();
        *slot = T::from_f64(value / (1.0 + (-value).exp()));
    }
}
/// Element-wise GELU over `len` values read from `a` and written to `out`.
///
/// On `x86_64` and `aarch64`, the `F32` and `F64` dtypes (plus `F16`/`BF16`
/// when the `f16` feature is enabled) are forwarded to the matching routine
/// in `simd::activations`. Every other dtype, and every other architecture,
/// runs the generic scalar loop, which uses the tanh-based approximation.
///
/// # Safety
///
/// - `a` must be valid for reads of `len` elements of `T`.
/// - `out` must be valid for writes of `len` elements of `T`.
/// - The pointer casts trust `T::DTYPE` to describe `T`'s real representation
///   (for example `DType::F32` only when `T` is `f32`).
/// - The scalar fallback views both buffers as slices, so on that path the
///   pointers must be non-null and aligned even when `len == 0`, and the two
///   ranges must not overlap.
#[inline]
pub unsafe fn gelu_kernel<T: Element>(a: *const T, out: *mut T, len: usize) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::super::simd::activations;
        use crate::dtype::DType;

        // `true` means a dtype-specific routine has already filled `out`.
        let handled = match T::DTYPE {
            DType::F32 => {
                activations::gelu_f32(a.cast::<f32>(), out.cast::<f32>(), len);
                true
            }
            DType::F64 => {
                activations::gelu_f64(a.cast::<f64>(), out.cast::<f64>(), len);
                true
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                activations::gelu_f16(a.cast::<half::f16>(), out.cast::<half::f16>(), len);
                true
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                activations::gelu_bf16(a.cast::<half::bf16>(), out.cast::<half::bf16>(), len);
                true
            }
            _ => false,
        };
        if handled {
            return;
        }
    }

    // No specialised routine for this dtype/architecture.
    gelu_scalar(a, out, len);
}
/// Generic scalar GELU using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// Each element is widened with `to_f64`, evaluated in `f64`, and converted
/// back through `T::from_f64`.
///
/// # Safety
///
/// - `a` must be non-null, aligned, and valid for reads of `len` initialized
///   elements of `T`.
/// - `out` must be non-null, aligned, and valid for reads and writes of `len`
///   initialized elements of `T`.
/// - The two ranges must not overlap: a shared slice over `a` and a mutable
///   slice over `out` are alive at the same time.
#[inline]
unsafe fn gelu_scalar<T: Element>(a: *const T, out: *mut T, len: usize) {
    // sqrt(2 / pi)
    const SQRT_2_OVER_PI: f64 = 0.7978845608028654;
    // Cubic coefficient of the tanh approximation.
    const TANH_COEF: f64 = 0.044715;

    let src = std::slice::from_raw_parts(a, len);
    let dst = std::slice::from_raw_parts_mut(out, len);

    for (input, slot) in src.iter().zip(dst.iter_mut()) {
        let x = input.to_f64();
        // Multiplication order is kept as-is so rounding stays identical.
        let tanh_arg = SQRT_2_OVER_PI * (x + TANH_COEF * x * x * x);
        *slot = T::from_f64(0.5 * x * (1.0 + tanh_arg.tanh()));
    }
}
/// Element-wise leaky ReLU over `len` values read from `a` and written to
/// `out`.
///
/// On `x86_64` and `aarch64`, the `F32` and `F64` dtypes (plus `F16`/`BF16`
/// when the `f16` feature is enabled) are forwarded to the matching routine
/// in `simd::activations`. `negative_slope` is narrowed to `f32` for every
/// routine except the `F64` one. Every other dtype, and every other
/// architecture, runs the generic scalar loop.
///
/// # Safety
///
/// - `a` must be valid for reads of `len` elements of `T`.
/// - `out` must be valid for writes of `len` elements of `T`.
/// - The pointer casts trust `T::DTYPE` to describe `T`'s real representation
///   (for example `DType::F32` only when `T` is `f32`).
/// - The scalar fallback views both buffers as slices, so on that path the
///   pointers must be non-null and aligned even when `len == 0`, and the two
///   ranges must not overlap.
pub unsafe fn leaky_relu_kernel<T: Element>(
    a: *const T,
    out: *mut T,
    len: usize,
    negative_slope: f64,
) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::super::simd::activations;
        use crate::dtype::DType;

        // `true` means a dtype-specific routine has already filled `out`.
        let handled = match T::DTYPE {
            DType::F32 => {
                activations::leaky_relu_f32(
                    a.cast::<f32>(),
                    out.cast::<f32>(),
                    len,
                    negative_slope as f32,
                );
                true
            }
            DType::F64 => {
                activations::leaky_relu_f64(
                    a.cast::<f64>(),
                    out.cast::<f64>(),
                    len,
                    negative_slope,
                );
                true
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                activations::leaky_relu_f16(
                    a.cast::<half::f16>(),
                    out.cast::<half::f16>(),
                    len,
                    negative_slope as f32,
                );
                true
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                activations::leaky_relu_bf16(
                    a.cast::<half::bf16>(),
                    out.cast::<half::bf16>(),
                    len,
                    negative_slope as f32,
                );
                true
            }
            _ => false,
        };
        if handled {
            return;
        }
    }

    // No specialised routine for this dtype/architecture.
    leaky_relu_scalar(a, out, len, negative_slope);
}
/// Generic scalar leaky ReLU.
///
/// Elements strictly greater than `T::zero()` are copied unchanged. Every
/// other element (including anything that does not compare greater than
/// zero) is widened with `to_f64`, multiplied by `negative_slope`, and
/// converted back through `T::from_f64`.
///
/// # Safety
///
/// - `a` must be non-null, aligned, and valid for reads of `len` initialized
///   elements of `T`.
/// - `out` must be non-null, aligned, and valid for reads and writes of `len`
///   initialized elements of `T`.
/// - The two ranges must not overlap: a shared slice over `a` and a mutable
///   slice over `out` are alive at the same time.
#[inline]
unsafe fn leaky_relu_scalar<T: Element>(a: *const T, out: *mut T, len: usize, negative_slope: f64) {
    let src = std::slice::from_raw_parts(a, len);
    let dst = std::slice::from_raw_parts_mut(out, len);
    let threshold = T::zero();

    let activate = |value: T| -> T {
        if value > threshold {
            value
        } else {
            T::from_f64(value.to_f64() * negative_slope)
        }
    };

    dst.iter_mut()
        .zip(src.iter())
        .for_each(|(slot, &value)| *slot = activate(value));
}
/// Element-wise ELU over `len` values read from `a` and written to `out`.
///
/// On `x86_64` and `aarch64`, the `F32` and `F64` dtypes (plus `F16`/`BF16`
/// when the `f16` feature is enabled) are forwarded to the matching routine
/// in `simd::activations`. `alpha` is narrowed to `f32` for every routine
/// except the `F64` one. Every other dtype, and every other architecture,
/// runs the generic scalar loop.
///
/// # Safety
///
/// - `a` must be valid for reads of `len` elements of `T`.
/// - `out` must be valid for writes of `len` elements of `T`.
/// - The pointer casts trust `T::DTYPE` to describe `T`'s real representation
///   (for example `DType::F32` only when `T` is `f32`).
/// - The scalar fallback views both buffers as slices, so on that path the
///   pointers must be non-null and aligned even when `len == 0`, and the two
///   ranges must not overlap.
pub unsafe fn elu_kernel<T: Element>(a: *const T, out: *mut T, len: usize, alpha: f64) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::super::simd::activations;
        use crate::dtype::DType;

        // `true` means a dtype-specific routine has already filled `out`.
        let handled = match T::DTYPE {
            DType::F32 => {
                activations::elu_f32(a.cast::<f32>(), out.cast::<f32>(), len, alpha as f32);
                true
            }
            DType::F64 => {
                activations::elu_f64(a.cast::<f64>(), out.cast::<f64>(), len, alpha);
                true
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                activations::elu_f16(
                    a.cast::<half::f16>(),
                    out.cast::<half::f16>(),
                    len,
                    alpha as f32,
                );
                true
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                activations::elu_bf16(
                    a.cast::<half::bf16>(),
                    out.cast::<half::bf16>(),
                    len,
                    alpha as f32,
                );
                true
            }
            _ => false,
        };
        if handled {
            return;
        }
    }

    // No specialised routine for this dtype/architecture.
    elu_scalar(a, out, len, alpha);
}
/// Generic scalar ELU.
///
/// Elements strictly greater than `T::zero()` are copied unchanged. Every
/// other element becomes `alpha * (e^x - 1)`, evaluated in `f64` and
/// converted back through `T::from_f64`.
///
/// The negative branch uses `f64::exp_m1` rather than `exp() - 1.0`. For
/// inputs close to zero, `exp(x)` rounds to exactly `1.0` and the subtraction
/// would lose the whole result; `exp_m1` keeps it accurate.
///
/// # Safety
///
/// - `a` must be non-null, aligned, and valid for reads of `len` initialized
///   elements of `T`.
/// - `out` must be non-null, aligned, and valid for reads and writes of `len`
///   initialized elements of `T`.
/// - The two ranges must not overlap: a shared slice over `a` and a mutable
///   slice over `out` are alive at the same time.
#[inline]
unsafe fn elu_scalar<T: Element>(a: *const T, out: *mut T, len: usize, alpha: f64) {
    let src = std::slice::from_raw_parts(a, len);
    let dst = std::slice::from_raw_parts_mut(out, len);
    let zero = T::zero();

    for (&x, slot) in src.iter().zip(dst.iter_mut()) {
        *slot = if x > zero {
            x
        } else {
            // exp_m1 avoids catastrophic cancellation when |x| is tiny.
            T::from_f64(alpha * x.to_f64().exp_m1())
        };
    }
}