1use ailake_core::{Centroid, VectorMetric};
3use half::f16;
4
5pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
8 #[cfg(target_arch = "x86_64")]
9 {
10 #[cfg(feature = "avx512")]
11 if is_x86_feature_detected!("avx512f") {
12 return unsafe { avx512::dot(a, b) };
13 }
14 if is_x86_feature_detected!("avx2") {
15 return unsafe { avx2::dot(a, b) };
16 }
17 }
18 #[cfg(target_arch = "aarch64")]
19 if std::arch::is_aarch64_feature_detected!("neon") {
20 return unsafe { neon_impl::dot(a, b) };
21 }
22 dot_scalar(a, b)
23}
24
25pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
26 #[cfg(target_arch = "x86_64")]
27 {
28 #[cfg(feature = "avx512")]
29 if is_x86_feature_detected!("avx512f") {
30 return unsafe { avx512::euclidean(a, b) };
31 }
32 if is_x86_feature_detected!("avx2") {
33 return unsafe { avx2::euclidean(a, b) };
34 }
35 }
36 #[cfg(target_arch = "aarch64")]
37 if std::arch::is_aarch64_feature_detected!("neon") {
38 return unsafe { neon_impl::euclidean(a, b) };
39 }
40 euclidean_scalar(a, b)
41}
42
43pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
44 #[cfg(target_arch = "x86_64")]
45 {
46 #[cfg(feature = "avx512")]
47 if is_x86_feature_detected!("avx512f") {
48 return unsafe { avx512::cosine(a, b) };
49 }
50 if is_x86_feature_detected!("avx2") {
51 return unsafe { avx2::cosine(a, b) };
52 }
53 }
54 #[cfg(target_arch = "aarch64")]
55 if std::arch::is_aarch64_feature_detected!("neon") {
56 return unsafe { neon_impl::cosine(a, b) };
57 }
58 cosine_scalar(a, b)
59}
60
61pub fn exact_distance(metric: VectorMetric, a: &[f32], b: &[f32]) -> f32 {
62 match metric {
63 VectorMetric::Cosine => cosine_distance(a, b),
64 VectorMetric::Euclidean => euclidean_distance(a, b),
65 VectorMetric::DotProduct => -dot_product(a, b),
66 VectorMetric::NormalizedCosine => normalized_cosine_distance(a, b),
67 }
68}
69
70pub fn cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
80 #[cfg(target_arch = "x86_64")]
81 {
82 #[cfg(feature = "avx512")]
83 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
84 return unsafe { avx512::cosine_f16(a, b) };
85 }
86 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
87 return unsafe { avx2_f16c::cosine(a, b) };
88 }
89 }
90 cosine_f16_scalar(a, b)
91}
92
93pub fn euclidean_distance_f16(a: &[f32], b: &[f16]) -> f32 {
94 #[cfg(target_arch = "x86_64")]
95 {
96 #[cfg(feature = "avx512")]
97 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
98 return unsafe { avx512::euclidean_f16(a, b) };
99 }
100 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
101 return unsafe { avx2_f16c::euclidean(a, b) };
102 }
103 }
104 euclidean_f16_scalar(a, b)
105}
106
107pub fn dot_product_f16(a: &[f32], b: &[f16]) -> f32 {
108 #[cfg(target_arch = "x86_64")]
109 {
110 #[cfg(feature = "avx512")]
111 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
112 return unsafe { avx512::dot_f16(a, b) };
113 }
114 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
115 return unsafe { avx2_f16c::dot(a, b) };
116 }
117 }
118 dot_f16_scalar(a, b)
119}
120
/// Returns a copy of `v` scaled to unit L2 norm.
///
/// If the squared norm is below `1e-12` (zero or numerically negligible), the
/// input is returned unchanged rather than dividing by a near-zero value.
pub fn normalize_l2(v: &[f32]) -> Vec<f32> {
    let squared_norm = v.iter().map(|&x| x * x).sum::<f32>();
    if squared_norm < 1e-12 {
        // Too small to normalise safely; hand back an untouched copy.
        return v.to_vec();
    }
    let scale = squared_norm.sqrt().recip();
    let mut unit = Vec::with_capacity(v.len());
    unit.extend(v.iter().map(|&x| x * scale));
    unit
}
130
131pub fn normalized_cosine_distance(a: &[f32], b: &[f32]) -> f32 {
134 1.0 - dot_product(a, b)
135}
136
137pub fn normalized_cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
138 1.0 - dot_product_f16(a, b)
139}
140
141pub fn compute_centroid_and_radius(vectors: &[Vec<f32>], metric: VectorMetric) -> Centroid {
142 if vectors.is_empty() {
143 return Centroid {
144 values: vec![],
145 radius: 0.0,
146 metric,
147 };
148 }
149 let dim = vectors[0].len();
150 let n = vectors.len() as f32;
151 let centroid: Vec<f32> = (0..dim)
152 .map(|i| vectors.iter().map(|v| v[i]).sum::<f32>() / n)
153 .collect();
154 let radius = vectors
155 .iter()
156 .map(|v| exact_distance(metric, ¢roid, v))
157 .fold(0.0_f32, f32::max);
158 Centroid {
159 values: centroid,
160 radius,
161 metric,
162 }
163}
164
/// Portable dot product over the common prefix of `a` and `b`.
#[inline(always)]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b).map(|(&x, &y)| x * y).sum()
}
171
/// Portable Euclidean distance over the common prefix of `a` and `b`.
#[inline(always)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = std::iter::zip(a, b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    squared.sqrt()
}
180
/// Portable cosine distance `1 - a·b / (‖a‖·‖b‖)` over the common prefix of `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` elements contribute, to the dot product
/// *and* to both norms. This matches the AVX2 / AVX-512 / NEON kernels, so the
/// result does not depend on which path the dispatcher picks. Returns `1.0` when
/// either norm is zero.
#[inline(always)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    // One zipped pass accumulates dot, ‖a‖² and ‖b‖² over the same elements.
    let (dot, na2, nb2) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, sa, sb), (&x, &y)| {
            (d + x * y, sa + x * x, sb + y * y)
        });
    let (na, nb) = (na2.sqrt(), nb2.sqrt());
    if na == 0.0 || nb == 0.0 {
        return 1.0;
    }
    1.0 - dot / (na * nb)
}
191
192#[inline(always)]
193fn cosine_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
194 let n = a.len().min(b.len());
195 let mut dot = 0.0f32;
196 let mut norm_a = 0.0f32;
197 let mut norm_b = 0.0f32;
198 for i in 0..n {
199 let ai = a[i];
200 let bi = b[i].to_f32();
201 dot += ai * bi;
202 norm_a += ai * ai;
203 norm_b += bi * bi;
204 }
205 let denom = (norm_a * norm_b).sqrt();
206 if denom < 1e-8 {
207 1.0
208 } else {
209 1.0 - dot / denom
210 }
211}
212
213#[inline(always)]
214fn euclidean_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
215 let n = a.len().min(b.len());
216 let mut sum = 0.0f32;
217 for i in 0..n {
218 let diff = a[i] - b[i].to_f32();
219 sum += diff * diff;
220 }
221 sum.sqrt()
222}
223
224#[inline(always)]
225fn dot_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
226 let n = a.len().min(b.len());
227 let mut acc = 0.0f32;
228 for i in 0..n {
229 acc += a[i] * b[i].to_f32();
230 }
231 acc
232}
233
/// AVX2 + FMA kernels over `f32` slices (x86_64 only).
///
/// Each kernel processes `min(a.len(), b.len())` elements through raw pointers
/// with unaligned loads (`_mm256_loadu_ps`), so no alignment is required.
///
/// # Safety
///
/// The `pub unsafe fn` kernels are compiled with
/// `#[target_feature(enable = "avx2,fma")]`. Calling one on a CPU lacking either
/// feature is undefined behaviour, so callers must detect both at runtime first.
#[cfg(target_arch = "x86_64")]
mod avx2 {
    use std::arch::x86_64::*;

