1use ailake_core::{Centroid, VectorMetric};
2use half::f16;
3
4pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
7 #[cfg(target_arch = "x86_64")]
8 {
9 #[cfg(feature = "avx512")]
10 if is_x86_feature_detected!("avx512f") {
11 return unsafe { avx512::dot(a, b) };
12 }
13 if is_x86_feature_detected!("avx2") {
14 return unsafe { avx2::dot(a, b) };
15 }
16 }
17 #[cfg(target_arch = "aarch64")]
18 if std::arch::is_aarch64_feature_detected!("neon") {
19 return unsafe { neon_impl::dot(a, b) };
20 }
21 dot_scalar(a, b)
22}
23
24pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
25 #[cfg(target_arch = "x86_64")]
26 {
27 #[cfg(feature = "avx512")]
28 if is_x86_feature_detected!("avx512f") {
29 return unsafe { avx512::euclidean(a, b) };
30 }
31 if is_x86_feature_detected!("avx2") {
32 return unsafe { avx2::euclidean(a, b) };
33 }
34 }
35 #[cfg(target_arch = "aarch64")]
36 if std::arch::is_aarch64_feature_detected!("neon") {
37 return unsafe { neon_impl::euclidean(a, b) };
38 }
39 euclidean_scalar(a, b)
40}
41
42pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
43 #[cfg(target_arch = "x86_64")]
44 {
45 #[cfg(feature = "avx512")]
46 if is_x86_feature_detected!("avx512f") {
47 return unsafe { avx512::cosine(a, b) };
48 }
49 if is_x86_feature_detected!("avx2") {
50 return unsafe { avx2::cosine(a, b) };
51 }
52 }
53 #[cfg(target_arch = "aarch64")]
54 if std::arch::is_aarch64_feature_detected!("neon") {
55 return unsafe { neon_impl::cosine(a, b) };
56 }
57 cosine_scalar(a, b)
58}
59
60pub fn exact_distance(metric: VectorMetric, a: &[f32], b: &[f32]) -> f32 {
61 match metric {
62 VectorMetric::Cosine => cosine_distance(a, b),
63 VectorMetric::Euclidean => euclidean_distance(a, b),
64 VectorMetric::DotProduct => -dot_product(a, b),
65 }
66}
67
68pub fn cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
78 #[cfg(target_arch = "x86_64")]
79 {
80 #[cfg(feature = "avx512")]
81 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
82 return unsafe { avx512::cosine_f16(a, b) };
83 }
84 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
85 return unsafe { avx2_f16c::cosine(a, b) };
86 }
87 }
88 cosine_f16_scalar(a, b)
89}
90
91pub fn euclidean_distance_f16(a: &[f32], b: &[f16]) -> f32 {
92 #[cfg(target_arch = "x86_64")]
93 {
94 #[cfg(feature = "avx512")]
95 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
96 return unsafe { avx512::euclidean_f16(a, b) };
97 }
98 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
99 return unsafe { avx2_f16c::euclidean(a, b) };
100 }
101 }
102 euclidean_f16_scalar(a, b)
103}
104
105pub fn dot_product_f16(a: &[f32], b: &[f16]) -> f32 {
106 #[cfg(target_arch = "x86_64")]
107 {
108 #[cfg(feature = "avx512")]
109 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
110 return unsafe { avx512::dot_f16(a, b) };
111 }
112 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
113 return unsafe { avx2_f16c::dot(a, b) };
114 }
115 }
116 dot_f16_scalar(a, b)
117}
118
119pub fn compute_centroid_and_radius(vectors: &[Vec<f32>], metric: VectorMetric) -> Centroid {
120 if vectors.is_empty() {
121 return Centroid {
122 values: vec![],
123 radius: 0.0,
124 metric,
125 };
126 }
127 let dim = vectors[0].len();
128 let n = vectors.len() as f32;
129 let centroid: Vec<f32> = (0..dim)
130 .map(|i| vectors.iter().map(|v| v[i]).sum::<f32>() / n)
131 .collect();
132 let radius = vectors
133 .iter()
134 .map(|v| exact_distance(metric, ¢roid, v))
135 .fold(0.0_f32, f32::max);
136 Centroid {
137 values: centroid,
138 radius,
139 metric,
140 }
141}
142
/// Portable dot product over the pairs produced by zipping `a` with `b`
/// (extra trailing elements of the longer slice are ignored).
#[inline(always)]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
149
/// Portable Euclidean distance over the pairs produced by zipping `a` with `b`
/// (extra trailing elements of the longer slice are ignored).
#[inline(always)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_distance: f32 = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| (p - q) * (p - q))
        .sum();
    squared_distance.sqrt()
}
158
/// Portable cosine distance (`1 - cosine similarity`) between `a` and `b`.
///
/// The dot product and both squared norms are accumulated in one pass over the
/// common prefix of the two slices, matching what the SIMD kernels compute when
/// the lengths differ. Returns `1.0` when either norm is zero.
#[inline(always)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0_f32;
    let mut norm_a_sq = 0.0_f32;
    let mut norm_b_sq = 0.0_f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let (na, nb) = (norm_a_sq.sqrt(), norm_b_sq.sqrt());
    if na == 0.0 || nb == 0.0 {
        return 1.0;
    }
    1.0 - dot / (na * nb)
}
169
170#[inline(always)]
171fn cosine_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
172 let n = a.len().min(b.len());
173 let mut dot = 0.0f32;
174 let mut norm_a = 0.0f32;
175 let mut norm_b = 0.0f32;
176 for i in 0..n {
177 let ai = a[i];
178 let bi = b[i].to_f32();
179 dot += ai * bi;
180 norm_a += ai * ai;
181 norm_b += bi * bi;
182 }
183 let denom = (norm_a * norm_b).sqrt();
184 if denom < 1e-8 {
185 1.0
186 } else {
187 1.0 - dot / denom
188 }
189}
190
191#[inline(always)]
192fn euclidean_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
193 let n = a.len().min(b.len());
194 let mut sum = 0.0f32;
195 for i in 0..n {
196 let diff = a[i] - b[i].to_f32();
197 sum += diff * diff;
198 }
199 sum.sqrt()
200}
201
202#[inline(always)]
203fn dot_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
204 let n = a.len().min(b.len());
205 let mut acc = 0.0f32;
206 for i in 0..n {
207 acc += a[i] * b[i].to_f32();
208 }
209 acc
210}
211
/// AVX2 + FMA kernels for pairs of `f32` slices.
///
/// Every kernel works on the common prefix of its two inputs: full SIMD blocks
/// first, then a scalar loop over whatever is left.
#[cfg(target_arch = "x86_64")]
mod avx2 {
    use std::arch::x86_64::*;

