use crate::model_merge::WeightTensor;
use crate::pruning::{prune_tensor, PruningConfig, PruningError};
/// One step of a compression pipeline; stages run in the order they appear in
/// `CompressionConfig::stages`.
#[derive(Debug, Clone)]
pub enum CompressionStage {
    /// Prune each tensor with `prune_tensor` using the contained configuration.
    Prune(PruningConfig),
    /// Simulated symmetric per-tensor int8 quantization: values are snapped to a
    /// 255-level grid (`-127..=127` times a per-tensor scale) but stay stored as `f32`.
    QuantizeInt8,
    /// Magnitude thresholding (not value clamping): weights whose absolute value is at
    /// or below the `percentile` quantile of the tensor's magnitudes are set to zero.
    Clip {
        /// Fraction in `(0, 1]`; `1.0` zeroes every weight.
        percentile: f32,
    },
}

impl CompressionStage {
    /// Short, stable identifier for this stage (recorded as `StageStats::stage_name`).
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Prune(_) => "prune",
            Self::QuantizeInt8 => "quantize_int8",
            Self::Clip { percentile: _ } => "clip",
        }
    }
}
/// Ordered description of a compression pipeline consumed by `compress_model`.
#[derive(Debug, Clone, Default)]
pub struct CompressionConfig {
    /// Stages applied in order, front to back.
    pub stages: Vec<CompressionStage>,
    /// When `true`, tensors whose names start with `embed` or `token`
    /// (ASCII case-insensitive) pass through every stage unchanged.
    pub skip_embedding_layers: bool,
}

impl CompressionConfig {
    /// An empty pipeline with embedding skipping disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style helper: appends `stage` to the end of the pipeline.
    pub fn add_stage(self, stage: CompressionStage) -> Self {
        let Self {
            mut stages,
            skip_embedding_layers,
        } = self;
        stages.push(stage);
        Self {
            stages,
            skip_embedding_layers,
        }
    }

    /// Unstructured L1 pruning to `sparsity`, followed by int8 quantization.
    pub fn prune_then_quantize(sparsity: f32) -> Self {
        Self::prune_only(sparsity).add_stage(CompressionStage::QuantizeInt8)
    }

    /// A single int8 quantization stage.
    pub fn quantize_only() -> Self {
        Self::new().add_stage(CompressionStage::QuantizeInt8)
    }

    /// A single unstructured L1 pruning stage targeting `sparsity`.
    pub fn prune_only(sparsity: f32) -> Self {
        let pruning = PruningConfig::unstructured_l1(sparsity);
        Self::new().add_stage(CompressionStage::Prune(pruning))
    }
}
/// Aggregate statistics for one pipeline stage, summed over every tensor.
#[derive(Debug, Clone)]
pub struct StageStats {
    /// Value of `CompressionStage::name` for the stage.
    pub stage_name: String,
    /// Tensors the stage actually transformed.
    pub tensors_processed: usize,
    /// Tensors passed through unchanged (embedding layers when skipping is enabled).
    pub tensors_skipped: usize,
    /// Total parameter count entering the stage.
    pub params_before: usize,
    /// Parameters that are not exactly `0.0` after the stage.
    pub nonzero_params_after: usize,
    /// Estimated storage, in bytes, before the stage.
    pub memory_before_bytes: usize,
    /// Estimated storage, in bytes, after the stage.
    pub memory_after_bytes: usize,
}

impl StageStats {
    /// `memory_before_bytes / memory_after_bytes`; `1.0` when nothing remains after.
    pub fn compression_ratio(&self) -> f32 {
        match self.memory_after_bytes {
            0 => 1.0,
            after => self.memory_before_bytes as f32 / after as f32,
        }
    }

    /// Fraction of parameters that are zero after the stage; `0.0` for an empty stage.
    ///
    /// A non-zero count larger than `params_before` is treated as "no zeros".
    pub fn sparsity(&self) -> f32 {
        let total = self.params_before;
        if total == 0 {
            0.0
        } else {
            let zeroed = total - self.nonzero_params_after.min(total);
            zeroed as f32 / total as f32
        }
    }
}
/// Output of `compress_model`: the compressed tensors plus one `StageStats` per stage.
#[derive(Debug, Clone)]
pub struct CompressionResult {
    /// Tensors after every stage has run.
    pub compressed_tensors: Vec<WeightTensor>,
    /// Per-stage statistics in execution order.
    pub stage_stats: Vec<StageStats>,
}

