burn_tensor/tensor/quantization/
data.rs

1use crate::quantization::QuantValue;
2use alloc::vec::Vec;
3use num_traits::PrimInt;
4
/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
///
/// Values are consumed four at a time; within each group the first value ends up in the
/// least significant byte of the resulting word:
///
/// ```text
/// (d & 0xFF) << 24 | (c & 0xFF) << 16 | (b & 0xFF) << 8 | (a & 0xFF)
/// ```
///
/// When the number of values is not a multiple of four, the missing high-order bytes of
/// the last word are zero. An empty input yields an empty output. The packed words are
/// identical on little- and big-endian targets.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // The output is built as a fresh `Vec<u32>` instead of reinterpreting the `Vec<i8>`
    // allocation in place: `Vec::from_raw_parts` requires the pointer to have been
    // allocated with the alignment of the target element type and the capacity to
    // describe the exact allocation size, neither of which holds for a byte buffer.
    values
        .chunks(4)
        .map(|group| {
            // Bytes not covered by a short trailing group stay zero (padding).
            let mut bytes = [0u8; 4];
            for (dst, &src) in bytes.iter_mut().zip(group) {
                // Bit-preserving reinterpretation of the two's complement byte.
                *dst = src as u8;
            }
            // Little-endian assembly places `group[0]` in the lowest byte on every target.
            u32::from_le_bytes(bytes)
        })
        .collect()
}
43
44/// Unpack integer values into a sequence of signed 8-bit integers.
45pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
46    values: &[Q],
47    numel: usize,
48    value: &QuantValue,
49) -> Vec<i8> {
50    let size_store = size_of::<Q>() * 8;
51    let size_quant = value.size_bits();
52    let num_quants = size_store / size_quant;
53    let mask = Q::from((1 << size_quant) - 1).unwrap();
54    let sign_shift = 8 - size_quant; // sign extension for sub-byte values
55    values
56        .iter()
57        .enumerate()
58        .flat_map(|(i, &packed)| {
59            // A single u32 could contain less than four 8-bit values...
60            let n = core::cmp::min(num_quants, numel - i * num_quants);
61            // Extract each 8-bit segment from u32 and cast back to i8
62            // Same as doing this (when 4 values are fully packed):
63            //     let a = (packed & 0xFF) as i8;
64            //     let b = ((packed >> 8) & 0xFF) as i8;
65            //     let c = ((packed >> 16) & 0xFF) as i8;
66            //     let d = ((packed >> 24) & 0xFF) as i8;
67            (0..n).map(move |i| {
68                let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
69                ((raw << sign_shift) as i8) >> sign_shift
70            })
71        })
72        .collect()
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// The values `[-128, 2, -3, 127]` packed into one word, first value in the low byte.
    const PACKED_WORD: u32 = 2147287680;

    #[test]
    fn should_pack_i8s_to_u32() {
        let input = vec![-128, 2, -3, 127];

        assert_eq!(pack_i8s_to_u32s(input), vec![PACKED_WORD]);
    }

    #[test]
    fn should_pack_i8s_to_u32_padded() {
        // Implicit padding must match explicitly zero-padded input.
        let implicit = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let explicit = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(implicit, vec![PACKED_WORD, 55]);
        assert_eq!(implicit, explicit);
    }

    #[test]
    fn should_unpack_u32s_to_i8s() {
        let words = [PACKED_WORD];
        let result = unpack_q_to_i8s(&words, 4, &QuantValue::Q8S);

        assert_eq!(result, vec![-128, 2, -3, 127]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        // Only the first of the four slots in the word is requested.
        let words = [55u32];
        let result = unpack_q_to_i8s(&words, 1, &QuantValue::Q8S);

        assert_eq!(result, vec![55]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let words = [
            0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
            1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
            2004318071,
        ];
        let expected: Vec<i8> = vec![
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
            3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
            5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
            6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        ];

        let result = unpack_q_to_i8s(&words, 128, &QuantValue::Q4S);

        assert_eq!(result, expected);
    }
}