use crate::analysis::sensitivity::LayerSensitivity;
use crate::error::{QuantError, QuantResult};
/// Per-layer bit-width assignment together with the average bit-width that
/// was requested when the assignment was built.
///
/// `layer_bits` and `layer_names` are parallel vectors: entry `i` of one
/// describes the same layer as entry `i` of the other.
#[derive(Debug, Clone)]
pub struct MixedPrecisionPolicy {
/// Bit-width assigned to each layer, in the same order as `layer_names`.
pub layer_bits: Vec<u32>,
/// Layer names, in the same order as `layer_bits`.
pub layer_names: Vec<String>,
/// The average bit-width target supplied to `from_sensitivity`.
pub target_avg_bits: f32,
}
impl MixedPrecisionPolicy {
    /// Builds a policy by greedily raising per-layer bit-widths until the mean
    /// bit-width across all layers reaches `target_avg_bits`.
    ///
    /// Every layer starts at the smallest entry of its `bits_range` (or 4 when
    /// the range is empty). While the mean is below the target, the single
    /// upgrade "layer -> its next larger bit-width" with the largest MSE
    /// reduction per added bit is applied. Ties go to the earliest layer. A
    /// missing MSE measurement (`mse_at` returning `None`) is treated as `0.0`.
    ///
    /// # Errors
    ///
    /// * [`QuantError::EmptyInput`] when `sensitivities` is empty.
    /// * [`QuantError::InfeasibleCompressionTarget`] when `target_avg_bits` is
    ///   NaN, exceeds the largest bit-width offered by any layer, or cannot be
    ///   reached even after every available upgrade has been applied.
    pub fn from_sensitivity(
        sensitivities: &[LayerSensitivity],
        target_avg_bits: f32,
    ) -> QuantResult<Self> {
        if sensitivities.is_empty() {
            return Err(QuantError::EmptyInput(
                "MixedPrecisionPolicy::from_sensitivity",
            ));
        }

        // Cheap early rejection: no mean can exceed the largest bit-width any
        // layer offers. A NaN target fails every comparison below and would
        // otherwise silently push every layer to its maximum.
        let max_bits = sensitivities
            .iter()
            .filter_map(|s| s.bits_range.iter().copied().max())
            .max()
            .unwrap_or(0) as f32;
        if target_avg_bits.is_nan() || target_avg_bits > max_bits {
            return Err(QuantError::InfeasibleCompressionTarget {
                target: target_avg_bits,
            });
        }

        // Start every layer at its cheapest option.
        let mut bits: Vec<u32> = sensitivities
            .iter()
            .map(|s| s.bits_range.iter().copied().min().unwrap_or(4))
            .collect();

        while Self::mean_bits(&bits) < target_avg_bits {
            let mut best: Option<(usize, u32)> = None;
            let mut best_gain = f32::NEG_INFINITY;

            for (idx, (sens, &cur_bits)) in sensitivities.iter().zip(&bits).enumerate() {
                // Smallest bit-width strictly above the current one, if any.
                let Some(next_bits) = sens
                    .bits_range
                    .iter()
                    .copied()
                    .filter(|&b| b > cur_bits)
                    .min()
                else {
                    continue;
                };

                let mse_cur = sens.mse_at(cur_bits).unwrap_or(0.0);
                let mse_next = sens.mse_at(next_bits).unwrap_or(0.0);
                // `next_bits > cur_bits`, so the divisor is at least 1.
                let gain = (mse_cur - mse_next) / (next_bits - cur_bits) as f32;

                // Strict `>` keeps the earliest layer on ties.
                if gain > best_gain {
                    best_gain = gain;
                    best = Some((idx, next_bits));
                }
            }

            // No layer can be upgraded any further.
            let Some((idx, next_bits)) = best else { break };
            bits[idx] = next_bits;
        }

        // The loop only stops short of the target when no upgrade is left, so
        // a mean below the target here means the target is out of reach.
        if Self::mean_bits(&bits) < target_avg_bits {
            return Err(QuantError::InfeasibleCompressionTarget {
                target: target_avg_bits,
            });
        }

        let layer_names = sensitivities.iter().map(|s| s.name.clone()).collect();
        Ok(Self {
            layer_bits: bits,
            layer_names,
            target_avg_bits,
        })
    }

    /// Mean bit-width actually assigned across all layers, or `0.0` when the
    /// policy holds no layers.
    #[must_use]
    pub fn effective_average_bits(&self) -> f32 {
        Self::mean_bits(&self.layer_bits)
    }

