// burn_tensor/tensor/quantization/bytes.rs
use core::any::TypeId;

use crate::{Bytes, Element, quantization::unpack_q_to_i8s};
use alloc::vec::Vec;

use super::{
    QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue, QuantizationStrategy,
    SymmetricQuantization,
};
10
/// A quantized tensor's data packed into a single byte buffer.
///
/// The buffer holds the quantized values first, followed by the quantization
/// parameters (f32 scales), as laid out by [`QuantizedBytes::new`].
pub struct QuantizedBytes {
    /// The packed bytes: quantized values, then the f32 scale(s).
    pub bytes: Bytes,
    /// The quantization scheme (level, mode, value type, storage).
    pub scheme: QuantScheme,
    /// The number of quantized elements (excluding the packed qparams).
    pub num_elements: usize,
}
28
impl QuantizedBytes {
    /// Packs quantized values together with their quantization parameters into a
    /// single byte buffer: the i8 values first, then the f32 scale(s), f32-aligned.
    ///
    /// # Panics
    ///
    /// Panics if `E` is not `i8` (only i8-backed quantized values are supported here).
    pub fn new<E: Element>(
        value: Vec<E>,
        strategy: QuantizationStrategy,
        scheme: QuantScheme,
    ) -> Self {
        let mut bytes: Bytes;
        let num_elements = value.len();

        match strategy {
            QuantizationStrategy::PerTensorSymmetric(quant) => {
                if TypeId::of::<E>() == TypeId::of::<i8>() {
                    // Re-interpret `Vec<E>` as `Vec<i8>` without copying (E is i8 here).
                    let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
                    bytes = Bytes::from_elems(i8s);
                } else {
                    panic!("Invalid quantized type");
                }
                // Single per-tensor scale appended after the values, f32-aligned.
                let scale_bytes = bytemuck::bytes_of(&quant.scale);
                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
            }
            QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
                if TypeId::of::<E>() == TypeId::of::<i8>() {
                    // Re-interpret `Vec<E>` as `Vec<i8>` without copying (E is i8 here).
                    let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
                    bytes = Bytes::from_elems(i8s);
                } else {
                    panic!("Invalid quantized type");
                }

                // One f32 scale per block, appended in block order after the values.
                let mut scale_bytes = Vec::with_capacity(quant.len() * size_of::<f32>());
                for q in quant {
                    scale_bytes.extend_from_slice(bytemuck::bytes_of(&q.scale));
                }
                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), align_of::<f32>());
            }
        }

        Self {
            bytes,
            scheme,
            num_elements,
        }
    }

    /// Splits the packed buffer into the quantized i8 values and the f32 scales.
    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
        let (values, (qparams, num_params)) = self.split_values_off();

        // The scales are the trailing `num_params` f32 words of the qparams buffer.
        let scale_size = core::mem::size_of::<f32>();
        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
        let total_bytes = qparams_bytes.len();

        let scales_size = scale_size * num_params;

        let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();

        (values, QParams { scales })
    }

    /// Splits i8-stored values from the trailing qparams, returning the qparams as u32 words.
    fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
        let mut values = self.bytes.try_into_vec::<i8>().unwrap();

        // The qparams occupy the last `num_params * size_of::<f32>()` bytes.
        let scale_size = num_params * size_of::<f32>();
        let values_end = values.len() - scale_size;

        let qparams = values.split_off(values_end);

        // Fast path: re-interpret the qparams allocation as u32 in place when the
        // pointer is 4-byte aligned; otherwise fall back to a copying conversion.
        let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
            let mut qparams = core::mem::ManuallyDrop::new(qparams);
            // SAFETY: the pointer is 4-byte aligned (checked above) and the buffer
            // contains whole f32/u32 words (`scale_size` is a multiple of 4).
            // NOTE(review): `Vec::from_raw_parts` also requires the reconstructed
            // capacity to match the original allocation exactly — this relies on
            // `split_off` producing a capacity that is a multiple of 4; confirm.
            unsafe {
                Vec::<u32>::from_raw_parts(
                    qparams.as_mut_ptr() as _,
                    qparams.len() / 4,
                    qparams.capacity() / 4,
                )
            }
        } else {
            #[cfg(target_endian = "little")]
            {
                bytemuck::cast_vec(qparams)
            }
            #[cfg(target_endian = "big")]
            {
                // Big-endian needs explicit packing to preserve the u32 word layout.
                crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
            }
        };
        (values, qparams)
    }

    /// Splits the quantized values from the quantization parameters.
    ///
    /// Returns the values as one i8 per element and the raw qparams as u32 words,
    /// together with the number of scales (one per tensor, or one per block).
    fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
        // One scale for per-tensor quantization; one scale per block otherwise.
        let num_params = match self.scheme.level {
            QuantLevel::Tensor => 1,
            QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
        };

        let (values, qparams) = match self.scheme.store {
            QuantStore::Native => self.split_i8_values(num_params),
            QuantStore::U32 => match self.scheme.value {
                // 8-bit values in u32 storage share the native i8 byte layout.
                QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                    // Sub-byte values: the buffer is u32 words, and each scale is
                    // exactly one u32 word at the end.
                    let mut values = self.bytes.try_into_vec::<u32>().unwrap();
                    let scale_size = num_params;
                    let values_end = values.len() - scale_size;

                    let qparams = values.split_off(values_end);
                    // Unpack the packed 4-/2-bit values into one i8 per element.
                    let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
                    (values, qparams)
                }
                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                    unimplemented!("Not yet supported")
                }
            },
        };

        (values, (qparams, num_params))
    }

    /// Dequantizes the values to f32, returning them with the quantization parameters.
    pub fn dequantize(self) -> (Vec<f32>, QParams<Vec<f32>>) {
        match self.scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8S
                    | QuantValue::Q8F
                    | QuantValue::Q4S
                    | QuantValue::Q4F
                    | QuantValue::Q2S
                    | QuantValue::Q2F,
                ..
            } => {
                let value = self.scheme.value;
                let (values, qparams) = self.into_vec_i8();
                // Per-tensor: a single scale drives the whole dequantization.
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(qparams.scales[0], value),
                );
                (strategy.dequantize(&values), qparams)
            }
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8S
                    | QuantValue::Q8F
                    | QuantValue::Q4S
                    | QuantValue::Q4F
                    | QuantValue::Q2S
                    | QuantValue::Q2F,
                ..
            } => {
                let value = self.scheme.value;
                let (values, qparams) = self.into_vec_i8();
                // Sanity check: values must divide evenly into blocks of the scheme's size.
                assert_eq!(
                    values.len() / qparams.scales.len(),
                    block_size.num_elements()
                );
                let strategy = QuantizationStrategy::PerBlockSymmetric(
                    qparams
                        .scales
                        .iter()
                        .map(|&s| SymmetricQuantization::init(s, value))
                        .collect(),
                    block_size,
                );
                (strategy.dequantize(&values), qparams)
            }
            QuantScheme {
                value: QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
                ..
            } => unimplemented!("Not yet supported"),
        }
    }
}
214
#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        // Values quantized with a symmetric 8-bit scheme and a single per-tensor scale.
        let scale = 0.03937008;
        let original = vec![0i8, 25, 51, 76, 102, 127];

        let strategy = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
            scale,
            QuantValue::Q8S,
        ));
        let q_bytes = QuantizedBytes::new(original.clone(), strategy, QuantScheme::default());

        // Round-trip: unpacking must recover both the values and the scale.
        let (unpacked, qparams) = q_bytes.into_vec_i8();

        assert_eq!(qparams.scales, vec![scale]);
        assert_eq!(unpacked, original);
    }
}
242}