ghostflow_nn/
curriculum_learning.rs

1//! Curriculum Learning
2//!
3//! Implements training strategies that gradually increase task difficulty:
4//! - Easy-to-hard curriculum
5//! - Self-paced learning
6//! - Teacher-student curriculum
7//! - Competence-based curriculum
8//! - Dynamic difficulty adjustment
9
10use std::collections::HashMap;
11
/// Strategy that decides which samples are eligible at the current
/// curriculum stage.
///
/// The enum is a plain tag with no payload, so it is `Eq` and `Hash` and can
/// be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CurriculumStrategy {
    /// Fixed curriculum: samples whose difficulty is at or below the current
    /// threshold are selected.
    Fixed,
    /// Self-paced learning: samples with a recorded loss are ranked by that
    /// loss and the lowest-loss ones are selected.
    SelfPaced,
    /// Teacher-student curriculum. The selection currently delegates to the
    /// fixed curriculum; a teacher is expected to supply the difficulty scores.
    TeacherStudent,
    /// Competence-based: the threshold is nudged up or down according to the
    /// recently recorded performance before selecting.
    CompetenceBased,
    /// Anti-curriculum (hard-to-easy): samples whose difficulty is at or above
    /// the current threshold are selected.
    AntiCurriculum,
}
26
/// Method used to turn per-sample measurements into a difficulty score.
///
/// The enum is a plain tag with no payload, so it is `Eq` and `Hash` and can
/// be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DifficultyMetric {
    /// Loss-based difficulty (higher loss = harder)
    Loss,
    /// Prediction confidence (lower confidence = harder)
    Confidence,
    /// Sample complexity (length, features, etc.)
    Complexity,
    /// Weighted combination of loss, confidence and complexity
    Custom,
}
39
40/// Curriculum learning configuration
41#[derive(Debug, Clone)]
42pub struct CurriculumConfig {
43    /// Curriculum strategy
44    pub strategy: CurriculumStrategy,
45    /// Difficulty metric
46    pub difficulty_metric: DifficultyMetric,
47    /// Initial difficulty threshold (0.0 = easiest, 1.0 = hardest)
48    pub initial_threshold: f32,
49    /// Final difficulty threshold
50    pub final_threshold: f32,
51    /// Number of epochs to reach final threshold
52    pub warmup_epochs: usize,
53    /// Pacing function (linear, exponential, etc.)
54    pub pacing_function: PacingFunction,
55    /// Minimum samples per batch
56    pub min_samples_per_batch: usize,
57}
58
/// Pacing function for curriculum progression.
///
/// During warm-up the threshold is
/// `initial + (final - initial) * f(progress)` where
/// `progress = epoch / warmup_epochs` lies in `[0, 1)`; each variant selects
/// a different `f`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PacingFunction {
    /// Linear progression: `f(p) = p`.
    Linear,
    /// Quadratic ease-in: `f(p) = p²`. Despite the name this is not a true
    /// exponential; it moves slowly at the start and faster towards the end
    /// of warm-up.
    Exponential,
    /// Step-wise progression: `f(p) = floor(5p) / 5`, i.e. the threshold is
    /// held constant within each fifth of the warm-up period.
    Step,
    /// Square-root progression: `f(p) = √p`. Moves quickly at the start and
    /// slows down as warm-up completes.
    Root,
}
71
72impl Default for CurriculumConfig {
73    fn default() -> Self {
74        CurriculumConfig {
75            strategy: CurriculumStrategy::Fixed,
76            difficulty_metric: DifficultyMetric::Loss,
77            initial_threshold: 0.3,
78            final_threshold: 1.0,
79            warmup_epochs: 10,
80            pacing_function: PacingFunction::Linear,
81            min_samples_per_batch: 8,
82        }
83    }
84}
85
86impl CurriculumConfig {
87    /// Self-paced learning configuration
88    pub fn self_paced(warmup_epochs: usize) -> Self {
89        CurriculumConfig {
90            strategy: CurriculumStrategy::SelfPaced,
91            warmup_epochs,
92            ..Default::default()
93        }
94    }
95    
96    /// Competence-based configuration
97    pub fn competence_based(warmup_epochs: usize) -> Self {
98        CurriculumConfig {
99            strategy: CurriculumStrategy::CompetenceBased,
100            warmup_epochs,
101            ..Default::default()
102        }
103    }
104    
105    /// Anti-curriculum (hard-to-easy)
106    pub fn anti_curriculum() -> Self {
107        CurriculumConfig {
108            strategy: CurriculumStrategy::AntiCurriculum,
109            initial_threshold: 1.0,
110            final_threshold: 0.0,
111            ..Default::default()
112        }
113    }
114}
115
/// A training sample together with its difficulty score.
///
/// Derives `PartialEq` so that samples can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct ScoredSample {
    /// Sample index
    pub index: usize,
    /// Difficulty score (0.0 = easy, 1.0 = hard)
    pub difficulty: f32,
    /// Most recently recorded loss for this sample, or `None` if no loss has
    /// been recorded yet
    pub loss: Option<f32>,
    /// Sample metadata (named numeric values)
    pub metadata: HashMap<String, f32>,
}
128
129/// Curriculum learning trainer
130pub struct CurriculumLearning {
131    config: CurriculumConfig,
132    /// Current epoch
133    current_epoch: usize,
134    /// Sample difficulty scores
135    sample_scores: Vec<ScoredSample>,
136    /// Current difficulty threshold
137    current_threshold: f32,
138    /// Performance history
139    performance_history: Vec<f32>,
140}
141
142impl CurriculumLearning {
143    /// Create new curriculum learning trainer
144    pub fn new(config: CurriculumConfig) -> Self {
145        CurriculumLearning {
146            current_threshold: config.initial_threshold,
147            config,
148            current_epoch: 0,
149            sample_scores: Vec::new(),
150            performance_history: Vec::new(),
151        }
152    }
153    
154    /// Initialize sample difficulties
155    pub fn initialize_samples(&mut self, num_samples: usize, difficulties: Vec<f32>) {
156        self.sample_scores = difficulties.into_iter()
157            .enumerate()
158            .map(|(i, difficulty)| ScoredSample {
159                index: i,
160                difficulty,
161                loss: None,
162                metadata: HashMap::new(),
163            })
164            .collect();
165    }
166    
167    /// Update difficulty threshold for current epoch
168    pub fn update_threshold(&mut self) {
169        self.current_threshold = self.compute_threshold(self.current_epoch);
170    }
171    
172    /// Compute threshold based on pacing function
173    fn compute_threshold(&self, epoch: usize) -> f32 {
174        if epoch >= self.config.warmup_epochs {
175            return self.config.final_threshold;
176        }
177        
178        let progress = epoch as f32 / self.config.warmup_epochs as f32;
179        let start = self.config.initial_threshold;
180        let end = self.config.final_threshold;
181        
182        match self.config.pacing_function {
183            PacingFunction::Linear => {
184                start + (end - start) * progress
185            }
186            PacingFunction::Exponential => {
187                start + (end - start) * progress.powi(2)
188            }
189            PacingFunction::Step => {
190                let num_steps = 5;
191                let step = (progress * num_steps as f32).floor() / num_steps as f32;
192                start + (end - start) * step
193            }
194            PacingFunction::Root => {
195                start + (end - start) * progress.sqrt()
196            }
197        }
198    }
199    
200    /// Select samples for current curriculum stage
201    pub fn select_samples(&self) -> Vec<usize> {
202        match self.config.strategy {
203            CurriculumStrategy::Fixed => self.select_fixed_curriculum(),
204            CurriculumStrategy::SelfPaced => self.select_self_paced(),
205            CurriculumStrategy::CompetenceBased => self.select_competence_based(),
206            CurriculumStrategy::TeacherStudent => self.select_teacher_student(),
207            CurriculumStrategy::AntiCurriculum => self.select_anti_curriculum(),
208        }
209    }
210    
211    /// Fixed curriculum: select samples below threshold
212    fn select_fixed_curriculum(&self) -> Vec<usize> {
213        self.sample_scores.iter()
214            .filter(|s| s.difficulty <= self.current_threshold)
215            .map(|s| s.index)
216            .collect()
217    }
218    
219    /// Self-paced learning: select based on loss
220    fn select_self_paced(&self) -> Vec<usize> {
221        let mut scored: Vec<_> = self.sample_scores.iter()
222            .filter(|s| s.loss.is_some())
223            .collect();
224        
225        scored.sort_by(|a, b| {
226            a.loss.unwrap().partial_cmp(&b.loss.unwrap()).unwrap()
227        });
228        
229        let num_select = (scored.len() as f32 * self.current_threshold) as usize;
230        let num_select = num_select.max(self.config.min_samples_per_batch);
231        
232        scored.iter()
233            .take(num_select)
234            .map(|s| s.index)
235            .collect()
236    }
237    
238    /// Competence-based: adjust based on recent performance
239    fn select_competence_based(&self) -> Vec<usize> {
240        let recent_performance = self.get_recent_performance();
241        
242        // Adjust threshold based on performance
243        let adjusted_threshold = if recent_performance > 0.8 {
244            // Doing well, increase difficulty
245            (self.current_threshold + 0.1).min(1.0)
246        } else if recent_performance < 0.5 {
247            // Struggling, decrease difficulty
248            (self.current_threshold - 0.1).max(0.0)
249        } else {
250            self.current_threshold
251        };
252        
253        self.sample_scores.iter()
254            .filter(|s| s.difficulty <= adjusted_threshold)
255            .map(|s| s.index)
256            .collect()
257    }
258    
259    /// Teacher-student curriculum
260    fn select_teacher_student(&self) -> Vec<usize> {
261        // Similar to fixed but with teacher guidance
262        // In practice, teacher would provide difficulty scores
263        self.select_fixed_curriculum()
264    }
265    
266    /// Anti-curriculum: hard-to-easy
267    fn select_anti_curriculum(&self) -> Vec<usize> {
268        self.sample_scores.iter()
269            .filter(|s| s.difficulty >= self.current_threshold)
270            .map(|s| s.index)
271            .collect()
272    }
273    
274    /// Update sample losses after training step
275    pub fn update_sample_losses(&mut self, indices: &[usize], losses: &[f32]) {
276        for (idx, &loss) in indices.iter().zip(losses.iter()) {
277            if let Some(sample) = self.sample_scores.iter_mut().find(|s| s.index == *idx) {
278                sample.loss = Some(loss);
279            }
280        }
281    }
282    
283    /// Update performance history
284    pub fn update_performance(&mut self, performance: f32) {
285        self.performance_history.push(performance);
286    }
287    
288    /// Get recent performance (average of last N epochs)
289    fn get_recent_performance(&self) -> f32 {
290        let window = 3;
291        let recent = self.performance_history.iter()
292            .rev()
293            .take(window)
294            .copied()
295            .collect::<Vec<_>>();
296        
297        if recent.is_empty() {
298            0.5 // Default
299        } else {
300            recent.iter().sum::<f32>() / recent.len() as f32
301        }
302    }
303    
304    /// Advance to next epoch
305    pub fn next_epoch(&mut self) {
306        self.current_epoch += 1;
307        self.update_threshold();
308    }
309    
310    /// Get current statistics
311    pub fn get_stats(&self) -> CurriculumStats {
312        let selected = self.select_samples();
313        let avg_difficulty = if !selected.is_empty() {
314            selected.iter()
315                .filter_map(|&idx| self.sample_scores.get(idx))
316                .map(|s| s.difficulty)
317                .sum::<f32>() / selected.len() as f32
318        } else {
319            0.0
320        };
321        
322        CurriculumStats {
323            current_epoch: self.current_epoch,
324            current_threshold: self.current_threshold,
325            num_selected_samples: selected.len(),
326            total_samples: self.sample_scores.len(),
327            avg_difficulty: avg_difficulty,
328            recent_performance: self.get_recent_performance(),
329        }
330    }
331}
332
/// Curriculum learning statistics: a snapshot of the trainer state.
///
/// Derives `PartialEq` so that snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct CurriculumStats {
    /// Epoch the trainer was at when the snapshot was taken
    pub current_epoch: usize,
    /// Difficulty threshold in effect when the snapshot was taken
    pub current_threshold: f32,
    /// Number of samples selected for the current stage
    pub num_selected_samples: usize,
    /// Total number of tracked samples
    pub total_samples: usize,
    /// Mean difficulty of the selected samples (0.0 when none are selected)
    pub avg_difficulty: f32,
    /// Average of the most recently recorded performance values
    pub recent_performance: f32,
}
343
344/// Difficulty scorer for samples
345pub struct DifficultyScorer {
346    metric: DifficultyMetric,
347}
348
349impl DifficultyScorer {
350    /// Create new difficulty scorer
351    pub fn new(metric: DifficultyMetric) -> Self {
352        DifficultyScorer { metric }
353    }
354    
355    /// Score sample difficulty
356    pub fn score(&self, loss: f32, confidence: f32, complexity: f32) -> f32 {
357        match self.metric {
358            DifficultyMetric::Loss => {
359                // Normalize loss to [0, 1]
360                loss.min(10.0) / 10.0
361            }
362            DifficultyMetric::Confidence => {
363                // Low confidence = high difficulty
364                1.0 - confidence
365            }
366            DifficultyMetric::Complexity => {
367                complexity
368            }
369            DifficultyMetric::Custom => {
370                // Combine multiple metrics
371                (loss * 0.4 + (1.0 - confidence) * 0.3 + complexity * 0.3).min(1.0)
372            }
373        }
374    }
375    
376    /// Batch score multiple samples
377    pub fn score_batch(&self, losses: &[f32], confidences: &[f32], complexities: &[f32]) -> Vec<f32> {
378        losses.iter()
379            .zip(confidences.iter())
380            .zip(complexities.iter())
381            .map(|((&loss, &conf), &comp)| self.score(loss, conf, comp))
382            .collect()
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Default and preset configurations expose the expected values.
    #[test]
    fn test_curriculum_config() {
        let config = CurriculumConfig::default();
        assert_eq!(config.strategy, CurriculumStrategy::Fixed);
        assert_eq!(config.initial_threshold, 0.3);

        let self_paced = CurriculumConfig::self_paced(20);
        assert_eq!(self_paced.strategy, CurriculumStrategy::SelfPaced);
        assert_eq!(self_paced.warmup_epochs, 20);
    }

    /// One sample is tracked per supplied difficulty.
    #[test]
    fn test_curriculum_initialization() {
        let config = CurriculumConfig::default();
        let mut curriculum = CurriculumLearning::new(config);

        let difficulties = vec![0.1, 0.3, 0.5, 0.7, 0.9];
        curriculum.initialize_samples(5, difficulties);

        assert_eq!(curriculum.sample_scores.len(), 5);
    }

    /// Linear pacing interpolates between the initial and final threshold.
    #[test]
    fn test_threshold_computation() {
        let config = CurriculumConfig {
            initial_threshold: 0.2,
            final_threshold: 1.0,
            warmup_epochs: 10,
            pacing_function: PacingFunction::Linear,
            ..Default::default()
        };
        let curriculum = CurriculumLearning::new(config);

        let threshold_0 = curriculum.compute_threshold(0);
        let threshold_5 = curriculum.compute_threshold(5);
        let threshold_10 = curriculum.compute_threshold(10);

        assert_eq!(threshold_0, 0.2);
        assert!((threshold_5 - 0.6).abs() < 0.01);
        assert_eq!(threshold_10, 1.0);
    }

    /// The fixed curriculum keeps exactly the samples at or below the
    /// threshold.
    #[test]
    fn test_fixed_curriculum_selection() {
        let config = CurriculumConfig {
            initial_threshold: 0.5,
            ..Default::default()
        };
        let mut curriculum = CurriculumLearning::new(config);

        let difficulties = vec![0.1, 0.3, 0.5, 0.7, 0.9];
        curriculum.initialize_samples(5, difficulties);

        let selected = curriculum.select_samples();
        assert_eq!(selected, vec![0, 1, 2]); // 0.1, 0.3, 0.5
    }

    /// Self-paced selection returns samples in ascending loss order.
    #[test]
    fn test_self_paced_selection() {
        let config = CurriculumConfig {
            strategy: CurriculumStrategy::SelfPaced,
            initial_threshold: 0.6,
            ..Default::default()
        };
        let mut curriculum = CurriculumLearning::new(config);

        let difficulties = vec![0.1, 0.3, 0.5, 0.7, 0.9];
        curriculum.initialize_samples(5, difficulties);

        // Update losses
        curriculum.update_sample_losses(&[0, 1, 2, 3, 4], &[0.5, 0.3, 0.8, 0.2, 0.9]);

        let selected = curriculum.select_samples();
        // The default minimum of 8 exceeds the 5 scored samples, so all of
        // them are returned, easiest (lowest loss) first.
        assert_eq!(selected, vec![3, 1, 0, 2, 4]);
    }

    /// Quadratic ("exponential") pacing lags behind linear pacing mid-way.
    #[test]
    fn test_pacing_functions() {
        let linear_config = CurriculumConfig {
            pacing_function: PacingFunction::Linear,
            warmup_epochs: 10,
            ..Default::default()
        };
        let linear_curriculum = CurriculumLearning::new(linear_config);

        let exp_config = CurriculumConfig {
            pacing_function: PacingFunction::Exponential,
            warmup_epochs: 10,
            ..Default::default()
        };
        let exp_curriculum = CurriculumLearning::new(exp_config);

        let linear_mid = linear_curriculum.compute_threshold(5);
        let exp_mid = exp_curriculum.compute_threshold(5);

        // Exponential should be slower at midpoint
        assert!(exp_mid < linear_mid);
    }

    /// Good recent performance widens the competence-based selection.
    #[test]
    fn test_competence_based_adjustment() {
        let config = CurriculumConfig {
            strategy: CurriculumStrategy::CompetenceBased,
            ..Default::default()
        };
        let mut curriculum = CurriculumLearning::new(config);

        // 0.35 sits strictly between the base threshold (0.3) and the raised
        // one (0.4), so the outcome does not hinge on float rounding.
        let difficulties = vec![0.2, 0.35, 0.6, 0.8];
        curriculum.initialize_samples(4, difficulties);

        let baseline = curriculum.select_samples().len();

        // Simulate good performance
        curriculum.update_performance(0.9);
        curriculum.update_performance(0.85);
        curriculum.update_performance(0.88);

        let selected = curriculum.select_samples();
        // Should select more samples due to good performance
        assert!(!selected.is_empty());
        assert!(selected.len() > baseline);
    }

    /// The anti-curriculum keeps only the hardest samples.
    #[test]
    fn test_anti_curriculum() {
        let config = CurriculumConfig::anti_curriculum();
        let mut curriculum = CurriculumLearning::new(config);

        assert_eq!(curriculum.current_threshold, 1.0);

        let difficulties = vec![0.1, 0.3, 0.5, 0.7, 0.9];
        curriculum.initialize_samples(5, difficulties);

        // Update threshold to select some samples
        curriculum.current_threshold = 0.8;

        let selected = curriculum.select_samples();
        // Should select hardest samples (>= 0.8)
        assert!(selected.contains(&4)); // 0.9 difficulty
        assert!(!selected.contains(&0)); // 0.1 difficulty should not be selected
    }

    /// Loss scores stay within the unit interval; batches score pairwise.
    #[test]
    fn test_difficulty_scorer() {
        let scorer = DifficultyScorer::new(DifficultyMetric::Loss);

        let score = scorer.score(2.0, 0.8, 0.5);
        assert!((0.0..=1.0).contains(&score));

        let batch_scores = scorer.score_batch(
            &[1.0, 2.0, 3.0],
            &[0.9, 0.7, 0.5],
            &[0.3, 0.5, 0.7],
        );
        assert_eq!(batch_scores.len(), 3);
    }

    /// Advancing epochs raises the threshold until it reaches the final
    /// value.
    #[test]
    fn test_epoch_progression() {
        let config = CurriculumConfig {
            initial_threshold: 0.2,
            final_threshold: 1.0,
            warmup_epochs: 5,
            ..Default::default()
        };
        let mut curriculum = CurriculumLearning::new(config);

        assert_eq!(curriculum.current_epoch, 0);
        assert_eq!(curriculum.current_threshold, 0.2);

        curriculum.next_epoch();
        assert_eq!(curriculum.current_epoch, 1);
        assert!(curriculum.current_threshold > 0.2);

        for _ in 0..10 {
            curriculum.next_epoch();
        }
        assert_eq!(curriculum.current_threshold, 1.0);
    }

    /// Statistics reflect the current selection.
    #[test]
    fn test_curriculum_stats() {
        let config = CurriculumConfig::default();
        let mut curriculum = CurriculumLearning::new(config);

        let difficulties = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        curriculum.initialize_samples(5, difficulties);

        let stats = curriculum.get_stats();
        assert_eq!(stats.total_samples, 5);
        // Threshold 0.3 selects the samples with difficulty 0.1, 0.2 and 0.3.
        assert_eq!(stats.num_selected_samples, 3);
        assert!((stats.avg_difficulty - 0.2).abs() < 1e-6);
    }
}