Skip to main content

quiver_simd/
lib.rs

1// SPDX-License-Identifier: AGPL-3.0-only
//! SIMD distance kernels for Quiver — cosine, squared-L2, and inner product over
//! `f32`; squared-L2 and inner product over `i8`; plus Hamming distance over
//! packed-bit (`u64`) vectors, with runtime CPU-feature dispatch and a scalar
//! fallback.
5//!
6//! Each public function selects the best available implementation once per call
7//! (`is_x86_feature_detected!` results are cached by `std`) and always has a
8//! correct scalar fallback. The SIMD paths are differential-tested against the
9//! scalar reference. Design: `docs/index/distance-kernels.md`, ADR-0009.
10
11mod scalar;
12
13#[cfg(target_arch = "x86_64")]
14mod avx2;
15
/// A supported distance / similarity metric over dense vectors.
///
/// The variants disagree on ordering direction: for [`Metric::Dot`] and
/// [`Metric::Cosine`] a larger value means "more similar", while for
/// [`Metric::L2`] a smaller value does.
///
/// The type is `Copy` and implements `Eq + Hash`, so it can be used directly as
/// a key in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Metric {
    /// Inner product — higher is more similar.
    Dot,
    /// Cosine similarity in `[-1, 1]` — higher is more similar.
    Cosine,
    /// Squared Euclidean distance — lower is more similar.
    L2,
}
27
28/// Inner product (dot product) of two equal-length `f32` vectors.
29///
30/// # Panics
31/// Panics if `a.len() != b.len()`.
32#[inline]
33#[must_use]
34pub fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
35    assert_eq!(a.len(), b.len(), "vectors must have equal length");
36    #[cfg(target_arch = "x86_64")]
37    {
38        if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
39            // SAFETY: AVX and FMA were just confirmed present.
40            return unsafe { avx2::dot_f32(a, b) };
41        }
42    }
43    scalar::dot_f32(a, b)
44}
45
46/// Squared Euclidean distance of two equal-length `f32` vectors.
47///
48/// # Panics
49/// Panics if `a.len() != b.len()`.
50#[inline]
51#[must_use]
52pub fn l2_sq_f32(a: &[f32], b: &[f32]) -> f32 {
53    assert_eq!(a.len(), b.len(), "vectors must have equal length");
54    #[cfg(target_arch = "x86_64")]
55    {
56        if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
57            // SAFETY: AVX and FMA were just confirmed present.
58            return unsafe { avx2::l2_sq_f32(a, b) };
59        }
60    }
61    scalar::l2_sq_f32(a, b)
62}
63
64/// Cosine similarity (in `[-1, 1]`) of two equal-length `f32` vectors.
65///
66/// Returns `0.0` if either vector has zero magnitude.
67///
68/// # Panics
69/// Panics if `a.len() != b.len()`.
70#[inline]
71#[must_use]
72pub fn cosine_f32(a: &[f32], b: &[f32]) -> f32 {
73    assert_eq!(a.len(), b.len(), "vectors must have equal length");
74    #[cfg(target_arch = "x86_64")]
75    {
76        if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
77            // SAFETY: AVX and FMA were just confirmed present.
78            return unsafe { avx2::cosine_f32(a, b) };
79        }
80    }
81    scalar::cosine_f32(a, b)
82}
83
84/// Inner product of two equal-length `i8` vectors, accumulated in `i32`.
85///
86/// # Panics
87/// Panics if `a.len() != b.len()`.
88#[inline]
89#[must_use]
90pub fn dot_i8(a: &[i8], b: &[i8]) -> i32 {
91    assert_eq!(a.len(), b.len(), "vectors must have equal length");
92    #[cfg(target_arch = "x86_64")]
93    {
94        if is_x86_feature_detected!("avx2") {
95            // SAFETY: AVX2 was just confirmed present.
96            return unsafe { avx2::dot_i8(a, b) };
97        }
98    }
99    scalar::dot_i8(a, b)
100}
101
102/// Squared Euclidean distance of two equal-length `i8` vectors, in `i32`.
103///
104/// # Panics
105/// Panics if `a.len() != b.len()`.
106#[inline]
107#[must_use]
108pub fn l2_sq_i8(a: &[i8], b: &[i8]) -> i32 {
109    assert_eq!(a.len(), b.len(), "vectors must have equal length");
110    #[cfg(target_arch = "x86_64")]
111    {
112        if is_x86_feature_detected!("avx2") {
113            // SAFETY: AVX2 was just confirmed present.
114            return unsafe { avx2::l2_sq_i8(a, b) };
115        }
116    }
117    scalar::l2_sq_i8(a, b)
118}
119
120/// Hamming distance of two equal-length packed-bit vectors: the number of
121/// differing bits, `popcount(a XOR b)`, over `u64` words.
122///
123/// This is the fast pre-filter for binary-quantized search (ADR-0008): pack each
124/// vector's sign bits into `u64` words, rank candidates by Hamming distance, then
125/// re-rank the shortlist with an exact full-precision metric.
126///
127/// # Panics
128/// Panics if `a.len() != b.len()`.
129#[inline]
130#[must_use]
131pub fn hamming_u64(a: &[u64], b: &[u64]) -> u32 {
132    assert_eq!(a.len(), b.len(), "vectors must have equal length");
133    #[cfg(target_arch = "x86_64")]
134    {
135        if is_x86_feature_detected!("avx2") {
136            // SAFETY: AVX2 was just confirmed present.
137            return unsafe { avx2::hamming_u64(a, b) };
138        }
139    }
140    scalar::hamming_u64(a, b)
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use std::panic::catch_unwind;

    /// Tiny deterministic xorshift PRNG so tests need no external dependency.
    struct Rng(u64);

    impl Rng {
        /// Seeds the generator. The low bit is forced on so the state is never
        /// zero; the shift/xor steps in `next_u64` map zero back to zero.
        fn new(seed: u64) -> Self {
            Self(seed | 1)
        }

        /// Advances the state by one shift/xor round and returns the new state.
        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }

        /// A value in `[-1, 1)`, from 24 random bits.
        fn f32(&mut self) -> f32 {
            let bits = (self.next_u64() >> 40) as u32;
            (bits as f32 / 16_777_216.0) * 2.0 - 1.0
        }

        /// A signed byte taken from the top 8 bits of the next output.
        fn i8(&mut self) -> i8 {
            (self.next_u64() >> 56) as i8
        }

        /// `len` consecutive `f32` draws.
        fn f32_vec(&mut self, len: usize) -> Vec<f32> {
            (0..len).map(|_| self.f32()).collect()
        }

        /// `len` consecutive `i8` draws.
        fn i8_vec(&mut self, len: usize) -> Vec<i8> {
            (0..len).map(|_| self.i8()).collect()
        }

        /// `len` consecutive raw `u64` draws.
        fn u64_vec(&mut self, len: usize) -> Vec<u64> {
            (0..len).map(|_| self.next_u64()).collect()
        }
    }

    const F32_DIMS: &[usize] = &[0, 1, 7, 8, 9, 16, 31, 128, 769];
    const I8_DIMS: &[usize] = &[0, 1, 15, 16, 17, 31, 128, 769];
    // Word counts including non-multiples of 4 to exercise the AVX2 tail.
    const U64_WORDS: &[usize] = &[0, 1, 2, 3, 4, 5, 7, 8, 13, 16, 96];

    // A naive, obviously-correct Hamming reference: count differing bits one at
    // a time, independent of `count_ones`.
    fn hamming_naive(a: &[u64], b: &[u64]) -> u32 {
        let mut n = 0u32;
        for (x, y) in a.iter().zip(b.iter()) {
            let mut d = x ^ y;
            while d != 0 {
                n += (d & 1) as u32;
                d >>= 1;
            }
        }
        n
    }

    /// Mixed absolute/relative tolerance for comparing a SIMD `f32` result
    /// against the scalar reference.
    fn close(got: f32, exp: f32) -> bool {
        (got - exp).abs() <= 1e-3 + 1e-4 * exp.abs()
    }

    #[test]
    fn dot_f32_matches_scalar() {
        let mut rng = Rng::new(0xC0FFEE);
        for &dim in F32_DIMS {
            let (a, b) = (rng.f32_vec(dim), rng.f32_vec(dim));
            let (got, exp) = (dot_f32(&a, &b), scalar::dot_f32(&a, &b));
            assert!(close(got, exp), "dim {dim}: {got} vs {exp}");
        }
    }

    #[test]
    fn l2_sq_f32_matches_scalar() {
        let mut rng = Rng::new(0xBEEF);
        for &dim in F32_DIMS {
            let (a, b) = (rng.f32_vec(dim), rng.f32_vec(dim));
            let (got, exp) = (l2_sq_f32(&a, &b), scalar::l2_sq_f32(&a, &b));
            assert!(close(got, exp), "dim {dim}: {got} vs {exp}");
        }
    }

    #[test]
    fn cosine_f32_matches_scalar() {
        let mut rng = Rng::new(0xABCD);
        for &dim in F32_DIMS {
            let (a, b) = (rng.f32_vec(dim), rng.f32_vec(dim));
            let (got, exp) = (cosine_f32(&a, &b), scalar::cosine_f32(&a, &b));
            assert!(close(got, exp), "dim {dim}: {got} vs {exp}");
        }
    }

    #[test]
    fn i8_kernels_match_scalar_exactly() {
        let mut rng = Rng::new(0x1234_5678);
        for &dim in I8_DIMS {
            let (a, b) = (rng.i8_vec(dim), rng.i8_vec(dim));
            assert_eq!(dot_i8(&a, &b), scalar::dot_i8(&a, &b), "dot dim {dim}");
            assert_eq!(l2_sq_i8(&a, &b), scalar::l2_sq_i8(&a, &b), "l2 dim {dim}");
        }
    }

    #[test]
    fn cosine_zero_vector_is_zero() {
        let z = vec![0.0f32; 8];
        let v = vec![1.0f32; 8];
        assert!(cosine_f32(&z, &v).abs() < 1e-6);
        assert!(cosine_f32(&z, &z).abs() < 1e-6);
    }

    #[test]
    fn empty_vectors() {
        let e: [f32; 0] = [];
        assert!(dot_f32(&e, &e).abs() < 1e-6);
        assert!(l2_sq_f32(&e, &e).abs() < 1e-6);
        // An empty vector has zero magnitude, so the documented zero-magnitude
        // result of `cosine_f32` applies.
        assert!(cosine_f32(&e, &e).abs() < 1e-6);
        let ei: [i8; 0] = [];
        assert_eq!(dot_i8(&ei, &ei), 0);
        assert_eq!(l2_sq_i8(&ei, &ei), 0);
        let eu: [u64; 0] = [];
        assert_eq!(hamming_u64(&eu, &eu), 0);
    }

    /// Every public kernel documents a panic on unequal lengths; pin that
    /// contract so a dispatch change cannot silently drop the check.
    #[test]
    fn mismatched_lengths_panic() {
        let (f2, f3) = ([0.0f32; 2], [0.0f32; 3]);
        let (i2, i3) = ([0i8; 2], [0i8; 3]);
        let (w2, w3) = ([0u64; 2], [0u64; 3]);
        assert!(catch_unwind(|| dot_f32(&f2, &f3)).is_err(), "dot_f32");
        assert!(catch_unwind(|| l2_sq_f32(&f2, &f3)).is_err(), "l2_sq_f32");
        assert!(catch_unwind(|| cosine_f32(&f2, &f3)).is_err(), "cosine_f32");
        assert!(catch_unwind(|| dot_i8(&i2, &i3)).is_err(), "dot_i8");
        assert!(catch_unwind(|| l2_sq_i8(&i2, &i3)).is_err(), "l2_sq_i8");
        assert!(catch_unwind(|| hamming_u64(&w2, &w3)).is_err(), "hamming_u64");
    }

    #[test]
    fn hamming_matches_naive_and_scalar() {
        let mut rng = Rng::new(0x9911_AA55);
        for &words in U64_WORDS {
            let (a, b) = (rng.u64_vec(words), rng.u64_vec(words));
            let naive = hamming_naive(&a, &b);
            assert_eq!(hamming_u64(&a, &b), naive, "dispatch, {words} words");
            assert_eq!(scalar::hamming_u64(&a, &b), naive, "scalar, {words} words");
        }
    }

    #[test]
    fn hamming_axioms() {
        let mut rng = Rng::new(0x5151_2727);
        for &words in U64_WORDS {
            let (a, b) = (rng.u64_vec(words), rng.u64_vec(words));
            let max_bits = u32::try_from(words * 64).expect("test word counts fit in u32");
            // Identity of indiscernibles, symmetry, and the bit-count bound.
            assert_eq!(hamming_u64(&a, &a), 0, "{words}: d(a,a)=0");
            assert_eq!(
                hamming_u64(&a, &b),
                hamming_u64(&b, &a),
                "{words}: symmetry"
            );
            assert!(hamming_u64(&a, &b) <= max_bits, "{words}: within bound");
        }
        // All-ones vs all-zeros differs in every bit.
        let ones = vec![u64::MAX; 8];
        let zeros = vec![0u64; 8];
        assert_eq!(hamming_u64(&ones, &zeros), 8 * 64);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn hamming_avx2_matches_scalar_directly() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let mut rng = Rng::new(0xC1A0_F00D);
        for &words in U64_WORDS {
            let (a, b) = (rng.u64_vec(words), rng.u64_vec(words));
            // SAFETY: AVX2 detected above.
            let got = unsafe { avx2::hamming_u64(&a, &b) };
            assert_eq!(got, scalar::hamming_u64(&a, &b), "avx2 {words} words");
        }
    }

    /// Calls each x86 kernel directly (bypassing dispatch) and compares it with
    /// the scalar reference, covering every `f32` and `i8` kernel that the
    /// public functions can dispatch to.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn avx2_paths_match_scalar_directly() {
        let have_f32 = is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma");
        let have_i8 = is_x86_feature_detected!("avx2");
        if !have_f32 && !have_i8 {
            return;
        }
        let mut rng = Rng::new(99);
        for &dim in &[8usize, 17, 256, 769] {
            let (a, b) = (rng.f32_vec(dim), rng.f32_vec(dim));
            if have_f32 {
                // SAFETY: AVX + FMA detected above.
                let got = unsafe { avx2::dot_f32(&a, &b) };
                assert!(close(got, scalar::dot_f32(&a, &b)), "dot dim {dim}");
                // SAFETY: AVX + FMA detected above.
                let got = unsafe { avx2::l2_sq_f32(&a, &b) };
                assert!(close(got, scalar::l2_sq_f32(&a, &b)), "l2 dim {dim}");
                // SAFETY: AVX + FMA detected above.
                let got = unsafe { avx2::cosine_f32(&a, &b) };
                assert!(close(got, scalar::cosine_f32(&a, &b)), "cosine dim {dim}");
            }
            if have_i8 {
                let (ai, bi) = (rng.i8_vec(dim), rng.i8_vec(dim));
                // SAFETY: AVX2 detected above.
                let got = unsafe { avx2::dot_i8(&ai, &bi) };
                assert_eq!(got, scalar::dot_i8(&ai, &bi), "dot_i8 dim {dim}");
                // SAFETY: AVX2 detected above.
                let got = unsafe { avx2::l2_sq_i8(&ai, &bi) };
                assert_eq!(got, scalar::l2_sq_i8(&ai, &bi), "l2_sq_i8 dim {dim}");
            }
        }
    }
}