use super::error::PruningError;
use super::importance::ImportanceScores;
use super::mask::{generate_unstructured_mask, SparsityMask, SparsityPattern};
use super::pruner::{Pruner, PruningResult};
use super::MagnitudeImportance;
use crate::autograd::Tensor;
use crate::nn::Module;
/// How surviving weights are reset once a winning-ticket mask has been found.
///
/// NOTE(review): in the code visible alongside this type, only `None` is
/// distinguished from the other variants; `Early`/`Late` currently rewind
/// exactly like `Init`. Confirm whether the payloads are consumed elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum RewindStrategy {
    /// Rewind surviving weights to the snapshot stored in the ticket.
    #[default]
    Init,
    /// Presumably: rewind to the weights seen at an early training iteration.
    Early {
        /// Training iteration to rewind to.
        iteration: usize,
    },
    /// Presumably: rewind to a point late in training, given as a fraction.
    Late {
        /// Fraction of training to rewind to.
        fraction: f32,
    },
    /// Do not rewind; keep the masked weights as they are.
    None,
}

/// Configuration for iterative magnitude pruning (Lottery Ticket Hypothesis).
#[derive(Debug, Clone)]
pub struct LotteryTicketConfig {
    /// Fraction of weights pruned after the last round; `[0.0, 0.99]` when
    /// built through [`LotteryTicketConfig::new`].
    pub target_sparsity: f32,
    /// Number of pruning rounds (at least 1 when built through `new`).
    pub pruning_rounds: usize,
    /// How surviving weights are reset after pruning.
    pub rewind_strategy: RewindStrategy,
    /// Per-round prune rate `p`, chosen so that
    /// `(1 - p)^pruning_rounds == 1 - target_sparsity`.
    pub prune_rate_per_round: f32,
    /// Whether ranking should be global across layers.
    /// NOTE(review): not read by `find_ticket` in this file — confirm usage.
    pub global_pruning: bool,
}

impl Default for LotteryTicketConfig {
    /// 90% target sparsity reached over 10 rounds.
    fn default() -> Self {
        Self::new(0.9, 10)
    }
}

impl LotteryTicketConfig {
    /// Creates a config that reaches `target_sparsity` after `pruning_rounds` rounds.
    ///
    /// `target_sparsity` is clamped to `[0.0, 0.99]`; a NaN target is treated as
    /// `0.0` (no pruning). `pruning_rounds` is raised to at least 1.
    ///
    /// The per-round rate is derived from the *clamped* target. It therefore
    /// always lies in `[0.0, 1.0)`, and compounding it over all rounds leaves
    /// `1 - target_sparsity` of the weights.
    #[must_use]
    pub fn new(target_sparsity: f32, pruning_rounds: usize) -> Self {
        // `f32::clamp` propagates NaN, so handle it explicitly: a NaN rate would
        // make every round prune down to a single weight.
        let target_sparsity = if target_sparsity.is_nan() {
            0.0
        } else {
            target_sparsity.clamp(0.0, 0.99)
        };
        let pruning_rounds = pruning_rounds.max(1);
        // Geometric schedule: keep a constant fraction (1 - p) of weights per round.
        let prune_rate_per_round =
            1.0 - (1.0 - target_sparsity).powf(1.0 / pruning_rounds as f32);
        Self {
            target_sparsity,
            pruning_rounds,
            rewind_strategy: RewindStrategy::Init,
            prune_rate_per_round,
            global_pruning: true,
        }
    }

    /// Returns this config with a different rewind strategy.
    #[must_use]
    pub fn with_rewind_strategy(mut self, strategy: RewindStrategy) -> Self {
        self.rewind_strategy = strategy;
        self
    }

    /// Returns this config with global pruning switched on or off.
    #[must_use]
    pub fn with_global_pruning(mut self, global: bool) -> Self {
        self.global_pruning = global;
        self
    }
}
/// Result of `LotteryTicketPruner::find_ticket`: a pruning mask plus the weight
/// snapshot needed to rewind surviving weights.
#[derive(Debug, Clone)]
pub struct WinningTicket {
/// Keep-mask shaped like the pruned tensor (1.0 = kept, 0.0 = pruned).
pub mask: SparsityMask,
/// Snapshot of the tensor's values taken when the ticket was found; used to
/// rewind surviving weights in `apply_ticket`.
pub initial_weights: Vec<f32>,
/// Shape of the pruned tensor.
pub shape: Vec<usize>,
/// Fraction of entries pruned by the final mask.
pub sparsity: f32,
/// Number of entries kept by the mask.
pub remaining_parameters: usize,
/// Total number of entries in the pruned tensor.
pub total_parameters: usize,
/// Sparsity recorded after each executed pruning round.
pub sparsity_history: Vec<f32>,
}
impl WinningTicket {
    /// Ratio of total to surviving parameters (`10.0` means 10x fewer weights).
    ///
    /// Returns `f32::INFINITY` when no parameters survive.
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        match self.remaining_parameters {
            0 => f32::INFINITY,
            kept => self.total_parameters as f32 / kept as f32,
        }
    }

    /// Fraction of parameters that survive pruning, i.e. `1 - sparsity`.
    #[must_use]
    pub fn density(&self) -> f32 {
        let pruned_fraction = self.sparsity;
        1.0 - pruned_fraction
    }
}
/// Step-by-step builder for `LotteryTicketPruner`; unset options fall back to
/// the defaults applied in `build` (0.9 sparsity, 10 rounds).
#[derive(Debug, Clone, Default)]
pub struct LotteryTicketPrunerBuilder {
/// Target sparsity, already clamped to `[0.0, 0.99]` by the setter.
target_sparsity: Option<f32>,
/// Number of rounds, already raised to at least 1 by the setter.
pruning_rounds: Option<usize>,
/// Rewind strategy override; `None` keeps the config default.
rewind_strategy: Option<RewindStrategy>,
/// Global-pruning override; `None` keeps the config default.
global_pruning: Option<bool>,
}
impl LotteryTicketPrunerBuilder {
    /// Creates a builder with every option unset.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the target sparsity, clamped to `[0.0, 0.99]`.
    #[must_use]
    pub fn target_sparsity(self, sparsity: f32) -> Self {
        Self {
            target_sparsity: Some(sparsity.clamp(0.0, 0.99)),
            ..self
        }
    }

