Skip to main content

next_plaid/
codec.rs

1//! Residual codec for quantization and decompression
2
3use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
4
5use crate::error::{Error, Result};
6
/// Default cap (1 GiB) on the memory allocated for the nearest-centroid
/// scores matrix in `compress_into_codes`. The matrix has shape
/// `[batch_size, num_centroids]`; a smaller cap reduces page-fault and
/// zero-fill overhead from giant temporary score buffers.
const DEFAULT_MAX_NEAREST_CENTROID_MEMORY: usize = 1024 * 1024 * 1024; // 1GB

/// Resolve the score-matrix memory budget, in bytes.
///
/// Honors the `NEXT_PLAID_MAX_NEAREST_CENTROID_MEMORY_MB` environment
/// variable (a positive integer number of megabytes); any unset, unparsable,
/// or zero value falls back to [`DEFAULT_MAX_NEAREST_CENTROID_MEMORY`].
fn max_nearest_centroid_memory() -> usize {
    match std::env::var("NEXT_PLAID_MAX_NEAREST_CENTROID_MEMORY_MB") {
        Ok(raw) => match raw.parse::<usize>() {
            // Saturating multiply: an absurdly large setting clamps to
            // usize::MAX instead of overflowing.
            Ok(mb) if mb > 0 => mb.saturating_mul(1024 * 1024),
            _ => DEFAULT_MAX_NEAREST_CENTROID_MEMORY,
        },
        Err(_) => DEFAULT_MAX_NEAREST_CENTROID_MEMORY,
    }
}
21
/// Storage backend for centroids, supporting both owned arrays and memory-mapped files.
///
/// This enum enables `ResidualCodec` to work with centroids stored either:
/// - In memory as an owned `Array2<f32>` (default, for `Index` and `LoadedIndex`)
/// - Memory-mapped from disk (for `MmapIndex`, reducing RAM usage)
pub enum CentroidStore {
    /// Centroids stored as an owned ndarray (loaded into RAM).
    Owned(Array2<f32>),
    /// Centroids stored as a memory-mapped NPY file (OS-managed paging);
    /// only the pages actually touched occupy physical memory.
    Mmap(crate::mmap::MmapNpyArray2F32),
}
33
34impl CentroidStore {
35    /// Get a view of the centroids as ArrayView2.
36    ///
37    /// This is zero-copy for both owned and mmap variants.
38    pub fn view(&self) -> ArrayView2<'_, f32> {
39        match self {
40            CentroidStore::Owned(arr) => arr.view(),
41            CentroidStore::Mmap(mmap) => mmap.view(),
42        }
43    }
44
45    /// Get the number of centroids (rows).
46    pub fn nrows(&self) -> usize {
47        match self {
48            CentroidStore::Owned(arr) => arr.nrows(),
49            CentroidStore::Mmap(mmap) => mmap.nrows(),
50        }
51    }
52
53    /// Get the embedding dimension (columns).
54    pub fn ncols(&self) -> usize {
55        match self {
56            CentroidStore::Owned(arr) => arr.ncols(),
57            CentroidStore::Mmap(mmap) => mmap.ncols(),
58        }
59    }
60
61    /// Get a view of a single centroid row.
62    pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
63        match self {
64            CentroidStore::Owned(arr) => arr.row(idx),
65            CentroidStore::Mmap(mmap) => mmap.row(idx),
66        }
67    }
68
69    /// Get a view of rows [start..end] as ArrayView2.
70    ///
71    /// This is zero-copy for both owned and mmap variants.
72    pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, f32> {
73        match self {
74            CentroidStore::Owned(arr) => arr.slice(s![start..end, ..]),
75            CentroidStore::Mmap(mmap) => mmap.slice_rows(start, end),
76        }
77    }
78}
79
80impl Clone for CentroidStore {
81    fn clone(&self) -> Self {
82        match self {
83            // For owned, just clone the array
84            CentroidStore::Owned(arr) => CentroidStore::Owned(arr.clone()),
85            // For mmap, we need to convert to owned since Mmap isn't Clone
86            CentroidStore::Mmap(mmap) => CentroidStore::Owned(mmap.to_owned()),
87        }
88    }
89}
90
/// A codec that manages quantization parameters and lookup tables for the index.
///
/// This struct contains all tensors required to compress embeddings during indexing
/// and decompress vectors during search. It uses pre-computed lookup tables to
/// accelerate bit unpacking operations.
#[derive(Clone)]
pub struct ResidualCodec {
    /// Number of bits used to represent each residual bucket (e.g., 2 or 4).
    /// Must be a divisor of 8 so codes pack evenly into bytes.
    pub nbits: usize,
    /// Coarse centroids (codebook) of shape `[num_centroids, dim]`.
    /// Can be either owned (in-memory) or memory-mapped for reduced RAM usage.
    pub centroids: CentroidStore,
    /// Average residual vector, used to reduce reconstruction error
    pub avg_residual: Array1<f32>,
    /// Boundaries defining which bucket a residual value falls into
    /// (optional; required only for quantization during indexing)
    pub bucket_cutoffs: Option<Array1<f32>>,
    /// Values (weights) corresponding to each quantization bucket
    /// (optional; required only for decompression during search)
    pub bucket_weights: Option<Array1<f32>>,
    /// Lookup table (256 entries) for byte-to-bits unpacking: maps each byte
    /// to the byte with bit order reversed within each `nbits`-wide segment
    pub byte_reversed_bits_map: Vec<u8>,
    /// Maps byte values directly to bucket indices for fast decompression;
    /// built only when `bucket_weights` is present
    pub bucket_weight_indices_lookup: Option<Array2<usize>>,
}
114
115impl ResidualCodec {
    /// Creates a new ResidualCodec and pre-computes lookup tables.
    ///
    /// Convenience wrapper around `new_with_store` that takes ownership of an
    /// in-memory centroid matrix.
    ///
    /// # Arguments
    ///
    /// * `nbits` - Number of bits per code (e.g., 2 bits = 4 buckets)
    /// * `centroids` - Coarse centroids of shape `[num_centroids, dim]`
    /// * `avg_residual` - Global average residual
    /// * `bucket_cutoffs` - Quantization boundaries (optional, for indexing)
    /// * `bucket_weights` - Reconstruction values (optional, for search)
    ///
    /// # Errors
    ///
    /// Returns `Error::Codec` if `nbits` is zero or not a divisor of 8.
    pub fn new(
        nbits: usize,
        centroids: Array2<f32>,
        avg_residual: Array1<f32>,
        bucket_cutoffs: Option<Array1<f32>>,
        bucket_weights: Option<Array1<f32>>,
    ) -> Result<Self> {
        Self::new_with_store(
            nbits,
            CentroidStore::Owned(centroids),
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
        )
    }
140
    /// Creates a new ResidualCodec with a specified centroid storage backend.
    ///
    /// This is the internal constructor that supports both owned and mmap centroids.
    /// Besides validating `nbits`, it precomputes two lookup tables:
    ///
    /// * `byte_reversed_bits_map` — for each byte value, the byte with the bit
    ///   order reversed *within* each `nbits`-wide segment (segment positions
    ///   are unchanged).
    /// * `bucket_weight_indices_lookup` — for each byte value, the `8 / nbits`
    ///   bucket indices it encodes, ordered from the most-significant segment
    ///   to the least-significant one. Only built when `bucket_weights` is
    ///   present (i.e. the codec will be used for decompression).
    ///
    /// # Errors
    ///
    /// Returns `Error::Codec` if `nbits` is zero or not a divisor of 8.
    pub fn new_with_store(
        nbits: usize,
        centroids: CentroidStore,
        avg_residual: Array1<f32>,
        bucket_cutoffs: Option<Array1<f32>>,
        bucket_weights: Option<Array1<f32>>,
    ) -> Result<Self> {
        if nbits == 0 || 8 % nbits != 0 {
            return Err(Error::Codec(format!(
                "nbits must be a divisor of 8, got {}",
                nbits
            )));
        }

        // Build bit reversal map for unpacking: walk each byte MSB-first in
        // nbits-wide segments and reverse the bit order inside every segment.
        let nbits_mask = (1u32 << nbits) - 1;
        let mut byte_reversed_bits_map = vec![0u8; 256];

        for (i, byte_slot) in byte_reversed_bits_map.iter_mut().enumerate() {
            let val = i as u32;
            let mut out = 0u32;
            let mut pos = 8i32;

            while pos >= nbits as i32 {
                // Extract the next segment, starting from the most significant.
                let segment = (val >> (pos as u32 - nbits as u32)) & nbits_mask;

                // Reverse the bit order within this segment.
                let mut rev_segment = 0u32;
                for k in 0..nbits {
                    if (segment & (1 << k)) != 0 {
                        rev_segment |= 1 << (nbits - 1 - k);
                    }
                }

                out |= rev_segment;

                // Shift to make room for the next segment (skipped after the
                // last one so segment positions are preserved overall).
                if pos > nbits as i32 {
                    out <<= nbits;
                }

                pos -= nbits as i32;
            }
            *byte_slot = out as u8;
        }

        // Build lookup table for bucket weight indices: row = byte value,
        // columns = the bucket index encoded by each segment, ordered from
        // the most significant segment to the least significant.
        let keys_per_byte = 8 / nbits;
        let bucket_weight_indices_lookup = if bucket_weights.is_some() {
            let mask = (1usize << nbits) - 1;
            let mut table = Array2::<usize>::zeros((256, keys_per_byte));

            for byte_val in 0..256usize {
                for k in (0..keys_per_byte).rev() {
                    let shift = k * nbits;
                    let index = (byte_val >> shift) & mask;
                    table[[byte_val, keys_per_byte - 1 - k]] = index;
                }
            }
            Some(table)
        } else {
            None
        };

        Ok(Self {
            nbits,
            centroids,
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
            byte_reversed_bits_map,
            bucket_weight_indices_lookup,
        })
    }
216
    /// Returns the embedding dimension (columns of the centroid matrix)
    pub fn embedding_dim(&self) -> usize {
        self.centroids.ncols()
    }

    /// Returns the number of centroids (rows of the centroid matrix)
    pub fn num_centroids(&self) -> usize {
        self.centroids.nrows()
    }

    /// Returns a view of the centroids.
    ///
    /// This is zero-copy for both owned and mmap centroids.
    pub fn centroids_view(&self) -> ArrayView2<'_, f32> {
        self.centroids.view()
    }
