// burn_std/tensor/quantization.rs

1//! Quantization data representation.
2
3// Re-exported types
4pub use cubecl_quant::scheme::{
5    BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
6};
7
8use alloc::vec::Vec;
9use core::any::TypeId;
10use num_traits::PrimInt;
11use serde::{Deserialize, Serialize};
12
13use crate::bytes::Bytes;
14
15#[derive(
16    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
17)]
18/// The precision of accumulating elements.
19pub enum QuantAcc {
20    /// Full precision.
21    #[default]
22    F32,
23    /// Half precision.
24    F16,
25    /// bfloat16 precision.
26    BF16,
27}
28
29/// Specify if the output of an operation is quantized using the scheme of the input
30/// or returned unquantized.
31#[derive(
32    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
33)]
34pub enum QuantPropagation {
35    /// The output is quantized using the scheme of the input.
36    Propagate,
37    /// The output is not quantized.
38    #[default]
39    Inhibit,
40}
41
/// The quantization tensor data parameters.
#[derive(Clone, Debug)]
pub struct QParams<S> {
    /// The scaling factor(s).
    ///
    /// `S` holds a single scale for per-tensor quantization, or one scale per
    /// block for per-block quantization (see [`QuantLevel`]).
    pub scales: S,
}
48
/// Quantized data bytes representation.
///
/// # Notes
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
///    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
///    we make sure to retrieve only the meaningful values (and ignore the alignment padding).
/// 2) Quantization parameters are appended to the tensor data.
///    As such, the last bytes always correspond to the scale parameter.
///    If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
pub struct QuantizedBytes {
    /// The quantized values and quantization parameters represented as bytes.
    pub bytes: Bytes,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The number of quantized elements (before packing).
    pub num_elements: usize,
}
66
67impl QuantizedBytes {
68    /// Creates a new quantized bytes representation.
69    pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
70        value: Vec<E>,
71        scheme: QuantScheme,
72        scales: &[f32],
73    ) -> Self {
74        let num_elements = value.len();
75        // Only used for 8-bit quantization data comparison in tests
76        if TypeId::of::<E>() != TypeId::of::<i8>() {
77            panic!("Invalid quantized type");
78        }
79
80        // Re-interpret `Vec<E>` as `Vec<i8>` with `Vec::from_raw_parts`
81        let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
82        let mut bytes = Bytes::from_elems(i8s);
83
84        match scheme.level {
85            QuantLevel::Tensor => {
86                let scale_bytes = bytemuck::bytes_of(&scales[0]);
87                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
88            }
89            QuantLevel::Block(_block_size) => {
90                let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
91                for scale in scales {
92                    scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
93                }
94                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), align_of::<f32>());
95            }
96        }
97
98        Self {
99            bytes,
100            scheme,
101            num_elements,
102        }
103    }
104
105    /// Returns the int8 quantized values with the quantization parameters.
106    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
107        let (values, (qparams, num_params)) = self.split_values_off();
108
109        // Quantization parameters are added at the end of the tensor data.
110        // As such, the last bytes always correspond to the scale parameter(s).
111        // For example, per-block quantization can have multiple parameters for a single tensor:
112        // [scale, scale, scale, ...]
113        let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
114        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
115        let total_bytes = qparams_bytes.len();
116
117        let scales_size = scale_size * num_params;
118
119        let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
120
121        (values, QParams { scales })
122    }
123
124    fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
125        let mut values = read_bytes_to_i8(self.bytes);
126
127        let scale_size = num_params * size_of::<f32>();
128        let values_end = values.len() - scale_size;
129
130        let qparams = values.split_off(values_end);
131
132        let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
133            let mut qparams = core::mem::ManuallyDrop::new(qparams);
134            unsafe {
135                Vec::<u32>::from_raw_parts(
136                    qparams.as_mut_ptr() as _,
137                    qparams.len() / 4,
138                    qparams.capacity() / 4,
139                )
140            }
141        } else {
142            #[cfg(target_endian = "little")]
143            {
144                // SAFETY: quantized bytes representation is created from packed u32 values in little endian
145                bytemuck::cast_vec(qparams)
146            }
147            #[cfg(target_endian = "big")]
148            {
149                crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
150            }
151        };
152        (values, qparams)
153    }
154
155    /// Splits the quantized values of the tensor from the quantization parameters.
156    ///
157    /// Returns the values in i8 and a newly allocated vector containing the quantization parameters.
158    fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
159        let num_params = match self.scheme.level {
160            QuantLevel::Tensor => 1,
161            QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
162        };
163
164        let (values, qparams) = match self.scheme.store {
165            QuantStore::Native => self.split_i8_values(num_params),
166            QuantStore::U32 => match self.scheme.value {
167                QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
168                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
169                    let mut values = self.bytes.try_into_vec::<u32>().unwrap();
170                    let scale_size = num_params; // size of f32 same as u32
171                    let values_end = values.len() - scale_size;
172
173                    let qparams = values.split_off(values_end);
174                    // Sub-byte values are unpacked as i8s for value equality tests
175                    let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
176                    (values, qparams)
177                }
178                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
179                    unimplemented!("Not yet supported")
180                }
181            },
182        };
183
184        (values, (qparams, num_params))
185    }
186}
187
188fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
189    match bytes.try_into_vec::<i8>() {
190        Ok(val) => val,
191        // Safety,
192        //
193        // `Vec<u8>` can be Re-interpreted as `Vec<i8>` since they share the same alignment.
194        Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
195    }
196}
197
/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
///
/// Values are packed in little-endian order: the first value of each group of
/// four occupies the least significant byte of the resulting `u32`. When the
/// number of values is not a multiple of four, the last word is zero-padded.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // Shift and combine groups of four 8-bit values into a u32.
    // Same as doing this:
    //     let result = (d_u8 & 0xFF) << 24 | (c_u8 & 0xFF) << 16 | (b_u8 & 0xFF) << 8 | (a_u8 & 0xFF);
    //
    // This byte-wise construction is endianness-independent and avoids
    // re-interpreting the `Vec<i8>` allocation (alignment 1, arbitrary capacity)
    // as a `Vec<u32>` via `Vec::from_raw_parts`, which violates that function's
    // layout requirements (alignment 4, exact capacity) and is undefined behavior.
    values
        .chunks(4)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .fold(0u32, |acc, (i, &v)| acc | ((v as u32 & 0xFF) << (i * 8)))
        })
        .collect()
}
236
237/// Unpack integer values into a sequence of signed 8-bit integers.
238pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
239    values: &[Q],
240    numel: usize,
241    value: &QuantValue,
242) -> Vec<i8> {
243    let size_store = size_of::<Q>() * 8;
244    let size_quant = value.size_bits();
245    let num_quants = size_store / size_quant;
246    let mask = Q::from((1 << size_quant) - 1).unwrap();
247    let sign_shift = 8 - size_quant; // sign extension for sub-byte values
248    values
249        .iter()
250        .enumerate()
251        .flat_map(|(i, &packed)| {
252            // A single u32 could contain less than four 8-bit values...
253            let n = core::cmp::min(num_quants, numel - i * num_quants);
254            // Extract each 8-bit segment from u32 and cast back to i8
255            // Same as doing this (when 4 values are fully packed):
256            //     let a = (packed & 0xFF) as i8;
257            //     let b = ((packed >> 8) & 0xFF) as i8;
258            //     let c = ((packed >> 16) & 0xFF) as i8;
259            //     let d = ((packed >> 24) & 0xFF) as i8;
260            (0..n).map(move |i| {
261                let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
262                ((raw << sign_shift) as i8) >> sign_shift
263            })
264        })
265        .collect()
266}
267
#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_i8s_to_u32() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);

        // 2147287680 == 0x7FFD0280: first value in the least significant byte.
        assert_eq!(packed, vec![2147287680]);
    }

    #[test]
    fn should_pack_i8s_to_u32_padded() {
        // A trailing partial group is zero-padded up to a full u32.
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(packed, vec![2147287680, 55]);
        assert_eq!(packed, packed_padded);
    }

    #[test]
    fn should_unpack_u32s_to_i8s() {
        let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![-128, 2, -3, 127]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        // Only `numel` values are meaningful; padding bytes are ignored.
        let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![55]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        // 4-bit values (8 per u32) covering 0..=7, sign-extended to i8.
        let unpacked = unpack_q_to_i8s(
            &[
                0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
                1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
                2004318071,
            ],
            128,
            &QuantValue::Q4S,
        );

        assert_eq!(
            unpacked,
            vec![
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
            ]
        );
    }

    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
        let scale = 0.03937008;
        let values = vec![0i8, 25, 51, 76, 102, 127];

        let q_bytes = QuantizedBytes::new(
            values.clone(),
            QuantScheme::default()
                .with_value(QuantValue::Q8S)
                .with_store(QuantStore::Native),
            &[scale],
        );

        let (q_values, qparams) = q_bytes.into_vec_i8();

        // The scale appended at the end must round-trip exactly.
        assert_eq!(qparams.scales, vec![scale]);

        assert_eq!(q_values, values);
    }
}
348}