Skip to main content

entrenar/train/curriculum/
adaptive.rs

//! Adaptive curriculum scheduler for error-specific training.
2
3use std::collections::HashMap;
4
5use super::CurriculumScheduler;
6
/// Adaptive curriculum that adjusts based on error class performance.
///
/// Tracks a moving-average accuracy and an observation count per error
/// class, and derives diagnostic tiers and sample weights from them so that
/// well-learned classes can move on while struggling (or rarely seen)
/// classes keep the model's focus.
///
/// Supports the CITL adaptive tier selection pattern.
#[derive(Debug, Clone)]
pub struct AdaptiveCurriculum {
    /// Exponential moving average of correctness per error class.
    ///
    /// A new class starts at `0.0`; values always stay within `[0, 1]`.
    pub(crate) class_accuracy: HashMap<String, f32>,
    /// Number of `update_class` observations recorded per error class.
    pub(crate) class_attempts: HashMap<String, usize>,
    /// Tier returned by [`AdaptiveCurriculum::tier_for_error`] on the first
    /// attempt of an error that has no special-case rule.
    default_tier: usize,
    /// Overall difficulty: `update_class` sets it to the mean of the
    /// per-class accuracies, and the scheduler's `step` blends observed
    /// accuracy into it.
    overall_difficulty: f32,
}

impl AdaptiveCurriculum {
    /// Smoothing factor for the per-class accuracy moving average.
    const CLASS_EMA_ALPHA: f32 = 0.1;

    /// Create an empty curriculum: no tracked classes, overall difficulty
    /// `0.0`, and a default tier of `1`.
    pub fn new() -> Self {
        Self {
            class_accuracy: HashMap::new(),
            class_attempts: HashMap::new(),
            default_tier: 1,
            overall_difficulty: 0.0,
        }
    }

    /// Get the recommended diagnostic tier for an error class.
    ///
    /// Based on the CITL `select_tier()` pattern. `attempt` counts retries
    /// starting at `0`. Rules, checked in order:
    ///
    /// * codes starting with `"ICE"` → `4` on every attempt;
    /// * `E0308`, `E0277`, `E0382` → `3` from attempt `1` on;
    /// * `E0425`, `E0433` → `3` from attempt `2` on;
    /// * anything else → the default tier (`1`) on attempt `0`, `2` on
    ///   attempt `1`, and `3` afterwards.
    pub fn tier_for_error(&self, error_code: &str, attempt: usize) -> usize {
        // ICEs always need full debug output.
        if error_code.starts_with("ICE") {
            return 4;
        }

        // Type/trait/move errors benefit from traces from the first retry.
        if matches!(error_code, "E0308" | "E0277" | "E0382") && attempt >= 1 {
            return 3;
        }

        // Name resolution needs verbose output from the second retry.
        // NOTE: this currently yields the same tier as the default escalation
        // below (3 for attempt >= 2); it is kept explicit so the per-class
        // policy survives changes to the default schedule.
        if matches!(error_code, "E0425" | "E0433") && attempt >= 2 {
            return 3;
        }

        // Default escalation pattern.
        match attempt {
            0 => self.default_tier,
            1 => 2,
            2.. => 3,
        }
    }

    /// Record one observation for an error class.
    ///
    /// Increments the class's attempt count and folds `correct` into its
    /// accuracy as an exponential moving average (alpha = 0.1, starting from
    /// `0.0`). It then sets the overall difficulty to the mean accuracy
    /// across all tracked classes.
    pub fn update_class(&mut self, error_code: &str, correct: bool) {
        *self
            .class_attempts
            .entry(error_code.to_string())
            .or_insert(0) += 1;

        let alpha = Self::CLASS_EMA_ALPHA;
        let acc = self
            .class_accuracy
            .entry(error_code.to_string())
            .or_insert(0.0);
        *acc = *acc * (1.0 - alpha) + (if correct { alpha } else { 0.0 });

        // The entry above guarantees at least one tracked class, so the
        // mean is always well-defined and needs no emptiness guard.
        let total: f32 = self.class_accuracy.values().sum();
        self.overall_difficulty = total / self.class_accuracy.len() as f32;
    }

    /// Get a sample weight for an error class based on rarity and accuracy.
    ///
    /// Long-tail (rare) errors get higher weights per Feldman (2020). The
    /// weight is `1 + 1/sqrt(attempts + 1) + (1 - accuracy)`, which lies in
    /// `(1, 3]`. A class that has never been observed gets the maximum,
    /// `3.0`.
    pub fn weight_for_class(&self, error_code: &str) -> f32 {
        let attempts = self.class_attempts.get(error_code).copied().unwrap_or(0);
        let accuracy = self.class_accuracy.get(error_code).copied().unwrap_or(0.0);

        // Rare classes get higher weight (1.0 for an unseen class).
        let rarity_weight = 1.0 / (attempts as f32 + 1.0).sqrt();

        // Low-accuracy classes get higher weight.
        let difficulty_weight = 1.0 - accuracy;

        // Each term is at most 1, so the cap is purely defensive.
        (1.0 + rarity_weight + difficulty_weight).min(3.0)
    }
}

