use crate::error::{QuantError, QuantResult};
use crate::pruning::mask::SparseMask;
/// Scoring function used to rank weights for magnitude pruning.
///
/// Weights with the lowest score are the ones selected for pruning by
/// `MagnitudePruner::compute_mask`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MagnitudeNorm {
    /// Score each weight by its absolute value, `|w|`.
    L1,
    /// Score each weight by its square, `w * w`.
    L2,
}
/// Unstructured pruner that removes the lowest-scoring fraction of weights.
///
/// Prefer `MagnitudePruner::new`, which validates `target_sparsity`. Because the
/// fields are public, a struct literal can bypass that check.
#[derive(Clone, Debug)]
pub struct MagnitudePruner {
    /// Fraction of weights to prune; `new` requires it to lie in `[0, 1)`.
    pub target_sparsity: f32,
    /// How each weight is scored before ranking.
    pub norm: MagnitudeNorm,
}
impl MagnitudePruner {
    /// Creates a pruner that removes `target_sparsity` of the weights, ranked by `norm`.
    ///
    /// # Panics
    ///
    /// Panics if `target_sparsity` is not in `[0, 1)`. NaN is not in that range,
    /// so it also panics.
    #[must_use]
    pub fn new(target_sparsity: f32, norm: MagnitudeNorm) -> Self {
        assert!(
            (0.0..1.0).contains(&target_sparsity),
            "target_sparsity must be in [0, 1), got {target_sparsity}"
        );
        Self {
            target_sparsity,
            norm,
        }
    }

    /// Computes a keep/prune mask for `weights` without modifying them.
    ///
    /// Each weight is scored with `self.norm`, either `|w|` or `w * w`. Exactly
    /// `ceil(len * target_sparsity)` entries are marked `false` (pruned):
    ///
    /// - every weight whose score is strictly below the cut-off score, and
    /// - as many weights tied at the cut-off as are needed to fill the budget,
    ///   taking the earliest indices first.
    ///
    /// All other entries are `true` (kept). Scores are ordered with
    /// `f32::total_cmp`, so NaN inputs get a deterministic position rather than
    /// breaking the ordering.
    ///
    /// # Errors
    ///
    /// - `QuantError::EmptyInput` if `weights` is empty.
    /// - `QuantError::AllZeroPruning` if the pruning budget would cover every
    ///   weight. This can happen when `target_sparsity` was set directly, without
    ///   `new`, or when the slice is too short for the requested sparsity.
    pub fn compute_mask(&self, weights: &[f32]) -> QuantResult<SparseMask> {
        if weights.is_empty() {
            return Err(QuantError::EmptyInput("MagnitudePruner::compute_mask"));
        }
        let n = weights.len();
        // A negative or NaN sparsity saturates to 0 in the float-to-int cast.
        let n_prune = ((n as f32) * self.target_sparsity).ceil() as usize;
        if n_prune >= n {
            return Err(QuantError::AllZeroPruning {
                threshold: self.target_sparsity,
                n,
            });
        }
        if n_prune == 0 {
            return Ok(SparseMask { mask: vec![true; n] });
        }

        let scores: Vec<f32> = weights
            .iter()
            .map(|&w| match self.norm {
                MagnitudeNorm::L1 => w.abs(),
                MagnitudeNorm::L2 => w * w,
            })
            .collect();

        // The cut-off is the n_prune-th smallest score. A linear-time selection on
        // a scratch copy finds it without a full sort. total_cmp is a true total
        // order, so NaN cannot make the comparator inconsistent.
        let mut scratch = scores.clone();
        let threshold = *scratch
            .select_nth_unstable_by(n_prune - 1, |a, b| a.total_cmp(b))
            .1;

        // Every score strictly below the cut-off must be pruned. The rest of the
        // budget goes to ties at the cut-off. n_below <= n_prune - 1 because at
        // most n_prune - 1 scores can rank before the n_prune-th smallest, so the
        // subtraction cannot underflow.
        let n_below = scores
            .iter()
            .filter(|s| s.total_cmp(&threshold).is_lt())
            .count();
        let mut ties_left = n_prune - n_below;
        let mask: Vec<bool> = scores
            .iter()
            .map(|s| match s.total_cmp(&threshold) {
                std::cmp::Ordering::Less => false,
                std::cmp::Ordering::Equal if ties_left > 0 => {
                    ties_left -= 1;
                    false
                }
                _ => true,
            })
            .collect();
        Ok(SparseMask { mask })
    }

    /// Computes the mask for `weights` and applies it to them in place.
    ///
    /// Pruned positions are zeroed, as the module tests check. Returns the mask
    /// that was applied.
    ///
    /// # Errors
    ///
    /// Same as `compute_mask`. On error, `weights` is left untouched.
    pub fn prune(&self, weights: &mut [f32]) -> QuantResult<SparseMask> {
        let mask = self.compute_mask(weights)?;
        mask.apply_in_place(weights);
        Ok(mask)
    }

