1use super::error::PruningError;
39use super::importance::ImportanceScores;
40use super::mask::{generate_unstructured_mask, SparsityMask, SparsityPattern};
41use super::pruner::{Pruner, PruningResult};
42use super::MagnitudeImportance;
43use crate::autograd::Tensor;
44use crate::nn::Module;
45
/// Selects what surviving weights are reset to when a winning ticket is applied.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum RewindStrategy {
    /// Reset surviving weights to the values captured in the ticket (the default).
    #[default]
    Init,

    /// NOTE(review): presumably rewinds to the weights seen at training step `iteration`;
    /// `LotteryTicketPruner::apply_ticket` currently handles this exactly like `Init`.
    Early {
        /// Presumably the training iteration to rewind to.
        iteration: usize,
    },

    /// NOTE(review): presumably rewinds to a point `fraction` of the way through training;
    /// `LotteryTicketPruner::apply_ticket` currently handles this exactly like `Init`.
    Late {
        /// Presumably the fraction of training to rewind to.
        fraction: f32,
    },

    /// Do not rewind: keep the current weights, zeroed wherever the mask prunes them.
    None,
}
76
77#[derive(Debug, Clone)]
79pub struct LotteryTicketConfig {
80 pub target_sparsity: f32,
82
83 pub pruning_rounds: usize,
86
87 pub rewind_strategy: RewindStrategy,
89
90 pub prune_rate_per_round: f32,
93
94 pub global_pruning: bool,
96}
97
98impl Default for LotteryTicketConfig {
99 fn default() -> Self {
100 Self::new(0.9, 10)
101 }
102}
103
104impl LotteryTicketConfig {
105 #[must_use]
111 pub fn new(target_sparsity: f32, pruning_rounds: usize) -> Self {
112 let rounds = pruning_rounds.max(1) as f32;
113 let prune_rate_per_round = 1.0 - (1.0 - target_sparsity).powf(1.0 / rounds);
117
118 Self {
119 target_sparsity: target_sparsity.clamp(0.0, 0.99),
120 pruning_rounds: pruning_rounds.max(1),
121 rewind_strategy: RewindStrategy::Init,
122 prune_rate_per_round,
123 global_pruning: true,
124 }
125 }
126
127 #[must_use]
129 pub fn with_rewind_strategy(mut self, strategy: RewindStrategy) -> Self {
130 self.rewind_strategy = strategy;
131 self
132 }
133
134 #[must_use]
136 pub fn with_global_pruning(mut self, global: bool) -> Self {
137 self.global_pruning = global;
138 self
139 }
140}
141
142#[derive(Debug, Clone)]
146pub struct WinningTicket {
147 pub mask: SparsityMask,
149
150 pub initial_weights: Vec<f32>,
152
153 pub shape: Vec<usize>,
155
156 pub sparsity: f32,
158
159 pub remaining_parameters: usize,
161
162 pub total_parameters: usize,
164
165 pub sparsity_history: Vec<f32>,
167}
168
169impl WinningTicket {
170 #[must_use]
172 pub fn compression_ratio(&self) -> f32 {
173 if self.remaining_parameters == 0 {
174 return f32::INFINITY;
175 }
176 self.total_parameters as f32 / self.remaining_parameters as f32
177 }
178
179 #[must_use]
181 pub fn density(&self) -> f32 {
182 1.0 - self.sparsity
183 }
184}
185
186#[derive(Debug, Clone, Default)]
188pub struct LotteryTicketPrunerBuilder {
189 target_sparsity: Option<f32>,
190 pruning_rounds: Option<usize>,
191 rewind_strategy: Option<RewindStrategy>,
192 global_pruning: Option<bool>,
193}
194
195impl LotteryTicketPrunerBuilder {
196 #[must_use]
198 pub fn new() -> Self {
199 Self::default()
200 }
201
202 #[must_use]
204 pub fn target_sparsity(mut self, sparsity: f32) -> Self {
205 self.target_sparsity = Some(sparsity.clamp(0.0, 0.99));
206 self
207 }
208
209 #[must_use]
211 pub fn pruning_rounds(mut self, rounds: usize) -> Self {
212 self.pruning_rounds = Some(rounds.max(1));
213 self
214 }
215
216 #[must_use]
218 pub fn rewind_strategy(mut self, strategy: RewindStrategy) -> Self {
219 self.rewind_strategy = Some(strategy);
220 self
221 }
222
223 #[must_use]
225 pub fn global_pruning(mut self, global: bool) -> Self {
226 self.global_pruning = Some(global);
227 self
228 }
229
230 #[must_use]
232 pub fn build(self) -> LotteryTicketPruner {
233 let target_sparsity = self.target_sparsity.unwrap_or(0.9);
234 let pruning_rounds = self.pruning_rounds.unwrap_or(10);
235
236 let mut config = LotteryTicketConfig::new(target_sparsity, pruning_rounds);
237
238 if let Some(strategy) = self.rewind_strategy {
239 config = config.with_rewind_strategy(strategy);
240 }
241 if let Some(global) = self.global_pruning {
242 config = config.with_global_pruning(global);
243 }
244
245 LotteryTicketPruner::with_config(config)
246 }
247}
248
249#[derive(Debug, Clone)]
254pub struct LotteryTicketPruner {
255 config: LotteryTicketConfig,
256 importance: MagnitudeImportance,
257}
258
259impl Default for LotteryTicketPruner {
260 fn default() -> Self {
261 Self::new()
262 }
263}
264
265impl LotteryTicketPruner {
266 #[must_use]
269 pub fn new() -> Self {
270 Self::with_config(LotteryTicketConfig::default())
271 }
272
273 #[must_use]
275 pub fn with_config(config: LotteryTicketConfig) -> Self {
276 Self {
277 config,
278 importance: MagnitudeImportance::l2(),
279 }
280 }
281
282 #[must_use]
284 pub fn builder() -> LotteryTicketPrunerBuilder {
285 LotteryTicketPrunerBuilder::new()
286 }
287
288 #[must_use]
290 pub fn config(&self) -> &LotteryTicketConfig {
291 &self.config
292 }
293
294 pub fn find_ticket(&self, module: &dyn Module) -> Result<WinningTicket, PruningError> {
305 let params = module.parameters();
306 if params.is_empty() {
307 return Err(PruningError::NoParameters {
308 module: "module".to_string(),
309 });
310 }
311
312 let weights = params[0];
314 let initial_weights = weights.data().to_vec();
315 let shape = weights.shape().to_vec();
316 let total_parameters = initial_weights.len();
317
318 let mut cumulative_mask: Vec<f32> = vec![1.0; total_parameters];
320 let mut sparsity_history = Vec::with_capacity(self.config.pruning_rounds);
321
322 for round in 0..self.config.pruning_rounds {
324 let active_count = cumulative_mask.iter().filter(|&&v| v == 1.0).count();
326 if active_count <= 1 {
327 let zeros = cumulative_mask.iter().filter(|&&v| v == 0.0).count();
329 let current_sparsity = zeros as f32 / total_parameters as f32;
330 sparsity_history.push(current_sparsity);
331 break;
332 }
333
334 let rounds_completed = (round + 1) as i32;
337 let remaining_fraction =
338 (1.0 - self.config.prune_rate_per_round).powi(rounds_completed);
339 let target_remaining = (total_parameters as f32 * remaining_fraction).round() as usize;
340 let target_remaining = target_remaining.max(1);
342
343 let to_prune = active_count.saturating_sub(target_remaining);
345
346 if to_prune == 0 {
347 let zeros = cumulative_mask.iter().filter(|&&v| v == 0.0).count();
349 let current_sparsity = zeros as f32 / total_parameters as f32;
350 sparsity_history.push(current_sparsity);
351 continue;
352 }
353
354 let mut active_scores: Vec<(usize, f32)> = initial_weights
357 .iter()
358 .zip(cumulative_mask.iter())
359 .enumerate()
360 .filter(|(_, (_, &mask))| mask == 1.0)
361 .map(|(i, (&w, _))| (i, w.abs()))
362 .collect();
363
364 active_scores
366 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
367
368 for (idx, _) in active_scores.iter().take(to_prune) {
370 cumulative_mask[*idx] = 0.0;
371 }
372
373 let zeros = cumulative_mask.iter().filter(|&&v| v == 0.0).count();
375 let current_sparsity = zeros as f32 / total_parameters as f32;
376 sparsity_history.push(current_sparsity);
377
378 #[cfg(debug_assertions)]
380 {
381 let _ = round; eprintln!(
383 "LTH Round {}/{}: sparsity = {:.2}% (pruned {} of {} active)",
384 round + 1,
385 self.config.pruning_rounds,
386 current_sparsity * 100.0,
387 to_prune,
388 active_count
389 );
390 }
391 }
392
393 let mask_tensor = Tensor::new(&cumulative_mask, &shape);
395 let final_mask = SparsityMask::new(mask_tensor, SparsityPattern::Unstructured)?;
396
397 let remaining = cumulative_mask.iter().filter(|&&v| v != 0.0).count();
398 let final_sparsity = 1.0 - (remaining as f32 / total_parameters as f32);
399
400 Ok(WinningTicket {
401 mask: final_mask,
402 initial_weights,
403 shape,
404 sparsity: final_sparsity,
405 remaining_parameters: remaining,
406 total_parameters,
407 sparsity_history,
408 })
409 }
410
411 pub fn apply_ticket(
415 &self,
416 module: &mut dyn Module,
417 ticket: &WinningTicket,
418 ) -> Result<PruningResult, PruningError> {
419 let mut params = module.parameters_mut();
420 if params.is_empty() {
421 return Err(PruningError::NoParameters {
422 module: "module".to_string(),
423 });
424 }
425
426 let weights = &mut *params[0];
427 let total = weights.data().len();
428
429 ticket.mask.apply(weights)?;
431
432 if self.config.rewind_strategy != RewindStrategy::None {
434 let data = weights.data_mut();
435 let mask_data = ticket.mask.tensor().data();
436
437 for (i, (w, &m)) in data.iter_mut().zip(mask_data.iter()).enumerate() {
438 if m != 0.0 {
439 *w = ticket.initial_weights[i];
440 }
441 }
442 }
443
444 let zeros = weights.data().iter().filter(|&&v| v == 0.0).count();
445 let achieved_sparsity = zeros as f32 / total as f32;
446
447 Ok(PruningResult::new(achieved_sparsity, zeros, total))
448 }
449}
450
451include!("lottery_part_02.rs");