Skip to main content

rvf_quant/
codec.rs

1//! QUANT_SEG and SKETCH_SEG wire format codec.
2//!
3//! Serializes / deserializes quantizer parameters and Count-Min Sketch
4//! data to the binary layout defined in the RVF wire spec.
5
6use alloc::boxed::Box;
7use alloc::vec;
8use alloc::vec::Vec;
9
10use crate::binary;
11use crate::product::ProductQuantizer;
12use crate::rabitq::RabitqQuantizer;
13use crate::scalar::ScalarQuantizer;
14use crate::sketch::CountMinSketch;
15use crate::traits::Quantizer;
16
17// ---------------------------------------------------------------------------
18// QUANT_SEG codec
19// ---------------------------------------------------------------------------
20
/// Quantization type tags matching the QUANT_SEG wire spec.
/// (Tag 3 is reserved for residual PQ in `rvf_types::QuantType`.)
///
/// Each tag is stored in byte 0 of the 64-byte QUANT_SEG header and selects
/// the decoder used by `decode_quant_seg`. This one marks a scalar
/// (per-dimension min/max) quantizer.
const QUANT_TYPE_SCALAR: u8 = 0;
/// Header tag for a product quantizer (M/K/sub_dim header plus codebooks).
const QUANT_TYPE_PRODUCT: u8 = 1;
/// Header tag for binary quantization, which carries no body.
const QUANT_TYPE_BINARY: u8 = 2;
/// Header tag for a RaBitQ quantizer (versioned header plus centroid).
const QUANT_TYPE_RABITQ: u8 = 4;

