use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::{creation::*, Tensor};
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
use scirs2_core::slice_random::shuffle;
/// Strategy used by the sparse layers in this module to choose which weights are zeroed.
#[derive(Debug, Clone, Copy)]
pub enum SparsityPattern {
    /// Zero a uniformly random subset of elements; `sparsity` is the fraction zeroed and
    /// must lie in `[0, 1]` (`SparseMask::random` rejects anything else).
    Random { sparsity: f32 },
    /// Zero random `block_size x block_size` tiles of a 2-D weight; `sparsity` is the fraction
    /// of tiles zeroed and `block_size` must divide both weight dimensions.
    /// NOTE(review): `SparseConv2d` ignores `block_size` and uses a random element mask instead.
    Blocked { block_size: usize, sparsity: f32 },
    /// Presumably meant to prune whole output channels (inferred from the name). Both
    /// `SparseLinear` and `SparseConv2d` currently approximate it with a random element mask of
    /// sparsity `channels_to_prune / out`, so whole channels are NOT removed.
    Structured { channels_to_prune: usize },
    /// Keep only weights whose absolute value is `>= threshold`; all others are zeroed.
    MagnitudeBased { threshold: f32 },
}
/// A `1.0` (keep) / `0.0` (prune) tensor plus summary statistics; applied to weights by
/// element-wise multiplication.
#[derive(Debug, Clone)]
pub struct SparseMask {
    /// Mask values, built with the shape passed to (or read from the tensor given to) the
    /// constructor.
    mask: Tensor,
    /// Fraction of zeroed entries. `random` stores the requested value; `blocked` and
    /// `from_magnitude` store the value measured from the generated mask.
    sparsity: f32,
    /// Number of kept (`1.0`) entries.
    nnz: usize,
}
impl SparseMask {
    /// Returns an `InvalidArgument` error unless `sparsity` lies in `[0, 1]` (NaN is rejected).
    fn validate_sparsity(sparsity: f32) -> Result<()> {
        if (0.0..=1.0).contains(&sparsity) {
            Ok(())
        } else {
            Err(TorshError::InvalidArgument(format!(
                "Sparsity must be in [0, 1], got {}",
                sparsity
            )))
        }
    }

    /// Fraction of zero entries given `nnz` kept entries out of `total`.
    ///
    /// Returns `0.0` for an empty mask instead of the NaN that `0 / 0` would produce.
    fn measured_sparsity(nnz: usize, total: usize) -> f32 {
        if total == 0 {
            0.0
        } else {
            1.0 - nnz as f32 / total as f32
        }
    }

    /// Builds a mask of `shape` that zeroes `floor(numel * sparsity)` uniformly random positions.
    ///
    /// The stored sparsity is the requested value, not re-measured from the mask.
    ///
    /// # Errors
    /// `InvalidArgument` if `sparsity` is outside `[0, 1]`; tensor-construction errors are
    /// propagated.
    pub fn random(shape: &[usize], sparsity: f32) -> Result<Self> {
        Self::validate_sparsity(sparsity)?;
        let total_elements: usize = shape.iter().product();
        // `total as f32` can round *up* for very large tensors; clamp so the `nnz`
        // subtraction below cannot underflow.
        let num_zeros = ((total_elements as f32 * sparsity) as usize).min(total_elements);
        let mut mask_data = vec![1.0_f32; total_elements];
        let mut indices: Vec<usize> = (0..total_elements).collect();
        shuffle(&mut indices);
        for &idx in indices.iter().take(num_zeros) {
            mask_data[idx] = 0.0;
        }
        Ok(Self {
            mask: Tensor::from_vec(mask_data, shape)?,
            sparsity,
            nnz: total_elements - num_zeros,
        })
    }

    /// Builds a mask that keeps exactly the entries of `weights` with `|w| >= threshold`.
    ///
    /// The stored sparsity is measured from the result (`0.0` for an empty tensor).
    ///
    /// # Errors
    /// Propagates errors from reading `weights` or building the mask tensor.
    pub fn from_magnitude(weights: &Tensor, threshold: f32) -> Result<Self> {
        let shape = weights.shape().dims().to_vec();
        let mask_data: Vec<f32> = weights
            .to_vec()?
            .iter()
            .map(|&w| if w.abs() >= threshold { 1.0 } else { 0.0 })
            .collect();
        let nnz = mask_data.iter().filter(|&&m| m > 0.0).count();
        let sparsity = Self::measured_sparsity(nnz, mask_data.len());
        Ok(Self {
            mask: Tensor::from_vec(mask_data, &shape)?,
            sparsity,
            nnz,
        })
    }

