1use ailake_core::{Centroid, VectorMetric};
2use half::f16;
3
4pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
7 #[cfg(target_arch = "x86_64")]
8 {
9 if is_x86_feature_detected!("avx512f") {
10 return unsafe { avx512::dot(a, b) };
11 }
12 if is_x86_feature_detected!("avx2") {
13 return unsafe { avx2::dot(a, b) };
14 }
15 }
16 #[cfg(target_arch = "aarch64")]
17 if std::arch::is_aarch64_feature_detected!("neon") {
18 return unsafe { neon_impl::dot(a, b) };
19 }
20 dot_scalar(a, b)
21}
22
23pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
24 #[cfg(target_arch = "x86_64")]
25 {
26 if is_x86_feature_detected!("avx512f") {
27 return unsafe { avx512::euclidean(a, b) };
28 }
29 if is_x86_feature_detected!("avx2") {
30 return unsafe { avx2::euclidean(a, b) };
31 }
32 }
33 #[cfg(target_arch = "aarch64")]
34 if std::arch::is_aarch64_feature_detected!("neon") {
35 return unsafe { neon_impl::euclidean(a, b) };
36 }
37 euclidean_scalar(a, b)
38}
39
40pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
41 #[cfg(target_arch = "x86_64")]
42 {
43 if is_x86_feature_detected!("avx512f") {
44 return unsafe { avx512::cosine(a, b) };
45 }
46 if is_x86_feature_detected!("avx2") {
47 return unsafe { avx2::cosine(a, b) };
48 }
49 }
50 #[cfg(target_arch = "aarch64")]
51 if std::arch::is_aarch64_feature_detected!("neon") {
52 return unsafe { neon_impl::cosine(a, b) };
53 }
54 cosine_scalar(a, b)
55}
56
57pub fn exact_distance(metric: VectorMetric, a: &[f32], b: &[f32]) -> f32 {
58 match metric {
59 VectorMetric::Cosine => cosine_distance(a, b),
60 VectorMetric::Euclidean => euclidean_distance(a, b),
61 VectorMetric::DotProduct => -dot_product(a, b),
62 }
63}
64
65pub fn cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
75 #[cfg(target_arch = "x86_64")]
76 {
77 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
78 return unsafe { avx512::cosine_f16(a, b) };
79 }
80 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
81 return unsafe { avx2_f16c::cosine(a, b) };
82 }
83 }
84 cosine_f16_scalar(a, b)
85}
86
87pub fn euclidean_distance_f16(a: &[f32], b: &[f16]) -> f32 {
88 #[cfg(target_arch = "x86_64")]
89 {
90 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
91 return unsafe { avx512::euclidean_f16(a, b) };
92 }
93 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
94 return unsafe { avx2_f16c::euclidean(a, b) };
95 }
96 }
97 euclidean_f16_scalar(a, b)
98}
99
100pub fn dot_product_f16(a: &[f32], b: &[f16]) -> f32 {
101 #[cfg(target_arch = "x86_64")]
102 {
103 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
104 return unsafe { avx512::dot_f16(a, b) };
105 }
106 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
107 return unsafe { avx2_f16c::dot(a, b) };
108 }
109 }
110 dot_f16_scalar(a, b)
111}
112
113pub fn compute_centroid_and_radius(vectors: &[Vec<f32>], metric: VectorMetric) -> Centroid {
114 if vectors.is_empty() {
115 return Centroid {
116 values: vec![],
117 radius: 0.0,
118 metric,
119 };
120 }
121 let dim = vectors[0].len();
122 let n = vectors.len() as f32;
123 let centroid: Vec<f32> = (0..dim)
124 .map(|i| vectors.iter().map(|v| v[i]).sum::<f32>() / n)
125 .collect();
126 let radius = vectors
127 .iter()
128 .map(|v| exact_distance(metric, ¢roid, v))
129 .fold(0.0_f32, f32::max);
130 Centroid {
131 values: centroid,
132 radius,
133 metric,
134 }
135}
136
/// Portable dot product: sums the pairwise products over the common prefix of `a` and `b`.
#[inline(always)]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
143
/// Portable Euclidean distance: square root of the summed squared differences over the
/// common prefix of `a` and `b`.
#[inline(always)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| {
            let delta = p - q;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