/// Current RaBitQ QUANT_SEG layout version. Bump on incompatible changes;
/// decoders reject unknown versions instead of misreading bytes.
///
/// Written to header byte 4 by `encode_rabitq_quantizer` and compared for
/// equality by `decode_rabitq`.
const RABITQ_VERSION: u8 = 1;
31
/// Errors that can occur while decoding QUANT_SEG payloads.
///
/// `decode_sketch_seg` reports SKETCH_SEG decoding failures with this same
/// type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CodecError {
    /// Input data is shorter than expected: either the fixed 64-byte header
    /// is incomplete or the body is smaller than the header fields require.
    TooShort,
    /// Unknown quantization type tag (carries the offending header byte 0).
    UnknownQuantType(u8),
    /// Known quantization type, but an unsupported layout version
    /// (carries the version byte that was read).
    UnsupportedVersion(u8),
    /// A header field is internally inconsistent (e.g. bad padded_dim).
    InvalidField,
}
44
45impl core::fmt::Display for CodecError {
46    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
47        match self {
48            Self::TooShort => write!(f, "input data too short"),
49            Self::UnknownQuantType(t) => write!(f, "unknown quant_type: {}", t),
50            Self::UnsupportedVersion(v) => write!(f, "unsupported quant_seg version: {}", v),
51            Self::InvalidField => write!(f, "invalid quant_seg header field"),
52        }
53    }
54}
55
56/// Encode a quantizer into the QUANT_SEG binary payload.
57///
58/// Layout:
59/// ```text
60/// [quant_type: u8] [tier: u8] [dim: u16 LE] [padding: 60 bytes to 64B]
61/// [type-specific data ...]
62/// ```
63pub fn encode_quant_seg(quantizer: &dyn Quantizer) -> Vec<u8> {
64    // Downcast (via the `Any` supertrait) to serialize the concrete
65    // quantizer's parameters.
66    let any: &dyn core::any::Any = quantizer;
67    if let Some(sq) = any.downcast_ref::<ScalarQuantizer>() {
68        encode_scalar_quantizer(sq)
69    } else if let Some(pq) = any.downcast_ref::<ProductQuantizer>() {
70        encode_product_quantizer(pq)
71    } else if let Some(rq) = any.downcast_ref::<RabitqQuantizer>() {
72        encode_rabitq_quantizer(rq)
73    } else if quantizer.tier() as u8 == 2 {
74        // Binary quantization is parameter-free beyond the dimension.
75        encode_binary_quant_seg(quantizer.dim() as u16)
76    } else {
77        panic!("unknown quantizer type")
78    }
79}
80
81/// Decode a QUANT_SEG binary payload into a boxed Quantizer.
82pub fn decode_quant_seg(data: &[u8]) -> Result<Box<dyn Quantizer>, CodecError> {
83    if data.len() < 64 {
84        return Err(CodecError::TooShort);
85    }
86
87    let quant_type = data[0];
88    let _tier = data[1];
89    let dim = u16::from_le_bytes([data[2], data[3]]) as usize;
90    let body = &data[64..];
91
92    match quant_type {
93        QUANT_TYPE_SCALAR => Ok(Box::new(decode_scalar(body, dim)?)),
94        QUANT_TYPE_PRODUCT => Ok(Box::new(decode_product(body, dim)?)),
95        QUANT_TYPE_BINARY => Ok(Box::new(BinaryQuantizerWrapper { dim })),
96        QUANT_TYPE_RABITQ => Ok(Box::new(decode_rabitq(data, body, dim)?)),
97        _ => Err(CodecError::UnknownQuantType(quant_type)),
98    }
99}
100
101// ---------------------------------------------------------------------------
102// Scalar
103// ---------------------------------------------------------------------------
104
105/// Encode a ScalarQuantizer directly (preferred over trait-based encoding).
106pub fn encode_scalar_quantizer(sq: &ScalarQuantizer) -> Vec<u8> {
107    let dim = sq.dim as u16;
108    let mut buf = vec![0u8; 64];
109    buf[0] = QUANT_TYPE_SCALAR;
110    buf[1] = 0; // Hot tier
111    buf[2..4].copy_from_slice(&dim.to_le_bytes());
112
113    // min[dim], max[dim]
114    for &v in &sq.min_vals {
115        buf.extend_from_slice(&v.to_le_bytes());
116    }
117    for &v in &sq.max_vals {
118        buf.extend_from_slice(&v.to_le_bytes());
119    }
120    buf
121}
122
123fn decode_scalar(body: &[u8], dim: usize) -> Result<ScalarQuantizer, CodecError> {
124    let float_bytes = dim * 4;
125    if body.len() < float_bytes * 2 {
126        return Err(CodecError::TooShort);
127    }
128
129    let mut min_vals = Vec::with_capacity(dim);
130    let mut max_vals = Vec::with_capacity(dim);
131
132    for d in 0..dim {
133        let offset = d * 4;
134        let v = f32::from_le_bytes([
135            body[offset],
136            body[offset + 1],
137            body[offset + 2],
138            body[offset + 3],
139        ]);
140        min_vals.push(v);
141    }
142    for d in 0..dim {
143        let offset = (dim + d) * 4;
144        let v = f32::from_le_bytes([
145            body[offset],
146            body[offset + 1],
147            body[offset + 2],
148            body[offset + 3],
149        ]);
150        max_vals.push(v);
151    }
152
153    Ok(ScalarQuantizer {
154        min_vals,
155        max_vals,
156        dim,
157    })
158}
159
160// ---------------------------------------------------------------------------
161// Product
162// ---------------------------------------------------------------------------
163
164/// Encode a ProductQuantizer directly.
165pub fn encode_product_quantizer(pq: &ProductQuantizer) -> Vec<u8> {
166    let dim = (pq.m * pq.sub_dim) as u16;
167    let mut buf = vec![0u8; 64];
168    buf[0] = QUANT_TYPE_PRODUCT;
169    buf[1] = 1; // Warm tier
170    buf[2..4].copy_from_slice(&dim.to_le_bytes());
171
172    // PQ header: M, K, sub_dim (each as u16 LE)
173    // Written after the 64-byte aligned header.
174    buf.extend_from_slice(&(pq.m as u16).to_le_bytes());
175    buf.extend_from_slice(&(pq.k as u16).to_le_bytes());
176    buf.extend_from_slice(&(pq.sub_dim as u16).to_le_bytes());
177
178    // Codebook: M * K * sub_dim floats
179    for sub_book in &pq.codebooks {
180        for centroid in sub_book {
181            for &val in centroid {
182                buf.extend_from_slice(&val.to_le_bytes());
183            }
184        }
185    }
186
187    buf
188}
189
190fn decode_product(body: &[u8], _dim: usize) -> Result<ProductQuantizer, CodecError> {
191    if body.len() < 6 {
192        return Err(CodecError::TooShort);
193    }
194
195    let m = u16::from_le_bytes([body[0], body[1]]) as usize;
196    let k = u16::from_le_bytes([body[2], body[3]]) as usize;
197    let sub_dim = u16::from_le_bytes([body[4], body[5]]) as usize;
198
199    // Compute the codebook size in u64 with checked arithmetic: on 32-bit
200    // targets (wasm32) `m * k * sub_dim * 4` can wrap usize, slip past the
201    // length check below, and then index out of bounds in the decode loop.
202    let codebook_bytes = (m as u64)
203        .checked_mul(k as u64)
204        .and_then(|v| v.checked_mul(sub_dim as u64))
205        .and_then(|v| v.checked_mul(4))
206        .ok_or(CodecError::InvalidField)?;
207    let expected = codebook_bytes
208        .checked_add(6)
209        .ok_or(CodecError::InvalidField)?;
210    if (body.len() as u64) < expected {
211        return Err(CodecError::TooShort);
212    }
213
214    let mut codebooks = Vec::with_capacity(m);
215    let mut offset = 6;
216    for _ in 0..m {
217        let mut sub_book = Vec::with_capacity(k);
218        for _ in 0..k {
219            let mut centroid = Vec::with_capacity(sub_dim);
220            for _ in 0..sub_dim {
221                let v = f32::from_le_bytes([
222                    body[offset],
223                    body[offset + 1],
224                    body[offset + 2],
225                    body[offset + 3],
226                ]);
227                centroid.push(v);
228                offset += 4;
229            }
230            sub_book.push(centroid);
231        }
232        codebooks.push(sub_book);
233    }
234
235    Ok(ProductQuantizer {
236        m,
237        k,
238        sub_dim,
239        codebooks,
240    })
241}
242
243// ---------------------------------------------------------------------------
244// Binary
245// ---------------------------------------------------------------------------
246
247fn encode_binary_quant_seg(dim: u16) -> Vec<u8> {
248    let mut buf = vec![0u8; 64];
249    buf[0] = QUANT_TYPE_BINARY;
250    buf[1] = 2; // Cold tier
251    buf[2..4].copy_from_slice(&dim.to_le_bytes());
252    // Binary quantization has no additional parameters (sign-based).
253    buf
254}
255
/// Wrapper to implement `Quantizer` for binary quantization.
///
/// `decode_quant_seg` returns this type for payloads tagged as binary; the
/// only state it needs is the vector dimension read from the header.
struct BinaryQuantizerWrapper {
    /// Vector dimension, passed to `binary::decode_binary` when decoding.
    dim: usize,
}
260
261impl Quantizer for BinaryQuantizerWrapper {
262    fn encode(&self, vector: &[f32]) -> Vec<u8> {
263        binary::encode_binary(vector)
264    }
265
266    fn decode(&self, codes: &[u8]) -> Vec<f32> {
267        binary::decode_binary(codes, self.dim)
268    }
269
270    fn tier(&self) -> crate::tier::TemperatureTier {
271        crate::tier::TemperatureTier::Cold
272    }
273
274    fn dim(&self) -> usize {
275        self.dim
276    }
277}
278
279// ---------------------------------------------------------------------------
280// RaBitQ
281// ---------------------------------------------------------------------------
282
283/// Encode a RaBitQ quantizer into a QUANT_SEG payload.
284///
285/// Header layout (within the shared 64-byte aligned header; bytes 4..20
286/// were zero padding in pre-RaBitQ payloads, so old types are unaffected):
287/// ```text
288/// [quant_type=4: u8] [tier: u8] [dim: u16 LE]
289/// [version: u8] [rounds: u8] [reserved: u16]
290/// [seed: u64 LE] [padded_dim: u32 LE] [padding to 64B]
291/// [centroid: dim * f32 LE]
292/// ```
293pub fn encode_rabitq_quantizer(rq: &RabitqQuantizer) -> Vec<u8> {
294    let mut buf = vec![0u8; 64];
295    buf[0] = QUANT_TYPE_RABITQ;
296    buf[1] = 2; // Cold tier
297    buf[2..4].copy_from_slice(&(rq.dim as u16).to_le_bytes());
298    buf[4] = RABITQ_VERSION;
299    buf[5] = rq.rounds;
300    // buf[6..8] reserved (zero)
301    buf[8..16].copy_from_slice(&rq.seed.to_le_bytes());
302    buf[16..20].copy_from_slice(&(rq.padded_dim as u32).to_le_bytes());
303
304    for &v in &rq.centroid {
305        buf.extend_from_slice(&v.to_le_bytes());
306    }
307    buf
308}
309
310/// Decode a RaBitQ QUANT_SEG payload (versioned; bounds-checked).
311///
312/// `data` is the full payload (for header fields beyond the shared
313/// prefix), `body` is the slice after the 64-byte header.
314fn decode_rabitq(data: &[u8], body: &[u8], dim: usize) -> Result<RabitqQuantizer, CodecError> {
315    // Caller guarantees data.len() >= 64.
316    let version = data[4];
317    if version != RABITQ_VERSION {
318        return Err(CodecError::UnsupportedVersion(version));
319    }
320    let rounds = data[5];
321    let seed = u64::from_le_bytes(data[8..16].try_into().expect("len checked"));
322    let padded_dim = u32::from_le_bytes(data[16..20].try_into().expect("len checked")) as usize;
323
324    if dim == 0 || rounds == 0 {
325        return Err(CodecError::InvalidField);
326    }
327    // padded_dim must be the canonical power-of-two padding of dim; this
328    // also bounds it (dim is u16, so padded_dim <= 65536).
329    if padded_dim != dim.max(1).next_power_of_two() {
330        return Err(CodecError::InvalidField);
331    }
332
333    let centroid_bytes = dim.checked_mul(4).ok_or(CodecError::InvalidField)?;
334    if body.len() < centroid_bytes {
335        return Err(CodecError::TooShort);
336    }
337    let mut centroid = Vec::with_capacity(dim);
338    for d in 0..dim {
339        let offset = d * 4;
340        centroid.push(f32::from_le_bytes(
341            body[offset..offset + 4].try_into().expect("len checked"),
342        ));
343    }
344
345    Ok(RabitqQuantizer::with_centroid(dim, centroid, seed, rounds))
346}
347
348// ---------------------------------------------------------------------------
349// SKETCH_SEG codec
350// ---------------------------------------------------------------------------
351
352/// Encode a CountMinSketch into the SKETCH_SEG binary payload.
353///
354/// Layout:
355/// ```text
356/// [width: u32 LE] [depth: u32 LE] [total_accesses: u64 LE] [padding: 48 bytes to 64B]
357/// [counters: depth * width bytes]
358/// ```
359pub fn encode_sketch_seg(sketch: &CountMinSketch) -> Vec<u8> {
360    let mut buf = vec![0u8; 64]; // 64-byte aligned header
361
362    buf[0..4].copy_from_slice(&(sketch.width as u32).to_le_bytes());
363    buf[4..8].copy_from_slice(&(sketch.depth as u32).to_le_bytes());
364    buf[8..16].copy_from_slice(&sketch.total_accesses.to_le_bytes());
365
366    // Counter data: row-major
367    for row in &sketch.counters {
368        buf.extend_from_slice(row);
369    }
370
371    buf
372}
373
374/// Decode a SKETCH_SEG binary payload into a CountMinSketch.
375///
376/// Returns an error (never panics) on malformed input: short headers,
377/// counter data shorter than `width * depth`, a zero `width` paired with a
378/// non-zero `depth` (which would bypass the length check while driving an
379/// unbounded row allocation), or `width * depth` overflow.
380pub fn decode_sketch_seg(data: &[u8]) -> Result<CountMinSketch, CodecError> {
381    if data.len() < 64 {
382        return Err(CodecError::TooShort);
383    }
384
385    let width = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
386    let depth = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
387    let total_accesses = u64::from_le_bytes([
388        data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
389    ]);
390
391    let body = &data[64..];
392
393    // Every row must consume at least one byte; otherwise a crafted
394    // depth (up to u32::MAX) passes the `expected == 0` length check and
395    // OOMs in `Vec::with_capacity` below.
396    if width == 0 && depth != 0 {
397        return Err(CodecError::InvalidField);
398    }
399    // Checked u64 arithmetic: `width * depth` can wrap usize on 32-bit
400    // targets (wasm32) and slip past the length check.
401    let expected = (width as u64)
402        .checked_mul(depth as u64)
403        .ok_or(CodecError::InvalidField)?;
404    if (body.len() as u64) < expected {
405        return Err(CodecError::TooShort);
406    }
407
408    // Safe: width >= 1 here, so depth <= expected <= body.len().
409    let mut counters = Vec::with_capacity(depth);
410    for row in 0..depth {
411        let start = row * width;
412        counters.push(body[start..start + width].to_vec());
413    }
414
415    Ok(CountMinSketch {
416        counters,
417        width,
418        depth,
419        total_accesses,
420    })
421}
422
/// Round-trip and malformed-input tests for the QUANT_SEG and SKETCH_SEG
/// codecs.
#[cfg(test)]
mod tests {
    use super::*;

