use super::calibration::CalibrationContext;
use super::error::PruningError;
use super::importance::{Importance, ImportanceScores};
use crate::autograd::Tensor;
use crate::nn::Module;
/// SparseGPT-style weight importance estimator.
///
/// Builds a damped layer-wise Hessian `XᵀX / n + λI` from calibration
/// activations, inverts it via Cholesky factorization, and scores each
/// weight as `w² / [H⁻¹]_jj` (lower score = cheaper to prune).
#[derive(Debug, Clone)]
pub struct SparseGPTImportance {
// Key used to look up this layer's calibration statistics in the context.
layer_name: String,
// Column block size (default 128).
// NOTE(review): not read by any method visible in this file — presumably
// consumed by a blockwise pruning schedule elsewhere; confirm before removing.
block_size: usize,
// Damping added to the Hessian diagonal for numerical stability (default 0.01).
damp: f32,
// When true, the added damping is `damp * mean(diag(H))`; otherwise `damp` is absolute.
damp_relative: bool,
}
impl SparseGPTImportance {
    /// Creates an estimator for `layer_name` with the defaults: block size
    /// 128 and 1% damping relative to the mean Hessian diagonal.
    pub fn new(layer_name: impl Into<String>) -> Self {
        Self {
            layer_name: layer_name.into(),
            block_size: 128,
            damp: 0.01,
            damp_relative: true,
        }
    }

    /// Sets the column block size (builder-style).
    #[must_use]
    pub fn with_block_size(mut self, block_size: usize) -> Self {
        self.block_size = block_size;
        self
    }

    /// Sets an absolute damping value added verbatim to the Hessian diagonal.
    #[must_use]
    pub fn with_damp(mut self, damp: f32) -> Self {
        self.damp = damp;
        self.damp_relative = false;
        self
    }

    /// Sets a relative damping factor; the value actually added to the
    /// diagonal is `damp * mean(diag(H))`.
    #[must_use]
    pub fn with_relative_damp(mut self, damp: f32) -> Self {
        self.damp = damp;
        self.damp_relative = true;
        self
    }

    /// Name of the layer whose calibration statistics this estimator reads.
    #[must_use]
    pub fn layer_name(&self) -> &str {
        &self.layer_name
    }

    /// Configured column block size.
    #[must_use]
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Configured damping value (absolute or relative — see the builders).
    #[must_use]
    pub fn damp(&self) -> f32 {
        self.damp
    }

    /// Computes the damped empirical Hessian `H = XᵀX / n + λI` from
    /// calibration activations of shape `[num_samples, in_features]`.
    ///
    /// # Errors
    /// - `ShapeMismatch` if `activations` is not 2-D.
    /// - `InvalidSparsity` if there are zero calibration samples.
    /// - `NumericalInstability` if any activation value is NaN or infinite.
    pub fn compute_hessian(&self, activations: &Tensor) -> Result<Tensor, PruningError> {
        let shape = activations.shape();
        if shape.len() != 2 {
            return Err(PruningError::ShapeMismatch {
                expected: vec![0, 0],
                got: shape.to_vec(),
            });
        }
        let num_samples = shape[0];
        let in_features = shape[1];
        if num_samples == 0 {
            return Err(PruningError::InvalidSparsity {
                value: 0.0,
                constraint: "calibration must have at least 1 sample".to_string(),
            });
        }
        let act_data = activations.data();
        // Reject non-finite inputs up front; they would silently poison H.
        for (i, &v) in act_data.iter().enumerate() {
            if v.is_nan() {
                return Err(PruningError::NumericalInstability {
                    method: "SparseGPT".to_string(),
                    details: format!("NaN in activation at index {i}"),
                });
            }
            if v.is_infinite() {
                return Err(PruningError::NumericalInstability {
                    method: "SparseGPT".to_string(),
                    details: format!("Inf in activation at index {i}"),
                });
            }
        }
        // Accumulate XᵀX. H is symmetric, so only the upper triangle is
        // computed and then mirrored — half the multiply-adds, identical result.
        let mut hessian = vec![0.0f32; in_features * in_features];
        for sample in 0..num_samples {
            let row = &act_data[sample * in_features..(sample + 1) * in_features];
            for i in 0..in_features {
                let xi = row[i];
                for j in i..in_features {
                    hessian[i * in_features + j] += xi * row[j];
                }
            }
        }
        // Normalize by the sample count and mirror into the lower triangle.
        let n = num_samples as f32;
        for i in 0..in_features {
            for j in i..in_features {
                let v = hessian[i * in_features + j] / n;
                hessian[i * in_features + j] = v;
                hessian[j * in_features + i] = v;
            }
        }
        // Damping keeps the Cholesky factorization in
        // `compute_hessian_inverse` well-conditioned.
        let damp_value = if self.damp_relative {
            let mut diag_sum = 0.0f32;
            for i in 0..in_features {
                diag_sum += hessian[i * in_features + i];
            }
            let mean_diag = diag_sum / in_features as f32;
            self.damp * mean_diag
        } else {
            self.damp
        };
        for i in 0..in_features {
            hessian[i * in_features + i] += damp_value;
        }
        Ok(Tensor::new(&hessian, &[in_features, in_features]))
    }

