next_plaid/
codec.rs

1//! Residual codec for quantization and decompression
2
3use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
4
5use crate::error::{Error, Result};
6
/// Maximum memory (bytes) to allocate for nearest centroid computation in `compress_into_codes`.
/// This limits the size of the [batch_size, num_centroids] scores matrix to prevent OOM errors
/// with large centroid counts (e.g., 2.5M centroids).
/// The batch size derived from this budget is additionally clamped to [1, 2048]
/// in `compress_into_codes`, so raising this constant only helps up to that cap.
const MAX_NEAREST_CENTROID_MEMORY: usize = 4 * 1024 * 1024 * 1024; // 4GB
11
/// Storage backend for centroids, supporting both owned arrays and memory-mapped files.
///
/// This enum enables `ResidualCodec` to work with centroids stored either:
/// - In memory as an owned `Array2<f32>` (default, for `Index` and `LoadedIndex`)
/// - Memory-mapped from disk (for `MmapIndex`, reducing RAM usage)
///
/// Note: cloning always produces an `Owned` store — a `Mmap` store is
/// materialized into RAM on clone (see the `Clone` impl below).
pub enum CentroidStore {
    /// Centroids stored as an owned ndarray (loaded into RAM)
    Owned(Array2<f32>),
    /// Centroids stored as a memory-mapped NPY file (OS-managed paging)
    Mmap(crate::mmap::MmapNpyArray2F32),
}
23
24impl CentroidStore {
25    /// Get a view of the centroids as ArrayView2.
26    ///
27    /// This is zero-copy for both owned and mmap variants.
28    pub fn view(&self) -> ArrayView2<'_, f32> {
29        match self {
30            CentroidStore::Owned(arr) => arr.view(),
31            CentroidStore::Mmap(mmap) => mmap.view(),
32        }
33    }
34
35    /// Get the number of centroids (rows).
36    pub fn nrows(&self) -> usize {
37        match self {
38            CentroidStore::Owned(arr) => arr.nrows(),
39            CentroidStore::Mmap(mmap) => mmap.nrows(),
40        }
41    }
42
43    /// Get the embedding dimension (columns).
44    pub fn ncols(&self) -> usize {
45        match self {
46            CentroidStore::Owned(arr) => arr.ncols(),
47            CentroidStore::Mmap(mmap) => mmap.ncols(),
48        }
49    }
50
51    /// Get a view of a single centroid row.
52    pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
53        match self {
54            CentroidStore::Owned(arr) => arr.row(idx),
55            CentroidStore::Mmap(mmap) => mmap.row(idx),
56        }
57    }
58
59    /// Get a view of rows [start..end] as ArrayView2.
60    ///
61    /// This is zero-copy for both owned and mmap variants.
62    pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, f32> {
63        match self {
64            CentroidStore::Owned(arr) => arr.slice(s![start..end, ..]),
65            CentroidStore::Mmap(mmap) => mmap.slice_rows(start, end),
66        }
67    }
68}
69
70impl Clone for CentroidStore {
71    fn clone(&self) -> Self {
72        match self {
73            // For owned, just clone the array
74            CentroidStore::Owned(arr) => CentroidStore::Owned(arr.clone()),
75            // For mmap, we need to convert to owned since Mmap isn't Clone
76            CentroidStore::Mmap(mmap) => CentroidStore::Owned(mmap.to_owned()),
77        }
78    }
79}
80
/// A codec that manages quantization parameters and lookup tables for the index.
///
/// This struct contains all tensors required to compress embeddings during indexing
/// and decompress vectors during search. It uses pre-computed lookup tables to
/// accelerate bit unpacking operations.
#[derive(Clone)]
pub struct ResidualCodec {
    /// Number of bits used to represent each residual bucket (e.g., 2 or 4).
    /// Must be a divisor of 8 so packed codes tile bytes exactly.
    pub nbits: usize,
    /// Coarse centroids (codebook) of shape `[num_centroids, dim]`.
    /// Can be either owned (in-memory) or memory-mapped for reduced RAM usage.
    pub centroids: CentroidStore,
    /// Average residual vector, used to reduce reconstruction error
    pub avg_residual: Array1<f32>,
    /// Boundaries defining which bucket a residual value falls into
    /// (present when the codec is used for indexing/quantization)
    pub bucket_cutoffs: Option<Array1<f32>>,
    /// Values (weights) corresponding to each quantization bucket
    /// (present when the codec is used for search/decompression)
    pub bucket_weights: Option<Array1<f32>>,
    /// Lookup table (256 entries) for byte-to-bits unpacking
    pub byte_reversed_bits_map: Vec<u8>,
    /// Maps byte values directly to bucket indices for fast decompression.
    /// Shape `[256, 8 / nbits]`; built only when `bucket_weights` is present.
    pub bucket_weight_indices_lookup: Option<Array2<usize>>,
}
104
105impl ResidualCodec {
106    /// Creates a new ResidualCodec and pre-computes lookup tables.
107    ///
108    /// # Arguments
109    ///
110    /// * `nbits` - Number of bits per code (e.g., 2 bits = 4 buckets)
111    /// * `centroids` - Coarse centroids of shape `[num_centroids, dim]`
112    /// * `avg_residual` - Global average residual
113    /// * `bucket_cutoffs` - Quantization boundaries (optional, for indexing)
114    /// * `bucket_weights` - Reconstruction values (optional, for search)
115    pub fn new(
116        nbits: usize,
117        centroids: Array2<f32>,
118        avg_residual: Array1<f32>,
119        bucket_cutoffs: Option<Array1<f32>>,
120        bucket_weights: Option<Array1<f32>>,
121    ) -> Result<Self> {
122        Self::new_with_store(
123            nbits,
124            CentroidStore::Owned(centroids),
125            avg_residual,
126            bucket_cutoffs,
127            bucket_weights,
128        )
129    }
130
    /// Creates a new ResidualCodec with a specified centroid storage backend.
    ///
    /// This is the internal constructor that supports both owned and mmap centroids.
    ///
    /// Besides validating `nbits`, it pre-computes two lookup tables:
    /// - `byte_reversed_bits_map`: for each of the 256 byte values, the bit
    ///   order *within* each `nbits`-wide segment is reversed while segment
    ///   order is preserved (for `nbits == 8` this reverses the whole byte).
    ///   This undoes the LSB-first bit packing done by `quantize_residuals`.
    /// - `bucket_weight_indices_lookup` (only when `bucket_weights` is
    ///   present): row `b` lists the `8 / nbits` bucket indices packed in
    ///   byte value `b`, ordered from the most-significant segment down.
    ///
    /// # Errors
    ///
    /// Returns `Error::Codec` if `nbits` is zero or does not divide 8.
    pub fn new_with_store(
        nbits: usize,
        centroids: CentroidStore,
        avg_residual: Array1<f32>,
        bucket_cutoffs: Option<Array1<f32>>,
        bucket_weights: Option<Array1<f32>>,
    ) -> Result<Self> {
        // Packed codes must tile bytes exactly, so only nbits in {1, 2, 4, 8}.
        if nbits == 0 || 8 % nbits != 0 {
            return Err(Error::Codec(format!(
                "nbits must be a divisor of 8, got {}",
                nbits
            )));
        }

        // Build bit reversal map for unpacking
        let nbits_mask = (1u32 << nbits) - 1;
        let mut byte_reversed_bits_map = vec![0u8; 256];

        for (i, byte_slot) in byte_reversed_bits_map.iter_mut().enumerate() {
            let val = i as u32;
            let mut out = 0u32;
            // `pos` counts not-yet-consumed bits; segments are taken from the
            // most-significant end of the byte downward.
            let mut pos = 8i32;

            while pos >= nbits as i32 {
                let segment = (val >> (pos as u32 - nbits as u32)) & nbits_mask;

                // Reverse the bit order within this nbits-wide segment.
                let mut rev_segment = 0u32;
                for k in 0..nbits {
                    if (segment & (1 << k)) != 0 {
                        rev_segment |= 1 << (nbits - 1 - k);
                    }
                }

                out |= rev_segment;

                // Make room for the next segment — but not after the last one,
                // which would shift the result out of place.
                if pos > nbits as i32 {
                    out <<= nbits;
                }

                pos -= nbits as i32;
            }
            *byte_slot = out as u8;
        }

        // Build lookup table for bucket weight indices
        let keys_per_byte = 8 / nbits;
        let bucket_weight_indices_lookup = if bucket_weights.is_some() {
            let mask = (1usize << nbits) - 1;
            let mut table = Array2::<usize>::zeros((256, keys_per_byte));

            for byte_val in 0..256usize {
                // Walk k from the highest shift to the lowest so that column 0
                // holds the high-order segment of the byte.
                for k in (0..keys_per_byte).rev() {
                    let shift = k * nbits;
                    let index = (byte_val >> shift) & mask;
                    table[[byte_val, keys_per_byte - 1 - k]] = index;
                }
            }
            Some(table)
        } else {
            None
        };

        Ok(Self {
            nbits,
            centroids,
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
            byte_reversed_bits_map,
            bucket_weight_indices_lookup,
        })
    }
206
207    /// Returns the embedding dimension
208    pub fn embedding_dim(&self) -> usize {
209        self.centroids.ncols()
210    }
211
212    /// Returns the number of centroids
213    pub fn num_centroids(&self) -> usize {
214        self.centroids.nrows()
215    }
216
    /// Returns a view of the centroids.
    ///
    /// This is zero-copy for both owned and mmap centroids; see
    /// `CentroidStore::view`.
    pub fn centroids_view(&self) -> ArrayView2<'_, f32> {
        self.centroids.view()
    }