    /// A directly encoded scalar quantizer decodes to a Hot-tier quantizer
    /// that yields the same codes as the original.
    #[test]
    fn scalar_quant_seg_round_trip() {
        let sq = ScalarQuantizer {
            min_vals: vec![-1.0, -2.0, -0.5, 0.0],
            max_vals: vec![1.0, 2.0, 0.5, 1.0],
            dim: 4,
        };

        let encoded = encode_scalar_quantizer(&sq);
        let decoded = decode_quant_seg(&encoded).unwrap();

        assert_eq!(decoded.dim(), 4);
        assert_eq!(decoded.tier(), crate::tier::TemperatureTier::Hot);

        // Verify round-trip: encode a test vector, check similar output
        let test_vec = vec![0.5, 1.0, 0.0, 0.5];
        let codes_orig = sq.encode_vec(&test_vec);
        let codes_decoded = decoded.encode(&test_vec);
        assert_eq!(codes_orig, codes_decoded);
    }

    /// A directly encoded product quantizer decodes to a Warm-tier quantizer
    /// that yields the same codes as the original.
    #[test]
    fn product_quant_seg_round_trip() {
        // Build a small PQ manually
        let pq = ProductQuantizer {
            m: 2,
            k: 4,
            sub_dim: 2,
            codebooks: vec![
                vec![
                    vec![0.0, 0.1],
                    vec![0.2, 0.3],
                    vec![0.4, 0.5],
                    vec![0.6, 0.7],
                ],
                vec![
                    vec![0.8, 0.9],
                    vec![1.0, 1.1],
                    vec![1.2, 1.3],
                    vec![1.4, 1.5],
                ],
            ],
        };

        let encoded = encode_product_quantizer(&pq);
        let decoded = decode_quant_seg(&encoded).unwrap();

        assert_eq!(decoded.dim(), 4);
        assert_eq!(decoded.tier(), crate::tier::TemperatureTier::Warm);

        let test_vec = vec![0.1, 0.2, 0.9, 1.0];
        let codes_orig = pq.encode_vec(&test_vec);
        let codes_decoded = decoded.encode(&test_vec);
        assert_eq!(codes_orig, codes_decoded);
    }

