Skip to main content

aprender/pruning/
pruner.rs

1//! High-level pruning interface.
2//!
3//! # Toyota Way: Genchi Genbutsu
4//! Pruners operate on actual model weights, not abstractions.
5//!
6//! # References
//! - Han, S., et al. (2015). Learning both Weights and Connections for Efficient Neural Networks. `NeurIPS`.
8
9use super::calibration::CalibrationContext;
10use super::error::PruningError;
11use super::importance::{Importance, ImportanceScores};
12use super::mask::{SparsityMask, SparsityPattern};
13use crate::nn::Module;
14use std::collections::HashMap;
15
/// Result of a pruning operation with diagnostics.
///
/// Contains statistics about the pruning operation including
/// achieved sparsity, parameter counts, and per-layer breakdown.
///
/// `Default` yields an all-zero result with an empty per-layer map.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PruningResult {
    /// Actual achieved sparsity (may differ from target for structured pruning).
    pub achieved_sparsity: f32,
    /// Number of parameters pruned (set to zero).
    pub parameters_pruned: usize,
    /// Total parameters in module.
    pub total_parameters: usize,
    /// Per-layer sparsity breakdown.
    pub layer_sparsity: HashMap<String, f32>,
    /// Estimated memory savings in bytes (assumes FP32).
    pub memory_savings_bytes: usize,
}

impl PruningResult {
    /// Create a new pruning result with an empty per-layer breakdown.
    ///
    /// `memory_savings_bytes` is estimated as `parameters_pruned` FP32 values
    /// (4 bytes each). The estimate saturates at `usize::MAX` instead of
    /// overflowing.
    #[must_use]
    pub fn new(achieved_sparsity: f32, parameters_pruned: usize, total_parameters: usize) -> Self {
        Self {
            achieved_sparsity,
            parameters_pruned,
            total_parameters,
            layer_sparsity: HashMap::new(),
            // Saturate rather than overflow (which would panic in debug builds).
            memory_savings_bytes: parameters_pruned.saturating_mul(std::mem::size_of::<f32>()),
        }
    }

    /// Record the sparsity of `layer_name`, replacing any earlier entry for
    /// that layer.
    #[must_use]
    pub fn with_layer_sparsity(mut self, layer_name: String, sparsity: f32) -> Self {
        self.layer_sparsity.insert(layer_name, sparsity);
        self
    }

