use crate::{CoreError, CoreResult};
use scirs2_core::ndarray::Array2;
#[allow(unused_imports)]
use scirs2_core::ndarray::Axis; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Scoring rule used to rank weights (or whole channels) by importance;
/// the least-important `target_sparsity` fraction is pruned.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PruningStrategy {
/// Absolute value per weight; mean absolute value when scoring channels.
Magnitude,
/// Sum of absolute values (L1 norm).
L1Norm,
/// Euclidean (L2) norm for channels; squared magnitude per element.
L2Norm,
/// Gradient-based saliency. `StructuredPruner` falls back to |w| for this
/// strategy; real gradient scoring requires `GradientPruner`.
Gradient,
/// Uniformly random scores — useful as a pruning baseline.
Random,
}
/// Structural unit at which the pruning mask is applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PruningGranularity {
/// Individual weights anywhere in the tensor.
Unstructured,
/// Whole rows of a 2D weight matrix (output channels).
Channel,
/// Whole filters; for 2D tensors this is handled the same as `Channel`.
Filter,
/// Attention heads — not yet implemented for 2D tensors.
Head,
/// Weight blocks — not yet implemented for 2D tensors.
Block,
}
/// Configuration for a pruning run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
/// Importance-scoring strategy.
pub strategy: PruningStrategy,
/// Structural unit at which to prune.
pub granularity: PruningGranularity,
/// Fraction of parameters to zero out; must lie in `[0, 1)`.
pub target_sparsity: f32,
/// Whether to use one threshold across all layers rather than per layer.
/// NOTE(review): not read by any pruner in this file — confirm intent.
pub global_threshold: bool,
/// Number of progressive-pruning rounds; must be > 0.
pub num_iterations: usize,
/// Whether to retain pruned weight values alongside the mask.
/// NOTE(review): `PruningMask::pruned_weights` is never populated here.
pub keep_pruned_weights: bool,
}
impl Default for PruningConfig {
fn default() -> Self {
Self {
strategy: PruningStrategy::Magnitude,
granularity: PruningGranularity::Unstructured,
target_sparsity: 0.5,
global_threshold: false,
num_iterations: 1,
keep_pruned_weights: false,
}
}
}
impl PruningConfig {
    /// Builds a config with the given strategy and sparsity target; every
    /// other setting comes from `Default`.
    pub fn new(strategy: PruningStrategy, target_sparsity: f32) -> Self {
        Self {
            strategy,
            target_sparsity,
            ..Default::default()
        }
    }

    /// Sets the pruning granularity (builder style).
    pub fn with_granularity(self, granularity: PruningGranularity) -> Self {
        Self { granularity, ..self }
    }

    /// Enables a single global threshold across layers (builder style).
    pub fn with_global_threshold(self) -> Self {
        Self {
            global_threshold: true,
            ..self
        }
    }

    /// Sets the number of progressive-pruning iterations (builder style).
    pub fn with_iterations(self, num_iterations: usize) -> Self {
        Self {
            num_iterations,
            ..self
        }
    }

    /// Requests retention of pruned weight values (builder style).
    pub fn with_keep_weights(self) -> Self {
        Self {
            keep_pruned_weights: true,
            ..self
        }
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    /// Returns `CoreError::InvalidConfig` when `target_sparsity` is outside
    /// `[0, 1)` or `num_iterations` is zero.
    pub fn validate(&self) -> CoreResult<()> {
        if self.target_sparsity < 0.0 || self.target_sparsity >= 1.0 {
            Err(CoreError::InvalidConfig(
                "target_sparsity must be in [0, 1)".into(),
            ))
        } else if self.num_iterations == 0 {
            Err(CoreError::InvalidConfig(
                "num_iterations must be > 0".into(),
            ))
        } else {
            Ok(())
        }
    }
}
/// A binary (0.0 / 1.0) mask over a 2D weight matrix, plus bookkeeping.
#[derive(Debug, Clone)]
pub struct PruningMask {
/// 1.0 keeps the weight at that position, 0.0 prunes it.
pub mask: Array2<f32>,
/// Original values of pruned weights, if retained.
/// NOTE(review): never populated by the pruners in this file — confirm.
pub pruned_weights: Option<Array2<f32>>,
/// Fraction of zero entries in `mask`.
pub sparsity: f32,
}
impl PruningMask {
    /// Builds a mask and records its sparsity (fraction of exact zeros).
    pub fn new(mask: Array2<f32>) -> Self {
        let total = mask.len();
        let zeros = mask.iter().filter(|&&x| x == 0.0).count();
        // Guard the zero-element case: `0 / 0` would yield a NaN sparsity
        // that silently poisons every downstream ratio computation.
        let sparsity = if total == 0 {
            0.0
        } else {
            zeros as f32 / total as f32
        };
        Self {
            mask,
            pruned_weights: None,
            sparsity,
        }
    }

