next_plaid/codec.rs

//! Residual codec for quantization and decompression

use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};

use crate::error::{Error, Result};

/// Maximum memory (bytes) to allocate for nearest centroid computation in `compress_into_codes`.
/// This limits the size of the [batch_size, num_centroids] scores matrix to prevent OOM errors
/// with large centroid counts (e.g., 2.5M centroids).
const MAX_NEAREST_CENTROID_MEMORY: usize = 4 * 1024 * 1024 * 1024; // 4GB

/// Storage backend for centroids, supporting both owned arrays and memory-mapped files.
///
/// This enum enables `ResidualCodec` to work with centroids stored either:
/// - In memory as an owned `Array2<f32>` (default, for `Index` and `LoadedIndex`)
/// - Memory-mapped from disk (for `MmapIndex`, reducing RAM usage)
pub enum CentroidStore {
    /// Centroids stored as an owned ndarray (loaded into RAM)
    Owned(Array2<f32>),
    /// Centroids stored as a memory-mapped NPY file (OS-managed paging)
    Mmap(crate::mmap::MmapNpyArray2F32),
}

impl CentroidStore {
    /// Get a view of the centroids as ArrayView2.
    ///
    /// This is zero-copy for both owned and mmap variants.
    pub fn view(&self) -> ArrayView2<'_, f32> {
        match self {
            CentroidStore::Owned(arr) => arr.view(),
            CentroidStore::Mmap(mmap) => mmap.view(),
        }
    }

    /// Get the number of centroids (rows).
    pub fn nrows(&self) -> usize {
        match self {
            CentroidStore::Owned(arr) => arr.nrows(),
            CentroidStore::Mmap(mmap) => mmap.nrows(),
        }
    }

    /// Get the embedding dimension (columns).
    pub fn ncols(&self) -> usize {
        match self {
            CentroidStore::Owned(arr) => arr.ncols(),
            CentroidStore::Mmap(mmap) => mmap.ncols(),
        }
    }

    /// Get a view of a single centroid row.
    pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
        match self {
            CentroidStore::Owned(arr) => arr.row(idx),
            CentroidStore::Mmap(mmap) => mmap.row(idx),
        }
    }

    /// Get a view of rows [start..end] as ArrayView2.
    ///
    /// This is zero-copy for both owned and mmap variants.
    pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, f32> {
        match self {
            CentroidStore::Owned(arr) => arr.slice(s![start..end, ..]),
            CentroidStore::Mmap(mmap) => mmap.slice_rows(start, end),
        }
    }
}

impl Clone for CentroidStore {
    fn clone(&self) -> Self {
        match self {
            // For owned, just clone the array
            CentroidStore::Owned(arr) => CentroidStore::Owned(arr.clone()),
            // For mmap, we need to convert to owned since Mmap isn't Clone
            CentroidStore::Mmap(mmap) => CentroidStore::Owned(mmap.to_owned()),
        }
    }
}

/// A codec that manages quantization parameters and lookup tables for the index.
///
/// This struct contains all tensors required to compress embeddings during indexing
/// and decompress vectors during search. It uses pre-computed lookup tables to
/// accelerate bit unpacking operations.
#[derive(Clone)]
pub struct ResidualCodec {
    /// Number of bits used to represent each residual bucket (e.g., 2 or 4)
    pub nbits: usize,
    /// Coarse centroids (codebook) of shape `[num_centroids, dim]`.
    /// Can be either owned (in-memory) or memory-mapped for reduced RAM usage.
    pub centroids: CentroidStore,
    /// Average residual vector, used to reduce reconstruction error
    pub avg_residual: Array1<f32>,
    /// Boundaries defining which bucket a residual value falls into
    pub bucket_cutoffs: Option<Array1<f32>>,
    /// Values (weights) corresponding to each quantization bucket
    pub bucket_weights: Option<Array1<f32>>,
    /// Lookup table (256 entries) for byte-to-bits unpacking
    pub byte_reversed_bits_map: Vec<u8>,
    /// Maps byte values directly to bucket indices for fast decompression
    pub bucket_weight_indices_lookup: Option<Array2<usize>>,
}

impl ResidualCodec {
    /// Creates a new ResidualCodec and pre-computes lookup tables.
    ///
    /// # Arguments
    ///
    /// * `nbits` - Number of bits per code (e.g., 2 bits = 4 buckets)
    /// * `centroids` - Coarse centroids of shape `[num_centroids, dim]`
    /// * `avg_residual` - Global average residual
    /// * `bucket_cutoffs` - Quantization boundaries (optional, for indexing)
    /// * `bucket_weights` - Reconstruction values (optional, for search)
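    ///
    /// A minimal construction sketch with toy values (module path assumed,
    /// hence `ignore`; mirrors `test_codec_creation` below):
    ///
    /// ```ignore
    /// use ndarray::{Array1, Array2};
    ///
    /// let centroids = Array2::<f32>::zeros((4, 8));
    /// let codec = ResidualCodec::new(2, centroids, Array1::zeros(8), None, None).unwrap();
    /// assert_eq!(codec.num_centroids(), 4);
    /// assert_eq!(codec.embedding_dim(), 8);
    /// ```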
    pub fn new(
        nbits: usize,
        centroids: Array2<f32>,
        avg_residual: Array1<f32>,
        bucket_cutoffs: Option<Array1<f32>>,
        bucket_weights: Option<Array1<f32>>,
    ) -> Result<Self> {
        Self::new_with_store(
            nbits,
            CentroidStore::Owned(centroids),
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
        )
    }

    /// Creates a new ResidualCodec with a specified centroid storage backend.
    ///
    /// This is the internal constructor that supports both owned and mmap centroids.
    pub fn new_with_store(
        nbits: usize,
        centroids: CentroidStore,
        avg_residual: Array1<f32>,
        bucket_cutoffs: Option<Array1<f32>>,
        bucket_weights: Option<Array1<f32>>,
    ) -> Result<Self> {
        if nbits == 0 || 8 % nbits != 0 {
            return Err(Error::Codec(format!(
                "nbits must be a divisor of 8, got {}",
                nbits
            )));
        }

        // Build bit reversal map for unpacking
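        // For nbits = 2, a byte is treated as four 2-bit codes and each code's
        // bits are reversed in place, e.g. 0b01_00_00_00 -> 0b10_00_00_00.
        // This undoes the LSB-first bit order written by `quantize_residuals`.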
        let nbits_mask = (1u32 << nbits) - 1;
        let mut byte_reversed_bits_map = vec![0u8; 256];

        for (i, byte_slot) in byte_reversed_bits_map.iter_mut().enumerate() {
            let val = i as u32;
            let mut out = 0u32;
            let mut pos = 8i32;

            while pos >= nbits as i32 {
                let segment = (val >> (pos as u32 - nbits as u32)) & nbits_mask;

                let mut rev_segment = 0u32;
                for k in 0..nbits {
                    if (segment & (1 << k)) != 0 {
                        rev_segment |= 1 << (nbits - 1 - k);
                    }
                }

                out |= rev_segment;

                if pos > nbits as i32 {
                    out <<= nbits;
                }

                pos -= nbits as i32;
            }
            *byte_slot = out as u8;
        }

        // Build lookup table for bucket weight indices
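        // Each table row lists a byte's nbits-wide groups, most significant
        // group first; e.g. for nbits = 2, the row for 0b10_01_00_11 is [2, 1, 0, 3].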
        let keys_per_byte = 8 / nbits;
        let bucket_weight_indices_lookup = if bucket_weights.is_some() {
            let mask = (1usize << nbits) - 1;
            let mut table = Array2::<usize>::zeros((256, keys_per_byte));

            for byte_val in 0..256usize {
                for k in (0..keys_per_byte).rev() {
                    let shift = k * nbits;
                    let index = (byte_val >> shift) & mask;
                    table[[byte_val, keys_per_byte - 1 - k]] = index;
                }
            }
            Some(table)
        } else {
            None
        };

        Ok(Self {
            nbits,
            centroids,
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
            byte_reversed_bits_map,
            bucket_weight_indices_lookup,
        })
    }

    /// Returns the embedding dimension
    pub fn embedding_dim(&self) -> usize {
        self.centroids.ncols()
    }

    /// Returns the number of centroids
    pub fn num_centroids(&self) -> usize {
        self.centroids.nrows()
    }

    /// Returns a view of the centroids.
    ///
    /// This is zero-copy for both owned and mmap centroids.
    pub fn centroids_view(&self) -> ArrayView2<'_, f32> {
        self.centroids.view()
    }

    /// Compress embeddings into centroid codes using nearest neighbor search.
    ///
    /// Uses batch matrix multiplication for efficiency:
    /// `scores = embeddings @ centroids.T  -> [N, K]`
    /// `codes = argmax(scores, axis=1)     -> [N]`
    ///
    /// When the `cuda` feature is enabled and a GPU is available, this function
    /// automatically uses CUDA acceleration. No code changes required.
    ///
    /// # Arguments
    ///
    /// * `embeddings` - Embeddings of shape `[N, dim]`
    ///
    /// # Returns
    ///
    /// Centroid indices of shape `[N]`
    pub fn compress_into_codes(&self, embeddings: &Array2<f32>) -> Array1<usize> {
        // Try CUDA acceleration if available
        #[cfg(feature = "cuda")]
        {
            if let Some(ctx) = crate::cuda::get_global_context() {
                let centroids = self.centroids_view();
                match crate::cuda::compress_into_codes_cuda_batched(
                    ctx,
                    &embeddings.view(),
                    &centroids,
                    None,
                ) {
                    Ok(codes) => return codes,
                    Err(e) => {
                        eprintln!(
                            "[next-plaid] CUDA compress_into_codes failed: {}, falling back to CPU",
                            e
                        );
                    }
                }
            }
        }

        self.compress_into_codes_cpu(embeddings)
    }

    /// CPU implementation of compress_into_codes.
    /// This is useful when you want to explicitly avoid CUDA overhead for small batches.
    pub fn compress_into_codes_cpu(&self, embeddings: &Array2<f32>) -> Array1<usize> {
        use rayon::prelude::*;

        let n = embeddings.nrows();
        if n == 0 {
            return Array1::zeros(0);
        }

        // Get centroids view once (zero-copy for both owned and mmap)
        let centroids = self.centroids_view();
        let num_centroids = centroids.nrows();

        // Dynamic batch size to stay within memory budget.
        // The scores matrix has shape [batch_size, num_centroids] with f32 elements.
        // With 2.5M centroids and 4GB budget: batch_size = 4GB / (2.5M * 4) = 400
        let max_batch_by_memory =
            MAX_NEAREST_CENTROID_MEMORY / (num_centroids * std::mem::size_of::<f32>());
        let batch_size = max_batch_by_memory.clamp(1, 2048);

        let mut all_codes = Vec::with_capacity(n);

        for start in (0..n).step_by(batch_size) {
            let end = (start + batch_size).min(n);
            let batch = embeddings.slice(ndarray::s![start..end, ..]);

            // Batch matrix multiplication: [batch, dim] @ [dim, K] -> [batch, K]
            let scores = batch.dot(&centroids.t());

            // Parallel argmax over each row
            let batch_codes: Vec<usize> = scores
                .axis_iter(Axis(0))
                .into_par_iter()
                .map(|row| {
                    row.iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                        .map(|(idx, _)| idx)
                        .unwrap_or(0)
                })
                .collect();

            all_codes.extend(batch_codes);
        }

        Array1::from_vec(all_codes)
    }

    /// Quantize residuals into packed bytes.
    ///
    /// Uses a per-value bucket search and parallel row processing for efficiency.
    ///
    /// # Arguments
    ///
    /// * `residuals` - Residual vectors of shape `[N, dim]`
    ///
    /// # Returns
    ///
    /// Packed residuals of shape `[N, dim * nbits / 8]` as bytes
    pub fn quantize_residuals(&self, residuals: &Array2<f32>) -> Result<Array2<u8>> {
        use rayon::prelude::*;

        let cutoffs = self
            .bucket_cutoffs
            .as_ref()
            .ok_or_else(|| Error::Codec("bucket_cutoffs required for quantization".into()))?;

        let n = residuals.nrows();
        let dim = residuals.ncols();
        let packed_dim = dim * self.nbits / 8;
        let nbits = self.nbits;

        if n == 0 {
            return Ok(Array2::zeros((0, packed_dim)));
        }

        // Convert cutoffs to a slice for faster access
        let cutoffs_slice = cutoffs.as_slice().unwrap();

        // Process rows in parallel
        let packed_rows: Vec<Vec<u8>> = residuals
            .axis_iter(Axis(0))
            .into_par_iter()
            .map(|row| {
                let mut packed_row = vec![0u8; packed_dim];
                let mut bit_idx = 0;

                for &val in row.iter() {
                    // Count how many cutoffs the value exceeds (equivalent to
                    // searchsorted; the cutoff array is tiny, so a linear scan is fine)
                    let bucket = cutoffs_slice.iter().filter(|&&c| val > c).count();

                    // Pack bits directly into bytes
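                    // The code's least significant bit is written first, starting
                    // at the byte's most significant bit; `byte_reversed_bits_map`
                    // undoes this per-code reversal during decompression.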
                    for b in 0..nbits {
                        let bit = ((bucket >> b) & 1) as u8;
                        let byte_idx = bit_idx / 8;
                        let bit_pos = 7 - (bit_idx % 8);
                        packed_row[byte_idx] |= bit << bit_pos;
                        bit_idx += 1;
                    }
                }

                packed_row
            })
            .collect();

        // Assemble into array
        let mut packed = Array2::<u8>::zeros((n, packed_dim));
        for (i, row) in packed_rows.into_iter().enumerate() {
            for (j, val) in row.into_iter().enumerate() {
                packed[[i, j]] = val;
            }
        }

        Ok(packed)
    }

    /// Decompress residuals from packed bytes using lookup tables.
    ///
    /// # Arguments
    ///
    /// * `packed_residuals` - Packed residuals of shape `[N, packed_dim]`
    /// * `codes` - Centroid codes of shape `[N]`
    ///
    /// # Returns
    ///
    /// Reconstructed embeddings of shape `[N, dim]`
    pub fn decompress(
        &self,
        packed_residuals: &Array2<u8>,
        codes: &ArrayView1<usize>,
    ) -> Result<Array2<f32>> {
        let bucket_weights = self
            .bucket_weights
            .as_ref()
            .ok_or_else(|| Error::Codec("bucket_weights required for decompression".into()))?;

        let lookup = self
            .bucket_weight_indices_lookup
            .as_ref()
            .ok_or_else(|| Error::Codec("bucket_weight_indices_lookup required".into()))?;

        let n = packed_residuals.nrows();
        let dim = self.embedding_dim();

        let mut output = Array2::<f32>::zeros((n, dim));

        for i in 0..n {
            // Get centroid for this embedding (zero-copy via CentroidStore)
            let centroid = self.centroids.row(codes[i]);

            // Unpack residuals
            let mut residual_idx = 0;
            for &byte_val in packed_residuals.row(i).iter() {
                let reversed = self.byte_reversed_bits_map[byte_val as usize];
                let indices = lookup.row(reversed as usize);

                for &bucket_idx in indices.iter() {
                    if residual_idx < dim {
                        output[[i, residual_idx]] =
                            centroid[residual_idx] + bucket_weights[bucket_idx];
                        residual_idx += 1;
                    }
                }
            }
        }

        // Normalize
        for mut row in output.axis_iter_mut(Axis(0)) {
            let norm = row.dot(&row).sqrt().max(1e-12);
            row /= norm;
        }

        Ok(output)
    }

    /// Load codec from index directory
    pub fn load_from_dir(index_path: &std::path::Path) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;
        use std::fs::File;

        let centroids_path = index_path.join("centroids.npy");
        let centroids: Array2<f32> = Array2::read_npy(
            File::open(&centroids_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open centroids.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read centroids.npy: {}", e)))?;

        let avg_residual_path = index_path.join("avg_residual.npy");
        let avg_residual: Array1<f32> =
            Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
                Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
            })?)
            .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;

        let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
        let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        let bucket_weights_path = index_path.join("bucket_weights.npy");
        let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        // Read nbits from metadata
        let metadata_path = index_path.join("metadata.json");
        let metadata: serde_json::Value = serde_json::from_reader(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;

        let nbits = metadata["nbits"]
            .as_u64()
            .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
            as usize;

        Self::new(
            nbits,
            centroids,
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
        )
    }

    /// Load codec from index directory with memory-mapped centroids.
    ///
    /// This is similar to `load_from_dir` but uses memory-mapped I/O for the
    /// centroids file, reducing RAM usage. The other small tensors (bucket weights,
    /// etc.) are still loaded into memory as they are negligible in size.
    ///
    /// Use this when loading for `MmapIndex` to minimize memory footprint.
    pub fn load_mmap_from_dir(index_path: &std::path::Path) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;
        use std::fs::File;

        // Memory-map centroids instead of loading into RAM
        let centroids_path = index_path.join("centroids.npy");
        let mmap_centroids = crate::mmap::MmapNpyArray2F32::from_npy_file(&centroids_path)?;

        // Load small tensors into memory (negligible size)
        let avg_residual_path = index_path.join("avg_residual.npy");
        let avg_residual: Array1<f32> =
            Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
                Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
            })?)
            .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;

        let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
        let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        let bucket_weights_path = index_path.join("bucket_weights.npy");
        let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        // Read nbits from metadata
        let metadata_path = index_path.join("metadata.json");
        let metadata: serde_json::Value = serde_json::from_reader(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;

        let nbits = metadata["nbits"]
            .as_u64()
            .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
            as usize;

        Self::new_with_store(
            nbits,
            CentroidStore::Mmap(mmap_centroids),
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_codec_creation() {
        let centroids =
            Array2::from_shape_vec((4, 8), (0..32).map(|x| x as f32).collect()).unwrap();
        let avg_residual = Array1::zeros(8);
        let bucket_cutoffs = Some(Array1::from_vec(vec![-0.5, 0.0, 0.5]));
        let bucket_weights = Some(Array1::from_vec(vec![-0.75, -0.25, 0.25, 0.75]));

        let codec = ResidualCodec::new(2, centroids, avg_residual, bucket_cutoffs, bucket_weights);
        assert!(codec.is_ok());

        let codec = codec.unwrap();
        assert_eq!(codec.nbits, 2);
        assert_eq!(codec.embedding_dim(), 8);
        assert_eq!(codec.num_centroids(), 4);
    }
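
    // A small added sanity check for the `CentroidStore` accessors, using toy
    // data; it asserts only behavior visible in the enum's implementation above.
    #[test]
    fn test_centroid_store_owned_accessors() {
        let arr = Array2::from_shape_vec((3, 2), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let store = CentroidStore::Owned(arr.clone());

        assert_eq!(store.nrows(), 3);
        assert_eq!(store.ncols(), 2);
        assert_eq!(store.view(), arr.view());
        assert_eq!(store.row(1), arr.row(1));
        // slice_rows uses a half-open range, like s![start..end, ..]
        assert_eq!(store.slice_rows(1, 3), arr.slice(s![1..3, ..]));
    }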

    #[test]
    fn test_compress_into_codes() {
        let centroids = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 0.0, 0.0, 0.0, // centroid 0
                0.0, 1.0, 0.0, 0.0, // centroid 1
                0.0, 0.0, 1.0, 0.0, // centroid 2
            ],
        )
        .unwrap();

        let avg_residual = Array1::zeros(4);
        let codec = ResidualCodec::new(2, centroids, avg_residual, None, None).unwrap();

        let embeddings = Array2::from_shape_vec(
            (2, 4),
            vec![
                0.9, 0.1, 0.0, 0.0, // should match centroid 0
                0.0, 0.0, 0.95, 0.05, // should match centroid 2
            ],
        )
        .unwrap();

        let codes = codec.compress_into_codes(&embeddings);
        assert_eq!(codes[0], 0);
        assert_eq!(codes[1], 2);
    }
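
    // An added sanity check of `byte_reversed_bits_map`: for nbits = 2, each
    // 2-bit group of a byte has its bits reversed in place (expected values
    // derived by hand from the construction loop in `new_with_store`).
    #[test]
    fn test_byte_reversed_bits_map_2bit() {
        let codec =
            ResidualCodec::new(2, Array2::zeros((1, 4)), Array1::zeros(4), None, None).unwrap();

        assert_eq!(codec.byte_reversed_bits_map[0b0000_0000], 0b0000_0000);
        // Top group 01 -> 10, all other groups are zero
        assert_eq!(codec.byte_reversed_bits_map[0b0100_0000], 0b1000_0000);
        // Groups 11, 10, 01, 00 -> 11, 01, 10, 00
        assert_eq!(codec.byte_reversed_bits_map[0b1110_0100], 0b1101_1000);
    }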

    #[test]
    fn test_quantize_decompress_roundtrip_4bit() {
        // Test round-trip with 4-bit quantization
        let dim = 8;
        let centroids = Array2::zeros((4, dim));
        let avg_residual = Array1::zeros(dim);

        // Create bucket cutoffs and weights for 16 buckets
        // Cutoffs at the 1/16, 2/16, ..., 15/16 quantiles of a uniform (-1, 1) range
        let bucket_cutoffs: Vec<f32> = (1..16).map(|i| (i as f32 / 16.0 - 0.5) * 2.0).collect();
        // Weights at quantile midpoints
        let bucket_weights: Vec<f32> = (0..16)
            .map(|i| ((i as f32 + 0.5) / 16.0 - 0.5) * 2.0)
            .collect();

        let codec = ResidualCodec::new(
            4,
            centroids,
            avg_residual,
            Some(Array1::from_vec(bucket_cutoffs)),
            Some(Array1::from_vec(bucket_weights)),
        )
        .unwrap();

        // Create test residuals that span different bucket ranges
        let residuals = Array2::from_shape_vec(
            (2, dim),
            vec![
                -0.9, -0.7, -0.5, -0.3, 0.0, 0.3, 0.5, 0.9, // various bucket values
                -0.8, -0.4, 0.0, 0.4, 0.8, -0.6, 0.2, 0.6,
            ],
        )
        .unwrap();

        // Quantize
        let packed = codec.quantize_residuals(&residuals).unwrap();
        assert_eq!(packed.ncols(), dim * 4 / 8); // 4 bytes per row for dim=8, nbits=4

        // Create a temporary centroid assignment (all zeros)
        let codes = Array1::from_vec(vec![0, 0]);

        // Decompress and verify the reconstruction is reasonable
        let decompressed = codec.decompress(&packed, &codes.view()).unwrap();

        // The decompressed values should be close to the quantized bucket weights
        // (plus centroid, which is zero here)
        for i in 0..residuals.nrows() {
            for j in 0..residuals.ncols() {
                let orig = residuals[[i, j]];
                let recon = decompressed[[i, j]];
                // After normalization, values should be in similar direction
                // The reconstruction won't be exact due to quantization, but
                // the sign should generally match for non-zero values
                if orig.abs() > 0.2 {
                    assert!(
                        (orig > 0.0) == (recon > 0.0) || recon.abs() < 0.1,
                        "Sign mismatch at [{}, {}]: orig={}, recon={}",
                        i,
                        j,
                        orig,
                        recon
                    );
                }
            }
        }
    }
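
    // An added check of the packing layout: each bucket code is written
    // LSB-first starting at the byte's most significant bit (expected byte
    // derived by hand from the packing loop in `quantize_residuals`).
    #[test]
    fn test_quantize_bit_packing_4bit() {
        let codec = ResidualCodec::new(
            4,
            Array2::zeros((1, 2)),
            Array1::zeros(2),
            Some(Array1::from_vec(vec![0.0])),
            None,
        )
        .unwrap();

        // bucket(0.5) = 1 (above the single cutoff), bucket(-0.5) = 0
        let residuals = Array2::from_shape_vec((1, 2), vec![0.5, -0.5]).unwrap();
        let packed = codec.quantize_residuals(&residuals).unwrap();

        // Code 0b0001 is written LSB-first into bits 7..4, code 0b0000 into bits 3..0
        assert_eq!(packed.shape(), &[1, 1]);
        assert_eq!(packed[[0, 0]], 0b1000_0000);
    }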
}