burn_std/tensor/
quantization.rs

1//! Quantization data representation.
2
3// Re-exported types
4pub use cubecl_common::quant::scheme::{
5    BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
6};
7
/// Alignment (in bytes) for quantization parameters in serialized tensor data.
///
/// Evaluates to `align_of::<f32>()` (4 on typical targets).
///
/// NOTE: This is currently f32-based since scales were originally always f32.
/// With `QuantParam` now supporting different precisions (F16, BF16, etc.),
/// this alignment may need to be revisited in the future.
pub const QPARAM_ALIGN: usize = core::mem::align_of::<f32>();
14
15use alloc::vec::Vec;
16use core::any::TypeId;
17use num_traits::PrimInt;
18use serde::{Deserialize, Serialize};
19
20use crate::{DType, Shape, bytes::Bytes};
21
22#[derive(
23    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
24)]
25/// The precision of accumulating elements.
26pub enum QuantAcc {
27    /// Full precision.
28    #[default]
29    F32,
30    /// Half precision.
31    F16,
32    /// bfloat16 precision.
33    BF16,
34}
35
/// Specify if the output of an operation is quantized using the scheme of the input
/// or returned unquantized.
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantPropagation {
    /// The output is quantized using the scheme of the input.
    Propagate,
    /// The output is not quantized. This is the default.
    #[default]
    Inhibit,
}
48
/// The quantization tensor data parameters.
#[derive(Clone, Debug)]
pub struct QParams<S> {
    /// The scaling factor(s).
    ///
    /// `S` may hold multiple values: per-tensor quantization has a single
    /// scale, while per-block quantization has one scale per block
    /// (see [`QuantizedBytes::into_vec_i8`]).
    pub scales: S,
}
55
/// A quantization parameter tensor descriptor.
///
/// Describes where a parameter tensor lives inside a larger buffer,
/// without owning the data itself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QParamTensor {
    /// Start of the tensor in the buffer
    // NOTE(review): presumably a byte offset — confirm against call sites.
    pub offset_start: usize,
    /// Offset of tensor end from the end of the buffer
    // i.e. counted backwards from the buffer's end, not from its start.
    pub offset_end: usize,
    /// Shape of the tensor
    pub shape: Shape,
    /// Strides of the tensor
    pub strides: Vec<usize>,
    /// Data type of the tensor
    pub dtype: DType,
}
70
71/// Calculate the shape of the quantization parameters for a given tensor and level
72pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {
73    match level {
74        QuantLevel::Tensor => Shape::new([1]),
75        QuantLevel::Block(block_size) => {
76            let mut params_shape = data_shape.clone();
77            let block_size = block_size.to_dim_vec(data_shape.num_dims());
78
79            for (shape, block_size) in params_shape.dims.iter_mut().zip(block_size) {
80                *shape = (*shape).div_ceil(block_size as usize);
81            }
82
83            params_shape
84        }
85    }
86}
87
/// Quantized data bytes representation.
///
/// # Notes
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
///    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
///    we make sure to retrieve only the meaningful values (and ignore the alignment padding).
/// 2) Quantization parameters are appended to the tensor data.
///    As such, the last bytes always correspond to the scale parameter.
///    If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
pub struct QuantizedBytes {
    /// The quantized values and quantization parameters represented as bytes.
    pub bytes: Bytes,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The number of quantized elements.
    ///
    /// May differ from the byte length of `bytes` when values are sub-byte
    /// packed (see `unpack_q_to_i8s`) and because parameters are appended.
    pub num_elements: usize,
}
105
impl QuantizedBytes {
    /// Creates a new quantized bytes representation.
    ///
    /// The quantized `value` elements are stored first, followed by the `scales`
    /// bytes aligned to [`QPARAM_ALIGN`].
    ///
    /// # Panics
    ///
    /// Panics if `E` is not `i8`; this constructor is only used for 8-bit
    /// quantization data comparison in tests.
    pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
        value: Vec<E>,
        scheme: QuantScheme,
        scales: &[f32],
    ) -> Self {
        let num_elements = value.len();
        // Only used for 8-bit quantization data comparison in tests
        if TypeId::of::<E>() != TypeId::of::<i8>() {
            panic!("Invalid quantized type");
        }

        // Re-interpret `Vec<E>` as `Vec<i8>` with `Vec::from_raw_parts`
        let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
        let mut bytes = Bytes::from_elems(i8s);

        match scheme.level {
            QuantLevel::Tensor => {
                // Per-tensor: a single f32 scale is appended after the values.
                let scale_bytes = bytemuck::bytes_of(&scales[0]);
                bytes.extend_from_byte_slice_aligned(scale_bytes, QPARAM_ALIGN);
            }
            QuantLevel::Block(_block_size) => {
                // Per-block: one f32 scale per block, appended in order.
                let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
                for scale in scales {
                    scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
                }
                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), QPARAM_ALIGN);
            }
        }

        Self {
            bytes,
            scheme,
            num_elements,
        }
    }

    /// Returns the int8 quantized values with the quantization parameters.
    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
        let (values, (qparams, num_params)) = self.split_values_off();

        // Quantization parameters are added at the end of the tensor data.
        // As such, the last bytes always correspond to the scale parameter(s).
        // For example, per-block quantization can have multiple parameters for a single tensor:
        // [scale, scale, scale, ...]
        let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
        let total_bytes = qparams_bytes.len();

        let scales_size = scale_size * num_params;

        // Re-interpret the trailing scale bytes as f32; the slice starts at the
        // tail of `qparams_bytes`, which is 4-byte aligned (cast from `&[u32]`).
        let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();

        (values, QParams { scales })
    }

    /// Splits the trailing quantization-parameter bytes off the i8 value buffer.
    ///
    /// Returns `(values, qparams)` where `qparams` holds the raw parameter
    /// bytes re-interpreted as `u32` words.
    fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
        let mut values = read_bytes_to_i8(self.bytes);

        // Parameters occupy the last `num_params` f32-sized slots.
        let scale_size = num_params * size_of::<f32>();
        let values_end = values.len() - scale_size;

        let qparams = values.split_off(values_end);

        let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
            // Zero-copy path: re-interpret the split-off `Vec<i8>` as `Vec<u32>`.
            // NOTE(review): `Vec::from_raw_parts` requires the deallocation
            // layout to match the allocation; `capacity() / 4` only does so
            // when the capacity is a multiple of 4 — TODO confirm this
            // invariant holds for vectors produced by `split_off`.
            let mut qparams = core::mem::ManuallyDrop::new(qparams);
            unsafe {
                Vec::<u32>::from_raw_parts(
                    qparams.as_mut_ptr() as _,
                    qparams.len() / 4,
                    qparams.capacity() / 4,
                )
            }
        } else {
            #[cfg(target_endian = "little")]
            {
                // SAFETY: quantized bytes representation is created from packed u32 values in little endian
                bytemuck::cast_vec(qparams)
            }
            #[cfg(target_endian = "big")]
            {
                // On big-endian targets the bytes must be re-packed explicitly
                // to preserve the little-endian packing semantics.
                crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
            }
        };
        (values, qparams)
    }

    /// Splits the quantized values of the tensor from the quantization parameters.
    ///
    /// Returns the values in i8 and a newly allocated vector containing the quantization parameters.
    fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
        // Number of scale parameters: one for the whole tensor, or one per block.
        let num_params = match self.scheme.level {
            QuantLevel::Tensor => 1,
            QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
        };

        let (values, qparams) = match self.scheme.store {
            QuantStore::Native => self.split_i8_values(num_params),
            QuantStore::U32 => match self.scheme.value {
                // 8-bit values stored in u32 words share the i8 split path.
                QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                    let mut values = self.bytes.try_into_vec::<u32>().unwrap();
                    let scale_size = num_params; // size of f32 same as u32
                    let values_end = values.len() - scale_size;

                    let qparams = values.split_off(values_end);
                    // Sub-byte values are unpacked as i8s for value equality tests
                    let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
                    (values, qparams)
                }
                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                    unimplemented!("Not yet supported")
                }
            },
        };

        (values, (qparams, num_params))
    }
}
226
227fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
228    match bytes.try_into_vec::<i8>() {
229        Ok(val) => val,
230        // Safety,
231        //
232        // `Vec<u8>` can be Re-interpreted as `Vec<i8>` since they share the same alignment.
233        Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
234    }
235}
236
/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
///
/// Groups of four values are combined into a single `u32` in little-endian byte
/// order, i.e. the first value of each group occupies the least significant byte:
///
///     result = (d as u32 & 0xFF) << 24 | (c as u32 & 0xFF) << 16
///            | (b as u32 & 0xFF) << 8  | (a as u32 & 0xFF)
///
/// When the number of values is not a multiple of 4, the last word is
/// zero-padded.
///
/// This implementation is endian-independent (`u32::from_le_bytes` yields the
/// same numeric result on little- and big-endian targets, matching both of the
/// previous cfg branches) and fully safe: the former little-endian path called
/// `Vec::from_raw_parts` with `capacity / 4`, which violates the allocator's
/// layout contract (undefined behavior on drop) whenever the source vector's
/// capacity is not a multiple of 4.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    values
        .chunks(4)
        .map(|chunk| {
            // Copy the (up to 4) bytes into a zero-initialized array so a
            // trailing partial group is implicitly zero-padded.
            let mut bytes = [0u8; 4];
            for (dst, &src) in bytes.iter_mut().zip(chunk) {
                *dst = src as u8;
            }
            u32::from_le_bytes(bytes)
        })
        .collect()
}
275
276/// Unpack integer values into a sequence of signed 8-bit integers.
277pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
278    values: &[Q],
279    numel: usize,
280    value: &QuantValue,
281) -> Vec<i8> {
282    let size_store = size_of::<Q>() * 8;
283    let size_quant = value.size_bits();
284    let num_quants = size_store / size_quant;
285    let mask = Q::from((1 << size_quant) - 1).unwrap();
286    let sign_shift = 8 - size_quant; // sign extension for sub-byte values
287    values
288        .iter()
289        .enumerate()
290        .flat_map(|(i, &packed)| {
291            // A single u32 could contain less than four 8-bit values...
292            let n = core::cmp::min(num_quants, numel - i * num_quants);
293            // Extract each 8-bit segment from u32 and cast back to i8
294            // Same as doing this (when 4 values are fully packed):
295            //     let a = (packed & 0xFF) as i8;
296            //     let b = ((packed >> 8) & 0xFF) as i8;
297            //     let c = ((packed >> 16) & 0xFF) as i8;
298            //     let d = ((packed >> 24) & 0xFF) as i8;
299            (0..n).map(move |i| {
300                let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
301                ((raw << sign_shift) as i8) >> sign_shift
302            })
303        })
304        .collect()
305}
306
#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    // Four i8 values pack into one little-endian u32 word.
    #[test]
    fn should_pack_i8s_to_u32() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);

        assert_eq!(packed, vec![2147287680]);
    }

    // A trailing partial group is zero-padded; explicit zero padding must
    // produce the identical result.
    #[test]
    fn should_pack_i8s_to_u32_padded() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(packed, vec![2147287680, 55]);
        assert_eq!(packed, packed_padded);
    }

    // Round-trip of the packing test above: unpack restores the i8 values.
    #[test]
    fn should_unpack_u32s_to_i8s() {
        let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![-128, 2, -3, 127]);
    }

    // Only `numel` values are produced; padding in the word is ignored.
    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![55]);
    }

    // 4-bit signed unpacking: 8 values per u32 word, 128 values total.
    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let unpacked = unpack_q_to_i8s(
            &[
                0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
                1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
                2004318071,
            ],
            128,
            &QuantValue::Q4S,
        );

        assert_eq!(
            unpacked,
            vec![
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
            ]
        );
    }

    // End-to-end: pack values + scale into bytes, then recover both.
    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
        let scale = 0.03937008;
        let values = vec![0i8, 25, 51, 76, 102, 127];

        let q_bytes = QuantizedBytes::new(
            values.clone(),
            QuantScheme::default()
                .with_value(QuantValue::Q8S)
                .with_store(QuantStore::Native),
            &[scale],
        );

        let (q_values, qparams) = q_bytes.into_vec_i8();

        assert_eq!(qparams.scales, vec![scale]);

        assert_eq!(q_values, values);
    }
}
387}