1use crate::analysis::sensitivity::LayerSensitivity;
17use crate::error::{QuantError, QuantResult};
18
/// Per-layer bit-width assignment for mixed-precision quantization.
///
/// `layer_bits[i]` is the bit width assigned to the layer named
/// `layer_names[i]`; the two vectors are parallel, and
/// [`MixedPrecisionPolicy::from_sensitivity`] always builds them with the
/// same length.
#[derive(Debug, Clone, PartialEq)]
pub struct MixedPrecisionPolicy {
    /// Bit width chosen for each layer, in the same order as `layer_names`.
    pub layer_bits: Vec<u32>,
    /// Layer names, parallel to `layer_bits`.
    pub layer_names: Vec<String>,
    /// Average bit width the policy was asked to reach.
    pub target_avg_bits: f32,
}
31
32impl MixedPrecisionPolicy {
33 pub fn from_sensitivity(
45 sensitivities: &[LayerSensitivity],
46 target_avg_bits: f32,
47 ) -> QuantResult<Self> {
48 if sensitivities.is_empty() {
49 return Err(QuantError::EmptyInput(
50 "MixedPrecisionPolicy::from_sensitivity",
51 ));
52 }
53
54 let max_bits = sensitivities
56 .iter()
57 .map(|s| s.bits_range.iter().copied().max().unwrap_or(0))
58 .max()
59 .unwrap_or(0) as f32;
60 let min_bits = sensitivities
61 .iter()
62 .map(|s| s.bits_range.iter().copied().min().unwrap_or(32))
63 .min()
64 .unwrap_or(32) as f32;
65
66 if target_avg_bits > max_bits {
67 return Err(QuantError::InfeasibleCompressionTarget {
68 target: target_avg_bits,
69 });
70 }
71
72 let n = sensitivities.len();
73 let mut bits: Vec<u32> = sensitivities
75 .iter()
76 .map(|s| s.bits_range.iter().copied().min().unwrap_or(4))
77 .collect();
78
79 loop {
81 let avg = bits.iter().sum::<u32>() as f32 / n as f32;
82 if avg >= target_avg_bits {
83 break;
84 }
85
86 let mut best_layer = None;
88 let mut best_gain = f32::NEG_INFINITY;
89
90 for i in 0..n {
91 let sens = &sensitivities[i];
92 let cur_bits = bits[i];
93 let next = sens
95 .bits_range
96 .iter()
97 .copied()
98 .filter(|&b| b > cur_bits)
99 .min();
100 let Some(next_bits) = next else { continue };
101
102 let mse_cur = sens.mse_at(cur_bits).unwrap_or(0.0);
104 let mse_next = sens.mse_at(next_bits).unwrap_or(0.0);
105 let delta_mse = mse_cur - mse_next; let delta_bits = (next_bits - cur_bits) as f32;
107 let gain = delta_mse / delta_bits.max(1.0);
108
109 if gain > best_gain {
110 best_gain = gain;
111 best_layer = Some((i, next_bits));
112 }
113 }
114
115 match best_layer {
116 Some((i, b)) => bits[i] = b,
117 None => break, }
119 }
120
121 let actual_avg = bits.iter().sum::<u32>() as f32 / n as f32;
123 if actual_avg < target_avg_bits - min_bits && target_avg_bits > min_bits {
124 return Err(QuantError::InfeasibleCompressionTarget {
126 target: target_avg_bits,
127 });
128 }
129
130 let layer_names = sensitivities.iter().map(|s| s.name.clone()).collect();
131 Ok(Self {
132 layer_bits: bits,
133 layer_names,
134 target_avg_bits,
135 })
136 }
137
138 #[must_use]
140 pub fn effective_average_bits(&self) -> f32 {
141 if self.layer_bits.is_empty() {
142 return 0.0;
143 }
144 self.layer_bits.iter().sum::<u32>() as f32 / self.layer_bits.len() as f32
145 }
146
147 #[must_use]
151 pub fn bits_for_layer(&self, name: &str) -> Option<u32> {
152 self.layer_names
153 .iter()
154 .position(|n| n == name)
155 .map(|i| self.layer_bits[i])
156 }
157
158 #[must_use]
160 pub fn n_layers(&self) -> usize {
161 self.layer_bits.len()
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::sensitivity::LayerSensitivity;
    use approx::assert_abs_diff_eq;

    /// Builds a `LayerSensitivity` fixture from parallel `bits` / `mse` slices.
    fn make_sensitivity(name: &str, bits: &[u32], mse: &[f32]) -> LayerSensitivity {
        LayerSensitivity {
            bits_range: bits.to_vec(),
            mse_per_bits: mse.to_vec(),
            name: name.to_string(),
        }
    }

    #[test]
    fn greedy_assigns_more_bits_to_sensitive_layer() {
        let s0 = make_sensitivity("l0", &[2, 4, 8], &[0.5, 0.05, 0.001]);
        let s1 = make_sensitivity("l1", &[2, 4, 8], &[0.01, 0.005, 0.001]);
        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], 5.0).unwrap();
        let l0 = policy.bits_for_layer("l0").unwrap();
        let l1 = policy.bits_for_layer("l1").unwrap();
        // Strict: the sensitive layer must actually receive more bits.
        assert!(l0 > l1, "l0 (sensitive, {l0} bits) should get more bits than l1 ({l1} bits)");
        assert!(policy.effective_average_bits() >= 5.0);
    }

    #[test]
    fn target_average_bits_met() {
        let s0 = make_sensitivity("l0", &[2, 4, 8], &[0.5, 0.05, 0.001]);
        let s1 = make_sensitivity("l1", &[2, 4, 8], &[0.5, 0.05, 0.001]);
        let target = 4.0_f32;
        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], target).unwrap();
        let avg = policy.effective_average_bits();
        assert!(
            avg >= target,
            "average bits {avg} should be >= target {target}"
        );
        // Identical layers split the budget evenly.
        assert_eq!(policy.layer_bits, vec![4, 4]);
    }

    #[test]
    fn single_layer_policy() {
        let s = make_sensitivity("only", &[2, 4, 8], &[0.3, 0.02, 0.001]);
        let policy = MixedPrecisionPolicy::from_sensitivity(&[s], 4.0).unwrap();
        assert_eq!(policy.n_layers(), 1);
        // The single layer stops at the first candidate that meets the target.
        assert_eq!(policy.layer_bits, vec![4]);
        assert_abs_diff_eq!(policy.effective_average_bits(), 4.0, epsilon = 1e-6);
    }

    #[test]
    fn infeasible_target_error() {
        let s = make_sensitivity("l", &[2, 4], &[0.5, 0.01]);
        assert!(matches!(
            MixedPrecisionPolicy::from_sensitivity(&[s], 16.0),
            Err(QuantError::InfeasibleCompressionTarget { .. })
        ));
    }

    #[test]
    fn empty_sensitivities_error() {
        assert!(matches!(
            MixedPrecisionPolicy::from_sensitivity(&[], 4.0),
            Err(QuantError::EmptyInput(_))
        ));
    }

    #[test]
    fn bits_for_layer_lookup() {
        let s0 = make_sensitivity("attn", &[2, 4, 8], &[0.5, 0.05, 0.001]);
        let s1 = make_sensitivity("ffn", &[2, 4, 8], &[0.1, 0.01, 0.001]);
        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], 4.0).unwrap();
        assert_eq!(policy.bits_for_layer("attn"), Some(4));
        assert_eq!(policy.bits_for_layer("ffn"), Some(4));
        assert!(policy.bits_for_layer("unknown").is_none());
    }

    #[test]
    fn all_layers_get_minimum_at_low_target() {
        let s0 = make_sensitivity("l0", &[2, 4, 8], &[0.5, 0.05, 0.001]);
        let s1 = make_sensitivity("l1", &[2, 4, 8], &[0.4, 0.04, 0.001]);
        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], 2.0).unwrap();
        for &b in &policy.layer_bits {
            // Exact equality: `b >= 2` is vacuous since 2 is the smallest candidate.
            assert_eq!(b, 2, "all layers should be at minimum bits");
        }
    }
}