1use crate::error::{Error, Result};
4use crate::ops::impl_generic::activation::{dropout_impl, log_softmax_impl, softplus_impl};
5use crate::ops::{
6 ActivationOps, BinaryOps, CompareOps, ConditionalOps, ScalarOps, UnaryOps,
7 activation::normalize_softmax_dim,
8};
9use crate::runtime::cpu::{
10 CpuClient, CpuRuntime,
11 helpers::{
12 ActivationOp, FusedActivationMulOp, activation_op_impl, dispatch_dtype, elu_impl,
13 ensure_contiguous, fused_activation_mul_impl, leaky_relu_impl,
14 },
15 kernels,
16};
17use crate::tensor::Tensor;
18
impl ActivationOps<CpuRuntime> for CpuClient {
    /// Element-wise ReLU: max(x, 0).
    fn relu(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        activation_op_impl(self, a, ActivationOp::Relu, "relu")
    }

    /// Element-wise logistic sigmoid: 1 / (1 + e^-x).
    fn sigmoid(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        activation_op_impl(self, a, ActivationOp::Sigmoid, "sigmoid")
    }

    /// Element-wise SiLU (swish): x * sigmoid(x).
    fn silu(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        activation_op_impl(self, a, ActivationOp::Silu, "silu")
    }

    /// Element-wise GELU (kernel choice of exact vs. tanh approximation is
    /// decided inside `activation_op_impl` — the backward pass below assumes
    /// the tanh approximation; verify the kernels agree).
    fn gelu(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        activation_op_impl(self, a, ActivationOp::Gelu, "gelu")
    }

    /// Fused silu(a) * b (e.g. for SwiGLU-style gated MLPs).
    fn silu_mul(
        &self,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<Tensor<CpuRuntime>> {
        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::SiluMul, "silu_mul")
    }

    /// Fused gelu(a) * b.
    fn gelu_mul(
        &self,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<Tensor<CpuRuntime>> {
        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::GeluMul, "gelu_mul")
    }

    /// Fused relu(a) * b.
    fn relu_mul(
        &self,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<Tensor<CpuRuntime>> {
        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::ReluMul, "relu_mul")
    }

