Skip to main content

diskann_quantization/minmax/
recompress.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use super::vectors::{DataMutRef, DataRef, MinMaxCompensation};
7use crate::CompressInto;
8use crate::bits::{Representation, Unsigned};
9use crate::scalar::bit_scale;
10use thiserror::Error;
11
/// Re-quantizes MinMax-encoded vectors to a narrower bit width.
///
/// Given a vector that has already been encoded with `N` bits per dimension,
/// [`Recompressor`] produces an `M`-bit encoding of it (with `N > M`) working
/// purely from the existing codes and their compensation metadata — the original
/// floating-point vector is not needed.
///
/// # Supported conversions
///
/// At the moment only `8 -> 4`, `8 -> 2` and `4 -> 2` are implemented.
///
/// # Example
///
/// ```rust
/// use std::num::NonZeroUsize;
/// use diskann_quantization::algorithms::{Transform, transforms::NullTransform};
/// use diskann_quantization::minmax::{Data, MinMaxQuantizer, Recompressor};
/// use diskann_quantization::num::Positive;
/// use diskann_quantization::CompressInto;
/// use diskann_utils::{Reborrow, ReborrowMut};
///
/// // Create a quantizer and compress an f32 vector to 8-bit
/// let vector = vec![0.1, -0.5, 0.8, -0.2];
/// let quantizer = MinMaxQuantizer::new(
///     Transform::Null(NullTransform::new(NonZeroUsize::new(4).unwrap())),
///     Positive::new(1.0).unwrap(),
/// );
///
/// let mut encoded_8 = Data::<8>::new_boxed(4);
/// quantizer.compress_into(vector.as_slice(), encoded_8.reborrow_mut()).unwrap();
///
/// // Recompress from 8-bit to 4-bit
/// let mut encoded_4 = Data::<4>::new_boxed(4);
/// Recompressor.compress_into(encoded_8.reborrow(), encoded_4.reborrow_mut()).unwrap();
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Recompressor;
46
47/// Error type for recompression operations.
48#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
49pub enum RecompressError {
50    /// Source and destination vectors have different dimensions.
51    #[error("dimension mismatch: source has {src} dimensions, destination has {dst}")]
52    DimensionMismatch {
53        /// Dimension of the source vector.
54        src: usize,
55        /// Dimension of the destination vector.
56        dst: usize,
57    },
58}
59
60/// Macro to implement `CompressInto<DataRef<'_, N>, DataMutRef<'_, M>>` for M > 1.
61macro_rules! impl_recompress {
62    ($n:literal -> $m:literal) => {
63        impl<'a, 'b> CompressInto<DataRef<'a, $n>, DataMutRef<'b, $m>> for Recompressor
64        where
65            Unsigned: Representation<$n> + Representation<$m>,
66        {
67            type Error = RecompressError;
68            type Output = ();
69
70            fn compress_into(
71                &self,
72                from: DataRef<'a, $n>,
73                to: DataMutRef<'b, $m>,
74            ) -> Result<(), Self::Error> {
75                recompress_kernel::<$n, $m>(from, to)
76            }
77        }
78    };
79}
80
81impl_recompress!(8 -> 4);
82impl_recompress!(8 -> 2);
83impl_recompress!(4 -> 2);
84
85////////////////////////////////////
86// Recompression Kernel for M > 1 //
87////////////////////////////////////
88
89/// Recompress N-bit codes to M-bit codes where M > 1.
90///
91/// Recall from the algorithm for minmax described in [`crate::minmax::MinMaxQuantizer`],
92/// the encoding of a vector `X` into `N`-bits per dimension using minmax is given by:
93///
94/// ```text
95/// X' = round((X - b) * a).clamp(0, 2^n - 1))
96/// ```
97///
98/// where `b = min_i X_i` and `a = max_i X_i - b / (2^N - 1)`.
99///
100/// This routine to recompress to `M`-bits is a simple recomputation
101/// of the codes, assuming the range of values `[b, b + a * (2^N - 1)]`
102/// remains the same.
103///
104/// # Algorithm
105///
106/// ```text
107/// Transformation:
108///   scale_M = (2^M - 1)
109///   scale_N = (2^N - 1)
110///   
111///   old_code = round((X - b) * scale_N)
112///   reconstructed_value = X' = (old_code / scale_N) + b
113///   new_code = round((X' - b) * scale_M)
114///            = round(old_code * scale_M / scale_N)
115/// ```
116#[inline(always)]
117fn recompress_kernel<const N: usize, const M: usize>(
118    from: DataRef<'_, N>,
119    mut to: DataMutRef<'_, M>,
120) -> Result<(), RecompressError>
121where
122    Unsigned: Representation<N> + Representation<M>,
123{
124    const { assert!(N > M, "source bit width must exceed target bits") };
125    const { assert!(M > 1, "target bit width must exceed 1") };
126
127    // Validate dimensions
128    let dim = from.len();
129    if dim != to.vector().len() {
130        return Err(RecompressError::DimensionMismatch {
131            src: dim,
132            dst: to.vector().len(),
133        });
134    }
135
136    let src_meta = from.meta();
137    let src_a = src_meta.a;
138    let src_b = src_meta.b;
139
140    let scale_n = bit_scale::<N>();
141    let scale_m = bit_scale::<M>();
142    let code_scale = scale_m / scale_n;
143
144    let new_a = src_a / code_scale;
145    let new_b = src_b;
146
147    // Single pass: encode and compute statistics
148    let from_vec = from.vector();
149    let mut to_vec = to.vector_mut();
150
151    let mut code_sum: f32 = 0.0;
152    let mut norm_squared: f32 = 0.0;
153
154    for i in 0..dim {
155        // Read source code
156        // SAFETY: we checked that `dim == from.len() == src.len()`
157        let old_code = unsafe { from_vec.get_unchecked(i) };
158        let old_code_f = old_code as f32;
159
160        // new code
161        let new_code_pre = (old_code_f * code_scale).round_ties_even();
162        let new_code = new_code_pre as u8;
163
164        // Write destination code
165        // SAFETY: we checked that `dim == from.len() == src.len()`
166        unsafe { to_vec.set_unchecked(i, new_code) };
167
168        // Accumulate statistics using the actual truncated integer code
169        let new_code_f = new_code as f32;
170        code_sum += new_code_f;
171
172        // Reconstruct value for norm computation
173        let v_m = new_code_f * new_a + new_b;
174
175        norm_squared += v_m * v_m;
176    }
177
178    // Construct metadata
179    to.set_meta(MinMaxCompensation {
180        dim: dim as u32,
181        b: new_b,
182        a: new_a,
183        n: new_a * code_sum,
184        norm_squared,
185    });
186
187    Ok(())
188}
189
#[cfg(test)]
mod recompress_tests {
    use std::num::NonZeroUsize;