    /// Horizontal sum of the eight `f32` lanes of `v`.
    ///
    /// Has no `#[target_feature]` of its own and is `#[inline(always)]`. It uses
    /// AVX/SSE3 intrinsics and is meant to be inlined into AVX-enabled callers
    /// (the kernels here and in `avx2_f16c`), so only call it where AVX is known
    /// to be available.
    #[inline(always)]
    pub unsafe fn hsum256(v: __m256) -> f32 {
        // Fold the upper 128-bit half onto the lower half: four partial sums s0..s3.
        let hi = _mm256_extractf128_ps(v, 1);
        let lo = _mm256_castps256_ps128(v);
        let s = _mm_add_ps(lo, hi);
        // movehdup gives [s1, s1, s3, s3]; adding leaves s0+s1 in lane 0, s2+s3 in lane 2.
        let shuf = _mm_movehdup_ps(s);
        let sums = _mm_add_ps(s, shuf);
        // movehl brings lane 2 down to lane 0; the scalar add yields (s0+s1)+(s2+s3).
        let shuf = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf))
    }

    /// Dot product of the first `min(a.len(), b.len())` elements.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        // Two independent accumulators, so back-to-back FMAs don't chain on one register.
        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        // Main loop: 16 elements (two 8-lane vectors) per iteration.
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            let a1 = _mm256_loadu_ps(ap.add(base + 8));
            let b1 = _mm256_loadu_ps(bp.add(base + 8));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
        }

        // At most one further full 8-lane block can remain after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
        }

        // Reduce both accumulators, then handle the final `n % 8` elements in scalar.
        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// Euclidean distance over the first `min(a.len(), b.len())` elements.
    ///
    /// Same 16 / 8 / scalar-tail structure as [`dot`], accumulating squared differences.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), _mm256_loadu_ps(bp.add(base)));
            let d1 = _mm256_sub_ps(
                _mm256_loadu_ps(ap.add(base + 8)),
                _mm256_loadu_ps(bp.add(base + 8)),
            );
            // d*d + acc: squared difference accumulated with a single FMA.
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
        }

        // Optional single 8-lane block left after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), _mm256_loadu_ps(bp.add(base)));
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
        }

        // Scalar tail for the last `n % 8` elements, then take the square root.
        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance `1 - a·b / (‖a‖·‖b‖)` over the first `min(a.len(), b.len())`
    /// elements; returns `1.0` if either norm is zero.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        // Dot product and both squared norms are accumulated in one pass.
        let mut dot_acc = _mm256_setzero_ps();
        let mut na_acc = _mm256_setzero_ps();
        let mut nb_acc = _mm256_setzero_ps();

        let chunks8 = n / 8;
        for i in 0..chunks8 {
            let base = i * 8;
            let av = _mm256_loadu_ps(ap.add(base));
            let bv = _mm256_loadu_ps(bp.add(base));
            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm256_fmadd_ps(av, av, na_acc);
            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
        }

        let mut dot = hsum256(dot_acc);
        let mut na2 = hsum256(na_acc);
        let mut nb2 = hsum256(nb_acc);

        // Scalar tail for the last `n % 8` elements.
        for i in (chunks8 * 8)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }

        let na = na2.sqrt();
        let nb = nb2.sqrt();
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }
}
369
/// AVX2 + F16C + FMA kernels pairing an `f32` vector with an `f16` vector (x86_64 only).
///
/// The `f16` slice is reinterpreted as raw 16-bit patterns (`*const u16`) and
/// widened eight at a time with `_mm256_cvtph_ps`. NOTE(review): this relies on
/// `half::f16` having the same layout as `u16` (the `half` crate documents it as
/// `#[repr(transparent)]`); confirm for the version in use.
///
/// # Safety
///
/// Every `pub unsafe fn` is compiled with `#[target_feature(enable = "avx2,f16c,fma")]`;
/// callers must detect all three features at runtime before calling.
#[cfg(target_arch = "x86_64")]
mod avx2_f16c {
    use half::f16;
    use std::arch::x86_64::*;

