use crate::Module;
use crate::Parameter;
use torsh_tensor::Tensor;
#[cfg(feature = "std")]
use std::{boxed::Box, collections::HashMap, string::String, vec::Vec};
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Algorithm used to decide which weights are masked out.
#[derive(Debug, Clone)]
pub enum PruningStrategy {
/// Mask individual weights by absolute value, using `PruningConfig::sparsity`
/// as the target fraction.
MagnitudeBased,
/// For 4-D tensors, mask whole slices along the first dimension (called
/// `out_channels` in the code), ranked by L2 norm. Tensors of any other rank
/// currently fall back to the magnitude-based mask.
Structured,
/// Sparsity schedule that is linearly interpolated between two training steps.
/// The initial mask is built with `initial_sparsity`; `Pruner::update_masks`
/// computes the interpolated target while `begin_step <= step <= end_step`.
Gradual {
/// Sparsity used when the mask is first created.
initial_sparsity: f32,
/// Sparsity reached at `end_step`.
final_sparsity: f32,
/// First step (inclusive) at which the schedule is active.
begin_step: usize,
/// Last step (inclusive) at which the schedule is active.
end_step: usize,
},
/// Currently produces the same mask as `MagnitudeBased`.
LotteryTicket,
}
/// Selects which named parameters of a module are pruned.
#[derive(Debug, Clone)]
pub enum PruningScope {
/// Every parameter is pruned.
Global,
/// Only parameters whose name contains at least one of these strings
/// (substring match).
LayerSpecific(Vec<String>),
/// Only parameters whose name contains this string (substring match).
LayerType(String),
}
/// Settings that drive a `Pruner`.
#[derive(Debug, Clone)]
pub struct PruningConfig {
/// Algorithm used to build masks.
pub strategy: PruningStrategy,
/// Which parameters are eligible for pruning.
pub scope: PruningScope,
/// Target fraction of weights (or channels, for structured pruning) to mask.
/// `Pruner::gradual` sets this to the schedule's final sparsity.
pub sparsity: f32,
/// Set to `true` by `Pruner::structured` and `false` by the other
/// constructors. NOTE(review): no code visible in this file reads this flag;
/// confirm whether other modules depend on it.
pub structured: bool,
}
/// A mask tensor together with the bookkeeping derived from it.
#[derive(Debug, Clone)]
pub struct PruningMask {
/// Mask values; entries equal to `0.0` are counted as pruned, and the mask is
/// applied by element-wise multiplication.
pub mask: Tensor<f32>,
/// Name of the parameter this mask is associated with.
pub parameter_name: String,
/// Fraction of mask entries equal to `0.0`, computed when the mask is built.
pub sparsity: f32,
}
impl PruningMask {
    /// Counts the mask entries that are exactly `0.0`.
    ///
    /// If the tensor data cannot be read the count is reported as `0`; this
    /// keeps the best-effort behaviour of the public accessors, which have no
    /// way to return an error.
    fn count_zeros(mask: &Tensor<f32>) -> usize {
        mask.data()
            .map(|values| values.iter().filter(|&&v| v == 0.0).count())
            .unwrap_or(0)
    }

    /// Wraps `mask` and records the fraction of its entries that are zero.
    ///
    /// An empty mask (zero elements) gets a sparsity of `0.0` rather than the
    /// `NaN` that `0 / 0` would produce.
    pub fn new(mask: Tensor<f32>, parameter_name: String) -> Self {
        let total_elements = mask.numel();
        let sparsity = if total_elements == 0 {
            0.0
        } else {
            Self::count_zeros(&mask) as f32 / total_elements as f32
        };
        Self {
            mask,
            parameter_name,
            sparsity,
        }
    }

    /// Returns a new parameter whose tensor is `parameter`'s tensor multiplied
    /// element-wise by this mask. The input parameter is not modified.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the tensor multiplication.
    pub fn apply(&self, parameter: &Parameter) -> Result<Parameter, Box<dyn std::error::Error>> {
        let data = parameter.tensor().read().clone();
        let masked_data = data.mul_op(&self.mask)?;
        Ok(Parameter::new(masked_data))
    }

