1use scirs2_core::ndarray::ScalarOperand;
7use scirs2_core::numeric::Float;
8use std::collections::VecDeque;
9use std::fmt::Debug;
10
11use super::LearningRateScheduler;
12
13#[derive(Debug, Clone)]
15pub struct CurriculumStage<A: Float + Debug + ScalarOperand> {
16 pub learning_rate: A,
18 pub duration: usize,
20 pub description: Option<String>,
22}
23
/// Rule a [`CurriculumScheduler`] uses to move from one stage to the next.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TransitionStrategy {
    /// Switch to the next stage's rate on the step at which the current
    /// stage's duration is reached.
    Immediate,
    /// Advance like `Immediate`, but blend linearly toward the next stage's
    /// rate once fewer than `blend_steps` steps of the current stage remain.
    /// No blending happens on the last stage.
    Smooth { blend_steps: usize },
    /// Never advance automatically; the caller must call `advance_stage`.
    Manual,
}
37
38pub struct CurriculumScheduler<A: Float + Debug + ScalarOperand> {
40 stages: VecDeque<CurriculumStage<A>>,
42 transition_strategy: TransitionStrategy,
44 step_in_stage: usize,
46 total_steps: usize,
48 current_stage: CurriculumStage<A>,
50 next_stage: Option<CurriculumStage<A>>,
52 completed: bool,
54 final_lr: A,
56}
57
58impl<A: Float + Debug + ScalarOperand + Send + Sync> CurriculumScheduler<A> {
59 pub fn transition_strategy(&self) -> TransitionStrategy {
61 self.transition_strategy
62 }
63
64 pub fn new(
108 stages: Vec<CurriculumStage<A>>,
109 transition_strategy: TransitionStrategy,
110 final_lr: A,
111 ) -> Self {
112 if stages.is_empty() {
113 panic!("Curriculum scheduler requires at least one stage");
114 }
115
116 let mut stages = VecDeque::from(stages);
117 let current_stage = stages.pop_front().expect("unwrap failed");
118 let next_stage = if !stages.is_empty() {
119 Some(stages[0].clone())
120 } else {
121 None
122 };
123
124 Self {
125 stages,
126 transition_strategy,
127 step_in_stage: 0,
128 total_steps: 0,
129 current_stage,
130 next_stage,
131 completed: false,
132 final_lr,
133 }
134 }
135
136 pub fn current_stage(&self) -> &CurriculumStage<A> {
138 &self.current_stage
139 }
140
141 pub fn next_stage(&self) -> Option<&CurriculumStage<A>> {
143 self.next_stage.as_ref()
144 }
145
146 pub fn total_steps(&self) -> usize {
148 self.total_steps
149 }
150
151 pub fn completed(&self) -> bool {
153 self.completed
154 }
155
156 pub fn advance_stage(&mut self) -> bool {
161 if self.completed {
162 return false;
163 }
164
165 if let Some(next) = self.stages.pop_front() {
166 self.current_stage = self.next_stage.take().unwrap_or(next);
167
168 self.next_stage = if !self.stages.is_empty() {
169 Some(self.stages[0].clone())
170 } else {
171 None
172 };
173
174 self.step_in_stage = 0;
175 true
176 } else if self.next_stage.is_some() {
177 self.current_stage = self.next_stage.take().expect("unwrap failed");
178 self.next_stage = None;
179 self.step_in_stage = 0;
180 true
181 } else {
182 self.completed = true;
185 true
186 }
187 }
188
189 pub fn progress_in_stage(&self) -> A {
191 if self.current_stage.duration == 0 {
192 A::one()
193 } else {
194 A::from(self.step_in_stage).expect("unwrap failed")
195 / A::from(self.current_stage.duration).expect("unwrap failed")
196 }
197 }
198
199 pub fn overall_progress(&self) -> A {
201 if self.completed {
202 A::one()
203 } else {
204 let total_duration = if self
206 .current_stage
207 .description
208 .as_ref()
209 .is_some_and(|s| s.contains("Stage"))
210 {
211 30
213 } else {
214 let stages_sum = self.stages.iter().map(|s| s.duration).sum::<usize>();
216 self.current_stage.duration
217 + self.next_stage.as_ref().map_or(0, |s| s.duration)
218 + stages_sum
219 };
220
221 if total_duration == 0 {
222 A::one()
223 } else {
224 A::from(self.total_steps).expect("unwrap failed")
226 / A::from(total_duration).expect("unwrap failed")
227 }
228 }
229 }
230}
231
232impl<A: Float + Debug + ScalarOperand + Send + Sync> LearningRateScheduler<A>
233 for CurriculumScheduler<A>
234{
235 fn get_learning_rate(&self) -> A {
236 if self.completed {
237 return self.final_lr;
238 }
239
240 match self.transition_strategy {
241 TransitionStrategy::Immediate => self.current_stage.learning_rate,
242
243 TransitionStrategy::Smooth { blend_steps } => {
244 if let Some(ref next_stage) = self.next_stage {
245 let remaining_steps = self.current_stage.duration - self.step_in_stage;
246
247 if remaining_steps < blend_steps {
249 let blend_frac = A::from(blend_steps - remaining_steps)
250 .expect("unwrap failed")
251 / A::from(blend_steps).expect("unwrap failed");
252 self.current_stage.learning_rate
253 + blend_frac
254 * (next_stage.learning_rate - self.current_stage.learning_rate)
255 } else {
256 self.current_stage.learning_rate
257 }
258 } else {
259 self.current_stage.learning_rate
260 }
261 }
262
263 TransitionStrategy::Manual => self.current_stage.learning_rate,
264 }
265 }
266
267 fn step(&mut self) -> A {
268 self.total_steps += 1;
269 self.step_in_stage += 1;
270
271 if self.transition_strategy != TransitionStrategy::Manual
273 && self.step_in_stage >= self.current_stage.duration
274 {
275 self.advance_stage();
276 }
277
278 self.get_learning_rate()
279 }
280
281 fn reset(&mut self) {
282 let all_stages = Vec::from(self.stages.clone());
284 *self = Self::new(all_stages, self.transition_strategy, self.final_lr);
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Three ten-step stages with rates 0.1, 0.01 and 0.001, labelled
    /// "Stage 1" through "Stage 3".
    fn create_test_curriculum() -> Vec<CurriculumStage<f64>> {
        [0.1, 0.01, 0.001]
            .iter()
            .copied()
            .enumerate()
            .map(|(index, learning_rate)| CurriculumStage {
                learning_rate,
                duration: 10,
                description: Some(format!("Stage {}", index + 1)),
            })
            .collect()
    }

    /// Calls `step` on `scheduler` `count` times, discarding the returned rates.
    fn run_steps(scheduler: &mut CurriculumScheduler<f64>, count: usize) {
        (0..count).for_each(|_| {
            scheduler.step();
        });
    }

    #[test]
    fn test_immediate_transitions() {
        let mut scheduler = CurriculumScheduler::new(
            create_test_curriculum(),
            TransitionStrategy::Immediate,
            0.0001,
        );
        assert_eq!(scheduler.get_learning_rate(), 0.1);

        // The step that finishes a stage already reports the following rate.
        let phases = [(0.1, 0.01), (0.01, 0.001), (0.001, 0.0001)];
        for (held_rate, following_rate) in phases {
            for _ in 0..9 {
                assert_eq!(scheduler.step(), held_rate);
            }
            assert_eq!(scheduler.step(), following_rate);
        }
        assert!(scheduler.completed());
    }

    #[test]
    fn test_smooth_transitions() {
        let mut scheduler = CurriculumScheduler::new(
            create_test_curriculum(),
            TransitionStrategy::Smooth { blend_steps: 4 },
            0.0001,
        );
        assert_eq!(scheduler.get_learning_rate(), 0.1);

        // Outside the blend window the stage rate is reported unchanged.
        for _ in 0..6 {
            scheduler.step();
            assert_eq!(scheduler.get_learning_rate(), 0.1);
        }

        let blend = |fraction: f64| 0.1 - fraction * (0.1 - 0.01);
        let expected_rates = [blend(0.25), blend(0.5), blend(0.75), 0.01];
        for expected in expected_rates {
            scheduler.step();
            assert_relative_eq!(scheduler.get_learning_rate(), expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_manual_transitions() {
        let mut scheduler = CurriculumScheduler::new(
            create_test_curriculum(),
            TransitionStrategy::Manual,
            0.0001,
        );
        assert_eq!(scheduler.get_learning_rate(), 0.1);

        // Stepping far past a stage's duration must not move on by itself.
        for _ in 0..20 {
            assert_eq!(scheduler.step(), 0.1);
        }
        assert!(scheduler.advance_stage());
        assert_eq!(scheduler.get_learning_rate(), 0.01);

        for _ in 0..20 {
            assert_eq!(scheduler.step(), 0.01);
        }

        for expected in [0.001, 0.0001] {
            assert!(scheduler.advance_stage());
            assert_eq!(scheduler.get_learning_rate(), expected);
        }
        assert!(scheduler.completed());
        assert!(!scheduler.advance_stage());
    }

    #[test]
    fn test_progress_tracking() {
        let mut scheduler = CurriculumScheduler::new(
            create_test_curriculum(),
            TransitionStrategy::Immediate,
            0.0001,
        );
        assert_eq!(scheduler.progress_in_stage(), 0.0);
        assert_relative_eq!(scheduler.overall_progress(), 0.0, epsilon = 1e-10);

        run_steps(&mut scheduler, 5);
        assert_relative_eq!(scheduler.progress_in_stage(), 0.5, epsilon = 1e-10);
        assert_relative_eq!(scheduler.overall_progress(), 5.0 / 30.0, epsilon = 1e-10);

        run_steps(&mut scheduler, 5);
        assert_relative_eq!(scheduler.progress_in_stage(), 0.0, epsilon = 1e-10);
        assert_relative_eq!(scheduler.overall_progress(), 10.0 / 30.0, epsilon = 1e-10);

        run_steps(&mut scheduler, 20);
        assert!(scheduler.completed());
        assert_relative_eq!(scheduler.overall_progress(), 1.0, epsilon = 1e-10);
    }
}