burn_tensor/tensor/quantization/
strategy.rs1use alloc::vec::Vec;
2use core::marker::PhantomData;
3use num_traits::{Float, PrimInt, Signed};
4use serde::{Deserialize, Serialize};
5
6use super::{QuantizationMode, QuantizationScheme, QuantizationType};
7
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10pub enum QuantizationStrategy {
11 PerTensorSymmetricInt8(SymmetricQuantization<f32, i8>),
13}
14
15impl QuantizationStrategy {
16 pub fn quantize(&self, values: &[f32]) -> Vec<i8> {
18 match self {
19 QuantizationStrategy::PerTensorSymmetricInt8(strategy) => strategy.quantize(values),
20 }
21 }
22
23 pub fn dequantize(&self, values: &[i8]) -> Vec<f32> {
25 match self {
26 QuantizationStrategy::PerTensorSymmetricInt8(strategy) => strategy.dequantize(values),
27 }
28 }
29}
30
31impl QuantizationStrategy {
32 pub fn scheme(&self) -> QuantizationScheme {
34 match self {
35 QuantizationStrategy::PerTensorSymmetricInt8(_) => {
36 QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8)
37 }
38 }
39 }
40}
41
42pub trait Quantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync> {
45 fn range() -> (Q, Q);
47 fn new(alpha: E, beta: E) -> Self;
49 fn quantize(&self, values: &[E]) -> Vec<Q>;
51 fn quantize_one(&self, value: E) -> Q;
53 fn dequantize(&self, values: &[Q]) -> Vec<E>;
55 fn dequantize_one(&self, value: Q) -> E;
57}
58
59fn valid_scale<E: Float>(mut scale: E) -> E {
60 if scale.eq(&E::zero()) {
63 scale = E::from(0.1).unwrap();
64 }
65 scale
66}
67
68#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
70pub struct SymmetricQuantization<E: Float + Send + Sync, Q: PrimInt + Signed + Send + Sync> {
71 pub scale: E,
73 _q: PhantomData<Q>,
75}
76
77impl<E: Float + Send + Sync, Q: PrimInt + Signed + Send + Sync> SymmetricQuantization<E, Q> {
78 pub fn init(scale: E) -> Self {
80 Self {
81 scale: valid_scale(scale),
82 _q: PhantomData,
83 }
84 }
85}
86
87impl<E: Float + Send + Sync, Q: PrimInt + Signed + Send + Sync> Quantization<E, Q>
88 for SymmetricQuantization<E, Q>
89{
90 fn new(alpha: E, beta: E) -> Self {
91 let (a, b) = Self::range();
92 let a = E::from(a).unwrap();
93 let b = E::from(b).unwrap();
94
95 let alpha = alpha.abs().max(beta.abs());
97 let scale = valid_scale((alpha + alpha) / (b - a));
98 Self {
99 scale,
100 _q: PhantomData,
101 }
102 }
103
104 fn quantize(&self, values: &[E]) -> Vec<Q> {
105 values.iter().map(|x| self.quantize_one(*x)).collect()
106 }
107
108 fn dequantize(&self, values: &[Q]) -> Vec<E> {
109 values.iter().map(|x_q| self.dequantize_one(*x_q)).collect()
110 }
111
112 fn quantize_one(&self, value: E) -> Q {
113 let (a, b) = Self::range();
114 let a = E::from(a).unwrap();
115 let b = E::from(b).unwrap();
116
117 Q::from(value.div(self.scale).round().clamp(a, b)).unwrap()
119 }
120
121 fn dequantize_one(&self, value: Q) -> E {
122 self.scale * E::from(value).unwrap()
124 }
125
126 fn range() -> (Q, Q) {
127 let b = Q::max_value();
129 (b.neg(), b)
130 }
131}
132
133impl<E: Float + Send + Sync, Q: PrimInt + Signed + Send + Sync> PartialEq
134 for SymmetricQuantization<E, Q>
135{
136 fn eq(&self, other: &Self) -> bool {
137 self.scale == other.scale
138 }
139}
140
141impl<E: Float + Send + Sync, Q: PrimInt + Signed + Send + Sync> Eq for SymmetricQuantization<E, Q> {}
142
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Quantizes a small `f32` sample to `i8` and dequantizes the expected
    /// quantized values, checking both directions against fixed results.
    #[test]
    fn test_int8_symmetric_quantization() {
        let input: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let quantized_expected = vec![-127, -71, 0, 35];
        let dequantized_expected = vec![-1.8, -1.0062993, 0.0, 0.496063];

        let strategy = SymmetricQuantization::<f32, i8>::new(-1.8, 0.5);

        let quantized: Vec<i8> = strategy.quantize(&input);
        assert_eq!(quantized, quantized_expected);

        let dequantized = strategy.dequantize(&quantized_expected);
        assert_eq!(dequantized, dequantized_expected);
    }
}