Skip to main content

iqdb_quantize/
binary.rs

1//! [`BinaryQuantizer`] — binary quantization (BQ, 32× compression).
2//!
3//! Codes are one bit per dimension, packed into [`u64`] words. The threshold
4//! used at encode time is the per-dimension mean learned from the training
5//! sample: bit `i` is `1` when `vector[i] >= mean[i]`, `0` otherwise. When
6//! the dimension is not a multiple of 64 the trailing word has unused high
7//! bits; those are zeroed at encode time so they cannot contribute to
8//! Hamming distance.
9//!
10//! Distance is supported under [`DistanceMetric::Hamming`] only — any other
11//! metric returns [`IqdbError::InvalidMetric`]. BQ discards magnitude
12//! entirely, so a cosine or Euclidean distance over ±1 codes would be a
13//! roundabout Hamming dressed in misleading units. Restricting the contract
14//! prevents that silent misuse and matches the public Faiss `IndexBinary`
15//! convention. The query path inside [`BinaryQuantizer::distance`]
16//! binarizes the query against the **same trained per-dimension means**
17//! used during [`BinaryQuantizer::quantize`], so the query bits live in
18//! the same space as the stored code bits.
19
20use error_forge::ForgeError;
21use iqdb_types::{DistanceMetric, IqdbError, Result};
22
23use crate::code::BqCode;
24use crate::traits::Quantizer;
25use crate::validate::{dim_eq, finite_non_empty, training_set};
26
/// Number of code bits packed into each `u64` word of a [`BqCode`].
const BITS_PER_WORD: usize = u64::BITS as usize;

