diskann_quantization/minmax/
recompress.rs1use super::vectors::{DataMutRef, DataRef, MinMaxCompensation};
7use crate::bits::{Representation, Unsigned};
8use crate::scalar::bit_scale;
9use crate::CompressInto;
10use thiserror::Error;
11
/// Stateless re-quantizer that narrows an already MinMax-quantized vector to a
/// smaller bit width via [`CompressInto`]; the supported width pairs are the
/// `impl_recompress!` invocations below.
#[derive(Debug, Clone, Copy)]
pub struct Recompressor;
46
/// Errors that can occur while recompressing a quantized vector.
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
pub enum RecompressError {
    /// Source and destination do not have the same number of dimensions, so an
    /// element-wise recompression is impossible.
    #[error("dimension mismatch: source has {src} dimensions, destination has {dst}")]
    DimensionMismatch {
        /// Number of dimensions in the source (wider) encoding.
        src: usize,
        /// Number of dimensions in the destination (narrower) encoding.
        dst: usize,
    },
}
59
/// Implements `CompressInto<DataRef<'_, N>, DataMutRef<'_, M>>` for
/// [`Recompressor`] for a single `$n -> $m` bit-width pair by delegating to
/// `recompress_kernel`.
macro_rules! impl_recompress {
    ($n:literal -> $m:literal) => {
        impl<'a, 'b> CompressInto<DataRef<'a, $n>, DataMutRef<'b, $m>> for Recompressor
        where
            Unsigned: Representation<$n> + Representation<$m>,
        {
            type Error = RecompressError;
            type Output = ();

            fn compress_into(
                &self,
                from: DataRef<'a, $n>,
                to: DataMutRef<'b, $m>,
            ) -> Result<(), Self::Error> {
                recompress_kernel::<$n, $m>(from, to)
            }
        }
    };
}

