1use super::granularity::{
11 calibrate_per_tensor, dequantize_with_params, quantization_mse, quantize_with_params,
12 QuantMode, QuantParams,
13};
14use serde::{Deserialize, Serialize};
15
/// Summary statistics describing the error of one quantize → dequantize round trip.
///
/// `analyze_error` fills every field; the derived `Default` (all zeros) is what
/// `analyze_error` returns for empty input.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct QuantErrorStats {
    /// Mean squared error between the original and the dequantized values.
    pub mse: f32,
    /// Mean of the absolute per-element errors.
    pub mae: f32,
    /// Largest absolute per-element error.
    pub max_error: f32,
    /// Signal-to-quantization-noise ratio in decibels; `analyze_error` reports
    /// `f32::INFINITY` when the noise power is negligible.
    pub sqnr_db: f32,
    /// Fraction of elements whose absolute error exceeds the caller's outlier threshold.
    pub outlier_rate: f32,
    /// Number of input elements the statistics were computed over.
    pub num_samples: usize,
}
32
33impl QuantErrorStats {
34 pub fn rmse(&self) -> f32 {
36 contract_pre_rmse!();
37 self.mse.sqrt()
38 }
39}
40
41pub fn analyze_error(
48 original: &[f32],
49 params: &QuantParams,
50 outlier_threshold: f32,
51) -> QuantErrorStats {
52 if original.is_empty() {
53 return QuantErrorStats::default();
54 }
55
56 let quantized = quantize_with_params(original, params);
57 let dequantized = dequantize_with_params(&quantized, params);
58
59 let errors: Vec<f32> =
60 original.iter().zip(dequantized.iter()).map(|(o, d)| (o - d).abs()).collect();
61
62 let mse = quantization_mse(original, &dequantized);
63 let mae = errors.iter().sum::<f32>() / errors.len().max(1) as f32;
64 let max_error = errors.iter().copied().fold(0.0f32, f32::max);
65
66 let outlier_count = errors.iter().filter(|&&e| e > outlier_threshold).count();
67 let outlier_rate = outlier_count as f32 / errors.len().max(1) as f32;
68
69 let signal_power: f32 =
71 original.iter().map(|x| x * x).sum::<f32>() / original.len().max(1) as f32;
72 let noise_power = mse;
73 let sqnr_db = if noise_power > 1e-10 {
74 10.0 * (signal_power / noise_power).max(f32::MIN_POSITIVE).log10()
75 } else {
76 f32::INFINITY
77 };
78
79 QuantErrorStats { mse, mae, max_error, sqnr_db, outlier_rate, num_samples: original.len() }
80}
81
82pub fn theoretical_max_error(params: &QuantParams) -> f32 {
87 let max_scale = params.scales.iter().copied().fold(0.0f32, f32::max);
88 max_scale / 2.0
89}
90
/// Ideal signal-to-quantization-noise ratio, in decibels, for a `bits`-bit quantizer.
///
/// Uses the common `6.02 · bits + 1.76` dB approximation.
pub fn theoretical_sqnr(bits: u8) -> f32 {
    const DB_PER_BIT: f32 = 6.02;
    const OFFSET_DB: f32 = 1.76;

    let bit_count = f32::from(bits);
    DB_PER_BIT * bit_count + OFFSET_DB
}
98
99pub fn error_within_bounds(stats: &QuantErrorStats, params: &QuantParams, tolerance: f32) -> bool {
101 let theoretical_max = theoretical_max_error(params);
102 stats.max_error <= theoretical_max * (1.0 + tolerance)
103}
104
105pub fn scale_sensitivity(
109 values: &[f32],
110 params: &QuantParams,
111 perturbation: f32,
112) -> (f32, f32, f32) {
113 let quantized = quantize_with_params(values, params);
115 let dequantized = dequantize_with_params(&quantized, params);
116 let original_mse = quantization_mse(values, &dequantized);
117
118 let perturbed_scales: Vec<f32> =
120 params.scales.iter().map(|s| s * (1.0 + perturbation)).collect();
121
122 let perturbed_params = QuantParams {
123 scales: perturbed_scales,
124 zero_points: params.zero_points.clone(),
125 granularity: params.granularity,
126 mode: params.mode,
127 bits: params.bits,
128 };
129
130 let perturbed_quantized = quantize_with_params(values, &perturbed_params);
131 let perturbed_dequantized = dequantize_with_params(&perturbed_quantized, &perturbed_params);
132 let perturbed_mse = quantization_mse(values, &perturbed_dequantized);
133
134 let sensitivity = if perturbation.abs() > 1e-10 {
135 (perturbed_mse - original_mse).abs() / (perturbation.abs() * original_mse.max(1e-10))
136 } else {
137 0.0
138 };
139
140 (original_mse, perturbed_mse, sensitivity)
141}
142
143pub fn compare_bit_widths(values: &[f32]) -> (f32, f32, f32) {
147 let params_4bit = calibrate_per_tensor(values, 4, QuantMode::Symmetric);
148 let params_8bit = calibrate_per_tensor(values, 8, QuantMode::Symmetric);
149
150 let q4 = quantize_with_params(values, ¶ms_4bit);
151 let q8 = quantize_with_params(values, ¶ms_8bit);
152
153 let d4 = dequantize_with_params(&q4, ¶ms_4bit);
154 let d8 = dequantize_with_params(&q8, ¶ms_8bit);
155
156 let mse_4bit = quantization_mse(values, &d4);
157 let mse_8bit = quantization_mse(values, &d8);
158
159 let improvement = if mse_8bit > 1e-10 {
160 mse_4bit / mse_8bit
161 } else if mse_4bit > 1e-10 {
162 f32::INFINITY
163 } else {
164 1.0
165 };
166
167 (mse_4bit, mse_8bit, improvement)
168}
169
170pub fn analyze_outlier_impact(values: &[f32], percentile: f32) -> (f32, f32, f32) {
174 if values.is_empty() || percentile <= 0.0 || percentile >= 100.0 {
175 return (0.0, 0.0, 0.0);
176 }
177
178 let mut sorted: Vec<f32> = values.iter().map(|v| v.abs()).collect();
180 sorted.sort_by(f32::total_cmp);
181
182 let upper_idx = (percentile / 100.0 * sorted.len() as f32) as usize;
183 let threshold = *sorted.get(upper_idx.min(sorted.len() - 1)).unwrap_or(&0.0);
184
185 let lower_threshold = -threshold;
186 let upper_threshold = threshold;
187
188 let clipped: Vec<f32> =
190 values.iter().map(|&v| v.clamp(lower_threshold, upper_threshold)).collect();
191
192 let params_original = calibrate_per_tensor(values, 8, QuantMode::Symmetric);
194 let params_clipped = calibrate_per_tensor(&clipped, 8, QuantMode::Symmetric);
195
196 let q_orig = quantize_with_params(values, ¶ms_original);
197 let q_clip = quantize_with_params(&clipped, ¶ms_clipped);
198
199 let d_orig = dequantize_with_params(&q_orig, ¶ms_original);
200 let d_clip = dequantize_with_params(&q_clip, ¶ms_clipped);
201
202 let mse_original = quantization_mse(values, &d_orig);
203 let mse_clipped = quantization_mse(&clipped, &d_clip);
204
205 let outlier_impact = if mse_clipped > 1e-10 {
206 mse_original / mse_clipped
207 } else if mse_original > 1e-10 {
208 f32::INFINITY
209 } else {
210 1.0
211 };
212
213 (mse_original, mse_clipped, outlier_impact)
214}
215
/// Unit and property tests for the quantization error analysis helpers.
#[cfg(test)]
mod tests {
    use super::super::granularity::{calibrate_per_channel, QuantGranularity};
    use super::*;
    use approx::assert_abs_diff_eq;
    use proptest::prelude::*;

