// aprender/pruning/lottery.rs
1//! Lottery Ticket Hypothesis implementation.
2//!
3//! This module implements the Lottery Ticket Hypothesis (LTH) from:
4//! Frankle, J., & Carbin, M. (2018). The Lottery Ticket Hypothesis:
5//! Finding Sparse, Trainable Neural Networks. arXiv:1803.03635.
6//!
7//! # Key Insight
8//! Dense neural networks contain sparse subnetworks ("winning tickets") that
9//! can achieve comparable test accuracy when trained from scratch with their
10//! original initialization.
11//!
12//! # Algorithm: Iterative Magnitude Pruning (IMP)
13//! 1. Initialize network with weights W₀
14//! 2. Train network to convergence → W_T
15//! 3. Prune p% of smallest magnitude weights → mask M
16//! 4. Reset remaining weights to W₀ (or W_k for late rewinding)
17//! 5. Repeat from step 2 with masked network
18//!
19//! # Toyota Way Principles
20//! - **Jidoka**: Validate weight tensors at each pruning round
21//! - **Poka-Yoke**: Type-safe configuration prevents invalid settings
22//! - **Genchi Genbutsu**: Uses actual trained weights for importance
23//!
24//! # Example
25//!
26//! ```ignore
27//! use aprender::pruning::{LotteryTicketPruner, RewindStrategy};
28//!
29//! let pruner = LotteryTicketPruner::builder()
30//!     .target_sparsity(0.9)
31//!     .pruning_rounds(10)
32//!     .rewind_strategy(RewindStrategy::Init)
33//!     .build();
34//!
35//! let ticket = pruner.find_ticket(&model, &train_fn)?;
36//! ```
37
38use super::error::PruningError;
39use super::importance::ImportanceScores;
40use super::mask::{generate_unstructured_mask, SparsityMask, SparsityPattern};
41use super::pruner::{Pruner, PruningResult};
42use super::MagnitudeImportance;
43use crate::autograd::Tensor;
44use crate::nn::Module;
45
/// Strategy for rewinding weights after pruning.
///
/// After each iterative-magnitude-pruning round, the surviving weights are
/// reset according to this strategy before the next training round.
///
/// Note: derives `PartialEq` but not `Eq` because `Late` carries an `f32`.
///
/// # References
/// - Init rewinding: Original LTH paper (Frankle & Carbin, 2018)
/// - Late rewinding: "Stabilizing the Lottery Ticket Hypothesis" (Frankle et al., 2019)
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum RewindStrategy {
    /// Rewind to initialization (W₀).
    /// Original LTH approach - works for small networks.
    #[default]
    Init,

    /// Rewind to early training iteration k (W_k).
    /// More stable for deeper networks. k is typically 0.1-1% of training.
    Early {
        /// Iteration to rewind to (e.g., 500 for 50k total iterations).
        iteration: usize,
    },

    /// Rewind to late training iteration.
    /// Used when early rewinding still fails.
    Late {
        /// Fraction of training to complete before capturing rewind point (0.0-1.0).
        fraction: f32,
    },

    /// No rewinding - just apply mask to current weights.
    /// Useful for one-shot pruning comparison.
    None,
}
76
/// Configuration for Lottery Ticket pruning.
///
/// Construct via [`LotteryTicketConfig::new`] (or `Default`, which is
/// `new(0.9, 10)`) so that `prune_rate_per_round` stays consistent with
/// `target_sparsity` and `pruning_rounds`.
#[derive(Debug, Clone)]
pub struct LotteryTicketConfig {
    /// Target sparsity (fraction of weights to prune, 0.0-1.0).
    pub target_sparsity: f32,

    /// Number of iterative pruning rounds.
    /// Each round prunes a fraction, accumulating to target_sparsity.
    pub pruning_rounds: usize,

    /// Strategy for rewinding weights after each pruning round.
    pub rewind_strategy: RewindStrategy,

    /// Pruning rate per round (computed from target_sparsity and rounds).
    /// p_per_round = 1 - (1 - target_sparsity)^(1/rounds)
    pub prune_rate_per_round: f32,

