use super::CurriculumScheduler;
/// Curriculum scheduler that advances through discrete, accuracy-gated tiers.
///
/// Tier `k` (1-based) is gated by `tier_thresholds[k - 1]`: once the accuracy
/// passed to [`CurriculumScheduler::step`] has been `>=` that threshold for
/// `patience` consecutive calls, the scheduler moves to tier `k + 1`. Any call
/// below the threshold resets the consecutive count. With `n` thresholds there
/// are `n + 1` tiers; the last tier has no threshold and is terminal.
#[derive(Debug, Clone)]
pub struct TieredCurriculum {
    /// Accuracy required to leave each tier; index `i` gates tier `i + 1`.
    tier_thresholds: Vec<f32>,
    /// Consecutive qualifying steps required before advancing (clamped to >= 1).
    patience: usize,
    /// Current tier, 1-based; ranges over `1..=tier_thresholds.len() + 1`.
    current_tier: usize,
    /// Consecutive qualifying steps observed at the current tier.
    epochs_at_threshold: usize,
}

impl TieredCurriculum {
    /// Creates a scheduler that starts at tier 1.
    ///
    /// `patience` is clamped to at least 1, so advancing always requires at
    /// least one qualifying step.
    pub fn new(tier_thresholds: Vec<f32>, patience: usize) -> Self {
        Self {
            tier_thresholds,
            patience: patience.max(1),
            current_tier: 1,
            epochs_at_threshold: 0,
        }
    }

    /// Preset with thresholds `0.6`, `0.7`, `0.8` (four tiers) and a patience of 3.
    pub fn citl_default() -> Self {
        Self::new(vec![0.6, 0.7, 0.8], 3)
    }

    /// Returns the accuracy threshold gating the current tier, or `None` once
    /// the final (threshold-less) tier has been reached.
    pub fn current_threshold(&self) -> Option<f32> {
        self.current_tier
            .checked_sub(1)
            .and_then(|idx| self.tier_thresholds.get(idx))
            .copied()
    }

    /// Highest reachable tier: one past the number of thresholds.
    fn final_tier(&self) -> usize {
        self.tier_thresholds.len() + 1
    }
}

impl CurriculumScheduler for TieredCurriculum {
    /// Progress through the curriculum, normalised to `[0.0, 1.0]`.
    ///
    /// Tier 1 maps to 0.0 and the final tier maps to 1.0, whatever the number
    /// of thresholds. With no thresholds there is a single tier and this
    /// returns 0.0.
    fn difficulty(&self) -> f32 {
        let transitions = self.tier_thresholds.len();
        if transitions == 0 {
            // Avoid 0/0 = NaN; a single-tier curriculum never progresses.
            return 0.0;
        }
        (self.current_tier as f32 - 1.0) / transitions as f32
    }

    /// Current tier, 1-based.
    fn tier(&self) -> usize {
        self.current_tier
    }

    /// Records one evaluation result and advances a tier once `patience`
    /// consecutive results meet the current threshold. No-op on the final tier.
    fn step(&mut self, _epoch: usize, accuracy: f32) {
        if let Some(threshold) = self.current_threshold() {
            if accuracy >= threshold {
                self.epochs_at_threshold += 1;
                if self.epochs_at_threshold >= self.patience {
                    // Cap derived from the configured thresholds rather than a
                    // hard-coded tier count, so every threshold is reachable.
                    self.current_tier = (self.current_tier + 1).min(self.final_tier());
                    self.epochs_at_threshold = 0;
                }
            } else {
                // The streak must be consecutive; a miss starts it over.
                self.epochs_at_threshold = 0;
            }
        }
    }

    /// Returns to tier 1 and clears the consecutive-step counter.
    fn reset(&mut self) {
        self.current_tier = 1;
        self.epochs_at_threshold = 0;
    }

    fn name(&self) -> &'static str {
        "TieredCurriculum"
    }
}