burn_tensor/tensor/quantization/
strategy.rs1use core::{
2 hash::{Hash, Hasher},
3 marker::PhantomData,
4};
5
6use alloc::vec::Vec;
7use burn_common::{iter_slice_par, run_par};
8use num_traits::{Float, PrimInt};
9use serde::{Deserialize, Serialize};
10
11use super::{QuantizationScheme, QuantizationType};
12
13#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
15pub enum QuantizationStrategy {
16 PerTensorAffineInt8(AffineQuantization<f32, i8, i32>),
18 PerTensorSymmetricInt8(SymmetricQuantization<f32, i8>),
20}
21
22impl QuantizationStrategy {
23 pub fn scheme(&self) -> QuantizationScheme {
25 match self {
26 QuantizationStrategy::PerTensorAffineInt8(_) => {
27 QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
28 }
29 QuantizationStrategy::PerTensorSymmetricInt8(_) => {
30 QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8)
31 }
32 }
33 }
34}
35
36pub trait Quantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync> {
39 fn new(alpha: E, beta: E) -> Self;
41 fn quantize(&self, values: &[E]) -> Vec<Q>;
43 fn dequantize(&self, values: &[Q]) -> Vec<E>;
45}
46
47#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
51pub struct AffineQuantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> {
52 pub scale: E,
54 pub offset: Q,
56 _a: PhantomData<A>,
58}
59
60fn valid_scale<E: Float>(mut scale: E) -> E {
61 if scale.eq(&E::zero()) {
64 scale = E::from(0.1).unwrap();
65 }
66 scale
67}
68
69impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> AffineQuantization<E, Q, A> {
70 pub fn init(scale: E, offset: Q) -> Self {
72 Self {
73 scale: valid_scale(scale),
74 offset,
75 _a: PhantomData,
76 }
77 }
78}
79
80impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt + Send + Sync> Quantization<E, Q>
81 for AffineQuantization<E, Q, A>
82{
83 fn new(alpha: E, beta: E) -> Self {
84 let a = E::from(Q::min_value()).unwrap();
86 let b = E::from(Q::max_value()).unwrap();
87
88 let alpha = E::min(alpha, E::zero());
92 let beta = E::max(beta, E::zero());
93
94 let scale = valid_scale((beta - alpha) / (b - a));
96 let z = -(alpha / scale - a);
97 Self {
98 scale,
99 offset: Q::from(z).unwrap(),
100 _a: PhantomData,
101 }
102 }
103
104 fn quantize(&self, values: &[E]) -> Vec<Q> {
105 let a = E::from(Q::min_value()).unwrap();
107 let b = E::from(Q::max_value()).unwrap();
108
109 let z = E::from(self.offset).unwrap();
111 run_par!(|| {
112 iter_slice_par!(values)
113 .map(|x| Q::from(x.div(self.scale).add(z).round().clamp(a, b)).unwrap())
114 .collect()
115 })
116 }
117
118 fn dequantize(&self, values: &[Q]) -> Vec<E> {
119 run_par!(|| {
121 iter_slice_par!(values)
122 .map(|x_q| {
123 self.scale
124 * (E::from(
125 A::from(*x_q)
126 .unwrap()
127 .saturating_sub(A::from(self.offset).unwrap()),
128 )
129 .unwrap())
130 })
131 .collect()
132 })
133 }
134}
135
136#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
138pub struct SymmetricQuantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync> {
139 pub scale: E,
141 _q: PhantomData<Q>,
143}
144
145impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> SymmetricQuantization<E, Q> {
146 pub fn init(scale: E) -> Self {
148 Self {
149 scale: valid_scale(scale),
150 _q: PhantomData,
151 }
152 }
153}
154
155impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Quantization<E, Q>
156 for SymmetricQuantization<E, Q>
157{
158 fn new(alpha: E, beta: E) -> Self {
159 assert!(
160 !Q::min_value().is_zero(),
161 "Symmetric quantization is only valid for signed integers."
162 );
163
164 let b = E::from(Q::max_value()).unwrap();
166 let a = b.neg();
167
168 let alpha = alpha.abs().max(beta.abs());
170 let scale = valid_scale((alpha + alpha) / (b - a));
171 Self {
172 scale,
173 _q: PhantomData,
174 }
175 }
176
177 fn quantize(&self, values: &[E]) -> Vec<Q> {
178 let b = E::from(Q::max_value()).unwrap();
180 let a = b.neg();
181
182 values
184 .iter()
185 .map(|x| Q::from(x.div(self.scale).round().clamp(a, b)).unwrap())
186 .collect()
187 }
188
189 fn dequantize(&self, values: &[Q]) -> Vec<E> {
190 values
192 .iter()
193 .map(|x_q| self.scale * E::from(*x_q).unwrap())
194 .collect()
195 }
196}
197
// Bit masks matching the field layout of a 64-bit IEEE 754 float.

