entrenar/prune/schedule/cubic/
core.rs1use crate::prune::schedule::PruningSchedule;
4
5impl PruningSchedule {
6 pub(in crate::prune::schedule) fn cubic_sparsity_at_step(
9 start_step: usize,
10 end_step: usize,
11 final_sparsity: f32,
12 step: usize,
13 ) -> f32 {
14 if step < start_step {
15 0.0
16 } else if step >= end_step {
17 final_sparsity
18 } else {
19 let t = (step - start_step) as f32;
20 let total = (end_step - start_step) as f32;
21 let ratio = 1.0 - t / total;
22 final_sparsity * (1.0 - ratio.powi(3))
23 }
24 }
25
26 pub(in crate::prune::schedule) fn cubic_should_prune_at_step(
29 start_step: usize,
30 end_step: usize,
31 step: usize,
32 ) -> bool {
33 step >= start_step && step <= end_step
34 }
35
36 pub(in crate::prune::schedule) fn cubic_num_pruning_steps(
38 start_step: usize,
39 end_step: usize,
40 ) -> usize {
41 end_step - start_step + 1
42 }
43
44 pub(in crate::prune::schedule) fn cubic_validate(
46 start_step: usize,
47 end_step: usize,
48 final_sparsity: f32,
49 ) -> Result<(), String> {
50 if end_step <= start_step {
51 return Err(format!(
52 "end_step ({end_step}) must be greater than start_step ({start_step})"
53 ));
54 }
55 if !(0.0..=1.0).contains(&final_sparsity) {
56 return Err(format!("final_sparsity ({final_sparsity}) must be between 0.0 and 1.0"));
57 }
58 Ok(())
59 }
60
61 pub(in crate::prune::schedule) fn cubic_is_complete(end_step: usize, step: usize) -> bool {
63 step > end_step
64 }
65}