// burn_tensor/tensor/quantization/bytes.rs

1use core::any::TypeId;
2
3use crate::{Bytes, Element};
4use alloc::vec::Vec;
5
6use super::{
7    pack_i8s_to_u32s, unpack_u32s_to_i8s, AffineQuantization, QParams, Quantization,
8    QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization,
9};
10
/// Quantized data bytes representation.
///
/// # Notes
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
///    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
///    we make sure to retrieve only the meaningful values (and ignore the alignment padding).
/// 2) Quantization parameters are appended to the tensor data.
///    As such, the last bytes always correspond to the scale parameter.
///    If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
pub struct QuantizedBytes {
    /// The quantized values and quantization parameters represented as bytes.
    pub bytes: Bytes,
    /// The quantization scheme.
    pub scheme: QuantizationScheme,
    /// The number of quantized elements (not bytes) in the packed values.
    pub num_elements: usize,
}
28
29impl QuantizedBytes {
30    /// Creates a new quantized bytes representation.
31    pub fn new<E: Element>(value: Vec<E>, strategy: QuantizationStrategy) -> Self {
32        let mut bytes: Bytes;
33        let num_elements = value.len();
34
35        match strategy {
36            QuantizationStrategy::PerTensorAffineInt8(q) => {
37                if TypeId::of::<E>() == TypeId::of::<i8>() {
38                    // Re-interpret `Vec<E>` as `Vec<i8>` with `Vec::from_raw_parts`
39                    let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value));
40                    bytes = Bytes::from_elems(u32s);
41                } else {
42                    panic!("Invalid quantized type");
43                }
44                // Scale is always stored as f32 and zero-point offset as i32
45                let offset = q.offset as i32;
46                let scale_bytes = bytemuck::bytes_of(&q.scale);
47                let offset_bytes = bytemuck::bytes_of(&offset);
48                bytes.extend_from_byte_slice_aligned(offset_bytes, align_of::<i32>());
49                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
50            }
51            QuantizationStrategy::PerTensorSymmetricInt8(q) => {
52                if TypeId::of::<E>() == TypeId::of::<i8>() {
53                    // Re-interpret `Vec<E>` as `Vec<i8>` with `Vec::from_raw_parts`
54                    let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value));
55                    bytes = Bytes::from_elems(u32s);
56                } else {
57                    panic!("Invalid quantized type");
58                }
59                let scale_bytes = bytemuck::bytes_of(&q.scale);
60                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
61            }
62        }
63
64        Self {
65            bytes,
66            scheme: strategy.scheme(),
67            num_elements,
68        }
69    }
70
71    /// Returns the int8 quantized values with the quantization parameters.
72    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<f32, i8>) {
73        let numel = self.num_elements;
74        let scheme = self.scheme;
75        let (values, qparams) = self.split_values_off();
76
77        let values = unpack_u32s_to_i8s(values, numel);
78
79        // Quantization parameters are added at the end of the tensor data.
80        // As such, the last bytes always correspond to the scale parameter.
81        // If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
82        let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
83        let qparams_bytes = bytemuck::cast_slice(&qparams);
84        let total_bytes = qparams_bytes.len();
85        let scale = *bytemuck::checked::from_bytes(&qparams_bytes[total_bytes - scale_size..]);
86
87        let offset = match scheme {
88            QuantizationScheme::PerTensorAffine(_) => {
89                let offset_size = core::mem::size_of::<i32>(); // zero-point offset is stored as i32
90                Some(*bytemuck::checked::from_bytes::<i32>(
91                    &qparams_bytes
92                        [total_bytes - scale_size - offset_size..total_bytes - scale_size],
93                ) as i8)
94            }
95            QuantizationScheme::PerTensorSymmetric(_) => None,
96        };
97
98        (values, QParams { scale, offset })
99    }
100
101    /// Splits the quantized values of the tensor from the quantization parameters.
102    ///
103    /// Returns the packed values and a newly allocated vector containing the quantization parameters.
104    fn split_values_off(self) -> (Vec<u32>, Vec<u32>) {
105        // The bytes can be created either from packed u32 or existing bytes with the same representation.
106        let mut values = match self.bytes.align() {
107            1 => {
108                let bytes = self.bytes.try_into_vec::<u8>().unwrap();
109                #[cfg(target_endian = "little")]
110                {
111                    // SAFETY: quantized bytes representation is created from packed u32 values in little endian
112                    unsafe { reinterpret_vec(bytes) }
113                }
114                #[cfg(target_endian = "big")]
115                {
116                    pack_i8s_to_u32s(bytemuck::allocation::cast_vec(bytes))
117                }
118            }
119            4 => self.bytes.try_into_vec::<u32>().unwrap(),
120            _ => unreachable!(),
121        };
122
123        let scale_size = 1; // f32 scale is the same number of bytes as u32
124        let mut values_end = values.len() - scale_size;
125
126        if let QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) = self.scheme {
127            values_end -= 1; // zero-point offset is stored as i32 (same number of bytes as u32)
128        }
129
130        let qparams = values.split_off(values_end);
131
132        (values, qparams)
133    }
134
135    /// Dequantizes the data according to its quantization scheme.
136    pub fn dequantize(self) -> (Vec<f32>, QParams<f32, i8>) {
137        match self.scheme {
138            QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => {
139                let (values, qparams) = self.into_vec_i8();
140                let strategy = AffineQuantization::<f32, i8, i32>::init(
141                    qparams.scale,
142                    qparams.offset.unwrap(),
143                );
144                (strategy.dequantize(&values), qparams)
145            }
146            QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
147                let (values, qparams) = self.into_vec_i8();
148                let strategy = SymmetricQuantization::<f32, i8>::init(qparams.scale);
149                (strategy.dequantize(&values), qparams)
150            }
151        }
152    }
153}
154
/// Reinterprets a `Vec<T>` as a `Vec<U>` without reallocation.
///
/// # Safety
/// - The pointer must be properly aligned for `U`, and the allocation itself must have been
///   made with an alignment compatible with `U` (a 1-aligned allocation viewed as `u32` is
///   only sound if the allocator produced a suitably aligned block — the caller must ensure this).
/// - Both the length and the capacity of the input, converted to bytes, must be exact
///   multiples of `size_of::<U>()` so the resulting `Vec<U>` describes the same allocation.
unsafe fn reinterpret_vec<T, U>(mut input: Vec<T>) -> Vec<U> {
    // Ensure alignment and size compatibility
    assert!(
        input.as_mut_ptr().align_offset(align_of::<U>()) == 0,
        "Alignment mismatch"
    );
    assert!(
        size_of::<T>() != 0 && size_of::<U>() != 0,
        "Zero-sized types not allowed"
    );
    assert!(
        input.len() * size_of::<T>() % size_of::<U>() == 0,
        "Size mismatch"
    );
    // The capacity must also convert exactly: `Vec::from_raw_parts` requires that
    // `cap * size_of::<U>()` equals the allocated size, otherwise deallocation on drop
    // would use a mismatched layout (undefined behavior).
    assert!(
        input.capacity() * size_of::<T>() % size_of::<U>() == 0,
        "Capacity mismatch"
    );

    let len = input.len() * size_of::<T>() / size_of::<U>();
    let cap = input.capacity() * size_of::<T>() / size_of::<U>();
    let ptr = input.as_mut_ptr() as *mut U;

    // Transfer ownership of the buffer to the new Vec; `input` must not free it.
    core::mem::forget(input);

    // SAFETY: `ptr` is aligned for `U` and `len`/`cap` describe the same allocation in
    // units of `U` (both conversions verified exact above); remaining allocator-alignment
    // requirements are the caller's contract per the `# Safety` section.
    Vec::from_raw_parts(ptr, len, cap)
}
184
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_unpack_quantization_parameters_symmetric() {
        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
        let expected = vec![0i8, 25, 51, 76, 102, 127];
        let scale = 0.03937008;

        let strategy =
            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale));
        let quantized = QuantizedBytes::new(expected.clone(), strategy);

        let (unpacked, params) = quantized.into_vec_i8();

        // Round-tripping must preserve both the parameters and the values.
        assert_eq!(params.scale, scale);
        assert_eq!(params.offset, None);
        assert_eq!(unpacked, expected);
    }

    #[test]
    fn should_pack_unpack_quantization_parameters_affine() {
        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
        let expected = vec![-128i8, -77, -26, 25, 76, 127];
        let scale = 0.019607844;
        let offset = -128;

        let strategy =
            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(scale, offset));
        let quantized = QuantizedBytes::new(expected.clone(), strategy);

        let (unpacked, params) = quantized.into_vec_i8();

        // Round-tripping must preserve the scale, the zero-point offset, and the values.
        assert_eq!(params.scale, scale);
        assert_eq!(params.offset, Some(offset));
        assert_eq!(unpacked, expected);
    }
}
227}