Skip to main content

ailake_vec/
distance.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use ailake_core::{Centroid, VectorMetric};
3use half::f16;
4
5// ── Public API ────────────────────────────────────────────────────────────────
6
7pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
8    debug_assert_eq!(
9        a.len(),
10        b.len(),
11        "dot_product: dimension mismatch {} vs {}",
12        a.len(),
13        b.len()
14    );
15    #[cfg(target_arch = "x86_64")]
16    {
17        #[cfg(feature = "avx512")]
18        if is_x86_feature_detected!("avx512f") {
19            return unsafe { avx512::dot(a, b) };
20        }
21        if is_x86_feature_detected!("avx2") {
22            return unsafe { avx2::dot(a, b) };
23        }
24    }
25    #[cfg(target_arch = "aarch64")]
26    if std::arch::is_aarch64_feature_detected!("neon") {
27        return unsafe { neon_impl::dot(a, b) };
28    }
29    dot_scalar(a, b)
30}
31
32pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
33    debug_assert_eq!(
34        a.len(),
35        b.len(),
36        "euclidean_distance: dimension mismatch {} vs {}",
37        a.len(),
38        b.len()
39    );
40    #[cfg(target_arch = "x86_64")]
41    {
42        #[cfg(feature = "avx512")]
43        if is_x86_feature_detected!("avx512f") {
44            return unsafe { avx512::euclidean(a, b) };
45        }
46        if is_x86_feature_detected!("avx2") {
47            return unsafe { avx2::euclidean(a, b) };
48        }
49    }
50    #[cfg(target_arch = "aarch64")]
51    if std::arch::is_aarch64_feature_detected!("neon") {
52        return unsafe { neon_impl::euclidean(a, b) };
53    }
54    euclidean_scalar(a, b)
55}
56
57pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
58    debug_assert_eq!(
59        a.len(),
60        b.len(),
61        "cosine_distance: dimension mismatch {} vs {}",
62        a.len(),
63        b.len()
64    );
65    #[cfg(target_arch = "x86_64")]
66    {
67        #[cfg(feature = "avx512")]
68        if is_x86_feature_detected!("avx512f") {
69            return unsafe { avx512::cosine(a, b) };
70        }
71        if is_x86_feature_detected!("avx2") {
72            return unsafe { avx2::cosine(a, b) };
73        }
74    }
75    #[cfg(target_arch = "aarch64")]
76    if std::arch::is_aarch64_feature_detected!("neon") {
77        return unsafe { neon_impl::cosine(a, b) };
78    }
79    cosine_scalar(a, b)
80}
81
82pub fn exact_distance(metric: VectorMetric, a: &[f32], b: &[f32]) -> f32 {
83    match metric {
84        VectorMetric::Cosine => cosine_distance(a, b),
85        VectorMetric::Euclidean => euclidean_distance(a, b),
86        VectorMetric::DotProduct => -dot_product(a, b),
87        VectorMetric::NormalizedCosine => normalized_cosine_distance(a, b),
88    }
89}
90
91// ── F16 distance functions ────────────────────────────────────────────────────
92//
93// Query `a` stays F32 (one vector, lives in registers).
94// Database vector `b` is F16 (dequantized inline — no allocation).
95//
96// Fast path: F16C converts 8 F16 values to F32 in one instruction via
97// _mm256_cvtph_ps, then FMA accumulates. Eliminates scalar half::to_f32()
98// loop that dominates HNSW graph traversal on dim=128 vectors.
99
100pub fn cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
101    debug_assert_eq!(
102        a.len(),
103        b.len(),
104        "cosine_distance_f16: dimension mismatch {} vs {}",
105        a.len(),
106        b.len()
107    );
108    #[cfg(target_arch = "x86_64")]
109    {
110        #[cfg(feature = "avx512")]
111        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
112            return unsafe { avx512::cosine_f16(a, b) };
113        }
114        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
115            return unsafe { avx2_f16c::cosine(a, b) };
116        }
117    }
118    cosine_f16_scalar(a, b)
119}
120
121pub fn euclidean_distance_f16(a: &[f32], b: &[f16]) -> f32 {
122    debug_assert_eq!(
123        a.len(),
124        b.len(),
125        "euclidean_distance_f16: dimension mismatch {} vs {}",
126        a.len(),
127        b.len()
128    );
129    #[cfg(target_arch = "x86_64")]
130    {
131        #[cfg(feature = "avx512")]
132        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
133            return unsafe { avx512::euclidean_f16(a, b) };
134        }
135        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
136            return unsafe { avx2_f16c::euclidean(a, b) };
137        }
138    }
139    euclidean_f16_scalar(a, b)
140}
141
142pub fn dot_product_f16(a: &[f32], b: &[f16]) -> f32 {
143    debug_assert_eq!(
144        a.len(),
145        b.len(),
146        "dot_product_f16: dimension mismatch {} vs {}",
147        a.len(),
148        b.len()
149    );
150    #[cfg(target_arch = "x86_64")]
151    {
152        #[cfg(feature = "avx512")]
153        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") {
154            return unsafe { avx512::dot_f16(a, b) };
155        }
156        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c") {
157            return unsafe { avx2_f16c::dot(a, b) };
158        }
159    }
160    dot_f16_scalar(a, b)
161}
162
/// Returns a copy of `v` scaled to unit L2 length.
///
/// When the squared norm is below `1e-12` (the empty slice, the all-zero
/// vector, or any vector that is merely tiny), the input is copied back
/// unchanged instead of being divided by a near-zero norm.
pub fn normalize_l2(v: &[f32]) -> Vec<f32> {
    let squared_norm: f32 = v.iter().map(|component| component * component).sum();
    if squared_norm < 1e-12 {
        // Too small to divide by safely; hand back the input as-is.
        return v.to_owned();
    }
    // `recip()` is exactly `1.0 / x`, so scaling matches a per-element multiply
    // by the inverse norm.
    let scale = squared_norm.sqrt().recip();
    let mut unit = Vec::with_capacity(v.len());
    unit.extend(v.iter().map(|&component| component * scale));
    unit
}
172
173/// 1 - dot(a, b) for pre-normalized unit vectors — no sqrt, no norm computation.
174/// Equivalent to cosine distance but ~2× faster in the HNSW hot loop.
175pub fn normalized_cosine_distance(a: &[f32], b: &[f32]) -> f32 {
176    1.0 - dot_product(a, b)
177}
178
179pub fn normalized_cosine_distance_f16(a: &[f32], b: &[f16]) -> f32 {
180    1.0 - dot_product_f16(a, b)
181}
182
183pub fn compute_centroid_and_radius(vectors: &[Vec<f32>], metric: VectorMetric) -> Centroid {
184    if vectors.is_empty() {
185        return Centroid {
186            values: vec![],
187            radius: 0.0,
188            metric,
189        };
190    }
191    let dim = vectors[0].len();
192    let n = vectors.len() as f32;
193    let centroid: Vec<f32> = (0..dim)
194        .map(|i| vectors.iter().map(|v| v[i]).sum::<f32>() / n)
195        .collect();
196    let radius = vectors
197        .iter()
198        .map(|v| exact_distance(metric, &centroid, v))
199        .fold(0.0_f32, f32::max);
200    Centroid {
201        values: centroid,
202        radius,
203        metric,
204    }
205}
206
207// ── Scalar fallbacks ──────────────────────────────────────────────────────────
208
/// Portable inner product over the common prefix of `a` and `b`.
#[inline(always)]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
213
/// Portable Euclidean (L2) distance over the common prefix of `a` and `b`.
#[inline(always)]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    squared.sqrt()
}
222
/// Portable cosine distance `1 - (a·b) / (|a|·|b|)`.
///
/// Each norm is taken over its whole slice, and the dot product over the
/// common prefix. Returns exactly `1.0` when either norm is zero.
#[inline(always)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (l2_norm(a), l2_norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 1.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    1.0 - dot / (norm_a * norm_b)
}
233
234#[inline(always)]
235fn cosine_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
236    let n = a.len().min(b.len());
237    let mut dot = 0.0f32;
238    let mut norm_a = 0.0f32;
239    let mut norm_b = 0.0f32;
240    for i in 0..n {
241        let ai = a[i];
242        let bi = b[i].to_f32();
243        dot += ai * bi;
244        norm_a += ai * ai;
245        norm_b += bi * bi;
246    }
247    let denom = (norm_a * norm_b).sqrt();
248    if denom < 1e-8 {
249        1.0
250    } else {
251        1.0 - dot / denom
252    }
253}
254
255#[inline(always)]
256fn euclidean_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
257    let n = a.len().min(b.len());
258    let mut sum = 0.0f32;
259    for i in 0..n {
260        let diff = a[i] - b[i].to_f32();
261        sum += diff * diff;
262    }
263    sum.sqrt()
264}
265
266#[inline(always)]
267fn dot_f16_scalar(a: &[f32], b: &[f16]) -> f32 {
268    let n = a.len().min(b.len());
269    let mut acc = 0.0f32;
270    for i in 0..n {
271        acc += a[i] * b[i].to_f32();
272    }
273    acc
274}
275
276// ── x86_64 AVX2 + FMA ────────────────────────────────────────────────────────
277//
278// Compiled with target_feature = "avx2,fma". The compiler emits vfmadd231ps
279// instead of separate vmulps + vaddps, cutting inner-loop instruction count
280// by ~33% and reducing latency via fused operations.
281
#[cfg(target_arch = "x86_64")]
mod avx2 {
    use std::arch::x86_64::*;

