use serde::{Deserialize, Serialize};

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::{vaddvq_u8, vcntq_u8, veorq_u8, vld1q_u8};
#[cfg(target_arch = "x86_64")]
#[allow(clippy::wildcard_imports)]
use std::arch::x86_64::*;

/// Parameters for binary (1-bit) quantization of float vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryParams {
    /// Per-dimension binarization thresholds (one entry per dimension).
    pub thresholds: Vec<f32>,
    /// Dimensionality of the input vectors.
    pub dimensions: usize,
}

impl BinaryParams {
    /// Creates parameters with all thresholds at zero, appropriate when
    /// the input vectors are already centered around the origin.
    #[must_use]
    pub fn new(dimensions: usize) -> Self {
        Self {
            thresholds: vec![0.0; dimensions],
            dimensions,
        }
    }

    /// Trains thresholds from sample vectors, using the per-dimension
    /// median so that each bit splits the training data roughly in half.
    ///
    /// # Errors
    ///
    /// Returns an error if `vectors` is empty or the vectors disagree on
    /// dimensionality.
    pub fn train(vectors: &[&[f32]]) -> Result<Self, &'static str> {
        if vectors.is_empty() {
            return Err("Need at least one vector to train");
        }
        let dimensions = vectors[0].len();
        if !vectors.iter().all(|v| v.len() == dimensions) {
            return Err("All vectors must have same dimensions");
        }

        let n = vectors.len();
        let mut thresholds = Vec::with_capacity(dimensions);
        let mut dim_values: Vec<f32> = Vec::with_capacity(n);

        for d in 0..dimensions {
            dim_values.clear();
            for v in vectors {
                dim_values.push(v[d]);
            }
            dim_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

            // Even count: average the two middle values; odd: take the middle.
            let median = if n.is_multiple_of(2) {
                let mid = n / 2;
                f32::midpoint(dim_values[mid - 1], dim_values[mid])
            } else {
                dim_values[n / 2]
            };

            thresholds.push(median);
        }

        Ok(Self {
            thresholds,
            dimensions,
        })
    }

    /// Quantizes a vector to a packed bit string: bit `i` is set when
    /// `vector[i]` exceeds `thresholds[i]`. Bits are packed LSB-first
    /// within each byte.
    #[must_use]
    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
        debug_assert_eq!(vector.len(), self.dimensions);

        let num_bytes = self.dimensions.div_ceil(8);
        let mut quantized = vec![0u8; num_bytes];

        for (i, (&value, &threshold)) in vector.iter().zip(self.thresholds.iter()).enumerate() {
            if value > threshold {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                quantized[byte_idx] |= 1 << bit_idx;
            }
        }

        quantized
    }

    /// Like [`Self::quantize`], but writes into a caller-provided buffer,
    /// avoiding an allocation on hot paths. `output` must hold at least
    /// `dimensions.div_ceil(8)` bytes.
    pub fn quantize_into(&self, vector: &[f32], output: &mut [u8]) {
        debug_assert_eq!(vector.len(), self.dimensions);
        let num_bytes = self.dimensions.div_ceil(8);
        debug_assert!(output.len() >= num_bytes);

        // Clear any stale bits from a reused buffer before setting new ones.
        for byte in output.iter_mut().take(num_bytes) {
            *byte = 0;
        }

        for (i, (&value, &threshold)) in vector.iter().zip(self.thresholds.iter()).enumerate() {
            if value > threshold {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                output[byte_idx] |= 1 << bit_idx;
            }
        }
    }
}

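// Illustrative usage (a sketch; `load_embeddings` is a hypothetical caller,
// not part of this module):
//
//     let rows: Vec<Vec<f32>> = load_embeddings();
//     let refs: Vec<&[f32]> = rows.iter().map(Vec::as_slice).collect();
//     let params = BinaryParams::train(&refs).expect("non-empty training set");
//     let code = params.quantize(&rows[0]);
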
/// Computes the Hamming distance (number of differing bits) between two
/// packed bit strings of equal length, dispatching at runtime to the
/// fastest implementation the CPU supports.
#[inline]
#[must_use]
#[allow(clippy::needless_return)]
pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { hamming_distance_avx2(a, b) };
        }
        if is_x86_feature_detected!("popcnt") {
            return unsafe { hamming_distance_popcnt(a, b) };
        }
        return hamming_distance_scalar(a, b);
    }

    #[cfg(target_arch = "aarch64")]
    {
        unsafe { hamming_distance_neon(a, b) }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    hamming_distance_scalar(a, b)
}

/// Portable scalar fallback used when no SIMD path applies.
#[allow(dead_code)]
fn hamming_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x ^ y).count_ones())
        .sum()
}

/// AVX2 Hamming distance using the nibble-lookup (`vpshufb`) popcount
/// technique: each nibble of the XOR indexes a 16-entry table of bit counts.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(clippy::cast_ptr_alignment)]
unsafe fn hamming_distance_avx2(a: &[u8], b: &[u8]) -> u32 {
    // Popcount of every possible nibble value, repeated for both 128-bit lanes.
    let lookup = _mm256_setr_epi8(
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, // low lane
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, // high lane
    );
    let low_mask = _mm256_set1_epi8(0x0f);
    let mut total = _mm256_setzero_si256();
    let mut i = 0;

    while i + 32 <= a.len() {
        let va = _mm256_loadu_si256(a.as_ptr().add(i).cast::<__m256i>());
        let vb = _mm256_loadu_si256(b.as_ptr().add(i).cast::<__m256i>());
        let xor = _mm256_xor_si256(va, vb);

        // Split each byte into its low and high nibble.
        let lo = _mm256_and_si256(xor, low_mask);
        let hi = _mm256_and_si256(_mm256_srli_epi16::<4>(xor), low_mask);

        // Table-lookup the popcount of each nibble.
        let cnt_lo = _mm256_shuffle_epi8(lookup, lo);
        let cnt_hi = _mm256_shuffle_epi8(lookup, hi);

        let cnt = _mm256_add_epi8(cnt_lo, cnt_hi);

        // Horizontally sum the per-byte counts into the 64-bit accumulators.
        total = _mm256_add_epi64(total, _mm256_sad_epu8(cnt, _mm256_setzero_si256()));

        i += 32;
    }

    // Reduce the four 64-bit lanes to a single count.
    let lo = _mm256_castsi256_si128(total);
    let hi = _mm256_extracti128_si256::<1>(total);
    let sum128 = _mm_add_epi64(lo, hi);
    let count = (_mm_extract_epi64::<0>(sum128) + _mm_extract_epi64::<1>(sum128)) as u32;

    // Handle the tail that doesn't fill a 32-byte block.
    let mut remainder = 0u32;
    for j in i..a.len() {
        remainder += (a[j] ^ b[j]).count_ones();
    }

    count + remainder
}

/// Hamming distance via 64-bit `popcnt` over unaligned 8-byte loads.
///
/// # Safety
///
/// The caller must ensure the CPU supports `popcnt`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
#[allow(clippy::cast_possible_wrap)]
unsafe fn hamming_distance_popcnt(a: &[u8], b: &[u8]) -> u32 {
    let mut count = 0u64;
    let mut i = 0;

    while i + 8 <= a.len() {
        // Unaligned reads: byte slices carry no alignment guarantee.
        let a_u64 = std::ptr::read_unaligned(a.as_ptr().add(i).cast::<u64>());
        let b_u64 = std::ptr::read_unaligned(b.as_ptr().add(i).cast::<u64>());
        count += _popcnt64((a_u64 ^ b_u64) as i64) as u64;
        i += 8;
    }

    // Tail bytes that don't fill a full u64.
    for j in i..a.len() {
        count += (a[j] ^ b[j]).count_ones() as u64;
    }

    count as u32
}

/// NEON Hamming distance: `vcnt` gives per-byte popcounts and `vaddv`
/// horizontally sums the 16-byte register.
///
/// # Safety
///
/// Requires NEON, which is mandatory on AArch64.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn hamming_distance_neon(a: &[u8], b: &[u8]) -> u32 {
    let mut sum: u32 = 0;
    let mut i = 0;

    while i + 16 <= a.len() {
        let va = vld1q_u8(a.as_ptr().add(i));
        let vb = vld1q_u8(b.as_ptr().add(i));

        let xor = veorq_u8(va, vb);

        // Per-byte popcount, then horizontal add across all 16 lanes
        // (at most 16 * 8 = 128, so the sum fits in a u8).
        let cnt = vcntq_u8(xor);
        sum += vaddvq_u8(cnt) as u32;

        i += 16;
    }

    // Tail bytes that don't fill a 16-byte block.
    for j in i..a.len() {
        sum += (a[j] ^ b[j]).count_ones();
    }

    sum
}

/// Rescales a raw Hamming distance by the query and vector norms so that
/// distances stay comparable across vectors of different magnitudes:
/// `hamming * (query_norm * vec_norm) / dimensions`.
#[inline]
#[must_use]
pub fn corrected_distance(hamming: u32, query_norm: f32, vec_norm: f32, dimensions: usize) -> f32 {
    let hamming_f = hamming as f32;
    hamming_f * (query_norm * vec_norm) / (dimensions as f32)
}

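// Worked example (mirrors the unit test below): hamming = 100, norms 2.0 and
// 1.5, d = 768 gives 100 * (2.0 * 1.5) / 768 = 0.390625.
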
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binary_quantize() {
        let params = BinaryParams::new(8);
        let vector = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.1, -0.5, 0.2];

        let quantized = params.quantize(&vector);

        assert_eq!(quantized.len(), 1);
        // Bits 0, 2, 5, and 7 exceed the zero threshold (0.0 itself does not).
        assert_eq!(quantized[0], 0b1010_0101);
    }

    #[test]
    fn test_binary_train() {
        let v1 = vec![1.0, 5.0, 0.0, 2.0];
        let v2 = vec![2.0, 6.0, 1.0, 3.0];
        let v3 = vec![3.0, 7.0, 2.0, 4.0];
        let vectors: Vec<&[f32]> = vec![v1.as_slice(), v2.as_slice(), v3.as_slice()];

        let params = BinaryParams::train(&vectors).unwrap();

        // With three samples, the threshold is the middle value per dimension.
        assert_eq!(params.thresholds, vec![2.0, 6.0, 1.0, 3.0]);
    }

    #[test]
    fn test_hamming_distance_identical() {
        let a = vec![0b1010_1010, 0b1111_0000, 0b0000_1111];
        let b = vec![0b1010_1010, 0b1111_0000, 0b0000_1111];

        let dist = hamming_distance(&a, &b);
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_hamming_distance_all_different() {
        let a = vec![0b0000_0000];
        let b = vec![0b1111_1111];

        let dist = hamming_distance(&a, &b);
        assert_eq!(dist, 8);
    }

    #[test]
    fn test_hamming_distance_partial() {
        let a = vec![0b1010_1010];
        let b = vec![0b0101_0101];

        let dist = hamming_distance(&a, &b);
        assert_eq!(dist, 8);
    }

    #[test]
    fn test_hamming_distance_large() {
        // 96 bytes exercises the 32-byte AVX2 and 16-byte NEON block paths.
        let a: Vec<u8> = vec![0b1010_1010; 96];
        let b: Vec<u8> = vec![0b0101_0101; 96];

        let dist = hamming_distance(&a, &b);
        assert_eq!(dist, 96 * 8);
    }

    #[test]
    fn test_compression_ratio() {
        let dims: usize = 768;
        let original_size = dims * 4; // f32 input: 4 bytes per dimension
        let quantized_size = dims.div_ceil(8); // 1 bit per dimension
        let ratio = original_size as f32 / quantized_size as f32;
        assert!(
            (ratio - 32.0).abs() < 0.1,
            "Expected 32x compression, got {ratio}"
        );
    }

    #[test]
    fn test_corrected_distance() {
        let hamming = 100;
        let query_norm = 2.0;
        let vec_norm = 1.5;
        let dimensions = 768;

        let dist = corrected_distance(hamming, query_norm, vec_norm, dimensions);

        // 100 * (2.0 * 1.5) / 768 = 0.390625
        assert!((dist - 0.39).abs() < 0.01);
    }
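
    // Extra consistency check (a sketch, not in the original suite):
    // `quantize` and `quantize_into` should produce identical bytes.
    #[test]
    fn test_quantize_into_matches_quantize() {
        let params = BinaryParams::new(16);
        let vector: Vec<f32> = (0..16u8).map(|i| f32::from(i) - 7.5).collect();
        let expected = params.quantize(&vector);

        let mut buf = vec![0u8; 2];
        params.quantize_into(&vector, &mut buf);
        assert_eq!(buf, expected);
    }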
}