1use ailake_core::{Centroid, VectorMetric};
3use half::f16;
4
5pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
8 #[cfg(target_arch = "x86_64")]
9 {
10 #[cfg(feature = "avx512")]
11 if is_x86_feature_detected!("avx512f") {
12 return unsafe { avx512::dot(a, b) };
13 }
14 if is_x86_feature_detected!("avx2") {
15 return unsafe { avx2::dot(a, b) };
16 }
17 }
18 #[cfg(target_arch = "aarch64")]
19 if std::arch::is_aarch64_feature_detected!("neon") {
20 return unsafe { neon_impl::dot(a, b) };
21 }
22 dot_scalar(a, b)
23}
24
25pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
26 #[cfg(target_arch = "x86_64")]
27 {
28 #[cfg(feature = "avx512")]
29 if is_x86_feature_detected!("avx512f") {
30 return unsafe { avx512::euclidean(a, b) };
31 }
32 if is_x86_feature_detected!("avx2") {
33 return unsafe { avx2::euclidean(a, b) };
34 }
35 }
36 #[cfg(target_arch = "aarch64")]
37 if std::arch::is_aarch64_feature_detected!("neon") {
38 return unsafe { neon_impl::euclidean(a, b) };
39 }
40 euclidean_scalar(a, b)
41}
42
43pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
44 #[cfg(target_arch = "x86_64")]
45 {
46 #[cfg(feature = "avx512")]
47 if is_x86_feature_detected!("avx512f") {
48 return unsafe { avx512::cosine(a, b) };
49 }
50 if is_x86_feature_detected!("avx2") {
51 return unsafe { avx2::cosine(a, b) };
52 }
53 }
54 #[cfg(target_arch = "aarch64")]
55 if std::arch::is_aarch64_feature_detected!("neon") {
56 return unsafe { neon_impl::cosine(a, b) };
57 }
58 cosine_scalar(a, b)
59}
60
61pub fn exact_distance(metric: VectorMetric, a: &[f32], b: &[f32]) -> f32 {
62 match metric {
63 VectorMetric::Cosine => cosine_distance(a, b),
64 VectorMetric::Euclidean => euclidean_distance(a, b),
65 VectorMetric::DotProduct => -dot_product(a, b),
66 }
67}
68
69pub fn cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
79 #[cfg(target_arch = "x86_64")]
80 {
81 #[cfg(feature = "avx512")]
82 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
83 return unsafe { avx512::cosine_f16(a, b) };
84 }
85 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
86 return unsafe { avx2_f16c::cosine(a, b) };
87 }
88 }
89 cosine_f16_scalar(a, b)
90}
91
92pub fn euclidean_distance_f16(a: &[f32], b: &[f16]) -> f32 {
93 #[cfg(target_arch = "x86_64")]
94 {
95 #[cfg(feature = "avx512")]
96 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
97 return unsafe { avx512::euclidean_f16(a, b) };
98 }
99 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
100 return unsafe { avx2_f16c::euclidean(a, b) };
101 }
102 }
103 euclidean_f16_scalar(a, b)
104}
105
106pub fn dot_product_f16(a: &[f32], b: &[f16]) -> f32 {
107 #[cfg(target_arch = "x86_64")]
108 {
109 #[cfg(feature = "avx512")]
110 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
111 return unsafe { avx512::dot_f16(a, b) };
112 }
113 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
114 return unsafe { avx2_f16c::dot(a, b) };
115 }
116 }
117 dot_f16_scalar(a, b)
118}
119
120pub fn compute_centroid_and_radius(vectors: &[Vec<f32>], metric: VectorMetric) -> Centroid {
121 if vectors.is_empty() {
122 return Centroid {
123 values: vec![],
124 radius: 0.0,
125 metric,
126 };
127 }
128 let dim = vectors[0].len();
129 let n = vectors.len() as f32;
130 let centroid: Vec<f32> = (0..dim)
131 .map(|i| vectors.iter().map(|v| v[i]).sum::<f32>() / n)
132 .collect();
133 let radius = vectors
134 .iter()
135 .map(|v| exact_distance(metric, ¢roid, v))
136 .fold(0.0_f32, f32::max);
137 Centroid {
138 values: centroid,
139 radius,
140 metric,
141 }
142}
143
/// Portable dot product: sums the element-wise products over the common
/// prefix of `a` and `b`.
#[inline(always)]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
150
/// Portable Euclidean distance: square root of the summed squared differences
/// over the common prefix of `a` and `b`.
#[inline(always)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
159
/// Portable cosine distance (`1 - cosine similarity`).
///
/// Only the common prefix of `a` and `b` contributes to the dot product and to
/// both norms, matching the SIMD kernels and the f16 scalar path, which all
/// truncate to the shorter length. Returns `1.0` when either norm is zero
/// (which includes empty input).
#[inline(always)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Truncate first so the norms cover exactly the elements the dot product sees.
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if na == 0.0 || nb == 0.0 {
        return 1.0;
    }
    1.0 - dot / (na * nb)
}
170
171#[inline(always)]
172fn cosine_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
173 let n = a.len().min(b.len());
174 let mut dot = 0.0f32;
175 let mut norm_a = 0.0f32;
176 let mut norm_b = 0.0f32;
177 for i in 0..n {
178 let ai = a[i];
179 let bi = b[i].to_f32();
180 dot += ai * bi;
181 norm_a += ai * ai;
182 norm_b += bi * bi;
183 }
184 let denom = (norm_a * norm_b).sqrt();
185 if denom < 1e-8 {
186 1.0
187 } else {
188 1.0 - dot / denom
189 }
190}
191
192#[inline(always)]
193fn euclidean_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
194 let n = a.len().min(b.len());
195 let mut sum = 0.0f32;
196 for i in 0..n {
197 let diff = a[i] - b[i].to_f32();
198 sum += diff * diff;
199 }
200 sum.sqrt()
201}
202
203#[inline(always)]
204fn dot_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
205 let n = a.len().min(b.len());
206 let mut acc = 0.0f32;
207 for i in 0..n {
208 acc += a[i] * b[i].to_f32();
209 }
210 acc
211}
212
/// AVX2 + FMA kernels for `f32` vectors.
///
/// All kernels operate on the common prefix of their two inputs. Callers must
/// confirm the required CPU features before invoking them.
#[cfg(target_arch = "x86_64")]
mod avx2 {
    use std::arch::x86_64::*;

