use crate::error::{ModelError, ModelResult};
use std::collections::HashMap;
/// Strategy used to decide which weights are zeroed out.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruneMethod {
    /// Zero individual weights with the smallest absolute value.
    MagnitudeUnstructured,
    /// Magnitude-based selection; in `compute_mask` this currently falls
    /// back to the same element-wise mask as `MagnitudeUnstructured`
    /// (whole-channel removal is handled by `prune_structured_channels`).
    StructuredMagnitude,
    /// Zero a deterministic pseudo-random subset of weights (fixed seed).
    RandomUnstructured,
}
/// Configuration controlling how a set of weight tensors is pruned.
#[derive(Debug, Clone)]
pub struct PruneConfig {
    /// Selection strategy for building the prune mask.
    pub method: PruneMethod,
    /// Target fraction of each tensor's elements to zero (e.g. 0.5 = 50%).
    pub sparsity: f32,
    /// If non-empty, only tensors whose names start with one of these
    /// prefixes are eligible for pruning.
    pub include_prefixes: Vec<String>,
    /// Tensors whose names contain any of these substrings are never pruned
    /// (despite the name, matching is by substring — see `should_prune`).
    pub exclude_prefixes: Vec<String>,
    /// Tensors with fewer elements than this are passed through untouched.
    pub min_tensor_size: usize,
}
impl Default for PruneConfig {
fn default() -> Self {
Self {
method: PruneMethod::MagnitudeUnstructured,
sparsity: 0.5,
include_prefixes: vec![],
exclude_prefixes: vec!["bias".to_string(), "norm".to_string()],
min_tensor_size: 16,
}
}
}
/// Per-tensor statistics produced by one pruning pass.
#[derive(Debug, Clone)]
pub struct TensorPruneResult {
    /// Tensor name (the key in the weight map).
    pub name: String,
    /// Count of non-zero elements before pruning.
    pub original_nonzero: usize,
    /// Count of non-zero elements after the mask was applied.
    pub pruned_nonzero: usize,
    /// Fraction of the tensor's elements newly zeroed by this pass
    /// (zeroed count / total element count).
    pub actual_sparsity: f32,
}
/// Aggregate outcome of a pruning pass over a whole weight map.
#[derive(Debug, Clone)]
pub struct PruneResult {
    /// One entry per tensor that was actually pruned; excluded or
    /// too-small tensors are not listed.
    pub tensor_results: Vec<TensorPruneResult>,
    /// Total element count across all tensors, including skipped ones.
    pub total_params: usize,
    /// Number of previously non-zero elements that were zeroed.
    pub pruned_params: usize,
    /// `pruned_params / total_params` (0.0 when there are no parameters).
    pub overall_sparsity: f32,
}
impl PruneResult {
    /// Ratio of the original parameter count to the parameters that remain
    /// after pruning. Returns `f32::INFINITY` when everything was pruned.
    pub fn compression_ratio(&self) -> f32 {
        match self.total_params.saturating_sub(self.pruned_params) {
            0 => f32::INFINITY,
            remaining => self.total_params as f32 / remaining as f32,
        }
    }
}
/// Per-element keep/drop mask; `true` means the weight is kept.
pub type PruneMask = Vec<bool>;

/// Applies a [`PruneConfig`] to named weight tensors.
pub struct ModelPruner {
    // Pruning parameters captured at construction time.
    config: PruneConfig,
}
impl ModelPruner {
    /// Creates a pruner that applies `config` to each tensor it is given.
    pub fn new(config: PruneConfig) -> Self {
        Self { config }
    }

    /// Computes a keep/drop mask (`true` = keep) for `values` using the
    /// configured method and target sparsity.
    ///
    /// The prune count is clamped to the slice length, so a sparsity above
    /// 1.0 prunes everything instead of panicking downstream (previously an
    /// unclamped count reached `random_mask`'s slice and could panic). A
    /// sparsity of 0.0 — or a negative/NaN value, which the `as usize` cast
    /// saturates to zero — keeps everything.
    pub fn compute_mask(&self, values: &[f32]) -> PruneMask {
        let n = values.len();
        let n_prune = ((n as f32 * self.config.sparsity) as usize).min(n);
        if n_prune == 0 {
            return vec![true; n];
        }
        match self.config.method {
            // StructuredMagnitude currently shares the element-wise mask;
            // whole-channel removal is done by `prune_structured_channels`.
            PruneMethod::MagnitudeUnstructured | PruneMethod::StructuredMagnitude => {
                magnitude_mask(values, n_prune)
            }
            PruneMethod::RandomUnstructured => random_mask(n, n_prune),
        }
    }