    use diskann_utils::{Reborrow, ReborrowMut};
    use rand::{
        SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{
        algorithms::{Transform, transforms::NullTransform},
        minmax::quantizer::MinMaxQuantizer,
        minmax::vectors::Data,
        num::Positive,
    };

    /// Reconstruct a MinMax quantized vector to f32 values (`code * a + b`).
    fn reconstruct<const NBITS: usize>(v: DataRef<'_, NBITS>) -> Vec<f32>
    where
        Unsigned: Representation<NBITS>,
    {
        let meta = v.meta();
        (0..v.len())
            .map(|i| v.vector().get(i).unwrap() as f32 * meta.a + meta.b)
            .collect()
    }

    /// Encode a random vector with `N` bits, recompress it to `M` bits, and check
    /// the resulting codes and metadata against independently computed values
    /// and against a direct `M`-bit encoding of the same vector.
    fn test_recompress_random<const N: usize, const M: usize>(dim: usize, rng: &mut StdRng)
    where
        Unsigned: Representation<N> + Representation<M>,
        MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, N>>
            + for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, M>>,
        Recompressor: for<'a, 'b> CompressInto<DataRef<'a, N>, DataMutRef<'b, M>, Output = ()>,
    {
        let distribution = Uniform::new_inclusive::<f32, f32>(-1.0, 1.0).unwrap();
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );
        let recompressor = Recompressor;

        // Generate random vector and compress to N bits
        let vector: Vec<f32> = distribution.sample_iter(rng).take(dim).collect();
        let mut encoded_n = Data::<N>::new_boxed(dim);
        quantizer
            .compress_into(&*vector, encoded_n.reborrow_mut())
            .unwrap();

        // Recompress to M bits
        let mut encoded_m = Data::<M>::new_boxed(dim);
        recompressor
            .compress_into(encoded_n.reborrow(), encoded_m.reborrow_mut())
            .unwrap();

        // Verify metadata
        let meta_n = encoded_n.meta();
        let meta_m = encoded_m.meta();

        assert_eq!(meta_m.dim as usize, dim, "Dimension should be preserved");

        // Recompression keeps the covered interval `[b, b + a * (2^N - 1)]` fixed:
        // `b` is copied unchanged and `a` is rescaled by `(2^N - 1) / (2^M - 1)`.
        assert_eq!(meta_m.b, meta_n.b, "Offset `b` should be preserved");
        let expected_a = meta_n.a * bit_scale::<N>() / bit_scale::<M>();
        assert!(
            (meta_m.a - expected_a).abs() <= 1e-5 * expected_a.abs() + f32::EPSILON,
            "Step `a` mismatch: expected {}, got {}",
            expected_a,
            meta_m.a
        );

        // Verify code_sum (n = a * code_sum)
        let expected_code_sum: f32 = (0..dim)
            .map(|i| encoded_m.vector().get(i).unwrap() as f32)
            .sum();
        let computed_code_sum = meta_m.n / meta_m.a;
        assert!(
            (computed_code_sum - expected_code_sum).abs() < 1e-4,
            "Code sum mismatch: expected {}, got {}",
            expected_code_sum,
            computed_code_sum
        );

        // Verify norm_squared
        let reconstructed_m = reconstruct(encoded_m.reborrow());
        let expected_norm_sq: f32 = reconstructed_m.iter().map(|x| x * x).sum();
        assert!(
            (meta_m.norm_squared - expected_norm_sq).abs() < 1e-4,
            "norm_squared mismatch: expected {}, got {}",
            expected_norm_sq,
            meta_m.norm_squared
        );

        // Verify that recompression stays close to a direct M-bit encoding.
        let mut direct_m = Data::<M>::new_boxed(dim);
        quantizer
            .compress_into(&*vector, direct_m.reborrow_mut())
            .unwrap();

        let reconstructed_direct_m = reconstruct(direct_m.reborrow());
        reconstructed_direct_m
            .iter()
            .zip(reconstructed_m.iter())
            .for_each(|(x, y)| {
                assert!(
                    (*x - *y).abs() < 1e-4,
                    "Direct compression and recompress vectors are not close"
                )
            });
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const TRIALS: usize = 2;
            const MAX_DIM: usize = 20;
        } else {
            const TRIALS: usize = 10;
            const MAX_DIM: usize = 100;
        }
    }