impl CompressionResult {
    /// Total number of scalar parameters across all compressed tensors.
    pub fn total_params(&self) -> usize {
        let mut total = 0usize;
        for tensor in &self.compressed_tensors {
            total += tensor.data.len();
        }
        total
    }

    /// Number of parameters that are not exactly `0.0` (NaN counts as non-zero).
    pub fn total_nonzero(&self) -> usize {
        self.compressed_tensors
            .iter()
            .flat_map(|tensor| tensor.data.iter())
            .filter(|&&value| value != 0.0)
            .count()
    }

    /// Fraction of all parameters that are zero; `0.0` when there are no parameters.
    pub fn overall_sparsity(&self) -> f32 {
        match self.total_params() {
            0 => 0.0,
            total => {
                let zeroed = total.saturating_sub(self.total_nonzero());
                zeroed as f32 / total as f32
            }
        }
    }

    /// Memory before the first stage divided by memory after the last stage.
    ///
    /// Returns `1.0` when there are no stages or the final memory is zero.
    pub fn total_compression_ratio(&self) -> f32 {
        let after = self.memory_after_bytes();
        // With no stages `after` is 0 as well, so both degenerate cases share one exit.
        if self.stage_stats.is_empty() || after == 0 {
            return 1.0;
        }
        self.memory_before_bytes() as f32 / after as f32
    }

    /// `memory_before_bytes` of the first stage, or `0` without stages.
    pub fn memory_before_bytes(&self) -> usize {
        self.stage_stats
            .first()
            .map_or(0, |first| first.memory_before_bytes)
    }

    /// `memory_after_bytes` of the last stage, or `0` without stages.
    pub fn memory_after_bytes(&self) -> usize {
        self.stage_stats
            .last()
            .map_or(0, |last| last.memory_after_bytes)
    }

