Skip to main content

ailake_vec/
distance.rs

1use ailake_core::{Centroid, VectorMetric};
2use half::f16;
3
4// ── Public API ────────────────────────────────────────────────────────────────
5
6pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
7    #[cfg(target_arch = "x86_64")]
8    {
9        if is_x86_feature_detected!("avx512f") {
10            return unsafe { avx512::dot(a, b) };
11        }
12        if is_x86_feature_detected!("avx2") {
13            return unsafe { avx2::dot(a, b) };
14        }
15    }
16    #[cfg(target_arch = "aarch64")]
17    if std::arch::is_aarch64_feature_detected!("neon") {
18        return unsafe { neon_impl::dot(a, b) };
19    }
20    dot_scalar(a, b)
21}
22
23pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
24    #[cfg(target_arch = "x86_64")]
25    {
26        if is_x86_feature_detected!("avx512f") {
27            return unsafe { avx512::euclidean(a, b) };
28        }
29        if is_x86_feature_detected!("avx2") {
30            return unsafe { avx2::euclidean(a, b) };
31        }
32    }
33    #[cfg(target_arch = "aarch64")]
34    if std::arch::is_aarch64_feature_detected!("neon") {
35        return unsafe { neon_impl::euclidean(a, b) };
36    }
37    euclidean_scalar(a, b)
38}
39
40pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
41    #[cfg(target_arch = "x86_64")]
42    {
43        if is_x86_feature_detected!("avx512f") {
44            return unsafe { avx512::cosine(a, b) };
45        }
46        if is_x86_feature_detected!("avx2") {
47            return unsafe { avx2::cosine(a, b) };
48        }
49    }
50    #[cfg(target_arch = "aarch64")]
51    if std::arch::is_aarch64_feature_detected!("neon") {
52        return unsafe { neon_impl::cosine(a, b) };
53    }
54    cosine_scalar(a, b)
55}
56
57pub fn exact_distance(metric: VectorMetric, a: &[f32], b: &[f32]) -> f32 {
58    match metric {
59        VectorMetric::Cosine => cosine_distance(a, b),
60        VectorMetric::Euclidean => euclidean_distance(a, b),
61        VectorMetric::DotProduct => -dot_product(a, b),
62    }
63}
64
65// ── F16 distance functions ────────────────────────────────────────────────────
66//
67// Query `a` stays F32 (one vector, lives in registers).
68// Database vector `b` is F16 (dequantized inline — no allocation).
69//
70// Fast path: F16C converts 8 F16 values to F32 in one instruction via
71// _mm256_cvtph_ps, then FMA accumulates. Eliminates scalar half::to_f32()
72// loop that dominates HNSW graph traversal on dim=128 vectors.
73
74pub fn cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
75    #[cfg(target_arch = "x86_64")]
76    {
77        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
78            return unsafe { avx512::cosine_f16(a, b) };
79        }
80        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
81            return unsafe { avx2_f16c::cosine(a, b) };
82        }
83    }
84    cosine_f16_scalar(a, b)
85}
86
87pub fn euclidean_distance_f16(a: &[f32], b: &[f16]) -> f32 {
88    #[cfg(target_arch = "x86_64")]
89    {
90        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
91            return unsafe { avx512::euclidean_f16(a, b) };
92        }
93        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
94            return unsafe { avx2_f16c::euclidean(a, b) };
95        }
96    }
97    euclidean_f16_scalar(a, b)
98}
99
100pub fn dot_product_f16(a: &[f32], b: &[f16]) -> f32 {
101    #[cfg(target_arch = "x86_64")]
102    {
103        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
104            return unsafe { avx512::dot_f16(a, b) };
105        }
106        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
107            return unsafe { avx2_f16c::dot(a, b) };
108        }
109    }
110    dot_f16_scalar(a, b)
111}
112
113pub fn compute_centroid_and_radius(vectors: &[Vec<f32>], metric: VectorMetric) -> Centroid {
114    if vectors.is_empty() {
115        return Centroid {
116            values: vec![],
117            radius: 0.0,
118            metric,
119        };
120    }
121    let dim = vectors[0].len();
122    let n = vectors.len() as f32;
123    let centroid: Vec<f32> = (0..dim)
124        .map(|i| vectors.iter().map(|v| v[i]).sum::<f32>() / n)
125        .collect();
126    let radius = vectors
127        .iter()
128        .map(|v| exact_distance(metric, &centroid, v))
129        .fold(0.0_f32, f32::max);
130    Centroid {
131        values: centroid,
132        radius,
133        metric,
134    }
135}
136
137// ── Scalar fallbacks ──────────────────────────────────────────────────────────
138
/// Portable dot product; pairs elements until the shorter slice runs out.
#[inline(always)]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let pairwise = std::iter::zip(a, b).map(|(&lhs, &rhs)| lhs * rhs);
    pairwise.sum()
}
143
/// Portable L2 distance; only the overlapping prefix of `a` and `b` counts.
#[inline(always)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = std::iter::zip(a, b)
        .map(|(&p, &q)| {
            let delta = p - q;
            delta * delta
        })
        .sum();
    squared.sqrt()
}
152
/// Portable cosine distance `1 - a·b / (|a|·|b|)`.
///
/// The dot product and both norms use the same common prefix
/// `min(a.len(), b.len())`, which is what every SIMD kernel does. This makes
/// the result independent of which code path the CPU selects. Returns `1.0`
/// when either (truncated) vector has zero norm.
#[inline(always)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if na == 0.0 || nb == 0.0 {
        return 1.0;
    }
    1.0 - dot / (na * nb)
}
163
164#[inline(always)]
165fn cosine_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
166    let n = a.len().min(b.len());
167    let mut dot = 0.0f32;
168    let mut norm_a = 0.0f32;
169    let mut norm_b = 0.0f32;
170    for i in 0..n {
171        let ai = a[i];
172        let bi = b[i].to_f32();
173        dot += ai * bi;
174        norm_a += ai * ai;
175        norm_b += bi * bi;
176    }
177    let denom = (norm_a * norm_b).sqrt();
178    if denom < 1e-8 {
179        1.0
180    } else {
181        1.0 - dot / denom
182    }
183}
184
185#[inline(always)]
186fn euclidean_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
187    let n = a.len().min(b.len());
188    let mut sum = 0.0f32;
189    for i in 0..n {
190        let diff = a[i] - b[i].to_f32();
191        sum += diff * diff;
192    }
193    sum.sqrt()
194}
195
196#[inline(always)]
197fn dot_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
198    let n = a.len().min(b.len());
199    let mut acc = 0.0f32;
200    for i in 0..n {
201        acc += a[i] * b[i].to_f32();
202    }
203    acc
204}
205
206// ── x86_64 AVX2 + FMA ────────────────────────────────────────────────────────
207//
208// Compiled with target_feature = "avx2,fma". The compiler emits vfmadd231ps
209// instead of separate vmulps + vaddps, cutting inner-loop instruction count
210// by ~33% and reducing latency via fused operations.
211
#[cfg(target_arch = "x86_64")]
mod avx2 {
    use std::arch::x86_64::*;

