burn_tensor/tensor/quantization/bytes.rs

1use core::any::TypeId;
2
3use crate::{Bytes, Element};
4use alloc::vec::Vec;
5
6use super::{
7    QParams, QuantizationMode, QuantizationScheme, QuantizationStrategy, QuantizationType,
8    SymmetricQuantization, pack_i8s_to_u32s, unpack_u32s_to_i8s,
9};
10
/// Quantized data bytes representation.
///
/// # Notes
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
///    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
///    we make sure to retrieve only the meaningful values (and ignore the alignment padding).
/// 2) Quantization parameters are appended to the tensor data.
///    As such, the last bytes always correspond to the scale parameter.
///    If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
pub struct QuantizedBytes {
    /// The packed quantized values followed by the quantization parameters, as raw bytes.
    pub bytes: Bytes,
    /// The quantization scheme used to produce (and later interpret) the values.
    pub scheme: QuantizationScheme,
    /// The number of quantized elements (excludes packing padding and the appended parameters).
    pub num_elements: usize,
}
28
impl QuantizedBytes {
    /// Creates a new quantized bytes representation.
    ///
    /// `value` holds the already-quantized elements. Only `i8` elements are currently
    /// supported: they are packed four-per-`u32`, and the `f32` scale parameter is
    /// appended to the byte buffer (see the layout notes on the type).
    ///
    /// # Panics
    /// Panics if `E` is not `i8`.
    pub fn new<E: Element>(value: Vec<E>, strategy: QuantizationStrategy) -> Self {
        let mut bytes: Bytes;
        let num_elements = value.len();
        let scheme = strategy.scheme();

        match strategy {
            QuantizationStrategy::PerTensorSymmetricInt8(quant) => {
                if TypeId::of::<E>() == TypeId::of::<i8>() {
                    // Re-interpret `Vec<E>` as `Vec<i8>` with `Vec::from_raw_parts`
                    let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value));
                    bytes = Bytes::from_elems(u32s);
                } else {
                    panic!("Invalid quantized type");
                }
                // Append the scale last so `split_values_off` can recover it from the tail.
                let scale_bytes = bytemuck::bytes_of(&quant.scale);
                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
            }
        }

        Self {
            bytes,
            scheme,
            num_elements,
        }
    }

    /// Returns the int8 quantized values with the quantization parameters.
    ///
    /// The returned `offset` is always `None`: only the symmetric (no zero-point)
    /// scheme is currently represented here.
    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>, Vec<i8>>) {
        let numel = self.num_elements;
        let (values, (qparams, num_params)) = self.split_values_off();

        // `numel` trims the alignment padding introduced by packing 4 x i8 per u32.
        let values = unpack_u32s_to_i8s(values, numel);

        // Quantization parameters are added at the end of the tensor data.
        // As such, the last bytes always correspond to the scale parameter(s).
        // If the quantization scheme includes an offset (zero-point) parameter, the value(s)
        // precede(s) the scale parameter(s) bytes.
        // For example, per-block quantization can have multiple parameters for a single tensor:
        // [offset, offset, offset, ..., scale, scale, scale, ...]
        let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
        let total_bytes = qparams_bytes.len();

        let scales_size = scale_size * num_params;

        let scale = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
        let offset = None;

        (values, QParams { scale, offset })
    }

    /// Splits the quantized values of the tensor from the quantization parameters.
    ///
    /// Returns the packed values and a newly allocated vector containing the quantization parameters.
    fn split_values_off(self) -> (Vec<u32>, (Vec<u32>, usize)) {
        // The bytes can be created either from packed u32 or existing bytes with the same representation.
        let mut values = match self.bytes.align() {
            1 => {
                let bytes = self.bytes.try_into_vec::<u8>().unwrap();
                #[cfg(target_endian = "little")]
                {
                    // SAFETY: quantized bytes representation is created from packed u32 values in little endian
                    unsafe { reinterpret_vec(bytes) }
                }
                #[cfg(target_endian = "big")]
                {
                    // On big-endian targets an in-place reinterpretation would swap the byte
                    // order within each word, so repack byte-by-byte instead.
                    pack_i8s_to_u32s(bytemuck::allocation::cast_vec(bytes))
                }
            }
            4 => self.bytes.try_into_vec::<u32>().unwrap(),
            // NOTE(review): assumes `Bytes` only ever reports an alignment of 1 or 4 for
            // quantized data — confirm against the `Bytes` implementation.
            _ => unreachable!(),
        };

        let num_params = match self.scheme {
            QuantizationScheme::PerTensor(..) => 1,
        };

        let scale_size = num_params; // f32 scale is the same number of bytes as u32
        let values_end = values.len() - scale_size;

        // The parameters live at the tail of the buffer; split them off in place.
        let qparams = values.split_off(values_end);

        (values, (qparams, num_params))
    }

    /// Dequantizes the data according to its quantization scheme.
    ///
    /// Returns the floating-point values together with the quantization parameters
    /// that were stored alongside them.
    pub fn dequantize(self) -> (Vec<f32>, QParams<Vec<f32>, Vec<i8>>) {
        match self.scheme {
            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
                let (values, qparams) = self.into_vec_i8();
                // Per-tensor quantization: exactly one scale, so `qparams.scale[0]` is it.
                let strategy = QuantizationStrategy::PerTensorSymmetricInt8(
                    SymmetricQuantization::init(qparams.scale[0]),
                );
                (strategy.dequantize(&values), qparams)
            }
        }
    }
}
129
/// Reinterprets a `Vec<T>` as a `Vec<U>` without reallocation (when possible).
///
/// If the vector's *capacity* does not translate to a whole number of `U` elements,
/// the data is copied into a fresh, exactly-sized buffer instead of reusing the
/// allocation: `Vec::from_raw_parts` requires `cap * size_of::<U>()` to equal the
/// size of the original allocation exactly, so truncating the capacity division
/// would be undefined behavior when the reinterpreted vector is freed.
///
/// # Safety
/// - `align_of::<U>()` must not exceed the alignment the buffer was allocated with
///   (for the reuse path, deallocation happens with `U`'s layout).
/// - The sizes of `T` and `U` must be non-zero.
/// - The total byte length (`len * size_of::<T>()`) must be a multiple of `size_of::<U>()`.
///
/// # Panics
/// Panics if the pointer is misaligned for `U`, if either type is zero-sized, or if
/// the byte length is not a multiple of `size_of::<U>()`.
unsafe fn reinterpret_vec<T, U>(mut input: Vec<T>) -> Vec<U> {
    // Ensure alignment and size compatibility.
    assert!(
        input.as_mut_ptr().align_offset(align_of::<U>()) == 0,
        "Alignment mismatch"
    );
    assert!(
        size_of::<T>() != 0 && size_of::<U>() != 0,
        "Zero-sized types not allowed"
    );
    assert!(
        input.len() * size_of::<T>() % size_of::<U>() == 0,
        "Size mismatch"
    );

    let len = input.len() * size_of::<T>() / size_of::<U>();

    if input.capacity() * size_of::<T>() % size_of::<U>() != 0 {
        // The capacity does not describe a whole number of `U` elements, so the
        // allocation cannot be reused soundly. Copy the bytes into a new buffer;
        // `input` is dropped (and freed) normally at the end of this branch.
        let mut copied: Vec<U> = Vec::with_capacity(len);
        // SAFETY: both buffers hold at least `len * size_of::<U>()` bytes (equal to
        // `input.len() * size_of::<T>()` by the size assert above), and they belong
        // to distinct allocations, so the regions cannot overlap.
        unsafe {
            core::ptr::copy_nonoverlapping(
                input.as_ptr() as *const u8,
                copied.as_mut_ptr() as *mut u8,
                len * size_of::<U>(),
            );
            copied.set_len(len);
        }
        return copied;
    }

    let cap = input.capacity() * size_of::<T>() / size_of::<U>();
    let ptr = input.as_mut_ptr() as *mut U;

    // Keep the allocation alive: `input` must not free the buffer we now own as `Vec<U>`.
    core::mem::forget(input);

    // SAFETY: `ptr` originates from a `Vec` allocation, is suitably aligned for `U`
    // (asserted above), and `len`/`cap` describe exactly the same byte region as the
    // original allocation (both divisions are exact).
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
159
#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
        let scale = 0.03937008;
        let expected = vec![0i8, 25, 51, 76, 102, 127];

        let strategy =
            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale));
        let q_bytes = QuantizedBytes::new(expected.clone(), strategy);

        let (unpacked, params) = q_bytes.into_vec_i8();

        assert_eq!(params.scale, vec![scale]);
        assert_eq!(params.offset, None);

        assert_eq!(unpacked, expected);
    }
}
184}