Skip to main content

sklears_simd/
activation.rs

1//! SIMD-optimized activation functions for machine learning
2//!
3//! This module provides high-performance implementations of activation functions
4//! commonly used in neural networks, optimized using SIMD instructions for
5//! maximum throughput. All functions include both forward and derivative variants
6//! for efficient backpropagation.
7
8use crate::vector::sum;
9use scirs2_autograd::ndarray::{Array1, Array2};
10
11#[cfg(feature = "no-std")]
12use alloc::vec;
13
14/// SIMD-optimized sigmoid activation function
15pub fn sigmoid(input: &[f32], output: &mut [f32]) {
16    assert_eq!(
17        input.len(),
18        output.len(),
19        "Vectors must have the same length"
20    );
21
22    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
23    {
24        if crate::simd_feature_detected!("avx2") {
25            unsafe { sigmoid_avx2(input, output) };
26            return;
27        } else if crate::simd_feature_detected!("sse2") {
28            unsafe { sigmoid_sse2(input, output) };
29            return;
30        }
31    }
32
33    sigmoid_scalar(input, output);
34}
35
/// Portable sigmoid fallback: stores `1 / (1 + e^(-x))` for each `x` of `input`.
///
/// Panics if `output` is shorter than `input`.
fn sigmoid_scalar(input: &[f32], output: &mut [f32]) {
    let dst = &mut output[..input.len()];
    for (slot, &x) in dst.iter_mut().zip(input) {
        *slot = 1.0 / (1.0 + (-x).exp());
    }
}
41
42#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
43#[target_feature(enable = "sse2")]
44unsafe fn sigmoid_sse2(input: &[f32], output: &mut [f32]) {
45    use core::arch::x86_64::*;
46
47    let mut i = 0;
48    let one = _mm_set1_ps(1.0);
49
50    while i + 4 <= input.len() {
51        let x = _mm_loadu_ps(input.as_ptr().add(i));
52
53        // Approximate exp(-x) using polynomial approximation for better SIMD performance
54        let neg_x = _mm_sub_ps(_mm_setzero_ps(), x);
55        let exp_neg_x = exp_approx_sse2(neg_x);
56
57        let one_plus_exp = _mm_add_ps(one, exp_neg_x);
58        let result = _mm_div_ps(one, one_plus_exp);
59
60        _mm_storeu_ps(output.as_mut_ptr().add(i), result);
61        i += 4;
62    }
63
64    while i < input.len() {
65        output[i] = 1.0 / (1.0 + (-input[i]).exp());
66        i += 1;
67    }
68}
69
70#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
71#[target_feature(enable = "avx2")]
72unsafe fn sigmoid_avx2(input: &[f32], output: &mut [f32]) {
73    use core::arch::x86_64::*;
74
75    let mut i = 0;
76    let one = _mm256_set1_ps(1.0);
77
78    while i + 8 <= input.len() {
79        let x = _mm256_loadu_ps(input.as_ptr().add(i));
80
81        let neg_x = _mm256_sub_ps(_mm256_setzero_ps(), x);
82        let exp_neg_x = exp_approx_avx2(neg_x);
83
84        let one_plus_exp = _mm256_add_ps(one, exp_neg_x);
85        let result = _mm256_div_ps(one, one_plus_exp);
86
87        _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
88        i += 8;
89    }
90
91    while i < input.len() {
92        output[i] = 1.0 / (1.0 + (-input[i]).exp());
93        i += 1;
94    }
95}
96
// Arch-specific name for the 128-bit float vector used in the signature below;
// `core::arch::x86_64` is not available when building for 32-bit x86.
#[cfg(target_arch = "x86")]
use core::arch::x86::__m128 as ExpVec128;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::__m128 as ExpVec128;

/// Fast exponential approximation for SSE2.
///
/// Computes `exp(x)` as `2^n * exp(r)` with `n = round(x / ln 2)` and
/// `r = x - n * ln 2`, so the polynomial is only evaluated for
/// `|r| <= ln(2) / 2`. A plain low-order Taylor polynomial over the whole
/// input range is wildly inaccurate (it even turns negative for `x < -1.6`),
/// which is why the range reduction is needed.
///
/// Inputs are clamped to `[-87, 88]` so that `2^n` stays a normal `f32`;
/// NaN lanes stay NaN. Expected relative accuracy is roughly `1e-6`.
///
/// # Safety
///
/// The CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn exp_approx_sse2(x: ExpVec128) -> ExpVec128 {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Clamp so that n ends up in [-126, 127] (a valid biased exponent after +127).
    let clamped = _mm_max_ps(_mm_set1_ps(-87.0), _mm_min_ps(_mm_set1_ps(88.0), x));

    // n = round(x * log2(e)), kept both as integer and as float.
    let n_int = _mm_cvtps_epi32(_mm_mul_ps(clamped, _mm_set1_ps(core::f32::consts::LOG2_E)));
    let n = _mm_cvtepi32_ps(n_int);

    // r = x - n * ln 2, with ln 2 split into a short high part and a small
    // correction so the subtraction loses as little precision as possible.
    let r = _mm_sub_ps(clamped, _mm_mul_ps(n, _mm_set1_ps(0.693_359_375)));
    let r = _mm_sub_ps(r, _mm_mul_ps(n, _mm_set1_ps(-2.121_944_4e-4)));

    // exp(r) ≈ 1 + r + r²/2 + r³/6 + r⁴/24 + r⁵/120 + r⁶/720 (Horner form).
    let mut poly = _mm_set1_ps(1.0 / 720.0);
    poly = _mm_add_ps(_mm_mul_ps(poly, r), _mm_set1_ps(1.0 / 120.0));
    poly = _mm_add_ps(_mm_mul_ps(poly, r), _mm_set1_ps(1.0 / 24.0));
    poly = _mm_add_ps(_mm_mul_ps(poly, r), _mm_set1_ps(1.0 / 6.0));
    poly = _mm_add_ps(_mm_mul_ps(poly, r), _mm_set1_ps(0.5));
    poly = _mm_add_ps(_mm_mul_ps(poly, r), _mm_set1_ps(1.0));
    poly = _mm_add_ps(_mm_mul_ps(poly, r), _mm_set1_ps(1.0));

    // Build 2^n by writing n + 127 into the exponent field of an f32.
    let biased = _mm_add_epi32(n_int, _mm_set1_epi32(127));
    let pow2n = _mm_castsi128_ps(_mm_slli_epi32(biased, 23));

    _mm_mul_ps(poly, pow2n)
}
123
// Arch-specific name for the 256-bit float vector used in the signature below;
// `core::arch::x86_64` is not available when building for 32-bit x86.
#[cfg(target_arch = "x86")]
use core::arch::x86::__m256 as ExpVec256;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::__m256 as ExpVec256;

