Skip to main content

ailake_vec/
distance.rs

1use ailake_core::{Centroid, VectorMetric};
2use half::f16;
3
4// ── Public API ────────────────────────────────────────────────────────────────
5
6pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
7    #[cfg(target_arch = "x86_64")]
8    {
9        #[cfg(feature = "avx512")]
10        if is_x86_feature_detected!("avx512f") {
11            return unsafe { avx512::dot(a, b) };
12        }
13        if is_x86_feature_detected!("avx2") {
14            return unsafe { avx2::dot(a, b) };
15        }
16    }
17    #[cfg(target_arch = "aarch64")]
18    if std::arch::is_aarch64_feature_detected!("neon") {
19        return unsafe { neon_impl::dot(a, b) };
20    }
21    dot_scalar(a, b)
22}
23
24pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
25    #[cfg(target_arch = "x86_64")]
26    {
27        #[cfg(feature = "avx512")]
28        if is_x86_feature_detected!("avx512f") {
29            return unsafe { avx512::euclidean(a, b) };
30        }
31        if is_x86_feature_detected!("avx2") {
32            return unsafe { avx2::euclidean(a, b) };
33        }
34    }
35    #[cfg(target_arch = "aarch64")]
36    if std::arch::is_aarch64_feature_detected!("neon") {
37        return unsafe { neon_impl::euclidean(a, b) };
38    }
39    euclidean_scalar(a, b)
40}
41
42pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
43    #[cfg(target_arch = "x86_64")]
44    {
45        #[cfg(feature = "avx512")]
46        if is_x86_feature_detected!("avx512f") {
47            return unsafe { avx512::cosine(a, b) };
48        }
49        if is_x86_feature_detected!("avx2") {
50            return unsafe { avx2::cosine(a, b) };
51        }
52    }
53    #[cfg(target_arch = "aarch64")]
54    if std::arch::is_aarch64_feature_detected!("neon") {
55        return unsafe { neon_impl::cosine(a, b) };
56    }
57    cosine_scalar(a, b)
58}
59
60pub fn exact_distance(metric: VectorMetric, a: &[f32], b: &[f32]) -> f32 {
61    match metric {
62        VectorMetric::Cosine => cosine_distance(a, b),
63        VectorMetric::Euclidean => euclidean_distance(a, b),
64        VectorMetric::DotProduct => -dot_product(a, b),
65    }
66}
67
68// ── F16 distance functions ────────────────────────────────────────────────────
69//
70// Query `a` stays F32 (one vector, lives in registers).
71// Database vector `b` is F16 (dequantized inline — no allocation).
72//
73// Fast path: F16C converts 8 F16 values to F32 in one instruction via
74// _mm256_cvtph_ps, then FMA accumulates. Eliminates scalar half::to_f32()
75// loop that dominates HNSW graph traversal on dim=128 vectors.
76
77pub fn cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
78    #[cfg(target_arch = "x86_64")]
79    {
80        #[cfg(feature = "avx512")]
81        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
82            return unsafe { avx512::cosine_f16(a, b) };
83        }
84        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
85            return unsafe { avx2_f16c::cosine(a, b) };
86        }
87    }
88    cosine_f16_scalar(a, b)
89}
90
91pub fn euclidean_distance_f16(a: &[f32], b: &[f16]) -> f32 {
92    #[cfg(target_arch = "x86_64")]
93    {
94        #[cfg(feature = "avx512")]
95        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
96            return unsafe { avx512::euclidean_f16(a, b) };
97        }
98        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
99            return unsafe { avx2_f16c::euclidean(a, b) };
100        }
101    }
102    euclidean_f16_scalar(a, b)
103}
104
105pub fn dot_product_f16(a: &[f32], b: &[f16]) -> f32 {
106    #[cfg(target_arch = "x86_64")]
107    {
108        #[cfg(feature = "avx512")]
109        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
110            return unsafe { avx512::dot_f16(a, b) };
111        }
112        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
113            return unsafe { avx2_f16c::dot(a, b) };
114        }
115    }
116    dot_f16_scalar(a, b)
117}
118
119pub fn compute_centroid_and_radius(vectors: &[Vec<f32>], metric: VectorMetric) -> Centroid {
120    if vectors.is_empty() {
121        return Centroid {
122            values: vec![],
123            radius: 0.0,
124            metric,
125        };
126    }
127    let dim = vectors[0].len();
128    let n = vectors.len() as f32;
129    let centroid: Vec<f32> = (0..dim)
130        .map(|i| vectors.iter().map(|v| v[i]).sum::<f32>() / n)
131        .collect();
132    let radius = vectors
133        .iter()
134        .map(|v| exact_distance(metric, &centroid, v))
135        .fold(0.0_f32, f32::max);
136    Centroid {
137        values: centroid,
138        radius,
139        metric,
140    }
141}
142
143// ── Scalar fallbacks ──────────────────────────────────────────────────────────
144
/// Portable dot product over the common prefix of `a` and `b`.
#[inline(always)]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let pairs = a.iter().zip(b);
    pairs.map(|(&x, &y)| x * y).sum::<f32>()
}
149
/// Portable Euclidean (L2) distance over the common prefix of `a` and `b`.
#[inline(always)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let gap = x - y;
            gap * gap
        })
        .sum();
    squared.sqrt()
}
158
/// Portable cosine distance `1 − cos(a, b)` over the first `min(a.len(), b.len())`
/// elements.
///
/// The dot product and both norms are taken over the same common prefix, in a
/// single pass. This matches the SIMD kernels, so the result no longer depends on
/// which path the dispatcher picks when the lengths differ. Returns `1.0` when
/// either prefix has zero norm.
#[inline(always)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    let (dot, na2, nb2) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na2, nb2), (&x, &y)| (dot + x * y, na2 + x * x, nb2 + y * y),
    );
    let na = na2.sqrt();
    let nb = nb2.sqrt();
    if na == 0.0 || nb == 0.0 {
        return 1.0;
    }
    1.0 - dot / (na * nb)
}
169
170#[inline(always)]
171fn cosine_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
172    let n = a.len().min(b.len());
173    let mut dot = 0.0f32;
174    let mut norm_a = 0.0f32;
175    let mut norm_b = 0.0f32;
176    for i in 0..n {
177        let ai = a[i];
178        let bi = b[i].to_f32();
179        dot += ai * bi;
180        norm_a += ai * ai;
181        norm_b += bi * bi;
182    }
183    let denom = (norm_a * norm_b).sqrt();
184    if denom < 1e-8 {
185        1.0
186    } else {
187        1.0 - dot / denom
188    }
189}
190
191#[inline(always)]
192fn euclidean_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
193    let n = a.len().min(b.len());
194    let mut sum = 0.0f32;
195    for i in 0..n {
196        let diff = a[i] - b[i].to_f32();
197        sum += diff * diff;
198    }
199    sum.sqrt()
200}
201
202#[inline(always)]
203fn dot_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
204    let n = a.len().min(b.len());
205    let mut acc = 0.0f32;
206    for i in 0..n {
207        acc += a[i] * b[i].to_f32();
208    }
209    acc
210}
211
212// ── x86_64 AVX2 + FMA ────────────────────────────────────────────────────────
213//
214// Compiled with target_feature = "avx2,fma". The compiler emits vfmadd231ps
215// instead of separate vmulps + vaddps, cutting inner-loop instruction count
216// by ~33% and reducing latency via fused operations.
217
#[cfg(target_arch = "x86_64")]
mod avx2 {
    use std::arch::x86_64::*;