    /// Whether to use global pruning (across all layers) or per-layer.
    pub global_pruning: bool,
}
97
98impl Default for LotteryTicketConfig {
99    fn default() -> Self {
100        Self::new(0.9, 10)
101    }
102}
103
104impl LotteryTicketConfig {
105    /// Create a new configuration.
106    ///
107    /// # Arguments
108    /// * `target_sparsity` - Final sparsity (0.0-1.0), e.g., 0.9 = 90% pruned
109    /// * `pruning_rounds` - Number of iterative pruning rounds
110    #[must_use]
111    pub fn new(target_sparsity: f32, pruning_rounds: usize) -> Self {
112        let rounds = pruning_rounds.max(1) as f32;
113        // Compute per-round pruning rate to achieve target after all rounds
114        // After n rounds: remaining = (1 - p)^n = 1 - target_sparsity
115        // So: p = 1 - (1 - target_sparsity)^(1/n)
116        let prune_rate_per_round = 1.0 - (1.0 - target_sparsity).powf(1.0 / rounds);
117
118        Self {
119            target_sparsity: target_sparsity.clamp(0.0, 0.99),
120            pruning_rounds: pruning_rounds.max(1),
121            rewind_strategy: RewindStrategy::Init,
122            prune_rate_per_round,
123            global_pruning: true,
124        }
125    }
126
127    /// Set the rewind strategy.
128    #[must_use]
129    pub fn with_rewind_strategy(mut self, strategy: RewindStrategy) -> Self {
130        self.rewind_strategy = strategy;
131        self
132    }
133
134    /// Enable or disable global pruning.
135    #[must_use]
136    pub fn with_global_pruning(mut self, global: bool) -> Self {
137        self.global_pruning = global;
138        self
139    }
140}
141
/// A "winning ticket" - the sparse subnetwork found by LTH.
///
/// Contains the pruning mask and the initial weights to use for retraining.
/// Produced by `LotteryTicketPruner::find_ticket` and consumed by
/// `apply_ticket`.
#[derive(Debug, Clone)]
pub struct WinningTicket {
    /// The sparsity mask identifying which weights to keep.
    pub mask: SparsityMask,

    /// The initial weights to rewind to (W₀ or W_k).
    /// NOTE(review): `find_ticket` snapshots the module's current weights
    /// here; W_k capture would require the training-aware path — verify.
    pub initial_weights: Vec<f32>,

    /// Shape of the weight tensor.
    pub shape: Vec<usize>,

    /// Final sparsity achieved (fraction of weights pruned, 0.0-1.0).
    pub sparsity: f32,

    /// Number of parameters remaining (non-zero).
    pub remaining_parameters: usize,

    /// Total parameters in original network.
    pub total_parameters: usize,

