// burn_tensor/tensor/quantization/bytes.rs
use core::any::TypeId;

use crate::{Bytes, Element};
use alloc::vec::Vec;

use super::{
    QParams, QuantizationMode, QuantizationScheme, QuantizationStrategy, QuantizationType,
    SymmetricQuantization, pack_i8s_to_u32s, unpack_u32s_to_i8s,
};
10
/// Quantized tensor data stored as raw bytes.
///
/// The buffer holds the packed quantized values followed by the quantization
/// parameters (the scale) appended at the end (see [`QuantizedBytes::new`]).
pub struct QuantizedBytes {
    /// The raw byte storage: `i8` values packed into `u32`s, with the
    /// quantization parameter bytes appended at the tail.
    pub bytes: Bytes,
    /// The quantization scheme used to produce the values.
    pub scheme: QuantizationScheme,
    /// The number of quantized elements (excludes the appended parameters).
    pub num_elements: usize,
}
28
impl QuantizedBytes {
    /// Creates a new quantized bytes representation from already-quantized values.
    ///
    /// The `i8` values are packed into `u32`s and the strategy's scale is
    /// appended to the end of the buffer as raw `f32` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `E` is not `i8` (the only element type the int8 strategy accepts).
    pub fn new<E: Element>(value: Vec<E>, strategy: QuantizationStrategy) -> Self {
        let mut bytes: Bytes;
        let num_elements = value.len();
        let scheme = strategy.scheme();

        match strategy {
            QuantizationStrategy::PerTensorSymmetricInt8(quant) => {
                if TypeId::of::<E>() == TypeId::of::<i8>() {
                    // Re-view the Vec<E> as Vec<i8> (checked above), then pack
                    // four i8s per u32 for storage.
                    let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value));
                    bytes = Bytes::from_elems(u32s);
                } else {
                    panic!("Invalid quantized type");
                }
                // Append the scale after the packed values so everything lives
                // in a single buffer; split_values_off relies on this layout.
                let scale_bytes = bytemuck::bytes_of(&quant.scale);
                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
            }
        }

        Self {
            bytes,
            scheme,
            num_elements,
        }
    }

    /// Unpacks the quantized values into a `Vec<i8>` along with the
    /// quantization parameters recovered from the buffer tail.
    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>, Vec<i8>>) {
        let numel = self.num_elements;
        let (values, (qparams, num_params)) = self.split_values_off();

        // Undo the four-i8s-per-u32 packing done in `new`.
        let values = unpack_u32s_to_i8s(values, numel);

        // The scales are stored as the last `num_params` f32 values of the
        // parameter section; re-read them from the raw bytes.
        let scale_size = core::mem::size_of::<f32>();
        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
        let total_bytes = qparams_bytes.len();

        let scales_size = scale_size * num_params;

        let scale = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
        // Symmetric quantization has no zero-point offset.
        let offset = None;

        (values, QParams { scale, offset })
    }

    /// Splits the buffer into the packed values and the trailing quantization
    /// parameter words, returning `(values, (qparams, num_params))`.
    fn split_values_off(self) -> (Vec<u32>, (Vec<u32>, usize)) {
        // Normalize the storage to a Vec<u32> regardless of the buffer's
        // current alignment (1 after deserialization, 4 when freshly packed).
        let mut values = match self.bytes.align() {
            1 => {
                let bytes = self.bytes.try_into_vec::<u8>().unwrap();
                #[cfg(target_endian = "little")]
                {
                    // On little-endian, the byte buffer is already the exact
                    // memory layout of the packed u32s, so reinterpret in place.
                    unsafe { reinterpret_vec(bytes) }
                }
                #[cfg(target_endian = "big")]
                {
                    // On big-endian, repack byte-by-byte to preserve the
                    // little-endian packing order.
                    // NOTE(review): this also repacks the trailing scale bytes;
                    // verify the scale round-trips correctly on big-endian targets.
                    pack_i8s_to_u32s(bytemuck::allocation::cast_vec(bytes))
                }
            }
            4 => self.bytes.try_into_vec::<u32>().unwrap(),
            _ => unreachable!(),
        };

        let num_params = match self.scheme {
            QuantizationScheme::PerTensor(..) => 1,
        };

        // One u32 word per f32 parameter (both are 4 bytes); the parameters
        // occupy the last `num_params` words of the buffer.
        let scale_size = num_params;
        let values_end = values.len() - scale_size;

        let qparams = values.split_off(values_end);

        (values, (qparams, num_params))
    }

    /// Dequantizes the values back to `f32` according to the stored scheme,
    /// returning the floats together with the recovered parameters.
    pub fn dequantize(self) -> (Vec<f32>, QParams<Vec<f32>, Vec<i8>>) {
        match self.scheme {
            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
                let (values, qparams) = self.into_vec_i8();
                let strategy = QuantizationStrategy::PerTensorSymmetricInt8(
                    SymmetricQuantization::init(qparams.scale[0]),
                );
                (strategy.dequantize(&values), qparams)
            }
        }
    }
}
129
/// Reinterprets a `Vec<T>` as a `Vec<U>` without copying the data.
///
/// # Safety
///
/// The caller must guarantee that any bit pattern of `T` forms a valid `U`,
/// and that the vector's backing allocation satisfies `U`'s alignment
/// requirement (the pointer-value check below is necessary but cannot verify
/// the alignment the allocator was originally asked for).
///
/// # Panics
///
/// Panics on alignment mismatch, zero-sized `T`/`U`, or when the length or
/// capacity in bytes is not a multiple of `size_of::<U>()`.
unsafe fn reinterpret_vec<T, U>(mut input: Vec<T>) -> Vec<U> {
    // The data pointer must already be aligned for `U`.
    assert!(
        input.as_mut_ptr().align_offset(align_of::<U>()) == 0,
        "Alignment mismatch"
    );
    assert!(
        size_of::<T>() != 0 && size_of::<U>() != 0,
        "Zero-sized types not allowed"
    );
    assert!(
        input.len() * size_of::<T>() % size_of::<U>() == 0,
        "Size mismatch"
    );
    // BUG FIX: `Vec::from_raw_parts` requires the new capacity to describe the
    // exact allocation size. If the capacity bytes were not divisible by
    // `size_of::<U>()`, the floor-truncated `cap` below would lie about the
    // allocation and deallocating the result would be undefined behavior.
    assert!(
        input.capacity() * size_of::<T>() % size_of::<U>() == 0,
        "Capacity mismatch"
    );

    let len = input.len() * size_of::<T>() / size_of::<U>();
    let cap = input.capacity() * size_of::<T>() / size_of::<U>();
    let ptr = input.as_mut_ptr() as *mut U;

    // Ownership of the buffer transfers to the new Vec; don't double-free.
    core::mem::forget(input);

    // SAFETY: `ptr` comes from a live Vec allocation, is aligned for `U`
    // (asserted above), and `len`/`cap` describe the same allocation measured
    // in `U`-sized units (divisibility asserted above). Validity of the bit
    // patterns as `U` is the caller's obligation per the function contract.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
159
#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    /// Packing followed by unpacking must round-trip both the quantized
    /// values and the quantization parameters.
    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        let scale = 0.03937008;
        let expected = vec![0i8, 25, 51, 76, 102, 127];

        let strategy =
            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale));
        let q_bytes = QuantizedBytes::new(expected.clone(), strategy);

        let (unpacked, params) = q_bytes.into_vec_i8();

        // Symmetric quantization: a single scale, no zero-point offset.
        assert_eq!(params.scale, vec![scale]);
        assert_eq!(params.offset, None);

        assert_eq!(unpacked, expected);
    }
}