// turboquant/packed/mod.rs

//! Packed data structures for quantized blocks.
//!
//! After quantization, indices are bit-packed into compact representations
//! to minimise memory usage. Every block also stores a 16-bit `f16` scale, so
//! for block_size=32 the effective cost is: TQ2 uses 2 bits per value
//! (2.5 bpw), TQ3 uses 3 bits per value (3.5 bpw), and TQ4 uses 4 bits per
//! value (4.5 bpw).
7
8use half::f16;
9
10use crate::error::{require, Result, TurboQuantError};
11
12pub mod indices;
13pub use indices::{
14    pack_indices_2bit, pack_indices_3bit, pack_indices_4bit, unpack_indices_2bit,
15    unpack_indices_3bit, unpack_indices_4bit,
16};
17// Re-export internal helpers for use by the tests submodule.
18#[allow(unused_imports)]
19use indices::{
20    chunk_to_2bit_array, chunk_to_3bit_array, chunk_to_4bit_array, chunk_to_packed_3bit_array,
21    has_2bit_remainder, has_3bit_remainder, has_4bit_remainder, num_2bit_groups, num_3bit_groups,
22    num_4bit_pairs, packed_2bit_capacity, packed_3bit_capacity, packed_4bit_capacity,
23    pad_remainder_2bit, pad_remainder_3bit, trailing_4bit_pair,
24};
25
26// ---------------------------------------------------------------------------
27// Named constants (eliminates magic numbers)
28// ---------------------------------------------------------------------------
29
// -- Supported bit widths ----------------------------------------------------

/// TQ2 stores each quantized value in 2 bits.
pub(crate) const BITS_TQ2: u8 = 2;
/// TQ3 stores each quantized value in 3 bits.
pub(crate) const BITS_TQ3: u8 = 3;
/// TQ4 stores each quantized value in 4 bits.
pub(crate) const BITS_TQ4: u8 = 4;

// -- Packing group geometry --------------------------------------------------

/// How many 2-bit indices form one packing group.
const PACK_2BIT_GROUP_SIZE: usize = 4;
/// How many 3-bit indices form one packing group.
const PACK_3BIT_GROUP_SIZE: usize = 8;
/// How many bytes one packed 3-bit group occupies.
const PACK_3BIT_BYTES: usize = 3;
/// How many 4-bit indices form one packing group.
const PACK_4BIT_GROUP_SIZE: usize = 2;

// -- Bit masks ---------------------------------------------------------------

/// Keeps only the lowest bit of a byte.
const MASK_1BIT: u8 = 0b1;
/// Keeps only the lowest 2 bits of a byte.
const MASK_2BIT: u8 = 0b11;
/// Keeps only the lowest 3 bits of a byte.
const MASK_3BIT: u8 = 0b111;
/// Keeps only the lowest 4 bits of a byte.
const MASK_4BIT: u8 = 0b1111;

// -- Shift amounts -----------------------------------------------------------

/// Shift by 1 bit position.
const SHIFT_1: u32 = 1;
/// Shift by 2 bit positions.
const SHIFT_2: u32 = 2;
/// Shift by 3 bit positions.
const SHIFT_3: u32 = 3;
/// Shift by 4 bit positions.
const SHIFT_4: u32 = 4;
/// Shift by 5 bit positions.
const SHIFT_5: u32 = 5;
/// Shift by 6 bit positions.
const SHIFT_6: u32 = 6;
/// Shift by 7 bit positions.
const SHIFT_7: u32 = 7;

// -- Block layout ------------------------------------------------------------

