1use super::calibration::CalibrationContext;
10use super::error::PruningError;
11use super::importance::{Importance, ImportanceScores};
12use super::mask::{SparsityMask, SparsityPattern};
13use crate::nn::Module;
14use std::collections::HashMap;
15
/// Summary of a pruning pass: how sparse the pruned tensor ended up and an
/// estimate of how much memory that frees.
#[derive(Debug, Clone, PartialEq)]
pub struct PruningResult {
    /// Fraction of parameters that are zero after pruning.
    pub achieved_sparsity: f32,
    /// Number of parameters counted as pruned.
    pub parameters_pruned: usize,
    /// Total number of parameters that were considered.
    pub total_parameters: usize,
    /// Optional per-layer sparsity breakdown, keyed by layer name.
    pub layer_sparsity: HashMap<String, f32>,
    /// Estimated bytes saved, assuming every pruned parameter is an `f32`.
    pub memory_savings_bytes: usize,
}

impl PruningResult {
    /// Creates a result with an empty per-layer breakdown.
    ///
    /// `memory_savings_bytes` is estimated as
    /// `parameters_pruned * size_of::<f32>()`, saturating at `usize::MAX`
    /// instead of overflowing (which would panic in debug builds).
    #[must_use]
    pub fn new(achieved_sparsity: f32, parameters_pruned: usize, total_parameters: usize) -> Self {
        Self {
            achieved_sparsity,
            parameters_pruned,
            total_parameters,
            layer_sparsity: HashMap::new(),
            memory_savings_bytes: parameters_pruned.saturating_mul(std::mem::size_of::<f32>()),
        }
    }

    /// Records the sparsity achieved for `layer_name`, replacing any previous
    /// entry for that layer, and returns the updated result (builder style).
    #[must_use]
    pub fn with_layer_sparsity(mut self, layer_name: String, sparsity: f32) -> Self {
        self.layer_sparsity.insert(layer_name, sparsity);
        self
    }

    /// Dense-to-sparse size ratio, `1 / (1 - achieved_sparsity)`.
    ///
    /// Returns `f32::INFINITY` when there were no parameters at all or when
    /// every parameter was pruned (`achieved_sparsity >= 1.0`).
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        if self.total_parameters == 0 || self.achieved_sparsity >= 1.0 {
            return f32::INFINITY;
        }
        1.0 / (1.0 - self.achieved_sparsity)
    }
}

