Skip to main content

entrenar/quant/
error_analysis.rs

//! Quantization Error Analysis and Property Tests
//!
//! Provides comprehensive error analysis for quantization:
//! - Error bounds validation
//! - Error distribution analysis
//! - Outlier impact measurement
//! - Scale sensitivity analysis
//! - Numerical stability tests

10use super::granularity::{
11    calibrate_per_tensor, dequantize_with_params, quantization_mse, quantize_with_params,
12    QuantMode, QuantParams,
13};
14use serde::{Deserialize, Serialize};
15
/// Error statistics for quantization analysis
///
/// Plain data record produced by `analyze_error`. The derived `Default`
/// (all fields zero) is what `analyze_error` returns for empty input.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct QuantErrorStats {
    /// Mean Squared Error
    pub mse: f32,
    /// Mean Absolute Error
    pub mae: f32,
    /// Maximum absolute error
    pub max_error: f32,
    /// Signal-to-Quantization-Noise Ratio (SQNR) in dB
    ///
    /// `analyze_error` sets this to `f32::INFINITY` when the MSE is at or
    /// below `1e-10`.
    pub sqnr_db: f32,
    /// Fraction of values with absolute error > threshold
    ///
    /// `analyze_error` stores `outlier_count / sample_count` here, so this is
    /// a ratio in `[0, 1]`, not a 0-100 percentage.
    pub outlier_rate: f32,
    /// Number of samples
    pub num_samples: usize,
}
32
33impl QuantErrorStats {
34    /// Root Mean Squared Error
35    pub fn rmse(&self) -> f32 {
36        contract_pre_rmse!();
37        self.mse.sqrt()
38    }
39}
40
41/// Analyze quantization error for given values and parameters
42///
43/// # Arguments
44/// * `original` - Original f32 values
45/// * `params` - Quantization parameters
46/// * `outlier_threshold` - Error threshold for outlier detection
47pub fn analyze_error(
48    original: &[f32],
49    params: &QuantParams,
50    outlier_threshold: f32,
51) -> QuantErrorStats {
52    if original.is_empty() {
53        return QuantErrorStats::default();
54    }
55
56    let quantized = quantize_with_params(original, params);
57    let dequantized = dequantize_with_params(&quantized, params);
58
59    let errors: Vec<f32> =
60        original.iter().zip(dequantized.iter()).map(|(o, d)| (o - d).abs()).collect();
61
62    let mse = quantization_mse(original, &dequantized);
63    let mae = errors.iter().sum::<f32>() / errors.len().max(1) as f32;
64    let max_error = errors.iter().copied().fold(0.0f32, f32::max);
65
66    let outlier_count = errors.iter().filter(|&&e| e > outlier_threshold).count();
67    let outlier_rate = outlier_count as f32 / errors.len().max(1) as f32;
68
69    // SQNR = 10 * log10(signal_power / noise_power)
70    let signal_power: f32 =
71        original.iter().map(|x| x * x).sum::<f32>() / original.len().max(1) as f32;
72    let noise_power = mse;
73    let sqnr_db = if noise_power > 1e-10 {
74        10.0 * (signal_power / noise_power).max(f32::MIN_POSITIVE).log10()
75    } else {
76        f32::INFINITY
77    };
78
79    QuantErrorStats { mse, mae, max_error, sqnr_db, outlier_rate, num_samples: original.len() }
80}
81
82/// Calculate theoretical maximum error for given quantization parameters
83///
84/// For symmetric quantization: max_error = scale / 2 (half quantization step)
85/// For asymmetric: max_error = scale / 2
86pub fn theoretical_max_error(params: &QuantParams) -> f32 {
87    let max_scale = params.scales.iter().copied().fold(0.0f32, f32::max);
88    max_scale / 2.0
89}
90
/// Expected SQNR (in dB) of an ideal uniform `bits`-bit quantizer.
///
/// Uses the rule of thumb `6.02 * b + 1.76` dB. The `1.76` dB offset comes
/// from the textbook derivation for a full-scale sinusoidal input with
/// uniformly distributed quantization noise, so treat the result as a
/// reference figure rather than a bound for arbitrary data.
pub fn theoretical_sqnr(bits: u8) -> f32 {
    const DB_PER_BIT: f32 = 6.02;
    const OFFSET_DB: f32 = 1.76;

    let bit_count = f32::from(bits);
    DB_PER_BIT * bit_count + OFFSET_DB
}
98
99/// Check if error is within expected bounds
100pub fn error_within_bounds(stats: &QuantErrorStats, params: &QuantParams, tolerance: f32) -> bool {
101    let theoretical_max = theoretical_max_error(params);
102    stats.max_error <= theoretical_max * (1.0 + tolerance)
103}
104
105/// Analyze sensitivity of error to scale perturbation
106///
107/// Returns (original_mse, perturbed_mse, sensitivity)
108pub fn scale_sensitivity(
109    values: &[f32],
110    params: &QuantParams,
111    perturbation: f32,
112) -> (f32, f32, f32) {
113    // Original error
114    let quantized = quantize_with_params(values, params);
115    let dequantized = dequantize_with_params(&quantized, params);
116    let original_mse = quantization_mse(values, &dequantized);
117
118    // Perturbed scales
119    let perturbed_scales: Vec<f32> =
120        params.scales.iter().map(|s| s * (1.0 + perturbation)).collect();
121
122    let perturbed_params = QuantParams {
123        scales: perturbed_scales,
124        zero_points: params.zero_points.clone(),
125        granularity: params.granularity,
126        mode: params.mode,
127        bits: params.bits,
128    };
129
130    let perturbed_quantized = quantize_with_params(values, &perturbed_params);
131    let perturbed_dequantized = dequantize_with_params(&perturbed_quantized, &perturbed_params);
132    let perturbed_mse = quantization_mse(values, &perturbed_dequantized);
133
134    let sensitivity = if perturbation.abs() > 1e-10 {
135        (perturbed_mse - original_mse).abs() / (perturbation.abs() * original_mse.max(1e-10))
136    } else {
137        0.0
138    };
139
140    (original_mse, perturbed_mse, sensitivity)
141}
142
143/// Compare error between different bit widths
144///
145/// Returns (mse_4bit, mse_8bit, improvement_ratio)
146pub fn compare_bit_widths(values: &[f32]) -> (f32, f32, f32) {
147    let params_4bit = calibrate_per_tensor(values, 4, QuantMode::Symmetric);
148    let params_8bit = calibrate_per_tensor(values, 8, QuantMode::Symmetric);
149
150    let q4 = quantize_with_params(values, &params_4bit);
151    let q8 = quantize_with_params(values, &params_8bit);
152
153    let d4 = dequantize_with_params(&q4, &params_4bit);
154    let d8 = dequantize_with_params(&q8, &params_8bit);
155
156    let mse_4bit = quantization_mse(values, &d4);
157    let mse_8bit = quantization_mse(values, &d8);
158
159    let improvement = if mse_8bit > 1e-10 {
160        mse_4bit / mse_8bit
161    } else if mse_4bit > 1e-10 {
162        f32::INFINITY
163    } else {
164        1.0
165    };
166
167    (mse_4bit, mse_8bit, improvement)
168}
169
170/// Analyze impact of outliers on quantization error
171///
172/// Returns (original_mse, clipped_mse, outlier_impact)
173pub fn analyze_outlier_impact(values: &[f32], percentile: f32) -> (f32, f32, f32) {
174    if values.is_empty() || percentile <= 0.0 || percentile >= 100.0 {
175        return (0.0, 0.0, 0.0);
176    }
177
178    // Sort values to find percentile thresholds
179    let mut sorted: Vec<f32> = values.iter().map(|v| v.abs()).collect();
180    sorted.sort_by(f32::total_cmp);
181
182    let upper_idx = (percentile / 100.0 * sorted.len() as f32) as usize;
183    let threshold = *sorted.get(upper_idx.min(sorted.len() - 1)).unwrap_or(&0.0);
184
185    let lower_threshold = -threshold;
186    let upper_threshold = threshold;
187
188    // Clipped values
189    let clipped: Vec<f32> =
190        values.iter().map(|&v| v.clamp(lower_threshold, upper_threshold)).collect();
191
192    // Quantize both
193    let params_original = calibrate_per_tensor(values, 8, QuantMode::Symmetric);
194    let params_clipped = calibrate_per_tensor(&clipped, 8, QuantMode::Symmetric);
195
196    let q_orig = quantize_with_params(values, &params_original);
197    let q_clip = quantize_with_params(&clipped, &params_clipped);
198
199    let d_orig = dequantize_with_params(&q_orig, &params_original);
200    let d_clip = dequantize_with_params(&q_clip, &params_clipped);
201
202    let mse_original = quantization_mse(values, &d_orig);
203    let mse_clipped = quantization_mse(&clipped, &d_clip);
204
205    let outlier_impact = if mse_clipped > 1e-10 {
206        mse_original / mse_clipped
207    } else if mse_original > 1e-10 {
208        f32::INFINITY
209    } else {
210        1.0
211    };
212
213    (mse_original, mse_clipped, outlier_impact)
214}
215
// Unit tests and proptest-based property tests for the error-analysis helpers
// defined in this file.
#[cfg(test)]
mod tests {
    use super::super::granularity::{calibrate_per_channel, QuantGranularity};
    use super::*;
    use approx::assert_abs_diff_eq;
    use proptest::prelude::*;

