burn_tensor/tensor/quantization/strategy.rs

1use core::{
2    hash::{Hash, Hasher},
3    marker::PhantomData,
4};
5
6use alloc::vec::Vec;
7use burn_common::{iter_slice_par, run_par};
8use num_traits::{Float, PrimInt};
9use serde::{Deserialize, Serialize};
10
11use super::{QuantizationScheme, QuantizationType};
12
/// Quantization strategy.
///
/// Unlike [`QuantizationScheme`], which only names the kind of quantization,
/// a strategy also carries the computed parameters (the scale, plus the
/// zero-point offset for the affine variant).
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationStrategy {
    /// Per-tensor `int8` affine/asymmetric quantization.
    PerTensorAffineInt8(AffineQuantization<f32, i8, i32>),
    /// Per-tensor `int8` symmetric quantization.
    PerTensorSymmetricInt8(SymmetricQuantization<f32, i8>),
}
21
22impl QuantizationStrategy {
23    /// Returns the corresponding quantization scheme.
24    pub fn scheme(&self) -> QuantizationScheme {
25        match self {
26            QuantizationStrategy::PerTensorAffineInt8(_) => {
27                QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
28            }
29            QuantizationStrategy::PerTensorSymmetricInt8(_) => {
30                QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8)
31            }
32        }
33    }
34}
35
/// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision
/// data type `Q` and vice-versa.
pub trait Quantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync> {
    /// Create a new quantization scheme for an input range `[alpha, beta]`.
    fn new(alpha: E, beta: E) -> Self;
    /// Convert the values to a lower precision data type.
    ///
    /// Values that would fall outside the representable range of `Q` are clamped.
    fn quantize(&self, values: &[E]) -> Vec<Q>;
    /// Convert the values back to a higher precision data type.
    ///
    /// Quantization is lossy, so the result is only an approximation of the
    /// originally quantized values.
    fn dequantize(&self, values: &[Q]) -> Vec<E>;
}
46
/// Affine quantization scheme.
///
/// Quantizes with `x_q = clamp(round(x / scale + offset), Q::MIN, Q::MAX)`.
///
/// Note that the accumulation type `A` should have a bigger range than quantized type `Q`,
/// because dequantization computes `x_q - offset` in `A` to avoid overflowing `Q`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct AffineQuantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> {
    /// The scaling factor.
    pub scale: E,
    /// The zero-point offset.
    pub offset: Q,
    /// Accumulation type.
    _a: PhantomData<A>,
}
59
60fn valid_scale<E: Float>(mut scale: E) -> E {
61    // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the
62    // scale to 0.1 to avoid division by zero.
63    if scale.eq(&E::zero()) {
64        scale = E::from(0.1).unwrap();
65    }
66    scale
67}
68
69impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> AffineQuantization<E, Q, A> {
70    /// Initialize an affine quantization scheme with the given parameters.
71    pub fn init(scale: E, offset: Q) -> Self {
72        Self {
73            scale: valid_scale(scale),
74            offset,
75            _a: PhantomData,
76        }
77    }
78}
79
impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt + Send + Sync> Quantization<E, Q>
    for AffineQuantization<E, Q, A>
{
    /// Computes scale and zero-point from the observed float range `[alpha, beta]`.
    fn new(alpha: E, beta: E) -> Self {
        // Q range `[a, b]`
        let a = E::from(Q::min_value()).unwrap();
        let b = E::from(Q::max_value()).unwrap();

        // We extend the `[alpha, beta]` interval to ensure that it contains 0.
        // Otherwise, we would not meet the requirement that 0 be an exactly
        // representable value (zero-point).
        let alpha = E::min(alpha, E::zero());
        let beta = E::max(beta, E::zero());

        // Compute scale and offset to convert a floating point value in range `[alpha, beta]` to the quantized range
        let scale = valid_scale((beta - alpha) / (b - a));
        let z = -(alpha / scale - a);
        Self {
            scale,
            // NOTE(review): `Q::from(z)` truncates the fractional part toward zero
            // rather than rounding to nearest; the unit tests below pin this exact
            // behavior (e.g. z = 71.56... stored as 71), so it must not be "fixed"
            // without updating the expected values.
            offset: Q::from(z).unwrap(),
            _a: PhantomData,
        }
    }

    /// Quantizes each value independently: `x_q = clamp(round(x / scale + offset), a, b)`.
    fn quantize(&self, values: &[E]) -> Vec<Q> {
        // Quantized range `[a, b]`
        let a = E::from(Q::min_value()).unwrap();
        let b = E::from(Q::max_value()).unwrap();

        // x_q = clamp(round(x / scale + offset), a, b)
        // Hoist the offset conversion out of the per-element closure.
        let z = E::from(self.offset).unwrap();
        // run_par!/iter_slice_par! come from burn_common; presumably they switch
        // between sequential and parallel iteration per build features — confirm
        // in burn_common if parallelism matters here.
        run_par!(|| {
            iter_slice_par!(values)
                .map(|x| Q::from(x.div(self.scale).add(z).round().clamp(a, b)).unwrap())
                .collect()
        })
    }

    /// Dequantizes each value: `x = scale * (x_q - offset)`.
    fn dequantize(&self, values: &[Q]) -> Vec<E> {
        // x = scale * (x_q - offset)
        run_par!(|| {
            iter_slice_par!(values)
                .map(|x_q| {
                    // The subtraction is performed in the wider accumulation type
                    // `A` with `saturating_sub`, since `x_q - offset` can exceed
                    // `Q`'s range (e.g. for i8: -128 - 127 = -255).
                    self.scale
                        * (E::from(
                            A::from(*x_q)
                                .unwrap()
                                .saturating_sub(A::from(self.offset).unwrap()),
                        )
                        .unwrap())
                })
                .collect()
        })
    }
}
135
/// Symmetric quantization scheme.
///
/// Quantizes with `x_q = clamp(round(x / scale), -Q::MAX, Q::MAX)`; no
/// zero-point offset is needed because the quantized range is symmetric
/// around zero.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct SymmetricQuantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync> {
    /// The scaling factor.
    pub scale: E,
    /// The quantized type.
    _q: PhantomData<Q>,
}
144
145impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> SymmetricQuantization<E, Q> {
146    /// Initialize a symmetric quantization scheme with the given parameters.
147    pub fn init(scale: E) -> Self {
148        Self {
149            scale: valid_scale(scale),
150            _q: PhantomData,
151        }
152    }
153}
154
155impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Quantization<E, Q>
156    for SymmetricQuantization<E, Q>
157{
158    fn new(alpha: E, beta: E) -> Self {
159        assert!(
160            !Q::min_value().is_zero(),
161            "Symmetric quantization is only valid for signed integers."
162        );
163
164        // Quantized range `[a, b]`
165        let b = E::from(Q::max_value()).unwrap();
166        let a = b.neg();
167
168        // Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range
169        let alpha = alpha.abs().max(beta.abs());
170        let scale = valid_scale((alpha + alpha) / (b - a));
171        Self {
172            scale,
173            _q: PhantomData,
174        }
175    }
176
177    fn quantize(&self, values: &[E]) -> Vec<Q> {
178        // Quantized range [a, b]
179        let b = E::from(Q::max_value()).unwrap();
180        let a = b.neg();
181
182        // x_q = clamp(round(x / scale), a, b)
183        values
184            .iter()
185            .map(|x| Q::from(x.div(self.scale).round().clamp(a, b)).unwrap())
186            .collect()
187    }
188
189    fn dequantize(&self, values: &[Q]) -> Vec<E> {
190        // x = scale * x_q
191        values
192            .iter()
193            .map(|x_q| self.scale * E::from(*x_q).unwrap())
194            .collect()
195    }
196}
197
// Masks for the parts of the IEEE 754 float: sign (1 bit), biased exponent
// (11 bits), mantissa (52 bits) — the f64 binary64 layout.
const SIGN_MASK: u64 = 0x8000000000000000u64;
const EXP_MASK: u64 = 0x7ff0000000000000u64;
const MAN_MASK: u64 = 0x000fffffffffffffu64;