    /// Builds a 2-D mask that zeroes `floor(num_blocks * sparsity)` random
    /// `block_size x block_size` tiles.
    ///
    /// The stored sparsity is measured from the result (`0.0` for an empty shape).
    ///
    /// # Errors
    /// `InvalidArgument` if `shape` is not 2-D, `block_size` is 0, `sparsity` is outside
    /// `[0, 1]`, or either dimension is not divisible by `block_size`.
    pub fn blocked(shape: &[usize], block_size: usize, sparsity: f32) -> Result<Self> {
        if shape.len() != 2 {
            return Err(TorshError::InvalidArgument(
                "Block sparsity only supported for 2D tensors".to_string(),
            ));
        }
        if block_size == 0 {
            // Checked explicitly: the divisibility test below would otherwise panic on `% 0`.
            return Err(TorshError::InvalidArgument(
                "block_size must be greater than 0".to_string(),
            ));
        }
        Self::validate_sparsity(sparsity)?;
        let (rows, cols) = (shape[0], shape[1]);
        if rows % block_size != 0 || cols % block_size != 0 {
            return Err(TorshError::InvalidArgument(format!(
                "Shape {:?} must be divisible by block_size {}",
                shape, block_size
            )));
        }
        let num_blocks_col = cols / block_size;
        let total_blocks = (rows / block_size) * num_blocks_col;
        let blocks_to_zero = (total_blocks as f32 * sparsity) as usize;
        let mut mask_data = vec![1.0_f32; rows * cols];
        let mut block_indices: Vec<usize> = (0..total_blocks).collect();
        shuffle(&mut block_indices);
        for &block_idx in block_indices.iter().take(blocks_to_zero) {
            let row_start = (block_idx / num_blocks_col) * block_size;
            let col_start = (block_idx % num_blocks_col) * block_size;
            // In the row-major buffer each tile is `block_size` contiguous runs, one per row.
            for row in row_start..row_start + block_size {
                let run_start = row * cols + col_start;
                mask_data[run_start..run_start + block_size].fill(0.0);
            }
        }
        let nnz = mask_data.iter().filter(|&&m| m > 0.0).count();
        let actual_sparsity = Self::measured_sparsity(nnz, mask_data.len());
        Ok(Self {
            mask: Tensor::from_vec(mask_data, shape)?,
            sparsity: actual_sparsity,
            nnz,
        })
    }

    /// Returns `weights * mask` (element-wise).
    pub fn apply(&self, weights: &Tensor) -> Result<Tensor> {
        weights.mul(&self.mask)
    }

    /// Number of kept entries.
    pub fn nnz(&self) -> usize {
        self.nnz
    }

    /// Fraction of zeroed entries (requested for `random`, measured otherwise).
    pub fn sparsity(&self) -> f32 {
        self.sparsity
    }

