use super::error::PruningError;
use crate::autograd::Tensor;
/// Structural constraint a pruning mask must satisfy.
///
/// `validate` (see the `impl`) checks a concrete mask tensor against the
/// chosen pattern; `Unstructured` accepts any binary mask.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparsityPattern {
/// No structural constraint: any element may be zeroed independently.
Unstructured,
/// N:M semi-structured sparsity: each consecutive group of `m` elements
/// (in flattened order) must contain exactly `n` non-zeros.
NM {
n: usize,
m: usize,
},
/// Block sparsity on a 2D tensor: every `height` x `width` tile is either
/// all-zero or all-one (uniform within the tile).
Block {
height: usize,
width: usize,
},
/// Row sparsity on a 2D tensor: each row is uniformly kept or pruned.
Row,
/// Column sparsity on a 2D tensor: each column is uniformly kept or pruned.
Column,
}
impl SparsityPattern {
    /// Returns `true` when the pattern's parameters are internally consistent
    /// (e.g. `n <= m` and `m > 0` for N:M, non-zero block dimensions).
    #[must_use]
    pub fn is_valid(&self) -> bool {
        match self {
            Self::NM { n, m } => *m > 0 && *n <= *m,
            Self::Block { height, width } => *height > 0 && *width > 0,
            _ => true,
        }
    }

    /// Sparsity ratio implied by the pattern itself, when one exists.
    ///
    /// Only `NM` fixes a ratio (`1 - n/m`); every other pattern depends on the
    /// actual mask contents, so `None` is returned.
    #[must_use]
    pub fn theoretical_sparsity(&self) -> Option<f32> {
        if let Self::NM { n, m } = self {
            Some(1.0 - (*n as f32 / *m as f32))
        } else {
            None
        }
    }

    /// Checks that `mask` conforms to this pattern, delegating to the
    /// pattern-specific validator.
    ///
    /// # Errors
    /// Returns `PruningError::InvalidPattern` when the mask violates the
    /// structural constraint.
    pub fn validate(&self, mask: &Tensor) -> Result<(), PruningError> {
        match *self {
            Self::Unstructured => Ok(()),
            Self::NM { n, m } => validate_nm(mask, n, m),
            Self::Block { height, width } => validate_block(mask, height, width),
            Self::Row => validate_row(mask),
            Self::Column => validate_column(mask),
        }
    }
}
impl Default for SparsityPattern {
fn default() -> Self {
SparsityPattern::Unstructured
}
}
/// A validated binary pruning mask paired with its structural pattern and
/// measured sparsity. Construct via `SparsityMask::new` (validating) or
/// `SparsityMask::dense` (all-ones).
#[derive(Debug, Clone)]
pub struct SparsityMask {
// Binary tensor (values within 1e-6 of 0.0 or 1.0, enforced by `new`).
mask: Tensor,
// Structural pattern the mask was validated against.
pattern: SparsityPattern,
// Fraction of zero entries, cached at construction; in [0.0, 1.0].
sparsity: f32,
}
impl SparsityMask {
    /// Builds a mask from `mask`, verifying that every entry is binary
    /// (within 1e-6 of 0.0 or 1.0) and that the tensor satisfies `pattern`.
    /// The stored sparsity is measured from the mask itself.
    ///
    /// # Errors
    /// `PruningError::InvalidMask` for non-binary entries, or whatever error
    /// `pattern.validate` reports for a structural violation.
    pub fn new(mask: Tensor, pattern: SparsityPattern) -> Result<Self, PruningError> {
        let non_binary = mask
            .data()
            .iter()
            .find(|&&v| (v - 0.0).abs() > 1e-6 && (v - 1.0).abs() > 1e-6);
        if let Some(&v) = non_binary {
            return Err(PruningError::InvalidMask {
                reason: format!("Mask contains non-binary value: {v}"),
            });
        }
        pattern.validate(&mask)?;
        let data = mask.data();
        // Empty masks are defined to have zero sparsity (avoids 0/0).
        let sparsity = match data.len() {
            0 => 0.0,
            len => {
                let zeros = data.iter().filter(|&&v| v < 0.5).count();
                zeros as f32 / len as f32
            }
        };
        Ok(Self {
            mask,
            pattern,
            sparsity,
        })
    }

