//! Wanda importance scoring (aprender/pruning/wanda.rs).
use super::calibration::CalibrationContext;
11use super::error::PruningError;
12use super::importance::{Importance, ImportanceScores};
13use super::mask::SparsityPattern;
14use crate::autograd::Tensor;
15use crate::nn::Module;
16
17#[derive(Debug, Clone)]
39pub struct WandaImportance {
40 layer_name: String,
42 pattern: Option<SparsityPattern>,
44 eps: f32,
46}
47
48impl WandaImportance {
49 pub fn new(layer_name: impl Into<String>) -> Self {
54 Self {
55 layer_name: layer_name.into(),
56 pattern: None,
57 eps: 1e-8,
58 }
59 }
60
61 #[must_use]
66 pub fn with_pattern(mut self, pattern: SparsityPattern) -> Self {
67 self.pattern = Some(pattern);
68 self
69 }
70
71 #[must_use]
76 pub fn with_eps(mut self, eps: f32) -> Self {
77 self.eps = eps;
78 self
79 }
80
81 #[must_use]
83 pub fn layer_name(&self) -> &str {
84 &self.layer_name
85 }
86
87 #[must_use]
89 pub fn pattern(&self) -> Option<SparsityPattern> {
90 self.pattern
91 }
92
93 pub fn compute_from_tensors(
105 &self,
106 weights: &Tensor,
107 activation_norms: &Tensor,
108 ) -> Result<Tensor, PruningError> {
109 let weight_shape = weights.shape();
110 let norm_shape = activation_norms.shape();
111
112 if weight_shape.len() < 2 {
114 return Err(PruningError::ShapeMismatch {
115 expected: vec![0, 0], got: weight_shape.to_vec(),
117 });
118 }
119
120 let out_features = weight_shape[0];
121 let in_features = weight_shape[1];
122
123 if norm_shape.is_empty() || norm_shape[0] != in_features {
124 return Err(PruningError::ShapeMismatch {
125 expected: vec![in_features],
126 got: norm_shape.to_vec(),
127 });
128 }
129
130 let weight_data = weights.data();
131 let norm_data = activation_norms.data();
132
133 for (i, &w) in weight_data.iter().enumerate() {
135 if w.is_nan() {
136 return Err(PruningError::NumericalInstability {
137 method: self.name().to_string(),
138 details: format!("NaN detected in weight at index {i}"),
139 });
140 }
141 if w.is_infinite() {
142 return Err(PruningError::NumericalInstability {
143 method: self.name().to_string(),
144 details: format!("Inf detected in weight at index {i}"),
145 });
146 }
147 }
148
149 for (i, &n) in norm_data.iter().enumerate() {
150 if n.is_nan() {
151 return Err(PruningError::NumericalInstability {
152 method: self.name().to_string(),
153 details: format!("NaN detected in activation norm at index {i}"),
154 });
155 }
156 if n.is_infinite() {
157 return Err(PruningError::NumericalInstability {
158 method: self.name().to_string(),
159 details: format!("Inf detected in activation norm at index {i}"),
160 });
161 }
162 }
163
164 let mut importance = vec![0.0f32; out_features * in_features];
166
167 for i in 0..out_features {
168 for j in 0..in_features {
169 let idx = i * in_features + j;
170 let w = weight_data[idx];
171 let norm = norm_data[j];
172
173 let sqrt_norm = if norm <= 0.0 {
175 self.eps.sqrt() } else {
177 norm.sqrt()
178 };
179
180 importance[idx] = w.abs() * sqrt_norm;
181 }
182 }
183
184 Ok(Tensor::new(&importance, weight_shape))
185 }
186}
187
188impl Importance for WandaImportance {
189 fn compute(
190 &self,
191 module: &dyn Module,
192 context: Option<&CalibrationContext>,
193 ) -> Result<ImportanceScores, PruningError> {
194 let ctx = context.ok_or(PruningError::CalibrationRequired {
196 method: self.name().to_string(),
197 })?;
198
199 let stats = ctx.require_stats(&self.layer_name)?;
201
202 let params = module.parameters();
204 if params.is_empty() {
205 return Err(PruningError::NoParameters {
206 module: self.layer_name.clone(),
207 });
208 }
209
210 let weights = params[0];
211
212 let importance = self.compute_from_tensors(weights, &stats.input_norms)?;
214
215 Ok(ImportanceScores::new(importance, "wanda".to_string()))
216 }
217
218 fn name(&self) -> &'static str {
219 "wanda"
220 }
221
222 fn requires_calibration(&self) -> bool {
223 true
224 }
225}
226
227#[cfg(test)]
228#[path = "wanda_tests.rs"]
229mod tests;