#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
7
/// SIMD instruction-set extensions reported by runtime detection.
///
/// Each flag is `true` when the corresponding extension was detected on the
/// running CPU (see `SimdFeatures::detect`). `Default` yields every flag
/// `false`, i.e. "scalar only".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct SimdFeatures {
    /// SSE2 support.
    pub sse2: bool,
    /// AVX support.
    pub avx: bool,
    /// AVX2 support.
    pub avx2: bool,
    /// AVX-512 Foundation support.
    pub avx512f: bool,
    /// FMA (fused multiply-add) support.
    pub fma: bool,
}
22
23impl SimdFeatures {
24 #[cfg(target_arch = "x86_64")]
26 pub fn detect() -> Self {
27 Self {
28 sse2: is_x86_feature_detected!("sse2"),
29 avx: is_x86_feature_detected!("avx"),
30 avx2: is_x86_feature_detected!("avx2"),
31 avx512f: is_x86_feature_detected!("avx512f"),
32 fma: is_x86_feature_detected!("fma"),
33 }
34 }
35
36 #[cfg(not(target_arch = "x86_64"))]
37 pub fn detect() -> Self {
38 Self {
39 sse2: false,
40 avx: false,
41 avx2: false,
42 avx512f: false,
43 fma: false,
44 }
45 }
46
47 pub fn best_simd(&self) -> SimdLevel {
49 if self.avx512f {
50 SimdLevel::Avx512
51 } else if self.avx2 {
52 SimdLevel::Avx2
53 } else if self.avx {
54 SimdLevel::Avx
55 } else if self.sse2 {
56 SimdLevel::Sse2
57 } else {
58 SimdLevel::Scalar
59 }
60 }
61}
62
/// SIMD capability tiers, ordered from no vectorization up to AVX-512.
///
/// The derived `Ord` follows the explicit discriminants, so a wider tier
/// compares greater (e.g. `SimdLevel::Avx2 > SimdLevel::Sse2`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SimdLevel {
    /// No SIMD.
    Scalar = 0,
    /// SSE2.
    Sse2 = 1,
    /// AVX.
    Avx = 2,
    /// AVX2.
    Avx2 = 3,
    /// AVX-512.
    Avx512 = 4,
}

impl SimdLevel {
    /// Vector register width in bytes for the SIMD tiers (16, 32 or 64).
    ///
    /// `Scalar` returns `1`, which is *not* a byte count of an element; use
    /// [`SimdLevel::f32_lanes`] when an element count is needed.
    pub fn vector_width(&self) -> usize {
        match self {
            SimdLevel::Scalar => 1,
            SimdLevel::Sse2 => 16,
            SimdLevel::Avx | SimdLevel::Avx2 => 32,
            SimdLevel::Avx512 => 64,
        }
    }

