use crate::error::{Error, Result};
use crate::ops::impl_generic::activation::{dropout_impl, log_softmax_impl, softplus_impl};
use crate::ops::{
ActivationOps, BinaryOps, CompareOps, ConditionalOps, ScalarOps, UnaryOps,
activation::normalize_softmax_dim,
};
use crate::runtime::cpu::{
CpuClient, CpuRuntime,
helpers::{
ActivationOp, FusedActivationMulOp, activation_op_impl, dispatch_dtype, elu_impl,
ensure_contiguous, fused_activation_mul_impl, leaky_relu_impl,
},
kernels,
};
use crate::tensor::Tensor;
impl ActivationOps<CpuRuntime> for CpuClient {
    /// Element-wise ReLU, dispatched to the shared CPU activation kernel.
    fn relu(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        activation_op_impl(self, a, ActivationOp::Relu, "relu")
    }
    /// Element-wise logistic sigmoid.
    fn sigmoid(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        activation_op_impl(self, a, ActivationOp::Sigmoid, "sigmoid")
    }
    /// Element-wise SiLU (x * sigmoid(x)).
    fn silu(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        activation_op_impl(self, a, ActivationOp::Silu, "silu")
    }
    /// Element-wise GELU.
    fn gelu(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        activation_op_impl(self, a, ActivationOp::Gelu, "gelu")
    }
    /// Fused silu(a) * b.
    fn silu_mul(
        &self,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<Tensor<CpuRuntime>> {
        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::SiluMul, "silu_mul")
    }
    /// Fused gelu(a) * b.
    fn gelu_mul(
        &self,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<Tensor<CpuRuntime>> {
        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::GeluMul, "gelu_mul")
    }
    /// Fused relu(a) * b.
    fn relu_mul(
        &self,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<Tensor<CpuRuntime>> {
        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::ReluMul, "relu_mul")
    }
    /// Fused sigmoid(a) * b.
    fn sigmoid_mul(
        &self,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<Tensor<CpuRuntime>> {
        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::SigmoidMul, "sigmoid_mul")
    }
    /// Backward pass for `silu_mul`: returns (d_a, d_b).
    ///
    /// d_b = grad * silu(a); d_a = grad * b * silu'(a), using the identity
    /// silu'(x) = sigmoid(x) * (1 + x - silu(x)).
    fn silu_mul_bwd(
        &self,
        grad: &Tensor<CpuRuntime>,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
        let silu_a = self.silu(a)?;
        let d_b = self.mul(grad, &silu_a)?;
        let sigmoid_a = self.sigmoid(a)?;
        // (1 + a) - silu(a), the bracketed factor of silu'.
        let one_plus_a = self.add_scalar(a, 1.0)?;
        let one_plus_a_minus_silu = self.sub(&one_plus_a, &silu_a)?;
        let silu_deriv = self.mul(&sigmoid_a, &one_plus_a_minus_silu)?;
        let grad_times_b = self.mul(grad, b)?;
        let d_a = self.mul(&grad_times_b, &silu_deriv)?;
        Ok((d_a, d_b))
    }
    /// Backward pass for `gelu_mul`: returns (d_a, d_b).
    ///
    /// Uses the tanh approximation of GELU:
    /// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
    /// so gelu'(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * u'.
    fn gelu_mul_bwd(
        &self,
        grad: &Tensor<CpuRuntime>,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
        let gelu_a = self.gelu(a)?;
        let d_b = self.mul(grad, &gelu_a)?;
        // u = sqrt(2/pi) * (a + 0.044715 * a^3)
        let x_sq = self.mul(a, a)?;
        let x_cu = self.mul(&x_sq, a)?;
        let coef_x_cu = self.mul_scalar(&x_cu, 0.044715)?;
        let inner_arg = self.add(a, &coef_x_cu)?;
        // sqrt(2/pi)
        let sqrt_2_pi: f64 = 0.7978845608028654;
        let inner = self.mul_scalar(&inner_arg, sqrt_2_pi)?;
        let tanh_inner = self.tanh(&inner)?;
        // term1 = 0.5 * (1 + tanh(u))
        let one_plus_tanh = self.add_scalar(&tanh_inner, 1.0)?;
        let term1 = self.mul_scalar(&one_plus_tanh, 0.5)?;
        // sech^2(u) = 1 - tanh^2(u), built as -(tanh^2 - 1).
        let tanh_sq = self.mul(&tanh_inner, &tanh_inner)?;
        let sech_sq = self.add_scalar(&tanh_sq, -1.0)?;
        let sech_sq = self.neg(&sech_sq)?;
        // u' = sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)
        let three_coef_x_sq = self.mul_scalar(&x_sq, 3.0 * 0.044715)?;
        let inner_deriv_unscaled = self.add_scalar(&three_coef_x_sq, 1.0)?;
        let inner_deriv = self.mul_scalar(&inner_deriv_unscaled, sqrt_2_pi)?;
        // term2 = 0.5 * a * sech^2(u) * u'
        let x_sech_sq = self.mul(a, &sech_sq)?;
        let x_sech_sq_inner_d = self.mul(&x_sech_sq, &inner_deriv)?;
        let term2 = self.mul_scalar(&x_sech_sq_inner_d, 0.5)?;
        let gelu_deriv = self.add(&term1, &term2)?;
        let grad_times_b = self.mul(grad, b)?;
        let d_a = self.mul(&grad_times_b, &gelu_deriv)?;
        Ok((d_a, d_b))
    }
    /// Backward pass for `relu_mul`: returns (d_a, d_b).
    ///
    /// d_b = grad * relu(a); d_a = grad * b * 1[a > 0] (0/1 mask via where_cond).
    fn relu_mul_bwd(
        &self,
        grad: &Tensor<CpuRuntime>,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
        let relu_a = self.relu(a)?;
        let d_b = self.mul(grad, &relu_a)?;
        let zeros = Tensor::<CpuRuntime>::zeros(a.shape(), a.dtype(), a.device());
        let ones = Tensor::<CpuRuntime>::ones(a.shape(), a.dtype(), a.device());
        let mask = self.gt(a, &zeros)?;
        let relu_deriv = self.where_cond(&mask, &ones, &zeros)?;
        let grad_times_b = self.mul(grad, b)?;
        let d_a = self.mul(&grad_times_b, &relu_deriv)?;
        Ok((d_a, d_b))
    }
    /// Backward pass for `sigmoid_mul`: returns (d_a, d_b).
    ///
    /// d_b = grad * sigmoid(a); d_a = grad * b * sigmoid(a) * (1 - sigmoid(a)).
    fn sigmoid_mul_bwd(
        &self,
        grad: &Tensor<CpuRuntime>,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
        let sigmoid_a = self.sigmoid(a)?;
        let d_b = self.mul(grad, &sigmoid_a)?;
        // 1 - sigmoid(a), built as -(sigmoid(a) - 1).
        let one_minus_sig = self.add_scalar(&sigmoid_a, -1.0)?;
        let one_minus_sig = self.neg(&one_minus_sig)?;
        let sigmoid_deriv = self.mul(&sigmoid_a, &one_minus_sig)?;
        let grad_times_b = self.mul(grad, b)?;
        let d_a = self.mul(&grad_times_b, &sigmoid_deriv)?;
        Ok((d_a, d_b))
    }
    /// Leaky ReLU with a configurable negative slope.
    fn leaky_relu(
        &self,
        a: &Tensor<CpuRuntime>,
        negative_slope: f64,
    ) -> Result<Tensor<CpuRuntime>> {
        leaky_relu_impl(self, a, negative_slope)
    }
    /// ELU with a configurable alpha.
    fn elu(&self, a: &Tensor<CpuRuntime>, alpha: f64) -> Result<Tensor<CpuRuntime>> {
        elu_impl(self, a, alpha)
    }
    /// Softmax along `dim` (negative dims allowed).
    ///
    /// Uses the optimized contiguous kernel when `dim` is the innermost axis;
    /// otherwise falls back to a strided per-lane implementation.
    ///
    /// # Errors
    /// Returns `Error::InvalidDimension` if `dim` is out of range for the
    /// tensor's rank.
    fn softmax(&self, a: &Tensor<CpuRuntime>, dim: isize) -> Result<Tensor<CpuRuntime>> {
        let dtype = a.dtype();
        let ndim = a.ndim();
        let dim_idx =
            normalize_softmax_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?;
        // Kernels assume a dense layout; copy if the input is strided.
        let a_contig = ensure_contiguous(a);
        let out = Tensor::<CpuRuntime>::empty(a.shape(), dtype, &self.device);
        // Factor the shape around the softmax axis: outer x dim x inner.
        let shape = a.shape();
        let outer_size: usize = shape[..dim_idx].iter().product();
        let dim_size = shape[dim_idx];
        let inner_size: usize = shape[dim_idx + 1..].iter().product();
        if dim_idx == ndim - 1 {
            // Fast path: lanes are contiguous (inner_size == 1).
            let a_ptr = a_contig.ptr();
            let out_ptr = out.ptr();
            dispatch_dtype!(dtype, T => {
                unsafe {
                    kernels::softmax_kernel::<T>(
                        a_ptr as *const T,
                        out_ptr as *mut T,
                        outer_size,
                        dim_size,
                    );
                }
            }, "softmax");
        } else {
            // Generic path: lanes are strided by inner_size.
            let a_ptr = a_contig.ptr();
            let out_ptr = out.ptr();
            dispatch_dtype!(dtype, T => {
                unsafe {
                    softmax_non_last_dim::<T>(
                        a_ptr as *const T,
                        out_ptr as *mut T,
                        outer_size,
                        dim_size,
                        inner_size,
                    );
                }
            }, "softmax");
        }
        Ok(out)
    }
    /// Backward pass for softmax: d_input = output * (grad - sum(grad * output))
    /// per lane along `dim`, where `output` is the forward softmax result.
    ///
    /// # Errors
    /// Returns `Error::InvalidDimension` if `dim` is out of range.
    fn softmax_bwd(
        &self,
        grad: &Tensor<CpuRuntime>,
        output: &Tensor<CpuRuntime>,
        dim: isize,
    ) -> Result<Tensor<CpuRuntime>> {
        let dtype = grad.dtype();
        let ndim = grad.ndim();
        let dim_idx =
            normalize_softmax_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?;
        let grad_contig = ensure_contiguous(grad);
        let output_contig = ensure_contiguous(output);
        let d_input = Tensor::<CpuRuntime>::empty(grad.shape(), dtype, &self.device);
        // Same outer x dim x inner factorization as the forward pass.
        let shape = grad.shape();
        let outer_size: usize = shape[..dim_idx].iter().product();
        let dim_size = shape[dim_idx];
        let inner_size: usize = shape[dim_idx + 1..].iter().product();
        if dim_idx == ndim - 1 {
            // Fast path: contiguous lanes.
            let g_ptr = grad_contig.ptr();
            let o_ptr = output_contig.ptr();
            let d_ptr = d_input.ptr();
            dispatch_dtype!(dtype, T => {
                unsafe {
                    kernels::softmax_bwd_kernel::<T>(
                        g_ptr as *const T,
                        o_ptr as *const T,
                        d_ptr as *mut T,
                        outer_size,
                        dim_size,
                    );
                }
            }, "softmax_bwd");
        } else {
            // Generic strided fallback.
            let g_ptr = grad_contig.ptr();
            let o_ptr = output_contig.ptr();
            let d_ptr = d_input.ptr();
            dispatch_dtype!(dtype, T => {
                unsafe {
                    softmax_bwd_non_last_dim::<T>(
                        g_ptr as *const T,
                        o_ptr as *const T,
                        d_ptr as *mut T,
                        outer_size,
                        dim_size,
                        inner_size,
                    );
                }
            }, "softmax_bwd");
        }
        Ok(d_input)
    }
    /// Element-wise softplus: ln(1 + e^x).
    fn softplus(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        softplus_impl(self, a)
    }
    /// Log-softmax along `dim`, delegated to the generic implementation.
    fn log_softmax(&self, a: &Tensor<CpuRuntime>, dim: isize) -> Result<Tensor<CpuRuntime>> {
        log_softmax_impl(self, a, dim)
    }
    /// Dropout with keep-probability 1 - p; identity when `training` is false.
    fn dropout(
        &self,
        a: &Tensor<CpuRuntime>,
        p: f64,
        training: bool,
    ) -> Result<Tensor<CpuRuntime>> {
        dropout_impl(self, a, p, training)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::ActivationOps;
    use crate::runtime::cpu::CpuDevice;

    /// Creates a fresh CPU device together with a client bound to it.
    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    #[test]
    fn test_log_softmax_basic() {
        let (device, client) = setup();
        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let data: Vec<f32> = client.log_softmax(&input, -1).unwrap().to_vec();
        // Exponentiated log-probabilities must sum to one...
        let exp_sum = data.iter().map(|x| x.exp()).sum::<f32>();
        assert!((exp_sum - 1.0).abs() < 1e-5);
        // ...and every log-probability is strictly negative.
        assert!(data.iter().all(|&v| v < 0.0));
    }

    #[test]
    fn test_log_softmax_2d() {
        let (device, client) = setup();
        let input =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &device);
        let data: Vec<f32> = client.log_softmax(&input, -1).unwrap().to_vec();
        // Each row normalizes independently along the last dimension.
        for row in data.chunks(3) {
            let total = row.iter().map(|x| x.exp()).sum::<f32>();
            assert!((total - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_dropout_training() {
        let (device, client) = setup();
        let input = Tensor::<CpuRuntime>::ones(&[1000], crate::dtype::DType::F32, &device);
        let data: Vec<f32> = client.dropout(&input, 0.5, true).unwrap().to_vec();
        let zeros = data.iter().filter(|&&v| v == 0.0).count();
        // Survivors are rescaled by 1 / (1 - p) = 2.0.
        let scaled = data.iter().filter(|&&v| (v - 2.0).abs() < 1e-5).count();
        assert!(zeros > 200, "too few zeros: {zeros}");
        assert!(zeros < 800, "too many zeros: {zeros}");
        // Every element is either dropped or rescaled -- nothing in between.
        assert_eq!(zeros + scaled, 1000);
    }

    #[test]
    fn test_dropout_inference() {
        let (device, client) = setup();
        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let data: Vec<f32> = client.dropout(&input, 0.5, false).unwrap().to_vec();
        // Inference mode is the identity regardless of p.
        for (got, want) in data.iter().zip([1.0f32, 2.0, 3.0]) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_dropout_p_zero() {
        let (device, client) = setup();
        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let data: Vec<f32> = client.dropout(&input, 0.0, true).unwrap().to_vec();
        // p = 0 keeps every element untouched even in training mode.
        for (got, want) in data.iter().zip([1.0f32, 2.0, 3.0]) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_dropout_p_one() {
        let (device, client) = setup();
        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let data: Vec<f32> = client.dropout(&input, 1.0, true).unwrap().to_vec();
        // p = 1 zeroes the whole tensor.
        assert!(data.iter().all(|&v| v.abs() < 1e-6));
    }
}
/// Softmax backward along a non-innermost dimension.
///
/// For each (outer, inner) lane, computes
/// `d_input = output * (grad - dot(grad, output))`, where the reduction runs
/// over `dim_size` elements spaced `inner_size` apart. Accumulation is done
/// in f64 regardless of `T` to preserve precision.
///
/// # Safety
/// All three pointers must be valid for `outer_size * dim_size * inner_size`
/// contiguous elements of `T`, and `d_input` must not alias `grad` or
/// `output`.
unsafe fn softmax_bwd_non_last_dim<T: crate::dtype::Element>(
    grad: *const T,
    output: *const T,
    d_input: *mut T,
    outer_size: usize,
    dim_size: usize,
    inner_size: usize,
) {
    let len = outer_size * dim_size * inner_size;
    // SAFETY: the caller guarantees each buffer holds `len` contiguous
    // elements and that `d_input` does not overlap the two inputs. Wrapping
    // the pointers into slices once keeps the rest of the body safe code.
    let (grad, output, d_input) = unsafe {
        (
            std::slice::from_raw_parts(grad, len),
            std::slice::from_raw_parts(output, len),
            std::slice::from_raw_parts_mut(d_input, len),
        )
    };
    let stride = inner_size;
    for outer in 0..outer_size {
        for inner in 0..inner_size {
            let base_idx = outer * dim_size * inner_size + inner;
            // dot = sum_d grad[d] * output[d] over this lane.
            let mut dot = 0.0f64;
            for d in 0..dim_size {
                let idx = base_idx + d * stride;
                dot += grad[idx].to_f64() * output[idx].to_f64();
            }
            for d in 0..dim_size {
                let idx = base_idx + d * stride;
                let g = grad[idx].to_f64();
                let o = output[idx].to_f64();
                d_input[idx] = T::from_f64(o * (g - dot));
            }
        }
    }
}
/// Softmax along a non-innermost dimension of a contiguous tensor.
///
/// Uses a single-pass "online softmax": the running maximum and the running
/// sum of exponentials are maintained together, rescaling the sum whenever a
/// new maximum appears, so each lane is read once for the reduction and once
/// for the write. Math is done in f64 regardless of `T`.
///
/// # Safety
/// `a_ptr` and `out_ptr` must each be valid for
/// `outer_size * dim_size * inner_size` contiguous elements of `T`, and the
/// output buffer must not alias the input.
unsafe fn softmax_non_last_dim<T: crate::dtype::Element>(
    a_ptr: *const T,
    out_ptr: *mut T,
    outer_size: usize,
    dim_size: usize,
    inner_size: usize,
) {
    let len = outer_size * dim_size * inner_size;
    // SAFETY: the caller guarantees both buffers hold `len` contiguous
    // elements and do not overlap. Converting to slices once keeps the rest
    // of the body safe, bounds-checked code.
    let (input, output) = unsafe {
        (
            std::slice::from_raw_parts(a_ptr, len),
            std::slice::from_raw_parts_mut(out_ptr, len),
        )
    };
    let stride = inner_size;
    for outer in 0..outer_size {
        for inner in 0..inner_size {
            let base_idx = outer * dim_size * inner_size + inner;
            // Online max/sum: `sum` is always expressed relative to the
            // current `max_val`; a new maximum rescales the accumulated sum.
            let mut max_val = input[base_idx].to_f64();
            let mut sum = 1.0f64;
            for d in 1..dim_size {
                let val = input[base_idx + d * stride].to_f64();
                if val > max_val {
                    sum = sum * (max_val - val).exp() + 1.0;
                    max_val = val;
                } else {
                    sum += (val - max_val).exp();
                }
            }
            // Guard against a degenerate zero sum (e.g. all lanes -inf).
            let inv_sum = if sum > 0.0 { 1.0 / sum } else { 0.0 };
            for d in 0..dim_size {
                let idx = base_idx + d * stride;
                let val = input[idx].to_f64();
                output[idx] = T::from_f64((val - max_val).exp() * inv_sum);
            }
        }
    }
}