/// Storage precision that can be assigned to a layer's weights.
///
/// Each variant's discriminant is its width in bits. Variants are declared from
/// narrowest to widest, so the derived ordering treats a narrower format as
/// "less than" a wider one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum BitWidth {
    /// 2 bits per weight.
    Int2 = 2,
    /// 4 bits per weight.
    Int4 = 4,
    /// 8 bits per weight.
    Int8 = 8,
    /// 16 bits per weight.
    Fp16 = 16,
    /// 32 bits per weight.
    Fp32 = 32,
}
impl BitWidth {
    /// Storage cost of a single weight in bytes (fractional for sub-byte formats).
    pub fn bytes_per_weight(&self) -> f64 {
        // Every discriminant is the format's width in bits, so the byte cost is
        // simply bits / 8. All five results are exactly representable in f64.
        let bits = *self as u8;
        f64::from(bits) / 8.0
    }

    /// How many times smaller this format is than 4-byte FP32 storage.
    pub fn compression_ratio_vs_fp32(&self) -> f64 {
        let fp32_bytes_per_weight = 4.0;
        fp32_bytes_per_weight / self.bytes_per_weight()
    }
}
/// Renders the bit width as its upper-case label, e.g. `INT8` or `FP16`.
impl std::fmt::Display for BitWidth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Int2 => "INT2",
            Self::Int4 => "INT4",
            Self::Int8 => "INT8",
            Self::Fp16 => "FP16",
            Self::Fp32 => "FP32",
        };
        // Written verbatim: width/alignment flags on the formatter are not applied.
        f.write_str(label)
    }
}
/// Per-layer measurements used to judge how much precision a layer needs.
///
/// The four numeric metrics are combined by `sensitivity_score`; how each metric
/// is measured is up to the caller and is not defined in this file.
#[derive(Debug, Clone)]
pub struct LayerSensitivity {
    /// Name of the layer; copied into the assignment produced for it.
    pub layer_name: String,
    /// Gradient-norm metric (weight 0.4 in the sensitivity score).
    pub gradient_norm: f64,
    /// Weight-variance metric (weight 0.3 in the sensitivity score).
    pub weight_variance: f64,
    /// Activation-range metric (weight 0.2 in the sensitivity score).
    pub activation_range: f64,
    /// Output-sensitivity metric (weight 0.1 in the sensitivity score).
    pub output_sensitivity: f64,
    /// Marks an embedding layer; adds 1.0 to the sensitivity score.
    pub is_embedding: bool,
    /// Marks the final layer; adds 0.5 to the sensitivity score.
    pub is_final_layer: bool,
}
impl LayerSensitivity {
    /// Combines the layer's metrics into a single sensitivity score.
    ///
    /// The score is a weighted sum of the four metrics (0.4 / 0.3 / 0.2 / 0.1)
    /// plus a flat bonus of 1.0 for embedding layers and 0.5 for the final layer.
    pub fn sensitivity_score(&self) -> f64 {
        const GRADIENT_WEIGHT: f64 = 0.4;
        const VARIANCE_WEIGHT: f64 = 0.3;
        const ACTIVATION_WEIGHT: f64 = 0.2;
        const OUTPUT_WEIGHT: f64 = 0.1;
        const EMBEDDING_BONUS: f64 = 1.0;
        const FINAL_LAYER_BONUS: f64 = 0.5;

        // A bonus applies in full when its flag is set and contributes 0.0 otherwise.
        let bonus_if = |flag: bool, amount: f64| if flag { amount } else { 0.0 };

        let weighted_metrics = self.gradient_norm * GRADIENT_WEIGHT
            + self.weight_variance * VARIANCE_WEIGHT
            + self.activation_range * ACTIVATION_WEIGHT
            + self.output_sensitivity * OUTPUT_WEIGHT;

        weighted_metrics
            + bonus_if(self.is_embedding, EMBEDDING_BONUS)
            + bonus_if(self.is_final_layer, FINAL_LAYER_BONUS)
    }
}
/// How the selector decides each layer's bit width.
#[derive(Debug, Clone)]
pub enum BitWidthStrategy {
    /// Request the same bit width for every layer (still clamped to the policy bounds).
    Uniform(BitWidth),
    /// Pick a width per layer by comparing its sensitivity score with thresholds.
    ///
    /// The checks run in this order: a score above `high_threshold` requests
    /// FP16, otherwise above `medium_threshold` requests INT8, otherwise below
    /// `low_threshold` requests INT2, and anything else requests INT4.
    SensitivityBased {
        /// Scores strictly above this request FP16.
        high_threshold: f64,
        /// Scores strictly above this (but not above `high_threshold`) request INT8.
        medium_threshold: f64,
        /// Scores strictly below this (and not above `medium_threshold`) request INT2.
        low_threshold: f64,
    },
    /// Start every layer at the policy's minimum width, then repeatedly upgrade
    /// the most sensitive layer that still fits within the byte budget.
    BudgetOptimal,
}
/// Configuration for a `PerLayerQuantSelector`.
#[derive(Debug, Clone)]
pub struct QuantizationPolicy {
    /// Rule used to pick each layer's bit width.
    pub strategy: BitWidthStrategy,
    /// Optional cap on the summed memory of all assignments, in bytes.
    /// When set and exceeded, assignment fails with `BudgetExceeded`.
    pub budget_bytes: Option<usize>,
    /// Narrowest width any layer may receive; narrower requests are raised to it.
    pub min_bit_width: BitWidth,
    /// Widest width any layer may receive; wider requests are lowered to it.
    pub max_bit_width: BitWidth,
}
/// Default policy: uniform INT8, no byte budget, widths limited to INT4..=FP32.
impl Default for QuantizationPolicy {
    fn default() -> Self {
        let strategy = BitWidthStrategy::Uniform(BitWidth::Int8);
        let (min_bit_width, max_bit_width) = (BitWidth::Int4, BitWidth::Fp32);
        Self {
            strategy,
            budget_bytes: None,
            min_bit_width,
            max_bit_width,
        }
    }
}
/// The precision chosen for one layer, together with its memory cost.
#[derive(Debug, Clone)]
pub struct LayerBitWidthAssignment {
    /// Name of the layer this assignment belongs to.
    pub layer_name: String,
    /// Selected precision; when produced by the selector it is already clamped
    /// to the policy's minimum and maximum.
    pub bit_width: BitWidth,
    /// Number of parameters in the layer.
    pub param_count: usize,
    /// Bytes needed to store the layer at `bit_width`; the selector computes it
    /// as `param_count` times bytes-per-weight, rounded up.
    pub memory_bytes: usize,
    /// The layer's sensitivity score at the time of assignment.
    pub sensitivity_score: f64,
}
/// Aggregate statistics over a set of layer assignments, as produced by
/// `PerLayerQuantSelector::summary_report`.
///
/// Derives `Debug` and `Clone` like the other public data types in this file.
#[derive(Debug, Clone)]
pub struct QuantizationSummary {
    /// Sum of `param_count` over all assignments.
    pub total_params: usize,
    /// Sum of `memory_bytes` over all assignments.
    pub total_memory_bytes: usize,
    /// FP32 size (4 bytes per parameter) divided by `total_memory_bytes`;
    /// reported as 1.0 when `total_memory_bytes` is zero.
    pub compression_ratio: f64,
    /// `(bit width, number of layers)` pairs, sorted by ascending bit width.
    pub bit_width_distribution: Vec<(BitWidth, usize)>,
    /// Mean sensitivity score of the assignments (0.0 when there are none).
    pub avg_sensitivity_score: f64,
}
/// Multi-line, human-readable rendering of the summary: one field per line
/// inside `QuantizationSummary { ... }`, with the distribution shown as a
/// comma-separated list of `WIDTH×COUNT` entries.
impl std::fmt::Display for QuantizationSummary {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            total_params,
            total_memory_bytes,
            compression_ratio,
            bit_width_distribution,
            avg_sensitivity_score,
        } = self;

        writeln!(f, "QuantizationSummary {{")?;
        writeln!(f, " total_params: {}", total_params)?;
        writeln!(f, " total_memory_bytes: {}", total_memory_bytes)?;
        writeln!(f, " compression_ratio: {:.2}x", compression_ratio)?;
        writeln!(f, " avg_sensitivity_score: {:.4}", avg_sensitivity_score)?;

        write!(f, " bit_width_distribution: [")?;
        // Empty before the first entry, ", " before every later one.
        let mut separator = "";
        for (bw, count) in bit_width_distribution {
            f.write_str(separator)?;
            write!(f, "{}×{}", bw, count)?;
            separator = ", ";
        }
        writeln!(f, "]")?;

        write!(f, "}}")
    }
}
/// Failures reported by `PerLayerQuantSelector::assign_bit_widths`.
#[derive(Debug, thiserror::Error)]
pub enum QuantSelectionError {
    /// The chosen assignments need more memory than the policy's byte budget.
    #[error("Budget exceeded: required {required} bytes, budget {budget} bytes")]
    BudgetExceeded {
        /// Total bytes the assignments would occupy.
        required: usize,
        /// The configured byte budget.
        budget: usize,
    },
    /// No layers were supplied.
    #[error("Empty layer list")]
    EmptyLayers,
    /// The layer slice and the parameter-count slice have different lengths.
    #[error("Layer count mismatch")]
    LengthMismatch,
}
/// Assigns a bit width to every layer of a model according to a
/// `QuantizationPolicy`.
///
/// Derives `Debug` and `Clone` so the selector can be logged and duplicated
/// like the policy it wraps.
#[derive(Debug, Clone)]
pub struct PerLayerQuantSelector {
    /// Policy applied by every call to `assign_bit_widths`.
    policy: QuantizationPolicy,
}
impl PerLayerQuantSelector {
    /// Precisions in ascending order; the budget-optimal search climbs this
    /// ladder one rung at a time.
    const PRECISION_LADDER: [BitWidth; 5] = [
        BitWidth::Int2,
        BitWidth::Int4,
        BitWidth::Int8,
        BitWidth::Fp16,
        BitWidth::Fp32,
    ];

