// burn_std/tensor/quantization.rs

1//! Quantization data representation.
2
3// Re-exported types
4pub use cubecl_common::quant::scheme::{
5    BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
6};
7
/// Alignment (in bytes) for quantization parameters in serialized tensor data.
///
/// Used when appending scale bytes to the value buffer (see
/// `QuantizedBytes::new`, which passes it to `extend_from_byte_slice_aligned`).
///
/// NOTE: This is currently f32-based since scales were originally always f32.
/// With `QuantParam` now supporting different precisions (F16, BF16, etc.),
/// this alignment may need to be revisited in the future.
pub const QPARAM_ALIGN: usize = core::mem::align_of::<f32>();
14
15use alloc::vec::Vec;
16use core::any::TypeId;
17use num_traits::PrimInt;
18use serde::{Deserialize, Serialize};
19
20use crate::{DType, Shape, bytes::Bytes};
21
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
/// The precision of accumulating elements.
///
/// Defaults to full precision ([`QuantAcc::F32`]).
pub enum QuantAcc {
    /// Full precision.
    #[default]
    F32,
    /// Half precision.
    F16,
    /// bfloat16 precision.
    BF16,
}
35
/// Specify if the output of an operation is quantized using the scheme of the input
/// or returned unquantized.
///
/// Defaults to [`QuantPropagation::Inhibit`] (unquantized output).
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantPropagation {
    /// The output is quantized using the scheme of the input.
    Propagate,
    /// The output is not quantized.
    #[default]
    Inhibit,
}
48
/// The quantization tensor data parameters.
///
/// Generic over the scales container type `S` (e.g. `Vec<f32>` as returned by
/// `QuantizedBytes::into_vec_i8`).
#[derive(Clone, Debug)]
pub struct QParams<S> {
    /// The scaling factor.
    pub scales: S,
}
55
/// A quantization parameter tensor descriptor.
///
/// Locates a parameter tensor (e.g. scales) inside a larger buffer and
/// records its layout metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QParamTensor {
    /// Start of the tensor in the buffer
    pub offset_start: usize,
    /// Offset of tensor end from the end of the buffer
    /// (i.e. measured backwards from the buffer's end, not from its start)
    pub offset_end: usize,
    /// Shape of the tensor
    pub shape: Shape,
    /// Strides of the tensor
    pub strides: Vec<usize>,
    /// Data type of the tensor
    pub dtype: DType,
}
70
71/// Calculate the shape of the quantization parameters for a given tensor and level
72pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {
73    match level {
74        QuantLevel::Tensor => Shape::new([1]),
75        QuantLevel::Block(block_size) => {
76            let mut params_shape = data_shape.clone();
77            let block_size = block_size.to_dim_vec(data_shape.num_dims());
78
79            for (shape, block_size) in params_shape.dims.iter_mut().zip(block_size) {
80                *shape = (*shape).div_ceil(block_size as usize);
81            }
82
83            params_shape
84        }
85    }
86}
87
/// Quantized data bytes representation.
///
/// # Notes
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
///    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
///    we make sure to retrieve only the meaningful values (and ignore the alignment padding).
/// 2) Quantization parameters are appended to the tensor data.
///    As such, the last bytes always correspond to the scale parameter.
///    If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
pub struct QuantizedBytes {
    /// The quantized values and quantization parameters represented as bytes.
    pub bytes: Bytes,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The number of quantized elements (values, not bytes or storage words).
    pub num_elements: usize,
}
105
impl QuantizedBytes {
    /// Creates a new quantized bytes representation.
    ///
    /// The quantized `value`s are stored first, followed by the `scales`
    /// appended as raw native-endian `f32` bytes, aligned to [`QPARAM_ALIGN`].
    ///
    /// # Panics
    /// Panics if `E` is not `i8`: this constructor is only used for 8-bit
    /// quantization data comparison in tests.
    pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
        value: Vec<E>,
        scheme: QuantScheme,
        scales: &[f32],
    ) -> Self {
        let num_elements = value.len();
        // Only used for 8-bit quantization data comparison in tests
        if TypeId::of::<E>() != TypeId::of::<i8>() {
            panic!("Invalid quantized type");
        }

        // Re-interpret `Vec<E>` as `Vec<i8>` with `Vec::from_raw_parts`
        let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
        let mut bytes = Bytes::from_elems(i8s);

        match scheme.level {
            QuantLevel::Tensor => {
                // Per-tensor quantization: a single scale parameter.
                let scale_bytes = bytemuck::bytes_of(&scales[0]);
                bytes.extend_from_byte_slice_aligned(scale_bytes, QPARAM_ALIGN);
            }
            QuantLevel::Block(_block_size) => {
                // Per-block quantization: one scale per block, appended in order.
                let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
                for scale in scales {
                    scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
                }
                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), QPARAM_ALIGN);
            }
        }

        Self {
            bytes,
            scheme,
            num_elements,
        }
    }

    /// Returns the int8 quantized values with the quantization parameters.
    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
        let (values, (qparams, num_params)) = self.split_values_off();

        // Quantization parameters are added at the end of the tensor data.
        // As such, the last bytes always correspond to the scale parameter(s).
        // For example, per-block quantization can have multiple parameters for a single tensor:
        // [scale, scale, scale, ...]
        let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
        let total_bytes = qparams_bytes.len();

        let scales_size = scale_size * num_params;

        // The scales occupy the tail of the parameter bytes; re-interpret them as f32.
        let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();

        (values, QParams { scales })
    }

    /// Splits the trailing `num_params` scale parameters off the i8 value
    /// buffer, returning `(values, qparams)` with the parameter bytes
    /// re-interpreted as `u32` words.
    fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
        let mut values = read_bytes_to_i8(self.bytes);

        // Each parameter is an f32, i.e. 4 bytes / one u32 word.
        let scale_size = num_params * size_of::<f32>();
        let values_end = values.len() - scale_size;

        let qparams = values.split_off(values_end);

        let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
            // NOTE(review): re-interpreting a `Vec<i8>` allocation as `Vec<u32>`
            // is questionable even when the pointer happens to be 4-byte aligned:
            // the allocation was made with `i8` layout (align 1), and
            // `qparams.capacity() / 4` floors a capacity that need not be a
            // multiple of 4, so dropping the resulting `Vec<u32>` may deallocate
            // with a mismatched layout. A safe byte-wise conversion
            // (`chunks_exact(4)` + `u32::from_ne_bytes`) would avoid this —
            // TODO confirm and fix.
            let mut qparams = core::mem::ManuallyDrop::new(qparams);
            unsafe {
                Vec::<u32>::from_raw_parts(
                    qparams.as_mut_ptr() as _,
                    qparams.len() / 4,
                    qparams.capacity() / 4,
                )
            }
        } else {
            #[cfg(target_endian = "little")]
            {
                // SAFETY: quantized bytes representation is created from packed u32 values in little endian
                // NOTE(review): `bytemuck::cast_vec` panics when source and target
                // alignments differ (i8 -> u32); verify this fallback path is
                // actually reachable in practice.
                bytemuck::cast_vec(qparams)
            }
            #[cfg(target_endian = "big")]
            {
                crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
            }
        };
        (values, qparams)
    }

    /// Splits the quantized values of the tensor from the quantization parameters.
    ///
    /// Returns the values in i8 and a newly allocated vector containing the quantization parameters.
    fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
        // Number of scale parameters: one for per-tensor quantization,
        // one per block otherwise.
        let num_params = match self.scheme.level {
            QuantLevel::Tensor => 1,
            QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
        };

        if let QuantStore::PackedU32(packed_dim) = self.scheme.store {
            assert_eq!(
                packed_dim, 0,
                "Packing must be on innermost dimension for splitting off values"
            );
        }

        let (values, qparams) = match self.scheme.store {
            QuantStore::Native => self.split_i8_values(num_params),
            QuantStore::PackedU32(_) => match self.scheme.value {
                // 8-bit values: the packed representation is byte-compatible with i8.
                QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                    let mut values = self.bytes.try_into_vec::<u32>().unwrap();
                    let scale_size = num_params; // size of f32 same as u32
                    let values_end = values.len() - scale_size;

                    let qparams = values.split_off(values_end);
                    // Sub-byte values are unpacked as i8s for value equality tests
                    let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
                    (values, qparams)
                }
                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                    unimplemented!("Not yet supported")
                }
            },
            QuantStore::PackedNative(_) => unimplemented!("Not yet supported"),
        };

        (values, (qparams, num_params))
    }
}
234
235fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
236    match bytes.try_into_vec::<i8>() {
237        Ok(val) => val,
238        // Safety,
239        //
240        // `Vec<u8>` can be Re-interpreted as `Vec<i8>` since they share the same alignment.
241        Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
242    }
243}
244
/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
///
/// Each group of four values is combined into one `u32` with the first value
/// of the group in the least significant byte, i.e. the same as:
///
///     let result = (d_u8 & 0xFF) << 24 | (c_u8 & 0xFF) << 16 | (b_u8 & 0xFF) << 8 | (a_u8 & 0xFF);
///
/// When the number of values is not a multiple of 4, the final group is
/// implicitly padded with zeros.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // NOTE: a previous version re-interpreted the `Vec<i8>` allocation as a
    // `Vec<u32>` on little-endian targets via `Vec::from_raw_parts`. That is
    // unsound: the allocation has `i8` layout (align 1) and a capacity that
    // need not be a multiple of 4 (`capacity() / 4` floored it), so dropping
    // the resulting `Vec<u32>` could deallocate with a mismatched layout.
    // The explicit shift-and-fold below is safe, endian-independent, and
    // produces byte-for-byte identical packing (short chunks shift zeros in,
    // which matches the old zero-padding).
    values
        .chunks(4)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .fold(0u32, |acc, (i, &v)| acc | ((v as u8 as u32) << (i * 8)))
        })
        .collect()
}
283
284/// Unpack integer values into a sequence of signed 8-bit integers.
285pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
286    values: &[Q],
287    numel: usize,
288    value: &QuantValue,
289) -> Vec<i8> {
290    let size_store = size_of::<Q>() * 8;
291    let size_quant = value.size_bits();
292    let num_quants = size_store / size_quant;
293    let mask = Q::from((1 << size_quant) - 1).unwrap();
294    let sign_shift = 8 - size_quant; // sign extension for sub-byte values
295    values
296        .iter()
297        .enumerate()
298        .flat_map(|(i, &packed)| {
299            // A single u32 could contain less than four 8-bit values...
300            let n = core::cmp::min(num_quants, numel - i * num_quants);
301            // Extract each 8-bit segment from u32 and cast back to i8
302            // Same as doing this (when 4 values are fully packed):
303            //     let a = (packed & 0xFF) as i8;
304            //     let b = ((packed >> 8) & 0xFF) as i8;
305            //     let c = ((packed >> 16) & 0xFF) as i8;
306            //     let d = ((packed >> 24) & 0xFF) as i8;
307            (0..n).map(move |i| {
308                let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
309                ((raw << sign_shift) as i8) >> sign_shift
310            })
311        })
312        .collect()
313}
314
#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_i8s_to_u32() {
        // [-128, 2, -3, 127] as bytes: 0x80, 0x02, 0xFD, 0x7F
        // -> 0x7FFD0280 = 2147287680 (first value in the least significant byte)
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);

        assert_eq!(packed, vec![2147287680]);
    }

    #[test]
    fn should_pack_i8s_to_u32_padded() {
        // Five values: the trailing group is zero-padded, so packing with
        // explicit zero padding must produce the same result.
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(packed, vec![2147287680, 55]);
        assert_eq!(packed, packed_padded);
    }

    #[test]
    fn should_unpack_u32s_to_i8s() {
        // Round-trip of the packing test above.
        let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![-128, 2, -3, 127]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        // Only one of the four packed slots is meaningful (numel = 1);
        // the zero padding must be ignored.
        let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![55]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        // 4-bit values: eight elements per u32, with sign extension.
        let unpacked = unpack_q_to_i8s(
            &[
                0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
                1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
                2004318071,
            ],
            128,
            &QuantValue::Q4S,
        );

        assert_eq!(
            unpacked,
            vec![
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
            ]
        );
    }

    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
        let scale = 0.03937008;
        let values = vec![0i8, 25, 51, 76, 102, 127];

        let q_bytes = QuantizedBytes::new(
            values.clone(),
            QuantScheme::default()
                .with_value(QuantValue::Q8S)
                .with_store(QuantStore::Native),
            &[scale],
        );

        // The scale appended by `new` must round-trip exactly.
        let (q_values, qparams) = q_bytes.into_vec_i8();

        assert_eq!(qparams.scales, vec![scale]);

        assert_eq!(q_values, values);
    }
}
395}