    /// History of sparsity at each pruning round (one entry per round run).
    pub sparsity_history: Vec<f32>,
}
168
169impl WinningTicket {
170    /// Get compression ratio (original size / pruned size).
171    #[must_use]
172    pub fn compression_ratio(&self) -> f32 {
173        if self.remaining_parameters == 0 {
174            return f32::INFINITY;
175        }
176        self.total_parameters as f32 / self.remaining_parameters as f32
177    }
178
179    /// Get the fraction of weights remaining.
180    #[must_use]
181    pub fn density(&self) -> f32 {
182        1.0 - self.sparsity
183    }
184}
185
/// Builder for `LotteryTicketPruner`.
///
/// All fields start as `None`; `build()` substitutes defaults
/// (sparsity 0.9, 10 rounds) for anything left unset.
#[derive(Debug, Clone, Default)]
pub struct LotteryTicketPrunerBuilder {
    // Target sparsity, clamped to [0.0, 0.99] by the setter.
    target_sparsity: Option<f32>,
    // Number of pruning rounds, forced to >= 1 by the setter.
    pruning_rounds: Option<usize>,
    // Optional rewind strategy override.
    rewind_strategy: Option<RewindStrategy>,
    // Optional global-vs-per-layer pruning override.
    global_pruning: Option<bool>,
}
194
195impl LotteryTicketPrunerBuilder {
196    /// Create a new builder.
197    #[must_use]
198    pub fn new() -> Self {
199        Self::default()
200    }
201
202    /// Set target sparsity (0.0-1.0).
203    #[must_use]
204    pub fn target_sparsity(mut self, sparsity: f32) -> Self {
205        self.target_sparsity = Some(sparsity.clamp(0.0, 0.99));
206        self
207    }
208
209    /// Set number of pruning rounds.
210    #[must_use]
211    pub fn pruning_rounds(mut self, rounds: usize) -> Self {
212        self.pruning_rounds = Some(rounds.max(1));
213        self
214    }
215
216    /// Set rewind strategy.
217    #[must_use]
218    pub fn rewind_strategy(mut self, strategy: RewindStrategy) -> Self {
219        self.rewind_strategy = Some(strategy);
220        self
221    }
222
223    /// Enable global pruning across all layers.
224    #[must_use]
225    pub fn global_pruning(mut self, global: bool) -> Self {
226        self.global_pruning = Some(global);
227        self
228    }
229
230    /// Build the pruner.
231    #[must_use]
232    pub fn build(self) -> LotteryTicketPruner {
233        let target_sparsity = self.target_sparsity.unwrap_or(0.9);
234        let pruning_rounds = self.pruning_rounds.unwrap_or(10);
235
236        let mut config = LotteryTicketConfig::new(target_sparsity, pruning_rounds);
237
238        if let Some(strategy) = self.rewind_strategy {
239            config = config.with_rewind_strategy(strategy);
240        }
241        if let Some(global) = self.global_pruning {
242            config = config.with_global_pruning(global);
243        }
244
245        LotteryTicketPruner::with_config(config)
246    }
247}
248
/// Lottery Ticket Hypothesis pruner.
///
/// Implements Iterative Magnitude Pruning (IMP) with weight rewinding
/// to find sparse, trainable subnetworks.
#[derive(Debug, Clone)]
pub struct LotteryTicketPruner {
    // Sparsity target, round count, rewind strategy, etc.
    config: LotteryTicketConfig,
    // L2 magnitude importance scorer. Not referenced by the methods visible
    // in this chunk — presumably used by the include!d part file; verify.
    importance: MagnitudeImportance,
}
258
259impl Default for LotteryTicketPruner {
260    fn default() -> Self {
261        Self::new()
262    }
263}
264
265impl LotteryTicketPruner {
266    /// Create a new LTH pruner with default configuration.
267    /// Default: 90% sparsity over 10 rounds with init rewinding.
268    #[must_use]
269    pub fn new() -> Self {
270        Self::with_config(LotteryTicketConfig::default())
271    }
272
273    /// Create a pruner with custom configuration.
274    #[must_use]
275    pub fn with_config(config: LotteryTicketConfig) -> Self {
276        Self {
277            config,
278            importance: MagnitudeImportance::l2(),
279        }
280    }
281
282    /// Get a builder for configuring the pruner.
283    #[must_use]
284    pub fn builder() -> LotteryTicketPrunerBuilder {
285        LotteryTicketPrunerBuilder::new()
286    }
287
288    /// Get the configuration.
289    #[must_use]
290    pub fn config(&self) -> &LotteryTicketConfig {
291        &self.config
292    }
293
294    /// Find a winning ticket from the given module.
295    ///
296    /// This performs iterative magnitude pruning:
297    /// 1. Compute importance scores
298    /// 2. Generate mask for current round's pruning rate
299    /// 3. Accumulate into overall mask
300    /// 4. Record weight state for rewinding
301    ///
302    /// Note: This is a simplified version that doesn't require training.
303    /// For full LTH, use `find_ticket_with_training`.
304    pub fn find_ticket(&self, module: &dyn Module) -> Result<WinningTicket, PruningError> {
305        let params = module.parameters();
306        if params.is_empty() {
307            return Err(PruningError::NoParameters {
308                module: "module".to_string(),
309            });
310        }
311
312        // Get initial weights (for rewinding)
313        let weights = params[0];
314        let initial_weights = weights.data().to_vec();
315        let shape = weights.shape().to_vec();
316        let total_parameters = initial_weights.len();
317
318        // Initialize cumulative mask (all ones = keep all)
319        let mut cumulative_mask: Vec<f32> = vec![1.0; total_parameters];
320        let mut sparsity_history = Vec::with_capacity(self.config.pruning_rounds);
321
322        // Iterative pruning rounds
323        for round in 0..self.config.pruning_rounds {
324            // Count active weights (not yet pruned)
325            let active_count = cumulative_mask.iter().filter(|&&v| v == 1.0).count();
326            if active_count <= 1 {
327                // Keep at least 1 weight - the "winning ticket" must have at least 1 parameter
328                let zeros = cumulative_mask.iter().filter(|&&v| v == 0.0).count();
329                let current_sparsity = zeros as f32 / total_parameters as f32;
330                sparsity_history.push(current_sparsity);
331                break;
332            }
333
334            // Compute target remaining weights after this round
335            // Using the LTH formula: remaining_fraction = (1 - p)^k where k is rounds completed
336            let rounds_completed = (round + 1) as i32;
337            let remaining_fraction =
338                (1.0 - self.config.prune_rate_per_round).powi(rounds_completed);
339            let target_remaining = (total_parameters as f32 * remaining_fraction).round() as usize;
340            // Ensure at least 1 weight remains
341            let target_remaining = target_remaining.max(1);
342
343            // How many to prune this round
344            let to_prune = active_count.saturating_sub(target_remaining);
345
346            if to_prune == 0 {
347                // Calculate current sparsity
348                let zeros = cumulative_mask.iter().filter(|&&v| v == 0.0).count();
349                let current_sparsity = zeros as f32 / total_parameters as f32;
350                sparsity_history.push(current_sparsity);
351                continue;
352            }
353
354            // Compute importance scores for active weights only
355            // Collect (index, importance) pairs for active weights
356            let mut active_scores: Vec<(usize, f32)> = initial_weights
357                .iter()
358                .zip(cumulative_mask.iter())
359                .enumerate()
360                .filter(|(_, (_, &mask))| mask == 1.0)
361                .map(|(i, (&w, _))| (i, w.abs()))
362                .collect();
363
364            // Sort by importance (ascending - lowest first to prune)
365            active_scores
366                .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
367
368            // Prune the lowest `to_prune` weights
369            for (idx, _) in active_scores.iter().take(to_prune) {
370                cumulative_mask[*idx] = 0.0;
371            }
372
373            // Calculate current sparsity
374            let zeros = cumulative_mask.iter().filter(|&&v| v == 0.0).count();
375            let current_sparsity = zeros as f32 / total_parameters as f32;
376            sparsity_history.push(current_sparsity);
377
378            // Log progress (in debug mode)
379            #[cfg(debug_assertions)]
380            {
381                let _ = round; // Silence unused warning in release
382                eprintln!(
383                    "LTH Round {}/{}: sparsity = {:.2}% (pruned {} of {} active)",
384                    round + 1,
385                    self.config.pruning_rounds,
386                    current_sparsity * 100.0,
387                    to_prune,
388                    active_count
389                );
390            }
391        }
392
393        // Create final mask
394        let mask_tensor = Tensor::new(&cumulative_mask, &shape);
395        let final_mask = SparsityMask::new(mask_tensor, SparsityPattern::Unstructured)?;
396
397        let remaining = cumulative_mask.iter().filter(|&&v| v != 0.0).count();
398        let final_sparsity = 1.0 - (remaining as f32 / total_parameters as f32);
399
400        Ok(WinningTicket {
401            mask: final_mask,
402            initial_weights,
403            shape,
404            sparsity: final_sparsity,
405            remaining_parameters: remaining,
406            total_parameters,
407            sparsity_history,
408        })
409    }
410
411    /// Apply a winning ticket mask to a module.
412    ///
413    /// This zeros out the pruned weights according to the mask.
414    pub fn apply_ticket(
415        &self,
416        module: &mut dyn Module,
417        ticket: &WinningTicket,
418    ) -> Result<PruningResult, PruningError> {
419        let mut params = module.parameters_mut();
420        if params.is_empty() {
421            return Err(PruningError::NoParameters {
422                module: "module".to_string(),
423            });
424        }
425
426        let weights = &mut *params[0];
427        let total = weights.data().len();
428
429        // Apply mask
430        ticket.mask.apply(weights)?;
431
432        // If rewinding is enabled, also reset to initial weights
433        if self.config.rewind_strategy != RewindStrategy::None {
434            let data = weights.data_mut();
435            let mask_data = ticket.mask.tensor().data();
436
437            for (i, (w, &m)) in data.iter_mut().zip(mask_data.iter()).enumerate() {
438                if m != 0.0 {
439                    *w = ticket.initial_weights[i];
440                }
441            }
442        }
443
444        let zeros = weights.data().iter().filter(|&&v| v == 0.0).count();
445        let achieved_sparsity = zeros as f32 / total as f32;
446
447        Ok(PruningResult::new(achieved_sparsity, zeros, total))
448    }
449}
450
// Spliced in textually: `include!` pastes the part file into this module's
// scope (presumably where the `Pruner`/`ImportanceScores` imports and the
// `importance` field are used — verify against lottery_part_02.rs).
include!("lottery_part_02.rs");