    use super::avx2::hsum256;

    /// Dot product of `a` with the widened `b` over `min(a.len(), b.len())` elements.
    ///
    /// # Safety
    /// The CPU must support AVX2, F16C and FMA.
    #[target_feature(enable = "avx2,f16c,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;

        // Two independent accumulators, so back-to-back FMAs don't chain on one register.
        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            // Each 128-bit unaligned load holds 8 halves; `_mm256_cvtph_ps` widens them to f32.
            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
            let a0 = _mm256_loadu_ps(ap.add(base));
            let a1 = _mm256_loadu_ps(ap.add(base + 8));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
        }

        // At most one further full 8-element block remains after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
            let a0 = _mm256_loadu_ps(ap.add(base));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
        }

        // Scalar tail: rebuild each half from its bits and widen with `to_f32`.
        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
        }
        sum
    }

    /// Euclidean distance between `a` and the widened `b` over `min(a.len(), b.len())`
    /// elements.
    ///
    /// # Safety
    /// The CPU must support AVX2, F16C and FMA.
    #[target_feature(enable = "avx2,f16c,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;

        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
            let d1 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base + 8)), b1);
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
        }

        // Optional single 8-element block left after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
        }

        // Scalar tail for the last `n % 8` elements, then take the square root.
        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            let diff = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
            sum += diff * diff;
        }
        sum.sqrt()
    }

    /// Cosine distance between `a` and the widened `b` over `min(a.len(), b.len())`
    /// elements; returns `1.0` when `sqrt(‖a‖² · ‖b‖²)` is below `1e-8`.
    ///
    /// # Safety
    /// The CPU must support AVX2, F16C and FMA.
    #[target_feature(enable = "avx2,f16c,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;

        // Dot product and both squared norms are accumulated in one pass.
        let mut dot_acc = _mm256_setzero_ps();
        let mut na_acc = _mm256_setzero_ps();
        let mut nb_acc = _mm256_setzero_ps();

        let chunks8 = n / 8;
        for i in 0..chunks8 {
            let base = i * 8;
            let av = _mm256_loadu_ps(ap.add(base));
            let bv = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm256_fmadd_ps(av, av, na_acc);
            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
        }

        let mut dot = hsum256(dot_acc);
        let mut na2 = hsum256(na_acc);
        let mut nb2 = hsum256(nb_acc);

        // Scalar tail for the last `n % 8` elements.
        for i in (chunks8 * 8)..n {
            let ai = *ap.add(i);
            let bi = f16::from_bits(*bp.add(i)).to_f32();
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }

        // Single sqrt of the product of squared norms; small values are treated as zero-norm.
        let denom = (na2 * nb2).sqrt();
        if denom < 1e-8 {
            1.0
        } else {
            1.0 - dot / denom
        }
    }
}
497
/// AVX-512F kernels (x86_64, only compiled with the `avx512` cargo feature).
///
/// Every kernel needs only AVX-512F. `_mm512_fmadd_ps` is AVX-512F, and
/// half→single widening uses the 512-bit `_mm512_cvtph_ps` (also AVX-512F)
/// through [`load_f16x16`]. Building the vector from two 256-bit F16C conversions
/// with `_mm512_insertf32x8` would additionally require AVX-512DQ.
///
/// The `f16` slices are read as raw 16-bit patterns (`*const u16`).
/// NOTE(review): this relies on `half::f16` having `u16` layout (documented as
/// `#[repr(transparent)]` in the `half` crate); confirm for the version in use.
///
/// # Safety
///
/// Every `pub unsafe fn` is compiled with `#[target_feature(enable = "avx512f")]`;
/// callers must detect AVX-512F at runtime before calling.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
mod avx512 {
    use half::f16;
    use std::arch::x86_64::*;