    /// All statistics are sane for a well-behaved sine signal.
    #[test]
    fn test_error_stats_basic() {
        let values: Vec<f32> = (0..100).map(|i| (i as f32 * 0.1).sin()).collect();
        let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
        let stats = analyze_error(&values, &params, 0.01);

        assert!(stats.mse >= 0.0);
        assert!(stats.mae >= 0.0);
        assert!(stats.max_error >= 0.0);
        assert!(stats.sqnr_db > 0.0);
        assert_eq!(stats.num_samples, 100);
    }

    /// `rmse` is the square root of `mse`.
    #[test]
    fn test_rmse_calculation() {
        let stats = QuantErrorStats { mse: 4.0, ..Default::default() };
        assert_abs_diff_eq!(stats.rmse(), 2.0, epsilon = 1e-6);
    }

    /// The theoretical bound is half of the largest scale.
    #[test]
    fn test_theoretical_max_error() {
        let params = QuantParams {
            scales: vec![0.1, 0.2],
            zero_points: vec![],
            granularity: QuantGranularity::PerChannel,
            mode: QuantMode::Symmetric,
            bits: 8,
        };

        let max_err = theoretical_max_error(&params);
        assert_abs_diff_eq!(max_err, 0.1, epsilon = 1e-6);
    }

    /// Spot-check the `6.02 * bits + 1.76` formula.
    #[test]
    fn test_theoretical_sqnr() {
        let sqnr_8bit = theoretical_sqnr(8);
        assert_abs_diff_eq!(sqnr_8bit, 49.92, epsilon = 0.01);

        let sqnr_4bit = theoretical_sqnr(4);
        assert_abs_diff_eq!(sqnr_4bit, 25.84, epsilon = 0.01);
    }