152
/// Portable cosine distance (`1 - cosine similarity`).
///
/// The dot product and both squared norms are accumulated in one pass over the common
/// prefix of `a` and `b`, so trailing elements of a longer slice contribute to nothing.
/// This matches the SIMD kernels, which all work on `min(a.len(), b.len())` elements.
///
/// Returns `1.0` when either norm is zero (including empty input).
#[inline(always)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    let na = norm_a_sq.sqrt();
    let nb = norm_b_sq.sqrt();
    if na == 0.0 || nb == 0.0 {
        return 1.0;
    }
    1.0 - dot / (na * nb)
}
163
164#[inline(always)]
165fn cosine_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
166 let n = a.len().min(b.len());
167 let mut dot = 0.0f32;
168 let mut norm_a = 0.0f32;
169 let mut norm_b = 0.0f32;
170 for i in 0..n {
171 let ai = a[i];
172 let bi = b[i].to_f32();
173 dot += ai * bi;
174 norm_a += ai * ai;
175 norm_b += bi * bi;
176 }
177 let denom = (norm_a * norm_b).sqrt();
178 if denom < 1e-8 {
179 1.0
180 } else {
181 1.0 - dot / denom
182 }
183}
184
185#[inline(always)]
186fn euclidean_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
187 let n = a.len().min(b.len());
188 let mut sum = 0.0f32;
189 for i in 0..n {
190 let diff = a[i] - b[i].to_f32();
191 sum += diff * diff;
192 }
193 sum.sqrt()
194}
195
196#[inline(always)]
197fn dot_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
198 let n = a.len().min(b.len());
199 let mut acc = 0.0f32;
200 for i in 0..n {
201 acc += a[i] * b[i].to_f32();
202 }
203 acc
204}
205
#[cfg(target_arch = "x86_64")]
/// AVX2 + FMA kernels for `f32` slices.
///
/// Every kernel processes only the first `min(a.len(), b.len())` elements: wide SIMD steps
/// first, then a scalar loop for whatever is left.
mod avx2 {
    use std::arch::x86_64::*;

    /// Adds the eight `f32` lanes of `v` into a single scalar.
    ///
    /// # Safety
    ///
    /// The CPU must support the AVX and SSE3 instructions used here. The function has no
    /// `#[target_feature]` of its own; it is meant to be inlined into a kernel that has.
    #[inline(always)]
    pub unsafe fn hsum256(v: __m256) -> f32 {
        // 8 lanes -> 4 lanes.
        let upper = _mm256_extractf128_ps(v, 1);
        let lower = _mm256_castps256_ps128(v);
        let quad = _mm_add_ps(lower, upper);
        // 4 lanes -> 2: add each odd lane onto its even neighbour.
        let odd_dup = _mm_movehdup_ps(quad);
        let pairs = _mm_add_ps(quad, odd_dup);
        // 2 lanes -> 1: bring the upper pair sum down to lane 0 and add.
        let high = _mm_movehl_ps(odd_dup, pairs);
        _mm_cvtss_f32(_mm_add_ss(pairs, high))
    }

    /// Dot product using two independent 8-lane accumulators.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());

        let mut first = _mm256_setzero_ps();
        let mut second = _mm256_setzero_ps();
        let mut pos = 0usize;

        // 16 elements per step, split across the two accumulators.
        while pos + 16 <= len {
            let x0 = _mm256_loadu_ps(pa.add(pos));
            let y0 = _mm256_loadu_ps(pb.add(pos));
            let x1 = _mm256_loadu_ps(pa.add(pos + 8));
            let y1 = _mm256_loadu_ps(pb.add(pos + 8));
            first = _mm256_fmadd_ps(x0, y0, first);
            second = _mm256_fmadd_ps(x1, y1, second);
            pos += 16;
        }

        // At most one further full 8-lane step fits.
        if pos + 8 <= len {
            let x0 = _mm256_loadu_ps(pa.add(pos));
            let y0 = _mm256_loadu_ps(pb.add(pos));
            first = _mm256_fmadd_ps(x0, y0, first);
            pos += 8;
        }

