use super::calibration::CalibrationContext;
use super::error::PruningError;
use super::importance::{Importance, ImportanceScores};
use super::mask::{SparsityMask, SparsityPattern};
use crate::nn::Module;
use std::collections::HashMap;
/// Summary of a completed pruning pass over a module.
#[derive(Debug, Clone)]
pub struct PruningResult {
/// Fraction of parameters that are zero after pruning, in `[0, 1]`.
pub achieved_sparsity: f32,
/// Number of parameters zeroed (includes any that were already zero).
pub parameters_pruned: usize,
/// Total number of parameters inspected in the pruned tensor.
pub total_parameters: usize,
/// Per-layer sparsity, keyed by layer name; populated via `with_layer_sparsity`.
pub layer_sparsity: HashMap<String, f32>,
/// Estimated bytes saved; computed as `parameters_pruned * 4`,
/// which assumes 4-byte (f32) parameters — TODO confirm element width.
pub memory_savings_bytes: usize,
}
impl PruningResult {
#[must_use]
pub fn new(achieved_sparsity: f32, parameters_pruned: usize, total_parameters: usize) -> Self {
Self {
achieved_sparsity,
parameters_pruned,
total_parameters,
layer_sparsity: HashMap::new(),
memory_savings_bytes: parameters_pruned * 4, }
}
#[must_use]
pub fn with_layer_sparsity(mut self, layer_name: String, sparsity: f32) -> Self {
self.layer_sparsity.insert(layer_name, sparsity);
self
}
#[must_use]
pub fn compression_ratio(&self) -> f32 {
if self.total_parameters == 0 || self.achieved_sparsity >= 1.0 {
return f32::INFINITY;
}
1.0 / (1.0 - self.achieved_sparsity)
}
}
impl Default for PruningResult {
fn default() -> Self {
Self::new(0.0, 0, 0)
}
}
/// Strategy interface for pruning a module: score importance, build a
/// sparsity mask, and apply it to the module's weights.
pub trait Pruner: Send + Sync {
/// Builds a [`SparsityMask`] from importance scores for the requested
/// `target_sparsity` and structural `pattern`.
///
/// # Errors
/// Returns a [`PruningError`] if a mask cannot be generated for the
/// given scores/pattern combination.
fn generate_mask(
&self,
scores: &ImportanceScores,
target_sparsity: f32,
pattern: SparsityPattern,
) -> Result<SparsityMask, PruningError>;
/// Applies `mask` to the module's weights, zeroing pruned entries, and
/// reports the achieved sparsity.
///
/// # Errors
/// Returns a [`PruningError`] if the module has no parameters or the
/// mask cannot be applied.
fn apply_mask(
&self,
module: &mut dyn Module,
mask: &SparsityMask,
) -> Result<PruningResult, PruningError>;
/// The importance-scoring strategy backing this pruner.
fn importance(&self) -> &dyn Importance;
/// Stable identifier for this pruner (e.g. for logging/registry keys).
fn name(&self) -> &'static str;
}
/// Pruner that ranks weights by magnitude (L1 or L2 importance).
#[derive(Debug, Clone)]
pub struct MagnitudePruner {
// Scoring backend; selected by the `l1()` / `l2()` constructors.
importance: super::magnitude::MagnitudeImportance,
}
impl MagnitudePruner {
    /// Creates a magnitude pruner with the default (L2) importance.
    ///
    /// Equivalent to [`MagnitudePruner::l2`].
    #[must_use]
    pub fn new() -> Self {
        Self::l2()
    }

    /// Creates a magnitude pruner scoring weights by L1 norm.
    #[must_use]
    pub fn l1() -> Self {
        let importance = super::magnitude::MagnitudeImportance::l1();
        Self { importance }
    }

    /// Creates a magnitude pruner scoring weights by L2 norm.
    #[must_use]
    pub fn l2() -> Self {
        let importance = super::magnitude::MagnitudeImportance::l2();
        Self { importance }
    }
}
impl Default for MagnitudePruner {
fn default() -> Self {
Self::new()
}
}
impl Pruner for MagnitudePruner {
    /// Dispatches to the pattern-specific mask generator in `super::mask`.
    fn generate_mask(
        &self,
        scores: &ImportanceScores,
        target_sparsity: f32,
        pattern: SparsityPattern,
    ) -> Result<SparsityMask, PruningError> {
        match pattern {
            SparsityPattern::Unstructured => {
                super::mask::generate_unstructured_mask(&scores.values, target_sparsity)
            }
            SparsityPattern::NM { n, m } => super::mask::generate_nm_mask(&scores.values, n, m),
            SparsityPattern::Block { height, width } => {
                super::mask::generate_block_mask(&scores.values, height, width, target_sparsity)
            }
            SparsityPattern::Row => super::mask::generate_row_mask(&scores.values, target_sparsity),
            SparsityPattern::Column => {
                super::mask::generate_column_mask(&scores.values, target_sparsity)
            }
        }
    }