    /// Fused sigmoid(a) * b.
    fn sigmoid_mul(
        &self,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<Tensor<CpuRuntime>> {
        fused_activation_mul_impl(self, a, b, FusedActivationMulOp::SigmoidMul, "sigmoid_mul")
    }

    /// Backward pass of `silu_mul`; returns `(d_a, d_b)` for forward
    /// `out = silu(a) * b` given upstream gradient `grad`.
    fn silu_mul_bwd(
        &self,
        grad: &Tensor<CpuRuntime>,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
        // d_b = grad * silu(a).
        let silu_a = self.silu(a)?;
        let d_b = self.mul(grad, &silu_a)?;
        // silu'(x) = sigmoid(x) + x*sigmoid(x)*(1-sigmoid(x))
        //          = sigmoid(x) * (1 + x - silu(x)).
        let sigmoid_a = self.sigmoid(a)?;
        let one_plus_a = self.add_scalar(a, 1.0)?;
        let one_plus_a_minus_silu = self.sub(&one_plus_a, &silu_a)?;
        let silu_deriv = self.mul(&sigmoid_a, &one_plus_a_minus_silu)?;
        // d_a = grad * b * silu'(a).
        let grad_times_b = self.mul(grad, b)?;
        let d_a = self.mul(&grad_times_b, &silu_deriv)?;
        Ok((d_a, d_b))
    }

    /// Backward pass of `gelu_mul`; returns `(d_a, d_b)` for forward
    /// `out = gelu(a) * b`, using the tanh approximation
    /// gelu(x) = 0.5 * x * (1 + tanh(c * (x + 0.044715 x^3))), c = sqrt(2/pi).
    fn gelu_mul_bwd(
        &self,
        grad: &Tensor<CpuRuntime>,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
        // d_b = grad * gelu(a).
        let gelu_a = self.gelu(a)?;
        let d_b = self.mul(grad, &gelu_a)?;
        // Inner argument u = c * (x + 0.044715 x^3).
        let x_sq = self.mul(a, a)?;
        let x_cu = self.mul(&x_sq, a)?;
        let coef_x_cu = self.mul_scalar(&x_cu, 0.044715)?;
        let inner_arg = self.add(a, &coef_x_cu)?;
        let sqrt_2_pi: f64 = 0.7978845608028654; // c = sqrt(2/pi)
        let inner = self.mul_scalar(&inner_arg, sqrt_2_pi)?;
        let tanh_inner = self.tanh(&inner)?;
        // term1 = 0.5 * (1 + tanh(u)).
        let one_plus_tanh = self.add_scalar(&tanh_inner, 1.0)?;
        let term1 = self.mul_scalar(&one_plus_tanh, 0.5)?;
        // sech^2(u) = 1 - tanh^2(u), computed as -(tanh^2(u) - 1).
        let tanh_sq = self.mul(&tanh_inner, &tanh_inner)?;
        let sech_sq = self.add_scalar(&tanh_sq, -1.0)?;
        let sech_sq = self.neg(&sech_sq)?;
        // du/dx = c * (1 + 3*0.044715 x^2).
        let three_coef_x_sq = self.mul_scalar(&x_sq, 3.0 * 0.044715)?;
        let inner_deriv_unscaled = self.add_scalar(&three_coef_x_sq, 1.0)?;
        let inner_deriv = self.mul_scalar(&inner_deriv_unscaled, sqrt_2_pi)?;
        // term2 = 0.5 * x * sech^2(u) * du/dx.
        let x_sech_sq = self.mul(a, &sech_sq)?;
        let x_sech_sq_inner_d = self.mul(&x_sech_sq, &inner_deriv)?;
        let term2 = self.mul_scalar(&x_sech_sq_inner_d, 0.5)?;
        // gelu'(x) = term1 + term2;  d_a = grad * b * gelu'(a).
        let gelu_deriv = self.add(&term1, &term2)?;
        let grad_times_b = self.mul(grad, b)?;
        let d_a = self.mul(&grad_times_b, &gelu_deriv)?;
        Ok((d_a, d_b))
    }

    /// Backward pass of `relu_mul`; returns `(d_a, d_b)` for forward
    /// `out = relu(a) * b`.
    fn relu_mul_bwd(
        &self,
        grad: &Tensor<CpuRuntime>,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
        // d_b = grad * relu(a).
        let relu_a = self.relu(a)?;
        let d_b = self.mul(grad, &relu_a)?;
        // relu'(x) = 1 where x > 0, else 0 (subgradient 0 at x == 0),
        // built as a select between ones and zeros on the mask a > 0.
        let zeros = Tensor::<CpuRuntime>::zeros(a.shape(), a.dtype(), a.device());
        let ones = Tensor::<CpuRuntime>::ones(a.shape(), a.dtype(), a.device());
        let mask = self.gt(a, &zeros)?;
        let relu_deriv = self.where_cond(&mask, &ones, &zeros)?;
        // d_a = grad * b * relu'(a).
        let grad_times_b = self.mul(grad, b)?;
        let d_a = self.mul(&grad_times_b, &relu_deriv)?;
        Ok((d_a, d_b))
    }

    /// Backward pass of `sigmoid_mul`; returns `(d_a, d_b)` for forward
    /// `out = sigmoid(a) * b`.
    fn sigmoid_mul_bwd(
        &self,
        grad: &Tensor<CpuRuntime>,
        a: &Tensor<CpuRuntime>,
        b: &Tensor<CpuRuntime>,
    ) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
        // d_b = grad * sigmoid(a).
        let sigmoid_a = self.sigmoid(a)?;
        let d_b = self.mul(grad, &sigmoid_a)?;
        // sigmoid'(x) = s * (1 - s); (1 - s) computed as -(s - 1).
        let one_minus_sig = self.add_scalar(&sigmoid_a, -1.0)?;
        let one_minus_sig = self.neg(&one_minus_sig)?;
        let sigmoid_deriv = self.mul(&sigmoid_a, &one_minus_sig)?;
        // d_a = grad * b * sigmoid'(a).
        let grad_times_b = self.mul(grad, b)?;
        let d_a = self.mul(&grad_times_b, &sigmoid_deriv)?;
        Ok((d_a, d_b))
    }