    /// Computes a mask independently for each consecutive run of `group_size` weights.
    ///
    /// Each group is pruned to the target sparsity on its own, using the same rule
    /// as `compute_mask`. The per-group masks are concatenated in input order.
    ///
    /// # Errors
    ///
    /// - `QuantError::EmptyInput` if `weights` is empty.
    /// - `QuantError::GroupSizeMismatch` if `group_size` is 0 or does not divide
    ///   `weights.len()`.
    /// - Any error from `compute_mask` for an individual group. For example,
    ///   `AllZeroPruning` when `group_size` is too small for the target sparsity.
    pub fn compute_grouped_mask(
        &self,
        weights: &[f32],
        group_size: usize,
    ) -> QuantResult<SparseMask> {
        if weights.is_empty() {
            return Err(QuantError::EmptyInput(
                "MagnitudePruner::compute_grouped_mask",
            ));
        }
        // A zero group size would divide by zero in the modulo below, and
        // `chunks_exact(0)` panics too; report it as a mismatch instead.
        if group_size == 0 || weights.len() % group_size != 0 {
            return Err(QuantError::GroupSizeMismatch {
                len: weights.len(),
                group: group_size,
            });
        }
        let mut combined = Vec::with_capacity(weights.len());
        for chunk in weights.chunks_exact(group_size) {
            let chunk_mask = self.compute_mask(chunk)?;
            combined.extend_from_slice(&chunk_mask.mask);
        }
        Ok(SparseMask { mask: combined })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `MagnitudePruner`: per-norm pruning, in-place zeroing,
    //! error reporting, and grouped masks.
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn l1_prune_50_percent() {
        let pruner = MagnitudePruner::new(0.5, MagnitudeNorm::L1);
        let values = [0.1_f32, 0.5, 0.3, 0.9, 0.2, 0.8, 0.4, 0.6];
        let result = pruner.compute_mask(&values).unwrap();
        assert_abs_diff_eq!(result.sparsity(), 0.5, epsilon = 0.01);
        assert!(!result.mask[0], "0.1 should be pruned");
        assert!(result.mask[1], "0.5 should be active");
    }

    #[test]
    fn l2_prune_25_percent() {
        let pruner = MagnitudePruner::new(0.25, MagnitudeNorm::L2);
        let values = [0.1_f32, 0.5, 0.3, 0.9];
        let result = pruner.compute_mask(&values).unwrap();
        assert_eq!(result.count_pruned(), 1);
        assert!(!result.mask[0]);
    }

    #[test]
    fn prune_in_place_zeroes_weights() {
        let pruner = MagnitudePruner::new(0.5, MagnitudeNorm::L1);
        let mut values = [1.0_f32, 0.01, 2.0, 0.02];
        let result = pruner.prune(&mut values).unwrap();
        assert!(result.count_pruned() >= 1);
        // The two smallest magnitudes (indices 1 and 3) must have been zeroed.
        assert_abs_diff_eq!(values[1], 0.0, epsilon = 1e-7);
        assert_abs_diff_eq!(values[3], 0.0, epsilon = 1e-7);
    }

    #[test]
    fn empty_input_error() {
        let pruner = MagnitudePruner::new(0.5, MagnitudeNorm::L1);
        let outcome = pruner.compute_mask(&[]);
        assert!(matches!(outcome, Err(QuantError::EmptyInput(_))));
    }

    #[test]
    fn all_zero_pruning_error() {
        // Built with a struct literal on purpose: `new` would reject nothing here,
        // but this exercises the runtime guard for a budget covering every weight.
        let pruner = MagnitudePruner {
            target_sparsity: 0.99,
            norm: MagnitudeNorm::L1,
        };
        let values = [0.1_f32; 4];
        let outcome = pruner.compute_mask(&values);
        assert!(matches!(outcome, Err(QuantError::AllZeroPruning { .. })));
    }

    #[test]
    fn zero_sparsity_all_active() {
        let pruner = MagnitudePruner::new(0.0, MagnitudeNorm::L1);
        let values = [0.5_f32, 1.0, 2.0];
        let result = pruner.compute_mask(&values).unwrap();
        assert_eq!(result.count_pruned(), 0);
    }

    #[test]
    fn grouped_mask_per_channel() {
        let pruner = MagnitudePruner::new(0.5, MagnitudeNorm::L1);
        let values = [0.1_f32, 0.5, 0.2, 0.8, 0.9_f32, 0.3, 0.7, 0.1];
        let result = pruner.compute_grouped_mask(&values, 4).unwrap();
        assert_eq!(result.len(), 8);
        assert_abs_diff_eq!(result.sparsity(), 0.5, epsilon = 0.01);
    }

    #[test]
    fn grouped_mask_mismatch_error() {
        let pruner = MagnitudePruner::new(0.5, MagnitudeNorm::L1);
        let values = [0.5_f32; 5];
        let outcome = pruner.compute_grouped_mask(&values, 4);
        assert!(matches!(outcome, Err(QuantError::GroupSizeMismatch { .. })));
    }
}