Skip to main content

oxicuda_quant/analysis/
sensitivity.rs

//! # Quantization Sensitivity Analysis
//!
//! Measures how sensitive each layer is to quantization at different bit-widths.
//! More sensitive layers should be assigned higher bit-widths in a mixed-precision
//! quantization scheme.
//!
//! ## Sensitivity metric
//!
//! For each layer and each candidate bit-width, we quantize the weights with a
//! per-tensor symmetric MinMax scheme and compute the mean squared error between
//! the original and dequantized weights:
//!
//! ```text
//! sensitivity(layer, bits) = MSE(W, dequant(quant(W, bits)))
//! ```
16
17use crate::error::{QuantError, QuantResult};
18use crate::scheme::minmax::{MinMaxQuantizer, QuantGranularity, QuantScheme};
19
20// ─── LayerSensitivity ────────────────────────────────────────────────────────
21
/// Sensitivity scores for one layer across multiple bit-widths.
///
/// `bits_range` and `mse_per_bits` are parallel vectors: entry `i` of
/// `mse_per_bits` is the error measured at `bits_range[i]`. Both fields are
/// public, so the accessors below tolerate vectors whose lengths disagree
/// instead of panicking.
#[derive(Debug, Clone)]
pub struct LayerSensitivity {
    /// Candidate bit-widths tested (conventionally sorted ascending).
    pub bits_range: Vec<u32>,
    /// Quantization MSE at each bit-width (same order as `bits_range`).
    pub mse_per_bits: Vec<f32>,
    /// Layer name or identifier (optional).
    pub name: String,
}

impl LayerSensitivity {
    /// Return the sensitivity (MSE) for a specific bit-width.
    ///
    /// Returns `None` if the bit-width was not tested, or if `mse_per_bits`
    /// holds no entry at the matching position.
    #[must_use]
    pub fn mse_at(&self, bits: u32) -> Option<f32> {
        self.bits_range
            .iter()
            .position(|&b| b == bits)
            // Checked lookup: the fields are public and may have been edited
            // to different lengths, so never index blindly.
            .and_then(|i| self.mse_per_bits.get(i).copied())
    }

    /// Mean sensitivity across all tested bit-widths.
    ///
    /// Returns `0.0` when no bit-width has been tested.
    #[must_use]
    pub fn mean_sensitivity(&self) -> f32 {
        if self.mse_per_bits.is_empty() {
            return 0.0;
        }
        self.mse_per_bits.iter().sum::<f32>() / self.mse_per_bits.len() as f32
    }