    /// The underlying `1.0`/`0.0` mask tensor.
    pub fn mask(&self) -> &Tensor {
        &self.mask
    }
}
/// Fully connected layer whose weight is multiplied by a [`SparseMask`] on every forward pass.
pub struct SparseLinear {
    /// Holds the `"weight"` parameter (`[in_features, out_features]`) and, when enabled, the
    /// `"bias"` parameter (`[out_features]`).
    base: ModuleBase,
    /// Mask applied to the weight at construction and again in `forward`.
    mask: SparseMask,
    in_features: usize,
    out_features: usize,
    /// Whether a `"bias"` parameter was registered and is added in `forward`.
    use_bias: bool,
}
impl SparseLinear {
pub fn new(
in_features: usize,
out_features: usize,
pattern: SparsityPattern,
bias: bool,
) -> Self {
let mut base = ModuleBase::new();
let weight = crate::init::kaiming_uniform(&[in_features, out_features], "fan_in")
.expect("Failed to initialize sparse linear weight");
let mask = match pattern {
SparsityPattern::Random { sparsity } => {
SparseMask::random(&[in_features, out_features], sparsity)
.expect("Failed to create random sparsity mask")
}
SparsityPattern::Blocked {
block_size,
sparsity,
} => SparseMask::blocked(&[in_features, out_features], block_size, sparsity)
.expect("Failed to create blocked sparsity mask"),
SparsityPattern::MagnitudeBased { threshold } => {
SparseMask::from_magnitude(&weight, threshold)
.expect("Failed to create magnitude-based mask")
}
SparsityPattern::Structured { channels_to_prune } => {
let sparsity = channels_to_prune as f32 / out_features as f32;
SparseMask::random(&[in_features, out_features], sparsity)
.expect("Failed to create structured sparsity mask")
}
};
let masked_weight = mask.apply(&weight).expect("Failed to apply mask");
base.register_parameter("weight".to_string(), Parameter::new(masked_weight));
if bias {
let bias_tensor = zeros(&[out_features]).expect("Failed to create bias tensor");
base.register_parameter("bias".to_string(), Parameter::new(bias_tensor));
}
Self {
base,
mask,
in_features,
out_features,
use_bias: bias,
}
}
pub fn sparsity(&self) -> f32 {
self.mask.sparsity()
}
pub fn nnz(&self) -> usize {
self.mask.nnz()
}
pub fn prune_by_magnitude(&mut self, threshold: f32) -> Result<()> {
let weight = self.base.parameters["weight"].tensor().read().clone();
self.mask = SparseMask::from_magnitude(&weight, threshold)?;
let masked = self.mask.apply(&weight)?;
self.base
.register_parameter("weight".to_string(), Parameter::new(masked));
Ok(())
}
pub fn increase_sparsity(&mut self, target_sparsity: f32) -> Result<()> {
if target_sparsity <= self.mask.sparsity() {
return Ok(()); }
let weight = self.base.parameters["weight"].tensor().read().clone();
let weight_data = weight.to_vec()?;
let mut abs_weights: Vec<f32> = weight_data.iter().map(|&w| w.abs()).collect();
abs_weights.sort_by(|a, b| {
a.partial_cmp(b)
.expect("weight comparison should not involve NaN")
});
let num_to_prune = (abs_weights.len() as f32 * target_sparsity) as usize
- (abs_weights.len() - self.mask.nnz());
let threshold = abs_weights[num_to_prune];
self.prune_by_magnitude(threshold)
}
}
impl Module for SparseLinear {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let weight = self.base.parameters["weight"].tensor().read().clone();
let sparse_weight = self.mask.apply(&weight)?;
let output = input.matmul(&sparse_weight)?;
if self.use_bias {
let bias = self.base.parameters["bias"].tensor().read().clone();
Ok(output.add(&bias)?)
} else {
Ok(output)
}
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
}
impl core::fmt::Debug for SparseLinear {
    /// Prints the layer dimensions and mask statistics; parameter tensors are omitted.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mask = &self.mask;
        let mut builder = f.debug_struct("SparseLinear");
        builder.field("in_features", &self.in_features);
        builder.field("out_features", &self.out_features);
        builder.field("sparsity", &mask.sparsity());
        builder.field("nnz", &mask.nnz());
        builder.finish()
    }
}
/// 2-D convolution whose weight is multiplied by a [`SparseMask`] on every forward pass.
pub struct SparseConv2d {
    /// Holds the `"weight"` parameter (`[out_channels, in_channels, kernel_size, kernel_size]`)
    /// and, when enabled, the `"bias"` parameter (`[out_channels]`).
    base: ModuleBase,
    /// Mask applied to the weight at construction and again in `forward`.
    mask: SparseMask,
    in_channels: usize,
    out_channels: usize,
    /// Side length of the square kernel.
    kernel_size: usize,
    /// Stride used for both spatial dimensions.
    stride: usize,
    /// Padding used for both spatial dimensions.
    padding: usize,
}
impl SparseConv2d {
    /// Creates a sparse 2-D convolution with a square `kernel_size` kernel.
    ///
    /// The `[out_channels, in_channels, kernel_size, kernel_size]` weight is initialised with
    /// `crate::init::kaiming_uniform(.., "fan_in")` and masked according to `pattern`:
    /// - `Random`: element-wise random mask.
    /// - `Blocked`: `block_size` is ignored; a random element mask with the same `sparsity`
    ///   is used instead.
    /// - `Structured`: random element mask with sparsity `channels_to_prune / out_channels`;
    ///   whole channels are not removed.
    /// - `MagnitudeBased`: keep weights with `|w| >= threshold`.
    ///
    /// The bias, if requested, is initialised to zeros.
    ///
    /// # Panics
    /// Panics if initialisation or mask construction fails, e.g. a sparsity outside `[0, 1]`
    /// (including `channels_to_prune > out_channels`).
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        pattern: SparsityPattern,
        bias: bool,
    ) -> Self {
        let mut base = ModuleBase::new();
        let weight_shape = [out_channels, in_channels, kernel_size, kernel_size];
        let weight = crate::init::kaiming_uniform(&weight_shape, "fan_in")
            .expect("Failed to initialize sparse conv2d weight");

        // Every non-magnitude pattern currently falls back to an element-wise random mask;
        // only the panic message differs between them.
        let random_mask = |sparsity: f32, context: &str| {
            SparseMask::random(&weight_shape, sparsity).expect(context)
        };
        let mask = match pattern {
            SparsityPattern::Random { sparsity } => {
                random_mask(sparsity, "Failed to create random sparsity mask")
            }
            SparsityPattern::Blocked { sparsity, .. } => {
                random_mask(sparsity, "Failed to create blocked sparsity mask")
            }
            SparsityPattern::MagnitudeBased { threshold } => {
                SparseMask::from_magnitude(&weight, threshold)
                    .expect("Failed to create magnitude-based mask")
            }
            SparsityPattern::Structured { channels_to_prune } => random_mask(
                channels_to_prune as f32 / out_channels as f32,
                "Failed to create structured sparsity mask",
            ),
        };

        // Store the already-masked weight so pruned entries are zero in the parameter itself.
        base.register_parameter(
            "weight".to_string(),
            Parameter::new(mask.apply(&weight).expect("Failed to apply mask")),
        );
        if bias {
            let bias_tensor = zeros(&[out_channels]).expect("Failed to create bias tensor");
            base.register_parameter("bias".to_string(), Parameter::new(bias_tensor));
        }

        Self {
            base,
            mask,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        }
    }

    /// Fraction of masked-out weights as recorded by the mask.
    pub fn sparsity(&self) -> f32 {
        self.mask.sparsity()
    }

    /// Number of weights kept by the mask.
    pub fn nnz(&self) -> usize {
        self.mask.nnz()
    }
}
impl Module for SparseConv2d {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
use crate::functional as F;
let weight = self.base.parameters["weight"].tensor().read().clone();
let sparse_weight = self.mask.apply(&weight)?;
let bias = if self.base.parameters.contains_key("bias") {
Some(self.base.parameters["bias"].tensor().read().clone())
} else {
None
};
F::conv2d(
input,
&sparse_weight,
bias.as_ref(),
(self.stride, self.stride),
(self.padding, self.padding),
(1, 1), 1, )
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
}
impl core::fmt::Debug for SparseConv2d {
    /// Prints channel counts, kernel size and mask sparsity.
    ///
    /// Stride, padding and the parameter tensors are omitted.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("SparseConv2d");
        builder
            .field("in_channels", &self.in_channels)
            .field("out_channels", &self.out_channels)
            .field("kernel_size", &self.kernel_size)
            .field("sparsity", &self.mask.sparsity());
        builder.finish()
    }
}
/// Hyper-parameters for [`GradualPruningScheduler`]'s linear sparsity ramp.
#[derive(Debug, Clone)]
pub struct SparseTrainingConfig {
    /// Sparsity reported before `pruning_start_step` and at the start of the ramp.
    pub initial_sparsity: f32,
    /// Sparsity reported once `pruning_steps` steps have elapsed since `pruning_start_step`.
    pub target_sparsity: f32,
    /// Length, in steps, of the linear ramp from `initial_sparsity` to `target_sparsity`.
    pub pruning_steps: usize,
    /// Step at which the ramp and pruning begin.
    pub pruning_start_step: usize,
    /// Steps between pruning events, counted from `pruning_start_step`; `0` is treated as `1`.
    pub pruning_frequency: usize,
}