    /// Creates a selector that applies `policy` on every call.
    pub fn new(policy: QuantizationPolicy) -> Self {
        Self { policy }
    }

    /// Chooses a bit width for every layer according to the policy's strategy.
    ///
    /// `layers[i]` and `layer_param_counts[i]` describe the same layer. The
    /// returned assignments are in the same order as `layers`.
    ///
    /// # Errors
    ///
    /// * `EmptyLayers` if `layers` is empty (checked first).
    /// * `LengthMismatch` if the two slices differ in length.
    /// * `BudgetExceeded` if the policy has a byte budget and the resulting
    ///   assignments need more memory than it allows.
    pub fn assign_bit_widths(
        &self,
        layers: &[LayerSensitivity],
        layer_param_counts: &[usize],
    ) -> Result<Vec<LayerBitWidthAssignment>, QuantSelectionError> {
        if layers.is_empty() {
            return Err(QuantSelectionError::EmptyLayers);
        }
        if layers.len() != layer_param_counts.len() {
            return Err(QuantSelectionError::LengthMismatch);
        }

        let assignments = match &self.policy.strategy {
            BitWidthStrategy::Uniform(bw) => self.assign_uniform(*bw, layers, layer_param_counts),
            BitWidthStrategy::SensitivityBased {
                high_threshold,
                medium_threshold,
                low_threshold,
            } => self.assign_sensitivity_based(
                *high_threshold,
                *medium_threshold,
                *low_threshold,
                layers,
                layer_param_counts,
            ),
            BitWidthStrategy::BudgetOptimal => {
                self.assign_budget_optimal(layers, layer_param_counts)?
            }
        };

        // The budget is enforced for every strategy, not only the budget-aware one.
        if let Some(budget) = self.policy.budget_bytes {
            let required = Self::total_memory_bytes(&assignments);
            if required > budget {
                return Err(QuantSelectionError::BudgetExceeded { required, budget });
            }
        }
        Ok(assignments)
    }