    // Sanity check on a 100-sample sine wave with 8-bit symmetric per-tensor
    // calibration: error metrics are non-negative, SQNR is positive, and the
    // sample count is reported correctly.
    #[test]
    fn test_error_stats_basic() {
        let values: Vec<f32> = (0..100).map(|i| (i as f32 * 0.1).sin()).collect();
        let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
        let stats = analyze_error(&values, &params, 0.01);

        assert!(stats.mse >= 0.0);
        assert!(stats.mae >= 0.0);
        assert!(stats.max_error >= 0.0);
        assert!(stats.sqnr_db > 0.0);
        assert_eq!(stats.num_samples, 100);
    }

    // rmse() is the square root of the stored MSE: sqrt(4.0) == 2.0.
    #[test]
    fn test_rmse_calculation() {
        let stats = QuantErrorStats { mse: 4.0, ..Default::default() };
        assert_abs_diff_eq!(stats.rmse(), 2.0, epsilon = 1e-6);
    }

    // With several scales, the bound is driven by the largest one (0.2 / 2 = 0.1).
    #[test]
    fn test_theoretical_max_error() {
        let params = QuantParams {
            scales: vec![0.1, 0.2],
            zero_points: vec![],
            granularity: QuantGranularity::PerChannel,
            mode: QuantMode::Symmetric,
            bits: 8,
        };

        let max_err = theoretical_max_error(&params);
        assert_abs_diff_eq!(max_err, 0.1, epsilon = 1e-6); // max scale / 2
    }

