Skip to main content

entrenar/prune/schedule/cubic/
core.rs

//! Core cubic pruning schedule implementation.
2
3use crate::prune::schedule::PruningSchedule;
4
5impl PruningSchedule {
6    /// Compute the target sparsity at a given training step for Cubic schedule.
7    /// Formula: s_t = s_f * (1 - (1 - t/T)^3)
8    pub(in crate::prune::schedule) fn cubic_sparsity_at_step(
9        start_step: usize,
10        end_step: usize,
11        final_sparsity: f32,
12        step: usize,
13    ) -> f32 {
14        if step < start_step {
15            0.0
16        } else if step >= end_step {
17            final_sparsity
18        } else {
19            let t = (step - start_step) as f32;
20            let total = (end_step - start_step) as f32;
21            let ratio = 1.0 - t / total;
22            final_sparsity * (1.0 - ratio.powi(3))
23        }
24    }
25
26    /// Check if pruning should be applied at this step for Cubic schedule.
27    /// Cubic prunes at every step in the window.
28    pub(in crate::prune::schedule) fn cubic_should_prune_at_step(
29        start_step: usize,
30        end_step: usize,
31        step: usize,
32    ) -> bool {
33        step >= start_step && step <= end_step
34    }
35
36    /// Get the total number of pruning operations for Cubic schedule.
37    pub(in crate::prune::schedule) fn cubic_num_pruning_steps(
38        start_step: usize,
39        end_step: usize,
40    ) -> usize {
41        end_step - start_step + 1
42    }
43
44    /// Validate Cubic schedule.
45    pub(in crate::prune::schedule) fn cubic_validate(
46        start_step: usize,
47        end_step: usize,
48        final_sparsity: f32,
49    ) -> Result<(), String> {
50        if end_step <= start_step {
51            return Err(format!(
52                "end_step ({end_step}) must be greater than start_step ({start_step})"
53            ));
54        }
55        if !(0.0..=1.0).contains(&final_sparsity) {
56            return Err(format!("final_sparsity ({final_sparsity}) must be between 0.0 and 1.0"));
57        }
58        Ok(())
59    }
60
61    /// Check if Cubic pruning has completed.
62    pub(in crate::prune::schedule) fn cubic_is_complete(end_step: usize, step: usize) -> bool {
63        step > end_step
64    }
65}