    /// Human-readable report: a header, one line per stage, then an overall line,
    /// joined with `'\n'` (no trailing newline).
    pub fn summary(&self) -> String {
        let header = format!(
            "=== Compression Summary ({} stage(s)) ===",
            self.stage_stats.len()
        );
        let per_stage = self.stage_stats.iter().zip(1usize..).map(|(stats, number)| {
            format!(
                " Stage {}: [{}] processed={} skipped={} sparsity={:.4} ratio={:.3}x \
                 memory={}B->{}B",
                number,
                stats.stage_name,
                stats.tensors_processed,
                stats.tensors_skipped,
                stats.sparsity(),
                stats.compression_ratio(),
                stats.memory_before_bytes,
                stats.memory_after_bytes,
            )
        });
        let overall = format!(
            " Overall: tensors={} total_params={} nonzero={} sparsity={:.4} \
             compression_ratio={:.3}x memory={}B->{}B",
            self.compressed_tensors.len(),
            self.total_params(),
            self.total_nonzero(),
            self.overall_sparsity(),
            self.total_compression_ratio(),
            self.memory_before_bytes(),
            self.memory_after_bytes(),
        );
        std::iter::once(header)
            .chain(per_stage)
            .chain(std::iter::once(overall))
            .collect::<Vec<String>>()
            .join("\n")
    }
}
/// Errors returned by the compression pipeline.
#[derive(Debug, thiserror::Error)]
pub enum CompressionError {
    /// A `Prune` stage failed; wraps the error returned by `prune_tensor`.
    #[error("pruning error: {0}")]
    Pruning(#[from] PruningError),
    /// `compress_model` was given an empty tensor list.
    #[error("empty model: no tensors")]
    EmptyModel,
    /// `compress_model` was given a config with no stages.
    #[error("empty pipeline: no stages")]
    EmptyPipeline,
    /// A `Clip` stage percentile was outside the accepted range `(0, 1]`.
    #[error("invalid clip percentile {0}: must be in (0, 1]")]
    InvalidPercentile(f32),
}
/// Returns `true` when `name` starts with `embed` or `token`, ignoring ASCII case.
///
/// Only the very start of the name is checked; e.g. `model.embed_tokens.weight` does
/// not match.
#[inline]
fn is_embedding_layer(name: &str) -> bool {
    const PREFIXES: [&[u8]; 2] = [b"embed", b"token"];
    let bytes = name.as_bytes();
    // Comparing raw bytes avoids allocating a lowercased copy of the name; non-ASCII
    // bytes can never equal the ASCII prefixes, exactly as after `to_ascii_lowercase`.
    PREFIXES.iter().any(|prefix| {
        bytes
            .get(..prefix.len())
            .map_or(false, |head| head.eq_ignore_ascii_case(prefix))
    })
}
/// Size in bytes of a tensor's dense `f32` payload: 4 bytes per element.
#[inline]
fn tensor_bytes(tensor: &WeightTensor) -> usize {
    const F32_BYTES: usize = core::mem::size_of::<f32>();
    F32_BYTES * tensor.data.len()
}
/// Number of elements that are not exactly `0.0` (NaN counts as non-zero).
#[inline]
fn count_nonzero(tensor: &WeightTensor) -> usize {
    tensor
        .data
        .iter()
        .fold(0usize, |acc, &value| acc + usize::from(value != 0.0))
}
/// Simulates symmetric per-tensor int8 quantization in place.
///
/// With `scale = max|w| / 127`, each weight becomes `round(w / scale)` clamped to
/// `[-127, 127]` and multiplied back by `scale`. The tensor keeps its `f32` storage but
/// afterwards only holds values on that 255-level grid.
///
/// The tensor is left untouched when it is empty, or when the scale is unusable:
/// - **zero**: an all-zero tensor, or magnitudes so tiny that `max / 127` underflows.
/// - **infinite**: some weight is infinite.
///
/// In both cases the round trip would wipe the data. A zero scale sends every weight to
/// `0`. An infinite scale turns every finite weight into `0` and the infinite one into NaN
/// (`0 * inf`).
///
/// NaN weights are ignored when computing the scale, because `f32::max` returns the
/// non-NaN operand. They come out as `0.0`, because a float-to-int `as` cast maps NaN to `0`.
fn apply_quantize_int8_inplace(tensor: &mut WeightTensor) {
    let data = &mut tensor.data;
    if data.is_empty() {
        return;
    }
    let max_abs = data.iter().map(|w| w.abs()).fold(0.0_f32, f32::max);
    let scale = max_abs / 127.0_f32;
    // Covers the all-zero case, subnormal underflow (scale == 0) and infinite weights
    // (scale == inf).
    if scale == 0.0 || !scale.is_finite() {
        return;
    }
    for w in data.iter_mut() {
        let q = (*w / scale).round().clamp(-127.0_f32, 127.0_f32) as i8;
        *w = f32::from(q) * scale;
    }
}
/// Zeroes every weight whose magnitude is at or below the `percentile` quantile of the
/// tensor's absolute values.
///
/// Despite the stage name, large values are not clamped: this is percentile-based
/// magnitude thresholding. The threshold is the `k`-th smallest magnitude, where
/// `k = ceil(percentile * n)` is clamped to `1..=n`. Every weight with `|w| <= threshold`
/// is set to `0.0`, so at least `k` weights are zeroed, more when magnitudes tie at the
/// threshold. `percentile == 1.0` zeroes the whole tensor. An empty tensor is a no-op.
///
/// # Errors
///
/// [`CompressionError::InvalidPercentile`] if `percentile` is not in `(0, 1]`,
/// including NaN.
fn apply_clip_inplace(tensor: &mut WeightTensor, percentile: f32) -> Result<(), CompressionError> {
    // Negated range test so NaN (for which every comparison is false) is rejected;
    // `p <= 0.0 || p > 1.0` would let it through.
    if !(percentile > 0.0 && percentile <= 1.0) {
        return Err(CompressionError::InvalidPercentile(percentile));
    }
    let data = &mut tensor.data;
    if data.is_empty() {
        return Ok(());
    }
    let n = data.len();
    // 1-based rank ceil(p * n) converted to a 0-based index, kept inside the slice.
    let idx = ((percentile * n as f32).ceil() as usize)
        .saturating_sub(1)
        .min(n - 1);
    let mut magnitudes: Vec<f32> = data.iter().map(|w| w.abs()).collect();
    // Only one order statistic is needed, so selection (O(n) average) replaces a full
    // sort. `total_cmp` is a true total order even with NaN present, unlike
    // `partial_cmp(..).unwrap_or(Equal)`, which std's sorts may reject by panicking.
    let threshold = *magnitudes.select_nth_unstable_by(idx, f32::total_cmp).1;
    for w in data.iter_mut() {
        if w.abs() <= threshold {
            *w = 0.0;
        }
    }
    Ok(())
}
/// Estimated storage, in bytes, after int8 quantization of data whose dense `f32`
/// size is `memory_before`: one byte per parameter, i.e. a quarter of the input.
///
/// Rounds to the nearest byte with ties rounding up, the same as `f32::round` on a
/// positive value.
#[inline]
fn quantize_int8_memory_after(memory_before: usize) -> usize {
    // Exact integer arithmetic: converting to `f32` first loses precision above 2^24
    // bytes and misreported large tensors. This form also cannot overflow.
    memory_before / 4 + usize::from(memory_before % 4 >= 2)
}
/// Runs every stage of `config`, in order, over a copy of `tensors` and records
/// per-stage statistics.
///
/// The input slice is not modified. It is cloned up front, and the clones are compressed
/// and returned in `CompressionResult::compressed_tensors` in input order. When
/// `config.skip_embedding_layers` is set, tensors accepted by `is_embedding_layer` pass
/// through every stage unchanged. They are counted as skipped but still included in the
/// parameter, non-zero and memory totals.
///
/// # Memory accounting
///
/// Tensors are always held as dense `f32`, so reported memory is a storage estimate. A
/// tensor counts at its `f32` size until a `QuantizeInt8` stage processes it, and at the
/// int8 estimate from then on. This state persists through later stages. As a result,
/// each stage's `memory_before_bytes` equals the previous stage's `memory_after_bytes`,
/// and the overall ratio reflects quantization wherever it occurs in the pipeline.
///
/// # Errors
///
/// * `CompressionError::EmptyModel` if `tensors` is empty.
/// * `CompressionError::EmptyPipeline` if `config` has no stages.
/// * `CompressionError::InvalidPercentile` if any `Clip` stage has a percentile outside
///   `(0, 1]` (including NaN). This is checked before any stage runs.
/// * `CompressionError::Pruning` if `prune_tensor` fails. The whole call is aborted and
///   no partial result is returned.
pub fn compress_model(
    tensors: &[WeightTensor],
    config: &CompressionConfig,
) -> Result<CompressionResult, CompressionError> {
    /// Storage estimate for one tensor, given whether it has already been quantized.
    fn stored_bytes(tensor: &WeightTensor, quantized: bool) -> usize {
        let dense = tensor_bytes(tensor);
        if quantized {
            quantize_int8_memory_after(dense)
        } else {
            dense
        }
    }

    if tensors.is_empty() {
        return Err(CompressionError::EmptyModel);
    }
    if config.stages.is_empty() {
        return Err(CompressionError::EmptyPipeline);
    }
    // Validate every stage before touching any data. The negated range test also
    // rejects NaN, which `p <= 0.0 || p > 1.0` would accept.
    for stage in &config.stages {
        if let CompressionStage::Clip { percentile } = stage {
            if !(*percentile > 0.0 && *percentile <= 1.0) {
                return Err(CompressionError::InvalidPercentile(*percentile));
            }
        }
    }

    let mut working: Vec<WeightTensor> = tensors.to_vec();
    // Per-tensor flag, parallel to `working`. It is carried across stages so that a stage
    // running after `QuantizeInt8` does not report the full f32 size again.
    let mut quantized = vec![false; working.len()];
    let mut stage_stats: Vec<StageStats> = Vec::with_capacity(config.stages.len());

    for stage in &config.stages {
        let mut stats = StageStats {
            stage_name: stage.name().to_string(),
            tensors_processed: 0,
            tensors_skipped: 0,
            params_before: 0,
            nonzero_params_after: 0,
            memory_before_bytes: 0,
            memory_after_bytes: 0,
        };
        for (tensor, is_quantized) in working.iter_mut().zip(quantized.iter_mut()) {
            stats.params_before += tensor.data.len();
            stats.memory_before_bytes += stored_bytes(tensor, *is_quantized);

            if config.skip_embedding_layers && is_embedding_layer(&tensor.name) {
                stats.tensors_skipped += 1;
            } else {
                stats.tensors_processed += 1;
                match stage {
                    CompressionStage::Prune(prune_cfg) => {
                        let (pruned, _mask) = prune_tensor(tensor, prune_cfg)?;
                        *tensor = pruned;
                    }
                    CompressionStage::QuantizeInt8 => {
                        apply_quantize_int8_inplace(tensor);
                        *is_quantized = true;
                    }
                    CompressionStage::Clip { percentile } => {
                        apply_clip_inplace(tensor, *percentile)?;
                    }
                }
            }

            // Recomputed from the tensor's current length, in case a stage resized it.
            stats.nonzero_params_after += count_nonzero(tensor);
            stats.memory_after_bytes += stored_bytes(tensor, *is_quantized);
        }
        stage_stats.push(stats);
    }

    Ok(CompressionResult {
        compressed_tensors: working,
        stage_stats,
    })
}
/// Estimates the storage size, in bytes, that `compress_model` would report after
/// running `config` over `tensors`, without doing any compression work.
///
/// Returns `0` when there are no tensors or no stages.
///
/// When `config.skip_embedding_layers` is set, embedding tensors (per
/// `is_embedding_layer`) keep their dense `f32` size. `Prune` and `Clip` are treated as
/// size-neutral, since zeros are still stored densely. If the pipeline contains
/// `QuantizeInt8`, the remaining bytes shrink to the int8 estimate once: re-quantizing
/// already-int8 data saves nothing further, so the factor must not compound per stage.
pub fn estimate_compressed_size(tensors: &[WeightTensor], config: &CompressionConfig) -> usize {
    if tensors.is_empty() || config.stages.is_empty() {
        return 0;
    }
    let total_f32_bytes: usize = tensors.iter().map(tensor_bytes).sum();
    let embedding_bytes: usize = if config.skip_embedding_layers {
        tensors
            .iter()
            .filter(|t| is_embedding_layer(&t.name))
            .map(tensor_bytes)
            .sum()
    } else {
        0
    };
    let compressible_bytes = total_f32_bytes.saturating_sub(embedding_bytes);

    let mut quantized = false;
    for stage in &config.stages {
        match stage {
            // Values are zeroed in place; dense storage size is unchanged.
            CompressionStage::Prune(_) | CompressionStage::Clip { .. } => {}
            CompressionStage::QuantizeInt8 => quantized = true,
        }
    }

    let compressible_after = if quantized {
        // One byte per parameter instead of four, rounded to nearest with ties up, in
        // exact integer arithmetic.
        compressible_bytes / 4 + usize::from(compressible_bytes % 4 >= 2)
    } else {
        compressible_bytes
    };
    embedding_bytes + compressible_after
}
// Unit tests for the helper functions, the statistics types and the `compress_model`
// pipeline.
#[cfg(test)]
mod tests {
    use super::*;