    /// Sets the number of pruning rounds (at least 1).
    #[must_use]
    pub fn pruning_rounds(self, rounds: usize) -> Self {
        Self {
            pruning_rounds: Some(rounds.max(1)),
            ..self
        }
    }

    /// Sets how surviving weights are rewound.
    #[must_use]
    pub fn rewind_strategy(self, strategy: RewindStrategy) -> Self {
        Self {
            rewind_strategy: Some(strategy),
            ..self
        }
    }

    /// Enables or disables global pruning.
    #[must_use]
    pub fn global_pruning(self, global: bool) -> Self {
        Self {
            global_pruning: Some(global),
            ..self
        }
    }

    /// Builds the pruner; unset sparsity/rounds default to 0.9 and 10, and unset
    /// rewind/global options keep whatever `LotteryTicketConfig::new` chose.
    #[must_use]
    pub fn build(self) -> LotteryTicketPruner {
        let base = LotteryTicketConfig::new(
            self.target_sparsity.unwrap_or(0.9),
            self.pruning_rounds.unwrap_or(10),
        );
        let rewound = match self.rewind_strategy {
            Some(strategy) => base.with_rewind_strategy(strategy),
            None => base,
        };
        let config = match self.global_pruning {
            Some(global) => rewound.with_global_pruning(global),
            None => rewound,
        };
        LotteryTicketPruner::with_config(config)
    }
}
/// Finds and applies lottery-ticket masks via iterative magnitude pruning.
#[derive(Debug, Clone)]
pub struct LotteryTicketPruner {
/// Pruning schedule and rewind settings.
config: LotteryTicketConfig,
/// Importance scorer, set to L2 magnitude by `with_config`.
/// NOTE(review): not read by the methods in this file; presumably used by
/// code pulled in via `include!` — confirm.
importance: MagnitudeImportance,
}
impl Default for LotteryTicketPruner {
fn default() -> Self {
Self::new()
}
}
impl LotteryTicketPruner {
    /// Creates a pruner using `LotteryTicketConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(LotteryTicketConfig::default())
    }

    /// Creates a pruner with an explicit configuration and an L2-magnitude scorer.
    #[must_use]
    pub fn with_config(config: LotteryTicketConfig) -> Self {
        Self {
            config,
            importance: MagnitudeImportance::l2(),
        }
    }

    /// Returns a builder for configuring a pruner step by step.
    #[must_use]
    pub fn builder() -> LotteryTicketPrunerBuilder {
        LotteryTicketPrunerBuilder::new()
    }

    /// Returns the active configuration.
    #[must_use]
    pub fn config(&self) -> &LotteryTicketConfig {
        &self.config
    }

    /// Finds a winning-ticket mask for the module's first parameter tensor.
    ///
    /// Runs up to `pruning_rounds` rounds of iterative magnitude pruning.
    /// - After round `r` (1-based), about `(1 - prune_rate_per_round)^r` of all
    ///   weights remain.
    /// - At least one weight is always kept.
    /// - Each round prunes the still-active weights with the smallest `|w|`.
    /// - Ties are pruned lower index first; NaN magnitudes are treated as most
    ///   important.
    ///
    /// Only `module.parameters()[0]` is considered. `global_pruning` and the
    /// `importance` field are not consulted here.
    ///
    /// NOTE(review): the tensor is read once and used both for ranking and as
    /// `initial_weights` (the rewind target). Classic LTH ranks *trained* weights
    /// but rewinds to *initial* ones — confirm the intended caller contract.
    ///
    /// # Errors
    ///
    /// - Returns `PruningError::NoParameters` if the module has no parameters or
    ///   its first parameter is empty.
    /// - Propagates any error from `SparsityMask::new`.
    pub fn find_ticket(&self, module: &dyn Module) -> Result<WinningTicket, PruningError> {
        // Fraction of mask entries that are pruned (exactly 0.0). Callers ensure
        // the mask is non-empty.
        fn sparsity_of(mask: &[f32]) -> f32 {
            let pruned = mask.iter().filter(|&&m| m == 0.0).count();
            pruned as f32 / mask.len() as f32
        }

        let params = module.parameters();
        if params.is_empty() {
            return Err(PruningError::NoParameters {
                module: "module".to_string(),
            });
        }
        let weights = params[0];
        let initial_weights = weights.data().to_vec();
        // An empty tensor has nothing to prune; without this guard every
        // sparsity below would be 0/0 = NaN.
        if initial_weights.is_empty() {
            return Err(PruningError::NoParameters {
                module: "module".to_string(),
            });
        }
        let shape = weights.shape().to_vec();
        let total_parameters = initial_weights.len();
        let mut cumulative_mask: Vec<f32> = vec![1.0; total_parameters];
        let mut sparsity_history = Vec::with_capacity(self.config.pruning_rounds);

        for round in 0..self.config.pruning_rounds {
            let active_count = cumulative_mask.iter().filter(|&&m| m == 1.0).count();
            if active_count <= 1 {
                // Nothing more can be pruned while keeping one weight alive.
                sparsity_history.push(sparsity_of(&cumulative_mask));
                break;
            }

            let rounds_completed = i32::try_from(round + 1).unwrap_or(i32::MAX);
            let remaining_fraction =
                (1.0 - self.config.prune_rate_per_round).powi(rounds_completed);
            // Always keep at least one weight so the ticket is never empty.
            let target_remaining =
                ((total_parameters as f32 * remaining_fraction).round() as usize).max(1);
            let to_prune = active_count.saturating_sub(target_remaining);

            if to_prune > 0 {
                let mut active_scores: Vec<(usize, f32)> = initial_weights
                    .iter()
                    .zip(&cumulative_mask)
                    .enumerate()
                    .filter(|(_, (_, &m))| m == 1.0)
                    .map(|(i, (&w, _))| (i, w.abs()))
                    .collect();
                // `total_cmp` is a total order. `partial_cmp(..).unwrap_or(Equal)`
                // is not once NaN appears, and `sort_by` may panic on a non-total
                // comparator. The stable sort keeps ties in index order.
                active_scores.sort_by(|a, b| a.1.total_cmp(&b.1));
                for &(idx, _) in active_scores.iter().take(to_prune) {
                    cumulative_mask[idx] = 0.0;
                }
            }

            sparsity_history.push(sparsity_of(&cumulative_mask));
        }

        let mask_tensor = Tensor::new(&cumulative_mask, &shape);
        let final_mask = SparsityMask::new(mask_tensor, SparsityPattern::Unstructured)?;
        let remaining = cumulative_mask.iter().filter(|&&m| m != 0.0).count();
        let final_sparsity = 1.0 - (remaining as f32 / total_parameters as f32);
        Ok(WinningTicket {
            mask: final_mask,
            initial_weights,
            shape,
            sparsity: final_sparsity,
            remaining_parameters: remaining,
            total_parameters,
            sparsity_history,
        })
    }

    /// Applies a ticket's mask to the module's first parameter tensor.
    ///
    /// 1. The mask is applied through `SparsityMask::apply`.
    /// 2. Unless the rewind strategy is `RewindStrategy::None`, every position
    ///    where the mask is non-zero is reset to `ticket.initial_weights`.
    ///
    /// The returned sparsity counts all zero-valued weights afterwards, including
    /// kept weights whose rewound value is zero.
    ///
    /// # Errors
    ///
    /// - Returns `PruningError::NoParameters` if the module has no parameters.
    /// - Propagates any error from `SparsityMask::apply`.
    pub fn apply_ticket(
        &self,
        module: &mut dyn Module,
        ticket: &WinningTicket,
    ) -> Result<PruningResult, PruningError> {
        let mut params = module.parameters_mut();
        if params.is_empty() {
            return Err(PruningError::NoParameters {
                module: "module".to_string(),
            });
        }
        let weights = &mut *params[0];
        let total = weights.data().len();
        ticket.mask.apply(weights)?;
        if self.config.rewind_strategy != RewindStrategy::None {
            let data = weights.data_mut();
            let mask_data = ticket.mask.tensor().data();
            for (i, (w, &m)) in data.iter_mut().zip(mask_data.iter()).enumerate() {
                if m != 0.0 {
                    *w = ticket.initial_weights[i];
                }
            }
        }
        let zeros = weights.data().iter().filter(|&&v| v == 0.0).count();
        // Guard the division so an empty tensor reports 0.0 rather than NaN.
        let achieved_sparsity = if total == 0 {
            0.0
        } else {
            zeros as f32 / total as f32
        };
        Ok(PruningResult::new(achieved_sparsity, zeros, total))
    }
}
// Splices the contents of `lottery_pruner_impl.rs` in at this point (presumably
// further `LotteryTicketPruner` impls such as `Pruner` — confirm).
include!("lottery_pruner_impl.rs");