    /// Horizontal sum of the eight `f32` lanes of `v`.
    ///
    /// # Safety
    /// Uses AVX (`_mm256_extractf128_ps`) and SSE3 (`_mm_movehdup_ps`); the CPU
    /// must support both. Marked `inline(always)` so it is folded into its
    /// `target_feature` callers.
    #[inline(always)]
    pub unsafe fn hsum256(v: __m256) -> f32 {
        // 256 → 128 bits, then 4 → 2 → 1 lanes.
        let hi = _mm256_extractf128_ps(v, 1);
        let lo = _mm256_castps256_ps128(v);
        let s = _mm_add_ps(lo, hi);
        let shuf = _mm_movehdup_ps(s);
        let sums = _mm_add_ps(s, shuf);
        let shuf = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf))
    }

    /// `Σ a[i]·b[i]` over the first `min(a.len(), b.len())` elements.
    /// AVX2+FMA, 2× unrolled (16 f32/iter).
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        // Two independent accumulators so consecutive FMAs don't serialise on
        // a single register.
        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            let a1 = _mm256_loadu_ps(ap.add(base + 8));
            let b1 = _mm256_loadu_ps(bp.add(base + 8));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
        }

        // At most one full 8-lane block can remain after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let a0 = _mm256_loadu_ps(ap.add(base));
            let b0 = _mm256_loadu_ps(bp.add(base));
            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
        }

        // Reduce, then finish the last < 8 elements in scalar code.
        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// Euclidean distance `sqrt(Σ (a[i] - b[i])²)` over the first
    /// `min(a.len(), b.len())` elements. AVX2+FMA, 2× unrolled. Note that
    /// the square root is applied, so this is not the squared distance.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();

        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), _mm256_loadu_ps(bp.add(base)));
            let d1 = _mm256_sub_ps(
                _mm256_loadu_ps(ap.add(base + 8)),
                _mm256_loadu_ps(bp.add(base + 8)),
            );
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
        }

        // At most one full 8-lane block can remain after the 16-wide loop.
        let chunks8 = n / 8;
        if chunks8 > chunks16 * 2 {
            let base = chunks16 * 16;
            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), _mm256_loadu_ps(bp.add(base)));
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
        }

        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
        for i in (chunks8 * 8)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance `1 - cos(a, b)`, computing the dot product and both
    /// squared norms in a single AVX2+FMA pass. Returns `1.0` when either
    /// norm is zero.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();

        let mut dot_acc = _mm256_setzero_ps();
        let mut na_acc = _mm256_setzero_ps();
        let mut nb_acc = _mm256_setzero_ps();

        let chunks8 = n / 8;
        for i in 0..chunks8 {
            let base = i * 8;
            let av = _mm256_loadu_ps(ap.add(base));
            let bv = _mm256_loadu_ps(bp.add(base));
            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm256_fmadd_ps(av, av, na_acc);
            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
        }

        let mut dot = hsum256(dot_acc);
        let mut na2 = hsum256(na_acc);
        let mut nb2 = hsum256(nb_acc);

        for i in (chunks8 * 8)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }

        let na = na2.sqrt();
        let nb = nb2.sqrt();
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }
}
411
// ── x86_64 AVX2 + F16C — F16 hot path ────────────────────────────────────────
//
// _mm256_cvtph_ps (vcvtph2ps) widens 8 packed F16 values (held in an __m128i)
// to 8 F32 lanes in a single instruction. Combined with FMA, this replaces 8
// scalar half::to_f32() calls per 8-lane step.
// Critical hot path: every HNSW edge traversal calls one of these functions.
417
418#[cfg(target_arch = "x86_64")]
419mod avx2_f16c {
420    use half::f16;
421    use std::arch::x86_64::*;
422
423    use super::avx2::hsum256;
424
425    /// dot(a_f32, b_f16) — AVX2+F16C+FMA, 16 F16/iter.
426    #[target_feature(enable = "avx2,f16c,fma")]
427    pub unsafe fn dot(a: &[f32], b: &[f16]) -> f32 {
428        let n = a.len().min(b.len());
429        let ap = a.as_ptr();
430        let bp = b.as_ptr() as *const u16;
431
432        let mut acc0 = _mm256_setzero_ps();
433        let mut acc1 = _mm256_setzero_ps();
434
435        let chunks16 = n / 16;
436        for i in 0..chunks16 {
437            let base = i * 16;
438            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
439            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
440            let a0 = _mm256_loadu_ps(ap.add(base));
441            let a1 = _mm256_loadu_ps(ap.add(base + 8));
442            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
443            acc1 = _mm256_fmadd_ps(a1, b1, acc1);
444        }
445
446        let chunks8 = n / 8;
447        if chunks8 > chunks16 * 2 {
448            let base = chunks16 * 16;
449            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
450            let a0 = _mm256_loadu_ps(ap.add(base));
451            acc0 = _mm256_fmadd_ps(a0, b0, acc0);
452        }
453
454        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
455        for i in (chunks8 * 8)..n {
456            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
457        }
458        sum
459    }
460
461    /// ||a_f32 - b_f16||² — AVX2+F16C+FMA, 16 F16/iter.
462    #[target_feature(enable = "avx2,f16c,fma")]
463    pub unsafe fn euclidean(a: &[f32], b: &[f16]) -> f32 {
464        let n = a.len().min(b.len());
465        let ap = a.as_ptr();
466        let bp = b.as_ptr() as *const u16;
467
468        let mut acc0 = _mm256_setzero_ps();
469        let mut acc1 = _mm256_setzero_ps();
470
471        let chunks16 = n / 16;
472        for i in 0..chunks16 {
473            let base = i * 16;
474            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
475            let b1 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base + 8) as *const __m128i));
476            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
477            let d1 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base + 8)), b1);
478            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
479            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
480        }
481
482        let chunks8 = n / 8;
483        if chunks8 > chunks16 * 2 {
484            let base = chunks16 * 16;
485            let b0 = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
486            let d0 = _mm256_sub_ps(_mm256_loadu_ps(ap.add(base)), b0);
487            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
488        }
489
490        let mut sum = hsum256(_mm256_add_ps(acc0, acc1));
491        for i in (chunks8 * 8)..n {
492            let diff = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
493            sum += diff * diff;
494        }
495        sum.sqrt()
496    }
497
498    /// 1 - cos(a_f32, b_f16) — AVX2+F16C+FMA, single pass.
499    #[target_feature(enable = "avx2,f16c,fma")]
500    pub unsafe fn cosine(a: &[f32], b: &[f16]) -> f32 {
501        let n = a.len().min(b.len());
502        let ap = a.as_ptr();
503        let bp = b.as_ptr() as *const u16;
504
505        let mut dot_acc = _mm256_setzero_ps();
506        let mut na_acc = _mm256_setzero_ps();
507        let mut nb_acc = _mm256_setzero_ps();
508
509        let chunks8 = n / 8;
510        for i in 0..chunks8 {
511            let base = i * 8;
512            let av = _mm256_loadu_ps(ap.add(base));
513            let bv = _mm256_cvtph_ps(_mm_loadu_si128(bp.add(base) as *const __m128i));
514            dot_acc = _mm256_fmadd_ps(av, bv, dot_acc);
515            na_acc = _mm256_fmadd_ps(av, av, na_acc);
516            nb_acc = _mm256_fmadd_ps(bv, bv, nb_acc);
517        }
518
519        let mut dot = hsum256(dot_acc);
520        let mut na2 = hsum256(na_acc);
521        let mut nb2 = hsum256(nb_acc);
522
523        for i in (chunks8 * 8)..n {
524            let ai = *ap.add(i);
525            let bi = f16::from_bits(*bp.add(i)).to_f32();
526            dot += ai * bi;
527            na2 += ai * ai;
528            nb2 += bi * bi;
529        }
530
531        let denom = (na2 * nb2).sqrt();
532        if denom < 1e-8 {
533            1.0
534        } else {
535            1.0 - dot / denom
536        }
537    }
538}
539
// ── x86_64 AVX-512F — forward compatibility ───────────────────────────────────
//
// 16 f32/iter (vs 8 for AVX2). Selected at runtime only when the CPU reports
// avx512f; otherwise the AVX2 path is used. AVX-512F is present on e.g. Xeon
// Scalable, AMD Zen 4+, and Intel Ice Lake/Tiger Lake client parts. It is NOT
// available on Intel Core 12th-gen+ consumer CPUs (Alder Lake onward), where
// AVX-512 is disabled.
// Requires Rust ≥ 1.89 (AVX-512 intrinsics stabilised there). Gated behind
// the `avx512` feature so the default/manylinux build always succeeds.
546
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
mod avx512 {
    use half::f16;
    use std::arch::x86_64::*;