    // Thin wrapper so tests construct tensors uniformly; delegates to `WeightTensor::new`.
    fn make_tensor(name: &str, data: Vec<f32>, shape: Vec<usize>) -> WeightTensor {
        WeightTensor::new(name, data, shape)
    }

    // Returns `[1.0, 2.0, ..., n as f32]`: strictly positive, distinct magnitudes.
    fn linear_data(n: usize) -> Vec<f32> {
        (1..=n).map(|i| i as f32).collect()
    }

    // Prefix match is case-insensitive and covers both `embed*` and `token*` names.
    #[test]
    fn is_embedding_layer_matches_embed_prefix() {
        assert!(is_embedding_layer("embed.weight"));
        assert!(is_embedding_layer("Embed.weight"));
        assert!(is_embedding_layer("embedding_layer"));
        assert!(is_embedding_layer("token_embedding"));
        assert!(!is_embedding_layer("linear.weight"));
        assert!(!is_embedding_layer("layer_norm"));
    }

    // Quantize/dequantize must not flip the sign of any weight.
    #[test]
    fn apply_quantize_int8_preserves_sign() {
        let mut t = make_tensor("w", vec![1.0, -2.0, 0.5, -0.25], vec![4]);
        apply_quantize_int8_inplace(&mut t);
        assert!(t.data[0] > 0.0);
        assert!(t.data[1] < 0.0);
        assert!(t.data[2] > 0.0);
        assert!(t.data[3] < 0.0);
    }