    /// Leaky ReLU: x for x > 0, `negative_slope * x` otherwise.
    fn leaky_relu(
        &self,
        a: &Tensor<CpuRuntime>,
        negative_slope: f64,
    ) -> Result<Tensor<CpuRuntime>> {
        leaky_relu_impl(self, a, negative_slope)
    }

    /// ELU: x for x > 0, `alpha * (e^x - 1)` otherwise.
    fn elu(&self, a: &Tensor<CpuRuntime>, alpha: f64) -> Result<Tensor<CpuRuntime>> {
        elu_impl(self, a, alpha)
    }

    /// Softmax along `dim` (negative values count from the last dimension).
    ///
    /// Uses the fast contiguous last-dim kernel when `dim` is the innermost
    /// axis; otherwise falls back to a strided walk over the softmax axis.
    fn softmax(&self, a: &Tensor<CpuRuntime>, dim: isize) -> Result<Tensor<CpuRuntime>> {
        let dtype = a.dtype();
        let ndim = a.ndim();

        // Map possibly-negative `dim` to a concrete axis index, or fail.
        let dim_idx =
            normalize_softmax_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?;

        // Kernels below assume dense row-major layout; presumably
        // `ensure_contiguous` copies when the input is strided — confirm.
        let a_contig = ensure_contiguous(a);
        // Uninitialized output buffer; every element is written by the kernel.
        let out = Tensor::<CpuRuntime>::empty(a.shape(), dtype, &self.device);

        let shape = a.shape();

        // Decompose the shape as [outer, dim, inner] around the softmax axis.
        let outer_size: usize = shape[..dim_idx].iter().product();
        let dim_size = shape[dim_idx];
        let inner_size: usize = shape[dim_idx + 1..].iter().product();

        if dim_idx == ndim - 1 {
            // Innermost axis: rows are contiguous, use the optimized kernel.
            let a_ptr = a_contig.ptr();
            let out_ptr = out.ptr();

            dispatch_dtype!(dtype, T => {
                unsafe {
                    kernels::softmax_kernel::<T>(
                        a_ptr as *const T,
                        out_ptr as *mut T,
                        outer_size,
                        dim_size,
                    );
                }
            }, "softmax");
        } else {
            // Non-innermost axis: elements along `dim` are `inner_size` apart.
            let a_ptr = a_contig.ptr();
            let out_ptr = out.ptr();

            dispatch_dtype!(dtype, T => {
                unsafe {
                    softmax_non_last_dim::<T>(
                        a_ptr as *const T,
                        out_ptr as *mut T,
                        outer_size,
                        dim_size,
                        inner_size,
                    );
                }
            }, "softmax");
        }

        Ok(out)
    }