    /// Number of `f32` elements handled per operation at this tier.
    ///
    /// Always at least 1. Dividing `Scalar`'s nominal width (1) by
    /// `size_of::<f32>()` would otherwise yield 0, which breaks callers that
    /// use this value as a chunk size or divisor.
    pub fn f32_lanes(&self) -> usize {
        (self.vector_width() / std::mem::size_of::<f32>()).max(1)
    }
}
94
95#[inline]
105pub fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
106 assert_eq!(a.len(), b.len(), "Arrays must have equal length");
107
108 let features = SimdFeatures::detect();
109
110 #[cfg(target_arch = "x86_64")]
111 {
112 if features.avx2 && features.fma {
113 unsafe { dot_product_f32_avx2_fma(a, b) }
114 } else if features.avx {
115 unsafe { dot_product_f32_avx(a, b) }
116 } else {
117 dot_product_f32_scalar(a, b)
118 }
119 }
120
121 #[cfg(not(target_arch = "x86_64"))]
122 {
123 dot_product_f32_scalar(a, b)
124 }
125}
126
/// Portable dot product: multiplies paired elements and sums the products.
///
/// Elements are paired with `zip`, so any excess in the longer slice is ignored.
#[inline]
fn dot_product_f32_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
132
133#[cfg(target_arch = "x86_64")]
135#[target_feature(enable = "avx")]
136#[inline]
137unsafe fn dot_product_f32_avx(a: &[f32], b: &[f32]) -> f32 {
138 unsafe {
139 let len = a.len();
140 let mut sum = _mm256_setzero_ps();
141
142 let chunks = len / 8;
144 for i in 0..chunks {
145 let idx = i * 8;
146 let va = _mm256_loadu_ps(a.as_ptr().add(idx));
147 let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
148 sum = _mm256_add_ps(sum, _mm256_mul_ps(va, vb));
149 }
150
151 let mut result = horizontal_sum_avx(sum);
153
154 for i in (chunks * 8)..len {
156 result += a[i] * b[i];
157 }
158
159 result
160 }
161}
162
163#[cfg(target_arch = "x86_64")]
165#[target_feature(enable = "avx2,fma")]
166#[inline]
167unsafe fn dot_product_f32_avx2_fma(a: &[f32], b: &[f32]) -> f32 {
168 unsafe {
169 let len = a.len();
170 let mut sum = _mm256_setzero_ps();
171
172 let chunks = len / 8;
174 for i in 0..chunks {
175 let idx = i * 8;
176 let va = _mm256_loadu_ps(a.as_ptr().add(idx));
177 let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
178 sum = _mm256_fmadd_ps(va, vb, sum);
180 }
181
182 let mut result = horizontal_sum_avx(sum);
184
185 for i in (chunks * 8)..len {
187 result += a[i] * b[i];
188 }
189
190 result
191 }
192}
193
/// Reduces the eight `f32` lanes of `v` to a single sum.
///
/// Reduction order: add the high 128-bit half to the low half, then add
/// lanes 2/3 onto lanes 0/1, then add lane 1 onto lane 0.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
#[inline]
unsafe fn horizontal_sum_avx(v: __m256) -> f32 {
    unsafe {
        // Four partial sums: high half + low half, lane by lane.
        let quad = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v));
        // Move lanes 2/3 down into lanes 0/1 and add.
        let upper_pair = _mm_movehl_ps(quad, quad);
        let pair = _mm_add_ps(quad, upper_pair);
        // Bring lane 1 into lane 0 and add just the lowest lane.
        let lane_one = _mm_shuffle_ps(pair, pair, 1);
        let total = _mm_add_ss(pair, lane_one);
        _mm_cvtss_f32(total)
    }
}
212
213#[inline]
218pub fn add_f32(a: &[f32], b: &[f32], result: &mut [f32]) {
219 assert_eq!(a.len(), b.len());
220 assert_eq!(a.len(), result.len());
221
222 let features = SimdFeatures::detect();
223
224 #[cfg(target_arch = "x86_64")]
225 {
226 if features.avx2 {
227 unsafe { add_f32_avx2(a, b, result) }
228 } else {
229 add_f32_scalar(a, b, result)
230 }
231 }
232
233 #[cfg(not(target_arch = "x86_64"))]
234 {
235 add_f32_scalar(a, b, result)
236 }
237}
238
/// Portable element-wise add over every index of `a`.
///
/// Indexes `b` and `result` directly, so it panics if either is shorter than `a`.
#[inline]
fn add_f32_scalar(a: &[f32], b: &[f32], result: &mut [f32]) {
    for (pos, &lhs) in a.iter().enumerate() {
        result[pos] = lhs + b[pos];
    }
}
246
/// AVX2 element-wise add: `result[i] = a[i] + b[i]` for every index of `a`.
///
/// `b` and `result` are resliced to `a.len()` up front. This panics if either
/// is shorter, and guarantees that every raw load and store below stays in
/// bounds.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn add_f32_avx2(a: &[f32], b: &[f32], result: &mut [f32]) {
    let len = a.len();
    let b = &b[..len];
    let result = &mut result[..len];
    let chunks = len / 8;

    unsafe {
        for i in 0..chunks {
            let idx = i * 8;
            // SAFETY: idx + 8 <= len for all three slices (resliced above);
            // the unaligned load/store intrinsics have no alignment requirement.
            let va = _mm256_loadu_ps(a.as_ptr().add(idx));
            let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
            _mm256_storeu_ps(result.as_mut_ptr().add(idx), _mm256_add_ps(va, vb));
        }
    }

    // Remaining (< 8) elements.
    let tail = chunks * 8;
    for ((out, &x), &y) in result[tail..].iter_mut().zip(&a[tail..]).zip(&b[tail..]) {
        *out = x + y;
    }
}
271
272#[inline]
277pub fn relu_f32(input: &[f32], output: &mut [f32]) {
278 assert_eq!(input.len(), output.len());
279
280 let features = SimdFeatures::detect();
281
282 #[cfg(target_arch = "x86_64")]
283 {
284 if features.avx2 {
285 unsafe { relu_f32_avx2(input, output) }
286 } else {
287 relu_f32_scalar(input, output)
288 }
289 }
290
291 #[cfg(not(target_arch = "x86_64"))]
292 {
293 relu_f32_scalar(input, output)
294 }
295}
296
/// Portable ReLU over every index of `input`, using `f32::max` against zero.
///
/// Indexes `output` directly, so it panics if `output` is shorter than `input`.
#[inline]
fn relu_f32_scalar(input: &[f32], output: &mut [f32]) {
    for (pos, &value) in input.iter().enumerate() {
        output[pos] = value.max(0.0);
    }
}
304
/// AVX2 ReLU kernel: `output[i] = max(input[i], 0.0)` for every index of `input`.
///
/// `output` is resliced to `input.len()` first. This panics if it is shorter,
/// and guarantees that every raw store below stays in bounds.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn relu_f32_avx2(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let output = &mut output[..len];
    let chunks = len / 8;

    unsafe {
        let zero = _mm256_setzero_ps();
        for i in 0..chunks {
            let idx = i * 8;
            // SAFETY: idx + 8 <= len for both slices (output resliced above);
            // the unaligned load/store intrinsics have no alignment requirement.
            let v = _mm256_loadu_ps(input.as_ptr().add(idx));
            _mm256_storeu_ps(output.as_mut_ptr().add(idx), _mm256_max_ps(v, zero));
        }
    }

    // Remaining (< 8) elements.
    let tail = chunks * 8;
    for (out, &x) in output[tail..].iter_mut().zip(&input[tail..]) {
        *out = x.max(0.0);
    }
}
329
330pub fn simd_features() -> SimdFeatures {
332 static FEATURES: std::sync::OnceLock<SimdFeatures> = std::sync::OnceLock::new();
333 *FEATURES.get_or_init(SimdFeatures::detect)
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_detection() {
        let features = SimdFeatures::detect();
        let level = features.best_simd();
        println!("Detected SIMD level: {:?}", level);
        println!("Features: {:?}", features);

        #[cfg(target_arch = "x86_64")]
        {
            assert!(features.sse2, "x86_64 always has SSE2");
        }
    }

    #[test]
    fn test_cached_features_match_detection() {
        // Compared via Debug so this does not rely on a PartialEq derive.
        assert_eq!(
            format!("{:?}", simd_features()),
            format!("{:?}", SimdFeatures::detect())
        );
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        assert!((result - expected).abs() < 1e-5);
    }

    #[test]
    fn test_dot_product_with_tail() {
        // 13 = one full 8-lane chunk plus a 5-element remainder, so the
        // scalar tail of the SIMD kernels is exercised.
        let a: Vec<f32> = (1..=13).map(|i| i as f32).collect();
        let b: Vec<f32> = (1..=13).map(|i| (14 - i) as f32 * 0.5).collect();

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        assert!((result - expected).abs() < 1e-4);
    }

    #[test]
    #[should_panic(expected = "Arrays must have equal length")]
    fn test_dot_product_length_mismatch_panics() {
        dot_product_f32(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn test_empty_inputs() {
        assert_eq!(dot_product_f32(&[], &[]), 0.0);
        let mut out: [f32; 0] = [];
        add_f32(&[], &[], &mut out);
        relu_f32(&[], &mut out);
    }

    #[test]
    fn test_add_vectorized() {
        let a = vec![1.0; 100];
        let b = vec![2.0; 100];
        let mut result = vec![0.0; 100];

        add_f32(&a, &b, &mut result);

        for &r in &result {
            assert!((r - 3.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_add_and_relu_with_tail() {
        let a: Vec<f32> = (0..13).map(|i| i as f32 - 6.0).collect();
        let b = vec![0.5f32; 13];

        let mut sum = vec![0.0f32; 13];
        add_f32(&a, &b, &mut sum);
        for ((s, x), y) in sum.iter().zip(&a).zip(&b) {
            assert_eq!(*s, x + y);
        }

        let mut relu = vec![0.0f32; 13];
        relu_f32(&a, &mut relu);
        for (r, x) in relu.iter().zip(&a) {
            assert_eq!(*r, x.max(0.0));
        }
    }

    #[test]
    fn test_relu() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0; 5];

        relu_f32(&input, &mut output);

        let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];
        for (o, e) in output.iter().zip(&expected) {
            assert!((o - e).abs() < 1e-5);
        }
    }

    #[test]
    fn test_large_dot_product() {
        let size = 10_000;
        let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..size).map(|i| (size - i) as f32).collect();

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        let relative_error = ((result - expected) / expected).abs();
        assert!(relative_error < 1e-4);
    }

    #[test]
    fn test_simd_level_comparison() {
        assert!(SimdLevel::Avx512 > SimdLevel::Avx2);
        assert!(SimdLevel::Avx2 > SimdLevel::Avx);
        assert!(SimdLevel::Avx > SimdLevel::Sse2);
        assert!(SimdLevel::Sse2 > SimdLevel::Scalar);
    }

    #[test]
    fn test_vector_widths() {
        assert_eq!(SimdLevel::Scalar.vector_width(), 1);
        assert_eq!(SimdLevel::Sse2.vector_width(), 16);
        assert_eq!(SimdLevel::Avx.vector_width(), 32);
        assert_eq!(SimdLevel::Avx2.vector_width(), 32);
        assert_eq!(SimdLevel::Avx512.vector_width(), 64);

        assert_eq!(SimdLevel::Avx2.f32_lanes(), 8);
        assert_eq!(SimdLevel::Avx512.f32_lanes(), 16);
    }
}