1#[cfg(target_arch = "x86_64")]
9use std::sync::OnceLock;
10
11use common::DistanceMetric;
12
// Cached result of the one-time AVX2/FMA CPU feature probe.
#[cfg(target_arch = "x86_64")]
static AVX2_AVAILABLE: OnceLock<bool> = OnceLock::new();

/// Returns `true` when the running CPU supports both AVX2 and FMA.
///
/// The runtime feature probe executes once; every later call reads the
/// cached flag from the `OnceLock`.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn avx2_available() -> bool {
    *AVX2_AVAILABLE
        .get_or_init(|| is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma"))
}
24
25pub fn simd_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
28 match metric {
29 DistanceMetric::Cosine => simd_cosine_similarity(a, b),
30 DistanceMetric::Euclidean => simd_negative_euclidean(a, b),
31 DistanceMetric::DotProduct => simd_dot_product(a, b),
32 }
33}
34
/// Dot product of `a` and `b`: AVX2+FMA on x86_64 when detected at runtime,
/// NEON on aarch64, and a scalar loop everywhere else (including x86_64
/// CPUs without AVX2).
///
/// NOTE(review): the SIMD paths index both slices directly, while the scalar
/// fallback truncates to the shorter slice via `zip` — confirm callers always
/// pass equal-length vectors.
#[inline]
pub fn simd_dot_product(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if avx2_available() {
            // SAFETY: avx2_available() confirmed AVX2 and FMA support.
            return unsafe { avx2_dot_product(a, b) };
        }
        // No AVX2: fall through to the scalar block below, which is also
        // compiled for x86_64 (it is gated on `not(aarch64)`).
    }

    #[cfg(target_arch = "aarch64")]
    {
        // NEON is baseline on aarch64, so no runtime detection is needed.
        unsafe { neon_dot_product(a, b) }
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }
}
56
/// Cosine similarity of `a` and `b`: AVX2+FMA on x86_64 when detected at
/// runtime, NEON on aarch64, and a scalar loop everywhere else.
///
/// NOTE(review): the SIMD paths index both slices directly, while the scalar
/// fallback truncates to the shorter slice via `zip` — confirm callers always
/// pass equal-length vectors.
#[inline]
pub fn simd_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if avx2_available() {
            // SAFETY: avx2_available() confirmed AVX2 and FMA support.
            return unsafe { avx2_cosine_similarity(a, b) };
        }
        // No AVX2: fall through to the `not(aarch64)` scalar block below.
    }

    #[cfg(target_arch = "aarch64")]
    {
        // NEON is baseline on aarch64, so no runtime detection is needed.
        unsafe { neon_cosine_similarity(a, b) }
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        fallback_cosine_similarity(a, b)
    }
}
78
/// Negated Euclidean distance between `a` and `b` (larger result = closer):
/// AVX2+FMA on x86_64 when detected at runtime, NEON on aarch64, and a
/// scalar loop everywhere else.
///
/// NOTE(review): the SIMD paths index both slices directly, while the scalar
/// fallback truncates to the shorter slice via `zip` — confirm callers always
/// pass equal-length vectors.
#[inline]
pub fn simd_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if avx2_available() {
            // SAFETY: avx2_available() confirmed AVX2 and FMA support.
            return unsafe { avx2_negative_euclidean(a, b) };
        }
        // No AVX2: fall through to the `not(aarch64)` scalar block below.
    }

    #[cfg(target_arch = "aarch64")]
    {
        // NEON is baseline on aarch64, so no runtime detection is needed.
        unsafe { neon_negative_euclidean(a, b) }
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
        -sum.sqrt()
    }
}
101
/// Scalar cosine similarity, used on targets with no SIMD path.
///
/// Accumulates the dot product and both squared magnitudes in a single
/// pass, then returns `dot / (|a| * |b|)`. Returns `0.0` when either
/// vector has zero magnitude (no defined direction).
#[inline]
#[allow(dead_code)]
fn fallback_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b.iter())
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, na, nb), (x, y)| {
            (d + x * y, na + x * x, nb + y * y)
        });

    let mag_a = sq_a.sqrt();
    let mag_b = sq_b.sqrt();

    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    dot / (mag_a * mag_b)
}
129
/// Reference dot product used by tests to validate the SIMD paths.
#[inline]
#[cfg(test)]
fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .fold(0.0f32, |acc, (x, y)| acc + x * y)
}
141
/// Reference cosine similarity used by tests to validate the SIMD paths.
/// Returns `0.0` when either vector has zero magnitude.
#[inline]
#[cfg(test)]
fn scalar_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Iterate the overlap of the two slices (same pairs `zip` would yield).
    let n = a.len().min(b.len());

    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for i in 0..n {
        dot += a[i] * b[i];
        sq_a += a[i] * a[i];
        sq_b += b[i] * b[i];
    }

    let mag_a = sq_a.sqrt();
    let mag_b = sq_b.sqrt();
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    dot / (mag_a * mag_b)
}
166
/// Reference negated Euclidean distance used by tests to validate the
/// SIMD paths (larger result = closer).
#[inline]
#[cfg(test)]
fn scalar_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    let mut sq_dist = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let d = x - y;
        sq_dist += d * d;
    }
    -sq_dist.sqrt()
}
175
176#[cfg(target_arch = "x86_64")]
181#[target_feature(enable = "avx2", enable = "fma")]
182unsafe fn avx2_dot_product(a: &[f32], b: &[f32]) -> f32 {
183 use std::arch::x86_64::*;
184
185 let n = a.len();
186 let chunks = n / 8;
187 let remainder = n % 8;
188
189 let mut sum = _mm256_setzero_ps();
190
191 let a_ptr = a.as_ptr();
192 let b_ptr = b.as_ptr();
193
194 for i in 0..chunks {
195 let offset = i * 8;
196 let va = _mm256_loadu_ps(a_ptr.add(offset));
197 let vb = _mm256_loadu_ps(b_ptr.add(offset));
198 sum = _mm256_fmadd_ps(va, vb, sum);
199 }
200
201 let mut result = hsum_avx(sum);
203
204 let start = chunks * 8;
206 for i in 0..remainder {
207 result += a[start + i] * b[start + i];
208 }
209
210 result
211}
212
213#[cfg(target_arch = "x86_64")]
214#[target_feature(enable = "avx2", enable = "fma")]
215unsafe fn avx2_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
216 use std::arch::x86_64::*;
217
218 let n = a.len();
219 let chunks = n / 8;
220 let remainder = n % 8;
221
222 let mut dot_sum = _mm256_setzero_ps();
223 let mut norm_a_sum = _mm256_setzero_ps();
224 let mut norm_b_sum = _mm256_setzero_ps();
225
226 let a_ptr = a.as_ptr();
227 let b_ptr = b.as_ptr();
228
229 for i in 0..chunks {
230 let offset = i * 8;
231 let va = _mm256_loadu_ps(a_ptr.add(offset));
232 let vb = _mm256_loadu_ps(b_ptr.add(offset));
233
234 dot_sum = _mm256_fmadd_ps(va, vb, dot_sum);
235 norm_a_sum = _mm256_fmadd_ps(va, va, norm_a_sum);
236 norm_b_sum = _mm256_fmadd_ps(vb, vb, norm_b_sum);
237 }
238
239 let mut dot = hsum_avx(dot_sum);
240 let mut norm_a = hsum_avx(norm_a_sum);
241 let mut norm_b = hsum_avx(norm_b_sum);
242
243 let start = chunks * 8;
245 for i in 0..remainder {
246 let x = a[start + i];
247 let y = b[start + i];
248 dot += x * y;
249 norm_a += x * x;
250 norm_b += y * y;
251 }
252
253 let norm_a = norm_a.sqrt();
254 let norm_b = norm_b.sqrt();
255
256 if norm_a == 0.0 || norm_b == 0.0 {
257 return 0.0;
258 }
259
260 dot / (norm_a * norm_b)
261}
262
263#[cfg(target_arch = "x86_64")]
264#[target_feature(enable = "avx2", enable = "fma")]
265unsafe fn avx2_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
266 use std::arch::x86_64::*;
267
268 let n = a.len();
269 let chunks = n / 8;
270 let remainder = n % 8;
271
272 let mut sum = _mm256_setzero_ps();
273
274 let a_ptr = a.as_ptr();
275 let b_ptr = b.as_ptr();
276
277 for i in 0..chunks {
278 let offset = i * 8;
279 let va = _mm256_loadu_ps(a_ptr.add(offset));
280 let vb = _mm256_loadu_ps(b_ptr.add(offset));
281 let diff = _mm256_sub_ps(va, vb);
282 sum = _mm256_fmadd_ps(diff, diff, sum);
283 }
284
285 let mut result = hsum_avx(sum);
286
287 let start = chunks * 8;
289 for i in 0..remainder {
290 let diff = a[start + i] - b[start + i];
291 result += diff * diff;
292 }
293
294 -result.sqrt()
295}
296
/// Horizontally sums the eight `f32` lanes of an AVX register into a scalar.
///
/// # Safety
/// Caller must guarantee the CPU supports AVX2; see `avx2_available`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hsum_avx(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // Fold the upper 128-bit half onto the lower half: 8 lanes -> 4 lanes.
    let high = _mm256_extractf128_ps(v, 1);
    let low = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(high, low);

    // Pairwise reduce: duplicate the odd lanes and add (4 -> 2 lanes),
    // then move the high pair down and add the scalar lane (2 -> 1).
    let shuf = _mm_movehdup_ps(sum128);
    let sums = _mm_add_ps(sum128, shuf);
    let shuf = _mm_movehl_ps(sums, sums);
    let sums = _mm_add_ss(sums, shuf);

    _mm_cvtss_f32(sums)
}
317
/// NEON dot product (aarch64): 4 lanes per iteration with a fused
/// multiply-add accumulator, then a scalar tail for the remainder.
///
/// Truncates to the shorter slice when lengths differ, matching the
/// scalar `zip`-based fallback.
///
/// # Safety
/// Uses raw-pointer lane loads; the bounds invariant is maintained
/// internally, so there is no extra caller obligation beyond running on
/// aarch64 (enforced by the `cfg`).
#[cfg(target_arch = "aarch64")]
unsafe fn neon_dot_product(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    // Only read the overlap of the two slices. Previously `a.len()` alone
    // bounded the loop, so a shorter `b` caused an out-of-bounds read (UB).
    let n = a.len().min(b.len());
    let chunks = n / 4;
    let remainder = n % 4;

    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    for i in 0..chunks {
        // SAFETY: offset + 4 <= n <= len of both slices.
        let offset = i * 4;
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        sum = vfmaq_f32(sum, va, vb);
    }

    let mut result = vaddvq_f32(sum);

    // Scalar tail for the final 0..=3 elements.
    let start = chunks * 4;
    for i in 0..remainder {
        result += a[start + i] * b[start + i];
    }

    result
}
352
/// NEON cosine similarity (aarch64): accumulates the dot product and both
/// squared magnitudes in one pass, 4 lanes per iteration, with a scalar
/// tail. Returns `0.0` when either vector has zero magnitude.
///
/// Truncates to the shorter slice when lengths differ, matching the
/// scalar `zip`-based fallback.
///
/// # Safety
/// Uses raw-pointer lane loads; the bounds invariant is maintained
/// internally, so there is no extra caller obligation beyond running on
/// aarch64 (enforced by the `cfg`).
#[cfg(target_arch = "aarch64")]
unsafe fn neon_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    // Only read the overlap of the two slices. Previously `a.len()` alone
    // bounded the loop, so a shorter `b` caused an out-of-bounds read (UB).
    let n = a.len().min(b.len());
    let chunks = n / 4;
    let remainder = n % 4;

    let mut dot_sum = vdupq_n_f32(0.0);
    let mut norm_a_sum = vdupq_n_f32(0.0);
    let mut norm_b_sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    for i in 0..chunks {
        // SAFETY: offset + 4 <= n <= len of both slices.
        let offset = i * 4;
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));

        dot_sum = vfmaq_f32(dot_sum, va, vb);
        norm_a_sum = vfmaq_f32(norm_a_sum, va, va);
        norm_b_sum = vfmaq_f32(norm_b_sum, vb, vb);
    }

    let mut dot = vaddvq_f32(dot_sum);
    let mut norm_a = vaddvq_f32(norm_a_sum);
    let mut norm_b = vaddvq_f32(norm_b_sum);

    // Scalar tail for the final 0..=3 elements.
    let start = chunks * 4;
    for i in 0..remainder {
        let x = a[start + i];
        let y = b[start + i];
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let norm_a = norm_a.sqrt();
    let norm_b = norm_b.sqrt();

    // A zero-magnitude vector has no defined direction.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}
401
/// NEON negated Euclidean distance (aarch64): accumulates squared
/// differences 4 lanes per iteration, with a scalar tail, then returns
/// `-sqrt(sum)` so that larger results mean "closer".
///
/// Truncates to the shorter slice when lengths differ, matching the
/// scalar `zip`-based fallback.
///
/// # Safety
/// Uses raw-pointer lane loads; the bounds invariant is maintained
/// internally, so there is no extra caller obligation beyond running on
/// aarch64 (enforced by the `cfg`).
#[cfg(target_arch = "aarch64")]
unsafe fn neon_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    // Only read the overlap of the two slices. Previously `a.len()` alone
    // bounded the loop, so a shorter `b` caused an out-of-bounds read (UB).
    let n = a.len().min(b.len());
    let chunks = n / 4;
    let remainder = n % 4;

    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    for i in 0..chunks {
        // SAFETY: offset + 4 <= n <= len of both slices.
        let offset = i * 4;
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        let diff = vsubq_f32(va, vb);
        sum = vfmaq_f32(sum, diff, diff);
    }

    let mut result = vaddvq_f32(sum);

    // Scalar tail for the final 0..=3 elements.
    let start = chunks * 4;
    for i in 0..remainder {
        let diff = a[start + i] - b[start + i];
        result += diff * diff;
    }

    -result.sqrt()
}
434
// Unit tests: known-answer checks for small inputs plus cross-validation
// of every SIMD path against the scalar reference implementations.
#[cfg(test)]
mod tests {
    use super::*;

