1use crate::quantization::QuantValue;
2use alloc::vec::Vec;
3use num_traits::PrimInt;
4
/// Packs signed 8-bit values into 32-bit words, four values per word.
///
/// Within each group of four, value `i` occupies bits `8 * i .. 8 * i + 8` of the
/// resulting word, so the first value lands in the least-significant byte. When
/// `values.len()` is not a multiple of four, the missing upper bytes of the last word
/// are zero. An empty input yields an empty output.
///
/// The produced words are the same on little- and big-endian targets.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // The `Vec<i8>` allocation is deliberately NOT reinterpreted as a `Vec<u32>`:
    // `Vec::from_raw_parts` requires the pointer to come from an allocation made with
    // the alignment of the new element type and with a byte size of exactly
    // `capacity * size_of::<T>()`. An `i8` buffer guarantees neither (alignment 1, and
    // a capacity that need not be a multiple of four), so the words are written into a
    // freshly allocated vector instead.
    values
        .chunks(4)
        .map(|chunk| {
            // Bytes not covered by a short trailing chunk stay zero (padding).
            let mut bytes = [0u8; 4];
            for (dst, &src) in bytes.iter_mut().zip(chunk) {
                // Bit-for-bit reinterpretation of the two's-complement value.
                *dst = src as u8;
            }
            u32::from_le_bytes(bytes)
        })
        .collect()
}
43
44pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
46 values: &[Q],
47 numel: usize,
48 value: &QuantValue,
49) -> Vec<i8> {
50 let size_store = size_of::<Q>() * 8;
51 let size_quant = value.size_bits();
52 let num_quants = size_store / size_quant;
53 let mask = Q::from((1 << size_quant) - 1).unwrap();
54 let sign_shift = 8 - size_quant; values
56 .iter()
57 .enumerate()
58 .flat_map(|(i, &packed)| {
59 let n = core::cmp::min(num_quants, numel - i * num_quants);
61 (0..n).map(move |i| {
68 let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
69 ((raw << sign_shift) as i8) >> sign_shift
70 })
71 })
72 .collect()
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Four values fill exactly one word.
    #[test]
    fn should_pack_i8s_to_u32() {
        let expected = vec![2147287680u32];

        assert_eq!(pack_i8s_to_u32s(vec![-128, 2, -3, 127]), expected);
    }

    /// A trailing partial group packs the same as one explicitly padded with zeros.
    #[test]
    fn should_pack_i8s_to_u32_padded() {
        let implicit = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let explicit = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(implicit, vec![2147287680u32, 55]);
        assert_eq!(implicit, explicit);
    }

    /// One fully packed word unpacks to four 8-bit values.
    #[test]
    fn should_unpack_u32s_to_i8s() {
        let words = [2147287680u32];

        let result = unpack_q_to_i8s(&words, 4, &QuantValue::Q8S);

        assert_eq!(result, vec![-128i8, 2, -3, 127]);
    }

    /// Only `numel` values are read from a partially filled word.
    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        let words = [55u32];

        let result = unpack_q_to_i8s(&words, 1, &QuantValue::Q8S);

        assert_eq!(result, vec![55i8]);
    }

    /// Sixteen words of 4-bit values unpack to 128 values.
    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let words = [
            0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
            1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
            2004318071,
        ];
        let expected = vec![
            0i8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
            3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
            5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
            6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        ];

        let result = unpack_q_to_i8s(&words, 128, &QuantValue::Q4S);

        assert_eq!(result, expected);
    }
}