mod cubic;
mod gradual;
mod oneshot;
mod types;
pub use types::PruningSchedule;
impl PruningSchedule {
pub fn sparsity_at_step(&self, step: usize) -> f32 {
match self {
PruningSchedule::OneShot { step: prune_step } => {
Self::oneshot_sparsity_at_step(*prune_step, step)
}
PruningSchedule::Gradual {
start_step,
end_step,
initial_sparsity,
final_sparsity,
..
} => Self::gradual_sparsity_at_step(
*start_step,
*end_step,
*initial_sparsity,
*final_sparsity,
step,
),
PruningSchedule::Cubic { start_step, end_step, final_sparsity } => {
Self::cubic_sparsity_at_step(*start_step, *end_step, *final_sparsity, step)
}
}
}
pub fn should_prune_at_step(&self, step: usize) -> bool {
match self {
PruningSchedule::OneShot { step: prune_step } => {
Self::oneshot_should_prune_at_step(*prune_step, step)
}
PruningSchedule::Gradual { start_step, end_step, frequency, .. } => {
Self::gradual_should_prune_at_step(*start_step, *end_step, *frequency, step)
}
PruningSchedule::Cubic { start_step, end_step, .. } => {
Self::cubic_should_prune_at_step(*start_step, *end_step, step)
}
}
}
pub fn num_pruning_steps(&self) -> usize {
match self {
PruningSchedule::OneShot { .. } => Self::oneshot_num_pruning_steps(),
PruningSchedule::Gradual { start_step, end_step, frequency, .. } => {
Self::gradual_num_pruning_steps(*start_step, *end_step, *frequency)
}
PruningSchedule::Cubic { start_step, end_step, .. } => {
Self::cubic_num_pruning_steps(*start_step, *end_step)
}
}
}
pub fn validate(&self) -> Result<(), String> {
match self {
PruningSchedule::OneShot { .. } => Self::oneshot_validate(),
PruningSchedule::Gradual {
start_step,
end_step,
initial_sparsity,
final_sparsity,
..
} => Self::gradual_validate(*start_step, *end_step, *initial_sparsity, *final_sparsity),
PruningSchedule::Cubic { start_step, end_step, final_sparsity } => {
Self::cubic_validate(*start_step, *end_step, *final_sparsity)
}
}
}
pub fn is_complete(&self, step: usize) -> bool {
match self {
PruningSchedule::OneShot { step: prune_step } => {
Self::oneshot_is_complete(*prune_step, step)
}
PruningSchedule::Gradual { end_step, .. } => Self::gradual_is_complete(*end_step, step),
PruningSchedule::Cubic { end_step, .. } => Self::cubic_is_complete(*end_step, step),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// SCHED-045: `PruningSchedule::default()` must be the `OneShot` variant at step 0.
    #[test]
    fn test_default_schedule() {
        // Extract the prune step, failing immediately if a different variant is returned.
        let default_step = match PruningSchedule::default() {
            PruningSchedule::OneShot { step } => step,
            _ => panic!("SCHED-045 FALSIFIED: Default should be OneShot variant"),
        };
        assert_eq!(default_step, 0, "SCHED-045 FALSIFIED: Default should be OneShot at step 0");
    }
}