    /// Bit-width assigned to the first layer called `name`.
    ///
    /// Returns `None` when no such layer exists, or when `layer_names` and
    /// `layer_bits` have diverged in length and the name has no matching entry.
    #[must_use]
    pub fn bits_for_layer(&self, name: &str) -> Option<u32> {
        self.layer_names
            .iter()
            .zip(&self.layer_bits)
            .find(|(layer, _)| layer.as_str() == name)
            .map(|(_, &bits)| bits)
    }

    /// Number of layers that have a bit-width assigned.
    #[must_use]
    pub fn n_layers(&self) -> usize {
        self.layer_bits.len()
    }

    /// Arithmetic mean of `bits`, or `0.0` for an empty slice.
    ///
    /// The sum is accumulated in `u64` so it cannot overflow `u32`.
    fn mean_bits(bits: &[u32]) -> f32 {
        if bits.is_empty() {
            return 0.0;
        }
        let total: u64 = bits.iter().map(|&b| u64::from(b)).sum();
        total as f32 / bits.len() as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::sensitivity::LayerSensitivity;
    use approx::assert_abs_diff_eq;

    /// Builds a `LayerSensitivity` fixture from parallel `bits` / `mse` slices.
    fn make_sensitivity(name: &str, bits: &[u32], mse: &[f32]) -> LayerSensitivity {
        LayerSensitivity {
            bits_range: bits.to_vec(),
            mse_per_bits: mse.to_vec(),
            name: name.to_string(),
        }
    }

    #[test]
    fn greedy_assigns_more_bits_to_sensitive_layer() {
        let s0 = make_sensitivity("l0", &[2, 4, 8], &[0.5, 0.05, 0.001]);
        let s1 = make_sensitivity("l1", &[2, 4, 8], &[0.01, 0.005, 0.001]);
        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], 5.0).unwrap();
        assert!(
            policy.bits_for_layer("l0").unwrap() >= policy.bits_for_layer("l1").unwrap(),
            "l0 (sensitive) should get >= bits than l1"
        );
    }

    #[test]
    fn target_average_bits_met() {
        let s0 = make_sensitivity("l0", &[2, 4, 8], &[0.5, 0.05, 0.001]);
        let s1 = make_sensitivity("l1", &[2, 4, 8], &[0.5, 0.05, 0.001]);
        let target = 4.0_f32;
        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], target).unwrap();
        let avg = policy.effective_average_bits();
        assert!(
            avg >= target,
            "average bits {avg} should be >= target {target}"
        );
    }

    #[test]
    fn single_layer_policy() {
        let s = make_sensitivity("only", &[2, 4, 8], &[0.3, 0.02, 0.001]);
        let policy = MixedPrecisionPolicy::from_sensitivity(&[s], 4.0).unwrap();
        assert_eq!(policy.n_layers(), 1);
        // With one layer the only way to reach a mean of 4.0 is the single
        // upgrade 2 -> 4, so the result is exact; a 1-bit tolerance would
        // hide a wrong assignment.
        assert_eq!(policy.bits_for_layer("only"), Some(4));
        assert_abs_diff_eq!(
            policy.effective_average_bits(),
            4.0,
            epsilon = f32::EPSILON
        );
    }

    #[test]
    fn infeasible_target_error() {
        let s = make_sensitivity("l", &[2, 4], &[0.5, 0.01]);
        assert!(matches!(
            MixedPrecisionPolicy::from_sensitivity(&[s], 16.0),
            Err(QuantError::InfeasibleCompressionTarget { .. })
        ));
    }

    #[test]
    fn empty_sensitivities_error() {
        assert!(matches!(
            MixedPrecisionPolicy::from_sensitivity(&[], 4.0),
            Err(QuantError::EmptyInput(_))
        ));
    }

    #[test]
    fn bits_for_layer_lookup() {
        let s0 = make_sensitivity("attn", &[2, 4, 8], &[0.5, 0.05, 0.001]);
        let s1 = make_sensitivity("ffn", &[2, 4, 8], &[0.1, 0.01, 0.001]);
        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], 4.0).unwrap();
        assert!(policy.bits_for_layer("attn").is_some());
        assert!(policy.bits_for_layer("ffn").is_some());
        assert!(policy.bits_for_layer("unknown").is_none());
    }

    #[test]
    fn all_layers_get_minimum_at_low_target() {
        let s0 = make_sensitivity("l0", &[2, 4, 8], &[0.5, 0.05, 0.001]);
        let s1 = make_sensitivity("l1", &[2, 4, 8], &[0.4, 0.04, 0.001]);
        let policy = MixedPrecisionPolicy::from_sensitivity(&[s0, s1], 2.0).unwrap();
        // The target equals the smallest available bit-width, so no layer may
        // be upgraded: `>= 2` would also pass for an over-allocated policy.
        for &b in &policy.layer_bits {
            assert_eq!(b, 2, "all layers should be at minimum bits");
        }
    }
}