use super::SparseTensor;
use crate::autograd::Variable;
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use ndarray::{s, Array1, Array2, ArrayD};
use num_traits::{Float, FromPrimitive, One, Zero};
use std::cmp::Ordering;
/// Selects which algorithm `ModelPruner::prune_tensor` dispatches to.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruningStrategy {
/// Keep the entries with the largest absolute value.
Magnitude,
/// Keep a randomly shuffled subset of entries.
Random,
/// Keep whole rows of a 2-D tensor, ranked by their L2 norm.
/// Tensors with fewer than two dimensions are rejected; tensors with more
/// than two dimensions are handled by the magnitude strategy.
Structured,
/// Currently forwarded to the magnitude strategy by `ModelPruner`.
GradientBased,
/// Currently forwarded to the magnitude strategy by `ModelPruner`.
SNIP,
}
/// Settings consumed by `ModelPruner`.
#[derive(Debug, Clone)]
pub struct PruningConfig {
/// Fraction of elements to remove; used only when `schedule` is `None`.
pub target_sparsity: f32,
/// Algorithm that `ModelPruner::prune_tensor` dispatches to.
pub strategy: PruningStrategy,
/// NOTE(review): not read by any code visible in this file — confirm intended use.
pub structured: bool,
/// When present, its `current_sparsity()` is used instead of `target_sparsity`.
pub schedule: Option<PruningSchedule>,
}
/// Linear ramp from an initial sparsity to a final sparsity over a fixed
/// number of steps.
#[derive(Debug, Clone)]
pub struct PruningSchedule {
    /// Sparsity reported while `current_step` is 0.
    pub initial_sparsity: f32,
    /// Sparsity reported once `current_step` has reached `num_steps`.
    pub final_sparsity: f32,
    /// Number of steps the ramp takes to go from initial to final sparsity.
    pub num_steps: usize,
    /// Steps taken so far; `step` never lets this exceed `num_steps`.
    pub current_step: usize,
}

impl PruningSchedule {
    /// Builds a schedule positioned at step 0.
    pub fn new(initial_sparsity: f32, final_sparsity: f32, num_steps: usize) -> Self {
        Self {
            current_step: 0,
            num_steps,
            final_sparsity,
            initial_sparsity,
        }
    }

    /// Returns the sparsity for the current step.
    ///
    /// While the schedule is still running, this is the linear interpolation
    /// between the initial and final sparsity; afterwards (including when
    /// `num_steps` is 0) it is exactly `final_sparsity`.
    pub fn current_sparsity(&self) -> f32 {
        if self.current_step < self.num_steps {
            let progress = self.current_step as f32 / self.num_steps as f32;
            let span = self.final_sparsity - self.initial_sparsity;
            self.initial_sparsity + progress * span
        } else {
            self.final_sparsity
        }
    }

