1#![allow(dead_code)]
2
3use std::sync::OnceLock;
4use wide::{f32x4, f32x8};
5
/// SIMD capability tiers, ordered from narrowest to widest vector width.
///
/// `Sse` doubles as the 128-bit tier on non-x86 targets (see
/// `detect_simd_level`, which returns it on aarch64).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SimdLevel {
    /// No SIMD; plain scalar loops.
    Scalar,
    /// 128-bit vectors (SSE4.1 on x86_64).
    Sse,
    /// 256-bit vectors (AVX2).
    Avx2,
    /// 512-bit vectors (AVX-512F).
    Avx512,
}
22
/// Process-wide cache of the detected SIMD capability; filled exactly once by
/// `detect_simd_level`, after which reads are a cheap atomic load.
static SIMD_LEVEL: OnceLock<SimdLevel> = OnceLock::new();
24
/// Returns the widest SIMD instruction set available on the running CPU.
///
/// Detection runs once and is cached in `SIMD_LEVEL`; subsequent calls just
/// copy the cached value.
pub fn detect_simd_level() -> SimdLevel {
    *SIMD_LEVEL.get_or_init(|| {
        #[cfg(target_arch = "x86_64")]
        {
            // Probe from widest to narrowest so the best supported level wins.
            if is_x86_feature_detected!("avx512f") {
                return SimdLevel::Avx512;
            }
            if is_x86_feature_detected!("avx2") {
                return SimdLevel::Avx2;
            }
            if is_x86_feature_detected!("sse4.1") {
                return SimdLevel::Sse;
            }
            return SimdLevel::Scalar;
        }

        #[cfg(target_arch = "aarch64")]
        {
            // Reuse the 128-bit `Sse` tier on aarch64; no runtime probe is
            // performed here. Presumably this stands in for NEON — the wider
            // tiers are never selected on this architecture.
            SimdLevel::Sse
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            // Unknown architecture: fall back to scalar loops.
            SimdLevel::Scalar
        }
    })
}
54
/// Scalar reference implementation of the dot product.
///
/// Pairs are combined in order, so results match the iterator-sum form
/// exactly; `zip` truncates to the shorter of the two slices.
#[inline(always)]
fn dot_f32_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut total = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        total += x * y;
    }
    total
}
77
78#[inline(always)]
89fn dot_f32_sse(a: &[f32], b: &[f32]) -> f32 {
90 let len = a.len();
91 let chunks = len / 4;
92 let mut acc = f32x4::ZERO;
93
94 unsafe {
95 let a_ptr = a.as_ptr();
96 let b_ptr = b.as_ptr();
97
98 for i in 0..chunks {
99 let offset = i * 4;
100 let va = f32x4::from(*(a_ptr.add(offset) as *const [f32; 4]));
101 let vb = f32x4::from(*(b_ptr.add(offset) as *const [f32; 4]));
102 acc += va * vb;
103 }
104 }
105
106 let mut sum = acc.reduce_add();
107 for i in (chunks * 4)..len {
108 sum += a[i] * b[i];
109 }
110 sum
111}
112
113#[inline(always)]
124fn dot_f32_avx2(a: &[f32], b: &[f32]) -> f32 {
125 let len = a.len();
126 let chunks = len / 8;
127 let mut acc = f32x8::ZERO;
128
129 unsafe {
130 let a_ptr = a.as_ptr();
131 let b_ptr = b.as_ptr();
132
133 for i in 0..chunks {
134 let offset = i * 8;
135 let va = f32x8::from(*(a_ptr.add(offset) as *const [f32; 8]));
136 let vb = f32x8::from(*(b_ptr.add(offset) as *const [f32; 8]));
137 acc += va * vb;
138 }
139 }
140
141 let mut sum = acc.reduce_add();
142 for i in (chunks * 8)..len {
143 sum += a[i] * b[i];
144 }
145 sum
146}
147
/// AVX-512 dot product using `std::arch` intrinsics with fused multiply-add.
///
/// Only compiled when the crate itself is built with `avx512f` (e.g.
/// `-C target-feature=+avx512f`); otherwise the fallback definition below
/// takes its place.
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn dot_f32_avx512(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let chunks = len / 16;

    // SAFETY: `chunks * 16 <= a.len()` keeps the `a` loads in bounds,
    // `_mm512_loadu_ps` tolerates unaligned pointers, and the `target_feature`
    // cfg above guarantees AVX-512F is available. NOTE(review): the `b` loads
    // assume `b.len() >= a.len()` — confirm callers always pass equal-length
    // slices.
    unsafe {
        let mut acc = _mm512_setzero_ps();

        for i in 0..chunks {
            let va = _mm512_loadu_ps(a.as_ptr().add(i * 16));
            let vb = _mm512_loadu_ps(b.as_ptr().add(i * 16));
            // Fused multiply-add: acc = va * vb + acc.
            acc = _mm512_fmadd_ps(va, vb, acc);
        }

        // Horizontal sum of all 16 lanes, then the scalar tail.
        let mut sum = _mm512_reduce_add_ps(acc);
        for i in (chunks * 16)..len {
            sum += a[i] * b[i];
        }
        sum
    }
}
182
/// Fallback when the crate is not compiled with `avx512f`: runtime dispatch
/// may still select `SimdLevel::Avx512`, so degrade gracefully to the AVX2
/// kernel instead.
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn dot_f32_avx512(a: &[f32], b: &[f32]) -> f32 {
    dot_f32_avx2(a, b)
}
198
199#[inline]
214pub fn dot_simd(a: &[f32], b: &[f32]) -> f32 {
215 match detect_simd_level() {
216 SimdLevel::Avx512 => dot_f32_avx512(a, b),
217 SimdLevel::Avx2 => dot_f32_avx2(a, b),
218 SimdLevel::Sse => dot_f32_sse(a, b),
219 SimdLevel::Scalar => dot_f32_scalar(a, b),
220 }
221}
222
/// Scalar fallback for `dst[i] += scale * source[i]`.
///
/// Uses `zip` instead of the old indexed loop: this removes the per-iteration
/// bounds checks (letting the compiler autovectorize), avoids the panic the
/// indexed form hit when `source` was shorter than `dst`, and matches the
/// truncating `zip` semantics of `dot_f32_scalar`.
#[inline(always)]
fn saxpy_f32_scalar(dst: &mut [f32], source: &[f32], scale: f32) {
    for (d, &s) in dst.iter_mut().zip(source.iter()) {
        *d += scale * s;
    }
}
240
241#[inline(always)]
249fn saxpy_f32_sse(dst: &mut [f32], source: &[f32], scale: f32) {
250 let len = dst.len();
251 let chunks = len / 4;
252 let scale_vec = f32x4::splat(scale);
253
254 unsafe {
255 let dst_ptr = dst.as_mut_ptr();
256 let src_ptr = source.as_ptr();
257
258 for i in 0..chunks {
259 let offset = i * 4;
260 let vdst = f32x4::from(*(dst_ptr.add(offset) as *const [f32; 4]));
261 let vsrc = f32x4::from(*(src_ptr.add(offset) as *const [f32; 4]));
262 let result = vdst + scale_vec * vsrc;
263 *(dst_ptr.add(offset) as *mut [f32; 4]) = result.into();
264 }
265 }
266
267 for i in (chunks * 4)..len {
268 dst[i] += scale * source[i];
269 }
270}
271
272#[inline(always)]
280fn saxpy_f32_avx2(dst: &mut [f32], source: &[f32], scale: f32) {
281 let len = dst.len();
282 let chunks = len / 8;
283 let scale_vec = f32x8::splat(scale);
284
285 unsafe {
286 let dst_ptr = dst.as_mut_ptr();
287 let src_ptr = source.as_ptr();
288
289 for i in 0..chunks {
290 let offset = i * 8;
291 let vdst = f32x8::from(*(dst_ptr.add(offset) as *const [f32; 8]));
292 let vsrc = f32x8::from(*(src_ptr.add(offset) as *const [f32; 8]));
293 let result = vdst + scale_vec * vsrc;
294 *(dst_ptr.add(offset) as *mut [f32; 8]) = result.into();
295 }
296 }
297
298 for i in (chunks * 8)..len {
299 dst[i] += scale * source[i];
300 }
301}
302
/// AVX-512 `dst += scale * source` using `std::arch` intrinsics with fused
/// multiply-add.
///
/// Only compiled when the crate itself is built with `avx512f`; otherwise the
/// fallback definition below takes its place.
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn saxpy_f32_avx512(dst: &mut [f32], source: &[f32], scale: f32) {
    use std::arch::x86_64::*;
    let len = dst.len();
    let chunks = len / 16;

    // SAFETY: `chunks * 16 <= dst.len()` keeps the `dst` accesses in bounds,
    // the unaligned load/store intrinsics tolerate any pointer alignment, and
    // the `target_feature` cfg guarantees AVX-512F. NOTE(review): the `source`
    // loads assume `source.len() >= dst.len()` — confirm callers always pass
    // equal-length slices.
    unsafe {
        let scale_vec = _mm512_set1_ps(scale);
        let dst_ptr = dst.as_mut_ptr();
        let src_ptr = source.as_ptr();

        for i in 0..chunks {
            let offset = i * 16;
            let vdst = _mm512_loadu_ps(dst_ptr.add(offset));
            let vsrc = _mm512_loadu_ps(src_ptr.add(offset));
            // Fused multiply-add: result = scale * src + dst.
            let result = _mm512_fmadd_ps(scale_vec, vsrc, vdst);
            _mm512_storeu_ps(dst_ptr.add(offset), result);
        }
    }

    // Scalar pass over the trailing len % 16 elements.
    for i in (chunks * 16)..len {
        dst[i] += scale * source[i];
    }
}
335
/// Fallback when the crate is not compiled with `avx512f`: runtime dispatch
/// may still select `SimdLevel::Avx512`, so degrade gracefully to the AVX2
/// kernel instead.
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn saxpy_f32_avx512(dst: &mut [f32], source: &[f32], scale: f32) {
    saxpy_f32_avx2(dst, source, scale)
}
348
349#[inline]
361pub fn saxpy_simd(dst: &mut [f32], source: &[f32], scale: f32) {
362 match detect_simd_level() {
363 SimdLevel::Avx512 => saxpy_f32_avx512(dst, source, scale),
364 SimdLevel::Avx2 => saxpy_f32_avx2(dst, source, scale),
365 SimdLevel::Sse => saxpy_f32_sse(dst, source, scale),
366 SimdLevel::Scalar => saxpy_f32_scalar(dst, source, scale),
367 }
368}
369
/// Scalar reference implementation of the L2 (Euclidean) norm.
///
/// Accumulates the sum of squares in order, then takes the square root —
/// the same sequential fold the iterator form performs.
#[inline(always)]
fn norm_l2_f32_scalar(a: &[f32]) -> f32 {
    let mut sum_sq = 0.0f32;
    for &x in a {
        sum_sq += x * x;
    }
    sum_sq.sqrt()
}
387
388#[inline(always)]
398fn norm_l2_f32_sse(a: &[f32]) -> f32 {
399 let len = a.len();
400 let chunks = len / 4;
401 let mut acc = f32x4::ZERO;
402
403 unsafe {
404 let a_ptr = a.as_ptr();
405 for i in 0..chunks {
406 let offset = i * 4;
407 let va = f32x4::from(*(a_ptr.add(offset) as *const [f32; 4]));
408 acc += va * va;
409 }
410 }
411
412 let mut sum = acc.reduce_add();
413 for i in (chunks * 4)..len {
414 sum += a[i] * a[i];
415 }
416 sum.sqrt()
417}
418
419#[inline(always)]
429fn norm_l2_f32_avx2(a: &[f32]) -> f32 {
430 let len = a.len();
431 let chunks = len / 8;
432 let mut acc = f32x8::ZERO;
433
434 unsafe {
435 let a_ptr = a.as_ptr();
436 for i in 0..chunks {
437 let offset = i * 8;
438 let va = f32x8::from(*(a_ptr.add(offset) as *const [f32; 8]));
439 acc += va * va;
440 }
441 }
442
443 let mut sum = acc.reduce_add();
444 for i in (chunks * 8)..len {
445 sum += a[i] * a[i];
446 }
447 sum.sqrt()
448}
449
/// AVX-512 L2 norm using `std::arch` intrinsics with fused multiply-add.
///
/// Only compiled when the crate itself is built with `avx512f`; otherwise the
/// fallback definition below takes its place.
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[inline(always)]
fn norm_l2_f32_avx512(a: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    let len = a.len();
    let chunks = len / 16;

    // SAFETY: `chunks * 16 <= len` keeps every load within `a`,
    // `_mm512_loadu_ps` tolerates unaligned pointers, and the
    // `target_feature` cfg above guarantees AVX-512F is available.
    unsafe {
        let mut acc = _mm512_setzero_ps();
        for i in 0..chunks {
            let va = _mm512_loadu_ps(a.as_ptr().add(i * 16));
            // Fused multiply-add: acc = va * va + acc.
            acc = _mm512_fmadd_ps(va, va, acc);
        }

        // Horizontal sum of all 16 lanes, then the scalar tail.
        let mut sum = _mm512_reduce_add_ps(acc);
        for i in (chunks * 16)..len {
            sum += a[i] * a[i];
        }
        sum.sqrt()
    }
}
480
/// Fallback when the crate is not compiled with `avx512f`: runtime dispatch
/// may still select `SimdLevel::Avx512`, so degrade gracefully to the AVX2
/// kernel instead.
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
#[inline(always)]
fn norm_l2_f32_avx512(a: &[f32]) -> f32 {
    norm_l2_f32_avx2(a)
}
495
496#[inline]
502pub fn norm_l2_simd(a: &[f32]) -> f32 {
503 match detect_simd_level() {
504 SimdLevel::Avx512 => norm_l2_f32_avx512(a),
505 SimdLevel::Avx2 => norm_l2_f32_avx2(a),
506 SimdLevel::Sse => norm_l2_f32_sse(a),
507 SimdLevel::Scalar => norm_l2_f32_scalar(a),
508 }
509}
510
/// Unit tests: each SIMD kernel is checked against the scalar reference and
/// against hand-computed values, across sizes that exercise full chunks,
/// remainders, and the empty/one-element boundaries.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- dot product ----

    // Hand-computed small case with exact integer-valued floats.
    #[test]
    fn test_dot_product_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let expected = 1.0 * 2.0 + 2.0 * 3.0 + 3.0 * 4.0 + 4.0 * 5.0;

        let result = dot_simd(&a, &b);
        assert!((result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_zero() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![0.0, 0.0, 0.0, 0.0];

        let result = dot_simd(&a, &b);
        assert_eq!(result, 0.0);
    }

    // Sizes chosen to hit every remainder pattern for widths 4, 8, and 16.
    #[test]
    fn test_dot_product_various_sizes() {
        for size in [1, 3, 4, 7, 8, 15, 16, 17, 31, 32, 100, 128, 256] {
            let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..size).map(|i| (i + 1) as f32).collect();

            let expected = dot_f32_scalar(&a, &b);
            let result = dot_simd(&a, &b);

            assert!(
                (result - expected).abs() < 1e-4,
                "Failed for size {}: expected {}, got {}",
                size,
                expected,
                result
            );
        }
    }

    #[test]
    fn test_dot_product_negative() {
        let a = vec![1.0, -2.0, 3.0, -4.0];
        let b = vec![-1.0, 2.0, -3.0, 4.0];
        let expected = -1.0 + -2.0 * 2.0 + 3.0 * -3.0 + -4.0 * 4.0;

        let result = dot_simd(&a, &b);
        assert!((result - expected).abs() < 1e-6);
    }

    // Cross-checks every implementation against the scalar reference on the
    // same input (looser tolerance: SIMD reassociates the additions).
    #[test]
    fn test_dot_product_all_implementations() {
        let a: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..100).map(|i| (i + 1) as f32 * 0.2).collect();

        let scalar_result = dot_f32_scalar(&a, &b);
        let sse_result = dot_f32_sse(&a, &b);
        let avx2_result = dot_f32_avx2(&a, &b);
        let avx512_result = dot_f32_avx512(&a, &b);

        assert!((scalar_result - sse_result).abs() < 1e-2);
        assert!((scalar_result - avx2_result).abs() < 1e-2);
        assert!((scalar_result - avx512_result).abs() < 1e-2);
    }

    // ---- saxpy ----

    #[test]
    fn test_saxpy_basic() {
        let mut dst = vec![1.0, 2.0, 3.0, 4.0];
        let source = vec![2.0, 3.0, 4.0, 5.0];
        let scale = 2.0;

        saxpy_simd(&mut dst, &source, scale);

        let expected = [5.0, 8.0, 11.0, 14.0];
        for (d, e) in dst.iter().zip(expected.iter()) {
            assert!((d - e).abs() < 1e-6);
        }
    }

    // scale == 0 must leave dst bit-for-bit unchanged.
    #[test]
    fn test_saxpy_zero_scale() {
        let mut dst = vec![1.0, 2.0, 3.0, 4.0];
        let source = vec![2.0, 3.0, 4.0, 5.0];
        let expected = dst.clone();

        saxpy_simd(&mut dst, &source, 0.0);

        assert_eq!(dst, expected);
    }

    #[test]
    fn test_saxpy_negative_scale() {
        let mut dst = vec![10.0, 20.0, 30.0, 40.0];
        let source = vec![1.0, 2.0, 3.0, 4.0];
        let scale = -2.0;

        saxpy_simd(&mut dst, &source, scale);

        let expected = [8.0, 16.0, 24.0, 32.0];
        for (d, e) in dst.iter().zip(expected.iter()) {
            assert!((d - e).abs() < 1e-6);
        }
    }

    // Same remainder-pattern coverage as the dot-product size sweep.
    #[test]
    fn test_saxpy_various_sizes() {
        for size in [1, 3, 4, 7, 8, 15, 16, 17, 31, 32, 100, 128, 256] {
            let mut dst: Vec<f32> = (0..size).map(|i| i as f32).collect();
            let source: Vec<f32> = (0..size).map(|i| (i + 1) as f32).collect();
            let scale = 0.5;

            let mut expected = dst.clone();
            saxpy_f32_scalar(&mut expected, &source, scale);

            saxpy_simd(&mut dst, &source, scale);

            for (i, (d, e)) in dst.iter().zip(expected.iter()).enumerate() {
                assert!(
                    (d - e).abs() < 1e-4,
                    "Failed at index {} for size {}: expected {}, got {}",
                    i,
                    size,
                    e,
                    d
                );
            }
        }
    }

    // Cross-checks every saxpy implementation element-wise against scalar.
    #[test]
    fn test_saxpy_all_implementations() {
        let size = 100;
        let source: Vec<f32> = (0..size).map(|i| i as f32 * 0.1).collect();
        let scale = 1.5;

        let mut dst_scalar: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let mut dst_sse = dst_scalar.clone();
        let mut dst_avx2 = dst_scalar.clone();
        let mut dst_avx512 = dst_scalar.clone();

        saxpy_f32_scalar(&mut dst_scalar, &source, scale);
        saxpy_f32_sse(&mut dst_sse, &source, scale);
        saxpy_f32_avx2(&mut dst_avx2, &source, scale);
        saxpy_f32_avx512(&mut dst_avx512, &source, scale);

        for i in 0..size {
            assert!((dst_scalar[i] - dst_sse[i]).abs() < 1e-4);
            assert!((dst_scalar[i] - dst_avx2[i]).abs() < 1e-4);
            assert!((dst_scalar[i] - dst_avx512[i]).abs() < 1e-4);
        }
    }

    // `source` is a copy of `dst`, so with scale 1.0 every element doubles.
    #[test]
    fn test_saxpy_inplace() {
        let mut dst = vec![1.0, 2.0, 3.0, 4.0];
        let source = dst.clone();
        let scale = 1.0;

        saxpy_simd(&mut dst, &source, scale);

        let expected = [2.0, 4.0, 6.0, 8.0];
        for (d, e) in dst.iter().zip(expected.iter()) {
            assert!((d - e).abs() < 1e-6);
        }
    }

    // ---- dispatch / norm ----

    // Smoke test: detection must return one of the known variants
    // (the exhaustive match is the assertion).
    #[test]
    fn test_simd_level_detection() {
        let level = detect_simd_level();
        match level {
            SimdLevel::Scalar | SimdLevel::Sse | SimdLevel::Avx2 | SimdLevel::Avx512 => {}
        }
    }

    // Large input: compare via relative error since the absolute sum is big.
    #[test]
    fn test_dot_product_large() {
        let size = 10_000;
        let a: Vec<f32> = (0..size).map(|i| (i % 100) as f32).collect();
        let b: Vec<f32> = (0..size).map(|i| ((i + 50) % 100) as f32).collect();

        let expected = dot_f32_scalar(&a, &b);
        let result = dot_simd(&a, &b);

        let rel_error = (result - expected).abs() / expected.abs();
        assert!(
            rel_error < 1e-4,
            "result: {}, expected: {}, rel_error: {}",
            result,
            expected,
            rel_error
        );
    }

    #[test]
    fn test_saxpy_large() {
        let size = 10_000;
        let mut dst: Vec<f32> = (0..size).map(|i| (i % 100) as f32).collect();
        let source: Vec<f32> = (0..size).map(|i| ((i + 50) % 100) as f32).collect();
        let scale = 0.75;

        let mut expected = dst.clone();
        saxpy_f32_scalar(&mut expected, &source, scale);

        saxpy_simd(&mut dst, &source, scale);

        for (d, e) in dst.iter().zip(expected.iter()) {
            assert!((d - e).abs() < 1e-4);
        }
    }

    // Classic 3-4-5 triangle.
    #[test]
    fn test_norm_basic() {
        let a = vec![3.0, 4.0];
        let result = norm_l2_simd(&a);
        assert!((result - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_zero() {
        let a = vec![0.0, 0.0, 0.0];
        let result = norm_l2_simd(&a);
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_norm_various_sizes() {
        for size in [1, 3, 4, 7, 8, 15, 16, 17, 31, 32, 100] {
            let a: Vec<f32> = (0..size).map(|i| i as f32 * 0.1).collect();

            let expected = norm_l2_f32_scalar(&a);
            let result = norm_l2_simd(&a);

            assert!(
                (result - expected).abs() < 1e-3,
                "Failed for size {}: expected {}, got {}",
                size,
                expected,
                result
            );
        }
    }

    #[test]
    fn test_norm_all_implementations() {
        let a: Vec<f32> = (0..100).map(|i| i as f32 * 0.01).collect();

        let scalar = norm_l2_f32_scalar(&a);
        let sse = norm_l2_f32_sse(&a);
        let avx2 = norm_l2_f32_avx2(&a);
        let avx512 = norm_l2_f32_avx512(&a);

        assert!((scalar - sse).abs() < 1e-3);
        assert!((scalar - avx2).abs() < 1e-3);
        assert!((scalar - avx512).abs() < 1e-3);
    }
}