    /// Number of mask entries that are exactly `0.0`.
    pub fn pruned_count(&self) -> usize {
        Self::count_zeros(&self.mask)
    }

    /// Total number of entries in the mask.
    pub fn total_count(&self) -> usize {
        self.mask.numel()
    }
}
/// Builds and tracks pruning masks for the parameters of a module.
pub struct Pruner {
// Strategy, scope and target sparsity used when creating masks.
config: PruningConfig,
// Masks created so far, keyed by parameter name.
masks: HashMap<String, PruningMask>,
// Step counter; incremented once per `update_masks` call.
current_step: usize,
}
impl Pruner {
/// Creates a pruner with the given configuration, no masks, and the step
/// counter at zero.
pub fn new(config: PruningConfig) -> Self {
    let masks = HashMap::new();
    let current_step = 0;
    Self {
        config,
        masks,
        current_step,
    }
}
/// Builds a pruning mask for every parameter of `module` selected by the
/// configured scope and stores it under the parameter's name.
///
/// An existing mask for the same name is replaced. The module itself is
/// not modified.
///
/// # Errors
///
/// Returns the first error raised while building a mask; masks stored
/// before the failure are kept.
pub fn prune_module<M: Module>(
    &mut self,
    module: &M,
) -> Result<(), Box<dyn std::error::Error>> {
    let named_params = module.named_parameters();
    for (name, param) in named_params.iter() {
        if !self.should_prune_parameter(name) {
            continue;
        }
        let key = name.to_string();
        let mut mask = self.create_mask(param)?;
        // `create_mask` does not know the parameter's name and labels the
        // mask with a placeholder; record the real name here.
        mask.parameter_name = key.clone();
        self.masks.insert(key, mask);
    }
    Ok(())
}
/// Intended to apply the stored masks to the parameters of a module.
///
/// NOTE(review): this is currently a no-op. `_module` and the stored masks
/// are ignored and `Ok(())` is always returned, so callers must not assume
/// that weights have been zeroed after calling it.
pub fn apply_masks<M: Module>(
&self,
_module: &mut M,
) -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
/// Advances the step counter by one and, for the `Gradual` strategy,
/// computes the scheduled sparsity for the current step.
///
/// The schedule is a linear interpolation from `initial_sparsity` at
/// `begin_step` to `final_sparsity` at `end_step` (both inclusive). A
/// zero-length schedule (`begin_step == end_step`) is treated as already
/// complete instead of evaluating `0 / 0`.
///
/// NOTE(review): masks whose sparsity is below the scheduled value are
/// detected but not yet rebuilt. Doing so needs the parameter data, which
/// is not available here, so today the only observable effect of this
/// method is the step increment.
pub fn update_masks(&mut self) {
    if let PruningStrategy::Gradual {
        initial_sparsity,
        final_sparsity,
        begin_step,
        end_step,
    } = &self.config.strategy
    {
        let (initial, target) = (*initial_sparsity, *final_sparsity);
        let (begin, end) = (*begin_step, *end_step);
        let step = self.current_step;
        if step >= begin && step <= end {
            // `end >= begin` is implied by the guard above, so the
            // subtractions cannot underflow.
            let progress = if end > begin {
                (step - begin) as f32 / (end - begin) as f32
            } else {
                1.0
            };
            let current_sparsity = initial + progress * (target - initial);
            for mask in self.masks.values_mut() {
                if mask.sparsity < current_sparsity {
                    // TODO: rebuild this mask at `current_sparsity` once the
                    // pruner has access to the parameter values.
                }
            }
        }
    }
    self.current_step += 1;
}
pub fn get_sparsity_stats(&self) -> HashMap<String, f32> {
self.masks
.iter()
.map(|(name, mask)| (name.clone(), mask.sparsity))
.collect()
}
/// Returns the fraction of pruned entries across all stored masks.
///
/// Yields `0.0` when there are no masks, and also when the stored masks
/// contain no elements at all (the latter used to produce `NaN`).
pub fn get_total_sparsity(&self) -> f32 {
    // Accumulate both counters in a single pass over the masks.
    let (pruned, total) = self
        .masks
        .values()
        .fold((0usize, 0usize), |(pruned, total), mask| {
            (pruned + mask.pruned_count(), total + mask.total_count())
        });
    if total == 0 {
        return 0.0;
    }
    pruned as f32 / total as f32
}
/// Reports whether the configured scope selects `param_name`.
///
/// `Global` selects everything; the other scopes use substring matching
/// against the parameter name.
fn should_prune_parameter(&self, param_name: &str) -> bool {
    let name_contains = |fragment: &String| param_name.contains(fragment.as_str());
    match &self.config.scope {
        PruningScope::LayerType(layer_type) => name_contains(layer_type),
        PruningScope::LayerSpecific(names) => names.iter().any(name_contains),
        PruningScope::Global => true,
    }
}
/// Builds a mask for `parameter` according to the configured strategy.
///
/// `MagnitudeBased` and `LotteryTicket` share the magnitude mask,
/// `Structured` uses the structured mask, and `Gradual` starts from a
/// magnitude mask at its `initial_sparsity`.
///
/// The returned mask is labelled with the placeholder name `"parameter"`
/// because this method is not told the parameter's real name.
fn create_mask(
    &self,
    parameter: &Parameter,
) -> Result<PruningMask, Box<dyn std::error::Error>> {
    let weights = parameter.tensor().read().clone();
    let mask_tensor = match &self.config.strategy {
        PruningStrategy::Structured => self.create_structured_mask(&weights),
        PruningStrategy::Gradual {
            initial_sparsity, ..
        } => self.create_magnitude_mask_with_sparsity(&weights, *initial_sparsity),
        PruningStrategy::MagnitudeBased | PruningStrategy::LotteryTicket => {
            self.create_magnitude_mask(&weights)
        }
    }?;
    Ok(PruningMask::new(mask_tensor, String::from("parameter")))
}
/// Builds a 0/1 mask that zeroes the smallest-magnitude entries of `data`,
/// using `self.config.sparsity` as the target fraction.
///
/// With `n` entries, `k = floor(n * sparsity)` (clamped to `0..=n`) entries
/// are targeted: the threshold is the `k`-th smallest absolute value and
/// every entry whose magnitude is not strictly greater than it is masked.
/// Ties at the threshold are all masked, so the achieved sparsity can
/// exceed the target when values repeat. `k == 0` (including an empty
/// tensor) yields an all-ones mask.
///
/// The mask has the shape and device of `data`.
///
/// # Errors
///
/// Propagates failures from reading the tensor data or creating the mask
/// tensor.
fn create_magnitude_mask(
    &self,
    data: &Tensor<f32>,
) -> Result<Tensor<f32>, Box<dyn std::error::Error>> {
    let abs_values = data.abs()?;
    let magnitudes: Vec<f32> = abs_values.data()?;

    // `total_cmp` gives a total order, so NaN entries sort last instead of
    // panicking the comparison.
    let mut sorted = magnitudes.clone();
    sorted.sort_unstable_by(f32::total_cmp);

    // Float-to-int `as` saturates: negative or NaN sparsity becomes 0, and
    // `min` caps anything above 1.0 at "prune everything".
    let prune_count = ((sorted.len() as f32 * self.config.sparsity) as usize).min(sorted.len());
    // `None` means nothing is pruned; otherwise the k-th smallest magnitude.
    let threshold = prune_count
        .checked_sub(1)
        .and_then(|idx| sorted.get(idx))
        .copied();

    let mask_values: Vec<f32> = magnitudes
        .iter()
        .map(|&magnitude| {
            let keep = match threshold {
                Some(limit) => magnitude > limit,
                None => true,
            };
            if keep {
                1.0
            } else {
                0.0
            }
        })
        .collect();

    let mask = Tensor::from_data(mask_values, data.shape().dims().to_vec(), data.device())?;
    Ok(mask)
}
/// Builds a mask for structured pruning, dispatching on tensor rank.
///
/// 4-D tensors get a per-channel mask; tensors of every other rank
/// (including 2-D) currently use the element-wise magnitude mask.
fn create_structured_mask(
    &self,
    data: &Tensor<f32>,
) -> Result<Tensor<f32>, Box<dyn std::error::Error>> {
    let rank = data.shape().dims().len();
    match rank {
        4 => self.create_channel_mask(data),
        _ => self.create_magnitude_mask(data),
    }
}
/// Builds a 0/1 mask that zeroes whole slices along the first dimension of
/// `data`, choosing the slices with the smallest L2 norm.
///
/// `floor(out_channels * self.config.sparsity)` slices are masked, where
/// `out_channels` is the size of the first dimension. Slices with equal
/// norms are masked in index order (the sort is stable). The mask has the
/// shape and device of `data`, and assumes the elements of one slice are
/// contiguous in the flattened data.
///
/// A tensor with no dimensions falls back to the magnitude mask.
///
/// # Errors
///
/// Propagates failures from the tensor operations used to compute the
/// norms or to create the mask tensor.
fn create_channel_mask(
    &self,
    data: &Tensor<f32>,
) -> Result<Tensor<f32>, Box<dyn std::error::Error>> {
    let shape = data.shape();
    let dims = shape.dims();
    let out_channels = match dims.first() {
        Some(&count) => count,
        None => return self.create_magnitude_mask(data),
    };
    let total = data.numel();

    // L2 norm of every slice along dimension 0.
    let mut channel_norms = Vec::with_capacity(out_channels);
    for i in 0..out_channels {
        let channel_data = data.slice(0, i, i + 1)?;
        let channel_tensor = channel_data.to_tensor()?;
        let squared = channel_tensor.mul_op(&channel_tensor)?;
        let sum_squared = squared.sum()?;
        let norm = sum_squared.sqrt()?;
        channel_norms.push((i, norm.item()?));
    }
    // Total order: a NaN norm sorts last instead of panicking.
    channel_norms.sort_by(|a, b| a.1.total_cmp(&b.1));

    let channels_to_prune = (out_channels as f32 * self.config.sparsity) as usize;
    let mut keep_channel = vec![true; out_channels];
    for &(idx, _) in channel_norms.iter().take(channels_to_prune) {
        keep_channel[idx] = false;
    }

    // Elements per slice, computed once rather than for every element.
    let channel_len = if out_channels == 0 {
        0
    } else {
        total / out_channels
    };
    let mut mask_data: Vec<f32> = Vec::with_capacity(total);
    for &keep in &keep_channel {
        let value = if keep { 1.0 } else { 0.0 };
        mask_data.resize(mask_data.len() + channel_len, value);
    }

    let mask = Tensor::from_data(mask_data, dims.to_vec(), data.device())?;
    Ok(mask)
}
/// Builds a 0/1 mask that zeroes the smallest-magnitude entries of `data`,
/// using the explicit `sparsity` argument as the target fraction.
///
/// With `n` entries, `k = floor(n * sparsity)` (clamped to `0..=n`) entries
/// are targeted: the threshold is the `k`-th smallest absolute value and
/// every entry whose magnitude is not strictly greater than it is masked.
/// Ties at the threshold are all masked. `k == 0` (including an empty
/// tensor) yields an all-ones mask.
///
/// # Errors
///
/// Propagates failures from reading the tensor data or creating the mask
/// tensor.
fn create_magnitude_mask_with_sparsity(
    &self,
    data: &Tensor<f32>,
    sparsity: f32,
) -> Result<Tensor<f32>, Box<dyn std::error::Error>> {
    let abs_values = data.abs()?;
    let magnitudes: Vec<f32> = abs_values.to_vec()?;

    // `total_cmp` gives a total order, so NaN entries sort last instead of
    // panicking the comparison.
    let mut sorted = magnitudes.clone();
    sorted.sort_unstable_by(f32::total_cmp);

    // Float-to-int `as` saturates: negative or NaN sparsity becomes 0, and
    // `min` caps anything above 1.0 at "prune everything".
    let prune_count = ((sorted.len() as f32 * sparsity) as usize).min(sorted.len());
    // `None` means nothing is pruned; otherwise the k-th smallest magnitude.
    let threshold = prune_count
        .checked_sub(1)
        .and_then(|idx| sorted.get(idx))
        .copied();

    let mask_f32: Vec<f32> = magnitudes
        .iter()
        .map(|&magnitude| {
            let keep = match threshold {
                Some(limit) => magnitude > limit,
                None => true,
            };
            if keep {
                1.0
            } else {
                0.0
            }
        })
        .collect();

    let mask = Tensor::from_vec(mask_f32, abs_values.shape().dims())?;
    Ok(mask)
}
}
impl Pruner {
pub fn magnitude_based(sparsity: f32) -> Self {
Self::new(PruningConfig {
strategy: PruningStrategy::MagnitudeBased,
scope: PruningScope::Global,
sparsity,
structured: false,
})
}
pub fn structured(sparsity: f32) -> Self {
Self::new(PruningConfig {
strategy: PruningStrategy::Structured,
scope: PruningScope::Global,
sparsity,
structured: true,
})
}
pub fn gradual(
initial_sparsity: f32,
final_sparsity: f32,
begin_step: usize,
end_step: usize,
) -> Self {
Self::new(PruningConfig {
strategy: PruningStrategy::Gradual {
initial_sparsity,
final_sparsity,
begin_step,
end_step,
},
scope: PruningScope::Global,
sparsity: final_sparsity,
structured: false,
})
}
}
// Unit tests for mask bookkeeping and the `Pruner` convenience constructors.
#[cfg(test)]
mod tests {
use super::*;
use torsh_core::device::DeviceType;
// A 2x2 mask with two zero entries reports 50% sparsity and matching counts.
#[test]
fn test_pruning_mask_creation() {
let mask_data = vec![1.0, 0.0, 1.0, 0.0];
let mask_tensor = Tensor::from_data(mask_data, vec![2, 2], DeviceType::Cpu).unwrap();
let mask = PruningMask::new(mask_tensor, "test_param".to_string());
assert_eq!(mask.sparsity, 0.5);
assert_eq!(mask.pruned_count(), 2);
assert_eq!(mask.total_count(), 4);
}
// `magnitude_based` stores the sparsity and selects the magnitude strategy.
#[test]
fn test_pruner_creation() {
let pruner = Pruner::magnitude_based(0.5);
assert_eq!(pruner.config.sparsity, 0.5);
assert!(matches!(
pruner.config.strategy,
PruningStrategy::MagnitudeBased
));
}
// `structured` selects the structured strategy and sets the `structured` flag.
#[test]
fn test_structured_pruner_creation() {
let pruner = Pruner::structured(0.3);
assert_eq!(pruner.config.sparsity, 0.3);
assert!(matches!(
pruner.config.strategy,
PruningStrategy::Structured
));
assert!(pruner.config.structured);
}
// `gradual` keeps all four schedule values and uses the final sparsity as the
// config-level sparsity.
#[test]
fn test_gradual_pruner_creation() {
let pruner = Pruner::gradual(0.1, 0.9, 1000, 5000);
assert_eq!(pruner.config.sparsity, 0.9);
if let PruningStrategy::Gradual {
initial_sparsity,
final_sparsity,
begin_step,
end_step,
} = pruner.config.strategy
{
assert_eq!(initial_sparsity, 0.1);
assert_eq!(final_sparsity, 0.9);
assert_eq!(begin_step, 1000);
assert_eq!(end_step, 5000);
} else {
panic!("Expected gradual strategy");
}
}
// A pruner with no masks reports empty per-parameter stats and 0.0 overall.
#[test]
fn test_sparsity_stats_empty() {
let pruner = Pruner::magnitude_based(0.5);
let stats = pruner.get_sparsity_stats();
assert!(stats.is_empty());
assert_eq!(pruner.get_total_sparsity(), 0.0);
}
}