    // With percentile 0.3 over 1..=10, the three smallest values are zeroed and the
    // largest survives.
    #[test]
    fn apply_clip_zeros_small_values() {
        let mut t = make_tensor("w", linear_data(10), vec![10]);
        apply_clip_inplace(&mut t, 0.3).expect("clip ok");
        assert_eq!(t.data[0], 0.0);
        assert_eq!(t.data[1], 0.0);
        assert_eq!(t.data[2], 0.0);
        assert!(t.data[9] != 0.0);
    }

    // Percentile must lie in (0, 1]: 0, negatives and > 1 are rejected.
    #[test]
    fn apply_clip_invalid_percentile_returns_error() {
        let mut t = make_tensor("w", vec![1.0; 4], vec![4]);
        assert!(apply_clip_inplace(&mut t, 0.0).is_err());
        assert!(apply_clip_inplace(&mut t, 1.1).is_err());
        assert!(apply_clip_inplace(&mut t, -0.5).is_err());
        // 1.0 is the inclusive upper bound and must be accepted.
        assert!(apply_clip_inplace(&mut t, 1.0).is_ok());
    }

    // Equal memory before and after gives a ratio of exactly 1.0.
    #[test]
    fn stage_stats_compression_ratio_equals_before_over_after() {
        let stats = StageStats {
            stage_name: "prune".to_string(),
            tensors_processed: 1,
            tensors_skipped: 0,
            params_before: 100,
            nonzero_params_after: 50,
            memory_before_bytes: 400,
            memory_after_bytes: 400,
        };
        let ratio = stats.compression_ratio();
        assert!((ratio - 1.0).abs() < 1e-6);
    }

