1use ailake_core::{Centroid, VectorMetric};
3use half::f16;
4
5pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
8 debug_assert_eq!(
9 a.len(),
10 b.len(),
11 "dot_product: dimension mismatch {} vs {}",
12 a.len(),
13 b.len()
14 );
15 #[cfg(target_arch = "x86_64")]
16 {
17 #[cfg(feature = "avx512")]
18 if is_x86_feature_detected!("avx512f") {
19 return unsafe { avx512::dot(a, b) };
20 }
21 if is_x86_feature_detected!("avx2") {
22 return unsafe { avx2::dot(a, b) };
23 }
24 }
25 #[cfg(target_arch = "aarch64")]
26 if std::arch::is_aarch64_feature_detected!("neon") {
27 return unsafe { neon_impl::dot(a, b) };
28 }
29 dot_scalar(a, b)
30}
31
32pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
33 debug_assert_eq!(
34 a.len(),
35 b.len(),
36 "euclidean_distance: dimension mismatch {} vs {}",
37 a.len(),
38 b.len()
39 );
40 #[cfg(target_arch = "x86_64")]
41 {
42 #[cfg(feature = "avx512")]
43 if is_x86_feature_detected!("avx512f") {
44 return unsafe { avx512::euclidean(a, b) };
45 }
46 if is_x86_feature_detected!("avx2") {
47 return unsafe { avx2::euclidean(a, b) };
48 }
49 }
50 #[cfg(target_arch = "aarch64")]
51 if std::arch::is_aarch64_feature_detected!("neon") {
52 return unsafe { neon_impl::euclidean(a, b) };
53 }
54 euclidean_scalar(a, b)
55}
56
57pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
58 debug_assert_eq!(
59 a.len(),
60 b.len(),
61 "cosine_distance: dimension mismatch {} vs {}",
62 a.len(),
63 b.len()
64 );
65 #[cfg(target_arch = "x86_64")]
66 {
67 #[cfg(feature = "avx512")]
68 if is_x86_feature_detected!("avx512f") {
69 return unsafe { avx512::cosine(a, b) };
70 }
71 if is_x86_feature_detected!("avx2") {
72 return unsafe { avx2::cosine(a, b) };
73 }
74 }
75 #[cfg(target_arch = "aarch64")]
76 if std::arch::is_aarch64_feature_detected!("neon") {
77 return unsafe { neon_impl::cosine(a, b) };
78 }
79 cosine_scalar(a, b)
80}
81
82pub fn exact_distance(metric: VectorMetric, a: &[f32], b: &[f32]) -> f32 {
83 match metric {
84 VectorMetric::Cosine => cosine_distance(a, b),
85 VectorMetric::Euclidean => euclidean_distance(a, b),
86 VectorMetric::DotProduct => -dot_product(a, b),
87 VectorMetric::NormalizedCosine => normalized_cosine_distance(a, b),
88 }
89}
90
91pub fn cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
101 debug_assert_eq!(
102 a.len(),
103 b.len(),
104 "cosine_distance_f16: dimension mismatch {} vs {}",
105 a.len(),
106 b.len()
107 );
108 #[cfg(target_arch = "x86_64")]
109 {
110 #[cfg(feature = "avx512")]
111 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
112 return unsafe { avx512::cosine_f16(a, b) };
113 }
114 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
115 return unsafe { avx2_f16c::cosine(a, b) };
116 }
117 }
118 cosine_f16_scalar(a, b)
119}
120
121pub fn euclidean_distance_f16(a: &[f32], b: &[f16]) -> f32 {
122 debug_assert_eq!(
123 a.len(),
124 b.len(),
125 "euclidean_distance_f16: dimension mismatch {} vs {}",
126 a.len(),
127 b.len()
128 );
129 #[cfg(target_arch = "x86_64")]
130 {
131 #[cfg(feature = "avx512")]
132 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
133 return unsafe { avx512::euclidean_f16(a, b) };
134 }
135 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
136 return unsafe { avx2_f16c::euclidean(a, b) };
137 }
138 }
139 euclidean_f16_scalar(a, b)
140}
141
142pub fn dot_product_f16(a: &[f32], b: &[f16]) -> f32 {
143 debug_assert_eq!(
144 a.len(),
145 b.len(),
146 "dot_product_f16: dimension mismatch {} vs {}",
147 a.len(),
148 b.len()
149 );
150 #[cfg(target_arch = "x86_64")]
151 {
152 #[cfg(feature = "avx512")]
153 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
154 return unsafe { avx512::dot_f16(a, b) };
155 }
156 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
157 return unsafe { avx2_f16c::dot(a, b) };
158 }
159 }
160 dot_f16_scalar(a, b)
161}
162
/// Returns a copy of `v` scaled to unit L2 length.
///
/// When the squared length is below `1e-12` the input is returned unscaled, which avoids
/// dividing by a (near-)zero norm.
pub fn normalize_l2(v: &[f32]) -> Vec<f32> {
    let squared_len = v.iter().map(|&c| c * c).sum::<f32>();
    let mut unit = v.to_vec();
    if squared_len < 1e-12 {
        return unit;
    }
    let scale = squared_len.sqrt().recip();
    for component in &mut unit {
        *component *= scale;
    }
    unit
}
172
173pub fn normalized_cosine_distance(a: &[f32], b: &[f32]) -> f32 {
176 1.0 - dot_product(a, b)
177}
178
179pub fn normalized_cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
180 1.0 - dot_product_f16(a, b)
181}
182
183pub fn compute_centroid_and_radius(vectors: &[Vec<f32>], metric: VectorMetric) -> Centroid {
184 if vectors.is_empty() {
185 return Centroid {
186 values: vec![],
187 radius: 0.0,
188 metric,
189 };
190 }
191 let dim = vectors[0].len();
192 let n = vectors.len() as f32;
193 let centroid: Vec<f32> = (0..dim)
194 .map(|i| vectors.iter().map(|v| v[i]).sum::<f32>() / n)
195 .collect();
196 let radius = vectors
197 .iter()
198 .map(|v| exact_distance(metric, ¢roid, v))
199 .fold(0.0_f32, f32::max);
200 Centroid {
201 values: centroid,
202 radius,
203 metric,
204 }
205}
206
/// Portable dot product over the common prefix of `a` and `b`.
#[inline(always)]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
213
/// Portable Euclidean distance over the common prefix of `a` and `b`.
#[inline(always)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| {
            let delta = p - q;
            delta * delta
        })
        .sum();
    squared.sqrt()
}
222
/// Portable cosine distance (`1 - cosine similarity`).
///
/// The dot product runs over the common prefix of the inputs, while each norm is taken over
/// the whole of its own slice. If either norm is exactly zero the distance is `1.0`.
#[inline(always)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    let inner: f32 = a.iter().zip(b).map(|(&p, &q)| p * q).sum();
    let len_a = l2_norm(a);
    let len_b = l2_norm(b);

    if len_a != 0.0 && len_b != 0.0 {
        1.0 - inner / (len_a * len_b)
    } else {
        1.0
    }
}
233
234#[inline(always)]
235fn cosine_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
236 let n = a.len().min(b.len());
237 let mut dot = 0.0f32;
238 let mut norm_a = 0.0f32;
239 let mut norm_b = 0.0f32;
240 for i in 0..n {
241 let ai = a[i];
242 let bi = b[i].to_f32();
243 dot += ai * bi;
244 norm_a += ai * ai;
245 norm_b += bi * bi;
246 }
247 let denom = (norm_a * norm_b).sqrt();
248 if denom < 1e-8 {
249 1.0
250 } else {
251 1.0 - dot / denom
252 }
253}
254
255#[inline(always)]
256fn euclidean_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
257 let n = a.len().min(b.len());
258 let mut sum = 0.0f32;
259 for i in 0..n {
260 let diff = a[i] - b[i].to_f32();
261 sum += diff * diff;
262 }
263 sum.sqrt()
264}
265
266#[inline(always)]
267fn dot_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
268 let n = a.len().min(b.len());
269 let mut acc = 0.0f32;
270 for i in 0..n {
271 acc += a[i] * b[i].to_f32();
272 }
273 acc
274}
275
#[cfg(target_arch = "x86_64")]
mod avx2 {
    //! AVX2 + FMA kernels for `f32` slices (eight lanes per vector register).
    //!
    //! Every kernel works on the common prefix of its two inputs.
    use std::arch::x86_64::*;

