/// Selects which distance function is used to compare two `f32` vectors.
///
/// The metric is a plain tag: it is `Copy`, comparable with `==`, and carries
/// no parameters of its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Squared Euclidean distance: the sum of squared element-wise
    /// differences, with no square root applied.
    L2,
    /// Cosine distance: `1 - cos(a, b)`, defined as `1.0` when the product of
    /// the squared norms is zero.
    Cosine,
}
14
15impl DistanceMetric {
16 #[inline]
18 pub fn distance(self, a: &[f32], b: &[f32]) -> f32 {
19 match self {
20 DistanceMetric::L2 => l2_distance(a, b),
21 DistanceMetric::Cosine => cosine_distance(a, b),
22 }
23 }
24}
25
26#[inline]
28pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
29 debug_assert_eq!(a.len(), b.len());
30
31 #[cfg(target_arch = "aarch64")]
32 {
33 l2_neon(a, b)
34 }
35
36 #[cfg(not(target_arch = "aarch64"))]
37 {
38 l2_scalar(a, b)
39 }
40}
41
42#[inline]
44pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
45 debug_assert_eq!(a.len(), b.len());
46
47 #[cfg(target_arch = "aarch64")]
48 {
49 cosine_neon(a, b)
50 }
51
52 #[cfg(not(target_arch = "aarch64"))]
53 {
54 cosine_scalar(a, b)
55 }
56}
57
/// Portable kernel for the squared Euclidean distance.
///
/// Accumulates `(a[i] - b[i])²` left to right in `f32` over every index of
/// `a` and returns the sum without taking a square root. Panics if `b` is
/// shorter than `a`; elements of `b` past `a.len()` are ignored.
// The dispatchers do not call this on aarch64 (they use the NEON kernel), so
// it would otherwise be reported as dead code on that architecture.
#[allow(dead_code)]
fn l2_scalar(a: &[f32], b: &[f32]) -> f32 {
    // One up-front bounds check instead of one per element.
    let b = &b[..a.len()];

    // Explicit `0.0` seed so the empty-input result is exactly `+0.0`.
    a.iter().zip(b).fold(0.0f32, |total, (&x, &y)| {
        let delta = x - y;
        total + delta * delta
    })
}
69
/// Portable kernel for the cosine distance `1 - dot(a, b) / sqrt(|a|² · |b|²)`.
///
/// The dot product and both squared norms are accumulated left to right in
/// `f32` over every index of `a`. Returns `1.0` when the denominator is zero.
/// Panics if `b` is shorter than `a`; elements of `b` past `a.len()` are
/// ignored.
// The dispatchers do not call this on aarch64 (they use the NEON kernel), so
// it would otherwise be reported as dead code on that architecture.
#[allow(dead_code)]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    // One up-front bounds check instead of one per element.
    let b = &b[..a.len()];

    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    let denom = (sq_norm_a * sq_norm_b).sqrt();
    if denom == 0.0 {
        // No direction to compare against; report the same value as for
        // orthogonal vectors.
        return 1.0;
    }
    1.0 - dot / denom
}
87
/// NEON kernel for the squared Euclidean distance.
///
/// Processes four lanes at a time with a fused multiply-add accumulator, adds
/// the lanes together, then folds the remaining `a.len() % 4` elements in
/// scalar code. No square root is taken.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`. Elements of `b` past `a.len()` are
/// ignored, matching the scalar kernel.
#[cfg(target_arch = "aarch64")]
fn l2_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    // This function is safe to call, so it must not trust the caller's
    // debug-only length check: without this re-slice a shorter `b` would be
    // read out of bounds by the vector loads below.
    let b = &b[..a.len()];

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // SAFETY: every slice yielded by `chunks_exact(4)` holds exactly four
    // contiguous, initialized `f32`s, so each `vld1q_f32` reads 16 bytes that
    // lie entirely inside its chunk. The remaining intrinsics operate on
    // register values only. NEON itself is assumed to be available on the
    // aarch64 target being compiled for, as the original code also assumed.
    let lane_sum = unsafe {
        let mut acc = vdupq_n_f32(0.0);
        for (ca, cb) in a_chunks.zip(b_chunks) {
            let va = vld1q_f32(ca.as_ptr());
            let vb = vld1q_f32(cb.as_ptr());
            let diff = vsubq_f32(va, vb);
            // acc += diff * diff, fused.
            acc = vfmaq_f32(acc, diff, diff);
        }
        // Horizontal add of the four lanes.
        vaddvq_f32(acc)
    };

    // Scalar tail; both tails have equal length because `b` was re-sliced.
    a_tail.iter().zip(b_tail).fold(lane_sum, |total, (&x, &y)| {
        let d = x - y;
        total + d * d
    })
}
122
/// NEON kernel for the cosine distance `1 - dot(a, b) / sqrt(|a|² · |b|²)`.
///
/// Accumulates the dot product and both squared norms four lanes at a time
/// with fused multiply-adds, adds the lanes together, then folds the remaining
/// `a.len() % 4` elements in scalar code. Returns `1.0` when the denominator
/// is zero.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`. Elements of `b` past `a.len()` are
/// ignored, matching the scalar kernel.
#[cfg(target_arch = "aarch64")]
fn cosine_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    // This function is safe to call, so it must not trust the caller's
    // debug-only length check: without this re-slice a shorter `b` would be
    // read out of bounds by the vector loads below.
    let b = &b[..a.len()];

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // SAFETY: every slice yielded by `chunks_exact(4)` holds exactly four
    // contiguous, initialized `f32`s, so each `vld1q_f32` reads 16 bytes that
    // lie entirely inside its chunk. The remaining intrinsics operate on
    // register values only. NEON itself is assumed to be available on the
    // aarch64 target being compiled for, as the original code also assumed.
    let (lane_dot, lane_na, lane_nb) = unsafe {
        let mut acc_dot = vdupq_n_f32(0.0);
        let mut acc_na = vdupq_n_f32(0.0);
        let mut acc_nb = vdupq_n_f32(0.0);
        for (ca, cb) in a_chunks.zip(b_chunks) {
            let va = vld1q_f32(ca.as_ptr());
            let vb = vld1q_f32(cb.as_ptr());
            acc_dot = vfmaq_f32(acc_dot, va, vb);
            acc_na = vfmaq_f32(acc_na, va, va);
            acc_nb = vfmaq_f32(acc_nb, vb, vb);
        }
        // Horizontal add of each accumulator's four lanes.
        (vaddvq_f32(acc_dot), vaddvq_f32(acc_na), vaddvq_f32(acc_nb))
    };

    // Scalar tail; both tails have equal length because `b` was re-sliced.
    let (dot, sq_norm_a, sq_norm_b) = a_tail.iter().zip(b_tail).fold(
        (lane_dot, lane_na, lane_nb),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    let denom = (sq_norm_a * sq_norm_b).sqrt();
    if denom == 0.0 {
        // No direction to compare against; report the same value as for
        // orthogonal vectors.
        1.0
    } else {
        1.0 - dot / denom
    }
}
169
// Unit tests for the public distance functions and the metric dispatcher.
//
// Lengths are chosen to cover each path of the chunked kernels: tail only
// (2, 3), full 4-lane chunks only (128), chunks plus tail (7), and empty.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn l2_known_vectors() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        let d = l2_distance(&a, &b);
        // The distance is squared: 1² + 1² = 2, not sqrt(2).
        assert!(
            (d - 2.0).abs() < 1e-6,
            "L2([1,0,0], [0,1,0]) = {d}, expected 2.0"
        );
    }

    #[test]
    fn l2_identical() {
        let a = [1.0f32, 2.0, 3.0];
        assert!(l2_distance(&a, &a).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        let d = cosine_distance(&a, &b);
        assert!(
            (d - 1.0).abs() < 1e-6,
            "cosine orthogonal = {d}, expected 1.0"
        );
    }

    #[test]
    fn cosine_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let d = cosine_distance(&a, &a);
        assert!(d.abs() < 1e-5, "cosine identical = {d}, expected ~0.0");
    }

    #[test]
    fn cosine_zero_vector() {
        let a = [0.0f32; 3];
        let b = [1.0f32, 2.0, 3.0];
        assert_eq!(cosine_distance(&a, &b), 1.0);
    }

    #[test]
    fn high_dim_consistency() {
        let dim = 128;
        let a: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1) + 0.5).collect();

        // Every element differs by 0.5, so the squared distance is 128 * 0.25.
        let l2 = l2_distance(&a, &b);
        assert!((l2 - 32.0).abs() < 0.01, "L2 128-d = {l2}, expected 32.0");

        let cos = cosine_distance(&a, &b);
        assert!(
            (0.0..=1.0).contains(&cos),
            "cosine 128-d = {cos}, out of range"
        );
    }

    #[test]
    fn mixed_length_matches_scalar() {
        // 7 = one full 4-lane chunk followed by a 3-element tail. All values
        // are small multiples of 0.5, so every product and sum is exact.
        let a: Vec<f32> = (0..7).map(|i| i as f32 * 0.5).collect();
        let b: Vec<f32> = a.iter().map(|x| x + 1.0).collect();

        let l2 = l2_distance(&a, &b);
        assert!((l2 - 7.0).abs() < 1e-6, "L2 7-d = {l2}, expected 7.0");
        assert!((l2 - l2_scalar(&a, &b)).abs() < 1e-6);

        let cos = cosine_distance(&a, &b);
        let reference = cosine_scalar(&a, &b);
        assert!(
            (cos - reference).abs() < 1e-6,
            "cosine 7-d = {cos}, scalar reference = {reference}"
        );
    }

    #[test]
    fn empty_inputs() {
        let empty: [f32; 0] = [];
        assert_eq!(l2_distance(&empty, &empty), 0.0);
        // Zero denominator is reported as distance 1.0.
        assert_eq!(cosine_distance(&empty, &empty), 1.0);
    }

    #[test]
    fn metric_dispatch() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];

        let l2 = DistanceMetric::L2.distance(&a, &b);
        let cos = DistanceMetric::Cosine.distance(&a, &b);
        assert!((l2 - 2.0).abs() < 1e-6);
        assert!((cos - 1.0).abs() < 1e-6);
    }
}