// The supported narrowing conversions. Each pair must satisfy the `N > M` and
// `M > 1` compile-time assertions inside `recompress_kernel`.
impl_recompress!(8 -> 4);
impl_recompress!(8 -> 2);
impl_recompress!(4 -> 2);
84
/// Re-quantizes an `N`-bit MinMax-encoded vector into an `M`-bit one
/// (`N > M`), rescaling every stored code and rebuilding the destination's
/// compensation metadata (`a`, `b`, `n`, `norm_squared`) in the same pass so
/// that affine reconstruction (`code * a + b`) of the narrow encoding stays
/// consistent with the source.
///
/// # Errors
///
/// Returns [`RecompressError::DimensionMismatch`] when `from` and `to` do not
/// hold the same number of dimensions.
#[inline(always)]
fn recompress_kernel<const N: usize, const M: usize>(
    from: DataRef<'_, N>,
    mut to: DataMutRef<'_, M>,
) -> Result<(), RecompressError>
where
    Unsigned: Representation<N> + Representation<M>,
{
    // Narrowing-only kernel: both width constraints are enforced at compile time.
    const { assert!(N > M, "source bit width must exceed target bits") };
    const { assert!(M > 1, "target bit width must exceed 1") };

    let dim = from.len();
    if dim != to.vector().len() {
        return Err(RecompressError::DimensionMismatch {
            src: dim,
            dst: to.vector().len(),
        });
    }

    // Source affine decoding parameters: value ~= code * a + b.
    let src_meta = from.meta();
    let src_a = src_meta.a;
    let src_b = src_meta.b;

    // Ratio mapping N-bit codes onto the M-bit code range; `a` is scaled
    // inversely below so decoded values are (approximately) preserved.
    // NOTE(review): assumes `bit_scale` is defined so `old_code * code_scale`
    // never exceeds the M-bit code range (and hence the `as u8` cast below) —
    // confirm against `bit_scale`'s definition.
    let scale_n = bit_scale::<N>();
    let scale_m = bit_scale::<M>();
    let code_scale = scale_m / scale_n;

    let new_a = src_a / code_scale;
    let new_b = src_b; // The offset is unaffected by rescaling the codes.

    let from_vec = from.vector();
    let mut to_vec = to.vector_mut();

    // Accumulated while writing codes so the metadata is filled in one pass.
    let mut code_sum: f32 = 0.0;
    let mut norm_squared: f32 = 0.0;

    for i in 0..dim {
        // SAFETY: `i < dim == from_vec.len()`.
        let old_code = unsafe { from_vec.get_unchecked(i) };
        let old_code_f = old_code as f32;

        // Round-to-nearest-even avoids the systematic upward bias of plain
        // round-half-up across the code range.
        let new_code_pre = (old_code_f * code_scale).round_ties_even();
        let new_code = new_code_pre as u8;

        // SAFETY: `i < dim`, and `to.vector().len() == dim` was checked above.
        unsafe { to_vec.set_unchecked(i, new_code) };

        let new_code_f = new_code as f32;
        code_sum += new_code_f;

        // Value as reconstructed under the *new* encoding, used for the norm.
        let v_m = new_code_f * new_a + new_b;

        norm_squared += v_m * v_m;
    }

    to.set_meta(MinMaxCompensation {
        dim: dim as u32,
        b: new_b,
        a: new_a,
        n: new_a * code_sum, // a * sum(codes); tests recover the sum as n / a.
        norm_squared,
    });

    Ok(())
}
189
190#[cfg(test)]
191mod recompress_tests {
192 use std::num::NonZeroUsize;
193
194 use diskann_utils::{Reborrow, ReborrowMut};
195 use rand::{
196 distr::{Distribution, Uniform},
197 rngs::StdRng,
198 SeedableRng,
199 };
200
201 use super::*;
202 use crate::{
203 algorithms::{transforms::NullTransform, Transform},
204 minmax::quantizer::MinMaxQuantizer,
205 minmax::vectors::Data,
206 num::Positive,
207 };
208
209 fn reconstruct<const NBITS: usize>(v: DataRef<'_, NBITS>) -> Vec<f32>
211 where
212 Unsigned: Representation<NBITS>,
213 {
214 let meta = v.meta();
215 (0..v.len())
216 .map(|i| v.vector().get(i).unwrap() as f32 * meta.a + meta.b)
217 .collect()
218 }
219
    /// End-to-end property check for one `N -> M` recompression on a random
    /// vector of length `dim`:
    /// 1. the dimension is preserved in the output metadata,
    /// 2. `meta.n / meta.a` recovers the sum of the new codes,
    /// 3. `meta.norm_squared` matches the reconstructed vector's norm, and
    /// 4. recompressing an N-bit encoding is element-wise close to quantizing
    ///    the original floats directly at M bits.
    fn test_recompress_random<const N: usize, const M: usize>(dim: usize, rng: &mut StdRng)
    where
        Unsigned: Representation<N> + Representation<M>,
        MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, N>>
            + for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, M>>,
        Recompressor: for<'a, 'b> CompressInto<DataRef<'a, N>, DataMutRef<'b, M>, Output = ()>,
    {
        let distribution = Uniform::new_inclusive::<f32, f32>(-1.0, 1.0).unwrap();
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );
        let recompressor = Recompressor;

        // Quantize a random vector at N bits, then narrow it to M bits.
        let vector: Vec<f32> = distribution.sample_iter(rng).take(dim).collect();
        let mut encoded_n = Data::<N>::new_boxed(dim);
        quantizer
            .compress_into(&*vector, encoded_n.reborrow_mut())
            .unwrap();

        let mut encoded_m = Data::<M>::new_boxed(dim);
        recompressor
            .compress_into(encoded_n.reborrow(), encoded_m.reborrow_mut())
            .unwrap();

        let meta_m = encoded_m.meta();

        assert_eq!(meta_m.dim as usize, dim, "Dimension should be preserved");

        // The kernel stores `n = a * sum(codes)`, so `n / a` must reproduce the
        // code sum actually present in the output vector.
        let expected_code_sum: f32 = (0..dim)
            .map(|i| encoded_m.vector().get(i).unwrap() as f32)
            .sum();
        let computed_code_sum = meta_m.n / meta_m.a;
        assert!(
            (computed_code_sum - expected_code_sum).abs() < 1e-4,
            "Code sum mismatch: expected {}, got {}",
            expected_code_sum,
            computed_code_sum
        );

        // The stored norm must match the norm of the decoded M-bit vector.
        let reconstructed_m = reconstruct(encoded_m.reborrow());
        let expected_norm_sq: f32 = reconstructed_m.iter().map(|x| x * x).sum();
        assert!(
            (meta_m.norm_squared - expected_norm_sq).abs() < 1e-4,
            "norm_squared mismatch: expected {}, got {}",
            expected_norm_sq,
            meta_m.norm_squared
        );

        // Recompression should be close to quantizing the floats at M bits
        // directly (same tolerance as the metadata checks above).
        let mut direct_m = Data::<M>::new_boxed(dim);
        quantizer
            .compress_into(&*vector, direct_m.reborrow_mut())
            .unwrap();

        let reconstructed_direct_m = reconstruct(direct_m.reborrow());
        reconstructed_direct_m
            .iter()
            .zip(reconstructed_m.iter())
            .for_each(|(x, y)| {
                assert!(
                    (*x - *y).abs() < 1e-4,
                    "Direct compression and recompress vectors are not close"
                )
            });
    }
295
    // Keep the randomized sweeps cheap under Miri, which runs orders of
    // magnitude slower than a native test build.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const TRIALS: usize = 2;
            const MAX_DIM: usize = 20;
        } else {
            const TRIALS: usize = 10;
            const MAX_DIM: usize = 100;
        }
    }