/// Fast exponential approximation for AVX2.
///
/// Computes `exp(x)` as `2^n * exp(r)` with `n = round(x / ln 2)` and
/// `r = x - n * ln 2`, so the polynomial is only evaluated for
/// `|r| <= ln(2) / 2`. A plain low-order Taylor polynomial over the whole
/// input range is wildly inaccurate (it even turns negative for `x < -1.6`),
/// which is why the range reduction is needed.
///
/// Inputs are clamped to `[-87, 88]` so that `2^n` stays a normal `f32`;
/// NaN lanes stay NaN. Expected relative accuracy is roughly `1e-6`.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn exp_approx_avx2(x: ExpVec256) -> ExpVec256 {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Clamp so that n ends up in [-126, 127] (a valid biased exponent after +127).
    let clamped = _mm256_max_ps(
        _mm256_set1_ps(-87.0),
        _mm256_min_ps(_mm256_set1_ps(88.0), x),
    );

    // n = round(x * log2(e)), kept both as integer and as float.
    let n_int = _mm256_cvtps_epi32(_mm256_mul_ps(
        clamped,
        _mm256_set1_ps(core::f32::consts::LOG2_E),
    ));
    let n = _mm256_cvtepi32_ps(n_int);

    // r = x - n * ln 2, with ln 2 split into a short high part and a small
    // correction so the subtraction loses as little precision as possible.
    let r = _mm256_sub_ps(clamped, _mm256_mul_ps(n, _mm256_set1_ps(0.693_359_375)));
    let r = _mm256_sub_ps(r, _mm256_mul_ps(n, _mm256_set1_ps(-2.121_944_4e-4)));

    // exp(r) ≈ 1 + r + r²/2 + r³/6 + r⁴/24 + r⁵/120 + r⁶/720 (Horner form).
    let mut poly = _mm256_set1_ps(1.0 / 720.0);
    poly = _mm256_add_ps(_mm256_mul_ps(poly, r), _mm256_set1_ps(1.0 / 120.0));
    poly = _mm256_add_ps(_mm256_mul_ps(poly, r), _mm256_set1_ps(1.0 / 24.0));
    poly = _mm256_add_ps(_mm256_mul_ps(poly, r), _mm256_set1_ps(1.0 / 6.0));
    poly = _mm256_add_ps(_mm256_mul_ps(poly, r), _mm256_set1_ps(0.5));
    poly = _mm256_add_ps(_mm256_mul_ps(poly, r), _mm256_set1_ps(1.0));
    poly = _mm256_add_ps(_mm256_mul_ps(poly, r), _mm256_set1_ps(1.0));

    // Build 2^n by writing n + 127 into the exponent field of an f32.
    let biased = _mm256_add_epi32(n_int, _mm256_set1_epi32(127));
    let pow2n = _mm256_castsi256_ps(_mm256_slli_epi32(biased, 23));

    _mm256_mul_ps(poly, pow2n)
}
148
149/// SIMD-optimized ReLU activation function
150pub fn relu(input: &[f32], output: &mut [f32]) {
151    assert_eq!(
152        input.len(),
153        output.len(),
154        "Vectors must have the same length"
155    );
156
157    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
158    {
159        if crate::simd_feature_detected!("avx2") {
160            unsafe { relu_avx2(input, output) };
161            return;
162        } else if crate::simd_feature_detected!("sse2") {
163            unsafe { relu_sse2(input, output) };
164            return;
165        }
166    }
167
168    relu_scalar(input, output);
169}
170
/// Portable ReLU fallback: stores `x.max(0.0)` for each `x` of `input`.
///
/// Panics if `output` is shorter than `input`.
fn relu_scalar(input: &[f32], output: &mut [f32]) {
    let dst = &mut output[..input.len()];
    dst.iter_mut()
        .zip(input)
        .for_each(|(slot, &x)| *slot = x.max(0.0));
}
176
/// SSE2 ReLU kernel: `max(x, 0)` on four lanes per iteration, scalar tail.
///
/// # Safety
///
/// The CPU must support SSE2 and `output` must be at least as long as `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn relu_sse2(input: &[f32], output: &mut [f32]) {
    // `core::arch::x86_64` only exists on 64-bit targets; pick the module per arch
    // so the 32-bit x86 build enabled by the cfg above also compiles.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let len = input.len();
    let zero = _mm_setzero_ps();
    let mut i = 0;

    while i + 4 <= len {
        // Unaligned load/store: `i + 4 <= len` keeps both accesses in bounds.
        let x = _mm_loadu_ps(input.as_ptr().add(i));
        _mm_storeu_ps(output.as_mut_ptr().add(i), _mm_max_ps(x, zero));
        i += 4;
    }

    for idx in i..len {
        output[idx] = input[idx].max(0.0);
    }
}
197
/// AVX2 ReLU kernel: `max(x, 0)` on eight lanes per iteration, scalar tail.
///
/// # Safety
///
/// The CPU must support AVX2 and `output` must be at least as long as `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn relu_avx2(input: &[f32], output: &mut [f32]) {
    // `core::arch::x86_64` only exists on 64-bit targets; pick the module per arch
    // so the 32-bit x86 build enabled by the cfg above also compiles.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let len = input.len();
    let zero = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= len {
        // Unaligned load/store: `i + 8 <= len` keeps both accesses in bounds.
        let x = _mm256_loadu_ps(input.as_ptr().add(i));
        _mm256_storeu_ps(output.as_mut_ptr().add(i), _mm256_max_ps(x, zero));
        i += 8;
    }

    for idx in i..len {
        output[idx] = input[idx].max(0.0);
    }
}
218
219/// SIMD-optimized Leaky ReLU activation function
220pub fn leaky_relu(input: &[f32], output: &mut [f32], alpha: f32) {
221    assert_eq!(
222        input.len(),
223        output.len(),
224        "Vectors must have the same length"
225    );
226
227    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
228    {
229        if crate::simd_feature_detected!("avx2") {
230            unsafe { leaky_relu_avx2(input, output, alpha) };
231            return;
232        } else if crate::simd_feature_detected!("sse2") {
233            unsafe { leaky_relu_sse2(input, output, alpha) };
234            return;
235        }
236    }
237
238    leaky_relu_scalar(input, output, alpha);
239}
240
/// Portable Leaky ReLU fallback: keeps positive values, scales the rest by `alpha`.
///
/// Panics if `output` is shorter than `input`.
fn leaky_relu_scalar(input: &[f32], output: &mut [f32], alpha: f32) {
    let dst = &mut output[..input.len()];
    for (slot, &x) in dst.iter_mut().zip(input) {
        *slot = if x > 0.0 { x } else { alpha * x };
    }
}
250
/// SSE2 Leaky ReLU kernel: four lanes per iteration, scalar tail.
///
/// The per-lane select is done with AND / ANDNOT / OR on the comparison mask.
/// `_mm_blendv_ps` is an SSE4.1 instruction and must not be used in a kernel
/// that is only gated on SSE2.
///
/// # Safety
///
/// The CPU must support SSE2 and `output` must be at least as long as `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn leaky_relu_sse2(input: &[f32], output: &mut [f32], alpha: f32) {
    // `core::arch::x86_64` only exists on 64-bit targets; pick the module per arch
    // so the 32-bit x86 build enabled by the cfg above also compiles.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let len = input.len();
    let zero = _mm_setzero_ps();
    let alpha_vec = _mm_set1_ps(alpha);
    let mut i = 0;

    while i + 4 <= len {
        // Unaligned load/store: `i + 4 <= len` keeps both accesses in bounds.
        let x = _mm_loadu_ps(input.as_ptr().add(i));
        // All-ones in lanes where x > 0, all-zeros elsewhere (including NaN).
        let mask = _mm_cmpgt_ps(x, zero);
        let scaled = _mm_mul_ps(x, alpha_vec);
        // (mask & x) | (!mask & scaled)
        let result = _mm_or_ps(_mm_and_ps(mask, x), _mm_andnot_ps(mask, scaled));
        _mm_storeu_ps(output.as_mut_ptr().add(i), result);
        i += 4;
    }

    for idx in i..len {
        let x = input[idx];
        output[idx] = if x > 0.0 { x } else { alpha * x };
    }
}
279
/// AVX2 Leaky ReLU kernel: eight lanes per iteration, scalar tail.
///
/// # Safety
///
/// The CPU must support AVX2 and `output` must be at least as long as `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn leaky_relu_avx2(input: &[f32], output: &mut [f32], alpha: f32) {
    // `core::arch::x86_64` only exists on 64-bit targets; pick the module per arch
    // so the 32-bit x86 build enabled by the cfg above also compiles.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let len = input.len();
    let zero = _mm256_setzero_ps();
    let alpha_vec = _mm256_set1_ps(alpha);
    let mut i = 0;

    while i + 8 <= len {
        // Unaligned load/store: `i + 8 <= len` keeps both accesses in bounds.
        let x = _mm256_loadu_ps(input.as_ptr().add(i));
        // All-ones in lanes where x > 0, all-zeros elsewhere (including NaN).
        let mask = _mm256_cmp_ps(x, zero, _CMP_GT_OQ);
        let scaled = _mm256_mul_ps(x, alpha_vec);
        // Takes `x` where the mask is set and `scaled` otherwise.
        let result = _mm256_blendv_ps(scaled, x, mask);
        _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
        i += 8;
    }

    for idx in i..len {
        let x = input[idx];
        output[idx] = if x > 0.0 { x } else { alpha * x };
    }
}
308
309/// SIMD-optimized tanh activation function
310pub fn tanh_activation(input: &[f32], output: &mut [f32]) {
311    assert_eq!(
312        input.len(),
313        output.len(),
314        "Vectors must have the same length"
315    );
316
317    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
318    {
319        if crate::simd_feature_detected!("avx2") {
320            unsafe { tanh_avx2(input, output) };
321            return;
322        } else if crate::simd_feature_detected!("sse2") {
323            unsafe { tanh_sse2(input, output) };
324            return;
325        }
326    }
327
328    tanh_scalar(input, output);
329}
330
/// Portable tanh fallback: stores `x.tanh()` for each `x` of `input`.
///
/// Panics if `output` is shorter than `input`.
fn tanh_scalar(input: &[f32], output: &mut [f32]) {
    let dst = &mut output[..input.len()];
    dst.iter_mut()
        .zip(input)
        .for_each(|(slot, &x)| *slot = x.tanh());
}
336
337#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
338#[target_feature(enable = "sse2")]
339unsafe fn tanh_sse2(input: &[f32], output: &mut [f32]) {
340    use core::arch::x86_64::*;
341
342    let mut i = 0;
343
344    while i + 4 <= input.len() {
345        let x = _mm_loadu_ps(input.as_ptr().add(i));
346        let result = tanh_approx_sse2(x);
347        _mm_storeu_ps(output.as_mut_ptr().add(i), result);
348        i += 4;
349    }
350
351    while i < input.len() {
352        output[i] = input[i].tanh();
353        i += 1;
354    }
355}
356
357#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
358#[target_feature(enable = "avx2")]
359unsafe fn tanh_avx2(input: &[f32], output: &mut [f32]) {
360    use core::arch::x86_64::*;
361
362    let mut i = 0;
363
364    while i + 8 <= input.len() {
365        let x = _mm256_loadu_ps(input.as_ptr().add(i));
366        let result = tanh_approx_avx2(x);
367        _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
368        i += 8;
369    }
370
371    while i < input.len() {
372        output[i] = input[i].tanh();
373        i += 1;
374    }
375}
376
// Arch-specific name for the 128-bit float vector used in the signature below;
// `core::arch::x86_64` is not available when building for 32-bit x86.
#[cfg(target_arch = "x86")]
use core::arch::x86::__m128 as TanhVec128;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::__m128 as TanhVec128;