    /// Inverts a symmetric positive-definite Hessian via Cholesky
    /// factorization (`H = LLᵀ`, so `H⁻¹ = L⁻ᵀL⁻¹`) in O(n³).
    ///
    /// # Errors
    /// - `ShapeMismatch` if `hessian` is not a square 2-D matrix.
    /// - `NumericalInstability` if `hessian` is not positive definite
    ///   (increase damping in that case).
    pub fn compute_hessian_inverse(&self, hessian: &Tensor) -> Result<Tensor, PruningError> {
        let shape = hessian.shape();
        if shape.len() != 2 || shape[0] != shape[1] {
            // `shape.first()` instead of `shape[0]`: indexing here panicked
            // on 0-d tensors instead of returning the error.
            let side = shape.first().copied().unwrap_or(0);
            return Err(PruningError::ShapeMismatch {
                expected: vec![side, side],
                got: shape.to_vec(),
            });
        }
        let n = shape[0];
        let h_data = hessian.data();
        // Cholesky factorization: H = LLᵀ with L lower-triangular.
        let mut l = vec![0.0f32; n * n];
        for i in 0..n {
            for j in 0..=i {
                let mut sum = h_data[i * n + j];
                for k in 0..j {
                    sum -= l[i * n + k] * l[j * n + k];
                }
                if i == j {
                    // A non-positive pivot means H is not positive definite.
                    if sum <= 0.0 {
                        return Err(PruningError::NumericalInstability {
                            method: "SparseGPT".to_string(),
                            details: format!(
                                "Hessian not positive definite at index {i}. Consider increasing damping."
                            ),
                        });
                    }
                    l[i * n + j] = sum.sqrt();
                } else {
                    l[i * n + j] = sum / l[j * n + j];
                }
            }
        }
        // Forward substitution: invert the triangular factor L in place.
        let mut l_inv = vec![0.0f32; n * n];
        for i in 0..n {
            l_inv[i * n + i] = 1.0 / l[i * n + i];
            for j in 0..i {
                let mut sum = 0.0f32;
                for k in j..i {
                    sum -= l[i * n + k] * l_inv[k * n + j];
                }
                l_inv[i * n + j] = sum / l[i * n + i];
            }
        }
        // H⁻¹ = L⁻ᵀL⁻¹; L⁻¹ is lower-triangular, so terms start at max(i, j).
        let mut h_inv = vec![0.0f32; n * n];
        for i in 0..n {
            for j in 0..n {
                let mut sum = 0.0f32;
                for k in i.max(j)..n {
                    sum += l_inv[k * n + i] * l_inv[k * n + j];
                }
                h_inv[i * n + j] = sum;
            }
        }
        Ok(Tensor::new(&h_inv, &[n, n]))
    }

