burn_tensor/tensor/quantization/
strategy.rs1use alloc::vec::Vec;
2use num_traits::{Float, PrimInt};
3use serde::{Deserialize, Serialize};
4
5use super::{BlockSize, QuantValue};
6
7#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9pub enum QuantizationStrategy {
10 PerTensorSymmetric(SymmetricQuantization<f32>),
12 PerBlockSymmetric(Vec<SymmetricQuantization<f32>>, BlockSize),
14}
15
16impl QuantizationStrategy {
17 pub fn quantize(&self, values: &[f32]) -> Vec<i8> {
19 match self {
20 QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.quantize(values),
21 QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => {
22 let block_elems = block_size.num_elements();
23 let num_blocks = strategy.len();
24 let numel = values.len();
25 assert_eq!(
26 numel / block_elems,
27 num_blocks,
28 "Invalid per-block quantization with num blocks {num_blocks} and {numel} values"
29 );
30 values
31 .chunks(block_elems)
32 .enumerate()
33 .flat_map(|(block_id, block)| strategy[block_id].quantize(block))
34 .collect()
35 }
36 }
37 }
38
39 pub fn dequantize(&self, values: &[i8]) -> Vec<f32> {
41 match self {
42 QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.dequantize(values),
43 QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => {
44 let block_elems = block_size.num_elements();
45 let num_blocks = strategy.len();
46 let numel = values.len();
47 assert_eq!(
48 numel / block_elems,
49 num_blocks,
50 "Invalid per-block quantization with block size {block_elems}, num blocks {num_blocks} and {numel} values"
51 );
52 values
53 .chunks(block_elems)
54 .enumerate()
55 .flat_map(|(block_id, block)| strategy[block_id].dequantize(block))
56 .collect()
57 }
58 }
59 }
60}
61
62pub trait Quantization<E: Float + Send + Sync> {
65 fn range(&self) -> (E, E);
67 fn quantize<Q: PrimInt>(&self, values: &[E]) -> Vec<Q>;
69 fn quantize_one<Q: PrimInt>(&self, value: E) -> Q;
71 fn dequantize<Q: PrimInt>(&self, values: &[Q]) -> Vec<E>;
73 fn dequantize_one<Q: PrimInt>(&self, value: Q) -> E;
75}
76
77fn valid_scale<E: Float>(mut scale: E) -> E {
78 if scale.eq(&E::zero()) {
81 scale = E::from(0.1).unwrap();
82 }
83 scale
84}
85
86#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
88pub struct SymmetricQuantization<E: Float + Send + Sync> {
89 pub scale: E,
91 value: QuantValue,
93}
94
95impl<E: Float + Send + Sync> SymmetricQuantization<E> {
96 pub fn init(scale: E, value: QuantValue) -> Self {
98 Self {
99 scale: valid_scale(scale),
100 value,
101 }
102 }
103
104 #[allow(dead_code)]
105 fn new(alpha: E, beta: E, value: QuantValue) -> Self {
107 let (a, b) = value.range();
108 let a = E::from(a).unwrap();
109 let b = E::from(b).unwrap();
110
111 let alpha = alpha.abs().max(beta.abs());
113 let scale = valid_scale((alpha + alpha) / (b - a));
114 Self { scale, value }
115 }
116}
117
118impl<E: Float + Send + Sync> Quantization<E> for SymmetricQuantization<E> {
119 fn quantize<Q: PrimInt>(&self, values: &[E]) -> Vec<Q> {
120 values.iter().map(|x| self.quantize_one(*x)).collect()
121 }
122
123 fn dequantize<Q: PrimInt>(&self, values: &[Q]) -> Vec<E> {
124 values.iter().map(|x_q| self.dequantize_one(*x_q)).collect()
125 }
126
127 fn quantize_one<Q: PrimInt>(&self, value: E) -> Q {
128 let (a, b) = self.range();
129
130 Q::from(value.div(self.scale).round().clamp(a, b)).unwrap()
132 }
133
134 fn dequantize_one<Q: PrimInt>(&self, value: Q) -> E {
135 self.scale * E::from(value).unwrap()
137 }
138
139 fn range(&self) -> (E, E) {
140 let (a, b) = self.value.range();
141 let a = E::from(a).unwrap();
142 let b = E::from(b).unwrap();
143 (a, b)
144 }
145}
146
147impl<E: Float + Send + Sync> PartialEq for SymmetricQuantization<E> {
148 fn eq(&self, other: &Self) -> bool {
149 self.scale == other.scale
150 }
151}
152
153impl<E: Float + Send + Sync> Eq for SymmetricQuantization<E> {}
154
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Round-trips a small tensor through a single symmetric int8 scheme.
    #[test]
    fn test_int8_symmetric_quantization() {
        let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let expected_q = vec![-127, -71, 0, 35];
        let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063];

        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);

        let q: Vec<i8> = symmetric.quantize(&x);
        assert_eq!(q, expected_q);

        let d = symmetric.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }

    /// Round-trips two blocks of four values through the per-block strategy.
    #[test]
    fn test_int8_symmetric_quantization_per_block() {
        let x: [f32; 8] = [-1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5];
        let expected_q = vec![-127, -71, 0, 35, -127, -71, 0, 35];
        let expected_d = vec![
            -1.8, -1.0062993, 0.0, 0.496063, -1.8, -1.0062993, 0.0, 0.496063,
        ];

        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);
        let strategy = QuantizationStrategy::PerBlockSymmetric(
            vec![symmetric, symmetric],
            BlockSize::new([4]),
        );

        let q: Vec<i8> = strategy.quantize(&x);
        assert_eq!(q, expected_q);

        // Dequantize through the strategy (not the bare scheme) so the per-block path is
        // the one under test.
        let d = strategy.dequantize(&expected_q);

        assert_eq!(d, expected_d);
    }

    /// Eight values form two blocks of four, but only one scheme is supplied.
    #[test]
    #[should_panic]
    fn test_int8_symmetric_quantization_per_block_count_mismatch() {
        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);
        let strategy =
            QuantizationStrategy::PerBlockSymmetric(vec![symmetric], BlockSize::new([4]));

        let _ = strategy.quantize(&[0.0; 8]);
    }
}