    /// Applies the mask element-wise, zeroing pruned positions.
    pub fn apply(&self, weights: &Array2<f32>) -> Array2<f32> {
        weights * &self.mask
    }

    /// Number of surviving (non-zero) parameters.
    pub fn num_parameters(&self) -> usize {
        self.mask.iter().filter(|&&x| x != 0.0).count()
    }

    /// Dense-to-sparse compression ratio; the denominator is clamped so a
    /// fully pruned mask does not divide by zero.
    pub fn compression_ratio(&self) -> f32 {
        1.0 / (1.0 - self.sparsity).max(1e-6)
    }
}
/// Applies magnitude/norm-based pruning to 2D weight matrices and caches
/// the mask produced for each named layer.
pub struct StructuredPruner {
/// Validated pruning configuration.
config: PruningConfig,
/// Masks produced so far, keyed by layer name.
masks: HashMap<String, PruningMask>,
}
impl StructuredPruner {
pub fn new(config: PruningConfig) -> CoreResult<Self> {
config.validate()?;
Ok(Self {
config,
masks: HashMap::new(),
})
}
pub fn prune(&mut self, name: &str, weights: &Array2<f32>) -> CoreResult<PruningMask> {
let mask = match self.config.granularity {
PruningGranularity::Unstructured => self.prune_unstructured(weights)?,
PruningGranularity::Channel => self.prune_channels(weights)?,
PruningGranularity::Filter => self.prune_filters(weights)?,
_ => {
return Err(CoreError::InvalidConfig(format!(
"Granularity {:?} not yet implemented for 2D tensors",
self.config.granularity
)))
}
};
self.masks.insert(name.to_string(), mask.clone());
Ok(mask)
}
fn prune_unstructured(&self, weights: &Array2<f32>) -> CoreResult<PruningMask> {
let importance = self.compute_importance(weights)?;
let threshold = self.compute_threshold(&importance)?;
let mask = importance.mapv(|v| if v.abs() >= threshold { 1.0 } else { 0.0 });
Ok(PruningMask::new(mask))
}
fn prune_channels(&self, weights: &Array2<f32>) -> CoreResult<PruningMask> {
let (out_channels, _in_features) = weights.dim();
let mut channel_importance = Vec::with_capacity(out_channels);
for channel_idx in 0..out_channels {
let channel = weights.row(channel_idx);
let importance = match self.config.strategy {
PruningStrategy::L1Norm => channel.iter().map(|x| x.abs()).sum::<f32>(),
PruningStrategy::L2Norm => channel.iter().map(|x| x.powi(2)).sum::<f32>().sqrt(),
PruningStrategy::Magnitude => {
channel.iter().map(|x| x.abs()).sum::<f32>() / channel.len() as f32
}
_ => channel.iter().map(|x| x.abs()).sum::<f32>(),
};
channel_importance.push((channel_idx, importance));
}
channel_importance.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let num_to_prune = (out_channels as f32 * self.config.target_sparsity) as usize;
let mut mask = Array2::ones(weights.dim());
for &(channel_idx, _) in channel_importance.iter().take(num_to_prune) {
mask.row_mut(channel_idx).fill(0.0);
}
Ok(PruningMask::new(mask))
}
fn prune_filters(&self, weights: &Array2<f32>) -> CoreResult<PruningMask> {
self.prune_channels(weights)
}
fn compute_importance(&self, weights: &Array2<f32>) -> CoreResult<Array2<f32>> {
let importance = match self.config.strategy {
PruningStrategy::Magnitude => weights.mapv(|x| x.abs()),
PruningStrategy::L1Norm => weights.mapv(|x| x.abs()),
PruningStrategy::L2Norm => weights.mapv(|x| x.powi(2)),
PruningStrategy::Random => {
use scirs2_core::random::thread_rng;
let mut rng = thread_rng();
Array2::from_shape_fn(weights.dim(), |_| rng.random::<f32>())
}
PruningStrategy::Gradient => {
weights.mapv(|x| x.abs())
}
};
Ok(importance)
}
fn compute_threshold(&self, importance: &Array2<f32>) -> CoreResult<f32> {
let mut values: Vec<f32> = importance.iter().copied().collect();
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
let threshold_idx = (values.len() as f32 * self.config.target_sparsity) as usize;
let threshold = values.get(threshold_idx).copied().unwrap_or(0.0);
Ok(threshold)
}
pub fn prune_progressive(
&mut self,
name: &str,
weights: &Array2<f32>,
) -> CoreResult<Vec<PruningMask>> {
let mut masks = Vec::with_capacity(self.config.num_iterations);
let sparsity_per_iter = self.config.target_sparsity / self.config.num_iterations as f32;
let mut current_weights = weights.clone();
for iter in 0..self.config.num_iterations {
let iter_config = PruningConfig {
target_sparsity: sparsity_per_iter,
..self.config.clone()
};
let mut iter_pruner = StructuredPruner::new(iter_config)?;
let mask = iter_pruner.prune(&format!("{}_{}", name, iter), ¤t_weights)?;
current_weights = mask.apply(¤t_weights);
masks.push(mask);
}
if let Some(final_mask) = masks.last() {
self.masks.insert(name.to_string(), final_mask.clone());
}
Ok(masks)
}
pub fn get_mask(&self, name: &str) -> Option<&PruningMask> {
self.masks.get(name)
}
pub fn masks(&self) -> &HashMap<String, PruningMask> {
&self.masks
}
pub fn global_sparsity(&self) -> f32 {
if self.masks.is_empty() {
return 0.0;
}
let total_params: usize = self.masks.values().map(|m| m.mask.len()).sum();
let pruned_params: usize = self
.masks
.values()
.map(|m| m.mask.iter().filter(|&&x| x == 0.0).count())
.sum();
pruned_params as f32 / total_params as f32
}
pub fn global_compression_ratio(&self) -> f32 {
let sparsity = self.global_sparsity();
1.0 / (1.0 - sparsity).max(1e-6)
}
}
/// Pruner that scores weights by |weight * accumulated gradient|.
pub struct GradientPruner {
/// Underlying pruner holding the config and the mask cache.
pruner: StructuredPruner,
/// Running per-layer gradient sums, keyed by layer name.
gradient_accumulator: HashMap<String, Array2<f32>>,
}
impl GradientPruner {
    /// Creates a gradient-aware pruner; the inner `StructuredPruner`
    /// validates the configuration.
    pub fn new(config: PruningConfig) -> CoreResult<Self> {
        Ok(Self {
            pruner: StructuredPruner::new(config)?,
            gradient_accumulator: HashMap::new(),
        })
    }

