Skip to main content

oxicuda_quant/analysis/
policy.rs

1//! # Mixed-Precision Policy
2//!
3//! Assigns per-layer bit-widths to meet a target average bits-per-parameter
4//! budget, while respecting per-layer sensitivity.  More sensitive layers
5//! receive higher bit-widths.
6//!
7//! ## Greedy algorithm
8//!
9//! 1. Initialise every layer to the minimum bit-width.
10//! 2. While the current average bits < target:
11//!    a. Find the layer whose upgrade (to the next bit-width) yields the
12//!    largest marginal sensitivity reduction per extra bit spent.
13//!    b. Upgrade that layer.
14//! 3. Return the final assignment.
15
16use crate::analysis::sensitivity::LayerSensitivity;
17use crate::error::{QuantError, QuantResult};
18
19// ─── MixedPrecisionPolicy ────────────────────────────────────────────────────
20
21/// Per-layer bit-width assignment produced by the greedy sensitivity policy.
22#[derive(Debug, Clone)]
23pub struct MixedPrecisionPolicy {
24    /// Bit-width assigned to each layer (same order as the input sensitivity list).
25    pub layer_bits: Vec<u32>,
26    /// Layer names (mirrors the order of `layer_bits`).
27    pub layer_names: Vec<String>,
28    /// Target average bits-per-parameter (budget constraint).
29    pub target_avg_bits: f32,
30}
31
32impl MixedPrecisionPolicy {
33    /// Compute a mixed-precision policy from layer sensitivity profiles.
34    ///
35    /// # Parameters
36    ///
37    /// * `sensitivities`   — per-layer [`LayerSensitivity`] from `SensitivityAnalyzer`.
38    /// * `target_avg_bits` — target average bit-width (e.g., `4.0` for ~4-bit).
39    ///
40    /// # Errors
41    ///
42    /// * [`QuantError::EmptyInput`]               — `sensitivities` is empty.
43    /// * [`QuantError::InfeasibleCompressionTarget`] — target cannot be met even at max bits.
44    pub fn from_sensitivity(
45        sensitivities: &[LayerSensitivity],
46        target_avg_bits: f32,
47    ) -> QuantResult<Self> {
48        if sensitivities.is_empty() {
49            return Err(QuantError::EmptyInput(
50                "MixedPrecisionPolicy::from_sensitivity",
51            ));
52        }
53
54        // Verify target feasibility.
55        let max_bits = sensitivities
56            .iter()
57            .map(|s| s.bits_range.iter().copied().max().unwrap_or(0))
58            .max()
59            .unwrap_or(0) as f32;
60        let min_bits = sensitivities
61            .iter()
62            .map(|s| s.bits_range.iter().copied().min().unwrap_or(32))
63            .min()
64            .unwrap_or(32) as f32;
65
66        if target_avg_bits > max_bits {
67            return Err(QuantError::InfeasibleCompressionTarget {
68                target: target_avg_bits,
69            });
70        }
71
72        let n = sensitivities.len();
73        // Start with minimum bit-width for each layer.
74        let mut bits: Vec<u32> = sensitivities
75            .iter()
76            .map(|s| s.bits_range.iter().copied().min().unwrap_or(4))
77            .collect();
78
79        // Greedy upgrade loop.
80        loop {
81            let avg = bits.iter().sum::<u32>() as f32 / n as f32;
82            if avg >= target_avg_bits {
83                break;
84            }
85
86            // Find the layer that benefits most from upgrading one step.
87            let mut best_layer = None;
88            let mut best_gain = f32::NEG_INFINITY;
89
90            for i in 0..n {
91                let sens = &sensitivities[i];
92                let cur_bits = bits[i];
93                // Find next higher bit-width in this layer's range.
94                let next = sens
95                    .bits_range
96                    .iter()
97                    .copied()
98                    .filter(|&b| b > cur_bits)
99                    .min();
100                let Some(next_bits) = next else { continue };
101
102                // Sensitivity gain = reduction in MSE per extra bit used.
103                let mse_cur = sens.mse_at(cur_bits).unwrap_or(0.0);
104                let mse_next = sens.mse_at(next_bits).unwrap_or(0.0);
105                let delta_mse = mse_cur - mse_next; // positive = improvement
106                let delta_bits = (next_bits - cur_bits) as f32;
107                let gain = delta_mse / delta_bits.max(1.0);
108
109                if gain > best_gain {
110                    best_gain = gain;
111                    best_layer = Some((i, next_bits));
112                }
113            }
114
115            match best_layer {
116                Some((i, b)) => bits[i] = b,
117                None => break, // All layers at maximum bits.
118            }
119        }
120
121        // Check that minimum is actually achievable (edge case: target < min_bits).
122        let actual_avg = bits.iter().sum::<u32>() as f32 / n as f32;
123        if actual_avg < target_avg_bits - min_bits && target_avg_bits > min_bits {
124            // Unable to reach target even at maximum.
125            return Err(QuantError::InfeasibleCompressionTarget {
126                target: target_avg_bits,
127            });
128        }
129
130        let layer_names = sensitivities.iter().map(|s| s.name.clone()).collect();
131        Ok(Self {
132            layer_bits: bits,
133            layer_names,
134            target_avg_bits,
135        })
136    }
137
138    /// Effective average bits per parameter across all layers.
139    #[must_use]
140    pub fn effective_average_bits(&self) -> f32 {
141        if self.layer_bits.is_empty() {
142            return 0.0;
143        }
144        self.layer_bits.iter().sum::<u32>() as f32 / self.layer_bits.len() as f32
145    }
146
147    /// Return the bit-width assigned to a layer by name.
148    ///
149    /// Returns `None` if the name is not found.
150    #[must_use]
151    pub fn bits_for_layer(&self, name: &str) -> Option<u32> {
152        self.layer_names
153            .iter()
154            .position(|n| n == name)
155            .map(|i| self.layer_bits[i])
156    }
157
158    /// Number of layers in the policy.
159    #[must_use]
160    pub fn n_layers(&self) -> usize {
161        self.layer_bits.len()
162    }
163}
164
165// ─── Tests ───────────────────────────────────────────────────────────────────
166
167#[cfg(test)]
168mod tests {
169    use super::*;
170    use crate::analysis::sensitivity::LayerSensitivity;
171    use approx::assert_abs_diff_eq;
172
173    fn make_sensitivity(name: &str, bits: &[u32], mse: &[f32]) -> LayerSensitivity {
174        LayerSensitivity {
175            bits_range: bits.to_vec(),
176            mse_per_bits: mse.to_vec(),
177            name: name.to_string(),
178        }
179    }
180
181    #[test]
182    fn greedy_assigns_more_bits_to_sensitive_layer() {
183        // Layer 0 is very sensitive (high MSE at low bits), layer 1 is not.
184        let s0 = make_sensitivity("l0", &[2, 4, 8], &[0.5, 0.05, 0.001]);
185        let s1 = make_sensitivity("l1", &[2, 4, 8], &[0.01, 0.005, 0.001]);
186        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], 5.0).unwrap();
187        // l0 should get more bits than l1.
188        assert!(
189            policy.bits_for_layer("l0").unwrap() >= policy.bits_for_layer("l1").unwrap(),
190            "l0 (sensitive) should get >= bits than l1"
191        );
192    }
193
194    #[test]
195    fn target_average_bits_met() {
196        let s0 = make_sensitivity("l0", &[2, 4, 8], &[0.5, 0.05, 0.001]);
197        let s1 = make_sensitivity("l1", &[2, 4, 8], &[0.5, 0.05, 0.001]);
198        let target = 4.0_f32;
199        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], target).unwrap();
200        let avg = policy.effective_average_bits();
201        assert!(
202            avg >= target,
203            "average bits {avg} should be >= target {target}"
204        );
205    }
206
207    #[test]
208    fn single_layer_policy() {
209        let s = make_sensitivity("only", &[2, 4, 8], &[0.3, 0.02, 0.001]);
210        let policy = MixedPrecisionPolicy::from_sensitivity(&[s], 4.0).unwrap();
211        assert_eq!(policy.n_layers(), 1);
212        assert_abs_diff_eq!(policy.effective_average_bits(), 4.0, epsilon = 1.0);
213    }
214
215    #[test]
216    fn infeasible_target_error() {
217        let s = make_sensitivity("l", &[2, 4], &[0.5, 0.01]);
218        // Target 16 bits but max is 4 → infeasible.
219        assert!(matches!(
220            MixedPrecisionPolicy::from_sensitivity(&[s], 16.0),
221            Err(QuantError::InfeasibleCompressionTarget { .. })
222        ));
223    }
224
225    #[test]
226    fn empty_sensitivities_error() {
227        assert!(matches!(
228            MixedPrecisionPolicy::from_sensitivity(&[], 4.0),
229            Err(QuantError::EmptyInput(_))
230        ));
231    }
232
233    #[test]
234    fn bits_for_layer_lookup() {
235        let s0 = make_sensitivity("attn", &[2, 4, 8], &[0.5, 0.05, 0.001]);
236        let s1 = make_sensitivity("ffn", &[2, 4, 8], &[0.1, 0.01, 0.001]);
237        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], 4.0).unwrap();
238        assert!(policy.bits_for_layer("attn").is_some());
239        assert!(policy.bits_for_layer("ffn").is_some());
240        assert!(policy.bits_for_layer("unknown").is_none());
241    }
242
243    #[test]
244    fn all_layers_get_minimum_at_low_target() {
245        // target = 2.0 = minimum → all layers should stay at 2 bits.
246        let s0 = make_sensitivity("l0", &[2, 4, 8], &[0.5, 0.05, 0.001]);
247        let s1 = make_sensitivity("l1", &[2, 4, 8], &[0.4, 0.04, 0.001]);
248        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], 2.0).unwrap();
249        for &b in &policy.layer_bits {
250            assert!(b >= 2, "all layers should be at minimum bits");
251        }
252    }
253}