#[cfg(feature = "no-std")]
use alloc::vec;
// Width-agnostic alias for the x86 intrinsics module, so SIMD helper signatures compile
// on both 32-bit (`x86`) and 64-bit (`x86_64`) targets.
#[cfg(target_arch = "x86")]
use core::arch::x86 as x86_arch;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64 as x86_arch;

use scirs2_autograd::ndarray::{Array1, Array2};

use crate::vector::sum;
13
14pub fn sigmoid(input: &[f32], output: &mut [f32]) {
16 assert_eq!(
17 input.len(),
18 output.len(),
19 "Vectors must have the same length"
20 );
21
22 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
23 {
24 if crate::simd_feature_detected!("avx2") {
25 unsafe { sigmoid_avx2(input, output) };
26 return;
27 } else if crate::simd_feature_detected!("sse2") {
28 unsafe { sigmoid_sse2(input, output) };
29 return;
30 }
31 }
32
33 sigmoid_scalar(input, output);
34}
35
/// Portable fallback for [`sigmoid`]. Evaluates `1 / (1 + e^-x)` one element at a time.
fn sigmoid_scalar(input: &[f32], output: &mut [f32]) {
    for (idx, &value) in input.iter().enumerate() {
        let denominator = 1.0 + (-value).exp();
        output[idx] = 1.0 / denominator;
    }
}
41
42#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
43#[target_feature(enable = "sse2")]
44unsafe fn sigmoid_sse2(input: &[f32], output: &mut [f32]) {
45 use core::arch::x86_64::*;
46
47 let mut i = 0;
48 let one = _mm_set1_ps(1.0);
49
50 while i + 4 <= input.len() {
51 let x = _mm_loadu_ps(input.as_ptr().add(i));
52
53 let neg_x = _mm_sub_ps(_mm_setzero_ps(), x);
55 let exp_neg_x = exp_approx_sse2(neg_x);
56
57 let one_plus_exp = _mm_add_ps(one, exp_neg_x);
58 let result = _mm_div_ps(one, one_plus_exp);
59
60 _mm_storeu_ps(output.as_mut_ptr().add(i), result);
61 i += 4;
62 }
63
64 while i < input.len() {
65 output[i] = 1.0 / (1.0 + (-input[i]).exp());
66 i += 1;
67 }
68}
69
70#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
71#[target_feature(enable = "avx2")]
72unsafe fn sigmoid_avx2(input: &[f32], output: &mut [f32]) {
73 use core::arch::x86_64::*;
74
75 let mut i = 0;
76 let one = _mm256_set1_ps(1.0);
77
78 while i + 8 <= input.len() {
79 let x = _mm256_loadu_ps(input.as_ptr().add(i));
80
81 let neg_x = _mm256_sub_ps(_mm256_setzero_ps(), x);
82 let exp_neg_x = exp_approx_avx2(neg_x);
83
84 let one_plus_exp = _mm256_add_ps(one, exp_neg_x);
85 let result = _mm256_div_ps(one, one_plus_exp);
86
87 _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
88 i += 8;
89 }
90
91 while i < input.len() {
92 output[i] = 1.0 / (1.0 + (-input[i]).exp());
93 i += 1;
94 }
95}
96
/// Vectorised `e^x` for four `f32` lanes, using SSE2 only.
///
/// Uses Cephes-style range reduction:
/// 1. Write `x = n·ln2 + r`, with `n = round(x·log2 e)` and `|r| <= ln2/2`.
/// 2. Approximate `e^r` with a degree-7 polynomial.
/// 3. Scale by `2^n`, built directly in the float's exponent field.
///
/// Inputs are clamped to `[-87, 88]` so that `2^n` stays a finite, normal `f32`.
/// Results therefore saturate at roughly `1.6e-38` and `1.65e38` outside that range.
/// Inside the range the relative error is a few ulps.
///
/// # Safety
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn exp_approx_sse2(x: x86_arch::__m128) -> x86_arch::__m128 {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Keeps n in [-126, 127], so the biased exponent n + 127 stays in [1, 254].
    const EXP_HI: f32 = 88.0;
    const EXP_LO: f32 = -87.0;
    const LOG2_E: f32 = 1.442_695;
    // ln 2 split into an exactly representable high part and a small correction.
    // This keeps the product n * LN2_HI exact, so range reduction loses no precision.
    const LN2_HI: f32 = 0.693_359_375;
    const LN2_LO: f32 = -2.121_944_4e-4;
    // Cephes expf coefficients for e^r on [-ln2/2, ln2/2].
    const P0: f32 = 1.987_569_2e-4;
    const P1: f32 = 1.398_2e-3;
    const P2: f32 = 8.333_452e-3;
    const P3: f32 = 4.166_579_6e-2;
    const P4: f32 = 1.666_666_5e-1;
    const P5: f32 = 5.000_000_1e-1;

    let x = _mm_max_ps(_mm_set1_ps(EXP_LO), _mm_min_ps(_mm_set1_ps(EXP_HI), x));

    // n = round(x * log2 e). The conversion honours MXCSR, which rounds to nearest by default.
    let n_int = _mm_cvtps_epi32(_mm_mul_ps(x, _mm_set1_ps(LOG2_E)));
    let n = _mm_cvtepi32_ps(n_int);

    // r = x - n*ln2, subtracted in two steps for accuracy.
    let r = _mm_sub_ps(x, _mm_mul_ps(n, _mm_set1_ps(LN2_HI)));
    let r = _mm_sub_ps(r, _mm_mul_ps(n, _mm_set1_ps(LN2_LO)));

    // e^r ≈ 1 + r + r²·P(r), with P evaluated by Horner's rule.
    let mut p = _mm_set1_ps(P0);
    for &c in &[P1, P2, P3, P4, P5] {
        p = _mm_add_ps(_mm_mul_ps(p, r), _mm_set1_ps(c));
    }
    let r2 = _mm_mul_ps(r, r);
    let exp_r = _mm_add_ps(_mm_add_ps(_mm_mul_ps(p, r2), r), _mm_set1_ps(1.0));

    // 2^n: write the biased exponent (n + 127) into bits 23..31.
    let biased = _mm_add_epi32(n_int, _mm_set1_epi32(127));
    let pow2n = _mm_castsi128_ps(_mm_slli_epi32(biased, 23));

    _mm_mul_ps(exp_r, pow2n)
}
123
/// Vectorised `e^x` for eight `f32` lanes, using AVX2.
///
/// Uses Cephes-style range reduction:
/// 1. Write `x = n·ln2 + r`, with `n = round(x·log2 e)` and `|r| <= ln2/2`.
/// 2. Approximate `e^r` with a degree-7 polynomial.
/// 3. Scale by `2^n`, built directly in the float's exponent field.
///
/// Inputs are clamped to `[-87, 88]` so that `2^n` stays a finite, normal `f32`.
/// Results therefore saturate at roughly `1.6e-38` and `1.65e38` outside that range.
/// Inside the range the relative error is a few ulps.
///
/// # Safety
/// The executing CPU must support AVX2. The integer shift and add are AVX2
/// instructions.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn exp_approx_avx2(x: x86_arch::__m256) -> x86_arch::__m256 {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Keeps n in [-126, 127], so the biased exponent n + 127 stays in [1, 254].
    const EXP_HI: f32 = 88.0;
    const EXP_LO: f32 = -87.0;
    const LOG2_E: f32 = 1.442_695;
    // ln 2 split into an exactly representable high part and a small correction.
    const LN2_HI: f32 = 0.693_359_375;
    const LN2_LO: f32 = -2.121_944_4e-4;
    // Cephes expf coefficients for e^r on [-ln2/2, ln2/2].
    const P0: f32 = 1.987_569_2e-4;
    const P1: f32 = 1.398_2e-3;
    const P2: f32 = 8.333_452e-3;
    const P3: f32 = 4.166_579_6e-2;
    const P4: f32 = 1.666_666_5e-1;
    const P5: f32 = 5.000_000_1e-1;

    let x = _mm256_max_ps(
        _mm256_set1_ps(EXP_LO),
        _mm256_min_ps(_mm256_set1_ps(EXP_HI), x),
    );

    // n = round(x * log2 e), using the MXCSR rounding mode (nearest by default).
    let n_int = _mm256_cvtps_epi32(_mm256_mul_ps(x, _mm256_set1_ps(LOG2_E)));
    let n = _mm256_cvtepi32_ps(n_int);

    // r = x - n*ln2, subtracted in two steps for accuracy.
    let r = _mm256_sub_ps(x, _mm256_mul_ps(n, _mm256_set1_ps(LN2_HI)));
    let r = _mm256_sub_ps(r, _mm256_mul_ps(n, _mm256_set1_ps(LN2_LO)));

    // e^r ≈ 1 + r + r²·P(r), with P evaluated by Horner's rule.
    let mut p = _mm256_set1_ps(P0);
    for &c in &[P1, P2, P3, P4, P5] {
        p = _mm256_add_ps(_mm256_mul_ps(p, r), _mm256_set1_ps(c));
    }
    let r2 = _mm256_mul_ps(r, r);
    let exp_r = _mm256_add_ps(_mm256_add_ps(_mm256_mul_ps(p, r2), r), _mm256_set1_ps(1.0));

    // 2^n: write the biased exponent (n + 127) into bits 23..31.
    let biased = _mm256_add_epi32(n_int, _mm256_set1_epi32(127));
    let pow2n = _mm256_castsi256_ps(_mm256_slli_epi32(biased, 23));

    _mm256_mul_ps(exp_r, pow2n)
}
148
149pub fn relu(input: &[f32], output: &mut [f32]) {
151 assert_eq!(
152 input.len(),
153 output.len(),
154 "Vectors must have the same length"
155 );
156
157 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
158 {
159 if crate::simd_feature_detected!("avx2") {
160 unsafe { relu_avx2(input, output) };
161 return;
162 } else if crate::simd_feature_detected!("sse2") {
163 unsafe { relu_sse2(input, output) };
164 return;
165 }
166 }
167
168 relu_scalar(input, output);
169}
170
/// Portable fallback for [`relu`]. Because it uses `f32::max`, NaN inputs map to `0.0`.
fn relu_scalar(input: &[f32], output: &mut [f32]) {
    for (idx, &value) in input.iter().enumerate() {
        output[idx] = value.max(0.0);
    }
}
176
/// SSE2 kernel for [`relu`]. It applies `_mm_max_ps` to four lanes per iteration and
/// handles the tail with scalar code.
///
/// # Safety
/// The executing CPU must support SSE2.
///
/// # Panics
/// Panics if `output` is shorter than `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn relu_sse2(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Bounds-check once so that no vector store can run past the end of `output`.
    let output = &mut output[..input.len()];
    let zero = _mm_setzero_ps();

    let mut src = input.chunks_exact(4);
    let mut dst = output.chunks_exact_mut(4);
    for (s, d) in (&mut src).zip(&mut dst) {
        let x = _mm_loadu_ps(s.as_ptr());
        // maxps returns its second operand for NaN lanes, which matches
        // `f32::max(NaN, 0.0) == 0.0` in the tail.
        _mm_storeu_ps(d.as_mut_ptr(), _mm_max_ps(x, zero));
    }
    for (d, &x) in dst.into_remainder().iter_mut().zip(src.remainder()) {
        *d = x.max(0.0);
    }
}
197
/// AVX2 kernel for [`relu`]. It applies `_mm256_max_ps` to eight lanes per iteration and
/// handles the tail with scalar code.
///
/// # Safety
/// The executing CPU must support AVX2.
///
/// # Panics
/// Panics if `output` is shorter than `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn relu_avx2(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Bounds-check once so that no vector store can run past the end of `output`.
    let output = &mut output[..input.len()];
    let zero = _mm256_setzero_ps();

    let mut src = input.chunks_exact(8);
    let mut dst = output.chunks_exact_mut(8);
    for (s, d) in (&mut src).zip(&mut dst) {
        let x = _mm256_loadu_ps(s.as_ptr());
        // vmaxps returns its second operand for NaN lanes, which matches
        // `f32::max(NaN, 0.0) == 0.0` in the tail.
        _mm256_storeu_ps(d.as_mut_ptr(), _mm256_max_ps(x, zero));
    }
    for (d, &x) in dst.into_remainder().iter_mut().zip(src.remainder()) {
        *d = x.max(0.0);
    }
}
218
219pub fn leaky_relu(input: &[f32], output: &mut [f32], alpha: f32) {
221 assert_eq!(
222 input.len(),
223 output.len(),
224 "Vectors must have the same length"
225 );
226
227 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
228 {
229 if crate::simd_feature_detected!("avx2") {
230 unsafe { leaky_relu_avx2(input, output, alpha) };
231 return;
232 } else if crate::simd_feature_detected!("sse2") {
233 unsafe { leaky_relu_sse2(input, output, alpha) };
234 return;
235 }
236 }
237
238 leaky_relu_scalar(input, output, alpha);
239}
240
/// Portable fallback for [`leaky_relu`]. Non-positive and NaN inputs are scaled by `alpha`.
fn leaky_relu_scalar(input: &[f32], output: &mut [f32], alpha: f32) {
    for (idx, &value) in input.iter().enumerate() {
        let is_positive = value > 0.0;
        output[idx] = match is_positive {
            true => value,
            false => alpha * value,
        };
    }
}
250
/// SSE2 kernel for [`leaky_relu`]. It processes four lanes per iteration and handles the
/// tail with scalar code.
///
/// The positive/negative selection uses `and`/`andnot`/`or` bit masking.
/// `_mm_blendv_ps` is an SSE4.1 instruction and would fault on CPUs that have SSE2 but
/// not SSE4.1, and this kernel must run on such CPUs.
///
/// # Safety
/// The executing CPU must support SSE2.
///
/// # Panics
/// Panics if `output` is shorter than `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn leaky_relu_sse2(input: &[f32], output: &mut [f32], alpha: f32) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Bounds-check once so that no vector store can run past the end of `output`.
    let output = &mut output[..input.len()];
    let zero = _mm_setzero_ps();
    let alpha_vec = _mm_set1_ps(alpha);

    let mut src = input.chunks_exact(4);
    let mut dst = output.chunks_exact_mut(4);
    for (s, d) in (&mut src).zip(&mut dst) {
        let x = _mm_loadu_ps(s.as_ptr());
        let positive = _mm_cmpgt_ps(x, zero);
        let scaled = _mm_mul_ps(x, alpha_vec);
        let result = _mm_or_ps(_mm_and_ps(positive, x), _mm_andnot_ps(positive, scaled));
        _mm_storeu_ps(d.as_mut_ptr(), result);
    }
    for (d, &x) in dst.into_remainder().iter_mut().zip(src.remainder()) {
        *d = if x > 0.0 { x } else { alpha * x };
    }
}
279
/// AVX2 kernel for [`leaky_relu`]. It processes eight lanes per iteration, using
/// `_mm256_blendv_ps` (an AVX instruction) for the selection, and handles the tail with
/// scalar code.
///
/// # Safety
/// The executing CPU must support AVX2.
///
/// # Panics
/// Panics if `output` is shorter than `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn leaky_relu_avx2(input: &[f32], output: &mut [f32], alpha: f32) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Bounds-check once so that no vector store can run past the end of `output`.
    let output = &mut output[..input.len()];
    let zero = _mm256_setzero_ps();
    let alpha_vec = _mm256_set1_ps(alpha);

    let mut src = input.chunks_exact(8);
    let mut dst = output.chunks_exact_mut(8);
    for (s, d) in (&mut src).zip(&mut dst) {
        let x = _mm256_loadu_ps(s.as_ptr());
        let positive = _mm256_cmp_ps(x, zero, _CMP_GT_OQ);
        let scaled = _mm256_mul_ps(x, alpha_vec);
        _mm256_storeu_ps(d.as_mut_ptr(), _mm256_blendv_ps(scaled, x, positive));
    }
    for (d, &x) in dst.into_remainder().iter_mut().zip(src.remainder()) {
        *d = if x > 0.0 { x } else { alpha * x };
    }
}
308
309pub fn tanh_activation(input: &[f32], output: &mut [f32]) {
311 assert_eq!(
312 input.len(),
313 output.len(),
314 "Vectors must have the same length"
315 );
316
317 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
318 {
319 if crate::simd_feature_detected!("avx2") {
320 unsafe { tanh_avx2(input, output) };
321 return;
322 } else if crate::simd_feature_detected!("sse2") {
323 unsafe { tanh_sse2(input, output) };
324 return;
325 }
326 }
327
328 tanh_scalar(input, output);
329}
330
/// Portable fallback for [`tanh_activation`], delegating to `f32::tanh`.
fn tanh_scalar(input: &[f32], output: &mut [f32]) {
    for (idx, &value) in input.iter().enumerate() {
        let activated = value.tanh();
        output[idx] = activated;
    }
}
336
337#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
338#[target_feature(enable = "sse2")]
339unsafe fn tanh_sse2(input: &[f32], output: &mut [f32]) {
340 use core::arch::x86_64::*;
341
342 let mut i = 0;
343
344 while i + 4 <= input.len() {
345 let x = _mm_loadu_ps(input.as_ptr().add(i));
346 let result = tanh_approx_sse2(x);
347 _mm_storeu_ps(output.as_mut_ptr().add(i), result);
348 i += 4;
349 }
350
351 while i < input.len() {
352 output[i] = input[i].tanh();
353 i += 1;
354 }
355}
356
357#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
358#[target_feature(enable = "avx2")]
359unsafe fn tanh_avx2(input: &[f32], output: &mut [f32]) {
360 use core::arch::x86_64::*;
361
362 let mut i = 0;
363
364 while i + 8 <= input.len() {
365 let x = _mm256_loadu_ps(input.as_ptr().add(i));
366 let result = tanh_approx_avx2(x);
367 _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
368 i += 8;
369 }
370
371 while i < input.len() {
372 output[i] = input[i].tanh();
373 i += 1;
374 }
375}
376
/// Vectorised `tanh(x)` for four `f32` lanes, using SSE2 only.
///
/// Uses the odd/even rational minimax approximation `p(x) / q(x)` from Eigen's
/// `generic_fast_tanh_float`:
/// - `p` is a degree-13 odd polynomial and `q` is a degree-6 even polynomial.
/// - Inputs are clamped to about ±7.905, beyond which `tanh` rounds to ±1 in `f32`.
/// - Lanes with `|x| < 4e-4` return `x` itself, which is exact to `f32` precision and
///   preserves `tanh(±0) = ±0`.
///
/// Absolute error is on the order of `1e-7`. NaN lanes propagate.
///
/// # Safety
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn tanh_approx_sse2(x: x86_arch::__m128) -> x86_arch::__m128 {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const CLAMP: f32 = 7.905_311;
    const TINY: f32 = 0.000_4;
    // Numerator coefficients (odd powers).
    const A1: f32 = 4.893_524_6e-3;
    const A3: f32 = 6.372_619_3e-4;
    const A5: f32 = 1.485_722_4e-5;
    const A7: f32 = 5.122_297e-8;
    const A9: f32 = -8.604_671_5e-11;
    const A11: f32 = 2.000_188e-13;
    const A13: f32 = -2.760_768_5e-16;
    // Denominator coefficients (even powers).
    const B0: f32 = 4.893_525_2e-3;
    const B2: f32 = 2.268_434_6e-3;
    const B4: f32 = 1.185_347e-4;
    const B6: f32 = 1.198_258_4e-6;

    let clamped = _mm_max_ps(_mm_set1_ps(-CLAMP), _mm_min_ps(_mm_set1_ps(CLAMP), x));
    let x2 = _mm_mul_ps(clamped, clamped);

    // p(x) = x * (A1 + x²(A3 + x²(... + x²·A13))), evaluated by Horner's rule in x².
    let mut p = _mm_set1_ps(A13);
    for &c in &[A11, A9, A7, A5, A3, A1] {
        p = _mm_add_ps(_mm_mul_ps(p, x2), _mm_set1_ps(c));
    }
    let p = _mm_mul_ps(p, clamped);

    // q(x) = B0 + x²(B2 + x²(B4 + x²·B6)).
    let mut q = _mm_set1_ps(B6);
    for &c in &[B4, B2, B0] {
        q = _mm_add_ps(_mm_mul_ps(q, x2), _mm_set1_ps(c));
    }
    let ratio = _mm_div_ps(p, q);

    // Select x for tiny lanes. SSE2 has no blendv, so use and/andnot/or.
    let abs_x = _mm_andnot_ps(_mm_set1_ps(-0.0), x);
    let tiny = _mm_cmplt_ps(abs_x, _mm_set1_ps(TINY));
    _mm_or_ps(_mm_and_ps(tiny, x), _mm_andnot_ps(tiny, ratio))
}
396
/// Vectorised `tanh(x)` for eight `f32` lanes, using AVX2.
///
/// Uses the odd/even rational minimax approximation `p(x) / q(x)` from Eigen's
/// `generic_fast_tanh_float`:
/// - `p` is a degree-13 odd polynomial and `q` is a degree-6 even polynomial.
/// - Inputs are clamped to about ±7.905, beyond which `tanh` rounds to ±1 in `f32`.
/// - Lanes with `|x| < 4e-4` return `x` itself.
///
/// Absolute error is on the order of `1e-7`. NaN lanes propagate.
///
/// # Safety
/// The executing CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn tanh_approx_avx2(x: x86_arch::__m256) -> x86_arch::__m256 {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const CLAMP: f32 = 7.905_311;
    const TINY: f32 = 0.000_4;
    // Numerator coefficients (odd powers).
    const A1: f32 = 4.893_524_6e-3;
    const A3: f32 = 6.372_619_3e-4;
    const A5: f32 = 1.485_722_4e-5;
    const A7: f32 = 5.122_297e-8;
    const A9: f32 = -8.604_671_5e-11;
    const A11: f32 = 2.000_188e-13;
    const A13: f32 = -2.760_768_5e-16;
    // Denominator coefficients (even powers).
    const B0: f32 = 4.893_525_2e-3;
    const B2: f32 = 2.268_434_6e-3;
    const B4: f32 = 1.185_347e-4;
    const B6: f32 = 1.198_258_4e-6;

    let clamped = _mm256_max_ps(
        _mm256_set1_ps(-CLAMP),
        _mm256_min_ps(_mm256_set1_ps(CLAMP), x),
    );
    let x2 = _mm256_mul_ps(clamped, clamped);

    // p(x) = x * (A1 + x²(A3 + x²(... + x²·A13))), evaluated by Horner's rule in x².
    let mut p = _mm256_set1_ps(A13);
    for &c in &[A11, A9, A7, A5, A3, A1] {
        p = _mm256_add_ps(_mm256_mul_ps(p, x2), _mm256_set1_ps(c));
    }
    let p = _mm256_mul_ps(p, clamped);

    // q(x) = B0 + x²(B2 + x²(B4 + x²·B6)).
    let mut q = _mm256_set1_ps(B6);
    for &c in &[B4, B2, B0] {
        q = _mm256_add_ps(_mm256_mul_ps(q, x2), _mm256_set1_ps(c));
    }
    let ratio = _mm256_div_ps(p, q);

    // Lanes with |x| < TINY take x unchanged.
    let abs_x = _mm256_andnot_ps(_mm256_set1_ps(-0.0), x);
    let tiny = _mm256_cmp_ps(abs_x, _mm256_set1_ps(TINY), _CMP_LT_OQ);
    _mm256_blendv_ps(ratio, x, tiny)
}
414
415pub fn softmax(input: &[f32], output: &mut [f32]) {
417 assert_eq!(
418 input.len(),
419 output.len(),
420 "Vectors must have the same length"
421 );
422
423 if input.is_empty() {
424 return;
425 }
426
427 let max_val = input.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
429
430 let mut exp_values = vec![0.0; input.len()];
432 for i in 0..input.len() {
433 exp_values[i] = (input[i] - max_val).exp();
434 }
435
436 let exp_sum = sum(&exp_values);
438
439 for i in 0..input.len() {
441 output[i] = exp_values[i] / exp_sum;
442 }
443}
444
445pub fn elu(input: &[f32], output: &mut [f32], alpha: f32) {
447 assert_eq!(
448 input.len(),
449 output.len(),
450 "Vectors must have the same length"
451 );
452
453 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
454 {
455 if crate::simd_feature_detected!("avx2") {
456 unsafe { elu_avx2(input, output, alpha) };
457 return;
458 } else if crate::simd_feature_detected!("sse2") {
459 unsafe { elu_sse2(input, output, alpha) };
460 return;
461 }
462 }
463
464 elu_scalar(input, output, alpha);
465}
466
/// Portable fallback for [`elu`].
fn elu_scalar(input: &[f32], output: &mut [f32], alpha: f32) {
    for (idx, &value) in input.iter().enumerate() {
        let non_negative = value >= 0.0;
        output[idx] = match non_negative {
            true => value,
            false => alpha * (value.exp() - 1.0),
        };
    }
}
476
477#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
478#[target_feature(enable = "sse2")]
479unsafe fn elu_sse2(input: &[f32], output: &mut [f32], alpha: f32) {
480 use core::arch::x86_64::*;
481
482 let zero = _mm_setzero_ps();
483 let one = _mm_set1_ps(1.0);
484 let alpha_vec = _mm_set1_ps(alpha);
485 let mut i = 0;
486
487 while i + 4 <= input.len() {
488 let x = _mm_loadu_ps(input.as_ptr().add(i));
489 let mask = _mm_cmpge_ps(x, zero);
490
491 let positive = x;
492 let exp_x = exp_approx_sse2(x);
493 let negative = _mm_mul_ps(alpha_vec, _mm_sub_ps(exp_x, one));
494
495 let result = _mm_blendv_ps(negative, positive, mask);
496 _mm_storeu_ps(output.as_mut_ptr().add(i), result);
497 i += 4;
498 }
499
500 while i < input.len() {
501 output[i] = if input[i] >= 0.0 {
502 input[i]
503 } else {
504 alpha * (input[i].exp() - 1.0)
505 };
506 i += 1;
507 }
508}
509
510#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
511#[target_feature(enable = "avx2")]
512unsafe fn elu_avx2(input: &[f32], output: &mut [f32], alpha: f32) {
513 use core::arch::x86_64::*;
514
515 let zero = _mm256_setzero_ps();
516 let one = _mm256_set1_ps(1.0);
517 let alpha_vec = _mm256_set1_ps(alpha);
518 let mut i = 0;
519
520 while i + 8 <= input.len() {
521 let x = _mm256_loadu_ps(input.as_ptr().add(i));
522 let mask = _mm256_cmp_ps(x, zero, _CMP_GE_OQ);
523
524 let positive = x;
525 let exp_x = exp_approx_avx2(x);
526 let negative = _mm256_mul_ps(alpha_vec, _mm256_sub_ps(exp_x, one));
527
528 let result = _mm256_blendv_ps(negative, positive, mask);
529 _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
530 i += 8;
531 }
532
533 while i < input.len() {
534 output[i] = if input[i] >= 0.0 {
535 input[i]
536 } else {
537 alpha * (input[i].exp() - 1.0)
538 };
539 i += 1;
540 }
541}
542
543pub fn swish(input: &[f32], output: &mut [f32]) {
545 assert_eq!(
546 input.len(),
547 output.len(),
548 "Vectors must have the same length"
549 );
550
551 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
552 {
553 if crate::simd_feature_detected!("avx2") {
554 unsafe { swish_avx2(input, output) };
555 return;
556 } else if crate::simd_feature_detected!("sse2") {
557 unsafe { swish_sse2(input, output) };
558 return;
559 }
560 }
561
562 swish_scalar(input, output);
563}
564
/// Portable fallback for [`swish`]. Gates each input by its own sigmoid.
fn swish_scalar(input: &[f32], output: &mut [f32]) {
    for (idx, &value) in input.iter().enumerate() {
        let gate = 1.0 / (1.0 + (-value).exp());
        output[idx] = value * gate;
    }
}
571
572#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
573#[target_feature(enable = "sse2")]
574unsafe fn swish_sse2(input: &[f32], output: &mut [f32]) {
575 use core::arch::x86_64::*;
576
577 let one = _mm_set1_ps(1.0);
578 let mut i = 0;
579
580 while i + 4 <= input.len() {
581 let x = _mm_loadu_ps(input.as_ptr().add(i));
582
583 let neg_x = _mm_sub_ps(_mm_setzero_ps(), x);
584 let exp_neg_x = exp_approx_sse2(neg_x);
585 let one_plus_exp = _mm_add_ps(one, exp_neg_x);
586 let sigmoid_x = _mm_div_ps(one, one_plus_exp);
587
588 let result = _mm_mul_ps(x, sigmoid_x);
589 _mm_storeu_ps(output.as_mut_ptr().add(i), result);
590 i += 4;
591 }
592
593 while i < input.len() {
594 let sigmoid_x = 1.0 / (1.0 + (-input[i]).exp());
595 output[i] = input[i] * sigmoid_x;
596 i += 1;
597 }
598}
599
600#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
601#[target_feature(enable = "avx2")]
602unsafe fn swish_avx2(input: &[f32], output: &mut [f32]) {
603 use core::arch::x86_64::*;
604
605 let one = _mm256_set1_ps(1.0);
606 let mut i = 0;
607
608 while i + 8 <= input.len() {
609 let x = _mm256_loadu_ps(input.as_ptr().add(i));
610
611 let neg_x = _mm256_sub_ps(_mm256_setzero_ps(), x);
612 let exp_neg_x = exp_approx_avx2(neg_x);
613 let one_plus_exp = _mm256_add_ps(one, exp_neg_x);
614 let sigmoid_x = _mm256_div_ps(one, one_plus_exp);
615
616 let result = _mm256_mul_ps(x, sigmoid_x);
617 _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
618 i += 8;
619 }
620
621 while i < input.len() {
622 let sigmoid_x = 1.0 / (1.0 + (-input[i]).exp());
623 output[i] = input[i] * sigmoid_x;
624 i += 1;
625 }
626}
627
628pub fn gelu(input: &[f32], output: &mut [f32]) {
630 assert_eq!(
631 input.len(),
632 output.len(),
633 "Vectors must have the same length"
634 );
635
636 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
637 {
638 if crate::simd_feature_detected!("avx2") {
639 unsafe { gelu_avx2(input, output) };
640 return;
641 } else if crate::simd_feature_detected!("sse2") {
642 unsafe { gelu_sse2(input, output) };
643 return;
644 }
645 }
646
647 gelu_scalar(input, output);
648}
649
/// Portable fallback for [`gelu`] (tanh approximation).
fn gelu_scalar(input: &[f32], output: &mut [f32]) {
    // √(2/π).
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;

    for (idx, &x) in input.iter().enumerate() {
        let cubic = x * x * x;
        let inner = SQRT_2_OVER_PI * (x + 0.044715 * cubic);
        output[idx] = 0.5 * x * (1.0 + inner.tanh());
    }
}
660
661#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
662#[target_feature(enable = "sse2")]
663unsafe fn gelu_sse2(input: &[f32], output: &mut [f32]) {
664 use core::arch::x86_64::*;
665
666 let sqrt_2_pi = _mm_set1_ps(0.797_884_6_f32);
667 let coeff = _mm_set1_ps(0.044715);
668 let half = _mm_set1_ps(0.5);
669 let one = _mm_set1_ps(1.0);
670 let mut i = 0;
671
672 while i + 4 <= input.len() {
673 let x = _mm_loadu_ps(input.as_ptr().add(i));
674
675 let x2 = _mm_mul_ps(x, x);
676 let x3 = _mm_mul_ps(x2, x);
677 let coeff_x3 = _mm_mul_ps(coeff, x3);
678 let inner_term = _mm_add_ps(x, coeff_x3);
679 let scaled_inner = _mm_mul_ps(sqrt_2_pi, inner_term);
680
681 let tanh_result = tanh_approx_sse2(scaled_inner);
682 let one_plus_tanh = _mm_add_ps(one, tanh_result);
683 let result = _mm_mul_ps(_mm_mul_ps(half, x), one_plus_tanh);
684
685 _mm_storeu_ps(output.as_mut_ptr().add(i), result);
686 i += 4;
687 }
688
689 while i < input.len() {
690 let x = input[i];
691 let x_cubed = x * x * x;
692 let inner = 0.797_884_6_f32 * (x + 0.044715 * x_cubed);
693 output[i] = 0.5 * x * (1.0 + inner.tanh());
694 i += 1;
695 }
696}
697
698#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
699#[target_feature(enable = "avx2")]
700unsafe fn gelu_avx2(input: &[f32], output: &mut [f32]) {
701 use core::arch::x86_64::*;
702
703 let sqrt_2_pi = _mm256_set1_ps(0.797_884_6_f32);
704 let coeff = _mm256_set1_ps(0.044715);
705 let half = _mm256_set1_ps(0.5);
706 let one = _mm256_set1_ps(1.0);
707 let mut i = 0;
708
709 while i + 8 <= input.len() {
710 let x = _mm256_loadu_ps(input.as_ptr().add(i));
711
712 let x2 = _mm256_mul_ps(x, x);
713 let x3 = _mm256_mul_ps(x2, x);
714 let coeff_x3 = _mm256_mul_ps(coeff, x3);
715 let inner_term = _mm256_add_ps(x, coeff_x3);
716 let scaled_inner = _mm256_mul_ps(sqrt_2_pi, inner_term);
717
718 let tanh_result = tanh_approx_avx2(scaled_inner);
719 let one_plus_tanh = _mm256_add_ps(one, tanh_result);
720 let result = _mm256_mul_ps(_mm256_mul_ps(half, x), one_plus_tanh);
721
722 _mm256_storeu_ps(output.as_mut_ptr().add(i), result);
723 i += 8;
724 }
725
726 while i < input.len() {
727 let x = input[i];
728 let x_cubed = x * x * x;
729 let inner = 0.797_884_6_f32 * (x + 0.044715 * x_cubed);
730 output[i] = 0.5 * x * (1.0 + inner.tanh());
731 i += 1;
732 }
733}
734
735pub fn sigmoid_derivative(input: &[f32], output: &mut [f32]) {
739 assert_eq!(
740 input.len(),
741 output.len(),
742 "Vectors must have the same length"
743 );
744
745 let mut sigmoid_vals = vec![0.0; input.len()];
747 sigmoid(input, &mut sigmoid_vals);
748
749 for i in 0..input.len() {
750 output[i] = sigmoid_vals[i] * (1.0 - sigmoid_vals[i]);
751 }
752}
753
754pub fn relu_derivative(input: &[f32], output: &mut [f32]) {
756 assert_eq!(
757 input.len(),
758 output.len(),
759 "Vectors must have the same length"
760 );
761
762 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
763 {
764 if crate::simd_feature_detected!("avx2") {
765 unsafe { relu_derivative_avx2(input, output) };
766 return;
767 } else if crate::simd_feature_detected!("sse2") {
768 unsafe { relu_derivative_sse2(input, output) };
769 return;
770 }
771 }
772
773 relu_derivative_scalar(input, output);
774}
775
/// Portable fallback for [`relu_derivative`]. Turns the `x > 0` test directly into
/// `1.0` or `0.0`.
fn relu_derivative_scalar(input: &[f32], output: &mut [f32]) {
    for (idx, &value) in input.iter().enumerate() {
        output[idx] = f32::from(u8::from(value > 0.0));
    }
}
781
/// SSE2 kernel for [`relu_derivative`]. It processes four lanes per iteration.
///
/// The all-ones `x > 0` comparison mask is ANDed with `1.0`, so each lane becomes
/// exactly `1.0` or `0.0`. The tail is handled with scalar code.
///
/// # Safety
/// The executing CPU must support SSE2.
///
/// # Panics
/// Panics if `output` is shorter than `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn relu_derivative_sse2(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Bounds-check once so that no vector store can run past the end of `output`.
    let output = &mut output[..input.len()];
    let zero = _mm_setzero_ps();
    let one = _mm_set1_ps(1.0);

    let mut src = input.chunks_exact(4);
    let mut dst = output.chunks_exact_mut(4);
    for (s, d) in (&mut src).zip(&mut dst) {
        let x = _mm_loadu_ps(s.as_ptr());
        let mask = _mm_cmpgt_ps(x, zero);
        _mm_storeu_ps(d.as_mut_ptr(), _mm_and_ps(mask, one));
    }
    for (d, &x) in dst.into_remainder().iter_mut().zip(src.remainder()) {
        *d = if x > 0.0 { 1.0 } else { 0.0 };
    }
}
804
/// AVX2 kernel for [`relu_derivative`]. It processes eight lanes per iteration.
///
/// The all-ones `x > 0` comparison mask is ANDed with `1.0`, so each lane becomes
/// exactly `1.0` or `0.0`. The tail is handled with scalar code.
///
/// # Safety
/// The executing CPU must support AVX2.
///
/// # Panics
/// Panics if `output` is shorter than `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn relu_derivative_avx2(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Bounds-check once so that no vector store can run past the end of `output`.
    let output = &mut output[..input.len()];
    let zero = _mm256_setzero_ps();
    let one = _mm256_set1_ps(1.0);

    let mut src = input.chunks_exact(8);
    let mut dst = output.chunks_exact_mut(8);
    for (s, d) in (&mut src).zip(&mut dst) {
        let x = _mm256_loadu_ps(s.as_ptr());
        let mask = _mm256_cmp_ps(x, zero, _CMP_GT_OQ);
        _mm256_storeu_ps(d.as_mut_ptr(), _mm256_and_ps(mask, one));
    }
    for (d, &x) in dst.into_remainder().iter_mut().zip(src.remainder()) {
        *d = if x > 0.0 { 1.0 } else { 0.0 };
    }
}
827
828pub fn tanh_derivative(input: &[f32], output: &mut [f32]) {
830 assert_eq!(
831 input.len(),
832 output.len(),
833 "Vectors must have the same length"
834 );
835
836 let mut tanh_vals = vec![0.0; input.len()];
838 tanh_activation(input, &mut tanh_vals);
839
840 for i in 0..input.len() {
841 output[i] = 1.0 - tanh_vals[i] * tanh_vals[i];
842 }
843}
844
845pub fn apply_activation_1d(
849 input: &Array1<f32>,
850 activation: ActivationFunction,
851 alpha: Option<f32>,
852) -> Array1<f32> {
853 let mut output = Array1::zeros(input.len());
854 apply_activation_slice(
855 input.as_slice().expect("slice operation should succeed"),
856 output
857 .as_slice_mut()
858 .expect("slice operation should succeed"),
859 activation,
860 alpha,
861 );
862 output
863}
864
865pub fn apply_activation_2d(
867 input: &Array2<f32>,
868 activation: ActivationFunction,
869 alpha: Option<f32>,
870) -> Array2<f32> {
871 let mut output = Array2::zeros(input.dim());
872 if let (Some(input_slice), Some(output_slice)) = (input.as_slice(), output.as_slice_mut()) {
873 apply_activation_slice(input_slice, output_slice, activation, alpha);
874 }
875 output
876}
877
878pub fn apply_activation_slice(
880 input: &[f32],
881 output: &mut [f32],
882 activation: ActivationFunction,
883 alpha: Option<f32>,
884) {
885 match activation {
886 ActivationFunction::ReLU => relu(input, output),
887 ActivationFunction::Sigmoid => sigmoid(input, output),
888 ActivationFunction::Tanh => tanh_activation(input, output),
889 ActivationFunction::LeakyReLU => {
890 let alpha_val = alpha.unwrap_or(0.01);
891 leaky_relu(input, output, alpha_val);
892 }
893 ActivationFunction::ELU => {
894 let alpha_val = alpha.unwrap_or(1.0);
895 elu(input, output, alpha_val);
896 }
897 ActivationFunction::Swish => swish(input, output),
898 ActivationFunction::GELU => gelu(input, output),
899 ActivationFunction::Softmax => softmax(input, output),
900 }
901}
902
/// The activation functions supported by [`apply_activation_slice`] and its `ndarray`
/// wrappers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ActivationFunction {
    /// `max(x, 0)`.
    ReLU,
    /// Logistic sigmoid `1 / (1 + e^-x)`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// `x` for `x > 0`, `alpha * x` otherwise.
    LeakyReLU,
    /// `x` for `x >= 0`, `alpha * (e^x - 1)` otherwise.
    ELU,
    /// `x * sigmoid(x)`.
    Swish,
    /// Gaussian error linear unit (tanh approximation).
    GELU,
    /// Softmax normalised over the whole input.
    Softmax,
}
915
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_sigmoid() {
        let input = vec![0.0, 1.0, -1.0, 2.0, -2.0];
        let mut output = vec![0.0; 5];

        sigmoid(&input, &mut output);

        assert_relative_eq!(output[0], 0.5, epsilon = 1e-3);
        assert!(output[1] > 0.7 && output[1] < 0.8);
        assert!(output[2] > 0.2 && output[2] < 0.3);
        assert!(output[3] > 0.8 && output[3] < 0.9);
        assert!(output[4] > 0.1 && output[4] < 0.2);
    }

    #[test]
    fn test_relu() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0; 5];

        relu(&input, &mut output);

        assert_relative_eq!(output[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[1], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[2], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[3], 1.0, epsilon = 1e-6);
        assert_relative_eq!(output[4], 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_leaky_relu() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0; 5];
        let alpha = 0.1;

        leaky_relu(&input, &mut output, alpha);

        assert_relative_eq!(output[0], -0.2, epsilon = 1e-6);
        assert_relative_eq!(output[1], -0.1, epsilon = 1e-6);
        assert_relative_eq!(output[2], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[3], 1.0, epsilon = 1e-6);
        assert_relative_eq!(output[4], 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_tanh_activation() {
        let input = vec![0.0, 1.0, -1.0, 2.0, -2.0];
        let mut output = vec![0.0; 5];

        tanh_activation(&input, &mut output);

        assert_relative_eq!(output[0], 0.0, epsilon = 1e-3);
        assert!(output[1] > 0.7 && output[1] < 0.8);
        assert!(output[2] > -0.8 && output[2] < -0.7);
        assert!(output[3] > 0.9);
        assert!(output[4] < -0.9);
    }

    #[test]
    fn test_softmax() {
        let input = vec![1.0, 2.0, 3.0];
        let mut output = vec![0.0; 3];

        softmax(&input, &mut output);

        let sum: f32 = output.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);

        assert!(output[2] > output[1]);
        assert!(output[1] > output[0]);
    }

    #[test]
    fn test_elu() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0; 5];
        let alpha = 1.0;

        elu(&input, &mut output, alpha);

        assert_relative_eq!(output[2], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[3], 1.0, epsilon = 1e-6);
        assert_relative_eq!(output[4], 2.0, epsilon = 1e-6);

        assert!(output[0] < 0.0 && output[0] > -alpha);
        assert!(output[1] < 0.0 && output[1] > output[0]);
    }

    #[test]
    fn test_swish() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0; 5];

        swish(&input, &mut output);

        assert_relative_eq!(output[2], 0.0, epsilon = 1e-3);

        assert!(output[3] > 0.0);
        assert!(output[4] > output[3]);

        assert!(output[0] < 0.0);
        assert!(output[1] < 0.0);
    }

    #[test]
    fn test_gelu() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0; 5];

        gelu(&input, &mut output);

        assert_relative_eq!(output[2], 0.0, epsilon = 1e-3);

        assert!(output[3] > 0.0);
        assert!(output[4] > output[3]);

        for &val in &output {
            assert!(!val.is_nan());
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_relu_derivative() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0; 5];

        relu_derivative(&input, &mut output);

        assert_relative_eq!(output[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[1], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[2], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[3], 1.0, epsilon = 1e-6);
        assert_relative_eq!(output[4], 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_sigmoid_derivative() {
        let input = vec![0.0, 1.0, -1.0];
        let mut output = vec![0.0; 3];

        sigmoid_derivative(&input, &mut output);

        assert_relative_eq!(output[0], 0.25, epsilon = 1e-3);

        for &val in &output {
            assert!(val > 0.0);
        }
    }

    #[test]
    fn test_tanh_derivative() {
        let input = vec![0.0, 1.0, -1.0];
        let mut output = vec![0.0; 3];

        tanh_derivative(&input, &mut output);

        assert_relative_eq!(output[0], 1.0, epsilon = 1e-3);

        for &val in &output {
            assert!(val > 0.0 && val <= 1.0);
        }
    }

    #[test]
    fn test_activation_function_enum() {
        let input = vec![1.0, 2.0, 3.0];
        let mut output = vec![0.0; 3];

        apply_activation_slice(&input, &mut output, ActivationFunction::ReLU, None);
        assert_eq!(output, input);

        apply_activation_slice(&input, &mut output, ActivationFunction::Sigmoid, None);
        assert!(output.iter().all(|&x| x > 0.0 && x < 1.0));

        apply_activation_slice(&input, &mut output, ActivationFunction::Softmax, None);
        let sum: f32 = output.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_ndarray_interface() {
        let input_1d = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let output_1d = apply_activation_1d(&input_1d, ActivationFunction::ReLU, None);
        assert_eq!(
            output_1d
                .as_slice()
                .expect("slice operation should succeed"),
            &[1.0, 2.0, 3.0]
        );

        let input_2d = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
            .expect("shape and data length should match");
        let output_2d = apply_activation_2d(&input_2d, ActivationFunction::ReLU, None);
        assert_eq!(
            output_2d
                .as_slice()
                .expect("slice operation should succeed"),
            &[1.0, 2.0, 3.0, 4.0]
        );
    }

    #[test]
    fn test_activation_with_alpha() {
        let input = vec![-1.0, 0.0, 1.0];
        let mut output = vec![0.0; 3];

        apply_activation_slice(
            &input,
            &mut output,
            ActivationFunction::LeakyReLU,
            Some(0.2),
        );
        assert_relative_eq!(output[0], -0.2, epsilon = 1e-6);
        assert_relative_eq!(output[1], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[2], 1.0, epsilon = 1e-6);

        apply_activation_slice(&input, &mut output, ActivationFunction::ELU, Some(2.0));
        assert!(output[0] < 0.0 && output[0] > -2.0);
    }

    /// Long inputs drive the vectorised main loops (8-wide AVX2, 4-wide SSE2) as well as
    /// the scalar tail. The five-element fixtures above reach only the tail on AVX2
    /// machines.
    #[test]
    fn test_piecewise_linear_simd_paths_match_scalar() {
        let input: Vec<f32> = (0..19).map(|i| i as f32 * 0.75 - 7.0).collect();
        let mut output = vec![0.0; input.len()];

        relu(&input, &mut output);
        for (&x, &y) in input.iter().zip(&output) {
            assert_eq!(y, x.max(0.0));
        }

        leaky_relu(&input, &mut output, 0.1);
        for (&x, &y) in input.iter().zip(&output) {
            assert_eq!(y, if x > 0.0 { x } else { 0.1 * x });
        }

        relu_derivative(&input, &mut output);
        for (&x, &y) in input.iter().zip(&output) {
            assert_eq!(y, if x > 0.0 { 1.0 } else { 0.0 });
        }
    }
}