burn_tensor/tensor/quantization/
data.rs1use alloc::vec::Vec;
2
/// Packs signed 8-bit values into 32-bit words, four values per word.
///
/// Value `4 * w + i` of `values` is stored, as its raw two's-complement byte, in bits
/// `8 * i .. 8 * i + 8` of output word `w`, so the first value of each group occupies the
/// least significant byte. When `values.len()` is not a multiple of 4, the last word is
/// padded with zero bytes; the output therefore has `values.len().div_ceil(4)` words.
///
/// The result is identical on little- and big-endian targets.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // This copies into a fresh allocation on purpose. Reinterpreting the `Vec<i8>` buffer in
    // place as a `Vec<u32>` would violate `Vec::from_raw_parts`: the alignment (1 vs 4) and,
    // when the capacity is not a multiple of 4, the allocation size would not match.
    values
        .chunks(4)
        .map(|chunk| {
            // Zero-initialised so a short trailing chunk is padded with zeros.
            let mut bytes = [0u8; 4];
            for (byte, &value) in bytes.iter_mut().zip(chunk) {
                // `as u8` keeps the two's-complement bits without sign extension.
                *byte = value as u8;
            }
            // Little-endian decoding places `bytes[0]` in the least significant byte,
            // regardless of the target's native byte order.
            u32::from_le_bytes(bytes)
        })
        .collect()
}
41
/// Unpacks 32-bit words produced by [`pack_i8s_to_u32s`] back into signed 8-bit values.
///
/// Byte `i` of word `w` (bits `8 * i .. 8 * i + 8`, least significant first) becomes output
/// element `4 * w + i`, reinterpreted as two's complement. Only the first `numel` elements
/// are kept. This drops the zero padding that packing adds when the original length is not
/// a multiple of 4.
///
/// If `numel` exceeds the `4 * values.len()` bytes available, every available byte is
/// returned. Words beyond those needed for `numel` elements are ignored. The result is
/// identical on little- and big-endian targets.
pub fn unpack_u32s_to_i8s(values: Vec<u32>, numel: usize) -> Vec<i8> {
    // Clamp so we neither read past the data nor underflow when `numel` is larger than the
    // number of packed bytes.
    let count = numel.min(values.len() * 4);

    // This copies into a fresh allocation on purpose. Reusing the `Vec<u32>` buffer as a
    // `Vec<i8>` would later free it with a smaller alignment than it was allocated with,
    // which `Vec::from_raw_parts` forbids.
    let mut unpacked = Vec::with_capacity(count);
    unpacked.extend(
        values
            .into_iter()
            // Little-endian bytes give element order independent of the native byte order.
            .flat_map(u32::to_le_bytes)
            .take(count)
            // `as i8` reinterprets the byte's bits as two's complement.
            .map(|byte| byte as i8),
    );
    unpacked
}
84
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// `[-128, 2, -3, 127]` as raw bytes is `[0x80, 0x02, 0xFD, 0x7F]`. Packed with the first
    /// value in the least significant byte, that is `0x7FFD_0280` (decimal 2147287680).
    const FULL_WORD: u32 = 0x7FFD_0280;

    #[test]
    fn should_pack_i8s_to_u32() {
        let expected = vec![FULL_WORD];

        assert_eq!(pack_i8s_to_u32s(vec![-128, 2, -3, 127]), expected);
    }

    #[test]
    fn should_pack_i8s_to_u32_padded() {
        // The fifth value spills into a second word whose upper bytes are zero padding, so
        // explicit zero padding must produce the same result.
        let implicit_padding = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let explicit_padding = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(implicit_padding, vec![FULL_WORD, 55]);
        assert_eq!(implicit_padding, explicit_padding);
    }

    #[test]
    fn should_unpack_u32s_to_i8s() {
        let expected = vec![-128, 2, -3, 127];

        assert_eq!(unpack_u32s_to_i8s(vec![FULL_WORD], 4), expected);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        // Only the first of the word's four bytes is requested; the rest is padding.
        let expected = vec![55];

        assert_eq!(unpack_u32s_to_i8s(vec![55u32], 1), expected);
    }
}