Skip to main content

ailake_vec/
distance.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use ailake_core::{Centroid, VectorMetric};
3use half::f16;
4
5// ── Public API ────────────────────────────────────────────────────────────────
6
7pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
8    #[cfg(target_arch = "x86_64")]
9    {
10        #[cfg(feature = "avx512")]
11        if is_x86_feature_detected!("avx512f") {
12            return unsafe { avx512::dot(a, b) };
13        }
14        if is_x86_feature_detected!("avx2") {
15            return unsafe { avx2::dot(a, b) };
16        }
17    }
18    #[cfg(target_arch = "aarch64")]
19    if std::arch::is_aarch64_feature_detected!("neon") {
20        return unsafe { neon_impl::dot(a, b) };
21    }
22    dot_scalar(a, b)
23}
24
25pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
26    #[cfg(target_arch = "x86_64")]
27    {
28        #[cfg(feature = "avx512")]
29        if is_x86_feature_detected!("avx512f") {
30            return unsafe { avx512::euclidean(a, b) };
31        }
32        if is_x86_feature_detected!("avx2") {
33            return unsafe { avx2::euclidean(a, b) };
34        }
35    }
36    #[cfg(target_arch = "aarch64")]
37    if std::arch::is_aarch64_feature_detected!("neon") {
38        return unsafe { neon_impl::euclidean(a, b) };
39    }
40    euclidean_scalar(a, b)
41}
42
43pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
44    #[cfg(target_arch = "x86_64")]
45    {
46        #[cfg(feature = "avx512")]
47        if is_x86_feature_detected!("avx512f") {
48            return unsafe { avx512::cosine(a, b) };
49        }
50        if is_x86_feature_detected!("avx2") {
51            return unsafe { avx2::cosine(a, b) };
52        }
53    }
54    #[cfg(target_arch = "aarch64")]
55    if std::arch::is_aarch64_feature_detected!("neon") {
56        return unsafe { neon_impl::cosine(a, b) };
57    }
58    cosine_scalar(a, b)
59}
60
61pub fn exact_distance(metric: VectorMetric, a: &[f32], b: &[f32]) -> f32 {
62    match metric {
63        VectorMetric::Cosine => cosine_distance(a, b),
64        VectorMetric::Euclidean => euclidean_distance(a, b),
65        VectorMetric::DotProduct => -dot_product(a, b),
66        VectorMetric::NormalizedCosine => normalized_cosine_distance(a, b),
67    }
68}
69
70// ── F16 distance functions ────────────────────────────────────────────────────
71//
72// Query `a` stays F32 (one vector, lives in registers).
73// Database vector `b` is F16 (dequantized inline — no allocation).
74//
75// Fast path: F16C converts 8 F16 values to F32 in one instruction via
76// _mm256_cvtph_ps, then FMA accumulates. Eliminates scalar half::to_f32()
77// loop that dominates HNSW graph traversal on dim=128 vectors.
78
79pub fn cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
80    #[cfg(target_arch = "x86_64")]
81    {
82        #[cfg(feature = "avx512")]
83        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
84            return unsafe { avx512::cosine_f16(a, b) };
85        }
86        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
87            return unsafe { avx2_f16c::cosine(a, b) };
88        }
89    }
90    cosine_f16_scalar(a, b)
91}
92
93pub fn euclidean_distance_f16(a: &[f32], b: &[f16]) -> f32 {
94    #[cfg(target_arch = "x86_64")]
95    {
96        #[cfg(feature = "avx512")]
97        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
98            return unsafe { avx512::euclidean_f16(a, b) };
99        }
100        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
101            return unsafe { avx2_f16c::euclidean(a, b) };
102        }
103    }
104    euclidean_f16_scalar(a, b)
105}
106
107pub fn dot_product_f16(a: &[f32], b: &[f16]) -> f32 {
108    #[cfg(target_arch = "x86_64")]
109    {
110        #[cfg(feature = "avx512")]
111        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
112            return unsafe { avx512::dot_f16(a, b) };
113        }
114        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
115            return unsafe { avx2_f16c::dot(a, b) };
116        }
117    }
118    dot_f16_scalar(a, b)
119}
120
/// Returns a copy of `v` scaled to unit L2 length.
///
/// When the squared norm is below `1e-12` (a zero or near-zero vector) the
/// input is copied back unchanged rather than divided by a vanishing norm.
pub fn normalize_l2(v: &[f32]) -> Vec<f32> {
    let squared_norm = v.iter().map(|x| x * x).sum::<f32>();
    if squared_norm < 1e-12 {
        v.to_vec()
    } else {
        let scale = squared_norm.sqrt().recip();
        v.iter().map(|&component| component * scale).collect()
    }
}
130
131/// 1 - dot(a, b) for pre-normalized unit vectors — no sqrt, no norm computation.
132/// Equivalent to cosine distance but ~2× faster in the HNSW hot loop.
133pub fn normalized_cosine_distance(a: &[f32], b: &[f32]) -> f32 {
134    1.0 - dot_product(a, b)
135}
136
137pub fn normalized_cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
138    1.0 - dot_product_f16(a, b)
139}
140
141pub fn compute_centroid_and_radius(vectors: &[Vec<f32>], metric: VectorMetric) -> Centroid {
142    if vectors.is_empty() {
143        return Centroid {
144            values: vec![],
145            radius: 0.0,
146            metric,
147        };
148    }
149    let dim = vectors[0].len();
150    let n = vectors.len() as f32;
151    let centroid: Vec<f32> = (0..dim)
152        .map(|i| vectors.iter().map(|v| v[i]).sum::<f32>() / n)
153        .collect();
154    let radius = vectors
155        .iter()
156        .map(|v| exact_distance(metric, &centroid, v))
157        .fold(0.0_f32, f32::max);
158    Centroid {
159        values: centroid,
160        radius,
161        metric,
162    }
163}
164
165// ── Scalar fallbacks ──────────────────────────────────────────────────────────
166
/// Portable dot product over the common prefix of `a` and `b`.
#[inline(always)]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
171
/// Portable Euclidean (L2) distance over the common prefix of `a` and `b`.
#[inline(always)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| {
            let delta = p - q;
            delta * delta
        })
        .sum();
    squared.sqrt()
}
180
/// Portable cosine distance `1 - cos(a, b)`.
///
/// Only the common prefix of `a` and `b` contributes to the dot product and to
/// both norms, which is what the SIMD kernels in this file compute. Returns
/// `1.0` when either norm is exactly zero.
#[inline(always)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Truncate first so the norms cover the same elements as the dot product;
    // otherwise mismatched lengths would give a different answer than the
    // SIMD paths, which all work on `min(a.len(), b.len())` elements.
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);

    let (dot, na2, nb2) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na2, nb2), (&x, &y)| (dot + x * y, na2 + x * x, nb2 + y * y),
    );

    let na = na2.sqrt();
    let nb = nb2.sqrt();
    if na == 0.0 || nb == 0.0 {
        return 1.0;
    }
    1.0 - dot / (na * nb)
}
191
192#[inline(always)]
193fn cosine_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
194    let n = a.len().min(b.len());
195    let mut dot = 0.0f32;
196    let mut norm_a = 0.0f32;
197    let mut norm_b = 0.0f32;
198    for i in 0..n {
199        let ai = a[i];
200        let bi = b[i].to_f32();
201        dot += ai * bi;
202        norm_a += ai * ai;
203        norm_b += bi * bi;
204    }
205    let denom = (norm_a * norm_b).sqrt();
206    if denom < 1e-8 {
207        1.0
208    } else {
209        1.0 - dot / denom
210    }
211}
212
213#[inline(always)]
214fn euclidean_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
215    let n = a.len().min(b.len());
216    let mut sum = 0.0f32;
217    for i in 0..n {
218        let diff = a[i] - b[i].to_f32();
219        sum += diff * diff;
220    }
221    sum.sqrt()
222}
223
224#[inline(always)]
225fn dot_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
226    let n = a.len().min(b.len());
227    let mut acc = 0.0f32;
228    for i in 0..n {
229        acc += a[i] * b[i].to_f32();
230    }
231    acc
232}
233
234// ── x86_64 AVX2 + FMA ────────────────────────────────────────────────────────
235//
236// Compiled with target_feature = "avx2,fma". The compiler emits vfmadd231ps
237// instead of separate vmulps + vaddps, cutting inner-loop instruction count
238// by ~33% and reducing latency via fused operations.
239
#[cfg(target_arch = "x86_64")]
mod avx2 {
    use std::arch::x86_64::*;

