/// Distance function used to compare two `f32` vectors.
///
/// `Hash` is derived alongside `Eq` so a metric can be stored in hash-based
/// collections (`HashMap`/`HashSet`) like any other small value type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Squared Euclidean distance (no square root is taken); see [`l2_distance`].
    L2,
    /// Cosine distance, `1 - cos(θ)`; see [`cosine_distance`].
    Cosine,
}
14
15impl DistanceMetric {
16 #[inline]
18 pub fn distance(self, a: &[f32], b: &[f32]) -> f32 {
19 match self {
20 DistanceMetric::L2 => l2_distance(a, b),
21 DistanceMetric::Cosine => cosine_distance(a, b),
22 }
23 }
24}
25
26#[inline]
28pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
29 debug_assert_eq!(a.len(), b.len());
30
31 #[cfg(target_arch = "aarch64")]
32 {
33 l2_neon(a, b)
34 }
35
36 #[cfg(not(target_arch = "aarch64"))]
37 {
38 l2_scalar(a, b)
39 }
40}
41
42#[inline]
44pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
45 debug_assert_eq!(a.len(), b.len());
46
47 #[cfg(target_arch = "aarch64")]
48 {
49 cosine_neon(a, b)
50 }
51
52 #[cfg(not(target_arch = "aarch64"))]
53 {
54 cosine_scalar(a, b)
55 }
56}
57
/// Portable squared-L2 kernel: folds `(a[i] - b[i])²` over every index of `a`.
///
/// Elements of `b` past `a.len()` are ignored. A `b` shorter than `a`
/// triggers an index-out-of-bounds panic.
// Unused on aarch64, where the dispatcher selects the NEON kernel instead.
#[allow(dead_code)]
fn l2_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().enumerate().fold(0.0f32, |acc, (idx, &lhs)| {
        let delta = lhs - b[idx];
        acc + delta * delta
    })
}
69
/// Portable cosine-distance kernel: `1 - dot(a, b) / (|a| * |b|)`.
///
/// Returns `1.0` when either vector has zero norm. The result is clamped to
/// `[0, 2]`, the mathematical range of cosine distance, so rounding cannot
/// produce e.g. a slightly negative distance for identical vectors.
///
/// Reads `b[i]` for every index of `a`, so panics if `b` is shorter than `a`.
// Unused on aarch64, where the dispatcher selects the NEON kernel instead.
#[allow(dead_code)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    let (mut dot, mut norm_a, mut norm_b) = (0.0f32, 0.0f32, 0.0f32);
    for (i, &x) in a.iter().enumerate() {
        let y = b[i];
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // Take each square root separately. `(norm_a * norm_b).sqrt()` overflows to
    // infinity for large vectors and underflows to zero for tiny ones, even
    // when both squared norms are perfectly representable in f32.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        1.0
    } else {
        (1.0 - dot / denom).clamp(0.0, 2.0)
    }
}
87
/// NEON kernel for squared Euclidean distance, processing four lanes at a time.
///
/// Full 4-element chunks are accumulated with fused multiply-add in a vector
/// register. The remaining `len % 4` elements are added with a scalar loop.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. The caller only checks this with
/// `debug_assert_eq!`, so this check keeps release builds from silently
/// pairing mismatched inputs.
#[cfg(target_arch = "aarch64")]
fn l2_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    assert_eq!(a.len(), b.len(), "l2_neon: input slices must have equal length");

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    // SAFETY: every chunk yielded by `chunks_exact(4)` holds exactly four `f32`s,
    // so each `vld1q_f32` reads 16 in-bounds, element-aligned bytes. The
    // intrinsics also rely on the `neon` target feature, which the standard
    // aarch64 targets enable by default.
    let mut sum = unsafe {
        let mut acc = vdupq_n_f32(0.0);
        for (ca, cb) in a_chunks.zip(b_chunks) {
            let diff = vsubq_f32(vld1q_f32(ca.as_ptr()), vld1q_f32(cb.as_ptr()));
            acc = vfmaq_f32(acc, diff, diff);
        }
        vaddvq_f32(acc)
    };

    // Scalar tail: the final `len % 4` elements.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let d = x - y;
        sum += d * d;
    }
    sum
}
122
/// NEON kernel for cosine distance, `1 - dot(a, b) / (|a| * |b|)`.
///
/// The dot product and both squared norms are accumulated four lanes at a
/// time. The remaining `len % 4` elements are handled by a scalar loop.
/// Returns `1.0` when either vector has zero norm. Otherwise the result is
/// clamped to `[0, 2]`, the mathematical range of cosine distance.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. The caller only checks this with
/// `debug_assert_eq!`, so this check keeps release builds from silently
/// pairing mismatched inputs.
#[cfg(target_arch = "aarch64")]
fn cosine_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    assert_eq!(a.len(), b.len(), "cosine_neon: input slices must have equal length");

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    // SAFETY: every chunk yielded by `chunks_exact(4)` holds exactly four `f32`s,
    // so each `vld1q_f32` reads 16 in-bounds, element-aligned bytes. The
    // intrinsics also rely on the `neon` target feature, which the standard
    // aarch64 targets enable by default.
    let (mut dot, mut norm_a, mut norm_b) = unsafe {
        let mut acc_dot = vdupq_n_f32(0.0);
        let mut acc_na = vdupq_n_f32(0.0);
        let mut acc_nb = vdupq_n_f32(0.0);
        for (ca, cb) in a_chunks.zip(b_chunks) {
            let va = vld1q_f32(ca.as_ptr());
            let vb = vld1q_f32(cb.as_ptr());
            acc_dot = vfmaq_f32(acc_dot, va, vb);
            acc_na = vfmaq_f32(acc_na, va, va);
            acc_nb = vfmaq_f32(acc_nb, vb, vb);
        }
        (vaddvq_f32(acc_dot), vaddvq_f32(acc_na), vaddvq_f32(acc_nb))
    };

    // Scalar tail: the final `len % 4` elements.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // Take each square root separately. `(norm_a * norm_b).sqrt()` overflows to
    // infinity for large vectors and underflows to zero for tiny ones.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        1.0
    } else {
        (1.0 - dot / denom).clamp(0.0, 2.0)
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn l2_known_vectors() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        let d = l2_distance(&a, &b);
        // `l2_distance` is the squared distance, hence 2.0 rather than sqrt(2).
        assert!((d - 2.0).abs() < 1e-6, "L2([1,0,0], [0,1,0]) = {d}, expected 2.0");
    }

    #[test]
    fn l2_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let d = l2_distance(&a, &a);
        assert!(d.abs() < 1e-6, "L2 of identical vectors = {d}, expected 0.0");
    }

    #[test]
    fn l2_chunk_plus_tail_is_exact() {
        // 5 elements = one 4-lane SIMD chunk + one tail element. Integer-valued
        // inputs keep every partial sum exact, so summation order cannot matter.
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = [0.0f32; 5];
        assert_eq!(l2_distance(&a, &b), 55.0);
    }

    #[test]
    fn cosine_orthogonal() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        let d = cosine_distance(&a, &b);
        assert!((d - 1.0).abs() < 1e-6, "cosine orthogonal = {d}, expected 1.0");
    }

    #[test]
    fn cosine_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let d = cosine_distance(&a, &a);
        assert!(d.abs() < 1e-5, "cosine identical = {d}, expected ~0.0");
    }

    #[test]
    fn cosine_opposite() {
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = [-1.0f32, -2.0, -3.0, -4.0, -5.0];
        let d = cosine_distance(&a, &b);
        assert!((d - 2.0).abs() < 1e-6, "cosine opposite = {d}, expected 2.0");
    }

    #[test]
    fn cosine_zero_vector() {
        let a = [0.0f32; 3];
        let b = [1.0f32, 2.0, 3.0];
        assert_eq!(cosine_distance(&a, &b), 1.0);
    }

    #[test]
    fn empty_inputs() {
        assert_eq!(l2_distance(&[], &[]), 0.0);
        // Zero norm on both sides falls back to the "no direction" value.
        assert_eq!(cosine_distance(&[], &[]), 1.0);
    }

    #[test]
    fn high_dim_consistency() {
        let dim = 128;
        let a: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1) + 0.5).collect();

        // Every component differs by ~0.5, so the squared distance is 128 * 0.25.
        let l2 = l2_distance(&a, &b);
        assert!((l2 - 32.0).abs() < 0.01, "L2 128-d = {l2}, expected 32.0");

        let cos = cosine_distance(&a, &b);
        assert!((0.0..=1.0).contains(&cos), "cosine 128-d = {cos}, out of range");
    }

    #[test]
    fn dispatch_matches_scalar_reference() {
        // 13 = three 4-lane chunks + one tail element, so on aarch64 this
        // compares the NEON kernels against the portable scalar ones.
        let a: Vec<f32> = (0..13).map(|i| (i as f32 * 0.37).sin()).collect();
        let b: Vec<f32> = (0..13).map(|i| (i as f32 * 0.61).cos()).collect();

        let (l2, l2_ref) = (l2_distance(&a, &b), l2_scalar(&a, &b));
        assert!((l2 - l2_ref).abs() < 1e-4, "L2 dispatch = {l2}, scalar = {l2_ref}");

        let (cos, cos_ref) = (cosine_distance(&a, &b), cosine_scalar(&a, &b));
        assert!((cos - cos_ref).abs() < 1e-4, "cosine dispatch = {cos}, scalar = {cos_ref}");
    }

    #[test]
    fn metric_dispatch() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];

        let l2 = DistanceMetric::L2.distance(&a, &b);
        let cos = DistanceMetric::Cosine.distance(&a, &b);
        assert!((l2 - 2.0).abs() < 1e-6);
        assert!((cos - 1.0).abs() < 1e-6);
    }
}