223
224    /// Compress embeddings into centroid codes using nearest neighbor search.
225    ///
226    /// Uses batch matrix multiplication for efficiency:
227    /// `scores = embeddings @ centroids.T  -> [N, K]`
228    /// `codes = argmax(scores, axis=1)     -> [N]`
229    ///
230    /// # Arguments
231    ///
232    /// * `embeddings` - Embeddings of shape `[N, dim]`
233    ///
234    /// # Returns
235    ///
236    /// Centroid indices of shape `[N]`
237    pub fn compress_into_codes(&self, embeddings: &Array2<f32>) -> Array1<usize> {
238        use rayon::prelude::*;
239
240        let n = embeddings.nrows();
241        if n == 0 {
242            return Array1::zeros(0);
243        }
244
245        // Get centroids view once (zero-copy for both owned and mmap)
246        let centroids = self.centroids_view();
247        let num_centroids = centroids.nrows();
248
249        // Dynamic batch size to stay within memory budget.
250        // The scores matrix has shape [batch_size, num_centroids] with f32 elements.
251        // With 2.5M centroids and 4GB budget: batch_size = 4GB / (2.5M * 4) = 400
252        let max_batch_by_memory =
253            MAX_NEAREST_CENTROID_MEMORY / (num_centroids * std::mem::size_of::<f32>());
254        let batch_size = max_batch_by_memory.clamp(1, 2048);
255
256        let mut all_codes = Vec::with_capacity(n);
257
258        for start in (0..n).step_by(batch_size) {
259            let end = (start + batch_size).min(n);
260            let batch = embeddings.slice(ndarray::s![start..end, ..]);
261
262            // Batch matrix multiplication: [batch, dim] @ [dim, K] -> [batch, K]
263            let scores = batch.dot(&centroids.t());
264
265            // Parallel argmax over each row
266            let batch_codes: Vec<usize> = scores
267                .axis_iter(Axis(0))
268                .into_par_iter()
269                .map(|row| {
270                    row.iter()
271                        .enumerate()
272                        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
273                        .map(|(idx, _)| idx)
274                        .unwrap_or(0)
275                })
276                .collect();
277
278            all_codes.extend(batch_codes);
279        }
280
281        Array1::from_vec(all_codes)
282    }
283
284    /// Quantize residuals into packed bytes.
285    ///
286    /// Uses vectorized bucket search and parallel processing for efficiency.
287    ///
288    /// # Arguments
289    ///
290    /// * `residuals` - Residual vectors of shape `[N, dim]`
291    ///
292    /// # Returns
293    ///
294    /// Packed residuals of shape `[N, dim * nbits / 8]` as bytes
295    pub fn quantize_residuals(&self, residuals: &Array2<f32>) -> Result<Array2<u8>> {
296        use rayon::prelude::*;
297
298        let cutoffs = self
299            .bucket_cutoffs
300            .as_ref()
301            .ok_or_else(|| Error::Codec("bucket_cutoffs required for quantization".into()))?;
302
303        let n = residuals.nrows();
304        let dim = residuals.ncols();
305        let packed_dim = dim * self.nbits / 8;
306        let nbits = self.nbits;
307
308        if n == 0 {
309            return Ok(Array2::zeros((0, packed_dim)));
310        }
311
312        // Convert cutoffs to a slice for faster access
313        let cutoffs_slice = cutoffs.as_slice().unwrap();
314
315        // Process rows in parallel
316        let packed_rows: Vec<Vec<u8>> = residuals
317            .axis_iter(Axis(0))
318            .into_par_iter()
319            .map(|row| {
320                let mut packed_row = vec![0u8; packed_dim];
321                let mut bit_idx = 0;
322
323                for &val in row.iter() {
324                    // Binary search for bucket (searchsorted equivalent)
325                    let bucket = cutoffs_slice.iter().filter(|&&c| val > c).count();
326
327                    // Pack bits directly into bytes
328                    for b in 0..nbits {
329                        let bit = ((bucket >> b) & 1) as u8;
330                        let byte_idx = bit_idx / 8;
331                        let bit_pos = 7 - (bit_idx % 8);
332                        packed_row[byte_idx] |= bit << bit_pos;
333                        bit_idx += 1;
334                    }
335                }
336
337                packed_row
338            })
339            .collect();
340
341        // Assemble into array
342        let mut packed = Array2::<u8>::zeros((n, packed_dim));
343        for (i, row) in packed_rows.into_iter().enumerate() {
344            for (j, val) in row.into_iter().enumerate() {
345                packed[[i, j]] = val;
346            }
347        }
348
349        Ok(packed)
350    }
351
352    /// Decompress residuals from packed bytes using lookup tables.
353    ///
354    /// # Arguments
355    ///
356    /// * `packed_residuals` - Packed residuals of shape `[N, packed_dim]`
357    /// * `codes` - Centroid codes of shape `[N]`
358    ///
359    /// # Returns
360    ///
361    /// Reconstructed embeddings of shape `[N, dim]`
362    pub fn decompress(
363        &self,
364        packed_residuals: &Array2<u8>,
365        codes: &ArrayView1<usize>,
366    ) -> Result<Array2<f32>> {
367        let bucket_weights = self
368            .bucket_weights
369            .as_ref()
370            .ok_or_else(|| Error::Codec("bucket_weights required for decompression".into()))?;
371
372        let lookup = self
373            .bucket_weight_indices_lookup
374            .as_ref()
375            .ok_or_else(|| Error::Codec("bucket_weight_indices_lookup required".into()))?;
376
377        let n = packed_residuals.nrows();
378        let dim = self.embedding_dim();
379
380        let mut output = Array2::<f32>::zeros((n, dim));
381
382        for i in 0..n {
383            // Get centroid for this embedding (zero-copy via CentroidStore)
384            let centroid = self.centroids.row(codes[i]);
385
386            // Unpack residuals
387            let mut residual_idx = 0;
388            for &byte_val in packed_residuals.row(i).iter() {
389                let reversed = self.byte_reversed_bits_map[byte_val as usize];
390                let indices = lookup.row(reversed as usize);
391
392                for &bucket_idx in indices.iter() {
393                    if residual_idx < dim {
394                        output[[i, residual_idx]] =
395                            centroid[residual_idx] + bucket_weights[bucket_idx];
396                        residual_idx += 1;
397                    }
398                }
399            }
400        }
401
402        // Normalize
403        for mut row in output.axis_iter_mut(Axis(0)) {
404            let norm = row.dot(&row).sqrt().max(1e-12);
405            row /= norm;
406        }
407
408        Ok(output)
409    }
410
411    /// Load codec from index directory
412    pub fn load_from_dir(index_path: &std::path::Path) -> Result<Self> {
413        use ndarray_npy::ReadNpyExt;
414        use std::fs::File;
415
416        let centroids_path = index_path.join("centroids.npy");
417        let centroids: Array2<f32> = Array2::read_npy(
418            File::open(&centroids_path)
419                .map_err(|e| Error::IndexLoad(format!("Failed to open centroids.npy: {}", e)))?,
420        )
421        .map_err(|e| Error::IndexLoad(format!("Failed to read centroids.npy: {}", e)))?;
422
423        let avg_residual_path = index_path.join("avg_residual.npy");
424        let avg_residual: Array1<f32> =
425            Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
426                Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
427            })?)
428            .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;
429
430        let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
431        let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
432            Some(
433                Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
434                    Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
435                })?)
436                .map_err(|e| {
437                    Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
438                })?,
439            )
440        } else {
441            None
442        };
443
444        let bucket_weights_path = index_path.join("bucket_weights.npy");
445        let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
446            Some(
447                Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
448                    Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
449                })?)
450                .map_err(|e| {
451                    Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
452                })?,
453            )
454        } else {
455            None
456        };
457
458        // Read nbits from metadata
459        let metadata_path = index_path.join("metadata.json");
460        let metadata: serde_json::Value = serde_json::from_reader(
461            File::open(&metadata_path)
462                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
463        )
464        .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;
465
466        let nbits = metadata["nbits"]
467            .as_u64()
468            .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
469            as usize;
470
471        Self::new(
472            nbits,
473            centroids,
474            avg_residual,
475            bucket_cutoffs,
476            bucket_weights,
477        )
478    }
479
480    /// Load codec from index directory with memory-mapped centroids.
481    ///
482    /// This is similar to `load_from_dir` but uses memory-mapped I/O for the
483    /// centroids file, reducing RAM usage. The other small tensors (bucket weights,
484    /// etc.) are still loaded into memory as they are negligible in size.
485    ///
486    /// Use this when loading for `MmapIndex` to minimize memory footprint.
487    pub fn load_mmap_from_dir(index_path: &std::path::Path) -> Result<Self> {
488        use ndarray_npy::ReadNpyExt;
489        use std::fs::File;
490
491        // Memory-map centroids instead of loading into RAM
492        let centroids_path = index_path.join("centroids.npy");
493        let mmap_centroids = crate::mmap::MmapNpyArray2F32::from_npy_file(&centroids_path)?;
494
495        // Load small tensors into memory (negligible size)
496        let avg_residual_path = index_path.join("avg_residual.npy");
497        let avg_residual: Array1<f32> =
498            Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
499                Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
500            })?)
501            .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;
502
503        let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
504        let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
505            Some(
506                Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
507                    Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
508                })?)
509                .map_err(|e| {
510                    Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
511                })?,
512            )
513        } else {
514            None
515        };
516
517        let bucket_weights_path = index_path.join("bucket_weights.npy");
518        let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
519            Some(
520                Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
521                    Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
522                })?)
523                .map_err(|e| {
524                    Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
525                })?,
526            )
527        } else {
528            None
529        };
530
531        // Read nbits from metadata
532        let metadata_path = index_path.join("metadata.json");
533        let metadata: serde_json::Value = serde_json::from_reader(
534            File::open(&metadata_path)
535                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
536        )
537        .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;
538
539        let nbits = metadata["nbits"]
540            .as_u64()
541            .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
542            as usize;
543
544        Self::new_with_store(
545            nbits,
546            CentroidStore::Mmap(mmap_centroids),
547            avg_residual,
548            bucket_cutoffs,
549            bucket_weights,
550        )
551    }
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction with a valid 2-bit configuration succeeds and the basic
    /// shape accessors report the codebook dimensions.
    #[test]
    fn test_codec_creation() {
        let centroids =
            Array2::from_shape_vec((4, 8), (0..32).map(|x| x as f32).collect()).unwrap();
        let avg_residual = Array1::zeros(8);
        let bucket_cutoffs = Some(Array1::from_vec(vec![-0.5, 0.0, 0.5]));
        let bucket_weights = Some(Array1::from_vec(vec![-0.75, -0.25, 0.25, 0.75]));

        let codec = ResidualCodec::new(2, centroids, avg_residual, bucket_cutoffs, bucket_weights);
        assert!(codec.is_ok());

        let codec = codec.unwrap();
        assert_eq!(codec.nbits, 2);
        assert_eq!(codec.embedding_dim(), 8);
        assert_eq!(codec.num_centroids(), 4);
    }

    /// Nearest-centroid assignment picks the centroid with the highest dot
    /// product; one-hot centroids make the expected argmax unambiguous.
    #[test]
    fn test_compress_into_codes() {
        let centroids = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 0.0, 0.0, 0.0, // centroid 0
                0.0, 1.0, 0.0, 0.0, // centroid 1
                0.0, 0.0, 1.0, 0.0, // centroid 2
            ],
        )
        .unwrap();

        let avg_residual = Array1::zeros(4);
        let codec = ResidualCodec::new(2, centroids, avg_residual, None, None).unwrap();

        let embeddings = Array2::from_shape_vec(
            (2, 4),
            vec![
                0.9, 0.1, 0.0, 0.0, // should match centroid 0
                0.0, 0.0, 0.95, 0.05, // should match centroid 2
            ],
        )
        .unwrap();

        let codes = codec.compress_into_codes(&embeddings);
        assert_eq!(codes[0], 0);
        assert_eq!(codes[1], 2);
    }

    /// Quantize-then-decompress should roughly preserve the direction of each
    /// residual; exact values are lost to 4-bit quantization and the final
    /// L2 normalization, so only sign agreement is asserted.
    #[test]
    fn test_quantize_decompress_roundtrip_4bit() {
        // Test round-trip with 4-bit quantization
        let dim = 8;
        let centroids = Array2::zeros((4, dim));
        let avg_residual = Array1::zeros(dim);

        // Create bucket cutoffs and weights for 16 buckets
        // Cutoffs at quantiles 1/16, 2/16, ..., 15/16
        let bucket_cutoffs: Vec<f32> = (1..16).map(|i| (i as f32 / 16.0 - 0.5) * 2.0).collect();
        // Weights at quantile midpoints
        let bucket_weights: Vec<f32> = (0..16)
            .map(|i| ((i as f32 + 0.5) / 16.0 - 0.5) * 2.0)
            .collect();

        let codec = ResidualCodec::new(
            4,
            centroids,
            avg_residual,
            Some(Array1::from_vec(bucket_cutoffs)),
            Some(Array1::from_vec(bucket_weights)),
        )
        .unwrap();

        // Create test residuals that span different bucket ranges
        let residuals = Array2::from_shape_vec(
            (2, dim),
            vec![
                -0.9, -0.7, -0.5, -0.3, 0.0, 0.3, 0.5, 0.9, // various bucket values
                -0.8, -0.4, 0.0, 0.4, 0.8, -0.6, 0.2, 0.6,
            ],
        )
        .unwrap();

        // Quantize
        let packed = codec.quantize_residuals(&residuals).unwrap();
        assert_eq!(packed.ncols(), dim * 4 / 8); // 4 bytes per row for dim=8, nbits=4

        // Create a temporary centroid assignment (all zeros)
        let codes = Array1::from_vec(vec![0, 0]);

        // Decompress and verify the reconstruction is reasonable
        let decompressed = codec.decompress(&packed, &codes.view()).unwrap();

        // The decompressed values should be close to the quantized bucket weights
        // (plus centroid, which is zero here)
        for i in 0..residuals.nrows() {
            for j in 0..residuals.ncols() {
                let orig = residuals[[i, j]];
                let recon = decompressed[[i, j]];
                // After normalization, values should be in similar direction
                // The reconstruction won't be exact due to quantization, but
                // the sign should generally match for non-zero values
                if orig.abs() > 0.2 {
                    assert!(
                        (orig > 0.0) == (recon > 0.0) || recon.abs() < 0.1,
                        "Sign mismatch at [{}, {}]: orig={}, recon={}",
                        i,
                        j,
                        orig,
                        recon
                    );
                }
            }
        }
    }
}