impl Default for SparseTrainingConfig {
    /// Ramps from dense (`0.0`) to `0.9` sparsity over 1000 steps starting at step 0,
    /// pruning every 100 steps.
    fn default() -> Self {
        Self {
            initial_sparsity: 0.0,
            target_sparsity: 0.9,
            pruning_steps: 1000,
            pruning_start_step: 0,
            pruning_frequency: 100,
        }
    }
}

/// Step counter that linearly ramps a sparsity target and signals when to prune.
#[derive(Debug, Clone)]
pub struct GradualPruningScheduler {
    config: SparseTrainingConfig,
    /// Number of `step` calls made so far.
    current_step: usize,
}

impl GradualPruningScheduler {
    /// Creates a scheduler positioned at step 0.
    pub fn new(config: SparseTrainingConfig) -> Self {
        Self {
            config,
            current_step: 0,
        }
    }

    /// Sparsity for the current step.
    ///
    /// Returns `initial_sparsity` before the start step and `target_sparsity` once
    /// `pruning_steps` steps have elapsed. In between, the value is linearly interpolated.
    pub fn get_sparsity(&self) -> f32 {
        let cfg = &self.config;
        let elapsed = match self.current_step.checked_sub(cfg.pruning_start_step) {
            Some(elapsed) => elapsed,
            None => return cfg.initial_sparsity,
        };
        // Also covers `pruning_steps == 0`, avoiding a 0/0 progress value.
        if elapsed >= cfg.pruning_steps {
            return cfg.target_sparsity;
        }
        let progress = elapsed as f32 / cfg.pruning_steps as f32;
        cfg.initial_sparsity + (cfg.target_sparsity - cfg.initial_sparsity) * progress
    }