233
    /// Compress embeddings into centroid codes using nearest neighbor search.
    ///
    /// Uses batch matrix multiplication for efficiency:
    /// `scores = embeddings @ centroids.T  -> [N, K]`
    /// `codes = argmax(scores, axis=1)     -> [N]`
    ///
    /// When the `cuda` feature is enabled and a GPU is available, this function
    /// automatically uses CUDA acceleration. No code changes required. A CUDA
    /// failure falls back to the CPU path, unless force-GPU mode is set, in
    /// which case it panics instead of silently degrading.
    ///
    /// # Arguments
    ///
    /// * `embeddings` - Embeddings of shape `[N, dim]`
    ///
    /// # Returns
    ///
    /// Centroid indices of shape `[N]`
    ///
    /// # Panics
    ///
    /// With the `cuda` feature, panics if `crate::is_force_gpu()` is true and
    /// either the CUDA context is unavailable or the CUDA kernel fails.
    pub fn compress_into_codes(&self, embeddings: &Array2<f32>) -> Array1<usize> {
        // Try CUDA acceleration if available
        #[cfg(feature = "cuda")]
        {
            let force_gpu = crate::is_force_gpu();
            if let Some(ctx) = crate::cuda::get_global_context() {
                let centroids = self.centroids_view();
                match crate::cuda::compress_into_codes_cuda_batched(
                    &ctx,
                    &embeddings.view(),
                    &centroids,
                    None,
                ) {
                    Ok(codes) => return codes,
                    Err(e) => {
                        if force_gpu {
                            panic!(
                                "FORCE_GPU is set but CUDA compress_into_codes failed: {}",
                                e
                            );
                        }
                        // Best-effort mode: report the failure and fall through
                        // to the CPU implementation below.
                        eprintln!(
                            "[next-plaid] CUDA compression error: {}. Falling back to CPU.",
                            e
                        );
                    }
                }
            } else if force_gpu {
                panic!("FORCE_GPU is set but CUDA context is unavailable");
            }
        }

        self.compress_into_codes_cpu(embeddings)
    }