    /// Computes the SparseGPT saliency `s_ij = w_ij² / [H⁻¹]_jj` for each
    /// weight of a `[out_features, in_features]` matrix; lower saliency
    /// means the weight is cheaper to prune.
    ///
    /// # Errors
    /// - `ShapeMismatch` if `weights` is not 2-D or `hessian_inv` is not
    ///   exactly `[in_features, in_features]`.
    /// - `NumericalInstability` if any `hessian_inv` diagonal entry is
    ///   not strictly positive.
    pub fn compute_saliency(
        &self,
        weights: &Tensor,
        hessian_inv: &Tensor,
    ) -> Result<Tensor, PruningError> {
        let w_shape = weights.shape();
        let h_shape = hessian_inv.shape();
        if w_shape.len() != 2 {
            return Err(PruningError::ShapeMismatch {
                expected: vec![0, 0],
                got: w_shape.to_vec(),
            });
        }
        let out_features = w_shape[0];
        let in_features = w_shape[1];
        // Validate the full shape, not just the first dimension: previously a
        // 0-d `hessian_inv` panicked on `h_shape[0]`, and a 1-D or non-square
        // one passed the check and then indexed `h_data` out of bounds below.
        if h_shape.len() != 2 || h_shape[0] != in_features || h_shape[1] != in_features {
            return Err(PruningError::ShapeMismatch {
                expected: vec![in_features, in_features],
                got: h_shape.to_vec(),
            });
        }
        let w_data = weights.data();
        let h_data = hessian_inv.data();
        // Extract and validate the diagonal once; it divides every saliency.
        let mut h_diag = vec![0.0f32; in_features];
        for j in 0..in_features {
            h_diag[j] = h_data[j * in_features + j];
            if h_diag[j] <= 0.0 {
                return Err(PruningError::NumericalInstability {
                    method: "SparseGPT".to_string(),
                    details: format!(
                        "Non-positive Hessian inverse diagonal at index {}: {}",
                        j, h_diag[j]
                    ),
                });
            }
        }
        let mut saliency = vec![0.0f32; out_features * in_features];
        for i in 0..out_features {
            for j in 0..in_features {
                let w = w_data[i * in_features + j];
                saliency[i * in_features + j] = (w * w) / h_diag[j];
            }
        }
        Ok(Tensor::new(&saliency, w_shape))
    }

    /// Full pipeline: Hessian from `activations`, Cholesky inverse, then
    /// per-weight saliency scores.
    ///
    /// # Errors
    /// Propagates any error from the three stages above.
    pub fn compute_from_activations(
        &self,
        weights: &Tensor,
        activations: &Tensor,
    ) -> Result<ImportanceScores, PruningError> {
        let hessian = self.compute_hessian(activations)?;
        let hessian_inv = self.compute_hessian_inverse(&hessian)?;
        let saliency = self.compute_saliency(weights, &hessian_inv)?;
        Ok(ImportanceScores::new(
            saliency,
            "sparsegpt_saliency".to_string(),
        ))
    }
}
impl Importance for SparseGPTImportance {
    /// Scores `module`'s first parameter tensor with SparseGPT saliency,
    /// using calibration statistics recorded for this estimator's layer.
    fn compute(
        &self,
        module: &dyn Module,
        context: Option<&CalibrationContext>,
    ) -> Result<ImportanceScores, PruningError> {
        // SparseGPT cannot run without calibration data.
        let ctx = context.ok_or_else(|| PruningError::CalibrationRequired {
            method: self.name().to_string(),
        })?;
        let stats = ctx.require_stats(&self.layer_name)?;

        let params = module.parameters();
        if params.is_empty() {
            return Err(PruningError::NoParameters {
                module: self.layer_name.clone(),
            });
        }
        let weights = params[0];

        // Collapse the calibration set into one RMS-activation row:
        // sqrt(E[x²]) per input feature serves as the activation proxy.
        let in_features = stats.input_features();
        let sq_mean = stats.squared_mean.data();
        let rms_row: Vec<f32> = (0..in_features).map(|i| sq_mean[i].sqrt()).collect();
        let act_tensor = Tensor::new(&rms_row, &[1, in_features]);

        self.compute_from_activations(weights, &act_tensor)
    }

    fn name(&self) -> &'static str {
        "sparsegpt"
    }

    fn requires_calibration(&self) -> bool {
        true
    }
}
// Unit tests live in a sibling file and are compiled only under `cargo test`.
#[cfg(test)]
#[path = "sparsegpt_tests.rs"]
mod tests;