1use std::collections::HashMap;
11
/// Policy that decides which training samples are eligible at the current
/// point of the curriculum.
///
/// The enum is field-less, so besides `Copy`/`PartialEq` it also derives `Eq`
/// and `Hash`, which lets a strategy be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CurriculumStrategy {
    /// Select every sample whose difficulty is at or below the current threshold.
    Fixed,
    /// Rank samples by their recorded loss and select the lowest-loss fraction.
    SelfPaced,
    /// Teacher-guided selection.
    ///
    /// NOTE(review): the scheduler in this file currently handles this variant
    /// exactly like [`CurriculumStrategy::Fixed`].
    TeacherStudent,
    /// Threshold selection where the threshold is nudged up or down according
    /// to recently reported performance.
    CompetenceBased,
    /// Hard-to-easy ordering: select samples whose difficulty is at or above
    /// the current threshold.
    AntiCurriculum,
}
26
/// Signal from which a per-sample difficulty score is derived.
///
/// Field-less, so `Eq` and `Hash` are derived in addition to `PartialEq` to
/// allow use as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DifficultyMetric {
    /// Difficulty grows with the sample's loss.
    Loss,
    /// Difficulty is the complement of the model's confidence.
    Confidence,
    /// Difficulty is a caller-supplied complexity value.
    Complexity,
    /// Difficulty is a weighted blend of loss, confidence and complexity.
    Custom,
}
39
40#[derive(Debug, Clone)]
42pub struct CurriculumConfig {
43 pub strategy: CurriculumStrategy,
45 pub difficulty_metric: DifficultyMetric,
47 pub initial_threshold: f32,
49 pub final_threshold: f32,
51 pub warmup_epochs: usize,
53 pub pacing_function: PacingFunction,
55 pub min_samples_per_batch: usize,
57}
58
/// Shape of the curve along which the threshold moves during warm-up.
///
/// Field-less, so `Eq` and `Hash` are derived in addition to `PartialEq` to
/// allow use as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PacingFunction {
    /// Threshold changes proportionally to warm-up progress.
    Linear,
    /// Slow start, fast finish.
    ///
    /// NOTE(review): despite the name, the scheduler in this file implements
    /// this as the square of the warm-up progress.
    Exponential,
    /// Threshold changes in discrete plateaus.
    Step,
    /// Fast start, slow finish (square root of the warm-up progress).
    Root,
}
71
72impl Default for CurriculumConfig {
73 fn default() -> Self {
74 CurriculumConfig {
75 strategy: CurriculumStrategy::Fixed,
76 difficulty_metric: DifficultyMetric::Loss,
77 initial_threshold: 0.3,
78 final_threshold: 1.0,
79 warmup_epochs: 10,
80 pacing_function: PacingFunction::Linear,
81 min_samples_per_batch: 8,
82 }
83 }
84}
85
86impl CurriculumConfig {
87 pub fn self_paced(warmup_epochs: usize) -> Self {
89 CurriculumConfig {
90 strategy: CurriculumStrategy::SelfPaced,
91 warmup_epochs,
92 ..Default::default()
93 }
94 }
95
96 pub fn competence_based(warmup_epochs: usize) -> Self {
98 CurriculumConfig {
99 strategy: CurriculumStrategy::CompetenceBased,
100 warmup_epochs,
101 ..Default::default()
102 }
103 }
104
105 pub fn anti_curriculum() -> Self {
107 CurriculumConfig {
108 strategy: CurriculumStrategy::AntiCurriculum,
109 initial_threshold: 1.0,
110 final_threshold: 0.0,
111 ..Default::default()
112 }
113 }
114}
115
/// A training sample together with the scores the curriculum tracks for it.
///
/// Derives `PartialEq` so samples can be compared by value.
#[derive(Debug, Clone, PartialEq)]
pub struct ScoredSample {
    /// Identifier of the sample that selections report back to the caller.
    pub index: usize,
    /// Difficulty score compared against the curriculum threshold.
    pub difficulty: f32,
    /// Most recently recorded loss, or `None` if none has been recorded yet.
    pub loss: Option<f32>,
    /// Free-form named numeric attributes attached to the sample.
    pub metadata: HashMap<String, f32>,
}
128
129pub struct CurriculumLearning {
131 config: CurriculumConfig,
132 current_epoch: usize,
134 sample_scores: Vec<ScoredSample>,
136 current_threshold: f32,
138 performance_history: Vec<f32>,
140}
141
142impl CurriculumLearning {
143 pub fn new(config: CurriculumConfig) -> Self {
145 CurriculumLearning {
146 current_threshold: config.initial_threshold,
147 config,
148 current_epoch: 0,
149 sample_scores: Vec::new(),
150 performance_history: Vec::new(),
151 }
152 }
153
154 pub fn initialize_samples(&mut self, num_samples: usize, difficulties: Vec<f32>) {
156 self.sample_scores = difficulties.into_iter()
157 .enumerate()
158 .map(|(i, difficulty)| ScoredSample {
159 index: i,
160 difficulty,
161 loss: None,
162 metadata: HashMap::new(),
163 })
164 .collect();
165 }
166
167 pub fn update_threshold(&mut self) {
169 self.current_threshold = self.compute_threshold(self.current_epoch);
170 }
171
172 fn compute_threshold(&self, epoch: usize) -> f32 {
174 if epoch >= self.config.warmup_epochs {
175 return self.config.final_threshold;
176 }
177
178 let progress = epoch as f32 / self.config.warmup_epochs as f32;
179 let start = self.config.initial_threshold;
180 let end = self.config.final_threshold;
181
182 match self.config.pacing_function {
183 PacingFunction::Linear => {
184 start + (end - start) * progress
185 }
186 PacingFunction::Exponential => {
187 start + (end - start) * progress.powi(2)
188 }
189 PacingFunction::Step => {
190 let num_steps = 5;
191 let step = (progress * num_steps as f32).floor() / num_steps as f32;
192 start + (end - start) * step
193 }
194 PacingFunction::Root => {
195 start + (end - start) * progress.sqrt()
196 }
197 }
198 }
199
200 pub fn select_samples(&self) -> Vec<usize> {
202 match self.config.strategy {
203 CurriculumStrategy::Fixed => self.select_fixed_curriculum(),
204 CurriculumStrategy::SelfPaced => self.select_self_paced(),
205 CurriculumStrategy::CompetenceBased => self.select_competence_based(),
206 CurriculumStrategy::TeacherStudent => self.select_teacher_student(),
207 CurriculumStrategy::AntiCurriculum => self.select_anti_curriculum(),
208 }
209 }
210
211 fn select_fixed_curriculum(&self) -> Vec<usize> {
213 self.sample_scores.iter()
214 .filter(|s| s.difficulty <= self.current_threshold)
215 .map(|s| s.index)
216 .collect()
217 }
218
219 fn select_self_paced(&self) -> Vec<usize> {
221 let mut scored: Vec<_> = self.sample_scores.iter()
222 .filter(|s| s.loss.is_some())
223 .collect();
224
225 scored.sort_by(|a, b| {
226 a.loss.unwrap().partial_cmp(&b.loss.unwrap()).unwrap()
227 });
228
229 let num_select = (scored.len() as f32 * self.current_threshold) as usize;
230 let num_select = num_select.max(self.config.min_samples_per_batch);
231
232 scored.iter()
233 .take(num_select)
234 .map(|s| s.index)
235 .collect()
236 }
237
238 fn select_competence_based(&self) -> Vec<usize> {
240 let recent_performance = self.get_recent_performance();
241
242 let adjusted_threshold = if recent_performance > 0.8 {
244 (self.current_threshold + 0.1).min(1.0)
246 } else if recent_performance < 0.5 {
247 (self.current_threshold - 0.1).max(0.0)
249 } else {
250 self.current_threshold
251 };
252
253 self.sample_scores.iter()
254 .filter(|s| s.difficulty <= adjusted_threshold)
255 .map(|s| s.index)
256 .collect()
257 }
258
259 fn select_teacher_student(&self) -> Vec<usize> {
261 self.select_fixed_curriculum()
264 }
265
266 fn select_anti_curriculum(&self) -> Vec<usize> {
268 self.sample_scores.iter()
269 .filter(|s| s.difficulty >= self.current_threshold)
270 .map(|s| s.index)
271 .collect()
272 }
273
274 pub fn update_sample_losses(&mut self, indices: &[usize], losses: &[f32]) {
276 for (idx, &loss) in indices.iter().zip(losses.iter()) {
277 if let Some(sample) = self.sample_scores.iter_mut().find(|s| s.index == *idx) {
278 sample.loss = Some(loss);
279 }
280 }
281 }
282
283 pub fn update_performance(&mut self, performance: f32) {
285 self.performance_history.push(performance);
286 }
287
288 fn get_recent_performance(&self) -> f32 {
290 let window = 3;
291 let recent = self.performance_history.iter()
292 .rev()
293 .take(window)
294 .copied()
295 .collect::<Vec<_>>();
296
297 if recent.is_empty() {
298 0.5 } else {
300 recent.iter().sum::<f32>() / recent.len() as f32
301 }
302 }
303
304 pub fn next_epoch(&mut self) {
306 self.current_epoch += 1;
307 self.update_threshold();
308 }
309
310 pub fn get_stats(&self) -> CurriculumStats {
312 let selected = self.select_samples();
313 let avg_difficulty = if !selected.is_empty() {
314 selected.iter()
315 .filter_map(|&idx| self.sample_scores.get(idx))
316 .map(|s| s.difficulty)
317 .sum::<f32>() / selected.len() as f32
318 } else {
319 0.0
320 };
321
322 CurriculumStats {
323 current_epoch: self.current_epoch,
324 current_threshold: self.current_threshold,
325 num_selected_samples: selected.len(),
326 total_samples: self.sample_scores.len(),
327 avg_difficulty: avg_difficulty,
328 recent_performance: self.get_recent_performance(),
329 }
330 }
331}
332
/// Point-in-time summary of a curriculum scheduler.
///
/// All fields are plain values, so the snapshot is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CurriculumStats {
    /// Number of epochs completed so far.
    pub current_epoch: usize,
    /// Threshold in effect when the snapshot was taken.
    pub current_threshold: f32,
    /// Number of samples eligible under the current selection.
    pub num_selected_samples: usize,
    /// Number of samples known to the scheduler.
    pub total_samples: usize,
    /// Mean difficulty of the eligible samples.
    pub avg_difficulty: f32,
    /// Mean of the most recently reported performance values.
    pub recent_performance: f32,
}
343
344pub struct DifficultyScorer {
346 metric: DifficultyMetric,
347}
348
349impl DifficultyScorer {
350 pub fn new(metric: DifficultyMetric) -> Self {
352 DifficultyScorer { metric }
353 }
354
355 pub fn score(&self, loss: f32, confidence: f32, complexity: f32) -> f32 {
357 match self.metric {
358 DifficultyMetric::Loss => {
359 loss.min(10.0) / 10.0
361 }
362 DifficultyMetric::Confidence => {
363 1.0 - confidence
365 }
366 DifficultyMetric::Complexity => {
367 complexity
368 }
369 DifficultyMetric::Custom => {
370 (loss * 0.4 + (1.0 - confidence) * 0.3 + complexity * 0.3).min(1.0)
372 }
373 }
374 }
375
376 pub fn score_batch(&self, losses: &[f32], confidences: &[f32], complexities: &[f32]) -> Vec<f32> {
378 losses.iter()
379 .zip(confidences.iter())
380 .zip(complexities.iter())
381 .map(|((&loss, &conf), &comp)| self.score(loss, conf, comp))
382 .collect()
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing computed `f32` values.
    const EPS: f32 = 1e-6;

    /// Builds a scheduler with default settings apart from the pacing curve.
    fn curriculum_with_pacing(pacing_function: PacingFunction) -> CurriculumLearning {
        CurriculumLearning::new(CurriculumConfig {
            pacing_function,
            warmup_epochs: 10,
            ..Default::default()
        })
    }

    #[test]
    fn test_curriculum_config() {
        let config = CurriculumConfig::default();
        assert_eq!(config.strategy, CurriculumStrategy::Fixed);
        assert_eq!(config.initial_threshold, 0.3);

        let self_paced = CurriculumConfig::self_paced(20);
        assert_eq!(self_paced.strategy, CurriculumStrategy::SelfPaced);
        assert_eq!(self_paced.warmup_epochs, 20);
    }

    #[test]
    fn test_curriculum_initialization() {
        let mut curriculum = CurriculumLearning::new(CurriculumConfig::default());

        curriculum.initialize_samples(5, vec![0.1, 0.3, 0.5, 0.7, 0.9]);

        assert_eq!(curriculum.sample_scores.len(), 5);
        // Every sample is identified by its position and starts without a loss.
        for (position, sample) in curriculum.sample_scores.iter().enumerate() {
            assert_eq!(sample.index, position);
            assert!(sample.loss.is_none());
        }
    }

    #[test]
    fn test_threshold_computation() {
        let config = CurriculumConfig {
            initial_threshold: 0.2,
            final_threshold: 1.0,
            warmup_epochs: 10,
            pacing_function: PacingFunction::Linear,
            ..Default::default()
        };
        let curriculum = CurriculumLearning::new(config);

        let threshold_0 = curriculum.compute_threshold(0);
        let threshold_5 = curriculum.compute_threshold(5);
        let threshold_10 = curriculum.compute_threshold(10);

        assert_eq!(threshold_0, 0.2);
        assert!((threshold_5 - 0.6).abs() < 0.01);
        assert_eq!(threshold_10, 1.0);
    }

    #[test]
    fn test_fixed_curriculum_selection() {
        let config = CurriculumConfig {
            initial_threshold: 0.5,
            ..Default::default()
        };
        let mut curriculum = CurriculumLearning::new(config);
        curriculum.initialize_samples(5, vec![0.1, 0.3, 0.5, 0.7, 0.9]);

        // Exactly the samples with difficulty <= 0.5.
        assert_eq!(curriculum.select_samples(), vec![0, 1, 2]);
    }

    #[test]
    fn test_self_paced_selection() {
        let losses = [0.5, 0.3, 0.8, 0.2, 0.9];

        // With the batch minimum out of the way, 60% of 5 samples are chosen,
        // lowest loss first.
        let mut curriculum = CurriculumLearning::new(CurriculumConfig {
            strategy: CurriculumStrategy::SelfPaced,
            initial_threshold: 0.6,
            min_samples_per_batch: 1,
            ..Default::default()
        });
        curriculum.initialize_samples(5, vec![0.1, 0.3, 0.5, 0.7, 0.9]);
        curriculum.update_sample_losses(&[0, 1, 2, 3, 4], &losses);
        assert_eq!(curriculum.select_samples(), vec![3, 1, 0]);

        // The default batch minimum (8) exceeds the sample count, so every
        // scored sample is returned, still ordered by loss.
        let mut floored = CurriculumLearning::new(CurriculumConfig {
            strategy: CurriculumStrategy::SelfPaced,
            initial_threshold: 0.6,
            ..Default::default()
        });
        floored.initialize_samples(5, vec![0.1, 0.3, 0.5, 0.7, 0.9]);
        floored.update_sample_losses(&[0, 1, 2, 3, 4], &losses);
        assert_eq!(floored.select_samples(), vec![3, 1, 0, 2, 4]);
    }

    #[test]
    fn test_pacing_functions() {
        let linear_mid = curriculum_with_pacing(PacingFunction::Linear).compute_threshold(5);
        let exp_mid = curriculum_with_pacing(PacingFunction::Exponential).compute_threshold(5);
        let root_mid = curriculum_with_pacing(PacingFunction::Root).compute_threshold(5);
        let step_mid = curriculum_with_pacing(PacingFunction::Step).compute_threshold(5);

        // Halfway through warm-up the curves are strictly ordered.
        assert!(exp_mid < linear_mid);
        assert!(linear_mid < root_mid);
        // Half progress falls on the third of five plateaus: 0.3 + 0.7 * 0.4.
        assert!((step_mid - 0.58).abs() < EPS);
    }

    #[test]
    fn test_competence_based_adjustment() {
        let config = CurriculumConfig {
            strategy: CurriculumStrategy::CompetenceBased,
            ..Default::default()
        };
        let mut curriculum = CurriculumLearning::new(config);
        curriculum.initialize_samples(4, vec![0.2, 0.35, 0.6, 0.8]);

        // Without any reported performance the base threshold (0.3) applies.
        assert_eq!(curriculum.select_samples(), vec![0]);

        curriculum.update_performance(0.9);
        curriculum.update_performance(0.85);
        curriculum.update_performance(0.88);

        // High performance raises the effective threshold to about 0.4,
        // which additionally admits the 0.35 sample.
        assert_eq!(curriculum.select_samples(), vec![0, 1]);
    }

    #[test]
    fn test_anti_curriculum() {
        let config = CurriculumConfig::anti_curriculum();
        let mut curriculum = CurriculumLearning::new(config);

        assert_eq!(curriculum.current_threshold, 1.0);

        curriculum.initialize_samples(5, vec![0.1, 0.3, 0.5, 0.7, 0.9]);
        curriculum.current_threshold = 0.8;

        // Only the hardest sample (0.9 >= 0.8) is eligible.
        assert_eq!(curriculum.select_samples(), vec![4]);
    }

    #[test]
    fn test_difficulty_scorer() {
        let scorer = DifficultyScorer::new(DifficultyMetric::Loss);

        let score = scorer.score(2.0, 0.8, 0.5);
        assert!((0.0..=1.0).contains(&score));
        assert!((score - 0.2).abs() < EPS);

        let batch_scores = scorer.score_batch(
            &[1.0, 2.0, 3.0],
            &[0.9, 0.7, 0.5],
            &[0.3, 0.5, 0.7],
        );
        assert_eq!(batch_scores.len(), 3);
        for (actual, expected) in batch_scores.iter().zip([0.1_f32, 0.2, 0.3]) {
            assert!((actual - expected).abs() < EPS);
        }
    }

    #[test]
    fn test_epoch_progression() {
        let config = CurriculumConfig {
            initial_threshold: 0.2,
            final_threshold: 1.0,
            warmup_epochs: 5,
            ..Default::default()
        };
        let mut curriculum = CurriculumLearning::new(config);

        assert_eq!(curriculum.current_epoch, 0);
        assert_eq!(curriculum.current_threshold, 0.2);

        curriculum.next_epoch();
        assert_eq!(curriculum.current_epoch, 1);
        assert!(curriculum.current_threshold > 0.2);

        for _ in 0..10 {
            curriculum.next_epoch();
        }
        assert_eq!(curriculum.current_epoch, 11);
        assert_eq!(curriculum.current_threshold, 1.0);
    }

    #[test]
    fn test_curriculum_stats() {
        let mut curriculum = CurriculumLearning::new(CurriculumConfig::default());
        curriculum.initialize_samples(5, vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        let stats = curriculum.get_stats();
        assert_eq!(stats.current_epoch, 0);
        assert_eq!(stats.total_samples, 5);
        // Default threshold 0.3 admits the first three samples.
        assert_eq!(stats.num_selected_samples, 3);
        assert!((stats.avg_difficulty - 0.2).abs() < EPS);
        // No performance reported yet, so the neutral default is returned.
        assert_eq!(stats.recent_performance, 0.5);
    }
}