omendb_core/compression/binary.rs

//! Binary Quantization (BBQ) for `OmenDB`
//!
//! 1-bit quantization with SIMD-optimized Hamming distance.
//!
//! # Algorithm
//!
//! - Quantize: bit[d] = 1 if f32[d] > threshold[d] else 0
//! - Distance: Hamming distance via XOR + popcnt
//! - Correction: Apply norm-based correction for accurate ranking
//!
//! # Performance
//!
//! - 32x compression (f32 → 1 bit)
//! - 2-4x faster search than SQ8 (SIMD Hamming is extremely fast)
//! - ~85% raw recall, ~95-98% with rescore
//!
//! # When to Use
//!
//! - Dimensions >= 384 (below this, SQ8 has better recall)
//! - Large datasets (>100K vectors) where memory matters
//! - Cost-sensitive deployments
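//!
//! # Example
//!
//! A minimal end-to-end sketch; the import path below is assumed from this
//! file's location in the crate and may differ:
//!
//! ```ignore
//! use omendb_core::compression::binary::{corrected_distance, hamming_distance, BinaryParams};
//!
//! let params = BinaryParams::new(8); // zero thresholds: bit = 1 iff value > 0
//! let q = params.quantize(&[0.5, -0.3, 0.8, -0.1, 0.0, 0.1, -0.5, 0.2]);
//! let v = params.quantize(&[0.4, 0.2, 0.9, -0.2, 0.1, -0.3, -0.6, 0.1]);
//!
//! // Rough ranking via Hamming, then norm-based correction
//! let h = hamming_distance(&q, &v);
//! let d = corrected_distance(h, 1.0, 1.0, 8);
//! assert!(d >= 0.0);
//! ```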

use serde::{Deserialize, Serialize};

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::{vaddvq_u8, vcntq_u8, veorq_u8, vld1q_u8};
#[cfg(target_arch = "x86_64")]
#[allow(clippy::wildcard_imports)]
use std::arch::x86_64::*;

/// Binary quantization parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryParams {
    /// Threshold per dimension (typically 0.0 or median)
    pub thresholds: Vec<f32>,
    /// Number of dimensions
    pub dimensions: usize,
}

impl BinaryParams {
    /// Create with zero thresholds (value > 0 = 1, value <= 0 = 0)
    #[must_use]
    pub fn new(dimensions: usize) -> Self {
        Self {
            thresholds: vec![0.0; dimensions],
            dimensions,
        }
    }

    /// Train thresholds from sample vectors using median per dimension.
    ///
    /// # Errors
    /// Returns an error if `vectors` is empty or the vectors have inconsistent dimensions.
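    ///
    /// # Example
    ///
    /// A small sketch: per-dimension medians become the thresholds.
    ///
    /// ```ignore
    /// let v1 = vec![1.0_f32, 5.0];
    /// let v2 = vec![2.0, 6.0];
    /// let v3 = vec![3.0, 7.0];
    /// let vectors: Vec<&[f32]> = vec![v1.as_slice(), v2.as_slice(), v3.as_slice()];
    /// let params = BinaryParams::train(&vectors).unwrap();
    /// assert_eq!(params.thresholds, vec![2.0, 6.0]);
    /// ```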
    pub fn train(vectors: &[&[f32]]) -> Result<Self, &'static str> {
        if vectors.is_empty() {
            return Err("Need at least one vector to train");
        }
        let dimensions = vectors[0].len();
        if !vectors.iter().all(|v| v.len() == dimensions) {
            return Err("All vectors must have same dimensions");
        }

        let n = vectors.len();
        let mut thresholds = Vec::with_capacity(dimensions);
        let mut dim_values: Vec<f32> = Vec::with_capacity(n);

        for d in 0..dimensions {
            dim_values.clear();
            for v in vectors {
                dim_values.push(v[d]);
            }
            // total_cmp gives a total order even if NaNs are present
            dim_values.sort_by(f32::total_cmp);

            // Use median as threshold
            let median = if n.is_multiple_of(2) {
                let mid = n / 2;
                f32::midpoint(dim_values[mid - 1], dim_values[mid])
            } else {
                dim_values[n / 2]
            };

            thresholds.push(median);
        }

        Ok(Self {
            thresholds,
            dimensions,
        })
    }

    /// Quantize f32 vector to packed binary
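    ///
    /// Bit `d` lands in byte `d / 8` at position `d % 8`; output length is
    /// `dimensions.div_ceil(8)` bytes. A quick sketch:
    ///
    /// ```ignore
    /// let params = BinaryParams::new(8);
    /// let q = params.quantize(&[0.5, -0.3, 0.8, -0.1, 0.0, 0.1, -0.5, 0.2]);
    /// assert_eq!(q, vec![0b1010_0101]); // bits 0, 2, 5, 7 set
    /// ```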
    #[must_use]
    pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
        debug_assert_eq!(vector.len(), self.dimensions);

        let num_bytes = self.dimensions.div_ceil(8);
        let mut quantized = vec![0u8; num_bytes];

        for (i, (&value, &threshold)) in vector.iter().zip(self.thresholds.iter()).enumerate() {
            if value > threshold {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                quantized[byte_idx] |= 1 << bit_idx;
            }
        }

        quantized
    }

    /// Quantize into pre-allocated buffer
    pub fn quantize_into(&self, vector: &[f32], output: &mut [u8]) {
        debug_assert_eq!(vector.len(), self.dimensions);
        let num_bytes = self.dimensions.div_ceil(8);
        debug_assert!(output.len() >= num_bytes);

        // Clear output first
        output[..num_bytes].fill(0);

        for (i, (&value, &threshold)) in vector.iter().zip(self.thresholds.iter()).enumerate() {
            if value > threshold {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                output[byte_idx] |= 1 << bit_idx;
            }
        }
    }
}

/// Compute Hamming distance between two binary vectors
///
/// SIMD-optimized using:
/// - AVX2: `_mm256_xor_si256` + PSHUFB nibble-lookup popcount
/// - x86_64 `popcnt`: XOR on 8-byte words + hardware POPCNT
/// - NEON: `veorq_u8` + `vcntq_u8`
///
/// Falls back to scalar `count_ones` if no SIMD is available.
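///
/// # Example
///
/// A minimal sketch; each differing bit adds one to the distance:
///
/// ```ignore
/// let a = [0b1111_0000u8];
/// let b = [0b1111_0011u8];
/// assert_eq!(hamming_distance(&a, &b), 2);
/// ```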
#[inline]
#[must_use]
#[allow(clippy::needless_return)] // returns needed for cfg-conditional control flow
pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { hamming_distance_avx2(a, b) };
        }
        if is_x86_feature_detected!("popcnt") {
            return unsafe { hamming_distance_popcnt(a, b) };
        }
        return hamming_distance_scalar(a, b);
    }

    #[cfg(target_arch = "aarch64")]
    {
        unsafe { hamming_distance_neon(a, b) }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    hamming_distance_scalar(a, b)
}

/// Scalar Hamming distance (fallback)
#[allow(dead_code)]
fn hamming_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x ^ y).count_ones())
        .sum()
}

/// AVX2 Hamming distance with PSHUFB popcount lookup table
///
/// Uses nibble-based lookup instead of scalar popcnt for full SIMD throughput.
/// Technique: split each byte into two 4-bit nibbles, use shuffle as LUT.
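///
/// Worked example: an XOR byte of `0xB4` (`0b1011_0100`) splits into low
/// nibble `0x4` (popcount 1) and high nibble `0xB` (popcount 3); the two
/// shuffled lookups sum to 4, matching `0xB4u8.count_ones()`.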
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(clippy::cast_ptr_alignment)] // loadu handles unaligned loads
unsafe fn hamming_distance_avx2(a: &[u8], b: &[u8]) -> u32 {
    // Popcount lookup table for 4-bit values: [0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4]
    let lookup = _mm256_setr_epi8(
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, // low 128 bits
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, // high 128 bits
    );
    let low_mask = _mm256_set1_epi8(0x0f); // mask for low nibble

    let mut total = _mm256_setzero_si256();
    let mut i = 0;

    // Process 32 bytes at a time
    while i + 32 <= a.len() {
        let va = _mm256_loadu_si256(a.as_ptr().add(i).cast::<__m256i>());
        let vb = _mm256_loadu_si256(b.as_ptr().add(i).cast::<__m256i>());
        let xor = _mm256_xor_si256(va, vb);

        // Split into nibbles and lookup popcount (immediates are const generics)
        let lo = _mm256_and_si256(xor, low_mask);
        let hi = _mm256_and_si256(_mm256_srli_epi16::<4>(xor), low_mask);

        let cnt_lo = _mm256_shuffle_epi8(lookup, lo);
        let cnt_hi = _mm256_shuffle_epi8(lookup, hi);

        // Add nibble counts (each byte now has popcount of original byte)
        let cnt = _mm256_add_epi8(cnt_lo, cnt_hi);

        // Accumulate using sad_epu8 against zero for horizontal sum
        total = _mm256_add_epi64(total, _mm256_sad_epu8(cnt, _mm256_setzero_si256()));

        i += 32;
    }

    // Horizontal sum of 4 x u64 accumulators
    let lo = _mm256_castsi256_si128(total);
    let hi = _mm256_extracti128_si256::<1>(total);
    let sum128 = _mm_add_epi64(lo, hi);
    let count = (_mm_extract_epi64::<0>(sum128) + _mm_extract_epi64::<1>(sum128)) as u32;

    // Handle remaining bytes with scalar
    let mut remainder = 0u32;
    for j in i..a.len() {
        remainder += (a[j] ^ b[j]).count_ones();
    }

    count + remainder
}

/// x86_64 popcnt-based Hamming distance (8 bytes at a time)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
#[allow(clippy::cast_possible_wrap)] // intentional reinterpret for popcnt
unsafe fn hamming_distance_popcnt(a: &[u8], b: &[u8]) -> u32 {
    let mut count = 0u64;
    let mut i = 0;

    // Process 8 bytes at a time using u64 popcnt
    while i + 8 <= a.len() {
        let a_u64 = std::ptr::read_unaligned(a.as_ptr().add(i).cast::<u64>());
        let b_u64 = std::ptr::read_unaligned(b.as_ptr().add(i).cast::<u64>());
        count += _popcnt64((a_u64 ^ b_u64) as i64) as u64;
        i += 8;
    }

    // Handle remaining bytes
    for j in i..a.len() {
        count += (a[j] ^ b[j]).count_ones() as u64;
    }

    count as u32
}

/// NEON Hamming distance with native vcntq_u8
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn hamming_distance_neon(a: &[u8], b: &[u8]) -> u32 {
    let mut sum: u32 = 0;
    let mut i = 0;

    // Process 16 bytes at a time
    while i + 16 <= a.len() {
        // Load 16 bytes each
        let va = vld1q_u8(a.as_ptr().add(i));
        let vb = vld1q_u8(b.as_ptr().add(i));

        // XOR
        let xor = veorq_u8(va, vb);

        // Count bits per byte and sum horizontally
        let cnt = vcntq_u8(xor);
        sum += vaddvq_u8(cnt) as u32;

        i += 16;
    }

    // Handle remaining bytes
    for j in i..a.len() {
        sum += (a[j] ^ b[j]).count_ones();
    }

    sum
}

/// Compute corrected distance for binary quantization
///
/// Raw Hamming distance gives a rough ranking; correcting with the query and
/// vector norms approximates the true L2 distance for more accurate ranking.
///
/// Formula: corrected = hamming * (query_norm * vec_norm) / dimensions
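///
/// Worked example (same numbers as the unit test below): 100 differing bits
/// with `query_norm` = 2.0, `vec_norm` = 1.5, and 768 dimensions gives
/// 100 * (2.0 * 1.5) / 768 ≈ 0.39.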
#[inline]
#[must_use]
pub fn corrected_distance(hamming: u32, query_norm: f32, vec_norm: f32, dimensions: usize) -> f32 {
    let hamming_f = hamming as f32;
    // Convert Hamming to approximate L2: each differing bit contributes to the
    // distance, and the norm scaling approximates the magnitude of the difference.
    hamming_f * (query_norm * vec_norm) / (dimensions as f32)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binary_quantize() {
        let params = BinaryParams::new(8);
        let vector = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.1, -0.5, 0.2];

        let quantized = params.quantize(&vector);

        // Bits: 1 (0.5>0), 0 (-0.3), 1 (0.8>0), 0 (-0.1), 0 (0.0==0), 1 (0.1>0), 0 (-0.5), 1 (0.2>0)
        // Packed: bit0=1, bit1=0, bit2=1, bit3=0, bit4=0, bit5=1, bit6=0, bit7=1
        // = 0b10100101 = 165
        assert_eq!(quantized.len(), 1);
        assert_eq!(quantized[0], 0b1010_0101);
    }

    #[test]
    fn test_binary_train() {
        let v1 = vec![1.0, 5.0, 0.0, 2.0];
        let v2 = vec![2.0, 6.0, 1.0, 3.0];
        let v3 = vec![3.0, 7.0, 2.0, 4.0];
        let vectors: Vec<&[f32]> = vec![v1.as_slice(), v2.as_slice(), v3.as_slice()];

        let params = BinaryParams::train(&vectors).unwrap();

        // Median: [2.0, 6.0, 1.0, 3.0]
        assert_eq!(params.thresholds, vec![2.0, 6.0, 1.0, 3.0]);
    }

    #[test]
    fn test_hamming_distance_identical() {
        let a = vec![0b1010_1010, 0b1111_0000, 0b0000_1111];
        let b = vec![0b1010_1010, 0b1111_0000, 0b0000_1111];

        let dist = hamming_distance(&a, &b);
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_hamming_distance_all_different() {
        let a = vec![0b0000_0000];
        let b = vec![0b1111_1111];

        let dist = hamming_distance(&a, &b);
        assert_eq!(dist, 8); // All 8 bits different
    }

    #[test]
    fn test_hamming_distance_partial() {
        let a = vec![0b1010_1010];
        let b = vec![0b1010_0101];

        let dist = hamming_distance(&a, &b);
        assert_eq!(dist, 4); // Low nibble flipped: 4 of 8 bits differ
    }

    #[test]
    fn test_hamming_distance_large() {
        // 768 dimensions = 96 bytes
        let a: Vec<u8> = vec![0b1010_1010; 96];
        let b: Vec<u8> = vec![0b0101_0101; 96];

        let dist = hamming_distance(&a, &b);
        assert_eq!(dist, 96 * 8); // All 8 bits different in all 96 bytes
    }

    #[test]
    fn test_compression_ratio() {
        let dims: usize = 768;
        let original_size = dims * 4; // f32 = 4 bytes
        let quantized_size = dims.div_ceil(8); // 1 bit = 1/8 byte

        let ratio = original_size as f32 / quantized_size as f32;
        assert!(
            (ratio - 32.0).abs() < 0.1,
            "Expected 32x compression, got {ratio}"
        );
    }

    #[test]
    fn test_corrected_distance() {
        let hamming = 100;
        let query_norm = 2.0;
        let vec_norm = 1.5;
        let dimensions = 768;

        let dist = corrected_distance(hamming, query_norm, vec_norm, dimensions);

        // Should be: 100 * 2.0 * 1.5 / 768 ≈ 0.39
        assert!((dist - 0.39).abs() < 0.01);
    }
}