    /// Adds the eight lanes of `v` into one scalar.
    ///
    /// # Safety
    ///
    /// Executes AVX and SSE3 instructions; the caller must ensure the CPU
    /// supports them.
    #[inline(always)]
    pub unsafe fn hsum256(v: __m256) -> f32 {
        let upper = _mm256_extractf128_ps(v, 1);
        let lower = _mm256_castps256_ps128(v);
        let quad = _mm_add_ps(lower, upper);
        let odd_lanes = _mm_movehdup_ps(quad);
        let pairs = _mm_add_ps(quad, odd_lanes);
        let high_pair = _mm_movehl_ps(odd_lanes, pairs);
        _mm_cvtss_f32(_mm_add_ss(pairs, high_pair))
    }

    /// Dot product over the common prefix of `a` and `b`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (a, b) = (&a[..len], &b[..len]);

        let blocks_a = a.chunks_exact(16);
        let blocks_b = b.chunks_exact(16);
        let (mut rest_a, mut rest_b) = (blocks_a.remainder(), blocks_b.remainder());

        // Two independent accumulators, one per 8-lane half of each 16-element block.
        let mut first = _mm256_setzero_ps();
        let mut second = _mm256_setzero_ps();
        for (ca, cb) in blocks_a.zip(blocks_b) {
            let (pa, pb) = (ca.as_ptr(), cb.as_ptr());
            first = _mm256_fmadd_ps(_mm256_loadu_ps(pa), _mm256_loadu_ps(pb), first);
            second = _mm256_fmadd_ps(
                _mm256_loadu_ps(pa.add(8)),
                _mm256_loadu_ps(pb.add(8)),
                second,
            );
        }

