1use crate::error::{QuantError, QuantResult};
25
/// Fake quantizer: simulates integer quantization while keeping values in `f32`.
///
/// [`forward`](FakeQuantize::forward) rounds each value onto the integer grid described by
/// `bits`, `symmetric`, `scale` and `zero_point`, clamps it to that grid's range and maps it
/// back to `f32`. [`backward`](FakeQuantize::backward) applies a straight-through estimator.
///
/// The fields are public, but validation only happens in [`FakeQuantize::new`] (bits and
/// scale) and [`FakeQuantize::update_params`] (scale). Assigning a `bits` value outside
/// `1..=16`, or a non-finite / non-positive `scale`, directly bypasses those checks.
#[derive(Debug, Clone, PartialEq)]
pub struct FakeQuantize {
    /// Bit width of the integer grid; the constructors accept `1..=16`.
    pub bits: u32,
    /// `true` selects the signed grid `[-2^(bits-1), 2^(bits-1) - 1]`,
    /// `false` the unsigned grid `[0, 2^bits - 1]`.
    pub symmetric: bool,
    /// Real-valued distance between adjacent integer codes (finite and `> 0` when validated).
    pub scale: f32,
    /// Integer code that dequantizes to exactly `0.0`.
    pub zero_point: i32,
    /// When `false`, `forward` and `backward` return their inputs unchanged.
    pub enabled: bool,
}
46
47impl FakeQuantize {
48 pub fn new(bits: u32, symmetric: bool, scale: f32, zero_point: i32) -> QuantResult<Self> {
55 if bits == 0 || bits > 16 {
56 return Err(QuantError::InvalidBitWidth { bits });
57 }
58 if !scale.is_finite() || scale <= 0.0 {
59 return Err(QuantError::InvalidScale { scale });
60 }
61 Ok(Self {
62 bits,
63 symmetric,
64 scale,
65 zero_point,
66 enabled: true,
67 })
68 }
69
70 pub fn with_defaults(bits: u32, symmetric: bool) -> QuantResult<Self> {
76 Self::new(bits, symmetric, 1.0, 0)
77 }
78
79 pub fn update_params(&mut self, scale: f32, zero_point: i32) -> QuantResult<()> {
85 if !scale.is_finite() || scale <= 0.0 {
86 return Err(QuantError::InvalidScale { scale });
87 }
88 self.scale = scale;
89 self.zero_point = zero_point;
90 Ok(())
91 }
92
93 #[must_use]
95 pub fn quant_range(&self) -> (i32, i32) {
96 if self.symmetric {
97 let half = 1i32 << (self.bits - 1);
98 (-half, half - 1)
99 } else {
100 (0i32, (1i32 << self.bits) - 1)
101 }
102 }
103
104 #[must_use]
106 pub fn float_range(&self) -> (f32, f32) {
107 let (q_min, q_max) = self.quant_range();
108 let zp = self.zero_point as f32;
109 let lo = (q_min as f32 - zp) * self.scale;
110 let hi = (q_max as f32 - zp) * self.scale;
111 (lo, hi)
112 }
113
114 #[must_use]
118 pub fn forward(&self, x: &[f32]) -> Vec<f32> {
119 if !self.enabled {
120 return x.to_vec();
121 }
122 let (q_min, q_max) = self.quant_range();
123 let zp = self.zero_point as f32;
124 x.iter()
125 .map(|&v| {
126 let q = (v / self.scale + zp)
127 .round()
128 .clamp(q_min as f32, q_max as f32);
129 (q - zp) * self.scale
130 })
131 .collect()
132 }
133
134 pub fn backward(&self, grad_output: &[f32], x: &[f32]) -> QuantResult<Vec<f32>> {
143 if grad_output.len() != x.len() {
144 return Err(QuantError::DimensionMismatch {
145 expected: x.len(),
146 got: grad_output.len(),
147 });
148 }
149 if !self.enabled {
150 return Ok(grad_output.to_vec());
151 }
152 let (x_min, x_max) = self.float_range();
153 let grad = grad_output
154 .iter()
155 .zip(x.iter())
156 .map(|(&g, &v)| if v >= x_min && v <= x_max { g } else { 0.0 })
157 .collect();
158 Ok(grad)
159 }
160
161 #[must_use]
165 pub fn quantization_noise(&self, x: &[f32]) -> f32 {
166 if x.is_empty() {
167 return 0.0;
168 }
169 let fq = self.forward(x);
170 let mse = x
171 .iter()
172 .zip(fq.iter())
173 .map(|(a, b)| (a - b).powi(2))
174 .sum::<f32>();
175 mse / x.len() as f32
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn forward_quantize_dequantize_int8() {
        let fq = FakeQuantize::new(8, true, 1.0 / 127.0, 0).unwrap();
        let out = fq.forward(&[1.0_f32]);
        // 1.0 lands on code 127, the top of the signed 8-bit grid, so it round-trips.
        assert_abs_diff_eq!(out[0], 1.0, epsilon = 0.01);
    }

    #[test]
    fn forward_passthrough_when_disabled() {
        let mut fq = FakeQuantize::new(8, true, 0.01, 0).unwrap();
        fq.enabled = false;
        let data = vec![1.5_f32, -2.3, 0.7];
        let out = fq.forward(&data);
        assert_eq!(out, data);
    }

    #[test]
    fn backward_ste_passthrough() {
        let fq = FakeQuantize::new(8, true, 1.0 / 127.0, 0).unwrap();
        let x = vec![0.5_f32, -0.5];
        let grad = vec![1.0_f32, -1.0];
        let ste = fq.backward(&grad, &x).unwrap();
        // Both inputs are well inside the representable range, so gradients pass unchanged.
        assert_abs_diff_eq!(ste[0], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(ste[1], -1.0, epsilon = 1e-6);
    }

    #[test]
    fn backward_ste_zero_outside_range() {
        let fq = FakeQuantize::new(8, true, 1.0 / 127.0, 0).unwrap();
        // +/-2.0 correspond to codes of about +/-254, far outside [-128, 127].
        let x = vec![2.0_f32, -2.0];
        let grad = vec![1.0_f32, 1.0];
        let ste = fq.backward(&grad, &x).unwrap();
        assert_abs_diff_eq!(ste[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(ste[1], 0.0, epsilon = 1e-6);
    }

    #[test]
    fn backward_passthrough_when_disabled() {
        let mut fq = FakeQuantize::new(8, true, 0.01, 0).unwrap();
        fq.enabled = false;
        // Inputs far outside the grid would be masked if the quantizer were enabled.
        let x = vec![100.0_f32, -100.0];
        let grad = vec![0.5_f32, -2.0];
        assert_eq!(fq.backward(&grad, &x).unwrap(), grad);
    }

    #[test]
    fn backward_dimension_mismatch_error() {
        let fq = FakeQuantize::new(8, true, 0.01, 0).unwrap();
        let x = vec![0.5_f32; 3];
        let grad = vec![1.0_f32; 4];
        assert!(matches!(
            fq.backward(&grad, &x),
            Err(QuantError::DimensionMismatch {
                expected: 3,
                got: 4
            })
        ));
    }

    #[test]
    fn invalid_scale_error() {
        for scale in [-0.01_f32, 0.0, f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            assert!(
                matches!(
                    FakeQuantize::new(8, true, scale, 0),
                    Err(QuantError::InvalidScale { .. })
                ),
                "scale {scale} was accepted"
            );
        }
    }

    #[test]
    fn invalid_bit_width_error() {
        assert!(matches!(
            FakeQuantize::new(0, true, 0.01, 0),
            Err(QuantError::InvalidBitWidth { bits: 0 })
        ));
        assert!(matches!(
            FakeQuantize::new(17, true, 0.01, 0),
            Err(QuantError::InvalidBitWidth { bits: 17 })
        ));
    }

    #[test]
    fn quant_range_int8_symmetric() {
        let fq = FakeQuantize::new(8, true, 0.01, 0).unwrap();
        assert_eq!(fq.quant_range(), (-128, 127));
    }

    #[test]
    fn quant_range_int4_asymmetric() {
        let fq = FakeQuantize::new(4, false, 0.01, 0).unwrap();
        assert_eq!(fq.quant_range(), (0, 15));
    }

    #[test]
    fn quantization_noise_zero_for_fine_scale() {
        let fq = FakeQuantize::new(8, true, 1.0 / 127.0, 0).unwrap();
        // Values in [-0.5, 0.5) are inside the grid; error is bounded by half a step.
        let data: Vec<f32> = (0..128).map(|i| i as f32 / 128.0 - 0.5).collect();
        let noise = fq.quantization_noise(&data);
        assert!(noise < 1e-5, "noise too high: {noise}");
    }

    #[test]
    fn quantization_noise_empty_is_zero() {
        let fq = FakeQuantize::new(8, true, 0.01, 0).unwrap();
        assert_eq!(fq.quantization_noise(&[]), 0.0);
    }

    #[test]
    fn update_params_works() {
        let mut fq = FakeQuantize::with_defaults(8, true).unwrap();
        fq.update_params(0.5, 4).unwrap();
        assert_abs_diff_eq!(fq.scale, 0.5, epsilon = 1e-7);
        assert_eq!(fq.zero_point, 4);
    }

    #[test]
    fn update_params_rejects_invalid_scale_without_mutating() {
        let mut fq = FakeQuantize::new(8, true, 0.25, 3).unwrap();
        assert!(matches!(
            fq.update_params(0.0, 7),
            Err(QuantError::InvalidScale { .. })
        ));
        assert!(fq.update_params(f32::NAN, 7).is_err());
        assert_abs_diff_eq!(fq.scale, 0.25, epsilon = 1e-7);
        assert_eq!(fq.zero_point, 3);
    }

    #[test]
    fn asymmetric_forward_unit_range() {
        // 4-bit unsigned grid with zero point 0 covers [0, 1] in steps of 1/15.
        let fq = FakeQuantize::new(4, false, 1.0 / 15.0, 0).unwrap();
        let out = fq.forward(&[0.0_f32, 0.5, 1.0]);
        assert_abs_diff_eq!(out[0], 0.0, epsilon = 0.001);
        // 0.5 is between grid points 7/15 and 8/15; either is within one step.
        assert!(out[1] > 0.4 && out[1] < 0.6, "midpoint: {}", out[1]);
        assert_abs_diff_eq!(out[2], 1.0, epsilon = 0.001);
    }

    #[test]
    fn asymmetric_forward_with_nonzero_zp() {
        // 4-bit unsigned grid [0, 15] with zero point 5 covers reals [-0.5, 1.0] in 0.1 steps.
        let fq = FakeQuantize::new(4, false, 0.1, 5).unwrap();
        let (lo, hi) = fq.float_range();
        assert_abs_diff_eq!(lo, -0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(hi, 1.0, epsilon = 1e-6);

        let out = fq.forward(&[0.0_f32, -0.5, 0.24, 1.0, 2.0, -3.0]);
        // Real zero maps onto the zero point and back to exactly zero.
        assert_abs_diff_eq!(out[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[1], -0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(out[2], 0.2, epsilon = 1e-6);
        assert_abs_diff_eq!(out[3], 1.0, epsilon = 1e-6);
        // Out-of-range values clamp to the nearest representable endpoint.
        assert_abs_diff_eq!(out[4], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[5], -0.5, epsilon = 1e-6);
    }
}