#[inline]
/// Reassembles the raw IEEE 754 bit pattern of a float from its
/// `integer_decode()` parts, for use as a `Hash` input.
/// Used for hashing. Input must not be zero or NaN.
/// Adapted from: https://github.com/reem/rust-ordered-float/blob/master/src/lib.rs
fn raw_double_bits<F: Float>(f: &F) -> u64 {
    let (man, exp, sign) = f.integer_decode();
    // `exp` is a signed i16; the `as u16` cast wraps it into the biased range
    // before widening, matching the ordered-float implementation.
    let exp_u64 = exp as u16 as u64;
    let sign_u64 = (sign > 0) as u64;
    (man & MAN_MASK) | ((exp_u64 << 52) & EXP_MASK) | ((sign_u64 << 63) & SIGN_MASK)
}
212
213#[inline(always)]
214fn canonicalize_signed_zero<T: Float>(x: T) -> T {
215    // -0.0 + 0.0 == +0.0 under IEEE754 roundTiesToEven rounding mode,
216    // which Rust guarantees. Thus by adding a positive zero we
217    // canonicalize signed zero without any branches in one instruction.
218    x + T::zero()
219}
220
221impl<E: Float + Send + Sync, Q: PrimInt + Hash + Send + Sync, A: PrimInt> Hash
222    for AffineQuantization<E, Q, A>
223{
224    fn hash<H: Hasher>(&self, state: &mut H) {
225        // Hash raw bits.
226        let bits = raw_double_bits(&canonicalize_signed_zero(self.scale));
227        bits.hash(state);
228        self.offset.hash(state);
229    }
230}
231
232impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> PartialEq
233    for AffineQuantization<E, Q, A>
234{
235    fn eq(&self, other: &Self) -> bool {
236        self.scale == other.scale && self.offset == other.offset
237    }
238}
239
// Marker impl promoting the scale/offset comparison to a total equivalence.
// NOTE(review): `scale` is a float, so a NaN scale would break `Eq`'s
// reflexivity contract — this assumes constructed scales are never NaN
// (`valid_scale` only guards against zero); confirm upstream invariants.
impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> Eq
    for AffineQuantization<E, Q, A>
{
}
244
245impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Hash for SymmetricQuantization<E, Q> {
246    fn hash<H: Hasher>(&self, state: &mut H) {
247        // Hash raw bits.
248        let bits = raw_double_bits(&canonicalize_signed_zero(self.scale));
249        bits.hash(state);
250    }
251}
252
253impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> PartialEq for SymmetricQuantization<E, Q> {
254    fn eq(&self, other: &Self) -> bool {
255        self.scale == other.scale
256    }
257}
258
259impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Eq for SymmetricQuantization<E, Q> {}
260
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    // Round-trips a small value set through per-tensor affine int8 quantization
    // and pins the exact quantized integers and (lossy) dequantized floats.
    // Note: the expected value 71 for input 0.0 pins the truncating zero-point
    // computation in `AffineQuantization::new`.
    #[test]
    fn test_int8_affine_quantization() {
        let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let expected_q = vec![-128, -40, 71, 126];
        let expected_d = vec![-1.794902, -1.0011765, 0.0, 0.49607843];

        let affine = AffineQuantization::<f32, i8, i32>::new(-1.8, 0.5);

        let q = affine.quantize(&x);
        assert_eq!(q, expected_q);

        let d = affine.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }

    // The input range [1.0, 5.0] does not contain zero; `new` must extend it
    // to [0.0, 5.0] so the zero-point is exactly representable (offset -128).
    #[test]
    fn test_affine_should_ensure_zero_point() {
        let x: [f32; 6] = [2.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let expected_q = vec![-26, -77, -26, 25, 76, 127];
        let expected_d = x.to_vec();

        let affine = AffineQuantization::<f32, i8, i32>::new(1.0, 5.0);

        assert_eq!(affine.offset, -128);
        assert_eq!(affine.scale, 0.019607844);

        let q = affine.quantize(&x);
        assert_eq!(q, expected_q);

        let d = affine.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }

    // Symmetric quantization maps 0.0 to exactly 0 (no offset) and clamps the
    // largest magnitude (-1.8) to -127, the edge of the symmetric i8 range.
    #[test]
    fn test_int8_symmetric_quantization() {
        let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let expected_q = vec![-127, -71, 0, 35];
        let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063];

        let symmetric = SymmetricQuantization::<f32, i8>::new(-1.8, 0.5);

        let q: Vec<i8> = symmetric.quantize(&x);
        assert_eq!(q, expected_q);

        let d = symmetric.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }
}
316}