/// Generates an N:M structured-sparsity mask.
///
/// The flattened `scores.data()` is split into consecutive groups of `m` values. In every
/// group the `n` highest-scoring positions are kept (`1.0`) and the rest are pruned (`0.0`).
/// The returned mask has the same shape as `scores`.
///
/// Ties are broken in favour of the lower index within a group (the sort is stable).
/// Scores are ordered with [`f32::total_cmp`], so a NaN score can no longer make the sort
/// ill-defined: a positive NaN ranks above every other value and a negative NaN below.
///
/// # Errors
///
/// Returns `PruningError::InvalidPattern` if `n > m`, if `m == 0`, or if the tensor length
/// is not a multiple of `m`. Any error from `SparsityMask::new` is propagated.
pub fn generate_nm_mask(
    scores: &Tensor,
    n: usize,
    m: usize,
) -> Result<SparsityMask, PruningError> {
    if n > m {
        return Err(PruningError::InvalidPattern {
            message: format!("N ({n}) must be <= M ({m})"),
        });
    }
    if m == 0 {
        return Err(PruningError::InvalidPattern {
            message: "M must be > 0".to_string(),
        });
    }
    let data = scores.data();
    if !data.len().is_multiple_of(m) {
        return Err(PruningError::InvalidPattern {
            message: format!("Tensor length {} not divisible by M={}", data.len(), m),
        });
    }
    let mut mask_data = vec![0.0f32; data.len()];
    // Local indices of one group, reused across groups to avoid an allocation per group.
    let mut ranking: Vec<usize> = Vec::with_capacity(m);
    // The length check above guarantees every chunk is exactly `m` long.
    for (group_scores, group_mask) in data.chunks_exact(m).zip(mask_data.chunks_exact_mut(m)) {
        ranking.clear();
        ranking.extend(0..m);
        // Descending by score. The sort is stable, so equal scores keep ascending index order.
        ranking.sort_by(|&a, &b| group_scores[b].total_cmp(&group_scores[a]));
        for &idx in &ranking[..n] {
            group_mask[idx] = 1.0;
        }
    }
    SparsityMask::new(
        Tensor::new(&mask_data, scores.shape()),
        SparsityPattern::NM { n, m },
    )
}
/// Builds a block-sparsity mask for a 2-D score tensor.
///
/// The `rows × cols` score matrix is tiled into non-overlapping
/// `block_height × block_width` blocks. Each block is scored by the sum of its entries. The
/// `floor(num_blocks * target_sparsity)` lowest-scoring blocks are set to `0.0` in the
/// mask, and every other entry is `1.0`. The mask has the same shape as `scores`.
///
/// # Errors
///
/// Propagates errors from `validate_block_mask_inputs`: a non-2-D shape, a shape not
/// divisible by the block size, or a `target_sparsity` outside `[0.0, 1.0]`. Also
/// propagates any error from `SparsityMask::new`.
pub fn generate_block_mask(
    scores: &Tensor,
    block_height: usize,
    block_width: usize,
    target_sparsity: f32,
) -> Result<SparsityMask, PruningError> {
    let shape = scores.shape();
    let (rows, cols) =
        validate_block_mask_inputs(shape, block_height, block_width, target_sparsity)?;
    let data = scores.data();
    let block_scores = compute_sorted_block_scores(data, rows, cols, block_height, block_width);
    // `len as f32` is exact only up to 2^24 and can round *up* beyond that. With
    // `target_sparsity` near 1.0 the product could then exceed the number of blocks, so
    // clamp it to keep the slice below in bounds.
    let num_prune =
        ((block_scores.len() as f32 * target_sparsity) as usize).min(block_scores.len());
    let mut mask_data = vec![1.0f32; rows * cols];
    zero_out_blocks(
        &mut mask_data,
        &block_scores[..num_prune],
        cols,
        block_height,
        block_width,
    );
    SparsityMask::new(
        Tensor::new(&mask_data, shape),
        SparsityPattern::Block {
            height: block_height,
            width: block_width,
        },
    )
}
/// Validates the inputs of `generate_block_mask` and returns `(rows, cols)` on success.
///
/// # Errors
///
/// - `PruningError::ShapeMismatch` if `shape` is not 2-D. `expected` is set to
///   `vec![0, 0]`, presumably as a placeholder for "any 2-D shape" (confirm with callers).
/// - `PruningError::InvalidPattern` if either block dimension is zero. Previously this case
///   panicked on the `%` below.
/// - `PruningError::InvalidPattern` if `rows`/`cols` are not multiples of the block size.
/// - `PruningError::InvalidSparsity` if `target_sparsity` is outside `[0.0, 1.0]`
///   (this includes NaN).
fn validate_block_mask_inputs(
    shape: &[usize],
    block_height: usize,
    block_width: usize,
    target_sparsity: f32,
) -> Result<(usize, usize), PruningError> {
    let &[rows, cols] = shape else {
        return Err(PruningError::ShapeMismatch {
            expected: vec![0, 0],
            got: shape.to_vec(),
        });
    };
    // Must run before the divisibility check: `x % 0` panics.
    if block_height == 0 || block_width == 0 {
        return Err(PruningError::InvalidPattern {
            message: format!("Block size must be non-zero, got [{block_height}, {block_width}]"),
        });
    }
    if rows % block_height != 0 || cols % block_width != 0 {
        return Err(PruningError::InvalidPattern {
            message: format!(
                "Shape [{rows}, {cols}] not divisible by block size [{block_height}, {block_width}]"
            ),
        });
    }
    if !(0.0..=1.0).contains(&target_sparsity) {
        return Err(PruningError::InvalidSparsity {
            value: target_sparsity,
            constraint: "must be between 0.0 and 1.0".to_string(),
        });
    }
    Ok((rows, cols))
}
/// Sums every `block_height × block_width` tile of a row-major `rows × cols` matrix.
///
/// Returns `(block_row, block_col, sum)` triples sorted by ascending sum. The sort is
/// stable, so blocks with equal sums stay in raster (row-major block) order. Sums are
/// ordered with [`f32::total_cmp`], which keeps the comparator a total order even if a sum
/// is NaN.
///
/// Callers must pass non-zero block dimensions that divide `rows` and `cols`, and a `data`
/// slice of at least `rows * cols` elements. Otherwise this panics.
fn compute_sorted_block_scores(
    data: &[f32],
    rows: usize,
    cols: usize,
    block_height: usize,
    block_width: usize,
) -> Vec<(usize, usize, f32)> {
    let num_block_rows = rows / block_height;
    let num_block_cols = cols / block_width;
    let mut block_scores: Vec<(usize, usize, f32)> =
        Vec::with_capacity(num_block_rows * num_block_cols);
    for br in 0..num_block_rows {
        for bc in 0..num_block_cols {
            let first_col = bc * block_width;
            // Row-major accumulation from +0.0, same summation order as the tile layout.
            let sum = (br * block_height..(br + 1) * block_height)
                .flat_map(|r| {
                    let start = r * cols + first_col;
                    &data[start..start + block_width]
                })
                .fold(0.0f32, |acc, &v| acc + v);
            block_scores.push((br, bc, sum));
        }
    }
    block_scores.sort_by(|a, b| a.2.total_cmp(&b.2));
    block_scores
}
/// Writes `0.0` into every mask entry covered by the listed blocks.
///
/// `mask_data` is treated as a row-major matrix with `cols` columns. Each
/// `(block_row, block_col, _)` entry names the `block_height × block_width` tile whose
/// top-left cell is at `(block_row * block_height, block_col * block_width)`. The third
/// tuple field (the block's score) is ignored. Panics if a tile reaches past the end of
/// `mask_data`.
fn zero_out_blocks(
    mask_data: &mut [f32],
    blocks_to_prune: &[(usize, usize, f32)],
    cols: usize,
    block_height: usize,
    block_width: usize,
) {
    for &(block_row, block_col, _score) in blocks_to_prune {
        let top = block_row * block_height;
        let left = block_col * block_width;
        for row in top..top + block_height {
            let row_start = row * cols + left;
            mask_data[row_start..row_start + block_width].fill(0.0);
        }
    }
}
/// Builds a row-sparsity mask for a 2-D score tensor.
///
/// Each row is scored by the sum of its entries. The `floor(rows * target_sparsity)`
/// lowest-scoring rows are set entirely to `0.0`; all other entries are `1.0`. The mask has
/// the same shape as `scores`. The sort is stable, so rows with equal sums are pruned in
/// index order. Sums are compared with [`f32::total_cmp`], so a NaN sum cannot make the
/// ordering ill-defined.
///
/// # Errors
///
/// - `PruningError::ShapeMismatch` if `scores` is not 2-D.
/// - `PruningError::InvalidSparsity` if `target_sparsity` is outside `[0.0, 1.0]` (or NaN).
/// - Any error from `SparsityMask::new`.
pub fn generate_row_mask(
    scores: &Tensor,
    target_sparsity: f32,
) -> Result<SparsityMask, PruningError> {
    let shape = scores.shape();
    if shape.len() != 2 {
        return Err(PruningError::ShapeMismatch {
            expected: vec![0, 0],
            got: shape.to_vec(),
        });
    }
    let (rows, cols) = (shape[0], shape[1]);
    if !(0.0..=1.0).contains(&target_sparsity) {
        return Err(PruningError::InvalidSparsity {
            value: target_sparsity,
            constraint: "must be between 0.0 and 1.0".to_string(),
        });
    }
    // Truncating cast = floor. `take` below keeps it in range if f32 rounding overshoots.
    let num_prune = (rows as f32 * target_sparsity) as usize;
    let data = scores.data();
    let mut row_scores: Vec<(usize, f32)> = (0..rows)
        .map(|r| (r, data[r * cols..(r + 1) * cols].iter().sum::<f32>()))
        .collect();
    row_scores.sort_by(|a, b| a.1.total_cmp(&b.1));
    let mut mask_data = vec![1.0f32; rows * cols];
    for &(row, _) in row_scores.iter().take(num_prune) {
        mask_data[row * cols..(row + 1) * cols].fill(0.0);
    }
    SparsityMask::new(Tensor::new(&mask_data, shape), SparsityPattern::Row)
}
/// Builds a column-sparsity mask for a 2-D score tensor.
///
/// Each column is scored by the sum of its entries. The `floor(cols * target_sparsity)`
/// lowest-scoring columns are set entirely to `0.0`; all other entries are `1.0`. The mask
/// has the same shape as `scores`. The sort is stable, so columns with equal sums are pruned
/// in index order. Sums are compared with [`f32::total_cmp`], so a NaN sum cannot make the
/// ordering ill-defined.
///
/// # Errors
///
/// - `PruningError::ShapeMismatch` if `scores` is not 2-D.
/// - `PruningError::InvalidSparsity` if `target_sparsity` is outside `[0.0, 1.0]` (or NaN).
/// - Any error from `SparsityMask::new`.
pub fn generate_column_mask(
    scores: &Tensor,
    target_sparsity: f32,
) -> Result<SparsityMask, PruningError> {
    let shape = scores.shape();
    if shape.len() != 2 {
        return Err(PruningError::ShapeMismatch {
            expected: vec![0, 0],
            got: shape.to_vec(),
        });
    }
    let (rows, cols) = (shape[0], shape[1]);
    if !(0.0..=1.0).contains(&target_sparsity) {
        return Err(PruningError::InvalidSparsity {
            value: target_sparsity,
            constraint: "must be between 0.0 and 1.0".to_string(),
        });
    }
    // Truncating cast = floor. `take` below keeps it in range if f32 rounding overshoots.
    let num_prune = (cols as f32 * target_sparsity) as usize;
    let data = scores.data();
    let mut col_scores: Vec<(usize, f32)> = (0..cols)
        .map(|c| (c, (0..rows).map(|r| data[r * cols + c]).sum::<f32>()))
        .collect();
    col_scores.sort_by(|a, b| a.1.total_cmp(&b.1));
    let mut mask_data = vec![1.0f32; rows * cols];
    for &(col, _) in col_scores.iter().take(num_prune) {
        for r in 0..rows {
            mask_data[r * cols + col] = 0.0;
        }
    }
    SparsityMask::new(Tensor::new(&mask_data, shape), SparsityPattern::Column)
}
#[cfg(test)]
mod tests;