use super::PruningSchedule;
impl PruningSchedule {
    /// Sparsity curve of a one-shot schedule evaluated at `step`.
    ///
    /// Yields `0.0` for every step strictly before `prune_step` and `1.0` from
    /// `prune_step` onward (the prune step itself included).
    ///
    /// NOTE(review): the dispatching method is not visible here; presumably the
    /// `OneShot { step }` variant forwards to this helper with `prune_step = step`.
    pub(super) fn oneshot_sparsity_at_step(prune_step: usize, step: usize) -> f32 {
        let before_prune = step < prune_step;
        if before_prune {
            0.0
        } else {
            1.0
        }
    }

    /// Reports whether pruning fires at `step`: true only when `step` equals
    /// `prune_step`, false on every other step.
    pub(super) fn oneshot_should_prune_at_step(prune_step: usize, step: usize) -> bool {
        prune_step == step
    }

    /// Number of pruning events in a one-shot schedule, which is always one.
    pub(super) fn oneshot_num_pruning_steps() -> usize {
        const SINGLE_EVENT: usize = 1;
        SINGLE_EVENT
    }

    /// Validates a one-shot schedule. Every configuration is accepted.
    ///
    /// # Errors
    ///
    /// Never returns `Err`.
    pub(super) fn oneshot_validate() -> Result<(), String> {
        Ok(())
    }