        // At most one whole 8-lane block can remain after the 16-wide loop.
        if rest_a.len() >= 8 {
            first = _mm256_fmadd_ps(
                _mm256_loadu_ps(rest_a.as_ptr()),
                _mm256_loadu_ps(rest_b.as_ptr()),
                first,
            );
            rest_a = &rest_a[8..];
            rest_b = &rest_b[8..];
        }

        let mut total = hsum256(_mm256_add_ps(first, second));
        for (&x, &y) in rest_a.iter().zip(rest_b) {
            total += x * y;
        }
        total
    }

    /// Euclidean distance over the common prefix of `a` and `b`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (a, b) = (&a[..len], &b[..len]);

        let blocks_a = a.chunks_exact(16);
        let blocks_b = b.chunks_exact(16);
        let (mut rest_a, mut rest_b) = (blocks_a.remainder(), blocks_b.remainder());

        let mut first = _mm256_setzero_ps();
        let mut second = _mm256_setzero_ps();
        for (ca, cb) in blocks_a.zip(blocks_b) {
            let (pa, pb) = (ca.as_ptr(), cb.as_ptr());
            let delta_lo = _mm256_sub_ps(_mm256_loadu_ps(pa), _mm256_loadu_ps(pb));
            let delta_hi = _mm256_sub_ps(_mm256_loadu_ps(pa.add(8)), _mm256_loadu_ps(pb.add(8)));
            first = _mm256_fmadd_ps(delta_lo, delta_lo, first);
            second = _mm256_fmadd_ps(delta_hi, delta_hi, second);
        }

        // At most one whole 8-lane block can remain after the 16-wide loop.
        if rest_a.len() >= 8 {
            let delta = _mm256_sub_ps(
                _mm256_loadu_ps(rest_a.as_ptr()),
                _mm256_loadu_ps(rest_b.as_ptr()),
            );
            first = _mm256_fmadd_ps(delta, delta, first);
            rest_a = &rest_a[8..];
            rest_b = &rest_b[8..];
        }

        let mut total = hsum256(_mm256_add_ps(first, second));
        for (&x, &y) in rest_a.iter().zip(rest_b) {
            let delta = x - y;
            total += delta * delta;
        }
        total.sqrt()
    }

    /// Cosine distance over the common prefix of `a` and `b`; `1.0` when either
    /// norm is zero.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (a, b) = (&a[..len], &b[..len]);

        let blocks_a = a.chunks_exact(8);
        let blocks_b = b.chunks_exact(8);
        let (rest_a, rest_b) = (blocks_a.remainder(), blocks_b.remainder());

        let mut dot_lanes = _mm256_setzero_ps();
        let mut a_sq_lanes = _mm256_setzero_ps();
        let mut b_sq_lanes = _mm256_setzero_ps();
        for (ca, cb) in blocks_a.zip(blocks_b) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());
            dot_lanes = _mm256_fmadd_ps(va, vb, dot_lanes);
            a_sq_lanes = _mm256_fmadd_ps(va, va, a_sq_lanes);
            b_sq_lanes = _mm256_fmadd_ps(vb, vb, b_sq_lanes);
        }

        let mut dot = hsum256(dot_lanes);
        let mut a_sq = hsum256(a_sq_lanes);
        let mut b_sq = hsum256(b_sq_lanes);
        for (&x, &y) in rest_a.iter().zip(rest_b) {
            dot += x * y;
            a_sq += x * x;
            b_sq += y * y;
        }

        let (norm_a, norm_b) = (a_sq.sqrt(), b_sq.sqrt());
        if norm_a == 0.0 || norm_b == 0.0 {
            return 1.0;
        }
        1.0 - dot / (norm_a * norm_b)
    }
}
347
348#[cfg(target_arch = "x86_64")]
355mod avx2_f16c {
356 use half::f16;
357 use std::arch::x86_64::*;
358
359 use super::avx2::hsum256;
360
361 #[target_feature(enable = "avx2,f16c,fma")]
363 pub unsafe fn dot(a: &[f32], b: &[f16]) -> f32 {
364 let n = a.len().min(b.len());
365 let ap = a.as_ptr();
366 let bp = b.as_ptr() as *const u16;
367
368 let mut acc0 = _mm256_setzero_ps();
369 let mut acc1 = _mm256_setzero_ps();
370
371 let chunks16 = n / 16;
372 for i in 0..chunks16 {
373 let base = i * 16;
374 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
375 let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
376 let a0 = _mm256_loadu_ps(ap.add(base));
377 let a1 = _mm256_loadu_ps(ap.add(base + 8));
378 acc0 = _mm256_fmadd_ps(a0, b0, acc0);
379 acc1 = _mm256_fmadd_ps(a1, b1, acc1);
380 }
381
382 let chunks8 = n / 8;
383 if chunks8 > chunks16 * 2 {
384 let base = chunks16 * 16;
385 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
386 let a0 = _mm256_loadu_ps(ap.add(base));
387 acc0 = _mm256_fmadd_ps(a0, b0, acc0);
388 }
389
390 let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
391 for i in (chunks8 * 8)..n {
392 sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
393 }
394 sum
395 }
396
397 #[target_feature(enable = "avx2,f16c,fma")]
399 pub unsafe fn euclidean(a: &[f32], b: &[f16]) -> f32 {
400 let n = a.len().min(b.len());
401 let ap = a.as_ptr();
402 let bp = b.as_ptr() as *const u16;
403
404 let mut acc0 = _mm256_setzero_ps();
405 let mut acc1 = _mm256_setzero_ps();
406
407 let chunks16 = n / 16;
408 for i in 0..chunks16 {
409 let base = i * 16;
410 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
411 let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
412 let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
413 let d1 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base + 8)), b1);
414 acc0 = _mm256_fmadd_ps(d0, d0, acc0);
415 acc1 = _mm256_fmadd_ps(d1, d1, acc1);
416 }
417
418 let chunks8 = n / 8;
419 if chunks8 > chunks16 * 2 {
420 let base = chunks16 * 16;
421 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
422 let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
423 acc0 = _mm256_fmadd_ps(d0, d0, acc0);
424 }
425
426 let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
427 for i in (chunks8 * 8)..n {
428 let diff = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
429 sum += diff * diff;
430 }
431 sum.sqrt()
432 }
433
434 #[target_feature(enable = "avx2,f16c,fma")]
436 pub unsafe fn cosine(a: &[f32], b: &[f16]) -> f32 {
437 let n = a.len().min(b.len());
438 let ap = a.as_ptr();
439 let bp = b.as_ptr() as *const u16;
440
441 let mut dot_acc = _mm256_setzero_ps();
442 let mut na_acc = _mm256_setzero_ps();
443 let mut nb_acc = _mm256_setzero_ps();
444
445 let chunks8 = n / 8;
446 for i in 0..chunks8 {
447 let base = i * 8;
448 let av = _mm256_loadu_ps(ap.add(base));
449 let bv = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
450 dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
451 na_acc = _mm256_fmadd_ps(av, av, na_acc);
452 nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
453 }
454
455 let mut dot = hsum256(dot_acc);
456 let mut na2 = hsum256(na_acc);
457 let mut nb2 = hsum256(nb_acc);
458
459 for i in (chunks8 * 8)..n {
460 let ai = *ap.add(i);
461 let bi = f16::from_bits(*bp.add(i)).to_f32();
462 dot += ai * bi;
463 na2 += ai * ai;
464 nb2 += bi * bi;
465 }
466
467 let denom = (na2 * nb2).sqrt();
468 if denom < 1e-8 {
469 1.0
470 } else {
471 1.0 - dot / denom
472 }
473 }
474}
475
/// AVX-512F kernels, compiled only when the `avx512` cargo feature is enabled.
///
/// Every kernel processes the common prefix of its inputs in 16-lane blocks and
/// finishes the remaining elements with a scalar loop.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
mod avx512 {
    use half::f16;
    use std::arch::x86_64::*;

