Skip to main content

entrenar/prune/schedule/gradual/
mod.rs

1//! Gradual pruning schedule methods.
2
3#[cfg(test)]
4mod proptests;
5#[cfg(test)]
6mod tests;
7
8use super::PruningSchedule;
9
10impl PruningSchedule {
11    /// Compute the target sparsity at a given training step for Gradual schedule.
12    pub(super) fn gradual_sparsity_at_step(
13        start_step: usize,
14        end_step: usize,
15        initial_sparsity: f32,
16        final_sparsity: f32,
17        step: usize,
18    ) -> f32 {
19        if step < start_step {
20            initial_sparsity
21        } else if step >= end_step {
22            final_sparsity
23        } else {
24            let progress = (step - start_step) as f32 / (end_step - start_step) as f32;
25            initial_sparsity + progress * (final_sparsity - initial_sparsity)
26        }
27    }
28
29    /// Check if pruning should be applied at this step for Gradual schedule.
30    pub(super) fn gradual_should_prune_at_step(
31        start_step: usize,
32        end_step: usize,
33        frequency: usize,
34        step: usize,
35    ) -> bool {
36        if step < start_step || step > end_step {
37            return false;
38        }
39        if frequency == 0 {
40            return step == start_step;
41        }
42        (step - start_step).is_multiple_of(frequency)
43    }
44
45    /// Get the total number of pruning operations for Gradual schedule.
46    pub(super) fn gradual_num_pruning_steps(
47        start_step: usize,
48        end_step: usize,
49        frequency: usize,
50    ) -> usize {
51        if frequency == 0 {
52            1
53        } else {
54            (end_step - start_step) / frequency + 1
55        }
56    }
57
58    /// Validate Gradual schedule.
59    pub(super) fn gradual_validate(
60        start_step: usize,
61        end_step: usize,
62        initial_sparsity: f32,
63        final_sparsity: f32,
64    ) -> Result<(), String> {
65        if end_step <= start_step {
66            return Err(format!(
67                "end_step ({end_step}) must be greater than start_step ({start_step})"
68            ));
69        }
70        if !(0.0..=1.0).contains(&initial_sparsity) {
71            return Err(format!(
72                "initial_sparsity ({initial_sparsity}) must be between 0.0 and 1.0"
73            ));
74        }
75        if !(0.0..=1.0).contains(&final_sparsity) {
76            return Err(format!("final_sparsity ({final_sparsity}) must be between 0.0 and 1.0"));
77        }
78        Ok(())
79    }
80
81    /// Check if Gradual pruning has completed.
82    pub(super) fn gradual_is_complete(end_step: usize, step: usize) -> bool {
83        step > end_step
84    }
85}