284
285    /// CPU implementation of compress_into_codes.
286    /// This is useful when you want to explicitly avoid CUDA overhead for small batches.
287    pub fn compress_into_codes_cpu(&self, embeddings: &Array2<f32>) -> Array1<usize> {
288        use rayon::prelude::*;
289
290        let n = embeddings.nrows();
291        if n == 0 {
292            return Array1::zeros(0);
293        }
294
295        // Get centroids view once (zero-copy for both owned and mmap)
296        let centroids = self.centroids_view();
297        let num_centroids = centroids.nrows();
298
299        // Dynamic batch size to stay within memory budget.
300        // The scores matrix has shape [batch_size, num_centroids] with f32 elements.
301        // With 2.5M centroids and 4GB budget: batch_size = 4GB / (2.5M * 4) = 400
302        let max_batch_by_memory =
303            max_nearest_centroid_memory() / (num_centroids * std::mem::size_of::<f32>());
304        let batch_size = max_batch_by_memory.clamp(1, 1024);
305        let batch_ranges: Vec<(usize, usize)> = (0..n)
306            .step_by(batch_size)
307            .map(|start| (start, (start + batch_size).min(n)))
308            .collect();
309
310        let chunked_codes: Vec<Vec<usize>> = batch_ranges
311            .into_par_iter()
312            .map(|(start, end)| {
313                let batch = embeddings.slice(ndarray::s![start..end, ..]);
314
315                // Batch matrix multiplication: [batch, dim] @ [dim, K] -> [batch, K]
316                let scores = batch.dot(&centroids.t());
317
318                // Keep the per-row scan local to avoid nested parallelism.
319                scores
320                    .axis_iter(Axis(0))
321                    .map(|row| {
322                        row.iter()
323                            .enumerate()
324                            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
325                            .map(|(idx, _)| idx)
326                            .unwrap_or(0)
327                    })
328                    .collect()
329            })
330            .collect();
331
332        Array1::from_vec(chunked_codes.into_iter().flatten().collect())
333    }
334
335    /// Quantize residuals into packed bytes.
336    ///
337    /// Uses vectorized bucket search and parallel processing for efficiency.
338    ///
339    /// # Arguments
340    ///
341    /// * `residuals` - Residual vectors of shape `[N, dim]`
342    ///
343    /// # Returns
344    ///
345    /// Packed residuals of shape `[N, dim * nbits / 8]` as bytes
346    pub fn quantize_residuals(&self, residuals: &Array2<f32>) -> Result<Array2<u8>> {
347        use rayon::prelude::*;
348
349        let cutoffs = self
350            .bucket_cutoffs
351            .as_ref()
352            .ok_or_else(|| Error::Codec("bucket_cutoffs required for quantization".into()))?;
353
354        let n = residuals.nrows();
355        let dim = residuals.ncols();
356        let packed_dim = dim * self.nbits / 8;
357        let nbits = self.nbits;
358
359        if n == 0 {
360            return Ok(Array2::zeros((0, packed_dim)));
361        }
362
363        // Convert cutoffs to a slice for faster access
364        let cutoffs_slice = cutoffs.as_slice().unwrap();
365
366        // Process rows in parallel
367        let packed_rows: Vec<Vec<u8>> = residuals
368            .axis_iter(Axis(0))
369            .into_par_iter()
370            .map(|row| {
371                let mut packed_row = vec![0u8; packed_dim];
372                let mut bit_idx = 0;
373
374                for &val in row.iter() {
375                    // Binary search for bucket (searchsorted equivalent)
376                    let bucket = cutoffs_slice.iter().filter(|&&c| val > c).count();
377
378                    // Pack bits directly into bytes
379                    for b in 0..nbits {
380                        let bit = ((bucket >> b) & 1) as u8;
381                        let byte_idx = bit_idx / 8;
382                        let bit_pos = 7 - (bit_idx % 8);
383                        packed_row[byte_idx] |= bit << bit_pos;
384                        bit_idx += 1;
385                    }
386                }
387
388                packed_row
389            })
390            .collect();
391
392        // Assemble into array
393        let mut packed = Array2::<u8>::zeros((n, packed_dim));
394        for (i, row) in packed_rows.into_iter().enumerate() {
395            for (j, val) in row.into_iter().enumerate() {
396                packed[[i, j]] = val;
397            }
398        }
399
400        Ok(packed)
401    }
402
    /// Decompress residuals from packed bytes using lookup tables.
    ///
    /// For each embedding, the centroid selected by `codes[i]` is added to the
    /// reconstructed bucket weights, and the result is L2-normalized.
    ///
    /// Bit-order note: `quantize_residuals` emits each bucket's bits LSB-first
    /// into an MSB-first bit stream, so every `nbits`-wide segment of a packed
    /// byte holds its bucket index bit-reversed. `byte_reversed_bits_map`
    /// undoes that per-segment reversal, after which
    /// `bucket_weight_indices_lookup` splits the corrected byte into bucket
    /// indices (most significant segment first).
    ///
    /// # Arguments
    ///
    /// * `packed_residuals` - Packed residuals of shape `[N, packed_dim]`
    /// * `codes` - Centroid codes of shape `[N]`; each entry is assumed to be
    ///   a valid centroid row index (out-of-range values panic)
    ///
    /// # Returns
    ///
    /// Reconstructed, L2-normalized embeddings of shape `[N, dim]`
    ///
    /// # Errors
    ///
    /// Returns `Error::Codec` if the codec was built without `bucket_weights`
    /// (and therefore without the bucket-index lookup table).
    pub fn decompress(
        &self,
        packed_residuals: &Array2<u8>,
        codes: &ArrayView1<usize>,
    ) -> Result<Array2<f32>> {
        let bucket_weights = self
            .bucket_weights
            .as_ref()
            .ok_or_else(|| Error::Codec("bucket_weights required for decompression".into()))?;

        let lookup = self
            .bucket_weight_indices_lookup
            .as_ref()
            .ok_or_else(|| Error::Codec("bucket_weight_indices_lookup required".into()))?;

        let n = packed_residuals.nrows();
        let dim = self.embedding_dim();

        let mut output = Array2::<f32>::zeros((n, dim));

        for i in 0..n {
            // Get centroid for this embedding (zero-copy via CentroidStore)
            let centroid = self.centroids.row(codes[i]);

            // Unpack residuals byte by byte.
            let mut residual_idx = 0;
            for &byte_val in packed_residuals.row(i).iter() {
                // Undo the per-segment bit reversal introduced by packing,
                // then look up the bucket index of every segment at once.
                let reversed = self.byte_reversed_bits_map[byte_val as usize];
                let indices = lookup.row(reversed as usize);

                for &bucket_idx in indices.iter() {
                    // Guard against padding segments in the final byte when
                    // the packed row carries more slots than `dim`.
                    if residual_idx < dim {
                        output[[i, residual_idx]] =
                            centroid[residual_idx] + bucket_weights[bucket_idx];
                        residual_idx += 1;
                    }
                }
            }
        }

        // Normalize each row to unit L2 norm; the 1e-12 floor avoids a
        // divide-by-zero for an all-zero reconstruction.
        for mut row in output.axis_iter_mut(Axis(0)) {
            let norm = row.dot(&row).sqrt().max(1e-12);
            row /= norm;
        }

        Ok(output)
    }