    /// Advances the schedule by one step, saturating at `num_steps`.
    pub fn step(&mut self) {
        let advanced = self.current_step + 1;
        self.current_step = std::cmp::min(advanced, self.num_steps);
    }
}
/// Prunes dense tensors into `SparseTensor`s according to a `PruningConfig`.
pub struct ModelPruner<T: Float> {
/// Strategy, target sparsity and optional schedule driving the pruning.
pub config: PruningConfig,
/// Absolute gradient values per parameter name, written by
/// `update_importance_scores`. The pruning methods in this file do not read it.
pub importance_scores: HashMap<String, Array1<T>>,
}
use std::collections::HashMap;
impl<
        T: Float
            + PartialOrd
            + Copy
            + Send
            + Sync
            + ndarray::ScalarOperand
            + FromPrimitive
            + std::ops::AddAssign,
    > ModelPruner<T>
{
    /// Creates a pruner with the given configuration and no importance scores.
    pub fn new(config: PruningConfig) -> Self {
        Self {
            config,
            importance_scores: HashMap::new(),
        }
    }

    /// Prunes `tensor` using the strategy selected in `self.config.strategy`.
    ///
    /// # Errors
    /// Returns whatever error the selected strategy or
    /// `SparseTensor::from_coo` reports.
    pub fn prune_tensor(&self, tensor: &ArrayD<T>) -> RusTorchResult<SparseTensor<T>> {
        match self.config.strategy {
            PruningStrategy::Magnitude => self.magnitude_pruning(tensor),
            PruningStrategy::Random => self.random_pruning(tensor),
            PruningStrategy::Structured => self.structured_pruning(tensor),
            PruningStrategy::GradientBased => self.gradient_based_pruning(tensor),
            PruningStrategy::SNIP => self.snip_pruning(tensor),
        }
    }

    /// Number of items out of `total` that survive at the current sparsity.
    ///
    /// The sparsity is clamped to `[0, 1]` and the result is capped at `total`
    /// so that out-of-range configuration values or `f32` rounding on very
    /// large counts can never produce a count larger than `total`.
    fn keep_count(&self, total: usize) -> usize {
        let sparsity = self.get_current_sparsity().clamp(0.0, 1.0);
        (((1.0 - sparsity) * total as f32) as usize).min(total)
    }

    /// Keeps the entries with the largest absolute value.
    ///
    /// Values that cannot be ordered (NaN) compare as equal to everything.
    fn magnitude_pruning(&self, tensor: &ArrayD<T>) -> RusTorchResult<SparseTensor<T>> {
        let elements_to_keep = self.keep_count(tensor.len());

        // Flat indices follow the tensor's logical (row-major) iteration order.
        let mut magnitude_indices: Vec<(usize, T)> = tensor
            .iter()
            .enumerate()
            .map(|(i, &val)| (i, val.abs()))
            .collect();
        // Stable descending sort: ties keep their original relative order.
        magnitude_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));

        let kept_indices: Vec<usize> = magnitude_indices
            .iter()
            .take(elements_to_keep)
            .map(|&(idx, _)| idx)
            .collect();
        self.create_sparse_from_indices(tensor, &kept_indices)
    }

    /// Keeps a randomly shuffled subset of entries.
    fn random_pruning(&self, tensor: &ArrayD<T>) -> RusTorchResult<SparseTensor<T>> {
        use rand::seq::SliceRandom;

        let total_elements = tensor.len();
        let elements_to_keep = self.keep_count(total_elements);

        let mut rng = rand::thread_rng();
        let mut all_indices: Vec<usize> = (0..total_elements).collect();
        all_indices.shuffle(&mut rng);

        // `keep_count` guarantees `elements_to_keep <= total_elements`, so the
        // slice below cannot go out of bounds.
        self.create_sparse_from_indices(tensor, &all_indices[..elements_to_keep])
    }

    /// Keeps whole rows of a 2-D tensor, ranked by L2 norm.
    ///
    /// Tensors with more than two dimensions are handled by magnitude pruning.
    ///
    /// # Errors
    /// Returns `RusTorchError::InvalidOperation` for tensors with fewer than
    /// two dimensions.
    fn structured_pruning(&self, tensor: &ArrayD<T>) -> RusTorchResult<SparseTensor<T>> {
        if tensor.ndim() < 2 {
            return Err(RusTorchError::InvalidOperation {
                operation: "structured_pruning".to_string(),
                message: "Structured pruning requires at least 2D tensors".to_string(),
            });
        }
        if tensor.ndim() > 2 {
            return self.magnitude_pruning(tensor);
        }

        let rows = tensor.shape()[0];
        let cols = tensor.shape()[1];
        let rows_to_keep = self.keep_count(rows);

        let mut row_norms: Vec<(usize, T)> = (0..rows)
            .map(|i| {
                let row = tensor.slice(s![i, ..]);
                let norm_sq = row.iter().map(|&x| x * x).fold(T::zero(), |a, b| a + b);
                (i, norm_sq.sqrt())
            })
            .collect();
        row_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));

        // Expand every surviving row into the flat indices of its elements.
        let kept_indices: Vec<usize> = row_norms
            .iter()
            .take(rows_to_keep)
            .flat_map(|&(row, _)| (row * cols)..(row * cols + cols))
            .collect();
        self.create_sparse_from_indices(tensor, &kept_indices)
    }

    /// Placeholder: currently identical to magnitude pruning.
    fn gradient_based_pruning(&self, tensor: &ArrayD<T>) -> RusTorchResult<SparseTensor<T>> {
        self.magnitude_pruning(tensor)
    }

    /// Placeholder: currently identical to magnitude pruning.
    fn snip_pruning(&self, tensor: &ArrayD<T>) -> RusTorchResult<SparseTensor<T>> {
        self.magnitude_pruning(tensor)
    }

    /// Builds a COO sparse tensor from the non-zero entries of `tensor` found
    /// at `kept_indices`.
    ///
    /// `kept_indices` are flat indices in logical (row-major) order.
    /// Out-of-range indices and entries whose value is zero are skipped.
    /// The coordinate of each kept entry along dimension `d` is stored in the
    /// `d`-th index array.
    ///
    /// NOTE(review): assumes `SparseTensor::from_coo` expects one index array
    /// per dimension in shape order — confirm against its definition.
    fn create_sparse_from_indices(
        &self,
        tensor: &ArrayD<T>,
        kept_indices: &[usize],
    ) -> RusTorchResult<SparseTensor<T>> {
        let shape = tensor.shape().to_vec();

        // Borrow the backing slice when the tensor is already contiguous in
        // row-major order; otherwise copy it out in logical order instead of
        // panicking.
        let flat: std::borrow::Cow<'_, [T]> = match tensor.as_slice() {
            Some(slice) => std::borrow::Cow::Borrowed(slice),
            None => std::borrow::Cow::Owned(tensor.iter().copied().collect()),
        };

        let mut indices_per_dim: Vec<Vec<usize>> = vec![Vec::new(); shape.len()];
        let mut values = Vec::with_capacity(kept_indices.len());

        for &flat_idx in kept_indices {
            let value = match flat.get(flat_idx) {
                Some(&v) if !v.is_zero() => v,
                _ => continue,
            };
            values.push(value);

            // Unravel the flat index, peeling off the fastest-varying (last)
            // dimension first. `dim` is the real axis number, so the
            // coordinate goes straight into that axis' index list.
            let mut remaining = flat_idx;
            for (dim, &dim_size) in shape.iter().enumerate().rev() {
                indices_per_dim[dim].push(remaining % dim_size);
                remaining /= dim_size;
            }
        }

        let indices: Vec<Array1<usize>> = indices_per_dim
            .into_iter()
            .map(Array1::from_vec)
            .collect();
        SparseTensor::from_coo(indices, Array1::from_vec(values), shape)
    }

    /// Sparsity to apply right now: the schedule's value when a schedule is
    /// configured, otherwise `config.target_sparsity`.
    fn get_current_sparsity(&self) -> f32 {
        match &self.config.schedule {
            Some(schedule) => schedule.current_sparsity(),
            None => self.config.target_sparsity,
        }
    }

    /// Stores the element-wise absolute value of `gradients` under
    /// `param_name`, replacing any previous entry.
    pub fn update_importance_scores(&mut self, param_name: &str, gradients: &ArrayD<T>) {
        let importance: Array1<T> = gradients.iter().map(|&grad| grad.abs()).collect();
        self.importance_scores
            .insert(param_name.to_string(), importance);
    }

    /// Prunes every parameter and returns the sparse results keyed by name.
    ///
    /// The schedule, if any, advances by one step only after all parameters
    /// were pruned successfully.
    ///
    /// # Errors
    /// Returns `RusTorchError::InvalidOperation` when a parameter's lock is
    /// poisoned, and propagates any error from `prune_tensor`.
    pub fn prune_model(
        &mut self,
        parameters: &HashMap<String, Variable<T>>,
    ) -> RusTorchResult<HashMap<String, SparseTensor<T>>> {
        let mut pruned_params = HashMap::with_capacity(parameters.len());

        for (name, param) in parameters {
            let param_tensor = param.data();
            let param_guard =
                param_tensor
                    .read()
                    .map_err(|_| RusTorchError::InvalidOperation {
                        operation: "prune_model".to_string(),
                        message: format!("Lock poisoned for parameter: {}", name),
                    })?;
            let sparse_param = self.prune_tensor(&param_guard.data)?;
            pruned_params.insert(name.clone(), sparse_param);
        }

        if let Some(schedule) = self.config.schedule.as_mut() {
            schedule.step();
        }
        Ok(pruned_params)
    }
}
/// Dense magnitude pruner: zeroes the smallest-magnitude entries of a tensor.
pub struct MagnitudePruner<T: Float> {
/// Stored by `new`; `prune` in this file does not read it.
pub global: bool,
/// Fraction of elements to zero; `new` clamps it to `[0, 1]`.
pub sparsity: f32,
// Ties the pruner to an element type without storing a value of it.
_phantom: std::marker::PhantomData<T>,
}
impl<T: Float + PartialOrd + Copy> MagnitudePruner<T> {
pub fn new(sparsity: f32, global: bool) -> Self {
Self {
global,
sparsity: sparsity.clamp(0.0, 1.0),
_phantom: std::marker::PhantomData,
}
}
pub fn prune(&self, tensor: &ArrayD<T>) -> RusTorchResult<ArrayD<T>> {
let flat_tensor = tensor.as_slice().unwrap();
let total_elements = flat_tensor.len();
let elements_to_zero = (self.sparsity * total_elements as f32) as usize;
let mut magnitude_indices: Vec<(usize, T)> = flat_tensor
.iter()
.enumerate()
.map(|(i, &val)| (i, val.abs()))
.collect();
magnitude_indices.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
let mut pruned = tensor.clone();
let pruned_flat = pruned.as_slice_mut().unwrap();
for i in 0..elements_to_zero.min(magnitude_indices.len()) {
let idx = magnitude_indices[i].0;
pruned_flat[idx] = T::zero();
}
Ok(pruned)
}
}
/// Pruner for 2-D weight matrices that removes structure at a chosen granularity.
pub struct StructuredPruner<T: Float> {
/// Unit of structure that `prune_linear_weights` removes.
pub granularity: StructuredGranularity,
/// Fraction of units to prune; `new` clamps it to `[0, 1]`.
pub ratio: f32,
// Ties the pruner to an element type without storing a value of it.
_phantom: std::marker::PhantomData<T>,
}
/// Unit of structure removed by `StructuredPruner`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StructuredGranularity {
/// `prune_linear_weights` zeroes whole rows, lowest L2 norm first.
Neuron,
/// `prune_linear_weights` currently falls back to element-wise magnitude pruning.
Channel,
/// `prune_linear_weights` currently falls back to element-wise magnitude pruning.
Filter,
}
impl<T: Float + PartialOrd + Copy> StructuredPruner<T> {
pub fn new(granularity: StructuredGranularity, ratio: f32) -> Self {
Self {
granularity,
ratio: ratio.clamp(0.0, 1.0),
_phantom: std::marker::PhantomData,
}
}
pub fn prune_linear_weights(&self, weights: &Array2<T>) -> RusTorchResult<Array2<T>> {
let (rows, cols) = weights.dim();
match self.granularity {
StructuredGranularity::Neuron => {
let neurons_to_prune = (self.ratio * rows as f32) as usize;
let mut neuron_norms: Vec<(usize, T)> = (0..rows)
.map(|i| {
let row = weights.row(i);
let norm_sq = row.iter().map(|&x| x * x).fold(T::zero(), |a, b| a + b);
(i, norm_sq.sqrt())
})
.collect();
neuron_norms.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
let mut pruned_weights = weights.clone();
for i in 0..neurons_to_prune.min(neuron_norms.len()) {
let neuron_idx = neuron_norms[i].0;
for j in 0..cols {
pruned_weights[[neuron_idx, j]] = T::zero();
}
}
Ok(pruned_weights)
}
_ => {
let flattened = weights.clone().into_dyn();
let magnitude_pruner = MagnitudePruner::new(self.ratio, false);
let pruned_flat = magnitude_pruner.prune(&flattened)?;
Ok(Array2::from_shape_vec(
(rows, cols),
pruned_flat.into_raw_vec_and_offset().0,
)?)
}
}
}
}
/// Pruner that ranks weights by accumulated squared gradients.
pub struct FisherPruner<T: Float> {
/// Squared-gradient accumulator per parameter name, maintained by `update_fisher`.
pub fisher_info: HashMap<String, ArrayD<T>>,
/// Total number of `update_fisher` calls, counted across all parameter names.
pub n_samples: usize,
}
impl<
        T: Float + std::ops::AddAssign + Copy + ndarray::ScalarOperand + Send + Sync + FromPrimitive,
    > FisherPruner<T>
{
    /// Creates an empty pruner with no accumulated gradient statistics.
    pub fn new() -> Self {
        Self {
            n_samples: 0,
            fisher_info: HashMap::new(),
        }
    }

    /// Folds the element-wise square of `gradients` into the accumulator for
    /// `param_name`.
    ///
    /// The first call for a name stores the squared gradients as-is. Later
    /// calls blend them in with weight `1 / (n_samples + 1)`. `n_samples` is
    /// shared by all parameter names and is incremented on every call.
    pub fn update_fisher(&mut self, param_name: &str, gradients: &ArrayD<T>) {
        let squared_grads = gradients.mapv(|g| g * g);

        if let Some(running) = self.fisher_info.get_mut(param_name) {
            let alpha = T::one() / T::from(self.n_samples + 1).unwrap();
            *running = &*running * (T::one() - alpha) + &squared_grads * alpha;
        } else {
            self.fisher_info
                .insert(param_name.to_string(), squared_grads);
        }

        self.n_samples += 1;
    }

    /// Keeps the entries of `tensor` whose accumulated score for `param_name`
    /// is highest and returns them as a sparse tensor.
    ///
    /// `target_sparsity` is the fraction of elements to drop. Scores that
    /// cannot be ordered compare as equal; ties keep their original order.
    ///
    /// # Errors
    /// * `RusTorchError::InvalidParameters` when nothing has been accumulated
    ///   for `param_name`.
    /// * `RusTorchError::ShapeMismatch` when the accumulator's shape differs
    ///   from `tensor`'s.
    /// * Any error reported while building the sparse tensor.
    pub fn prune_with_fisher(
        &self,
        param_name: &str,
        tensor: &ArrayD<T>,
        target_sparsity: f32,
    ) -> RusTorchResult<SparseTensor<T>> {
        let fisher_scores = match self.fisher_info.get(param_name) {
            Some(scores) => scores,
            None => {
                return Err(RusTorchError::InvalidParameters {
                    operation: "fisher_pruning".to_string(),
                    message: format!(
                        "No Fisher information available for parameter: {}",
                        param_name
                    ),
                });
            }
        };

        if fisher_scores.shape() != tensor.shape() {
            return Err(RusTorchError::ShapeMismatch {
                expected: tensor.shape().to_vec(),
                actual: fisher_scores.shape().to_vec(),
            });
        }

        let total_elements = tensor.len();
        let elements_to_keep = ((1.0 - target_sparsity) * total_elements as f32) as usize;

        // Rank flat positions by score, highest first, using a stable sort.
        let scores: Vec<T> = fisher_scores.iter().copied().collect();
        let mut ranking: Vec<usize> = (0..scores.len()).collect();
        ranking.sort_by(|&lhs, &rhs| {
            scores[rhs]
                .partial_cmp(&scores[lhs])
                .unwrap_or(Ordering::Equal)
        });
        ranking.truncate(elements_to_keep);

        // A throwaway ModelPruner gives access to the shared COO builder.
        let sparse_builder = ModelPruner::new(PruningConfig {
            target_sparsity,
            strategy: PruningStrategy::GradientBased,
            structured: false,
            schedule: None,
        });
        sparse_builder.create_sparse_from_indices(tensor, &ranking)
    }
}
// Unit tests for the pruning utilities in this file.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array2;
// Magnitude strategy at 50% sparsity on six elements: at most three
// non-zeros may remain.
#[test]
fn test_magnitude_pruning() {
let config = PruningConfig {
target_sparsity: 0.5,
strategy: PruningStrategy::Magnitude,
structured: false,
schedule: None,
};
let pruner = ModelPruner::new(config);
let tensor = Array2::from_shape_vec((2, 3), vec![1.0f32, -2.0, 0.5, -4.0, 3.0, 0.1])
.unwrap()
.into_dyn();
let sparse_result = pruner.prune_tensor(&tensor).unwrap();
assert!(sparse_result.nnz <= 3);
assert!(sparse_result.sparsity() >= 0.4);
}
// Neuron granularity at ratio 0.5 on four rows: rows 1 and 3 have the
// smallest norms, so exactly two rows should end up all-zero.
#[test]
fn test_structured_pruning() {
let structured_pruner = StructuredPruner::new(StructuredGranularity::Neuron, 0.5);
let weights = Array2::from_shape_vec(
(4, 3),
vec![
1.0f32, 2.0, 3.0, 0.1, 0.1, 0.1, -2.0, 1.5, -1.0, 0.05, 0.02, 0.03, ],
)
.unwrap();
let pruned = structured_pruner.prune_linear_weights(&weights).unwrap();
let zero_rows = (0..4)
.filter(|&i| pruned.row(i).iter().all(|&x| x == 0.0))
.count();
assert_eq!(zero_rows, 2);
}
// The schedule starts at the initial sparsity, lies strictly between the
// endpoints after one step, and saturates at the final sparsity.
#[test]
fn test_pruning_schedule() {
let mut schedule = PruningSchedule::new(0.0, 0.9, 10);
assert_eq!(schedule.current_sparsity(), 0.0);
schedule.step();
assert!(schedule.current_sparsity() > 0.0 && schedule.current_sparsity() < 0.9);
for _ in 0..10 {
schedule.step();
}
assert_eq!(schedule.current_sparsity(), 0.9);
}
// With 50% sparsity on four non-zero weights, the two entries with the
// largest squared gradients are kept, so two non-zeros remain.
#[test]
fn test_fisher_pruner() {
let mut fisher_pruner = FisherPruner::new();
let gradients = Array2::from_shape_vec((2, 2), vec![0.1f32, 0.9, 0.3, 0.7])
.unwrap()
.into_dyn();
fisher_pruner.update_fisher("layer1", &gradients);
let weights = Array2::from_shape_vec((2, 2), vec![1.0f32, 2.0, 3.0, 4.0])
.unwrap()
.into_dyn();
let sparse_result = fisher_pruner
.prune_with_fisher("layer1", &weights, 0.5)
.unwrap();
assert!(sparse_result.nnz == 2);
}
}