    /// Horizontal sum of the sixteen `f32` lanes of `v`.
    ///
    /// Spills `v` to a stack buffer, then folds 256 → 128 → 32 bits using
    /// AVX/SSE3 operations. It has no `#[target_feature]` of its own and is
    /// `#[inline(always)]`, so it is meant to be inlined into the AVX-512F
    /// kernels below.
    #[inline(always)]
    unsafe fn hsum512(v: __m512) -> f32 {
        let mut buf = [0.0f32; 16];
        _mm512_storeu_ps(buf.as_mut_ptr(), v);
        let lo = _mm256_loadu_ps(buf.as_ptr());
        let hi = _mm256_loadu_ps(buf.as_ptr().add(8));
        let sum256 = _mm256_add_ps(lo, hi);
        let hi128 = _mm256_extractf128_ps(sum256, 1);
        let lo128 = _mm256_castps256_ps128(sum256);
        let sum128 = _mm_add_ps(lo128, hi128);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf2))
    }

    /// Loads 16 consecutive half-precision bit patterns from `p` and widens them to `f32`.
    ///
    /// # Safety
    /// `p` must be readable for 16 `u16` values (32 bytes; no alignment needed) and
    /// the CPU must support AVX-512F.
    #[inline]
    #[target_feature(enable = "avx512f")]
    unsafe fn load_f16x16(p: *const u16) -> __m512 {
        // One AVX-512F conversion of a 256-bit unaligned load; no DQ instructions needed.
        _mm512_cvtph_ps(_mm256_loadu_si256(p as *const __m256i))
    }

    /// Dot product of the first `min(a.len(), b.len())` elements.
    ///
    /// # Safety
    /// The CPU must support AVX-512F.
    #[target_feature(enable = "avx512f")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            acc = _mm512_fmadd_ps(
                _mm512_loadu_ps(ap.add(base)),
                _mm512_loadu_ps(bp.add(base)),
                acc,
            );
        }
        // Scalar tail for the last `n % 16` elements.
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// Euclidean distance over the first `min(a.len(), b.len())` elements.
    ///
    /// # Safety
    /// The CPU must support AVX-512F.
    #[target_feature(enable = "avx512f")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), _mm512_loadu_ps(bp.add(base)));
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance over the first `min(a.len(), b.len())` elements; returns
    /// `1.0` if either norm is zero.
    ///
    /// # Safety
    /// The CPU must support AVX-512F.
    #[target_feature(enable = "avx512f")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let av = _mm512_loadu_ps(ap.add(base));
            let bv = _mm512_loadu_ps(bp.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }

    /// Dot product of `a` with the widened `b` over `min(a.len(), b.len())` elements.
    ///
    /// # Safety
    /// The CPU must support AVX-512F.
    #[target_feature(enable = "avx512f")]
    pub unsafe fn dot_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            acc = _mm512_fmadd_ps(_mm512_loadu_ps(ap.add(base)), bv, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
        }
        sum
    }

    /// Euclidean distance between `a` and the widened `b` over `min(a.len(), b.len())`
    /// elements.
    ///
    /// # Safety
    /// The CPU must support AVX-512F.
    #[target_feature(enable = "avx512f")]
    pub unsafe fn euclidean_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), bv);
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance between `a` and the widened `b` over `min(a.len(), b.len())`
    /// elements; returns `1.0` when `sqrt(‖a‖² · ‖b‖²)` is below `1e-8`.
    ///
    /// # Safety
    /// The CPU must support AVX-512F.
    #[target_feature(enable = "avx512f")]
    pub unsafe fn cosine_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            let av = _mm512_loadu_ps(ap.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = f16::from_bits(*bp.add(i)).to_f32();
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let denom = (na2 * nb2).sqrt();
        if denom < 1e-8 {
            1.0
        } else {
            1.0 - dot / denom
        }
    }
}
686
/// NEON kernels over `f32` slices (aarch64 only).
///
/// Each kernel processes `min(a.len(), b.len())` elements: four lanes at a time
/// with `vld1q_f32`, then a scalar tail for the last `n % 4` elements.
/// `vmlaq_f32(acc, x, y)` accumulates `acc + x * y` lane-wise, and `vaddvq_f32`
/// sums the four lanes of a vector.
///
/// # Safety
///
/// Every `pub unsafe fn` is compiled with `#[target_feature(enable = "neon")]`;
/// callers must confirm NEON support before calling.
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    use std::arch::aarch64::*;

    /// Dot product of the first `min(a.len(), b.len())` elements.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let av = vld1q_f32(a.as_ptr().add(base));
            let bv = vld1q_f32(b.as_ptr().add(base));
            acc = vmlaq_f32(acc, av, bv);
        }
        // Horizontal reduction, then the scalar tail.
        let mut sum = vaddvq_f32(acc);
        for i in (chunks * 4)..n {
            sum += a[i] * b[i];
        }
        sum
    }

    /// Euclidean distance over the first `min(a.len(), b.len())` elements.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let d = vsubq_f32(
                vld1q_f32(a.as_ptr().add(base)),
                vld1q_f32(b.as_ptr().add(base)),
            );
            // Accumulate the squared difference d*d.
            acc = vmlaq_f32(acc, d, d);
        }
        let mut sum = vaddvq_f32(acc);
        for i in (chunks * 4)..n {
            let d = a[i] - b[i];
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance over the first `min(a.len(), b.len())` elements; returns
    /// `1.0` if either norm is zero.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        // Dot product and both squared norms are accumulated in one pass.
        let mut dot_acc = vdupq_n_f32(0.0);
        let mut na_acc = vdupq_n_f32(0.0);
        let mut nb_acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let av = vld1q_f32(a.as_ptr().add(base));
            let bv = vld1q_f32(b.as_ptr().add(base));
            dot_acc = vmlaq_f32(dot_acc, av, bv);
            na_acc = vmlaq_f32(na_acc, av, av);
            nb_acc = vmlaq_f32(nb_acc, bv, bv);
        }
        let mut dot = vaddvq_f32(dot_acc);
        let mut na2 = vaddvq_f32(na_acc);
        let mut nb2 = vaddvq_f32(nb_acc);
        for i in (chunks * 4)..n {
            dot += a[i] * b[i];
            na2 += a[i] * a[i];
            nb2 += b[i] * b[i];
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }
}
762
#[cfg(test)]
mod tests {
    use super::*;

