burn_tensor/tensor/quantization/strategy.rs

1use alloc::vec::Vec;
2use core::marker::PhantomData;
3use num_traits::{Float, PrimInt, Signed};
4use serde::{Deserialize, Serialize};
5
6use super::{QuantizationMode, QuantizationScheme, QuantizationType};
7
8/// Quantization strategy.
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10pub enum QuantizationStrategy {
11    /// Per-tensor `int8` symmetric quantization.
12    PerTensorSymmetricInt8(SymmetricQuantization<f32, i8>),
13}
14
15impl QuantizationStrategy {
16    /// Quantize the values to a lower precision data type.
17    pub fn quantize(&self, values: &[f32]) -> Vec<i8> {
18        match self {
19            QuantizationStrategy::PerTensorSymmetricInt8(strategy) => strategy.quantize(values),
20        }
21    }
22
23    /// Dequantize the values to a higher precision data type.
24    pub fn dequantize(&self, values: &[i8]) -> Vec<f32> {
25        match self {
26            QuantizationStrategy::PerTensorSymmetricInt8(strategy) => strategy.dequantize(values),
27        }
28    }
29}
30
31impl QuantizationStrategy {
32    /// Returns the corresponding quantization scheme.
33    pub fn scheme(&self) -> QuantizationScheme {
34        match self {
35            QuantizationStrategy::PerTensorSymmetricInt8(_) => {
36                QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8)
37            }
38        }
39    }
40}
41
42/// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision
43/// data type `Q` and vice-versa.
44pub trait Quantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync> {
45    /// Returns the quantization range `[a, b]`.
46    fn range() -> (Q, Q);
47    /// Create a new quantization scheme for an input range `[alpha, beta]`.
48    fn new(alpha: E, beta: E) -> Self;
49    /// Convert the values to a lower precision data type.
50    fn quantize(&self, values: &[E]) -> Vec<Q>;
51    /// Convert a single value to a lower precision data type.
52    fn quantize_one(&self, value: E) -> Q;
53    /// Convert the values back to a higher precision data type.
54    fn dequantize(&self, values: &[Q]) -> Vec<E>;
55    /// Convert a single value back to a higher precision data type.
56    fn dequantize_one(&self, value: Q) -> E;
57}
58
59fn valid_scale<E: Float>(mut scale: E) -> E {
60    // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the
61    // scale to 0.1 to avoid division by zero.
62    if scale.eq(&E::zero()) {
63        scale = E::from(0.1).unwrap();
64    }
65    scale
66}
67
68/// Symmetric quantization scheme.
69#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
70pub struct SymmetricQuantization<E: Float + Send + Sync, Q: PrimInt + Signed + Send + Sync> {
71    /// The scaling factor.
72    pub scale: E,
73    /// The quantized type.
74    _q: PhantomData<Q>,
75}
76
77impl<E: Float + Send + Sync, Q: PrimInt + Signed + Send + Sync> SymmetricQuantization<E, Q> {
78    /// Initialize a symmetric quantization scheme with the given parameters.
79    pub fn init(scale: E) -> Self {
80        Self {
81            scale: valid_scale(scale),
82            _q: PhantomData,
83        }
84    }
85}
86
87impl<E: Float + Send + Sync, Q: PrimInt + Signed + Send + Sync> Quantization<E, Q>
88    for SymmetricQuantization<E, Q>
89{
90    fn new(alpha: E, beta: E) -> Self {
91        let (a, b) = Self::range();
92        let a = E::from(a).unwrap();
93        let b = E::from(b).unwrap();
94
95        // Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range
96        let alpha = alpha.abs().max(beta.abs());
97        let scale = valid_scale((alpha + alpha) / (b - a));
98        Self {
99            scale,
100            _q: PhantomData,
101        }
102    }
103
104    fn quantize(&self, values: &[E]) -> Vec<Q> {
105        values.iter().map(|x| self.quantize_one(*x)).collect()
106    }
107
108    fn dequantize(&self, values: &[Q]) -> Vec<E> {
109        values.iter().map(|x_q| self.dequantize_one(*x_q)).collect()
110    }
111
112    fn quantize_one(&self, value: E) -> Q {
113        let (a, b) = Self::range();
114        let a = E::from(a).unwrap();
115        let b = E::from(b).unwrap();
116
117        // x_q = clamp(round(x / scale), a, b)
118        Q::from(value.div(self.scale).round().clamp(a, b)).unwrap()
119    }
120
121    fn dequantize_one(&self, value: Q) -> E {
122        // x = scale * x_q
123        self.scale * E::from(value).unwrap()
124    }
125
126    fn range() -> (Q, Q) {
127        // Only implemented for symmetric *signed* at this time
128        let b = Q::max_value();
129        (b.neg(), b)
130    }
131}
132
133impl<E: Float + Send + Sync, Q: PrimInt + Signed + Send + Sync> PartialEq
134    for SymmetricQuantization<E, Q>
135{
136    fn eq(&self, other: &Self) -> bool {
137        self.scale == other.scale
138    }
139}
140
141impl<E: Float + Send + Sync, Q: PrimInt + Signed + Send + Sync> Eq for SymmetricQuantization<E, Q> {}
142
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_int8_symmetric_quantization() {
        let input: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        // Range [-1.8, 0.5] -> scale = 2 * 1.8 / 254.
        let scheme = SymmetricQuantization::<f32, i8>::new(-1.8, 0.5);

        let quantized: Vec<i8> = scheme.quantize(&input);
        let expected_quantized = vec![-127, -71, 0, 35];
        assert_eq!(quantized, expected_quantized);

        let dequantized = scheme.dequantize(&expected_quantized);
        assert_eq!(dequantized, vec![-1.8, -1.0062993, 0.0, 0.496063]);
    }
}