    /// Adds `gradient` into the running sum for layer `name`, starting from
    /// zeros the first time the layer is seen.
    pub fn accumulate_gradient(&mut self, name: &str, gradient: &Array2<f32>) {
        let acc = self
            .gradient_accumulator
            .entry(name.to_string())
            .or_insert_with(|| Array2::zeros(gradient.dim()));
        *acc = &*acc + gradient;
    }

    /// Prunes using |weight * accumulated gradient| as the importance score
    /// and caches the resulting mask in the inner pruner.
    ///
    /// # Errors
    /// Returns `CoreError::InvalidConfig` when no gradients have been
    /// accumulated for `name`.
    pub fn prune_with_gradients(
        &mut self,
        name: &str,
        weights: &Array2<f32>,
    ) -> CoreResult<PruningMask> {
        let gradients = self
            .gradient_accumulator
            .get(name)
            .ok_or_else(|| CoreError::InvalidConfig("No gradients accumulated".into()))?;
        // First-order saliency: element-wise |w * g|.
        let importance = weights * gradients;
        let importance = importance.mapv(|x| x.abs());
        let threshold = self.compute_gradient_threshold(&importance)?;
        let mask = importance.mapv(|v| if v >= threshold { 1.0 } else { 0.0 });
        let pruning_mask = PruningMask::new(mask);
        self.pruner
            .masks
            .insert(name.to_string(), pruning_mask.clone());
        Ok(pruning_mask)
    }

