diskann_quantization/binary/quantizer.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use crate::{
7 AsFunctor, CompressInto,
8 bits::{Binary, MutBitSlice, Representation},
9 distances::Hamming,
10};
11
/// A simple, training-free binary quantizer.
///
/// The canonical implementation of compression with a binary quantizer maps positive
/// values to 1 and every other value to -1 (encoded as a bit 0). "Positive" means
/// strictly greater than `T::default()`, so a value equal to `T::default()` (zero for
/// the built-in numeric types) is mapped to -1 as well, as is any value that does not
/// compare greater than it (for example a floating point NaN). Distances can then be
/// approximated using the Hamming distance between the compressed vectors.
///
/// As a convenience `diskann_quantization::bits::SquaredL2` and
/// `diskann_quantization::bits::InnerProduct` may be used, which correctly dispatch to the
/// proper post-op for similarity scores versus mathematical values.
///
/// # Example
/// ```rust
/// use diskann_quantization::{
///     AsFunctor, CompressInto,
///     distances::Hamming,
///     binary::BinaryQuantizer,
///     bits::{BoxedBitSlice, Binary},
/// };
///
/// use diskann_utils::{Reborrow, ReborrowMut};
/// use diskann_vector::{
///     PureDistanceFunction, DistanceFunction, MathematicalValue,
/// };
///
/// let x = vec![-1, 1, 1, -1, 1];
/// let y = vec![1, -1, 1, -1, -1];
///
/// // Create a quantizer
/// let quantizer = BinaryQuantizer;
///
/// // Create output vectors for compression.
/// let mut bx = BoxedBitSlice::<1, Binary>::new_boxed(x.len());
/// let mut by = BoxedBitSlice::<1, Binary>::new_boxed(y.len());
///
/// // Do the compression.
/// quantizer.compress_into(x.as_slice(), bx.reborrow_mut()).unwrap();
/// quantizer.compress_into(y.as_slice(), by.reborrow_mut()).unwrap();
///
/// // Because our inputs are limited to -1 and 1, the compression is perfect.
/// assert_eq!(bx.get(0).unwrap(), x[0]);
/// assert_eq!(bx.get(1).unwrap(), x[1]);
///
/// // But the compressed vectors only consume a single byte.
/// assert_eq!(bx.bytes(), 1);
///
/// // Let's compute some distances!
/// assert_eq!(
///     Hamming::evaluate(bx.reborrow(), by.reborrow()).unwrap(),
///     MathematicalValue::<u32>::new(3)
/// );
///
/// // We can also use the `AsFunctor` trait if we want more uniformity.
/// let f: Hamming = quantizer.as_functor();
/// assert_eq!(
///     f.evaluate_similarity(bx.reborrow(), by.reborrow()).unwrap(),
///     MathematicalValue::<u32>::new(3)
/// );
/// ```
#[derive(Debug, Clone, Copy)]
pub struct BinaryQuantizer;
72
73/////////////////
74// Compression //
75/////////////////
76
77impl<T> CompressInto<&[T], MutBitSlice<'_, 1, Binary>> for BinaryQuantizer
78where
79 T: PartialOrd + Default,
80{
81 type Error = std::convert::Infallible;
82 type Output = ();
83
84 /// Compress the source vector into a binary representation.
85 ///
86 /// This works by mapping positive numbers (as defined by `v > T::default()`) to 1 and
87 /// negative numbers (as defined by `v <= T::default()`) to -1.
88 ///
89 /// # Panics
90 ///
91 /// Panics if `from.len() != into.len()`.
92 fn compress_into(
93 &self,
94 from: &[T],
95 mut into: MutBitSlice<'_, 1, Binary>,
96 ) -> Result<(), Self::Error> {
97 // Check 1
98 assert_eq!(from.len(), into.len());
99 from.iter().enumerate().for_each(|(i, v)| {
100 // Note: Both 1 and -1 are in the domain of `Binary`.
101 let v: u8 = if v > &T::default() {
102 Binary::encode_unchecked(1)
103 } else {
104 Binary::encode_unchecked(-1)
105 };
106
107 // SAFETY: From check 1, we know that `i < into.len()`.
108 unsafe { into.set_unchecked(i, v) };
109 });
110 Ok(())
111 }
112}
113
114///////////////
115// AsFunctor //
116///////////////
117
118impl AsFunctor<Hamming> for BinaryQuantizer {
119 /// Return a [`crate::distances::Hamming`] functor for performing distance computations
120 /// on bit vectors.
121 fn as_functor(&self) -> Hamming {
122 Hamming
123 }
124}
125
126///////////
127// Tests //
128///////////
129
#[cfg(test)]
mod tests {
    use diskann_utils::{ReborrowMut, views::Matrix};
    use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom};

    use super::*;
    use crate::bits::{Binary, BoxedBitSlice};

    /// Compress randomly arranged rows of length `len` and verify every encoded element.
    ///
    /// Each column of the test matrix is an independent shuffle of a small domain that
    /// contains negative values, zero, and positive values, so every row mixes all three.
    fn test_compression_impl(len: usize, rng: &mut StdRng) {
        let mut domain = [-10, -1, 0, 1, 10];
        let mut test_pattern = Matrix::<i32>::new(0, domain.len(), len);

        // Populate the matrix one column at a time from a freshly shuffled domain.
        for col in 0..len {
            domain.shuffle(rng);
            for (row, &value) in domain.iter().enumerate() {
                test_pattern[(row, col)] = value;
            }
        }

        let quantizer = BinaryQuantizer;
        let mut binary = BoxedBitSlice::<1, Binary>::new_boxed(len);
        for row in test_pattern.row_iter() {
            quantizer.compress_into(row, binary.reborrow_mut()).unwrap();

            // Strictly positive inputs must decode to 1, everything else to -1.
            for (i, r) in row.iter().enumerate() {
                let expected = if *r > 0 { 1 } else { -1 };
                assert_eq!(binary.get(i).unwrap(), expected);
            }
        }
    }

    #[test]
    fn test_compression() {
        let mut rng = StdRng::seed_from_u64(0x9673d0890bbb7231);
        (1..17).for_each(|len| test_compression_impl(len, &mut rng));
    }
}