    /// A header-only binary payload decodes to a Cold-tier quantizer whose
    /// reconstruction has the recorded dimension.
    #[test]
    fn binary_quant_seg_round_trip() {
        let dim: u16 = 16;
        let encoded = encode_binary_quant_seg(dim);
        let decoded = decode_quant_seg(&encoded).unwrap();

        assert_eq!(decoded.dim(), 16);
        assert_eq!(decoded.tier(), crate::tier::TemperatureTier::Cold);

        let test_vec: Vec<f32> = (0..16)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let codes = decoded.encode(&test_vec);
        let recon = decoded.decode(&codes);
        assert_eq!(recon.len(), 16);
    }

    /// Trait-based encoding of a scalar quantizer round-trips every field.
    #[test]
    fn encode_quant_seg_scalar_round_trip() {
        let sq = ScalarQuantizer {
            min_vals: vec![-1.0, -2.0, -0.5, 0.0],
            max_vals: vec![1.0, 2.0, 0.5, 1.0],
            dim: 4,
        };

        let encoded = encode_quant_seg(&sq);
        let decoded = decode_quant_seg(&encoded).unwrap();

        // Downcast to inspect the concrete decoded quantizer.
        let any: &dyn core::any::Any = decoded.as_ref();
        let dec_sq = any
            .downcast_ref::<ScalarQuantizer>()
            .expect("expected ScalarQuantizer");
        assert_eq!(dec_sq.min_vals, sq.min_vals);
        assert_eq!(dec_sq.max_vals, sq.max_vals);
        assert_eq!(dec_sq.dim, sq.dim);
    }