    /// Bytes needed to store `param_count` weights at `bit_width`, rounded up
    /// to a whole byte.
    fn layer_memory_bytes(param_count: usize, bit_width: BitWidth) -> usize {
        (param_count as f64 * bit_width.bytes_per_weight()).ceil() as usize
    }

    /// Restricts `bw` to the policy's `[min_bit_width, max_bit_width]` range.
    ///
    /// Written by hand rather than with `Ord::clamp`, which panics when the
    /// policy's minimum is above its maximum; here the minimum check simply
    /// takes precedence.
    fn clamp_bit_width(&self, bw: BitWidth) -> BitWidth {
        if bw < self.policy.min_bit_width {
            self.policy.min_bit_width
        } else if bw > self.policy.max_bit_width {
            self.policy.max_bit_width
        } else {
            bw
        }
    }

    /// Builds the assignment for one layer, clamping the requested width to the
    /// policy bounds and computing the memory cost at the clamped width.
    fn make_assignment(
        &self,
        layer: &LayerSensitivity,
        param_count: usize,
        bit_width: BitWidth,
    ) -> LayerBitWidthAssignment {
        let bw = self.clamp_bit_width(bit_width);
        LayerBitWidthAssignment {
            layer_name: layer.layer_name.clone(),
            bit_width: bw,
            param_count,
            memory_bytes: Self::layer_memory_bytes(param_count, bw),
            sensitivity_score: layer.sensitivity_score(),
        }
    }