    /// Returns `true` if higher bit-widths give lower (or equal) MSE
    /// (monotone sensitivity).
    ///
    /// The check is made in order of increasing bit-width, regardless of the
    /// order in which `bits_range` stores the candidates. If the two vectors
    /// have different lengths the pairing is undefined, so the MSE values are
    /// compared in their stored order instead.
    #[must_use]
    pub fn is_monotone(&self) -> bool {
        if self.bits_range.len() != self.mse_per_bits.len() {
            return self.mse_per_bits.windows(2).all(|w| w[0] >= w[1]);
        }

        let mut by_bits: Vec<(u32, f32)> = self
            .bits_range
            .iter()
            .copied()
            .zip(self.mse_per_bits.iter().copied())
            .collect();
        // Stable sort: entries sharing a bit-width keep their stored order.
        by_bits.sort_by_key(|&(bits, _)| bits);
        by_bits.windows(2).all(|w| w[0].1 >= w[1].1)
    }
}
60
61// ─── SensitivityAnalyzer ─────────────────────────────────────────────────────
62
63/// Analyses per-layer quantization sensitivity.
64#[derive(Debug, Clone, Default)]
65pub struct SensitivityAnalyzer;
66
67impl SensitivityAnalyzer {
68    /// Create a new sensitivity analyser.
69    #[must_use]
70    pub fn new() -> Self {
71        Self
72    }
73
74    /// Compute quantization sensitivity for one layer across `bits_range`.
75    ///
76    /// # Parameters
77    ///
78    /// * `weights`    — flat weight tensor (any layout).
79    /// * `bits_range` — candidate bit-widths (e.g., `&[2, 3, 4, 8]`).
80    /// * `name`       — optional layer label.
81    ///
82    /// # Errors
83    ///
84    /// * [`QuantError::EmptyInput`] — `weights` is empty.
85    /// * [`QuantError::InvalidBitWidth`] — any bit-width in `bits_range` is 0 or > 16.
86    pub fn analyze_layer(
87        &self,
88        weights: &[f32],
89        bits_range: &[u32],
90        name: impl Into<String>,
91    ) -> QuantResult<LayerSensitivity> {
92        if weights.is_empty() {
93            return Err(QuantError::EmptyInput("SensitivityAnalyzer::analyze_layer"));
94        }
95        for &b in bits_range {
96            if b == 0 || b > 16 {
97                return Err(QuantError::InvalidBitWidth { bits: b });
98            }
99        }
100
101        let mut mse_per_bits = Vec::with_capacity(bits_range.len());
102        for &bits in bits_range {
103            let q = MinMaxQuantizer::new(bits, QuantScheme::Symmetric, QuantGranularity::PerTensor);
104            let p = q.calibrate(weights)?;
105            let qw = q.quantize(weights, &p)?;
106            let dqw = q.dequantize(&qw, &p);
107            let mse = weights
108                .iter()
109                .zip(dqw.iter())
110                .map(|(a, b)| (a - b).powi(2))
111                .sum::<f32>()
112                / weights.len() as f32;
113            mse_per_bits.push(mse);
114        }
115
116        Ok(LayerSensitivity {
117            bits_range: bits_range.to_vec(),
118            mse_per_bits,
119            name: name.into(),
120        })
121    }
122
123    /// Analyse multiple layers and return their sensitivity profiles.
124    ///
125    /// # Parameters
126    ///
127    /// * `layers`     — list of `(name, weights)` pairs.
128    /// * `bits_range` — candidate bit-widths to test for each layer.
129    ///
130    /// # Errors
131    ///
132    /// Propagates errors from [`analyze_layer`](Self::analyze_layer).
133    pub fn analyze_multiple<'a>(
134        &self,
135        layers: &[(&'a str, &'a [f32])],
136        bits_range: &[u32],
137    ) -> QuantResult<Vec<LayerSensitivity>> {
138        layers
139            .iter()
140            .map(|(name, weights)| self.analyze_layer(weights, bits_range, *name))
141            .collect()
142    }
143}
144
145// ─── Tests ───────────────────────────────────────────────────────────────────
146
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Build `n` evenly spaced weights starting at `-1.0` and stepping by `2/n`.
    fn make_weights(n: usize) -> Vec<f32> {
        let count = n as f32;
        (0..n).map(|i| (i as f32 / count) * 2.0 - 1.0).collect()
    }

    /// The coarsest bit-width must not beat the finest one.
    #[test]
    fn higher_bits_lower_mse() {
        let weights = make_weights(64);
        let profile = SensitivityAnalyzer::new()
            .analyze_layer(&weights, &[2, 4, 8], "test_layer")
            .unwrap();
        let coarse = profile.mse_per_bits[0];
        let fine = profile.mse_per_bits[2];
        assert!(
            coarse >= fine,
            "MSE at 2 bits ({}) should be >= MSE at 8 bits ({})",
            coarse,
            fine
        );
    }

    /// 8-bit quantization of a unit-range ramp is nearly lossless.
    #[test]
    fn int8_very_low_mse() {
        let weights = make_weights(128);
        let profile = SensitivityAnalyzer::new()
            .analyze_layer(&weights, &[8], "layer0")
            .unwrap();
        let mse = profile.mse_at(8).unwrap();
        assert!(mse < 1e-4, "INT8 MSE should be tiny");
    }

    /// Looking up an untested bit-width yields `None`.
    #[test]
    fn mse_at_missing_bits_returns_none() {
        let weights = make_weights(16);
        let profile = SensitivityAnalyzer::new()
            .analyze_layer(&weights, &[4, 8], "l")
            .unwrap();
        assert!(profile.mse_at(2).is_none());
        assert!(profile.mse_at(4).is_some());
    }

    /// MSE never increases as the bit-width grows.
    #[test]
    fn monotone_sensitivity() {
        let weights = make_weights(64);
        let profile = SensitivityAnalyzer::new()
            .analyze_layer(&weights, &[2, 4, 8], "l")
            .unwrap();
        assert!(
            profile.is_monotone(),
            "MSE should decrease with increasing bits"
        );
    }

    /// One profile per layer, in input order, carrying the layer name.
    #[test]
    fn analyze_multiple_layers() {
        let first = make_weights(32);
        let second = make_weights(64);
        let layers = [("layer0", first.as_slice()), ("layer1", second.as_slice())];
        let profiles = SensitivityAnalyzer::new()
            .analyze_multiple(&layers, &[4, 8])
            .unwrap();
        assert_eq!(profiles.len(), 2);
        assert_eq!(profiles[0].name, "layer0");
        assert_eq!(profiles[1].name, "layer1");
    }

    /// An empty weight slice is rejected.
    #[test]
    fn empty_input_error() {
        let outcome = SensitivityAnalyzer::new().analyze_layer(&[], &[4], "l");
        assert!(matches!(outcome, Err(QuantError::EmptyInput(_))));
    }

    /// A zero bit-width is rejected and reported back.
    #[test]
    fn invalid_bit_width_error() {
        let weights = make_weights(16);
        let outcome = SensitivityAnalyzer::new().analyze_layer(&weights, &[0], "l");
        assert!(matches!(
            outcome,
            Err(QuantError::InvalidBitWidth { bits: 0 })
        ));
    }

    /// The mean is positive and equals the arithmetic average of the entries.
    #[test]
    fn mean_sensitivity_nonzero() {
        let weights = make_weights(32);
        let profile = SensitivityAnalyzer::new()
            .analyze_layer(&weights, &[2, 4], "l")
            .unwrap();
        let mean = profile.mean_sensitivity();
        assert!(mean > 0.0);
        let expected = (profile.mse_per_bits[0] + profile.mse_per_bits[1]) / 2.0;
        assert_abs_diff_eq!(mean, expected, epsilon = 1e-6);
    }
}