461
    /// Load codec from index directory.
    ///
    /// Reads `centroids.npy`, `avg_residual.npy`, the optional
    /// `bucket_cutoffs.npy` / `bucket_weights.npy`, and `nbits` from
    /// `metadata.json`, loading everything fully into RAM. For a
    /// memory-mapped alternative see `load_mmap_from_dir`.
    ///
    /// # Errors
    ///
    /// Returns `Error::IndexLoad` if a required file is missing or cannot be
    /// parsed, and `Error::Codec` if the stored `nbits` is invalid.
    pub fn load_from_dir(index_path: &std::path::Path) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;
        use std::fs::File;

        let centroids_path = index_path.join("centroids.npy");
        let centroids: Array2<f32> = Array2::read_npy(
            File::open(&centroids_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open centroids.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read centroids.npy: {}", e)))?;

        let avg_residual_path = index_path.join("avg_residual.npy");
        let avg_residual: Array1<f32> =
            Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
                Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
            })?)
            .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;

        // Optional: only present when the index was built for quantization.
        let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
        let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        // Optional: only present when the index supports decompression.
        let bucket_weights_path = index_path.join("bucket_weights.npy");
        let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        // Read nbits from metadata
        let metadata_path = index_path.join("metadata.json");
        let metadata: serde_json::Value = serde_json::from_reader(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;

        let nbits = metadata["nbits"]
            .as_u64()
            .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
            as usize;

        Self::new(
            nbits,
            centroids,
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
        )
    }
