//! TurboQuant multi-mode weight quantization with randomized Hadamard rotation.
//!
//! Compresses weight vectors using one of three quantization modes:
//!
//! | Mode | Levels | Packing | Compression vs f64 | Typical error |
//! |------|--------|---------|--------------------|---------------|
//! | 8-bit | 256 | 4 per u32 | ~8x | ~0.4% |
//! | 3.5-bit | 11 | 7 per u32 (base-11) | ~14x | ~10% |
//! | 2.5-bit | 5 | 13 per u32 (base-5) | ~26x | ~20% |
//!
//! # Design
//!
//! - **Data-oblivious**: No calibration set required -- randomized rotation + min/max scaling
//! - **Online-compatible**: Quantize once after training; prediction dequantizes on the fly
//! - **Embedded-friendly**: [`TurboQuantizedView`] is zero-copy from `&[u8]`
//! - **Zero-alloc predict**: [`predict_with_scratch`](TurboQuantized::predict_with_scratch)
//!   avoids allocation when given a caller-provided scratch buffer
//! - **Multi-type input**: [`quantize_f32`] and [`quantize_i16`] accept non-f64 weights
//! - **Hadamard rotation**: Applies `H * D * w` before quantization where `D` is a
//!   random sign-flip diagonal and `H` is the normalized Walsh-Hadamard matrix.
//!   This decorrelates weight dimensions so quantization error distributes uniformly.
//!
//! # Packing
//!
//! ```text
//! 8-bit:   4 values x 256 levels = byte packing, 4 per u32 (shift encoding)
//! 3.5-bit: 7 values x  11 levels = 11^7 = 19,487,171 states <= 2^25 (base-11)
//! 2.5-bit: 13 values x  5 levels = 5^13 = 1,220,703,125 states <= 2^31 (base-5)
//! ```
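//!
//! For example, packing the 3.5-bit values `[3, 7, 1]` (index 0 first) gives
//! `3 + 11*7 + 11^2*1 = 201`; decoding takes the remainder and quotient by 11
//! repeatedly to recover `3`, `7`, `1`.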
//!
//! # References
//!
//! Inspired by data-oblivious quantization (Google/NYU, ICLR 2026).

use alloc::vec;
use alloc::vec::Vec;

// ---------------------------------------------------------------------------
// QuantMode
// ---------------------------------------------------------------------------

/// Quantization bit depth. Controls the quality/compression tradeoff.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum QuantMode {
    /// 8-bit: 256 levels, 4 values per u32. Near-lossless (~0.4% max error).
    /// ~8x compression vs f64. Simple byte packing.
    Bits8,
    /// 3.5-bit: 11 levels, 7 values per u32. Aggressive (~10% max error).
    /// ~14x compression vs f64. Base-11 mixed-radix packing.
    Bits3_5,
    /// 2.5-bit: 5 levels, 13 values per u32. Ultra-aggressive (~20% max error).
    /// ~26x compression vs f64. Base-5 mixed-radix packing.
    Bits2_5,
}

impl QuantMode {
    /// Number of quantization levels for this mode.
    #[inline]
    fn n_levels(self) -> u32 {
        match self {
            QuantMode::Bits8 => N_LEVELS_8,
            QuantMode::Bits3_5 => N_LEVELS_3_5,
            QuantMode::Bits2_5 => N_LEVELS_2_5,
        }
    }

    /// Number of values packed per u32 word for this mode.
    #[inline]
    fn values_per_word(self) -> usize {
        match self {
            QuantMode::Bits8 => VALUES_PER_WORD_8,
            QuantMode::Bits3_5 => VALUES_PER_WORD_3_5,
            QuantMode::Bits2_5 => VALUES_PER_WORD_2_5,
        }
    }

    /// Encode mode as u32 for serialization.
    #[inline]
    fn to_u32(self) -> u32 {
        match self {
            QuantMode::Bits8 => 0,
            QuantMode::Bits3_5 => 1,
            QuantMode::Bits2_5 => 2,
        }
    }

    /// Decode mode from u32. Returns `None` for unknown values.
    #[inline]
    fn from_u32(v: u32) -> Option<Self> {
        match v {
            0 => Some(QuantMode::Bits8),
            1 => Some(QuantMode::Bits3_5),
            2 => Some(QuantMode::Bits2_5),
            _ => None,
        }
    }
}

// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------

/// 8-bit: 256 levels (0..=255).
const N_LEVELS_8: u32 = 256;
/// 8-bit: 4 values per u32.
const VALUES_PER_WORD_8: usize = 4;

/// 3.5-bit: 11 levels (0..=10).
const N_LEVELS_3_5: u32 = 11;
/// 3.5-bit: 7 values per u32.
const VALUES_PER_WORD_3_5: usize = 7;

/// 2.5-bit: 5 levels (0..=4).
const N_LEVELS_2_5: u32 = 5;
/// 2.5-bit: 13 values per u32.
const VALUES_PER_WORD_2_5: usize = 13;

// ---------------------------------------------------------------------------
// Packing: 8-bit (4 values per u32, byte shift encoding)
// ---------------------------------------------------------------------------

/// Pack up to 4 quantized u8 values into a single `u32` using byte shift encoding.
#[inline]
fn pack4_bytes(values: &[u8]) -> u32 {
    let mut packed: u32 = 0;
    for (i, &v) in values.iter().enumerate().take(4) {
        packed |= (v as u32) << (i * 8);
    }
    packed
}

/// Unpack a `u32` into up to 4 quantized u8 values.
#[inline]
fn unpack4_bytes(packed: u32, count: usize) -> [u8; 4] {
    let mut values = [0u8; 4];
    for (i, v) in values.iter_mut().enumerate().take(count) {
        *v = ((packed >> (i * 8)) & 0xFF) as u8;
    }
    values
}

// ---------------------------------------------------------------------------
// Packing: 3.5-bit (7 values per u32, base-11 mixed-radix)
// ---------------------------------------------------------------------------

/// Pack up to 7 quantized values (each in 0..=10) into a single `u32`.
///
/// Uses base-11 mixed-radix encoding: `v0 + 11*v1 + 11^2*v2 + ... + 11^6*v6`.
/// The maximum packed value is `11^7 - 1 = 19,487,170`, which fits in 25 bits.
///
/// `values` must have length <= 7, and each element must be in `0..=10`.
#[inline]
fn pack7(values: &[u8]) -> u32 {
    debug_assert!(values.len() <= 7);
    let mut packed: u32 = 0;
    for &v in values.iter().rev() {
        debug_assert!(v < N_LEVELS_3_5 as u8);
        packed = packed * N_LEVELS_3_5 + v as u32;
    }
    packed
}

/// Unpack a `u32` into up to 7 quantized values.
///
/// Extracts `count` values from the base-11 mixed-radix encoding.
/// Remaining slots in the returned array are zero-filled.
#[inline]
fn unpack7(packed: u32, count: usize) -> [u8; 7] {
    let mut values = [0u8; 7];
    let mut p = packed;
    for v in values.iter_mut().take(count) {
        *v = (p % N_LEVELS_3_5) as u8;
        p /= N_LEVELS_3_5;
    }
    values
}

// ---------------------------------------------------------------------------
// Packing: 2.5-bit (13 values per u32, base-5 mixed-radix)
// ---------------------------------------------------------------------------

/// Pack up to 13 quantized values (each in 0..=4) into a single `u32`.
///
/// Uses base-5 mixed-radix encoding. The maximum packed value is
/// `5^13 - 1 = 1,220,703,124`, which fits in 31 bits.
#[inline]
fn pack13(values: &[u8]) -> u32 {
    debug_assert!(values.len() <= 13);
    let mut packed: u32 = 0;
    for &v in values.iter().rev() {
        debug_assert!(v < N_LEVELS_2_5 as u8);
        packed = packed * N_LEVELS_2_5 + v as u32;
    }
    packed
}

/// Unpack a `u32` into up to 13 quantized values.
///
/// Extracts `count` values from the base-5 mixed-radix encoding.
/// Remaining slots in the returned array are zero-filled.
#[inline]
fn unpack13(packed: u32, count: usize) -> [u8; 13] {
    let mut values = [0u8; 13];
    let mut p = packed;
    for v in values.iter_mut().take(count) {
        *v = (p % N_LEVELS_2_5) as u8;
        p /= N_LEVELS_2_5;
    }
    values
}

// ---------------------------------------------------------------------------
// Generic mode-dispatched packing
// ---------------------------------------------------------------------------

/// Pack a slice of quantized values into a u32 word, dispatching on mode.
#[inline]
fn pack_word(values: &[u8], mode: QuantMode) -> u32 {
    match mode {
        QuantMode::Bits8 => pack4_bytes(values),
        QuantMode::Bits3_5 => pack7(values),
        QuantMode::Bits2_5 => pack13(values),
    }
}

// ---------------------------------------------------------------------------
// TurboQuantized (owned)
// ---------------------------------------------------------------------------

/// Quantized weight vector (owned).
///
/// Created by [`quantize`], [`quantize_weights`], [`quantize_f32`], or
/// [`quantize_i16`]. Supports inference via [`predict`](Self::predict)
/// and serialization via [`to_bytes`](Self::to_bytes).
///
/// Weights are stored in Hadamard-rotated space (`H * D * w`). During prediction,
/// the same rotation is applied to features so the dot product is preserved.
pub struct TurboQuantized {
    /// Packed u32 words. Format depends on `mode`:
    /// - Bits8: 4 byte-shift-encoded u8 values per u32
    /// - Bits3_5: 7 base-11 mixed-radix values per u32
    /// - Bits2_5: 13 base-5 mixed-radix values per u32
    packed: Vec<u32>,
    /// Number of original weights (last word may be partially filled).
    n_weights: usize,
    /// Scale factor for dequantization. Zero means all weights are identical.
    scale: f64,
    /// Offset: minimum weight value (in rotated space).
    offset: f64,
    /// Seed for the random sign-flip diagonal (needed to reproduce rotation).
    seed: u64,
    /// Power-of-2 padded length used for FWHT.
    padded_len: usize,
    /// Quantization mode.
    mode: QuantMode,
}

impl TurboQuantized {
    /// Dot product of quantized weights with a feature vector.
    ///
    /// Applies the same Hadamard rotation to `features`, then computes
    /// the dot product with the quantized rotated weights over the full
    /// padded length. Since `HD` is orthogonal, `w . x == (HD*w) . (HD*x)`.
    pub fn predict(&self, features: &[f64]) -> f64 {
        if self.n_weights == 0 {
            return 0.0;
        }
        // Rotate features with the same transform (pad to padded_len)
        let mut rotated_features = Vec::with_capacity(self.padded_len);
        let use_len = self.n_weights.min(features.len());
        rotated_features.extend_from_slice(&features[..use_len]);
        rotated_features.resize(self.padded_len, 0.0);
        apply_rotation(&mut rotated_features, self.seed);

        self.dot_with_rotated(&rotated_features)
    }

    /// Predict using a caller-provided scratch buffer for the Hadamard rotation.
    ///
    /// `scratch` must have length >= `padded_len`. This avoids allocation,
    /// making it suitable for embedded inference loops.
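    ///
    /// A minimal usage sketch (the scratch buffer is allocated once and reused
    /// across calls; the weights and seed here are arbitrary):
    ///
    /// ```
    /// use irithyll_core::turbo_quant::{quantize, QuantMode};
    ///
    /// let q = quantize(&[0.5, -0.25, 1.0, 0.75], QuantMode::Bits8, 42);
    /// // One scratch allocation sized to the padded (power-of-2) length.
    /// let mut scratch = vec![0.0; q.padded_len()];
    /// for features in [[1.0, 0.0, 1.0, 0.0], [0.5, 0.5, 0.5, 0.5]] {
    ///     let pred = q.predict_with_scratch(&features, &mut scratch);
    ///     assert!(pred.is_finite());
    /// }
    /// ```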
    pub fn predict_with_scratch(&self, features: &[f64], scratch: &mut [f64]) -> f64 {
        if self.n_weights == 0 {
            return 0.0;
        }
        assert!(
            scratch.len() >= self.padded_len,
            "scratch buffer too small: {} < {}",
            scratch.len(),
            self.padded_len
        );

        // Zero and fill scratch
        for v in scratch[..self.padded_len].iter_mut() {
            *v = 0.0;
        }
        let use_len = self.n_weights.min(features.len());
        scratch[..use_len].copy_from_slice(&features[..use_len]);

        // Rotate in-place
        apply_rotation(&mut scratch[..self.padded_len], self.seed);

        self.dot_with_rotated(&scratch[..self.padded_len])
    }

    /// Compute dot product of packed weights with already-rotated features.
    fn dot_with_rotated(&self, rotated_features: &[f64]) -> f64 {
        let mut sum = 0.0;
        let mut feat_idx = 0;
        let vpw = self.mode.values_per_word();

        for &word in self.packed.iter() {
            let remaining = self.padded_len - feat_idx;
            let count = remaining.min(vpw);
            // Inline unpack + dot to avoid temporary array overhead
            match self.mode {
                QuantMode::Bits8 => {
                    let values = unpack4_bytes(word, count);
                    for &q in values.iter().take(count) {
                        let w = q as f64 * self.scale + self.offset;
                        sum += w * rotated_features[feat_idx];
                        feat_idx += 1;
                    }
                }
                QuantMode::Bits3_5 => {
                    let values = unpack7(word, count);
                    for &q in values.iter().take(count) {
                        let w = q as f64 * self.scale + self.offset;
                        sum += w * rotated_features[feat_idx];
                        feat_idx += 1;
                    }
                }
                QuantMode::Bits2_5 => {
                    let values = unpack13(word, count);
                    for &q in values.iter().take(count) {
                        let w = q as f64 * self.scale + self.offset;
                        sum += w * rotated_features[feat_idx];
                        feat_idx += 1;
                    }
                }
            }
            if feat_idx >= self.padded_len {
                break;
            }
        }
        sum
    }

    /// Dequantize all weights back to `f64` (approximate original space).
    ///
    /// Unpacks all `padded_len` rotated values, then applies the inverse
    /// Hadamard rotation to recover approximate original weights.
    pub fn dequantize(&self) -> Vec<f64> {
        let mut rotated = Vec::with_capacity(self.padded_len);
        let mut count_total = 0;
        let vpw = self.mode.values_per_word();

        for &word in self.packed.iter() {
            let remaining = self.padded_len - count_total;
            let count = remaining.min(vpw);
            match self.mode {
                QuantMode::Bits8 => {
                    let values = unpack4_bytes(word, count);
                    for &q in values.iter().take(count) {
                        rotated.push(q as f64 * self.scale + self.offset);
                        count_total += 1;
                    }
                }
                QuantMode::Bits3_5 => {
                    let values = unpack7(word, count);
                    for &q in values.iter().take(count) {
                        rotated.push(q as f64 * self.scale + self.offset);
                        count_total += 1;
                    }
                }
                QuantMode::Bits2_5 => {
                    let values = unpack13(word, count);
                    for &q in values.iter().take(count) {
                        rotated.push(q as f64 * self.scale + self.offset);
                        count_total += 1;
                    }
                }
            }
            if count_total >= self.padded_len {
                break;
            }
        }
        // Apply inverse rotation to recover original space
        apply_inverse_rotation(&mut rotated, self.seed);
        rotated.truncate(self.n_weights);
        rotated
    }

    /// Number of quantized weights.
    pub fn n_weights(&self) -> usize {
        self.n_weights
    }

    /// Power-of-2 padded length used for FWHT (needed for scratch allocation).
    pub fn padded_len(&self) -> usize {
        self.padded_len
    }

    /// Quantization mode used.
    pub fn mode(&self) -> QuantMode {
        self.mode
    }

    /// Compression ratio vs `f64` (original bytes / packed bytes).
    pub fn compression_ratio(&self) -> f64 {
        let original_bytes = self.n_weights * 8; // f64
        let packed_bytes = self.packed.len() * 4 + HEADER_SIZE;
        original_bytes as f64 / packed_bytes as f64
    }

    /// Serialize to bytes for embedded deployment.
    ///
    /// Format (36-byte header):
    /// ```text
    /// [n_weights: u32 LE]
    /// [mode: u32 LE]        // 0=Bits8, 1=Bits3_5, 2=Bits2_5
    /// [seed: u64 LE]
    /// [padded_len: u32 LE]
    /// [scale: f64 LE]
    /// [offset: f64 LE]
    /// [packed_words: u32 LE...]
    /// ```
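    ///
    /// The total size is `36 + 4 * ceil(padded_len / values_per_word)` bytes.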
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(HEADER_SIZE + self.packed.len() * 4);
        buf.extend_from_slice(&(self.n_weights as u32).to_le_bytes());
        buf.extend_from_slice(&self.mode.to_u32().to_le_bytes());
        buf.extend_from_slice(&self.seed.to_le_bytes());
        buf.extend_from_slice(&(self.padded_len as u32).to_le_bytes());
        buf.extend_from_slice(&self.scale.to_le_bytes());
        buf.extend_from_slice(&self.offset.to_le_bytes());
        for &word in &self.packed {
            buf.extend_from_slice(&word.to_le_bytes());
        }
        buf
    }
}

// ---------------------------------------------------------------------------
// TurboQuantizedView (zero-copy)
// ---------------------------------------------------------------------------

/// Zero-copy view over a TurboQuant packed binary.
///
/// Constructed from `&[u8]` with no allocation -- suitable for embedded
/// targets where the binary is in flash/ROM. Note: `predict` does allocate
/// for the Hadamard rotation of the feature vector; use
/// [`predict_with_scratch`](Self::predict_with_scratch) for zero-alloc inference.
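///
/// A sketch of the round trip from an owned [`TurboQuantized`] to a view (in a
/// real deployment the bytes would live in flash/ROM rather than a `Vec`):
///
/// ```
/// use irithyll_core::turbo_quant::{quantize, QuantMode, TurboQuantizedView};
///
/// let q = quantize(&[0.1, -0.2, 0.3, 0.4], QuantMode::Bits3_5, 7);
/// let bytes = q.to_bytes();
///
/// // Parse without copying the packed words.
/// let view = TurboQuantizedView::from_bytes(&bytes).expect("well-formed binary");
/// let mut scratch = vec![0.0; view.padded_len()];
/// let pred = view.predict_with_scratch(&[1.0, 1.0, 1.0, 1.0], &mut scratch);
/// assert!(pred.is_finite());
/// ```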
pub struct TurboQuantizedView<'a> {
    /// Raw bytes of packed u32 words.
    packed: &'a [u8],
    /// Number of original weights.
    n_weights: usize,
    /// Seed for the random sign-flip diagonal.
    seed: u64,
    /// Power-of-2 padded length used for FWHT.
    padded_len: usize,
    /// Scale factor for dequantization.
    scale: f64,
    /// Offset (minimum weight value in rotated space) for dequantization.
    offset: f64,
    /// Quantization mode.
    mode: QuantMode,
}

/// Header size in bytes: n_weights(4) + mode(4) + seed(8) + padded_len(4) + scale(8) + offset(8) = 36.
const HEADER_SIZE: usize = 36;

impl<'a> TurboQuantizedView<'a> {
    /// Parse a TurboQuant binary from raw bytes.
    ///
    /// Returns [`FormatError::Truncated`](crate::error::FormatError::Truncated)
    /// if the buffer is too short, has inconsistent length, or contains an
    /// unknown quantization mode.
    pub fn from_bytes(bytes: &'a [u8]) -> Result<Self, crate::error::FormatError> {
        if bytes.len() < HEADER_SIZE {
            return Err(crate::error::FormatError::Truncated);
        }
        let n_weights = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
        let mode_raw = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
        let mode = QuantMode::from_u32(mode_raw).ok_or(crate::error::FormatError::Truncated)?;
        let seed = u64::from_le_bytes([
            bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
        ]);
        let padded_len = u32::from_le_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]) as usize;
        let scale = f64::from_le_bytes([
            bytes[20], bytes[21], bytes[22], bytes[23], bytes[24], bytes[25], bytes[26], bytes[27],
        ]);
        let offset = f64::from_le_bytes([
            bytes[28], bytes[29], bytes[30], bytes[31], bytes[32], bytes[33], bytes[34], bytes[35],
        ]);

        // Packed data holds padded_len values
        let vpw = mode.values_per_word();
        let n_words = padded_len.div_ceil(vpw);
        let expected_len = HEADER_SIZE + n_words * 4;
        if bytes.len() < expected_len {
            return Err(crate::error::FormatError::Truncated);
        }

        Ok(Self {
            packed: &bytes[HEADER_SIZE..HEADER_SIZE + n_words * 4],
            n_weights,
            seed,
            padded_len,
            scale,
            offset,
            mode,
        })
    }

    /// Dot product of quantized weights with a feature vector.
    ///
    /// Applies the same Hadamard rotation to `features`, then computes
    /// the dot product with the quantized rotated weights over the full
    /// padded length.
    pub fn predict(&self, features: &[f64]) -> f64 {
        if self.n_weights == 0 {
            return 0.0;
        }
        // Rotate features with the same transform (pad to padded_len)
        let mut rotated_features = Vec::with_capacity(self.padded_len);
        let use_len = self.n_weights.min(features.len());
        rotated_features.extend_from_slice(&features[..use_len]);
        rotated_features.resize(self.padded_len, 0.0);
        apply_rotation(&mut rotated_features, self.seed);

        self.dot_with_rotated(&rotated_features)
    }

    /// Predict using a caller-provided scratch buffer for the Hadamard rotation.
    ///
    /// `scratch` must have length >= `padded_len`. This avoids allocation,
    /// making it suitable for embedded inference loops.
    pub fn predict_with_scratch(&self, features: &[f64], scratch: &mut [f64]) -> f64 {
        if self.n_weights == 0 {
            return 0.0;
        }
        assert!(
            scratch.len() >= self.padded_len,
            "scratch buffer too small: {} < {}",
            scratch.len(),
            self.padded_len
        );

        for v in scratch[..self.padded_len].iter_mut() {
            *v = 0.0;
        }
        let use_len = self.n_weights.min(features.len());
        scratch[..use_len].copy_from_slice(&features[..use_len]);
        apply_rotation(&mut scratch[..self.padded_len], self.seed);

        self.dot_with_rotated(&scratch[..self.padded_len])
    }

    /// Compute dot product of packed weights with already-rotated features.
    fn dot_with_rotated(&self, rotated_features: &[f64]) -> f64 {
        let mut sum = 0.0;
        let mut feat_idx = 0;
        let vpw = self.mode.values_per_word();
        let n_words = self.packed.len() / 4;

        for word_idx in 0..n_words {
            let off = word_idx * 4;
            let word = u32::from_le_bytes([
                self.packed[off],
                self.packed[off + 1],
                self.packed[off + 2],
                self.packed[off + 3],
            ]);
            let remaining = self.padded_len - feat_idx;
            let count = remaining.min(vpw);
            match self.mode {
                QuantMode::Bits8 => {
                    let values = unpack4_bytes(word, count);
                    for &q in values.iter().take(count) {
                        let w = q as f64 * self.scale + self.offset;
                        sum += w * rotated_features[feat_idx];
                        feat_idx += 1;
                    }
                }
                QuantMode::Bits3_5 => {
                    let values = unpack7(word, count);
                    for &q in values.iter().take(count) {
                        let w = q as f64 * self.scale + self.offset;
                        sum += w * rotated_features[feat_idx];
                        feat_idx += 1;
                    }
                }
                QuantMode::Bits2_5 => {
                    let values = unpack13(word, count);
                    for &q in values.iter().take(count) {
                        let w = q as f64 * self.scale + self.offset;
                        sum += w * rotated_features[feat_idx];
                        feat_idx += 1;
                    }
                }
            }
            if feat_idx >= self.padded_len {
                break;
            }
        }
        sum
    }

    /// Number of weights in this view.
    pub fn n_weights(&self) -> usize {
        self.n_weights
    }

    /// Power-of-2 padded length used for FWHT.
    pub fn padded_len(&self) -> usize {
        self.padded_len
    }

    /// Quantization mode.
    pub fn mode(&self) -> QuantMode {
        self.mode
    }
}

// ---------------------------------------------------------------------------
// Hadamard rotation internals
// ---------------------------------------------------------------------------

/// Default deterministic seed for Hadamard rotation.
const DEFAULT_SEED: u64 = 0xDEAD_BEEF;

/// Smallest power of 2 >= `n`. Returns 1 for `n == 0`.
#[inline]
fn next_power_of_two(n: usize) -> usize {
    if n <= 1 {
        return 1;
    }
    // Bit trick: round up to next power of 2
    let mut v = n - 1;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    #[cfg(target_pointer_width = "64")]
    {
        v |= v >> 32;
    }
    v + 1
}

/// In-place Fast Walsh-Hadamard Transform (normalized).
///
/// `x` must have a power-of-2 length. The normalized transform is orthogonal and
/// self-inverse: applying it twice recovers the original vector.
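///
/// For example, on length 4 the normalized transform maps `[1.0, 1.0, 0.0, 0.0]`
/// to `[1.0, 0.0, 1.0, 0.0]`, and applying it again recovers the input.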
fn fwht_inplace(x: &mut [f64]) {
    let n = x.len();
    debug_assert!(
        n > 0 && (n & (n - 1)) == 0,
        "FWHT requires power-of-2 length"
    );
    let mut h = 1;
    while h < n {
        for i in (0..n).step_by(h * 2) {
            for j in i..i + h {
                let a = x[j];
                let b = x[j + h];
                x[j] = a + b;
                x[j + h] = a - b;
            }
        }
        h *= 2;
    }
    let scale = 1.0 / crate::math::sqrt(n as f64);
    for v in x.iter_mut() {
        *v *= scale;
    }
}

/// Apply random sign flips (diagonal D matrix) to `x`.
///
/// `D` is self-inverse: applying the same sign flips twice recovers the original.
fn apply_sign_flip(x: &mut [f64], seed: u64) {
    let mut state = seed;
    for v in x.iter_mut() {
        let r = crate::rng::xorshift64(&mut state);
        if r & 1 == 0 {
            *v = -*v;
        }
    }
}

/// Apply the full Hadamard rotation `H * D * x`.
fn apply_rotation(buf: &mut [f64], seed: u64) {
    apply_sign_flip(buf, seed);
    fwht_inplace(buf);
}

/// Apply the inverse Hadamard rotation `D * H * x` to recover original space.
fn apply_inverse_rotation(buf: &mut [f64], seed: u64) {
    fwht_inplace(buf);
    apply_sign_flip(buf, seed);
}

// ---------------------------------------------------------------------------
// Public quantization API
// ---------------------------------------------------------------------------

/// Quantize a weight vector to 3.5-bit TurboQuant format.
///
/// Applies a randomized Hadamard rotation before quantization to decorrelate
/// weight dimensions, then compresses using an 11-level linear grid with
/// min/max scaling. Uses a deterministic default seed.
///
/// This is a backwards-compatible wrapper around [`quantize`] with
/// [`QuantMode::Bits3_5`] and the default seed.
///
/// # Example
///
/// ```
/// use irithyll_core::turbo_quant::quantize_weights;
///
/// let weights = vec![0.1, -0.5, 0.3, 0.0, -0.2, 0.4, 0.1, -0.1, 0.2];
/// let quantized = quantize_weights(&weights);
/// assert_eq!(quantized.n_weights(), 9);
///
/// // Predict with on-the-fly dequantization
/// let features = vec![1.0; 9];
/// let pred = quantized.predict(&features);
/// assert!(pred.is_finite());
///
/// // Roundtrip check
/// let original_dot: f64 = weights.iter().zip(features.iter()).map(|(w, f)| w * f).sum();
/// assert!((pred - original_dot).abs() < 0.5, "quantization error should be small");
/// ```
pub fn quantize_weights(weights: &[f64]) -> TurboQuantized {
    quantize(weights, QuantMode::Bits3_5, DEFAULT_SEED)
}

/// Quantize with an explicit seed for the Hadamard rotation (3.5-bit mode).
///
/// Backwards-compatible wrapper around [`quantize`].
pub fn quantize_weights_with_seed(weights: &[f64], seed: u64) -> TurboQuantized {
    quantize(weights, QuantMode::Bits3_5, seed)
}

/// Quantize a weight vector with explicit mode and seed.
///
/// Applies a randomized Hadamard rotation before quantization to decorrelate
/// weight dimensions, then compresses using a linear grid with min/max
/// scaling at the specified bit depth.
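///
/// A small sketch comparing the three modes on the same weights (the seed is
/// arbitrary and the assertions are deliberately loose):
///
/// ```
/// use irithyll_core::turbo_quant::{quantize, QuantMode};
///
/// let w = vec![0.2, -0.4, 0.6, -0.8, 1.0, -1.2, 1.4, -1.6];
/// for mode in [QuantMode::Bits8, QuantMode::Bits3_5, QuantMode::Bits2_5] {
///     let q = quantize(&w, mode, 0xABCD);
///     assert_eq!(q.n_weights(), 8);
///     // Header overhead dominates at this size; larger vectors approach the
///     // ratios in the module-level table.
///     assert!(q.compression_ratio() > 1.0);
/// }
/// ```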
pub fn quantize(weights: &[f64], mode: QuantMode, seed: u64) -> TurboQuantized {
    if weights.is_empty() {
        return TurboQuantized {
            packed: vec![],
            n_weights: 0,
            scale: 0.0,
            offset: 0.0,
            seed,
            padded_len: 1,
            mode,
        };
    }

    // Apply Hadamard rotation: pad to power of 2, sign flip, FWHT
    let padded_len = next_power_of_two(weights.len());
    let mut rotated = Vec::with_capacity(padded_len);
    rotated.extend_from_slice(weights);
    rotated.resize(padded_len, 0.0);
    apply_rotation(&mut rotated, seed);

    // Quantize ALL padded_len rotated values (rotation spreads information
    // across the full padded vector, so truncating would lose data).
    let min_val = rotated.iter().copied().fold(f64::INFINITY, f64::min);
    let max_val = rotated.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let range = max_val - min_val;
    let n_levels = mode.n_levels();
    let max_level = n_levels - 1;
    let scale = if range < 1e-15 {
        0.0
    } else {
        range / max_level as f64
    };

    // Quantize each rotated weight to [0, max_level]
    let quantized: Vec<u8> = rotated
        .iter()
        .map(|&w| {
            if scale < 1e-15 {
                (max_level / 2) as u8 // constant weights -> mid-grid
            } else {
                let q = crate::math::round((w - min_val) / scale);
                (q as u8).min(max_level as u8)
            }
        })
        .collect();

    // Pack into u32 words (all padded_len values)
    let vpw = mode.values_per_word();
    let n_words = padded_len.div_ceil(vpw);
    let mut packed = Vec::with_capacity(n_words);
    for chunk in quantized.chunks(vpw) {
        packed.push(pack_word(chunk, mode));
    }

    TurboQuantized {
        packed,
        n_weights: weights.len(),
        scale,
        offset: min_val,
        seed,
        padded_len,
        mode,
    }
}

/// Quantize f32 weights with explicit mode. Uses the default seed.
pub fn quantize_f32(weights: &[f32], mode: QuantMode) -> TurboQuantized {
    let f64_weights: Vec<f64> = weights.iter().map(|&w| w as f64).collect();
    quantize(&f64_weights, mode, DEFAULT_SEED)
}

/// Quantize i16 weights with a dequantization scale and explicit mode.
///
/// Each i16 value is converted to `f64` via `value as f64 * scale` before
/// quantization. Uses the default seed.
pub fn quantize_i16(weights: &[i16], scale: f64, mode: QuantMode) -> TurboQuantized {
    let f64_weights: Vec<f64> = weights.iter().map(|&w| w as f64 * scale).collect();
    quantize(&f64_weights, mode, DEFAULT_SEED)
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // ---- Original tests (updated for new internals) ----

    #[test]
    fn pack_unpack_roundtrip() {
        let values = [0u8, 5, 10, 3, 7, 1, 9];
        let packed = pack7(&values);
        let unpacked = unpack7(packed, 7);
        assert_eq!(&unpacked, &values, "pack/unpack roundtrip failed");
    }

    #[test]
    fn pack_unpack_partial() {
        let values = [2u8, 8, 4];
        let packed = pack7(&values);
        let unpacked = unpack7(packed, 3);
        assert_eq!(&unpacked[..3], &values, "partial pack/unpack failed");
    }

    #[test]
    fn quantize_empty() {
        let q = quantize_weights(&[]);
        assert_eq!(q.n_weights(), 0);
        assert_eq!(q.predict(&[]), 0.0);
    }

    #[test]
    fn quantize_single_weight() {
        let q = quantize_weights(&[3.125]);
        assert_eq!(q.n_weights(), 1);
        let pred = q.predict(&[1.0]);
        assert!(
            (pred - 3.125).abs() < 0.5,
            "single weight should roundtrip reasonably, got {pred}"
        );
    }

    #[test]
    fn quantize_constant_weights() {
        let q = quantize_weights(&[2.5, 2.5, 2.5, 2.5]);
        let dq = q.dequantize();
        for (i, &w) in dq.iter().enumerate() {
            assert!(
                (w - 2.5).abs() < 0.05,
                "constant weights should dequantize closely, got {w} at [{i}]"
            );
        }
    }

    #[test]
    fn quantize_predict_accuracy() {
        let weights = vec![0.1, -0.5, 0.3, 0.0, -0.2, 0.4, 0.1, -0.1, 0.2];
        let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let exact: f64 = weights
            .iter()
            .zip(features.iter())
            .map(|(w, f)| w * f)
            .sum();
        let q = quantize_weights(&weights);
        let pred = q.predict(&features);
        let rel_err = if exact.abs() > 1e-10 {
            (pred - exact).abs() / exact.abs()
        } else {
            (pred - exact).abs()
        };
        assert!(
            rel_err < 0.25,
            "relative error should be < 25%, got {rel_err:.4} (exact={exact:.4}, pred={pred:.4})"
        );
    }

    #[test]
    fn quantize_dequantize_bounded_error() {
        let weights: Vec<f64> = (0..100).map(|i| (i as f64 - 50.0) / 50.0).collect();
        let q = quantize_weights(&weights);
        let dq = q.dequantize();
        let max_err = weights
            .iter()
            .zip(dq.iter())
            .map(|(w, d)| (w - d).abs())
            .fold(0.0f64, f64::max);
        assert!(
            max_err < 0.25,
            "max dequantize error should be < 0.25, got {max_err}"
        );
    }

    #[test]
    fn to_bytes_from_bytes_roundtrip() {
        let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.2, 0.7, -0.9, 0.4, 0.6];
        let q = quantize_weights(&weights);
        let bytes = q.to_bytes();
        let view = TurboQuantizedView::from_bytes(&bytes).expect("valid bytes");
        assert_eq!(view.n_weights(), q.n_weights());
        let features = vec![1.0; 10];
        let pred_owned = q.predict(&features);
        let pred_view = view.predict(&features);
        assert!(
            (pred_owned - pred_view).abs() < 1e-15,
            "owned vs view predict mismatch: {pred_owned} vs {pred_view}"
        );
    }

    #[test]
    fn from_bytes_rejects_short() {
        assert!(TurboQuantizedView::from_bytes(&[0u8; 10]).is_err());
        assert!(TurboQuantizedView::from_bytes(&[0u8; 35]).is_err());
    }

    #[test]
    fn compression_ratio_reasonable() {
        let weights: Vec<f64> = (0..100).map(|i| i as f64 * 0.01).collect();
        let q = quantize_weights(&weights);
        let ratio = q.compression_ratio();
        assert!(
            ratio > 3.0,
            "compression ratio should be > 3x for 100 weights, got {ratio:.2}"
        );
    }

    #[test]
    fn predict_large_vector() {
        let n = 1000;
        let weights: Vec<f64> = (0..n).map(|i| ((i as f64) * 0.1).sin()).collect();
        let features: Vec<f64> = (0..n).map(|i| ((i as f64) * 0.05).cos()).collect();
        let exact: f64 = weights
            .iter()
            .zip(features.iter())
            .map(|(w, f)| w * f)
            .sum();
        let q = quantize_weights(&weights);
        let pred = q.predict(&features);
        assert!(pred.is_finite(), "prediction should be finite");
        let abs_err = (pred - exact).abs();
        assert!(
            abs_err < exact.abs() * 0.5 + 5.0,
            "absolute error too large: {abs_err} for exact {exact}"
        );
    }

    #[test]
    fn next_power_of_two_correctness() {
        assert_eq!(next_power_of_two(0), 1);
        assert_eq!(next_power_of_two(1), 1);
        assert_eq!(next_power_of_two(2), 2);
        assert_eq!(next_power_of_two(3), 4);
        assert_eq!(next_power_of_two(4), 4);
        assert_eq!(next_power_of_two(5), 8);
        assert_eq!(next_power_of_two(7), 8);
        assert_eq!(next_power_of_two(8), 8);
        assert_eq!(next_power_of_two(9), 16);
        assert_eq!(next_power_of_two(100), 128);
        assert_eq!(next_power_of_two(1024), 1024);
        assert_eq!(next_power_of_two(1025), 2048);
    }

    #[test]
    fn fwht_roundtrip() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let original = data.clone();
        fwht_inplace(&mut data);
        fwht_inplace(&mut data);
        for (i, (&a, &b)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-10,
                "FWHT roundtrip failed at [{i}]: {a} vs {b}"
            );
        }
    }

    #[test]
    fn fwht_roundtrip_large() {
        let n = 64;
        let mut data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1 - 3.0).collect();
        let original = data.clone();
        fwht_inplace(&mut data);
        fwht_inplace(&mut data);
        for (i, (&a, &b)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-10,
                "FWHT large roundtrip failed at [{i}]: {a} vs {b}"
            );
        }
    }

    #[test]
    fn sign_flip_is_self_inverse() {
        let seed = 42u64;
        let mut data = vec![1.0, -2.5, 3.7, 0.0, -1.1, 5.5, 2.2, -0.8];
        let original = data.clone();
        apply_sign_flip(&mut data, seed);
        apply_sign_flip(&mut data, seed);
        for (i, (&a, &b)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-15,
                "sign flip self-inverse failed at [{i}]: {a} vs {b}"
            );
        }
    }

    #[test]
    fn full_rotation_roundtrip() {
        let seed = 0xCAFE_u64;
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut buf = original.clone();
        apply_rotation(&mut buf, seed);
        apply_inverse_rotation(&mut buf, seed);
        for (i, (&a, &b)) in buf.iter().zip(original.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-10,
                "rotation roundtrip failed at [{i}]: {a} vs {b}"
            );
        }
    }

    #[test]
    fn rotation_preserves_norm() {
        let seed = 0xBEEF_u64;
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let norm_before: f64 = data.iter().map(|x| x * x).sum();
        let mut rotated = data;
        apply_rotation(&mut rotated, seed);
        let norm_after: f64 = rotated.iter().map(|x| x * x).sum();
        assert!(
            (norm_before - norm_after).abs() < 1e-10,
            "rotation should preserve norm: {norm_before} vs {norm_after}"
        );
    }

    #[test]
    fn rotation_improves_correlated_weights() {
        let weights = vec![1.0, 1.01, 0.99, 1.02, 0.98, 1.01, 0.99, 1.0];
        let q = quantize_weights(&weights);
        let dq = q.dequantize();
        let max_err: f64 = weights
            .iter()
            .zip(dq.iter())
            .map(|(w, d)| (w - d).abs())
            .fold(0.0f64, f64::max);
        assert!(
            max_err < 0.05,
            "rotation should improve correlated weight quantization, max_err={max_err}"
        );
    }

    #[test]
    fn quantize_with_seed_deterministic() {
        let weights = vec![0.1, -0.5, 0.3, 0.0, -0.2, 0.4, 0.1, -0.1];
        let features = vec![1.0; 8];
        let q1 = quantize_weights_with_seed(&weights, 123);
        let q2 = quantize_weights_with_seed(&weights, 123);
        let p1 = q1.predict(&features);
        let p2 = q2.predict(&features);
        assert!(
            (p1 - p2).abs() < 1e-15,
            "same seed should give identical results: {p1} vs {p2}"
        );
    }

    #[test]
    fn different_seeds_produce_different_quantizations() {
        let weights = vec![0.1, -0.5, 0.3, 0.0, -0.2, 0.4, 0.1, -0.1];
        let q1 = quantize_weights_with_seed(&weights, 111);
        let q2 = quantize_weights_with_seed(&weights, 222);
        assert_ne!(
            q1.packed, q2.packed,
            "different seeds should produce different packed data"
        );
    }

    #[test]
    fn to_bytes_from_bytes_preserves_seed_and_padded_len() {
        let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0];
        let q = quantize_weights_with_seed(&weights, 0xABCD);
        let bytes = q.to_bytes();
        let view = TurboQuantizedView::from_bytes(&bytes).expect("valid bytes");
        assert_eq!(view.seed, 0xABCD);
        assert_eq!(view.padded_len, q.padded_len);
        assert_eq!(view.n_weights(), q.n_weights());
    }

    // ---- New tests: 8-bit mode ----

    #[test]
    fn bits8_pack_unpack_roundtrip() {
        let values = [0u8, 127, 255, 42];
        let packed = pack4_bytes(&values);
        let unpacked = unpack4_bytes(packed, 4);
        assert_eq!(&unpacked, &values, "8-bit pack/unpack roundtrip failed");
    }

    #[test]
    fn bits8_near_lossless() {
        let weights: Vec<f64> = (0..64).map(|i| (i as f64 - 32.0) / 32.0).collect();
        let q = quantize(&weights, QuantMode::Bits8, DEFAULT_SEED);
        let dq = q.dequantize();
        let max_err = weights
            .iter()
            .zip(dq.iter())
            .map(|(w, d)| (w - d).abs())
            .fold(0.0f64, f64::max);
        assert!(
            max_err < 0.02,
            "8-bit should be near-lossless, max_err={max_err}"
        );
    }

    #[test]
    fn bits8_predict_accuracy() {
        let weights: Vec<f64> = (0..32).map(|i| (i as f64).sin() * 0.5).collect();
        let features: Vec<f64> = (0..32).map(|i| (i as f64).cos() * 0.3).collect();
        let exact: f64 = weights
            .iter()
            .zip(features.iter())
            .map(|(w, f)| w * f)
            .sum();
        let q = quantize(&weights, QuantMode::Bits8, DEFAULT_SEED);
        let pred = q.predict(&features);
        let rel_err = (pred - exact).abs() / exact.abs().max(1e-10);
        assert!(
            rel_err < 0.10,
            "8-bit predict should have <10% relative error, got {rel_err:.4}"
        );
    }

    // ---- New tests: 2.5-bit mode ----

    #[test]
    fn bits2_5_packing_roundtrip() {
        let values = [0u8, 4, 2, 1, 3, 0, 4, 2, 1, 3, 0, 4, 2];
        let packed = pack13(&values);
        let unpacked = unpack13(packed, 13);
        assert_eq!(&unpacked, &values, "2.5-bit pack/unpack roundtrip failed");
    }

    #[test]
    fn bits2_5_quantize_and_predict() {
        let weights: Vec<f64> = (0..16).map(|i| (i as f64 - 8.0) / 8.0).collect();
        let features = vec![1.0; 16];
        let q = quantize(&weights, QuantMode::Bits2_5, DEFAULT_SEED);
        let pred = q.predict(&features);
        assert!(pred.is_finite(), "2.5-bit predict should be finite");
    }

    // ---- New tests: cross-mode serialization ----

    #[test]
    fn all_modes_serialize_roundtrip() {
        let weights = vec![0.1, -0.3, 0.5, 0.0, -0.2, 0.4, 0.3, -0.1];
        for mode in [QuantMode::Bits8, QuantMode::Bits3_5, QuantMode::Bits2_5] {
            let q = quantize(&weights, mode, 42);
            let bytes = q.to_bytes();
            let view = TurboQuantizedView::from_bytes(&bytes).expect("valid bytes");
            assert_eq!(view.n_weights(), q.n_weights());
            assert_eq!(view.mode(), mode);
            let features = vec![1.0; 8];
            let p1 = q.predict(&features);
            let p2 = view.predict(&features);
            assert!(
                (p1 - p2).abs() < 1e-15,
                "mode {mode:?}: owned={p1} vs view={p2}"
            );
        }
    }

    // ---- New tests: zero-alloc predict ----

    #[test]
    fn predict_with_scratch_matches_predict() {
        let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.2];
        let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let q = quantize(&weights, QuantMode::Bits3_5, DEFAULT_SEED);
        let pred = q.predict(&features);
        let mut scratch = vec![0.0; q.padded_len()];
        let pred_scratch = q.predict_with_scratch(&features, &mut scratch);
        assert!(
            (pred - pred_scratch).abs() < 1e-15,
            "scratch predict should match: {pred} vs {pred_scratch}"
        );
    }

    #[test]
    fn predict_with_scratch_view_matches_predict() {
        let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.2];
        let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let q = quantize(&weights, QuantMode::Bits8, DEFAULT_SEED);
        let bytes = q.to_bytes();
        let view = TurboQuantizedView::from_bytes(&bytes).expect("valid bytes");
        let pred = view.predict(&features);
        let mut scratch = vec![0.0; view.padded_len()];
        let pred_scratch = view.predict_with_scratch(&features, &mut scratch);
        assert!(
            (pred - pred_scratch).abs() < 1e-15,
            "view scratch predict should match: {pred} vs {pred_scratch}"
        );
    }

    // ---- New tests: multi-type input ----

    #[test]
    fn quantize_f32_works() {
        let weights = vec![0.5f32, -0.3, 0.8, -0.1];
        let q = quantize_f32(&weights, QuantMode::Bits8);
        assert_eq!(q.n_weights(), 4);
        let pred = q.predict(&[1.0, 1.0, 1.0, 1.0]);
        assert!(pred.is_finite());
    }

    #[test]
    fn quantize_i16_works() {
        let weights = vec![1000i16, -500, 2000, -1000];
        let scale = 1.0 / 32767.0;
        let q = quantize_i16(&weights, scale, QuantMode::Bits3_5);
        assert_eq!(q.n_weights(), 4);
    }

    // ---- New tests: mode-specific compression ratios ----

    #[test]
    fn bits8_compression_ratio() {
        let weights: Vec<f64> = (0..256).map(|i| i as f64 * 0.01).collect();
        let q = quantize(&weights, QuantMode::Bits8, DEFAULT_SEED);
        let ratio = q.compression_ratio();
        // 256 * 8 bytes = 2048 bytes original. 8-bit: 256/4 = 64 words + 36 header = 292 bytes. ~7x
        assert!(
            ratio > 5.0,
            "8-bit compression ratio should be > 5x, got {ratio:.2}"
        );
    }

    #[test]
    fn bits2_5_compression_ratio() {
        let weights: Vec<f64> = (0..256).map(|i| i as f64 * 0.01).collect();
        let q = quantize(&weights, QuantMode::Bits2_5, DEFAULT_SEED);
        let ratio = q.compression_ratio();
        // 2.5-bit: 256/13 = ~20 words + 36 header = ~116 bytes. ~17x
        assert!(
            ratio > 10.0,
            "2.5-bit compression ratio should be > 10x, got {ratio:.2}"
        );
    }

    // ---- New tests: edge cases ----

    #[test]
    fn quantize_empty_all_modes() {
        for mode in [QuantMode::Bits8, QuantMode::Bits3_5, QuantMode::Bits2_5] {
            let q = quantize(&[], mode, DEFAULT_SEED);
            assert_eq!(q.n_weights(), 0);
            assert_eq!(q.predict(&[]), 0.0);
        }
    }

    #[test]
    fn predict_with_scratch_empty() {
        let q = quantize(&[], QuantMode::Bits3_5, DEFAULT_SEED);
        let mut scratch = vec![0.0; 1];
        assert_eq!(q.predict_with_scratch(&[], &mut scratch), 0.0);
    }
}