    /// Reports whether the schedule has finished at `step`.
    ///
    /// True only once `step` has moved strictly past `prune_step`; the prune
    /// step itself is not yet complete.
    pub(super) fn oneshot_is_complete(prune_step: usize, step: usize) -> bool {
        prune_step < step
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the one-shot pruning schedule. Most failure messages name the
    //! falsified property with a `SCHED-xxx` identifier.

    use super::*;

    #[test]
    fn test_oneshot_before_step_returns_zero() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.sparsity_at_step(0),
            0.0,
            "SCHED-001 FALSIFIED: OneShot should return 0.0 before prune step"
        );
        assert_eq!(
            schedule.sparsity_at_step(999),
            0.0,
            "SCHED-001 FALSIFIED: OneShot should return 0.0 at step before prune"
        );
    }

    #[test]
    fn test_oneshot_at_step_returns_one() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.sparsity_at_step(1000),
            1.0,
            "SCHED-002 FALSIFIED: OneShot should return 1.0 at prune step"
        );
    }

    #[test]
    fn test_oneshot_after_step_returns_one() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.sparsity_at_step(1001),
            1.0,
            "SCHED-003 FALSIFIED: OneShot should return 1.0 after prune step"
        );
        assert_eq!(
            schedule.sparsity_at_step(10000),
            1.0,
            "SCHED-003 FALSIFIED: OneShot should return 1.0 long after prune step"
        );
    }

    #[test]
    fn test_oneshot_step_zero() {
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert_eq!(
            schedule.sparsity_at_step(0),
            1.0,
            "SCHED-004 FALSIFIED: OneShot at step 0 should return 1.0 immediately"
        );
    }

    #[test]
    fn test_oneshot_should_prune_only_at_step() {
        let schedule = PruningSchedule::OneShot { step: 500 };
        assert!(
            !schedule.should_prune_at_step(499),
            "SCHED-005 FALSIFIED: should_prune should be false before step"
        );
        assert!(
            schedule.should_prune_at_step(500),
            "SCHED-005 FALSIFIED: should_prune should be true at step"
        );
        assert!(
            !schedule.should_prune_at_step(501),
            "SCHED-005 FALSIFIED: should_prune should be false after step"
        );
    }

    #[test]
    fn test_validate_oneshot_always_valid() {
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert!(
            schedule.validate().is_ok(),
            "SCHED-030 FALSIFIED: OneShot should always be valid"
        );
    }

    #[test]
    fn test_num_pruning_steps_oneshot() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.num_pruning_steps(),
            1,
            "SCHED-040 FALSIFIED: OneShot should have exactly 1 pruning step"
        );
    }

    #[test]
    fn test_is_complete_oneshot() {
        let schedule = PruningSchedule::OneShot { step: 100 };
        // Checking strictly before the prune step rules out implementations such as
        // `step != prune_step`, which would pass the at/after checks alone.
        assert!(
            !schedule.is_complete(99),
            "SCHED-043 FALSIFIED: OneShot should not be complete before prune step"
        );
        assert!(
            !schedule.is_complete(100),
            "SCHED-043 FALSIFIED: OneShot should not be complete at prune step"
        );
        assert!(
            schedule.is_complete(101),
            "SCHED-043 FALSIFIED: OneShot should be complete after prune step"
        );
    }

    #[test]
    fn test_oneshot_num_pruning_steps() {
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert_eq!(
            schedule.num_pruning_steps(),
            1,
            "SCHED-040 FALSIFIED: OneShot at step 0 should still have exactly 1 pruning step"
        );
    }

    #[test]
    fn test_is_complete_oneshot_at_zero() {
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert!(
            !schedule.is_complete(0),
            "SCHED-043 FALSIFIED: OneShot at step 0 should not be complete at step 0"
        );
        assert!(
            schedule.is_complete(1),
            "SCHED-043 FALSIFIED: OneShot at step 0 should be complete at step 1"
        );
    }

    #[test]
    fn test_debug_format() {
        let schedule = PruningSchedule::OneShot { step: 100 };
        let debug = format!("{schedule:?}");
        assert!(
            debug.contains("OneShot"),
            "SCHED-064 FALSIFIED: Debug should contain variant name"
        );
        assert!(debug.contains("100"), "SCHED-064 FALSIFIED: Debug should contain step value");
    }

    #[test]
    fn test_serialize_oneshot() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        let json = serde_json::to_string(&schedule).expect("JSON serialization should succeed");
        assert!(
            json.contains("one_shot"),
            "SCHED-050 FALSIFIED: OneShot should serialize with type=one_shot"
        );
        let deserialized: PruningSchedule =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(
            schedule, deserialized,
            "SCHED-050 FALSIFIED: Deserialized should match original"
        );
    }

    #[test]
    fn test_deserialize_oneshot_from_yaml() {
        let yaml = "type: one_shot\nstep: 500\n";
        let schedule: PruningSchedule =
            serde_yaml::from_str(yaml).expect("YAML deserialization should succeed");
        match schedule {
            PruningSchedule::OneShot { step } => assert_eq!(step, 500),
            // Report which variant was produced instead of a bare failure.
            other => panic!("Should deserialize to OneShot, got {other:?}"),
        }
    }
}
#[cfg(test)]
mod proptests {
    //! Property tests for the one-shot pruning schedule over arbitrary prune/test steps.

    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Sparsity is a pure step function: re-evaluating the same step gives the same
        // value, which is 1.0 from the prune step onward and 0.0 before it.
        #[test]
        fn oneshot_idempotent(
            prune_step in 0usize..1000,
            test_step in 0usize..2000,
        ) {
            let schedule = PruningSchedule::OneShot { step: prune_step };
            let sparsity = schedule.sparsity_at_step(test_step);
            // The property this test is named after: a second evaluation must agree.
            prop_assert_eq!(schedule.sparsity_at_step(test_step), sparsity);
            if test_step >= prune_step {
                prop_assert_eq!(sparsity, 1.0);
            } else {
                prop_assert_eq!(sparsity, 0.0);
            }
        }

        // Pruning fires exactly at the prune step, and completion holds exactly after it.
        #[test]
        fn oneshot_prune_and_completion_boundaries(
            prune_step in 0usize..1000,
            test_step in 0usize..2000,
        ) {
            let schedule = PruningSchedule::OneShot { step: prune_step };
            prop_assert_eq!(
                schedule.should_prune_at_step(test_step),
                test_step == prune_step
            );
            prop_assert_eq!(schedule.is_complete(test_step), test_step > prune_step);
        }
    }
}