    /// Dimensions chosen to exercise every SIMD tail path: shorter than one
    /// register, exactly one 8-wide block, 8-wide block plus scalar tail, a
    /// 16-wide block alone, 16-wide plus 8-wide plus scalar tail, and so on.
    /// (The dim-128 tests below are multiples of 16 and never reach these paths.)
    const ODD_DIMS: [usize; 12] = [1, 3, 7, 8, 9, 15, 16, 17, 24, 31, 37, 45];

    #[test]
    fn cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!(cosine_distance(&v, &v).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal() {
        assert!((cosine_distance(&[1.0f32, 0.0], &[0.0f32, 1.0]) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn euclidean_basic() {
        assert!((euclidean_distance(&[0.0f32, 0.0], &[3.0f32, 4.0]) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn dot_basic() {
        assert!((dot_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]) - 32.0).abs() < 1e-5);
    }

    #[test]
    fn exact_distance_negates_dot_product() {
        let d = exact_distance(
            VectorMetric::DotProduct,
            &[1.0f32, 2.0, 3.0],
            &[4.0f32, 5.0, 6.0],
        );
        assert!((d + 32.0).abs() < 1e-5, "d={d}");
    }

    #[test]
    fn simd_matches_scalar_dim128() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(99);
        let a: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();

        let dot_s = dot_scalar(&a, &b);
        let euclid_s = euclidean_scalar(&a, &b);
        let cos_s = cosine_scalar(&a, &b);

        let dot_f = dot_product(&a, &b);
        let euclid_f = euclidean_distance(&a, &b);
        let cos_f = cosine_distance(&a, &b);

        assert!(
            (dot_f - dot_s).abs() < 1e-4,
            "dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-4,
            "euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-4,
            "cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    #[test]
    fn simd_matches_scalar_odd_dims() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(7);
        for &dim in &ODD_DIMS {
            let a: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
            let b: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();

            let (dot_s, dot_f) = (dot_scalar(&a, &b), dot_product(&a, &b));
            let (euc_s, euc_f) = (euclidean_scalar(&a, &b), euclidean_distance(&a, &b));
            let (cos_s, cos_f) = (cosine_scalar(&a, &b), cosine_distance(&a, &b));

            assert!((dot_f - dot_s).abs() < 1e-4, "dim {dim}: dot {dot_f} vs {dot_s}");
            assert!((euc_f - euc_s).abs() < 1e-4, "dim {dim}: euclidean {euc_f} vs {euc_s}");
            assert!((cos_f - cos_s).abs() < 1e-4, "dim {dim}: cosine {cos_f} vs {cos_s}");
        }
    }

    #[test]
    fn f16_simd_matches_scalar() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(42);
        let a: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b_f32: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b: Vec<f16> = b_f32.iter().map(|&x| f16::from_f32(x)).collect();

        let dot_s = dot_f16_scalar(&a, &b);
        let euclid_s = euclidean_f16_scalar(&a, &b);
        let cos_s = cosine_f16_scalar(&a, &b);

        let dot_f = dot_product_f16(&a, &b);
        let euclid_f = euclidean_distance_f16(&a, &b);
        let cos_f = cosine_distance_f16(&a, &b);

        assert!(
            (dot_f - dot_s).abs() < 1e-3,
            "f16 dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-3,
            "f16 euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-3,
            "f16 cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    #[test]
    fn f16_simd_matches_scalar_odd_dims() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(11);
        for &dim in &ODD_DIMS {
            let a: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
            let b: Vec<f16> = (0..dim)
                .map(|_| f16::from_f32(rng.gen::<f32>() * 2.0 - 1.0))
                .collect();

            let (dot_s, dot_f) = (dot_f16_scalar(&a, &b), dot_product_f16(&a, &b));
            let (euc_s, euc_f) = (euclidean_f16_scalar(&a, &b), euclidean_distance_f16(&a, &b));
            let (cos_s, cos_f) = (cosine_f16_scalar(&a, &b), cosine_distance_f16(&a, &b));

            assert!((dot_f - dot_s).abs() < 1e-3, "dim {dim}: f16 dot {dot_f} vs {dot_s}");
            assert!((euc_f - euc_s).abs() < 1e-3, "dim {dim}: f16 euclidean {euc_f} vs {euc_s}");
            assert!((cos_f - cos_s).abs() < 1e-3, "dim {dim}: f16 cosine {cos_f} vs {cos_s}");
        }
    }

    #[test]
    fn normalize_l2_unit() {
        let v = vec![3.0f32, 4.0];
        let n = normalize_l2(&v);
        let norm: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6, "norm={norm}");
        assert!((n[0] - 0.6).abs() < 1e-6);
        assert!((n[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn normalized_cosine_matches_cosine_on_unit_vecs() {
        let a = normalize_l2(&[1.0f32, 1.0, 0.0]);
        let b = normalize_l2(&[1.0f32, 0.0, 1.0]);
        let cos = cosine_distance(&a, &b);
        let ncos = normalized_cosine_distance(&a, &b);
        assert!((cos - ncos).abs() < 1e-5, "cos={cos} ncos={ncos}");
    }

    #[test]
    fn centroid_single() {
        let v = vec![vec![1.0f32, 2.0, 3.0]];
        let c = compute_centroid_and_radius(&v, VectorMetric::Cosine);
        assert_eq!(c.values, vec![1.0, 2.0, 3.0]);
        assert!(c.radius < 1e-6, "radius={}", c.radius);
    }

    #[test]
    fn centroid_two_points() {
        let vs = vec![vec![0.0f32, 0.0], vec![2.0f32, 2.0]];
        let c = compute_centroid_and_radius(&vs, VectorMetric::Euclidean);
        assert!((c.values[0] - 1.0).abs() < 1e-6);
        assert!(c.radius > 0.0);
    }
}