    // Absolute tolerance for small known-answer comparisons.
    const EPSILON: f32 = 1e-5;

    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    // (1+2+...+8) dotted with the all-ones vector is 36.
    #[test]
    fn test_simd_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = simd_dot_product(&a, &b);
        assert!(approx_eq(result, 36.0), "Expected 36.0, got {}", result);
    }

    // SIMD vs scalar on 1024 elements (a multiple of both vector widths,
    // so the whole input goes through the vectorized loop).
    #[test]
    fn test_simd_dot_product_large() {
        let a: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();
        let b: Vec<f32> = (0..1024).map(|i| (1024 - i) as f32 * 0.001).collect();

        let simd_result = simd_dot_product(&a, &b);
        let scalar_result = scalar_dot_product(&a, &b);

        // Loose tolerance: SIMD accumulates in a different order than scalar.
        assert!(
            (simd_result - scalar_result).abs() < 0.01,
            "SIMD: {}, Scalar: {}",
            simd_result,
            scalar_result
        );
    }

    // A vector is perfectly similar to itself.
    #[test]
    fn test_simd_cosine_identical() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let result = simd_cosine_similarity(&a, &a);
        assert!(approx_eq(result, 1.0), "Expected 1.0, got {}", result);
    }

    // Orthogonal unit vectors have zero cosine similarity.
    #[test]
    fn test_simd_cosine_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];
        let result = simd_cosine_similarity(&a, &b);
        assert!(approx_eq(result, 0.0), "Expected 0.0, got {}", result);
    }

    // SIMD vs scalar cosine on 1024 pseudo-random (sin/cos) inputs.
    #[test]
    fn test_simd_cosine_large() {
        let a: Vec<f32> = (0..1024).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..1024).map(|i| (i as f32).cos()).collect();

        let simd_result = simd_cosine_similarity(&a, &b);
        let scalar_result = scalar_cosine_similarity(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-4,
            "SIMD: {}, Scalar: {}",
            simd_result,
            scalar_result
        );
    }

    // Distance from a vector to itself is zero.
    #[test]
    fn test_simd_euclidean_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let result = simd_negative_euclidean(&a, &a);
        assert!(approx_eq(result, 0.0), "Expected 0.0, got {}", result);
    }

    // Classic 3-4-5 right triangle: distance 5, negated to -5.
    #[test]
    fn test_simd_euclidean_known() {
        let a = vec![0.0, 0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0, 0.0];
        let result = simd_negative_euclidean(&a, &b);
        assert!(approx_eq(result, -5.0), "Expected -5.0, got {}", result);
    }

    // SIMD vs scalar Euclidean on 1024 elements offset by a constant 0.01.
    #[test]
    fn test_simd_euclidean_large() {
        let a: Vec<f32> = (0..1024).map(|i| i as f32 * 0.01).collect();
        let b: Vec<f32> = (0..1024).map(|i| (i + 1) as f32 * 0.01).collect();

        let simd_result = simd_negative_euclidean(&a, &b);
        let scalar_result = scalar_negative_euclidean(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-3,
            "SIMD: {}, Scalar: {}",
            simd_result,
            scalar_result
        );
    }

    // Identical vectors score "closest" under every metric, so the
    // dispatcher's routing is observable from the results alone.
    #[test]
    fn test_simd_distance_dispatch() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 0.0];

        assert!(approx_eq(
            simd_distance(&a, &b, DistanceMetric::Cosine),
            1.0
        ));
        assert!(approx_eq(
            simd_distance(&a, &b, DistanceMetric::Euclidean),
            0.0
        ));
        assert!(approx_eq(
            simd_distance(&a, &b, DistanceMetric::DotProduct),
            1.0
        ));
    }

    // Odd sizes exercise the scalar-tail (remainder) handling in every
    // SIMD kernel: below, at, and just past the 4- and 8-lane widths.
    #[test]
    fn test_simd_remainder_handling() {
        for size in [3, 5, 7, 9, 11, 13, 15, 17] {
            let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..size).map(|i| (i + 1) as f32).collect();

            let simd_dot = simd_dot_product(&a, &b);
            let scalar_dot = scalar_dot_product(&a, &b);
            assert!(
                approx_eq(simd_dot, scalar_dot),
                "Size {}: SIMD {} != Scalar {}",
                size,
                simd_dot,
                scalar_dot
            );
        }
    }
}