use super::calibration::CalibrationContext;
use super::error::PruningError;
use super::importance::{Importance, ImportanceScores};
use crate::autograd::Tensor;
use crate::nn::Module;
/// Norm used to reduce a weight to a scalar importance value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormType {
    /// L1 norm: importance is the absolute value `|w|`.
    L1,
    /// L2 norm: importance is the squared value `w * w`.
    L2,
}
/// Data-free importance criterion that scores each weight by its magnitude.
///
/// Larger-magnitude weights are considered more important; the criterion
/// requires no calibration data.
#[derive(Debug, Clone)]
pub struct MagnitudeImportance {
    // Which norm (L1 or L2) to apply per weight.
    norm: NormType,
}
impl MagnitudeImportance {
    /// Creates an L1 (absolute-value) magnitude importance scorer.
    #[must_use]
    pub fn l1() -> Self {
        Self { norm: NormType::L1 }
    }

    /// Creates an L2 (squared-value) magnitude importance scorer.
    #[must_use]
    pub fn l2() -> Self {
        Self { norm: NormType::L2 }
    }

    /// Creates a scorer using the given norm.
    #[must_use]
    pub fn with_norm(norm: NormType) -> Self {
        Self { norm }
    }

    /// Returns the norm this scorer applies.
    #[must_use]
    pub fn norm(&self) -> NormType {
        self.norm
    }

    /// Computes element-wise importance scores for `weights`.
    ///
    /// L1 importance is `|w|`; L2 importance is `w * w` (monotonic in `|w|`,
    /// so rankings are unaffected by skipping the square root).
    ///
    /// # Errors
    ///
    /// Returns [`PruningError::NumericalInstability`] naming the first
    /// NaN or infinite weight encountered.
    pub fn compute_for_weights(&self, weights: &Tensor) -> Result<Tensor, PruningError> {
        let data = weights.data();
        // Single validation pass: fail on the first non-finite weight,
        // reporting whether it was NaN or Inf.
        if let Some((i, &w)) = data.iter().enumerate().find(|(_, w)| !w.is_finite()) {
            let kind = if w.is_nan() { "NaN" } else { "Inf" };
            return Err(PruningError::NumericalInstability {
                method: self.name().to_string(),
                details: format!("{kind} detected in weight at index {i}"),
            });
        }
        let importance: Vec<f32> = match self.norm {
            NormType::L1 => data.iter().map(|&w| w.abs()).collect(),
            NormType::L2 => data.iter().map(|&w| w * w).collect(),
        };
        Ok(Tensor::new(&importance, weights.shape()))
    }
}
impl Importance for MagnitudeImportance {
    /// Computes magnitude importance from the module's first parameter
    /// tensor (assumed to be the weight tensor; remaining parameters,
    /// e.g. biases, are ignored — TODO confirm this convention holds for
    /// all `Module` implementations).
    ///
    /// Magnitude importance is data-free, so `_context` is unused.
    ///
    /// # Errors
    ///
    /// Returns [`PruningError::NoParameters`] if the module exposes no
    /// parameters, or propagates failures from
    /// [`MagnitudeImportance::compute_for_weights`].
    fn compute(
        &self,
        module: &dyn Module,
        _context: Option<&CalibrationContext>,
    ) -> Result<ImportanceScores, PruningError> {
        let params = module.parameters();
        if params.is_empty() {
            // NOTE(review): `Module` apparently exposes no identifier here,
            // hence the placeholder — confirm whether a real name is available.
            return Err(PruningError::NoParameters {
                module: "unknown".to_string(),
            });
        }
        let weights = params[0];
        let importance = self.compute_for_weights(weights)?;
        // Reuse `name()` for the score label so it can never drift from the
        // identifier reported elsewhere ("magnitude_l1" / "magnitude_l2").
        Ok(ImportanceScores::new(importance, self.name().to_string()))
    }

    /// Returns the method identifier, distinguishing the norm in use.
    fn name(&self) -> &'static str {
        match self.norm {
            NormType::L1 => "magnitude_l1",
            NormType::L2 => "magnitude_l2",
        }
    }

    /// Magnitude importance never needs calibration data.
    fn requires_calibration(&self) -> bool {
        false
    }
}
// Unit tests live in the sibling file `magnitude_tests.rs`.
#[cfg(test)]
#[path = "magnitude_tests.rs"]
mod tests;