    /// Backward pass of softmax. With `output = softmax(input, dim)`:
    /// `d_input = output * (grad - sum(grad * output, dim))`.
    fn softmax_bwd(
        &self,
        grad: &Tensor<CpuRuntime>,
        output: &Tensor<CpuRuntime>,
        dim: isize,
    ) -> Result<Tensor<CpuRuntime>> {
        let dtype = grad.dtype();
        let ndim = grad.ndim();
        let dim_idx =
            normalize_softmax_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?;

        let grad_contig = ensure_contiguous(grad);
        let output_contig = ensure_contiguous(output);
        // Uninitialized; fully overwritten by the kernel below.
        let d_input = Tensor::<CpuRuntime>::empty(grad.shape(), dtype, &self.device);

        // Same [outer, dim, inner] decomposition as the forward pass.
        let shape = grad.shape();
        let outer_size: usize = shape[..dim_idx].iter().product();
        let dim_size = shape[dim_idx];
        let inner_size: usize = shape[dim_idx + 1..].iter().product();

        if dim_idx == ndim - 1 {
            // Contiguous last-dim fast path.
            let g_ptr = grad_contig.ptr();
            let o_ptr = output_contig.ptr();
            let d_ptr = d_input.ptr();

            dispatch_dtype!(dtype, T => {
                unsafe {
                    kernels::softmax_bwd_kernel::<T>(
                        g_ptr as *const T,
                        o_ptr as *const T,
                        d_ptr as *mut T,
                        outer_size,
                        dim_size,
                    );
                }
            }, "softmax_bwd");
        } else {
            // Strided path for an interior softmax axis.
            let g_ptr = grad_contig.ptr();
            let o_ptr = output_contig.ptr();
            let d_ptr = d_input.ptr();

            dispatch_dtype!(dtype, T => {
                unsafe {
                    softmax_bwd_non_last_dim::<T>(
                        g_ptr as *const T,
                        o_ptr as *const T,
                        d_ptr as *mut T,
                        outer_size,
                        dim_size,
                        inner_size,
                    );
                }
            }, "softmax_bwd");
        }

        Ok(d_input)
    }

    /// Softplus: log(1 + e^x); delegates to the generic implementation.
    fn softplus(&self, a: &Tensor<CpuRuntime>) -> Result<Tensor<CpuRuntime>> {
        softplus_impl(self, a)
    }

    /// Log-softmax along `dim`; delegates to the generic implementation.
    fn log_softmax(&self, a: &Tensor<CpuRuntime>, dim: isize) -> Result<Tensor<CpuRuntime>> {
        log_softmax_impl(self, a, dim)
    }

