// File: burn_std/tensor/quantization.rs

1//! Quantization data representation.
2
3// Re-exported types
4pub use cubecl_common::quant::scheme::{
5    BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
6};
7
/// Alignment (in bytes) for quantization parameters in serialized tensor data.
///
/// NOTE: This is currently f32-based since scales were originally always f32.
/// With `QuantParam` now supporting different precisions (F16, BF16, etc.),
/// this alignment may need to be revisited in the future.
pub const QPARAM_ALIGN: usize = core::mem::align_of::<f32>(); // = 4 bytes
14
15use alloc::vec::Vec;
16use core::any::TypeId;
17use num_traits::PrimInt;
18use serde::{Deserialize, Serialize};
19
20use crate::{DType, Metadata, Shape, bytes::Bytes};
21
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
/// The precision of accumulating elements.
///
/// Defaults to [`QuantAcc::F32`].
pub enum QuantAcc {
    /// Full precision (f32).
    #[default]
    F32,
    /// Half precision (f16).
    F16,
    /// bfloat16 precision.
    BF16,
}
35
/// Specify if the output of an operation is quantized using the scheme of the input
/// or returned unquantized.
///
/// Defaults to [`QuantPropagation::Inhibit`].
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantPropagation {
    /// The output is quantized using the scheme of the input.
    Propagate,
    /// The output is not quantized.
    #[default]
    Inhibit,
}
48
/// The quantization tensor data parameters.
///
/// Generic over the scales container `S` (e.g. `Vec<f32>` as produced when
/// reading quantized data back out of a byte buffer).
#[derive(Clone, Debug)]
pub struct QParams<S> {
    /// The scaling factor.
    pub scales: S,
}
55
/// A quantization parameter tensor descriptor.
///
/// Locates the parameter tensor inside a larger buffer via offsets measured
/// from both ends of that buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QParamTensor {
    /// Start of the tensor in the buffer
    pub offset_start: usize,
    /// Offset of tensor end from the end of the buffer
    pub offset_end: usize,
    /// Metadata of the tensor
    pub metadata: Metadata,
    /// Data type of the tensor
    pub dtype: DType,
}
68
69/// Calculate the shape of the quantization parameters for a given tensor and level
70pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {
71    match level {
72        QuantLevel::Tensor => Shape::new([1]),
73        QuantLevel::Block(block_size) => {
74            let mut params_shape = data_shape.clone();
75            let block_size = block_size.to_dim_vec(data_shape.num_dims());
76
77            for (shape, block_size) in params_shape.iter_mut().zip(block_size) {
78                *shape = (*shape).div_ceil(block_size as usize);
79            }
80
81            params_shape
82        }
83    }
84}
85
/// Quantized data bytes representation.
///
/// # Notes
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
///    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
///    we make sure to retrieve only the meaningful values (and ignore the alignment padding).
/// 2) Quantization parameters are appended to the tensor data.
///    As such, the last bytes always correspond to the scale parameter.
///    If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
pub struct QuantizedBytes {
    /// The quantized values and quantization parameters represented as bytes.
    pub bytes: Bytes,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The number of quantized elements (i.e. values, not packed words).
    pub num_elements: usize,
}
103
104impl QuantizedBytes {
105    /// Creates a new quantized bytes representation.
106    pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
107        value: Vec<E>,
108        scheme: QuantScheme,
109        scales: &[f32],
110    ) -> Self {
111        let num_elements = value.len();
112        // Only used for 8-bit quantization data comparison in tests
113        if TypeId::of::<E>() != TypeId::of::<i8>() {
114            panic!("Invalid quantized type");
115        }
116
117        // Re-interpret `Vec<E>` as `Vec<i8>` with `Vec::from_raw_parts`
118        let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
119        let mut bytes = Bytes::from_elems(i8s);
120
121        match scheme.level {
122            QuantLevel::Tensor => {
123                let scale_bytes = bytemuck::bytes_of(&scales[0]);
124                bytes.extend_from_byte_slice_aligned(scale_bytes, QPARAM_ALIGN);
125            }
126            QuantLevel::Block(_block_size) => {
127                let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
128                for scale in scales {
129                    scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
130                }
131                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), QPARAM_ALIGN);
132            }
133        }
134
135        Self {
136            bytes,
137            scheme,
138            num_elements,
139        }
140    }
141
142    /// Returns the int8 quantized values with the quantization parameters.
143    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
144        let (values, (qparams, num_params)) = self.split_values_off();
145
146        // Quantization parameters are added at the end of the tensor data.
147        // As such, the last bytes always correspond to the scale parameter(s).
148        // For example, per-block quantization can have multiple parameters for a single tensor:
149        // [scale, scale, scale, ...]
150        let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
151        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
152        let total_bytes = qparams_bytes.len();
153
154        let scales_size = scale_size * num_params;
155
156        let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
157
158        (values, QParams { scales })
159    }
160
    /// Splits the raw byte buffer into the i8 quantized values and the
    /// quantization parameters re-interpreted as `u32` words.
    ///
    /// `num_params` is the number of f32 scale parameters stored at the end of
    /// the buffer.
    fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
        let mut values = read_bytes_to_i8(self.bytes);

        // Parameters are appended after the values, so the boundary sits
        // `num_params * size_of::<f32>()` bytes from the end.
        let scale_size = num_params * size_of::<f32>();
        let values_end = values.len() - scale_size;

        let qparams = values.split_off(values_end);

        // Zero-copy re-interpretation of the parameter bytes as u32 words is
        // only possible when the split-off allocation is 4-byte aligned.
        let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
            let mut qparams = core::mem::ManuallyDrop::new(qparams);
            // SAFETY: the pointer is 4-byte aligned (checked above) and `len`
            // is a multiple of 4, since `scale_size` is `num_params * 4`.
            // NOTE(review): this also assumes `capacity` is a multiple of 4 so
            // the rebuilt Vec matches the original allocation — confirm.
            unsafe {
                Vec::<u32>::from_raw_parts(
                    qparams.as_mut_ptr() as _,
                    qparams.len() / 4,
                    qparams.capacity() / 4,
                )
            }
        } else {
            #[cfg(target_endian = "little")]
            {
                // SAFETY: quantized bytes representation is created from packed u32 values in little endian
                bytemuck::cast_vec(qparams)
            }
            #[cfg(target_endian = "big")]
            {
                // On big-endian targets the i8 bytes must be re-packed so the
                // resulting u32 words match the little-endian layout.
                crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
            }
        };
        (values, qparams)
    }
