burn_std/tensor/quantization.rs

//! Quantization data representation.

// Re-exported types
pub use cubecl_quant::scheme::{
    BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
};

use alloc::vec::Vec;
use core::any::TypeId;
use num_traits::PrimInt;

use crate::bytes::Bytes;

/// The quantization tensor data parameters.
#[derive(Clone, Debug)]
pub struct QParams<S> {
    /// The scaling factor.
    pub scales: S,
}

/// Quantized data bytes representation.
///
/// # Notes
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
///    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
///    we make sure to retrieve only the meaningful values (and ignore the alignment padding).
/// 2) Quantization parameters are appended to the tensor data.
///    As such, the last bytes always correspond to the scale parameter.
///    If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
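///
/// For example, with per-tensor int8 quantization the layout is, roughly:
///
/// ```text
/// [q0, q1, ..., q_{n-1}, <alignment padding>, scale (f32)]
/// ```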
pub struct QuantizedBytes {
    /// The quantized values and quantization parameters represented as bytes.
    pub bytes: Bytes,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The number of quantized elements.
    pub num_elements: usize,
}

impl QuantizedBytes {
    /// Creates a new quantized bytes representation.
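    ///
    /// A minimal sketch, mirroring the per-tensor test at the bottom of this file:
    ///
    /// ```ignore
    /// let q_bytes = QuantizedBytes::new(
    ///     vec![0i8, 25, 51, 76, 102, 127],
    ///     QuantScheme::default()
    ///         .with_value(QuantValue::Q8S)
    ///         .with_store(QuantStore::Native),
    ///     &[0.03937008],
    /// );
    /// ```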
    pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
        value: Vec<E>,
        scheme: QuantScheme,
        scales: &[f32],
    ) -> Self {
        let num_elements = value.len();
        // This constructor is only used to compare 8-bit quantized data in tests,
        // so only `i8` values are accepted.
        if TypeId::of::<E>() != TypeId::of::<i8>() {
            panic!("Invalid quantized type");
        }

        // Re-interpret `Vec<E>` as `Vec<i8>` without copying (size and alignment checked by bytemuck)
        let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
        let mut bytes = Bytes::from_elems(i8s);

        match scheme.level {
            QuantLevel::Tensor => {
                let scale_bytes = bytemuck::bytes_of(&scales[0]);
                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
            }
            QuantLevel::Block(_block_size) => {
                let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
                for scale in scales {
                    scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
                }
                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), align_of::<f32>());
            }
        }

        Self {
            bytes,
            scheme,
            num_elements,
        }
    }

    /// Returns the int8 quantized values with the quantization parameters.
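    ///
    /// A minimal sketch, assuming `q_bytes` was built as in [`QuantizedBytes::new`] above:
    ///
    /// ```ignore
    /// let (values, qparams) = q_bytes.into_vec_i8();
    /// assert_eq!(qparams.scales.len(), 1); // per-tensor quantization: a single scale
    /// ```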
    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
        let (values, (qparams, num_params)) = self.split_values_off();

        // Quantization parameters are added at the end of the tensor data.
        // As such, the last bytes always correspond to the scale parameter(s).
        // For example, per-block quantization can have multiple parameters for a single tensor:
        // [scale, scale, scale, ...]
        let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
        let total_bytes = qparams_bytes.len();

        let scales_size = scale_size * num_params;

        let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();

        (values, QParams { scales })
    }

    fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
        let mut values = read_bytes_to_i8(self.bytes);

        let scale_size = num_params * size_of::<f32>();
        let values_end = values.len() - scale_size;

        let qparams = values.split_off(values_end);

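        // If the split-off scale bytes happen to be 4-byte aligned, reuse the allocation
        // directly as a `Vec<u32>`; otherwise fall back to casting (little endian) or
        // re-packing (big endian).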
        let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
            let mut qparams = core::mem::ManuallyDrop::new(qparams);
            unsafe {
                Vec::<u32>::from_raw_parts(
                    qparams.as_mut_ptr() as _,
                    qparams.len() / 4,
                    qparams.capacity() / 4,
                )
            }
        } else {
            #[cfg(target_endian = "little")]
            {
                // SAFETY: quantized bytes representation is created from packed u32 values in little endian
                bytemuck::cast_vec(qparams)
            }
            #[cfg(target_endian = "big")]
            {
                crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
            }
        };
        (values, qparams)
    }

    /// Splits the quantized values of the tensor from the quantization parameters.
    ///
    /// Returns the values in i8 and a newly allocated vector containing the quantization parameters.
    fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
        let num_params = match self.scheme.level {
            QuantLevel::Tensor => 1,
            QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
        };

        let (values, qparams) = match self.scheme.store {
            QuantStore::Native => self.split_i8_values(num_params),
            QuantStore::U32 => match self.scheme.value {
                QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                    let mut values = self.bytes.try_into_vec::<u32>().unwrap();
                    let scale_size = num_params; // an f32 scale occupies one u32 slot
                    let values_end = values.len() - scale_size;

                    let qparams = values.split_off(values_end);
                    // Sub-byte values are unpacked as i8s for value equality tests
                    let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
                    (values, qparams)
                }
                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                    unimplemented!("Not yet supported")
                }
            },
        };

        (values, (qparams, num_params))
    }
}

fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
    match bytes.try_into_vec::<i8>() {
        Ok(val) => val,
        // SAFETY: `Vec<u8>` can be re-interpreted as `Vec<i8>` since the two types share the
        // same size and alignment.
        Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
    }
}

/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
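///
/// A small sketch, mirroring the packing test below (four i8 values end up in one u32,
/// first value in the lowest byte):
///
/// ```ignore
/// assert_eq!(pack_i8s_to_u32s(vec![-128, 2, -3, 127]), vec![2147287680]);
/// ```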
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // Shift and combine groups of four 8-bit values into a u32.
    // Same as doing this:
    //     let result = (d_u8 & 0xFF) << 24 | (c_u8 & 0xFF) << 16 | (b_u8 & 0xFF) << 8 | (a_u8 & 0xFF);
    #[cfg(target_endian = "big")]
    {
        values
            .chunks(4)
            .map(|x| {
                x.iter()
                    .enumerate()
                    .fold(0u32, |acc, (i, x)| acc | (*x as u32 & 0xFF) << (i * 8))
            })
            .collect()
    }

    // The order of bytes in little endian matches the above description; we just need to
    // handle padding when the number of values is not a multiple of 4.
    #[cfg(target_endian = "little")]
    {
        let mut values = values;
        let remainder = values.len() % 4;
        if remainder != 0 {
            // Pad with zeros
            values.extend(core::iter::repeat_n(0, 4 - remainder));
        }

        let len = values.len() / 4;
        let capacity = values.capacity() / 4;

        // Pre-forget the old vec and re-interpret as u32
        let mut values = core::mem::ManuallyDrop::new(values);
        let ptr = values.as_mut_ptr() as *mut u32;

        unsafe { Vec::from_raw_parts(ptr, len, capacity) }
    }
}

/// Unpack integer values into a sequence of signed 8-bit integers.
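///
/// A small sketch, mirroring the unpacking test below:
///
/// ```ignore
/// let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);
/// assert_eq!(unpacked, vec![-128, 2, -3, 127]);
/// ```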
pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
    values: &[Q],
    numel: usize,
    value: &QuantValue,
) -> Vec<i8> {
    let size_store = size_of::<Q>() * 8;
    let size_quant = value.size_bits();
    let num_quants = size_store / size_quant;
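    // Mask selecting a single quantized value, e.g. 0xFF for 8-bit or 0x0F for 4-bit values.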
    let mask = Q::from((1 << size_quant) - 1).unwrap();
    let sign_shift = 8 - size_quant; // used to sign-extend sub-byte values
    values
        .iter()
        .enumerate()
        .flat_map(|(i, &packed)| {
            // The last packed word may contain fewer than `num_quants` meaningful values
            let n = core::cmp::min(num_quants, numel - i * num_quants);
            // Extract each quantized segment and sign-extend it back to i8.
            // Same as doing this (when four 8-bit values are fully packed into a u32):
            //     let a = (packed & 0xFF) as i8;
            //     let b = ((packed >> 8) & 0xFF) as i8;
            //     let c = ((packed >> 16) & 0xFF) as i8;
            //     let d = ((packed >> 24) & 0xFF) as i8;
            (0..n).map(move |i| {
                let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
                ((raw << sign_shift) as i8) >> sign_shift
            })
        })
        .collect()
}

#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_i8s_to_u32() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);

        assert_eq!(packed, vec![2147287680]);
    }

    #[test]
    fn should_pack_i8s_to_u32_padded() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(packed, vec![2147287680, 55]);
        assert_eq!(packed, packed_padded);
    }

    #[test]
    fn should_unpack_u32s_to_i8s() {
        let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![-128, 2, -3, 127]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![55]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let unpacked = unpack_q_to_i8s(
            &[
                0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
                1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
                2004318071,
            ],
            128,
            &QuantValue::Q4S,
        );

        assert_eq!(
            unpacked,
            vec![
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
            ]
        );
    }

    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
        let scale = 0.03937008;
        let values = vec![0i8, 25, 51, 76, 102, 127];

        let q_bytes = QuantizedBytes::new(
            values.clone(),
            QuantScheme::default()
                .with_value(QuantValue::Q8S)
                .with_store(QuantStore::Native),
            &[scale],
        );

        let (q_values, qparams) = q_bytes.into_vec_i8();

        assert_eq!(qparams.scales, vec![scale]);

        assert_eq!(q_values, values);
    }
}