    /// A linear ramp quantized at 8 bits stays within 10 % of the bound.
    #[test]
    fn test_error_within_bounds() {
        let values: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
        let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
        let stats = analyze_error(&values, &params, 0.1);

        assert!(error_within_bounds(&stats, &params, 0.1));
    }

    #[test]
    fn test_scale_sensitivity() {
        let values: Vec<f32> = (0..100).map(|i| (i as f32 * 0.1).sin()).collect();
        let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);

        let (orig_mse, pert_mse, sensitivity) = scale_sensitivity(&values, &params, 0.1);

        assert!(orig_mse >= 0.0);
        assert!(pert_mse >= 0.0);
        assert!(sensitivity >= 0.0);
    }

    #[test]
    fn test_compare_bit_widths() {
        let values: Vec<f32> = (0..100).map(|i| (i as f32 * 0.1).sin()).collect();

        let (mse_4bit, mse_8bit, improvement) = compare_bit_widths(&values);

        assert!(mse_8bit <= mse_4bit);
        assert!(improvement >= 1.0);
    }

    /// A small-amplitude signal with two large outliers appended.
    #[test]
    fn test_outlier_impact() {
        let mut values: Vec<f32> = (0..100).map(|i| (i as f32 * 0.01).sin()).collect();
        values.push(100.0);
        values.push(-100.0);

        let (mse_orig, mse_clip, impact) = analyze_outlier_impact(&values, 99.0);

        assert!(mse_orig >= 0.0);
        assert!(mse_clip >= 0.0);
        assert!(impact >= 0.0);
    }

    #[test]
    fn test_empty_values() {
        let values: Vec<f32> = vec![];
        let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
        let stats = analyze_error(&values, &params, 0.1);

        assert_eq!(stats.num_samples, 0);
    }

    /// An all-zero tensor round-trips without error.
    #[test]
    fn test_zeros_error() {
        let values = vec![0.0; 100];
        let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
        let stats = analyze_error(&values, &params, 0.001);

        assert!(stats.mse < 1e-10);
        assert!(stats.mae < 1e-10);
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        #[test]
        fn prop_mse_non_negative(values in proptest::collection::vec(-100.0f32..100.0, 10..100)) {
            let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let stats = analyze_error(&values, &params, 0.1);

            prop_assert!(stats.mse >= 0.0, "MSE must be non-negative");
            prop_assert!(stats.mae >= 0.0, "MAE must be non-negative");
            prop_assert!(stats.max_error >= 0.0, "Max error must be non-negative");
        }

        #[test]
        fn prop_8bit_better_than_4bit(values in proptest::collection::vec(-100.0f32..100.0, 10..100)) {
            let (mse_4bit, mse_8bit, _) = compare_bit_widths(&values);

            // 1 % slack absorbs floating-point noise.
            prop_assert!(
                mse_8bit <= mse_4bit * 1.01,
                "8-bit MSE ({}) should be <= 4-bit MSE ({})",
                mse_8bit,
                mse_4bit
            );
        }

        #[test]
        fn prop_error_bounded(values in proptest::collection::vec(-100.0f32..100.0, 10..100)) {
            let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let stats = analyze_error(&values, &params, 0.1);

            let theoretical_max = theoretical_max_error(&params);
            prop_assert!(
                stats.max_error <= theoretical_max * 1.5,
                "Max error ({}) should be <= theoretical max * 1.5 ({})",
                stats.max_error,
                theoretical_max * 1.5
            );
        }

        #[test]
        fn prop_sqnr_positive_for_nonzero_signal(
            values in proptest::collection::vec(
                prop_oneof![
                    -100.0f32..-1.0,
                    1.0f32..100.0,
                ],
                10..100
            )
        ) {
            let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let stats = analyze_error(&values, &params, 0.1);

            prop_assert!(stats.sqnr_db > 0.0, "SQNR must be positive for non-zero signal");
        }

        #[test]
        fn prop_outlier_rate_bounded(
            values in proptest::collection::vec(-100.0f32..100.0, 10..100),
            threshold in 0.001f32..10.0
        ) {
            let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let stats = analyze_error(&values, &params, threshold);

            prop_assert!(
                stats.outlier_rate >= 0.0 && stats.outlier_rate <= 1.0,
                "Outlier rate must be in [0, 1], got {}",
                stats.outlier_rate
            );
        }

        #[test]
        fn prop_per_channel_lower_error(
            num_channels in 2usize..5,
            features_per_channel in 5usize..20,
            scale_multiplier in 2.0f32..20.0
        ) {
            // Each channel is a ramp centred near zero whose amplitude grows with
            // the channel index, so the channels have clearly different ranges.
            let values: Vec<f32> = (0..num_channels)
                .flat_map(|ch| {
                    let scale = (ch as f32 + 1.0) * scale_multiplier;
                    (0..features_per_channel).map(move |i| {
                        (i as f32 / features_per_channel as f32 - 0.5) * scale
                    })
                })
                .collect();

            let params_pt = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let params_pc = calibrate_per_channel(&values, num_channels, 8, QuantMode::Symmetric);

            let stats_pt = analyze_error(&values, &params_pt, 0.1);
            let stats_pc = analyze_error(&values, &params_pc, 0.1);

            prop_assert!(
                stats_pc.mse <= stats_pt.mse * 1.01,
                "Per-channel MSE ({}) should be <= per-tensor MSE ({})",
                stats_pc.mse,
                stats_pt.mse
            );
        }

        #[test]
        fn prop_scale_sensitivity_finite(
            values in proptest::collection::vec(-100.0f32..100.0, 10..100),
            perturbation in 0.01f32..0.5
        ) {
            let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let (orig, pert, sens) = scale_sensitivity(&values, &params, perturbation);

            prop_assert!(orig.is_finite(), "Original MSE must be finite");
            prop_assert!(pert.is_finite(), "Perturbed MSE must be finite");
            prop_assert!(sens.is_finite(), "Sensitivity must be finite");
        }
    }
}