1mod scalar;
12
13#[cfg(target_arch = "x86_64")]
14mod avx2;
15
/// Names a vector comparison metric.
///
/// `Hash` is derived alongside `Eq` so the value can be used directly as a
/// `HashMap` / `HashSet` key (for example, to key per-metric state).
///
/// NOTE(review): nothing in the visible part of this file consumes this enum.
/// The variant names line up with the `dot_*`, `cosine_*` and `l2_sq_*`
/// kernels in this module, but confirm the mapping at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Metric {
    /// Dot-product metric.
    Dot,
    /// Cosine metric.
    Cosine,
    /// L2 (Euclidean) metric.
    L2,
}
27
28#[inline]
33#[must_use]
34pub fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
35 assert_eq!(a.len(), b.len(), "vectors must have equal length");
36 #[cfg(target_arch = "x86_64")]
37 {
38 if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
39 return unsafe { avx2::dot_f32(a, b) };
41 }
42 }
43 scalar::dot_f32(a, b)
44}
45
46#[inline]
51#[must_use]
52pub fn l2_sq_f32(a: &[f32], b: &[f32]) -> f32 {
53 assert_eq!(a.len(), b.len(), "vectors must have equal length");
54 #[cfg(target_arch = "x86_64")]
55 {
56 if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
57 return unsafe { avx2::l2_sq_f32(a, b) };
59 }
60 }
61 scalar::l2_sq_f32(a, b)
62}
63
64#[inline]
71#[must_use]
72pub fn cosine_f32(a: &[f32], b: &[f32]) -> f32 {
73 assert_eq!(a.len(), b.len(), "vectors must have equal length");
74 #[cfg(target_arch = "x86_64")]
75 {
76 if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
77 return unsafe { avx2::cosine_f32(a, b) };
79 }
80 }
81 scalar::cosine_f32(a, b)
82}
83
84#[inline]
89#[must_use]
90pub fn dot_i8(a: &[i8], b: &[i8]) -> i32 {
91 assert_eq!(a.len(), b.len(), "vectors must have equal length");
92 #[cfg(target_arch = "x86_64")]
93 {
94 if is_x86_feature_detected!("avx2") {
95 return unsafe { avx2::dot_i8(a, b) };
97 }
98 }
99 scalar::dot_i8(a, b)
100}
101
102#[inline]
107#[must_use]
108pub fn l2_sq_i8(a: &[i8], b: &[i8]) -> i32 {
109 assert_eq!(a.len(), b.len(), "vectors must have equal length");
110 #[cfg(target_arch = "x86_64")]
111 {
112 if is_x86_feature_detected!("avx2") {
113 return unsafe { avx2::l2_sq_i8(a, b) };
115 }
116 }
117 scalar::l2_sq_i8(a, b)
118}
119
120#[inline]
130#[must_use]
131pub fn hamming_u64(a: &[u64], b: &[u64]) -> u32 {
132 assert_eq!(a.len(), b.len(), "vectors must have equal length");
133 #[cfg(target_arch = "x86_64")]
134 {
135 if is_x86_feature_detected!("avx2") {
136 return unsafe { avx2::hamming_u64(a, b) };
138 }
139 }
140 scalar::hamming_u64(a, b)
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny xorshift64 generator: deterministic test data without pulling in
    /// an external crate.
    struct Rng(u64);

    impl Rng {
        /// Seeds the generator. The low bit is forced on so the state is never
        /// zero (a zero xorshift state would stay zero forever).
        fn new(seed: u64) -> Self {
            Self(seed | 1)
        }

        /// Advances the state with the 13 / 7 / 17 shift-xor steps and returns
        /// the new state.
        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }

        /// Returns an `f32` in `[-1.0, 1.0)` built from the top 24 bits of the
        /// next state.
        fn f32(&mut self) -> f32 {
            // 24 bits convert to f32 exactly; 16_777_216 is 2^24.
            let bits = (self.next_u64() >> 40) as u32;
            (bits as f32 / 16_777_216.0) * 2.0 - 1.0
        }

        /// Returns the top byte of the next state reinterpreted as `i8`.
        fn i8(&mut self) -> i8 {
            (self.next_u64() >> 56) as i8
        }

        /// Draws `n` consecutive `f32` samples.
        fn f32s(&mut self, n: usize) -> Vec<f32> {
            (0..n).map(|_| self.f32()).collect()
        }

        /// Draws `n` consecutive `i8` samples.
        fn i8s(&mut self, n: usize) -> Vec<i8> {
            (0..n).map(|_| self.i8()).collect()
        }

        /// Draws `n` consecutive raw `u64` states.
        fn u64s(&mut self, n: usize) -> Vec<u64> {
            (0..n).map(|_| self.next_u64()).collect()
        }
    }

    // Lengths include empty, single-element, and values on either side of
    // small powers of two -- presumably the SIMD block sizes of the kernels,
    // which are not visible from this file -- plus a large odd length.
    const F32_DIMS: &[usize] = &[0, 1, 7, 8, 9, 16, 31, 128, 769];
    const I8_DIMS: &[usize] = &[0, 1, 15, 16, 17, 31, 128, 769];
    const U64_WORDS: &[usize] = &[0, 1, 2, 3, 4, 5, 7, 8, 13, 16, 96];

    /// Reference Hamming distance: XOR each word pair and count the set bits
    /// one at a time. Deliberately naive so it shares no code with the kernels
    /// under test.
    fn hamming_naive(a: &[u64], b: &[u64]) -> u32 {
        let mut n = 0u32;
        for (x, y) in a.iter().zip(b.iter()) {
            let mut d = x ^ y;
            while d != 0 {
                n += (d & 1) as u32;
                d >>= 1;
            }
        }
        n
    }

    /// Float comparison with an absolute floor of `1e-3` plus a relative term
    /// of `1e-4 * |exp|`.
    fn close(got: f32, exp: f32) -> bool {
        (got - exp).abs() <= 1e-3 + 1e-4 * exp.abs()
    }

    #[test]
    fn dot_f32_matches_scalar() {
        let mut rng = Rng::new(0xC0FFEE);
        for &dim in F32_DIMS {
            let a = rng.f32s(dim);
            let b = rng.f32s(dim);
            let (got, exp) = (dot_f32(&a, &b), scalar::dot_f32(&a, &b));
            assert!(close(got, exp), "dim {dim}: {got} vs {exp}");
        }
    }

    #[test]
    fn l2_sq_f32_matches_scalar() {
        let mut rng = Rng::new(0xBEEF);
        for &dim in F32_DIMS {
            let a = rng.f32s(dim);
            let b = rng.f32s(dim);
            let (got, exp) = (l2_sq_f32(&a, &b), scalar::l2_sq_f32(&a, &b));
            assert!(close(got, exp), "dim {dim}: {got} vs {exp}");
        }
    }

    #[test]
    fn cosine_f32_matches_scalar() {
        let mut rng = Rng::new(0xABCD);
        for &dim in F32_DIMS {
            let a = rng.f32s(dim);
            let b = rng.f32s(dim);
            let (got, exp) = (cosine_f32(&a, &b), scalar::cosine_f32(&a, &b));
            assert!(close(got, exp), "dim {dim}: {got} vs {exp}");
        }
    }

    /// Integer kernels must agree with the scalar reference bit-for-bit.
    #[test]
    fn i8_kernels_match_scalar_exactly() {
        let mut rng = Rng::new(0x1234_5678);
        for &dim in I8_DIMS {
            let a = rng.i8s(dim);
            let b = rng.i8s(dim);
            assert_eq!(dot_i8(&a, &b), scalar::dot_i8(&a, &b), "dot dim {dim}");
            assert_eq!(l2_sq_i8(&a, &b), scalar::l2_sq_i8(&a, &b), "l2 dim {dim}");
        }
    }

    #[test]
    fn cosine_zero_vector_is_zero() {
        let z = vec![0.0f32; 8];
        let v = vec![1.0f32; 8];
        assert!(cosine_f32(&z, &v).abs() < 1e-6);
        assert!(cosine_f32(&z, &z).abs() < 1e-6);
    }

    #[test]
    fn empty_vectors() {
        let e: [f32; 0] = [];
        assert!(dot_f32(&e, &e).abs() < 1e-6);
        assert!(l2_sq_f32(&e, &e).abs() < 1e-6);
        let ei: [i8; 0] = [];
        assert_eq!(dot_i8(&ei, &ei), 0);
        let eu: [u64; 0] = [];
        assert_eq!(hamming_u64(&eu, &eu), 0);
    }

    #[test]
    fn hamming_matches_naive_and_scalar() {
        let mut rng = Rng::new(0x9911_AA55);
        for &words in U64_WORDS {
            let a = rng.u64s(words);
            let b = rng.u64s(words);
            let naive = hamming_naive(&a, &b);
            assert_eq!(hamming_u64(&a, &b), naive, "dispatch, {words} words");
            assert_eq!(scalar::hamming_u64(&a, &b), naive, "scalar, {words} words");
        }
    }

    /// Identity, symmetry and the trivial upper bound, plus the all-bits-differ
    /// extreme.
    #[test]
    fn hamming_axioms() {
        let mut rng = Rng::new(0x5151_2727);
        for &words in U64_WORDS {
            let a = rng.u64s(words);
            let b = rng.u64s(words);
            assert_eq!(hamming_u64(&a, &a), 0, "{words}: d(a,a)=0");
            assert_eq!(
                hamming_u64(&a, &b),
                hamming_u64(&b, &a),
                "{words}: symmetry"
            );
            assert!(
                hamming_u64(&a, &b) <= (words * 64) as u32,
                "{words}: within bound"
            );
        }
        let ones = vec![u64::MAX; 8];
        let zeros = vec![0u64; 8];
        assert_eq!(hamming_u64(&ones, &zeros), 8 * 64);
    }

    /// Calls the AVX2 Hamming kernel directly, bypassing the dispatcher.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn hamming_avx2_matches_scalar_directly() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let mut rng = Rng::new(0xC1A0_F00D);
        for &words in U64_WORDS {
            let a = rng.u64s(words);
            let b = rng.u64s(words);
            // SAFETY: AVX2 was detected above and both slices were built with
            // the same length; this mirrors the guard used by `hamming_u64`.
            let got = unsafe { avx2::hamming_u64(&a, &b) };
            assert_eq!(got, scalar::hamming_u64(&a, &b), "avx2 {words} words");
        }
    }

    /// Calls every remaining `avx2` kernel directly, bypassing the dispatchers,
    /// so a dispatcher bug cannot hide a kernel bug (or vice versa).
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn avx2_paths_match_scalar_directly() {
        let have_f32 = is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma");
        let have_i8 = is_x86_feature_detected!("avx2");
        if !have_f32 && !have_i8 {
            return;
        }
        let mut rng = Rng::new(99);
        for &dim in &[8usize, 17, 256, 769] {
            let a = rng.f32s(dim);
            let b = rng.f32s(dim);
            if have_f32 {
                // SAFETY (all three f32 calls): AVX and FMA were detected
                // above and `a` / `b` share a length; this mirrors the guards
                // used by the public dispatchers.
                let got = unsafe { avx2::dot_f32(&a, &b) };
                assert!(close(got, scalar::dot_f32(&a, &b)), "dot dim {dim}");
                let got = unsafe { avx2::l2_sq_f32(&a, &b) };
                assert!(close(got, scalar::l2_sq_f32(&a, &b)), "l2 dim {dim}");
                let got = unsafe { avx2::cosine_f32(&a, &b) };
                assert!(close(got, scalar::cosine_f32(&a, &b)), "cosine dim {dim}");
            }
            if have_i8 {
                let ai = rng.i8s(dim);
                let bi = rng.i8s(dim);
                // SAFETY (both i8 calls): AVX2 was detected above and `ai` /
                // `bi` share a length; this mirrors the guards used by the
                // public dispatchers.
                assert_eq!(unsafe { avx2::dot_i8(&ai, &bi) }, scalar::dot_i8(&ai, &bi));
                assert_eq!(
                    unsafe { avx2::l2_sq_i8(&ai, &bi) },
                    scalar::l2_sq_i8(&ai, &bi),
                    "l2_i8 dim {dim}"
                );
            }
        }
    }
}