    /// Horizontal sum of the eight lanes of `v`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX. This function carries no `#[target_feature]`
    /// of its own and is meant to be inlined into a kernel that enables it.
    #[inline(always)]
    pub unsafe fn hsum256(v: __m256) -> f32 {
        // 8 lanes -> 4 lanes.
        let upper = _mm256_extractf128_ps(v, 1);
        let lower = _mm256_castps256_ps128(v);
        let quad = _mm_add_ps(lower, upper);
        // 4 lanes -> 2 lanes (odd lanes folded onto even lanes).
        let odd_lanes = _mm_movehdup_ps(quad);
        let pairs = _mm_add_ps(quad, odd_lanes);
        // 2 lanes -> 1 lane.
        let high_pair = _mm_movehl_ps(odd_lanes, pairs);
        _mm_cvtss_f32(_mm_add_ss(pairs, high_pair))
    }

    /// dot(a, b) over the common prefix: AVX2+FMA, 2× unrolled (16 f32/iter).
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());

        // Two independent accumulators keep two FMA chains in flight.
        let mut even = _mm256_setzero_ps();
        let mut odd = _mm256_setzero_ps();

        let mut off = 0usize;
        while off + 16 <= len {
            let lhs0 = _mm256_loadu_ps(pa.add(off));
            let rhs0 = _mm256_loadu_ps(pb.add(off));
            let lhs1 = _mm256_loadu_ps(pa.add(off + 8));
            let rhs1 = _mm256_loadu_ps(pb.add(off + 8));
            even = _mm256_fmadd_ps(lhs0, rhs0, even);
            odd = _mm256_fmadd_ps(lhs1, rhs1, odd);
            off += 16;
        }

        // At most one full 8-lane block is left after the unrolled loop.
        if off + 8 <= len {
            let lhs = _mm256_loadu_ps(pa.add(off));
            let rhs = _mm256_loadu_ps(pb.add(off));
            even = _mm256_fmadd_ps(lhs, rhs, even);
            off += 8;
        }

        // Fewer than 8 elements remain; finish them in scalar code.
        let mut total = hsum256(_mm256_add_ps(even, odd));
        for (&x, &y) in a[off..len].iter().zip(&b[off..len]) {
            total += x * y;
        }
        total
    }

    /// ||a - b|| (square root included) over the common prefix: AVX2+FMA,
    /// 2× unrolled.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());

        let mut even = _mm256_setzero_ps();
        let mut odd = _mm256_setzero_ps();

        let mut off = 0usize;
        while off + 16 <= len {
            let delta0 = _mm256_sub_ps(_mm256_loadu_ps(pa.add(off)), _mm256_loadu_ps(pb.add(off)));
            let delta1 = _mm256_sub_ps(
                _mm256_loadu_ps(pa.add(off + 8)),
                _mm256_loadu_ps(pb.add(off + 8)),
            );
            even = _mm256_fmadd_ps(delta0, delta0, even);
            odd = _mm256_fmadd_ps(delta1, delta1, odd);
            off += 16;
        }

        // At most one full 8-lane block is left after the unrolled loop.
        if off + 8 <= len {
            let delta = _mm256_sub_ps(_mm256_loadu_ps(pa.add(off)), _mm256_loadu_ps(pb.add(off)));
            even = _mm256_fmadd_ps(delta, delta, even);
            off += 8;
        }

        let mut total = hsum256(_mm256_add_ps(even, odd));
        for (&x, &y) in a[off..len].iter().zip(&b[off..len]) {
            let delta = x - y;
            total += delta * delta;
        }
        total.sqrt()
    }

    /// 1 - cos(a, b) over the common prefix: AVX2+FMA, one pass accumulating
    /// the dot product and both squared norms. Returns `1.0` when either norm
    /// is exactly zero.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let len = a.len().min(b.len());
        let (pa, pb) = (a.as_ptr(), b.as_ptr());

        let mut ab = _mm256_setzero_ps();
        let mut aa = _mm256_setzero_ps();
        let mut bb = _mm256_setzero_ps();

        let mut off = 0usize;
        while off + 8 <= len {
            let va = _mm256_loadu_ps(pa.add(off));
            let vb = _mm256_loadu_ps(pb.add(off));
            ab = _mm256_fmadd_ps(va, vb, ab);
            aa = _mm256_fmadd_ps(va, va, aa);
            bb = _mm256_fmadd_ps(vb, vb, bb);
            off += 8;
        }

        let mut dot = hsum256(ab);
        let mut sq_a = hsum256(aa);
        let mut sq_b = hsum256(bb);

        for (&x, &y) in a[off..len].iter().zip(&b[off..len]) {
            dot += x * y;
            sq_a += x * x;
            sq_b += y * y;
        }

        let norm_a = sq_a.sqrt();
        let norm_b = sq_b.sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            1.0
        } else {
            1.0 - dot / (norm_a * norm_b)
        }
    }
}
369
370// ── x86_64 AVX2 + F16C — F16 hot path ────────────────────────────────────────
371//
// _mm256_cvtph_ps converts 8 packed F16 (as __m128i) to 8 F32 in a single instruction.
373// Combined with FMA, this replaces 8 scalar half::to_f32() calls per iteration.
374// Critical hot path: every HNSW edge traversal calls one of these functions.
375
376#[cfg(target_arch = "x86_64")]
377mod avx2_f16c {
378    use half::f16;
379    use std::arch::x86_64::*;
380
381    use super::avx2::hsum256;
382
383    /// dot(a_f32, b_f16) — AVX2+F16C+FMA, 16 F16/iter.
384    #[target_feature(enable = "avx2,f16c,fma")]
385    pub unsafe fn dot(a: &[f32], b: &[f16]) -> f32 {
386        let n = a.len().min(b.len());
387        let ap = a.as_ptr();
388        let bp = b.as_ptr() as *const u16;
389
390        let mut acc0 = _mm256_setzero_ps();
391        let mut acc1 = _mm256_setzero_ps();
392
393        let chunks16 = n / 16;
394        for i in 0..chunks16 {
395            let base = i * 16;
396            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
397            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
398            let a0 = _mm256_loadu_ps(ap.add(base));
399            let a1 = _mm256_loadu_ps(ap.add(base + 8));
400            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
401            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
402        }
403
404        let chunks8 = n / 8;
405        if chunks8 > chunks16 * 2 {
406            let base = chunks16 * 16;
407            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
408            let a0 = _mm256_loadu_ps(ap.add(base));
409            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
410        }
411
412        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
413        for i in (chunks8 * 8)..n {
414            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
415        }
416        sum
417    }
418
419    /// ||a_f32 - b_f16||² — AVX2+F16C+FMA, 16 F16/iter.
420    #[target_feature(enable = "avx2,f16c,fma")]
421    pub unsafe fn euclidean(a: &[f32], b: &[f16]) -> f32 {
422        let n = a.len().min(b.len());
423        let ap = a.as_ptr();
424        let bp = b.as_ptr() as *const u16;
425
426        let mut acc0 = _mm256_setzero_ps();
427        let mut acc1 = _mm256_setzero_ps();
428
429        let chunks16 = n / 16;
430        for i in 0..chunks16 {
431            let base = i * 16;
432            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
433            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
434            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
435            let d1 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base + 8)), b1);
436            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
437            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
438        }
439
440        let chunks8 = n / 8;
441        if chunks8 > chunks16 * 2 {
442            let base = chunks16 * 16;
443            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
444            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
445            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
446        }
447
448        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
449        for i in (chunks8 * 8)..n {
450            let diff = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
451            sum += diff * diff;
452        }
453        sum.sqrt()
454    }
455
456    /// 1 - cos(a_f32, b_f16) — AVX2+F16C+FMA, single pass.
457    #[target_feature(enable = "avx2,f16c,fma")]
458    pub unsafe fn cosine(a: &[f32], b: &[f16]) -> f32 {
459        let n = a.len().min(b.len());
460        let ap = a.as_ptr();
461        let bp = b.as_ptr() as *const u16;
462
463        let mut dot_acc = _mm256_setzero_ps();
464        let mut na_acc = _mm256_setzero_ps();
465        let mut nb_acc = _mm256_setzero_ps();
466
467        let chunks8 = n / 8;
468        for i in 0..chunks8 {
469            let base = i * 8;
470            let av = _mm256_loadu_ps(ap.add(base));
471            let bv = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
472            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
473            na_acc = _mm256_fmadd_ps(av, av, na_acc);
474            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
475        }
476
477        let mut dot = hsum256(dot_acc);
478        let mut na2 = hsum256(na_acc);
479        let mut nb2 = hsum256(nb_acc);
480
481        for i in (chunks8 * 8)..n {
482            let ai = *ap.add(i);
483            let bi = f16::from_bits(*bp.add(i)).to_f32();
484            dot += ai * bi;
485            na2 += ai * ai;
486            nb2 += bi * bi;
487        }
488
489        let denom = (na2 * nb2).sqrt();
490        if denom < 1e-8 {
491            1.0
492        } else {
493            1.0 - dot / denom
494        }
495    }
496}
497
498// ── x86_64 AVX-512F — forward compatibility ───────────────────────────────────
499//
// 16 f32/iter (vs 8 for AVX2). Runtime-detected — skipped on CPUs without
// avx512f, active on Xeon Scalable, AMD Zen 4+, and Intel Ice/Tiger/Rocket Lake
// client parts (12th-gen+ hybrid Core CPUs ship with AVX-512 disabled).
502// Requires Rust ≥ 1.89 (AVX-512 intrinsics stabilised there). Gated behind
503// the `avx512` feature so the default/manylinux build always succeeds.
504
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
mod avx512 {
    use half::f16;
    use std::arch::x86_64::*;