191
    /// Splits the quantized values of the tensor from the quantization parameters.
    ///
    /// Returns the values in i8 and a newly allocated vector containing the quantization parameters.
    fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
        // Number of scale parameters: one for per-tensor quantization,
        // one per block for per-block quantization.
        let num_params = match self.scheme.level {
            QuantLevel::Tensor => 1,
            QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
        };

        // Splitting assumes values sit contiguously before the parameters,
        // which the code only supports when packing is along dimension 0.
        if let QuantStore::PackedU32(packed_dim) = self.scheme.store {
            assert_eq!(
                packed_dim, 0,
                "Packing must be on innermost dimension for splitting off values"
            );
        }

        let (values, qparams) = match self.scheme.store {
            // Native i8 storage: split the byte buffer directly.
            QuantStore::Native => self.split_i8_values(num_params),
            QuantStore::PackedU32(_) => match self.scheme.value {
                // 8-bit values packed into u32 share the byte layout of
                // native i8 storage, so the same split applies.
                QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                    let mut values = self.bytes.try_into_vec::<u32>().unwrap();
                    let scale_size = num_params; // size of f32 same as u32
                    let values_end = values.len() - scale_size;

                    let qparams = values.split_off(values_end);
                    // Sub-byte values are unpacked as i8s for value equality tests
                    let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
                    (values, qparams)
                }
                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                    unimplemented!("Not yet supported")
                }
            },
            QuantStore::PackedNative(_) => unimplemented!("Not yet supported"),
        };

        (values, (qparams, num_params))
    }
