Skip to main content

entrenar/train/curriculum/
linear.rs

1//! Linear curriculum scheduler
2
3use super::CurriculumScheduler;
4
/// Curriculum scheduler whose difficulty moves linearly over a fixed number of epochs.
///
/// Difficulty interpolates in a straight line from `start_difficulty` to
/// `end_difficulty` over `ramp_epochs` epochs, then holds at `end_difficulty`.
/// The ramp normally climbs from easy to hard. A `start_difficulty` above
/// `end_difficulty` produces a descending schedule instead.
///
/// Progress is purely time-based. Every call to `step` advances the schedule by one
/// epoch, and both the epoch number and the accuracy passed to `step` are ignored.
///
/// # Example
///
/// ```
/// use entrenar::train::{LinearCurriculum, CurriculumScheduler};
///
/// let mut curriculum = LinearCurriculum::new(0.3, 1.0, 10);
///
/// // Initially at start difficulty
/// assert!((curriculum.difficulty() - 0.3).abs() < 1e-5);
///
/// // Halfway through the ramp (the accuracy argument does not affect progress)
/// for _ in 0..5 {
///     curriculum.step(0, 1.0);
/// }
/// assert!(curriculum.difficulty() > 0.5);
///
/// // Past the end of the ramp, difficulty holds at the end value
/// for _ in 0..10 {
///     curriculum.step(0, 0.0);
/// }
/// assert!((curriculum.difficulty() - 1.0).abs() < 1e-5);
///
/// // `reset` returns the schedule to its starting point
/// curriculum.reset();
/// assert!((curriculum.difficulty() - 0.3).abs() < 1e-5);
/// ```
#[derive(Debug, Clone)]
pub struct LinearCurriculum {
    /// Difficulty at epoch 0 (clamped to `[0.0, 1.0]` by `new`).
    start_difficulty: f32,
    /// Difficulty once the ramp has completed (clamped to `[0.0, 1.0]` by `new`).
    end_difficulty: f32,
    /// Number of epochs the ramp takes (`new` enforces a minimum of 1).
    ramp_epochs: usize,
    /// Number of `step` calls since construction or the last `reset`.
    current_epoch: usize,
}
33
34impl LinearCurriculum {
35    /// Create a new linear curriculum
36    ///
37    /// # Arguments
38    ///
39    /// * `start_difficulty` - Initial difficulty (0.0-1.0)
40    /// * `end_difficulty` - Final difficulty (0.0-1.0)
41    /// * `ramp_epochs` - Epochs to reach full difficulty
42    pub fn new(start_difficulty: f32, end_difficulty: f32, ramp_epochs: usize) -> Self {
43        Self {
44            start_difficulty: start_difficulty.clamp(0.0, 1.0),
45            end_difficulty: end_difficulty.clamp(0.0, 1.0),
46            ramp_epochs: ramp_epochs.max(1),
47            current_epoch: 0,
48        }
49    }
50}
51
52impl CurriculumScheduler for LinearCurriculum {
53    fn difficulty(&self) -> f32 {
54        let progress = (self.current_epoch as f32 / self.ramp_epochs as f32).min(1.0);
55        let difficulty =
56            self.start_difficulty + progress * (self.end_difficulty - self.start_difficulty);
57        let (min, max) = if self.start_difficulty <= self.end_difficulty {
58            (self.start_difficulty, self.end_difficulty)
59        } else {
60            (self.end_difficulty, self.start_difficulty)
61        };
62        difficulty.clamp(min, max)
63    }
64
65    fn tier(&self) -> usize {
66        // Map difficulty to 4 tiers (1-4)
67        let d = self.difficulty();
68        if d < 0.25 {
69            1
70        } else if d < 0.5 {
71            2
72        } else if d < 0.75 {
73            3
74        } else {
75            4
76        }
77    }
78
79    fn step(&mut self, _epoch: usize, _accuracy: f32) {
80        self.current_epoch += 1;
81    }
82
83    fn reset(&mut self) {
84        self.current_epoch = 0;
85    }
86
87    fn name(&self) -> &'static str {
88        "LinearCurriculum"
89    }
90}