    /// Trait-based encoding of a product quantizer round-trips every field,
    /// including the full codebooks.
    #[test]
    fn encode_quant_seg_product_round_trip() {
        let pq = ProductQuantizer {
            m: 2,
            k: 2,
            sub_dim: 2,
            codebooks: vec![
                vec![vec![0.0, 0.1], vec![0.2, 0.3]],
                vec![vec![0.8, 0.9], vec![1.0, 1.1]],
            ],
        };

        let encoded = encode_quant_seg(&pq);
        let decoded = decode_quant_seg(&encoded).unwrap();

        // Downcast to inspect the concrete decoded quantizer.
        let any: &dyn core::any::Any = decoded.as_ref();
        let dec_pq = any
            .downcast_ref::<ProductQuantizer>()
            .expect("expected ProductQuantizer");
        assert_eq!(dec_pq.m, pq.m);
        assert_eq!(dec_pq.k, pq.k);
        assert_eq!(dec_pq.sub_dim, pq.sub_dim);
        assert_eq!(dec_pq.codebooks, pq.codebooks);
    }

    /// Trait-based encoding falls back to the binary layout for the
    /// Cold-tier wrapper and preserves its dimension.
    #[test]
    fn encode_quant_seg_binary_round_trip() {
        let bq = BinaryQuantizerWrapper { dim: 16 };
        let encoded = encode_quant_seg(&bq);
        let decoded = decode_quant_seg(&encoded).unwrap();

        assert_eq!(decoded.dim(), 16);
        assert_eq!(decoded.tier(), crate::tier::TemperatureTier::Cold);
    }

