Skip to main content

entrenar/prune/schedule/
mod.rs

//! Pruning schedule definitions
//!
//! Defines when and how sparsity increases during training:
//! - `OneShot`: prune all at once at a specific step
//! - `Gradual`: linear interpolation from initial to final sparsity
//! - `Cubic`: cubic schedule (Zhu & Gupta, 2017) for smoother transitions
//!
//! # Toyota Way: Kaizen (Continuous Improvement)
//! Gradual and cubic schedules allow incremental model adaptation.
//!
//! # References
//! - Zhu, M., & Gupta, S. (2017). To prune, or not to prune: exploring the
//!   efficacy of pruning for model compression. arXiv:1710.01878.
14
15mod cubic;
16mod gradual;
17mod oneshot;
18mod types;
19
20pub use types::PruningSchedule;
21
22impl PruningSchedule {
23    /// Compute the target sparsity at a given training step.
24    ///
25    /// # Arguments
26    ///
27    /// * `step` - Current training step
28    ///
29    /// # Returns
30    ///
31    /// Target sparsity as a value between 0.0 and 1.0.
32    ///
33    /// # Panics
34    ///
35    /// Does not panic. Returns bounded values for all inputs.
36    pub fn sparsity_at_step(&self, step: usize) -> f32 {
37        match self {
38            PruningSchedule::OneShot { step: prune_step } => {
39                Self::oneshot_sparsity_at_step(*prune_step, step)
40            }
41            PruningSchedule::Gradual {
42                start_step,
43                end_step,
44                initial_sparsity,
45                final_sparsity,
46                ..
47            } => Self::gradual_sparsity_at_step(
48                *start_step,
49                *end_step,
50                *initial_sparsity,
51                *final_sparsity,
52                step,
53            ),
54            PruningSchedule::Cubic { start_step, end_step, final_sparsity } => {
55                Self::cubic_sparsity_at_step(*start_step, *end_step, *final_sparsity, step)
56            }
57        }
58    }
59
60    /// Check if pruning should be applied at this step based on frequency.
61    ///
62    /// For `OneShot`, returns true only at the prune step.
63    /// For `Gradual`, returns true at steps matching the frequency.
64    /// For `Cubic`, returns true at every step in the pruning window.
65    ///
66    /// # Arguments
67    ///
68    /// * `step` - Current training step
69    pub fn should_prune_at_step(&self, step: usize) -> bool {
70        match self {
71            PruningSchedule::OneShot { step: prune_step } => {
72                Self::oneshot_should_prune_at_step(*prune_step, step)
73            }
74            PruningSchedule::Gradual { start_step, end_step, frequency, .. } => {
75                Self::gradual_should_prune_at_step(*start_step, *end_step, *frequency, step)
76            }
77            PruningSchedule::Cubic { start_step, end_step, .. } => {
78                Self::cubic_should_prune_at_step(*start_step, *end_step, step)
79            }
80        }
81    }
82
83    /// Get the total number of pruning operations for this schedule.
84    ///
85    /// # Returns
86    ///
87    /// Expected number of times pruning will be applied.
88    pub fn num_pruning_steps(&self) -> usize {
89        match self {
90            PruningSchedule::OneShot { .. } => Self::oneshot_num_pruning_steps(),
91            PruningSchedule::Gradual { start_step, end_step, frequency, .. } => {
92                Self::gradual_num_pruning_steps(*start_step, *end_step, *frequency)
93            }
94            PruningSchedule::Cubic { start_step, end_step, .. } => {
95                Self::cubic_num_pruning_steps(*start_step, *end_step)
96            }
97        }
98    }
99
100    /// Check if the schedule is valid.
101    ///
102    /// # Errors
103    ///
104    /// Returns an error message if the schedule is invalid.
105    pub fn validate(&self) -> Result<(), String> {
106        match self {
107            PruningSchedule::OneShot { .. } => Self::oneshot_validate(),
108            PruningSchedule::Gradual {
109                start_step,
110                end_step,
111                initial_sparsity,
112                final_sparsity,
113                ..
114            } => Self::gradual_validate(*start_step, *end_step, *initial_sparsity, *final_sparsity),
115            PruningSchedule::Cubic { start_step, end_step, final_sparsity } => {
116                Self::cubic_validate(*start_step, *end_step, *final_sparsity)
117            }
118        }
119    }
120
121    /// Check if pruning has completed (current step is past the schedule).
122    pub fn is_complete(&self, step: usize) -> bool {
123        match self {
124            PruningSchedule::OneShot { step: prune_step } => {
125                Self::oneshot_is_complete(*prune_step, step)
126            }
127            PruningSchedule::Gradual { end_step, .. } => Self::gradual_is_complete(*end_step, step),
128            PruningSchedule::Cubic { end_step, .. } => Self::cubic_is_complete(*end_step, step),
129        }
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// The default schedule must be a one-shot prune at step 0.
    #[test]
    fn test_default_schedule() {
        // TEST_ID: SCHED-045
        if let PruningSchedule::OneShot { step } = PruningSchedule::default() {
            assert_eq!(step, 0, "SCHED-045 FALSIFIED: Default should be OneShot at step 0");
        } else {
            panic!("SCHED-045 FALSIFIED: Default should be OneShot variant");
        }
    }
}