Skip to main content

ailake_vec/
distance.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use ailake_core::{Centroid, VectorMetric};
3use half::f16;
4
5// ── Public API ────────────────────────────────────────────────────────────────
6
7pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
8    #[cfg(target_arch = "x86_64")]
9    {
10        #[cfg(feature = "avx512")]
11        if is_x86_feature_detected!("avx512f") {
12            return unsafe { avx512::dot(a, b) };
13        }
14        if is_x86_feature_detected!("avx2") {
15            return unsafe { avx2::dot(a, b) };
16        }
17    }
18    #[cfg(target_arch = "aarch64")]
19    if std::arch::is_aarch64_feature_detected!("neon") {
20        return unsafe { neon_impl::dot(a, b) };
21    }
22    dot_scalar(a, b)
23}
24
25pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
26    #[cfg(target_arch = "x86_64")]
27    {
28        #[cfg(feature = "avx512")]
29        if is_x86_feature_detected!("avx512f") {
30            return unsafe { avx512::euclidean(a, b) };
31        }
32        if is_x86_feature_detected!("avx2") {
33            return unsafe { avx2::euclidean(a, b) };
34        }
35    }
36    #[cfg(target_arch = "aarch64")]
37    if std::arch::is_aarch64_feature_detected!("neon") {
38        return unsafe { neon_impl::euclidean(a, b) };
39    }
40    euclidean_scalar(a, b)
41}
42
43pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
44    #[cfg(target_arch = "x86_64")]
45    {
46        #[cfg(feature = "avx512")]
47        if is_x86_feature_detected!("avx512f") {
48            return unsafe { avx512::cosine(a, b) };
49        }
50        if is_x86_feature_detected!("avx2") {
51            return unsafe { avx2::cosine(a, b) };
52        }
53    }
54    #[cfg(target_arch = "aarch64")]
55    if std::arch::is_aarch64_feature_detected!("neon") {
56        return unsafe { neon_impl::cosine(a, b) };
57    }
58    cosine_scalar(a, b)
59}
60
61pub fn exact_distance(metric: VectorMetric, a: &[f32], b: &[f32]) -> f32 {
62    match metric {
63        VectorMetric::Cosine => cosine_distance(a, b),
64        VectorMetric::Euclidean => euclidean_distance(a, b),
65        VectorMetric::DotProduct => -dot_product(a, b),
66    }
67}
68
69// ── F16 distance functions ────────────────────────────────────────────────────
70//
71// Query `a` stays F32 (one vector, lives in registers).
72// Database vector `b` is F16 (dequantized inline — no allocation).
73//
74// Fast path: F16C converts 8 F16 values to F32 in one instruction via
75// _mm256_cvtph_ps, then FMA accumulates. Eliminates scalar half::to_f32()
76// loop that dominates HNSW graph traversal on dim=128 vectors.
77
78pub fn cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
79    #[cfg(target_arch = "x86_64")]
80    {
81        #[cfg(feature = "avx512")]
82        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
83            return unsafe { avx512::cosine_f16(a, b) };
84        }
85        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
86            return unsafe { avx2_f16c::cosine(a, b) };
87        }
88    }
89    cosine_f16_scalar(a, b)
90}
91
92pub fn euclidean_distance_f16(a: &[f32], b: &[f16]) -> f32 {
93    #[cfg(target_arch = "x86_64")]
94    {
95        #[cfg(feature = "avx512")]
96        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
97            return unsafe { avx512::euclidean_f16(a, b) };
98        }
99        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
100            return unsafe { avx2_f16c::euclidean(a, b) };
101        }
102    }
103    euclidean_f16_scalar(a, b)
104}
105
106pub fn dot_product_f16(a: &[f32], b: &[f16]) -> f32 {
107    #[cfg(target_arch = "x86_64")]
108    {
109        #[cfg(feature = "avx512")]
110        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
111            return unsafe { avx512::dot_f16(a, b) };
112        }
113        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
114            return unsafe { avx2_f16c::dot(a, b) };
115        }
116    }
117    dot_f16_scalar(a, b)
118}
119
120pub fn compute_centroid_and_radius(vectors: &[Vec<f32>], metric: VectorMetric) -> Centroid {
121    if vectors.is_empty() {
122        return Centroid {
123            values: vec![],
124            radius: 0.0,
125            metric,
126        };
127    }
128    let dim = vectors[0].len();
129    let n = vectors.len() as f32;
130    let centroid: Vec<f32> = (0..dim)
131        .map(|i| vectors.iter().map(|v| v[i]).sum::<f32>() / n)
132        .collect();
133    let radius = vectors
134        .iter()
135        .map(|v| exact_distance(metric, &centroid, v))
136        .fold(0.0_f32, f32::max);
137    Centroid {
138        values: centroid,
139        radius,
140        metric,
141    }
142}
143
144// ── Scalar fallbacks ──────────────────────────────────────────────────────────
145
/// Portable dot product: sums the pairwise products over the common prefix
/// of `a` and `b`.
#[inline(always)]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
150
/// Portable Euclidean distance: square root of the summed squared
/// differences over the common prefix of `a` and `b`.
#[inline(always)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum();
    squared.sqrt()
}
159
/// Portable cosine distance `1 - cos(a, b)`.
///
/// Only the common prefix (`min(a.len(), b.len())` elements) of the two
/// slices is considered for the dot product *and* both norms, which matches
/// what the SIMD kernels in this file do. Returns `1.0` when either
/// (truncated) vector has zero norm.
#[inline(always)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Truncate first so the norms cover exactly the elements the dot product
    // sees; otherwise trailing elements of the longer slice skew the result.
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if na == 0.0 || nb == 0.0 {
        return 1.0;
    }
    1.0 - dot / (na * nb)
}
170
171#[inline(always)]
172fn cosine_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
173    let n = a.len().min(b.len());
174    let mut dot = 0.0f32;
175    let mut norm_a = 0.0f32;
176    let mut norm_b = 0.0f32;
177    for i in 0..n {
178        let ai = a[i];
179        let bi = b[i].to_f32();
180        dot += ai * bi;
181        norm_a += ai * ai;
182        norm_b += bi * bi;
183    }
184    let denom = (norm_a * norm_b).sqrt();
185    if denom < 1e-8 {
186        1.0
187    } else {
188        1.0 - dot / denom
189    }
190}
191
192#[inline(always)]
193fn euclidean_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
194    let n = a.len().min(b.len());
195    let mut sum = 0.0f32;
196    for i in 0..n {
197        let diff = a[i] - b[i].to_f32();
198        sum += diff * diff;
199    }
200    sum.sqrt()
201}
202
203#[inline(always)]
204fn dot_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
205    let n = a.len().min(b.len());
206    let mut acc = 0.0f32;
207    for i in 0..n {
208        acc += a[i] * b[i].to_f32();
209    }
210    acc
211}
212
213// ── x86_64 AVX2 + FMA ────────────────────────────────────────────────────────
214//
215// Compiled with target_feature = "avx2,fma". The compiler emits vfmadd231ps
216// instead of separate vmulps + vaddps, cutting inner-loop instruction count
217// by ~33% and reducing latency via fused operations.
218
/// AVX2+FMA kernels for F32/F32 distances (8 f32 lanes per 256-bit register).
///
/// Every kernel works on the first `min(a.len(), b.len())` elements: full
/// 8-lane chunks go through SIMD, the remaining tail is handled by a scalar loop.
#[cfg(target_arch = "x86_64")]
mod avx2 {
    use std::arch::x86_64::*;

