Skip to main content

luci/vector/
mod.rs

1//! Vector search: HNSW graph construction, search, and kNN queries.
2//!
3//! See [[hierarchical-navigable-small-world]] and [[architecture-overview|milestone-4]].
4
5use crate::core::LuciError;
6#[cfg(target_arch = "aarch64")]
7use std::arch::aarch64::{vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vsubq_f32};
8
9pub mod global;
10pub mod hnsw;
11pub mod quantize;
12pub mod query;
13
14#[cfg(test)]
15mod distance_tests;
16
/// Distance metric for vector similarity.
///
/// Every variant is oriented so that a lower value means "more similar".
///
/// The `#[repr(u8)]` and explicit discriminants pin the on-disk byte
/// encoding used by HNSW and quantized-vector segment blobs. Adding a
/// new variant requires picking the next unused discriminant and
/// updating [`Self::try_from_byte`] in the same change. See
/// [[code-must-not-lie]].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum DistanceMetric {
    /// Cosine distance: 1 - cosine_similarity. Lower = more similar.
    Cosine = 0,
    /// Negative dot product. Lower = more similar (for pre-normalized vectors).
    DotProduct = 1,
    /// Euclidean (L2) distance. Lower = more similar.
    L2 = 2,
}

impl DistanceMetric {
    /// Decode a metric byte written by `metric as u8` without panicking.
    ///
    /// Returns `None` for any byte that is not a known discriminant, so a
    /// caller reading possibly corrupted data can surface its own error
    /// instead of aborting. An unknown byte is never mapped to a default
    /// metric: doing so would produce wrong recall and scoring without any
    /// signal to the caller.
    pub const fn try_from_byte(byte: u8) -> Option<Self> {
        match byte {
            0 => Some(Self::Cosine),
            1 => Some(Self::DotProduct),
            2 => Some(Self::L2),
            _ => None,
        }
    }

    /// Decode a metric byte written by `metric as u8`.
    ///
    /// Use [`Self::try_from_byte`] when the byte comes from data that may
    /// be damaged and the caller wants to handle that case itself.
    ///
    /// # Panics
    ///
    /// Panics on unknown bytes — these mean the segment is corrupted or was
    /// written by a newer Luci version with an unfamiliar metric. Silently
    /// mapping unknown bytes to a default (e.g., L2) would produce wrong
    /// recall and scoring without any signal to the caller.
    pub fn from_byte(byte: u8) -> Self {
        match Self::try_from_byte(byte) {
            Some(metric) => metric,
            None => panic!(
                "unknown distance metric byte {byte}: segment is corrupted \
                 or was written by a newer version of Luci"
            ),
        }
    }
}
53
54/// Compute distance between two vectors.
55pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
56    debug_assert_eq!(a.len(), b.len());
57    match metric {
58        DistanceMetric::Cosine => cosine_distance_normalized(a, b),
59        DistanceMetric::DotProduct => -dot_product(a, b),
60        DistanceMetric::L2 => l2_distance(a, b),
61    }
62}
63
/// f32 dot product.
///
/// On aarch64 dispatches to an explicit NEON kernel with 4 parallel
/// FMA accumulators × 4 lanes (16-element block) + 4-element middle
/// + scalar tail. Rust's strict float-associativity forbids the
/// auto-vectorizer from emitting the parallel-accumulator pattern
/// that hnswlib's C scalar relies on; only explicit intrinsics close
/// the gap. See [[optimization-vector-distance-kernel-trait]] §"Phase 1.x"
/// and [[vector-bench-glove100-global-hnsw]] §"SIMD kernel gap".
///
/// Other architectures use the iterator chain, which the LLVM
/// auto-vectorizer handles adequately on x86_64 with AVX2.
///
/// `a` and `b` are expected to have the same length (asserted in debug
/// builds). If they differ in a release build, both code paths process
/// only the common prefix: `zip` stops at the shorter input, and the NEON
/// kernel bounds its loads by the shorter slice.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is part of the AArch64 baseline ISA; always available.
        // The kernel never reads past the end of either slice because it
        // bounds every access by `min(a.len(), b.len())`.
        unsafe { dot_product_neon(a, b) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }
}