    /// All-ones (fully dense) mask of the given shape.
    #[must_use]
    pub fn dense(shape: &[usize]) -> Self {
        Self {
            mask: Tensor::ones(shape),
            pattern: SparsityPattern::Unstructured,
            sparsity: 0.0,
        }
    }

    /// Fraction of zero entries, in `[0.0, 1.0]`.
    #[must_use]
    pub fn sparsity(&self) -> f32 {
        self.sparsity
    }

    /// Structural pattern this mask was validated against.
    #[must_use]
    pub fn pattern(&self) -> SparsityPattern {
        self.pattern
    }

    /// Borrow of the underlying binary tensor.
    #[must_use]
    pub fn tensor(&self) -> &Tensor {
        &self.mask
    }

    /// Shape of the mask tensor.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        self.mask.shape()
    }

    /// Zeroes out pruned weights in place (element-wise multiply by the mask).
    ///
    /// # Errors
    /// `PruningError::ShapeMismatch` when `weights` and the mask differ in shape.
    pub fn apply(&self, weights: &mut Tensor) -> Result<(), PruningError> {
        if self.mask.shape() != weights.shape() {
            return Err(PruningError::ShapeMismatch {
                expected: self.mask.shape().to_vec(),
                got: weights.shape().to_vec(),
            });
        }
        let mask_data = self.mask.data();
        for (w, &m) in weights.data_mut().iter_mut().zip(mask_data) {
            *w *= m;
        }
        Ok(())
    }

    /// Number of kept (non-zero) entries.
    #[must_use]
    pub fn nnz(&self) -> usize {
        self.mask.data().iter().filter(|&&v| v > 0.5).count()
    }

    /// Number of pruned (zero) entries.
    #[must_use]
    pub fn num_zeros(&self) -> usize {
        self.mask.data().iter().filter(|&&v| v < 0.5).count()
    }
}
/// Generates an unstructured mask that prunes the `target_sparsity` fraction
/// of elements with the lowest scores.
///
/// Exactly `floor(len * target_sparsity)` elements are pruned. Ties are broken
/// by index order (lower indices pruned first), which keeps the achieved
/// sparsity at the target even when scores contain duplicates — a plain
/// threshold comparison would prune *every* value tied with the cutoff and
/// could overshoot badly (e.g. all-equal scores would be fully pruned).
///
/// # Errors
/// `PruningError::InvalidSparsity` when `target_sparsity` is outside
/// `[0.0, 1.0]` (including NaN, which fails the range check).
pub fn generate_unstructured_mask(
    scores: &Tensor,
    target_sparsity: f32,
) -> Result<SparsityMask, PruningError> {
    if !(0.0..=1.0).contains(&target_sparsity) {
        return Err(PruningError::InvalidSparsity {
            value: target_sparsity,
            constraint: "must be between 0.0 and 1.0".to_string(),
        });
    }
    let data = scores.data();
    if data.is_empty() {
        return SparsityMask::new(Tensor::new(&[], &[0]), SparsityPattern::Unstructured);
    }
    let num_prune = (data.len() as f32 * target_sparsity) as usize;
    // Rank element indices by ascending score. A stable sort with
    // `Ordering::Equal` for incomparable values (NaN) keeps tie-breaking
    // deterministic in index order.
    let mut order: Vec<usize> = (0..data.len()).collect();
    order.sort_by(|&a, &b| {
        data[a]
            .partial_cmp(&data[b])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    // Zero out exactly the `num_prune` lowest-scoring positions.
    let mut mask_data = vec![1.0_f32; data.len()];
    for &idx in &order[..num_prune] {
        mask_data[idx] = 0.0;
    }
    SparsityMask::new(
        Tensor::new(&mask_data, scores.shape()),
        SparsityPattern::Unstructured,
    )
}
fn validate_nm(mask: &Tensor, n: usize, m: usize) -> Result<(), PruningError> {
let data = mask.data();
if !data.len().is_multiple_of(m) {
return Err(PruningError::InvalidPattern {
message: format!("Tensor length {} not divisible by M={}", data.len(), m),
});
}
for (i, chunk) in data.chunks(m).enumerate() {
let nnz = chunk.iter().filter(|&&v| v > 0.5).count();
if nnz != n {
return Err(PruningError::InvalidPattern {
message: format!(
"Group {} has {} non-zeros, expected {} (N:M = {}:{})",
i, nnz, n, n, m
),
});
}
}
Ok(())
}
/// Extracts `(rows, cols)` from a tensor that must be exactly 2D.
///
/// # Errors
/// `PruningError::InvalidPattern` (naming `pattern_name`) for any other rank.
fn require_2d(mask: &Tensor, pattern_name: &str) -> Result<(usize, usize), PruningError> {
    match mask.shape() {
        [rows, cols] => Ok((*rows, *cols)),
        shape => Err(PruningError::InvalidPattern {
            message: format!(
                "{pattern_name} pattern requires 2D tensor, got {}D",
                shape.len()
            ),
        }),
    }
}
/// Checks that block `(br, bc)` of a row-major `height` x `width` tiling is
/// uniform (all entries within 1e-6 of the block's top-left element).
/// `cols` is the row stride of the full matrix backing `data`.
fn check_block_uniform(
    data: &[f32],
    br: usize,
    bc: usize,
    height: usize,
    width: usize,
    cols: usize,
) -> Result<(), PruningError> {
    // Flat offset of the block's top-left corner; it supplies the
    // reference value the rest of the block is compared against.
    let base = br * height * cols + bc * width;
    let first = data[base];
    for r in 0..height {
        let row_start = base + r * cols;
        for c in 0..width {
            let val = data[row_start + c];
            if (val - first).abs() > 1e-6 {
                return Err(PruningError::InvalidPattern {
                    message: format!("Block ({br}, {bc}) is not uniform: found {val} and {first}"),
                });
            }
        }
    }
    Ok(())
}
/// Verifies block structure on a 2D mask: the shape must tile evenly into
/// `height` x `width` blocks and every block must be uniform.
///
/// # Errors
/// `PruningError::InvalidPattern` for non-2D input, non-divisible shape, or a
/// non-uniform block.
fn validate_block(mask: &Tensor, height: usize, width: usize) -> Result<(), PruningError> {
    let (rows, cols) = require_2d(mask, "Block")?;
    let divisible = rows % height == 0 && cols % width == 0;
    if !divisible {
        return Err(PruningError::InvalidPattern {
            message: format!("Shape [{rows}, {cols}] not divisible by block [{height}, {width}]"),
        });
    }
    let data = mask.data();
    let (block_rows, block_cols) = (rows / height, cols / width);
    for br in 0..block_rows {
        for bc in 0..block_cols {
            check_block_uniform(data, br, bc, height, width, cols)?;
        }
    }
    Ok(())
}
/// Verifies row structure on a 2D mask: every row must be uniform
/// (all entries within 1e-6 of the row's first element).
///
/// # Errors
/// `PruningError::InvalidPattern` for non-2D input or a non-uniform row.
fn validate_row(mask: &Tensor) -> Result<(), PruningError> {
    let (rows, cols) = require_2d(mask, "Row")?;
    let data = mask.data();
    for r in 0..rows {
        let first = data[r * cols];
        let uniform = data[r * cols..r * cols + cols]
            .iter()
            .skip(1)
            .all(|&v| (v - first).abs() <= 1e-6);
        if !uniform {
            return Err(PruningError::InvalidPattern {
                message: format!("Row {r} is not uniform"),
            });
        }
    }
    Ok(())
}
/// Verifies column structure on a 2D mask: every column must be uniform
/// (all entries within 1e-6 of the column's first element).
///
/// # Errors
/// `PruningError::InvalidPattern` for non-2D input or a non-uniform column.
fn validate_column(mask: &Tensor) -> Result<(), PruningError> {
    let (rows, cols) = require_2d(mask, "Column")?;
    let data = mask.data();
    for c in 0..cols {
        let first = data[c];
        // Row-major layout: column `c` lives at stride `cols`.
        if (1..rows).any(|r| (data[r * cols + c] - first).abs() > 1e-6) {
            return Err(PruningError::InvalidPattern {
                message: format!("Column {c} is not uniform"),
            });
        }
    }
    Ok(())
}
// NOTE(review): this include looks suspicious — if this file is itself named
// mask.rs the include is self-referential and will fail to compile (infinite
// textual recursion). Verify the intended target file; if this line is a
// leftover from a refactor, remove it.
include!("mask.rs");