    /// Adds the eight lanes of `v` into a single scalar.
    ///
    /// # Safety
    /// The executing CPU must support the AVX and SSE3 instructions used here.
    #[inline(always)]
    pub unsafe fn hsum256(v: __m256) -> f32 {
        let upper = _mm256_extractf128_ps::<1>(v);
        let lower = _mm256_castps256_ps128(v);
        let quad = _mm_add_ps(lower, upper);
        let odd_lanes = _mm_movehdup_ps(quad);
        let pairs = _mm_add_ps(quad, odd_lanes);
        let high_pair = _mm_movehl_ps(odd_lanes, pairs);
        _mm_cvtss_f32(_mm_add_ss(pairs, high_pair))
    }

    /// Dot product using two independent accumulators (16 floats per main-loop step).
    ///
    /// # Safety
    /// The executing CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (lhs, rhs) = (a.as_ptr(), b.as_ptr());

        let mut even = _mm256_setzero_ps();
        let mut odd = _mm256_setzero_ps();
        let mut pos = 0usize;

        // Main loop: two 8-lane blocks per step.
        while len - pos >= 16 {
            let x0 = _mm256_loadu_ps(lhs.add(pos));
            let y0 = _mm256_loadu_ps(rhs.add(pos));
            let x1 = _mm256_loadu_ps(lhs.add(pos + 8));
            let y1 = _mm256_loadu_ps(rhs.add(pos + 8));
            even = _mm256_fmadd_ps(x0, y0, even);
            odd = _mm256_fmadd_ps(x1, y1, odd);
            pos += 16;
        }