530
    /// Load codec from index directory with memory-mapped centroids.
    ///
    /// This is similar to `load_from_dir` but uses memory-mapped I/O for the
    /// centroids file, reducing RAM usage. The other small tensors (bucket weights,
    /// etc.) are still loaded into memory as they are negligible in size.
    ///
    /// Use this when loading for `MmapIndex` to minimize memory footprint.
    ///
    /// # Errors
    ///
    /// Returns `Error::IndexLoad` if a required file is missing or cannot be
    /// parsed, and `Error::Codec` if the stored `nbits` is invalid.
    pub fn load_mmap_from_dir(index_path: &std::path::Path) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;
        use std::fs::File;

        // Memory-map centroids instead of loading into RAM
        let centroids_path = index_path.join("centroids.npy");
        let mmap_centroids = crate::mmap::MmapNpyArray2F32::from_npy_file(&centroids_path)?;

        // Load small tensors into memory (negligible size)
        let avg_residual_path = index_path.join("avg_residual.npy");
        let avg_residual: Array1<f32> =
            Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
                Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
            })?)
            .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;

        // Optional: only present when the index was built for quantization.
        let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
        let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        // Optional: only present when the index supports decompression.
        let bucket_weights_path = index_path.join("bucket_weights.npy");
        let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        // Read nbits from metadata
        let metadata_path = index_path.join("metadata.json");
        let metadata: serde_json::Value = serde_json::from_reader(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;

        let nbits = metadata["nbits"]
            .as_u64()
            .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
            as usize;

        Self::new_with_store(
            nbits,
            CentroidStore::Mmap(mmap_centroids),
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
        )
    }
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction succeeds for a valid nbits and reports the expected
    /// dimensions.
    #[test]
    fn test_codec_creation() {
        let centroids =
            Array2::from_shape_vec((4, 8), (0..32).map(|x| x as f32).collect()).unwrap();
        let avg_residual = Array1::zeros(8);
        let bucket_cutoffs = Some(Array1::from_vec(vec![-0.5, 0.0, 0.5]));
        let bucket_weights = Some(Array1::from_vec(vec![-0.75, -0.25, 0.25, 0.75]));

        let codec = ResidualCodec::new(2, centroids, avg_residual, bucket_cutoffs, bucket_weights);
        assert!(codec.is_ok());

        let codec = codec.unwrap();
        assert_eq!(codec.nbits, 2);
        assert_eq!(codec.embedding_dim(), 8);
        assert_eq!(codec.num_centroids(), 4);
    }

    /// Each embedding maps to the centroid with the highest dot-product score.
    #[test]
    fn test_compress_into_codes() {
        let centroids = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 0.0, 0.0, 0.0, // centroid 0
                0.0, 1.0, 0.0, 0.0, // centroid 1
                0.0, 0.0, 1.0, 0.0, // centroid 2
            ],
        )
        .unwrap();

        let avg_residual = Array1::zeros(4);
        let codec = ResidualCodec::new(2, centroids, avg_residual, None, None).unwrap();

        let embeddings = Array2::from_shape_vec(
            (2, 4),
            vec![
                0.9, 0.1, 0.0, 0.0, // should match centroid 0
                0.0, 0.0, 0.95, 0.05, // should match centroid 2
            ],
        )
        .unwrap();

        let codes = codec.compress_into_codes(&embeddings);
        assert_eq!(codes[0], 0);
        assert_eq!(codes[1], 2);
    }

    /// Quantize then decompress with 4-bit codes; the reconstruction should
    /// roughly preserve the sign/direction of each component (exact equality
    /// is impossible because of quantization plus final L2 normalization).
    #[test]
    fn test_quantize_decompress_roundtrip_4bit() {
        // Test round-trip with 4-bit quantization
        let dim = 8;
        let centroids = Array2::zeros((4, dim));
        let avg_residual = Array1::zeros(dim);

        // Create bucket cutoffs and weights for 16 buckets
        // Cutoffs at quantiles 1/16, 2/16, ..., 15/16
        let bucket_cutoffs: Vec<f32> = (1..16).map(|i| (i as f32 / 16.0 - 0.5) * 2.0).collect();
        // Weights at quantile midpoints
        let bucket_weights: Vec<f32> = (0..16)
            .map(|i| ((i as f32 + 0.5) / 16.0 - 0.5) * 2.0)
            .collect();

        let codec = ResidualCodec::new(
            4,
            centroids,
            avg_residual,
            Some(Array1::from_vec(bucket_cutoffs)),
            Some(Array1::from_vec(bucket_weights)),
        )
        .unwrap();

        // Create test residuals that span different bucket ranges
        let residuals = Array2::from_shape_vec(
            (2, dim),
            vec![
                -0.9, -0.7, -0.5, -0.3, 0.0, 0.3, 0.5, 0.9, // various bucket values
                -0.8, -0.4, 0.0, 0.4, 0.8, -0.6, 0.2, 0.6,
            ],
        )
        .unwrap();

        // Quantize
        let packed = codec.quantize_residuals(&residuals).unwrap();
        assert_eq!(packed.ncols(), dim * 4 / 8); // 4 bytes per row for dim=8, nbits=4

        // Create a temporary centroid assignment (all zeros)
        let codes = Array1::from_vec(vec![0, 0]);

        // Decompress and verify the reconstruction is reasonable
        let decompressed = codec.decompress(&packed, &codes.view()).unwrap();

        // The decompressed values should be close to the quantized bucket weights
        // (plus centroid, which is zero here)
        for i in 0..residuals.nrows() {
            for j in 0..residuals.ncols() {
                let orig = residuals[[i, j]];
                let recon = decompressed[[i, j]];
                // After normalization, values should be in similar direction
                // The reconstruction won't be exact due to quantization, but
                // the sign should generally match for non-zero values
                if orig.abs() > 0.2 {
                    assert!(
                        (orig > 0.0) == (recon > 0.0) || recon.abs() < 0.1,
                        "Sign mismatch at [{}, {}]: orig={}, recon={}",
                        i,
                        j,
                        orig,
                        recon
                    );
                }
            }
        }
    }
}