    macro_rules! test_recompress_pair {
        ($name:ident, $n:literal -> $m:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                for dim in 10..=MAX_DIM {
                    for _ in 0..TRIALS {
                        test_recompress_random::<$n, $m>(dim, &mut rng);
                    }
                }
            }
        };
    }

    test_recompress_pair!(recompress_8_to_4, 8 -> 4, 0xabc123def456);
    test_recompress_pair!(recompress_8_to_2, 8 -> 2, 0xdef456abc123);
    test_recompress_pair!(recompress_4_to_2, 4 -> 2, 0x456def123abc);

    #[test]
    fn test_dimension_mismatch_error() {
        let recompressor = Recompressor;

        let mut src = Data::<8>::new_boxed(10);
        src.set_meta(MinMaxCompensation {
            dim: 10,
            b: 0.0,
            a: 1.0,
            n: 0.0,
            norm_squared: 0.0,
        });

        let mut dst = Data::<4>::new_boxed(15); // Different dimension

        let result: Result<(), RecompressError> =
            recompressor.compress_into(src.reborrow(), dst.reborrow_mut());

        assert_eq!(
            result.unwrap_err(),
            RecompressError::DimensionMismatch { src: 10, dst: 15 }
        );
    }

    #[test]
    fn test_constant_value_vector() {
        let dim = 30;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );
        let recompressor = Recompressor;

        let constant_value = 42.5f32;
        let vector = vec![constant_value; dim];

        // Compress to 8 bits
        let mut encoded_8 = Data::<8>::new_boxed(dim);
        quantizer
            .compress_into(&*vector, encoded_8.reborrow_mut())
            .unwrap();

        // Recompress to 4 bits
        let mut encoded_4 = Data::<4>::new_boxed(dim);
        recompressor
            .compress_into(encoded_8.reborrow(), encoded_4.reborrow_mut())
            .unwrap();

        // For constant value, all codes should be the same
        let first_code = encoded_4.vector().get(0).unwrap();
        for i in 1..dim {
            assert_eq!(
                encoded_4.vector().get(i).unwrap(),
                first_code,
                "All codes should be identical for constant-value vector"
            );
        }

        // Reconstruction should be close to original
        let reconstructed = reconstruct(encoded_4.reborrow());
        for &val in &reconstructed {
            assert!(
                (val - constant_value).abs() < 1.0,
                "Reconstructed value {} should be close to original {}",
                val,
                constant_value
            );
        }
    }
}