use super::calibration::CalibrationContext;
use super::error::PruningError;
use super::importance::{Importance, ImportanceScores};
use super::mask::SparsityPattern;
use crate::autograd::Tensor;
use crate::nn::Module;
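
/// Wanda ("Pruning by Weights and Activations") importance scoring.
///
/// Each weight is scored as `|W[i][j]| * ||X[j]||_2`, the product of the
/// weight magnitude and the L2 norm of its input feature collected during
/// calibration (Sun et al., "A Simple and Effective Pruning Approach for
/// Large Language Models", 2023). Weights that are small *and* fed by weak
/// activations score lowest and are pruned first.
///
/// A minimal construction sketch (the layer name is illustrative):
///
/// ```ignore
/// let wanda = WandaImportance::new("encoder.fc1").with_eps(1e-6);
/// ```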
#[derive(Debug, Clone)]
pub struct WandaImportance {
    /// Name of the layer whose calibration statistics are consulted.
    layer_name: String,
    /// Optional sparsity pattern associated with this scorer.
    pattern: Option<SparsityPattern>,
    /// Floor applied to non-positive activation norms before `sqrt`.
    eps: f32,
}
impl WandaImportance {
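    /// Creates a scorer for `layer_name` with no sparsity pattern and a
    /// default epsilon of `1e-8`.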
    pub fn new(layer_name: impl Into<String>) -> Self {
        Self {
            layer_name: layer_name.into(),
            pattern: None,
            eps: 1e-8,
        }
    }
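
    /// Sets the sparsity pattern associated with the produced scores.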
    #[must_use]
    pub fn with_pattern(mut self, pattern: SparsityPattern) -> Self {
        self.pattern = Some(pattern);
        self
    }
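
    /// Overrides the epsilon used as a floor for non-positive activation
    /// norms before the square root is taken.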
    #[must_use]
    pub fn with_eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }
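
    /// Returns the layer name whose calibration statistics this scorer reads.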
    #[must_use]
    pub fn layer_name(&self) -> &str {
        &self.layer_name
    }
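
    /// Returns the configured sparsity pattern, if any.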
    #[must_use]
    pub fn pattern(&self) -> Option<SparsityPattern> {
        self.pattern
    }
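
    /// Computes `|W[i][j]| * sqrt(norms[j])` for a 2-D
    /// `[out_features, in_features]` weight tensor. `activation_norms` is
    /// treated as holding *squared* per-input-feature L2 norms (the square
    /// root is taken here), and non-positive norms are floored at `eps`.
    ///
    /// A worked sketch, assuming a `Tensor::new(data, shape)` constructor as
    /// used in the body below:
    ///
    /// ```ignore
    /// let weights = Tensor::new(&[1.0, -2.0, 3.0, -4.0], &[2, 2]);
    /// let norms = Tensor::new(&[4.0, 9.0], &[2]); // squared L2 norms
    /// let scores = WandaImportance::new("fc1")
    ///     .compute_from_tensors(&weights, &norms)?;
    /// // scores == [|1|*2, |-2|*3, |3|*2, |-4|*3] == [2.0, 6.0, 6.0, 12.0]
    /// ```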
    pub fn compute_from_tensors(
        &self,
        weights: &Tensor,
        activation_norms: &Tensor,
    ) -> Result<Tensor, PruningError> {
        let weight_shape = weights.shape();
        let norm_shape = activation_norms.shape();
        // The row-major indexing below assumes a 2-D `[out, in]` layout, so
        // reject anything that is not exactly two-dimensional.
        if weight_shape.len() != 2 {
            return Err(PruningError::ShapeMismatch {
                expected: vec![0, 0],
                got: weight_shape.to_vec(),
            });
        }
        let out_features = weight_shape[0];
        let in_features = weight_shape[1];
        // Expect exactly one norm per input feature.
        if norm_shape.is_empty() || norm_shape[0] != in_features {
            return Err(PruningError::ShapeMismatch {
                expected: vec![in_features],
                got: norm_shape.to_vec(),
            });
        }
        let weight_data = weights.data();
        let norm_data = activation_norms.data();
        // Reject non-finite inputs up front so NaN/Inf cannot silently
        // propagate into the importance scores and the pruning mask.
        for (i, &w) in weight_data.iter().enumerate() {
            if w.is_nan() {
                return Err(PruningError::NumericalInstability {
                    method: self.name().to_string(),
                    details: format!("NaN detected in weight at index {i}"),
                });
            }
            if w.is_infinite() {
                return Err(PruningError::NumericalInstability {
                    method: self.name().to_string(),
                    details: format!("Inf detected in weight at index {i}"),
                });
            }
        }
        for (i, &n) in norm_data.iter().enumerate() {
            if n.is_nan() {
                return Err(PruningError::NumericalInstability {
                    method: self.name().to_string(),
                    details: format!("NaN detected in activation norm at index {i}"),
                });
            }
            if n.is_infinite() {
                return Err(PruningError::NumericalInstability {
                    method: self.name().to_string(),
                    details: format!("Inf detected in activation norm at index {i}"),
                });
            }
        }
        // Wanda score: |w| * sqrt(norm), where `norm` holds the squared L2
        // norm of the corresponding input feature. Non-positive norms are
        // floored at `eps` so the square root stays well-defined.
        let mut importance = vec![0.0f32; out_features * in_features];
        for i in 0..out_features {
            for j in 0..in_features {
                let idx = i * in_features + j;
                let w = weight_data[idx];
                let norm = norm_data[j];
                let sqrt_norm = if norm <= 0.0 {
                    self.eps.sqrt()
                } else {
                    norm.sqrt()
                };
                importance[idx] = w.abs() * sqrt_norm;
            }
        }
        Ok(Tensor::new(&importance, weight_shape))
    }
}
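
// A hedged call-site sketch for the trait path, assuming a
// `CalibrationContext` (`ctx`) populated with input norms for "fc1" and a
// module `fc1` that exposes its weight matrix as its first parameter:
//
//     let wanda = WandaImportance::new("fc1");
//     let scores = wanda.compute(&fc1, Some(&ctx))?;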
impl Importance for WandaImportance {
    fn compute(
        &self,
        module: &dyn Module,
        context: Option<&CalibrationContext>,
    ) -> Result<ImportanceScores, PruningError> {
        // Wanda cannot run without calibration statistics.
        let ctx = context.ok_or(PruningError::CalibrationRequired {
            method: self.name().to_string(),
        })?;
        let stats = ctx.require_stats(&self.layer_name)?;
        let params = module.parameters();
        if params.is_empty() {
            return Err(PruningError::NoParameters {
                module: self.layer_name.clone(),
            });
        }
        // The first parameter is assumed to be the weight matrix.
        let weights = params[0];
        let importance = self.compute_from_tensors(weights, &stats.input_norms)?;
        Ok(ImportanceScores::new(importance, "wanda".to_string()))
    }

    fn name(&self) -> &'static str {
        "wanda"
    }

    fn requires_calibration(&self) -> bool {
        true
    }
}
#[cfg(test)]
#[path = "wanda_tests.rs"]
mod tests;