Skip to main content

entrenar/train/curriculum/
tiered.rs

1//! Tiered curriculum scheduler for CITL
2
3use super::CurriculumScheduler;
4
/// Accuracy-gated curriculum that walks through diagnostic verbosity tiers.
///
/// Intended for CITL training, where each tier adds more diagnostic detail:
///
/// | Tier | Diagnostics                          |
/// |------|--------------------------------------|
/// | 1    | JSON diagnostics + clippy (baseline) |
/// | 2    | + verbose build output               |
/// | 3    | + RUSTC_LOG traces                   |
/// | 4    | + full debug output                  |
///
/// The scheduler moves up one tier once the reported accuracy has met the
/// current tier's threshold for `patience` consecutive epochs.
///
/// # Examples
///
/// ```
/// use entrenar::train::{TieredCurriculum, CurriculumScheduler};
///
/// let mut curriculum = TieredCurriculum::new(vec![0.6, 0.7, 0.8], 3);
///
/// assert_eq!(curriculum.tier(), 1);
///
/// // Advance to tier 2 after achieving 60% accuracy for 3 epochs
/// for _ in 0..3 {
///     curriculum.step(0, 0.65);
/// }
/// assert_eq!(curriculum.tier(), 2);
/// ```
#[derive(Debug, Clone)]
pub struct TieredCurriculum {
    /// Accuracy required to leave each tier; entry `i` gates the move from
    /// tier `i + 1` to tier `i + 2`.
    tier_thresholds: Vec<f32>,
    /// Number of consecutive qualifying epochs needed before advancing.
    patience: usize,
    /// Tier the curriculum is currently on (1-4).
    current_tier: usize,
    /// Length of the current streak of epochs that met the threshold.
    epochs_at_threshold: usize,
}
41
42impl TieredCurriculum {
43    /// Create new tiered curriculum
44    ///
45    /// # Arguments
46    ///
47    /// * `tier_thresholds` - Accuracy thresholds for each tier advancement
48    /// * `patience` - Epochs at threshold before advancing
49    pub fn new(tier_thresholds: Vec<f32>, patience: usize) -> Self {
50        Self { tier_thresholds, patience: patience.max(1), current_tier: 1, epochs_at_threshold: 0 }
51    }
52
53    /// Create with default CITL thresholds
54    ///
55    /// - Tier 1 -> 2: 60% accuracy
56    /// - Tier 2 -> 3: 70% accuracy
57    /// - Tier 3 -> 4: 80% accuracy
58    pub fn citl_default() -> Self {
59        Self::new(vec![0.6, 0.7, 0.8], 3)
60    }
61
62    /// Get the threshold for current tier advancement
63    pub fn current_threshold(&self) -> Option<f32> {
64        if self.current_tier <= self.tier_thresholds.len() {
65            Some(self.tier_thresholds[self.current_tier - 1])
66        } else {
67            None
68        }
69    }
70}
71
72impl CurriculumScheduler for TieredCurriculum {
73    fn difficulty(&self) -> f32 {
74        (self.current_tier as f32 - 1.0) / 3.0
75    }
76
77    fn tier(&self) -> usize {
78        self.current_tier
79    }
80
81    fn step(&mut self, _epoch: usize, accuracy: f32) {
82        if let Some(threshold) = self.current_threshold() {
83            if accuracy >= threshold {
84                self.epochs_at_threshold += 1;
85                if self.epochs_at_threshold >= self.patience {
86                    // Advance to next tier
87                    self.current_tier = (self.current_tier + 1).min(4);
88                    self.epochs_at_threshold = 0;
89                }
90            } else {
91                // Reset counter if below threshold
92                self.epochs_at_threshold = 0;
93            }
94        }
95    }
96
97    fn reset(&mut self) {
98        self.current_tier = 1;
99        self.epochs_at_threshold = 0;
100    }
101
102    fn name(&self) -> &'static str {
103        "TieredCurriculum"
104    }
105}