    /// Applies `mask` to the module's first parameter tensor and measures
    /// the resulting sparsity by counting exact zeros (so pre-existing
    /// zeros are included in the count).
    fn apply_mask(
        &self,
        module: &mut dyn Module,
        mask: &SparsityMask,
    ) -> Result<PruningResult, PruningError> {
        let mut params = module.parameters_mut();
        if params.is_empty() {
            return Err(PruningError::NoParameters {
                module: "unknown".to_string(),
            });
        }
        // NOTE(review): only the first parameter tensor is pruned — confirm
        // this matches the intended per-layer call pattern.
        let weights = &mut *params[0];
        let total = weights.data().len();
        mask.apply(weights)?;
        let zeros = weights.data().iter().filter(|&&v| v == 0.0).count();
        // Guard the empty-tensor case: 0/0 would yield NaN sparsity.
        let achieved_sparsity = if total == 0 {
            0.0
        } else {
            zeros as f32 / total as f32
        };
        Ok(PruningResult::new(achieved_sparsity, zeros, total))
    }

    fn importance(&self) -> &dyn Importance {
        &self.importance
    }

    fn name(&self) -> &'static str {
        "magnitude_pruner"
    }
}
/// Pruner using Wanda-style importance (weight magnitude scaled by
/// calibration activations) for a specific named layer.
#[derive(Debug, Clone)]
pub struct WandaPruner {
// Layer-scoped Wanda importance backend.
importance: super::wanda::WandaImportance,
}
impl WandaPruner {
    /// Creates a Wanda pruner whose importance scores are scoped to the
    /// layer identified by `layer_name`.
    pub fn new(layer_name: impl Into<String>) -> Self {
        let importance = super::wanda::WandaImportance::new(layer_name);
        Self { importance }
    }
}
impl Pruner for WandaPruner {
    /// Dispatches to the pattern-specific mask generator in `super::mask`.
    fn generate_mask(
        &self,
        scores: &ImportanceScores,
        target_sparsity: f32,
        pattern: SparsityPattern,
    ) -> Result<SparsityMask, PruningError> {
        match pattern {
            SparsityPattern::Unstructured => {
                super::mask::generate_unstructured_mask(&scores.values, target_sparsity)
            }
            SparsityPattern::NM { n, m } => super::mask::generate_nm_mask(&scores.values, n, m),
            SparsityPattern::Block { height, width } => {
                super::mask::generate_block_mask(&scores.values, height, width, target_sparsity)
            }
            SparsityPattern::Row => super::mask::generate_row_mask(&scores.values, target_sparsity),
            SparsityPattern::Column => {
                super::mask::generate_column_mask(&scores.values, target_sparsity)
            }
        }
    }

    /// Applies `mask` to the module's first parameter tensor and measures
    /// the resulting sparsity by counting exact zeros (so pre-existing
    /// zeros are included in the count).
    fn apply_mask(
        &self,
        module: &mut dyn Module,
        mask: &SparsityMask,
    ) -> Result<PruningResult, PruningError> {
        let mut params = module.parameters_mut();
        if params.is_empty() {
            return Err(PruningError::NoParameters {
                module: "unknown".to_string(),
            });
        }
        // NOTE(review): only the first parameter tensor is pruned — confirm
        // this matches the intended per-layer call pattern.
        let weights = &mut *params[0];
        let total = weights.data().len();
        mask.apply(weights)?;
        let zeros = weights.data().iter().filter(|&&v| v == 0.0).count();
        // Guard the empty-tensor case: 0/0 would yield NaN sparsity.
        let achieved_sparsity = if total == 0 {
            0.0
        } else {
            zeros as f32 / total as f32
        };
        Ok(PruningResult::new(achieved_sparsity, zeros, total))
    }

    fn importance(&self) -> &dyn Importance {
        &self.importance
    }

    fn name(&self) -> &'static str {
        "wanda_pruner"
    }
}
/// Runs the complete pruning pipeline on `module`:
/// compute importance scores, build a mask, then apply it.
///
/// `context` supplies optional calibration data for importance methods
/// (such as Wanda) that need activations; magnitude-based methods may
/// ignore it.
///
/// # Errors
/// Propagates any [`PruningError`] from scoring, mask generation, or
/// mask application.
pub fn prune_module(
    module: &mut dyn Module,
    pruner: &dyn Pruner,
    target_sparsity: f32,
    pattern: SparsityPattern,
    context: Option<&CalibrationContext>,
) -> Result<PruningResult, PruningError> {
    let importance_scores = pruner.importance().compute(module, context)?;
    let sparsity_mask = pruner.generate_mask(&importance_scores, target_sparsity, pattern)?;
    pruner.apply_mask(module, &sparsity_mask)
}
#[cfg(test)]
#[path = "pruner_tests.rs"]
mod tests;