// aprender/pruning/wanda.rs
//! Wanda (Weights and Activations) importance scoring.
//!
//! # Toyota Way: Genchi Genbutsu
//! Uses real activation patterns from calibration data, not estimates.
//!
//! # References
//! - Sun, M., et al. (2023). A simple and effective pruning approach for large language models.
//!   arXiv:2306.11695.

use super::calibration::CalibrationContext;
use super::error::PruningError;
use super::importance::{Importance, ImportanceScores};
use super::mask::SparsityPattern;
use crate::autograd::Tensor;
use crate::nn::Module;

/// Wanda (Weights and Activations) importance estimator.
///
/// Combines weight magnitudes with input activation norms to identify
/// important weights. This method is from Sun et al. (2023) and
/// achieves strong results with no retraining needed.
///
/// # Formula
/// `importance = |w| * sqrt(activation_norm)`
///
/// Where `activation_norm` is the L2 norm of input activations across
/// calibration samples for each input channel.
///
/// # Advantages
/// - No gradient computation needed
/// - No retraining required after pruning
/// - Works well at moderate sparsity (50%)
/// - Very fast (single forward pass for calibration)
///
/// # Requirements
/// - Calibration data (128 samples typically sufficient)
/// - Activation statistics for target layer
#[derive(Debug, Clone)]
pub struct WandaImportance {
    /// Layer name to look up in the `CalibrationContext` at compute time.
    layer_name: String,
    /// Optional pattern constraint for N:M pruning; `None` means unstructured.
    pattern: Option<SparsityPattern>,
    /// Small positive value substituted for zero/negative activation norms
    /// during scoring, so importance never collapses to exactly 0 (default 1e-8).
    eps: f32,
}
47
48impl WandaImportance {
49    /// Create Wanda importance estimator for a specific layer.
50    ///
51    /// # Arguments
52    /// * `layer_name` - Layer identifier to look up in `CalibrationContext`
53    pub fn new(layer_name: impl Into<String>) -> Self {
54        Self {
55            layer_name: layer_name.into(),
56            pattern: None,
57            eps: 1e-8,
58        }
59    }
60
61    /// Set sparsity pattern constraint.
62    ///
63    /// # Arguments
64    /// * `pattern` - N:M pattern or other structural constraint
65    #[must_use]
66    pub fn with_pattern(mut self, pattern: SparsityPattern) -> Self {
67        self.pattern = Some(pattern);
68        self
69    }
70
71    /// Set epsilon for numerical stability.
72    ///
73    /// # Arguments
74    /// * `eps` - Small value to prevent division by zero (default: 1e-8)
75    #[must_use]
76    pub fn with_eps(mut self, eps: f32) -> Self {
77        self.eps = eps;
78        self
79    }
80
81    /// Get the layer name.
82    #[must_use]
83    pub fn layer_name(&self) -> &str {
84        &self.layer_name
85    }
86
87    /// Get the pattern if set.
88    #[must_use]
89    pub fn pattern(&self) -> Option<SparsityPattern> {
90        self.pattern
91    }
92
93    /// Compute Wanda importance scores.
94    ///
95    /// # Arguments
96    /// * `weights` - Weight tensor of shape \[`out_features`, `in_features`\]
97    /// * `activation_norms` - L2 norms of input activations \[`in_features`\]
98    ///
99    /// # Returns
100    /// Importance scores with same shape as weights.
101    ///
102    /// # Formula
103    /// `importance[i,j] = |weights[i,j]| * sqrt(activation_norms[j])`
104    pub fn compute_from_tensors(
105        &self,
106        weights: &Tensor,
107        activation_norms: &Tensor,
108    ) -> Result<Tensor, PruningError> {
109        let weight_shape = weights.shape();
110        let norm_shape = activation_norms.shape();
111
112        // Validate shapes
113        if weight_shape.len() < 2 {
114            return Err(PruningError::ShapeMismatch {
115                expected: vec![0, 0], // Indicates 2D expected
116                got: weight_shape.to_vec(),
117            });
118        }
119
120        let out_features = weight_shape[0];
121        let in_features = weight_shape[1];
122
123        if norm_shape.is_empty() || norm_shape[0] != in_features {
124            return Err(PruningError::ShapeMismatch {
125                expected: vec![in_features],
126                got: norm_shape.to_vec(),
127            });
128        }
129
130        let weight_data = weights.data();
131        let norm_data = activation_norms.data();
132
133        // Jidoka: Check for NaN/Inf
134        for (i, &w) in weight_data.iter().enumerate() {
135            if w.is_nan() {
136                return Err(PruningError::NumericalInstability {
137                    method: self.name().to_string(),
138                    details: format!("NaN detected in weight at index {i}"),
139                });
140            }
141            if w.is_infinite() {
142                return Err(PruningError::NumericalInstability {
143                    method: self.name().to_string(),
144                    details: format!("Inf detected in weight at index {i}"),
145                });
146            }
147        }
148
149        for (i, &n) in norm_data.iter().enumerate() {
150            if n.is_nan() {
151                return Err(PruningError::NumericalInstability {
152                    method: self.name().to_string(),
153                    details: format!("NaN detected in activation norm at index {i}"),
154                });
155            }
156            if n.is_infinite() {
157                return Err(PruningError::NumericalInstability {
158                    method: self.name().to_string(),
159                    details: format!("Inf detected in activation norm at index {i}"),
160                });
161            }
162        }
163
164        // Compute importance: |w| * sqrt(activation_norm)
165        let mut importance = vec![0.0f32; out_features * in_features];
166
167        for i in 0..out_features {
168            for j in 0..in_features {
169                let idx = i * in_features + j;
170                let w = weight_data[idx];
171                let norm = norm_data[j];
172
173                // Handle zero/negative norms gracefully
174                let sqrt_norm = if norm <= 0.0 {
175                    self.eps.sqrt() // Use epsilon for zero activations
176                } else {
177                    norm.sqrt()
178                };
179
180                importance[idx] = w.abs() * sqrt_norm;
181            }
182        }
183
184        Ok(Tensor::new(&importance, weight_shape))
185    }
186}
187
188impl Importance for WandaImportance {
189    fn compute(
190        &self,
191        module: &dyn Module,
192        context: Option<&CalibrationContext>,
193    ) -> Result<ImportanceScores, PruningError> {
194        // Require calibration context
195        let ctx = context.ok_or(PruningError::CalibrationRequired {
196            method: self.name().to_string(),
197        })?;
198
199        // Get activation stats for this layer
200        let stats = ctx.require_stats(&self.layer_name)?;
201
202        // Get module parameters
203        let params = module.parameters();
204        if params.is_empty() {
205            return Err(PruningError::NoParameters {
206                module: self.layer_name.clone(),
207            });
208        }
209
210        let weights = params[0];
211
212        // Compute importance using activation norms
213        let importance = self.compute_from_tensors(weights, &stats.input_norms)?;
214
215        Ok(ImportanceScores::new(importance, "wanda".to_string()))
216    }
217
218    fn name(&self) -> &'static str {
219        "wanda"
220    }
221
222    fn requires_calibration(&self) -> bool {
223        true
224    }
225}
226
// Unit tests live in a sibling file to keep this module focused.
#[cfg(test)]
#[path = "wanda_tests.rs"]
mod tests;