    /// Malformed QUANT_SEG payloads produce errors rather than panics.
    #[test]
    fn decode_quant_seg_malformed_inputs() {
        // Header too short.
        assert!(matches!(
            decode_quant_seg(&[0u8; 8]),
            Err(CodecError::TooShort)
        ));

        // Unknown quant_type tag.
        let mut bad_type = vec![0u8; 64];
        bad_type[0] = 9;
        assert!(matches!(
            decode_quant_seg(&bad_type),
            Err(CodecError::UnknownQuantType(9))
        ));

        // Scalar header claims dim 4 but carries no min/max body.
        let mut truncated = vec![0u8; 64];
        truncated[0] = 0; // scalar
        truncated[2..4].copy_from_slice(&4u16.to_le_bytes());
        assert!(matches!(
            decode_quant_seg(&truncated),
            Err(CodecError::TooShort)
        ));

        // Product header present but codebook data missing.
        let mut pq_truncated = vec![0u8; 64];
        pq_truncated[0] = 1; // product
        pq_truncated[2..4].copy_from_slice(&4u16.to_le_bytes());
        pq_truncated.extend_from_slice(&2u16.to_le_bytes()); // m
        pq_truncated.extend_from_slice(&4u16.to_le_bytes()); // k
        pq_truncated.extend_from_slice(&2u16.to_le_bytes()); // sub_dim
        assert!(matches!(
            decode_quant_seg(&pq_truncated),
            Err(CodecError::TooShort)
        ));
    }

