// burn_tensor/tensor/quantization/bytes.rs

1use core::any::TypeId;
2
3use crate::{Bytes, Element, quantization::unpack_q_to_i8s};
4use alloc::vec::Vec;
5
6use super::{
7    QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue, QuantizationStrategy,
8    SymmetricQuantization,
9};
10
/// Quantized data bytes representation.
///
/// # Notes
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
///    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
///    we make sure to retrieve only the meaningful values (and ignore the alignment padding).
/// 2) Quantization parameters are appended to the tensor data.
///    As such, the last bytes always correspond to the scale parameter.
///    If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
pub struct QuantizedBytes {
    /// The quantized values and quantization parameters represented as bytes.
    pub bytes: Bytes,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The number of quantized elements.
    // NOTE: for sub-byte quant values this is the logical element count, which can be
    // larger than the number of stored bytes (elements are packed into u32 words).
    pub num_elements: usize,
}
28
impl QuantizedBytes {
    /// Creates a new quantized bytes representation.
    ///
    /// The quantized values are stored first; the scale parameter(s) are serialized as
    /// `f32` bytes and appended at the end (aligned to `f32`), matching the layout
    /// described on [`QuantizedBytes`].
    ///
    /// # Panics
    /// Panics if `E` is not `i8`, the only quantized element type supported here.
    pub fn new<E: Element>(
        value: Vec<E>,
        strategy: QuantizationStrategy,
        scheme: QuantScheme,
    ) -> Self {
        let mut bytes: Bytes;
        let num_elements = value.len();

        match strategy {
            QuantizationStrategy::PerTensorSymmetric(quant) => {
                if TypeId::of::<E>() == TypeId::of::<i8>() {
                    // Re-interpret `Vec<E>` as `Vec<i8>` in place (no copy).
                    let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
                    bytes = Bytes::from_elems(i8s);
                } else {
                    panic!("Invalid quantized type");
                }
                // Append the single per-tensor scale after the values.
                let scale_bytes = bytemuck::bytes_of(&quant.scale);
                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
            }
            QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
                if TypeId::of::<E>() == TypeId::of::<i8>() {
                    // Re-interpret `Vec<E>` as `Vec<i8>` in place (no copy).
                    let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
                    bytes = Bytes::from_elems(i8s);
                } else {
                    panic!("Invalid quantized type");
                }

                // Append one scale per block, in block order.
                let mut scale_bytes = Vec::with_capacity(quant.len() * size_of::<f32>());
                for q in quant {
                    scale_bytes.extend_from_slice(bytemuck::bytes_of(&q.scale));
                }
                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), align_of::<f32>());
            }
        }

        Self {
            bytes,
            scheme,
            num_elements,
        }
    }

    /// Returns the int8 quantized values with the quantization parameters.
    ///
    /// Sub-byte stored values (4-bit / 2-bit) are unpacked to one `i8` per element.
    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
        let (values, (qparams, num_params)) = self.split_values_off();

        // Quantization parameters are added at the end of the tensor data.
        // As such, the last bytes always correspond to the scale parameter(s).
        // For example, per-block quantization can have multiple parameters for a single tensor:
        // [scale, scale, scale, ...]
        let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
        let total_bytes = qparams_bytes.len();

        let scales_size = scale_size * num_params;

        // Re-interpret the trailing `num_params` f32 scales from the raw qparams bytes.
        let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();

        (values, QParams { scales })
    }

    /// Splits byte-addressable quantized data into values and raw qparams words.
    ///
    /// The trailing `num_params * size_of::<f32>()` bytes are split off and returned
    /// re-interpreted as `u32` words; the rest are the quantized `i8` values.
    fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
        let mut values = self.bytes.try_into_vec::<i8>().unwrap();

        // The qparams occupy the trailing `num_params` f32s of the buffer.
        let scale_size = num_params * size_of::<f32>();
        let values_end = values.len() - scale_size;

        let qparams = values.split_off(values_end);

        let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
            // Allocation is 4-byte aligned: re-interpret the `Vec<i8>` as `Vec<u32>` in place.
            let mut qparams = core::mem::ManuallyDrop::new(qparams);
            // SAFETY: the pointer is 4-byte aligned (checked above), and the qparams region
            // was written as whole f32 words by `new`, so len is a multiple of 4.
            // The `ManuallyDrop` prevents a double free of the re-used allocation.
            unsafe {
                Vec::<u32>::from_raw_parts(
                    qparams.as_mut_ptr() as _,
                    qparams.len() / 4,
                    qparams.capacity() / 4,
                )
            }
        } else {
            // NOTE(review): `bytemuck::cast_vec` panics when the source allocation's alignment
            // differs from the target's (i8 -> u32). Confirm this fallback is either unreachable
            // in practice or that a copying conversion (`pod_collect_to_vec`) is intended.
            #[cfg(target_endian = "little")]
            {
                // SAFETY: quantized bytes representation is created from packed u32 values in little endian
                bytemuck::cast_vec(qparams)
            }
            #[cfg(target_endian = "big")]
            {
                // On big-endian targets the bytes must be re-packed to preserve the u32 word values.
                crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
            }
        };
        (values, qparams)
    }

    /// Splits the quantized values of the tensor from the quantization parameters.
    ///
    /// Returns the values in i8 and a newly allocated vector containing the quantization parameters.
    /// The second tuple element also carries the number of parameters
    /// (one for per-tensor quantization, one per block otherwise).
    fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
        let num_params = match self.scheme.level {
            QuantLevel::Tensor => 1,
            QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
        };

        let (values, qparams) = match self.scheme.store {
            QuantStore::Native => self.split_i8_values(num_params),
            QuantStore::U32 => match self.scheme.value {
                // 8-bit values are one byte each, so the byte-wise split applies as-is.
                QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                    let mut values = self.bytes.try_into_vec::<u32>().unwrap();
                    let scale_size = num_params; // size of f32 same as u32
                    let values_end = values.len() - scale_size;

                    let qparams = values.split_off(values_end);
                    // Sub-byte values are unpacked as i8s for value equality tests
                    let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
                    (values, qparams)
                }
                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                    unimplemented!("Not yet supported")
                }
            },
        };

        (values, (qparams, num_params))
    }

    /// Dequantizes the data according to its quantization scheme.
    ///
    /// Returns the dequantized `f32` values together with the quantization parameters.
    ///
    /// # Panics
    /// Panics (`unimplemented!`) for float quant values (`E4M3`, `E5M2`, `E2M1`).
    pub fn dequantize(self) -> (Vec<f32>, QParams<Vec<f32>>) {
        match self.scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8S
                    | QuantValue::Q8F
                    | QuantValue::Q4S
                    | QuantValue::Q4F
                    | QuantValue::Q2S
                    | QuantValue::Q2F,
                ..
            } => {
                let value = self.scheme.value;
                let (values, qparams) = self.into_vec_i8();
                // Per-tensor: a single scale parameter applies to every value.
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(qparams.scales[0], value),
                );
                (strategy.dequantize(&values), qparams)
            }
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8S
                    | QuantValue::Q8F
                    | QuantValue::Q4S
                    | QuantValue::Q4F
                    | QuantValue::Q2S
                    | QuantValue::Q2F,
                ..
            } => {
                let value = self.scheme.value;
                let (values, qparams) = self.into_vec_i8();
                // Sanity check: the values must divide evenly into `scales.len()` blocks.
                assert_eq!(
                    values.len() / qparams.scales.len(),
                    block_size.num_elements()
                );
                // Per-block: one symmetric quantization (scale) per block, in block order.
                let strategy = QuantizationStrategy::PerBlockSymmetric(
                    qparams
                        .scales
                        .iter()
                        .map(|&s| SymmetricQuantization::init(s, value))
                        .collect(),
                    block_size,
                );
                (strategy.dequantize(&values), qparams)
            }
            QuantScheme {
                value: QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
                ..
            } => unimplemented!("Not yet supported"),
        }
    }
}
214
#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        // Input corresponds to the quantization of [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]].
        let scale = 0.03937008;
        let expected = vec![0i8, 25, 51, 76, 102, 127];

        let strategy = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
            scale,
            QuantValue::Q8S,
        ));
        let q_bytes = QuantizedBytes::new(expected.clone(), strategy, QuantScheme::default());

        // Round-trip: values and the scale parameter must come back out unchanged.
        let (unpacked, qparams) = q_bytes.into_vec_i8();

        assert_eq!(unpacked, expected);
        assert_eq!(qparams.scales, vec![scale]);
    }
}
242}