    /// Horizontal sum of the eight `f32` lanes of `v`.
    ///
    /// Folds the high 128-bit half onto the low half, then reduces the remaining
    /// four lanes with two shuffle+add steps.
    ///
    /// # Safety
    /// Must only run on a CPU with AVX and SSE3. Every caller in this file is
    /// an AVX2 `target_feature` function.
    #[inline(always)]
    pub unsafe fn hsum256(v: __m256) -> f32 {
        let hi = _mm256_extractf128_ps(v, 1);
        let lo = _mm256_castps256_ps128(v);
        let s = _mm_add_ps(lo, hi);
        let shuf = _mm_movehdup_ps(s);
        let sums = _mm_add_ps(s, shuf);
        let shuf = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf))
    }

    /// `Σ aᵢ·bᵢ` over the first `min(a.len(), b.len())` elements.
    ///
    /// Two independent FMA accumulators consume 16 `f32` per iteration. One
    /// optional 8-wide block and a scalar tail handle the remainder.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            let a1 = _mm256_loadu_ps(ap.add(base + 8));
            let b1 = _mm256_loadu_ps(bp.add(base + 8));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
        }

        // At most one full 8-lane block remains after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
        }

        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// Euclidean (L2) distance `‖a − b‖₂` over the first `min(a.len(), b.len())`
    /// elements, i.e. the square root of the summed squared differences.
    ///
    /// Uses the same 2× unrolled FMA structure as [`dot`].
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), _mm256_loadu_ps(bp.add(base)));
            let d1 = _mm256_sub_ps(
                _mm256_loadu_ps(ap.add(base + 8)),
                _mm256_loadu_ps(bp.add(base + 8)),
            );
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
        }

        // At most one full 8-lane block remains after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), _mm256_loadu_ps(bp.add(base)));
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
        }

        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance `1 − cos(a, b)` over the first `min(a.len(), b.len())`
    /// elements. The dot product and both squared norms are computed in one pass.
    ///
    /// Returns `1.0` when either norm is exactly zero.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        let mut dot_acc = _mm256_setzero_ps();
        let mut na_acc = _mm256_setzero_ps();
        let mut nb_acc = _mm256_setzero_ps();

        let chunks8 = n / 8;
        for i in 0..chunks8 {
            let base = i * 8;
            let av = _mm256_loadu_ps(ap.add(base));
            let bv = _mm256_loadu_ps(bp.add(base));
            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm256_fmadd_ps(av, av, na_acc);
            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
        }

        let mut dot = hsum256(dot_acc);
        let mut na2 = hsum256(na_acc);
        let mut nb2 = hsum256(nb_acc);

        for i in (chunks8 * 8)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }

        let na = na2.sqrt();
        let nb = nb2.sqrt();
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }
}
347
348// ── x86_64 AVX2 + F16C — F16 hot path ────────────────────────────────────────
349//
// _mm256_cvtph_ps converts 8 packed F16 (as __m128i) to 8 F32 in a single instruction.
351// Combined with FMA, this replaces 8 scalar half::to_f32() calls per iteration.
352// Critical hot path: every HNSW edge traversal calls one of these functions.
353
354#[cfg(target_arch = "x86_64")]
355mod avx2_f16c {
356    use half::f16;
357    use std::arch::x86_64::*;
358
359    use super::avx2::hsum256;
360
361    /// dot(a_f32, b_f16) — AVX2+F16C+FMA, 16 F16/iter.
362    #[target_feature(enable = "avx2,f16c,fma")]
363    pub unsafe fn dot(a: &[f32], b: &[f16]) -> f32 {
364        let n = a.len().min(b.len());
365        let ap = a.as_ptr();
366        let bp = b.as_ptr() as *const u16;
367
368        let mut acc0 = _mm256_setzero_ps();
369        let mut acc1 = _mm256_setzero_ps();
370
371        let chunks16 = n / 16;
372        for i in 0..chunks16 {
373            let base = i * 16;
374            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
375            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
376            let a0 = _mm256_loadu_ps(ap.add(base));
377            let a1 = _mm256_loadu_ps(ap.add(base + 8));
378            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
379            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
380        }
381
382        let chunks8 = n / 8;
383        if chunks8 > chunks16 * 2 {
384            let base = chunks16 * 16;
385            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
386            let a0 = _mm256_loadu_ps(ap.add(base));
387            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
388        }
389
390        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
391        for i in (chunks8 * 8)..n {
392            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
393        }
394        sum
395    }
396
397    /// ||a_f32 - b_f16||² — AVX2+F16C+FMA, 16 F16/iter.
398    #[target_feature(enable = "avx2,f16c,fma")]
399    pub unsafe fn euclidean(a: &[f32], b: &[f16]) -> f32 {
400        let n = a.len().min(b.len());
401        let ap = a.as_ptr();
402        let bp = b.as_ptr() as *const u16;
403
404        let mut acc0 = _mm256_setzero_ps();
405        let mut acc1 = _mm256_setzero_ps();
406
407        let chunks16 = n / 16;
408        for i in 0..chunks16 {
409            let base = i * 16;
410            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
411            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
412            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
413            let d1 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base + 8)), b1);
414            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
415            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
416        }
417
418        let chunks8 = n / 8;
419        if chunks8 > chunks16 * 2 {
420            let base = chunks16 * 16;
421            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
422            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
423            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
424        }
425
426        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
427        for i in (chunks8 * 8)..n {
428            let diff = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
429            sum += diff * diff;
430        }
431        sum.sqrt()
432    }
433
434    /// 1 - cos(a_f32, b_f16) — AVX2+F16C+FMA, single pass.
435    #[target_feature(enable = "avx2,f16c,fma")]
436    pub unsafe fn cosine(a: &[f32], b: &[f16]) -> f32 {
437        let n = a.len().min(b.len());
438        let ap = a.as_ptr();
439        let bp = b.as_ptr() as *const u16;
440
441        let mut dot_acc = _mm256_setzero_ps();
442        let mut na_acc = _mm256_setzero_ps();
443        let mut nb_acc = _mm256_setzero_ps();
444
445        let chunks8 = n / 8;
446        for i in 0..chunks8 {
447            let base = i * 8;
448            let av = _mm256_loadu_ps(ap.add(base));
449            let bv = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
450            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
451            na_acc = _mm256_fmadd_ps(av, av, na_acc);
452            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
453        }
454
455        let mut dot = hsum256(dot_acc);
456        let mut na2 = hsum256(na_acc);
457        let mut nb2 = hsum256(nb_acc);
458
459        for i in (chunks8 * 8)..n {
460            let ai = *ap.add(i);
461            let bi = f16::from_bits(*bp.add(i)).to_f32();
462            dot += ai * bi;
463            na2 += ai * ai;
464            nb2 += bi * bi;
465        }
466
467        let denom = (na2 * nb2).sqrt();
468        if denom < 1e-8 {
469            1.0
470        } else {
471            1.0 - dot / denom
472        }
473    }
474}
475
476// ── x86_64 AVX-512F — forward compatibility ───────────────────────────────────
477//
// 16 f32 per vector register (vs 8 for AVX2). Runtime-detected: used only on CPUs
// that report avx512f (e.g. Xeon Scalable, AMD Zen 4+); most consumer Intel Core
// 12th gen+ parts ship with AVX-512 disabled and take the AVX2 path instead.
480// Requires Rust ≥ 1.89 (AVX-512 intrinsics stabilised there). Gated behind
481// the `avx512` feature so the default/manylinux build always succeeds.
482
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
mod avx512 {
    use half::f16;
    use std::arch::x86_64::*;

