1#[cfg(target_arch = "aarch64")]
18use std::arch::is_aarch64_feature_detected;
19#[cfg(target_arch = "x86_64")]
20use std::arch::is_x86_feature_detected;
21
22#[inline]
24pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
25 debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");
26
27 #[cfg(target_arch = "aarch64")]
28 {
29 if is_aarch64_feature_detected!("neon") {
30 return unsafe { l2_distance_neon(a, b) };
31 }
32 }
33
34 #[cfg(target_arch = "x86_64")]
35 {
36 if is_x86_feature_detected!("avx2") {
37 return unsafe { l2_distance_avx2(a, b) };
38 }
39 if is_x86_feature_detected!("avx") {
40 return unsafe { l2_distance_avx(a, b) };
41 }
42 if is_x86_feature_detected!("sse") {
43 return unsafe { l2_distance_sse(a, b) };
44 }
45 }
46
47 l2_distance_scalar(a, b)
49}
50
51#[inline]
53pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
54 debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");
55
56 #[cfg(target_arch = "aarch64")]
57 {
58 if is_aarch64_feature_detected!("neon") {
59 return unsafe { dot_product_neon(a, b) };
60 }
61 }
62
63 #[cfg(target_arch = "x86_64")]
64 {
65 if is_x86_feature_detected!("avx2") {
66 return unsafe { dot_product_avx2(a, b) };
67 }
68 if is_x86_feature_detected!("avx") {
69 return unsafe { dot_product_avx(a, b) };
70 }
71 if is_x86_feature_detected!("sse") {
72 return unsafe { dot_product_sse(a, b) };
73 }
74 }
75
76 dot_product_scalar(a, b)
78}
79
80#[inline]
82pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
83 debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");
84
85 #[cfg(target_arch = "aarch64")]
86 {
87 if is_aarch64_feature_detected!("neon") {
88 return unsafe { cosine_distance_neon(a, b) };
89 }
90 }
91
92 #[cfg(target_arch = "x86_64")]
93 {
94 if is_x86_feature_detected!("avx2") {
95 return unsafe { cosine_distance_avx2(a, b) };
96 }
97 if is_x86_feature_detected!("avx") {
98 return unsafe { cosine_distance_avx(a, b) };
99 }
100 if is_x86_feature_detected!("sse") {
101 return unsafe { cosine_distance_sse(a, b) };
102 }
103 }
104
105 cosine_distance_scalar(a, b)
107}
108
/// NEON kernel for the Euclidean (L2) distance between `a` and `b`.
///
/// Four lanes are processed per iteration with a fused multiply-add; any
/// elements left over are folded in one at a time before the square root.
///
/// # Safety
///
/// The CPU must support NEON, and `b` must contain at least `a.len()`
/// elements because the vector loop reads `b` through a raw pointer.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn l2_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    const LANES: usize = 4;

    let n = a.len();
    let blocks = n / LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    // Vector part: accumulate squared differences lane-wise.
    let mut acc = vdupq_n_f32(0.0);
    for block in 0..blocks {
        let offset = block * LANES;
        let delta = vsubq_f32(vld1q_f32(pa.add(offset)), vld1q_f32(pb.add(offset)));
        acc = vfmaq_f32(acc, delta, delta);
    }

    // Collapse the four lanes, then handle the remainder with checked indexing.
    let mut total = vaddvq_f32(acc);
    for i in blocks * LANES..n {
        let delta = a[i] - b[i];
        total += delta * delta;
    }

    total.sqrt()
}
143
/// NEON kernel for the dot product of `a` and `b`.
///
/// Four lanes are processed per iteration with a fused multiply-add; any
/// elements left over are folded in one at a time.
///
/// # Safety
///
/// The CPU must support NEON, and `b` must contain at least `a.len()`
/// elements because the vector loop reads `b` through a raw pointer.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    const LANES: usize = 4;

    let n = a.len();
    let blocks = n / LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    // Vector part: accumulate products lane-wise.
    let mut acc = vdupq_n_f32(0.0);
    for block in 0..blocks {
        let offset = block * LANES;
        acc = vfmaq_f32(acc, vld1q_f32(pa.add(offset)), vld1q_f32(pb.add(offset)));
    }

    // Collapse the four lanes, then handle the remainder with checked indexing.
    let mut total = vaddvq_f32(acc);
    for i in blocks * LANES..n {
        total += a[i] * b[i];
    }

    total
}
172
/// NEON kernel for the cosine distance (`1 - cosine similarity`) of `a` and `b`.
///
/// A single pass accumulates the dot product and both squared norms, four
/// lanes at a time, then finishes the remaining elements one by one. When a
/// norm is zero the final division produces `NaN`.
///
/// # Safety
///
/// The CPU must support NEON, and `b` must contain at least `a.len()`
/// elements because the vector loop reads `b` through a raw pointer.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn cosine_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    const LANES: usize = 4;

    let n = a.len();
    let blocks = n / LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut dot_acc = vdupq_n_f32(0.0);
    let mut sq_a_acc = vdupq_n_f32(0.0);
    let mut sq_b_acc = vdupq_n_f32(0.0);
    for block in 0..blocks {
        let offset = block * LANES;
        let va = vld1q_f32(pa.add(offset));
        let vb = vld1q_f32(pb.add(offset));
        dot_acc = vfmaq_f32(dot_acc, va, vb);
        sq_a_acc = vfmaq_f32(sq_a_acc, va, va);
        sq_b_acc = vfmaq_f32(sq_b_acc, vb, vb);
    }

    // Collapse each accumulator, then handle the remainder with checked indexing.
    let mut dot_total = vaddvq_f32(dot_acc);
    let mut sq_a = vaddvq_f32(sq_a_acc);
    let mut sq_b = vaddvq_f32(sq_b_acc);
    for i in blocks * LANES..n {
        let (x, y) = (a[i], b[i]);
        dot_total += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    1.0 - dot_total / (sq_a.sqrt() * sq_b.sqrt())
}
210
211#[cfg(target_arch = "x86_64")]
216#[target_feature(enable = "sse")]
217unsafe fn l2_distance_sse(a: &[f32], b: &[f32]) -> f32 {
218 use std::arch::x86_64::*;
219
220 let len = a.len();
221 let mut sum = _mm_setzero_ps();
222 let mut i = 0;
223
224 while i + 4 <= len {
226 let va = _mm_loadu_ps(a.as_ptr().add(i));
227 let vb = _mm_loadu_ps(b.as_ptr().add(i));
228 let diff = _mm_sub_ps(va, vb);
229 sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
230 i += 4;
231 }
232
233 let mut result = horizontal_sum_sse(sum);
235
236 while i < len {
238 let diff = a[i] - b[i];
239 result += diff * diff;
240 i += 1;
241 }
242
243 result.sqrt()
244}
245
246#[cfg(target_arch = "x86_64")]
247#[target_feature(enable = "sse")]
248unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
249 use std::arch::x86_64::*;
250
251 let len = a.len();
252 let mut sum = _mm_setzero_ps();
253 let mut i = 0;
254
255 while i + 4 <= len {
257 let va = _mm_loadu_ps(a.as_ptr().add(i));
258 let vb = _mm_loadu_ps(b.as_ptr().add(i));
259 sum = _mm_add_ps(sum, _mm_mul_ps(va, vb));
260 i += 4;
261 }
262
263 let mut result = horizontal_sum_sse(sum);
265
266 while i < len {
268 result += a[i] * b[i];
269 i += 1;
270 }
271
272 result
273}
274
275#[cfg(target_arch = "x86_64")]
276#[target_feature(enable = "sse")]
277unsafe fn cosine_distance_sse(a: &[f32], b: &[f32]) -> f32 {
278 use std::arch::x86_64::*;
279
280 let len = a.len();
281 let mut dot = _mm_setzero_ps();
282 let mut norm_a = _mm_setzero_ps();
283 let mut norm_b = _mm_setzero_ps();
284 let mut i = 0;
285
286 while i + 4 <= len {
288 let va = _mm_loadu_ps(a.as_ptr().add(i));
289 let vb = _mm_loadu_ps(b.as_ptr().add(i));
290 dot = _mm_add_ps(dot, _mm_mul_ps(va, vb));
291 norm_a = _mm_add_ps(norm_a, _mm_mul_ps(va, va));
292 norm_b = _mm_add_ps(norm_b, _mm_mul_ps(vb, vb));
293 i += 4;
294 }
295
296 let mut dot_sum = horizontal_sum_sse(dot);
298 let mut norm_a_sum = horizontal_sum_sse(norm_a);
299 let mut norm_b_sum = horizontal_sum_sse(norm_b);
300
301 while i < len {
303 dot_sum += a[i] * b[i];
304 norm_a_sum += a[i] * a[i];
305 norm_b_sum += b[i] * b[i];
306 i += 1;
307 }
308
309 let similarity = dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt());
310 1.0 - similarity
311}
312
/// Adds the four `f32` lanes of `v` and returns the scalar total.
///
/// The lanes are combined pairwise: `(v0 + v1) + (v2 + v3)`.
///
/// Only SSE instructions are used. The odd-lane duplication is done with
/// `_mm_shuffle_ps` rather than `_mm_movehdup_ps`, because the latter is an
/// SSE3 instruction and this helper is reached from kernels that are gated on
/// plain SSE.
///
/// # Safety
///
/// The CPU must support SSE.
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn horizontal_sum_sse(v: std::arch::x86_64::__m128) -> f32 {
    use std::arch::x86_64::*;

    // [v1, v1, v3, v3]: the same lane layout `_mm_movehdup_ps` would produce.
    let odd_lanes = _mm_shuffle_ps(v, v, 0b11_11_01_01);
    // Lane 0 holds v0 + v1, lane 2 holds v2 + v3.
    let pair_sums = _mm_add_ps(v, odd_lanes);
    // Move the upper half of `pair_sums` down so lane 0 holds v2 + v3.
    let upper_pair = _mm_movehl_ps(odd_lanes, pair_sums);
    let total = _mm_add_ss(pair_sums, upper_pair);
    _mm_cvtss_f32(total)
}
324
325#[cfg(target_arch = "x86_64")]
330#[target_feature(enable = "avx")]
331unsafe fn l2_distance_avx(a: &[f32], b: &[f32]) -> f32 {
332 use std::arch::x86_64::*;
333
334 let len = a.len();
335 let mut sum = _mm256_setzero_ps();
336 let mut i = 0;
337
338 while i + 8 <= len {
340 let va = _mm256_loadu_ps(a.as_ptr().add(i));
341 let vb = _mm256_loadu_ps(b.as_ptr().add(i));
342 let diff = _mm256_sub_ps(va, vb);
343 sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
344 i += 8;
345 }
346
347 let mut result = horizontal_sum_avx(sum);
349
350 while i < len {
352 let diff = a[i] - b[i];
353 result += diff * diff;
354 i += 1;
355 }
356
357 result.sqrt()
358}
359
360#[cfg(target_arch = "x86_64")]
361#[target_feature(enable = "avx")]
362unsafe fn dot_product_avx(a: &[f32], b: &[f32]) -> f32 {
363 use std::arch::x86_64::*;
364
365 let len = a.len();
366 let mut sum = _mm256_setzero_ps();
367 let mut i = 0;
368
369 while i + 8 <= len {
371 let va = _mm256_loadu_ps(a.as_ptr().add(i));
372 let vb = _mm256_loadu_ps(b.as_ptr().add(i));
373 sum = _mm256_add_ps(sum, _mm256_mul_ps(va, vb));
374 i += 8;
375 }
376
377 let mut result = horizontal_sum_avx(sum);
379
380 while i < len {
382 result += a[i] * b[i];
383 i += 1;
384 }
385
386 result
387}
388
389#[cfg(target_arch = "x86_64")]
390#[target_feature(enable = "avx")]
391unsafe fn cosine_distance_avx(a: &[f32], b: &[f32]) -> f32 {
392 use std::arch::x86_64::*;
393
394 let len = a.len();
395 let mut dot = _mm256_setzero_ps();
396 let mut norm_a = _mm256_setzero_ps();
397 let mut norm_b = _mm256_setzero_ps();
398 let mut i = 0;
399
400 while i + 8 <= len {
402 let va = _mm256_loadu_ps(a.as_ptr().add(i));
403 let vb = _mm256_loadu_ps(b.as_ptr().add(i));
404 dot = _mm256_add_ps(dot, _mm256_mul_ps(va, vb));
405 norm_a = _mm256_add_ps(norm_a, _mm256_mul_ps(va, va));
406 norm_b = _mm256_add_ps(norm_b, _mm256_mul_ps(vb, vb));
407 i += 8;
408 }
409
410 let mut dot_sum = horizontal_sum_avx(dot);
412 let mut norm_a_sum = horizontal_sum_avx(norm_a);
413 let mut norm_b_sum = horizontal_sum_avx(norm_b);
414
415 while i < len {
417 dot_sum += a[i] * b[i];
418 norm_a_sum += a[i] * a[i];
419 norm_b_sum += b[i] * b[i];
420 i += 1;
421 }
422
423 let similarity = dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt());
424 1.0 - similarity
425}
426
427#[cfg(target_arch = "x86_64")]
428#[inline]
429unsafe fn horizontal_sum_avx(v: std::arch::x86_64::__m256) -> f32 {
430 use std::arch::x86_64::*;
431
432 let hi = _mm256_extractf128_ps(v, 1);
433 let lo = _mm256_castps256_ps128(v);
434 let sum128 = _mm_add_ps(hi, lo);
435 horizontal_sum_sse(sum128)
436}
437
438#[cfg(target_arch = "x86_64")]
443#[target_feature(enable = "avx2")]
444unsafe fn l2_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
445 l2_distance_avx(a, b)
447}
448
449#[cfg(target_arch = "x86_64")]
450#[target_feature(enable = "avx2")]
451unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
452 dot_product_avx(a, b)
454}
455
456#[cfg(target_arch = "x86_64")]
457#[target_feature(enable = "avx2")]
458unsafe fn cosine_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
459 cosine_distance_avx(a, b)
461}
462
/// Portable fallback for the Euclidean (L2) distance between `a` and `b`.
///
/// Elements are paired with `zip`, so if the slices differ in length the
/// extra elements of the longer one are ignored.
#[inline]
fn l2_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_total: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| (lhs - rhs) * (lhs - rhs))
        .sum();
    squared_total.sqrt()
}
478
/// Portable fallback for the dot product of `a` and `b`.
///
/// Elements are paired with `zip`, so if the slices differ in length the
/// extra elements of the longer one are ignored.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
483
/// Portable fallback for the cosine distance (`1 - cosine similarity`).
///
/// The dot product pairs elements with `zip`, while each norm is taken over
/// the whole of its own slice. When a norm is zero the division yields `NaN`.
#[inline]
fn cosine_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    /// Euclidean length of `v`.
    fn magnitude(v: &[f32]) -> f32 {
        v.iter().map(|&c| c * c).sum::<f32>().sqrt()
    }

    let inner: f32 = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum();
    let similarity = inner / (magnitude(a) * magnitude(b));
    1.0 - similarity
}
491
/// Unit tests for the public distance functions and for agreement between the
/// runtime-dispatched (possibly SIMD) paths and the scalar fallbacks.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_distance() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let dist = l2_distance(&a, &b);
        // Eight coordinates each differ by exactly 1.
        let expected = (8.0_f32).sqrt();

        assert!((dist - expected).abs() < 1e-5, "L2 distance mismatch");
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let dot = dot_product(&a, &b);
        let expected = 1.0 * 5.0 + 2.0 * 6.0 + 3.0 * 7.0 + 4.0 * 8.0;

        assert!((dot - expected).abs() < 1e-5, "Dot product mismatch");
    }

    #[test]
    fn test_cosine_distance() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 0.0];

        let dist = cosine_distance(&a, &b);

        assert!(
            dist.abs() < 1e-5,
            "Cosine distance should be 0 for identical vectors"
        );
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];

        let dist = cosine_distance(&a, &b);

        assert!(
            (dist - 1.0).abs() < 1e-5,
            "Cosine distance should be 1 for orthogonal vectors"
        );
    }

    #[test]
    fn test_simd_vs_scalar_l2() {
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 + 1.0) * 0.1).collect();

        let simd_result = l2_distance(&a, &b);
        let scalar_result = l2_distance_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-4,
            "SIMD and scalar L2 results should match"
        );
    }

    #[test]
    fn test_simd_vs_scalar_dot() {
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 + 1.0) * 0.1).collect();

        let simd_result = dot_product(&a, &b);
        let scalar_result = dot_product_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-3,
            "SIMD and scalar dot product results should match"
        );
    }

    #[test]
    fn test_simd_vs_scalar_cosine() {
        let a: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 1.0).collect();
        let b: Vec<f32> = (0..128).map(|i| ((i as f32 + 1.0) * 0.1) + 1.0).collect();

        let simd_result = cosine_distance(&a, &b);
        let scalar_result = cosine_distance_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-4,
            "SIMD and scalar cosine results should match"
        );
    }

    /// Lengths that are not a multiple of the SIMD width exercise the
    /// element-by-element remainder loops, which the fixed-size tests above
    /// never reach. Inputs are small positive integers, so every product and
    /// partial sum is exactly representable in `f32` and the result does not
    /// depend on summation order.
    #[test]
    fn test_simd_vs_scalar_remainder_lengths() {
        for len in 1..34 {
            let a: Vec<f32> = (0..len).map(|i| (i % 5) as f32 + 1.0).collect();
            let b: Vec<f32> = (0..len).map(|i| (i % 3) as f32 + 1.0).collect();

            assert!(
                (l2_distance(&a, &b) - l2_distance_scalar(&a, &b)).abs() < 1e-5,
                "L2 mismatch at length {}",
                len
            );
            assert!(
                (dot_product(&a, &b) - dot_product_scalar(&a, &b)).abs() < 1e-5,
                "Dot product mismatch at length {}",
                len
            );
            assert!(
                (cosine_distance(&a, &b) - cosine_distance_scalar(&a, &b)).abs() < 1e-5,
                "Cosine mismatch at length {}",
                len
            );
        }
    }
}