use super::super::kernels;
use super::super::{CpuClient, CpuRuntime};
use crate::dispatch_dtype;
use crate::error::Result;
use crate::runtime::ensure_contiguous;
use crate::tensor::Tensor;
/// Element-wise activation functions that take no extra parameter.
///
/// Each variant maps 1:1 onto a kernel in [`kernels`] (see
/// `activation_op_impl`).
#[derive(Copy, Clone)]
pub enum ActivationOp {
    Relu,
    Sigmoid,
    Silu,
    Gelu,
}
/// Element-wise activations that take one scalar parameter
/// (the `param` argument of `parametric_activation_impl`):
/// the negative slope for `LeakyRelu`, alpha for `Elu`.
#[derive(Copy, Clone)]
pub enum ParametricActivationOp {
    LeakyRelu,
    Elu,
}
/// Applies the element-wise activation `op` to `a` and returns a new tensor
/// of the same shape and dtype.
///
/// Non-contiguous inputs are first materialized into a contiguous buffer via
/// `ensure_contiguous`; the result tensor is always freshly allocated on
/// `client.device`. `op_name` is forwarded to `dispatch_dtype!` for error
/// reporting on unsupported dtypes.
pub fn activation_op_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    op: ActivationOp,
    op_name: &'static str,
) -> Result<Tensor<CpuRuntime>> {
    let dtype = a.dtype();
    // Keep the contiguous copy alive for the duration of the kernel call.
    let src = ensure_contiguous(a);
    let out = Tensor::<CpuRuntime>::empty(a.shape(), dtype, &client.device);
    let n = a.numel();
    let src_ptr = src.ptr();
    let dst_ptr = out.ptr();
    dispatch_dtype!(dtype, T => {
        // SAFETY: `src` is a live contiguous tensor and `out` a freshly
        // allocated tensor of the same shape/dtype, so both pointers are
        // valid for `n` elements of `T` and do not alias.
        unsafe {
            match op {
                ActivationOp::Relu => {
                    kernels::relu_kernel::<T>(src_ptr as *const T, dst_ptr as *mut T, n)
                }
                ActivationOp::Sigmoid => {
                    kernels::sigmoid_kernel::<T>(src_ptr as *const T, dst_ptr as *mut T, n)
                }
                ActivationOp::Silu => {
                    kernels::silu_kernel::<T>(src_ptr as *const T, dst_ptr as *mut T, n)
                }
                ActivationOp::Gelu => {
                    kernels::gelu_kernel::<T>(src_ptr as *const T, dst_ptr as *mut T, n)
                }
            }
        }
    }, op_name);
    Ok(out)
}
/// Fused `activation(a) * b` operations, one variant per fused kernel
/// (see `fused_activation_mul_impl`). Fusing avoids materializing the
/// intermediate activation result.
#[derive(Copy, Clone)]
#[allow(clippy::enum_variant_names)]
pub enum FusedActivationMulOp {
    SiluMul,
    GeluMul,
    ReluMul,
    SigmoidMul,
}
/// Computes the fused `activation(a) * b` given by `op`, returning a new
/// tensor with `a`'s shape and dtype.
///
/// # Errors
/// Returns `Error::DTypeMismatch` when `b`'s dtype differs from `a`'s, and
/// `Error::ShapeMismatch` when the shapes differ (dtype is checked first).
pub fn fused_activation_mul_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    b: &Tensor<CpuRuntime>,
    op: FusedActivationMulOp,
    op_name: &'static str,
) -> Result<Tensor<CpuRuntime>> {
    let dtype = a.dtype();
    // Guard: both operands must agree on dtype...
    if b.dtype() != dtype {
        return Err(crate::error::Error::DTypeMismatch {
            lhs: dtype,
            rhs: b.dtype(),
        });
    }
    // ...and on shape (no broadcasting in the fused path).
    if a.shape() != b.shape() {
        return Err(crate::error::Error::ShapeMismatch {
            expected: a.shape().to_vec(),
            got: b.shape().to_vec(),
        });
    }
    // Contiguous copies must outlive the kernel call below.
    let lhs = ensure_contiguous(a);
    let rhs = ensure_contiguous(b);
    let out = Tensor::<CpuRuntime>::empty(a.shape(), dtype, &client.device);
    let n = a.numel();
    let lhs_ptr = lhs.ptr();
    let rhs_ptr = rhs.ptr();
    let dst_ptr = out.ptr();
    dispatch_dtype!(dtype, T => {
        // SAFETY: `lhs`, `rhs`, and `out` all hold `n` elements of `T`
        // (same shape/dtype, checked above); `out` is freshly allocated,
        // so the destination does not alias either source.
        unsafe {
            match op {
                FusedActivationMulOp::SiluMul => kernels::silu_mul_kernel::<T>(
                    lhs_ptr as *const T, rhs_ptr as *const T, dst_ptr as *mut T, n,
                ),
                FusedActivationMulOp::GeluMul => kernels::gelu_mul_kernel::<T>(
                    lhs_ptr as *const T, rhs_ptr as *const T, dst_ptr as *mut T, n,
                ),
                FusedActivationMulOp::ReluMul => kernels::relu_mul_kernel::<T>(
                    lhs_ptr as *const T, rhs_ptr as *const T, dst_ptr as *mut T, n,
                ),
                FusedActivationMulOp::SigmoidMul => kernels::sigmoid_mul_kernel::<T>(
                    lhs_ptr as *const T, rhs_ptr as *const T, dst_ptr as *mut T, n,
                ),
            }
        }
    }, op_name);
    Ok(out)
}
/// Applies the one-parameter activation `op` to `a` with scalar `param`
/// (negative slope for leaky-ReLU, alpha for ELU), returning a new tensor
/// of the same shape and dtype.
pub fn parametric_activation_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    op: ParametricActivationOp,
    param: f64,
    op_name: &'static str,
) -> Result<Tensor<CpuRuntime>> {
    let dtype = a.dtype();
    // The contiguous copy must stay alive across the kernel call.
    let src = ensure_contiguous(a);
    let out = Tensor::<CpuRuntime>::empty(a.shape(), dtype, &client.device);
    let n = a.numel();
    let src_ptr = src.ptr();
    let dst_ptr = out.ptr();
    dispatch_dtype!(dtype, T => {
        // SAFETY: `src` is a live contiguous tensor and `out` a freshly
        // allocated tensor of matching shape/dtype, so both pointers are
        // valid for `n` elements of `T` and do not alias.
        unsafe {
            match op {
                ParametricActivationOp::LeakyRelu => kernels::leaky_relu_kernel::<T>(
                    src_ptr as *const T, dst_ptr as *mut T, n, param,
                ),
                ParametricActivationOp::Elu => kernels::elu_kernel::<T>(
                    src_ptr as *const T, dst_ptr as *mut T, n, param,
                ),
            }
        }
    }, op_name);
    Ok(out)
}
/// Leaky ReLU: thin wrapper over the generic parametric-activation path,
/// passing `negative_slope` as the scalar parameter.
#[inline]
pub fn leaky_relu_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    negative_slope: f64,
) -> Result<Tensor<CpuRuntime>> {
    let op = ParametricActivationOp::LeakyRelu;
    parametric_activation_impl(client, a, op, negative_slope, "leaky_relu")
}
/// ELU: thin wrapper over the generic parametric-activation path, passing
/// `alpha` as the scalar parameter.
#[inline]
pub fn elu_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    alpha: f64,
) -> Result<Tensor<CpuRuntime>> {
    let op = ParametricActivationOp::Elu;
    parametric_activation_impl(client, a, op, alpha, "elu")
}