entrenar/train/curriculum/
linear.rs1use super::CurriculumScheduler;
4
/// A curriculum whose difficulty moves linearly from a start value to an end
/// value over a fixed number of steps, then holds at the end value.
///
/// Both difficulties are kept within `[0.0, 1.0]` by [`LinearCurriculum::new`].
#[derive(Debug, Clone)]
pub struct LinearCurriculum {
    /// Difficulty before any step has been taken; always in `[0.0, 1.0]`.
    start_difficulty: f32,
    /// Difficulty once `ramp_epochs` steps have elapsed; always in `[0.0, 1.0]`.
    end_difficulty: f32,
    /// Number of steps the ramp spans; never zero, so the progress ratio is defined.
    ramp_epochs: usize,
    /// Steps taken since construction (or since the counter was last reset).
    current_epoch: usize,
}

impl LinearCurriculum {
    /// Creates a curriculum that ramps from `start_difficulty` to
    /// `end_difficulty` over `ramp_epochs` steps, starting at step 0.
    ///
    /// Arguments are sanitized rather than rejected:
    /// - each difficulty is clamped into `[0.0, 1.0]`;
    /// - a NaN `start_difficulty` is replaced by `0.0` and a NaN `end_difficulty`
    ///   by `1.0`. `f32::clamp` returns NaN unchanged, and a stored NaN would later
    ///   be used as a `clamp` bound when computing the difficulty, which panics;
    /// - a `ramp_epochs` of `0` is raised to `1` so the progress ratio never
    ///   divides by zero.
    pub fn new(start_difficulty: f32, end_difficulty: f32, ramp_epochs: usize) -> Self {
        Self {
            start_difficulty: Self::sanitize_difficulty(start_difficulty, 0.0),
            end_difficulty: Self::sanitize_difficulty(end_difficulty, 1.0),
            ramp_epochs: ramp_epochs.max(1),
            current_epoch: 0,
        }
    }

    /// Clamps `value` into `[0.0, 1.0]`, substituting `fallback` when it is NaN.
    fn sanitize_difficulty(value: f32, fallback: f32) -> f32 {
        if value.is_nan() {
            fallback
        } else {
            value.clamp(0.0, 1.0)
        }
    }
}
51
52impl CurriculumScheduler for LinearCurriculum {
53 fn difficulty(&self) -> f32 {
54 let progress = (self.current_epoch as f32 / self.ramp_epochs as f32).min(1.0);
55 let difficulty =
56 self.start_difficulty + progress * (self.end_difficulty - self.start_difficulty);
57 let (min, max) = if self.start_difficulty <= self.end_difficulty {
58 (self.start_difficulty, self.end_difficulty)
59 } else {
60 (self.end_difficulty, self.start_difficulty)
61 };
62 difficulty.clamp(min, max)
63 }
64
65 fn tier(&self) -> usize {
66 let d = self.difficulty();
68 if d < 0.25 {
69 1
70 } else if d < 0.5 {
71 2
72 } else if d < 0.75 {
73 3
74 } else {
75 4
76 }
77 }
78
79 fn step(&mut self, _epoch: usize, _accuracy: f32) {
80 self.current_epoch += 1;
81 }
82
83 fn reset(&mut self) {
84 self.current_epoch = 0;
85 }
86
87 fn name(&self) -> &'static str {
88 "LinearCurriculum"
89 }
90}