    /// Requests the same width `bw` for every layer.
    fn assign_uniform(
        &self,
        bw: BitWidth,
        layers: &[LayerSensitivity],
        counts: &[usize],
    ) -> Vec<LayerBitWidthAssignment> {
        layers
            .iter()
            .zip(counts)
            .map(|(layer, &count)| self.make_assignment(layer, count, bw))
            .collect()
    }

    /// Picks a width per layer by comparing its sensitivity score with the
    /// thresholds: above `high_threshold` -> FP16, above `medium_threshold` ->
    /// INT8, below `low_threshold` -> INT2, otherwise INT4.
    fn assign_sensitivity_based(
        &self,
        high_threshold: f64,
        medium_threshold: f64,
        low_threshold: f64,
        layers: &[LayerSensitivity],
        counts: &[usize],
    ) -> Vec<LayerBitWidthAssignment> {
        layers
            .iter()
            .zip(counts)
            .map(|(layer, &count)| {
                let score = layer.sensitivity_score();
                let bw = if score > high_threshold {
                    BitWidth::Fp16
                } else if score > medium_threshold {
                    BitWidth::Int8
                } else if score < low_threshold {
                    BitWidth::Int2
                } else {
                    BitWidth::Int4
                };
                self.make_assignment(layer, count, bw)
            })
            .collect()
    }