/// Byte size of the `f16` scale stored alongside the packed indices.
const SCALE_SIZE_BYTES: usize = 2;
86
87// ---------------------------------------------------------------------------
88// Configuration
89// ---------------------------------------------------------------------------
90
/// Configuration for TurboQuant quantization.
///
/// Holds the bit width, the vector dimension, and the seed for the rotation
/// matrix. The type is `Copy`, and derives `Debug`/`PartialEq`/`Eq` so that
/// configurations can be logged and compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TurboQuantConfig {
    /// Bits per value (2, 3, or 4).
    pub(crate) bits: u8,
    /// Vector dimension (must be a power of two for WHT).
    pub(crate) dim: usize,
    /// Seed for the rotation matrix.
    pub(crate) rotation_seed: u64,
}
101
102/// Check whether `bits` is a supported value (2, 3, or 4).
103///
104/// Pure Operation: contains only logic, no calls to other project functions.
105pub(crate) fn is_valid_bits(bits: u8) -> bool {
106    bits == BITS_TQ2 || bits == BITS_TQ3 || bits == BITS_TQ4
107}
108
/// Report whether `dim` is a non-zero power of two.
///
/// A power of two has exactly one bit set, and zero has none, so a single
/// population-count comparison covers both requirements.
///
/// Pure Operation: contains only logic, no calls to other project functions.
pub(crate) fn is_valid_dim(dim: usize) -> bool {
    dim.count_ones() == 1
}
115
116impl TurboQuantConfig {
117    /// Create a new configuration after validating inputs.
118    ///
119    /// Returns an error when `bits` is not 2, 3, or 4, or `dim` is not a power
120    /// of two.
121    ///
122    /// Pure Integration: only calls `require`, `is_valid_bits`, `is_valid_dim`.
123    pub fn new(bits: u8, dim: usize) -> Result<Self> {
124        require(is_valid_bits(bits), TurboQuantError::UnsupportedBits(bits))?;
125        require(is_valid_dim(dim), TurboQuantError::InvalidDimension(dim))?;
126        Ok(Self {
127            bits,
128            dim,
129            rotation_seed: 0,
130        })
131    }
132
133    /// Builder-style setter for the rotation seed.
134    // qual:api — public builder API for downstream consumers
135    pub fn with_seed(mut self, seed: u64) -> Self {
136        self.rotation_seed = seed;
137        self
138    }
139}
140
141// ---------------------------------------------------------------------------
142// Unified PackedBlock
143// ---------------------------------------------------------------------------
144
145/// A packed quantized block that stores a scale factor and bit-packed indices.
146///
147/// Replaces the former `BlockTQ2`, `BlockTQ3`, and `BlockTQ4` structs with a
148/// single type that tracks its own bit width.
149pub struct PackedBlock {
150    /// Bit width used for packing (2, 3, or 4).
151    pub bits: u8,
152    /// Scaling factor (L2-norm of original vector).
153    pub scale: f16,
154    /// Packed indices (layout depends on `bits`).
155    pub packed_indices: Vec<u8>,
156}
157
158impl PackedBlock {
159    /// Create a new packed block from a scale and a slice of unpacked index values.
160    ///
161    /// The indices are bit-packed internally based on the specified `bits` width.
162    ///
163    /// Pure Integration: delegates packing to the bit-width-specific helper
164    /// selected by the `pack` closure (IOSP lenient-mode closure pattern).
165    pub fn new(bits: u8, scale: f16, indices: &[u8]) -> Self {
166        let pack = |indices: &[u8]| -> Vec<u8> {
167            match bits {
168                BITS_TQ2 => pack_indices_2bit(indices),
169                BITS_TQ3 => pack_indices_3bit(indices),
170                BITS_TQ4 => pack_indices_4bit(indices),
171                _ => unreachable!("bits validated to be 2, 3, or 4"),
172            }
173        };
174        Self {
175            bits,
176            scale,
177            packed_indices: pack(indices),
178        }
179    }
180
181    /// Total size of the block in bytes (2 bytes for f16 scale + packed data).
182    pub fn size_bytes(&self) -> usize {
183        SCALE_SIZE_BYTES + self.packed_indices.len()
184    }
185
186    /// Creates a `PackedBlock` from pre-packed data without re-packing.
187    ///
188    /// Use this to reconstruct blocks from GPU-quantized data that is already
189    /// in the correct packed layout.
190    ///
191    /// Pure Operation: field assignment only.
192    // qual:api — used by GPU kernel integration for importing quantized data
193    pub fn from_raw(bits: u8, scale: f16, packed_indices: Vec<u8>) -> Self {
194        Self {
195            bits,
196            scale,
197            packed_indices,
198        }
199    }
200
201    /// Unpacks stored indices into a caller-provided buffer, avoiding allocation.
202    ///
203    /// This is the hot-path variant: reuses the buffer across repeated calls
204    /// (e.g. inside attention score loops) to eliminate per-key allocations.
205    ///
206    /// Pure Integration: delegates unpacking to the bit-width-specific helper
207    /// selected by the `do_unpack` closure (IOSP lenient-mode closure pattern).
208    pub fn unpack_into(&self, count: usize, buf: &mut Vec<u8>) {
209        buf.clear();
210        let do_unpack = |packed: &[u8], out: &mut Vec<u8>| match self.bits {
211            BITS_TQ2 => out.extend_from_slice(&unpack_indices_2bit(packed, count)),
212            BITS_TQ3 => out.extend_from_slice(&unpack_indices_3bit(packed, count)),
213            BITS_TQ4 => out.extend_from_slice(&unpack_indices_4bit(packed, count)),
214            _ => unreachable!("bits validated"),
215        };
216        do_unpack(&self.packed_indices, buf);
217        buf.truncate(count);
218    }
219
220    /// Recover the unpacked index values.
221    ///
222    /// Allocates a fresh buffer. For hot paths, prefer
223    /// [`unpack_into`](Self::unpack_into) with a reusable buffer.
224    pub fn unpack(&self, count: usize) -> Vec<u8> {
225        let do_unpack = |packed: &[u8]| match self.bits {
226            BITS_TQ2 => unpack_indices_2bit(packed, count),
227            BITS_TQ3 => unpack_indices_3bit(packed, count),
228            BITS_TQ4 => unpack_indices_4bit(packed, count),
229            _ => unreachable!("bits validated"),
230        };
231        do_unpack(&self.packed_indices)
232    }
233}
234
235// ---------------------------------------------------------------------------
236// 2-bit packing / unpacking  (pure Operation functions)
237// ---------------------------------------------------------------------------
238
239/// Pack 4 two-bit values into 1 byte.
240///
241/// Only the lowest 2 bits of each input byte are used.
242pub fn pack_2bit(values: &[u8; PACK_2BIT_GROUP_SIZE]) -> u8 {
243    (values[0] & MASK_2BIT)
244        | ((values[1] & MASK_2BIT) << SHIFT_2)
245        | ((values[2] & MASK_2BIT) << SHIFT_4)
246        | ((values[3] & MASK_2BIT) << SHIFT_6)
247}
248
249/// Unpack 1 byte into 4 two-bit values.
250pub fn unpack_2bit(packed: u8) -> [u8; PACK_2BIT_GROUP_SIZE] {
251    [
252        packed & MASK_2BIT,
253        (packed >> SHIFT_2) & MASK_2BIT,
254        (packed >> SHIFT_4) & MASK_2BIT,
255        (packed >> SHIFT_6) & MASK_2BIT,
256    ]
257}
258
259// ---------------------------------------------------------------------------
260// 3-bit packing / unpacking  (pure Operation functions)
261// ---------------------------------------------------------------------------
262
263/// Pack 8 three-bit values into 3 bytes.
264///
265/// Only the lowest 3 bits of each input byte are used.
266pub fn pack_3bit(values: &[u8; PACK_3BIT_GROUP_SIZE]) -> [u8; PACK_3BIT_BYTES] {
267    let mut packed = [0u8; PACK_3BIT_BYTES];
268    packed[0] = (values[0] & MASK_3BIT)
269        | ((values[1] & MASK_3BIT) << SHIFT_3)
270        | ((values[2] & MASK_2BIT) << SHIFT_6);
271    packed[1] = ((values[2] >> SHIFT_2) & MASK_1BIT)
272        | ((values[3] & MASK_3BIT) << SHIFT_1)
273        | ((values[4] & MASK_3BIT) << SHIFT_4)
274        | ((values[5] & MASK_1BIT) << SHIFT_7);
275    packed[2] = ((values[5] >> SHIFT_1) & MASK_2BIT)
276        | ((values[6] & MASK_3BIT) << SHIFT_2)
277        | ((values[7] & MASK_3BIT) << SHIFT_5);
278    packed
279}
280
281/// Unpack 3 bytes into 8 three-bit values.
282pub fn unpack_3bit(packed: &[u8; PACK_3BIT_BYTES]) -> [u8; PACK_3BIT_GROUP_SIZE] {
283    let mut values = [0u8; PACK_3BIT_GROUP_SIZE];
284    values[0] = packed[0] & MASK_3BIT;
285    values[1] = (packed[0] >> SHIFT_3) & MASK_3BIT;
286    values[2] = ((packed[0] >> SHIFT_6) & MASK_2BIT) | ((packed[1] & MASK_1BIT) << SHIFT_2);
287    values[3] = (packed[1] >> SHIFT_1) & MASK_3BIT;
288    values[4] = (packed[1] >> SHIFT_4) & MASK_3BIT;
289    values[5] = ((packed[1] >> SHIFT_7) & MASK_1BIT) | ((packed[2] & MASK_2BIT) << SHIFT_1);
290    values[6] = (packed[2] >> SHIFT_2) & MASK_3BIT;
291    values[7] = (packed[2] >> SHIFT_5) & MASK_3BIT;
292    values
293}
294
295// ---------------------------------------------------------------------------
296// 4-bit packing / unpacking  (pure Operation functions)
297// ---------------------------------------------------------------------------
298
299/// Pack 2 four-bit values into 1 byte.
300///
301/// Only the lowest 4 bits of each input byte are used.
302pub fn pack_4bit(values: &[u8; 2]) -> u8 {
303    (values[0] & MASK_4BIT) | ((values[1] & MASK_4BIT) << SHIFT_4)
304}
305
306/// Unpack 1 byte into 2 four-bit values.
307pub fn unpack_4bit(packed: u8) -> [u8; 2] {
308    [packed & MASK_4BIT, (packed >> SHIFT_4) & MASK_4BIT]
309}
310
311// ---------------------------------------------------------------------------
312// Unit tests
313// ---------------------------------------------------------------------------
314
#[cfg(test)]
mod tests {
    use super::*;