    /// Dropout with drop probability `p`. When `training` is false the input
    /// passes through unchanged (see `dropout_impl` for scaling behavior).
    fn dropout(
        &self,
        a: &Tensor<CpuRuntime>,
        p: f64,
        training: bool,
    ) -> Result<Tensor<CpuRuntime>> {
        dropout_impl(self, a, p, training)
    }
}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::ActivationOps;
    use crate::runtime::cpu::CpuDevice;

    #[test]
    fn test_log_softmax_basic() {
        let dev = CpuDevice::new();
        let cl = CpuClient::new(dev.clone());

        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &dev);
        let out: Vec<f32> = cl.log_softmax(&input, -1).unwrap().to_vec();

        // Exponentiating log-probabilities must recover a distribution.
        let total: f32 = out.iter().map(|v| v.exp()).sum();
        assert!((total - 1.0).abs() < 1e-5);

        // Every log-probability is strictly negative.
        assert!(out.iter().all(|&v| v < 0.0));
    }

    #[test]
    fn test_log_softmax_2d() {
        let dev = CpuDevice::new();
        let cl = CpuClient::new(dev.clone());

        let input =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &dev);
        let out: Vec<f32> = cl.log_softmax(&input, -1).unwrap().to_vec();

        // Each row normalizes independently along the last dimension.
        for row in out.chunks(3) {
            let total: f32 = row.iter().map(|v| v.exp()).sum();
            assert!((total - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_dropout_training() {
        let dev = CpuDevice::new();
        let cl = CpuClient::new(dev.clone());

        let input = Tensor::<CpuRuntime>::ones(&[1000], crate::dtype::DType::F32, &dev);
        let out: Vec<f32> = cl.dropout(&input, 0.5, true).unwrap().to_vec();

        // With p = 0.5 every survivor is rescaled by 1/(1-p) = 2.
        let zeros = out.iter().filter(|&&v| v == 0.0).count();
        let scaled = out.iter().filter(|&&v| (v - 2.0).abs() < 1e-5).count();

        // ~500 zeros expected; bounds are deliberately loose.
        assert!(zeros > 200, "too few zeros: {zeros}");
        assert!(zeros < 800, "too many zeros: {zeros}");
        assert_eq!(zeros + scaled, 1000);
    }

    #[test]
    fn test_dropout_inference() {
        let dev = CpuDevice::new();
        let cl = CpuClient::new(dev.clone());

        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &dev);
        let out: Vec<f32> = cl.dropout(&input, 0.5, false).unwrap().to_vec();

        // Inference mode is the identity.
        for (got, want) in out.iter().zip([1.0f32, 2.0, 3.0]) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_dropout_p_zero() {
        let dev = CpuDevice::new();
        let cl = CpuClient::new(dev.clone());

        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &dev);
        let out: Vec<f32> = cl.dropout(&input, 0.0, true).unwrap().to_vec();

        // p = 0 keeps everything, unscaled, even in training mode.
        for (got, want) in out.iter().zip([1.0f32, 2.0, 3.0]) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_dropout_p_one() {
        let dev = CpuDevice::new();
        let cl = CpuClient::new(dev.clone());

        let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &dev);
        let out: Vec<f32> = cl.dropout(&input, 1.0, true).unwrap().to_vec();

        // p = 1 zeroes the whole tensor.
        assert!(out.iter().all(|&v| v.abs() < 1e-6));
    }
}
420
421unsafe fn softmax_bwd_non_last_dim<T: crate::dtype::Element>(
425 grad: *const T,
426 output: *const T,
427 d_input: *mut T,
428 outer_size: usize,
429 dim_size: usize,
430 inner_size: usize,
431) {
432 unsafe {
433 for outer in 0..outer_size {
434 for inner in 0..inner_size {
435 let base_idx = outer * dim_size * inner_size + inner;
436 let stride = inner_size;
437
438 let mut dot = 0.0f64;
440 for d in 0..dim_size {
441 let idx = base_idx + d * stride;
442 dot += (*grad.add(idx)).to_f64() * (*output.add(idx)).to_f64();
443 }
444
445 for d in 0..dim_size {
447 let idx = base_idx + d * stride;
448 let g = (*grad.add(idx)).to_f64();
449 let o = (*output.add(idx)).to_f64();
450 *d_input.add(idx) = T::from_f64(o * (g - dot));
451 }
452 }
453 }
454 }
455}
456
457unsafe fn softmax_non_last_dim<T: crate::dtype::Element>(
458 a_ptr: *const T,
459 out_ptr: *mut T,
460 outer_size: usize,
461 dim_size: usize,
462 inner_size: usize,
463) {
464 unsafe {
465 for outer in 0..outer_size {
466 for inner in 0..inner_size {
467 let base_idx = outer * dim_size * inner_size + inner;
468 let stride = inner_size;
469
470 let mut max_val = (*a_ptr.add(base_idx)).to_f64();
472 let mut sum = 1.0f64;
473 for d in 1..dim_size {
474 let idx = base_idx + d * stride;
475 let val = (*a_ptr.add(idx)).to_f64();
476 if val > max_val {
477 sum = sum * (max_val - val).exp() + 1.0;
478 max_val = val;
479 } else {
480 sum += (val - max_val).exp();
481 }
482 }
483
484 let inv_sum = if sum > 0.0 { 1.0 / sum } else { 0.0 };
486 for d in 0..dim_size {
487 let idx = base_idx + d * stride;
488 let val = (*a_ptr.add(idx)).to_f64();
489 *out_ptr.add(idx) = T::from_f64((val - max_val).exp() * inv_sum);
490 }
491 }
492 }
493 }
494}