    /// Threshold at the `target_sparsity` quantile of the importance values.
    fn compute_gradient_threshold(&self, importance: &Array2<f32>) -> CoreResult<f32> {
        let mut values: Vec<f32> = importance.iter().copied().collect();
        // `total_cmp` cannot panic on NaN, unlike the previous
        // `partial_cmp().unwrap()`; this also matches `StructuredPruner`.
        values.sort_unstable_by(f32::total_cmp);
        let threshold_idx = (values.len() as f32 * self.pruner.config.target_sparsity) as usize;
        Ok(values.get(threshold_idx).copied().unwrap_or(0.0))
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Builder sets strategy/sparsity and the result passes validation.
#[test]
fn test_pruning_config() {
let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5);
assert_eq!(config.strategy, PruningStrategy::Magnitude);
assert_eq!(config.target_sparsity, 0.5);
assert!(config.validate().is_ok());
}
// Out-of-range sparsity (either side) and zero iterations are rejected.
#[test]
fn test_pruning_config_validation() {
let mut config = PruningConfig::new(PruningStrategy::Magnitude, 1.5);
assert!(config.validate().is_err());
config.target_sparsity = -0.1;
assert!(config.validate().is_err());
config.target_sparsity = 0.5;
config.num_iterations = 0;
assert!(config.validate().is_err());
}
// Element-wise pruning on distinct magnitudes lands near the 0.5 target.
#[test]
fn test_unstructured_pruning() {
let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5);
let mut pruner = StructuredPruner::new(config).unwrap();
let weights = Array2::from_shape_fn((10, 10), |(i, j)| ((i * 10 + j) as f32) * 0.01);
let mask = pruner.prune("layer1", &weights).unwrap();
assert!(
mask.sparsity >= 0.45 && mask.sparsity <= 0.55,
"Expected sparsity ~0.5, got {}",
mask.sparsity
);
}
// Channel pruning must zero whole rows: every row is all-0 or all-1.
#[test]
fn test_channel_pruning() {
let config = PruningConfig::new(PruningStrategy::L2Norm, 0.5)
.with_granularity(PruningGranularity::Channel);
let mut pruner = StructuredPruner::new(config).unwrap();
let weights = Array2::from_shape_fn((8, 16), |(i, _j)| {
if i < 4 {
1.0
} else {
0.1
} });
let mask = pruner.prune("layer1", &weights).unwrap();
for row in mask.mask.axis_iter(Axis(0)) {
let sum: f32 = row.sum();
assert!(sum == 0.0 || sum == row.len() as f32);
}
}
// Applying a diagonal mask keeps exactly the diagonal entries.
#[test]
fn test_pruning_mask_apply() {
let mask_data = Array2::from_shape_fn((4, 4), |(i, j)| if i == j { 1.0 } else { 0.0 });
let mask = PruningMask::new(mask_data);
let weights = Array2::ones((4, 4));
let pruned = mask.apply(&weights);
for i in 0..4 {
for j in 0..4 {
if i == j {
assert_eq!(pruned[[i, j]], 1.0);
} else {
assert_eq!(pruned[[i, j]], 0.0);
}
}
}
}
// Progressive pruning yields one mask per iteration with
// non-decreasing sparsity across rounds.
#[test]
fn test_progressive_pruning() {
let config = PruningConfig::new(PruningStrategy::Magnitude, 0.6).with_iterations(3);
let mut pruner = StructuredPruner::new(config).unwrap();
let weights = Array2::from_shape_fn((8, 8), |(i, j)| (i as f32 + j as f32) * 0.1);
let masks = pruner.prune_progressive("layer1", &weights).unwrap();
assert_eq!(masks.len(), 3);
for i in 1..masks.len() {
assert!(masks[i].sparsity >= masks[i - 1].sparsity);
}
}
// A partially sparse mask gives a ratio strictly between 1x and 10x.
#[test]
fn test_compression_ratio() {
let mask = PruningMask::new(Array2::from_shape_fn((10, 10), |(i, j)| {
if i + j < 5 {
1.0
} else {
0.0
}
}));
let ratio = mask.compression_ratio();
assert!(ratio > 1.0); assert!(ratio < 10.0); }
// Sparsity aggregated over two pruned layers stays near the 0.5 target.
#[test]
fn test_global_sparsity() {
let config = PruningConfig::new(PruningStrategy::Magnitude, 0.5);
let mut pruner = StructuredPruner::new(config).unwrap();
let weights1 = Array2::from_shape_fn((4, 4), |(i, j)| (i + j) as f32);
let weights2 = Array2::from_shape_fn((4, 4), |(i, j)| (i * j) as f32);
pruner.prune("layer1", &weights1).unwrap();
pruner.prune("layer2", &weights2).unwrap();
let global_sparsity = pruner.global_sparsity();
assert!((0.4..=0.6).contains(&global_sparsity));
}
// Gradients accumulate element-wise: 1 + 2 == 3.
#[test]
fn test_gradient_pruner_accumulation() {
let config = PruningConfig::new(PruningStrategy::Gradient, 0.5);
let mut pruner = GradientPruner::new(config).unwrap();
let gradient1 = Array2::ones((4, 4));
let gradient2 = Array2::ones((4, 4)) * 2.0;
pruner.accumulate_gradient("layer1", &gradient1);
pruner.accumulate_gradient("layer1", &gradient2);
let accumulated = &pruner.gradient_accumulator["layer1"];
assert_eq!(accumulated[[0, 0]], 3.0);
}
// Random pruning on uniform weights still lands near the target sparsity.
#[test]
fn test_random_pruning() {
let config = PruningConfig::new(PruningStrategy::Random, 0.5);
let mut pruner = StructuredPruner::new(config).unwrap();
let weights = Array2::ones((10, 10));
let mask = pruner.prune("layer1", &weights).unwrap();
assert!(mask.sparsity >= 0.4 && mask.sparsity <= 0.6);
}
}