/// Fast tanh approximation for SSE2.
///
/// Uses the [7/6] Padé approximant
/// `x (135135 + 17325 x² + 378 x⁴ + x⁶) / (135135 + 62370 x² + 3150 x⁴ + 28 x⁶)`
/// on the input clamped to `[-5, 5]`, then clamps the result to `[-1, 1]`.
/// The cubic Taylor term `x (1 - x²/3)` is only usable near zero: it already
/// gives `-0.67` at `x = 2` where tanh is `0.96`.
///
/// Expected absolute error is on the order of `1e-4` or better; NaN lanes stay NaN.
///
/// # Safety
///
/// The CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn tanh_approx_sse2(x: TanhVec128) -> TanhVec128 {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Beyond |x| = 5 tanh is within 1e-4 of ±1, so saturate there.
    let clamped = _mm_max_ps(_mm_set1_ps(-5.0), _mm_min_ps(_mm_set1_ps(5.0), x));
    let x2 = _mm_mul_ps(clamped, clamped);
    let c0 = _mm_set1_ps(135_135.0);

    // Numerator: x * (135135 + x² (17325 + x² (378 + x²)))
    let num = _mm_add_ps(_mm_set1_ps(378.0), x2);
    let num = _mm_add_ps(_mm_set1_ps(17_325.0), _mm_mul_ps(num, x2));
    let num = _mm_add_ps(c0, _mm_mul_ps(num, x2));
    let num = _mm_mul_ps(num, clamped);

    // Denominator: 135135 + x² (62370 + x² (3150 + 28 x²)); always >= 135135.
    let den = _mm_add_ps(_mm_set1_ps(3_150.0), _mm_mul_ps(_mm_set1_ps(28.0), x2));
    let den = _mm_add_ps(_mm_set1_ps(62_370.0), _mm_mul_ps(den, x2));
    let den = _mm_add_ps(c0, _mm_mul_ps(den, x2));

    // The approximant creeps just above 1 near |x| = 5; keep the result in range.
    let ratio = _mm_div_ps(num, den);
    _mm_max_ps(_mm_set1_ps(-1.0), _mm_min_ps(_mm_set1_ps(1.0), ratio))
}
396
// Arch-specific name for the 256-bit float vector used in the signature below;
// `core::arch::x86_64` is not available when building for 32-bit x86.
#[cfg(target_arch = "x86")]
use core::arch::x86::__m256 as TanhVec256;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::__m256 as TanhVec256;