    /// Horizontal sum of the sixteen lanes of `v`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F. This function carries no
    /// `#[target_feature]` of its own and is meant to be inlined into a kernel
    /// that enables it.
    #[inline(always)]
    unsafe fn hsum512(v: __m512) -> f32 {
        // _mm512_reduce_add_ps stabilized Rust 1.89; _mm512_extractf32x8_ps needs avx512dq.
        // Store all 16 lanes to stack (avx512f), reload as two __m256 (avx), then reduce.
        let mut buf = [0.0f32; 16];
        _mm512_storeu_ps(buf.as_mut_ptr(), v);
        let lo = _mm256_loadu_ps(buf.as_ptr());
        let hi = _mm256_loadu_ps(buf.as_ptr().add(8));
        let sum256 = _mm256_add_ps(lo, hi);
        let hi128 = _mm256_extractf128_ps(sum256, 1);
        let lo128 = _mm256_castps256_ps128(sum256);
        let sum128 = _mm_add_ps(lo128, hi128);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf2))
    }

    /// dot(a, b) over the common prefix: AVX-512F+FMA, 16 f32/iter.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            acc = _mm512_fmadd_ps(
                _mm512_loadu_ps(ap.add(base)),
                _mm512_loadu_ps(bp.add(base)),
                acc,
            );
        }
        // Fewer than 16 elements remain; finish them in scalar code.
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// ||a - b|| (square root included) over the common prefix: AVX-512F+FMA.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), _mm512_loadu_ps(bp.add(base)));
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// 1 - cos(a, b) over the common prefix: AVX-512F+FMA, single pass.
    /// Returns `1.0` when either norm is exactly zero.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let av = _mm512_loadu_ps(ap.add(base));
            let bv = _mm512_loadu_ps(bp.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }

    /// dot(a_f32, b_f16) over the common prefix: AVX-512F+FMA, 16 F16/iter.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn dot_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            // One 256-bit load + `_mm512_cvtph_ps` widens all 16 F16 lanes using
            // AVX-512F only. Stitching two 256-bit halves together with
            // `_mm512_insertf32x8` would require AVX-512DQ, which this function
            // does not enable and the dispatcher does not detect.
            let bv = _mm512_cvtph_ps(_mm256_loadu_si256(bp.add(base) as *const __m256i));
            acc = _mm512_fmadd_ps(_mm512_loadu_ps(ap.add(base)), bv, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
        }
        sum
    }

    /// ||a_f32 - b_f16|| (square root included) over the common prefix:
    /// AVX-512F+FMA, 16 F16/iter.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn euclidean_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            // AVX-512F-only widening of 16 F16 lanes (no AVX-512DQ insert).
            let bv = _mm512_cvtph_ps(_mm256_loadu_si256(bp.add(base) as *const __m256i));
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), bv);
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
            sum += d * d;
        }
        sum.sqrt()
    }

    /// 1 - cos(a_f32, b_f16) over the common prefix: AVX-512F+FMA, single
    /// pass. Returns `1.0` when the product of the norms is below `1e-8`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn cosine_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            // AVX-512F-only widening of 16 F16 lanes (no AVX-512DQ insert).
            let bv = _mm512_cvtph_ps(_mm256_loadu_si256(bp.add(base) as *const __m256i));
            let av = _mm512_loadu_ps(ap.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = f16::from_bits(*bp.add(i)).to_f32();
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let denom = (na2 * nb2).sqrt();
        if denom < 1e-8 {
            1.0
        } else {
            1.0 - dot / denom
        }
    }
}
686
687// ── aarch64 NEON ──────────────────────────────────────────────────────────────
688
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    use std::arch::aarch64::*;

    /// dot(a, b) over the common prefix: NEON, 4 f32/iter.
    ///
    /// Uses the fused `vfmaq_f32` (one rounding per lane), in line with the
    /// FMA-based x86 kernels, rather than a separate multiply and add.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let av = vld1q_f32(a.as_ptr().add(base));
            let bv = vld1q_f32(b.as_ptr().add(base));
            acc = vfmaq_f32(acc, av, bv);
        }
        // Fewer than 4 elements remain; finish them in scalar code.
        let mut sum = vaddvq_f32(acc);
        for i in (chunks * 4)..n {
            sum += a[i] * b[i];
        }
        sum
    }

    /// ||a - b|| (square root included) over the common prefix: NEON, 4 f32/iter.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let d = vsubq_f32(
                vld1q_f32(a.as_ptr().add(base)),
                vld1q_f32(b.as_ptr().add(base)),
            );
            acc = vfmaq_f32(acc, d, d);
        }
        let mut sum = vaddvq_f32(acc);
        for i in (chunks * 4)..n {
            let d = a[i] - b[i];
            sum += d * d;
        }
        sum.sqrt()
    }

    /// 1 - cos(a, b) over the common prefix: NEON, single pass for the dot
    /// product and both squared norms. Returns `1.0` when either norm is
    /// exactly zero.
    ///
    /// # Safety
    ///
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut dot_acc = vdupq_n_f32(0.0);
        let mut na_acc = vdupq_n_f32(0.0);
        let mut nb_acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let av = vld1q_f32(a.as_ptr().add(base));
            let bv = vld1q_f32(b.as_ptr().add(base));
            dot_acc = vfmaq_f32(dot_acc, av, bv);
            na_acc = vfmaq_f32(na_acc, av, av);
            nb_acc = vfmaq_f32(nb_acc, bv, bv);
        }
        let mut dot = vaddvq_f32(dot_acc);
        let mut na2 = vaddvq_f32(na_acc);
        let mut nb2 = vaddvq_f32(nb_acc);
        for i in (chunks * 4)..n {
            dot += a[i] * b[i];
            na2 += a[i] * a[i];
            nb2 += b[i] * b[i];
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }
}
762
763// ── Tests ─────────────────────────────────────────────────────────────────────
764
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic values in [-1, 1) so tail-length tests need no RNG.
    fn ramp(len: usize, stride: usize, modulus: usize) -> Vec<f32> {
        (0..len)
            .map(|i| ((i * stride + 3) % modulus) as f32 / modulus as f32 * 2.0 - 1.0)
            .collect()
    }

    #[test]
    fn cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!(cosine_distance(&v, &v).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal() {
        assert!((cosine_distance(&[1.0f32, 0.0], &[0.0f32, 1.0]) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn euclidean_basic() {
        assert!((euclidean_distance(&[0.0f32, 0.0], &[3.0f32, 4.0]) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn dot_basic() {
        assert!((dot_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]) - 32.0).abs() < 1e-5);
    }

    #[test]
    fn simd_matches_scalar_dim128() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(99);
        let a: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();

        let dot_s = dot_scalar(&a, &b);
        let euclid_s = euclidean_scalar(&a, &b);
        let cos_s = cosine_scalar(&a, &b);

        let dot_f = dot_product(&a, &b);
        let euclid_f = euclidean_distance(&a, &b);
        let cos_f = cosine_distance(&a, &b);

        assert!(
            (dot_f - dot_s).abs() < 1e-4,
            "dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-4,
            "euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-4,
            "cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    #[test]
    fn f16_simd_matches_scalar() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(42);
        let a: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b_f32: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b: Vec<f16> = b_f32.iter().map(|&x| f16::from_f32(x)).collect();

        let dot_s = dot_f16_scalar(&a, &b);
        let euclid_s = euclidean_f16_scalar(&a, &b);
        let cos_s = cosine_f16_scalar(&a, &b);

        let dot_f = dot_product_f16(&a, &b);
        let euclid_f = euclidean_distance_f16(&a, &b);
        let cos_f = cosine_distance_f16(&a, &b);

        // F16 rounding introduces small error — tolerate 1e-3
        assert!(
            (dot_f - dot_s).abs() < 1e-3,
            "f16 dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-3,
            "f16 euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-3,
            "f16 cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    /// dim=128 is a multiple of every SIMD width, so it never reaches the
    /// leftover-block and scalar-tail code in the kernels. These lengths do.
    #[test]
    fn simd_matches_scalar_on_tail_lengths() {
        for dim in [1usize, 7, 8, 9, 15, 16, 17, 24, 31, 33, 100] {
            let a = ramp(dim, 7, 11);
            let b = ramp(dim, 5, 13);
            let b16: Vec<f16> = b.iter().map(|&x| f16::from_f32(x)).collect();

            let dot_d = (dot_product(&a, &b) - dot_scalar(&a, &b)).abs();
            let euclid_d = (euclidean_distance(&a, &b) - euclidean_scalar(&a, &b)).abs();
            let cos_d = (cosine_distance(&a, &b) - cosine_scalar(&a, &b)).abs();
            assert!(dot_d < 1e-4, "dot mismatch at dim={dim}: {dot_d}");
            assert!(euclid_d < 1e-4, "euclidean mismatch at dim={dim}: {euclid_d}");
            assert!(cos_d < 1e-4, "cosine mismatch at dim={dim}: {cos_d}");

            let dot16_d = (dot_product_f16(&a, &b16) - dot_f16_scalar(&a, &b16)).abs();
            let euclid16_d =
                (euclidean_distance_f16(&a, &b16) - euclidean_f16_scalar(&a, &b16)).abs();
            let cos16_d = (cosine_distance_f16(&a, &b16) - cosine_f16_scalar(&a, &b16)).abs();
            assert!(dot16_d < 1e-3, "f16 dot mismatch at dim={dim}: {dot16_d}");
            assert!(
                euclid16_d < 1e-3,
                "f16 euclidean mismatch at dim={dim}: {euclid16_d}"
            );
            assert!(cos16_d < 1e-3, "f16 cosine mismatch at dim={dim}: {cos16_d}");
        }
    }

    #[test]
    fn normalize_l2_unit() {
        let v = vec![3.0f32, 4.0];
        let n = normalize_l2(&v);
        let norm: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6, "norm={norm}");
        assert!((n[0] - 0.6).abs() < 1e-6);
        assert!((n[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn normalized_cosine_matches_cosine_on_unit_vecs() {
        let a = normalize_l2(&[1.0f32, 1.0, 0.0]);
        let b = normalize_l2(&[1.0f32, 0.0, 1.0]);
        let cos = cosine_distance(&a, &b);
        let ncos = normalized_cosine_distance(&a, &b);
        assert!((cos - ncos).abs() < 1e-5, "cos={cos} ncos={ncos}");
    }

    #[test]
    fn centroid_single() {
        let v = vec![vec![1.0f32, 2.0, 3.0]];
        let c = compute_centroid_and_radius(&v, VectorMetric::Cosine);
        assert_eq!(c.values, vec![1.0, 2.0, 3.0]);
        assert!(c.radius < 1e-6, "radius={}", c.radius);
    }

    #[test]
    fn centroid_two_points() {
        let vs = vec![vec![0.0f32, 0.0], vec![2.0f32, 2.0]];
        let c = compute_centroid_and_radius(&vs, VectorMetric::Euclidean);
        assert!((c.values[0] - 1.0).abs() < 1e-6);
        assert!(c.radius > 0.0);
    }
}