/// Per-dimension thresholds fitted by [`BinaryQuantizer::train`].
#[derive(Clone, Debug, PartialEq)]
struct BqCalibration {
    /// Training-sample mean of each dimension. Component `i` of a vector
    /// encodes to bit `1` exactly when it is `>=` `means[i]`.
    means: Vec<f32>,
}
35
36/// Binary quantizer (BQ): one bit per dimension, 32× compression.
37///
38/// Build one with [`BinaryQuantizer::new`] (or [`Default`]), train it once
39/// with a representative sample, then quantize and compare. BQ supports
40/// [`DistanceMetric::Hamming`] only; other metrics return
41/// [`IqdbError::InvalidMetric`].
42///
43/// # Examples
44///
45/// ```
46/// use iqdb_quantize::{BinaryQuantizer, Quantizer};
47/// use iqdb_types::DistanceMetric;
48///
49/// let mut bq = BinaryQuantizer::new();
50/// bq.train(&[&[0.0_f32, 1.0, 2.0][..], &[2.0_f32, 1.0, 0.0][..]])
51///     .expect("two non-empty, finite vectors of equal dim");
52///
53/// let code = bq.quantize(&[0.5_f32, 1.5, 2.5]).expect("dim matches");
54/// let d = bq
55///     .distance(&[0.5_f32, 1.5, 2.5], &code, DistanceMetric::Hamming)
56///     .expect("dim matches");
57/// // Self-distance is zero.
58/// assert_eq!(d, 0.0);
59/// ```
60#[derive(Debug, Clone, Default, PartialEq)]
61pub struct BinaryQuantizer {
62    calibration: Option<BqCalibration>,
63}
64
65impl BinaryQuantizer {
66    /// Build an untrained binary quantizer.
67    ///
68    /// Every hot method returns [`IqdbError::InvalidConfig`] until
69    /// [`Quantizer::train`] succeeds.
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// use iqdb_quantize::BinaryQuantizer;
75    /// let _bq = BinaryQuantizer::new();
76    /// ```
77    #[must_use]
78    pub fn new() -> Self {
79        Self { calibration: None }
80    }
81
82    /// The trained dimension, if any.
83    ///
84    /// # Examples
85    ///
86    /// ```
87    /// use iqdb_quantize::{BinaryQuantizer, Quantizer};
88    ///
89    /// let mut bq = BinaryQuantizer::new();
90    /// assert_eq!(bq.dim(), None);
91    /// bq.train(&[&[0.0_f32, 1.0][..]]).expect("ok");
92    /// assert_eq!(bq.dim(), Some(2));
93    /// ```
94    #[must_use]
95    pub fn dim(&self) -> Option<usize> {
96        self.calibration.as_ref().map(|c| c.means.len())
97    }
98
99    fn calibration(&self) -> Result<&BqCalibration> {
100        self.calibration.as_ref().ok_or(IqdbError::InvalidConfig {
101            reason: "BinaryQuantizer has not been trained",
102        })
103    }
104}
105
106impl Quantizer for BinaryQuantizer {
107    type Quantized = BqCode;
108
109    #[tracing::instrument(
110        level = "info",
111        skip_all,
112        fields(quantizer = "bq", training_size = vectors.len()),
113    )]
114    fn train(&mut self, vectors: &[&[f32]]) -> Result<()> {
115        let dim = training_set(vectors).inspect_err(|err: &IqdbError| {
116            tracing::error!(
117                error.kind = err.kind(),
118                error.reason = err.caption(),
119                "binary quantizer training failed",
120            );
121        })?;
122        let mut sums = vec![0.0_f64; dim];
123        for v in vectors {
124            for (i, &x) in v.iter().enumerate() {
125                sums[i] += f64::from(x);
126            }
127        }
128        let n = vectors.len() as f64;
129        let means: Vec<f32> = sums.iter().map(|s| (s / n) as f32).collect();
130        // Defensive: a finite f64 average MUST cast to a finite f32 given
131        // finite inputs (training_set rejected non-finite). Belt-and-braces
132        // guard avoids storing a NaN threshold if a future change weakens
133        // that invariant.
134        if means.iter().any(|m| !m.is_finite()) {
135            let err = IqdbError::InvalidVector;
136            tracing::error!(
137                error.kind = err.kind(),
138                error.reason = err.caption(),
139                "binary quantizer training failed: non-finite mean",
140            );
141            return Err(err);
142        }
143        self.calibration = Some(BqCalibration { means });
144        Ok(())
145    }
146
147    fn quantize(&self, vector: &[f32]) -> Result<Self::Quantized> {
148        let cal = self.calibration()?;
149        finite_non_empty(vector)?;
150        dim_eq(cal.means.len(), vector.len())?;
151        Ok(BqCode {
152            words: pack_bits(vector, &cal.means),
153            dim: vector.len(),
154        })
155    }
156
157    fn dequantize(&self, quantized: &Self::Quantized) -> Result<Vec<f32>> {
158        let cal = self.calibration()?;
159        dim_eq(cal.means.len(), quantized.dim)?;
160        let mut out = Vec::with_capacity(quantized.dim);
161        for i in 0..quantized.dim {
162            let word = quantized.words[i / BITS_PER_WORD];
163            let bit = (word >> (i % BITS_PER_WORD)) & 1;
164            out.push(if bit == 1 { 1.0_f32 } else { -1.0_f32 });
165        }
166        Ok(out)
167    }
168
169    fn distance(
170        &self,
171        query: &[f32],
172        quantized: &Self::Quantized,
173        metric: DistanceMetric,
174    ) -> Result<f32> {
175        let cal = self.calibration()?;
176        finite_non_empty(query)?;
177        dim_eq(cal.means.len(), query.len())?;
178        dim_eq(cal.means.len(), quantized.dim)?;
179        if metric != DistanceMetric::Hamming {
180            return Err(IqdbError::InvalidMetric);
181        }
182        // Binarize the query against the same trained thresholds the stored
183        // code was built from, then Hamming via packed XOR + popcount.
184        let query_words = pack_bits(query, &cal.means);
185        let mut diff: u32 = 0;
186        for (q, c) in query_words.iter().zip(quantized.words.iter()) {
187            diff = diff.saturating_add((q ^ c).count_ones());
188        }
189        Ok(diff as f32)
190    }
191}
192
193/// Pack one bit per component of `vector` into `u64` words, with the bit
194/// set when `vector[i] >= means[i]`. The trailing word's unused high bits
195/// are zero so they cannot contribute to Hamming distance.
196fn pack_bits(vector: &[f32], means: &[f32]) -> Vec<u64> {
197    let dim = vector.len();
198    let words = dim.div_ceil(BITS_PER_WORD);
199    let mut out = vec![0_u64; words];
200    for i in 0..dim {
201        if vector[i] >= means[i] {
202            out[i / BITS_PER_WORD] |= 1_u64 << (i % BITS_PER_WORD);
203        }
204    }
205    out
206}
207
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;
    use iqdb_types::{DistanceMetric, IqdbError};

    /// Quantizer trained on `[0, 1, 2]` and `[2, 1, 0]`, so every
    /// per-dimension mean is exactly `1.0`.
    fn trained_unit() -> BinaryQuantizer {
        let mut bq = BinaryQuantizer::new();
        bq.train(&[&[0.0_f32, 1.0, 2.0][..], &[2.0_f32, 1.0, 0.0][..]])
            .unwrap();
        bq
    }

    #[test]
    fn quantize_before_train_returns_invalid_config() {
        let bq = BinaryQuantizer::new();
        let err = bq.quantize(&[0.5_f32, 0.5]).unwrap_err();
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    #[test]
    fn distance_before_train_returns_invalid_config() {
        let bq = BinaryQuantizer::new();
        let code = BqCode {
            words: vec![0],
            dim: 3,
        };
        let err = bq
            .distance(&[0.0_f32, 0.0, 0.0], &code, DistanceMetric::Hamming)
            .unwrap_err();
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    #[test]
    fn dequantize_before_train_returns_invalid_config() {
        let bq = BinaryQuantizer::new();
        let code = BqCode {
            words: vec![0],
            dim: 3,
        };
        let err = bq.dequantize(&code).unwrap_err();
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    #[test]
    fn dim_reports_trained_dimension() {
        assert_eq!(BinaryQuantizer::new().dim(), None);
        assert_eq!(trained_unit().dim(), Some(3));
    }

    #[test]
    fn train_empty_set_returns_invalid_config() {
        let mut bq = BinaryQuantizer::new();
        let empty: [&[f32]; 0] = [];
        let err = bq.train(&empty).unwrap_err();
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    #[test]
    fn train_inconsistent_dim_returns_dimension_mismatch() {
        let mut bq = BinaryQuantizer::new();
        let a = [0.0_f32, 1.0, 2.0];
        let b = [1.0_f32, 0.0];
        let err = bq.train(&[&a[..], &b[..]]).unwrap_err();
        assert_eq!(
            err,
            IqdbError::DimensionMismatch {
                expected: 3,
                found: 2,
            },
        );
    }

    #[test]
    fn train_non_finite_returns_invalid_vector() {
        let mut bq = BinaryQuantizer::new();
        let v = [1.0_f32, f32::NAN];
        assert_eq!(bq.train(&[&v[..]]).unwrap_err(), IqdbError::InvalidVector);
    }

    #[test]
    fn quantize_dim_mismatch_returns_dimension_mismatch() {
        let bq = trained_unit();
        let err = bq.quantize(&[0.5_f32, 0.5]).unwrap_err();
        assert_eq!(
            err,
            IqdbError::DimensionMismatch {
                expected: 3,
                found: 2,
            },
        );
    }

    #[test]
    fn quantize_non_finite_returns_invalid_vector() {
        let bq = trained_unit();
        let err = bq.quantize(&[0.5_f32, f32::NEG_INFINITY, 0.5]).unwrap_err();
        assert_eq!(err, IqdbError::InvalidVector);
    }

    #[test]
    fn dequantize_maps_bits_to_signs() {
        // Means are all 1.0, so 0.5 -> bit 0, 1.5 -> bit 1 and 1.0 -> bit 1
        // (the threshold test is `>=`).
        let bq = trained_unit();
        let code = bq.quantize(&[0.5_f32, 1.5, 1.0]).unwrap();
        assert_eq!(code.words, vec![0b110]);
        assert_eq!(bq.dequantize(&code).unwrap(), vec![-1.0_f32, 1.0, 1.0]);
    }

    #[test]
    fn dequantize_dim_mismatch_returns_dimension_mismatch() {
        let bq = trained_unit();
        let code = BqCode {
            words: vec![0],
            dim: 2,
        };
        assert_eq!(
            bq.dequantize(&code).unwrap_err(),
            IqdbError::DimensionMismatch {
                expected: 3,
                found: 2,
            },
        );
    }

    #[test]
    fn distance_rejects_non_hamming_metrics() {
        let bq = trained_unit();
        let code = bq.quantize(&[0.5_f32, 0.5, 0.5]).unwrap();
        let q = [0.5_f32, 0.5, 0.5];
        for metric in [
            DistanceMetric::Cosine,
            DistanceMetric::DotProduct,
            DistanceMetric::Euclidean,
            DistanceMetric::Manhattan,
        ] {
            assert_eq!(
                bq.distance(&q, &code, metric).unwrap_err(),
                IqdbError::InvalidMetric,
                "metric {metric:?} must be rejected",
            );
        }
    }

    #[test]
    fn distance_self_consistency_is_zero() {
        // The query path MUST binarize against the same trained means used
        // to build the stored code. If it didn't, `distance(v, code(v))`
        // would generally be > 0.
        let bq = trained_unit();
        let v = [0.4_f32, 1.1, 1.9];
        let code = bq.quantize(&v).unwrap();
        let d = bq.distance(&v, &code, DistanceMetric::Hamming).unwrap();
        assert_eq!(d, 0.0);
    }

    #[test]
    fn distance_counts_disagreeing_bits() {
        // Means are all 1.0: the stored code is 0b111 and the query packs
        // to 0b010, so exactly two bits differ.
        let bq = trained_unit();
        let code = bq.quantize(&[2.0_f32, 2.0, 2.0]).unwrap();
        let d = bq
            .distance(&[0.0_f32, 2.0, 0.0], &code, DistanceMetric::Hamming)
            .unwrap();
        assert_eq!(d, 2.0);
    }

    /// Reference Hamming distance computed per dimension from the threshold
    /// rule alone, independent of `pack_bits` and the word layout.
    fn naive_hamming(query: &[f32], stored: &[f32], means: &[f32]) -> u32 {
        let mismatches = query
            .iter()
            .zip(stored)
            .zip(means)
            .filter(|&((&q, &s), &m)| (q >= m) != (s >= m))
            .count();
        u32::try_from(mismatches).unwrap()
    }

    #[test]
    fn hamming_matches_naive_per_dimension_reference() {
        let mut bq = BinaryQuantizer::new();
        // 70-dim training -> two words per code, with 6 padding bits in
        // the trailing word that MUST stay zero.
        let dim = 70;
        let a: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i as f32).cos()).collect();
        bq.train(&[&a[..], &b[..]]).unwrap();

        let query: Vec<f32> = (0..dim).map(|i| ((i as f32) * 0.5).sin()).collect();
        let code = bq.quantize(&b).unwrap();
        let d = bq.distance(&query, &code, DistanceMetric::Hamming).unwrap();

        let cal = bq.calibration.as_ref().unwrap();
        let expected = naive_hamming(&query, &b, &cal.means);
        assert_eq!(d as u32, expected);
    }

    #[test]
    fn quantize_zeros_padding_bits_for_dim_not_multiple_of_64() {
        let dims = [63_usize, 64, 65, 127, 128, 129];
        for &dim in &dims {
            // Train with all-zeros and all-ones so the per-dim mean is 0.5;
            // an all-ones query then sets every meaningful bit, leaving the
            // padding bits in the trailing word as the only source of 0s
            // above `dim`.
            let zeros = vec![0.0_f32; dim];
            let ones = vec![1.0_f32; dim];
            let mut bq = BinaryQuantizer::new();
            bq.train(&[&zeros[..], &ones[..]]).unwrap();

            let code = bq.quantize(&ones).unwrap();
            assert_eq!(code.dim, dim);
            assert_eq!(code.words.len(), dim.div_ceil(BITS_PER_WORD));

            let used_in_last = dim % BITS_PER_WORD;
            if used_in_last != 0 {
                let last = *code.words.last().unwrap();
                let padding_mask = !0_u64 << used_in_last;
                assert_eq!(
                    last & padding_mask,
                    0,
                    "dim={dim}: padding bits must be zero",
                );
            }
        }
    }
}