        let mut total = hsum256(_mm256_add_ps(first, second));
        // Scalar tail (fewer than 8 elements).
        while pos < len {
            total += *pa.add(pos) * *pb.add(pos);
            pos += 1;
        }
        total
    }

    /// Euclidean distance using two independent 8-lane accumulators.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());

        let mut first = _mm256_setzero_ps();
        let mut second = _mm256_setzero_ps();
        let mut pos = 0usize;

        // 16 elements per step, split across the two accumulators.
        while pos + 16 <= len {
            let diff0 = _mm256_sub_ps(_mm256_loadu_ps(pa.add(pos)), _mm256_loadu_ps(pb.add(pos)));
            let diff1 = _mm256_sub_ps(
                _mm256_loadu_ps(pa.add(pos + 8)),
                _mm256_loadu_ps(pb.add(pos + 8)),
            );
            first = _mm256_fmadd_ps(diff0, diff0, first);
            second = _mm256_fmadd_ps(diff1, diff1, second);
            pos += 16;
        }

        // At most one further full 8-lane step fits.
        if pos + 8 <= len {
            let diff0 = _mm256_sub_ps(_mm256_loadu_ps(pa.add(pos)), _mm256_loadu_ps(pb.add(pos)));
            first = _mm256_fmadd_ps(diff0, diff0, first);
            pos += 8;
        }

        let mut total = hsum256(_mm256_add_ps(first, second));
        // Scalar tail (fewer than 8 elements).
        while pos < len {
            let diff = *pa.add(pos) - *pb.add(pos);
            total += diff * diff;
            pos += 1;
        }
        total.sqrt()
    }

    /// Cosine distance; the dot product and both squared norms are accumulated in one pass.
    /// Returns `1.0` when either norm is zero.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());

        let mut dot_lanes = _mm256_setzero_ps();
        let mut a_sq_lanes = _mm256_setzero_ps();
        let mut b_sq_lanes = _mm256_setzero_ps();
        let mut pos = 0usize;

        while pos + 8 <= len {
            let x = _mm256_loadu_ps(pa.add(pos));
            let y = _mm256_loadu_ps(pb.add(pos));
            dot_lanes = _mm256_fmadd_ps(x, y, dot_lanes);
            a_sq_lanes = _mm256_fmadd_ps(x, x, a_sq_lanes);
            b_sq_lanes = _mm256_fmadd_ps(y, y, b_sq_lanes);
            pos += 8;
        }

        let mut dot = hsum256(dot_lanes);
        let mut a_sq = hsum256(a_sq_lanes);
        let mut b_sq = hsum256(b_sq_lanes);

        // Scalar tail (fewer than 8 elements).
        while pos < len {
            let x = *pa.add(pos);
            let y = *pb.add(pos);
            dot += x * y;
            a_sq += x * x;
            b_sq += y * y;
            pos += 1;
        }

        let (norm_a, norm_b) = (a_sq.sqrt(), b_sq.sqrt());
        if norm_a == 0.0 || norm_b == 0.0 {
            return 1.0;
        }
        1.0 - dot / (norm_a * norm_b)
    }
}
341
342#[cfg(target_arch = "x86_64")]
349mod avx2_f16c {
350 use half::f16;
351 use std::arch::x86_64::*;
352
353 use super::avx2::hsum256;
354
355 #[target_feature(enable = "avx2,f16c,fma")]
357 pub unsafe fn dot(a: &[f32], b: &[f16]) -> f32 {
358 let n = a.len().min(b.len());
359 let ap = a.as_ptr();
360 let bp = b.as_ptr() as *const u16;
361
362 let mut acc0 = _mm256_setzero_ps();
363 let mut acc1 = _mm256_setzero_ps();
364
365 let chunks16 = n / 16;
366 for i in 0..chunks16 {
367 let base = i * 16;
368 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
369 let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
370 let a0 = _mm256_loadu_ps(ap.add(base));
371 let a1 = _mm256_loadu_ps(ap.add(base + 8));
372 acc0 = _mm256_fmadd_ps(a0, b0, acc0);
373 acc1 = _mm256_fmadd_ps(a1, b1, acc1);
374 }
375
376 let chunks8 = n / 8;
377 if chunks8 > chunks16 * 2 {
378 let base = chunks16 * 16;
379 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
380 let a0 = _mm256_loadu_ps(ap.add(base));
381 acc0 = _mm256_fmadd_ps(a0, b0, acc0);
382 }
383
384 let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
385 for i in (chunks8 * 8)..n {
386 sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
387 }
388 sum
389 }
390
391 #[target_feature(enable = "avx2,f16c,fma")]
393 pub unsafe fn euclidean(a: &[f32], b: &[f16]) -> f32 {
394 let n = a.len().min(b.len());
395 let ap = a.as_ptr();
396 let bp = b.as_ptr() as *const u16;
397
398 let mut acc0 = _mm256_setzero_ps();
399 let mut acc1 = _mm256_setzero_ps();
400
401 let chunks16 = n / 16;
402 for i in 0..chunks16 {
403 let base = i * 16;
404 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
405 let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
406 let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
407 let d1 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base + 8)), b1);
408 acc0 = _mm256_fmadd_ps(d0, d0, acc0);
409 acc1 = _mm256_fmadd_ps(d1, d1, acc1);
410 }
411
412 let chunks8 = n / 8;
413 if chunks8 > chunks16 * 2 {
414 let base = chunks16 * 16;
415 let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
416 let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
417 acc0 = _mm256_fmadd_ps(d0, d0, acc0);
418 }
419
420 let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
421 for i in (chunks8 * 8)..n {
422 let diff = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
423 sum += diff * diff;
424 }
425 sum.sqrt()
426 }
427
428 #[target_feature(enable = "avx2,f16c,fma")]
430 pub unsafe fn cosine(a: &[f32], b: &[f16]) -> f32 {
431 let n = a.len().min(b.len());
432 let ap = a.as_ptr();
433 let bp = b.as_ptr() as *const u16;
434
435 let mut dot_acc = _mm256_setzero_ps();
436 let mut na_acc = _mm256_setzero_ps();
437 let mut nb_acc = _mm256_setzero_ps();
438
439 let chunks8 = n / 8;
440 for i in 0..chunks8 {
441 let base = i * 8;
442 let av = _mm256_loadu_ps(ap.add(base));
443 let bv = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
444 dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
445 na_acc = _mm256_fmadd_ps(av, av, na_acc);
446 nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
447 }
448
449 let mut dot = hsum256(dot_acc);
450 let mut na2 = hsum256(na_acc);
451 let mut nb2 = hsum256(nb_acc);
452
453 for i in (chunks8 * 8)..n {
454 let ai = *ap.add(i);
455 let bi = f16::from_bits(*bp.add(i)).to_f32();
456 dot += ai * bi;
457 na2 += ai * ai;
458 nb2 += bi * bi;
459 }
460
461 let denom = (na2 * nb2).sqrt();
462 if denom < 1e-8 {
463 1.0
464 } else {
465 1.0 - dot / denom
466 }
467 }
468}
469
470#[cfg(target_arch = "x86_64")]
476mod avx512 {
477 use half::f16;
478 use std::arch::x86_64::*;
479
480 #[inline(always)]
481 unsafe fn hsum512(v: __m512) -> f32 {
482 _mm512_reduce_add_ps(v)
483 }
484
485 #[target_feature(enable = "avx512f,fma")]
486 pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
487 let n = a.len().min(b.len());
488 let ap = a.as_ptr();
489 let bp = b.as_ptr();
490 let mut acc = _mm512_setzero_ps();
491 let chunks16 = n / 16;
492 for i in 0..chunks16 {
493 let base = i * 16;
494 acc = _mm512_fmadd_ps(
495 _mm512_loadu_ps(ap.add(base)),
496 _mm512_loadu_ps(bp.add(base)),
497 acc,
498 );
499 }
500 let mut sum = hsum512(acc);
501 for i in (chunks16 * 16)..n {
502 sum += *ap.add(i) * *bp.add(i);
503 }
504 sum
505 }
506
507 #[target_feature(enable = "avx512f,fma")]
508 pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
509 let n = a.len().min(b.len());
510 let ap = a.as_ptr();
511 let bp = b.as_ptr();
512 let mut acc = _mm512_setzero_ps();
513 let chunks16 = n / 16;
514 for i in 0..chunks16 {
515 let base = i * 16;
516 let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), _mm512_loadu_ps(bp.add(base)));
517 acc = _mm512_fmadd_ps(d, d, acc);
518 }
519 let mut sum = hsum512(acc);
520 for i in (chunks16 * 16)..n {
521 let d = *ap.add(i) - *bp.add(i);
522 sum += d * d;
523 }
524 sum.sqrt()
525 }
526
527 #[target_feature(enable = "avx512f,fma")]
528 pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
529 let n = a.len().min(b.len());
530 let ap = a.as_ptr();
531 let bp = b.as_ptr();
532 let mut dot_acc = _mm512_setzero_ps();
533 let mut na_acc = _mm512_setzero_ps();
534 let mut nb_acc = _mm512_setzero_ps();
535 let chunks16 = n / 16;
536 for i in 0..chunks16 {
537 let base = i * 16;
538 let av = _mm512_loadu_ps(ap.add(base));
539 let bv = _mm512_loadu_ps(bp.add(base));
540 dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
541 na_acc = _mm512_fmadd_ps(av, av, na_acc);
542 nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
543 }
544 let mut dot = hsum512(dot_acc);
545 let mut na2 = hsum512(na_acc);
546 let mut nb2 = hsum512(nb_acc);
547 for i in (chunks16 * 16)..n {
548 let ai = *ap.add(i);
549 let bi = *bp.add(i);
550 dot += ai * bi;
551 na2 += ai * ai;
552 nb2 += bi * bi;
553 }
554 let (na, nb) = (na2.sqrt(), nb2.sqrt());
555 if na == 0.0 || nb == 0.0 {
556 return 1.0;
557 }
558 1.0 - dot / (na * nb)
559 }
560
561 #[target_feature(enable = "avx512f,f16c,fma")]
563 pub unsafe fn dot_f16(a: &[f32], b: &[f16]) -> f32 {
564 let n = a.len().min(b.len());
565 let ap = a.as_ptr();
566 let bp = b.as_ptr() as *const u16;
567 let mut acc = _mm512_setzero_ps();
568 let chunks16 = n / 16;
569 for i in 0..chunks16 {
570 let base = i * 16;
571 let b_lo = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
572 let b_hi = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
573 let bv = _mm512_insertf32x8(_mm512_castps256_ps512(b_lo), b_hi, 1);
574 acc = _mm512_fmadd_ps(_mm512_loadu_ps(ap.add(base)), bv, acc);
575 }
576 let mut sum = hsum512(acc);
577 for i in (chunks16 * 16)..n {
578 sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
579 }
580 sum
581 }
582
583 #[target_feature(enable = "avx512f,f16c,fma")]
584 pub unsafe fn euclidean_f16(a: &[f32], b: &[f16]) -> f32 {
585 let n = a.len().min(b.len());
586 let ap = a.as_ptr();
587 let bp = b.as_ptr() as *const u16;
588 let mut acc = _mm512_setzero_ps();
589 let chunks16 = n / 16;
590 for i in 0..chunks16 {
591 let base = i * 16;
592 let b_lo = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
593 let b_hi = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
594 let bv = _mm512_insertf32x8(_mm512_castps256_ps512(b_lo), b_hi, 1);
595 let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), bv);
596 acc = _mm512_fmadd_ps(d, d, acc);
597 }
598 let mut sum = hsum512(acc);
599 for i in (chunks16 * 16)..n {
600 let d = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
601 sum += d * d;
602 }
603 sum.sqrt()
604 }
605
606 #[target_feature(enable = "avx512f,f16c,fma")]
607 pub unsafe fn cosine_f16(a: &[f32], b: &[f16]) -> f32 {
608 let n = a.len().min(b.len());
609 let ap = a.as_ptr();
610 let bp = b.as_ptr() as *const u16;
611 let mut dot_acc = _mm512_setzero_ps();
612 let mut na_acc = _mm512_setzero_ps();
613 let mut nb_acc = _mm512_setzero_ps();
614 let chunks16 = n / 16;
615 for i in 0..chunks16 {
616 let base = i * 16;
617 let b_lo = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
618 let b_hi = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
619 let bv = _mm512_insertf32x8(_mm512_castps256_ps512(b_lo), b_hi, 1);
620 let av = _mm512_loadu_ps(ap.add(base));
621 dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
622 na_acc = _mm512_fmadd_ps(av, av, na_acc);
623 nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
624 }
625 let mut dot = hsum512(dot_acc);
626 let mut na2 = hsum512(na_acc);
627 let mut nb2 = hsum512(nb_acc);
628 for i in (chunks16 * 16)..n {
629 let ai = *ap.add(i);
630 let bi = f16::from_bits(*bp.add(i)).to_f32();
631 dot += ai * bi;
632 na2 += ai * ai;
633 nb2 += bi * bi;
634 }
635 let denom = (na2 * nb2).sqrt();
636 if denom < 1e-8 {
637 1.0
638 } else {
639 1.0 - dot / denom
640 }
641 }
642}
643
#[cfg(target_arch = "aarch64")]
/// NEON kernels for `f32` slices, 4 lanes per step.
///
/// Every kernel processes only the first `min(a.len(), b.len())` elements and finishes with
/// a scalar loop over the remainder.
mod neon_impl {
    use std::arch::aarch64::*;