/// NEON dot-product kernel over the common prefix of `a` and `b`.
///
/// # Safety
///
/// The caller must ensure the `neon` target feature is available on the
/// executing CPU.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    // Bound by the shorter slice. The length equality check upstream is a
    // `debug_assert`, so in release builds a mismatched pair could otherwise
    // drive raw-pointer loads past the end of the shorter input.
    let n = a.len().min(b.len());
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // SAFETY: caller guarantees NEON via target_feature; `n` does not exceed
    // either slice's length and every load is gated by `i + N <= n` (or
    // `i < n` for the scalar tail), so all pointer arithmetic stays in
    // bounds of both slices.
    unsafe {
        let mut acc0 = vdupq_n_f32(0.0);
        let mut acc1 = vdupq_n_f32(0.0);
        let mut acc2 = vdupq_n_f32(0.0);
        let mut acc3 = vdupq_n_f32(0.0);

        let mut i = 0;
        // Main loop: 16 elements per iteration, one accumulator per 4 lanes.
        while i + 16 <= n {
            let a0 = vld1q_f32(a_ptr.add(i));
            let a1 = vld1q_f32(a_ptr.add(i + 4));
            let a2 = vld1q_f32(a_ptr.add(i + 8));
            let a3 = vld1q_f32(a_ptr.add(i + 12));
            let b0 = vld1q_f32(b_ptr.add(i));
            let b1 = vld1q_f32(b_ptr.add(i + 4));
            let b2 = vld1q_f32(b_ptr.add(i + 8));
            let b3 = vld1q_f32(b_ptr.add(i + 12));
            acc0 = vfmaq_f32(acc0, a0, b0);
            acc1 = vfmaq_f32(acc1, a1, b1);
            acc2 = vfmaq_f32(acc2, a2, b2);
            acc3 = vfmaq_f32(acc3, a3, b3);
            i += 16;
        }
        // Middle: remaining whole 4-element groups fold into `acc0`.
        while i + 4 <= n {
            let av = vld1q_f32(a_ptr.add(i));
            let bv = vld1q_f32(b_ptr.add(i));
            acc0 = vfmaq_f32(acc0, av, bv);
            i += 4;
        }
        // Combine the four accumulators, then reduce the lanes to a scalar.
        let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
        let mut sum = vaddvq_f32(acc);
        // Scalar tail: at most 3 leftover elements.
        while i < n {
            sum += *a_ptr.add(i) * *b_ptr.add(i);
            i += 1;
        }
        sum
    }
}
135
136/// Cosine distance over pre-normalized vectors.
137///
138/// Both inputs must be unit length. Builder invariant for
139/// [`DistanceMetric::Cosine`] (see [[optimize-cosine-norm-precompute]]).
140/// L2 / DotProduct kernels are unchanged.
141fn cosine_distance_normalized(a: &[f32], b: &[f32]) -> f32 {
142    // cosine_similarity = dot(a, b)  (in [-1, 1])
143    // cosine_distance   = 1 - dot(a, b)
144    1.0 - dot_product(a, b)
145}
146
147/// Normalize a vector to unit length in place.
148///
149/// Idempotent on already-unit-length inputs (skips the multiply when the
150/// squared norm is within 1e-4 of 1.0 — picked to absorb f32 rounding from
151/// vectors normalized in f64 elsewhere and downcast). Returns an error for
152/// zero, subnormal-collapse-to-zero, NaN, or infinite squared norms; the
153/// builder propagates the error so a cosine index never silently embeds a
154/// degenerate vector.
155///
156/// See [[optimize-cosine-norm-precompute]] and [[code-must-not-lie]].
157pub fn normalize_in_place(v: &mut [f32]) -> Result<(), LuciError> {
158    let norm_sq: f32 = v.iter().map(|x| x * x).sum();
159    if !norm_sq.is_finite() || norm_sq == 0.0 {
160        return Err(LuciError::InvalidQuery(
161            "zero-length / non-finite vector not supported with cosine \
162             metric — use metric: dot_product to bypass normalization"
163                .into(),
164        ));
165    }
166    if (norm_sq - 1.0).abs() < 1e-4 {
167        return Ok(());
168    }
169    let inv = 1.0 / norm_sq.sqrt();
170    for x in v.iter_mut() {
171        *x *= inv;
172    }
173    Ok(())
174}
175
/// f32 L2 (Euclidean) distance.
///
/// Same shape and rationale as [`dot_product`] — explicit NEON on
/// aarch64 with 4 parallel FMA accumulators on `(a-b)²`, iterator
/// chain elsewhere. See [[optimization-vector-distance-kernel-trait]]
/// §"Phase 1.x".
///
/// `a` and `b` are expected to have the same length (asserted in debug
/// builds). If they differ in a release build, both code paths process
/// only the common prefix: `zip` stops at the shorter input, and the NEON
/// kernel bounds its loads by the shorter slice.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is part of the AArch64 baseline ISA; always available.
        // The kernel never reads past the end of either slice because it
        // bounds every access by `min(a.len(), b.len())`.
        unsafe { l2_distance_neon(a, b) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }
}