    /// Horizontal sum of the sixteen `f32` lanes of `v`.
    ///
    /// Splitting a `__m512` in-register would need `_mm512_extractf32x8_ps`
    /// (AVX-512DQ). Instead the lanes are stored to the stack (AVX-512F),
    /// reloaded as two `__m256` halves, and reduced with AVX/SSE3.
    #[inline(always)]
    unsafe fn hsum512(v: __m512) -> f32 {
        let mut buf = [0.0f32; 16];
        _mm512_storeu_ps(buf.as_mut_ptr(), v);
        let lo = _mm256_loadu_ps(buf.as_ptr());
        let hi = _mm256_loadu_ps(buf.as_ptr().add(8));
        let sum256 = _mm256_add_ps(lo, hi);
        let hi128 = _mm256_extractf128_ps(sum256, 1);
        let lo128 = _mm256_castps256_ps128(sum256);
        let sum128 = _mm_add_ps(lo128, hi128);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf2))
    }

    /// Loads 16 consecutive F16 bit patterns starting at `p` and widens them to F32.
    ///
    /// `_mm512_cvtph_ps` is the 512-bit form of `vcvtph2ps` and needs only
    /// AVX-512F. The previous approach joined two `_mm256_cvtph_ps` halves with
    /// `_mm512_insertf32x8`, which is an AVX-512DQ instruction. The dispatcher
    /// never checks for DQ, so that faulted on AVX-512F-only CPUs.
    ///
    /// # Safety
    /// `p` must be valid for reading 16 `u16`s, and the CPU must support AVX-512F.
    #[inline(always)]
    unsafe fn load_f16x16(p: *const u16) -> __m512 {
        _mm512_cvtph_ps(_mm256_loadu_si256(p as *const __m256i))
    }

    /// `Σ aᵢ·bᵢ` over the first `min(a.len(), b.len())` elements, 16 `f32` per
    /// iteration plus a scalar tail.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            acc = _mm512_fmadd_ps(
                _mm512_loadu_ps(ap.add(base)),
                _mm512_loadu_ps(bp.add(base)),
                acc,
            );
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// Euclidean (L2) distance `‖a − b‖₂` over the first `min(a.len(), b.len())`
    /// elements.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), _mm512_loadu_ps(bp.add(base)));
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance `1 − cos(a, b)` over the first `min(a.len(), b.len())`
    /// elements. Returns `1.0` when either norm is exactly zero.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let av = _mm512_loadu_ps(ap.add(base));
            let bv = _mm512_loadu_ps(bp.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }

    /// `Σ aᵢ·bᵢ` for an F32 query `a` and an F16 vector `b`, over the first
    /// `min(a.len(), b.len())` elements, 16 elements per iteration.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn dot_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        // Assumes `half::f16` is layout-compatible with its `u16` bit pattern.
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            acc = _mm512_fmadd_ps(_mm512_loadu_ps(ap.add(base)), bv, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
        }
        sum
    }

    /// Euclidean (L2) distance `‖a − b‖₂` between an F32 query `a` and an F16
    /// vector `b`, over the first `min(a.len(), b.len())` elements.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn euclidean_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), bv);
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance `1 − cos(a, b)` between an F32 query `a` and an F16 vector
    /// `b`, over the first `min(a.len(), b.len())` elements.
    ///
    /// Returns `1.0` when `sqrt(‖a‖²·‖b‖²) < 1e-8`.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn cosine_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            let av = _mm512_loadu_ps(ap.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = f16::from_bits(*bp.add(i)).to_f32();
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let denom = (na2 * nb2).sqrt();
        if denom < 1e-8 {
            1.0
        } else {
            1.0 - dot / denom
        }
    }
}
664
665// ── aarch64 NEON ──────────────────────────────────────────────────────────────
666
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    use std::arch::aarch64::*;

    /// `Σ aᵢ·bᵢ` over the common prefix of `a` and `b`, four lanes at a time
    /// followed by a scalar tail.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (xs, ys) = (&a[..len], &b[..len]);
        let xq = xs.chunks_exact(4);
        let yq = ys.chunks_exact(4);
        let (x_tail, y_tail) = (xq.remainder(), yq.remainder());

        let mut lanes = vdupq_n_f32(0.0);
        for (x4, y4) in xq.zip(yq) {
            lanes = vmlaq_f32(lanes, vld1q_f32(x4.as_ptr()), vld1q_f32(y4.as_ptr()));
        }

        let mut total = vaddvq_f32(lanes);
        for (x, y) in x_tail.iter().zip(y_tail) {
            total += x * y;
        }
        total
    }

    /// Euclidean (L2) distance `‖a − b‖₂` over the common prefix of `a` and `b`.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (xs, ys) = (&a[..len], &b[..len]);
        let xq = xs.chunks_exact(4);
        let yq = ys.chunks_exact(4);
        let (x_tail, y_tail) = (xq.remainder(), yq.remainder());

        let mut lanes = vdupq_n_f32(0.0);
        for (x4, y4) in xq.zip(yq) {
            let diff = vsubq_f32(vld1q_f32(x4.as_ptr()), vld1q_f32(y4.as_ptr()));
            lanes = vmlaq_f32(lanes, diff, diff);
        }

        let mut total = vaddvq_f32(lanes);
        for (x, y) in x_tail.iter().zip(y_tail) {
            let diff = x - y;
            total += diff * diff;
        }
        total.sqrt()
    }

    /// Cosine distance `1 − cos(a, b)` over the common prefix of `a` and `b`.
    /// Returns `1.0` when either norm is exactly zero.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (xs, ys) = (&a[..len], &b[..len]);
        let xq = xs.chunks_exact(4);
        let yq = ys.chunks_exact(4);
        let (x_tail, y_tail) = (xq.remainder(), yq.remainder());

        let mut dot_lanes = vdupq_n_f32(0.0);
        let mut sq_a_lanes = vdupq_n_f32(0.0);
        let mut sq_b_lanes = vdupq_n_f32(0.0);
        for (x4, y4) in xq.zip(yq) {
            let xv = vld1q_f32(x4.as_ptr());
            let yv = vld1q_f32(y4.as_ptr());
            dot_lanes = vmlaq_f32(dot_lanes, xv, yv);
            sq_a_lanes = vmlaq_f32(sq_a_lanes, xv, xv);
            sq_b_lanes = vmlaq_f32(sq_b_lanes, yv, yv);
        }

        let mut dot = vaddvq_f32(dot_lanes);
        let mut sq_a = vaddvq_f32(sq_a_lanes);
        let mut sq_b = vaddvq_f32(sq_b_lanes);
        for (&x, &y) in x_tail.iter().zip(y_tail) {
            dot += x * y;
            sq_a += x * x;
            sq_b += y * y;
        }

        let norm_a = sq_a.sqrt();
        let norm_b = sq_b.sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            return 1.0;
        }
        1.0 - dot / (norm_a * norm_b)
    }
}
740
741// ── Tests ─────────────────────────────────────────────────────────────────────
742
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!(cosine_distance(&v, &v).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal() {
        assert!((cosine_distance(&[1.0f32, 0.0], &[0.0f32, 1.0]) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn euclidean_basic() {
        assert!((euclidean_distance(&[0.0f32, 0.0], &[3.0f32, 4.0]) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn dot_basic() {
        assert!((dot_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]) - 32.0).abs() < 1e-5);
    }

    #[test]
    fn simd_matches_scalar_dim128() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(99);
        let a: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();

        let dot_s = dot_scalar(&a, &b);
        let euclid_s = euclidean_scalar(&a, &b);
        let cos_s = cosine_scalar(&a, &b);

        let dot_f = dot_product(&a, &b);
        let euclid_f = euclidean_distance(&a, &b);
        let cos_f = cosine_distance(&a, &b);

        assert!(
            (dot_f - dot_s).abs() < 1e-4,
            "dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-4,
            "euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-4,
            "cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    #[test]
    fn f16_simd_matches_scalar() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(42);
        let a: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b_f32: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b: Vec<f16> = b_f32.iter().map(|&x| f16::from_f32(x)).collect();

        let dot_s = dot_f16_scalar(&a, &b);
        let euclid_s = euclidean_f16_scalar(&a, &b);
        let cos_s = cosine_f16_scalar(&a, &b);

        let dot_f = dot_product_f16(&a, &b);
        let euclid_f = euclidean_distance_f16(&a, &b);
        let cos_f = cosine_distance_f16(&a, &b);

        // F16 rounding introduces small error — tolerate 1e-3
        assert!(
            (dot_f - dot_s).abs() < 1e-3,
            "f16 dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-3,
            "f16 euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-3,
            "f16 cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    /// Dimensions that are not multiples of 8/16 exercise the SIMD remainder
    /// handling (the single 8-wide block and the scalar tail). The dim=128 tests
    /// above never reach it. The tolerance only absorbs summation-order
    /// differences between the dispatched kernel and the scalar fallback.
    #[test]
    fn simd_matches_scalar_odd_dims() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(7);
        let close = |got: f32, want: f32| (got - want).abs() <= 1e-4 * (1.0 + want.abs());

        for dim in 0..=41usize {
            let a: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
            let b: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
            let b16: Vec<f16> = b.iter().map(|&x| f16::from_f32(x)).collect();

            let (got, want) = (dot_product(&a, &b), dot_scalar(&a, &b));
            assert!(close(got, want), "dot dim={dim}: {got} vs {want}");
            let (got, want) = (euclidean_distance(&a, &b), euclidean_scalar(&a, &b));
            assert!(close(got, want), "euclidean dim={dim}: {got} vs {want}");
            let (got, want) = (cosine_distance(&a, &b), cosine_scalar(&a, &b));
            assert!(close(got, want), "cosine dim={dim}: {got} vs {want}");

            let (got, want) = (dot_product_f16(&a, &b16), dot_f16_scalar(&a, &b16));
            assert!(close(got, want), "f16 dot dim={dim}: {got} vs {want}");
            let (got, want) = (euclidean_distance_f16(&a, &b16), euclidean_f16_scalar(&a, &b16));
            assert!(close(got, want), "f16 euclidean dim={dim}: {got} vs {want}");
            let (got, want) = (cosine_distance_f16(&a, &b16), cosine_f16_scalar(&a, &b16));
            assert!(close(got, want), "f16 cosine dim={dim}: {got} vs {want}");
        }
    }

    #[test]
    fn centroid_single() {
        let v = vec![vec![1.0f32, 2.0, 3.0]];
        let c = compute_centroid_and_radius(&v, VectorMetric::Cosine);
        assert_eq!(c.values, vec![1.0, 2.0, 3.0]);
        assert!(c.radius < 1e-6, "radius={}", c.radius);
    }

    #[test]
    fn centroid_two_points() {
        let vs = vec![vec![0.0f32, 0.0], vec![2.0f32, 2.0]];
        let c = compute_centroid_and_radius(&vs, VectorMetric::Euclidean);
        assert!((c.values[0] - 1.0).abs() < 1e-6);
        assert!(c.radius > 0.0);
    }
}