    // Pins the 6.02 * b + 1.76 formula for the two bit widths used in this file.
    #[test]
    fn test_theoretical_sqnr() {
        // 8-bit: 6.02 * 8 + 1.76 = 49.92 dB
        let sqnr_8bit = theoretical_sqnr(8);
        assert_abs_diff_eq!(sqnr_8bit, 49.92, epsilon = 0.01);

        // 4-bit: 6.02 * 4 + 1.76 = 25.84 dB
        let sqnr_4bit = theoretical_sqnr(4);
        assert_abs_diff_eq!(sqnr_4bit, 25.84, epsilon = 0.01);
    }

    // A linear ramp 0.0..9.9 quantized at 8 bits must stay within the
    // theoretical bound plus 10% slack.
    #[test]
    fn test_error_within_bounds() {
        let values: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
        let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
        let stats = analyze_error(&values, &params, 0.1);

        // Error should be within bounds with some tolerance
        assert!(error_within_bounds(&stats, &params, 0.1));
    }

    // Smoke test: a 10% scale perturbation yields non-negative MSEs and sensitivity.
    #[test]
    fn test_scale_sensitivity() {
        let values: Vec<f32> = (0..100).map(|i| (i as f32 * 0.1).sin()).collect();
        let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);

        let (orig_mse, pert_mse, sensitivity) = scale_sensitivity(&values, &params, 0.1);

