#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
use std::is_x86_feature_detected;

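/// Scalar fallback: dot product over the first `min(a.len(), b.len())` elements.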
#[inline]
fn dot_scalar(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len().min(b.len());
    let mut sum = 0.0;
    for i in 0..n {
        sum += a[i] * b[i];
    }
    sum
}

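/// Scalar fallback: row-major matrix-vector product, `out[row] = dot(w[row*cols..], x[..cols])`.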
#[inline]
fn mat_vec_scalar(w: &[f64], x: &[f64], _rows: usize, cols: usize, out: &mut [f64]) {
    for (row, out_i) in out.iter_mut().enumerate() {
        let start = row * cols;
        let mut sum = 0.0;
        for j in 0..cols {
            sum += w[start + j] * x[j];
        }
        *out_i = sum;
    }
}

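/// Scalar fallback: element-wise tanh via `crate::math::tanh`.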
#[inline]
fn tanh_scalar(input: &[f64], output: &mut [f64]) {
    for (i, &x) in input.iter().enumerate() {
        output[i] = crate::math::tanh(x);
    }
}

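/// Scalar fallback: element-wise exp via `crate::math::exp`.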
#[inline]
fn exp_scalar(input: &[f64], output: &mut [f64]) {
    for (i, &x) in input.iter().enumerate() {
        output[i] = crate::math::exp(x);
    }
}

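/// Scalar fallback: element-wise logistic sigmoid via `crate::math::sigmoid`.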
#[inline]
fn sigmoid_scalar(input: &[f64], output: &mut [f64]) {
    for (i, &x) in input.iter().enumerate() {
        output[i] = crate::math::sigmoid(x);
    }
}

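/// Scalar fallback: element-wise SiLU, `x * sigmoid(x)`.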
#[inline]
fn silu_scalar(input: &[f64], output: &mut [f64]) {
    for (i, &x) in input.iter().enumerate() {
        output[i] = x * crate::math::sigmoid(x);
    }
}

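/// AVX2 kernels, only compiled on x86_64 with the `simd-avx2` feature enabled.
/// Every function here requires runtime AVX2 support; the public `simd_*` wrappers
/// check `is_x86_feature_detected!("avx2")` before calling in.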
#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
mod avx2 {
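    /// Dot product over the first `min(a.len(), b.len())` elements: four lanes per
    /// iteration, plus a scalar tail for the remainder.
    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2.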
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn dot_avx2(a: &[f64], b: &[f64]) -> f64 {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        let n = a.len().min(b.len());
        let chunks = n / 4;
        let remainder = n % 4;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();

        unsafe {
            let mut acc = _mm256_setzero_pd();

            for i in 0..chunks {
                let offset = i * 4;
                let va = _mm256_loadu_pd(a_ptr.add(offset));
                let vb = _mm256_loadu_pd(b_ptr.add(offset));
                acc = _mm256_add_pd(acc, _mm256_mul_pd(va, vb));
            }

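            // Horizontal reduction: fold the four f64 lanes of `acc` into one scalar.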
            let hi128 = _mm256_extractf128_pd(acc, 1);
            let lo128 = _mm256_castpd256_pd128(acc);
            let pair = _mm_add_pd(lo128, hi128);
            let high64 = _mm_unpackhi_pd(pair, pair);
            let total = _mm_add_sd(pair, high64);
            let mut scalar_sum = _mm_cvtsd_f64(total);

            let base = chunks * 4;
            for i in 0..remainder {
                scalar_sum += *a_ptr.add(base + i) * *b_ptr.add(base + i);
            }

            scalar_sum
        }
    }

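    /// Row-major matrix-vector product; each output element is a `dot_avx2` over one
    /// row of `w`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2.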
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn mat_vec_avx2(
        w: &[f64],
        x: &[f64],
        _rows: usize,
        cols: usize,
        out: &mut [f64],
    ) {
        for (row, out_i) in out.iter_mut().enumerate() {
            let row_start = row * cols;
            unsafe {
                *out_i = dot_avx2(&w[row_start..row_start + cols], &x[..cols]);
            }
        }
    }

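    /// Element-wise tanh using a vectorized rational approximation with saturation;
    /// the scalar tail falls back to `crate::math::tanh`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2.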
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn tanh_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        let n = input.len();
        let chunks = n / 4;

        unsafe {
            let c15 = _mm256_set1_pd(15.0);
            let c6 = _mm256_set1_pd(6.0);
            let pos_sat = _mm256_set1_pd(4.97);
            let neg_sat = _mm256_set1_pd(-4.97);
            let one = _mm256_set1_pd(1.0);
            let neg_one = _mm256_set1_pd(-1.0);

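            // Per lane: tanh(x) ~ x * (15 + x^2) / (15 + 6 * x^2), clamped to [-1, 1],
            // then saturated to +/-1 once |x| exceeds 4.97.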
            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));
                let x2 = _mm256_mul_pd(x, x);

                let numer = _mm256_mul_pd(x, _mm256_add_pd(c15, x2));
                let denom = _mm256_add_pd(c15, _mm256_mul_pd(c6, x2));
                let approx = _mm256_div_pd(numer, denom);

                let clamped = _mm256_min_pd(one, _mm256_max_pd(neg_one, approx));

                let sat_pos = _mm256_cmp_pd(x, pos_sat, _CMP_GT_OQ);
                let sat_neg = _mm256_cmp_pd(x, neg_sat, _CMP_LT_OQ);
                let result = _mm256_blendv_pd(clamped, one, sat_pos);
                let result = _mm256_blendv_pd(result, neg_one, sat_neg);

                _mm256_storeu_pd(output.as_mut_ptr().add(off), result);
            }
        }

        for i in (chunks * 4)..n {
            output[i] = crate::math::tanh(input[i]);
        }
    }

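    /// Element-wise exp via range reduction (exp(x) = 2^n * exp(r)) and a degree-5
    /// polynomial; the scalar tail falls back to `crate::math::exp`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2.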
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn exp_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        let n = input.len();
        let chunks = n / 4;

        unsafe {
            let ln2 = _mm256_set1_pd(core::f64::consts::LN_2);
            let log2e = _mm256_set1_pd(core::f64::consts::LOG2_E);
            let clamp_hi = _mm256_set1_pd(708.0);
            let clamp_lo = _mm256_set1_pd(-708.0);
            let one = _mm256_set1_pd(1.0);
            let half = _mm256_set1_pd(0.5);
            let c3 = _mm256_set1_pd(1.0 / 6.0);
            let c4 = _mm256_set1_pd(1.0 / 24.0);
            let c5 = _mm256_set1_pd(1.0 / 120.0);
            let bias = _mm256_set1_epi64x(1023);

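            // Per lane: clamp x to [-708, 708], split x = n * ln(2) + r with small r,
            // evaluate a degree-5 Taylor polynomial for exp(r), then scale by 2^n
            // assembled directly in the f64 exponent bits.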
            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));

                let x = _mm256_min_pd(clamp_hi, _mm256_max_pd(clamp_lo, x));

                let x_scaled = _mm256_mul_pd(x, log2e);
                let n_f = _mm256_floor_pd(_mm256_add_pd(x_scaled, half));
                let r = _mm256_sub_pd(x, _mm256_mul_pd(n_f, ln2));

                let mut p = _mm256_add_pd(c4, _mm256_mul_pd(c5, r));
                p = _mm256_add_pd(c3, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(half, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));

                let n_i32 = _mm256_cvtpd_epi32(n_f);
                let n_i64 = _mm256_cvtepi32_epi64(n_i32);
                let shifted = _mm256_slli_epi64(_mm256_add_epi64(n_i64, bias), 52);
                let pow2n = _mm256_castsi256_pd(shifted);
                let result = _mm256_mul_pd(p, pow2n);

                _mm256_storeu_pd(output.as_mut_ptr().add(off), result);
            }
        }

        for i in (chunks * 4)..n {
            output[i] = crate::math::exp(input[i]);
        }
    }

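    /// Element-wise logistic sigmoid: computes exp(-x) with the same kernel as
    /// `exp_avx2`, then 1 / (1 + exp(-x)); the scalar tail falls back to
    /// `crate::math::sigmoid`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2.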
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn sigmoid_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        let n = input.len();
        let chunks = n / 4;

        unsafe {
            let ln2 = _mm256_set1_pd(core::f64::consts::LN_2);
            let log2e = _mm256_set1_pd(core::f64::consts::LOG2_E);
            let clamp_hi = _mm256_set1_pd(708.0);
            let clamp_lo = _mm256_set1_pd(-708.0);
            let one = _mm256_set1_pd(1.0);
            let half = _mm256_set1_pd(0.5);
            let c3 = _mm256_set1_pd(1.0 / 6.0);
            let c4 = _mm256_set1_pd(1.0 / 24.0);
            let c5 = _mm256_set1_pd(1.0 / 120.0);
            let bias = _mm256_set1_epi64x(1023);
            let neg_one = _mm256_set1_pd(-1.0);

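            // Per lane: sigmoid(x) = 1 / (1 + exp(-x)), with exp(-x) computed by the
            // same clamp / range-reduction / polynomial steps as in `exp_avx2`.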
            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));

                let neg_x = _mm256_mul_pd(x, neg_one);
                let neg_x = _mm256_min_pd(clamp_hi, _mm256_max_pd(clamp_lo, neg_x));

                let x_scaled = _mm256_mul_pd(neg_x, log2e);
                let n_f = _mm256_floor_pd(_mm256_add_pd(x_scaled, half));
                let r = _mm256_sub_pd(neg_x, _mm256_mul_pd(n_f, ln2));

                let mut p = _mm256_add_pd(c4, _mm256_mul_pd(c5, r));
                p = _mm256_add_pd(c3, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(half, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));

                let n_i32 = _mm256_cvtpd_epi32(n_f);
                let n_i64 = _mm256_cvtepi32_epi64(n_i32);
                let shifted = _mm256_slli_epi64(_mm256_add_epi64(n_i64, bias), 52);
                let pow2n = _mm256_castsi256_pd(shifted);
                let exp_neg_x = _mm256_mul_pd(p, pow2n);

                let result = _mm256_div_pd(one, _mm256_add_pd(one, exp_neg_x));

                _mm256_storeu_pd(output.as_mut_ptr().add(off), result);
            }
        }

        for i in (chunks * 4)..n {
            output[i] = crate::math::sigmoid(input[i]);
        }
    }

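    /// Element-wise SiLU: runs `sigmoid_avx2` into `output`, then multiplies each
    /// element by the corresponding input.
    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2.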
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn silu_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        unsafe {
            sigmoid_avx2(input, output);
        }

        let n = input.len();
        let chunks = n / 4;
        unsafe {
            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));
                let sig = _mm256_loadu_pd(output.as_ptr().add(off));
                _mm256_storeu_pd(output.as_mut_ptr().add(off), _mm256_mul_pd(x, sig));
            }
        }
        for i in (chunks * 4)..n {
            output[i] *= input[i];
        }
    }
}

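/// Dot product over the first `min(a.len(), b.len())` elements, using the AVX2 kernel
/// when the `simd-avx2` feature is enabled and the CPU supports it, and the scalar
/// path otherwise.
///
/// Illustrative use (values match the unit tests below):
/// ```ignore
/// let a = [1.0, 2.0, 3.0];
/// let b = [4.0, 5.0, 6.0];
/// assert!((simd_dot(&a, &b) - 32.0).abs() < 1e-12);
/// ```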
pub fn simd_dot(a: &[f64], b: &[f64]) -> f64 {
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::dot_avx2(a, b) };
        }
    }
    dot_scalar(a, b)
}

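/// Row-major matrix-vector product: `out[i] = dot(w[i*cols..(i+1)*cols], x[..cols])`.
/// Panics if `w`, `x`, or `out` is shorter than the given dimensions require.
/// Dispatches to AVX2 when available, otherwise uses the scalar path.
///
/// Illustrative use (values match the unit tests below):
/// ```ignore
/// let w = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2 x 3, row-major
/// let x = [1.0, 2.0, 3.0];
/// let mut out = [0.0; 2];
/// simd_mat_vec(&w, &x, 2, 3, &mut out); // out == [14.0, 32.0]
/// ```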
pub fn simd_mat_vec(w: &[f64], x: &[f64], rows: usize, cols: usize, out: &mut [f64]) {
    assert!(
        w.len() >= rows * cols,
        "simd_mat_vec: w.len()={} < rows*cols={}",
        w.len(),
        rows * cols
    );
    assert!(
        out.len() >= rows,
        "simd_mat_vec: out.len()={} < rows={}",
        out.len(),
        rows
    );
    assert!(
        x.len() >= cols,
        "simd_mat_vec: x.len()={} < cols={}",
        x.len(),
        cols
    );

    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe {
                avx2::mat_vec_avx2(w, x, rows, cols, out);
            }
            return;
        }
    }
    mat_vec_scalar(w, x, rows, cols, out);
}

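/// Element-wise tanh of `input` into `output`. Panics if `output` is shorter than
/// `input`. The AVX2 path uses a rational approximation, so results may differ
/// slightly from `crate::math::tanh`.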
pub fn simd_tanh(input: &[f64], output: &mut [f64]) {
    assert!(
        output.len() >= input.len(),
        "simd_tanh: output.len()={} < input.len()={}",
        output.len(),
        input.len()
    );
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe {
                avx2::tanh_avx2(input, output);
            }
            return;
        }
    }
    tanh_scalar(input, output);
}

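/// Element-wise exp of `input` into `output`. Panics if `output` is shorter than
/// `input`. Inputs are clamped to roughly ±708 on the AVX2 path to avoid overflow.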
pub fn simd_exp(input: &[f64], output: &mut [f64]) {
    assert!(
        output.len() >= input.len(),
        "simd_exp: output.len()={} < input.len()={}",
        output.len(),
        input.len()
    );
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe {
                avx2::exp_avx2(input, output);
            }
            return;
        }
    }
    exp_scalar(input, output);
}

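/// Element-wise logistic sigmoid of `input` into `output`. Panics if `output` is
/// shorter than `input`.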
pub fn simd_sigmoid(input: &[f64], output: &mut [f64]) {
    assert!(
        output.len() >= input.len(),
        "simd_sigmoid: output.len()={} < input.len()={}",
        output.len(),
        input.len()
    );
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe {
                avx2::sigmoid_avx2(input, output);
            }
            return;
        }
    }
    sigmoid_scalar(input, output);
}

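/// Element-wise SiLU (`x * sigmoid(x)`) of `input` into `output`. Panics if `output`
/// is shorter than `input`.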
pub fn simd_silu(input: &[f64], output: &mut [f64]) {
    assert!(
        output.len() >= input.len(),
        "simd_silu: output.len()={} < input.len()={}",
        output.len(),
        input.len()
    );
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe {
                avx2::silu_avx2(input, output);
            }
            return;
        }
    }
    silu_scalar(input, output);
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;

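    /// Minimal xorshift64 PRNG so the tests are deterministic without pulling in an
    /// external crate.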
    struct TestRng(u64);

    impl TestRng {
        fn new(seed: u64) -> Self {
            Self(seed)
        }

        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }

        fn next_f64(&mut self) -> f64 {
            (self.next_u64() >> 11) as f64 / ((1u64 << 53) as f64) * 2.0 - 1.0
        }

        fn fill_vec(&mut self, n: usize) -> Vec<f64> {
            (0..n).map(|_| self.next_f64()).collect()
        }
    }

    #[test]
    fn dot_empty_returns_zero() {
        let a: [f64; 0] = [];
        let b: [f64; 0] = [];
        assert_eq!(simd_dot(&a, &b), 0.0, "dot of empty slices should be 0");
    }

    #[test]
    fn dot_single_element() {
        let a = [3.0];
        let b = [4.0];
        assert!(
            (simd_dot(&a, &b) - 12.0).abs() < 1e-12,
            "dot([3], [4]) should be 12, got {}",
            simd_dot(&a, &b)
        );
    }

    #[test]
    fn dot_known_result() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 32.0).abs() < 1e-12,
            "dot([1,2,3], [4,5,6]) should be 32, got {}",
            result
        );
    }

    #[test]
    fn dot_large_matches_scalar() {
        let mut rng = TestRng::new(42);
        let a = rng.fill_vec(1000);
        let b = rng.fill_vec(1000);

        let simd_result = simd_dot(&a, &b);
        let scalar_result = dot_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-9,
            "1000-element dot: SIMD={} vs scalar={}, diff={}",
            simd_result,
            scalar_result,
            (simd_result - scalar_result).abs()
        );
    }

    #[test]
    fn dot_mismatched_lengths() {
        let a = [1.0, 2.0, 3.0, 999.0];
        let b = [4.0, 5.0, 6.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 32.0).abs() < 1e-12,
            "mismatched lengths should use min, expected 32, got {}",
            result
        );
    }

    #[test]
    fn dot_non_aligned_length() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let b = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 28.0).abs() < 1e-12,
            "dot of [1..7] with [1..1] should be 28, got {}",
            result
        );
    }

    #[test]
    fn dot_negative_values() {
        let a = [-1.0, -2.0, -3.0, -4.0];
        let b = [4.0, 3.0, 2.0, 1.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - (-20.0)).abs() < 1e-12,
            "expected -20, got {}",
            result
        );
    }

    #[test]
    fn dot_orthogonal_vectors() {
        let a = [1.0, 0.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0, 0.0];
        let result = simd_dot(&a, &b);
        assert!(
            result.abs() < 1e-12,
            "orthogonal vectors should have dot=0, got {}",
            result
        );
    }

    #[test]
    fn mat_vec_identity_like() {
        let w = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 3];
        simd_mat_vec(&w, &x, 3, 3, &mut out);
        assert!(
            (out[0] - 1.0).abs() < 1e-12,
            "identity row 0: expected 1, got {}",
            out[0]
        );
        assert!(
            (out[1] - 2.0).abs() < 1e-12,
            "identity row 1: expected 2, got {}",
            out[1]
        );
        assert!(
            (out[2] - 3.0).abs() < 1e-12,
            "identity row 2: expected 3, got {}",
            out[2]
        );
    }

    #[test]
    fn mat_vec_known_result() {
        let w = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
        assert!(
            (out[0] - 14.0).abs() < 1e-12,
            "row 0: expected 14, got {}",
            out[0]
        );
        assert!(
            (out[1] - 32.0).abs() < 1e-12,
            "row 1: expected 32, got {}",
            out[1]
        );
    }

    #[test]
    fn mat_vec_large_matches_scalar() {
        let mut rng = TestRng::new(7777);
        let rows = 100;
        let cols = 100;
        let w = rng.fill_vec(rows * cols);
        let x = rng.fill_vec(cols);
        let mut out_simd = vec![0.0; rows];
        let mut out_scalar = vec![0.0; rows];

        simd_mat_vec(&w, &x, rows, cols, &mut out_simd);
        mat_vec_scalar(&w, &x, rows, cols, &mut out_scalar);

        for i in 0..rows {
            assert!(
                (out_simd[i] - out_scalar[i]).abs() < 1e-9,
                "row {}: SIMD={} vs scalar={}, diff={}",
                i,
                out_simd[i],
                out_scalar[i],
                (out_simd[i] - out_scalar[i]).abs()
            );
        }
    }

    #[test]
    fn mat_vec_single_row() {
        let w = [1.0, 2.0, 3.0, 4.0, 5.0];
        let x = [2.0, 2.0, 2.0, 2.0, 2.0];
        let mut out = [0.0; 1];
        simd_mat_vec(&w, &x, 1, 5, &mut out);
        assert!(
            (out[0] - 30.0).abs() < 1e-12,
            "single-row mat_vec should be dot product, expected 30, got {}",
            out[0]
        );
    }

    #[test]
    fn mat_vec_single_element() {
        let w = [7.0];
        let x = [3.0];
        let mut out = [0.0; 1];
        simd_mat_vec(&w, &x, 1, 1, &mut out);
        assert!(
            (out[0] - 21.0).abs() < 1e-12,
            "1x1 mat_vec: 7*3=21, got {}",
            out[0]
        );
    }

    #[test]
    #[should_panic(expected = "simd_mat_vec: w.len()")]
    fn mat_vec_panics_w_too_short() {
        let w = [1.0, 2.0];
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }

    #[test]
    #[should_panic(expected = "simd_mat_vec: out.len()")]
    fn mat_vec_panics_out_too_short() {
        let w = [1.0; 6];
        let x = [1.0; 3];
        let mut out = [0.0; 1];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }

    #[test]
    #[should_panic(expected = "simd_mat_vec: x.len()")]
    fn mat_vec_panics_x_too_short() {
        let w = [1.0; 6];
        let x = [1.0; 2];
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }

    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn simd_available_on_x86() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 120.0).abs() < 1e-12,
            "8-element dot product should be 120, got {}",
            result
        );
    }

    #[test]
    fn tanh_known_values() {
        let input = [0.0, 1.0, -1.0, 5.0, -5.0, 0.5];
        let mut output = [0.0; 6];
        simd_tanh(&input, &mut output);
        let expected = [0.0, 0.7616, -0.7616, 0.9999, -0.9999, 0.4621];
        for (i, (&got, &exp)) in output.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 0.01,
                "tanh[{i}]: expected ~{exp}, got {got}"
            );
        }
    }

    #[test]
    fn tanh_matches_scalar() {
        let mut rng = TestRng::new(42);
        let input = rng.fill_vec(100);
        let mut simd_out = vec![0.0; 100];
        let mut scalar_out = vec![0.0; 100];
        simd_tanh(&input, &mut simd_out);
        for (i, &x) in input.iter().enumerate() {
            scalar_out[i] = crate::math::tanh(x);
        }
        for i in 0..100 {
            assert!(
                (simd_out[i] - scalar_out[i]).abs() < 0.01,
                "tanh[{i}]: SIMD={} vs scalar={}",
                simd_out[i],
                scalar_out[i]
            );
        }
    }

    #[test]
    fn exp_known_values() {
        let input = [0.0, 1.0, -1.0, 2.0, -2.0];
        let mut output = [0.0; 5];
        simd_exp(&input, &mut output);
        let expected = [
            1.0,
            core::f64::consts::E,
            1.0 / core::f64::consts::E,
            core::f64::consts::E * core::f64::consts::E,
            1.0 / (core::f64::consts::E * core::f64::consts::E),
        ];
        for (i, (&got, &exp)) in output.iter().zip(expected.iter()).enumerate() {
            let rel = (got - exp).abs() / exp.abs().max(1e-15);
            assert!(
                rel < 1e-5,
                "exp[{i}]: expected {exp}, got {got}, rel_err={rel}"
            );
        }
    }

    #[test]
    fn exp_matches_scalar() {
        let mut rng = TestRng::new(99);
        let input: Vec<f64> = (0..100).map(|_| rng.next_f64() * 10.0).collect();
        let mut simd_out = vec![0.0; 100];
        let mut scalar_out = vec![0.0; 100];
        simd_exp(&input, &mut simd_out);
        for (i, &x) in input.iter().enumerate() {
            scalar_out[i] = crate::math::exp(x);
        }
        for i in 0..100 {
            let rel = (simd_out[i] - scalar_out[i]).abs() / scalar_out[i].abs().max(1e-15);
            assert!(
                rel < 1e-5,
                "exp[{i}] (x={}): SIMD={} vs scalar={}, rel_err={}",
                input[i],
                simd_out[i],
                scalar_out[i],
                rel
            );
        }
    }

    #[test]
    fn exp_extreme_values() {
        let input = [700.0, -700.0, 0.0, 100.0, -100.0];
        let mut output = [0.0; 5];
        simd_exp(&input, &mut output);
        assert!(output[0].is_finite(), "exp(700) should be finite");
        assert!(output[0] > 0.0, "exp(700) should be positive");
        assert!(output[1] > 0.0, "exp(-700) should be positive");
        assert!(output[1].is_finite(), "exp(-700) should be finite");
        assert!((output[2] - 1.0).abs() < 1e-12, "exp(0) should be 1.0");
    }

    #[test]
    fn sigmoid_known_values() {
        let input = [0.0, 10.0, -10.0, 1.0];
        let mut output = [0.0; 4];
        simd_sigmoid(&input, &mut output);
        assert!(
            (output[0] - 0.5).abs() < 0.01,
            "sigmoid(0) should be ~0.5, got {}",
            output[0]
        );
        assert!(
            output[1] > 0.99,
            "sigmoid(10) should be ~1.0, got {}",
            output[1]
        );
        assert!(
            output[2] < 0.01,
            "sigmoid(-10) should be ~0.0, got {}",
            output[2]
        );
    }

    #[test]
    fn sigmoid_matches_scalar() {
        let mut rng = TestRng::new(123);
        let input: Vec<f64> = (0..100).map(|_| rng.next_f64() * 20.0 - 10.0).collect();
        let mut simd_out = vec![0.0; 100];
        let mut scalar_out = vec![0.0; 100];
        simd_sigmoid(&input, &mut simd_out);
        for (i, &x) in input.iter().enumerate() {
            scalar_out[i] = crate::math::sigmoid(x);
        }
        for i in 0..100 {
            assert!(
                (simd_out[i] - scalar_out[i]).abs() < 1e-6,
                "sigmoid[{i}] (x={}): SIMD={} vs scalar={}, diff={}",
                input[i],
                simd_out[i],
                scalar_out[i],
                (simd_out[i] - scalar_out[i]).abs()
            );
        }
    }

    #[test]
    fn silu_known_values() {
        let input = [0.0, 1.0, -1.0, 3.0];
        let mut output = [0.0; 4];
        simd_silu(&input, &mut output);
        assert!(
            output[0].abs() < 0.01,
            "silu(0) should be ~0, got {}",
            output[0]
        );
        assert!(
            (output[1] - 0.731).abs() < 0.01,
            "silu(1) should be ~0.731, got {}",
            output[1]
        );
    }

    #[test]
    fn silu_matches_scalar() {
        let mut rng = TestRng::new(456);
        let input: Vec<f64> = (0..100).map(|_| rng.next_f64() * 10.0 - 5.0).collect();
        let mut simd_out = vec![0.0; 100];
        simd_silu(&input, &mut simd_out);
        for (i, &x) in input.iter().enumerate() {
            let expected = x * crate::math::sigmoid(x);
            assert!(
                (simd_out[i] - expected).abs() < 1e-6,
                "silu[{i}] (x={}): SIMD={} vs scalar={}, diff={}",
                x,
                simd_out[i],
                expected,
                (simd_out[i] - expected).abs()
            );
        }
    }

    #[test]
    fn activations_handle_empty() {
        let input: [f64; 0] = [];
        let mut output: [f64; 0] = [];
        simd_tanh(&input, &mut output);
        simd_exp(&input, &mut output);
        simd_sigmoid(&input, &mut output);
        simd_silu(&input, &mut output);
    }
}