Skip to main content

entrenar/prune/schedule/oneshot.rs

1//! OneShot pruning schedule methods.
2
3use super::PruningSchedule;
4
5impl PruningSchedule {
6    /// Compute the target sparsity at a given training step for OneShot schedule.
7    pub(super) fn oneshot_sparsity_at_step(prune_step: usize, step: usize) -> f32 {
8        if step >= prune_step {
9            1.0
10        } else {
11            0.0
12        }
13    }
14
15    /// Check if pruning should be applied at this step for OneShot schedule.
16    pub(super) fn oneshot_should_prune_at_step(prune_step: usize, step: usize) -> bool {
17        step == prune_step
18    }
19
20    /// Get the total number of pruning operations for OneShot schedule.
21    pub(super) fn oneshot_num_pruning_steps() -> usize {
22        1
23    }
24
25    /// Validate OneShot schedule (always valid).
26    pub(super) fn oneshot_validate() -> Result<(), String> {
27        Ok(())
28    }
29
30    /// Check if OneShot pruning has completed.
31    pub(super) fn oneshot_is_complete(prune_step: usize, step: usize) -> bool {
32        step > prune_step
33    }
34}
35
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // OneShot Schedule Tests
    // =========================================================================

    #[test]
    fn test_oneshot_before_step_returns_zero() {
        // TEST_ID: SCHED-001
        // FALSIFIES: OneShot returns non-zero before prune step
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.sparsity_at_step(0),
            0.0,
            "SCHED-001 FALSIFIED: OneShot should return 0.0 before prune step"
        );
        assert_eq!(
            schedule.sparsity_at_step(999),
            0.0,
            "SCHED-001 FALSIFIED: OneShot should return 0.0 at step before prune"
        );
    }

    #[test]
    fn test_oneshot_at_step_returns_one() {
        // TEST_ID: SCHED-002
        // FALSIFIES: OneShot returns wrong value at prune step
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.sparsity_at_step(1000),
            1.0,
            "SCHED-002 FALSIFIED: OneShot should return 1.0 at prune step"
        );
    }

    #[test]
    fn test_oneshot_after_step_returns_one() {
        // TEST_ID: SCHED-003
        // FALSIFIES: OneShot returns wrong value after prune step
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.sparsity_at_step(1001),
            1.0,
            "SCHED-003 FALSIFIED: OneShot should return 1.0 after prune step"
        );
        assert_eq!(
            schedule.sparsity_at_step(10000),
            1.0,
            "SCHED-003 FALSIFIED: OneShot should return 1.0 long after prune step"
        );
    }

    #[test]
    fn test_oneshot_step_zero() {
        // TEST_ID: SCHED-004
        // Edge case: prune at step 0
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert_eq!(
            schedule.sparsity_at_step(0),
            1.0,
            "SCHED-004 FALSIFIED: OneShot at step 0 should return 1.0 immediately"
        );
        // The prune event itself must also fire on the very first step.
        assert!(
            schedule.should_prune_at_step(0),
            "SCHED-004 FALSIFIED: OneShot at step 0 should prune at step 0"
        );
    }

    #[test]
    fn test_oneshot_should_prune_only_at_step() {
        // TEST_ID: SCHED-005
        let schedule = PruningSchedule::OneShot { step: 500 };
        assert!(
            !schedule.should_prune_at_step(0),
            "SCHED-005 FALSIFIED: should_prune should be false at step 0"
        );
        assert!(
            !schedule.should_prune_at_step(499),
            "SCHED-005 FALSIFIED: should_prune should be false before step"
        );
        assert!(
            schedule.should_prune_at_step(500),
            "SCHED-005 FALSIFIED: should_prune should be true at step"
        );
        assert!(
            !schedule.should_prune_at_step(501),
            "SCHED-005 FALSIFIED: should_prune should be false after step"
        );
    }

    #[test]
    fn test_validate_oneshot_always_valid() {
        // TEST_ID: SCHED-030
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert!(schedule.validate().is_ok(), "SCHED-030 FALSIFIED: OneShot should always be valid");
    }

    #[test]
    fn test_num_pruning_steps_oneshot() {
        // TEST_ID: SCHED-040
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.num_pruning_steps(),
            1,
            "SCHED-040 FALSIFIED: OneShot should have exactly 1 pruning step"
        );
    }

    #[test]
    fn test_is_complete_oneshot() {
        // TEST_ID: SCHED-043
        let schedule = PruningSchedule::OneShot { step: 100 };
        // Checking steps *before* the prune step matters: without these, an
        // `is_complete` implemented as `step != prune_step` would still pass.
        assert!(
            !schedule.is_complete(0),
            "SCHED-043 FALSIFIED: OneShot should not be complete before prune step"
        );
        assert!(
            !schedule.is_complete(99),
            "SCHED-043 FALSIFIED: OneShot should not be complete just before prune step"
        );
        assert!(
            !schedule.is_complete(100),
            "SCHED-043 FALSIFIED: OneShot should not be complete at prune step"
        );
        assert!(
            schedule.is_complete(101),
            "SCHED-043 FALSIFIED: OneShot should be complete after prune step"
        );
    }

    #[test]
    fn test_oneshot_num_pruning_steps() {
        // TEST_ID: SCHED-072
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert_eq!(
            schedule.num_pruning_steps(),
            1,
            "SCHED-072 FALSIFIED: OneShot at step 0 should still have exactly 1 pruning step"
        );
    }

    #[test]
    fn test_is_complete_oneshot_at_zero() {
        // TEST_ID: SCHED-074
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert!(
            !schedule.is_complete(0),
            "SCHED-074 FALSIFIED: OneShot at step 0 should not be complete at step 0"
        );
        assert!(
            schedule.is_complete(1),
            "SCHED-074 FALSIFIED: OneShot at step 0 should be complete at step 1"
        );
    }

    #[test]
    fn test_debug_format() {
        // TEST_ID: SCHED-064
        let schedule = PruningSchedule::OneShot { step: 100 };
        let debug = format!("{schedule:?}");
        assert!(
            debug.contains("OneShot"),
            "SCHED-064 FALSIFIED: Debug should contain variant name"
        );
        assert!(debug.contains("100"), "SCHED-064 FALSIFIED: Debug should contain step value");
    }

    #[test]
    fn test_serialize_oneshot() {
        // TEST_ID: SCHED-050
        let schedule = PruningSchedule::OneShot { step: 1000 };
        let json = serde_json::to_string(&schedule).expect("JSON serialization should succeed");
        assert!(
            json.contains("one_shot"),
            "SCHED-050 FALSIFIED: OneShot should serialize with type=one_shot"
        );
        let deserialized: PruningSchedule =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(
            schedule, deserialized,
            "SCHED-050 FALSIFIED: Deserialized should match original"
        );
    }

    #[test]
    fn test_deserialize_oneshot_from_yaml() {
        // TEST_ID: SCHED-084
        let yaml = "type: one_shot\nstep: 500\n";
        let schedule: PruningSchedule =
            serde_yaml::from_str(yaml).expect("SCHED-084: YAML deserialization should succeed");
        match schedule {
            PruningSchedule::OneShot { step } => assert_eq!(
                step, 500,
                "SCHED-084 FALSIFIED: step should round-trip from YAML"
            ),
            // Report which variant was produced instead of a bare panic.
            other => panic!("SCHED-084 FALSIFIED: expected OneShot, got {other:?}"),
        }
    }
}
208
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// OneShot sparsity is a step function: 0.0 strictly before the prune
        /// step and 1.0 at and after it.
        #[test]
        fn oneshot_idempotent(
            prune_step in 0usize..1000,
            test_step in 0usize..2000,
        ) {
            let schedule = PruningSchedule::OneShot { step: prune_step };
            let sparsity = schedule.sparsity_at_step(test_step);

            if test_step >= prune_step {
                prop_assert_eq!(sparsity, 1.0);
            } else {
                prop_assert_eq!(sparsity, 0.0);
            }
        }

        /// `should_prune_at_step` fires on exactly one step (the prune step) and
        /// `is_complete` becomes true only strictly after it.
        #[test]
        fn oneshot_prune_and_complete_boundaries(
            prune_step in 0usize..1000,
            test_step in 0usize..2000,
        ) {
            let schedule = PruningSchedule::OneShot { step: prune_step };

            prop_assert_eq!(schedule.should_prune_at_step(test_step), test_step == prune_step);
            prop_assert_eq!(schedule.is_complete(test_step), test_step > prune_step);
        }

        /// Sparsity never decreases from one step to the next.
        #[test]
        fn oneshot_sparsity_non_decreasing(
            prune_step in 0usize..1000,
            test_step in 0usize..2000,
        ) {
            let schedule = PruningSchedule::OneShot { step: prune_step };

            prop_assert!(
                schedule.sparsity_at_step(test_step) <= schedule.sparsity_at_step(test_step + 1)
            );
        }
    }
}