Skip to main content

aprender_rag/multivector/
codec.rs

1//! Residual quantization codec for WARP
2//!
3//! This module implements the residual quantization scheme used to compress
4//! token embeddings in the WARP algorithm. Each vector is decomposed into:
5//! - A centroid (learned via k-means)
//! - A residual (difference from centroid), quantized to 2 or 4 bits per dimension
7//!
8//! The codec enables efficient scoring without full decompression by using
9//! precomputed centroid scores and bucket weights.
10
11use crate::multivector::types::WarpIndexConfig;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14
/// Residual quantization codec for compressing token embeddings.
///
/// The codec learns centroids via k-means clustering, then quantizes the
/// residuals (v - centroid) to a small number of bits per dimension.
///
/// # Compression Process
///
/// 1. Find nearest centroid for input vector
/// 2. Compute residual = vector - centroid
/// 3. Quantize each dimension to `nbits` using learned bucket boundaries
/// 4. Pack quantized values into bytes
///
/// # Scoring
///
/// Score computation avoids full decompression:
/// ```text
/// q · v ≈ q · c + Σ_d q[d] × bucket_weight[d, code[d]]
/// ```
///
/// # Serialization
///
/// `Serialize`/`Deserialize` are derived, so the field names below are part of
/// the persisted format. Deserialization fills the fields directly and does not
/// run the checks performed by `train` or `with_params` (e.g. `dim > 0`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResidualCodec {
    /// Centroid vectors: [num_centroids × dim], flattened row-major
    /// (centroid `c` occupies `centroids[c * dim..(c + 1) * dim]`).
    centroids: Vec<f32>,
    /// Number of centroids
    num_centroids: usize,
    /// Token dimension
    dim: usize,
    /// Quantization bucket boundaries per dimension: [dim × (num_buckets - 1)],
    /// where `num_buckets = 1 << nbits`; row `d` holds dimension `d`'s cutoffs.
    bucket_cutoffs: Vec<f32>,
    /// Reconstruction weights per bucket: [dim × num_buckets]; entry
    /// `d * num_buckets + code` is the value used for residual component `d`
    /// when it was quantized to `code`.
    bucket_weights: Vec<f32>,
    /// Bits per dimension (2 or 4; checked by `train`, not by deserialization)
    nbits: u8,
}
48
49impl ResidualCodec {
50    /// Train a codec from sample embeddings.
51    ///
52    /// # Arguments
53    ///
54    /// * `embeddings` - Flattened sample embeddings [n × dim]
55    /// * `dim` - Embedding dimension
56    /// * `num_centroids` - Number of k-means centroids
57    /// * `nbits` - Bits per dimension (2 or 4)
58    /// * `iterations` - K-means iterations
59    ///
60    /// # Errors
61    ///
62    /// Returns an error if training data is insufficient or parameters invalid.
63    pub fn train(
64        embeddings: &[f32],
65        dim: usize,
66        num_centroids: usize,
67        nbits: u8,
68        iterations: usize,
69    ) -> Result<Self> {
70        if nbits != 2 && nbits != 4 {
71            return Err(crate::Error::InvalidInput("nbits must be 2 or 4".to_string()));
72        }
73
74        if dim == 0 {
75            return Err(crate::Error::InvalidInput("dim must be > 0".to_string()));
76        }
77        let n = embeddings.len() / dim;
78        if n < num_centroids {
79            return Err(crate::Error::InvalidInput(format!(
80                "Insufficient training data: {n} samples for {num_centroids} centroids"
81            )));
82        }
83
84        // Contract: embedding-algebra-v1.yaml precondition (pv codegen)
85        contract_pre_embedding_lookup!(embeddings);
86
87        // Step 1: K-means clustering to find centroids
88        let centroids = Self::kmeans_clustering(embeddings, dim, num_centroids, iterations);
89
90        // Step 2: Compute residuals for all training points
91        let residuals = Self::compute_all_residuals(embeddings, dim, &centroids, num_centroids);
92
93        // Step 3: Learn quantization boundaries from residual distribution
94        let (bucket_cutoffs, bucket_weights) =
95            Self::learn_quantization_params(&residuals, dim, nbits);
96
97        Ok(Self { centroids, num_centroids, dim, bucket_cutoffs, bucket_weights, nbits })
98    }
99
100    /// Create a codec with pre-trained parameters.
101    ///
102    /// # Panics
103    ///
104    /// Panics if `dim == 0` (poka-yoke: division-by-zero guard).
105    #[must_use]
106    pub fn with_params(
107        centroids: Vec<f32>,
108        num_centroids: usize,
109        dim: usize,
110        bucket_cutoffs: Vec<f32>,
111        bucket_weights: Vec<f32>,
112        nbits: u8,
113    ) -> Self {
114        assert!(dim > 0, "dim must be > 0: division by zero in centroid/residual arithmetic");
115        Self { centroids, num_centroids, dim, bucket_cutoffs, bucket_weights, nbits }
116    }
117
118    /// Get the number of centroids.
119    #[must_use]
120    pub fn num_centroids(&self) -> usize {
121        self.num_centroids
122    }
123
124    /// Get the embedding dimension.
125    #[must_use]
126    pub fn dim(&self) -> usize {
127        self.dim
128    }
129
130    /// Get bits per dimension.
131    #[must_use]
132    pub fn nbits(&self) -> u8 {
133        self.nbits
134    }
135
136    /// Get the packed residual size in bytes.
137    #[must_use]
138    pub fn packed_size(&self) -> usize {
139        (self.dim * self.nbits as usize + 7) / 8
140    }
141
142    /// Get centroid slice by ID.
143    #[must_use]
144    pub fn centroid(&self, id: usize) -> &[f32] {
145        let start = id * self.dim;
146        &self.centroids[start..start + self.dim]
147    }
148
149    /// Get all centroids as a flat slice.
150    #[must_use]
151    pub fn centroids(&self) -> &[f32] {
152        &self.centroids
153    }
154
155    /// Find the nearest centroid for a vector.
156    #[must_use]
157    pub fn find_nearest_centroid(&self, embedding: &[f32]) -> usize {
158        // Contract: configuration-v1.yaml precondition (pv codegen)
159        contract_pre_configuration!(embedding);
160        let mut best_id = 0;
161        let mut best_dist = f32::MAX;
162
163        for c in 0..self.num_centroids {
164            let centroid = self.centroid(c);
165            let dist = Self::squared_distance(embedding, centroid);
166            if dist < best_dist {
167                best_dist = dist;
168                best_id = c;
169            }
170        }
171
172        best_id
173    }
174
175    /// Compress an embedding to (centroid_id, packed_residual).
176    #[must_use]
177    pub fn compress(&self, embedding: &[f32]) -> (usize, Vec<u8>) {
178        // Contract: embedding-algebra-v1.yaml precondition (pv codegen)
179        contract_pre_embedding_lookup!(embedding);
180        // Find nearest centroid
181        let centroid_id = self.find_nearest_centroid(embedding);
182        let centroid = self.centroid(centroid_id);
183
184        // Compute residual
185        let residual: Vec<f32> =
186            embedding.iter().zip(centroid.iter()).map(|(e, c)| e - c).collect();
187
188        // Quantize residual
189        let codes = self.quantize_residual(&residual);
190
191        // Pack codes into bytes
192        let packed = self.pack_codes(&codes);
193
194        (centroid_id, packed)
195    }
196
197    /// Compute score between query token and compressed document token.
198    ///
199    /// score ≈ q · d = q · c + q · r
200    ///
201    /// # Arguments
202    ///
203    /// * `query_token` - Query embedding
204    /// * `centroid_id` - Assigned centroid
205    /// * `centroid_score` - Precomputed q · c
206    /// * `packed_residual` - Packed quantized residual
207    #[must_use]
208    pub fn decompress_score(
209        &self,
210        query_token: &[f32],
211        centroid_id: usize,
212        centroid_score: f32,
213        packed_residual: &[u8],
214    ) -> f32 {
215        let _ = centroid_id; // Centroid info already in centroid_score
216
217        // Unpack residual codes
218        let codes = self.unpack_codes(packed_residual);
219
220        // Compute q · r using bucket weights
221        let num_buckets = 1usize << self.nbits;
222        let residual_score: f32 = codes
223            .iter()
224            .enumerate()
225            .map(|(d, &code)| {
226                let weight_idx = d * num_buckets + code as usize;
227                query_token[d] * self.bucket_weights[weight_idx]
228            })
229            .sum();
230
231        centroid_score + residual_score
232    }
233
234    /// Compute dot product between query and centroid.
235    #[must_use]
236    pub fn centroid_score(&self, query_token: &[f32], centroid_id: usize) -> f32 {
237        let centroid = self.centroid(centroid_id);
238        Self::dot_product(query_token, centroid)
239    }
240
241    /// Quantize a residual vector to codes.
242    fn quantize_residual(&self, residual: &[f32]) -> Vec<u8> {
243        let num_buckets = 1usize << self.nbits;
244
245        residual
246            .iter()
247            .enumerate()
248            .map(|(d, &value)| {
249                // Binary search for bucket
250                let cutoff_start = d * (num_buckets - 1);
251                let cutoffs = &self.bucket_cutoffs[cutoff_start..cutoff_start + num_buckets - 1];
252
253                // Find first cutoff >= value
254                cutoffs.iter().position(|&c| value < c).unwrap_or(num_buckets - 1) as u8
255            })
256            .collect()
257    }
258
259    /// Pack quantization codes into bytes.
260    fn pack_codes(&self, codes: &[u8]) -> Vec<u8> {
261        match self.nbits {
262            2 => {
263                // Pack 4 codes per byte
264                codes
265                    .chunks(4)
266                    .map(|chunk| {
267                        let mut byte = 0u8;
268                        for (i, &code) in chunk.iter().enumerate() {
269                            byte |= (code & 0x03) << (i * 2);
270                        }
271                        byte
272                    })
273                    .collect()
274            }
275            4 => {
276                // Pack 2 codes per byte
277                codes
278                    .chunks(2)
279                    .map(|chunk| {
280                        let low = chunk.first().copied().unwrap_or(0) & 0x0F;
281                        let high = chunk.get(1).copied().unwrap_or(0) & 0x0F;
282                        low | (high << 4)
283                    })
284                    .collect()
285            }
286            _ => panic!("Unsupported nbits: {}", self.nbits),
287        }
288    }
289
290    /// Unpack codes from packed bytes.
291    fn unpack_codes(&self, packed: &[u8]) -> Vec<u8> {
292        match self.nbits {
293            2 => packed
294                .iter()
295                .flat_map(|&byte| (0..4).map(move |i| (byte >> (i * 2)) & 0x03))
296                .take(self.dim)
297                .collect(),
298            4 => packed
299                .iter()
300                .flat_map(|&byte| vec![byte & 0x0F, (byte >> 4) & 0x0F])
301                .take(self.dim)
302                .collect(),
303            _ => panic!("Unsupported nbits: {}", self.nbits),
304        }
305    }
306
307    // ============ K-means Implementation ============
308
309    /// K-means clustering with k-means++ initialization.
310    fn kmeans_clustering(embeddings: &[f32], dim: usize, k: usize, iterations: usize) -> Vec<f32> {
311        let n = embeddings.len() / dim;
312
313        // K-means++ initialization
314        let mut centroids = Self::kmeans_plus_plus_init(embeddings, dim, k);
315        let mut assignments = vec![0usize; n];
316
317        for _ in 0..iterations {
318            // Assign points to nearest centroid
319            for i in 0..n {
320                let point = &embeddings[i * dim..(i + 1) * dim];
321                let mut best_dist = f32::MAX;
322                let mut best_c = 0;
323
324                for c in 0..k {
325                    let centroid = &centroids[c * dim..(c + 1) * dim];
326                    let dist = Self::squared_distance(point, centroid);
327                    if dist < best_dist {
328                        best_dist = dist;
329                        best_c = c;
330                    }
331                }
332                assignments[i] = best_c;
333            }
334
335            // Update centroids as mean of assigned points
336            let mut new_centroids = vec![0.0f32; k * dim];
337            let mut counts = vec![0usize; k];
338
339            for i in 0..n {
340                let c = assignments[i];
341                counts[c] += 1;
342                let point = &embeddings[i * dim..(i + 1) * dim];
343                for d in 0..dim {
344                    new_centroids[c * dim + d] += point[d];
345                }
346            }
347
348            for c in 0..k {
349                if counts[c] > 0 {
350                    for d in 0..dim {
351                        new_centroids[c * dim + d] /= counts[c] as f32;
352                    }
353                } else {
354                    // Keep old centroid if no points assigned
355                    for d in 0..dim {
356                        new_centroids[c * dim + d] = centroids[c * dim + d];
357                    }
358                }
359            }
360
361            centroids = new_centroids;
362        }
363
364        centroids
365    }
366
367    /// K-means++ initialization.
368    fn kmeans_plus_plus_init(embeddings: &[f32], dim: usize, k: usize) -> Vec<f32> {
369        let n = embeddings.len() / dim;
370        let mut centroids = Vec::with_capacity(k * dim);
371        let mut rng_state = 42u64; // Simple deterministic RNG
372
373        // Choose first centroid uniformly at random
374        let first_idx = Self::simple_random(&mut rng_state, n);
375        centroids.extend_from_slice(&embeddings[first_idx * dim..(first_idx + 1) * dim]);
376
377        let mut distances = vec![f32::MAX; n];
378
379        for _ in 1..k {
380            let num_centroids = centroids.len() / dim;
381
382            // Update distances to nearest centroid
383            for i in 0..n {
384                let point = &embeddings[i * dim..(i + 1) * dim];
385                let centroid = &centroids[(num_centroids - 1) * dim..num_centroids * dim];
386                let dist = Self::squared_distance(point, centroid);
387                distances[i] = distances[i].min(dist);
388            }
389
390            // Choose next centroid with probability proportional to D²
391            let total: f32 = distances.iter().sum();
392            if total <= 0.0 {
393                // All points are centroids already, pick random
394                let idx = Self::simple_random(&mut rng_state, n);
395                centroids.extend_from_slice(&embeddings[idx * dim..(idx + 1) * dim]);
396                continue;
397            }
398
399            let threshold = Self::simple_random_f32(&mut rng_state) * total;
400            let mut cumsum = 0.0f32;
401            let mut chosen = 0;
402
403            for (i, &d) in distances.iter().enumerate() {
404                cumsum += d;
405                if cumsum >= threshold {
406                    chosen = i;
407                    break;
408                }
409            }
410
411            centroids.extend_from_slice(&embeddings[chosen * dim..(chosen + 1) * dim]);
412        }
413
414        centroids
415    }
416
417    /// Compute residuals for all embeddings.
418    fn compute_all_residuals(
419        embeddings: &[f32],
420        dim: usize,
421        centroids: &[f32],
422        num_centroids: usize,
423    ) -> Vec<f32> {
424        let n = embeddings.len() / dim;
425        let mut residuals = Vec::with_capacity(n * dim);
426
427        for i in 0..n {
428            let point = &embeddings[i * dim..(i + 1) * dim];
429
430            // Find nearest centroid
431            let mut best_c = 0;
432            let mut best_dist = f32::MAX;
433            for c in 0..num_centroids {
434                let centroid = &centroids[c * dim..(c + 1) * dim];
435                let dist = Self::squared_distance(point, centroid);
436                if dist < best_dist {
437                    best_dist = dist;
438                    best_c = c;
439                }
440            }
441
442            // Compute residual
443            let centroid = &centroids[best_c * dim..(best_c + 1) * dim];
444            for d in 0..dim {
445                residuals.push(point[d] - centroid[d]);
446            }
447        }
448
449        residuals
450    }
451
452    /// Learn quantization bucket boundaries and weights from residuals.
453    fn learn_quantization_params(residuals: &[f32], dim: usize, nbits: u8) -> (Vec<f32>, Vec<f32>) {
454        let num_buckets = 1usize << nbits;
455        let n = residuals.len() / dim;
456
457        let mut cutoffs = Vec::with_capacity(dim * (num_buckets - 1));
458        let mut weights = Vec::with_capacity(dim * num_buckets);
459
460        for d in 0..dim {
461            // Collect residual values for dimension d
462            let mut values: Vec<f32> = (0..n).map(|i| residuals[i * dim + d]).collect();
463            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
464
465            // Quantile-based boundaries for equal-frequency buckets
466            for b in 1..num_buckets {
467                let quantile_idx = (b * n) / num_buckets;
468                cutoffs.push(values[quantile_idx.min(n - 1)]);
469            }
470
471            // Bucket weights = mean value in each bucket
472            for b in 0..num_buckets {
473                let start = (b * n) / num_buckets;
474                let end = ((b + 1) * n) / num_buckets;
475                let end = end.max(start + 1).min(n);
476
477                let sum: f32 = values[start..end].iter().sum();
478                let mean = sum / (end - start) as f32;
479                weights.push(mean);
480            }
481        }
482
483        (cutoffs, weights)
484    }
485
486    // ============ Math Utilities ============
487
488    fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
489        a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
490    }
491
492    fn dot_product(a: &[f32], b: &[f32]) -> f32 {
493        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
494    }
495
496    fn simple_random(state: &mut u64, max: usize) -> usize {
497        *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
498        ((*state >> 33) as usize) % max
499    }
500
501    fn simple_random_f32(state: &mut u64) -> f32 {
502        *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
503        ((*state >> 33) as f32) / (u32::MAX as f32)
504    }
505}
506
507/// Builder for creating a `ResidualCodec` from a `WarpIndexConfig`.
508pub struct ResidualCodecBuilder {
509    config: WarpIndexConfig,
510}
511
512impl ResidualCodecBuilder {
513    /// Create a new builder from config.
514    #[must_use]
515    pub fn new(config: WarpIndexConfig) -> Self {
516        Self { config }
517    }
518
519    /// Train the codec from sample embeddings.
520    pub fn train(&self, embeddings: &[f32]) -> Result<ResidualCodec> {
521        // Contract: embedding-algebra-v1.yaml precondition (pv codegen)
522        contract_pre_embedding_lookup!(embeddings);
523        ResidualCodec::train(
524            embeddings,
525            self.config.token_dim,
526            self.config.num_centroids,
527            self.config.nbits,
528            self.config.kmeans_iterations,
529        )
530    }
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random test data: `n * dim` values from a fixed-seed LCG.
    ///
    /// NOTE(review): `rng_state >> 33` is at most 2^31 - 1, so dividing by
    /// `u32::MAX` yields at most ~0.5. The `* 2.0 - 1.0` step therefore maps
    /// every value into roughly [-1, 0], not the [-1, 1) the expression
    /// suggests. Confirm the intent; fixing it would change the data every
    /// training test below uses.
    fn generate_test_embeddings(n: usize, dim: usize) -> Vec<f32> {
        let mut embeddings = Vec::with_capacity(n * dim);
        let mut rng_state = 12345u64;

        for _ in 0..(n * dim) {
            rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
            let val = ((rng_state >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
            embeddings.push(val);
        }

        embeddings
    }

    // ============ Basic Codec Tests ============

    #[test]
    fn test_codec_train_2bit() {
        let embeddings = generate_test_embeddings(1000, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        assert_eq!(codec.num_centroids(), 16);
        assert_eq!(codec.dim(), 32);
        assert_eq!(codec.nbits(), 2);
    }

    #[test]
    fn test_codec_train_4bit() {
        let embeddings = generate_test_embeddings(1000, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 4, 5).unwrap();

        assert_eq!(codec.nbits(), 4);
    }

    // 5 samples cannot seed 16 centroids.
    #[test]
    fn test_codec_train_insufficient_data() {
        let embeddings = generate_test_embeddings(5, 32);
        let result = ResidualCodec::train(&embeddings, 32, 16, 2, 5);

        assert!(result.is_err());
    }

    // 3 bits per dimension is not a supported width.
    #[test]
    fn test_codec_train_invalid_nbits() {
        let embeddings = generate_test_embeddings(100, 32);
        let result = ResidualCodec::train(&embeddings, 32, 16, 3, 5);

        assert!(result.is_err());
    }

    /// Regression test for paiml/trueno-rag#15: train() rejects dim=0.
    #[test]
    fn test_codec_train_dim_zero() {
        let result = ResidualCodec::train(&[], 0, 4, 2, 3);
        assert!(result.is_err());
    }

    /// Regression test for paiml/trueno-rag#15: with_params() rejects dim=0.
    #[test]
    #[should_panic(expected = "dim must be > 0")]
    fn test_codec_with_params_dim_zero() {
        let _ = ResidualCodec::with_params(vec![], 0, 0, vec![], vec![], 2);
    }

    // ============ Compression Tests ============

    #[test]
    fn test_codec_compress() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let test_vec = &embeddings[0..32];
        let (centroid_id, packed) = codec.compress(test_vec);

        assert!(centroid_id < 16);
        assert_eq!(packed.len(), codec.packed_size());
    }

    #[test]
    fn test_codec_packed_size_2bit() {
        let embeddings = generate_test_embeddings(500, 128);
        let codec = ResidualCodec::train(&embeddings, 128, 16, 2, 5).unwrap();

        // 128 dims × 2 bits = 256 bits = 32 bytes
        assert_eq!(codec.packed_size(), 32);
    }

    #[test]
    fn test_codec_packed_size_4bit() {
        let embeddings = generate_test_embeddings(500, 128);
        let codec = ResidualCodec::train(&embeddings, 128, 16, 4, 5).unwrap();

        // 128 dims × 4 bits = 512 bits = 64 bytes
        assert_eq!(codec.packed_size(), 64);
    }

    // ============ Pack/Unpack Tests ============

    // The codec is trained only to get an instance with the right `dim`/`nbits`;
    // the codes themselves are hand-written and must round-trip exactly.
    #[test]
    fn test_pack_unpack_2bit() {
        let embeddings = generate_test_embeddings(500, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 16, 2, 5).unwrap();

        let codes: Vec<u8> = vec![0, 1, 2, 3, 0, 1, 2, 3];
        let packed = codec.pack_codes(&codes);
        let unpacked = codec.unpack_codes(&packed);

        assert_eq!(codes, unpacked);
    }

    #[test]
    fn test_pack_unpack_4bit() {
        let embeddings = generate_test_embeddings(500, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 16, 4, 5).unwrap();

        let codes: Vec<u8> = vec![0, 5, 10, 15, 1, 6, 11, 14];
        let packed = codec.pack_codes(&codes);
        let unpacked = codec.unpack_codes(&packed);

        assert_eq!(codes, unpacked);
    }

    // ============ Scoring Tests ============

    #[test]
    fn test_decompress_score() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let query = &embeddings[0..32];
        let doc = &embeddings[32..64];

        // Compress document
        let (centroid_id, packed) = codec.compress(doc);

        // Compute centroid score
        let centroid_score = codec.centroid_score(query, centroid_id);

        // Compute approximate score
        let approx_score = codec.decompress_score(query, centroid_id, centroid_score, &packed);

        // Compute exact score
        let exact_score: f32 = query.iter().zip(doc.iter()).map(|(q, d)| q * d).sum();

        // Approximate score should be close to exact (within reasonable tolerance)
        // The bound is deliberately loose (50% relative + 1.0 absolute), so it
        // only catches gross scoring errors, not fine accuracy regressions.
        let error = (approx_score - exact_score).abs();
        assert!(
            error < exact_score.abs() * 0.5 + 1.0,
            "Error too large: approx={}, exact={}, error={}",
            approx_score,
            exact_score,
            error
        );
    }

    #[test]
    fn test_centroid_score() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let query = &embeddings[0..32];
        let centroid = codec.centroid(0);

        let expected: f32 = query.iter().zip(centroid.iter()).map(|(q, c)| q * c).sum();
        let actual = codec.centroid_score(query, 0);

        assert!((expected - actual).abs() < 1e-6);
    }

    // ============ K-means Tests ============

    #[test]
    fn test_find_nearest_centroid() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        // A centroid should be nearest to itself
        let centroid_0 = codec.centroid(0).to_vec();
        let nearest = codec.find_nearest_centroid(&centroid_0);
        assert_eq!(nearest, 0);
    }

    // ============ Builder Tests ============

    #[test]
    fn test_codec_builder() {
        let config = WarpIndexConfig::new(2, 16, 32).with_kmeans_iterations(5);
        let builder = ResidualCodecBuilder::new(config);

        let embeddings = generate_test_embeddings(500, 32);
        let codec = builder.train(&embeddings).unwrap();

        assert_eq!(codec.num_centroids(), 16);
        assert_eq!(codec.dim(), 32);
    }

    // ============ Serialization Tests ============

    // Round-trips through JSON and checks the scalar configuration survives.
    #[test]
    fn test_codec_serialization() {
        let embeddings = generate_test_embeddings(500, 16);
        let codec = ResidualCodec::train(&embeddings, 16, 8, 2, 5).unwrap();

        let json = serde_json::to_string(&codec).unwrap();
        let deserialized: ResidualCodec = serde_json::from_str(&json).unwrap();

        assert_eq!(codec.num_centroids(), deserialized.num_centroids());
        assert_eq!(codec.dim(), deserialized.dim());
        assert_eq!(codec.nbits(), deserialized.nbits());
    }

    // ============ Property-Based Tests ============

    use proptest::prelude::*;

    proptest! {
        // NOTE(review): the inline generator below uses the same
        // `(rng_state >> 33) / u32::MAX` formula as `generate_test_embeddings`,
        // so its values also fall in roughly [-1, 0].
        #[test]
        fn prop_compress_produces_valid_centroid_id(
            seed in 0u64..1000
        ) {
            let mut embeddings = Vec::with_capacity(200 * 16);
            let mut rng_state = seed;
            for _ in 0..(200 * 16) {
                rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
                embeddings.push(((rng_state >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0);
            }

            let codec = ResidualCodec::train(&embeddings, 16, 8, 2, 3).unwrap();
            let test_vec = &embeddings[0..16];
            let (centroid_id, _) = codec.compress(test_vec);

            prop_assert!(centroid_id < 8);
        }

        // `num_samples / dim` is always 100 rows here.
        // NOTE(review): a training error is skipped by `if let Ok`, so this
        // property passes vacuously whenever `train` fails.
        #[test]
        fn prop_packed_size_matches_config(
            nbits in prop::sample::select(vec![2u8, 4]),
            dim in 8usize..64
        ) {
            let num_samples = 100 * dim;
            let embeddings = generate_test_embeddings(num_samples / dim, dim);

            if let Ok(codec) = ResidualCodec::train(&embeddings, dim, 8, nbits, 3) {
                let expected_size = (dim * nbits as usize + 7) / 8;
                prop_assert_eq!(codec.packed_size(), expected_size);
            }
        }
    }
}