    // -- shared fixtures -----------------------------------------------------

    /// Standard block size (power of two) used in config validation tests.
    const TEST_BLOCK_SIZE: usize = 32;
    /// Standard large dimension (power of two) used in config validation tests.
    const TEST_DIM_128: usize = 128;
    /// Number of 3-bit groups in capacity tests.
    const TEST_3BIT_GROUPS: usize = 4;
    /// Number of 4-bit pairs in capacity tests.
    const TEST_4BIT_PAIRS: usize = 5;
    /// Maximum valid 2-bit value (2^2 - 1).
    const MAX_2BIT_VALUE: u8 = 3;
    /// Maximum valid 3-bit value (2^3 - 1).
    const MAX_3BIT_VALUE: u8 = 7;
    /// Maximum valid 4-bit value (2^4 - 1).
    const MAX_4BIT_VALUE: u8 = 15;
    /// Test trailing-pair input value.
    const TEST_TRAILING_VALUE: u8 = 9;
    /// Number of 2-bit indices in exact-multiple roundtrip test (3 groups of 4).
    const TEST_2BIT_EXACT_COUNT: usize = 12;
    /// Number of 2-bit indices in remainder roundtrip test.
    const TEST_2BIT_REMAINDER_COUNT: usize = 7;
    /// Number of 3-bit indices in exact-multiple roundtrip test (2 groups of 8).
    const TEST_3BIT_EXACT_COUNT: usize = 16;
    /// Number of 3-bit indices in remainder roundtrip test.
    const TEST_3BIT_REMAINDER_COUNT: usize = 11;
    /// Number of 4-bit indices in even-count roundtrip test.
    const TEST_4BIT_EVEN_COUNT: usize = 10;
    /// Number of 4-bit indices in odd-count roundtrip test.
    const TEST_4BIT_ODD_COUNT: usize = 7;
    /// Number of 3-bit levels (2^3).
    const TEST_3BIT_LEVELS: usize = 8;
    /// Number of 4-bit levels (2^4).
    const TEST_4BIT_LEVELS: u8 = 16;
    /// Test scale value.
    const TEST_SCALE: f32 = 1.5;
    /// Test scale value (half).
    const TEST_SCALE_HALF: f32 = 0.5;
    /// Non-default rotation seed used by the builder test.
    const TEST_SEED: u64 = 42;