231}
232
233fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
234    match bytes.try_into_vec::<i8>() {
235        Ok(val) => val,
236        // Safety,
237        //
238        // `Vec<u8>` can be Re-interpreted as `Vec<i8>` since they share the same alignment.
239        Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
240    }
241}
242
/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
///
/// Each `u32` holds four consecutive `i8` values in little-endian byte order:
///     let result = (d_u8 & 0xFF) << 24 | (c_u8 & 0xFF) << 16 | (b_u8 & 0xFF) << 8 | (a_u8 & 0xFF);
/// When the number of values is not a multiple of 4, the final word is padded
/// with zeros.
///
/// This implementation is endian-portable (`u32::from_le_bytes` fixes the byte
/// order regardless of the target) and avoids the previous zero-copy
/// `Vec::from_raw_parts` re-interpretation, which passed `capacity / 4` and
/// therefore violated the requirement that the rebuilt capacity match the
/// original allocation whenever the i8 capacity was not a multiple of 4
/// (undefined behavior on deallocation).
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    let mut packed = Vec::with_capacity(values.len().div_ceil(4));
    let mut chunks = values.chunks_exact(4);

    // Full groups of four values.
    for chunk in &mut chunks {
        packed.push(u32::from_le_bytes([
            chunk[0] as u8,
            chunk[1] as u8,
            chunk[2] as u8,
            chunk[3] as u8,
        ]));
    }

    // Zero-pad the trailing partial group, if any.
    let remainder = chunks.remainder();
    if !remainder.is_empty() {
        let mut last = [0u8; 4];
        for (dst, &src) in last.iter_mut().zip(remainder) {
            *dst = src as u8;
        }
        packed.push(u32::from_le_bytes(last));
    }

    packed
}
281
282/// Unpack integer values into a sequence of signed 8-bit integers.
283pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
284    values: &[Q],
285    numel: usize,
286    value: &QuantValue,
287) -> Vec<i8> {
288    let size_store = size_of::<Q>() * 8;
289    let size_quant = value.size_bits();
290    let num_quants = size_store / size_quant;
291    let mask = Q::from((1 << size_quant) - 1).unwrap();
292    let sign_shift = 8 - size_quant; // sign extension for sub-byte values
293    values
294        .iter()
295        .enumerate()
296        .flat_map(|(i, &packed)| {
297            // A single u32 could contain less than four 8-bit values...
298            let n = core::cmp::min(num_quants, numel - i * num_quants);
299            // Extract each 8-bit segment from u32 and cast back to i8
300            // Same as doing this (when 4 values are fully packed):
301            //     let a = (packed & 0xFF) as i8;
302            //     let b = ((packed >> 8) & 0xFF) as i8;
303            //     let c = ((packed >> 16) & 0xFF) as i8;
304            //     let d = ((packed >> 24) & 0xFF) as i8;
305            (0..n).map(move |i| {
306                let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
307                ((raw << sign_shift) as i8) >> sign_shift
308            })
309        })
310        .collect()
311}
312
#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    // [-128, 2, -3, 127] packed little-endian is 0x7FFD0280 == 2147287680.
    #[test]
    fn should_pack_i8s_to_u32() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);

        assert_eq!(packed, vec![2147287680]);
    }

    // A trailing partial group is zero-padded, so explicitly appending zeros
    // must produce the same packed output.
    #[test]
    fn should_pack_i8s_to_u32_padded() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(packed, vec![2147287680, 55]);
        assert_eq!(packed, packed_padded);
    }

    // Round-trip of the packed word from `should_pack_i8s_to_u32`.
    #[test]
    fn should_unpack_u32s_to_i8s() {
        let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![-128, 2, -3, 127]);
    }

    // Only `numel` values are extracted; padding bytes in the word are ignored.
    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![55]);
    }

    // 4-bit unpacking: 128 values recovered from 16 u32 words (8 per word).
    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let unpacked = unpack_q_to_i8s(
            &[
                0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
                1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
                2004318071,
            ],
            128,
            &QuantValue::Q4S,
        );

        assert_eq!(
            unpacked,
            vec![
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
            ]
        );
    }

    // End-to-end: pack values + scale into bytes, then recover both unchanged.
    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
        let scale = 0.03937008;
        let values = vec![0i8, 25, 51, 76, 102, 127];

        let q_bytes = QuantizedBytes::new(
            values.clone(),
            QuantScheme::default()
                .with_value(QuantValue::Q8S)
                .with_store(QuantStore::Native),
            &[scale],
        );

        let (q_values, qparams) = q_bytes.into_vec_i8();

        assert_eq!(qparams.scales, vec![scale]);

        assert_eq!(q_values, values);
    }
}