burn_tensor/tensor/quantization/data.rs

1use alloc::vec::Vec;
2
/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
///
/// Values are grouped four at a time. Within each `u32`, the value with the lowest
/// index occupies the least-significant byte:
///
/// ```text
/// packed = (d as u8) << 24 | (c as u8) << 16 | (b as u8) << 8 | (a as u8)
/// ```
///
/// When `values.len()` is not a multiple of 4, the unused high-order bytes of the last
/// word are zero. The output therefore always has `ceil(values.len() / 4)` elements.
///
/// The packed layout is identical on little- and big-endian targets.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // Build each word from its bytes instead of reinterpreting the `Vec<i8>` allocation
    // as a `Vec<u32>`. `Vec::from_raw_parts` requires the same alignment and byte size
    // the buffer was allocated with. An `i8` buffer is only 1-aligned and its capacity
    // need not be a multiple of 4, so that cast is unsound.
    values
        .chunks(4)
        .map(|chunk| {
            // A trailing chunk may hold fewer than four values; the remaining
            // high-order bytes stay zero (padding).
            let mut bytes = [0u8; 4];
            for (byte, &value) in bytes.iter_mut().zip(chunk) {
                // `as u8` keeps the two's-complement bit pattern (e.g. -1 -> 0xFF).
                *byte = value as u8;
            }
            u32::from_le_bytes(bytes)
        })
        .collect()
}
41
/// Unpack 32-bit unsigned integer values into a sequence of signed 8-bit integers.
///
/// This is the inverse of [`pack_i8s_to_u32s`]. Each `u32` yields four `i8`s, starting
/// with its least-significant byte.
///
/// Only the first `numel` values are kept, which removes the zero padding added when
/// the packed element count was not a multiple of 4. If `numel` is larger than the
/// `4 * values.len()` values available, every available value is returned.
///
/// The output is identical on little- and big-endian targets.
pub fn unpack_u32s_to_i8s(values: Vec<u32>, numel: usize) -> Vec<i8> {
    // Decode byte-by-byte instead of reinterpreting the `Vec<u32>` allocation as a
    // `Vec<i8>`. `Vec::from_raw_parts` requires the element alignment to match the one
    // the buffer was allocated with (4 for `u32`, 1 for `i8`), so that cast is unsound.
    // Clamping to the available values also keeps a `numel` that does not match the
    // packed length from underflowing.
    let count = numel.min(values.len().saturating_mul(4));

    let mut unpacked = Vec::with_capacity(count);
    unpacked.extend(
        values
            .into_iter()
            .flat_map(u32::to_le_bytes)
            .take(count)
            // `as i8` reinterprets each byte's two's-complement bit pattern.
            .map(|byte| byte as i8),
    );
    unpacked
}
84
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_i8s_to_u32() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);

        assert_eq!(packed, vec![2147287680]);
    }

    #[test]
    fn should_pack_i8s_to_u32_padded() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(packed, vec![2147287680, 55]);
        assert_eq!(packed, packed_padded);
    }

    #[test]
    fn should_pack_i8s_with_capacity_not_multiple_of_four() {
        // The input buffer's capacity (7) is not a multiple of 4; packing must not
        // depend on how the input vector was allocated.
        let mut values = Vec::with_capacity(7);
        values.extend_from_slice(&[1i8, 2, 3, 4, 5]);

        assert_eq!(pack_i8s_to_u32s(values), vec![0x0403_0201, 5]);
    }

    #[test]
    fn should_pack_empty() {
        assert!(pack_i8s_to_u32s(Vec::new()).is_empty());
    }

    #[test]
    fn should_unpack_u32s_to_i8s() {
        let unpacked = unpack_u32s_to_i8s(vec![2147287680u32], 4);

        assert_eq!(unpacked, vec![-128, 2, -3, 127]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        let unpacked = unpack_u32s_to_i8s(vec![55u32], 1);

        assert_eq!(unpacked, vec![55]);
    }

    #[test]
    fn should_unpack_across_multiple_words() {
        let words = vec![0x0403_0201u32, 0x0807_0605];

        assert_eq!(unpack_u32s_to_i8s(words.clone(), 5), vec![1, 2, 3, 4, 5]);
        // Fewer elements than the words can hold: everything past `numel` is dropped.
        assert_eq!(unpack_u32s_to_i8s(words, 3), vec![1, 2, 3]);
    }

    #[test]
    fn should_unpack_empty() {
        assert!(unpack_u32s_to_i8s(Vec::new(), 0).is_empty());
    }

    #[test]
    fn should_round_trip() {
        // 29 values spanning the whole i8 range; not a multiple of 4, so padding is exercised.
        let values: Vec<i8> = (-128i8..=127).step_by(9).collect();

        let packed = pack_i8s_to_u32s(values.clone());
        assert_eq!(packed.len(), values.len().div_ceil(4));

        assert_eq!(unpack_u32s_to_i8s(packed, values.len()), values);
    }
}