    /// Expected size for 2-bit packing of TEST_BLOCK_SIZE=32 indices:
    /// packed = 32 / 4 = 8 bytes, + 2 (scale) = 10 bytes.
    const TQ2_D32_EXPECTED_SIZE: usize = SCALE_SIZE_BYTES + 8;
    /// Expected size for 3-bit packing of TEST_BLOCK_SIZE=32 indices:
    /// packed = 32 * 3 / 8 = 12 bytes, + 2 (scale) = 14 bytes.
    const TQ3_D32_EXPECTED_SIZE: usize = SCALE_SIZE_BYTES + 12;
    /// Expected size for 4-bit packing of TEST_BLOCK_SIZE=32 indices:
    /// packed = 32 / 2 = 16 bytes, + 2 (scale) = 18 bytes.
    const TQ4_D32_EXPECTED_SIZE: usize = SCALE_SIZE_BYTES + 16;

    /// Build `count` indices that cycle through every level representable in
    /// `bits` bits (0, 1, ..., 2^bits - 1, 0, ...).
    fn cycling_indices(bits: u8, count: usize) -> Vec<u8> {
        let levels = 1usize << bits;
        // `levels` is at most 16 here, so the remainder always fits in a u8.
        (0..count).map(|i| (i % levels) as u8).collect()
    }