305
    /// Generates a `#[test]` that sweeps dimensions `10..=MAX_DIM`, running
    /// `TRIALS` randomized recompressions per dimension for one `$n -> $m`
    /// bit-width pair.
    macro_rules! test_recompress_pair {
        ($name:ident, $n:literal -> $m:literal, $seed:literal) => {
            #[test]
            fn $name() {
                // Fixed seed keeps the randomized sweep deterministic.
                let mut rng = StdRng::seed_from_u64($seed);
                for dim in 10..=MAX_DIM {
                    for _ in 0..TRIALS {
                        test_recompress_random::<$n, $m>(dim, &mut rng);
                    }
                }
            }
        };
    }

    // One test per supported narrowing pair (mirrors the `impl_recompress!` list).
    test_recompress_pair!(recompress_8_to_4, 8 -> 4, 0xabc123def456);
    test_recompress_pair!(recompress_8_to_2, 8 -> 2, 0xdef456abc123);
    test_recompress_pair!(recompress_4_to_2, 4 -> 2, 0x456def123abc);
323
324 #[test]
325 fn test_dimension_mismatch_error() {
326 let recompressor = Recompressor;
327
328 let mut src = Data::<8>::new_boxed(10);
329 src.set_meta(MinMaxCompensation {
330 dim: 10,
331 b: 0.0,
332 a: 1.0,
333 n: 0.0,
334 norm_squared: 0.0,
335 });
336
337 let mut dst = Data::<4>::new_boxed(15); let result: Result<(), RecompressError> =
340 recompressor.compress_into(src.reborrow(), dst.reborrow_mut());
341
342 assert_eq!(
343 result.unwrap_err(),
344 RecompressError::DimensionMismatch { src: 10, dst: 15 }
345 );
346 }
347
348 #[test]
349 fn test_constant_value_vector() {
350 let dim = 30;
351 let quantizer = MinMaxQuantizer::new(
352 Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
353 Positive::new(1.0).unwrap(),
354 );
355 let recompressor = Recompressor;
356
357 let constant_value = 42.5f32;
358 let vector = vec![constant_value; dim];
359
360 let mut encoded_8 = Data::<8>::new_boxed(dim);
362 quantizer
363 .compress_into(&*vector, encoded_8.reborrow_mut())
364 .unwrap();
365
366 let mut encoded_4 = Data::<4>::new_boxed(dim);
368 recompressor
369 .compress_into(encoded_8.reborrow(), encoded_4.reborrow_mut())
370 .unwrap();
371
372 let first_code = encoded_4.vector().get(0).unwrap();
374 for i in 1..dim {
375 assert_eq!(
376 encoded_4.vector().get(i).unwrap(),
377 first_code,
378 "All codes should be identical for constant-value vector"
379 );
380 }
381
382 let reconstructed = reconstruct(encoded_4.reborrow());
384 for &val in &reconstructed {
385 assert!(
386 (val - constant_value).abs() < 1.0,
387 "Reconstructed value {} should be close to original {}",
388 val,
389 constant_value
390 );
391 }
392 }
393}