1#[cfg(all(target_arch = "x86_64", feature = "std"))]
21use std::is_x86_feature_detected;
22
/// Dot product of `a` and `b`, truncated to the shorter slice.
///
/// Portable scalar fallback; also serves as the reference implementation
/// that the SIMD paths are compared against in tests.
#[inline]
fn dot_scalar(a: &[f64], b: &[f64]) -> f64 {
    // `zip` stops at the shorter slice — same truncation as the old
    // `a.len().min(b.len())` bound — and the iterator form lets the
    // compiler elide the per-element bounds checks. Summation order is
    // unchanged (left to right), so the float result is bit-identical.
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
39
/// Row-major matrix–vector product: `out[r] = dot(w[r*cols..(r+1)*cols], x)`.
///
/// Portable scalar fallback. `_rows` is implied by `out.len()` and is kept
/// only for signature parity with the AVX2 path.
#[inline]
fn mat_vec_scalar(w: &[f64], x: &[f64], _rows: usize, cols: usize, out: &mut [f64]) {
    // Slicing `x` once preserves the original panic if `x.len() < cols`,
    // while hoisting its bounds check out of the per-row loop.
    let x = &x[..cols];
    for (row, out_i) in out.iter_mut().enumerate() {
        let start = row * cols;
        // One bounds check per row instead of one per element; the zipped
        // iteration then accumulates in the same left-to-right order as the
        // original indexed loop.
        *out_i = w[start..start + cols].iter().zip(x).map(|(wj, xj)| wj * xj).sum();
    }
}
55
#[cfg(all(target_arch = "x86_64", feature = "std"))]
mod avx2 {
    // Module is already gated on x86_64, so the arch import needs no
    // additional per-item cfg.
    use core::arch::x86_64::*;

    /// AVX2 dot product of `a` and `b`, truncated to the shorter slice.
    ///
    /// Multiplies four f64 lanes per iteration, horizontally reduces the
    /// accumulator, then handles the last `n % 4` elements with a scalar
    /// tail loop. Accumulation order differs from `dot_scalar`, so results
    /// may differ by float rounding.
    ///
    /// # Safety
    ///
    /// The caller must ensure the executing CPU supports AVX2 (e.g. by
    /// checking `is_x86_feature_detected!("avx2")` first).
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn dot_avx2(a: &[f64], b: &[f64]) -> f64 {
        let n = a.len().min(b.len());
        let chunks = n / 4;
        let remainder = n % 4;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();

        // SAFETY: every offset used below (chunked loads over `chunks * 4`
        // elements, then the tail indices `base..base + remainder`) stays
        // within the first `n` elements, and both slices hold at least `n`.
        unsafe {
            let mut acc = _mm256_setzero_pd();

            for i in 0..chunks {
                let offset = i * 4;
                let va = _mm256_loadu_pd(a_ptr.add(offset));
                let vb = _mm256_loadu_pd(b_ptr.add(offset));
                acc = _mm256_add_pd(acc, _mm256_mul_pd(va, vb));
            }

            // Horizontal reduction: fold the 256-bit accumulator to a
            // 128-bit pair, then the pair to a single scalar lane.
            let hi128 = _mm256_extractf128_pd(acc, 1);
            let lo128 = _mm256_castpd256_pd128(acc);
            let pair = _mm_add_pd(lo128, hi128);
            let high64 = _mm_unpackhi_pd(pair, pair);
            let total = _mm_add_sd(pair, high64);
            let mut scalar_sum = _mm_cvtsd_f64(total);

            // Scalar tail for elements that don't fill a full 4-lane chunk.
            let base = chunks * 4;
            for i in 0..remainder {
                scalar_sum += *a_ptr.add(base + i) * *b_ptr.add(base + i);
            }

            scalar_sum
        }
    }

    /// AVX2 row-major matrix–vector product: one [`dot_avx2`] per row of `w`.
    ///
    /// # Safety
    ///
    /// Same contract as [`dot_avx2`]: AVX2 must be available at runtime.
    /// The row slicing below panics (safely) if `w` is shorter than
    /// `out.len() * cols` or `x` is shorter than `cols`.
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn mat_vec_avx2(
        w: &[f64],
        x: &[f64],
        _rows: usize,
        cols: usize,
        out: &mut [f64],
    ) {
        for (row, out_i) in out.iter_mut().enumerate() {
            let row_start = row * cols;
            // SAFETY: the AVX2 requirement is forwarded from this
            // function's own contract; the slice arguments are bounds-
            // checked safe code.
            unsafe {
                *out_i = dot_avx2(&w[row_start..row_start + cols], &x[..cols]);
            }
        }
    }
}
137
138pub fn simd_dot(a: &[f64], b: &[f64]) -> f64 {
160 #[cfg(all(target_arch = "x86_64", feature = "std"))]
161 {
162 if is_x86_feature_detected!("avx2") {
163 return unsafe { avx2::dot_avx2(a, b) };
165 }
166 }
167 dot_scalar(a, b)
168}
169
170pub fn simd_mat_vec(w: &[f64], x: &[f64], rows: usize, cols: usize, out: &mut [f64]) {
194 assert!(
195 w.len() >= rows * cols,
196 "simd_mat_vec: w.len()={} < rows*cols={}",
197 w.len(),
198 rows * cols
199 );
200 assert!(
201 out.len() >= rows,
202 "simd_mat_vec: out.len()={} < rows={}",
203 out.len(),
204 rows
205 );
206 assert!(
207 x.len() >= cols,
208 "simd_mat_vec: x.len()={} < cols={}",
209 x.len(),
210 cols
211 );
212
213 #[cfg(all(target_arch = "x86_64", feature = "std"))]
214 {
215 if is_x86_feature_detected!("avx2") {
216 unsafe {
218 avx2::mat_vec_avx2(w, x, rows, cols, out);
219 }
220 return;
221 }
222 }
223 mat_vec_scalar(w, x, rows, cols, out);
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Minimal deterministic PRNG (xorshift64) so tests are reproducible
    /// without a `rand` dependency.
    ///
    /// NOTE(review): xorshift64 maps a zero state to zero forever, so the
    /// seed must be non-zero; all tests below use non-zero seeds.
    struct TestRng(u64);

    impl TestRng {
        fn new(seed: u64) -> Self {
            Self(seed)
        }

        // One xorshift64 step (shift constants 13, 7, 17).
        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }

        // f64 in [-1.0, 1.0): keep the top 53 random bits (f64 mantissa
        // width), scale to [0, 1), then map onto [-1, 1).
        fn next_f64(&mut self) -> f64 {
            (self.next_u64() >> 11) as f64 / ((1u64 << 53) as f64) * 2.0 - 1.0
        }

        // Convenience: a Vec of `n` pseudo-random values.
        fn fill_vec(&mut self, n: usize) -> Vec<f64> {
            (0..n).map(|_| self.next_f64()).collect()
        }
    }

    #[test]
    fn dot_empty_returns_zero() {
        let a: [f64; 0] = [];
        let b: [f64; 0] = [];
        assert_eq!(simd_dot(&a, &b), 0.0, "dot of empty slices should be 0");
    }

    #[test]
    fn dot_single_element() {
        let a = [3.0];
        let b = [4.0];
        assert!(
            (simd_dot(&a, &b) - 12.0).abs() < 1e-12,
            "dot([3], [4]) should be 12, got {}",
            simd_dot(&a, &b)
        );
    }

    #[test]
    fn dot_known_result() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 32.0).abs() < 1e-12,
            "dot([1,2,3], [4,5,6]) should be 32, got {}",
            result
        );
    }

    // Cross-check the dispatched (possibly SIMD) path against the scalar
    // reference on a large random input. 1000 elements exercises many full
    // 4-lane chunks; the looser 1e-9 tolerance allows for the different
    // accumulation order of the vectorized sum.
    #[test]
    fn dot_large_matches_scalar() {
        let mut rng = TestRng::new(42);
        let a = rng.fill_vec(1000);
        let b = rng.fill_vec(1000);

        let simd_result = simd_dot(&a, &b);
        let scalar_result = dot_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-9,
            "1000-element dot: SIMD={} vs scalar={}, diff={}",
            simd_result,
            scalar_result,
            (simd_result - scalar_result).abs()
        );
    }

    // The trailing 999.0 must be ignored: dot truncates to the shorter slice.
    #[test]
    fn dot_mismatched_lengths() {
        let a = [1.0, 2.0, 3.0, 999.0];
        let b = [4.0, 5.0, 6.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 32.0).abs() < 1e-12,
            "mismatched lengths should use min, expected 32, got {}",
            result
        );
    }

    // Length 7 = one full 4-lane chunk plus a 3-element scalar tail,
    // covering the remainder path of the AVX2 kernel.
    #[test]
    fn dot_non_aligned_length() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let b = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 28.0).abs() < 1e-12,
            "dot of [1..7] with [1..1] should be 28, got {}",
            result
        );
    }

    // -4 - 6 - 6 - 4 = -20: sign handling over exactly one full chunk.
    #[test]
    fn dot_negative_values() {
        let a = [-1.0, -2.0, -3.0, -4.0];
        let b = [4.0, 3.0, 2.0, 1.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - (-20.0)).abs() < 1e-12,
            "expected -20, got {}",
            result
        );
    }

    #[test]
    fn dot_orthogonal_vectors() {
        let a = [1.0, 0.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0, 0.0];
        let result = simd_dot(&a, &b);
        assert!(
            result.abs() < 1e-12,
            "orthogonal vectors should have dot=0, got {}",
            result
        );
    }

    // 3x3 identity matrix: output must reproduce the input vector.
    #[test]
    fn mat_vec_identity_like() {
        let w = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 3];
        simd_mat_vec(&w, &x, 3, 3, &mut out);
        assert!(
            (out[0] - 1.0).abs() < 1e-12,
            "identity row 0: expected 1, got {}",
            out[0]
        );
        assert!(
            (out[1] - 2.0).abs() < 1e-12,
            "identity row 1: expected 2, got {}",
            out[1]
        );
        assert!(
            (out[2] - 3.0).abs() < 1e-12,
            "identity row 2: expected 3, got {}",
            out[2]
        );
    }

    // 2x3 matrix [[1,2,3],[4,5,6]] times [1,2,3] = [14, 32].
    #[test]
    fn mat_vec_known_result() {
        let w = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
        assert!(
            (out[0] - 14.0).abs() < 1e-12,
            "row 0: expected 14, got {}",
            out[0]
        );
        assert!(
            (out[1] - 32.0).abs() < 1e-12,
            "row 1: expected 32, got {}",
            out[1]
        );
    }

    // Row-by-row comparison of the dispatched path against the scalar
    // reference on a 100x100 random matrix.
    #[test]
    fn mat_vec_large_matches_scalar() {
        let mut rng = TestRng::new(7777);
        let rows = 100;
        let cols = 100;
        let w = rng.fill_vec(rows * cols);
        let x = rng.fill_vec(cols);
        let mut out_simd = vec![0.0; rows];
        let mut out_scalar = vec![0.0; rows];

        simd_mat_vec(&w, &x, rows, cols, &mut out_simd);
        mat_vec_scalar(&w, &x, rows, cols, &mut out_scalar);

        for i in 0..rows {
            assert!(
                (out_simd[i] - out_scalar[i]).abs() < 1e-9,
                "row {}: SIMD={} vs scalar={}, diff={}",
                i,
                out_simd[i],
                out_scalar[i],
                (out_simd[i] - out_scalar[i]).abs()
            );
        }
    }

    // A 1-row matrix-vector product degenerates to a dot product.
    #[test]
    fn mat_vec_single_row() {
        let w = [1.0, 2.0, 3.0, 4.0, 5.0];
        let x = [2.0, 2.0, 2.0, 2.0, 2.0];
        let mut out = [0.0; 1];
        simd_mat_vec(&w, &x, 1, 5, &mut out);
        assert!(
            (out[0] - 30.0).abs() < 1e-12,
            "single-row mat_vec should be dot product, expected 30, got {}",
            out[0]
        );
    }

    #[test]
    fn mat_vec_single_element() {
        let w = [7.0];
        let x = [3.0];
        let mut out = [0.0; 1];
        simd_mat_vec(&w, &x, 1, 1, &mut out);
        assert!(
            (out[0] - 21.0).abs() < 1e-12,
            "1x1 mat_vec: 7*3=21, got {}",
            out[0]
        );
    }

    // Each of the three panic tests below violates exactly one of the
    // simd_mat_vec preconditions; the `expected` substrings pin the
    // assert messages in simd_mat_vec and must stay in sync with them.
    #[test]
    #[should_panic(expected = "simd_mat_vec: w.len()")]
    fn mat_vec_panics_w_too_short() {
        let w = [1.0, 2.0]; // needs 2*3 = 6 elements
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }

    #[test]
    #[should_panic(expected = "simd_mat_vec: out.len()")]
    fn mat_vec_panics_out_too_short() {
        let w = [1.0; 6];
        let x = [1.0; 3];
        let mut out = [0.0; 1]; // needs 2 elements
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }

    #[test]
    #[should_panic(expected = "simd_mat_vec: x.len()")]
    fn mat_vec_panics_x_too_short() {
        let w = [1.0; 6];
        let x = [1.0; 2]; // needs 3 elements
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }

    // On x86_64 std builds, sanity-check an 8-element dot product (two
    // full AVX2 chunks when the fast path is taken at runtime).
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn simd_available_on_x86() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 120.0).abs() < 1e-12,
            "8-element dot product should be 120, got {}",
            result
        );

        if is_x86_feature_detected!("avx2") {
            // NOTE(review): empty branch — presumably a placeholder for an
            // AVX2-specific assertion; consider filling in or removing.
        }
    }
}