/// Fast tanh approximation for AVX2.
///
/// Uses the [7/6] Padé approximant
/// `x (135135 + 17325 x² + 378 x⁴ + x⁶) / (135135 + 62370 x² + 3150 x⁴ + 28 x⁶)`
/// on the input clamped to `[-5, 5]`, then clamps the result to `[-1, 1]`.
/// The cubic Taylor term `x (1 - x²/3)` is only usable near zero: it already
/// gives `-0.67` at `x = 2` where tanh is `0.96`.
///
/// Expected absolute error is on the order of `1e-4` or better; NaN lanes stay NaN.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn tanh_approx_avx2(x: TanhVec256) -> TanhVec256 {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Beyond |x| = 5 tanh is within 1e-4 of ±1, so saturate there.
    let clamped = _mm256_max_ps(
        _mm256_set1_ps(-5.0),
        _mm256_min_ps(_mm256_set1_ps(5.0), x),
    );
    let x2 = _mm256_mul_ps(clamped, clamped);
    let c0 = _mm256_set1_ps(135_135.0);

    // Numerator: x * (135135 + x² (17325 + x² (378 + x²)))
    let num = _mm256_add_ps(_mm256_set1_ps(378.0), x2);
    let num = _mm256_add_ps(_mm256_set1_ps(17_325.0), _mm256_mul_ps(num, x2));
    let num = _mm256_add_ps(c0, _mm256_mul_ps(num, x2));
    let num = _mm256_mul_ps(num, clamped);

    // Denominator: 135135 + x² (62370 + x² (3150 + 28 x²)); always >= 135135.
    let den = _mm256_add_ps(
        _mm256_set1_ps(3_150.0),
        _mm256_mul_ps(_mm256_set1_ps(28.0), x2),
    );
    let den = _mm256_add_ps(_mm256_set1_ps(62_370.0), _mm256_mul_ps(den, x2));
    let den = _mm256_add_ps(c0, _mm256_mul_ps(den, x2));

    // The approximant creeps just above 1 near |x| = 5; keep the result in range.
    let ratio = _mm256_div_ps(num, den);
    _mm256_max_ps(
        _mm256_set1_ps(-1.0),
        _mm256_min_ps(_mm256_set1_ps(1.0), ratio),
    )
}
414
415/// SIMD-optimized softmax activation function
416pub fn softmax(input: &[f32], output: &mut [f32]) {
417    assert_eq!(
418        input.len(),
419        output.len(),
420        "Vectors must have the same length"
421    );
422
423    if input.is_empty() {
424        return;
425    }
426
427    // Find max for numerical stability
428    let max_val = input.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
429
430    // Compute exp(x - max)
431    let mut exp_values = vec![0.0; input.len()];
432    for i in 0..input.len() {
433        exp_values[i] = (input[i] - max_val).exp();
434    }
435
436    // Compute sum of exponentials
437    let exp_sum = sum(&exp_values);
438
439    // Normalize
440    for i in 0..input.len() {
441        output[i] = exp_values[i] / exp_sum;
442    }
443}
444
445/// SIMD-optimized ELU (Exponential Linear Unit) activation function
446pub fn elu(input: &[f32], output: &mut [f32], alpha: f32) {
447    assert_eq!(
448        input.len(),
449        output.len(),
450        "Vectors must have the same length"
451    );
452
453    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
454    {
455        if crate::simd_feature_detected!("avx2") {
456            unsafe { elu_avx2(input, output, alpha) };
457            return;
458        } else if crate::simd_feature_detected!("sse2") {
459            unsafe { elu_sse2(input, output, alpha) };
460            return;
461        }
462    }
463
464    elu_scalar(input, output, alpha);
465}
466
/// Portable ELU fallback: identity for `x >= 0`, `alpha * (e^x - 1)` otherwise.
///
/// Panics if `output` is shorter than `input`.
fn elu_scalar(input: &[f32], output: &mut [f32], alpha: f32) {
    let dst = &mut output[..input.len()];
    for (slot, &x) in dst.iter_mut().zip(input) {
        *slot = if x >= 0.0 { x } else { alpha * (x.exp() - 1.0) };
    }
}
476
477#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
478#[target_feature(enable = "sse2")]
479unsafe fn elu_sse2(input: &[f32], output: &mut [f32], alpha: f32) {
480    use core::arch::x86_64::*;
481
482    let zero = _mm_setzero_ps();
483    let one = _mm_set1_ps(1.0);
484    let alpha_vec = _mm_set1_ps(alpha);
485    let mut i = 0;
486
487    while i + 4 <= input.len() {
488        let x = _mm_loadu_ps(input.as_ptr().add(i));
489        let mask = _mm_cmpge_ps(x, zero);
490
491        let positive = x;
492        let exp_x = exp_approx_sse2(x);
493        let negative = _mm_mul_ps(alpha_vec, _mm_sub_ps(exp_x, one));
494
495        let result = _mm_blendv_ps(negative, positive, mask);
496        _mm_storeu_ps(output.as_mut_ptr().add(i), result);
497        i += 4;
498    }
499
500    while i < input.len() {
501        output[i] = if input[i] >= 0.0 {
502            input[i]
503        } else {
504            alpha * (input[i].exp() - 1.0)
505        };
506        i += 1;
507    }
508}
509
510#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
511#[target_feature(enable = "avx2")]
512unsafe fn elu_avx2(input: &[f32], output: &mut [f32], alpha: f32) {
513    use core::arch::x86_64::*;
514
515    let zero = _mm256_setzero_ps();
516    let one = _mm256_set1_ps(1.0);
517    let alpha_vec = _mm256_set1_ps(alpha);
518    let mut i = 0;
519
520    while i + 8 <= input.len() {
521        let x = _mm256_loadu_ps(input.as_ptr().add(i));
522        let mask = _mm256_cmp_ps(x, zero, _CMP_GE_OQ);
523
524        let positive = x;
525        let exp_x = exp_approx_avx2(x);
526        let negative = _mm256_mul_ps(alpha_vec, _mm256_sub_ps(exp_x, one));
527
528        let result = _mm256_blendv_ps(negative, positive, mask);
529        _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
530        i += 8;
531    }
532
533    while i < input.len() {
534        output[i] = if input[i] >= 0.0 {
535            input[i]
536        } else {
537            alpha * (input[i].exp() - 1.0)
538        };
539        i += 1;
540    }
541}
542
543/// SIMD-optimized Swish (SiLU) activation function: x * sigmoid(x)
544pub fn swish(input: &[f32], output: &mut [f32]) {
545    assert_eq!(
546        input.len(),
547        output.len(),
548        "Vectors must have the same length"
549    );
550
551    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
552    {
553        if crate::simd_feature_detected!("avx2") {
554            unsafe { swish_avx2(input, output) };
555            return;
556        } else if crate::simd_feature_detected!("sse2") {
557            unsafe { swish_sse2(input, output) };
558            return;
559        }
560    }
561
562    swish_scalar(input, output);
563}
564
/// Portable Swish fallback: stores `x * (1 / (1 + e^(-x)))` for each `x` of `input`.
///
/// Panics if `output` is shorter than `input`.
fn swish_scalar(input: &[f32], output: &mut [f32]) {
    let dst = &mut output[..input.len()];
    for (slot, &x) in dst.iter_mut().zip(input) {
        let gate = 1.0 / (1.0 + (-x).exp());
        *slot = x * gate;
    }
}
571
572#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
573#[target_feature(enable = "sse2")]
574unsafe fn swish_sse2(input: &[f32], output: &mut [f32]) {
575    use core::arch::x86_64::*;
576
577    let one = _mm_set1_ps(1.0);
578    let mut i = 0;
579
580    while i + 4 <= input.len() {
581        let x = _mm_loadu_ps(input.as_ptr().add(i));
582
583        let neg_x = _mm_sub_ps(_mm_setzero_ps(), x);
584        let exp_neg_x = exp_approx_sse2(neg_x);
585        let one_plus_exp = _mm_add_ps(one, exp_neg_x);
586        let sigmoid_x = _mm_div_ps(one, one_plus_exp);
587
588        let result = _mm_mul_ps(x, sigmoid_x);
589        _mm_storeu_ps(output.as_mut_ptr().add(i), result);
590        i += 4;
591    }
592
593    while i < input.len() {
594        let sigmoid_x = 1.0 / (1.0 + (-input[i]).exp());
595        output[i] = input[i] * sigmoid_x;
596        i += 1;
597    }
598}
599
600#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
601#[target_feature(enable = "avx2")]
602unsafe fn swish_avx2(input: &[f32], output: &mut [f32]) {
603    use core::arch::x86_64::*;
604
605    let one = _mm256_set1_ps(1.0);
606    let mut i = 0;
607
608    while i + 8 <= input.len() {
609        let x = _mm256_loadu_ps(input.as_ptr().add(i));
610
611        let neg_x = _mm256_sub_ps(_mm256_setzero_ps(), x);
612        let exp_neg_x = exp_approx_avx2(neg_x);
613        let one_plus_exp = _mm256_add_ps(one, exp_neg_x);
614        let sigmoid_x = _mm256_div_ps(one, one_plus_exp);
615
616        let result = _mm256_mul_ps(x, sigmoid_x);
617        _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
618        i += 8;
619    }
620
621    while i < input.len() {
622        let sigmoid_x = 1.0 / (1.0 + (-input[i]).exp());
623        output[i] = input[i] * sigmoid_x;
624        i += 1;
625    }
626}
627
628/// SIMD-optimized GELU (Gaussian Error Linear Unit) activation function
629pub fn gelu(input: &[f32], output: &mut [f32]) {
630    assert_eq!(
631        input.len(),
632        output.len(),
633        "Vectors must have the same length"
634    );
635
636    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
637    {
638        if crate::simd_feature_detected!("avx2") {
639            unsafe { gelu_avx2(input, output) };
640            return;
641        } else if crate::simd_feature_detected!("sse2") {
642            unsafe { gelu_sse2(input, output) };
643            return;
644        }
645    }
646
647    gelu_scalar(input, output);
648}
649
/// Portable GELU fallback using the tanh formulation
/// `0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`.
///
/// Panics if `output` is shorter than `input`.
fn gelu_scalar(input: &[f32], output: &mut [f32]) {
    const SQRT_2_PI: f32 = 0.797_884_6; // sqrt(2/π)

    let dst = &mut output[..input.len()];
    for (slot, &x) in dst.iter_mut().zip(input) {
        let cube = x * x * x;
        let arg = SQRT_2_PI * (x + 0.044715 * cube);
        *slot = 0.5 * x * (1.0 + arg.tanh());
    }
}
660
661#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
662#[target_feature(enable = "sse2")]
663unsafe fn gelu_sse2(input: &[f32], output: &mut [f32]) {
664    use core::arch::x86_64::*;
665
666    let sqrt_2_pi = _mm_set1_ps(0.797_884_6_f32);
667    let coeff = _mm_set1_ps(0.044715);
668    let half = _mm_set1_ps(0.5);
669    let one = _mm_set1_ps(1.0);
670    let mut i = 0;
671
672    while i + 4 <= input.len() {
673        let x = _mm_loadu_ps(input.as_ptr().add(i));
674
675        let x2 = _mm_mul_ps(x, x);
676        let x3 = _mm_mul_ps(x2, x);
677        let coeff_x3 = _mm_mul_ps(coeff, x3);
678        let inner_term = _mm_add_ps(x, coeff_x3);
679        let scaled_inner = _mm_mul_ps(sqrt_2_pi, inner_term);
680
681        let tanh_result = tanh_approx_sse2(scaled_inner);
682        let one_plus_tanh = _mm_add_ps(one, tanh_result);
683        let result = _mm_mul_ps(_mm_mul_ps(half, x), one_plus_tanh);
684
685        _mm_storeu_ps(output.as_mut_ptr().add(i), result);
686        i += 4;
687    }
688
689    while i < input.len() {
690        let x = input[i];
691        let x_cubed = x * x * x;
692        let inner = 0.797_884_6_f32 * (x + 0.044715 * x_cubed);
693        output[i] = 0.5 * x * (1.0 + inner.tanh());
694        i += 1;
695    }
696}
697
698#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
699#[target_feature(enable = "avx2")]
700unsafe fn gelu_avx2(input: &[f32], output: &mut [f32]) {
701    use core::arch::x86_64::*;
702
703    let sqrt_2_pi = _mm256_set1_ps(0.797_884_6_f32);
704    let coeff = _mm256_set1_ps(0.044715);
705    let half = _mm256_set1_ps(0.5);
706    let one = _mm256_set1_ps(1.0);
707    let mut i = 0;
708
709    while i + 8 <= input.len() {
710        let x = _mm256_loadu_ps(input.as_ptr().add(i));
711
712        let x2 = _mm256_mul_ps(x, x);
713        let x3 = _mm256_mul_ps(x2, x);
714        let coeff_x3 = _mm256_mul_ps(coeff, x3);
715        let inner_term = _mm256_add_ps(x, coeff_x3);
716        let scaled_inner = _mm256_mul_ps(sqrt_2_pi, inner_term);
717
718        let tanh_result = tanh_approx_avx2(scaled_inner);
719        let one_plus_tanh = _mm256_add_ps(one, tanh_result);
720        let result = _mm256_mul_ps(_mm256_mul_ps(half, x), one_plus_tanh);
721
722        _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
723        i += 8;
724    }
725
726    while i < input.len() {
727        let x = input[i];
728        let x_cubed = x * x * x;
729        let inner = 0.797_884_6_f32 * (x + 0.044715 * x_cubed);
730        output[i] = 0.5 * x * (1.0 + inner.tanh());
731        i += 1;
732    }
733}
734
735// ===== DERIVATIVE FUNCTIONS FOR BACKPROPAGATION =====
736
737/// SIMD-optimized sigmoid derivative
738pub fn sigmoid_derivative(input: &[f32], output: &mut [f32]) {
739    assert_eq!(
740        input.len(),
741        output.len(),
742        "Vectors must have the same length"
743    );
744
745    // Compute sigmoid first, then derivative: sigmoid(x) * (1 - sigmoid(x))
746    let mut sigmoid_vals = vec![0.0; input.len()];
747    sigmoid(input, &mut sigmoid_vals);
748
749    for i in 0..input.len() {
750        output[i] = sigmoid_vals[i] * (1.0 - sigmoid_vals[i]);
751    }
752}
753
754/// SIMD-optimized ReLU derivative
755pub fn relu_derivative(input: &[f32], output: &mut [f32]) {
756    assert_eq!(
757        input.len(),
758        output.len(),
759        "Vectors must have the same length"
760    );
761
762    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
763    {
764        if crate::simd_feature_detected!("avx2") {
765            unsafe { relu_derivative_avx2(input, output) };
766            return;
767        } else if crate::simd_feature_detected!("sse2") {
768            unsafe { relu_derivative_sse2(input, output) };
769            return;
770        }
771    }
772
773    relu_derivative_scalar(input, output);
774}
775
/// Portable ReLU-derivative fallback: `1.0` for strictly positive inputs, else `0.0`.
///
/// Panics if `output` is shorter than `input`.
fn relu_derivative_scalar(input: &[f32], output: &mut [f32]) {
    let dst = &mut output[..input.len()];
    for (slot, &x) in dst.iter_mut().zip(input) {
        *slot = if x > 0.0 { 1.0 } else { 0.0 };
    }
}
781
/// SSE2 ReLU-derivative kernel: four lanes per iteration, scalar tail.
///
/// # Safety
///
/// The CPU must support SSE2 and `output` must be at least as long as `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn relu_derivative_sse2(input: &[f32], output: &mut [f32]) {
    // `core::arch::x86_64` only exists on 64-bit targets; pick the module per arch
    // so the 32-bit x86 build enabled by the cfg above also compiles.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let len = input.len();
    let zero = _mm_setzero_ps();
    let one = _mm_set1_ps(1.0);
    let mut i = 0;

    while i + 4 <= len {
        // Unaligned load/store: `i + 4 <= len` keeps both accesses in bounds.
        let x = _mm_loadu_ps(input.as_ptr().add(i));
        // The comparison yields all-ones / all-zeros lanes; AND-ing with 1.0
        // turns that into 1.0 / 0.0.
        let mask = _mm_cmpgt_ps(x, zero);
        _mm_storeu_ps(output.as_mut_ptr().add(i), _mm_and_ps(mask, one));
        i += 4;
    }

    for idx in i..len {
        output[idx] = if input[idx] > 0.0 { 1.0 } else { 0.0 };
    }
}
804
/// AVX2 ReLU-derivative kernel: eight lanes per iteration, scalar tail.
///
/// # Safety
///
/// The CPU must support AVX2 and `output` must be at least as long as `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn relu_derivative_avx2(input: &[f32], output: &mut [f32]) {
    // `core::arch::x86_64` only exists on 64-bit targets; pick the module per arch
    // so the 32-bit x86 build enabled by the cfg above also compiles.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let len = input.len();
    let zero = _mm256_setzero_ps();
    let one = _mm256_set1_ps(1.0);
    let mut i = 0;

    while i + 8 <= len {
        // Unaligned load/store: `i + 8 <= len` keeps both accesses in bounds.
        let x = _mm256_loadu_ps(input.as_ptr().add(i));
        // The comparison yields all-ones / all-zeros lanes; AND-ing with 1.0
        // turns that into 1.0 / 0.0.
        let mask = _mm256_cmp_ps(x, zero, _CMP_GT_OQ);
        _mm256_storeu_ps(output.as_mut_ptr().add(i), _mm256_and_ps(mask, one));
        i += 8;
    }

    for idx in i..len {
        output[idx] = if input[idx] > 0.0 { 1.0 } else { 0.0 };
    }
}
827
828/// SIMD-optimized tanh derivative: 1 - tanh²(x)
829pub fn tanh_derivative(input: &[f32], output: &mut [f32]) {
830    assert_eq!(
831        input.len(),
832        output.len(),
833        "Vectors must have the same length"
834    );
835
836    // Compute tanh first, then derivative
837    let mut tanh_vals = vec![0.0; input.len()];
838    tanh_activation(input, &mut tanh_vals);
839
840    for i in 0..input.len() {
841        output[i] = 1.0 - tanh_vals[i] * tanh_vals[i];
842    }
843}
844
845// ===== CONVENIENT NDARRAY INTERFACES =====
846
847/// Apply activation function to ndarray Array1
848pub fn apply_activation_1d(
849    input: &Array1<f32>,
850    activation: ActivationFunction,
851    alpha: Option<f32>,
852) -> Array1<f32> {
853    let mut output = Array1::zeros(input.len());
854    apply_activation_slice(
855        input.as_slice().expect("slice operation should succeed"),
856        output
857            .as_slice_mut()
858            .expect("slice operation should succeed"),
859        activation,
860        alpha,
861    );
862    output
863}
864
865/// Apply activation function to ndarray Array2 (applies to each element)
866pub fn apply_activation_2d(
867    input: &Array2<f32>,
868    activation: ActivationFunction,
869    alpha: Option<f32>,
870) -> Array2<f32> {
871    let mut output = Array2::zeros(input.dim());
872    if let (Some(input_slice), Some(output_slice)) = (input.as_slice(), output.as_slice_mut()) {
873        apply_activation_slice(input_slice, output_slice, activation, alpha);
874    }
875    output
876}
877
878/// Apply activation function to raw slice
879pub fn apply_activation_slice(
880    input: &[f32],
881    output: &mut [f32],
882    activation: ActivationFunction,
883    alpha: Option<f32>,
884) {
885    match activation {
886        ActivationFunction::ReLU => relu(input, output),
887        ActivationFunction::Sigmoid => sigmoid(input, output),
888        ActivationFunction::Tanh => tanh_activation(input, output),
889        ActivationFunction::LeakyReLU => {
890            let alpha_val = alpha.unwrap_or(0.01);
891            leaky_relu(input, output, alpha_val);
892        }
893        ActivationFunction::ELU => {
894            let alpha_val = alpha.unwrap_or(1.0);
895            elu(input, output, alpha_val);
896        }
897        ActivationFunction::Swish => swish(input, output),
898        ActivationFunction::GELU => gelu(input, output),
899        ActivationFunction::Softmax => softmax(input, output),
900    }
901}
902
/// Enumeration of available activation functions.
///
/// Derives `Hash` alongside `Eq` so values can be used as keys in hashed
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivationFunction {
    /// Rectified linear unit: `max(x, 0)`.
    ReLU,
    /// Logistic sigmoid: `1 / (1 + e^(-x))`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Leaky ReLU: `x` for `x > 0`, `alpha * x` otherwise.
    LeakyReLU,
    /// Exponential linear unit: `x` for `x >= 0`, `alpha * (e^x - 1)` otherwise.
    ELU,
    /// Swish / SiLU: `x * sigmoid(x)`.
    Swish,
    /// Gaussian error linear unit (tanh formulation).
    GELU,
    /// Softmax over the whole slice.
    Softmax,
}
915
916#[allow(non_snake_case)]
917#[cfg(all(test, not(feature = "no-std")))]
918mod tests {
919    use super::*;
920    use approx::assert_relative_eq;
921
922    #[cfg(feature = "no-std")]
923    use alloc::{vec, vec::Vec};
924
925    #[test]
926    fn test_sigmoid() {
927        let input = vec![0.0, 1.0, -1.0, 2.0, -2.0];
928        let mut output = vec![0.0; 5];
929
930        sigmoid(&input, &mut output);
931
932        assert_relative_eq!(output[0], 0.5, epsilon = 1e-3);
933        assert!(output[1] > 0.7 && output[1] < 0.8);
934        assert!(output[2] > 0.2 && output[2] < 0.3);
935        assert!(output[3] > 0.8 && output[3] < 0.9);
936        assert!(output[4] > 0.1 && output[4] < 0.2);
937    }
938
939    #[test]
940    fn test_relu() {
941        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
942        let mut output = vec![0.0; 5];
943
944        relu(&input, &mut output);
945
946        assert_relative_eq!(output[0], 0.0, epsilon = 1e-6);
947        assert_relative_eq!(output[1], 0.0, epsilon = 1e-6);
948        assert_relative_eq!(output[2], 0.0, epsilon = 1e-6);
949        assert_relative_eq!(output[3], 1.0, epsilon = 1e-6);
950        assert_relative_eq!(output[4], 2.0, epsilon = 1e-6);
951    }
952
953    #[test]
954    fn test_leaky_relu() {
955        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
956        let mut output = vec![0.0; 5];
957        let alpha = 0.1;
958
959        leaky_relu(&input, &mut output, alpha);
960
961        assert_relative_eq!(output[0], -0.2, epsilon = 1e-6);
962        assert_relative_eq!(output[1], -0.1, epsilon = 1e-6);
963        assert_relative_eq!(output[2], 0.0, epsilon = 1e-6);
964        assert_relative_eq!(output[3], 1.0, epsilon = 1e-6);
965        assert_relative_eq!(output[4], 2.0, epsilon = 1e-6);
966    }
967
968    #[test]
969    fn test_tanh_activation() {
970        let input = vec![0.0, 1.0, -1.0, 2.0, -2.0];
971        let mut output = vec![0.0; 5];
972
973        tanh_activation(&input, &mut output);
974
975        assert_relative_eq!(output[0], 0.0, epsilon = 1e-3);
976        assert!(output[1] > 0.7 && output[1] < 0.8);
977        assert!(output[2] > -0.8 && output[2] < -0.7);
978        assert!(output[3] > 0.9);
979        assert!(output[4] < -0.9);
980    }
981
982    #[test]
983    fn test_softmax() {
984        let input = vec![1.0, 2.0, 3.0];
985        let mut output = vec![0.0; 3];
986
987        softmax(&input, &mut output);
988
989        // Check that probabilities sum to 1
990        let sum: f32 = output.iter().sum();
991        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
992
993        // Check that largest input corresponds to largest output
994        assert!(output[2] > output[1]);
995        assert!(output[1] > output[0]);
996    }
997
/// ELU must be the identity on non-negative inputs and stay within (-alpha, 0) below zero.
#[test]
fn test_elu() {
    let alpha = 1.0;
    let samples = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
    let mut activated = vec![0.0; 5];

    elu(&samples, &mut activated, alpha);

    // Non-negative samples (the last three) must come back unchanged.
    for (actual, wanted) in activated.iter().zip(samples.iter()).skip(2) {
        assert_relative_eq!(*actual, *wanted, epsilon = 1e-6);
    }

    // Negative samples follow alpha * (exp(x) - 1): negative, but never reaching -alpha.
    let (at_neg_two, at_neg_one) = (activated[0], activated[1]);
    assert!(at_neg_two < 0.0 && at_neg_two > -alpha);
    // The sample closer to zero must be the less negative of the two.
    assert!(at_neg_one < 0.0 && at_neg_one > at_neg_two);
}
1015
/// Checks the sign and the positive-side ordering of `swish` on a symmetric set of samples.
#[test]
fn test_swish() {
    let samples = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
    let mut activated = vec![0.0; 5];

    swish(&samples, &mut activated);

    let (at_neg_two, at_neg_one) = (activated[0], activated[1]);
    let at_zero = activated[2];
    let (at_pos_one, at_pos_two) = (activated[3], activated[4]);

    // swish(0) = 0
    assert_relative_eq!(at_zero, 0.0, epsilon = 1e-3);

    // On the positive side the output is positive and grows with the input.
    assert!(at_pos_one > 0.0);
    assert!(at_pos_two > at_pos_one);

    // On the negative side the output is negative.
    assert!(at_neg_two < 0.0);
    assert!(at_neg_one < 0.0);
    // The relative order of the two negative samples is deliberately not asserted:
    // swish is not monotonic below zero (its minimum lies near x ≈ -1.28), so
    // swish(-1) < swish(-2).
}
1036
/// Checks `gelu` at zero, on the positive side, and for well-formed (finite) output.
#[test]
fn test_gelu() {
    let samples = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
    let mut activated = vec![0.0; 5];

    gelu(&samples, &mut activated);

    // gelu(0) = 0
    assert_relative_eq!(activated[2], 0.0, epsilon = 1e-3);

    // On the positive side the output is positive and grows with the input.
    let (at_pos_one, at_pos_two) = (activated[3], activated[4]);
    assert!(at_pos_one > 0.0);
    assert!(at_pos_two > at_pos_one);

    // Every output must be a real, finite number.
    for value in activated.iter().copied() {
        assert!(!value.is_nan());
        assert!(value.is_finite());
    }
}
1057
/// The ReLU derivative must be 0 for negative and zero inputs and 1 for positive inputs.
#[test]
fn test_relu_derivative() {
    let samples = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
    let mut grads = vec![0.0; 5];

    relu_derivative(&samples, &mut grads);

    // Expected gradient per sample: negative -> 0, zero -> 0, positive -> 1.
    let expected = [0.0, 0.0, 0.0, 1.0, 1.0];
    for (actual, wanted) in grads.iter().zip(expected.iter()) {
        assert_relative_eq!(*actual, *wanted, epsilon = 1e-6);
    }
}
1071
/// The sigmoid derivative must be 0.25 at the origin and strictly positive elsewhere.
#[test]
fn test_sigmoid_derivative() {
    let samples = vec![0.0, 1.0, -1.0];
    let mut grads = vec![0.0; 3];

    sigmoid_derivative(&samples, &mut grads);

    // sigmoid'(0) = 0.25
    let at_zero = grads[0];
    assert_relative_eq!(at_zero, 0.25, epsilon = 1e-3);

    // The derivative must be strictly positive at every sample.
    for grad in grads.iter().copied() {
        assert!(grad > 0.0);
    }
}
1087
/// The tanh derivative must be 1 at the origin and lie in (0, 1] at every sample.
#[test]
fn test_tanh_derivative() {
    let samples = vec![0.0, 1.0, -1.0];
    let mut grads = vec![0.0; 3];

    tanh_derivative(&samples, &mut grads);

    // tanh'(0) = 1
    let at_zero = grads[0];
    assert_relative_eq!(at_zero, 1.0, epsilon = 1e-3);

    // The derivative must be strictly positive and at most 1 at every sample.
    for grad in grads.iter().copied() {
        assert!(grad > 0.0 && grad <= 1.0);
    }
}
1103
/// Drives the ReLU, Sigmoid and Softmax variants through `apply_activation_slice`.
#[test]
fn test_activation_function_enum() {
    let samples = vec![1.0, 2.0, 3.0];
    let mut result = vec![0.0; 3];

    // ReLU: all samples are positive, so the output must equal the input.
    apply_activation_slice(&samples, &mut result, ActivationFunction::ReLU, None);
    assert_eq!(result, samples);

    // Sigmoid: every output must fall strictly inside (0, 1).
    apply_activation_slice(&samples, &mut result, ActivationFunction::Sigmoid, None);
    let inside_unit_interval = result.iter().all(|&value| value > 0.0 && value < 1.0);
    assert!(inside_unit_interval);

    // Softmax: the outputs must add up to 1.
    apply_activation_slice(&samples, &mut result, ActivationFunction::Softmax, None);
    let total: f32 = result.iter().sum();
    assert_relative_eq!(total, 1.0, epsilon = 1e-6);
}
1120
/// Runs ReLU through the 1-D and 2-D array wrappers on all-positive data.
#[test]
fn test_ndarray_interface() {
    // 1-D: positive values must come back unchanged.
    let vector_in = Array1::from_vec(vec![1.0, 2.0, 3.0]);
    let vector_out = apply_activation_1d(&vector_in, ActivationFunction::ReLU, None);
    let vector_flat = vector_out
        .as_slice()
        .expect("slice operation should succeed");
    assert_eq!(vector_flat, &[1.0, 2.0, 3.0]);

    // 2-D: a 2x2 matrix of positive values, compared through its flat slice view.
    let matrix_in = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
        .expect("shape and data length should match");
    let matrix_out = apply_activation_2d(&matrix_in, ActivationFunction::ReLU, None);
    let matrix_flat = matrix_out
        .as_slice()
        .expect("slice operation should succeed");
    assert_eq!(matrix_flat, &[1.0, 2.0, 3.0, 4.0]);
}
1142
/// Verifies that an explicit alpha is forwarded to LeakyReLU and ELU by the enum dispatcher.
#[test]
fn test_activation_with_alpha() {
    let samples = vec![-1.0, 0.0, 1.0];
    let mut result = vec![0.0; 3];

    // LeakyReLU with alpha = 0.2: the negative sample is scaled, the others pass through.
    let leaky_alpha = Some(0.2);
    apply_activation_slice(
        &samples,
        &mut result,
        ActivationFunction::LeakyReLU,
        leaky_alpha,
    );
    let expected = [-0.2, 0.0, 1.0];
    for (actual, wanted) in result.iter().zip(expected.iter()) {
        assert_relative_eq!(*actual, *wanted, epsilon = 1e-6);
    }

    // ELU with alpha = 2.0: the negative sample must land strictly between -2.0 and 0.
    apply_activation_slice(&samples, &mut result, ActivationFunction::ELU, Some(2.0));
    let elu_at_neg_one = result[0];
    assert!(elu_at_neg_one < 0.0 && elu_at_neg_one > -2.0);
}
1163}