1use super::distance::DistanceMetric;
17
/// CPU SIMD capability tier, ordered from least to most capable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// No vector instructions; plain scalar loops.
    Scalar,
    /// 128-bit SSE registers (4 x f32 per operation).
    Sse,
    /// 256-bit AVX registers (8 x f32 per operation).
    Avx,
    /// AVX plus fused multiply-add instructions.
    AvxFma,
}

impl SimdLevel {
    /// Probes the running CPU and returns the best supported tier.
    ///
    /// On non-x86_64 targets this always reports [`SimdLevel::Scalar`].
    #[inline]
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            let has_avx = is_x86_feature_detected!("avx");
            if has_avx && is_x86_feature_detected!("fma") {
                return SimdLevel::AvxFma;
            }
            if has_avx {
                return SimdLevel::Avx;
            }
            if is_x86_feature_detected!("sse") {
                return SimdLevel::Sse;
            }
            SimdLevel::Scalar
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            SimdLevel::Scalar
        }
    }
}
53
/// Lazily-initialized, process-wide cache of the detected SIMD level.
static SIMD_LEVEL: std::sync::OnceLock<SimdLevel> = std::sync::OnceLock::new();

/// Returns the process-wide SIMD level, running CPU detection exactly once.
///
/// Thread-safe: `OnceLock` guarantees `detect()` runs at most once even
/// under concurrent first calls.
#[inline]
pub fn simd_level() -> SimdLevel {
    *SIMD_LEVEL.get_or_init(SimdLevel::detect)
}
62
/// Computes the squared Euclidean (L2) distance between `a` and `b`,
/// dispatching to the fastest SIMD kernel the running CPU supports.
///
/// Returns the *squared* distance (no square root), which preserves
/// ordering for nearest-neighbor comparisons.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn l2_squared_simd(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    match simd_level() {
        // SAFETY: each unsafe arm is only reachable when `SimdLevel::detect`
        // confirmed the corresponding CPU features at runtime.
        #[cfg(target_arch = "x86_64")]
        SimdLevel::AvxFma => unsafe { l2_squared_avx_fma(a, b) },
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx => unsafe { l2_squared_avx(a, b) },
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Sse => unsafe { l2_squared_sse(a, b) },
        _ => l2_squared_scalar(a, b),
    }
}
82
/// Portable fallback: squared L2 distance accumulated element-by-element,
/// in the same left-to-right order as the SIMD tails.
#[inline]
fn l2_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    (0..a.len())
        .map(|i| {
            let delta = a[i] - b[i];
            delta * delta
        })
        .sum()
}
93
/// Squared L2 distance using 128-bit SSE: processes 4 lanes per iteration,
/// then handles the remaining `len % 4` elements with scalar code.
///
/// # Safety
/// The caller must ensure the CPU supports the SSE instruction set
/// (guaranteed when dispatched through `simd_level()`), and that `b` is at
/// least as long as `a` — the public wrapper asserts equal lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn l2_squared_sse(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    // Vector accumulator: four independent partial sums, one per lane.
    let mut sum = _mm_setzero_ps();
    let chunks = len / 4;
    for i in 0..chunks {
        let idx = i * 4;
        // Unaligned loads: input slices carry no alignment guarantee.
        let va = _mm_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm_sub_ps(va, vb);
        let sq = _mm_mul_ps(diff, diff);
        sum = _mm_add_ps(sum, sq);
    }

    // Collapse the four lane sums into one scalar.
    let mut result = horizontal_sum_sse(sum);

    // Scalar tail for lengths not divisible by 4.
    for i in (chunks * 4)..len {
        let d = a[i] - b[i];
        result += d * d;
    }

    result
}
124
125#[cfg(target_arch = "x86_64")]
126#[target_feature(enable = "sse")]
127#[inline]
128unsafe fn horizontal_sum_sse(v: std::arch::x86_64::__m128) -> f32 {
129 use std::arch::x86_64::*;
130
131 let shuf = _mm_movehdup_ps(v); let sums = _mm_add_ps(v, shuf); let shuf2 = _mm_movehl_ps(sums, sums); let sums2 = _mm_add_ss(sums, shuf2); _mm_cvtss_f32(sums2)
138}
139
/// Squared L2 distance using 256-bit AVX: processes 8 lanes per iteration,
/// then handles the remaining `len % 8` elements with scalar code.
///
/// # Safety
/// The caller must ensure the CPU supports AVX (guaranteed when dispatched
/// through `simd_level()`), and that `b` is at least as long as `a` — the
/// public wrapper asserts equal lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn l2_squared_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    // Vector accumulator: eight independent partial sums, one per lane.
    let mut sum = _mm256_setzero_ps();
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        // Unaligned loads: input slices carry no alignment guarantee.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm256_sub_ps(va, vb);
        let sq = _mm256_mul_ps(diff, diff);
        sum = _mm256_add_ps(sum, sq);
    }

    // Collapse the eight lane sums into one scalar.
    let mut result = horizontal_sum_avx(sum);

    // Scalar tail for lengths not divisible by 8.
    for i in (chunks * 8)..len {
        let d = a[i] - b[i];
        result += d * d;
    }

    result
}
170
171#[cfg(target_arch = "x86_64")]
172#[target_feature(enable = "avx")]
173#[inline]
174unsafe fn horizontal_sum_avx(v: std::arch::x86_64::__m256) -> f32 {
175 use std::arch::x86_64::*;
176
177 let high = _mm256_extractf128_ps(v, 1);
179 let low = _mm256_castps256_ps128(v);
180 let sum128 = _mm_add_ps(high, low);
181
182 horizontal_sum_sse(sum128)
184}
185
/// Squared L2 distance using AVX with fused multiply-add: processes 8 lanes
/// per iteration, then handles the remaining `len % 8` elements with scalar
/// code.
///
/// `fmadd` computes `diff * diff + sum` with a single rounding step, so
/// results may differ from the mul+add kernels in the last bit.
///
/// # Safety
/// The caller must ensure the CPU supports both AVX and FMA (guaranteed when
/// dispatched through `simd_level()`), and that `b` is at least as long as
/// `a` — the public wrapper asserts equal lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx", enable = "fma")]
unsafe fn l2_squared_avx_fma(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    // Vector accumulator: eight independent partial sums, one per lane.
    let mut sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        // Unaligned loads: input slices carry no alignment guarantee.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm256_sub_ps(va, vb);
        // Fused square-and-accumulate in one instruction.
        sum = _mm256_fmadd_ps(diff, diff, sum);
    }

    // Collapse the eight lane sums into one scalar.
    let mut result = horizontal_sum_avx(sum);

    // Scalar tail for lengths not divisible by 8.
    for i in (chunks * 8)..len {
        let d = a[i] - b[i];
        result += d * d;
    }

    result
}
215
/// Computes the dot product of `a` and `b`, dispatching to the fastest SIMD
/// kernel the running CPU supports.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    match simd_level() {
        // SAFETY: each unsafe arm is only reachable when `SimdLevel::detect`
        // confirmed the corresponding CPU features at runtime.
        #[cfg(target_arch = "x86_64")]
        SimdLevel::AvxFma => unsafe { dot_product_avx_fma(a, b) },
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx => unsafe { dot_product_avx(a, b) },
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Sse => unsafe { dot_product_sse(a, b) },
        _ => dot_product_scalar(a, b),
    }
}
235
/// Portable fallback: dot product accumulated element-by-element, in the
/// same left-to-right order as the SIMD tails.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    (0..a.len()).map(|i| a[i] * b[i]).sum()
}
244
/// Dot product using 128-bit SSE: processes 4 lanes per iteration, then
/// handles the remaining `len % 4` elements with scalar code.
///
/// # Safety
/// The caller must ensure the CPU supports the SSE instruction set
/// (guaranteed when dispatched through `simd_level()`), and that `b` is at
/// least as long as `a` — the public wrapper asserts equal lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    // Vector accumulator: four independent partial sums, one per lane.
    let mut sum = _mm_setzero_ps();

    let chunks = len / 4;
    for i in 0..chunks {
        let idx = i * 4;
        // Unaligned loads: input slices carry no alignment guarantee.
        let va = _mm_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm_loadu_ps(b.as_ptr().add(idx));
        let prod = _mm_mul_ps(va, vb);
        sum = _mm_add_ps(sum, prod);
    }

    // Collapse the four lane sums into one scalar.
    let mut result = horizontal_sum_sse(sum);

    // Scalar tail for lengths not divisible by 4.
    for i in (chunks * 4)..len {
        result += a[i] * b[i];
    }

    result
}
270
/// Dot product using 256-bit AVX: processes 8 lanes per iteration, then
/// handles the remaining `len % 8` elements with scalar code.
///
/// # Safety
/// The caller must ensure the CPU supports AVX (guaranteed when dispatched
/// through `simd_level()`), and that `b` is at least as long as `a` — the
/// public wrapper asserts equal lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn dot_product_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    // Vector accumulator: eight independent partial sums, one per lane.
    let mut sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        // Unaligned loads: input slices carry no alignment guarantee.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let prod = _mm256_mul_ps(va, vb);
        sum = _mm256_add_ps(sum, prod);
    }

    // Collapse the eight lane sums into one scalar.
    let mut result = horizontal_sum_avx(sum);

    // Scalar tail for lengths not divisible by 8.
    for i in (chunks * 8)..len {
        result += a[i] * b[i];
    }

    result
}
296
/// Dot product using AVX with fused multiply-add: processes 8 lanes per
/// iteration, then handles the remaining `len % 8` elements with scalar code.
///
/// `fmadd` computes `va * vb + sum` with a single rounding step, so results
/// may differ from the mul+add kernels in the last bit.
///
/// # Safety
/// The caller must ensure the CPU supports both AVX and FMA (guaranteed when
/// dispatched through `simd_level()`), and that `b` is at least as long as
/// `a` — the public wrapper asserts equal lengths.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx", enable = "fma")]
unsafe fn dot_product_avx_fma(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    // Vector accumulator: eight independent partial sums, one per lane.
    let mut sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        // Unaligned loads: input slices carry no alignment guarantee.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        // Fused multiply-accumulate in one instruction.
        sum = _mm256_fmadd_ps(va, vb, sum);
    }

    // Collapse the eight lane sums into one scalar.
    let mut result = horizontal_sum_avx(sum);

    // Scalar tail for lengths not divisible by 8.
    for i in (chunks * 8)..len {
        result += a[i] * b[i];
    }

    result
}
322
/// Euclidean (L2) norm of `v`, computed as `sqrt(v · v)` via the SIMD dot
/// product.
///
/// # Panics
/// Never panics on length (both arguments to the dot product are `v`).
#[inline]
pub fn l2_norm_simd(v: &[f32]) -> f32 {
    dot_product_simd(v, v).sqrt()
}
332
333#[inline]
339pub fn cosine_distance_simd(a: &[f32], b: &[f32]) -> f32 {
340 let dot = dot_product_simd(a, b);
341 let norm_a = l2_norm_simd(a);
342 let norm_b = l2_norm_simd(b);
343
344 if norm_a == 0.0 || norm_b == 0.0 {
345 return 1.0;
346 }
347
348 let similarity = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
349 1.0 - similarity
350}
351
/// Inner-product "distance": the negated dot product, so that a larger
/// similarity maps to a smaller distance for nearest-neighbor ordering.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn inner_product_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    -dot_product_simd(a, b)
}
357
/// Computes the distance between `a` and `b` under the given metric.
///
/// Note: `L2` returns the *squared* distance (monotonic with true L2, which
/// suffices for ranking), and `InnerProduct` returns the negated dot product.
#[inline]
pub fn distance_simd(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
    match metric {
        DistanceMetric::L2 => l2_squared_simd(a, b),
        DistanceMetric::Cosine => cosine_distance_simd(a, b),
        DistanceMetric::InnerProduct => inner_product_distance_simd(a, b),
    }
}
371
372pub fn batch_distances(
379 query: &[f32],
380 targets: &[Vec<f32>],
381 metric: DistanceMetric,
382 top_k: usize,
383) -> Vec<(usize, f32)> {
384 let mut results: Vec<(usize, f32)> = targets
385 .iter()
386 .enumerate()
387 .map(|(i, target)| (i, distance_simd(query, target, metric)))
388 .collect();
389
390 if top_k < results.len() {
392 results.select_nth_unstable_by(top_k, |a, b| {
393 a.1.partial_cmp(&b.1)
394 .unwrap_or(std::cmp::Ordering::Equal)
395 .then_with(|| a.0.cmp(&b.0))
396 });
397 results.truncate(top_k);
398 }
399
400 results.sort_by(|a, b| {
401 a.1.partial_cmp(&b.1)
402 .unwrap_or(std::cmp::Ordering::Equal)
403 .then_with(|| a.0.cmp(&b.0))
404 });
405 results
406}
407
/// Unit tests: exercise dispatch, SIMD-vs-scalar agreement, remainder
/// handling for non-multiple-of-lane lengths, and top-k batching.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_level_detection() {
        // Must return one of the known tiers on any machine.
        let level = simd_level();
        println!("Detected SIMD level: {:?}", level);
        assert!(matches!(
            level,
            SimdLevel::Scalar | SimdLevel::Sse | SimdLevel::Avx | SimdLevel::AvxFma
        ));
    }

    #[test]
    fn test_l2_squared_simd_identical() {
        // Distance of a vector to itself is zero.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert!((l2_squared_simd(&a, &b) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_squared_simd_simple() {
        // Single differing component of 1.0 -> squared distance 1.0.
        let a = vec![0.0, 0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 0.0];
        assert!((l2_squared_simd(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_squared_simd_vs_scalar() {
        // SIMD path must agree with the scalar reference on a 256-dim input.
        let a: Vec<f32> = (0..256).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..256).map(|i| (i + 1) as f32 * 0.1).collect();

        let simd_result = l2_squared_simd(&a, &b);
        let scalar_result = l2_squared_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-3,
            "SIMD: {}, Scalar: {}",
            simd_result,
            scalar_result
        );
    }

    #[test]
    fn test_dot_product_simd() {
        // Dot with all-ones vector equals the element sum: 1+2+...+8 = 36.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = dot_product_simd(&a, &b);
        assert!((result - 36.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_simd_vs_scalar() {
        let a: Vec<f32> = (0..256).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..256).map(|i| (i + 1) as f32 * 0.1).collect();

        let simd_result = dot_product_simd(&a, &b);
        let scalar_result = dot_product_scalar(&a, &b);

        // Generous tolerance: FMA vs mul+add accumulate different rounding.
        assert!(
            (simd_result - scalar_result).abs() < 1.0,
            "SIMD: {}, Scalar: {}",
            simd_result,
            scalar_result
        );
    }

    #[test]
    fn test_l2_norm_simd() {
        // Classic 3-4-5 triangle.
        let v = vec![3.0, 4.0, 0.0, 0.0];
        assert!((l2_norm_simd(&v) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_simd_identical() {
        // Parallel vectors -> similarity 1 -> distance 0.
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 0.0];
        assert!((cosine_distance_simd(&a, &b) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_simd_orthogonal() {
        // Orthogonal vectors -> similarity 0 -> distance 1.
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];
        assert!((cosine_distance_simd(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_batch_distances() {
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let targets = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![2.0, 0.0, 0.0, 0.0],
            vec![0.5, 0.0, 0.0, 0.0],
        ];

        let results = batch_distances(&query, &targets, DistanceMetric::L2, 2);
        assert_eq!(results.len(), 2);
        // Target 2 (0.5 away) is closest, then target 0 (1.0 away).
        assert_eq!(results[0].0, 2);
        assert_eq!(results[1].0, 0);
    }

    #[test]
    fn test_odd_length_vectors() {
        // Length 5 exercises the scalar remainder after the 4/8-lane chunks.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![5.0, 4.0, 3.0, 2.0, 1.0];

        let simd_result = l2_squared_simd(&a, &b);
        let expected = 16.0 + 4.0 + 0.0 + 4.0 + 16.0;
        assert!((simd_result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_large_vectors() {
        // Typical embedding dimension (1536) with non-trivial values.
        let a: Vec<f32> = (0..1536).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..1536).map(|i| (i as f32).cos()).collect();

        let simd_result = l2_squared_simd(&a, &b);
        let scalar_result = l2_squared_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() / scalar_result.abs() < 1e-5,
            "Relative error too large: SIMD={}, Scalar={}",
            simd_result,
            scalar_result
        );
    }
}