    /// Returns a copy of `values` with every masked-out (`false`) position
    /// set to 0.0.
    pub fn apply_mask(values: &[f32], mask: &PruneMask) -> Vec<f32> {
        values
            .iter()
            .zip(mask.iter())
            .map(|(&v, &keep)| if keep { v } else { 0.0 })
            .collect()
    }

    /// Prunes every eligible tensor in `weights`.
    ///
    /// Tensors rejected by name (`should_prune`) or smaller than
    /// `min_tensor_size` are copied through unchanged; they still count in
    /// `total_params` but get no `TensorPruneResult` entry. Returns the new
    /// weight map alongside aggregate statistics.
    pub fn prune_weights(
        &self,
        weights: &HashMap<String, Vec<f32>>,
    ) -> ModelResult<(HashMap<String, Vec<f32>>, PruneResult)> {
        let mut pruned_map: HashMap<String, Vec<f32>> = HashMap::with_capacity(weights.len());
        let mut tensor_results: Vec<TensorPruneResult> = Vec::new();
        let mut total_params: usize = 0;
        let mut pruned_params: usize = 0;
        for (name, values) in weights {
            total_params += values.len();
            if !self.should_prune(name) || values.len() < self.config.min_tensor_size {
                // Excluded or too-small tensors pass through untouched.
                pruned_map.insert(name.clone(), values.clone());
                continue;
            }
            let original_nonzero = values.iter().filter(|&&v| v != 0.0).count();
            let mask = self.compute_mask(values);
            let pruned_values = Self::apply_mask(values, &mask);
            let pruned_nonzero = pruned_values.iter().filter(|&&v| v != 0.0).count();
            // Only weights that were non-zero and became zero count as pruned.
            let zeroed = original_nonzero.saturating_sub(pruned_nonzero);
            pruned_params += zeroed;
            let actual_sparsity = if values.is_empty() {
                0.0
            } else {
                zeroed as f32 / values.len() as f32
            };
            tensor_results.push(TensorPruneResult {
                name: name.clone(),
                original_nonzero,
                pruned_nonzero,
                actual_sparsity,
            });
            pruned_map.insert(name.clone(), pruned_values);
        }
        let overall_sparsity = if total_params == 0 {
            0.0
        } else {
            pruned_params as f32 / total_params as f32
        };
        let result = PruneResult {
            tensor_results,
            total_params,
            pruned_params,
            overall_sparsity,
        };
        Ok((pruned_map, result))
    }

    /// Cubic sparsity schedule: holds `initial_sparsity` before
    /// `start_step`, then shrinks the gap `(initial - final)` by a cubic
    /// factor until `final_sparsity` is reached at `total_steps`.
    ///
    /// Progress advances in whole multiples of `prune_freq` — the integer
    /// division is deliberate, so the returned sparsity only changes once
    /// per `prune_freq` steps. A `prune_freq` of zero is treated as 1.
    pub fn schedule_sparsity(
        initial_sparsity: f32,
        final_sparsity: f32,
        step: usize,
        total_steps: usize,
        start_step: usize,
        prune_freq: usize,
    ) -> f32 {
        if step < start_step {
            return initial_sparsity;
        }
        let freq = prune_freq.max(1);
        let t = ((step - start_step) / freq) as f32;
        let t_total = ((total_steps.saturating_sub(start_step)) / freq) as f32;
        if t_total == 0.0 || t >= t_total {
            return final_sparsity;
        }
        let frac = 1.0 - t / t_total;
        final_sparsity + (initial_sparsity - final_sparsity) * frac * frac * frac
    }