        // At most one full 8-lane block can remain.
        if len - pos >= 8 {
            let x0 = _mm256_loadu_ps(lhs.add(pos));
            let y0 = _mm256_loadu_ps(rhs.add(pos));
            even = _mm256_fmadd_ps(x0, y0, even);
            pos += 8;
        }

        // Scalar tail for the last (< 8) elements.
        let mut total = hsum256(_mm256_add_ps(even, odd));
        while pos < len {
            total += *lhs.add(pos) * *rhs.add(pos);
            pos += 1;
        }
        total
    }

    /// Euclidean distance using two independent accumulators of squared differences.
    ///
    /// # Safety
    /// The executing CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (lhs, rhs) = (a.as_ptr(), b.as_ptr());

        let mut even = _mm256_setzero_ps();
        let mut odd = _mm256_setzero_ps();
        let mut pos = 0usize;

        while len - pos >= 16 {
            let delta0 = _mm256_sub_ps(_mm256_loadu_ps(lhs.add(pos)), _mm256_loadu_ps(rhs.add(pos)));
            let delta1 = _mm256_sub_ps(
                _mm256_loadu_ps(lhs.add(pos + 8)),
                _mm256_loadu_ps(rhs.add(pos + 8)),
            );
            even = _mm256_fmadd_ps(delta0, delta0, even);
            odd = _mm256_fmadd_ps(delta1, delta1, odd);
            pos += 16;
        }

        if len - pos >= 8 {
            let delta0 = _mm256_sub_ps(_mm256_loadu_ps(lhs.add(pos)), _mm256_loadu_ps(rhs.add(pos)));
            even = _mm256_fmadd_ps(delta0, delta0, even);
            pos += 8;
        }

        let mut total = hsum256(_mm256_add_ps(even, odd));
        while pos < len {
            let delta = *lhs.add(pos) - *rhs.add(pos);
            total += delta * delta;
            pos += 1;
        }
        total.sqrt()
    }

    /// Cosine distance; accumulates the dot product and both squared norms in one pass.
    ///
    /// Returns `1.0` when either norm is exactly zero.
    ///
    /// # Safety
    /// The executing CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (lhs, rhs) = (a.as_ptr(), b.as_ptr());

        let mut inner_v = _mm256_setzero_ps();
        let mut sq_a_v = _mm256_setzero_ps();
        let mut sq_b_v = _mm256_setzero_ps();
        let mut pos = 0usize;

        while len - pos >= 8 {
            let x = _mm256_loadu_ps(lhs.add(pos));
            let y = _mm256_loadu_ps(rhs.add(pos));
            inner_v = _mm256_fmadd_ps(x, y, inner_v);
            sq_a_v = _mm256_fmadd_ps(x, x, sq_a_v);
            sq_b_v = _mm256_fmadd_ps(y, y, sq_b_v);
            pos += 8;
        }

        let mut inner = hsum256(inner_v);
        let mut sq_a = hsum256(sq_a_v);
        let mut sq_b = hsum256(sq_b_v);

        while pos < len {
            let x = *lhs.add(pos);
            let y = *rhs.add(pos);
            inner += x * y;
            sq_a += x * x;
            sq_b += y * y;
            pos += 1;
        }

        let (len_a, len_b) = (sq_a.sqrt(), sq_b.sqrt());
        if len_a != 0.0 && len_b != 0.0 {
            1.0 - inner / (len_a * len_b)
        } else {
            1.0
        }
    }
}
411
412#[cfg(target_arch = "x86_64")]
419mod avx2_f16c {
420 use half::f16;
421 use std::arch::x86_64::*;
422
423 use super::avx2::hsum256;
424
425 #[target_feature(enable = "avx2,f16c,fma")]
427 pub unsafe fn dot(a: &[f32], b: &[f16]) -> f32 {
428 let n = a.len().min(b.len());
429 let ap = a.as_ptr();
430 let bp = b.as_ptr() as *const u16;
431
432 let mut acc0 = _mm256_setzero_ps();
433 let mut acc1 = _mm256_setzero_ps();
434
435 let chunks16 = n / 16;
436 for i in 0..chunks16 {
437 let base = i * 16;
438 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
439 let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
440 let a0 = _mm256_loadu_ps(ap.add(base));
441 let a1 = _mm256_loadu_ps(ap.add(base + 8));
442 acc0 = _mm256_fmadd_ps(a0, b0, acc0);
443 acc1 = _mm256_fmadd_ps(a1, b1, acc1);
444 }
445
446 let chunks8 = n / 8;
447 if chunks8 > chunks16 * 2 {
448 let base = chunks16 * 16;
449 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
450 let a0 = _mm256_loadu_ps(ap.add(base));
451 acc0 = _mm256_fmadd_ps(a0, b0, acc0);
452 }
453
454 let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
455 for i in (chunks8 * 8)..n {
456 sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
457 }
458 sum
459 }
460
461 #[target_feature(enable = "avx2,f16c,fma")]
463 pub unsafe fn euclidean(a: &[f32], b: &[f16]) -> f32 {
464 let n = a.len().min(b.len());
465 let ap = a.as_ptr();
466 let bp = b.as_ptr() as *const u16;
467
468 let mut acc0 = _mm256_setzero_ps();
469 let mut acc1 = _mm256_setzero_ps();
470
471 let chunks16 = n / 16;
472 for i in 0..chunks16 {
473 let base = i * 16;
474 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
475 let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
476 let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
477 let d1 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base + 8)), b1);
478 acc0 = _mm256_fmadd_ps(d0, d0, acc0);
479 acc1 = _mm256_fmadd_ps(d1, d1, acc1);
480 }
481
482 let chunks8 = n / 8;
483 if chunks8 > chunks16 * 2 {
484 let base = chunks16 * 16;
485 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
486 let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
487 acc0 = _mm256_fmadd_ps(d0, d0, acc0);
488 }
489
490 let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
491 for i in (chunks8 * 8)..n {
492 let diff = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
493 sum += diff * diff;
494 }
495 sum.sqrt()
496 }
497
498 #[target_feature(enable = "avx2,f16c,fma")]
500 pub unsafe fn cosine(a: &[f32], b: &[f16]) -> f32 {
501 let n = a.len().min(b.len());
502 let ap = a.as_ptr();
503 let bp = b.as_ptr() as *const u16;
504
505 let mut dot_acc = _mm256_setzero_ps();
506 let mut na_acc = _mm256_setzero_ps();
507 let mut nb_acc = _mm256_setzero_ps();
508
509 let chunks8 = n / 8;
510 for i in 0..chunks8 {
511 let base = i * 8;
512 let av = _mm256_loadu_ps(ap.add(base));
513 let bv = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
514 dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
515 na_acc = _mm256_fmadd_ps(av, av, na_acc);
516 nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
517 }
518
519 let mut dot = hsum256(dot_acc);
520 let mut na2 = hsum256(na_acc);
521 let mut nb2 = hsum256(nb_acc);
522
523 for i in (chunks8 * 8)..n {
524 let ai = *ap.add(i);
525 let bi = f16::from_bits(*bp.add(i)).to_f32();
526 dot += ai * bi;
527 na2 += ai * ai;
528 nb2 += bi * bi;
529 }
530
531 let denom = (na2 * nb2).sqrt();
532 if denom < 1e-8 {
533 1.0
534 } else {
535 1.0 - dot / denom
536 }
537 }
538}
539
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
mod avx512 {
    //! AVX-512F kernels (sixteen `f32` lanes per vector register).
    //!
    //! Every kernel works on the common prefix of its two inputs and finishes the last
    //! (< 16) elements with a scalar loop.
    use half::f16;
    use std::arch::x86_64::*;

