entrenar/prune/schedule/gradual/
mod.rs1#[cfg(test)]
4mod proptests;
5#[cfg(test)]
6mod tests;
7
8use super::PruningSchedule;
9
10impl PruningSchedule {
11 pub(super) fn gradual_sparsity_at_step(
13 start_step: usize,
14 end_step: usize,
15 initial_sparsity: f32,
16 final_sparsity: f32,
17 step: usize,
18 ) -> f32 {
19 if step < start_step {
20 initial_sparsity
21 } else if step >= end_step {
22 final_sparsity
23 } else {
24 let progress = (step - start_step) as f32 / (end_step - start_step) as f32;
25 initial_sparsity + progress * (final_sparsity - initial_sparsity)
26 }
27 }
28
29 pub(super) fn gradual_should_prune_at_step(
31 start_step: usize,
32 end_step: usize,
33 frequency: usize,
34 step: usize,
35 ) -> bool {
36 if step < start_step || step > end_step {
37 return false;
38 }
39 if frequency == 0 {
40 return step == start_step;
41 }
42 (step - start_step).is_multiple_of(frequency)
43 }
44
45 pub(super) fn gradual_num_pruning_steps(
47 start_step: usize,
48 end_step: usize,
49 frequency: usize,
50 ) -> usize {
51 if frequency == 0 {
52 1
53 } else {
54 (end_step - start_step) / frequency + 1
55 }
56 }
57
58 pub(super) fn gradual_validate(
60 start_step: usize,
61 end_step: usize,
62 initial_sparsity: f32,
63 final_sparsity: f32,
64 ) -> Result<(), String> {
65 if end_step <= start_step {
66 return Err(format!(
67 "end_step ({end_step}) must be greater than start_step ({start_step})"
68 ));
69 }
70 if !(0.0..=1.0).contains(&initial_sparsity) {
71 return Err(format!(
72 "initial_sparsity ({initial_sparsity}) must be between 0.0 and 1.0"
73 ));
74 }
75 if !(0.0..=1.0).contains(&final_sparsity) {
76 return Err(format!("final_sparsity ({final_sparsity}) must be between 0.0 and 1.0"));
77 }
78 Ok(())
79 }
80
81 pub(super) fn gradual_is_complete(end_step: usize, step: usize) -> bool {
83 step > end_step
84 }
85}