    /// Adds the eight lanes of `v` into a single `f32`.
    ///
    /// # Safety
    ///
    /// The CPU must support the AVX/SSE3 instructions used here; in this file
    /// it is only called from kernels that enable `avx2`.
    #[inline(always)]
    pub unsafe fn hsum256(v: __m256) -> f32 {
        // 8 lanes -> 4: low half + high half.
        let upper = _mm256_extractf128_ps(v, 1);
        let lower = _mm256_castps256_ps128(v);
        let quad = _mm_add_ps(lower, upper);
        // 4 lanes -> 2: add each odd lane onto its even neighbour.
        let odd = _mm_movehdup_ps(quad);
        let pairs = _mm_add_ps(quad, odd);
        // 2 lanes -> 1: bring lane 2 down to lane 0 and add.
        let high_pair = _mm_movehl_ps(odd, pairs);
        _mm_cvtss_f32(_mm_add_ss(pairs, high_pair))
    }

    /// Dot product using two independent 8-lane accumulators.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());

        let mut sum_lo = _mm256_setzero_ps();
        let mut sum_hi = _mm256_setzero_ps();

        // Main loop: 16 elements per iteration, split across both accumulators.
        let mut off = 0usize;
        while len - off >= 16 {
            let x_lo = _mm256_loadu_ps(pa.add(off));
            let y_lo = _mm256_loadu_ps(pb.add(off));
            let x_hi = _mm256_loadu_ps(pa.add(off + 8));
            let y_hi = _mm256_loadu_ps(pb.add(off + 8));
            sum_lo = _mm256_fmadd_ps(x_lo, y_lo, sum_lo);
            sum_hi = _mm256_fmadd_ps(x_hi, y_hi, sum_hi);
            off += 16;
        }

        // At most one leftover group of 8 goes into the first accumulator.
        if len - off >= 8 {
            let x = _mm256_loadu_ps(pa.add(off));
            let y = _mm256_loadu_ps(pb.add(off));
            sum_lo = _mm256_fmadd_ps(x, y, sum_lo);
            off += 8;
        }