    /// Horizontal sum: collapses the eight f32 lanes of `v` into one scalar.
    ///
    /// # Safety
    /// Executes AVX/SSE intrinsics without a `target_feature` attribute of its
    /// own; it must only be called (and inlined) from code that already runs
    /// with those features available, as the kernels in this file do.
    #[inline(always)]
    pub unsafe fn hsum256(v: __m256) -> f32 {
        // Fold the upper 128-bit half onto the lower half: 8 lanes -> 4.
        let hi = _mm256_extractf128_ps(v, 1);
        let lo = _mm256_castps256_ps128(v);
        let s = _mm_add_ps(lo, hi);
        // Add each odd lane onto its even neighbour: 4 lanes -> 2 partial sums.
        let shuf = _mm_movehdup_ps(s);
        let sums = _mm_add_ps(s, shuf);
        // Bring the upper partial sum down to lane 0 and add the final pair.
        let shuf = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf))
    }

    /// dot(a, b) — AVX2+FMA, 2× unrolled (16 f32/iter).
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA, the features this function is
    /// compiled with.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        // Two independent accumulators so consecutive FMAs do not depend on
        // each other.
        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            let a1 = _mm256_loadu_ps(ap.add(base + 8));
            let b1 = _mm256_loadu_ps(bp.add(base + 8));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
        }

        // At most one full 8-lane chunk can remain after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
        }

        // Scalar tail for the last `n % 8` elements.
        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// ||a - b|| (Euclidean distance, square root applied) — AVX2+FMA, 2× unrolled.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA, the features this function is
    /// compiled with.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), _mm256_loadu_ps(bp.add(base)));
            let d1 = _mm256_sub_ps(
                _mm256_loadu_ps(ap.add(base + 8)),
                _mm256_loadu_ps(bp.add(base + 8)),
            );
            // acc += d * d, i.e. accumulate squared differences.
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
        }

        // At most one full 8-lane chunk can remain after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), _mm256_loadu_ps(bp.add(base)));
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
        }

        // Scalar tail for the last `n % 8` elements.
        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// 1 - cos(a, b) — AVX2+FMA, single pass for dot + norms².
    ///
    /// Returns `1.0` when either vector has zero norm.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA, the features this function is
    /// compiled with.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        // One accumulator each for a·b, |a|² and |b|².
        let mut dot_acc = _mm256_setzero_ps();
        let mut na_acc = _mm256_setzero_ps();
        let mut nb_acc = _mm256_setzero_ps();

        let chunks8 = n / 8;
        for i in 0..chunks8 {
            let base = i * 8;
            let av = _mm256_loadu_ps(ap.add(base));
            let bv = _mm256_loadu_ps(bp.add(base));
            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm256_fmadd_ps(av, av, na_acc);
            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
        }

        let mut dot = hsum256(dot_acc);
        let mut na2 = hsum256(na_acc);
        let mut nb2 = hsum256(nb_acc);

        // Scalar tail for the last `n % 8` elements.
        for i in (chunks8 * 8)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }

        let na = na2.sqrt();
        let nb = nb2.sqrt();
        // A zero-norm vector has no direction; report the distance as 1.0
        // instead of dividing by zero.
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }
}
348
349// ── x86_64 AVX2 + F16C — F16 hot path ────────────────────────────────────────
350//
// _mm256_cvtph_ps converts 8 packed F16 (as __m128i) to 8 F32 in one instruction.
352// Combined with FMA, this replaces 8 scalar half::to_f32() calls per iteration.
353// Critical hot path: every HNSW edge traversal calls one of these functions.
354
/// AVX2+F16C+FMA kernels for distances between an F32 vector and an F16 vector.
///
/// The F16 slice is read through a `*const u16` pointer and widened eight
/// values at a time with `_mm256_cvtph_ps`; tail elements are converted with
/// `f16::from_bits(..).to_f32()`. Every kernel works on the first
/// `min(a.len(), b.len())` elements.
#[cfg(target_arch = "x86_64")]
mod avx2_f16c {
    use half::f16;
    use std::arch::x86_64::*;