    /// Adds the sixteen lanes of `v` into one scalar.
    ///
    /// The vector is spilled to a stack buffer and folded as two 256-bit halves.
    #[inline(always)]
    unsafe fn hsum512(v: __m512) -> f32 {
        let mut buf = [0.0f32; 16];
        _mm512_storeu_ps(buf.as_mut_ptr(), v);
        let lo = _mm256_loadu_ps(buf.as_ptr());
        let hi = _mm256_loadu_ps(buf.as_ptr().add(8));
        let sum256 = _mm256_add_ps(lo, hi);
        let hi128 = _mm256_extractf128_ps(sum256, 1);
        let lo128 = _mm256_castps256_ps128(sum256);
        let sum128 = _mm_add_ps(lo128, hi128);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf2))
    }

    /// Loads 16 consecutive `f16` bit patterns from `src` and widens them to `f32`.
    ///
    /// Uses `_mm512_cvtph_ps`, which needs only AVX-512F. The previous
    /// `_mm512_insertf32x8` formulation is an AVX-512DQ instruction that is
    /// neither enabled on these kernels nor checked by the dispatchers.
    #[inline(always)]
    unsafe fn load16_f16(src: *const u16) -> __m512 {
        _mm512_cvtph_ps(_mm256_loadu_si256(src as *const __m256i))
    }

    /// Dot product over the common prefix of `a` and `b`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            acc = _mm512_fmadd_ps(
                _mm512_loadu_ps(ap.add(base)),
                _mm512_loadu_ps(bp.add(base)),
                acc,
            );
        }
        let mut sum = hsum512(acc);
        // Scalar tail for the last `n % 16` elements.
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// Euclidean distance over the common prefix of `a` and `b`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), _mm512_loadu_ps(bp.add(base)));
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        // Scalar tail for the last `n % 16` elements.
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance over the common prefix of `a` and `b`; `1.0` when either
    /// norm is zero.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let av = _mm512_loadu_ps(ap.add(base));
            let bv = _mm512_loadu_ps(bp.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        // Scalar tail for the last `n % 16` elements.
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }

    /// Dot product of an `f32` slice and an `f16` slice over their common prefix.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn dot_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load16_f16(bp.add(base));
            acc = _mm512_fmadd_ps(_mm512_loadu_ps(ap.add(base)), bv, acc);
        }
        let mut sum = hsum512(acc);
        // Scalar tail for the last `n % 16` elements.
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
        }
        sum
    }

    /// Euclidean distance between an `f32` slice and an `f16` slice over their
    /// common prefix.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn euclidean_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load16_f16(bp.add(base));
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), bv);
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        // Scalar tail for the last `n % 16` elements.
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance between an `f32` slice and an `f16` slice over their
    /// common prefix; `1.0` when the product of the norms is below `1e-8`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn cosine_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load16_f16(bp.add(base));
            let av = _mm512_loadu_ps(ap.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        // Scalar tail for the last `n % 16` elements.
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = f16::from_bits(*bp.add(i)).to_f32();
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let denom = (na2 * nb2).sqrt();
        if denom < 1e-8 {
            1.0
        } else {
            1.0 - dot / denom
        }
    }
}
664
/// NEON kernels for pairs of `f32` slices on aarch64.
///
/// Every kernel processes the common prefix of its inputs in 4-lane blocks and
/// finishes the remaining elements with a scalar loop.
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    use std::arch::aarch64::*;

    /// Dot product over the common prefix of `a` and `b`.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (a, b) = (&a[..len], &b[..len]);

        let quads_a = a.chunks_exact(4);
        let quads_b = b.chunks_exact(4);
        let (rest_a, rest_b) = (quads_a.remainder(), quads_b.remainder());

        let mut lanes = vdupq_n_f32(0.0);
        for (qa, qb) in quads_a.zip(quads_b) {
            let va = vld1q_f32(qa.as_ptr());
            let vb = vld1q_f32(qb.as_ptr());
            lanes = vmlaq_f32(lanes, va, vb);
        }

        let mut total = vaddvq_f32(lanes);
        for (&x, &y) in rest_a.iter().zip(rest_b) {
            total += x * y;
        }
        total
    }

    /// Euclidean distance over the common prefix of `a` and `b`.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (a, b) = (&a[..len], &b[..len]);

        let quads_a = a.chunks_exact(4);
        let quads_b = b.chunks_exact(4);
        let (rest_a, rest_b) = (quads_a.remainder(), quads_b.remainder());

        let mut lanes = vdupq_n_f32(0.0);
        for (qa, qb) in quads_a.zip(quads_b) {
            let delta = vsubq_f32(vld1q_f32(qa.as_ptr()), vld1q_f32(qb.as_ptr()));
            lanes = vmlaq_f32(lanes, delta, delta);
        }

        let mut total = vaddvq_f32(lanes);
        for (&x, &y) in rest_a.iter().zip(rest_b) {
            let delta = x - y;
            total += delta * delta;
        }
        total.sqrt()
    }

    /// Cosine distance over the common prefix of `a` and `b`; `1.0` when either
    /// norm is zero.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (a, b) = (&a[..len], &b[..len]);

        let quads_a = a.chunks_exact(4);
        let quads_b = b.chunks_exact(4);
        let (rest_a, rest_b) = (quads_a.remainder(), quads_b.remainder());

        let mut dot_lanes = vdupq_n_f32(0.0);
        let mut a_sq_lanes = vdupq_n_f32(0.0);
        let mut b_sq_lanes = vdupq_n_f32(0.0);
        for (qa, qb) in quads_a.zip(quads_b) {
            let va = vld1q_f32(qa.as_ptr());
            let vb = vld1q_f32(qb.as_ptr());
            dot_lanes = vmlaq_f32(dot_lanes, va, vb);
            a_sq_lanes = vmlaq_f32(a_sq_lanes, va, va);
            b_sq_lanes = vmlaq_f32(b_sq_lanes, vb, vb);
        }

        let mut dot = vaddvq_f32(dot_lanes);
        let mut a_sq = vaddvq_f32(a_sq_lanes);
        let mut b_sq = vaddvq_f32(b_sq_lanes);
        for (&x, &y) in rest_a.iter().zip(rest_b) {
            dot += x * y;
            a_sq += x * x;
            b_sq += y * y;
        }

        let (norm_a, norm_b) = (a_sq.sqrt(), b_sq.sqrt());
        if norm_a == 0.0 || norm_b == 0.0 {
            return 1.0;
        }
        1.0 - dot / (norm_a * norm_b)
    }
}
740
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    /// Lengths that are not multiples of 16, so the SIMD kernels' partial-block
    /// and scalar-tail code runs in addition to the main loop.
    const REMAINDER_DIMS: [usize; 9] = [1, 7, 8, 9, 15, 17, 24, 31, 131];

    /// Draws `len` values uniformly from `[-1, 1)`.
    fn random_vec(rng: &mut StdRng, len: usize) -> Vec<f32> {
        (0..len).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect()
    }

    /// Asserts that the dispatched f32 kernels agree with the scalar reference.
    fn assert_f32_kernels_match(a: &[f32], b: &[f32], tol: f32) {
        let dim = a.len();
        let (dot_s, dot_f) = (dot_scalar(a, b), dot_product(a, b));
        let (euclid_s, euclid_f) = (euclidean_scalar(a, b), euclidean_distance(a, b));
        let (cos_s, cos_f) = (cosine_scalar(a, b), cosine_distance(a, b));
        assert!(
            (dot_f - dot_s).abs() < tol,
            "dot mismatch (dim {dim}): {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < tol,
            "euclidean mismatch (dim {dim}): {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < tol,
            "cosine mismatch (dim {dim}): {cos_f} vs {cos_s}"
        );
    }

    /// Asserts that the dispatched f16 kernels agree with the scalar reference.
    fn assert_f16_kernels_match(a: &[f32], b: &[f16], tol: f32) {
        let dim = a.len();
        let (dot_s, dot_f) = (dot_f16_scalar(a, b), dot_product_f16(a, b));
        let (euclid_s, euclid_f) = (euclidean_f16_scalar(a, b), euclidean_distance_f16(a, b));
        let (cos_s, cos_f) = (cosine_f16_scalar(a, b), cosine_distance_f16(a, b));
        assert!(
            (dot_f - dot_s).abs() < tol,
            "f16 dot mismatch (dim {dim}): {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < tol,
            "f16 euclidean mismatch (dim {dim}): {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < tol,
            "f16 cosine mismatch (dim {dim}): {cos_f} vs {cos_s}"
        );
    }

    #[test]
    fn cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!(cosine_distance(&v, &v).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal() {
        assert!((cosine_distance(&[1.0f32, 0.0], &[0.0f32, 1.0]) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn euclidean_basic() {
        assert!((euclidean_distance(&[0.0f32, 0.0], &[3.0f32, 4.0]) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn dot_basic() {
        assert!((dot_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]) - 32.0).abs() < 1e-5);
    }

    #[test]
    fn simd_matches_scalar_dim128() {
        let mut rng = StdRng::seed_from_u64(99);
        let a = random_vec(&mut rng, 128);
        let b = random_vec(&mut rng, 128);
        assert_f32_kernels_match(&a, &b, 1e-4);
    }

    #[test]
    fn simd_matches_scalar_remainder_dims() {
        let mut rng = StdRng::seed_from_u64(99);
        for dim in REMAINDER_DIMS {
            let a = random_vec(&mut rng, dim);
            let b = random_vec(&mut rng, dim);
            assert_f32_kernels_match(&a, &b, 1e-4);
        }
    }

    #[test]
    fn f16_simd_matches_scalar() {
        let mut rng = StdRng::seed_from_u64(42);
        let a = random_vec(&mut rng, 128);
        let b: Vec<f16> = random_vec(&mut rng, 128)
            .iter()
            .map(|&x| f16::from_f32(x))
            .collect();
        assert_f16_kernels_match(&a, &b, 1e-3);
    }

    #[test]
    fn f16_simd_matches_scalar_remainder_dims() {
        let mut rng = StdRng::seed_from_u64(42);
        for dim in REMAINDER_DIMS {
            let a = random_vec(&mut rng, dim);
            let b: Vec<f16> = random_vec(&mut rng, dim)
                .iter()
                .map(|&x| f16::from_f32(x))
                .collect();
            assert_f16_kernels_match(&a, &b, 1e-3);
        }
    }

    #[test]
    fn centroid_single() {
        let v = vec![vec![1.0f32, 2.0, 3.0]];
        let c = compute_centroid_and_radius(&v, VectorMetric::Cosine);
        assert_eq!(c.values, vec![1.0, 2.0, 3.0]);
        assert!(c.radius < 1e-6, "radius={}", c.radius);
    }

    #[test]
    fn centroid_two_points() {
        let vs = vec![vec![0.0f32, 0.0], vec![2.0f32, 2.0]];
        let c = compute_centroid_and_radius(&vs, VectorMetric::Euclidean);
        assert!((c.values[0] - 1.0).abs() < 1e-6);
        assert!(c.radius > 0.0);
    }
}