    /// Greedy budget-aware assignment.
    ///
    /// Every layer starts at the policy's minimum width. Then, repeatedly, the
    /// most sensitive layer that can move one rung up the precision ladder
    /// (without passing the policy's maximum width or the byte budget) is
    /// upgraded, until no layer can be upgraded. With no budget configured the
    /// budget is treated as unlimited.
    ///
    /// This function itself never returns `Err`; if even the starting
    /// configuration is over budget, the caller's budget check reports it.
    fn assign_budget_optimal(
        &self,
        layers: &[LayerSensitivity],
        counts: &[usize],
    ) -> Result<Vec<LayerBitWidthAssignment>, QuantSelectionError> {
        let budget = self.policy.budget_bytes.unwrap_or(usize::MAX);
        let min_bit_width = self.policy.min_bit_width;
        let max_bit_width = self.policy.max_bit_width;
        let mut bit_widths = vec![min_bit_width; layers.len()];

        // Score each layer once, then visit layers from most to least sensitive.
        // `total_cmp` gives the sort a genuine total order even if a score is
        // NaN, which `partial_cmp(..).unwrap_or(Equal)` does not.
        let scores: Vec<f64> = layers
            .iter()
            .map(LayerSensitivity::sensitivity_score)
            .collect();
        let mut upgrade_order: Vec<usize> = (0..layers.len()).collect();
        upgrade_order.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]));

        // Running memory total, updated incrementally after each upgrade
        // instead of being recomputed for every candidate layer.
        let mut current_total: usize = counts
            .iter()
            .map(|&count| Self::layer_memory_bytes(count, min_bit_width))
            .sum();

        loop {
            let upgrade = upgrade_order.iter().find_map(|&idx| {
                let current = bit_widths[idx];
                // Next rung above the current width that the policy still allows;
                // layers already at the top are skipped.
                let next_bw = Self::PRECISION_LADDER
                    .iter()
                    .copied()
                    .find(|&bw| bw > current && bw <= max_bit_width)?;
                let old_mem = Self::layer_memory_bytes(counts[idx], current);
                let new_mem = Self::layer_memory_bytes(counts[idx], next_bw);
                let proposed_total = current_total - old_mem + new_mem;
                if proposed_total <= budget {
                    Some((idx, next_bw, proposed_total))
                } else {
                    None
                }
            });

            match upgrade {
                Some((idx, next_bw, new_total)) => {
                    bit_widths[idx] = next_bw;
                    current_total = new_total;
                }
                None => break,
            }
        }

        Ok(layers
            .iter()
            .zip(counts)
            .zip(&bit_widths)
            .map(|((layer, &count), &bw)| self.make_assignment(layer, count, bw))
            .collect())
    }

    /// Sum of `memory_bytes` over all assignments.
    pub fn total_memory_bytes(assignments: &[LayerBitWidthAssignment]) -> usize {
        assignments.iter().map(|a| a.memory_bytes).sum()
    }

    /// Aggregates assignments into a `QuantizationSummary`.
    ///
    /// The compression ratio compares FP32 storage (4 bytes per parameter) with
    /// the assigned memory and is reported as 1.0 when no memory is used. The
    /// distribution lists how many layers use each width, in ascending width
    /// order. The average sensitivity is 0.0 for an empty slice.
    pub fn summary_report(assignments: &[LayerBitWidthAssignment]) -> QuantizationSummary {
        let total_params: usize = assignments.iter().map(|a| a.param_count).sum();
        let total_memory_bytes = Self::total_memory_bytes(assignments);

        let compression_ratio = if total_memory_bytes == 0 {
            1.0
        } else {
            // Computed in f64 so that `total_params * 4` cannot overflow `usize`.
            total_params as f64 * 4.0 / total_memory_bytes as f64
        };

        let avg_sensitivity_score = if assignments.is_empty() {
            0.0
        } else {
            assignments.iter().map(|a| a.sensitivity_score).sum::<f64>() / assignments.len() as f64
        };

        // Keyed directly on `BitWidth`: the ordered map yields ascending widths
        // and no width can be silently dropped by a lossy integer round-trip.
        let mut layers_per_width = std::collections::BTreeMap::new();
        for assignment in assignments {
            *layers_per_width.entry(assignment.bit_width).or_insert(0usize) += 1;
        }
        let bit_width_distribution: Vec<(BitWidth, usize)> = layers_per_width.into_iter().collect();

        QuantizationSummary {
            total_params,
            total_memory_bytes,
            compression_ratio,
            bit_width_distribution,
            avg_sensitivity_score,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used for most floating-point comparisons in these tests.
    const EPS: f64 = 1e-10;

    /// Builds a layer from its raw metrics and role flags.
    fn make_layer(
        name: &str,
        gradient_norm: f64,
        weight_variance: f64,
        activation_range: f64,
        output_sensitivity: f64,
        is_embedding: bool,
        is_final_layer: bool,
    ) -> LayerSensitivity {
        LayerSensitivity {
            layer_name: name.to_string(),
            gradient_norm,
            weight_variance,
            activation_range,
            output_sensitivity,
            is_embedding,
            is_final_layer,
        }
    }

    /// A layer whose four metrics are all 1.0, so its base score is 1.0.
    fn unit_layer(name: &str, is_embedding: bool, is_final_layer: bool) -> LayerSensitivity {
        make_layer(name, 1.0, 1.0, 1.0, 1.0, is_embedding, is_final_layer)
    }

    /// A small four-layer model together with its per-layer parameter counts.
    fn simple_layers() -> (Vec<LayerSensitivity>, Vec<usize>) {
        let layers = vec![
            make_layer("embed", 2.0, 1.0, 1.0, 1.0, true, false),
            make_layer("attn", 1.5, 0.8, 0.5, 0.5, false, false),
            make_layer("ffn", 0.5, 0.3, 0.2, 0.1, false, false),
            make_layer("head", 1.0, 0.6, 0.4, 0.4, false, true),
        ];
        let counts = vec![10_000, 8_000, 16_000, 4_000];
        (layers, counts)
    }

    /// Selector with the given strategy, budget and minimum width; the maximum
    /// width is always FP32.
    fn selector(
        strategy: BitWidthStrategy,
        budget_bytes: Option<usize>,
        min_bit_width: BitWidth,
    ) -> PerLayerQuantSelector {
        PerLayerQuantSelector::new(QuantizationPolicy {
            strategy,
            budget_bytes,
            min_bit_width,
            max_bit_width: BitWidth::Fp32,
        })
    }

    /// Runs the threshold strategy (2.0 / 1.0 / 0.3, INT2..=FP32, no budget) on
    /// a single layer and returns the width it was given.
    fn threshold_pick(layer: LayerSensitivity, param_count: usize) -> BitWidth {
        let strategy = BitWidthStrategy::SensitivityBased {
            high_threshold: 2.0,
            medium_threshold: 1.0,
            low_threshold: 0.3,
        };
        let assignments = selector(strategy, None, BitWidth::Int2)
            .assign_bit_widths(&[layer], &[param_count])
            .expect("assign");
        assignments[0].bit_width
    }

    #[test]
    fn test_bytes_per_weight() {
        let expected = [
            (BitWidth::Int2, 0.25),
            (BitWidth::Int4, 0.5),
            (BitWidth::Int8, 1.0),
            (BitWidth::Fp16, 2.0),
            (BitWidth::Fp32, 4.0),
        ];
        for &(bw, bytes) in expected.iter() {
            assert!((bw.bytes_per_weight() - bytes).abs() < EPS);
        }
    }

    #[test]
    fn test_compression_ratio() {
        let expected = [
            (BitWidth::Fp32, 1.0),
            (BitWidth::Fp16, 2.0),
            (BitWidth::Int8, 4.0),
            (BitWidth::Int4, 8.0),
            (BitWidth::Int2, 16.0),
        ];
        for &(bw, ratio) in expected.iter() {
            assert!((bw.compression_ratio_vs_fp32() - ratio).abs() < EPS);
        }
    }

    #[test]
    fn test_sensitivity_score_basic() {
        let score = unit_layer("l", false, false).sensitivity_score();
        assert!((score - 1.0).abs() < EPS, "expected 1.0 got {}", score);
    }

    #[test]
    fn test_sensitivity_score_embedding_bonus() {
        let plain = unit_layer("l", false, false).sensitivity_score();
        let embedding = unit_layer("e", true, false).sensitivity_score();
        let diff = embedding - plain;
        assert!((diff - 1.0).abs() < EPS, "embedding bonus should be +1.0");
    }

    #[test]
    fn test_sensitivity_score_final_layer_bonus() {
        let plain = unit_layer("l", false, false).sensitivity_score();
        let last = unit_layer("f", false, true).sensitivity_score();
        let diff = last - plain;
        assert!((diff - 0.5).abs() < EPS, "final layer bonus should be +0.5");
    }

    #[test]
    fn test_sensitivity_score_both_bonuses() {
        let score = unit_layer("l", true, true).sensitivity_score();
        assert!((score - 2.5).abs() < EPS);
    }

    #[test]
    fn test_uniform_strategy() {
        let (layers, counts) = simple_layers();
        let assignments = selector(BitWidthStrategy::Uniform(BitWidth::Int8), None, BitWidth::Int2)
            .assign_bit_widths(&layers, &counts)
            .expect("assign");
        assert!(assignments.iter().all(|a| a.bit_width == BitWidth::Int8));
        assert_eq!(assignments.len(), 4);
    }

    #[test]
    fn test_uniform_strategy_clamped_by_min() {
        // INT2 is requested, but the policy's minimum is INT4.
        let layers = vec![make_layer("l", 0.1, 0.1, 0.1, 0.1, false, false)];
        let counts = vec![100];
        let assignments = selector(BitWidthStrategy::Uniform(BitWidth::Int2), None, BitWidth::Int4)
            .assign_bit_widths(&layers, &counts)
            .expect("assign");
        assert_eq!(assignments[0].bit_width, BitWidth::Int4);
    }

    #[test]
    fn test_sensitivity_based_high_score() {
        let layer = make_layer("l", 3.0, 2.0, 2.0, 2.0, false, false);
        assert_eq!(threshold_pick(layer, 1000), BitWidth::Fp16);
    }

    #[test]
    fn test_sensitivity_based_medium_score() {
        let layer = make_layer("l", 1.5, 1.5, 1.0, 1.0, false, false);
        assert_eq!(threshold_pick(layer, 500), BitWidth::Int8);
    }

    #[test]
    fn test_sensitivity_based_low_score() {
        let layer = make_layer("l", 0.1, 0.1, 0.1, 0.1, false, false);
        assert_eq!(threshold_pick(layer, 200), BitWidth::Int2);
    }

    #[test]
    fn test_sensitivity_based_between_medium_and_low() {
        let layer = make_layer("l", 0.8, 0.8, 0.6, 0.6, false, false);
        assert_eq!(threshold_pick(layer, 300), BitWidth::Int4);
    }

    #[test]
    fn test_budget_optimal_fits() {
        let large_budget = 1_000_000;
        let (layers, counts) = simple_layers();
        let assignments = selector(
            BitWidthStrategy::BudgetOptimal,
            Some(large_budget),
            BitWidth::Int2,
        )
        .assign_bit_widths(&layers, &counts)
        .expect("assign");
        assert_eq!(assignments.len(), 4);
        let total = PerLayerQuantSelector::total_memory_bytes(&assignments);
        assert!(total <= large_budget);
    }

    #[test]
    fn test_empty_layers_error() {
        let default_selector = PerLayerQuantSelector::new(QuantizationPolicy::default());
        let err = default_selector
            .assign_bit_widths(&[], &[])
            .expect_err("should error on empty");
        assert!(matches!(err, QuantSelectionError::EmptyLayers));
    }

    #[test]
    fn test_length_mismatch_error() {
        let default_selector = PerLayerQuantSelector::new(QuantizationPolicy::default());
        let layers = vec![unit_layer("l", false, false)];
        // Two counts for a single layer.
        let counts = vec![100, 200];
        let err = default_selector
            .assign_bit_widths(&layers, &counts)
            .expect_err("should error on mismatch");
        assert!(matches!(err, QuantSelectionError::LengthMismatch));
    }

    #[test]
    fn test_budget_exceeded_error() {
        // A one-byte budget cannot hold 10_000 INT8 weights.
        let tiny_budget = 1;
        let layers = vec![unit_layer("l", false, false)];
        let counts = vec![10_000];
        let err = selector(
            BitWidthStrategy::Uniform(BitWidth::Int8),
            Some(tiny_budget),
            BitWidth::Int2,
        )
        .assign_bit_widths(&layers, &counts)
        .expect_err("should exceed budget");
        assert!(matches!(err, QuantSelectionError::BudgetExceeded { .. }));
    }

    #[test]
    fn test_total_memory_bytes() {
        let assignments = vec![
            LayerBitWidthAssignment {
                layer_name: "a".to_string(),
                bit_width: BitWidth::Int8,
                param_count: 100,
                memory_bytes: 100,
                sensitivity_score: 1.0,
            },
            LayerBitWidthAssignment {
                layer_name: "b".to_string(),
                bit_width: BitWidth::Fp16,
                param_count: 50,
                memory_bytes: 100,
                sensitivity_score: 2.0,
            },
        ];
        assert_eq!(PerLayerQuantSelector::total_memory_bytes(&assignments), 200);
    }

    #[test]
    fn test_summary_report_display() {
        let (layers, counts) = simple_layers();
        let assignments = selector(BitWidthStrategy::Uniform(BitWidth::Int8), None, BitWidth::Int2)
            .assign_bit_widths(&layers, &counts)
            .expect("assign");
        let rendered = PerLayerQuantSelector::summary_report(&assignments).to_string();
        for needle in ["total_params", "compression_ratio", "INT8"].iter() {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_summary_report_compression_ratio() {
        // 1000 parameters at one byte each versus four bytes each in FP32.
        let assignments = vec![LayerBitWidthAssignment {
            layer_name: "l".to_string(),
            bit_width: BitWidth::Int8,
            param_count: 1000,
            memory_bytes: 1000,
            sensitivity_score: 0.5,
        }];
        let summary = PerLayerQuantSelector::summary_report(&assignments);
        assert!((summary.compression_ratio - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_quant_selection_error_display() {
        assert!(QuantSelectionError::EmptyLayers.to_string().contains("Empty"));
        assert!(QuantSelectionError::LengthMismatch.to_string().contains("mismatch"));

        let exceeded = QuantSelectionError::BudgetExceeded {
            required: 500,
            budget: 100,
        };
        let message = exceeded.to_string();
        assert!(message.contains("500"));
        assert!(message.contains("100"));
    }
}