oxicuda_quant/analysis/
sensitivity.rs1use crate::error::{QuantError, QuantResult};
18use crate::scheme::minmax::{MinMaxQuantizer, QuantGranularity, QuantScheme};
19
/// Quantization error profile of a single layer across several bit-widths.
///
/// `bits_range` and `mse_per_bits` are parallel vectors: entry `i` of
/// `mse_per_bits` is the error recorded for the bit-width stored at entry `i`
/// of `bits_range`. Both fields are public, so nothing enforces that they
/// have the same length; the accessors below tolerate a mismatch.
#[derive(Debug, Clone)]
pub struct LayerSensitivity {
    /// Bit-widths that were evaluated, in the order they were supplied.
    pub bits_range: Vec<u32>,
    /// Mean squared error recorded for the matching entry of `bits_range`.
    pub mse_per_bits: Vec<f32>,
    /// Name identifying the layer.
    pub name: String,
}

impl LayerSensitivity {
    /// Returns the MSE recorded for `bits`.
    ///
    /// The first occurrence of `bits` in `bits_range` is used. Returns `None`
    /// when `bits` is not present, or when `mse_per_bits` is shorter than
    /// `bits_range` and therefore holds no value at that position.
    #[must_use]
    pub fn mse_at(&self, bits: u32) -> Option<f32> {
        let idx = self.bits_range.iter().position(|&b| b == bits)?;
        // Checked access: the two vectors are public and may disagree in
        // length, in which case a plain index would panic.
        self.mse_per_bits.get(idx).copied()
    }

    /// Returns the arithmetic mean of all recorded MSE values, or `0.0` when
    /// none are recorded.
    #[must_use]
    pub fn mean_sensitivity(&self) -> f32 {
        if self.mse_per_bits.is_empty() {
            return 0.0;
        }
        self.mse_per_bits.iter().sum::<f32>() / self.mse_per_bits.len() as f32
    }