    use super::avx2::hsum256;

    /// dot(a_f32, b_f16) — AVX2+F16C+FMA, 16 F16/iter.
    ///
    /// # Safety
    /// The CPU must support AVX2, F16C and FMA, the features this function is
    /// compiled with.
    #[target_feature(enable = "avx2,f16c,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        // Reinterpret the F16 storage as raw 16-bit words for the SIMD loads.
        let bp = b.as_ptr() as *const u16;

        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            // Each 128-bit load holds 8 F16 values, widened to 8 F32 lanes.
            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
            let a0 = _mm256_loadu_ps(ap.add(base));
            let a1 = _mm256_loadu_ps(ap.add(base + 8));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
        }

        // At most one full 8-lane chunk can remain after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
            let a0 = _mm256_loadu_ps(ap.add(base));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
        }

        // Scalar tail for the last `n % 8` elements.
        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
        }
        sum
    }

    /// ||a_f32 - b_f16|| (Euclidean distance, square root applied) — AVX2+F16C+FMA, 16 F16/iter.
    ///
    /// # Safety
    /// The CPU must support AVX2, F16C and FMA, the features this function is
    /// compiled with.
    #[target_feature(enable = "avx2,f16c,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        // Reinterpret the F16 storage as raw 16-bit words for the SIMD loads.
        let bp = b.as_ptr() as *const u16;

        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
            let d1 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base + 8)), b1);
            // acc += d * d, i.e. accumulate squared differences.
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
        }

        // At most one full 8-lane chunk can remain after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
        }

        // Scalar tail for the last `n % 8` elements.
        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            let diff = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
            sum += diff * diff;
        }
        sum.sqrt()
    }

    /// 1 - cos(a_f32, b_f16) — AVX2+F16C+FMA, single pass.
    ///
    /// Returns `1.0` when `sqrt(|a|² · |b|²)` is below `1e-8`.
    ///
    /// # Safety
    /// The CPU must support AVX2, F16C and FMA, the features this function is
    /// compiled with.
    #[target_feature(enable = "avx2,f16c,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        // Reinterpret the F16 storage as raw 16-bit words for the SIMD loads.
        let bp = b.as_ptr() as *const u16;

        // One accumulator each for a·b, |a|² and |b|².
        let mut dot_acc = _mm256_setzero_ps();
        let mut na_acc = _mm256_setzero_ps();
        let mut nb_acc = _mm256_setzero_ps();

        let chunks8 = n / 8;
        for i in 0..chunks8 {
            let base = i * 8;
            let av = _mm256_loadu_ps(ap.add(base));
            let bv = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm256_fmadd_ps(av, av, na_acc);
            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
        }

        let mut dot = hsum256(dot_acc);
        let mut na2 = hsum256(na_acc);
        let mut nb2 = hsum256(nb_acc);

        // Scalar tail for the last `n % 8` elements.
        for i in (chunks8 * 8)..n {
            let ai = *ap.add(i);
            let bi = f16::from_bits(*bp.add(i)).to_f32();
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }

        // Guard against dividing by a (near-)zero norm product.
        let denom = (na2 * nb2).sqrt();
        if denom < 1e-8 {
            1.0
        } else {
            1.0 - dot / denom
        }
    }
}
476
477// ── x86_64 AVX-512F — forward compatibility ───────────────────────────────────
478//
// 16 f32/iter (vs 8 for AVX2). Runtime-detected: skipped on CPUs without
// avx512f, active on e.g. Xeon Scalable and Zen 4+ (most 12th gen+ Intel Core
// client parts ship with AVX-512 disabled and take the AVX2 path instead).
481// Requires Rust ≥ 1.89 (AVX-512 intrinsics stabilised there). Gated behind
482// the `avx512` feature so the default/manylinux build always succeeds.
483
/// AVX-512F kernels (16 f32 lanes per 512-bit register).
///
/// Every kernel works on the first `min(a.len(), b.len())` elements: full
/// 16-lane chunks go through SIMD, the remaining tail is handled by a scalar
/// loop. Only AVX-512F (plus FMA) instructions are used, so detecting
/// `avx512f` is sufficient to call any of them.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
mod avx512 {
    use half::f16;
    use std::arch::x86_64::*;