impl Default for PruningResult {
    /// An empty result: zero sparsity, zero parameters, no layer breakdown.
    fn default() -> Self {
        Self::new(0.0, 0, 0)
    }
}
69
70pub trait Pruner: Send + Sync {
78 fn generate_mask(
89 &self,
90 scores: &ImportanceScores,
91 target_sparsity: f32,
92 pattern: SparsityPattern,
93 ) -> Result<SparsityMask, PruningError>;
94
95 fn apply_mask(
109 &self,
110 module: &mut dyn Module,
111 mask: &SparsityMask,
112 ) -> Result<PruningResult, PruningError>;
113
114 fn importance(&self) -> &dyn Importance;
116
117 fn name(&self) -> &'static str;
119}
120
121#[derive(Debug, Clone)]
126pub struct MagnitudePruner {
127 importance: super::magnitude::MagnitudeImportance,
128}
129
130impl MagnitudePruner {
131 #[must_use]
133 pub fn new() -> Self {
134 Self {
135 importance: super::magnitude::MagnitudeImportance::l2(),
136 }
137 }
138
139 #[must_use]
141 pub fn l1() -> Self {
142 Self {
143 importance: super::magnitude::MagnitudeImportance::l1(),
144 }
145 }
146
147 #[must_use]
149 pub fn l2() -> Self {
150 Self {
151 importance: super::magnitude::MagnitudeImportance::l2(),
152 }
153 }
154}
155
156impl Default for MagnitudePruner {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
162impl Pruner for MagnitudePruner {
163 fn generate_mask(
164 &self,
165 scores: &ImportanceScores,
166 target_sparsity: f32,
167 pattern: SparsityPattern,
168 ) -> Result<SparsityMask, PruningError> {
169 match pattern {
170 SparsityPattern::Unstructured => {
171 super::mask::generate_unstructured_mask(&scores.values, target_sparsity)
172 }
173 SparsityPattern::NM { n, m } => super::mask::generate_nm_mask(&scores.values, n, m),
174 SparsityPattern::Block { height, width } => {
175 super::mask::generate_block_mask(&scores.values, height, width, target_sparsity)
176 }
177 SparsityPattern::Row => super::mask::generate_row_mask(&scores.values, target_sparsity),
178 SparsityPattern::Column => {
179 super::mask::generate_column_mask(&scores.values, target_sparsity)
180 }
181 }
182 }
183
184 fn apply_mask(
185 &self,
186 module: &mut dyn Module,
187 mask: &SparsityMask,
188 ) -> Result<PruningResult, PruningError> {
189 let mut params = module.parameters_mut();
190 if params.is_empty() {
191 return Err(PruningError::NoParameters {
192 module: "unknown".to_string(),
193 });
194 }
195
196 let weights = &mut *params[0];
198 let total = weights.data().len();
199
200 mask.apply(weights)?;
201
202 let zeros = weights.data().iter().filter(|&&v| v == 0.0).count();
203 let achieved_sparsity = zeros as f32 / total as f32;
204
205 Ok(PruningResult::new(achieved_sparsity, zeros, total))
206 }
207
208 fn importance(&self) -> &dyn Importance {
209 &self.importance
210 }
211
212 fn name(&self) -> &'static str {
213 "magnitude_pruner"
214 }
215}
216
217#[derive(Debug, Clone)]
222pub struct WandaPruner {
223 importance: super::wanda::WandaImportance,
224}
225
226impl WandaPruner {
227 pub fn new(layer_name: impl Into<String>) -> Self {
232 Self {
233 importance: super::wanda::WandaImportance::new(layer_name),
234 }
235 }
236}
237
238impl Pruner for WandaPruner {
239 fn generate_mask(
240 &self,
241 scores: &ImportanceScores,
242 target_sparsity: f32,
243 pattern: SparsityPattern,
244 ) -> Result<SparsityMask, PruningError> {
245 match pattern {
246 SparsityPattern::Unstructured => {
247 super::mask::generate_unstructured_mask(&scores.values, target_sparsity)
248 }
249 SparsityPattern::NM { n, m } => super::mask::generate_nm_mask(&scores.values, n, m),
250 SparsityPattern::Block { height, width } => {
251 super::mask::generate_block_mask(&scores.values, height, width, target_sparsity)
252 }
253 SparsityPattern::Row => super::mask::generate_row_mask(&scores.values, target_sparsity),
254 SparsityPattern::Column => {
255 super::mask::generate_column_mask(&scores.values, target_sparsity)
256 }
257 }
258 }
259
260 fn apply_mask(
261 &self,
262 module: &mut dyn Module,
263 mask: &SparsityMask,
264 ) -> Result<PruningResult, PruningError> {
265 let mut params = module.parameters_mut();
266 if params.is_empty() {
267 return Err(PruningError::NoParameters {
268 module: "unknown".to_string(),
269 });
270 }
271
272 let weights = &mut *params[0];
273 let total = weights.data().len();
274
275 mask.apply(weights)?;
276
277 let zeros = weights.data().iter().filter(|&&v| v == 0.0).count();
278 let achieved_sparsity = zeros as f32 / total as f32;
279
280 Ok(PruningResult::new(achieved_sparsity, zeros, total))
281 }
282
283 fn importance(&self) -> &dyn Importance {
284 &self.importance
285 }
286
287 fn name(&self) -> &'static str {
288 "wanda_pruner"
289 }
290}
291
292pub fn prune_module(
304 module: &mut dyn Module,
305 pruner: &dyn Pruner,
306 target_sparsity: f32,
307 pattern: SparsityPattern,
308 context: Option<&CalibrationContext>,
309) -> Result<PruningResult, PruningError> {
310 let scores = pruner.importance().compute(module, context)?;
312
313 let mask = pruner.generate_mask(&scores, target_sparsity, pattern)?;
315
316 pruner.apply_mask(module, &mask)
318}
319
320#[cfg(test)]
321#[path = "pruner_tests.rs"]
322mod tests;