// burn_tensor/tensor/quantization/bytes.rs

use core::any::TypeId;

use crate::{Bytes, Element};
use alloc::vec::Vec;

use super::{
    pack_i8s_to_u32s, unpack_u32s_to_i8s, AffineQuantization, QParams, Quantization,
    QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization,
};

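/// A quantized tensor's data as packed bytes, with the quantization parameters
/// appended at the end of the byte buffer.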
pub struct QuantizedBytes {
    /// The packed quantized values, followed by the quantization parameters.
    pub bytes: Bytes,
    /// The quantization scheme.
    pub scheme: QuantizationScheme,
    /// The number of quantized elements.
    pub num_elements: usize,
}

impl QuantizedBytes {
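    /// Creates a new quantized bytes representation from the quantized values,
    /// packing them into `u32`s and appending the quantization parameters.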
    pub fn new<E: Element>(value: Vec<E>, strategy: QuantizationStrategy) -> Self {
        let mut bytes: Bytes;
        let num_elements = value.len();

        match strategy {
            QuantizationStrategy::PerTensorAffineInt8(q) => {
                if TypeId::of::<E>() == TypeId::of::<i8>() {
                    // Pack the i8 values four per u32 before storing them as bytes.
                    let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value));
                    bytes = Bytes::from_elems(u32s);
                } else {
                    panic!("Invalid quantized type");
                }
                // Append the quantization parameters: the i32 offset, then the f32 scale.
                let offset = q.offset as i32;
                let scale_bytes = bytemuck::bytes_of(&q.scale);
                let offset_bytes = bytemuck::bytes_of(&offset);
                bytes.extend_from_byte_slice_aligned(offset_bytes, align_of::<i32>());
                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
            }
            QuantizationStrategy::PerTensorSymmetricInt8(q) => {
                if TypeId::of::<E>() == TypeId::of::<i8>() {
                    // Pack the i8 values four per u32 before storing them as bytes.
                    let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value));
                    bytes = Bytes::from_elems(u32s);
                } else {
                    panic!("Invalid quantized type");
                }
                // Append the quantization parameters: only the f32 scale for the symmetric scheme.
                let scale_bytes = bytemuck::bytes_of(&q.scale);
                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
            }
        }

        Self {
            bytes,
            scheme: strategy.scheme(),
            num_elements,
        }
    }
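    /// Unpacks the quantized values into a `Vec<i8>` and returns them together with
    /// the quantization parameters read back from the end of the byte buffer.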
    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<f32, i8>) {
        let numel = self.num_elements;
        let scheme = self.scheme;
        let (values, qparams) = self.split_values_off();

        let values = unpack_u32s_to_i8s(values, numel);

        // The quantization parameters sit at the end of the buffer: a 4-byte f32 scale,
        // preceded by a 4-byte i32 offset for affine schemes.
        let scale_size = core::mem::size_of::<f32>();
        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
        let total_bytes = qparams_bytes.len();

        let scale = *bytemuck::checked::from_bytes(&qparams_bytes[total_bytes - scale_size..]);

        let offset = match scheme {
            QuantizationScheme::PerTensorAffine(_) => {
                let offset_size = core::mem::size_of::<i32>();
                Some(*bytemuck::checked::from_bytes::<i32>(
                    &qparams_bytes
                        [total_bytes - scale_size - offset_size..total_bytes - scale_size],
                ) as i8)
            }
            QuantizationScheme::PerTensorSymmetric(_) => None,
        };

        (values, QParams { scale, offset })
    }
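    /// Splits the packed quantized values from the quantization parameters.
    ///
    /// Returns the packed values and the trailing `u32`s holding the quantization parameters.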
    fn split_values_off(self) -> (Vec<u32>, Vec<u32>) {
        // The underlying bytes are either u8-aligned or already aligned as packed u32s.
        let mut values = match self.bytes.align() {
            1 => {
                let bytes = self.bytes.try_into_vec::<u8>().unwrap();
                #[cfg(target_endian = "little")]
                {
                    // SAFETY: the bytes were produced from packed u32s in little-endian order,
                    // so they can be reinterpreted in place; `reinterpret_vec` asserts the
                    // alignment and size requirements.
                    unsafe { reinterpret_vec(bytes) }
                }
                #[cfg(target_endian = "big")]
                {
                    pack_i8s_to_u32s(bytemuck::allocation::cast_vec(bytes))
                }
            }
            4 => self.bytes.try_into_vec::<u32>().unwrap(),
            _ => unreachable!(),
        };

        let scale_size = 1; // the f32 scale occupies one u32
        let mut values_end = values.len() - scale_size;

        if let QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) = self.scheme {
            values_end -= 1; // the i32 offset also occupies one u32
        }

        let qparams = values.split_off(values_end);

        (values, qparams)
    }
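    /// Dequantizes the values into floats and returns them together with the
    /// quantization parameters.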
    pub fn dequantize(self) -> (Vec<f32>, QParams<f32, i8>) {
        match self.scheme {
            QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => {
                let (values, qparams) = self.into_vec_i8();
                let strategy = AffineQuantization::<f32, i8, i32>::init(
                    qparams.scale,
                    qparams.offset.unwrap(),
                );
                (strategy.dequantize(&values), qparams)
            }
            QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
                let (values, qparams) = self.into_vec_i8();
                let strategy = SymmetricQuantization::<f32, i8>::init(qparams.scale);
                (strategy.dequantize(&values), qparams)
            }
        }
    }
}
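/// Reinterprets a `Vec<T>` as a `Vec<U>`, reusing the original allocation without copying.
///
/// # Safety
///
/// The caller must guarantee that the bytes of the `T` values form valid `U` values and
/// that the allocation is compatible with `U` (suitable alignment, and a total size that
/// is a multiple of `size_of::<U>()`); the assertions below check the alignment and size
/// requirements at runtime.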
unsafe fn reinterpret_vec<T, U>(mut input: Vec<T>) -> Vec<U> {
    // Runtime checks: alignment, non-zero element sizes, and an exact size conversion.
    assert!(
        input.as_mut_ptr().align_offset(align_of::<U>()) == 0,
        "Alignment mismatch"
    );
    assert!(
        size_of::<T>() != 0 && size_of::<U>() != 0,
        "Zero-sized types not allowed"
    );
    assert!(
        input.len() * size_of::<T>() % size_of::<U>() == 0,
        "Size mismatch"
    );

    let len = input.len() * size_of::<T>() / size_of::<U>();
    let cap = input.capacity() * size_of::<T>() / size_of::<U>();
    let ptr = input.as_mut_ptr() as *mut U;

    // Transfer ownership of the allocation to the new vector without dropping the old one.
    core::mem::forget(input);

    Vec::from_raw_parts(ptr, len, cap)
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_unpack_quantization_parameters_symmetric() {
        let scale = 0.03937008;
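        // Symmetric int8 quantization of [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] (scale = 5 / 127).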
        let values = vec![0i8, 25, 51, 76, 102, 127];

        let q_bytes = QuantizedBytes::new(
            values.clone(),
            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale)),
        );

        let (q_values, qparams) = q_bytes.into_vec_i8();

        assert_eq!(qparams.scale, scale);
        assert_eq!(qparams.offset, None);

        assert_eq!(q_values, values);
    }

    #[test]
    fn should_pack_unpack_quantization_parameters_affine() {
        let scale = 0.019607844;
        let offset = -128;
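        // Affine int8 quantization of [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
        // (scale = 5 / 255, zero-point offset = -128).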
        let values = vec![-128i8, -77, -26, 25, 76, 127];

        let q_bytes = QuantizedBytes::new(
            values.clone(),
            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(scale, offset)),
        );

        let (q_values, qparams) = q_bytes.into_vec_i8();

        assert_eq!(qparams.scale, scale);
        assert_eq!(qparams.offset, Some(offset));

        assert_eq!(q_values, values);
    }
}