Skip to main content

ext_vector/
distance.rs

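//! SIMD-accelerated distance functions for vector similarity search.
//!
//! Uses NEON intrinsics on aarch64 (e.g. Apple Silicon) and a portable scalar
//! fallback on other architectures. The fallback sums `f32` values in strict
//! order; LLVM generally does not vectorize such reductions without permission
//! to reassociate floating-point math, so it should not be assumed to compile
//! to packed AVX2/SSE code.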
5
/// Which distance function to use when comparing two vectors.
///
/// For both metrics, a smaller value means the vectors are more similar.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Squared Euclidean distance, `sum((a[i] - b[i])^2)` (no square root taken).
    L2,
    /// Cosine distance, `1 - (a . b) / (|a| * |b|)`.
    Cosine,
}
14
15impl DistanceMetric {
16    /// Compute distance between two vectors.
17    #[inline]
18    pub fn distance(self, a: &[f32], b: &[f32]) -> f32 {
19        match self {
20            DistanceMetric::L2 => l2_distance(a, b),
21            DistanceMetric::Cosine => cosine_distance(a, b),
22        }
23    }
24}
25
26/// Squared L2 (Euclidean) distance.
27#[inline]
28pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
29    debug_assert_eq!(a.len(), b.len());
30
31    #[cfg(target_arch = "aarch64")]
32    {
33        l2_neon(a, b)
34    }
35
36    #[cfg(not(target_arch = "aarch64"))]
37    {
38        l2_scalar(a, b)
39    }
40}
41
42/// Cosine distance: 1.0 - cosine_similarity.
43#[inline]
44pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
45    debug_assert_eq!(a.len(), b.len());
46
47    #[cfg(target_arch = "aarch64")]
48    {
49        cosine_neon(a, b)
50    }
51
52    #[cfg(not(target_arch = "aarch64"))]
53    {
54        cosine_scalar(a, b)
55    }
56}
57
// ---- Scalar implementations (portable fallback used on non-aarch64 targets) ----
59
/// Portable squared-L2 kernel: `sum((a[i] - b[i])^2)`, accumulated in index
/// order in `f32`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths. A longer `b` is rejected
/// rather than silently truncated.
// Only called from the non-aarch64 branch of `l2_distance`, so it is unused
// (but still compiled) on aarch64.
#[allow(dead_code)]
fn l2_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    // zip after the length check: no per-element bounds checks, same
    // left-to-right summation order as an indexed loop.
    a.iter().zip(b).fold(0.0f32, |sum, (x, y)| {
        let d = x - y;
        sum + d * d
    })
}
69
/// Portable cosine-distance kernel: `1 - dot(a, b) / (|a| * |b|)`.
///
/// Returns 1.0 (the value for orthogonal vectors) when either input has zero
/// norm, so that case never divides by zero.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
// Only called from the non-aarch64 branch of `cosine_distance`, so it is
// unused (but still compiled) on aarch64.
#[allow(dead_code)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    let (mut dot, mut norm_a, mut norm_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // Take each square root before multiplying. The product of the two
    // *squared* norms overflows f32 to infinity (e.g. components ~1e19) or
    // underflows to zero (e.g. components ~1e-15), which would report even
    // identical vectors as distance 1.0.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        1.0 // zero-norm input: treat as orthogonal to everything
    } else {
        1.0 - dot / denom
    }
}
87
// ---- NEON implementations (aarch64) ----
89
/// NEON kernel for squared L2 distance.
///
/// Processes four lanes per iteration with a fused multiply-add into one vector
/// accumulator, reduces it horizontally, then adds any leftover (`len % 4`)
/// elements in scalar code.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths. The check is unconditional
/// because this is a safe function wrapping raw SIMD loads; a `debug_assert!`
/// would let release builds read past the end of a shorter `b`.
#[cfg(target_arch = "aarch64")]
fn l2_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    assert_eq!(a.len(), b.len());

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    // SAFETY: each `ca`/`cb` yielded by `chunks_exact(4)` is a slice of exactly
    // four initialized, contiguous f32 values, so every `vld1q_f32` reads in
    // bounds. `vld1q_f32` needs only f32 (4-byte) alignment, which any `&[f32]`
    // provides; no 16-byte alignment is assumed. NEON availability is assumed,
    // as it is enabled by default on standard aarch64 targets.
    let simd_sum = unsafe {
        let mut acc = vdupq_n_f32(0.0);
        for (ca, cb) in a_chunks.zip(b_chunks) {
            let diff = vsubq_f32(vld1q_f32(ca.as_ptr()), vld1q_f32(cb.as_ptr()));
            acc = vfmaq_f32(acc, diff, diff);
        }
        vaddvq_f32(acc)
    };

    // Scalar tail for the final `len % 4` elements.
    a_tail.iter().zip(b_tail).fold(simd_sum, |sum, (x, y)| {
        let d = x - y;
        sum + d * d
    })
}
122
/// NEON kernel for cosine distance: `1 - dot(a, b) / (|a| * |b|)`.
///
/// Accumulates the dot product and both squared norms four lanes at a time
/// with fused multiply-adds, reduces each accumulator horizontally, then folds
/// in any leftover (`len % 4`) elements in scalar code. Returns 1.0 (the value
/// for orthogonal vectors) when either input has zero norm.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths. The check is unconditional
/// because this is a safe function wrapping raw SIMD loads; a `debug_assert!`
/// would let release builds read past the end of a shorter `b`.
#[cfg(target_arch = "aarch64")]
fn cosine_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    assert_eq!(a.len(), b.len());

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    // SAFETY: each `ca`/`cb` yielded by `chunks_exact(4)` is a slice of exactly
    // four initialized, contiguous f32 values, so every `vld1q_f32` reads in
    // bounds. `vld1q_f32` needs only f32 (4-byte) alignment, which any `&[f32]`
    // provides; no 16-byte alignment is assumed. NEON availability is assumed,
    // as it is enabled by default on standard aarch64 targets.
    let (mut dot, mut norm_a, mut norm_b) = unsafe {
        let mut acc_dot = vdupq_n_f32(0.0);
        let mut acc_na = vdupq_n_f32(0.0);
        let mut acc_nb = vdupq_n_f32(0.0);
        for (ca, cb) in a_chunks.zip(b_chunks) {
            let va = vld1q_f32(ca.as_ptr());
            let vb = vld1q_f32(cb.as_ptr());
            acc_dot = vfmaq_f32(acc_dot, va, vb);
            acc_na = vfmaq_f32(acc_na, va, va);
            acc_nb = vfmaq_f32(acc_nb, vb, vb);
        }
        (vaddvq_f32(acc_dot), vaddvq_f32(acc_na), vaddvq_f32(acc_nb))
    };

    // Scalar tail for the final `len % 4` elements.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // Take each square root before multiplying. The product of the two
    // *squared* norms overflows f32 to infinity (e.g. components ~1e19) or
    // underflows to zero (e.g. components ~1e-15), which would report even
    // identical vectors as distance 1.0.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        1.0 // zero-norm input: treat as orthogonal to everything
    } else {
        1.0 - dot / denom
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// f64 reference for squared L2 distance, used to cross-check the f32 kernels.
    fn l2_reference(a: &[f32], b: &[f32]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| {
                let d = f64::from(x) - f64::from(y);
                d * d
            })
            .sum()
    }

    /// f64 reference for cosine distance; both inputs must have non-zero norm.
    fn cosine_reference(a: &[f32], b: &[f32]) -> f64 {
        let (mut dot, mut na, mut nb) = (0.0f64, 0.0f64, 0.0f64);
        for (&x, &y) in a.iter().zip(b) {
            let (x, y) = (f64::from(x), f64::from(y));
            dot += x * y;
            na += x * x;
            nb += y * y;
        }
        1.0 - dot / (na.sqrt() * nb.sqrt())
    }

    #[test]
    fn l2_known_vectors() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        let d = l2_distance(&a, &b);
        assert!(
            (d - 2.0).abs() < 1e-6,
            "L2([1,0,0], [0,1,0]) = {d}, expected 2.0"
        );
    }

    #[test]
    fn l2_identical() {
        let a = [1.0f32, 2.0, 3.0];
        assert!(l2_distance(&a, &a).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        let d = cosine_distance(&a, &b);
        assert!(
            (d - 1.0).abs() < 1e-6,
            "cosine orthogonal = {d}, expected 1.0"
        );
    }

    #[test]
    fn cosine_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let d = cosine_distance(&a, &a);
        assert!(d.abs() < 1e-5, "cosine identical = {d}, expected ~0.0");
    }

    #[test]
    fn cosine_zero_vector() {
        let a = [0.0f32; 3];
        let b = [1.0f32, 2.0, 3.0];
        assert_eq!(cosine_distance(&a, &b), 1.0);
    }

    #[test]
    fn high_dim_consistency() {
        // Test with higher-dimensional vectors to exercise SIMD lanes.
        let dim = 128;
        let a: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1) + 0.5).collect();

        let l2 = l2_distance(&a, &b);
        // Each element differs by 0.5, so L2 = dim * 0.25 = 32.0
        assert!((l2 - 32.0).abs() < 0.01, "L2 128-d = {l2}, expected 32.0");

        let cos = cosine_distance(&a, &b);
        assert!(
            (0.0..=1.0).contains(&cos),
            "cosine 128-d = {cos}, out of range"
        );
    }

    #[test]
    fn mixed_simd_and_tail_lengths() {
        // Lengths that are not multiples of 4 exercise both the vector loop and
        // the scalar tail of the SIMD kernels. The other tests only use lengths
        // that are entirely tail (2, 3) or entirely full chunks (128).
        for dim in [1usize, 4, 5, 7, 8, 13, 130] {
            let a: Vec<f32> = (0..dim).map(|i| i as f32 * 0.37 - 1.5).collect();
            let b: Vec<f32> = (0..dim).map(|i| 2.0 - i as f32 * 0.11).collect();

            let l2 = f64::from(l2_distance(&a, &b));
            let l2_ref = l2_reference(&a, &b);
            assert!(
                (l2 - l2_ref).abs() <= 1e-4 * l2_ref.max(1.0),
                "dim {dim}: L2 {l2} vs reference {l2_ref}"
            );

            let cos = f64::from(cosine_distance(&a, &b));
            let cos_ref = cosine_reference(&a, &b);
            assert!(
                (cos - cos_ref).abs() < 1e-4,
                "dim {dim}: cosine {cos} vs reference {cos_ref}"
            );
        }
    }

    #[test]
    fn metric_dispatch() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];

        let l2 = DistanceMetric::L2.distance(&a, &b);
        let cos = DistanceMetric::Cosine.distance(&a, &b);
        assert!((l2 - 2.0).abs() < 1e-6);
        assert!((cos - 1.0).abs() < 1e-6);
    }
}