    /// Horizontal sum of the sixteen `f32` lanes of `v`.
    ///
    /// # Safety
    /// The CPU must support AVX-512F (and therefore AVX). Marked
    /// `inline(always)` so it is folded into its `target_feature` callers.
    #[inline(always)]
    unsafe fn hsum512(v: __m512) -> f32 {
        // _mm512_reduce_add_ps stabilized Rust 1.89; _mm512_extractf32x8_ps needs avx512dq.
        // Store all 16 lanes to stack (avx512f), reload as two __m256 (avx), then reduce.
        let mut buf = [0.0f32; 16];
        _mm512_storeu_ps(buf.as_mut_ptr(), v);
        let lo = _mm256_loadu_ps(buf.as_ptr());
        let hi = _mm256_loadu_ps(buf.as_ptr().add(8));
        let sum256 = _mm256_add_ps(lo, hi);
        let hi128 = _mm256_extractf128_ps(sum256, 1);
        let lo128 = _mm256_castps256_ps128(sum256);
        let sum128 = _mm_add_ps(lo128, hi128);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf2))
    }

    /// Loads 16 consecutive binary16 bit patterns starting at `p` and widens
    /// them to 16 `f32` lanes (lane `i` = element `i`).
    ///
    /// Uses the AVX-512F form of `vcvtph2ps` (zmm ← ymm). The previous
    /// `_mm512_insertf32x8` merge of two 256-bit conversions required
    /// AVX512DQ, which is neither enabled nor detected for these kernels.
    ///
    /// # Safety
    /// `p` must be valid for reading 16 `u16`s (no alignment required), and the
    /// CPU must support AVX-512F.
    #[inline(always)]
    unsafe fn load_f16x16(p: *const u16) -> __m512 {
        _mm512_cvtph_ps(_mm256_loadu_si256(p as *const __m256i))
    }

    /// `Σ a[i]·b[i]` over `min(a.len(), b.len())` elements, 16 f32/iter.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            acc = _mm512_fmadd_ps(
                _mm512_loadu_ps(ap.add(base)),
                _mm512_loadu_ps(bp.add(base)),
                acc,
            );
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * *bp.add(i);
        }
        sum
    }

    /// Euclidean distance `sqrt(Σ (a[i] - b[i])²)`, 16 f32/iter.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), _mm512_loadu_ps(bp.add(base)));
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - *bp.add(i);
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance `1 - cos(a, b)` in one pass. Returns `1.0` when either
    /// norm is zero.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and FMA.
    #[target_feature(enable = "avx512f,fma")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let av = _mm512_loadu_ps(ap.add(base));
            let bv = _mm512_loadu_ps(bp.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = *bp.add(i);
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }

    /// dot(a_f32, b_f16), 16 F16/iter.
    ///
    /// # Safety
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn dot_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            acc = _mm512_fmadd_ps(_mm512_loadu_ps(ap.add(base)), bv, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            sum += *ap.add(i) * f16::from_bits(*bp.add(i)).to_f32();
        }
        sum
    }

    /// Euclidean distance between an `f32` query and an `f16` vector,
    /// 16 F16/iter.
    ///
    /// # Safety
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn euclidean_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            let d = _mm512_sub_ps(_mm512_loadu_ps(ap.add(base)), bv);
            acc = _mm512_fmadd_ps(d, d, acc);
        }
        let mut sum = hsum512(acc);
        for i in (chunks16 * 16)..n {
            let d = *ap.add(i) - f16::from_bits(*bp.add(i)).to_f32();
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance between an `f32` query and an `f16` vector in one pass.
    /// Returns `1.0` when `sqrt(|a|² · |b|²) < 1e-8`.
    ///
    /// # Safety
    /// The CPU must support AVX-512F, F16C and FMA.
    #[target_feature(enable = "avx512f,f16c,fma")]
    pub unsafe fn cosine_f16(a: &[f32], b: &[f16]) -> f32 {
        let n = a.len().min(b.len());
        let ap = a.as_ptr();
        let bp = b.as_ptr() as *const u16;
        let mut dot_acc = _mm512_setzero_ps();
        let mut na_acc = _mm512_setzero_ps();
        let mut nb_acc = _mm512_setzero_ps();
        let chunks16 = n / 16;
        for i in 0..chunks16 {
            let base = i * 16;
            let bv = load_f16x16(bp.add(base));
            let av = _mm512_loadu_ps(ap.add(base));
            dot_acc = _mm512_fmadd_ps(av, bv, dot_acc);
            na_acc = _mm512_fmadd_ps(av, av, na_acc);
            nb_acc = _mm512_fmadd_ps(bv, bv, nb_acc);
        }
        let mut dot = hsum512(dot_acc);
        let mut na2 = hsum512(na_acc);
        let mut nb2 = hsum512(nb_acc);
        for i in (chunks16 * 16)..n {
            let ai = *ap.add(i);
            let bi = f16::from_bits(*bp.add(i)).to_f32();
            dot += ai * bi;
            na2 += ai * ai;
            nb2 += bi * bi;
        }
        let denom = (na2 * nb2).sqrt();
        if denom < 1e-8 {
            1.0
        } else {
            1.0 - dot / denom
        }
    }
}
728
729// ── aarch64 NEON ──────────────────────────────────────────────────────────────
730
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    use std::arch::aarch64::*;

    // All kernels accumulate with `vfmaq_f32` (fused multiply-add, one
    // rounding), matching the fused `_mm*_fmadd_ps` used by the x86 kernels.
    // The previous `vmlaq_f32` is a separate multiply + add with two roundings.

    /// `Σ a[i]·b[i]` over `min(a.len(), b.len())` elements, 4 f32/iter.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let av = vld1q_f32(a.as_ptr().add(base));
            let bv = vld1q_f32(b.as_ptr().add(base));
            acc = vfmaq_f32(acc, av, bv);
        }
        let mut sum = vaddvq_f32(acc);
        for i in (chunks * 4)..n {
            sum += a[i] * b[i];
        }
        sum
    }

    /// Euclidean distance `sqrt(Σ (a[i] - b[i])²)`, 4 f32/iter.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let d = vsubq_f32(
                vld1q_f32(a.as_ptr().add(base)),
                vld1q_f32(b.as_ptr().add(base)),
            );
            acc = vfmaq_f32(acc, d, d);
        }
        let mut sum = vaddvq_f32(acc);
        for i in (chunks * 4)..n {
            let d = a[i] - b[i];
            sum += d * d;
        }
        sum.sqrt()
    }

    /// Cosine distance `1 - cos(a, b)` in one pass. Returns `1.0` when either
    /// norm is zero.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len().min(b.len());
        let mut dot_acc = vdupq_n_f32(0.0);
        let mut na_acc = vdupq_n_f32(0.0);
        let mut nb_acc = vdupq_n_f32(0.0);
        let chunks = n / 4;
        for i in 0..chunks {
            let base = i * 4;
            let av = vld1q_f32(a.as_ptr().add(base));
            let bv = vld1q_f32(b.as_ptr().add(base));
            dot_acc = vfmaq_f32(dot_acc, av, bv);
            na_acc = vfmaq_f32(na_acc, av, av);
            nb_acc = vfmaq_f32(nb_acc, bv, bv);
        }
        let mut dot = vaddvq_f32(dot_acc);
        let mut na2 = vaddvq_f32(na_acc);
        let mut nb2 = vaddvq_f32(nb_acc);
        for i in (chunks * 4)..n {
            dot += a[i] * b[i];
            na2 += a[i] * a[i];
            nb2 += b[i] * b[i];
        }
        let (na, nb) = (na2.sqrt(), nb2.sqrt());
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        1.0 - dot / (na * nb)
    }
}
804
805// ── Tests ─────────────────────────────────────────────────────────────────────
806
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!(cosine_distance(&v, &v).abs() < 1e-5);
    }

    #[test]
    fn cosine_orthogonal() {
        assert!((cosine_distance(&[1.0f32, 0.0], &[0.0f32, 1.0]) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_zero_vector_returns_one() {
        // Long enough to go through the SIMD main loops, not just the tails.
        let zero = [0.0f32; 20];
        let ones = [1.0f32; 20];
        assert_eq!(cosine_distance(&zero, &ones), 1.0);
        assert_eq!(cosine_distance(&ones, &zero), 1.0);

        let ones_f16: Vec<f16> = ones.iter().map(|&x| f16::from_f32(x)).collect();
        assert_eq!(cosine_distance_f16(&zero, &ones_f16), 1.0);
    }

    #[test]
    fn euclidean_basic() {
        assert!((euclidean_distance(&[0.0f32, 0.0], &[3.0f32, 4.0]) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn dot_basic() {
        assert!((dot_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]) - 32.0).abs() < 1e-5);
    }

    #[test]
    fn exact_distance_negates_dot_product() {
        let d = exact_distance(VectorMetric::DotProduct, &[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]);
        assert!((d + 32.0).abs() < 1e-5, "d={d}");
    }

    #[test]
    fn simd_matches_scalar_dim128() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(99);
        let a: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();

        let dot_s = dot_scalar(&a, &b);
        let euclid_s = euclidean_scalar(&a, &b);
        let cos_s = cosine_scalar(&a, &b);

        let dot_f = dot_product(&a, &b);
        let euclid_f = euclidean_distance(&a, &b);
        let cos_f = cosine_distance(&a, &b);

        assert!(
            (dot_f - dot_s).abs() < 1e-4,
            "dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-4,
            "euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-4,
            "cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    #[test]
    fn simd_matches_scalar_on_tail_lengths() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        // dim=128 is a multiple of 16, so it never reaches the 8-wide
        // remainder block or the scalar tail of the SIMD kernels. These
        // lengths cover every remainder shape.
        let mut rng = StdRng::seed_from_u64(7);
        for len in [1usize, 3, 7, 8, 9, 15, 16, 17, 23, 24, 31, 45, 130] {
            let a: Vec<f32> = (0..len).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
            let b: Vec<f32> = (0..len).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();

            let (dot_s, dot_f) = (dot_scalar(&a, &b), dot_product(&a, &b));
            let (euc_s, euc_f) = (euclidean_scalar(&a, &b), euclidean_distance(&a, &b));
            let (cos_s, cos_f) = (cosine_scalar(&a, &b), cosine_distance(&a, &b));

            assert!((dot_f - dot_s).abs() < 1e-4, "len={len} dot: {dot_f} vs {dot_s}");
            assert!((euc_f - euc_s).abs() < 1e-4, "len={len} euclid: {euc_f} vs {euc_s}");
            assert!((cos_f - cos_s).abs() < 1e-4, "len={len} cos: {cos_f} vs {cos_s}");
        }
    }

    #[test]
    fn f16_simd_matches_scalar() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(42);
        let a: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b_f32: Vec<f32> = (0..128).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
        let b: Vec<f16> = b_f32.iter().map(|&x| f16::from_f32(x)).collect();

        let dot_s = dot_f16_scalar(&a, &b);
        let euclid_s = euclidean_f16_scalar(&a, &b);
        let cos_s = cosine_f16_scalar(&a, &b);

        let dot_f = dot_product_f16(&a, &b);
        let euclid_f = euclidean_distance_f16(&a, &b);
        let cos_f = cosine_distance_f16(&a, &b);

        // F16 rounding introduces small error — tolerate 1e-3
        assert!(
            (dot_f - dot_s).abs() < 1e-3,
            "f16 dot mismatch: {dot_f} vs {dot_s}"
        );
        assert!(
            (euclid_f - euclid_s).abs() < 1e-3,
            "f16 euclidean mismatch: {euclid_f} vs {euclid_s}"
        );
        assert!(
            (cos_f - cos_s).abs() < 1e-3,
            "f16 cosine mismatch: {cos_f} vs {cos_s}"
        );
    }

    #[test]
    fn f16_simd_matches_scalar_on_tail_lengths() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(1234);
        for len in [1usize, 5, 8, 9, 15, 16, 17, 24, 33, 130] {
            let a: Vec<f32> = (0..len).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
            let b: Vec<f16> = (0..len)
                .map(|_| f16::from_f32(rng.gen::<f32>() * 2.0 - 1.0))
                .collect();

            let (dot_s, dot_f) = (dot_f16_scalar(&a, &b), dot_product_f16(&a, &b));
            let (euc_s, euc_f) = (euclidean_f16_scalar(&a, &b), euclidean_distance_f16(&a, &b));
            let (cos_s, cos_f) = (cosine_f16_scalar(&a, &b), cosine_distance_f16(&a, &b));

            assert!((dot_f - dot_s).abs() < 1e-3, "len={len} f16 dot: {dot_f} vs {dot_s}");
            assert!((euc_f - euc_s).abs() < 1e-3, "len={len} f16 euclid: {euc_f} vs {euc_s}");
            assert!((cos_f - cos_s).abs() < 1e-3, "len={len} f16 cos: {cos_f} vs {cos_s}");
        }
    }

    #[test]
    fn normalize_l2_unit() {
        let v = vec![3.0f32, 4.0];
        let n = normalize_l2(&v);
        let norm: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6, "norm={norm}");
        assert!((n[0] - 0.6).abs() < 1e-6);
        assert!((n[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn normalize_l2_zero_vector_unchanged() {
        assert_eq!(normalize_l2(&[0.0f32, 0.0, 0.0]), vec![0.0, 0.0, 0.0]);
        assert!(normalize_l2(&[]).is_empty());
    }

    #[test]
    fn normalized_cosine_matches_cosine_on_unit_vecs() {
        let a = normalize_l2(&[1.0f32, 1.0, 0.0]);
        let b = normalize_l2(&[1.0f32, 0.0, 1.0]);
        let cos = cosine_distance(&a, &b);
        let ncos = normalized_cosine_distance(&a, &b);
        assert!((cos - ncos).abs() < 1e-5, "cos={cos} ncos={ncos}");
    }

    #[test]
    fn centroid_single() {
        let v = vec![vec![1.0f32, 2.0, 3.0]];
        let c = compute_centroid_and_radius(&v, VectorMetric::Cosine);
        assert_eq!(c.values, vec![1.0, 2.0, 3.0]);
        assert!(c.radius < 1e-6, "radius={}", c.radius);
    }

    #[test]
    fn centroid_two_points() {
        let vs = vec![vec![0.0f32, 0.0], vec![2.0f32, 2.0]];
        let c = compute_centroid_and_radius(&vs, VectorMetric::Euclidean);
        assert!((c.values[0] - 1.0).abs() < 1e-6);
        assert!(c.radius > 0.0);
    }
}