    /// Adds the sixteen lanes of `v` into a single scalar.
    ///
    /// The vector is spilled to a stack buffer and reduced as two 256-bit halves.
    #[inline(always)]
    unsafe fn hsum512(v: __m512) -> f32 {
        let mut buf = [0.0f32; 16];
        _mm512_storeu_ps(buf.as_mut_ptr(), v);
        let lo = _mm256_loadu_ps(buf.as_ptr());
        let hi = _mm256_loadu_ps(buf.as_ptr().add(8));
        let sum256 = _mm256_add_ps(lo, hi);
        let hi128 = _mm256_extractf128_ps(sum256, 1);
        let lo128 = _mm256_castps256_ps128(sum256);
        let sum128 = _mm_add_ps(lo128, hi128);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf2))
    }

    /// Loads sixteen half-precision values starting at `src` and widens them to `f32`.
    ///
    /// Uses the AVX-512F form of `vcvtph2ps`. The previous implementation assembled the vector
    /// with `_mm512_insertf32x8`, which is an AVX-512DQ instruction that the dispatchers never
    /// check for.
    ///
    /// # Safety
    /// `src` must be valid for reading sixteen `u16` values and the CPU must support AVX-512F.
    #[inline]
    #[target_feature(enable = "avx512f")]
    unsafe fn load_f16x16(src: *const u16) -> __m512 {
        _mm512_cvtph_ps(_mm256_loadu_si256(src as *const __m256i))
    }

    /// Dot product of two `f32` slices.
    ///
    /// # Safety
    /// The executing CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            acc = _mm512_fmadd_ps(
                _mm512_loadu_ps(ap.add(base)),
                _mm512_loadu_ps(bp.add(base)),
                acc,
            );
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// Euclidean distance between two `f32` slices.
    ///
    /// # Safety
    /// The executing CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), _mm512_loadu_ps(bp.add(base)));
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance between two `f32` slices; returns `1.0` when either norm is exactly zero.
    ///
    /// # Safety
    /// The executing CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let av = _mm512_loadu_ps(ap.add(base));
            let bv = _mm512_loadu_ps(bp.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }

    /// Dot product of an `f32` slice and a half-precision slice.
    ///
    /// # Safety
    /// The executing CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn dot_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        // The half-precision values are read through their raw 16-bit representation.
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            acc = _mm512_fmadd_ps(_mm512_loadu_ps(ap.add(base)), bv, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
        }
        sum
    }

    /// Euclidean distance between an `f32` slice and a half-precision slice.
    ///
    /// # Safety
    /// The executing CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn euclidean_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), bv);
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance between an `f32` slice and a half-precision slice.
    ///
    /// Returns `1.0` when the product of the norms is below `1e-8`.
    ///
    /// # Safety
    /// The executing CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn cosine_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            let av = _mm512_loadu_ps(ap.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = f16::from_bits(*bp.add(i)).to_f32();
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let denom = (na2 * nb2).sqrt();
        if denom < 1e-8 {
            1.0
        } else {
            1.0 - dot / denom
        }
    }
}
728
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    //! NEON kernels for `f32` slices (four lanes per vector register).
    //!
    //! Every kernel works on the common prefix of its two inputs and finishes the last
    //! (< 4) elements with a scalar loop.
    use std::arch::aarch64::*;

    /// Dot product of two `f32` slices.
    ///
    /// # Safety
    /// The executing CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (lhs, rhs) = (a.as_ptr(), b.as_ptr());
        let mut lanes = vdupq_n_f32(0.0);
        let mut pos = 0usize;
        while len - pos >= 4 {
            let x = vld1q_f32(lhs.add(pos));
            let y = vld1q_f32(rhs.add(pos));
            lanes = vmlaq_f32(lanes, x, y);
            pos += 4;
        }
        let mut total = vaddvq_f32(lanes);
        while pos < len {
            total += a[pos] * b[pos];
            pos += 1;
        }
        total
    }

    /// Euclidean distance between two `f32` slices.
    ///
    /// # Safety
    /// The executing CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (lhs, rhs) = (a.as_ptr(), b.as_ptr());
        let mut lanes = vdupq_n_f32(0.0);
        let mut pos = 0usize;
        while len - pos >= 4 {
            let delta = vsubq_f32(vld1q_f32(lhs.add(pos)), vld1q_f32(rhs.add(pos)));
            lanes = vmlaq_f32(lanes, delta, delta);
            pos += 4;
        }
        let mut total = vaddvq_f32(lanes);
        while pos < len {
            let delta = a[pos] - b[pos];
            total += delta * delta;
            pos += 1;
        }
        total.sqrt()
    }

    /// Cosine distance between two `f32` slices; returns `1.0` when either norm is exactly zero.
    ///
    /// # Safety
    /// The executing CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (lhs, rhs) = (a.as_ptr(), b.as_ptr());
        let mut inner_v = vdupq_n_f32(0.0);
        let mut sq_a_v = vdupq_n_f32(0.0);
        let mut sq_b_v = vdupq_n_f32(0.0);
        let mut pos = 0usize;
        while len - pos >= 4 {
            let x = vld1q_f32(lhs.add(pos));
            let y = vld1q_f32(rhs.add(pos));
            inner_v = vmlaq_f32(inner_v, x, y);
            sq_a_v = vmlaq_f32(sq_a_v, x, x);
            sq_b_v = vmlaq_f32(sq_b_v, y, y);
            pos += 4;
        }
        let mut inner = vaddvq_f32(inner_v);
        let mut sq_a = vaddvq_f32(sq_a_v);
        let mut sq_b = vaddvq_f32(sq_b_v);
        while pos < len {
            inner += a[pos] * b[pos];
            sq_a += a[pos] * a[pos];
            sq_b += b[pos] * b[pos];
            pos += 1;
        }
        let (len_a, len_b) = (sq_a.sqrt(), sq_b.sqrt());
        if len_a != 0.0 && len_b != 0.0 {
            1.0 - inner / (len_a * len_b)
        } else {
            1.0
        }
    }
}
804
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    /// Lengths that are not all multiples of 16, so the 8-lane remainder block and the scalar
    /// tail of every SIMD kernel are exercised in addition to the main loop.
    const TAIL_DIMS: [usize; 8] = [1, 7, 8, 15, 16, 24, 33, 131];

    /// Draws `dim` values uniformly from `[-1, 1)`.
    fn random_vec(rng: &mut StdRng, dim: usize) -> Vec<f32> {
        (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect()
    }

    #[test]
    fn cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!(cosine_distance(&v, &v).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal() {
        assert!((cosine_distance(&[1.0f32, 0.0], &[0.0f32, 1.0]) - 1.0).abs() < 1e-5);
    }

    /// A zero vector has no direction; every kernel reports the distance as exactly 1.
    #[test]
    fn cosine_zero_vector() {
        assert_eq!(cosine_distance(&[0.0f32, 0.0], &[1.0f32, 0.0]), 1.0);
        assert_eq!(cosine_distance(&[1.0f32, 0.0], &[0.0f32, 0.0]), 1.0);
    }

    #[test]
    fn euclidean_basic() {
        assert!((euclidean_distance(&[0.0f32, 0.0], &[3.0f32, 4.0]) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn dot_basic() {
        assert!((dot_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]) - 32.0).abs() < 1e-5);
    }

    #[test]
    fn simd_matches_scalar_dim128() {
        let mut rng = StdRng::seed_from_u64(99);
        let a: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();

        let dot_s = dot_scalar(&a, &b);
        let euclid_s = euclidean_scalar(&a, &b);
        let cos_s = cosine_scalar(&a, &b);

        let dot_f = dot_product(&a, &b);
        let euclid_f = euclidean_distance(&a, &b);
        let cos_f = cosine_distance(&a, &b);

        assert!(
            (dot_f - dot_s).abs() < 1e-4,
            "dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-4,
            "euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-4,
            "cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    /// Same comparison as the dim-128 test, repeated for lengths that leave a remainder.
    #[test]
    fn simd_matches_scalar_tail_dims() {
        for &dim in TAIL_DIMS.iter() {
            let mut rng = StdRng::seed_from_u64(99);
            let a = random_vec(&mut rng, dim);
            let b = random_vec(&mut rng, dim);

            let dot_s = dot_scalar(&a, &b);
            let euclid_s = euclidean_scalar(&a, &b);
            let cos_s = cosine_scalar(&a, &b);

            let dot_f = dot_product(&a, &b);
            let euclid_f = euclidean_distance(&a, &b);
            let cos_f = cosine_distance(&a, &b);

            assert!(
                (dot_f - dot_s).abs() < 1e-4,
                "dim {dim}: dot {dot_f} vs {dot_s}"
            );
            assert!(
                (euclid_f - euclid_s).abs() < 1e-4,
                "dim {dim}: euclidean {euclid_f} vs {euclid_s}"
            );
            assert!(
                (cos_f - cos_s).abs() < 1e-4,
                "dim {dim}: cosine {cos_f} vs {cos_s}"
            );
        }
    }

    #[test]
    fn f16_simd_matches_scalar() {
        let mut rng = StdRng::seed_from_u64(42);
        let a: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b_f32: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b: Vec<f16> = b_f32.iter().map(|&x| f16::from_f32(x)).collect();

        let dot_s = dot_f16_scalar(&a, &b);
        let euclid_s = euclidean_f16_scalar(&a, &b);
        let cos_s = cosine_f16_scalar(&a, &b);

        let dot_f = dot_product_f16(&a, &b);
        let euclid_f = euclidean_distance_f16(&a, &b);
        let cos_f = cosine_distance_f16(&a, &b);

        assert!(
            (dot_f - dot_s).abs() < 1e-3,
            "f16 dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-3,
            "f16 euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-3,
            "f16 cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    /// Half-precision comparison repeated for lengths that leave a remainder.
    #[test]
    fn f16_simd_matches_scalar_tail_dims() {
        for &dim in TAIL_DIMS.iter() {
            let mut rng = StdRng::seed_from_u64(42);
            let a = random_vec(&mut rng, dim);
            let b: Vec<f16> = random_vec(&mut rng, dim)
                .into_iter()
                .map(f16::from_f32)
                .collect();

            let dot_s = dot_f16_scalar(&a, &b);
            let euclid_s = euclidean_f16_scalar(&a, &b);
            let cos_s = cosine_f16_scalar(&a, &b);

            let dot_f = dot_product_f16(&a, &b);
            let euclid_f = euclidean_distance_f16(&a, &b);
            let cos_f = cosine_distance_f16(&a, &b);

            assert!(
                (dot_f - dot_s).abs() < 1e-3,
                "dim {dim}: f16 dot {dot_f} vs {dot_s}"
            );
            assert!(
                (euclid_f - euclid_s).abs() < 1e-3,
                "dim {dim}: f16 euclidean {euclid_f} vs {euclid_s}"
            );
            assert!(
                (cos_f - cos_s).abs() < 1e-3,
                "dim {dim}: f16 cosine {cos_f} vs {cos_s}"
            );
        }
    }

    #[test]
    fn normalize_l2_unit() {
        let v = vec![3.0f32, 4.0];
        let n = normalize_l2(&v);
        let norm: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6, "norm={norm}");
        assert!((n[0] - 0.6).abs() < 1e-6);
        assert!((n[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn normalized_cosine_matches_cosine_on_unit_vecs() {
        let a = normalize_l2(&[1.0f32, 1.0, 0.0]);
        let b = normalize_l2(&[1.0f32, 0.0, 1.0]);
        let cos = cosine_distance(&a, &b);
        let ncos = normalized_cosine_distance(&a, &b);
        assert!((cos - ncos).abs() < 1e-5, "cos={cos} ncos={ncos}");
    }

    #[test]
    fn centroid_single() {
        let v = vec![vec![1.0f32, 2.0, 3.0]];
        let c = compute_centroid_and_radius(&v, VectorMetric::Cosine);
        assert_eq!(c.values, vec![1.0, 2.0, 3.0]);
        assert!(c.radius < 1e-6, "radius={}", c.radius);
    }

    #[test]
    fn centroid_two_points() {
        let vs = vec![vec![0.0f32, 0.0], vec![2.0f32, 2.0]];
        let c = compute_centroid_and_radius(&vs, VectorMetric::Euclidean);
        assert!((c.values[0] - 1.0).abs() < 1e-6);
        assert!(c.radius > 0.0);
    }
}