    // -- is_valid_bits -------------------------------------------------------

    #[test]
    fn is_valid_bits_accepts_2_3_and_4() {
        assert!(is_valid_bits(BITS_TQ2));
        assert!(is_valid_bits(BITS_TQ3));
        assert!(is_valid_bits(BITS_TQ4));
    }

    #[test]
    fn is_valid_bits_rejects_others() {
        assert!(!is_valid_bits(0));
        assert!(!is_valid_bits(1));
        assert!(!is_valid_bits(5));
    }

    // -- is_valid_dim --------------------------------------------------------

    #[test]
    fn is_valid_dim_accepts_powers_of_two() {
        assert!(is_valid_dim(TEST_DIM_128 / 2));
        assert!(is_valid_dim(TEST_DIM_128));
    }

    #[test]
    fn is_valid_dim_rejects_invalid() {
        assert!(!is_valid_dim(0));
        assert!(!is_valid_dim(3));
        assert!(!is_valid_dim(100));
    }

    // -- packed_3bit_capacity ------------------------------------------------

    #[test]
    fn packed_3bit_capacity_no_remainder() {
        // 4 groups of 8 -> 4 * 3 = 12 bytes
        assert_eq!(
            packed_3bit_capacity(TEST_3BIT_GROUPS, false),
            TEST_3BIT_GROUPS * PACK_3BIT_BYTES
        );
    }

    #[test]
    fn packed_3bit_capacity_with_remainder() {
        // 4 groups + remainder -> 4 * 3 + 3 = 15 bytes
        assert_eq!(
            packed_3bit_capacity(TEST_3BIT_GROUPS, true),
            TEST_3BIT_GROUPS * PACK_3BIT_BYTES + PACK_3BIT_BYTES
        );
    }

    #[test]
    fn packed_3bit_capacity_zero_groups() {
        assert_eq!(packed_3bit_capacity(0, false), 0);
        assert_eq!(packed_3bit_capacity(0, true), 3);
    }

    // -- packed_4bit_capacity ------------------------------------------------

    #[test]
    fn packed_4bit_capacity_no_remainder() {
        assert_eq!(
            packed_4bit_capacity(TEST_4BIT_PAIRS, false),
            TEST_4BIT_PAIRS
        );
    }

    #[test]
    fn packed_4bit_capacity_with_remainder() {
        assert_eq!(
            packed_4bit_capacity(TEST_4BIT_PAIRS, true),
            TEST_4BIT_PAIRS + 1
        );
    }

    // -- chunk_to_3bit_array / chunk_to_4bit_array ---------------------------

