Skip to main content

nvs_core/
simd.rs

1#[inline]
2pub fn dot(a: &[f32], b: &[f32]) -> f32 {
3    debug_assert_eq!(a.len(), b.len());
4    #[cfg(target_arch = "x86_64")]
5    {
6        if x86_avx2_fma::get() {
7            unsafe {
8                return dot_avx2_fma(a, b);
9            };
10        } else if x86_avx2::get() {
11            unsafe {
12                return dot_avx2(a, b);
13            };
14        } else if x86_sse2::get() {
15            unsafe {
16                return dot_sse2(a, b);
17            };
18        }
19        return dot_scalar(a, b);
20    }
21    #[cfg(target_arch = "aarch64")]
22    {
23        unsafe {
24            return dot_neon(a, b);
25        };
26    }
27    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
28    {
29        return dot_scalar(a, b);
30    }
31}
32
/// Portable fallback: multiplies `a` and `b` element-wise and sums the products
/// sequentially.
///
/// Pairs are formed with `zip`, so if the slices differ in length only the
/// common prefix contributes to the result.
#[inline]
#[allow(dead_code)] // not referenced on targets where `dot` always picks a SIMD kernel
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
38
39// x86 feature detectors
40#[cfg(target_arch = "x86_64")]
41cpufeatures::new!(x86_avx2_fma, "avx2", "fma");
42#[cfg(target_arch = "x86_64")]
43cpufeatures::new!(x86_avx2, "avx2");
44#[cfg(target_arch = "x86_64")]
45cpufeatures::new!(x86_sse2, "sse2");
46
/// Dot product of `a` and `b` using 256-bit AVX loads with a separate multiply and add.
///
/// Only the first `min(a.len(), b.len())` elements of each slice are used, so a
/// length mismatch can never cause a read past the end of the shorter slice.
///
/// Eight lanes are accumulated independently and reduced once at the end; any
/// leftover elements (fewer than eight) are then added one at a time. The
/// summation order therefore differs from a plain sequential loop.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_avx2(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    // Every index below is bounded by the shorter slice.
    let n = a.len().min(b.len());
    let mut i = 0usize;
    let mut acc = _mm256_setzero_ps();
    while i + 8 <= n {
        // SAFETY: `i + 8 <= n` and `n` does not exceed either slice's length, so
        // both unaligned 8-float loads stay in bounds.
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        acc = _mm256_add_ps(acc, _mm256_mul_ps(va, vb));
        i += 8;
    }

    // Horizontal reduction: spill the eight partial sums and add them in lane order.
    let mut lanes = [0f32; 8];
    // SAFETY: `lanes` holds exactly eight `f32`s, the size of one 256-bit store.
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut sum: f32 = lanes.iter().sum();

    // Scalar tail for the remaining `n - i` (< 8) elements.
    for (x, y) in a[i..n].iter().zip(&b[i..n]) {
        sum += x * y;
    }
    sum
}
69
/// Dot product of `a` and `b` using 256-bit AVX loads and fused multiply-add.
///
/// Only the first `min(a.len(), b.len())` elements of each slice are used, so a
/// length mismatch can never cause a read past the end of the shorter slice.
///
/// Eight lanes are accumulated independently with `_mm256_fmadd_ps` and reduced
/// once at the end; any leftover elements (fewer than eight) are then added one
/// at a time with an ordinary multiply and add.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports both AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_avx2_fma(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    // Every index below is bounded by the shorter slice.
    let n = a.len().min(b.len());
    let mut i = 0usize;
    let mut acc = _mm256_setzero_ps();
    while i + 8 <= n {
        // SAFETY: `i + 8 <= n` and `n` does not exceed either slice's length, so
        // both unaligned 8-float loads stay in bounds.
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        // acc = va * vb + acc, per lane.
        acc = _mm256_fmadd_ps(va, vb, acc);
        i += 8;
    }

    // Horizontal reduction: spill the eight partial sums and add them in lane order.
    let mut lanes = [0f32; 8];
    // SAFETY: `lanes` holds exactly eight `f32`s, the size of one 256-bit store.
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut sum: f32 = lanes.iter().sum();

    // Scalar tail for the remaining `n - i` (< 8) elements.
    for (x, y) in a[i..n].iter().zip(&b[i..n]) {
        sum += x * y;
    }
    sum
}
91
/// Dot product of `a` and `b` using 128-bit SSE loads with a separate multiply and add.
///
/// Only the first `min(a.len(), b.len())` elements of each slice are used, so a
/// length mismatch can never cause a read past the end of the shorter slice.
///
/// Four lanes are accumulated independently and reduced once at the end; any
/// leftover elements (fewer than four) are then added one at a time.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports SSE2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn dot_sse2(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    // Every index below is bounded by the shorter slice.
    let n = a.len().min(b.len());
    let mut i = 0usize;
    let mut acc = _mm_setzero_ps();
    while i + 4 <= n {
        // SAFETY: `i + 4 <= n` and `n` does not exceed either slice's length, so
        // both unaligned 4-float loads stay in bounds.
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
        i += 4;
    }

    // Horizontal reduction: spill the four partial sums and add them in lane order.
    let mut lanes = [0f32; 4];
    // SAFETY: `lanes` holds exactly four `f32`s, the size of one 128-bit store.
    _mm_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut sum: f32 = lanes.iter().sum();

    // Scalar tail for the remaining `n - i` (< 4) elements.
    for (x, y) in a[i..n].iter().zip(&b[i..n]) {
        sum += x * y;
    }
    sum
}
114
/// Dot product of `a` and `b` using 128-bit NEON loads and fused multiply-add.
///
/// Only the first `min(a.len(), b.len())` elements of each slice are used, so a
/// length mismatch can never cause a read past the end of the shorter slice.
///
/// Four lanes are accumulated independently with `vfmaq_f32` and reduced with
/// `vaddvq_f32`; any leftover elements (fewer than four) are then added one at
/// a time with an ordinary multiply and add.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    // Every index below is bounded by the shorter slice.
    let n = a.len().min(b.len());
    let mut i = 0usize;
    let mut acc = vdupq_n_f32(0.0);
    while i + 4 <= n {
        // SAFETY: `i + 4 <= n` and `n` does not exceed either slice's length, so
        // both 4-float loads stay in bounds.
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        // acc = acc + va * vb, per lane (fused multiply-add).
        acc = vfmaq_f32(acc, va, vb);
        i += 4;
    }

    // Across-lanes add collapses the four partial sums into one scalar.
    let mut sum: f32 = vaddvq_f32(acc);

    // Scalar tail for the remaining `n - i` (< 4) elements.
    for (x, y) in a[i..n].iter().zip(&b[i..n]) {
        sum += x * y;
    }
    sum
}
133}