    /// Dot product of two `f32` slices.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());
        let mut lanes = vdupq_n_f32(0.0);
        let mut pos = 0usize;
        while pos + 4 <= len {
            let x = vld1q_f32(pa.add(pos));
            let y = vld1q_f32(pb.add(pos));
            lanes = vmlaq_f32(lanes, x, y);
            pos += 4;
        }
        let mut total = vaddvq_f32(lanes);
        // Scalar tail (fewer than 4 elements).
        for (&x, &y) in a[pos..len].iter().zip(&b[pos..len]) {
            total += x * y;
        }
        total
    }

    /// Euclidean distance between two `f32` slices.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());
        let mut lanes = vdupq_n_f32(0.0);
        let mut pos = 0usize;
        while pos + 4 <= len {
            let diff = vsubq_f32(vld1q_f32(pa.add(pos)), vld1q_f32(pb.add(pos)));
            lanes = vmlaq_f32(lanes, diff, diff);
            pos += 4;
        }
        let mut total = vaddvq_f32(lanes);
        // Scalar tail (fewer than 4 elements).
        for (&x, &y) in a[pos..len].iter().zip(&b[pos..len]) {
            let diff = x - y;
            total += diff * diff;
        }
        total.sqrt()
    }

    /// Cosine distance between two `f32` slices. Returns `1.0` when either norm is zero.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());
        let mut dot_lanes = vdupq_n_f32(0.0);
        let mut a_sq_lanes = vdupq_n_f32(0.0);
        let mut b_sq_lanes = vdupq_n_f32(0.0);
        let mut pos = 0usize;
        while pos + 4 <= len {
            let x = vld1q_f32(pa.add(pos));
            let y = vld1q_f32(pb.add(pos));
            dot_lanes = vmlaq_f32(dot_lanes, x, y);
            a_sq_lanes = vmlaq_f32(a_sq_lanes, x, x);
            b_sq_lanes = vmlaq_f32(b_sq_lanes, y, y);
            pos += 4;
        }
        let mut dot = vaddvq_f32(dot_lanes);
        let mut a_sq = vaddvq_f32(a_sq_lanes);
        let mut b_sq = vaddvq_f32(b_sq_lanes);
        // Scalar tail (fewer than 4 elements).
        for (&x, &y) in a[pos..len].iter().zip(&b[pos..len]) {
            dot += x * y;
            a_sq += x * x;
            b_sq += y * y;
        }
        let norm_a = a_sq.sqrt();
        let norm_b = b_sq.sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            return 1.0;
        }
        1.0 - dot / (norm_a * norm_b)
    }
}
719
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    /// Dimensions that are not multiples of 16, so the narrower SIMD step and the scalar
    /// remainder loops of the kernels are exercised, not just the widest main loop.
    const REMAINDER_DIMS: [usize; 12] = [1, 3, 4, 7, 8, 9, 15, 17, 24, 31, 33, 100];

    /// Draws `dim` values uniformly from `[-1, 1)`.
    fn random_vec(rng: &mut StdRng, dim: usize) -> Vec<f32> {
        (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect()
    }

    /// Checks that the dispatched `f32` functions agree with the scalar references.
    fn assert_f32_matches_scalar(dim: usize, seed: u64) {
        let mut rng = StdRng::seed_from_u64(seed);
        let a = random_vec(&mut rng, dim);
        let b = random_vec(&mut rng, dim);

        let dot_s = dot_scalar(&a, &b);
        let euclid_s = euclidean_scalar(&a, &b);
        let cos_s = cosine_scalar(&a, &b);

        let dot_f = dot_product(&a, &b);
        let euclid_f = euclidean_distance(&a, &b);
        let cos_f = cosine_distance(&a, &b);

        assert!(
            (dot_f - dot_s).abs() < 1e-4,
            "dot mismatch: {dot_f} vs {dot_s} (dim {dim})"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-4,
            "euclidean mismatch: {euclid_f} vs {euclid_s} (dim {dim})"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-4,
            "cosine mismatch: {cos_f} vs {cos_s} (dim {dim})"
        );
    }

    /// Checks that the dispatched `f16` functions agree with the scalar references.
    fn assert_f16_matches_scalar(dim: usize, seed: u64) {
        let mut rng = StdRng::seed_from_u64(seed);
        let a = random_vec(&mut rng, dim);
        let b_f32 = random_vec(&mut rng, dim);
        let b: Vec<f16> = b_f32.iter().map(|&x| f16::from_f32(x)).collect();

        let dot_s = dot_f16_scalar(&a, &b);
        let euclid_s = euclidean_f16_scalar(&a, &b);
        let cos_s = cosine_f16_scalar(&a, &b);

        let dot_f = dot_product_f16(&a, &b);
        let euclid_f = euclidean_distance_f16(&a, &b);
        let cos_f = cosine_distance_f16(&a, &b);

        assert!(
            (dot_f - dot_s).abs() < 1e-3,
            "f16 dot mismatch: {dot_f} vs {dot_s} (dim {dim})"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-3,
            "f16 euclidean mismatch: {euclid_f} vs {euclid_s} (dim {dim})"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-3,
            "f16 cosine mismatch: {cos_f} vs {cos_s} (dim {dim})"
        );
    }

    #[test]
    fn cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!(cosine_distance(&v, &v).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal() {
        assert!((cosine_distance(&[1.0f32, 0.0], &[0.0f32, 1.0]) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn euclidean_basic() {
        assert!((euclidean_distance(&[0.0f32, 0.0], &[3.0f32, 4.0]) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn dot_basic() {
        assert!((dot_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]) - 32.0).abs() < 1e-5);
    }

    #[test]
    fn simd_matches_scalar_dim128() {
        assert_f32_matches_scalar(128, 99);
    }

    #[test]
    fn simd_matches_scalar_remainder_dims() {
        for dim in REMAINDER_DIMS {
            assert_f32_matches_scalar(dim, 1_000 + dim as u64);
        }
    }

    #[test]
    fn f16_simd_matches_scalar() {
        assert_f16_matches_scalar(128, 42);
    }

    #[test]
    fn f16_simd_matches_scalar_remainder_dims() {
        for dim in REMAINDER_DIMS {
            assert_f16_matches_scalar(dim, 2_000 + dim as u64);
        }
    }

    #[test]
    fn centroid_single() {
        let v = vec![vec![1.0f32, 2.0, 3.0]];
        let c = compute_centroid_and_radius(&v, VectorMetric::Cosine);
        assert_eq!(c.values, vec![1.0, 2.0, 3.0]);
        assert!(c.radius < 1e-6, "radius={}", c.radius);
    }

    #[test]
    fn centroid_two_points() {
        let vs = vec![vec![0.0f32, 0.0], vec![2.0f32, 2.0]];
        let c = compute_centroid_and_radius(&vs, VectorMetric::Euclidean);
        assert!((c.values[0] - 1.0).abs() < 1e-6);
        assert!(c.radius > 0.0);
    }
}