burn_tensor/tensor/quantization/strategy.rs

1use alloc::vec::Vec;
2use num_traits::{Float, PrimInt};
3use serde::{Deserialize, Serialize};
4
5use super::{BlockSize, QuantValue};
6
7/// Quantization strategy.
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9pub enum QuantizationStrategy {
10    /// Per-tensor symmetric quantization.
11    PerTensorSymmetric(SymmetricQuantization<f32>),
12    /// Per-block symmetric quantization.
13    PerBlockSymmetric(Vec<SymmetricQuantization<f32>>, BlockSize),
14}
15
16impl QuantizationStrategy {
17    /// Quantize the values to a lower precision data type.
18    pub fn quantize(&self, values: &[f32]) -> Vec<i8> {
19        match self {
20            QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.quantize(values),
21            QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => {
22                let block_elems = block_size.num_elements();
23                let num_blocks = strategy.len();
24                let numel = values.len();
25                assert_eq!(
26                    numel / block_elems,
27                    num_blocks,
28                    "Invalid per-block quantization with num blocks {num_blocks} and {numel} values"
29                );
30                values
31                    .chunks(block_elems)
32                    .enumerate()
33                    .flat_map(|(block_id, block)| strategy[block_id].quantize(block))
34                    .collect()
35            }
36        }
37    }
38
39    /// Dequantize the values to a higher precision data type.
40    pub fn dequantize(&self, values: &[i8]) -> Vec<f32> {
41        match self {
42            QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.dequantize(values),
43            QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => {
44                let block_elems = block_size.num_elements();
45                let num_blocks = strategy.len();
46                let numel = values.len();
47                assert_eq!(
48                    numel / block_elems,
49                    num_blocks,
50                    "Invalid per-block quantization with block size {block_elems}, num blocks {num_blocks} and {numel} values"
51                );
52                values
53                    .chunks(block_elems)
54                    .enumerate()
55                    .flat_map(|(block_id, block)| strategy[block_id].dequantize(block))
56                    .collect()
57            }
58        }
59    }
60}
61
62/// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision
63/// data type `Q` and vice-versa.
64pub trait Quantization<E: Float + Send + Sync> {
65    /// Returns the quantization range `[a, b]`.
66    fn range(&self) -> (E, E);
67    /// Convert the values to a lower precision data type.
68    fn quantize<Q: PrimInt>(&self, values: &[E]) -> Vec<Q>;
69    /// Convert a single value to a lower precision data type.
70    fn quantize_one<Q: PrimInt>(&self, value: E) -> Q;
71    /// Convert the values back to a higher precision data type.
72    fn dequantize<Q: PrimInt>(&self, values: &[Q]) -> Vec<E>;
73    /// Convert a single value back to a higher precision data type.
74    fn dequantize_one<Q: PrimInt>(&self, value: Q) -> E;
75}
76
77fn valid_scale<E: Float>(mut scale: E) -> E {
78    // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the
79    // scale to 0.1 to avoid division by zero.
80    if scale.eq(&E::zero()) {
81        scale = E::from(0.1).unwrap();
82    }
83    scale
84}
85
86/// Symmetric quantization scheme.
87#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
88pub struct SymmetricQuantization<E: Float + Send + Sync> {
89    /// The scaling factor.
90    pub scale: E,
91    // The quantization value data type.
92    value: QuantValue,
93}
94
95impl<E: Float + Send + Sync> SymmetricQuantization<E> {
96    /// Initialize a symmetric quantization scheme with the given parameters.
97    pub fn init(scale: E, value: QuantValue) -> Self {
98        Self {
99            scale: valid_scale(scale),
100            value,
101        }
102    }
103
104    #[allow(dead_code)]
105    /// Create a new quantization scheme for an input range `[alpha, beta]`.
106    fn new(alpha: E, beta: E, value: QuantValue) -> Self {
107        let (a, b) = value.range();
108        let a = E::from(a).unwrap();
109        let b = E::from(b).unwrap();
110
111        // Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range
112        let alpha = alpha.abs().max(beta.abs());
113        let scale = valid_scale((alpha + alpha) / (b - a));
114        Self { scale, value }
115    }
116}
117
118impl<E: Float + Send + Sync> Quantization<E> for SymmetricQuantization<E> {
119    fn quantize<Q: PrimInt>(&self, values: &[E]) -> Vec<Q> {
120        values.iter().map(|x| self.quantize_one(*x)).collect()
121    }
122
123    fn dequantize<Q: PrimInt>(&self, values: &[Q]) -> Vec<E> {
124        values.iter().map(|x_q| self.dequantize_one(*x_q)).collect()
125    }
126
127    fn quantize_one<Q: PrimInt>(&self, value: E) -> Q {
128        let (a, b) = self.range();
129
130        // x_q = clamp(round(x / scale), a, b)
131        Q::from(value.div(self.scale).round().clamp(a, b)).unwrap()
132    }
133
134    fn dequantize_one<Q: PrimInt>(&self, value: Q) -> E {
135        // x = scale * x_q
136        self.scale * E::from(value).unwrap()
137    }
138
139    fn range(&self) -> (E, E) {
140        let (a, b) = self.value.range();
141        let a = E::from(a).unwrap();
142        let b = E::from(b).unwrap();
143        (a, b)
144    }
145}
146
147impl<E: Float + Send + Sync> PartialEq for SymmetricQuantization<E> {
148    fn eq(&self, other: &Self) -> bool {
149        self.scale == other.scale
150    }
151}
152
153impl<E: Float + Send + Sync> Eq for SymmetricQuantization<E> {}
154
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_int8_symmetric_quantization() {
        let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let expected_q: Vec<i8> = vec![-127, -71, 0, 35];
        let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063];

        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);

        let q: Vec<i8> = symmetric.quantize(&x);
        assert_eq!(q, expected_q);

        let d = symmetric.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }

    #[test]
    fn test_int8_symmetric_quantization_per_block() {
        let x: [f32; 8] = [-1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5];
        let expected_q: Vec<i8> = vec![-127, -71, 0, 35, -127, -71, 0, 35];
        let expected_d = vec![
            -1.8, -1.0062993, 0.0, 0.496063, -1.8, -1.0062993, 0.0, 0.496063,
        ];

        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);
        let strategy = QuantizationStrategy::PerBlockSymmetric(
            vec![symmetric, symmetric],
            BlockSize::new([4]),
        );

        let q: Vec<i8> = strategy.quantize(&x);
        assert_eq!(q, expected_q);

        // Dequantize through the per-block strategy (not the single scheme) so that the
        // per-block dequantization path is actually exercised.
        let d = strategy.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }

    #[test]
    #[should_panic(expected = "Invalid per-block quantization")]
    fn test_per_block_quantization_rejects_mismatched_num_blocks() {
        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);
        let strategy = QuantizationStrategy::PerBlockSymmetric(
            vec![symmetric, symmetric],
            BlockSize::new([4]),
        );

        // Only one block worth of values for two block schemes.
        let _ = strategy.quantize(&[0.0; 4]);
    }
}