Skip to main content

lattice_embed/simd/
binary.rs

1//! Binary quantization for ultra-fast pre-filtering.
2//!
3//! Sign-bit quantization: 1 bit per dimension (32x compression vs f32).
4//! Distance via Hamming distance using hardware popcount.
5//!
6//! ## Format
7//!
8//! `b[i] = 1 if v[i] >= 0 else 0`. Pack 8 dimensions into one byte
9//! (bit 7 = dimension 0, bit 6 = dimension 1, etc.).
10//! For D dimensions, storage is `ceil(D / 8)` bytes.
11//!
12//! ## Distance
13//!
14//! Hamming distance counts differing bits between two binary vectors.
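//! Cosine similarity is estimated linearly as `cos_approx = 1.0 - 2.0 * hamming / dims`,
//! so the approximate cosine distance is `1.0 - cos_approx = 2.0 * hamming / dims`,
//! which lies in `[0.0, 2.0]` for vectors with matching dimensions.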
17
18#[cfg(target_arch = "aarch64")]
19use std::arch::aarch64::*;
20
21use super::simd_config;
22
/// **Unstable**: binary quantization format and struct layout are under active design.
///
/// Binary quantized vector with packed bit storage.
///
/// Values produced by [`BinaryVector::from_f32`] and
/// [`BinaryVector::from_f32_with_threshold`] satisfy these invariants. The
/// fields are public, so hand-built values may not:
/// - `data.len() == dims.div_ceil(8)`;
/// - dimension `i` is stored in bit `7 - i % 8` of `data[i / 8]` (MSB-first);
/// - padding bits past `dims` in the last byte are zero.
///
/// `PartialEq` compares all three fields, including the `f32` norm (so a NaN
/// norm never compares equal).
#[derive(Debug, Clone, PartialEq)]
pub struct BinaryVector {
    /// **Unstable**: packed bit data; bit layout may change with format revision.
    pub data: Vec<u8>,
    /// **Unstable**: number of original dimensions.
    pub dims: usize,
    /// **Unstable**: L2 norm of the original float vector (finite components
    /// only); field may be removed.
    pub norm: f32,
}
35
36impl BinaryVector {
37    /// **Unstable**: quantization API; threshold default may change.
38    ///
39    /// Values >= threshold map to 1, values < threshold map to 0.
40    /// Default threshold is 0.0 (sign bit).
41    pub fn from_f32(vector: &[f32]) -> Self {
42        Self::from_f32_with_threshold(vector, 0.0)
43    }
44
45    /// **Unstable**: custom-threshold variant; may be merged into a config struct.
46    pub fn from_f32_with_threshold(vector: &[f32], threshold: f32) -> Self {
47        let dims = vector.len();
48
49        // Compute norm
50        let mut norm_sq = 0.0f32;
51        for &v in vector {
52            if v.is_finite() {
53                norm_sq += v * v;
54            }
55        }
56        let norm = norm_sq.sqrt();
57
58        let packed_len = dims.div_ceil(8);
59        let mut data = vec![0u8; packed_len];
60
61        for (i, &v) in vector.iter().enumerate() {
62            let val = if v.is_finite() { v } else { 0.0 };
63            if val >= threshold {
64                let byte_idx = i / 8;
65                let bit_idx = 7 - (i % 8); // bit 7 = first dimension in byte
66                data[byte_idx] |= 1 << bit_idx;
67            }
68        }
69
70        Self { data, dims, norm }
71    }
72
73    /// **Unstable**: dequantize to float32; output semantics may change.
74    ///
75    /// Binary quantization is lossy: 1 -> +1.0, 0 -> -1.0.
76    pub fn to_f32(&self) -> Vec<f32> {
77        let mut result = Vec::with_capacity(self.dims);
78        for i in 0..self.dims {
79            let byte_idx = i / 8;
80            let bit_idx = 7 - (i % 8);
81            let bit = (self.data[byte_idx] >> bit_idx) & 1;
82            result.push(if bit == 1 { 1.0 } else { -1.0 });
83        }
84        result
85    }
86
87    /// **Unstable**: Hamming dispatch; delegates to NEON or scalar based on runtime detection.
88    ///
89    /// Returns the number of differing bits (dimensions with different signs).
90    #[inline]
91    pub fn hamming_distance(&self, other: &BinaryVector) -> u32 {
92        hamming_distance_binary(self, other)
93    }
94
95    /// **Unstable**: approximation formula may be revised; do not use in latency-sensitive production paths.
96    ///
97    /// The relationship between Hamming distance and angular distance:
98    /// `cos_approx = 1.0 - 2.0 * hamming / dims`
99    /// `cosine_distance_approx = 2.0 * hamming / dims`
100    #[inline]
101    pub fn cosine_distance_approx(&self, other: &BinaryVector) -> f32 {
102        if self.dims == 0 {
103            return 0.0;
104        }
105        let hamming = self.hamming_distance(other) as f32;
106        2.0 * hamming / self.dims as f32
107    }
108
109    /// **Unstable**: approximation formula may be revised; complement of `cosine_distance_approx`.
110    #[inline]
111    pub fn cosine_similarity_approx(&self, other: &BinaryVector) -> f32 {
112        1.0 - self.cosine_distance_approx(other)
113    }
114}
115
116/// **Unstable**: free function dispatches to NEON or scalar; may become private in a future cleanup.
117///
118/// Uses popcount on XOR of the packed bytes.
119#[inline]
120pub fn hamming_distance_binary(a: &BinaryVector, b: &BinaryVector) -> u32 {
121    if a.dims != b.dims {
122        return u32::MAX;
123    }
124
125    let config = simd_config();
126
127    #[cfg(target_arch = "aarch64")]
128    {
129        if config.neon_enabled {
130            debug_assert_eq!(a.data.len(), b.data.len());
131            // SAFETY: NEON is available on aarch64. Matching dimensions require
132            // matching packed-byte lengths for vectors built by the constructor;
133            // the debug guard catches public-field invariant violations. The callee
134            // uses unaligned loads and chunk/remainder bounds within those slices.
135            return unsafe { hamming_distance_neon(&a.data, &b.data) };
136        }
137    }
138
139    #[cfg(not(target_arch = "aarch64"))]
140    {
141        let _ = config;
142    }
143
144    hamming_distance_scalar(&a.data, &b.data)
145}
146
/// Scalar Hamming distance: the popcount of `a XOR b`.
///
/// The inputs are processed 8 bytes at a time as `u64` words, so each step is
/// one XOR plus `count_ones()`. The trailing `len % 8` bytes are handled one
/// at a time. Words are formed in native byte order; a popcount does not
/// depend on bit position, so byte order cannot change the result.
///
/// `a` and `b` are expected to have equal length (checked in debug builds).
/// In release builds only the common prefix `min(a.len(), b.len())` is
/// compared, so a length mismatch never panics here.
fn hamming_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "hamming_distance_scalar: slice lengths differ ({} vs {})",
        a.len(),
        b.len()
    );

    // Clamp both inputs to one length so the 8-byte words and the tails below
    // pair up exactly.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    // Load an exactly-8-byte chunk as a `u64`. `copy_from_slice` avoids any
    // fallible slice-to-array conversion.
    fn word(bytes: &[u8]) -> u64 {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(bytes);
        u64::from_ne_bytes(buf)
    }

    let a_words = a.chunks_exact(8);
    let b_words = b.chunks_exact(8);
    // `remainder()` borrows from the slice, not the iterator, so it can be
    // taken before the iterators are consumed.
    let (a_tail, b_tail) = (a_words.remainder(), b_words.remainder());

    let word_diff: u32 = a_words
        .zip(b_words)
        .map(|(x, y)| (word(x) ^ word(y)).count_ones())
        .sum();
    let tail_diff: u32 = a_tail
        .iter()
        .zip(b_tail)
        .map(|(x, y)| (x ^ y).count_ones())
        .sum();

    word_diff + tail_diff
}
186
/// NEON Hamming distance using `vcnt` (per-byte population count).
///
/// Each 16-byte chunk goes through four steps:
/// - XOR the two inputs;
/// - count set bits per byte with `vcntq_u8`;
/// - widen pairwise u8 -> u16 -> u32 -> u64 with `vpaddlq_*`;
/// - add the result into a two-lane u64 accumulator.
///
/// The two lanes are summed at the end. Bytes after the last full chunk are
/// counted with scalar `count_ones()`.
///
/// # Safety
///
/// Caller must ensure running on aarch64 (NEON is mandatory).
/// `a` and `b` must have equal length. The vector loop reads the first
/// `a.len() / 16 * 16` bytes of `b` through a raw pointer with no bounds
/// check, so a shorter `b` is an out-of-bounds read. Only a
/// `debug_assert_eq!` checks this inside the function. The scalar remainder
/// loop uses checked indexing.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn hamming_distance_neon(a: &[u8], b: &[u8]) -> u32 {
    // SA-163/164: verify equal-length backing slices before the SIMD loop.
    debug_assert_eq!(
        a.len(),
        b.len(),
        "hamming_distance_neon: slice lengths differ ({} vs {})",
        a.len(),
        b.len()
    );
    let len = a.len();
    const SIMD_WIDTH: usize = 16;
    let chunks = len / SIMD_WIDTH;

    // Each chunk's per-byte counts (at most 8 each, 128 per chunk) are widened
    // all the way to u64 before being added to the accumulator. The running
    // total therefore cannot overflow for any realistic input length. This
    // costs three pairwise-add steps per chunk instead of batching in u8/u16.
    let mut sum_u64 = vdupq_n_u64(0);

    for c in 0..chunks {
        let base = c * SIMD_WIDTH;
        // Unaligned 16-byte loads; `base + 16 <= len` for every chunk.
        let va = vld1q_u8(a.as_ptr().add(base));
        let vb = vld1q_u8(b.as_ptr().add(base));

        // XOR to find differing bits
        let xor = veorq_u8(va, vb);

        // Population count per byte
        let popcnt = vcntq_u8(xor);

        // Widen and accumulate: u8 -> u16 -> u32 -> u64
        let sum_u16 = vpaddlq_u8(popcnt);
        let sum_u32 = vpaddlq_u16(sum_u16);
        sum_u64 = vaddq_u64(sum_u64, vpaddlq_u32(sum_u32));
    }

    // Extract final sum: add the two u64 lanes.
    let total = vgetq_lane_u64(sum_u64, 0) + vgetq_lane_u64(sum_u64, 1);
    // NOTE(review): `as u32` truncates if more than u32::MAX bits differ
    // (inputs larger than 512 MiB); not expected for embedding vectors.
    let mut result = total as u32;

    // Handle remainder: the final `len % 16` bytes, one byte at a time.
    let remainder_start = chunks * SIMD_WIDTH;
    for i in remainder_start..len {
        result += (a[i] ^ b[i]).count_ones();
    }

    result
}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[-1.0, 1.0]`.
    /// It is LCG-based, so tests are reproducible without an RNG dependency.
    fn generate_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    #[test]
    fn test_binary_quantize_basic() {
        let v = vec![0.5, -0.3, 0.0, -1.0, 1.0, 0.1, -0.1, 0.9];
        let bv = BinaryVector::from_f32(&v);
        assert_eq!(bv.data.len(), 1); // 8 dims -> 1 byte
        assert_eq!(bv.dims, 8);

        // Expected bits (MSB first): 1, 0, 1, 0, 1, 1, 0, 1 = 0b10101101 = 0xAD
        assert_eq!(bv.data[0], 0xAD, "packed bits: {:08b}", bv.data[0]);
    }

    #[test]
    fn test_binary_roundtrip() {
        let v = vec![0.5, -0.3, 0.0, -1.0, 1.0, 0.1, -0.1, 0.9];
        let bv = BinaryVector::from_f32(&v);
        let deq = bv.to_f32();

        // Binary: values >= 0.0 (including 0.0 itself) -> +1.0, negative -> -1.0
        assert_eq!(deq, vec![1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0]);
    }

    #[test]
    fn test_binary_hamming_distance() {
        // Same vector should have 0 Hamming distance
        let v = generate_vector(384, 42);
        let bv = BinaryVector::from_f32(&v);
        assert_eq!(bv.hamming_distance(&bv), 0);

        // Opposite sign vector should have max Hamming distance
        let neg_v: Vec<f32> = v.iter().map(|x| -x).collect();
        let neg_bv = BinaryVector::from_f32(&neg_v);
        // Some values might be exactly 0.0, which maps to +1 in both cases
        // So Hamming may not be exactly 384
        let hamming = bv.hamming_distance(&neg_bv);
        // But it should be close to 384 for random vectors
        assert!(hamming > 350, "hamming={hamming}, expected close to 384");
    }

    #[test]
    fn test_binary_cosine_approx_identical() {
        let v = generate_vector(384, 55);
        let bv = BinaryVector::from_f32(&v);
        let cos_dist = bv.cosine_distance_approx(&bv);
        assert!(
            cos_dist.abs() < 1e-5,
            "Identical binary vectors should have 0 cosine distance, got {cos_dist}"
        );
    }

    #[test]
    fn test_binary_cosine_approx_quality() {
        let a = generate_vector(384, 101);
        let b = generate_vector(384, 202);

        // f32 reference cosine
        let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        let f32_cos = dot / (norm_a * norm_b);

        let ba = BinaryVector::from_f32(&a);
        let bb = BinaryVector::from_f32(&b);
        let bin_cos = ba.cosine_similarity_approx(&bb);

        // Binary is a rough approximation -- within 0.35 is acceptable for pre-filtering
        assert!(
            (f32_cos - bin_cos).abs() < 0.35,
            "Binary cosine too far from f32: f32={f32_cos}, binary={bin_cos}"
        );
    }

    #[test]
    fn test_binary_memory_savings() {
        let v = generate_vector(384, 999);
        let bv = BinaryVector::from_f32(&v);

        // f32: 384 * 4 = 1536 bytes
        // Binary: ceil(384/8) = 48 bytes = 32x compression
        assert_eq!(bv.data.len(), 48);
    }

    #[test]
    fn test_binary_non_multiple_of_8_dims() {
        // 385 dims -> ceil(385/8) = 49 bytes
        let v = generate_vector(385, 77);
        let bv = BinaryVector::from_f32(&v);
        assert_eq!(bv.data.len(), 49);
        assert_eq!(bv.dims, 385);

        // Roundtrip should preserve all 385 values
        let deq = bv.to_f32();
        assert_eq!(deq.len(), 385);

        // Padding bits past `dims` in the last byte must be zero.
        assert_eq!(bv.data[48] & 0x7F, 0, "last byte: {:08b}", bv.data[48]);
    }

    #[test]
    fn test_binary_empty_vector() {
        let bv = BinaryVector::from_f32(&[]);
        assert!(bv.data.is_empty());
        assert_eq!(bv.dims, 0);
        assert_eq!(bv.norm, 0.0);
        assert!(bv.to_f32().is_empty());
        assert_eq!(bv.hamming_distance(&bv), 0);
        assert_eq!(bv.cosine_distance_approx(&bv), 0.0);
    }

    #[test]
    fn test_binary_with_threshold() {
        let v = vec![0.5, 0.3, 0.1, -0.1, -0.3, -0.5, 0.7, 0.2];
        // With threshold 0.25, only values >= 0.25 map to 1
        let bv = BinaryVector::from_f32_with_threshold(&v, 0.25);
        let deq = bv.to_f32();
        // Expected: 1, 1, -1, -1, -1, -1, 1, -1
        assert_eq!(deq, vec![1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0]);
    }

    #[test]
    fn test_binary_nan_inf_handling() {
        let v = vec![
            f32::NAN,
            f32::INFINITY,
            f32::NEG_INFINITY,
            1.0,
            -1.0,
            0.0,
            0.5,
            -0.5,
        ];
        let bv = BinaryVector::from_f32(&v);
        let deq = bv.to_f32();
        assert_eq!(deq.len(), 8);
        for &val in &deq {
            assert!(val == 1.0 || val == -1.0, "Binary should produce +/-1.0");
        }

        // Non-finite values are treated as 0.0, which is >= 0.0 -> +1.0.
        assert_eq!(deq, vec![1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, -1.0]);

        // Non-finite values are excluded from the norm: 1 + 1 + 0.25 + 0.25.
        let expected_norm = 2.5f32.sqrt();
        assert!(
            (bv.norm - expected_norm).abs() < 1e-6,
            "norm={}, expected {expected_norm}",
            bv.norm
        );
    }

    #[test]
    fn test_hamming_scalar_vs_neon_parity() {
        // These lengths land on both sides of the 8-byte (scalar) and
        // 16-byte (NEON) strides, so the vector loops and their byte-wise
        // remainders are all exercised. On non-aarch64 targets the dispatched
        // path is the scalar path; the bitwise oracle below still checks it.
        for dims in [1usize, 7, 8, 9, 63, 64, 65, 127, 128, 129, 384, 385, 1000] {
            let a = generate_vector(dims, 111);
            let b = generate_vector(dims, 222);
            let ba = BinaryVector::from_f32(&a);
            let bb = BinaryVector::from_f32(&b);

            let scalar_result = hamming_distance_scalar(&ba.data, &bb.data);
            let dispatch_result = ba.hamming_distance(&bb);

            assert_eq!(
                scalar_result, dispatch_result,
                "Scalar and dispatched Hamming should match (dims={dims})"
            );

            // Independent oracle: count dimensions whose dequantized signs differ.
            let oracle = ba
                .to_f32()
                .iter()
                .zip(bb.to_f32().iter())
                .filter(|(x, y)| x != y)
                .count() as u32;
            assert_eq!(dispatch_result, oracle, "Hamming vs bitwise oracle (dims={dims})");
        }
    }
}