use super::calibration::CalibrationContext;
use super::error::PruningError;
use crate::autograd::Tensor;
use crate::nn::Module;
/// Summary statistics describing a collection of importance values.
///
/// `ImportanceStats::from_tensor` fills the scalar fields from a tensor's
/// elements; all fields are public, so callers may also build or adjust a
/// value by hand.
#[derive(Debug, Clone)]
pub struct ImportanceStats {
/// Smallest value seen (`0.0` when built from an empty tensor).
pub min: f32,
/// Largest value seen (`0.0` when built from an empty tensor).
pub max: f32,
/// Arithmetic mean of the values.
pub mean: f32,
/// Standard deviation of the values; `from_tensor` divides by the element
/// count (population form), not by `count - 1`.
pub std: f32,
/// NOTE(review): presumably `(threshold, sparsity)` pairs — confirm against
/// callers. `from_tensor` and `Default` both leave this empty.
pub sparsity_at_threshold: Vec<(f32, f32)>,
}
impl ImportanceStats {
    /// Computes min, max, mean and population standard deviation over every
    /// element returned by `values.data()`.
    ///
    /// * An empty tensor yields the all-zero [`Default`] value.
    /// * `sparsity_at_threshold` is always left empty by this constructor.
    /// * NaN elements never win the `<` / `>` comparisons, so they are ignored
    ///   for `min` / `max`, but they do propagate into `mean` and `std`.
    ///
    /// The sum and the squared deviations are accumulated in `f64` and only
    /// narrowed to `f32` at the end. An `f32` accumulator has a 24-bit
    /// significand, so on tensors with millions of elements it silently drops
    /// low-order contributions (a running sum of ones stalls at 2^24).
    #[must_use]
    pub fn from_tensor(values: &Tensor) -> Self {
        let data = values.data();
        if data.is_empty() {
            // Same all-zero value as the `Default` impl; avoids a second
            // hand-written copy of the struct literal.
            return Self::default();
        }

        let count = data.len() as f64;
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;
        let mut sum = 0.0_f64;
        for &v in data {
            if v < min {
                min = v;
            }
            if v > max {
                max = v;
            }
            sum += f64::from(v);
        }
        let mean = sum / count;

        // Second pass: population variance (divide by N) around the f64 mean.
        let variance = data
            .iter()
            .map(|&x| {
                let delta = f64::from(x) - mean;
                delta * delta
            })
            .sum::<f64>()
            / count;

        Self {
            min,
            max,
            // Intentional narrowing: the public fields are `f32`.
            mean: mean as f32,
            std: variance.sqrt() as f32,
            sparsity_at_threshold: Vec::new(),
        }
    }

    /// Returns the fraction of elements in `values` that are strictly less
    /// than `threshold`, as a value in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` for an empty tensor. NaN elements compare false against
    /// any threshold and are therefore never counted as "below".
    ///
    /// The receiver is not consulted; the result depends only on `values` and
    /// `threshold`. The `&self` parameter is kept so existing callers compile.
    #[must_use]
    pub fn sparsity_at(&self, values: &Tensor, threshold: f32) -> f32 {
        let data = values.data();
        if data.is_empty() {
            return 0.0;
        }
        let below = data.iter().filter(|&&v| v < threshold).count();
        // Divide in f64 so counts above 2^24 are not rounded before the ratio
        // is formed; for smaller counts the result is identical to f32 division.
        (below as f64 / data.len() as f64) as f32
    }
}
impl Default for ImportanceStats {
    /// Produces statistics with every scalar set to zero and no
    /// threshold/sparsity entries.
    fn default() -> Self {
        let zero = 0.0_f32;
        Self {
            sparsity_at_threshold: Vec::new(),
            std: zero,
            mean: zero,
            max: zero,
            min: zero,
        }
    }
}
/// Importance values together with their summary statistics and the name of
/// the method that produced them.
#[derive(Debug, Clone)]
pub struct ImportanceScores {
/// The raw importance values.
pub values: Tensor,
/// Summary statistics; `ImportanceScores::new` derives these from `values`.
pub stats: ImportanceStats,
/// Label identifying the scoring method, as supplied by the caller of `new`.
pub method: String,
}
impl ImportanceScores {
    /// Bundles `values` with the `method` label, eagerly computing summary
    /// statistics via [`ImportanceStats::from_tensor`].
    #[must_use]
    pub fn new(values: Tensor, method: String) -> Self {
        // Field initialisers run in written order, so the statistics are
        // computed from a borrow before `values` is moved into the struct.
        Self {
            stats: ImportanceStats::from_tensor(&values),
            values,
            method,
        }
    }

    /// Shape of the underlying score tensor, as reported by the tensor itself.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        let scores = &self.values;
        scores.shape()
    }

    /// Number of elements in the score tensor's data.
    #[must_use]
    pub fn len(&self) -> usize {
        let elements = self.values.data();
        elements.len()
    }

    /// `true` when the score tensor's data holds no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let elements = self.values.data();
        elements.is_empty()
    }
}
/// Strategy interface for computing importance scores for a module.
///
/// Implementors must be `Send + Sync`.
pub trait Importance: Send + Sync {
/// Computes importance scores for `module`, optionally using calibration
/// data supplied through `context`.
///
/// # Errors
///
/// Returns a `PruningError` when scoring fails; the exact conditions are
/// defined by each implementation.
fn compute(
&self,
module: &dyn Module,
context: Option<&CalibrationContext>,
) -> Result<ImportanceScores, PruningError>;
/// Static name of this importance method.
fn name(&self) -> &'static str;
/// Reports whether this method needs calibration data.
///
/// NOTE(review): presumably implementations returning `true` expect
/// `compute` to receive `Some(context)` — confirm against implementors.
fn requires_calibration(&self) -> bool;
}
// Test module: compiled only under `cfg(test)` and loaded from
// `importance_tests.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "importance_tests.rs"]
mod tests;