    /// Whether pruning should happen at the current step.
    ///
    /// This is true at `pruning_start_step` and every `pruning_frequency` steps after it.
    pub fn should_prune(&self) -> bool {
        // A zero frequency would panic on `% 0`; treat it as "every step".
        let frequency = self.config.pruning_frequency.max(1);
        match self.current_step.checked_sub(self.config.pruning_start_step) {
            Some(elapsed) => elapsed % frequency == 0,
            None => false,
        }
    }

    /// Advances the scheduler by one step.
    pub fn step(&mut self) {
        self.current_step += 1;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 10x10 random mask at 50% reports that sparsity and a consistent non-zero count.
    #[test]
    fn test_sparse_mask_random() {
        let mask = SparseMask::random(&[10, 10], 0.5).unwrap();
        let sparsity = mask.sparsity();
        assert!((sparsity - 0.5).abs() < 0.1);
        let zeroed = (100.0 * sparsity) as usize;
        assert_eq!(mask.nnz() + zeroed, 100);
    }

    /// Zeroing half of the 2x2 tiles of an 8x8 mask yields roughly 50% sparsity.
    #[test]
    fn test_sparse_mask_blocked() {
        let sparsity = SparseMask::blocked(&[8, 8], 2, 0.5).unwrap().sparsity();
        assert!((0.4..=0.6).contains(&sparsity));
    }

    /// Construction records the dimensions, and forward maps `[2, 10]` to `[2, 5]`.
    #[test]
    fn test_sparse_linear() {
        let pattern = SparsityPattern::Random { sparsity: 0.8 };
        let layer = SparseLinear::new(10, 5, pattern, true);
        assert_eq!(layer.in_features, 10);
        assert_eq!(layer.out_features, 5);
        assert!((layer.sparsity() - 0.8).abs() < 0.1);

        let batch = randn(&[2, 10]).unwrap();
        let output = layer.forward(&batch).unwrap();
        assert_eq!(output.shape().dims(), &[2, 5]);
    }

    /// A padded 3x3 stride-1 convolution preserves spatial size.
    #[test]
    fn test_sparse_conv2d() {
        let (in_channels, out_channels, kernel_size, stride, padding) = (3, 16, 3, 1, 1);
        let layer = SparseConv2d::new(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            SparsityPattern::Random { sparsity: 0.7 },
            true,
        );
        assert!((layer.sparsity() - 0.7).abs() < 0.1);

        let images = randn(&[2, 3, 32, 32]).unwrap();
        let output = layer.forward(&images).unwrap();
        assert_eq!(output.shape().dims(), &[2, 16, 32, 32]);
    }

    /// Sparsity starts at the initial value, pruning begins at the start step, and the
    /// target is reached after the ramp.
    #[test]
    fn test_gradual_pruning_scheduler() {
        let mut scheduler = GradualPruningScheduler::new(SparseTrainingConfig {
            initial_sparsity: 0.0,
            target_sparsity: 0.9,
            pruning_steps: 100,
            pruning_start_step: 10,
            pruning_frequency: 10,
        });
        assert_eq!(scheduler.get_sparsity(), 0.0);
        assert!(!scheduler.should_prune());

        (0..10).for_each(|_| scheduler.step());
        assert!(scheduler.should_prune());
        assert!(scheduler.get_sparsity() < 0.9);

        (0..100).for_each(|_| scheduler.step());
        assert_eq!(scheduler.get_sparsity(), 0.9);
    }
}