    /// Returns `true` when the recorded MSE values never increase from one
    /// entry to the next, in stored order.
    ///
    /// The check looks only at `mse_per_bits`; it does not sort by bit-width.
    /// It is trivially `true` for fewer than two entries, and `false` if any
    /// compared value is NaN.
    #[must_use]
    pub fn is_monotone(&self) -> bool {
        self.mse_per_bits.windows(2).all(|w| w[0] >= w[1])
    }
}
60
61#[derive(Debug, Clone, Default)]
65pub struct SensitivityAnalyzer;
66
67impl SensitivityAnalyzer {
68 #[must_use]
70 pub fn new() -> Self {
71 Self
72 }
73
74 pub fn analyze_layer(
87 &self,
88 weights: &[f32],
89 bits_range: &[u32],
90 name: impl Into<String>,
91 ) -> QuantResult<LayerSensitivity> {
92 if weights.is_empty() {
93 return Err(QuantError::EmptyInput("SensitivityAnalyzer::analyze_layer"));
94 }
95 for &b in bits_range {
96 if b == 0 || b > 16 {
97 return Err(QuantError::InvalidBitWidth { bits: b });
98 }
99 }
100
101 let mut mse_per_bits = Vec::with_capacity(bits_range.len());
102 for &bits in bits_range {
103 let q = MinMaxQuantizer::new(bits, QuantScheme::Symmetric, QuantGranularity::PerTensor);
104 let p = q.calibrate(weights)?;
105 let qw = q.quantize(weights, &p)?;
106 let dqw = q.dequantize(&qw, &p);
107 let mse = weights
108 .iter()
109 .zip(dqw.iter())
110 .map(|(a, b)| (a - b).powi(2))
111 .sum::<f32>()
112 / weights.len() as f32;
113 mse_per_bits.push(mse);
114 }
115
116 Ok(LayerSensitivity {
117 bits_range: bits_range.to_vec(),
118 mse_per_bits,
119 name: name.into(),
120 })
121 }
122
123 pub fn analyze_multiple<'a>(
134 &self,
135 layers: &[(&'a str, &'a [f32])],
136 bits_range: &[u32],
137 ) -> QuantResult<Vec<LayerSensitivity>> {
138 layers
139 .iter()
140 .map(|(name, weights)| self.analyze_layer(weights, bits_range, *name))
141 .collect()
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds `n` evenly spaced weights starting at -1.0 with step `2 / n`.
    fn make_weights(n: usize) -> Vec<f32> {
        let mut weights = Vec::with_capacity(n);
        for i in 0..n {
            weights.push((i as f32 / n as f32) * 2.0 - 1.0);
        }
        weights
    }

    /// The 2-bit error must not be smaller than the 8-bit error.
    #[test]
    fn higher_bits_lower_mse() {
        let analyzer = SensitivityAnalyzer::new();
        let weights = make_weights(64);
        let report = analyzer
            .analyze_layer(&weights, &[2, 4, 8], "test_layer")
            .unwrap();
        let coarse = report.mse_per_bits[0];
        let fine = report.mse_per_bits[2];
        assert!(
            coarse >= fine,
            "MSE at 2 bits ({}) should be >= MSE at 8 bits ({})",
            coarse,
            fine
        );
    }

    /// 8-bit quantization of a smooth ramp should be nearly lossless.
    #[test]
    fn int8_very_low_mse() {
        let analyzer = SensitivityAnalyzer::new();
        let weights = make_weights(128);
        let report = analyzer.analyze_layer(&weights, &[8], "layer0").unwrap();
        let mse = report.mse_at(8).unwrap();
        assert!(mse < 1e-4, "INT8 MSE should be tiny");
    }

    /// Looking up a bit-width that was never analyzed yields `None`.
    #[test]
    fn mse_at_missing_bits_returns_none() {
        let analyzer = SensitivityAnalyzer::new();
        let weights = make_weights(16);
        let report = analyzer.analyze_layer(&weights, &[4, 8], "l").unwrap();
        let missing = report.mse_at(2);
        let present = report.mse_at(4);
        assert!(missing.is_none());
        assert!(present.is_some());
    }

    /// Error should not increase as the bit-width grows.
    #[test]
    fn monotone_sensitivity() {
        let analyzer = SensitivityAnalyzer::new();
        let weights = make_weights(64);
        let report = analyzer.analyze_layer(&weights, &[2, 4, 8], "l").unwrap();
        let monotone = report.is_monotone();
        assert!(monotone, "MSE should decrease with increasing bits");
    }

    /// Batch analysis returns one report per layer, in input order.
    #[test]
    fn analyze_multiple_layers() {
        let analyzer = SensitivityAnalyzer::new();
        let first = make_weights(32);
        let second = make_weights(64);
        let reports = analyzer
            .analyze_multiple(&[("layer0", &first), ("layer1", &second)], &[4, 8])
            .unwrap();
        assert_eq!(reports.len(), 2);
        assert_eq!(reports[0].name, "layer0");
        assert_eq!(reports[1].name, "layer1");
    }

    /// An empty weight slice is rejected.
    #[test]
    fn empty_input_error() {
        let analyzer = SensitivityAnalyzer::new();
        let outcome = analyzer.analyze_layer(&[], &[4], "l");
        assert!(matches!(outcome, Err(QuantError::EmptyInput(_))));
    }

    /// A bit-width of zero is rejected and reported back in the error.
    #[test]
    fn invalid_bit_width_error() {
        let analyzer = SensitivityAnalyzer::new();
        let weights = make_weights(16);
        let outcome = analyzer.analyze_layer(&weights, &[0], "l");
        assert!(matches!(
            outcome,
            Err(QuantError::InvalidBitWidth { bits: 0 })
        ));
    }

    /// The mean is positive and equals the average of the two recorded values.
    #[test]
    fn mean_sensitivity_nonzero() {
        let analyzer = SensitivityAnalyzer::new();
        let weights = make_weights(32);
        let report = analyzer.analyze_layer(&weights, &[2, 4], "l").unwrap();
        assert!(report.mean_sensitivity() > 0.0);
        let expected = (report.mse_per_bits[0] + report.mse_per_bits[1]) / 2.0;
        assert_abs_diff_eq!(report.mean_sensitivity(), expected, epsilon = 1e-6);
    }
}