    /// Horizontal sum: collapses the sixteen f32 lanes of `v` into one scalar.
    ///
    /// # Safety
    /// Must only be called (and inlined) from code running with AVX-512F
    /// available, as the kernels in this module do.
    #[inline(always)]
    unsafe fn hsum512(v: __m512) -> f32 {
        // _mm512_reduce_add_ps stabilized Rust 1.89; _mm512_extractf32x8_ps needs avx512dq.
        // Store all 16 lanes to stack (avx512f), reload as two __m256 (avx), then reduce.
        let mut buf = [0.0f32; 16];
        _mm512_storeu_ps(buf.as_mut_ptr(), v);
        let lo = _mm256_loadu_ps(buf.as_ptr());
        let hi = _mm256_loadu_ps(buf.as_ptr().add(8));
        let sum256 = _mm256_add_ps(lo, hi);
        // 8 lanes -> 4 -> 2 -> 1, same reduction shape as the AVX2 helper.
        let hi128 = _mm256_extractf128_ps(sum256, 1);
        let lo128 = _mm256_castps256_ps128(sum256);
        let sum128 = _mm_add_ps(lo128, hi128);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf2))
    }

    /// Loads 16 consecutive F16 values starting at `src` and widens them to
    /// 16 F32 lanes with a single AVX-512F conversion.
    ///
    /// This deliberately avoids `_mm512_insertf32x8`, which is an AVX-512DQ
    /// intrinsic: it would neither inline into an `avx512f`-only kernel nor be
    /// safe to execute when only `avx512f` has been detected.
    ///
    /// # Safety
    /// `src` must be valid for reading 16 `u16` values, and the caller must be
    /// running with AVX-512F available.
    #[inline(always)]
    unsafe fn load16_f16(src: *const u16) -> __m512 {
        _mm512_cvtph_ps(_mm256_loadu_si256(src as *const __m256i))
    }

    /// dot(a, b) — AVX-512F+FMA, 16 f32/iter.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            acc = _mm512_fmadd_ps(
                _mm512_loadu_ps(ap.add(base)),
                _mm512_loadu_ps(bp.add(base)),
                acc,
            );
        }
        // Scalar tail for the last `n % 16` elements.
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// ||a - b|| (Euclidean distance, square root applied) — AVX-512F+FMA.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), _mm512_loadu_ps(bp.add(base)));
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        // Scalar tail for the last `n % 16` elements.
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// 1 - cos(a, b) — AVX-512F+FMA, single pass for dot + norms².
    ///
    /// Returns `1.0` when either vector has zero norm.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let av = _mm512_loadu_ps(ap.add(base));
            let bv = _mm512_loadu_ps(bp.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        // Scalar tail for the last `n % 16` elements.
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }

    /// dot(a_f32, b_f16) — AVX-512F+FMA, 16 F16/iter.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn dot_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        // Reinterpret the F16 storage as raw 16-bit words for the SIMD loads.
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load16_f16(bp.add(base));
            acc = _mm512_fmadd_ps(_mm512_loadu_ps(ap.add(base)), bv, acc);
        }
        // Scalar tail for the last `n % 16` elements.
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
        }
        sum
    }

    /// ||a_f32 - b_f16|| (Euclidean distance, square root applied) —
    /// AVX-512F+FMA, 16 F16/iter.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn euclidean_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        // Reinterpret the F16 storage as raw 16-bit words for the SIMD loads.
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load16_f16(bp.add(base));
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), bv);
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        // Scalar tail for the last `n % 16` elements.
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
            sum += d * d;
        }
        sum.sqrt()
    }

    /// 1 - cos(a_f32, b_f16) — AVX-512F+FMA, single pass.
    ///
    /// Returns `1.0` when `sqrt(|a|² · |b|²)` is below `1e-8`.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn cosine_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        // Reinterpret the F16 storage as raw 16-bit words for the SIMD loads.
        let bp = b.as_ptr() as *const u16;
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load16_f16(bp.add(base));
            let av = _mm512_loadu_ps(ap.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        // Scalar tail for the last `n % 16` elements.
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = f16::from_bits(*bp.add(i)).to_f32();
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        // Guard against dividing by a (near-)zero norm product.
        let denom = (na2 * nb2).sqrt();
        if denom < 1e-8 {
            1.0
        } else {
            1.0 - dot / denom
        }
    }
}
665
666// ── aarch64 NEON ──────────────────────────────────────────────────────────────
667
/// aarch64 NEON kernels for F32/F32 distances (4 f32 lanes per 128-bit register).
///
/// Every kernel works on the first `min(a.len(), b.len())` elements: full
/// 4-lane chunks go through SIMD, the remaining tail uses checked scalar indexing.
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    use std::arch::aarch64::*;

    /// dot(a, b) — NEON, 4 f32/iter.
    ///
    /// # Safety
    /// The CPU must support NEON, the feature this function is compiled with.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let av = vld1q_f32(a.as_ptr().add(base));
            let bv = vld1q_f32(b.as_ptr().add(base));
            // acc += av * bv, lane-wise.
            acc = vmlaq_f32(acc, av, bv);
        }
        // Reduce the four lanes, then finish the last `n % 4` elements.
        let mut sum = vaddvq_f32(acc);
        for i in (chunks * 4)..n {
            sum += a[i] * b[i];
        }
        sum
    }

    /// ||a - b|| (Euclidean distance, square root applied) — NEON, 4 f32/iter.
    ///
    /// # Safety
    /// The CPU must support NEON, the feature this function is compiled with.
    #[target_feature(enable = "neon")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let d = vsubq_f32(
                vld1q_f32(a.as_ptr().add(base)),
                vld1q_f32(b.as_ptr().add(base)),
            );
            // acc += d * d, i.e. accumulate squared differences.
            acc = vmlaq_f32(acc, d, d);
        }
        // Reduce the four lanes, then finish the last `n % 4` elements.
        let mut sum = vaddvq_f32(acc);
        for i in (chunks * 4)..n {
            let d = a[i] - b[i];
            sum += d * d;
        }
        sum.sqrt()
    }

    /// 1 - cos(a, b) — NEON, single pass for dot + norms².
    ///
    /// Returns `1.0` when either vector has zero norm.
    ///
    /// # Safety
    /// The CPU must support NEON, the feature this function is compiled with.
    #[target_feature(enable = "neon")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        // One accumulator each for a·b, |a|² and |b|².
        let mut dot_acc = vdupq_n_f32(0.0);
        let mut na_acc = vdupq_n_f32(0.0);
        let mut nb_acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let av = vld1q_f32(a.as_ptr().add(base));
            let bv = vld1q_f32(b.as_ptr().add(base));
            dot_acc = vmlaq_f32(dot_acc, av, bv);
            na_acc = vmlaq_f32(na_acc, av, av);
            nb_acc = vmlaq_f32(nb_acc, bv, bv);
        }
        let mut dot = vaddvq_f32(dot_acc);
        let mut na2 = vaddvq_f32(na_acc);
        let mut nb2 = vaddvq_f32(nb_acc);
        // Scalar tail for the last `n % 4` elements.
        for i in (chunks * 4)..n {
            dot += a[i] * b[i];
            na2 += a[i] * a[i];
            nb2 += b[i] * b[i];
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        // A zero-norm vector has no direction; report the distance as 1.0
        // instead of dividing by zero.
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }
}
741
742// ── Tests ─────────────────────────────────────────────────────────────────────
743
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    /// Draws `len` samples from `rng`, mapping each `gen::<f32>()` value `x`
    /// to `x * 2.0 - 1.0`.
    fn random_vector(rng: &mut StdRng, len: usize) -> Vec<f32> {
        (0..len).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect()
    }

    /// A vector compared with itself has cosine distance ~0.
    #[test]
    fn cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        let distance = cosine_distance(&v, &v);
        assert!(distance.abs() < 1e-5);
    }

    /// Perpendicular unit vectors have cosine distance ~1.
    #[test]
    fn cosine_orthogonal() {
        let distance = cosine_distance(&[1.0f32, 0.0], &[0.0f32, 1.0]);
        assert!((distance - 1.0).abs() < 1e-5);
    }

    /// Classic 3-4-5 triangle.
    #[test]
    fn euclidean_basic() {
        let distance = euclidean_distance(&[0.0f32, 0.0], &[3.0f32, 4.0]);
        assert!((distance - 5.0).abs() < 1e-5);
    }

    /// 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn dot_basic() {
        let product = dot_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]);
        assert!((product - 32.0).abs() < 1e-5);
    }

    /// The dispatched F32 kernels must agree with the scalar fallbacks on
    /// random 128-dimensional vectors.
    #[test]
    fn simd_matches_scalar_dim128() {
        let mut rng = StdRng::seed_from_u64(99);
        let a = random_vector(&mut rng, 128);
        let b = random_vector(&mut rng, 128);

        let (dot_s, dot_f) = (dot_scalar(&a, &b), dot_product(&a, &b));
        assert!(
            (dot_f - dot_s).abs() < 1e-4,
            "dot mismatch: {dot_f} vs {dot_s}"
        );

        let (euclid_s, euclid_f) = (euclidean_scalar(&a, &b), euclidean_distance(&a, &b));
        assert!(
            (euclid_f - euclid_s).abs() < 1e-4,
            "euclidean mismatch: {euclid_f} vs {euclid_s}"
        );

        let (cos_s, cos_f) = (cosine_scalar(&a, &b), cosine_distance(&a, &b));
        assert!(
            (cos_f - cos_s).abs() < 1e-4,
            "cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    /// The dispatched F16 kernels must agree with the scalar F16 fallbacks.
    #[test]
    fn f16_simd_matches_scalar() {
        let mut rng = StdRng::seed_from_u64(42);
        let a = random_vector(&mut rng, 128);
        let b: Vec<f16> = random_vector(&mut rng, 128)
            .into_iter()
            .map(f16::from_f32)
            .collect();

        // F16 rounding introduces small error — tolerate 1e-3
        let (dot_s, dot_f) = (dot_f16_scalar(&a, &b), dot_product_f16(&a, &b));
        assert!(
            (dot_f - dot_s).abs() < 1e-3,
            "f16 dot mismatch: {dot_f} vs {dot_s}"
        );

        let (euclid_s, euclid_f) = (
            euclidean_f16_scalar(&a, &b),
            euclidean_distance_f16(&a, &b),
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-3,
            "f16 euclidean mismatch: {euclid_f} vs {euclid_s}"
        );

        let (cos_s, cos_f) = (cosine_f16_scalar(&a, &b), cosine_distance_f16(&a, &b));
        assert!(
            (cos_f - cos_s).abs() < 1e-3,
            "f16 cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    /// A single vector is its own centroid and sits at distance ~0 from it.
    #[test]
    fn centroid_single() {
        let input = vec![vec![1.0f32, 2.0, 3.0]];
        let c = compute_centroid_and_radius(&input, VectorMetric::Cosine);
        assert_eq!(c.values, vec![1.0, 2.0, 3.0]);
        assert!(c.radius < 1e-6, "radius={}", c.radius);
    }

    /// The centroid of (0,0) and (2,2) has first coordinate 1 and a positive radius.
    #[test]
    fn centroid_two_points() {
        let input = vec![vec![0.0f32, 0.0], vec![2.0f32, 2.0]];
        let c = compute_centroid_and_radius(&input, VectorMetric::Euclidean);
        assert!((c.values[0] - 1.0).abs() < 1e-6);
        assert!(c.radius > 0.0);
    }
}