        // Scalar tail for the final 0..=7 elements.
        let mut total = hsum256(_mm256_add_ps(sum_lo, sum_hi));
        while off < len {
            total += *pa.add(off) * *pb.add(off);
            off += 1;
        }
        total
    }

    /// Euclidean distance using two independent 8-lane accumulators.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());

        let mut sq_lo = _mm256_setzero_ps();
        let mut sq_hi = _mm256_setzero_ps();

        // Main loop: 16 elements per iteration, split across both accumulators.
        let mut off = 0usize;
        while len - off >= 16 {
            let diff_lo = _mm256_sub_ps(_mm256_loadu_ps(pa.add(off)), _mm256_loadu_ps(pb.add(off)));
            let diff_hi = _mm256_sub_ps(
                _mm256_loadu_ps(pa.add(off + 8)),
                _mm256_loadu_ps(pb.add(off + 8)),
            );
            sq_lo = _mm256_fmadd_ps(diff_lo, diff_lo, sq_lo);
            sq_hi = _mm256_fmadd_ps(diff_hi, diff_hi, sq_hi);
            off += 16;
        }

        // At most one leftover group of 8 goes into the first accumulator.
        if len - off >= 8 {
            let diff = _mm256_sub_ps(_mm256_loadu_ps(pa.add(off)), _mm256_loadu_ps(pb.add(off)));
            sq_lo = _mm256_fmadd_ps(diff, diff, sq_lo);
            off += 8;
        }

        // Scalar tail for the final 0..=7 elements.
        let mut total = hsum256(_mm256_add_ps(sq_lo, sq_hi));
        while off < len {
            let delta = *pa.add(off) - *pb.add(off);
            total += delta * delta;
            off += 1;
        }
        total.sqrt()
    }

    /// Cosine distance; returns `1.0` when either vector has zero norm.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());

        let mut dot_v = _mm256_setzero_ps();
        let mut sq_a_v = _mm256_setzero_ps();
        let mut sq_b_v = _mm256_setzero_ps();

        // One pass accumulates the dot product and both squared norms.
        let mut off = 0usize;
        while len - off >= 8 {
            let x = _mm256_loadu_ps(pa.add(off));
            let y = _mm256_loadu_ps(pb.add(off));
            dot_v = _mm256_fmadd_ps(x, y, dot_v);
            sq_a_v = _mm256_fmadd_ps(x, x, sq_a_v);
            sq_b_v = _mm256_fmadd_ps(y, y, sq_b_v);
            off += 8;
        }

        let mut dot = hsum256(dot_v);
        let mut sq_a = hsum256(sq_a_v);
        let mut sq_b = hsum256(sq_b_v);

        // Scalar tail for the final 0..=7 elements.
        while off < len {
            let x = *pa.add(off);
            let y = *pb.add(off);
            dot += x * y;
            sq_a += x * x;
            sq_b += y * y;
            off += 1;
        }

        let norm_a = sq_a.sqrt();
        let norm_b = sq_b.sqrt();
        if norm_a != 0.0 && norm_b != 0.0 {
            1.0 - dot / (norm_a * norm_b)
        } else {
            1.0
        }
    }
}
348
349#[cfg(target_arch = "x86_64")]
356mod avx2_f16c {
357 use half::f16;
358 use std::arch::x86_64::*;
359
360 use super::avx2::hsum256;
361
362 #[target_feature(enable = "avx2,f16c,fma")]
364 pub unsafe fn dot(a: &[f32], b: &[f16]) -> f32 {
365 let n = a.len().min(b.len());
366 let ap = a.as_ptr();
367 let bp = b.as_ptr() as *const u16;
368
369 let mut acc0 = _mm256_setzero_ps();
370 let mut acc1 = _mm256_setzero_ps();
371
372 let chunks16 = n / 16;
373 for i in 0..chunks16 {
374 let base = i * 16;
375 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
376 let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
377 let a0 = _mm256_loadu_ps(ap.add(base));
378 let a1 = _mm256_loadu_ps(ap.add(base + 8));
379 acc0 = _mm256_fmadd_ps(a0, b0, acc0);
380 acc1 = _mm256_fmadd_ps(a1, b1, acc1);
381 }
382
383 let chunks8 = n / 8;
384 if chunks8 > chunks16 * 2 {
385 let base = chunks16 * 16;
386 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
387 let a0 = _mm256_loadu_ps(ap.add(base));
388 acc0 = _mm256_fmadd_ps(a0, b0, acc0);
389 }
390
391 let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
392 for i in (chunks8 * 8)..n {
393 sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
394 }
395 sum
396 }
397
398 #[target_feature(enable = "avx2,f16c,fma")]
400 pub unsafe fn euclidean(a: &[f32], b: &[f16]) -> f32 {
401 let n = a.len().min(b.len());
402 let ap = a.as_ptr();
403 let bp = b.as_ptr() as *const u16;
404
405 let mut acc0 = _mm256_setzero_ps();
406 let mut acc1 = _mm256_setzero_ps();
407
408 let chunks16 = n / 16;
409 for i in 0..chunks16 {
410 let base = i * 16;
411 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
412 let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
413 let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
414 let d1 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base + 8)), b1);
415 acc0 = _mm256_fmadd_ps(d0, d0, acc0);
416 acc1 = _mm256_fmadd_ps(d1, d1, acc1);
417 }
418
419 let chunks8 = n / 8;
420 if chunks8 > chunks16 * 2 {
421 let base = chunks16 * 16;
422 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
423 let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
424 acc0 = _mm256_fmadd_ps(d0, d0, acc0);
425 }
426
427 let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
428 for i in (chunks8 * 8)..n {
429 let diff = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
430 sum += diff * diff;
431 }
432 sum.sqrt()
433 }
434
435 #[target_feature(enable = "avx2,f16c,fma")]
437 pub unsafe fn cosine(a: &[f32], b: &[f16]) -> f32 {
438 let n = a.len().min(b.len());
439 let ap = a.as_ptr();
440 let bp = b.as_ptr() as *const u16;
441
442 let mut dot_acc = _mm256_setzero_ps();
443 let mut na_acc = _mm256_setzero_ps();
444 let mut nb_acc = _mm256_setzero_ps();
445
446 let chunks8 = n / 8;
447 for i in 0..chunks8 {
448 let base = i * 8;
449 let av = _mm256_loadu_ps(ap.add(base));
450 let bv = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
451 dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
452 na_acc = _mm256_fmadd_ps(av, av, na_acc);
453 nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
454 }
455
456 let mut dot = hsum256(dot_acc);
457 let mut na2 = hsum256(na_acc);
458 let mut nb2 = hsum256(nb_acc);
459
460 for i in (chunks8 * 8)..n {
461 let ai = *ap.add(i);
462 let bi = f16::from_bits(*bp.add(i)).to_f32();
463 dot += ai * bi;
464 na2 += ai * ai;
465 nb2 += bi * bi;
466 }
467
468 let denom = (na2 * nb2).sqrt();
469 if denom < 1e-8 {
470 1.0
471 } else {
472 1.0 - dot / denom
473 }
474 }
475}
476
/// AVX-512F kernels for `f32` vectors and for `f32` against half-precision
/// vectors. Only compiled when the `avx512` cargo feature is enabled.
///
/// All kernels operate on the common prefix of their inputs, process 16 lanes
/// per iteration and finish the remaining 0..=15 elements with a scalar tail.
/// Callers must confirm the required CPU features first.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
mod avx512 {
    use half::f16;
    use std::arch::x86_64::*;

