entrenar/train/curriculum/
adaptive.rs1use std::collections::HashMap;
4
5use super::CurriculumScheduler;
6
/// Curriculum state that adapts to per-error-class results.
///
/// The per-class maps are keyed by error code strings (for example `"E0308"`)
/// and are filled in by `update_class`.
#[derive(Clone, Debug)]
pub struct AdaptiveCurriculum {
    /// Exponential moving average of correctness for each error code, as
    /// maintained by `update_class`.
    pub(crate) class_accuracy: HashMap<String, f32>,
    /// Number of outcomes recorded through `update_class` for each error code.
    pub(crate) class_attempts: HashMap<String, usize>,
    /// Tier that `tier_for_error` returns for attempt 0 when no special rule
    /// applies; `new` sets it to 1.
    default_tier: usize,
    /// Aggregate progress value: written by `update_class` (mean of the class
    /// accuracies) and by `step` (moving average), read by `difficulty` and
    /// `tier`.
    overall_difficulty: f32,
}
24
25impl AdaptiveCurriculum {
26 pub fn new() -> Self {
28 Self {
29 class_accuracy: HashMap::new(),
30 class_attempts: HashMap::new(),
31 default_tier: 1,
32 overall_difficulty: 0.0,
33 }
34 }
35
36 pub fn tier_for_error(&self, error_code: &str, attempt: usize) -> usize {
40 if error_code.starts_with("ICE") {
42 return 4; }
44
45 if matches!(error_code, "E0308" | "E0277" | "E0382") && attempt >= 1 {
47 return 3;
48 }
49
50 if matches!(error_code, "E0425" | "E0433") && attempt >= 2 {
52 return 3;
53 }
54
55 match attempt {
57 0 => self.default_tier,
58 1 => 2,
59 2.. => 3,
60 }
61 }
62
63 pub fn update_class(&mut self, error_code: &str, correct: bool) {
65 let attempts = self.class_attempts.entry(error_code.to_string()).or_insert(0);
66 *attempts += 1;
67
68 let acc = self.class_accuracy.entry(error_code.to_string()).or_insert(0.0);
69 let alpha = 0.1;
71 *acc = *acc * (1.0 - alpha) + if correct { alpha } else { 0.0 };
72
73 if !self.class_accuracy.is_empty() {
75 self.overall_difficulty =
76 self.class_accuracy.values().sum::<f32>() / self.class_accuracy.len() as f32;
77 }
78 }
79
80 pub fn weight_for_class(&self, error_code: &str) -> f32 {
84 let attempts = *self.class_attempts.get(error_code).unwrap_or(&0);
85 let accuracy = *self.class_accuracy.get(error_code).unwrap_or(&0.0);
86
87 let rarity_weight = 1.0 / (attempts as f32 + 1.0).sqrt();
89
90 let difficulty_weight = 1.0 - accuracy;
92
93 (1.0 + rarity_weight + difficulty_weight).min(3.0)
95 }
96}
97
98impl Default for AdaptiveCurriculum {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104impl CurriculumScheduler for AdaptiveCurriculum {
105 fn difficulty(&self) -> f32 {
106 self.overall_difficulty
107 }
108
109 fn tier(&self) -> usize {
110 if self.overall_difficulty < 0.25 {
111 1
112 } else if self.overall_difficulty < 0.5 {
113 2
114 } else if self.overall_difficulty < 0.75 {
115 3
116 } else {
117 4
118 }
119 }
120
121 fn step(&mut self, _epoch: usize, accuracy: f32) {
122 let alpha = 0.1;
124 self.overall_difficulty = self.overall_difficulty * (1.0 - alpha) + accuracy * alpha;
125 }
126
127 fn reset(&mut self) {
128 self.class_accuracy.clear();
129 self.class_attempts.clear();
130 self.overall_difficulty = 0.0;
131 }
132
133 fn name(&self) -> &'static str {
134 "AdaptiveCurriculum"
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in this module.
    const EPS: f32 = 0.001;

    /// Returns `true` when `actual` is within `EPS` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Checks `tier_for_error` on a fresh curriculum for every
    /// `(error_code, attempt, expected_tier)` row.
    fn assert_tiers(rows: &[(&str, usize, usize)]) {
        let scheduler = AdaptiveCurriculum::new();
        for &(code, attempt, expected) in rows {
            assert_eq!(scheduler.tier_for_error(code, attempt), expected);
        }
    }

    #[test]
    fn test_adaptive_curriculum_new() {
        let scheduler = AdaptiveCurriculum::new();
        assert!(scheduler.class_accuracy.is_empty());
        assert!(scheduler.class_attempts.is_empty());
        assert_eq!(scheduler.overall_difficulty, 0.0);
    }

    #[test]
    fn test_adaptive_curriculum_default() {
        assert_eq!(AdaptiveCurriculum::default().difficulty(), 0.0);
    }

    #[test]
    fn test_tier_for_error_ice() {
        assert_tiers(&[("ICE001", 0, 4), ("ICE-crash", 5, 4)]);
    }

    #[test]
    fn test_tier_for_error_type_errors() {
        assert_tiers(&[
            ("E0308", 0, 1),
            ("E0308", 1, 3),
            ("E0277", 2, 3),
            ("E0382", 1, 3),
        ]);
    }

    #[test]
    fn test_tier_for_error_name_resolution() {
        assert_tiers(&[
            ("E0425", 0, 1),
            ("E0425", 1, 2),
            ("E0425", 2, 3),
            ("E0433", 3, 3),
        ]);
    }

    #[test]
    fn test_tier_for_error_default_escalation() {
        assert_tiers(&[
            ("E0001", 0, 1),
            ("E0001", 1, 2),
            ("E0001", 2, 3),
            ("E0001", 5, 3),
        ]);
    }

    #[test]
    fn test_update_class() {
        let mut scheduler = AdaptiveCurriculum::new();

        // A correct answer on a new class moves its accuracy from 0.0 to 0.1.
        scheduler.update_class("E0308", true);
        assert_eq!(scheduler.class_attempts.get("E0308"), Some(&1));
        let after_hit = *scheduler
            .class_accuracy
            .get("E0308")
            .expect("key should exist");
        assert!(close(after_hit, 0.1));

        // A miss decays the accuracy by the factor 0.9.
        scheduler.update_class("E0308", false);
        assert_eq!(scheduler.class_attempts.get("E0308"), Some(&2));
        let after_miss = *scheduler
            .class_accuracy
            .get("E0308")
            .expect("key should exist");
        assert!(close(after_miss, 0.09));
    }

    #[test]
    fn test_weight_for_class_unknown() {
        let scheduler = AdaptiveCurriculum::new();
        assert!(close(scheduler.weight_for_class("unknown"), 3.0));
    }

    #[test]
    fn test_weight_for_class_known() {
        let mut scheduler = AdaptiveCurriculum::new();
        (0..10).for_each(|_| scheduler.update_class("E0308", true));

        let weight = scheduler.weight_for_class("E0308");
        assert!(weight < 3.0);
        assert!(weight >= 1.0);
    }

    #[test]
    fn test_curriculum_scheduler_difficulty() {
        let mut scheduler = AdaptiveCurriculum::new();
        assert_eq!(scheduler.difficulty(), 0.0);

        scheduler.step(0, 0.5);
        assert!(scheduler.difficulty() > 0.0);
    }

    #[test]
    fn test_curriculum_scheduler_tier() {
        let mut scheduler = AdaptiveCurriculum::new();
        assert_eq!(scheduler.tier(), 1);

        for &(level, expected) in [(0.3_f32, 2_usize), (0.6, 3), (0.8, 4)].iter() {
            scheduler.overall_difficulty = level;
            assert_eq!(scheduler.tier(), expected);
        }
    }

    #[test]
    fn test_curriculum_scheduler_step() {
        let mut scheduler = AdaptiveCurriculum::new();

        scheduler.step(0, 1.0);
        assert!(close(scheduler.difficulty(), 0.1));

        scheduler.step(1, 1.0);
        assert!(close(scheduler.difficulty(), 0.19));
    }

    #[test]
    fn test_curriculum_scheduler_reset() {
        let mut scheduler = AdaptiveCurriculum::new();
        scheduler.update_class("E0308", true);
        scheduler.step(0, 0.5);

        assert!(!scheduler.class_accuracy.is_empty());
        assert!(scheduler.difficulty() > 0.0);

        scheduler.reset();

        assert!(scheduler.class_accuracy.is_empty());
        assert!(scheduler.class_attempts.is_empty());
        assert_eq!(scheduler.difficulty(), 0.0);
    }

    #[test]
    fn test_curriculum_scheduler_name() {
        assert_eq!(AdaptiveCurriculum::new().name(), "AdaptiveCurriculum");
    }

    #[test]
    fn test_adaptive_curriculum_clone() {
        let mut source = AdaptiveCurriculum::new();
        source.update_class("E0308", true);

        let copy = source.clone();
        assert_eq!(source.class_attempts, copy.class_attempts);
        assert_eq!(source.class_accuracy, copy.class_accuracy);
    }

    #[test]
    fn test_adaptive_curriculum_debug() {
        let rendered = format!("{:?}", AdaptiveCurriculum::new());
        assert!(rendered.contains("AdaptiveCurriculum"));
    }
}