        assert!(orig_mse >= 0.0);
        assert!(pert_mse >= 0.0);
        assert!(sensitivity >= 0.0);
    }

    // On a sine wave, 8-bit quantization must not be worse than 4-bit.
    #[test]
    fn test_compare_bit_widths() {
        let values: Vec<f32> = (0..100).map(|i| (i as f32 * 0.1).sin()).collect();

        let (mse_4bit, mse_8bit, improvement) = compare_bit_widths(&values);

        // 8-bit should be better than 4-bit
        assert!(mse_8bit <= mse_4bit);
        assert!(improvement >= 1.0);
    }

    // Smoke test with two injected +/-100 outliers; only non-negativity of the
    // three results is asserted, not that clipping actually reduces the error.
    #[test]
    fn test_outlier_impact() {
        // Values with outliers
        let mut values: Vec<f32> = (0..100).map(|i| (i as f32 * 0.01).sin()).collect();
        values.push(100.0); // Add outlier
        values.push(-100.0); // Add outlier

        let (mse_orig, mse_clip, impact) = analyze_outlier_impact(&values, 99.0);

        // Clipping should generally help when there are outliers
        assert!(mse_orig >= 0.0);
        assert!(mse_clip >= 0.0);
        assert!(impact >= 0.0);
    }

    // Empty input takes the early-return path and reports zero samples.
    #[test]
    fn test_empty_values() {
        let values: Vec<f32> = vec![];
        let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
        let stats = analyze_error(&values, &params, 0.1);

        assert_eq!(stats.num_samples, 0);
    }

    // An all-zero tensor must round-trip with (numerically) zero error.
    #[test]
    fn test_zeros_error() {
        let values = vec![0.0; 100];
        let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
        let stats = analyze_error(&values, &params, 0.001);

        // Zeros should quantize perfectly
        assert!(stats.mse < 1e-10);
        assert!(stats.mae < 1e-10);
    }

    // Property tests

    proptest! {
        // Each property below is checked on 200 generated cases.
        #![proptest_config(ProptestConfig::with_cases(200))]

        // MSE, MAE and max error are non-negative for random vectors of
        // 10..100 values drawn from [-100, 100).
        #[test]
        fn prop_mse_non_negative(values in proptest::collection::vec(-100.0f32..100.0, 10..100)) {
            let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let stats = analyze_error(&values, &params, 0.1);

            prop_assert!(stats.mse >= 0.0, "MSE must be non-negative");
            prop_assert!(stats.mae >= 0.0, "MAE must be non-negative");
            prop_assert!(stats.max_error >= 0.0, "Max error must be non-negative");
        }

        // 8-bit MSE never exceeds 4-bit MSE by more than 1%.
        #[test]
        fn prop_8bit_better_than_4bit(values in proptest::collection::vec(-100.0f32..100.0, 10..100)) {
            let (mse_4bit, mse_8bit, _) = compare_bit_widths(&values);

            prop_assert!(
                mse_8bit <= mse_4bit * 1.01, // Small tolerance for edge cases
                "8-bit MSE ({}) should be <= 4-bit MSE ({})",
                mse_8bit,
                mse_4bit
            );
        }

        // Observed max error stays within 1.5x the theoretical half-step bound.
        #[test]
        fn prop_error_bounded(values in proptest::collection::vec(-100.0f32..100.0, 10..100)) {
            let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let stats = analyze_error(&values, &params, 0.1);

            // Error should be bounded by theoretical max (with tolerance)
            let theoretical_max = theoretical_max_error(&params);
            prop_assert!(
                stats.max_error <= theoretical_max * 1.5,
                "Max error ({}) should be <= theoretical max * 1.5 ({})",
                stats.max_error,
                theoretical_max * 1.5
            );
        }

        // Inputs are drawn from [-100, -1) or [1, 100), so every sample has
        // magnitude of at least 1 and the signal power cannot be zero.
        #[test]
        fn prop_sqnr_positive_for_nonzero_signal(
            values in proptest::collection::vec(
                prop_oneof![
                    -100.0f32..-1.0,
                    1.0f32..100.0,
                ],
                10..100
            )
        ) {
            let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let stats = analyze_error(&values, &params, 0.1);

            // SQNR should be positive for non-zero signal
            prop_assert!(stats.sqnr_db > 0.0, "SQNR must be positive for non-zero signal");
        }

        // outlier_rate is a fraction: always within [0, 1] for any threshold
        // in [0.001, 10).
        #[test]
        fn prop_outlier_rate_bounded(
            values in proptest::collection::vec(-100.0f32..100.0, 10..100),
            threshold in 0.001f32..10.0
        ) {
            let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let stats = analyze_error(&values, &params, threshold);

            prop_assert!(
                stats.outlier_rate >= 0.0 && stats.outlier_rate <= 1.0,
                "Outlier rate must be in [0, 1], got {}",
                stats.outlier_rate
            );
        }

        // Per-channel calibration must not be worse (beyond 1%) than
        // per-tensor calibration when channels differ in magnitude.
        #[test]
        fn prop_per_channel_lower_error(
            num_channels in 2usize..5,
            features_per_channel in 5usize..20,
            scale_multiplier in 2.0f32..20.0
        ) {
            // Create values where channels have very different scales
            // Channel `ch` is a linear ramp over [-0.5, 0.5) multiplied by
            // (ch + 1) * scale_multiplier; channels are laid out back to back.
            let values: Vec<f32> = (0..num_channels)
                .flat_map(|ch| {
                    let scale = (ch as f32 + 1.0) * scale_multiplier;
                    (0..features_per_channel).map(move |i| {
                        (i as f32 / features_per_channel as f32 - 0.5) * scale
                    })
                })
                .collect();

            let params_pt = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let params_pc = calibrate_per_channel(&values, num_channels, 8, QuantMode::Symmetric);

            let stats_pt = analyze_error(&values, &params_pt, 0.1);
            let stats_pc = analyze_error(&values, &params_pc, 0.1);

            prop_assert!(
                stats_pc.mse <= stats_pt.mse * 1.01,
                "Per-channel MSE ({}) should be <= per-tensor MSE ({})",
                stats_pc.mse,
                stats_pt.mse
            );
        }

        // All three outputs of scale_sensitivity stay finite for perturbations
        // in [0.01, 0.5).
        #[test]
        fn prop_scale_sensitivity_finite(
            values in proptest::collection::vec(-100.0f32..100.0, 10..100),
            perturbation in 0.01f32..0.5
        ) {
            let params = calibrate_per_tensor(&values, 8, QuantMode::Symmetric);
            let (orig, pert, sens) = scale_sensitivity(&values, &params, perturbation);

            prop_assert!(orig.is_finite(), "Original MSE must be finite");
            prop_assert!(pert.is_finite(), "Perturbed MSE must be finite");
            prop_assert!(sens.is_finite(), "Sensitivity must be finite");
        }
    }
}