/// NEON Euclidean-distance kernel over the common prefix of `a` and `b`.
///
/// # Safety
///
/// The caller must ensure the `neon` target feature is available on the
/// executing CPU.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn l2_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    // Bound by the shorter slice. The length equality check upstream is a
    // `debug_assert`, so in release builds a mismatched pair could otherwise
    // drive raw-pointer loads past the end of the shorter input.
    let n = a.len().min(b.len());
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // SAFETY: caller guarantees NEON via target_feature; `n` does not exceed
    // either slice's length and every load is gated by `i + N <= n` (or
    // `i < n` for the scalar tail), so all pointer arithmetic stays in
    // bounds of both slices.
    unsafe {
        let mut acc0 = vdupq_n_f32(0.0);
        let mut acc1 = vdupq_n_f32(0.0);
        let mut acc2 = vdupq_n_f32(0.0);
        let mut acc3 = vdupq_n_f32(0.0);

        let mut i = 0;
        // Main loop: 16 elements per iteration, one accumulator per 4 lanes.
        while i + 16 <= n {
            let a0 = vld1q_f32(a_ptr.add(i));
            let a1 = vld1q_f32(a_ptr.add(i + 4));
            let a2 = vld1q_f32(a_ptr.add(i + 8));
            let a3 = vld1q_f32(a_ptr.add(i + 12));
            let b0 = vld1q_f32(b_ptr.add(i));
            let b1 = vld1q_f32(b_ptr.add(i + 4));
            let b2 = vld1q_f32(b_ptr.add(i + 8));
            let b3 = vld1q_f32(b_ptr.add(i + 12));
            let d0 = vsubq_f32(a0, b0);
            let d1 = vsubq_f32(a1, b1);
            let d2 = vsubq_f32(a2, b2);
            let d3 = vsubq_f32(a3, b3);
            acc0 = vfmaq_f32(acc0, d0, d0);
            acc1 = vfmaq_f32(acc1, d1, d1);
            acc2 = vfmaq_f32(acc2, d2, d2);
            acc3 = vfmaq_f32(acc3, d3, d3);
            i += 16;
        }
        // Middle: remaining whole 4-element groups fold into `acc0`.
        while i + 4 <= n {
            let av = vld1q_f32(a_ptr.add(i));
            let bv = vld1q_f32(b_ptr.add(i));
            let d = vsubq_f32(av, bv);
            acc0 = vfmaq_f32(acc0, d, d);
            i += 4;
        }
        // Combine the four accumulators, then reduce the lanes to a scalar.
        let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
        let mut sum = vaddvq_f32(acc);
        // Scalar tail: at most 3 leftover elements.
        while i < n {
            let d = *a_ptr.add(i) - *b_ptr.add(i);
            sum += d * d;
            i += 1;
        }
        sum.sqrt()
    }
}
251
252/// Convert a raw distance value to a score.
253///
254/// Uses metric-specific formulas matching Lucene's `VectorSimilarityFunction`:
255/// - Cosine: `max((1 + cos_sim) / 2, 0)` where cos_sim = 1 - distance
256/// - L2: `1 / (1 + distance²)` — inherently in (0, 1]
257/// - DotProduct: `max((1 + dot) / 2, 0)` where dot = -distance
258///
259/// Lucene floors at 0 but does NOT ceil at 1 — scores above 1 are possible
260/// for DotProduct when the normalization contract is violated (unnormalized
261/// vectors). This matches Lucene's `VectorUtil.normalizeToUnitInterval()`.
262///
263/// See [[feature-knn-query-type#2b]].
264pub fn distance_to_score(raw_distance: f32, metric: DistanceMetric) -> f32 {
265    match metric {
266        DistanceMetric::Cosine => {
267            // cosine_distance = 1 - cos_sim, so cos_sim = 1 - distance
268            // Lucene: max((1 + cos_sim) / 2, 0) = max((2 - distance) / 2, 0)
269            ((2.0 - raw_distance) / 2.0).max(0.0)
270        }
271        DistanceMetric::L2 => {
272            // Lucene: 1 / (1 + squaredDistance)
273            // Inherently in (0, 1] — no clamping needed.
274            1.0 / (1.0 + raw_distance * raw_distance)
275        }
276        DistanceMetric::DotProduct => {
277            // distance = -dot_product, so dot = -distance
278            // Lucene: max((1 + dot) / 2, 0) = max((1 - distance) / 2, 0)
279            // Floor at 0, no ceiling — scores > 1 possible for unnormalized vectors.
280            ((1.0 - raw_distance) / 2.0).max(0.0)
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Every metric variant, for table-driven tests.
    const ALL_METRICS: [DistanceMetric; 3] = [
        DistanceMetric::Cosine,
        DistanceMetric::DotProduct,
        DistanceMetric::L2,
    ];

    /// Substring every `normalize_in_place` rejection message must contain.
    const DEGENERATE_MSG: &str = "zero-length / non-finite vector";

    // --- distance tests ---

    #[test]
    fn cosine_identical() {
        // The cosine kernel assumes unit-length inputs, so normalize first
        // to honor the builder invariant.
        let mut unit = [1.0_f32, 2.0, 3.0];
        normalize_in_place(&mut unit).unwrap();
        let d = distance(&unit, &unit, DistanceMetric::Cosine);
        assert!(
            d.abs() < 1e-5,
            "identical vectors should have cosine distance ~0, got {d}"
        );
    }

    #[test]
    fn cosine_orthogonal() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let d = distance(&x_axis, &y_axis, DistanceMetric::Cosine);
        assert!(
            (d - 1.0).abs() < 1e-5,
            "orthogonal vectors should have cosine distance ~1, got {d}"
        );
    }

    #[test]
    fn cosine_opposite() {
        let forward = [1.0_f32, 0.0];
        let backward = [-1.0_f32, 0.0];
        let d = distance(&forward, &backward, DistanceMetric::Cosine);
        assert!(
            (d - 2.0).abs() < 1e-5,
            "opposite vectors should have cosine distance ~2, got {d}"
        );
    }

    #[test]
    fn dot_product_metric() {
        // 1*3 + 2*4 = 11; the metric reports the negated dot product.
        let lhs = [1.0_f32, 2.0];
        let rhs = [3.0_f32, 4.0];
        let d = distance(&lhs, &rhs, DistanceMetric::DotProduct);
        assert_eq!(d, -11.0);
    }

    #[test]
    fn l2_distance_metric() {
        // Classic 3-4-5 triangle from the origin.
        let origin = [0.0_f32, 0.0];
        let point = [3.0_f32, 4.0];
        let d = distance(&origin, &point, DistanceMetric::L2);
        assert!((d - 5.0).abs() < 1e-5, "L2 distance should be 5.0, got {d}");
    }

    #[test]
    fn l2_identical() {
        let point = [1.0_f32, 2.0, 3.0];
        let d = distance(&point, &point, DistanceMetric::L2);
        assert!(d.abs() < 1e-5);
    }

    // NOTE: an earlier `zero_vector_cosine` test was removed. The cosine
    // kernel requires unit-length inputs, so zero/raw vectors no longer flow
    // into `distance(..., Cosine)`. Zero-vector rejection is covered by
    // `normalize_in_place_zero_errors` and the builder/bulk tests in hnsw.rs.

    #[test]
    fn unit_vectors() {
        let e0 = [1.0_f32, 0.0, 0.0];
        let e1 = [0.0_f32, 1.0, 0.0];
        let d_cos = distance(&e0, &e1, DistanceMetric::Cosine);
        let d_l2 = distance(&e0, &e1, DistanceMetric::L2);
        assert!((d_cos - 1.0).abs() < 1e-5);
        assert!((d_l2 - std::f32::consts::SQRT_2).abs() < 1e-5);
    }

    // --- distance_to_score tests ---

    #[test]
    fn cosine_score_identical() {
        // distance 0 means cos_sim 1, so the score is (2 - 0) / 2 = 1.0.
        let s = distance_to_score(0.0, DistanceMetric::Cosine);
        assert!(
            (s - 1.0).abs() < 1e-5,
            "identical vectors: score={s}, expected 1.0"
        );
    }

    #[test]
    fn cosine_score_orthogonal() {
        // distance 1 means cos_sim 0, so the score is (2 - 1) / 2 = 0.5.
        let s = distance_to_score(1.0, DistanceMetric::Cosine);
        assert!(
            (s - 0.5).abs() < 1e-5,
            "orthogonal vectors: score={s}, expected 0.5"
        );
    }

    #[test]
    fn cosine_score_opposite() {
        // distance 2 means cos_sim -1, so the score is (2 - 2) / 2 = 0.0.
        let s = distance_to_score(2.0, DistanceMetric::Cosine);
        assert!(s.abs() < 1e-5, "opposite vectors: score={s}, expected 0.0");
    }

    #[test]
    fn l2_score_identical() {
        // 1 / (1 + 0²) = 1.0
        let s = distance_to_score(0.0, DistanceMetric::L2);
        assert!((s - 1.0).abs() < 1e-5, "identical: score={s}, expected 1.0");
    }

    #[test]
    fn l2_score_unit_distance() {
        // 1 / (1 + 1²) = 0.5
        let s = distance_to_score(1.0, DistanceMetric::L2);
        assert!(
            (s - 0.5).abs() < 1e-5,
            "unit distance: score={s}, expected 0.5"
        );
    }

    #[test]
    fn l2_score_far() {
        // 1 / (1 + 2²) = 0.2
        let s = distance_to_score(2.0, DistanceMetric::L2);
        assert!((s - 0.2).abs() < 1e-5, "far: score={s}, expected 0.2");
    }

    #[test]
    fn dot_product_score_high_similarity() {
        // dot 1.0 is distance -1.0, so the score is (1 - (-1)) / 2 = 1.0.
        let s = distance_to_score(-1.0, DistanceMetric::DotProduct);
        assert!((s - 1.0).abs() < 1e-5, "high sim: score={s}, expected 1.0");
    }

    #[test]
    fn dot_product_score_zero() {
        // dot 0 is distance 0, so the score is (1 - 0) / 2 = 0.5.
        let s = distance_to_score(0.0, DistanceMetric::DotProduct);
        assert!((s - 0.5).abs() < 1e-5, "zero dot: score={s}, expected 0.5");
    }

    #[test]
    fn dot_product_score_negative() {
        // dot -1 is distance 1, so the score is (1 - 1) / 2 = 0.0.
        let s = distance_to_score(1.0, DistanceMetric::DotProduct);
        assert!(s.abs() < 1e-5, "negative dot: score={s}, expected 0.0");
    }

    #[test]
    fn all_scores_non_negative() {
        // Lucene floors every score at 0. Scores above 1 remain legal for
        // DotProduct with unnormalized vectors, so only the floor is checked.
        const DISTANCES: [f32; 6] = [0.0, 0.5, 1.0, 2.0, 5.0, 10.0];
        for dist in DISTANCES {
            for metric in ALL_METRICS {
                let s = distance_to_score(dist, metric);
                assert!(
                    s >= 0.0,
                    "score should be non-negative: metric={metric:?}, dist={dist}, score={s}"
                );
            }
        }
    }

    #[test]
    fn l2_scores_bounded_unit() {
        // The L2 formula keeps scores inside (0, 1] for these distances.
        const DISTANCES: [f32; 5] = [0.0, 0.1, 1.0, 10.0, 100.0];
        for dist in DISTANCES {
            let s = distance_to_score(dist, DistanceMetric::L2);
            assert!(
                s > 0.0 && s <= 1.0,
                "L2 score out of (0,1]: dist={dist}, score={s}"
            );
        }
    }

    #[test]
    fn dot_product_unnormalized_can_exceed_one() {
        // distance -2 means dot 2, so the score is (1 + 2) / 2 = 1.5 —
        // valid for unnormalized vectors.
        let s = distance_to_score(-2.0, DistanceMetric::DotProduct);
        assert!(
            s > 1.0,
            "unnormalized dot product should produce score > 1: {s}"
        );
    }

    // --- DistanceMetric byte-encoding tests ---

    #[test]
    fn from_byte_round_trips_known_metrics() {
        for metric in ALL_METRICS {
            let encoded = metric as u8;
            assert_eq!(DistanceMetric::from_byte(encoded), metric);
        }
    }

    #[test]
    fn from_byte_discriminants_are_pinned() {
        // The on-disk encoding depends on these exact discriminants.
        // Changing them silently is a forward-compat break.
        assert_eq!(DistanceMetric::Cosine as u8, 0);
        assert_eq!(DistanceMetric::DotProduct as u8, 1);
        assert_eq!(DistanceMetric::L2 as u8, 2);
    }

    #[test]
    #[should_panic(expected = "unknown distance metric byte 3")]
    fn from_byte_panics_on_unknown_metric() {
        // Forward-version mismatch: a newer Luci wrote a metric this build
        // does not know. It must NOT silently fall back to L2 (or anything
        // else).
        let _ = DistanceMetric::from_byte(3);
    }

    #[test]
    #[should_panic(expected = "unknown distance metric byte 255")]
    fn from_byte_panics_on_garbage() {
        // Segment corruption or a wild pointer landing on a metric byte.
        let _ = DistanceMetric::from_byte(255);
    }

    // --- normalize_in_place tests ---

    #[test]
    fn normalize_in_place_unit_length() {
        let mut v = [3.0_f32, 4.0];
        normalize_in_place(&mut v).unwrap();
        let norm = (v[0] * v[0] + v[1] * v[1]).sqrt();
        assert!((norm - 1.0).abs() < 1e-6, "norm after normalize: {norm}");
        assert!((v[0] - 0.6).abs() < 1e-6 && (v[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn normalize_in_place_idempotent_on_unit_input() {
        let original = [0.6_f32, 0.8];
        let mut v = original;
        normalize_in_place(&mut v).unwrap();
        // Already unit length: the short-circuit must leave every component
        // bit-for-bit unchanged (no drift).
        for (after, before) in v.iter().zip(original.iter()) {
            assert_eq!(after, before);
        }
    }

    #[test]
    fn normalize_in_place_zero_errors() {
        let mut zeros = [0.0_f32; 3];
        let err = normalize_in_place(&mut zeros).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("zero-length / non-finite vector"),
            "unexpected message: {msg}",
        );
    }

    #[test]
    fn normalize_in_place_subnormal_errors() {
        // f32::MIN_POSITIVE * 1e-2 lives in the subnormal range. Squaring
        // it underflows to 0 in f32, so the zero-norm guard fires.
        let mut tiny = [f32::MIN_POSITIVE * 1e-2; 3];
        let err = normalize_in_place(&mut tiny).unwrap_err();
        assert!(err.to_string().contains(DEGENERATE_MSG));
    }

    #[test]
    fn normalize_in_place_overflow_errors() {
        // 1e20² = 1e40, well past f32 max (~3.4e38), so norm_sq is +Inf.
        let mut huge = [1e20_f32; 3];
        let err = normalize_in_place(&mut huge).unwrap_err();
        assert!(err.to_string().contains(DEGENERATE_MSG));
    }

    #[test]
    fn normalize_in_place_nan_errors() {
        let mut poisoned = [1.0_f32, f32::NAN, 2.0];
        let err = normalize_in_place(&mut poisoned).unwrap_err();
        assert!(err.to_string().contains(DEGENERATE_MSG));
    }

    #[test]
    fn cosine_score_unchanged_after_normalize() {
        // The previous cosine kernel computed `1 - dot/(norm_a*norm_b)` in
        // f32. The current one computes `1 - dot(a/norm_a, b/norm_b)` in f32
        // after explicit normalization. Both must produce the same score
        // within floating-point tolerance; an f64 evaluation of the old
        // formula serves as the oracle.
        fn norm64(v: &[f32]) -> f64 {
            v.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt()
        }

        let cases: [(Vec<f32>, Vec<f32>); 3] = [
            (vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]),
            (vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]),
            (vec![0.1; 100], vec![0.2; 100]),
        ];
        for (a_raw, b_raw) in &cases {
            let dot64 = a_raw
                .iter()
                .zip(b_raw.iter())
                .map(|(x, y)| (*x as f64) * (*y as f64))
                .sum::<f64>();
            let oracle_dist = 1.0 - dot64 / (norm64(a_raw) * norm64(b_raw));
            let oracle_score = ((2.0 - oracle_dist) / 2.0).max(0.0);

            let mut a = a_raw.clone();
            let mut b = b_raw.clone();
            normalize_in_place(&mut a).unwrap();
            normalize_in_place(&mut b).unwrap();
            let s = distance_to_score(
                distance(&a, &b, DistanceMetric::Cosine),
                DistanceMetric::Cosine,
            );
            assert!(
                ((s as f64) - oracle_score).abs() < 1e-3,
                "score drift > 1e-3: post={s}, oracle={oracle_score}",
            );
        }
    }

    #[test]
    fn cosine_distance_orthogonal_after_normalize() {
        let mut along_x = [3.0_f32, 0.0];
        let mut along_y = [0.0_f32, 7.0];
        normalize_in_place(&mut along_x).unwrap();
        normalize_in_place(&mut along_y).unwrap();
        let d = distance(&along_x, &along_y, DistanceMetric::Cosine);
        assert!((d - 1.0).abs() < 1e-6, "orthogonal cosine distance: {d}");
    }

    #[test]
    fn cosine_distance_identical_after_normalize() {
        let mut first = [1.0_f32, 2.0, 3.0];
        let mut second = first;
        normalize_in_place(&mut first).unwrap();
        normalize_in_place(&mut second).unwrap();
        let d = distance(&first, &second, DistanceMetric::Cosine);
        assert!(d.abs() < 1e-6, "identical cosine distance: {d}");
    }
}