entrenar/quant/fake_quantize/
quantize.rs1use crate::Tensor;
4
5use super::config::FakeQuantConfig;
6
7#[derive(Clone, Debug)]
12pub struct FakeQuantize {
13 pub config: FakeQuantConfig,
15 pub scale: f32,
17 pub zero_point: i32,
19 pub initialized: bool,
21}
22
23impl FakeQuantize {
24 pub fn new(config: FakeQuantConfig) -> Self {
26 Self { config, scale: 1.0, zero_point: 0, initialized: false }
27 }
28
29 pub fn q4() -> Self {
31 Self::new(FakeQuantConfig::q4_symmetric())
32 }
33
34 pub fn q8() -> Self {
36 Self::new(FakeQuantConfig::q8_symmetric())
37 }
38
39 pub fn calibrate(&mut self, data: &[f32]) {
44 if data.is_empty() {
45 return;
46 }
47
48 let min_val = data.iter().copied().fold(f32::INFINITY, f32::min);
49 let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
50
51 if self.config.symmetric {
52 let max_abs = min_val.abs().max(max_val.abs());
54 self.scale = max_abs / self.config.qmax as f32;
55 self.zero_point = 0;
56 } else {
57 self.scale = (max_val - min_val) / (self.config.qmax - self.config.qmin) as f32;
59 self.zero_point = (self.config.qmin as f32 - min_val / self.scale).round() as i32;
60 self.zero_point = self.zero_point.clamp(self.config.qmin, self.config.qmax);
61 }
62
63 if self.scale < 1e-10 {
65 self.scale = 1e-10;
66 }
67
68 self.initialized = true;
69 }
70
71 pub fn forward(&self, input: &Tensor) -> Tensor {
76 let data: Vec<f32> = input.data().iter().map(|&x| self.fake_quantize_value(x)).collect();
77
78 Tensor::new(ndarray::arr1(&data), input.requires_grad())
79 }
80
81 pub fn forward_with_calibration(&mut self, input: &Tensor) -> Tensor {
85 if !self.initialized {
86 self.calibrate(input.data().as_slice().unwrap_or(&[]));
87 }
88 self.forward(input)
89 }
90
91 pub fn backward(&self, grad_output: &Tensor) -> Tensor {
99 grad_output.clone()
101 }
102
103 pub fn backward_clamped(&self, grad_output: &Tensor, input: &Tensor) -> Tensor {
108 let qmin_float = self.config.qmin as f32 * self.scale;
109 let qmax_float = self.config.qmax as f32 * self.scale;
110
111 let data: Vec<f32> = grad_output
112 .data()
113 .iter()
114 .zip(input.data().iter())
115 .map(|(&grad, &x)| {
116 if x < qmin_float || x > qmax_float {
118 0.0
119 } else {
120 grad
121 }
122 })
123 .collect();
124
125 Tensor::new(ndarray::arr1(&data), grad_output.requires_grad())
126 }
127
128 fn fake_quantize_value(&self, x: f32) -> f32 {
130 let q = if self.config.symmetric {
132 (x / self.scale).round().clamp(self.config.qmin as f32, self.config.qmax as f32) as i32
133 } else {
134 ((x / self.scale) + self.zero_point as f32)
135 .round()
136 .clamp(self.config.qmin as f32, self.config.qmax as f32) as i32
137 };
138
139 if self.config.symmetric {
141 q as f32 * self.scale
142 } else {
143 (q - self.zero_point) as f32 * self.scale
144 }
145 }
146
147 pub fn scale(&self) -> f32 {
149 self.scale
150 }
151
152 pub fn zero_point(&self) -> i32 {
154 self.zero_point
155 }
156
157 pub fn is_initialized(&self) -> bool {
159 self.initialized
160 }
161
162 pub fn num_levels(&self) -> usize {
164 (self.config.qmax - self.config.qmin + 1) as usize
165 }
166}