    /// Decides whether a tensor named `name` may be pruned: any exclude
    /// entry appearing as a substring vetoes it; otherwise an empty include
    /// list accepts everything, and a non-empty one requires a matching
    /// name prefix.
    fn should_prune(&self, name: &str) -> bool {
        // `contains` subsumes `starts_with`, so the original's extra
        // `starts_with` test per exclude entry was dead code.
        if self
            .config
            .exclude_prefixes
            .iter()
            .any(|p| name.contains(p.as_str()))
        {
            return false;
        }
        if self.config.include_prefixes.is_empty() {
            return true;
        }
        self.config
            .include_prefixes
            .iter()
            .any(|p| name.starts_with(p.as_str()))
    }
}
pub fn prune_magnitude(
weights: &HashMap<String, Vec<f32>>,
sparsity: f32,
) -> ModelResult<(HashMap<String, Vec<f32>>, PruneResult)> {
let config = PruneConfig {
sparsity,
min_tensor_size: 0,
..Default::default()
};
ModelPruner::new(config).prune_weights(weights)
}
/// Removes whole output channels (rows) with the smallest L2 norm.
///
/// `shape` is interpreted as `[out_channels, rest...]`; a fraction
/// `sparsity` of the output channels is dropped and the surviving rows are
/// concatenated in order. Returns the compacted weights together with the
/// kept channel indices (ascending).
///
/// # Errors
/// Fails when `shape` has fewer than two dimensions, when the inner
/// dimensions multiply to zero, or when `weight.len()` disagrees with the
/// shape product.
pub fn prune_structured_channels(
    weight: &[f32],
    shape: &[usize],
    sparsity: f32,
) -> ModelResult<(Vec<f32>, Vec<usize>)> {
    if shape.len() < 2 {
        return Err(ModelError::invalid_config(
            "prune_structured_channels: shape must have at least 2 dimensions [out, in, ...]",
        ));
    }
    let out_channels = shape[0];
    let row_size: usize = shape[1..].iter().product();
    if row_size == 0 {
        return Err(ModelError::invalid_config(
            "prune_structured_channels: inner dimensions must not be zero",
        ));
    }
    let expected = out_channels * row_size;
    if weight.len() != expected {
        return Err(ModelError::invalid_config(format!(
            "prune_structured_channels: weight length {} does not match shape product {}",
            weight.len(),
            expected
        )));
    }
    let n_prune = (out_channels as f32 * sparsity) as usize;
    // L2 norm of each output-channel row, paired with its channel index.
    // The length check above guarantees exactly `out_channels` full chunks.
    let mut norms: Vec<(usize, f32)> = weight
        .chunks(row_size)
        .enumerate()
        .map(|(ch, row)| {
            let norm_sq: f32 = row.iter().map(|v| v * v).sum();
            (ch, norm_sq.sqrt())
        })
        .collect();
    // Stable ascending sort by norm; incomparable values (NaN) tie.
    norms.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
    // The first `n_prune` entries (clamped) are the weakest channels.
    let pruned_set: std::collections::HashSet<usize> = norms
        .iter()
        .take(n_prune.min(out_channels))
        .map(|&(i, _)| i)
        .collect();
    let kept_indices: Vec<usize> = (0..out_channels)
        .filter(|ch| !pruned_set.contains(ch))
        .collect();
    let mut pruned_weight: Vec<f32> = Vec::with_capacity(kept_indices.len() * row_size);
    for &ch in &kept_indices {
        pruned_weight.extend_from_slice(&weight[ch * row_size..(ch + 1) * row_size]);
    }
    Ok((pruned_weight, kept_indices))
}
/// Builds a keep-mask (`true` = keep) that drops the `n_prune` elements
/// with the smallest absolute value, ties broken by lower index first.
/// `n_prune` greater than the slice length prunes everything. The return
/// type is spelled as the concrete `Vec<bool>` (identical to the
/// `PruneMask` alias) so this private helper stands alone.
fn magnitude_mask(values: &[f32], n_prune: usize) -> Vec<bool> {
    let n = values.len();
    let mut indexed: Vec<(f32, usize)> = values
        .iter()
        .enumerate()
        .map(|(i, v)| (v.abs(), i))
        .collect();
    // The index tie-break makes this comparator a total order (for non-NaN
    // inputs), so `sort_unstable_by` yields the same result as the stable
    // sort while avoiding its temporary allocation.
    indexed.sort_unstable_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a.1.cmp(&b.1))
    });
    let mut mask = vec![true; n];
    for &(_, idx) in &indexed[..n_prune.min(n)] {
        mask[idx] = false;
    }
    mask
}
/// Builds a keep-mask (`true` = keep) that drops `n_prune` pseudo-randomly
/// chosen indices.
///
/// A fixed-seed LCG drives a Fisher–Yates shuffle, so the result is
/// deterministic for a given `n`/`n_prune`. `n_prune` is clamped to `n`:
/// the previous unclamped slice `&indices[..n_prune]` panicked when asked
/// to prune more elements than exist. The return type is spelled as the
/// concrete `Vec<bool>` (identical to the `PruneMask` alias) so this
/// private helper stands alone.
fn random_mask(n: usize, n_prune: usize) -> Vec<bool> {
    let mut indices: Vec<usize> = (0..n).collect();
    // Fixed seed (the 64-bit golden-ratio constant) keeps runs reproducible.
    let mut state: u64 = 0x9e3779b97f4a7c15u64;
    for i in (1..n).rev() {
        // LCG step (constants from Knuth's MMIX); take the high bits, which
        // have better statistical quality than the low bits.
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let j = (state >> 33) as usize % (i + 1);
        indices.swap(i, j);
    }
    let mut mask = vec![true; n];
    for &idx in &indices[..n_prune.min(n)] {
        mask[idx] = false;
    }
    mask
}
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults: 50% sparsity, magnitude method, bias/norm excluded.
    #[test]
    fn test_prune_config_default() {
        let cfg = PruneConfig::default();
        assert_eq!(cfg.sparsity, 0.5);
        assert!(cfg.exclude_prefixes.contains(&"bias".to_string()));
        assert!(cfg.exclude_prefixes.contains(&"norm".to_string()));
        assert_eq!(cfg.method, PruneMethod::MagnitudeUnstructured);
    }

    // Roughly half of 8 values should survive at 50% sparsity.
    #[test]
    fn test_magnitude_mask_correct_sparsity() {
        let pruner = ModelPruner::new(PruneConfig {
            sparsity: 0.5,
            ..Default::default()
        });
        let vals = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mask = pruner.compute_mask(&vals);
        let kept = mask.iter().filter(|&&b| b).count();
        assert!((3..=5).contains(&kept), "kept={kept}");
    }

    // `false` mask entries zero the corresponding values.
    #[test]
    fn test_apply_mask_zeros_out() {
        let vals = vec![1.0f32, 2.0, 3.0, 4.0];
        let mask = vec![true, false, true, false];
        let result = ModelPruner::apply_mask(&vals, &mask);
        assert_eq!(result, vec![1.0, 0.0, 3.0, 0.0]);
    }

    // Tensors matching an exclude entry ("bias") pass through unchanged.
    #[test]
    fn test_prune_weights_excludes_bias() {
        let mut weights = HashMap::new();
        weights.insert("layer.weight".to_string(), vec![1.0f32; 32]);
        weights.insert("layer.bias".to_string(), vec![1.0f32; 8]);
        let pruner = ModelPruner::new(PruneConfig {
            sparsity: 0.5,
            ..Default::default()
        });
        let (pruned, result) = pruner.prune_weights(&weights).unwrap();
        assert_eq!(pruned["layer.bias"], weights["layer.bias"]);
        let zeros = pruned["layer.weight"].iter().filter(|&&v| v == 0.0).count();
        assert!((12..=20).contains(&zeros), "zeros={zeros}");
        let _ = result;
    }

    // The prune_magnitude wrapper hits the requested sparsity and reports
    // a compression ratio above 1.
    #[test]
    fn test_prune_magnitude_convenience() {
        let mut weights = HashMap::new();
        weights.insert(
            "proj".to_string(),
            vec![0.1f32, 0.5, 0.01, 0.9, 0.2, 0.8, 0.05, 0.7],
        );
        let (pruned, result) = prune_magnitude(&weights, 0.5).unwrap();
        assert!(
            result.overall_sparsity >= 0.4 && result.overall_sparsity <= 0.6,
            "sparsity={}",
            result.overall_sparsity
        );
        assert!(
            result.compression_ratio() > 1.0,
            "ratio={}",
            result.compression_ratio()
        );
        let _ = pruned;
    }

    // Overall sparsity tracks the configured target on a 100-element tensor.
    #[test]
    fn test_prune_result_overall_sparsity() {
        let mut weights = HashMap::new();
        weights.insert(
            "w".to_string(),
            (0..100).map(|i| i as f32 * 0.01).collect::<Vec<_>>(),
        );
        let pruner = ModelPruner::new(PruneConfig {
            sparsity: 0.7,
            ..Default::default()
        });
        let (_, result) = pruner.prune_weights(&weights).unwrap();
        assert!(
            result.overall_sparsity >= 0.65 && result.overall_sparsity <= 0.75,
            "sparsity={}",
            result.overall_sparsity
        );
    }

    // The schedule starts at the initial sparsity and ends at the final one.
    #[test]
    fn test_schedule_sparsity_bounds() {
        let s0 = ModelPruner::schedule_sparsity(0.0, 0.9, 0, 100, 0, 1);
        assert!(
            (s0 - 0.0).abs() < 0.01,
            "s0 should be initial=0.0, got {s0}"
        );
        let s_end = ModelPruner::schedule_sparsity(0.0, 0.9, 100, 100, 0, 1);
        assert!(
            (s_end - 0.9).abs() < 0.01,
            "s_end should be 0.9, got {s_end}"
        );
    }

    // Dropping 50% of 4 channels from a [4, 3] weight keeps 2 rows.
    #[test]
    fn test_structured_channel_pruning() {
        let weight: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let shape = vec![4, 3];
        let (pruned, kept) = prune_structured_channels(&weight, &shape, 0.5).unwrap();
        assert_eq!(kept.len(), 2, "kept={kept:?}");
        assert_eq!(pruned.len(), 6, "pruned.len={}", pruned.len());
    }

    // Zero sparsity must be a no-op.
    #[test]
    fn test_prune_zero_sparsity_noop() {
        let mut weights = HashMap::new();
        weights.insert("w".to_string(), vec![1.0f32, 2.0, 3.0]);
        let (pruned, result) = prune_magnitude(&weights, 0.0).unwrap();
        assert_eq!(pruned["w"], weights["w"]);
        assert_eq!(result.pruned_params, 0);
    }

    // Before start_step the schedule reports the initial sparsity.
    #[test]
    fn test_schedule_sparsity_before_start() {
        let s = ModelPruner::schedule_sparsity(0.1, 0.9, 5, 100, 10, 1);
        assert!((s - 0.1).abs() < 1e-6, "s={s}");
    }

    // Random pruning removes exactly the requested element count.
    #[test]
    fn test_prune_method_random() {
        let pruner = ModelPruner::new(PruneConfig {
            method: PruneMethod::RandomUnstructured,
            sparsity: 0.5,
            ..Default::default()
        });
        let vals: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let mask = pruner.compute_mask(&vals);
        let kept = mask.iter().filter(|&&b| b).count();
        assert_eq!(kept, 50, "kept={kept}");
    }

    // No pruning means a compression ratio of exactly 1.0.
    #[test]
    fn test_compression_ratio_no_pruning() {
        let result = PruneResult {
            tensor_results: vec![],
            total_params: 100,
            pruned_params: 0,
            overall_sparsity: 0.0,
        };
        assert!((result.compression_ratio() - 1.0).abs() < 1e-6);
    }

    // A 1-D shape is rejected for structured pruning.
    #[test]
    fn test_prune_structured_bad_shape() {
        let weight = vec![1.0f32; 12];
        let result = prune_structured_channels(&weight, &[4], 0.5);
        assert!(result.is_err());
    }

    // Exclude entries veto pruning; include prefixes gate everything else.
    #[test]
    fn test_should_prune_include_exclude() {
        let pruner = ModelPruner::new(PruneConfig {
            include_prefixes: vec!["proj".to_string()],
            exclude_prefixes: vec!["proj.bias".to_string()],
            ..Default::default()
        });
        assert!(pruner.should_prune("proj.weight"));
        assert!(!pruner.should_prune("proj.bias"));
        assert!(!pruner.should_prune("embed.weight"));
    }
}