1#[cfg(target_arch = "aarch64")]
19use std::arch::aarch64::*;
20
21use super::simd_config;
22
/// A vector quantized to one bit per dimension and packed eight dimensions per byte.
///
/// Built by [`BinaryVector::from_f32`] / [`BinaryVector::from_f32_with_threshold`], which
/// pack bits MSB-first: dimension `i` lives in `data[i / 8]` at bit position `7 - i % 8`.
///
/// The fields are public, so nothing here enforces their consistency. Vectors built by the
/// constructors satisfy `data.len() == dims.div_ceil(8)`, and their unused trailing bits
/// are left clear; hand-built values should keep to the same layout.
#[derive(Clone, Debug)]
pub struct BinaryVector {
    /// Packed bits, one byte per group of up to eight dimensions.
    pub data: Vec<u8>,
    /// Number of logical dimensions (meaningful bits) stored in `data`.
    pub dims: usize,
    /// L2 norm of the finite components of the vector this was quantized from.
    pub norm: f32,
}
35
36impl BinaryVector {
37 pub fn from_f32(vector: &[f32]) -> Self {
42 Self::from_f32_with_threshold(vector, 0.0)
43 }
44
45 pub fn from_f32_with_threshold(vector: &[f32], threshold: f32) -> Self {
47 let dims = vector.len();
48
49 let mut norm_sq = 0.0f32;
51 for &v in vector {
52 if v.is_finite() {
53 norm_sq += v * v;
54 }
55 }
56 let norm = norm_sq.sqrt();
57
58 let packed_len = dims.div_ceil(8);
59 let mut data = vec![0u8; packed_len];
60
61 for (i, &v) in vector.iter().enumerate() {
62 let val = if v.is_finite() { v } else { 0.0 };
63 if val >= threshold {
64 let byte_idx = i / 8;
65 let bit_idx = 7 - (i % 8); data[byte_idx] |= 1 << bit_idx;
67 }
68 }
69
70 Self { data, dims, norm }
71 }
72
73 pub fn to_f32(&self) -> Vec<f32> {
77 let mut result = Vec::with_capacity(self.dims);
78 for i in 0..self.dims {
79 let byte_idx = i / 8;
80 let bit_idx = 7 - (i % 8);
81 let bit = (self.data[byte_idx] >> bit_idx) & 1;
82 result.push(if bit == 1 { 1.0 } else { -1.0 });
83 }
84 result
85 }
86
87 #[inline]
91 pub fn hamming_distance(&self, other: &BinaryVector) -> u32 {
92 hamming_distance_binary(self, other)
93 }
94
95 #[inline]
101 pub fn cosine_distance_approx(&self, other: &BinaryVector) -> f32 {
102 if self.dims == 0 {
103 return 0.0;
104 }
105 let hamming = self.hamming_distance(other) as f32;
106 2.0 * hamming / self.dims as f32
107 }
108
109 #[inline]
111 pub fn cosine_similarity_approx(&self, other: &BinaryVector) -> f32 {
112 1.0 - self.cosine_distance_approx(other)
113 }
114}
115
116#[inline]
120pub fn hamming_distance_binary(a: &BinaryVector, b: &BinaryVector) -> u32 {
121 if a.dims != b.dims {
122 return u32::MAX;
123 }
124
125 let config = simd_config();
126
127 #[cfg(target_arch = "aarch64")]
128 {
129 if config.neon_enabled {
130 debug_assert_eq!(a.data.len(), b.data.len());
131 return unsafe { hamming_distance_neon(&a.data, &b.data) };
136 }
137 }
138
139 #[cfg(not(target_arch = "aarch64"))]
140 {
141 let _ = config;
142 }
143
144 hamming_distance_scalar(&a.data, &b.data)
145}
146
/// Portable Hamming distance between two packed bit buffers.
///
/// Whole 8-byte groups are compared as `u64` words so each `count_ones` covers 64 bits; the
/// remaining `len % 8` bytes are compared individually. Byte order is irrelevant to a
/// popcount, so a native-endian conversion is used.
///
/// Both slices are expected to have the same length (debug-asserted). In release builds only
/// their common prefix is compared, rather than panicking on an out-of-range index or
/// silently ignoring extra bytes of `b`.
fn hamming_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "hamming_distance_scalar: slice lengths differ ({} vs {})",
        a.len(),
        b.len()
    );
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    // Reassemble an 8-byte chunk into a word; `chunks_exact(8)` guarantees the length.
    let word = |bytes: &[u8]| {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(bytes);
        u64::from_ne_bytes(buf)
    };

    let a_words = a.chunks_exact(8);
    let b_words = b.chunks_exact(8);

    let tail: u32 = a_words
        .remainder()
        .iter()
        .zip(b_words.remainder())
        .map(|(x, y)| (x ^ y).count_ones())
        .sum();

    let bulk: u32 = a_words
        .zip(b_words)
        .map(|(x, y)| (word(x) ^ word(y)).count_ones())
        .sum();

    bulk + tail
}
186
#[cfg(target_arch = "aarch64")]
/// Hamming distance between two equally sized packed bit buffers, using NEON.
///
/// Each 16-byte block is XORed and popcounted per byte (`vcntq_u8`). The byte counts are
/// widened with pairwise adds (u8 -> u16 -> u32 -> u64) and accumulated in two `u64` lanes.
/// The trailing `len % 16` bytes are handled with scalar `count_ones`. The 64-bit total is
/// truncated to `u32` on return.
///
/// # Safety
///
/// The executing CPU must support NEON. NEON is in the baseline feature set of the standard
/// aarch64 targets; confirm this for custom targets. The function also requires
/// `a.len() == b.len()`. That requirement is checked with an always-on assertion before any
/// vector load, so a mismatch panics instead of reading past the end of the shorter slice.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
unsafe fn hamming_distance_neon(a: &[u8], b: &[u8]) -> u32 {
    // Release-mode check: the loads below index `b` using `a`'s length, so a `debug_assert`
    // would leave out-of-bounds reads possible in optimized builds.
    assert_eq!(
        a.len(),
        b.len(),
        "hamming_distance_neon: slice lengths differ ({} vs {})",
        a.len(),
        b.len()
    );
    const SIMD_WIDTH: usize = 16;
    let len = a.len();
    let chunks = len / SIMD_WIDTH;

    let mut sum_u64 = vdupq_n_u64(0);
    for c in 0..chunks {
        let base = c * SIMD_WIDTH;
        // SAFETY: base + SIMD_WIDTH <= chunks * SIMD_WIDTH <= len == b.len(), so both
        // 16-byte loads stay inside their slices.
        let va = vld1q_u8(a.as_ptr().add(base));
        let vb = vld1q_u8(b.as_ptr().add(base));

        // Per-byte popcount of the differing bits (each lane holds at most 8).
        let popcnt = vcntq_u8(veorq_u8(va, vb));

        // Widen before accumulating so the running sum lives in 64-bit lanes.
        let sum_u16 = vpaddlq_u8(popcnt);
        let sum_u32 = vpaddlq_u16(sum_u16);
        sum_u64 = vaddq_u64(sum_u64, vpaddlq_u32(sum_u32));
    }

    let simd_total = vgetq_lane_u64(sum_u64, 0) + vgetq_lane_u64(sum_u64, 1);

    let tail_start = chunks * SIMD_WIDTH;
    let tail: u32 = a[tail_start..]
        .iter()
        .zip(&b[tail_start..])
        .map(|(x, y)| (x ^ y).count_ones())
        .sum();

    simd_total as u32 + tail
}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in [-1, 1], seeded by `seed` and `dim`.
    fn generate_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    /// Reference Hamming distance: counts positions where the dequantized +/-1 values differ.
    fn naive_hamming(a: &BinaryVector, b: &BinaryVector) -> u32 {
        let da = a.to_f32();
        let db = b.to_f32();
        da.iter().zip(db.iter()).filter(|(x, y)| x != y).count() as u32
    }

    #[test]
    fn test_binary_quantize_basic() {
        let v = vec![0.5, -0.3, 0.0, -1.0, 1.0, 0.1, -0.1, 0.9];
        let bv = BinaryVector::from_f32(&v);
        assert_eq!(bv.data.len(), 1);
        assert_eq!(bv.dims, 8);

        assert_eq!(bv.data[0], 0xAD, "packed bits: {:08b}", bv.data[0]);
    }

    #[test]
    fn test_binary_roundtrip() {
        let v = vec![0.5, -0.3, 0.0, -1.0, 1.0, 0.1, -0.1, 0.9];
        let bv = BinaryVector::from_f32(&v);
        let deq = bv.to_f32();

        assert_eq!(deq, vec![1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0]);
    }

    #[test]
    fn test_binary_hamming_distance() {
        let v = generate_vector(384, 42);
        let bv = BinaryVector::from_f32(&v);
        assert_eq!(bv.hamming_distance(&bv), 0);

        // Negation flips every sign except exact zeros, so nearly all bits should differ.
        let neg_v: Vec<f32> = v.iter().map(|x| -x).collect();
        let neg_bv = BinaryVector::from_f32(&neg_v);
        let hamming = bv.hamming_distance(&neg_bv);
        assert!(hamming > 350, "hamming={hamming}, expected close to 384");
    }

    #[test]
    fn test_binary_cosine_approx_identical() {
        let v = generate_vector(384, 55);
        let bv = BinaryVector::from_f32(&v);
        let cos_dist = bv.cosine_distance_approx(&bv);
        assert!(
            cos_dist.abs() < 1e-5,
            "Identical binary vectors should have 0 cosine distance, got {cos_dist}"
        );
    }

    #[test]
    fn test_binary_cosine_approx_quality() {
        let a = generate_vector(384, 101);
        let b = generate_vector(384, 202);

        let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        let f32_cos = dot / (norm_a * norm_b);

        let ba = BinaryVector::from_f32(&a);
        let bb = BinaryVector::from_f32(&b);
        let bin_cos = ba.cosine_similarity_approx(&bb);

        // Binary quantization is coarse; only require rough agreement.
        assert!(
            (f32_cos - bin_cos).abs() < 0.35,
            "Binary cosine too far from f32: f32={f32_cos}, binary={bin_cos}"
        );
    }

    #[test]
    fn test_binary_memory_savings() {
        let v = generate_vector(384, 999);
        let bv = BinaryVector::from_f32(&v);

        // 384 dims at one bit each: 48 bytes instead of 1536 bytes of f32.
        assert_eq!(bv.data.len(), 48);
    }

    #[test]
    fn test_binary_non_multiple_of_8_dims() {
        let v = generate_vector(385, 77);
        let bv = BinaryVector::from_f32(&v);
        assert_eq!(bv.data.len(), 49);
        assert_eq!(bv.dims, 385);

        let deq = bv.to_f32();
        assert_eq!(deq.len(), 385);
    }

    #[test]
    fn test_binary_empty_vector() {
        let empty_input: Vec<f32> = Vec::new();
        let bv = BinaryVector::from_f32(&empty_input);
        assert!(bv.data.is_empty());
        assert_eq!(bv.dims, 0);
        assert_eq!(bv.norm, 0.0);
        assert!(bv.to_f32().is_empty());
        assert_eq!(bv.hamming_distance(&bv), 0);
        assert_eq!(bv.cosine_distance_approx(&bv), 0.0);
    }

    #[test]
    fn test_binary_with_threshold() {
        let v = vec![0.5, 0.3, 0.1, -0.1, -0.3, -0.5, 0.7, 0.2];
        let bv = BinaryVector::from_f32_with_threshold(&v, 0.25);
        let deq = bv.to_f32();
        assert_eq!(deq, vec![1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0]);
    }

    #[test]
    fn test_binary_nan_inf_handling() {
        let v = vec![
            f32::NAN,
            f32::INFINITY,
            f32::NEG_INFINITY,
            1.0,
            -1.0,
            0.0,
            0.5,
            -0.5,
        ];
        let bv = BinaryVector::from_f32(&v);
        let deq = bv.to_f32();
        assert_eq!(deq.len(), 8);
        for &val in &deq {
            assert!(val == 1.0 || val == -1.0, "Binary should produce +/-1.0");
        }
    }

    #[test]
    fn test_hamming_scalar_vs_neon_parity() {
        let a = generate_vector(384, 111);
        let b = generate_vector(384, 222);
        let ba = BinaryVector::from_f32(&a);
        let bb = BinaryVector::from_f32(&b);

        let scalar_result = hamming_distance_scalar(&ba.data, &bb.data);
        let dispatch_result = ba.hamming_distance(&bb);

        assert_eq!(
            scalar_result, dispatch_result,
            "Scalar and dispatched Hamming should match"
        );
    }

    #[test]
    fn test_hamming_parity_across_lengths() {
        // Byte lengths straddle the 8-byte scalar word and 16-byte NEON block boundaries, so
        // the tail-handling paths of both kernels are checked, not just their bulk loops.
        for &dims in &[1usize, 7, 8, 9, 63, 64, 65, 127, 128, 129, 130, 200, 256, 257, 385, 1000] {
            let ba = BinaryVector::from_f32(&generate_vector(dims, 5));
            let bb = BinaryVector::from_f32(&generate_vector(dims, 6));
            let expected = naive_hamming(&ba, &bb);
            assert_eq!(
                hamming_distance_scalar(&ba.data, &bb.data),
                expected,
                "scalar, dims={dims}"
            );
            assert_eq!(ba.hamming_distance(&bb), expected, "dispatch, dims={dims}");
        }
    }
}