    // 50 non-zero out of 100 parameters is 50% sparsity.
    #[test]
    fn stage_stats_sparsity_half() {
        let stats = StageStats {
            stage_name: "prune".to_string(),
            tensors_processed: 2,
            tensors_skipped: 0,
            params_before: 100,
            nonzero_params_after: 50,
            memory_before_bytes: 400,
            memory_after_bytes: 400,
        };
        assert!((stats.sparsity() - 0.5).abs() < 1e-6);
    }

    // "before" comes from the first stage and "after" from the last: 40 B -> 10 B is 4x.
    #[test]
    fn compression_result_memory_helpers() {
        let result = CompressionResult {
            compressed_tensors: vec![],
            stage_stats: vec![
                StageStats {
                    stage_name: "prune".to_string(),
                    tensors_processed: 1,
                    tensors_skipped: 0,
                    params_before: 10,
                    nonzero_params_after: 5,
                    memory_before_bytes: 40,
                    memory_after_bytes: 40,
                },
                StageStats {
                    stage_name: "quantize_int8".to_string(),
                    tensors_processed: 1,
                    tensors_skipped: 0,
                    params_before: 10,
                    nonzero_params_after: 5,
                    memory_before_bytes: 40,
                    memory_after_bytes: 10,
                },
            ],
        };
        assert_eq!(result.memory_before_bytes(), 40);
        assert_eq!(result.memory_after_bytes(), 10);
        assert!((result.total_compression_ratio() - 4.0).abs() < 1e-4);
    }

    // The pipeline must return one output tensor per input tensor.
    #[test]
    fn compress_model_returns_same_tensor_count() {
        let tensors = vec![
            make_tensor("layer1.weight", linear_data(8), vec![2, 4]),
            make_tensor("layer2.weight", linear_data(4), vec![2, 2]),
        ];
        let config = CompressionConfig::quantize_only();
        let result = compress_model(&tensors, &config).expect("compress ok");
        assert_eq!(result.compressed_tensors.len(), 2);
    }

    // Input has no zeros (values 1..=10), so any zero in the output came from pruning.
    #[test]
    fn compress_model_prune_reduces_nonzero() {
        let tensors = vec![make_tensor("layer.weight", linear_data(10), vec![10])];
        let config = CompressionConfig::prune_only(0.5);
        let result = compress_model(&tensors, &config).expect("compress ok");
        assert!(result.total_nonzero() < 10);
    }
}