entrenar/prune/schedule/
mod.rs1mod cubic;
16mod gradual;
17mod oneshot;
18mod types;
19
20pub use types::PruningSchedule;
21
22impl PruningSchedule {
23 pub fn sparsity_at_step(&self, step: usize) -> f32 {
37 match self {
38 PruningSchedule::OneShot { step: prune_step } => {
39 Self::oneshot_sparsity_at_step(*prune_step, step)
40 }
41 PruningSchedule::Gradual {
42 start_step,
43 end_step,
44 initial_sparsity,
45 final_sparsity,
46 ..
47 } => Self::gradual_sparsity_at_step(
48 *start_step,
49 *end_step,
50 *initial_sparsity,
51 *final_sparsity,
52 step,
53 ),
54 PruningSchedule::Cubic { start_step, end_step, final_sparsity } => {
55 Self::cubic_sparsity_at_step(*start_step, *end_step, *final_sparsity, step)
56 }
57 }
58 }
59
60 pub fn should_prune_at_step(&self, step: usize) -> bool {
70 match self {
71 PruningSchedule::OneShot { step: prune_step } => {
72 Self::oneshot_should_prune_at_step(*prune_step, step)
73 }
74 PruningSchedule::Gradual { start_step, end_step, frequency, .. } => {
75 Self::gradual_should_prune_at_step(*start_step, *end_step, *frequency, step)
76 }
77 PruningSchedule::Cubic { start_step, end_step, .. } => {
78 Self::cubic_should_prune_at_step(*start_step, *end_step, step)
79 }
80 }
81 }
82
83 pub fn num_pruning_steps(&self) -> usize {
89 match self {
90 PruningSchedule::OneShot { .. } => Self::oneshot_num_pruning_steps(),
91 PruningSchedule::Gradual { start_step, end_step, frequency, .. } => {
92 Self::gradual_num_pruning_steps(*start_step, *end_step, *frequency)
93 }
94 PruningSchedule::Cubic { start_step, end_step, .. } => {
95 Self::cubic_num_pruning_steps(*start_step, *end_step)
96 }
97 }
98 }
99
100 pub fn validate(&self) -> Result<(), String> {
106 match self {
107 PruningSchedule::OneShot { .. } => Self::oneshot_validate(),
108 PruningSchedule::Gradual {
109 start_step,
110 end_step,
111 initial_sparsity,
112 final_sparsity,
113 ..
114 } => Self::gradual_validate(*start_step, *end_step, *initial_sparsity, *final_sparsity),
115 PruningSchedule::Cubic { start_step, end_step, final_sparsity } => {
116 Self::cubic_validate(*start_step, *end_step, *final_sparsity)
117 }
118 }
119 }
120
121 pub fn is_complete(&self, step: usize) -> bool {
123 match self {
124 PruningSchedule::OneShot { step: prune_step } => {
125 Self::oneshot_is_complete(*prune_step, step)
126 }
127 PruningSchedule::Gradual { end_step, .. } => Self::gradual_is_complete(*end_step, step),
128 PruningSchedule::Cubic { end_step, .. } => Self::cubic_is_complete(*end_step, step),
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// `PruningSchedule::default()` must be the one-shot variant at step 0.
    #[test]
    fn test_default_schedule() {
        if let PruningSchedule::OneShot { step } = PruningSchedule::default() {
            assert_eq!(step, 0, "SCHED-045 FALSIFIED: Default should be OneShot at step 0");
        } else {
            panic!("SCHED-045 FALSIFIED: Default should be OneShot variant");
        }
    }
}