/// Bit 63: the sign field.
const SIGN_MASK: u64 = 0x8000_0000_0000_0000;
/// Bits 52 through 62: the exponent field.
const EXP_MASK: u64 = 0x7ff0_0000_0000_0000;
/// Bits 0 through 51: the mantissa field.
const MAN_MASK: u64 = 0x000f_ffff_ffff_ffff;
202
203#[inline]
204fn raw_double_bits<F: Float>(f: &F) -> u64 {
207 let (man, exp, sign) = f.integer_decode();
208 let exp_u64 = exp as u16 as u64;
209 let sign_u64 = (sign > 0) as u64;
210 (man & MAN_MASK) | ((exp_u64 << 52) & EXP_MASK) | ((sign_u64 << 63) & SIGN_MASK)
211}
212
213#[inline(always)]
214fn canonicalize_signed_zero<T: Float>(x: T) -> T {
215 x + T::zero()
219}
220
221impl<E: Float + Send + Sync, Q: PrimInt + Hash + Send + Sync, A: PrimInt> Hash
222 for AffineQuantization<E, Q, A>
223{
224 fn hash<H: Hasher>(&self, state: &mut H) {
225 let bits = raw_double_bits(&canonicalize_signed_zero(self.scale));
227 bits.hash(state);
228 self.offset.hash(state);
229 }
230}
231
232impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> PartialEq
233 for AffineQuantization<E, Q, A>
234{
235 fn eq(&self, other: &Self) -> bool {
236 self.scale == other.scale && self.offset == other.offset
237 }
238}
239
240impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> Eq
241 for AffineQuantization<E, Q, A>
242{
243}
244
245impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Hash for SymmetricQuantization<E, Q> {
246 fn hash<H: Hasher>(&self, state: &mut H) {
247 let bits = raw_double_bits(&canonicalize_signed_zero(self.scale));
249 bits.hash(state);
250 }
251}
252
253impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> PartialEq for SymmetricQuantization<E, Q> {
254 fn eq(&self, other: &Self) -> bool {
255 self.scale == other.scale
256 }
257}
258
259impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Eq for SymmetricQuantization<E, Q> {}
260
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Affine `f32` -> `i8` round trip over a range that already contains zero.
    #[test]
    fn test_int8_affine_quantization() {
        let input: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let quantized_ref = vec![-128, -40, 71, 126];
        let dequantized_ref = vec![-1.794902, -1.0011765, 0.0, 0.49607843];

        let strategy = AffineQuantization::<f32, i8, i32>::new(-1.8, 0.5);

        let quantized = strategy.quantize(&input);
        assert_eq!(quantized, quantized_ref);

        let dequantized = strategy.dequantize(&quantized_ref);
        assert_eq!(dequantized, dequantized_ref);
    }

    /// A strictly positive range must be widened to include zero, which pins the offset
    /// to the minimum of the quantized type.
    #[test]
    fn test_affine_should_ensure_zero_point() {
        let input: [f32; 6] = [2.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized_ref = vec![-26, -77, -26, 25, 76, 127];
        let dequantized_ref = input.to_vec();

        let strategy = AffineQuantization::<f32, i8, i32>::new(1.0, 5.0);

        assert_eq!(strategy.offset, -128);
        assert_eq!(strategy.scale, 0.019607844);

        let quantized = strategy.quantize(&input);
        assert_eq!(quantized, quantized_ref);

        let dequantized = strategy.dequantize(&quantized_ref);
        assert_eq!(dequantized, dequantized_ref);
    }

    /// Symmetric `f32` -> `i8` round trip; the largest magnitude maps to -127.
    #[test]
    fn test_int8_symmetric_quantization() {
        let input: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let quantized_ref = vec![-127, -71, 0, 35];
        let dequantized_ref = vec![-1.8, -1.0062993, 0.0, 0.496063];

        let strategy = SymmetricQuantization::<f32, i8>::new(-1.8, 0.5);

        let quantized: Vec<i8> = strategy.quantize(&input);
        assert_eq!(quantized, quantized_ref);

        let dequantized = strategy.dequantize(&quantized_ref);
        assert_eq!(dequantized, dequantized_ref);
    }
}