entrenar/prune/schedule/
oneshot.rs1use super::PruningSchedule;
4
5impl PruningSchedule {
6 pub(super) fn oneshot_sparsity_at_step(prune_step: usize, step: usize) -> f32 {
8 if step >= prune_step {
9 1.0
10 } else {
11 0.0
12 }
13 }
14
15 pub(super) fn oneshot_should_prune_at_step(prune_step: usize, step: usize) -> bool {
17 step == prune_step
18 }
19
20 pub(super) fn oneshot_num_pruning_steps() -> usize {
22 1
23 }
24
25 pub(super) fn oneshot_validate() -> Result<(), String> {
27 Ok(())
28 }
29
30 pub(super) fn oneshot_is_complete(prune_step: usize, step: usize) -> bool {
32 step > prune_step
33 }
34}
35
#[cfg(test)]
mod tests {
    use super::*;

    // ---- sparsity_at_step ------------------------------------------------

    #[test]
    fn test_oneshot_before_step_returns_zero() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.sparsity_at_step(0),
            0.0,
            "SCHED-001 FALSIFIED: OneShot should return 0.0 before prune step"
        );
        assert_eq!(
            schedule.sparsity_at_step(999),
            0.0,
            "SCHED-001 FALSIFIED: OneShot should return 0.0 at step before prune"
        );
    }

    #[test]
    fn test_oneshot_at_step_returns_one() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.sparsity_at_step(1000),
            1.0,
            "SCHED-002 FALSIFIED: OneShot should return 1.0 at prune step"
        );
    }

    #[test]
    fn test_oneshot_after_step_returns_one() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.sparsity_at_step(1001),
            1.0,
            "SCHED-003 FALSIFIED: OneShot should return 1.0 after prune step"
        );
        assert_eq!(
            schedule.sparsity_at_step(10000),
            1.0,
            "SCHED-003 FALSIFIED: OneShot should return 1.0 long after prune step"
        );
    }

    #[test]
    fn test_oneshot_step_zero() {
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert_eq!(
            schedule.sparsity_at_step(0),
            1.0,
            "SCHED-004 FALSIFIED: OneShot at step 0 should return 1.0 immediately"
        );
    }

    // ---- should_prune_at_step --------------------------------------------

    #[test]
    fn test_oneshot_should_prune_only_at_step() {
        let schedule = PruningSchedule::OneShot { step: 500 };
        assert!(
            !schedule.should_prune_at_step(499),
            "SCHED-005 FALSIFIED: should_prune should be false before step"
        );
        assert!(
            schedule.should_prune_at_step(500),
            "SCHED-005 FALSIFIED: should_prune should be true at step"
        );
        assert!(
            !schedule.should_prune_at_step(501),
            "SCHED-005 FALSIFIED: should_prune should be false after step"
        );
    }

    /// Boundary case: a schedule that prunes at step 0 must fire on the very
    /// first step and never again.
    #[test]
    fn test_oneshot_should_prune_at_step_zero() {
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert!(
            schedule.should_prune_at_step(0),
            "SCHED-005 FALSIFIED: should_prune should be true at step 0 when prune step is 0"
        );
        assert!(
            !schedule.should_prune_at_step(1),
            "SCHED-005 FALSIFIED: should_prune should be false after step"
        );
    }

    // ---- validate --------------------------------------------------------

    /// "Always valid" is checked across the whole `usize` range, not just at 0.
    #[test]
    fn test_validate_oneshot_always_valid() {
        for step in [0, 1, 1000, usize::MAX] {
            let schedule = PruningSchedule::OneShot { step };
            assert!(
                schedule.validate().is_ok(),
                "SCHED-030 FALSIFIED: OneShot should always be valid (step = {step})"
            );
        }
    }

    // ---- num_pruning_steps -----------------------------------------------

    #[test]
    fn test_num_pruning_steps_oneshot() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        assert_eq!(
            schedule.num_pruning_steps(),
            1,
            "SCHED-040 FALSIFIED: OneShot should have exactly 1 pruning step"
        );
    }

    /// The single-event count must not depend on where the prune step sits.
    #[test]
    fn test_oneshot_num_pruning_steps() {
        for step in [0, 1, 1000] {
            let schedule = PruningSchedule::OneShot { step };
            assert_eq!(
                schedule.num_pruning_steps(),
                1,
                "SCHED-040 FALSIFIED: OneShot should have exactly 1 pruning step (step = {step})"
            );
        }
    }

    // ---- is_complete -----------------------------------------------------

    #[test]
    fn test_is_complete_oneshot() {
        let schedule = PruningSchedule::OneShot { step: 100 };
        assert!(
            !schedule.is_complete(0),
            "SCHED-043 FALSIFIED: OneShot should not be complete before prune step"
        );
        assert!(
            !schedule.is_complete(99),
            "SCHED-043 FALSIFIED: OneShot should not be complete before prune step"
        );
        assert!(
            !schedule.is_complete(100),
            "SCHED-043 FALSIFIED: OneShot should not be complete at prune step"
        );
        assert!(
            schedule.is_complete(101),
            "SCHED-043 FALSIFIED: OneShot should be complete after prune step"
        );
    }

    #[test]
    fn test_is_complete_oneshot_at_zero() {
        let schedule = PruningSchedule::OneShot { step: 0 };
        assert!(
            !schedule.is_complete(0),
            "SCHED-043 FALSIFIED: OneShot should not be complete at prune step"
        );
        assert!(
            schedule.is_complete(1),
            "SCHED-043 FALSIFIED: OneShot should be complete after prune step"
        );
    }

    // ---- formatting & serialization --------------------------------------

    #[test]
    fn test_debug_format() {
        let schedule = PruningSchedule::OneShot { step: 100 };
        let debug = format!("{schedule:?}");
        assert!(
            debug.contains("OneShot"),
            "SCHED-064 FALSIFIED: Debug should contain variant name"
        );
        assert!(debug.contains("100"), "SCHED-064 FALSIFIED: Debug should contain step value");
    }

    #[test]
    fn test_serialize_oneshot() {
        let schedule = PruningSchedule::OneShot { step: 1000 };
        let json = serde_json::to_string(&schedule).expect("JSON serialization should succeed");
        assert!(
            json.contains("one_shot"),
            "SCHED-050 FALSIFIED: OneShot should serialize with type=one_shot"
        );
        let deserialized: PruningSchedule =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(
            schedule, deserialized,
            "SCHED-050 FALSIFIED: Deserialized should match original"
        );
    }

    #[test]
    fn test_deserialize_oneshot_from_yaml() {
        let yaml = "type: one_shot\nstep: 500\n";
        let schedule: PruningSchedule =
            serde_yaml::from_str(yaml).expect("operation should succeed");
        match schedule {
            PruningSchedule::OneShot { step } => assert_eq!(step, 500),
            // Report what was actually produced so a tag mix-up is diagnosable.
            other => panic!("Should deserialize to OneShot, got {other:?}"),
        }
    }
}
208
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Sparsity is a pure step function of the queried step. Asking twice
        // for the same step gives the same answer (idempotent query), and the
        // value is 0.0 strictly before the prune step and 1.0 from it onward.
        #[test]
        fn oneshot_idempotent(
            prune_step in 0usize..1000,
            test_step in 0usize..2000,
        ) {
            let schedule = PruningSchedule::OneShot { step: prune_step };
            let sparsity = schedule.sparsity_at_step(test_step);

            prop_assert_eq!(sparsity, schedule.sparsity_at_step(test_step));

            if test_step >= prune_step {
                prop_assert_eq!(sparsity, 1.0);
            } else {
                prop_assert_eq!(sparsity, 0.0);
            }
        }

        // Pruning fires on exactly one step, the schedule is complete strictly
        // after that step, and whenever pruning fires the sparsity is already
        // at its final value.
        #[test]
        fn oneshot_prunes_exactly_at_step(
            prune_step in 0usize..1000,
            test_step in 0usize..2000,
        ) {
            let schedule = PruningSchedule::OneShot { step: prune_step };

            prop_assert_eq!(schedule.should_prune_at_step(test_step), test_step == prune_step);
            prop_assert_eq!(schedule.is_complete(test_step), test_step > prune_step);

            if schedule.should_prune_at_step(test_step) {
                prop_assert_eq!(schedule.sparsity_at_step(test_step), 1.0);
            }
        }
    }
}