    /// A RaBitQ quantizer round-trips all fields, produces identical codes
    /// after decoding, and is encoded identically via the trait-based path.
    #[test]
    fn rabitq_quant_seg_round_trip() {
        let centroid: Vec<f32> = (0..20).map(|i| i as f32 * 0.1 - 1.0).collect();
        let rq = RabitqQuantizer::with_centroid(20, centroid.clone(), 0x1234_5678_9ABC_DEF0, 3);

        let encoded = encode_rabitq_quantizer(&rq);
        let decoded = decode_quant_seg(&encoded).unwrap();
        assert_eq!(decoded.dim(), 20);
        assert_eq!(decoded.tier(), crate::tier::TemperatureTier::Cold);

        let any: &dyn core::any::Any = decoded.as_ref();
        let dec = any
            .downcast_ref::<RabitqQuantizer>()
            .expect("expected RabitqQuantizer");
        assert_eq!(dec.dim, rq.dim);
        assert_eq!(dec.padded_dim, 32);
        assert_eq!(dec.seed, rq.seed);
        assert_eq!(dec.rounds, rq.rounds);
        assert_eq!(dec.centroid, centroid);

        // The decoded quantizer must produce byte-identical codes.
        let v: Vec<f32> = (0..20).map(|i| (i as f32 * 0.7).sin()).collect();
        assert_eq!(dec.encode(&v), rq.encode(&v));

        // Trait-based encode dispatches to the RaBitQ layout too.
        assert_eq!(encode_quant_seg(&rq), encoded);
    }

    /// RaBitQ decoding rejects unknown versions, inconsistent header fields
    /// and truncated centroid data.
    #[test]
    fn rabitq_quant_seg_rejects_bad_versions_and_fields() {
        let rq = RabitqQuantizer::with_centroid(8, vec![0.0; 8], 7, 3);
        let good = encode_rabitq_quantizer(&rq);

        // Future layout version: reject instead of misreading.
        let mut future = good.clone();
        future[4] = RABITQ_VERSION + 1;
        assert!(matches!(
            decode_quant_seg(&future),
            Err(CodecError::UnsupportedVersion(v)) if v == RABITQ_VERSION + 1
        ));

        // Inconsistent padded_dim.
        let mut bad_pad = good.clone();
        bad_pad[16..20].copy_from_slice(&7u32.to_le_bytes());
        assert!(matches!(
            decode_quant_seg(&bad_pad),
            Err(CodecError::InvalidField)
        ));

        // Truncated centroid body.
        assert!(matches!(
            decode_quant_seg(&good[..good.len() - 4]),
            Err(CodecError::TooShort)
        ));

        // Zero rounds.
        let mut zero_rounds = good.clone();
        zero_rounds[5] = 0;
        assert!(matches!(
            decode_quant_seg(&zero_rounds),
            Err(CodecError::InvalidField)
        ));
    }

    /// Payloads written before the RaBitQ header extension still decode.
    #[test]
    fn pre_rabitq_payloads_still_decode() {
        // A byte-frozen legacy binary-quantizer payload (type 2, header
        // bytes 4..64 all zero, no body) must keep decoding after the
        // RaBitQ extension claimed header bytes 4..20 for type 4.
        let mut legacy = vec![0u8; 64];
        legacy[0] = 2; // QUANT_TYPE_BINARY
        legacy[1] = 2; // Cold tier
        legacy[2..4].copy_from_slice(&24u16.to_le_bytes());
        let decoded = decode_quant_seg(&legacy).unwrap();
        assert_eq!(decoded.dim(), 24);
        assert_eq!(decoded.tier(), crate::tier::TemperatureTier::Cold);

        // Same for a legacy scalar payload.
        let sq = ScalarQuantizer {
            min_vals: vec![-1.0, 0.0],
            max_vals: vec![1.0, 2.0],
            dim: 2,
        };
        let legacy_scalar = encode_scalar_quantizer(&sq);
        assert!(decode_quant_seg(&legacy_scalar).is_ok());
    }