    /// Get compression ratio (original / pruned size), i.e. `1 / (1 - sparsity)`.
    ///
    /// Returns `f32::INFINITY` when the module has no parameters or when the
    /// achieved sparsity is `1.0` or more (everything pruned).
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        if self.total_parameters == 0 || self.achieved_sparsity >= 1.0 {
            return f32::INFINITY;
        }
        1.0 / (1.0 - self.achieved_sparsity)
    }
}
69
70/// High-level pruning interface.
71///
72/// # Toyota Way: Genchi Genbutsu
73/// Pruners must operate on actual model weights, not abstractions.
74///
75/// # Object Safety
76/// This trait is object-safe and can be used with `dyn Pruner`.
77pub trait Pruner: Send + Sync {
78    /// Generate a sparsity mask based on importance scores.
79    ///
80    /// # Arguments
81    /// * `scores` - Pre-computed importance scores
82    /// * `target_sparsity` - Desired fraction of weights to prune (0.0 to 1.0)
83    /// * `pattern` - Sparsity pattern constraint (unstructured, N:M, block)
84    ///
85    /// # Returns
86    /// * `Ok(SparsityMask)` - Generated mask
87    /// * `Err(PruningError)` - If mask generation fails
88    fn generate_mask(
89        &self,
90        scores: &ImportanceScores,
91        target_sparsity: f32,
92        pattern: SparsityPattern,
93    ) -> Result<SparsityMask, PruningError>;
94
95    /// Apply a sparsity mask to a module, zeroing pruned weights.
96    ///
97    /// # Arguments
98    /// * `module` - The module to prune (modified in-place)
99    /// * `mask` - The sparsity mask to apply
100    ///
101    /// # Returns
102    /// * `Ok(PruningResult)` - Statistics about the pruning operation
103    /// * `Err(PruningError)` - If mask application fails
104    ///
105    /// # Safety
106    /// This operation modifies weights in-place. The mask must match
107    /// the module's parameter shapes exactly.
108    fn apply_mask(
109        &self,
110        module: &mut dyn Module,
111        mask: &SparsityMask,
112    ) -> Result<PruningResult, PruningError>;
113
114    /// Get the importance estimator used by this pruner.
115    fn importance(&self) -> &dyn Importance;
116
117    /// Name of this pruner for logging.
118    fn name(&self) -> &'static str;
119}
120
121/// Simple magnitude-based pruner.
122///
123/// Uses weight magnitude as importance and generates masks to achieve
124/// the target sparsity by pruning the smallest weights.
125#[derive(Debug, Clone)]
126pub struct MagnitudePruner {
127    importance: super::magnitude::MagnitudeImportance,
128}
129
130impl MagnitudePruner {
131    /// Create a new magnitude pruner with L2 norm.
132    #[must_use]
133    pub fn new() -> Self {
134        Self {
135            importance: super::magnitude::MagnitudeImportance::l2(),
136        }
137    }
138
139    /// Create a magnitude pruner with L1 norm.
140    #[must_use]
141    pub fn l1() -> Self {
142        Self {
143            importance: super::magnitude::MagnitudeImportance::l1(),
144        }
145    }
146
147    /// Create a magnitude pruner with L2 norm.
148    #[must_use]
149    pub fn l2() -> Self {
150        Self {
151            importance: super::magnitude::MagnitudeImportance::l2(),
152        }
153    }
154}
155
156impl Default for MagnitudePruner {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
162impl Pruner for MagnitudePruner {
163    fn generate_mask(
164        &self,
165        scores: &ImportanceScores,
166        target_sparsity: f32,
167        pattern: SparsityPattern,
168    ) -> Result<SparsityMask, PruningError> {
169        match pattern {
170            SparsityPattern::Unstructured => {
171                super::mask::generate_unstructured_mask(&scores.values, target_sparsity)
172            }
173            SparsityPattern::NM { n, m } => super::mask::generate_nm_mask(&scores.values, n, m),
174            SparsityPattern::Block { height, width } => {
175                super::mask::generate_block_mask(&scores.values, height, width, target_sparsity)
176            }
177            SparsityPattern::Row => super::mask::generate_row_mask(&scores.values, target_sparsity),
178            SparsityPattern::Column => {
179                super::mask::generate_column_mask(&scores.values, target_sparsity)
180            }
181        }
182    }
183
184    fn apply_mask(
185        &self,
186        module: &mut dyn Module,
187        mask: &SparsityMask,
188    ) -> Result<PruningResult, PruningError> {
189        let mut params = module.parameters_mut();
190        if params.is_empty() {
191            return Err(PruningError::NoParameters {
192                module: "unknown".to_string(),
193            });
194        }
195
196        // Apply mask to first parameter (weight matrix)
197        let weights = &mut *params[0];
198        let total = weights.data().len();
199
200        mask.apply(weights)?;
201
202        let zeros = weights.data().iter().filter(|&&v| v == 0.0).count();
203        let achieved_sparsity = zeros as f32 / total as f32;
204
205        Ok(PruningResult::new(achieved_sparsity, zeros, total))
206    }
207
208    fn importance(&self) -> &dyn Importance {
209        &self.importance
210    }
211
212    fn name(&self) -> &'static str {
213        "magnitude_pruner"
214    }
215}
216
217/// Wanda-based pruner.
218///
219/// Uses activation-weighted importance (Wanda) and generates masks
220/// to achieve the target sparsity. Requires calibration data.
221#[derive(Debug, Clone)]
222pub struct WandaPruner {
223    importance: super::wanda::WandaImportance,
224}
225
226impl WandaPruner {
227    /// Create a new Wanda pruner for a specific layer.
228    ///
229    /// # Arguments
230    /// * `layer_name` - Layer identifier to look up in `CalibrationContext`
231    pub fn new(layer_name: impl Into<String>) -> Self {
232        Self {
233            importance: super::wanda::WandaImportance::new(layer_name),
234        }
235    }
236}
237
238impl Pruner for WandaPruner {
239    fn generate_mask(
240        &self,
241        scores: &ImportanceScores,
242        target_sparsity: f32,
243        pattern: SparsityPattern,
244    ) -> Result<SparsityMask, PruningError> {
245        match pattern {
246            SparsityPattern::Unstructured => {
247                super::mask::generate_unstructured_mask(&scores.values, target_sparsity)
248            }
249            SparsityPattern::NM { n, m } => super::mask::generate_nm_mask(&scores.values, n, m),
250            SparsityPattern::Block { height, width } => {
251                super::mask::generate_block_mask(&scores.values, height, width, target_sparsity)
252            }
253            SparsityPattern::Row => super::mask::generate_row_mask(&scores.values, target_sparsity),
254            SparsityPattern::Column => {
255                super::mask::generate_column_mask(&scores.values, target_sparsity)
256            }
257        }
258    }
259
260    fn apply_mask(
261        &self,
262        module: &mut dyn Module,
263        mask: &SparsityMask,
264    ) -> Result<PruningResult, PruningError> {
265        let mut params = module.parameters_mut();
266        if params.is_empty() {
267            return Err(PruningError::NoParameters {
268                module: "unknown".to_string(),
269            });
270        }
271
272        let weights = &mut *params[0];
273        let total = weights.data().len();
274
275        mask.apply(weights)?;
276
277        let zeros = weights.data().iter().filter(|&&v| v == 0.0).count();
278        let achieved_sparsity = zeros as f32 / total as f32;
279
280        Ok(PruningResult::new(achieved_sparsity, zeros, total))
281    }
282
283    fn importance(&self) -> &dyn Importance {
284        &self.importance
285    }
286
287    fn name(&self) -> &'static str {
288        "wanda_pruner"
289    }
290}
291
292/// Convenience function to prune a module with a single call.
293///
294/// # Arguments
295/// * `module` - Module to prune
296/// * `pruner` - Pruner to use
297/// * `target_sparsity` - Desired sparsity ratio
298/// * `pattern` - Sparsity pattern
299/// * `context` - Optional calibration context
300///
301/// # Returns
302/// Pruning result with statistics.
303pub fn prune_module(
304    module: &mut dyn Module,
305    pruner: &dyn Pruner,
306    target_sparsity: f32,
307    pattern: SparsityPattern,
308    context: Option<&CalibrationContext>,
309) -> Result<PruningResult, PruningError> {
310    // Compute importance scores
311    let scores = pruner.importance().compute(module, context)?;
312
313    // Generate mask
314    let mask = pruner.generate_mask(&scores, target_sparsity, pattern)?;
315
316    // Apply mask
317    pruner.apply_mask(module, &mask)
318}
319
320#[cfg(test)]
321#[path = "pruner_tests.rs"]
322mod tests;