    #[test]
    fn chunk_to_3bit_array_preserves_values() {
        let input: Vec<u8> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let arr = chunk_to_3bit_array(&input);
        assert_eq!(arr, [0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn chunk_to_4bit_array_preserves_values() {
        let input: Vec<u8> = vec![10, 15];
        let arr = chunk_to_4bit_array(&input);
        assert_eq!(arr, [10, 15]);
    }

    // -- pad_remainder_3bit --------------------------------------------------

    #[test]
    fn pad_remainder_3bit_pads_correctly() {
        let tail: Vec<u8> = vec![1, 2, 3];
        let padded = pad_remainder_3bit(&tail);
        assert_eq!(padded, [1, 2, 3, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn pad_remainder_3bit_single_element() {
        let tail: Vec<u8> = vec![5];
        let padded = pad_remainder_3bit(&tail);
        assert_eq!(padded, [5, 0, 0, 0, 0, 0, 0, 0]);
    }

    // -- trailing_4bit_pair --------------------------------------------------

    #[test]
    fn trailing_4bit_pair_handles_single_element() {
        let pair = trailing_4bit_pair(TEST_TRAILING_VALUE);
        assert_eq!(pair, [TEST_TRAILING_VALUE, 0]);
    }

    // -- chunk_to_packed_3bit_array ------------------------------------------

    #[test]
    fn chunk_to_packed_3bit_array_preserves_values() {
        let input: Vec<u8> = vec![0xAB, 0xCD, 0xEF];
        let arr = chunk_to_packed_3bit_array(&input);
        assert_eq!(arr, [0xAB, 0xCD, 0xEF]);
    }

    // -- 2-bit pack/unpack ---------------------------------------------------

    #[test]
    fn pack_unpack_2bit_identity() {
        let values: [u8; PACK_2BIT_GROUP_SIZE] = [0, 1, 2, MAX_2BIT_VALUE];
        let packed = pack_2bit(&values);
        let unpacked = unpack_2bit(packed);
        assert_eq!(values, unpacked);
    }

    #[test]
    fn pack_unpack_2bit_zeros() {
        let values = [0u8; PACK_2BIT_GROUP_SIZE];
        assert_eq!(unpack_2bit(pack_2bit(&values)), values);
    }

    #[test]
    fn pack_unpack_2bit_max() {
        let values = [MAX_2BIT_VALUE; PACK_2BIT_GROUP_SIZE];
        assert_eq!(unpack_2bit(pack_2bit(&values)), values);
    }

    // -- 3-bit pack/unpack ---------------------------------------------------

    #[test]
    fn pack_unpack_3bit_identity() {
        let values: [u8; PACK_3BIT_GROUP_SIZE] = [0, 1, 2, 3, 4, 5, 6, MAX_3BIT_VALUE];
        let packed = pack_3bit(&values);
        let unpacked = unpack_3bit(&packed);
        assert_eq!(values, unpacked);
    }

    #[test]
    fn pack_unpack_3bit_zeros() {
        let values = [0u8; PACK_3BIT_GROUP_SIZE];
        assert_eq!(unpack_3bit(&pack_3bit(&values)), values);
    }

    #[test]
    fn pack_unpack_3bit_max() {
        let values = [MAX_3BIT_VALUE; PACK_3BIT_GROUP_SIZE];
        assert_eq!(unpack_3bit(&pack_3bit(&values)), values);
    }

    // -- 4-bit pack/unpack ---------------------------------------------------

    #[test]
    fn pack_unpack_4bit_identity() {
        let values: [u8; PACK_4BIT_GROUP_SIZE] = [0, MAX_4BIT_VALUE];
        let packed = pack_4bit(&values);
        let unpacked = unpack_4bit(packed);
        assert_eq!(values, unpacked);
    }

    #[test]
    fn pack_unpack_4bit_zeros() {
        let values = [0u8; PACK_4BIT_GROUP_SIZE];
        assert_eq!(unpack_4bit(pack_4bit(&values)), values);
    }

    #[test]
    fn pack_unpack_4bit_max() {
        let values = [MAX_4BIT_VALUE; PACK_4BIT_GROUP_SIZE];
        assert_eq!(unpack_4bit(pack_4bit(&values)), values);
    }

    // -- roundtrip: pack_indices_2bit / unpack_indices_2bit -------------------

    #[test]
    fn roundtrip_2bit_exact_multiple() {
        let indices: Vec<u8> = (0..TEST_2BIT_EXACT_COUNT as u8)
            .map(|i| i % (MAX_2BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_2bit(&indices);
        let unpacked = unpack_indices_2bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_2bit_with_remainder() {
        let indices: Vec<u8> = (0..TEST_2BIT_REMAINDER_COUNT as u8)
            .map(|i| i % (MAX_2BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_2bit(&indices);
        let unpacked = unpack_indices_2bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    // -- roundtrip: pack_indices_3bit / unpack_indices_3bit -------------------

    #[test]
    fn roundtrip_3bit_exact_multiple() {
        let indices: Vec<u8> = (0..TEST_3BIT_EXACT_COUNT as u8)
            .map(|i| i % (MAX_3BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_3bit(&indices);
        let unpacked = unpack_indices_3bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_3bit_with_remainder() {
        let indices: Vec<u8> = (0..TEST_3BIT_REMAINDER_COUNT as u8)
            .map(|i| i % (MAX_3BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_3bit(&indices);
        let unpacked = unpack_indices_3bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    // -- roundtrip: pack_indices_4bit / unpack_indices_4bit -------------------

    #[test]
    fn roundtrip_4bit_even_count() {
        let indices: Vec<u8> = (0..TEST_4BIT_EVEN_COUNT as u8)
            .map(|i| i % TEST_4BIT_LEVELS)
            .collect();
        let packed = pack_indices_4bit(&indices);
        let unpacked = unpack_indices_4bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_4bit_odd_count() {
        let indices: Vec<u8> = (0..TEST_4BIT_ODD_COUNT as u8)
            .map(|i| i % TEST_4BIT_LEVELS)
            .collect();
        let packed = pack_indices_4bit(&indices);
        let unpacked = unpack_indices_4bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    // -- config validation ---------------------------------------------------

    #[test]
    fn config_rejects_invalid_bits() {
        assert!(TurboQuantConfig::new(1, TEST_BLOCK_SIZE).is_err());
        assert!(TurboQuantConfig::new(5, TEST_BLOCK_SIZE).is_err());
    }

    #[test]
    fn config_rejects_non_power_of_two() {
        assert!(TurboQuantConfig::new(BITS_TQ3, 33).is_err());
        assert!(TurboQuantConfig::new(BITS_TQ4, 0).is_err());
    }

    #[test]
    fn config_accepts_valid() {
        assert!(TurboQuantConfig::new(BITS_TQ2, TEST_BLOCK_SIZE).is_ok());
        assert!(TurboQuantConfig::new(BITS_TQ3, TEST_BLOCK_SIZE).is_ok());
        assert!(TurboQuantConfig::new(BITS_TQ4, TEST_DIM_128).is_ok());
    }

    #[test]
    fn config_seed_defaults_to_zero_and_with_seed_overrides() {
        let default_seed = TurboQuantConfig::new(BITS_TQ3, TEST_BLOCK_SIZE)
            .ok()
            .map(|config| config.rotation_seed);
        assert_eq!(default_seed, Some(0));

        let custom_seed = TurboQuantConfig::new(BITS_TQ3, TEST_BLOCK_SIZE)
            .ok()
            .map(|config| config.with_seed(TEST_SEED).rotation_seed);
        assert_eq!(custom_seed, Some(TEST_SEED));
    }

    // -- size_bytes -----------------------------------------------------------

    #[test]
    fn packed_block_tq2_size_bytes() {
        let indices = vec![0u8; TEST_BLOCK_SIZE];
        let block = PackedBlock::new(BITS_TQ2, f16::from_f32(1.0), &indices);
        // 32 indices / 4 per byte = 8 bytes packed + 2 bytes scale = 10
        assert_eq!(block.size_bytes(), TQ2_D32_EXPECTED_SIZE);
    }

    #[test]
    fn packed_block_tq3_size_bytes() {
        let indices = vec![0u8; TEST_BLOCK_SIZE];
        let block = PackedBlock::new(BITS_TQ3, f16::from_f32(1.0), &indices);
        // 32 indices * 3 bits / 8 = 12 packed bytes + 2 scale bytes = 14
        assert_eq!(block.size_bytes(), TQ3_D32_EXPECTED_SIZE);
    }

    #[test]
    fn packed_block_tq4_size_bytes() {
        let indices = vec![0u8; TEST_BLOCK_SIZE];
        let block = PackedBlock::new(BITS_TQ4, f16::from_f32(1.0), &indices);
        // 32 indices / 2 = 16 packed bytes + 2 scale bytes = 18
        assert_eq!(block.size_bytes(), TQ4_D32_EXPECTED_SIZE);
    }

    // -- packed_indices field ------------------------------------------------

    #[test]
    fn packed_indices_returns_raw_bytes() {
        let indices = vec![1u8, 2, 3, 0, 1, 2, 3, 0];
        let block = PackedBlock::new(BITS_TQ2, f16::from_f32(TEST_SCALE), &indices);
        let raw = block.packed_indices;
        // 8 indices at 2-bit = 2 bytes
        assert_eq!(raw.len(), 2);
        // Packing is deterministic: a second block built from the same input
        // must hold identical raw bytes.
        let block2 = PackedBlock::new(BITS_TQ2, f16::from_f32(TEST_SCALE), &indices);
        assert_eq!(raw, block2.packed_indices);
    }

    #[test]
    fn packed_indices_3bit_length() {
        let indices = vec![0u8; TEST_DIM_128];
        let block = PackedBlock::new(BITS_TQ3, f16::from_f32(1.0), &indices);
        // 128 indices at 3-bit: 128/8 = 16 groups × 3 bytes = 48 bytes
        assert_eq!(block.packed_indices.len(), 48);
    }

    // -- from_raw constructor ------------------------------------------------

    #[test]
    fn from_raw_roundtrip() {
        let indices = vec![3u8, 1, 0, 2, 3, 1, 0, 2];
        let original = PackedBlock::new(BITS_TQ2, f16::from_f32(2.0), &indices);
        let reconstructed = PackedBlock::from_raw(
            original.bits,
            original.scale,
            original.packed_indices.clone(),
        );
        assert_eq!(reconstructed.bits, original.bits);
        assert_eq!(reconstructed.scale, original.scale);
        assert_eq!(reconstructed.packed_indices, original.packed_indices);
        // Unpack should recover original indices
        assert_eq!(reconstructed.unpack(indices.len()), indices);
    }

    #[test]
    fn from_raw_3bit_roundtrip() {
        let indices: Vec<u8> = (0..TEST_DIM_128)
            .map(|i| (i % TEST_3BIT_LEVELS) as u8)
            .collect();
        let original = PackedBlock::new(BITS_TQ3, f16::from_f32(TEST_SCALE_HALF), &indices);
        let reconstructed =
            PackedBlock::from_raw(BITS_TQ3, original.scale, original.packed_indices.clone());
        assert_eq!(reconstructed.unpack(TEST_DIM_128), indices);
    }

    // -- unpack_into ---------------------------------------------------------

    #[test]
    fn unpack_into_matches_unpack_for_all_widths() {
        for &bits in &[BITS_TQ2, BITS_TQ3, BITS_TQ4] {
            // One count that fills whole groups and one that leaves a remainder.
            for &count in &[TEST_BLOCK_SIZE, TEST_3BIT_REMAINDER_COUNT] {
                let indices = cycling_indices(bits, count);
                let block = PackedBlock::new(bits, f16::from_f32(TEST_SCALE), &indices);
                let mut buf = Vec::new();
                block.unpack_into(count, &mut buf);
                assert_eq!(buf, indices);
                assert_eq!(buf, block.unpack(count));
            }
        }
    }

    #[test]
    fn unpack_into_discards_previous_buffer_contents() {
        let indices = cycling_indices(BITS_TQ4, TEST_4BIT_EVEN_COUNT);
        let block = PackedBlock::new(BITS_TQ4, f16::from_f32(TEST_SCALE), &indices);
        // Start with a buffer that is longer than the result and full of junk.
        let mut buf = vec![u8::MAX; TEST_DIM_128];
        block.unpack_into(indices.len(), &mut buf);
        assert_eq!(buf, indices);
    }

    // -- unsupported bit width -----------------------------------------------

    #[test]
    #[should_panic]
    fn packed_block_new_panics_on_unsupported_bits() {
        let _ = PackedBlock::new(1, f16::from_f32(1.0), &[0u8; PACK_4BIT_GROUP_SIZE]);
    }
}