Skip to main content

numr/ops/cpu/
activation.rs

1//! CPU implementation of activation operations.
2
3use crate::error::{Error, Result};
4use crate::ops::impl_generic::activation::{dropout_impl, log_softmax_impl, softplus_impl};
5use crate::ops::{
6    ActivationOps, BinaryOps, CompareOps, ConditionalOps, ScalarOps, UnaryOps,
7    activation::normalize_softmax_dim,
8};
9use crate::runtime::cpu::{
10    CpuClient, CpuRuntime,
11    helpers::{
12        ActivationOp, FusedActivationMulOp, activation_op_impl, dispatch_dtype, elu_impl,
13        ensure_contiguous, fused_activation_mul_impl, leaky_relu_impl,
14    },
15    kernels,
16};
17use crate::tensor::Tensor;
18
19/// ActivationOps implementation for CPU runtime.
20impl ActivationOps<CpuRuntime> for CpuClient {
21    fn relu(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
22        activation_op_impl(self, a, ActivationOp::Relu, "relu")
23    }
24
25    fn sigmoid(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
26        activation_op_impl(self, a, ActivationOp::Sigmoid, "sigmoid")
27    }
28
29    fn silu(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
30        activation_op_impl(self, a, ActivationOp::Silu, "silu")
31    }
32
33    fn gelu(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
34        activation_op_impl(self, a, ActivationOp::Gelu, "gelu")
35    }
36
37    fn silu_mul(
38        &self,
39        a: &Tensor<CpuRuntime>,
40        b: &Tensor<CpuRuntime>,
41    ) -> Result<Tensor<CpuRuntime>> {
42        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::SiluMul, "silu_mul")
43    }
44
45    fn gelu_mul(
46        &self,
47        a: &Tensor<CpuRuntime>,
48        b: &Tensor<CpuRuntime>,
49    ) -> Result<Tensor<CpuRuntime>> {
50        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::GeluMul, "gelu_mul")
51    }
52
53    fn relu_mul(
54        &self,
55        a: &Tensor<CpuRuntime>,
56        b: &Tensor<CpuRuntime>,
57    ) -> Result<Tensor<CpuRuntime>> {
58        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::ReluMul, "relu_mul")
59    }
60
61    fn sigmoid_mul(
62        &self,
63        a: &Tensor<CpuRuntime>,
64        b: &Tensor<CpuRuntime>,
65    ) -> Result<Tensor<CpuRuntime>> {
66        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::SigmoidMul, "sigmoid_mul")
67    }
68
69    fn silu_mul_bwd(
70        &self,
71        grad: &Tensor<CpuRuntime>,
72        a: &Tensor<CpuRuntime>,
73        b: &Tensor<CpuRuntime>,
74    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
75        // silu(a) = a * sigmoid(a)
76        let silu_a = self.silu(a)?;
77        let d_b = self.mul(grad, &silu_a)?;
78        // silu'(x) = sigmoid(x) * (1 + x - silu(x))
79        let sigmoid_a = self.sigmoid(a)?;
80        let one_plus_a = self.add_scalar(a, 1.0)?;
81        let one_plus_a_minus_silu = self.sub(&one_plus_a, &silu_a)?;
82        let silu_deriv = self.mul(&sigmoid_a, &one_plus_a_minus_silu)?;
83        let grad_times_b = self.mul(grad, b)?;
84        let d_a = self.mul(&grad_times_b, &silu_deriv)?;
85        Ok((d_a, d_b))
86    }
87
88    fn gelu_mul_bwd(
89        &self,
90        grad: &Tensor<CpuRuntime>,
91        a: &Tensor<CpuRuntime>,
92        b: &Tensor<CpuRuntime>,
93    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
94        let gelu_a = self.gelu(a)?;
95        let d_b = self.mul(grad, &gelu_a)?;
96        // gelu'(x) = 0.5*(1+tanh(inner)) + 0.5*x*sech²(inner)*inner'
97        // inner = sqrt(2/π) * (x + 0.044715*x³), inner' = sqrt(2/π)*(1 + 3*0.044715*x²)
98        let x_sq = self.mul(a, a)?;
99        let x_cu = self.mul(&x_sq, a)?;
100        let coef_x_cu = self.mul_scalar(&x_cu, 0.044715)?;
101        let inner_arg = self.add(a, &coef_x_cu)?;
102        let sqrt_2_pi: f64 = 0.7978845608028654;
103        let inner = self.mul_scalar(&inner_arg, sqrt_2_pi)?;
104        // Use tanh op directly — avoids exp overflow for low-precision dtypes (F16/FP8)
105        let tanh_inner = self.tanh(&inner)?;
106        // term1 = 0.5*(1+tanh(inner))
107        let one_plus_tanh = self.add_scalar(&tanh_inner, 1.0)?;
108        let term1 = self.mul_scalar(&one_plus_tanh, 0.5)?;
109        // sech²(inner) = 1 - tanh²(inner)
110        let tanh_sq = self.mul(&tanh_inner, &tanh_inner)?;
111        let sech_sq = self.add_scalar(&tanh_sq, -1.0)?;
112        let sech_sq = self.neg(&sech_sq)?;
113        // inner' = sqrt(2/π) * (1 + 3*0.044715*x²)
114        let three_coef_x_sq = self.mul_scalar(&x_sq, 3.0 * 0.044715)?;
115        let inner_deriv_unscaled = self.add_scalar(&three_coef_x_sq, 1.0)?;
116        let inner_deriv = self.mul_scalar(&inner_deriv_unscaled, sqrt_2_pi)?;
117        // term2 = 0.5 * x * sech²(inner) * inner'
118        let x_sech_sq = self.mul(a, &sech_sq)?;
119        let x_sech_sq_inner_d = self.mul(&x_sech_sq, &inner_deriv)?;
120        let term2 = self.mul_scalar(&x_sech_sq_inner_d, 0.5)?;
121        let gelu_deriv = self.add(&term1, &term2)?;
122        let grad_times_b = self.mul(grad, b)?;
123        let d_a = self.mul(&grad_times_b, &gelu_deriv)?;
124        Ok((d_a, d_b))
125    }
126
127    fn relu_mul_bwd(
128        &self,
129        grad: &Tensor<CpuRuntime>,
130        a: &Tensor<CpuRuntime>,
131        b: &Tensor<CpuRuntime>,
132    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
133        let relu_a = self.relu(a)?;
134        let d_b = self.mul(grad, &relu_a)?;
135        // relu'(x) = 1 if x > 0, else 0
136        let zeros = Tensor::<CpuRuntime>::zeros(a.shape(), a.dtype(), a.device());
137        let ones = Tensor::<CpuRuntime>::ones(a.shape(), a.dtype(), a.device());
138        let mask = self.gt(a, &zeros)?;
139        let relu_deriv = self.where_cond(&mask, &ones, &zeros)?;
140        let grad_times_b = self.mul(grad, b)?;
141        let d_a = self.mul(&grad_times_b, &relu_deriv)?;
142        Ok((d_a, d_b))
143    }
144
145    fn sigmoid_mul_bwd(
146        &self,
147        grad: &Tensor<CpuRuntime>,
148        a: &Tensor<CpuRuntime>,
149        b: &Tensor<CpuRuntime>,
150    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
151        let sigmoid_a = self.sigmoid(a)?;
152        let d_b = self.mul(grad, &sigmoid_a)?;
153        // sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
154        let one_minus_sig = self.add_scalar(&sigmoid_a, -1.0)?;
155        let one_minus_sig = self.neg(&one_minus_sig)?;
156        let sigmoid_deriv = self.mul(&sigmoid_a, &one_minus_sig)?;
157        let grad_times_b = self.mul(grad, b)?;
158        let d_a = self.mul(&grad_times_b, &sigmoid_deriv)?;
159        Ok((d_a, d_b))
160    }
161
162    fn leaky_relu(
163        &self,
164        a: &Tensor<CpuRuntime>,
165        negative_slope: f64,
166    ) -> Result<Tensor<CpuRuntime>> {
167        leaky_relu_impl(self, a, negative_slope)
168    }
169
170    fn elu(&self, a: &Tensor<CpuRuntime>, alpha: f64) -> Result<Tensor<CpuRuntime>> {
171        elu_impl(self, a, alpha)
172    }
173
174    fn softmax(&self, a: &Tensor<CpuRuntime>, dim: isize) -> Result<Tensor<CpuRuntime>> {
175        let dtype = a.dtype();
176        let ndim = a.ndim();
177
178        // Normalize dimension
179        let dim_idx =
180            normalize_softmax_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?;
181
182        let a_contig = ensure_contiguous(a);
183        let out = Tensor::<CpuRuntime>::empty(a.shape(), dtype, &self.device);
184
185        let shape = a.shape();
186
187        // Calculate outer_size (product of dims before softmax dim)
188        // and dim_size (size of softmax dim)
189        // and inner_size (product of dims after softmax dim)
190        let outer_size: usize = shape[..dim_idx].iter().product();
191        let dim_size = shape[dim_idx];
192        let inner_size: usize = shape[dim_idx + 1..].iter().product();
193
194        // For softmax, we need the data laid out so that the softmax dimension is contiguous
195        // If dim is the last dimension, we can use the simple kernel
196        // Otherwise, we need to iterate
197
198        if dim_idx == ndim - 1 {
199            // Simple case: softmax over last dimension
200            let a_ptr = a_contig.ptr();
201            let out_ptr = out.ptr();
202
203            dispatch_dtype!(dtype, T => {
204                unsafe {
205                    kernels::softmax_kernel::<T>(
206                        a_ptr as *const T,
207                        out_ptr as *mut T,
208                        outer_size,
209                        dim_size,
210                    );
211                }
212            }, "softmax");
213        } else {
214            // General case: softmax over non-last dimension
215            // Pre-allocate buffer outside loops to avoid repeated allocations
216            let a_ptr = a_contig.ptr();
217            let out_ptr = out.ptr();
218
219            dispatch_dtype!(dtype, T => {
220                unsafe {
221                    softmax_non_last_dim::<T>(
222                        a_ptr as *const T,
223                        out_ptr as *mut T,
224                        outer_size,
225                        dim_size,
226                        inner_size,
227                    );
228                }
229            }, "softmax");
230        }
231
232        Ok(out)
233    }
234
235    fn softmax_bwd(
236        &self,
237        grad: &Tensor<CpuRuntime>,
238        output: &Tensor<CpuRuntime>,
239        dim: isize,
240    ) -> Result<Tensor<CpuRuntime>> {
241        let dtype = grad.dtype();
242        let ndim = grad.ndim();
243        let dim_idx =
244            normalize_softmax_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?;
245
246        let grad_contig = ensure_contiguous(grad);
247        let output_contig = ensure_contiguous(output);
248        let d_input = Tensor::<CpuRuntime>::empty(grad.shape(), dtype, &self.device);
249
250        let shape = grad.shape();
251        let outer_size: usize = shape[..dim_idx].iter().product();
252        let dim_size = shape[dim_idx];
253        let inner_size: usize = shape[dim_idx + 1..].iter().product();
254
255        if dim_idx == ndim - 1 {
256            // Last dim: use fused SIMD kernel
257            let g_ptr = grad_contig.ptr();
258            let o_ptr = output_contig.ptr();
259            let d_ptr = d_input.ptr();
260
261            dispatch_dtype!(dtype, T => {
262                unsafe {
263                    kernels::softmax_bwd_kernel::<T>(
264                        g_ptr as *const T,
265                        o_ptr as *const T,
266                        d_ptr as *mut T,
267                        outer_size,
268                        dim_size,
269                    );
270                }
271            }, "softmax_bwd");
272        } else {
273            // Non-last dim: strided access pattern
274            let g_ptr = grad_contig.ptr();
275            let o_ptr = output_contig.ptr();
276            let d_ptr = d_input.ptr();
277
278            dispatch_dtype!(dtype, T => {
279                unsafe {
280                    softmax_bwd_non_last_dim::<T>(
281                        g_ptr as *const T,
282                        o_ptr as *const T,
283                        d_ptr as *mut T,
284                        outer_size,
285                        dim_size,
286                        inner_size,
287                    );
288                }
289            }, "softmax_bwd");
290        }
291
292        Ok(d_input)
293    }
294
295    fn softplus(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
296        softplus_impl(self, a)
297    }
298
299    fn log_softmax(&self, a: &Tensor<CpuRuntime>, dim: isize) -> Result<Tensor<CpuRuntime>> {
300        log_softmax_impl(self, a, dim)
301    }
302
303    fn dropout(
304        &self,
305        a: &Tensor<CpuRuntime>,
306        p: f64,
307        training: bool,
308    ) -> Result<Tensor<CpuRuntime>> {
309        dropout_impl(self, a, p, training)
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::ActivationOps;
    use crate::runtime::cpu::CpuDevice;

    /// Creates a CPU device and a client bound to it.
    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    /// Asserts that two equally sized slices agree element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < tol, "index {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_log_softmax_basic() {
        let (device, client) = setup();

        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let result = client.log_softmax(&input, -1).unwrap();
        let data: Vec<f32> = result.to_vec();

        // exp(log_softmax) is a probability distribution, so it must sum to 1.
        let exp_sum: f32 = data.iter().map(|x| x.exp()).sum();
        assert!((exp_sum - 1.0).abs() < 1e-5);

        // Each entry is the log of a probability strictly below 1, hence negative.
        for &v in &data {
            assert!(v < 0.0);
        }
    }

    #[test]
    fn test_log_softmax_2d() {
        let (device, client) = setup();

        let input =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &device);
        let result = client.log_softmax(&input, -1).unwrap();
        let data: Vec<f32> = result.to_vec();

        // Each row should independently sum (in exp space) to 1
        let row1_sum: f32 = data[0..3].iter().map(|x| x.exp()).sum();
        let row2_sum: f32 = data[3..6].iter().map(|x| x.exp()).sum();
        assert!((row1_sum - 1.0).abs() < 1e-5);
        assert!((row2_sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_softmax_last_dim() {
        let (device, client) = setup();

        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let data: Vec<f32> = client.softmax(&input, -1).unwrap().to_vec();

        let denom = 1.0f32.exp() + 2.0f32.exp() + 3.0f32.exp();
        let expected = [
            1.0f32.exp() / denom,
            2.0f32.exp() / denom,
            3.0f32.exp() / denom,
        ];
        assert_close(&data, &expected, 1e-5);
    }

    #[test]
    fn test_softmax_non_last_dim() {
        let (device, client) = setup();

        // [[1, 2], [3, 4]] along dim 0: each column is softmax([x, x + 2]).
        let input =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device);
        let data: Vec<f32> = client.softmax(&input, 0).unwrap().to_vec();

        let lo = 1.0f32 / (1.0 + 2.0f32.exp());
        let hi = 1.0 - lo;
        assert_close(&data, &[lo, lo, hi, hi], 1e-6);
    }

    #[test]
    fn test_softmax_middle_dim() {
        let (device, client) = setup();

        // Shape [2, 3, 2] with value = 0.5 * flat_index, so every (outer, inner) lane along
        // dim 1 is [x, x + 1, x + 2] and therefore has the same softmax.
        let input = Tensor::<CpuRuntime>::from_slice(
            &[0.0f32, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5],
            &[2, 3, 2],
            &device,
        );
        let data: Vec<f32> = client.softmax(&input, 1).unwrap().to_vec();

        let e = 1.0f32.exp();
        let denom = 1.0 + e + e * e;
        let lane_expected = [1.0 / denom, e / denom, e * e / denom];
        for outer in 0..2 {
            for inner in 0..2 {
                let lane: Vec<f32> = (0..3).map(|d| data[outer * 6 + d * 2 + inner]).collect();
                assert_close(&lane, &lane_expected, 1e-6);
            }
        }
    }

    #[test]
    fn test_softmax_bwd_last_dim() {
        let (device, client) = setup();

        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let output = client.softmax(&input, -1).unwrap();
        let o: Vec<f32> = output.to_vec();

        // grad = e_0 => dot = sum(grad * output) = o[0], d_i = o_i * (g_i - o[0]).
        let grad = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0], &[3], &device);
        let d: Vec<f32> = client.softmax_bwd(&grad, &output, -1).unwrap().to_vec();
        let expected = [o[0] * (1.0 - o[0]), -o[1] * o[0], -o[2] * o[0]];
        assert_close(&d, &expected, 1e-5);
    }

    #[test]
    fn test_softmax_bwd_non_last_dim() {
        let (device, client) = setup();

        let input =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device);
        let output = client.softmax(&input, 0).unwrap();
        let o: Vec<f32> = output.to_vec();

        // Only element (0, 0) receives gradient: column 1 gets zero gradient, and column 0
        // (flat indices 0 and 2) gets d = o * (g - o[0]).
        let grad =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 0.0], &[2, 2], &device);
        let d: Vec<f32> = client.softmax_bwd(&grad, &output, 0).unwrap().to_vec();
        let expected = [o[0] * (1.0 - o[0]), 0.0, -o[2] * o[0], 0.0];
        assert_close(&d, &expected, 1e-6);

        // A uniform upstream gradient is annihilated because sum(output) == 1 per lane.
        let ones = Tensor::<CpuRuntime>::ones(&[2, 2], crate::dtype::DType::F32, &device);
        let d: Vec<f32> = client.softmax_bwd(&ones, &output, 0).unwrap().to_vec();
        assert_close(&d, &[0.0; 4], 1e-6);
    }

    #[test]
    fn test_dropout_training() {
        let (device, client) = setup();

        let input = Tensor::<CpuRuntime>::ones(&[1000], crate::dtype::DType::F32, &device);
        let result = client.dropout(&input, 0.5, true).unwrap();
        let data: Vec<f32> = result.to_vec();

        // Some values should be 0 (dropped), others should be 2.0 (scaled by 1/(1-0.5))
        let zeros = data.iter().filter(|&&v| v == 0.0).count();
        let scaled = data.iter().filter(|&&v| (v - 2.0).abs() < 1e-5).count();

        // With p=0.5, roughly half should be dropped (allow wide margin for randomness)
        assert!(zeros > 200, "too few zeros: {zeros}");
        assert!(zeros < 800, "too many zeros: {zeros}");
        assert_eq!(zeros + scaled, 1000);
    }

    #[test]
    fn test_dropout_inference() {
        let (device, client) = setup();

        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let result = client.dropout(&input, 0.5, false).unwrap();
        let data: Vec<f32> = result.to_vec();

        // During inference, dropout is identity
        assert_close(&data, &[1.0, 2.0, 3.0], 1e-6);
    }

    #[test]
    fn test_dropout_p_zero() {
        let (device, client) = setup();

        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let result = client.dropout(&input, 0.0, true).unwrap();
        let data: Vec<f32> = result.to_vec();

        // p=0 means no dropout
        assert_close(&data, &[1.0, 2.0, 3.0], 1e-6);
    }

    #[test]
    fn test_dropout_p_one() {
        let (device, client) = setup();

        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let result = client.dropout(&input, 1.0, true).unwrap();
        let data: Vec<f32> = result.to_vec();

        // p=1 means all dropped
        assert_close(&data, &[0.0; 3], 1e-6);
    }
}
420
421/// Softmax backward for non-last dimension (strided access pattern).
422///
423/// d_input = output * (grad - dot), where dot = sum(grad * output) along dim.
424unsafe fn softmax_bwd_non_last_dim<T: crate::dtype::Element>(
425    grad: *const T,
426    output: *const T,
427    d_input: *mut T,
428    outer_size: usize,
429    dim_size: usize,
430    inner_size: usize,
431) {
432    unsafe {
433        for outer in 0..outer_size {
434            for inner in 0..inner_size {
435                let base_idx = outer * dim_size * inner_size + inner;
436                let stride = inner_size;
437
438                // Pass 1: dot = sum(grad * output) along dim
439                let mut dot = 0.0f64;
440                for d in 0..dim_size {
441                    let idx = base_idx + d * stride;
442                    dot += (*grad.add(idx)).to_f64() * (*output.add(idx)).to_f64();
443                }
444
445                // Pass 2: d_input = output * (grad - dot)
446                for d in 0..dim_size {
447                    let idx = base_idx + d * stride;
448                    let g = (*grad.add(idx)).to_f64();
449                    let o = (*output.add(idx)).to_f64();
450                    *d_input.add(idx) = T::from_f64(o * (g - dot));
451                }
452            }
453        }
454    }
455}
456
457unsafe fn softmax_non_last_dim<T: crate::dtype::Element>(
458    a_ptr: *const T,
459    out_ptr: *mut T,
460    outer_size: usize,
461    dim_size: usize,
462    inner_size: usize,
463) {
464    unsafe {
465        for outer in 0..outer_size {
466            for inner in 0..inner_size {
467                let base_idx = outer * dim_size * inner_size + inner;
468                let stride = inner_size;
469
470                // Pass 1: Online max + sum (reads strided input once)
471                let mut max_val = (*a_ptr.add(base_idx)).to_f64();
472                let mut sum = 1.0f64;
473                for d in 1..dim_size {
474                    let idx = base_idx + d * stride;
475                    let val = (*a_ptr.add(idx)).to_f64();
476                    if val > max_val {
477                        sum = sum * (max_val - val).exp() + 1.0;
478                        max_val = val;
479                    } else {
480                        sum += (val - max_val).exp();
481                    }
482                }
483
484                // Pass 2: exp(x - max) / sum (reads input, writes output)
485                let inv_sum = if sum > 0.0 { 1.0 / sum } else { 0.0 };
486                for d in 0..dim_size {
487                    let idx = base_idx + d * stride;
488                    let val = (*a_ptr.add(idx)).to_f64();
489                    *out_ptr.add(idx) = T::from_f64((val - max_val).exp() * inv_sum);
490                }
491            }
492        }
493    }
494}