    /// Horizontal sum of the 8 lanes of `v`.
    ///
    /// Declared with `target_feature(enable = "avx")` so the AVX/SSE3
    /// shuffles it uses are compiled for the right ISA. Its AVX2 callers have
    /// a superset of that feature set, so it can still be inlined into them.
    ///
    /// # Safety
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    #[inline]
    pub unsafe fn hsum256(v: __m256) -> f32 {
        let hi = _mm256_extractf128_ps(v, 1);
        let lo = _mm256_castps256_ps128(v);
        let s = _mm_add_ps(lo, hi);
        let shuf = _mm_movehdup_ps(s);
        let sums = _mm_add_ps(s, shuf);
        let shuf = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf))
    }

    /// dot(a, b) over the common prefix — AVX2+FMA, 2× unrolled (16 f32/iter).
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        // Two independent accumulators hide FMA latency.
        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            let a1 = _mm256_loadu_ps(ap.add(base + 8));
            let b1 = _mm256_loadu_ps(bp.add(base + 8));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
        }

        // At most one leftover full 8-lane block after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
        }

        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// Euclidean distance `sqrt(Σ (a[i] - b[i])²)` over the common prefix —
    /// AVX2+FMA, 2× unrolled.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), _mm256_loadu_ps(bp.add(base)));
            let d1 = _mm256_sub_ps(
                _mm256_loadu_ps(ap.add(base + 8)),
                _mm256_loadu_ps(bp.add(base + 8)),
            );
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
        }

        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), _mm256_loadu_ps(bp.add(base)));
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
        }

        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// 1 - cos(a, b) over the common prefix — AVX2+FMA, single pass for the
    /// dot product and both squared norms. Returns `1.0` if either norm is 0.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        let mut dot_acc = _mm256_setzero_ps();
        let mut na_acc = _mm256_setzero_ps();
        let mut nb_acc = _mm256_setzero_ps();

        let chunks8 = n / 8;
        for i in 0..chunks8 {
            let base = i * 8;
            let av = _mm256_loadu_ps(ap.add(base));
            let bv = _mm256_loadu_ps(bp.add(base));
            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm256_fmadd_ps(av, av, na_acc);
            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
        }

        let mut dot = hsum256(dot_acc);
        let mut na2 = hsum256(na_acc);
        let mut nb2 = hsum256(nb_acc);

        for i in (chunks8 * 8)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }

        let na = na2.sqrt();
        let nb = nb2.sqrt();
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }
}
341
342// ── x86_64 AVX2 + F16C — F16 hot path ────────────────────────────────────────
343//
// _mm256_cvtph_ps converts 8 packed F16 (as __m128i) to 8 F32 in a single instruction.
345// Combined with FMA, this replaces 8 scalar half::to_f32() calls per iteration.
346// Critical hot path: every HNSW edge traversal calls one of these functions.
347
348#[cfg(target_arch = "x86_64")]
349mod avx2_f16c {
350    use half::f16;
351    use std::arch::x86_64::*;
352
353    use super::avx2::hsum256;
354
355    /// dot(a_f32, b_f16) — AVX2+F16C+FMA, 16 F16/iter.
356    #[target_feature(enable = "avx2,f16c,fma")]
357    pub unsafe fn dot(a: &[f32], b: &[f16]) -> f32 {
358        let n = a.len().min(b.len());
359        let ap = a.as_ptr();
360        let bp = b.as_ptr() as *const u16;
361
362        let mut acc0 = _mm256_setzero_ps();
363        let mut acc1 = _mm256_setzero_ps();
364
365        let chunks16 = n / 16;
366        for i in 0..chunks16 {
367            let base = i * 16;
368            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
369            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
370            let a0 = _mm256_loadu_ps(ap.add(base));
371            let a1 = _mm256_loadu_ps(ap.add(base + 8));
372            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
373            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
374        }
375
376        let chunks8 = n / 8;
377        if chunks8 > chunks16 * 2 {
378            let base = chunks16 * 16;
379            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
380            let a0 = _mm256_loadu_ps(ap.add(base));
381            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
382        }
383
384        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
385        for i in (chunks8 * 8)..n {
386            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
387        }
388        sum
389    }
390
391    /// ||a_f32 - b_f16||² — AVX2+F16C+FMA, 16 F16/iter.
392    #[target_feature(enable = "avx2,f16c,fma")]
393    pub unsafe fn euclidean(a: &[f32], b: &[f16]) -> f32 {
394        let n = a.len().min(b.len());
395        let ap = a.as_ptr();
396        let bp = b.as_ptr() as *const u16;
397
398        let mut acc0 = _mm256_setzero_ps();
399        let mut acc1 = _mm256_setzero_ps();
400
401        let chunks16 = n / 16;
402        for i in 0..chunks16 {
403            let base = i * 16;
404            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
405            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
406            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
407            let d1 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base + 8)), b1);
408            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
409            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
410        }
411
412        let chunks8 = n / 8;
413        if chunks8 > chunks16 * 2 {
414            let base = chunks16 * 16;
415            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
416            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
417            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
418        }
419
420        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
421        for i in (chunks8 * 8)..n {
422            let diff = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
423            sum += diff * diff;
424        }
425        sum.sqrt()
426    }
427
428    /// 1 - cos(a_f32, b_f16) — AVX2+F16C+FMA, single pass.
429    #[target_feature(enable = "avx2,f16c,fma")]
430    pub unsafe fn cosine(a: &[f32], b: &[f16]) -> f32 {
431        let n = a.len().min(b.len());
432        let ap = a.as_ptr();
433        let bp = b.as_ptr() as *const u16;
434
435        let mut dot_acc = _mm256_setzero_ps();
436        let mut na_acc = _mm256_setzero_ps();
437        let mut nb_acc = _mm256_setzero_ps();
438
439        let chunks8 = n / 8;
440        for i in 0..chunks8 {
441            let base = i * 8;
442            let av = _mm256_loadu_ps(ap.add(base));
443            let bv = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
444            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
445            na_acc = _mm256_fmadd_ps(av, av, na_acc);
446            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
447        }
448
449        let mut dot = hsum256(dot_acc);
450        let mut na2 = hsum256(na_acc);
451        let mut nb2 = hsum256(nb_acc);
452
453        for i in (chunks8 * 8)..n {
454            let ai = *ap.add(i);
455            let bi = f16::from_bits(*bp.add(i)).to_f32();
456            dot += ai * bi;
457            na2 += ai * ai;
458            nb2 += bi * bi;
459        }
460
461        let denom = (na2 * nb2).sqrt();
462        if denom < 1e-8 {
463            1.0
464        } else {
465            1.0 - dot / denom
466        }
467    }
468}
469
470// ── x86_64 AVX-512F — forward compatibility ───────────────────────────────────
471//
// 16 f32/iter (vs 8 for AVX2). Selected at runtime only when the CPU reports
// avx512f (e.g. Xeon Scalable, AMD Zen 4+); most Intel client CPUs from 12th gen
// onward ship with AVX-512 disabled and therefore take the AVX2 path.
474
475#[cfg(target_arch = "x86_64")]
476mod avx512 {
477    use half::f16;
478    use std::arch::x86_64::*;
479
480    #[inline(always)]
481    unsafe fn hsum512(v: __m512) -> f32 {
482        _mm512_reduce_add_ps(v)
483    }
484
485    #[target_feature(enable = "avx512f,fma")]
486    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
487        let n = a.len().min(b.len());
488        let ap = a.as_ptr();
489        let bp = b.as_ptr();
490        let mut acc = _mm512_setzero_ps();
491        let chunks16 = n / 16;
492        for i in 0..chunks16 {
493            let base = i * 16;
494            acc = _mm512_fmadd_ps(
495                _mm512_loadu_ps(ap.add(base)),
496                _mm512_loadu_ps(bp.add(base)),
497                acc,
498            );
499        }
500        let mut sum = hsum512(acc);
501        for i in (chunks16 * 16)..n {
502            sum += *ap.add(i) * *bp.add(i);
503        }
504        sum
505    }
506
507    #[target_feature(enable = "avx512f,fma")]
508    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
509        let n = a.len().min(b.len());
510        let ap = a.as_ptr();
511        let bp = b.as_ptr();
512        let mut acc = _mm512_setzero_ps();
513        let chunks16 = n / 16;
514        for i in 0..chunks16 {
515            let base = i * 16;
516            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), _mm512_loadu_ps(bp.add(base)));
517            acc = _mm512_fmadd_ps(d, d, acc);
518        }
519        let mut sum = hsum512(acc);
520        for i in (chunks16 * 16)..n {
521            let d = *ap.add(i) - *bp.add(i);
522            sum += d * d;
523        }
524        sum.sqrt()
525    }
526
527    #[target_feature(enable = "avx512f,fma")]
528    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
529        let n = a.len().min(b.len());
530        let ap = a.as_ptr();
531        let bp = b.as_ptr();
532        let mut dot_acc = _mm512_setzero_ps();
533        let mut na_acc = _mm512_setzero_ps();
534        let mut nb_acc = _mm512_setzero_ps();
535        let chunks16 = n / 16;
536        for i in 0..chunks16 {
537            let base = i * 16;
538            let av = _mm512_loadu_ps(ap.add(base));
539            let bv = _mm512_loadu_ps(bp.add(base));
540            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
541            na_acc = _mm512_fmadd_ps(av, av, na_acc);
542            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
543        }
544        let mut dot = hsum512(dot_acc);
545        let mut na2 = hsum512(na_acc);
546        let mut nb2 = hsum512(nb_acc);
547        for i in (chunks16 * 16)..n {
548            let ai = *ap.add(i);
549            let bi = *bp.add(i);
550            dot += ai * bi;
551            na2 += ai * ai;
552            nb2 += bi * bi;
553        }
554        let (na, nb) = (na2.sqrt(), nb2.sqrt());
555        if na == 0.0 || nb == 0.0 {
556            return 1.0;
557        }
558        1.0 - dot / (na * nb)
559    }
560
561    /// dot(a_f32, b_f16) — AVX-512F+F16C+FMA, 16 F16/iter.
562    #[target_feature(enable = "avx512f,f16c,fma")]
563    pub unsafe fn dot_f16(a: &[f32], b: &[f16]) -> f32 {
564        let n = a.len().min(b.len());
565        let ap = a.as_ptr();
566        let bp = b.as_ptr() as *const u16;
567        let mut acc = _mm512_setzero_ps();
568        let chunks16 = n / 16;
569        for i in 0..chunks16 {
570            let base = i * 16;
571            let b_lo = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
572            let b_hi = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
573            let bv = _mm512_insertf32x8(_mm512_castps256_ps512(b_lo), b_hi, 1);
574            acc = _mm512_fmadd_ps(_mm512_loadu_ps(ap.add(base)), bv, acc);
575        }
576        let mut sum = hsum512(acc);
577        for i in (chunks16 * 16)..n {
578            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
579        }
580        sum
581    }
582
583    #[target_feature(enable = "avx512f,f16c,fma")]
584    pub unsafe fn euclidean_f16(a: &[f32], b: &[f16]) -> f32 {
585        let n = a.len().min(b.len());
586        let ap = a.as_ptr();
587        let bp = b.as_ptr() as *const u16;
588        let mut acc = _mm512_setzero_ps();
589        let chunks16 = n / 16;
590        for i in 0..chunks16 {
591            let base = i * 16;
592            let b_lo = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
593            let b_hi = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
594            let bv = _mm512_insertf32x8(_mm512_castps256_ps512(b_lo), b_hi, 1);
595            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), bv);
596            acc = _mm512_fmadd_ps(d, d, acc);
597        }
598        let mut sum = hsum512(acc);
599        for i in (chunks16 * 16)..n {
600            let d = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
601            sum += d * d;
602        }
603        sum.sqrt()
604    }
605
606    #[target_feature(enable = "avx512f,f16c,fma")]
607    pub unsafe fn cosine_f16(a: &[f32], b: &[f16]) -> f32 {
608        let n = a.len().min(b.len());
609        let ap = a.as_ptr();
610        let bp = b.as_ptr() as *const u16;
611        let mut dot_acc = _mm512_setzero_ps();
612        let mut na_acc = _mm512_setzero_ps();
613        let mut nb_acc = _mm512_setzero_ps();
614        let chunks16 = n / 16;
615        for i in 0..chunks16 {
616            let base = i * 16;
617            let b_lo = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
618            let b_hi = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
619            let bv = _mm512_insertf32x8(_mm512_castps256_ps512(b_lo), b_hi, 1);
620            let av = _mm512_loadu_ps(ap.add(base));
621            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
622            na_acc = _mm512_fmadd_ps(av, av, na_acc);
623            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
624        }
625        let mut dot = hsum512(dot_acc);
626        let mut na2 = hsum512(na_acc);
627        let mut nb2 = hsum512(nb_acc);
628        for i in (chunks16 * 16)..n {
629            let ai = *ap.add(i);
630            let bi = f16::from_bits(*bp.add(i)).to_f32();
631            dot += ai * bi;
632            na2 += ai * ai;
633            nb2 += bi * bi;
634        }
635        let denom = (na2 * nb2).sqrt();
636        if denom < 1e-8 {
637            1.0
638        } else {
639            1.0 - dot / denom
640        }
641    }
642}
643
644// ── aarch64 NEON ──────────────────────────────────────────────────────────────
645
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    use std::arch::aarch64::*;

    // All kernels accumulate with `vfmaq_f32` (acc + x*y with a single
    // rounding, i.e. fused), matching the FMA used by the x86 kernels.
    // `vmlaq_f32` is the unfused multiply-then-add form.

    /// dot(a, b) over the common prefix — NEON, 4 f32/iter.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let av = vld1q_f32(a.as_ptr().add(base));
            let bv = vld1q_f32(b.as_ptr().add(base));
            acc = vfmaq_f32(acc, av, bv);
        }
        let mut sum = vaddvq_f32(acc);
        for i in (chunks * 4)..n {
            sum += a[i] * b[i];
        }
        sum
    }

    /// Euclidean distance over the common prefix — NEON.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let d = vsubq_f32(
                vld1q_f32(a.as_ptr().add(base)),
                vld1q_f32(b.as_ptr().add(base)),
            );
            acc = vfmaq_f32(acc, d, d);
        }
        let mut sum = vaddvq_f32(acc);
        for i in (chunks * 4)..n {
            let d = a[i] - b[i];
            sum += d * d;
        }
        sum.sqrt()
    }

    /// 1 - cos(a, b) over the common prefix — NEON, single pass. Returns
    /// `1.0` if either norm is 0.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut dot_acc = vdupq_n_f32(0.0);
        let mut na_acc = vdupq_n_f32(0.0);
        let mut nb_acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let av = vld1q_f32(a.as_ptr().add(base));
            let bv = vld1q_f32(b.as_ptr().add(base));
            dot_acc = vfmaq_f32(dot_acc, av, bv);
            na_acc = vfmaq_f32(na_acc, av, av);
            nb_acc = vfmaq_f32(nb_acc, bv, bv);
        }
        let mut dot = vaddvq_f32(dot_acc);
        let mut na2 = vaddvq_f32(na_acc);
        let mut nb2 = vaddvq_f32(nb_acc);
        for i in (chunks * 4)..n {
            dot += a[i] * b[i];
            na2 += a[i] * a[i];
            nb2 += b[i] * b[i];
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }
}
719
720// ── Tests ─────────────────────────────────────────────────────────────────────
721
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    /// Uniform random vector in [-1, 1) of length `dim`.
    fn random_vec(rng: &mut StdRng, dim: usize) -> Vec<f32> {
        (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect()
    }

    /// `true` when `got` is within `tol` of `want`, scaled by `max(1, |want|)`.
    fn close(got: f32, want: f32, tol: f32) -> bool {
        (got - want).abs() <= tol * want.abs().max(1.0)
    }

    #[test]
    fn cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!(cosine_distance(&v, &v).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal() {
        assert!((cosine_distance(&[1.0f32, 0.0], &[0.0f32, 1.0]) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn euclidean_basic() {
        assert!((euclidean_distance(&[0.0f32, 0.0], &[3.0f32, 4.0]) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn dot_basic() {
        assert!((dot_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]) - 32.0).abs() < 1e-5);
    }

    #[test]
    fn simd_matches_scalar_dim128() {
        let mut rng = StdRng::seed_from_u64(99);
        let a = random_vec(&mut rng, 128);
        let b = random_vec(&mut rng, 128);

        let dot_s = dot_scalar(&a, &b);
        let euclid_s = euclidean_scalar(&a, &b);
        let cos_s = cosine_scalar(&a, &b);

        let dot_f = dot_product(&a, &b);
        let euclid_f = euclidean_distance(&a, &b);
        let cos_f = cosine_distance(&a, &b);

        assert!(
            (dot_f - dot_s).abs() < 1e-4,
            "dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-4,
            "euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-4,
            "cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    /// Dimensions 1..=48 exercise every combination of 16-wide body, optional
    /// 8-wide block, and scalar remainder in the SIMD kernels. dim=128 alone
    /// never reaches the tail paths.
    #[test]
    fn simd_matches_scalar_every_tail_length() {
        let mut rng = StdRng::seed_from_u64(7);
        for dim in 1..=48 {
            let a = random_vec(&mut rng, dim);
            let b = random_vec(&mut rng, dim);
            assert!(
                close(dot_product(&a, &b), dot_scalar(&a, &b), 1e-4),
                "dot mismatch at dim={dim}"
            );
            assert!(
                close(euclidean_distance(&a, &b), euclidean_scalar(&a, &b), 1e-4),
                "euclidean mismatch at dim={dim}"
            );
            assert!(
                close(cosine_distance(&a, &b), cosine_scalar(&a, &b), 1e-4),
                "cosine mismatch at dim={dim}"
            );
        }
    }

    #[test]
    fn f16_simd_matches_scalar() {
        let mut rng = StdRng::seed_from_u64(42);
        let a = random_vec(&mut rng, 128);
        let b_f32 = random_vec(&mut rng, 128);
        let b: Vec<f16> = b_f32.iter().map(|&x| f16::from_f32(x)).collect();

        let dot_s = dot_f16_scalar(&a, &b);
        let euclid_s = euclidean_f16_scalar(&a, &b);
        let cos_s = cosine_f16_scalar(&a, &b);

        let dot_f = dot_product_f16(&a, &b);
        let euclid_f = euclidean_distance_f16(&a, &b);
        let cos_f = cosine_distance_f16(&a, &b);

        // F16 rounding introduces small error — tolerate 1e-3
        assert!(
            (dot_f - dot_s).abs() < 1e-3,
            "f16 dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-3,
            "f16 euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-3,
            "f16 cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    /// Same tail-length sweep as above, for the f16 kernels.
    #[test]
    fn f16_simd_matches_scalar_every_tail_length() {
        let mut rng = StdRng::seed_from_u64(11);
        for dim in 1..=48 {
            let a = random_vec(&mut rng, dim);
            let b: Vec<f16> = random_vec(&mut rng, dim)
                .into_iter()
                .map(f16::from_f32)
                .collect();
            assert!(
                close(dot_product_f16(&a, &b), dot_f16_scalar(&a, &b), 1e-4),
                "f16 dot mismatch at dim={dim}"
            );
            assert!(
                close(euclidean_distance_f16(&a, &b), euclidean_f16_scalar(&a, &b), 1e-4),
                "f16 euclidean mismatch at dim={dim}"
            );
            assert!(
                close(cosine_distance_f16(&a, &b), cosine_f16_scalar(&a, &b), 1e-4),
                "f16 cosine mismatch at dim={dim}"
            );
        }
    }

    #[test]
    fn cosine_of_zero_vector_is_one() {
        let zero = vec![0.0f32; 20];
        let ones = vec![1.0f32; 20];
        assert_eq!(cosine_distance(&zero, &ones), 1.0);
        let zero_f16: Vec<f16> = zero.iter().map(|&x| f16::from_f32(x)).collect();
        assert_eq!(cosine_distance_f16(&ones, &zero_f16), 1.0);
    }

    #[test]
    fn centroid_single() {
        let v = vec![vec![1.0f32, 2.0, 3.0]];
        let c = compute_centroid_and_radius(&v, VectorMetric::Cosine);
        assert_eq!(c.values, vec![1.0, 2.0, 3.0]);
        assert!(c.radius < 1e-6, "radius={}", c.radius);
    }

    #[test]
    fn centroid_two_points() {
        let vs = vec![vec![0.0f32, 0.0], vec![2.0f32, 2.0]];
        let c = compute_centroid_and_radius(&vs, VectorMetric::Euclidean);
        assert!((c.values[0] - 1.0).abs() < 1e-6);
        assert!(c.radius > 0.0);
    }
}