impl Default for AdaptiveCurriculum {
    /// Equivalent to [`AdaptiveCurriculum::new`].
    fn default() -> Self {
        Self::new()
    }
}
103
104impl CurriculumScheduler for AdaptiveCurriculum {
105    fn difficulty(&self) -> f32 {
106        self.overall_difficulty
107    }
108
109    fn tier(&self) -> usize {
110        if self.overall_difficulty < 0.25 {
111            1
112        } else if self.overall_difficulty < 0.5 {
113            2
114        } else if self.overall_difficulty < 0.75 {
115            3
116        } else {
117            4
118        }
119    }
120
121    fn step(&mut self, _epoch: usize, accuracy: f32) {
122        // Update overall difficulty based on recent accuracy
123        let alpha = 0.1;
124        self.overall_difficulty = self.overall_difficulty * (1.0 - alpha) + accuracy * alpha;
125    }
126
127    fn reset(&mut self) {
128        self.class_accuracy.clear();
129        self.class_attempts.clear();
130        self.overall_difficulty = 0.0;
131    }
132
133    fn name(&self) -> &'static str {
134        "AdaptiveCurriculum"
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adaptive_curriculum_new() {
        let curriculum = AdaptiveCurriculum::new();
        assert!(curriculum.class_accuracy.is_empty());
        assert!(curriculum.class_attempts.is_empty());
        assert_eq!(curriculum.overall_difficulty, 0.0);
    }

    #[test]
    fn test_adaptive_curriculum_default() {
        let curriculum = AdaptiveCurriculum::default();
        assert_eq!(curriculum.difficulty(), 0.0);
    }

    #[test]
    fn test_tier_for_error_ice() {
        let curriculum = AdaptiveCurriculum::new();
        assert_eq!(curriculum.tier_for_error("ICE001", 0), 4);
        assert_eq!(curriculum.tier_for_error("ICE-crash", 5), 4);
    }

    #[test]
    fn test_tier_for_error_ice_prefix_only() {
        let curriculum = AdaptiveCurriculum::new();
        // Only a leading "ICE" triggers tier 4; elsewhere it is a generic code
        assert_eq!(curriculum.tier_for_error("E-ICE", 0), 1);
        assert_eq!(curriculum.tier_for_error("E-ICE", 1), 2);
    }

    #[test]
    fn test_tier_for_error_type_errors() {
        let curriculum = AdaptiveCurriculum::new();
        // E0308 on attempt 0 uses default
        assert_eq!(curriculum.tier_for_error("E0308", 0), 1);
        // E0308 on attempt 1+ gets tier 3
        assert_eq!(curriculum.tier_for_error("E0308", 1), 3);
        assert_eq!(curriculum.tier_for_error("E0277", 2), 3);
        assert_eq!(curriculum.tier_for_error("E0382", 1), 3);
    }

    #[test]
    fn test_tier_for_error_name_resolution() {
        let curriculum = AdaptiveCurriculum::new();
        // E0425 needs attempt >= 2 for tier 3
        assert_eq!(curriculum.tier_for_error("E0425", 0), 1);
        assert_eq!(curriculum.tier_for_error("E0425", 1), 2);
        assert_eq!(curriculum.tier_for_error("E0425", 2), 3);
        assert_eq!(curriculum.tier_for_error("E0433", 3), 3);
    }

    #[test]
    fn test_tier_for_error_default_escalation() {
        let curriculum = AdaptiveCurriculum::new();
        // Generic error escalation
        assert_eq!(curriculum.tier_for_error("E0001", 0), 1);
        assert_eq!(curriculum.tier_for_error("E0001", 1), 2);
        assert_eq!(curriculum.tier_for_error("E0001", 2), 3);
        assert_eq!(curriculum.tier_for_error("E0001", 5), 3);
    }

    #[test]
    fn test_update_class() {
        let mut curriculum = AdaptiveCurriculum::new();

        curriculum.update_class("E0308", true);
        assert_eq!(curriculum.class_attempts.get("E0308"), Some(&1));
        // First correct: 0.0 * 0.9 + 0.1 = 0.1
        assert!(
            (curriculum.class_accuracy.get("E0308").expect("key should exist") - 0.1).abs() < 0.001
        );

        curriculum.update_class("E0308", false);
        assert_eq!(curriculum.class_attempts.get("E0308"), Some(&2));
        // 0.1 * 0.9 + 0 = 0.09
        assert!(
            (curriculum.class_accuracy.get("E0308").expect("key should exist") - 0.09).abs()
                < 0.001
        );
    }

    #[test]
    fn test_update_class_overall_difficulty_is_mean() {
        let mut curriculum = AdaptiveCurriculum::new();
        curriculum.update_class("E0308", true); // accuracy 0.1
        curriculum.update_class("E0425", false); // accuracy 0.0
        // mean(0.1, 0.0) = 0.05
        assert!((curriculum.difficulty() - 0.05).abs() < 0.001);
        assert_eq!(curriculum.class_attempts.len(), 2);
        assert_eq!(curriculum.class_accuracy.len(), 2);
    }

    #[test]
    fn test_weight_for_class_unknown() {
        let curriculum = AdaptiveCurriculum::new();
        let weight = curriculum.weight_for_class("unknown");
        // rarity = 1/sqrt(1) = 1.0, difficulty = 1.0 - 0 = 1.0
        // weight = 1.0 + 1.0 + 1.0 = 3.0
        assert!((weight - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_weight_for_class_known() {
        let mut curriculum = AdaptiveCurriculum::new();

        // Add some attempts
        for _ in 0..10 {
            curriculum.update_class("E0308", true);
        }

        let weight = curriculum.weight_for_class("E0308");
        // rarity = 1/sqrt(11) ≈ 0.3, difficulty = 1 - accuracy
        // Should be lower than unknown
        assert!(weight < 3.0);
        assert!(weight >= 1.0);
    }

    #[test]
    fn test_weight_for_class_frequent_failures() {
        let mut curriculum = AdaptiveCurriculum::new();
        for _ in 0..50 {
            curriculum.update_class("E0599", false);
        }

        let weight = curriculum.weight_for_class("E0599");
        // accuracy stays 0.0, rarity = 1/sqrt(51)
        let expected = 1.0 + 1.0 / 51f32.sqrt() + 1.0;
        assert!((weight - expected).abs() < 0.001);
        assert!(weight > 1.0 && weight <= 3.0);
    }

    #[test]
    fn test_curriculum_scheduler_difficulty() {
        let mut curriculum = AdaptiveCurriculum::new();
        assert_eq!(curriculum.difficulty(), 0.0);

        curriculum.step(0, 0.5);
        assert!(curriculum.difficulty() > 0.0);
    }

    #[test]
    fn test_curriculum_scheduler_tier() {
        let mut curriculum = AdaptiveCurriculum::new();

        // Tier 1 for low difficulty
        assert_eq!(curriculum.tier(), 1);

        // Tier 2 for 0.25-0.5
        curriculum.overall_difficulty = 0.3;
        assert_eq!(curriculum.tier(), 2);

        // Tier 3 for 0.5-0.75
        curriculum.overall_difficulty = 0.6;
        assert_eq!(curriculum.tier(), 3);

        // Tier 4 for >= 0.75
        curriculum.overall_difficulty = 0.8;
        assert_eq!(curriculum.tier(), 4);
    }

    #[test]
    fn test_curriculum_scheduler_tier_boundaries() {
        let mut curriculum = AdaptiveCurriculum::new();
        // Each band is closed at its lower edge and open at its upper edge
        for (difficulty, expected) in [(0.24, 1), (0.25, 2), (0.5, 3), (0.75, 4), (1.0, 4)] {
            curriculum.overall_difficulty = difficulty;
            assert_eq!(curriculum.tier(), expected, "difficulty {difficulty}");
        }
    }

    #[test]
    fn test_curriculum_scheduler_step() {
        let mut curriculum = AdaptiveCurriculum::new();

        curriculum.step(0, 1.0);
        assert!((curriculum.difficulty() - 0.1).abs() < 0.001);

        curriculum.step(1, 1.0);
        // 0.1 * 0.9 + 1.0 * 0.1 = 0.19
        assert!((curriculum.difficulty() - 0.19).abs() < 0.001);
    }

    #[test]
    fn test_curriculum_scheduler_reset() {
        let mut curriculum = AdaptiveCurriculum::new();
        curriculum.update_class("E0308", true);
        curriculum.step(0, 0.5);

        assert!(!curriculum.class_accuracy.is_empty());
        assert!(curriculum.difficulty() > 0.0);

        curriculum.reset();

        assert!(curriculum.class_accuracy.is_empty());
        assert!(curriculum.class_attempts.is_empty());
        assert_eq!(curriculum.difficulty(), 0.0);
    }

    #[test]
    fn test_curriculum_scheduler_name() {
        let curriculum = AdaptiveCurriculum::new();
        assert_eq!(curriculum.name(), "AdaptiveCurriculum");
    }

    #[test]
    fn test_adaptive_curriculum_clone() {
        let mut curriculum = AdaptiveCurriculum::new();
        curriculum.update_class("E0308", true);

        let cloned = curriculum.clone();
        assert_eq!(curriculum.class_attempts, cloned.class_attempts);
        assert_eq!(curriculum.class_accuracy, cloned.class_accuracy);
    }

    #[test]
    fn test_adaptive_curriculum_debug() {
        let curriculum = AdaptiveCurriculum::new();
        let debug = format!("{curriculum:?}");
        assert!(debug.contains("AdaptiveCurriculum"));
    }
}
318}