    /// Oversized PQ header fields are rejected by the length check instead
    /// of wrapping the size computation.
    #[test]
    fn decode_product_rejects_huge_codebook_dimensions() {
        // m = k = sub_dim = u16::MAX -> codebook of ~1.1e15 bytes. The
        // u64 checked size computation must reject this against the
        // actual body length instead of wrapping usize on 32-bit targets
        // (wasm32) and reading out of bounds.
        let mut pq = vec![0u8; 64];
        pq[0] = QUANT_TYPE_PRODUCT;
        pq[2..4].copy_from_slice(&4u16.to_le_bytes());
        pq.extend_from_slice(&u16::MAX.to_le_bytes()); // m
        pq.extend_from_slice(&u16::MAX.to_le_bytes()); // k
        pq.extend_from_slice(&u16::MAX.to_le_bytes()); // sub_dim
        assert!(matches!(decode_quant_seg(&pq), Err(CodecError::TooShort)));
    }

    /// Malformed SKETCH_SEG payloads produce errors rather than panics,
    /// while the consistent empty sketch still decodes.
    #[test]
    fn decode_sketch_seg_rejects_malformed_inputs() {
        // Header too short: error, not panic.
        assert!(matches!(decode_sketch_seg(&[]), Err(CodecError::TooShort)));
        assert!(matches!(
            decode_sketch_seg(&[0u8; 16]),
            Err(CodecError::TooShort)
        ));

        // width = 0 + depth = u32::MAX: expected counter bytes are 0, so
        // the length check alone passes; the zero-width guard must reject
        // it before the depth-sized allocation OOMs.
        let mut zero_width = vec![0u8; 64];
        zero_width[4..8].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(matches!(
            decode_sketch_seg(&zero_width),
            Err(CodecError::InvalidField)
        ));

        // width = depth = u32::MAX: product (~1.8e19) wraps a 32-bit
        // usize; the checked u64 arithmetic must reject it against the
        // body length.
        let mut huge = vec![0u8; 64];
        huge[0..4].copy_from_slice(&u32::MAX.to_le_bytes());
        huge[4..8].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(matches!(
            decode_sketch_seg(&huge),
            Err(CodecError::TooShort)
        ));

        // Counter data shorter than width * depth.
        let mut truncated = vec![0u8; 64 + 10];
        truncated[0..4].copy_from_slice(&8u32.to_le_bytes()); // width
        truncated[4..8].copy_from_slice(&4u32.to_le_bytes()); // depth -> needs 32
        assert!(matches!(
            decode_sketch_seg(&truncated),
            Err(CodecError::TooShort)
        ));

        // Degenerate-but-consistent empty sketch (width = depth = 0)
        // still decodes.
        let empty = decode_sketch_seg(&[0u8; 64]).expect("empty sketch decodes");
        assert_eq!(empty.width, 0);
        assert_eq!(empty.depth, 0);
        assert!(empty.counters.is_empty());
    }

    /// A populated sketch round-trips its shape, access total and every
    /// per-block estimate.
    #[test]
    fn sketch_seg_round_trip() {
        let mut sketch = CountMinSketch::new(64, 4);
        // Block `i` is incremented `i + 1` times.
        for block_id in 0..20u64 {
            for _ in 0..(block_id + 1) {
                sketch.increment(block_id);
            }
        }

        let encoded = encode_sketch_seg(&sketch);
        let decoded = decode_sketch_seg(&encoded).expect("well-formed sketch should decode");

        assert_eq!(decoded.width, sketch.width);
        assert_eq!(decoded.depth, sketch.depth);
        assert_eq!(decoded.total_accesses, sketch.total_accesses);

        // Verify estimates match
        for block_id in 0..20u64 {
            assert_eq!(decoded.estimate(block_id), sketch.estimate(block_id));
        }
    }
}