Skip to main content

diskann_quantization/binary/
quantizer.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use crate::{
7    AsFunctor, CompressInto,
8    bits::{Binary, MutBitSlice, Representation},
9    distances::Hamming,
10};
11
/// A stateless binary quantizer that requires no training.
///
/// Compression keeps only the sign of each dimension: negative values become -1 (stored
/// as a 0 bit) and positive values become 1. Because every dimension collapses to a
/// single bit, the distance between two compressed vectors can be approximated with the
/// Hamming distance.
///
/// For convenience, `diskann_quantization::bits::SquaredL2` and
/// `diskann_quantization::bits::InnerProduct` may also be used; they dispatch to the
/// proper post-op for similarity scores versus mathematical values.
///
/// # Example
/// ```rust
/// use diskann_quantization::{
///     AsFunctor, CompressInto,
///     distances::Hamming,
///     binary::BinaryQuantizer,
///     bits::{BoxedBitSlice, Binary},
/// };
///
/// use diskann_utils::{Reborrow, ReborrowMut};
/// use diskann_vector::{
///     PureDistanceFunction, DistanceFunction, MathematicalValue,
/// };
///
/// let x = vec![-1, 1, 1, -1, 1];
/// let y = vec![1, -1, 1, -1, -1];
///
/// // The quantizer carries no state, so constructing it is free.
/// let quantizer = BinaryQuantizer;
///
/// // Allocate one-bit-per-dimension destinations for the compressed vectors.
/// let mut bx = BoxedBitSlice::<1, Binary>::new_boxed(x.len());
/// let mut by = BoxedBitSlice::<1, Binary>::new_boxed(y.len());
///
/// // Compress both inputs.
/// quantizer.compress_into(x.as_slice(), bx.reborrow_mut()).unwrap();
/// quantizer.compress_into(y.as_slice(), by.reborrow_mut()).unwrap();
///
/// // The inputs only contain -1 and 1, so compression is lossless here.
/// assert_eq!(bx.get(0).unwrap(), x[0]);
/// assert_eq!(bx.get(1).unwrap(), x[1]);
///
/// // Five dimensions fit in a single byte once compressed.
/// assert_eq!(bx.bytes(), 1);
///
/// // Let's compute some distances!
/// assert_eq!(
///     Hamming::evaluate(bx.reborrow(), by.reborrow()).unwrap(),
///     MathematicalValue::<u32>::new(3)
/// );
///
/// // The `AsFunctor` trait offers a more uniform way to obtain the distance functor.
/// let f: Hamming = quantizer.as_functor();
/// assert_eq!(
///     f.evaluate_similarity(bx.reborrow(), by.reborrow()).unwrap(),
///     MathematicalValue::<u32>::new(3)
/// );
/// ```
#[derive(Clone, Copy, Debug)]
pub struct BinaryQuantizer;
72
73/////////////////
74// Compression //
75/////////////////
76
77impl<T> CompressInto<&[T], MutBitSlice<'_, 1, Binary>> for BinaryQuantizer
78where
79    T: PartialOrd + Default,
80{
81    type Error = std::convert::Infallible;
82    type Output = ();
83
84    /// Compress the source vector into a binary representation.
85    ///
86    /// This works by mapping positive numbers (as defined by `v > T::default()`) to 1 and
87    /// negative numbers (as defined by `v <= T::default()`) to -1.
88    ///
89    /// # Panics
90    ///
91    /// Panics if `from.len() != into.len()`.
92    fn compress_into(
93        &self,
94        from: &[T],
95        mut into: MutBitSlice<'_, 1, Binary>,
96    ) -> Result<(), Self::Error> {
97        // Check 1
98        assert_eq!(from.len(), into.len());
99        from.iter().enumerate().for_each(|(i, v)| {
100            // Note: Both 1 and -1 are in the domain of `Binary`.
101            let v: u8 = if v > &T::default() {
102                Binary::encode_unchecked(1)
103            } else {
104                Binary::encode_unchecked(-1)
105            };
106
107            // SAFETY: From check 1, we know that `i < into.len()`.
108            unsafe { into.set_unchecked(i, v) };
109        });
110        Ok(())
111    }
112}
113
114///////////////
115// AsFunctor //
116///////////////
117
118impl AsFunctor<Hamming> for BinaryQuantizer {
119    /// Return a [`crate::distances::Hamming`] functor for performing distance computations
120    /// on bit vectors.
121    fn as_functor(&self) -> Hamming {
122        Hamming
123    }
124}
125
126///////////
127// Tests //
128///////////
129
#[cfg(test)]
mod tests {
    use diskann_utils::{ReborrowMut, views::Matrix};
    use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom};

    use super::*;
    use crate::bits::{Binary, BoxedBitSlice};

    /// Compress randomly arranged rows of length `len` and verify every encoded bit.
    ///
    /// Each column of the test matrix is an independent shuffle of a small domain that
    /// contains negative values, zero, and positive values, so every sign case shows up
    /// in every column.
    fn test_compression_impl(len: usize, rng: &mut StdRng) {
        let mut domain = [-10, -1, 0, 1, 10];
        let mut test_pattern = Matrix::<i32>::new(0, domain.len(), len);

        // Populate the matrix one column at a time from a fresh shuffle of the domain.
        for col in 0..len {
            domain.shuffle(rng);
            for row in 0..test_pattern.nrows() {
                test_pattern[(row, col)] = domain[row];
            }
        }

        let quantizer = BinaryQuantizer;
        let mut binary = BoxedBitSlice::<1, Binary>::new_boxed(len);
        for row in test_pattern.row_iter() {
            quantizer.compress_into(row, binary.reborrow_mut()).unwrap();

            // Strictly positive inputs must decode to 1; everything else (including
            // zero) must decode to -1.
            for (i, &value) in row.iter().enumerate() {
                let expected = if value > 0 { 1 } else { -1 };
                assert_eq!(binary.get(i).unwrap(), expected);
            }
        }
    }

    /// Exercise compression for every vector length from 1 through 16 using a fixed
    /// seed so the run is reproducible.
    #[test]
    fn test_compression() {
        let mut rng = StdRng::seed_from_u64(0x9673d0890bbb7231);
        (1..17).for_each(|len| test_compression_impl(len, &mut rng));
    }
}