    /// Adds the sixteen lanes of `v` into a single `f32`.
    ///
    /// The vector is spilled to a stack buffer, folded to 8 lanes, then
    /// reduced with the same 8 -> 4 -> 2 -> 1 steps as the AVX2 helper.
    #[inline(always)]
    unsafe fn hsum512(v: __m512) -> f32 {
        let mut buf = [0.0f32; 16];
        _mm512_storeu_ps(buf.as_mut_ptr(), v);
        let lo = _mm256_loadu_ps(buf.as_ptr());
        let hi = _mm256_loadu_ps(buf.as_ptr().add(8));
        let sum256 = _mm256_add_ps(lo, hi);
        let hi128 = _mm256_extractf128_ps(sum256, 1);
        let lo128 = _mm256_castps256_ps128(sum256);
        let sum128 = _mm_add_ps(lo128, hi128);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf2))
    }

    /// Dot product of two `f32` vectors.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            acc = _mm512_fmadd_ps(
                _mm512_loadu_ps(ap.add(base)),
                _mm512_loadu_ps(bp.add(base)),
                acc,
            );
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// Euclidean distance between two `f32` vectors.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), _mm512_loadu_ps(bp.add(base)));
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance between two `f32` vectors; returns `1.0` when either
    /// vector has zero norm.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let av = _mm512_loadu_ps(ap.add(base));
            let bv = _mm512_loadu_ps(bp.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }

    /// Dot product of an `f32` vector and a half-precision vector.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn dot_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            // Widen 16 halves in one AVX-512F instruction. Building the vector
            // from two 8-lane halves with `_mm512_insertf32x8` would need
            // AVX-512DQ, which this function neither enables nor is checked for.
            let bv = _mm512_cvtph_ps(_mm256_loadu_si256(bp.add(base) as *const __m256i));
            acc = _mm512_fmadd_ps(_mm512_loadu_ps(ap.add(base)), bv, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
        }
        sum
    }

    /// Euclidean distance between an `f32` vector and a half-precision vector.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn euclidean_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            // AVX-512F widening load; avoids the AVX-512DQ-only
            // `_mm512_insertf32x8`.
            let bv = _mm512_cvtph_ps(_mm256_loadu_si256(bp.add(base) as *const __m256i));
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), bv);
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance between an `f32` vector and a half-precision vector;
    /// returns `1.0` when the product of the norms is below `1e-8`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn cosine_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            // AVX-512F widening load; avoids the AVX-512DQ-only
            // `_mm512_insertf32x8`.
            let bv = _mm512_cvtph_ps(_mm256_loadu_si256(bp.add(base) as *const __m256i));
            let av = _mm512_loadu_ps(ap.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = f16::from_bits(*bp.add(i)).to_f32();
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let denom = (na2 * nb2).sqrt();
        if denom < 1e-8 {
            1.0
        } else {
            1.0 - dot / denom
        }
    }
}
665
/// NEON kernels for `f32` vectors on aarch64.
///
/// All kernels operate on the common prefix of their inputs, process four
/// lanes per iteration and finish the remaining 0..=3 elements in scalar code.
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    use std::arch::aarch64::*;

    /// Dot product of two `f32` vectors.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let mut lanes = vdupq_n_f32(0.0);

        let mut off = 0usize;
        while len - off >= 4 {
            let x = vld1q_f32(a.as_ptr().add(off));
            let y = vld1q_f32(b.as_ptr().add(off));
            lanes = vmlaq_f32(lanes, x, y);
            off += 4;
        }

        // Scalar tail for the final 0..=3 elements.
        let mut total = vaddvq_f32(lanes);
        for (x, y) in a[off..len].iter().zip(&b[off..len]) {
            total += x * y;
        }
        total
    }

    /// Euclidean distance between two `f32` vectors.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let mut lanes = vdupq_n_f32(0.0);

        let mut off = 0usize;
        while len - off >= 4 {
            let diff = vsubq_f32(
                vld1q_f32(a.as_ptr().add(off)),
                vld1q_f32(b.as_ptr().add(off)),
            );
            lanes = vmlaq_f32(lanes, diff, diff);
            off += 4;
        }

        // Scalar tail for the final 0..=3 elements.
        let mut total = vaddvq_f32(lanes);
        for (x, y) in a[off..len].iter().zip(&b[off..len]) {
            let delta = x - y;
            total += delta * delta;
        }
        total.sqrt()
    }

    /// Cosine distance between two `f32` vectors; returns `1.0` when either
    /// vector has zero norm.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let mut dot_v = vdupq_n_f32(0.0);
        let mut sq_a_v = vdupq_n_f32(0.0);
        let mut sq_b_v = vdupq_n_f32(0.0);

        // One pass accumulates the dot product and both squared norms.
        let mut off = 0usize;
        while len - off >= 4 {
            let x = vld1q_f32(a.as_ptr().add(off));
            let y = vld1q_f32(b.as_ptr().add(off));
            dot_v = vmlaq_f32(dot_v, x, y);
            sq_a_v = vmlaq_f32(sq_a_v, x, x);
            sq_b_v = vmlaq_f32(sq_b_v, y, y);
            off += 4;
        }

        let mut dot = vaddvq_f32(dot_v);
        let mut sq_a = vaddvq_f32(sq_a_v);
        let mut sq_b = vaddvq_f32(sq_b_v);

        // Scalar tail for the final 0..=3 elements.
        for (x, y) in a[off..len].iter().zip(&b[off..len]) {
            dot += x * y;
            sq_a += x * x;
            sq_b += y * y;
        }

        let norm_a = sq_a.sqrt();
        let norm_b = sq_b.sqrt();
        if norm_a != 0.0 && norm_b != 0.0 {
            1.0 - dot / (norm_a * norm_b)
        } else {
            1.0
        }
    }
}
741
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    /// Lengths that are not multiples of 16, so the 8-wide step and the scalar
    /// tail of every SIMD kernel are exercised (plus the empty case).
    const TAIL_LENS: [usize; 10] = [0, 1, 3, 7, 8, 9, 15, 17, 31, 100];

    /// Draws `len` values uniformly from `[-1, 1)`.
    fn random_vec(rng: &mut StdRng, len: usize) -> Vec<f32> {
        (0..len).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect()
    }

    /// Asserts that the dispatched f32 kernels agree with the scalar reference.
    fn assert_f32_matches_scalar(a: &[f32], b: &[f32]) {
        let len = a.len();

        let dot_s = dot_scalar(a, b);
        let euclid_s = euclidean_scalar(a, b);
        let cos_s = cosine_scalar(a, b);

        let dot_f = dot_product(a, b);
        let euclid_f = euclidean_distance(a, b);
        let cos_f = cosine_distance(a, b);

        assert!(
            (dot_f - dot_s).abs() < 1e-4,
            "dot mismatch (len {len}): {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-4,
            "euclidean mismatch (len {len}): {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-4,
            "cosine mismatch (len {len}): {cos_f} vs {cos_s}"
        );
    }

    /// Asserts that the dispatched f16 kernels agree with the scalar reference.
    fn assert_f16_matches_scalar(a: &[f32], b: &[f16]) {
        let len = a.len();

        let dot_s = dot_f16_scalar(a, b);
        let euclid_s = euclidean_f16_scalar(a, b);
        let cos_s = cosine_f16_scalar(a, b);

        let dot_f = dot_product_f16(a, b);
        let euclid_f = euclidean_distance_f16(a, b);
        let cos_f = cosine_distance_f16(a, b);

        assert!(
            (dot_f - dot_s).abs() < 1e-3,
            "f16 dot mismatch (len {len}): {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-3,
            "f16 euclidean mismatch (len {len}): {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-3,
            "f16 cosine mismatch (len {len}): {cos_f} vs {cos_s}"
        );
    }

    #[test]
    fn cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!(cosine_distance(&v, &v).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal() {
        assert!((cosine_distance(&[1.0f32, 0.0], &[0.0f32, 1.0]) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn euclidean_basic() {
        assert!((euclidean_distance(&[0.0f32, 0.0], &[3.0f32, 4.0]) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn dot_basic() {
        assert!((dot_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]) - 32.0).abs() < 1e-5);
    }

    #[test]
    fn simd_matches_scalar_dim128() {
        let mut rng = StdRng::seed_from_u64(99);
        let a = random_vec(&mut rng, 128);
        let b = random_vec(&mut rng, 128);
        assert_f32_matches_scalar(&a, &b);
    }

    #[test]
    fn simd_matches_scalar_tail_lengths() {
        let mut rng = StdRng::seed_from_u64(7);
        for &len in TAIL_LENS.iter() {
            let a = random_vec(&mut rng, len);
            let b = random_vec(&mut rng, len);
            assert_f32_matches_scalar(&a, &b);
        }
    }

    #[test]
    fn f16_simd_matches_scalar() {
        let mut rng = StdRng::seed_from_u64(42);
        let a = random_vec(&mut rng, 128);
        let b_f32 = random_vec(&mut rng, 128);
        let b: Vec<f16> = b_f32.iter().map(|&x| f16::from_f32(x)).collect();
        assert_f16_matches_scalar(&a, &b);
    }

    #[test]
    fn f16_simd_matches_scalar_tail_lengths() {
        let mut rng = StdRng::seed_from_u64(11);
        for &len in TAIL_LENS.iter() {
            let a = random_vec(&mut rng, len);
            let b: Vec<f16> = random_vec(&mut rng, len)
                .iter()
                .map(|&x| f16::from_f32(x))
                .collect();
            assert_f16_matches_scalar(&a, &b);
        }
    }

    #[test]
    fn centroid_empty() {
        let none: Vec<Vec<f32>> = Vec::new();
        let c = compute_centroid_and_radius(&none, VectorMetric::Euclidean);
        assert!(c.values.is_empty());
        assert_eq!(c.radius, 0.0);
    }

    #[test]
    fn centroid_single() {
        let v = vec![vec![1.0f32, 2.0, 3.0]];
        let c = compute_centroid_and_radius(&v, VectorMetric::Cosine);
        assert_eq!(c.values, vec![1.0, 2.0, 3.0]);
        assert!(c.radius < 1e-6, "radius={}", c.radius);
    }

    #[test]
    fn centroid_two_points() {
        let vs = vec![vec![0.0f32, 0.0], vec![2.0f32, 2.0]];
        let c = compute_centroid_and_radius(&vs, VectorMetric::Euclidean);
        assert!((c.values[0] - 1.0).abs() < 1e-6);
        assert!(c.radius > 0.0);
    }
}