Skip to main content

optirs_learned/transformer/training/
curriculum.rs

1// Curriculum learning strategies for transformer optimization
2//
3// This module implements various curriculum learning approaches that progressively
4// introduce optimization challenges of increasing difficulty to improve learning.
5
6#[allow(dead_code)]
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11
12use crate::error::{OptimError, Result};
13
/// Strategy used to drive the curriculum.
///
/// The enum is a plain, field-less tag, so it also derives `PartialEq`, `Eq`
/// and `Hash`; this lets callers compare strategies and use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CurriculumStrategy {
    /// No curriculum learning
    None,
    /// Difficulty-based progression
    DifficultyProgression,
    /// Diversity-based curriculum
    DiversityBased,
    /// Self-paced learning
    SelfPaced,
    /// Teacher-student curriculum
    TeacherStudent,
    /// Adversarial curriculum
    Adversarial,
    /// Multi-task curriculum
    MultiTask,
    /// Adaptive curriculum
    Adaptive,
}
34
35/// Curriculum learning manager
36#[derive(Debug, Clone)]
37pub struct CurriculumLearner<T: Float + Debug + Send + Sync + 'static> {
38    /// Curriculum strategy
39    strategy: CurriculumStrategy,
40
41    /// Curriculum parameters
42    curriculum_params: CurriculumParams<T>,
43
44    /// Task difficulty estimator
45    difficulty_estimator: TaskDifficultyEstimator<T>,
46
47    /// Learning progress tracker
48    progress_tracker: LearningProgressTracker<T>,
49
50    /// Current curriculum state
51    curriculum_state: CurriculumState<T>,
52
53    /// Task scheduling policy
54    task_scheduler: TaskScheduler<T>,
55
56    /// Performance history
57    performance_history: VecDeque<PerformanceRecord<T>>,
58}
59
60/// Curriculum parameters
61#[derive(Debug, Clone)]
62pub struct CurriculumParams<T: Float + Debug + Send + Sync + 'static> {
63    /// Initial difficulty threshold
64    initial_difficulty: T,
65
66    /// Maximum difficulty threshold
67    max_difficulty: T,
68
69    /// Difficulty increment per epoch
70    difficulty_increment: T,
71
72    /// Performance threshold for progression
73    progression_threshold: T,
74
75    /// Patience for difficulty increases
76    patience: usize,
77
78    /// Self-pacing factor
79    self_pacing_factor: T,
80
81    /// Diversity weight in curriculum
82    diversity_weight: T,
83
84    /// Teacher model confidence threshold
85    teacher_confidence: T,
86}
87
88/// Task difficulty estimator
89#[derive(Debug, Clone)]
90pub struct TaskDifficultyEstimator<T: Float + Debug + Send + Sync + 'static> {
91    /// Learned difficulty predictor
92    difficulty_predictor: DifficultyPredictor<T>,
93
94    /// Feature extractors for tasks
95    task_features: HashMap<String, Array1<T>>,
96
97    /// Historical difficulty measurements
98    difficulty_history: HashMap<String, Vec<T>>,
99
100    /// Difficulty estimation method
101    estimation_method: DifficultyEstimationMethod,
102}
103
104/// Learning progress tracker
105#[derive(Debug, Clone)]
106pub struct LearningProgressTracker<T: Float + Debug + Send + Sync + 'static> {
107    /// Performance metrics over time
108    performance_timeline: VecDeque<T>,
109
110    /// Learning rate estimates
111    learning_rates: VecDeque<T>,
112
113    /// Competency levels for different task types
114    competency_levels: HashMap<String, T>,
115
116    /// Progress milestones
117    milestones: Vec<ProgressMilestone<T>>,
118}
119
120/// Current curriculum state
121#[derive(Debug, Clone)]
122pub struct CurriculumState<T: Float + Debug + Send + Sync + 'static> {
123    /// Current difficulty level
124    current_difficulty: T,
125
126    /// Active task types
127    active_tasks: Vec<String>,
128
129    /// Recent performance
130    recent_performance: T,
131
132    /// Epochs since last difficulty increase
133    epochs_since_increase: usize,
134
135    /// Current learning phase
136    learning_phase: LearningPhase,
137
138    /// Adaptive parameters
139    adaptive_params: HashMap<String, T>,
140}
141
142/// Task scheduler for curriculum
143#[derive(Debug, Clone)]
144pub struct TaskScheduler<T: Float + Debug + Send + Sync + 'static> {
145    /// Task queue with priorities
146    task_queue: VecDeque<ScheduledTask<T>>,
147
148    /// Scheduling policy
149    scheduling_policy: SchedulingPolicy,
150
151    /// Task weights for sampling
152    task_weights: HashMap<String, T>,
153
154    /// Load balancing factors
155    load_balancing: HashMap<String, T>,
156}
157
158/// Performance record for curriculum tracking
159#[derive(Debug, Clone)]
160pub struct PerformanceRecord<T: Float + Debug + Send + Sync + 'static> {
161    /// Task identifier
162    task_id: String,
163
164    /// Performance score
165    performance: T,
166
167    /// Difficulty level when task was attempted
168    difficulty_level: T,
169
170    /// Number of training steps
171    training_steps: usize,
172
173    /// Timestamp
174    timestamp: usize,
175
176    /// Additional metrics
177    metrics: HashMap<String, T>,
178}
179
180/// Difficulty predictor network
181#[derive(Debug, Clone)]
182pub struct DifficultyPredictor<T: Float + Debug + Send + Sync + 'static> {
183    /// Input features dimension
184    input_dim: usize,
185
186    /// Hidden layers
187    hidden_layers: Vec<Array2<T>>,
188
189    /// Output layer
190    output_layer: Array1<T>,
191
192    /// Training history
193    training_history: Vec<(Array1<T>, T)>,
194}
195
196/// Progress milestone
197#[derive(Debug, Clone)]
198pub struct ProgressMilestone<T: Float + Debug + Send + Sync + 'static> {
199    /// Milestone name
200    name: String,
201
202    /// Performance threshold
203    threshold: T,
204
205    /// Whether milestone is achieved
206    achieved: bool,
207
208    /// Achievement timestamp
209    achieved_at: Option<usize>,
210}
211
212/// Scheduled task with priority
213#[derive(Debug, Clone)]
214pub struct ScheduledTask<T: Float + Debug + Send + Sync + 'static> {
215    /// Task identifier
216    task_id: String,
217
218    /// Task priority
219    priority: T,
220
221    /// Estimated difficulty
222    difficulty: T,
223
224    /// Required competency level
225    required_competency: T,
226
227    /// Task parameters
228    parameters: HashMap<String, T>,
229}
230
/// Phase of learning the curriculum is currently in.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to the copy/debug traits
/// so that phases can be compared and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LearningPhase {
    /// Initial exploration phase
    Exploration,
    /// Skill building phase
    SkillBuilding,
    /// Mastery phase
    Mastery,
    /// Transfer phase
    Transfer,
    /// Generalization phase
    Generalization,
}
245
/// Method used to estimate task difficulty.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to the copy/debug traits
/// so that methods can be compared and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DifficultyEstimationMethod {
    /// Performance-based estimation
    PerformanceBased,
    /// Feature-based prediction
    FeatureBased,
    /// Gradient-based estimation
    GradientBased,
    /// Uncertainty-based estimation
    UncertaintyBased,
    /// Multi-modal estimation
    MultiModal,
}
260
/// Policy the task scheduler uses to choose the next task.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to the copy/debug traits
/// so that policies can be compared and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SchedulingPolicy {
    /// First-in-first-out
    // The all-caps spelling is part of the public interface and is kept as is.
    FIFO,
    /// Priority-based scheduling
    Priority,
    /// Weighted random sampling
    WeightedRandom,
    /// Balanced sampling
    Balanced,
    /// Adaptive scheduling
    Adaptive,
}
275
276impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> CurriculumLearner<T> {
277    /// Create new curriculum learner
278    pub fn new(strategy: CurriculumStrategy) -> Result<Self> {
279        Ok(Self {
280            strategy,
281            curriculum_params: CurriculumParams::default(),
282            difficulty_estimator: TaskDifficultyEstimator::new()?,
283            progress_tracker: LearningProgressTracker::new(),
284            curriculum_state: CurriculumState::new()?,
285            task_scheduler: TaskScheduler::new()?,
286            performance_history: VecDeque::new(),
287        })
288    }
289
290    /// Update curriculum based on performance
291    pub fn update_curriculum(
292        &mut self,
293        task_id: &str,
294        performance: T,
295        training_steps: usize,
296    ) -> Result<()> {
297        // Record performance
298        let record = PerformanceRecord {
299            task_id: task_id.to_string(),
300            performance,
301            difficulty_level: self.curriculum_state.current_difficulty,
302            training_steps,
303            timestamp: self.performance_history.len(),
304            metrics: HashMap::new(),
305        };
306
307        self.performance_history.push_back(record);
308        if self.performance_history.len() > 1000 {
309            self.performance_history.pop_front();
310        }
311
312        // Update progress tracker
313        self.progress_tracker.update_performance(performance);
314
315        // Update curriculum state based on strategy
316        match self.strategy {
317            CurriculumStrategy::None => Ok(()),
318            CurriculumStrategy::DifficultyProgression => {
319                self.update_difficulty_progression(performance)
320            }
321            CurriculumStrategy::SelfPaced => self.update_self_paced_curriculum(performance),
322            CurriculumStrategy::Adaptive => self.update_adaptive_curriculum(task_id, performance),
323            _ => self.update_generic_curriculum(performance),
324        }
325    }
326
327    /// Get next task according to curriculum
328    pub fn get_next_task(&mut self) -> Result<Option<String>> {
329        match self.strategy {
330            CurriculumStrategy::None => Ok(None),
331            _ => Ok(self.task_scheduler.schedule_next_task()),
332        }
333    }
334
335    /// Update difficulty progression curriculum
336    fn update_difficulty_progression(&mut self, performance: T) -> Result<()> {
337        self.curriculum_state.recent_performance = performance;
338
339        if performance > self.curriculum_params.progression_threshold {
340            self.curriculum_state.epochs_since_increase += 1;
341
342            if self.curriculum_state.epochs_since_increase >= self.curriculum_params.patience {
343                // Increase difficulty
344                let new_difficulty = (self.curriculum_state.current_difficulty
345                    + self.curriculum_params.difficulty_increment)
346                    .min(self.curriculum_params.max_difficulty);
347
348                self.curriculum_state.current_difficulty = new_difficulty;
349                self.curriculum_state.epochs_since_increase = 0;
350
351                // Update learning phase
352                self.update_learning_phase();
353            }
354        } else {
355            self.curriculum_state.epochs_since_increase = 0;
356        }
357
358        Ok(())
359    }
360
361    /// Update self-paced curriculum
362    fn update_self_paced_curriculum(&mut self, performance: T) -> Result<()> {
363        let pacing_factor = self.curriculum_params.self_pacing_factor;
364
365        // Adjust difficulty based on performance
366        let performance_ratio = performance / self.get_expected_performance();
367        let difficulty_adjustment = (performance_ratio - T::one()) * pacing_factor;
368
369        let new_difficulty = (self.curriculum_state.current_difficulty + difficulty_adjustment)
370            .max(self.curriculum_params.initial_difficulty)
371            .min(self.curriculum_params.max_difficulty);
372
373        self.curriculum_state.current_difficulty = new_difficulty;
374
375        Ok(())
376    }
377
378    /// Update adaptive curriculum
379    fn update_adaptive_curriculum(&mut self, task_id: &str, performance: T) -> Result<()> {
380        // Update task-specific competency
381        let competency = self
382            .progress_tracker
383            .competency_levels
384            .get(task_id)
385            .copied()
386            .unwrap_or(T::zero());
387
388        let alpha = scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero());
389        let new_competency = competency * (T::one() - alpha) + performance * alpha;
390
391        self.progress_tracker
392            .competency_levels
393            .insert(task_id.to_string(), new_competency);
394
395        // Adapt curriculum parameters
396        self.adapt_curriculum_parameters(task_id, performance)?;
397
398        Ok(())
399    }
400
401    /// Generic curriculum update
402    fn update_generic_curriculum(&mut self, performance: T) -> Result<()> {
403        // Simple linear progression based on performance
404        if performance > scirs2_core::numeric::NumCast::from(0.8).unwrap_or_else(|| T::zero()) {
405            let increment = self.curriculum_params.difficulty_increment
406                * scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero());
407            self.curriculum_state.current_difficulty = (self.curriculum_state.current_difficulty
408                + increment)
409                .min(self.curriculum_params.max_difficulty);
410        }
411
412        Ok(())
413    }
414
415    /// Update learning phase
416    fn update_learning_phase(&mut self) {
417        let difficulty_ratio =
418            self.curriculum_state.current_difficulty / self.curriculum_params.max_difficulty;
419
420        self.curriculum_state.learning_phase = match difficulty_ratio {
421            x if x < scirs2_core::numeric::NumCast::from(0.2).unwrap_or_else(|| T::zero()) => {
422                LearningPhase::Exploration
423            }
424            x if x < scirs2_core::numeric::NumCast::from(0.4).unwrap_or_else(|| T::zero()) => {
425                LearningPhase::SkillBuilding
426            }
427            x if x < scirs2_core::numeric::NumCast::from(0.7).unwrap_or_else(|| T::zero()) => {
428                LearningPhase::Mastery
429            }
430            x if x < scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()) => {
431                LearningPhase::Transfer
432            }
433            _ => LearningPhase::Generalization,
434        };
435    }
436
437    /// Adapt curriculum parameters based on performance
438    fn adapt_curriculum_parameters(&mut self, task_id: &str, performance: T) -> Result<()> {
439        // Adapt patience based on task performance variance
440        let performance_variance = self.calculate_performance_variance(task_id);
441        if performance_variance
442            > scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero())
443        {
444            self.curriculum_params.patience = self.curriculum_params.patience.max(5);
445        } else {
446            self.curriculum_params.patience =
447                (self.curriculum_params.patience.saturating_sub(1)).max(1);
448        }
449
450        // Adapt progression threshold based on recent performance trend
451        let trend = self.calculate_performance_trend();
452        if trend > T::zero() {
453            // Performance is improving, can be more aggressive
454            self.curriculum_params.progression_threshold =
455                (self.curriculum_params.progression_threshold
456                    * scirs2_core::numeric::NumCast::from(0.95).unwrap_or_else(|| T::zero()))
457                .max(scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero()));
458        } else {
459            // Performance declining, be more conservative
460            self.curriculum_params.progression_threshold =
461                (self.curriculum_params.progression_threshold
462                    * scirs2_core::numeric::NumCast::from(1.05).unwrap_or_else(|| T::zero()))
463                .min(scirs2_core::numeric::NumCast::from(0.95).unwrap_or_else(|| T::zero()));
464        }
465
466        Ok(())
467    }
468
469    /// Calculate performance variance for a task
470    fn calculate_performance_variance(&self, task_id: &str) -> T {
471        let task_performances: Vec<T> = self
472            .performance_history
473            .iter()
474            .filter(|record| record.task_id == task_id)
475            .map(|record| record.performance)
476            .collect();
477
478        if task_performances.len() < 2 {
479            return T::zero();
480        }
481
482        let mean = task_performances
483            .iter()
484            .cloned()
485            .fold(T::zero(), |a, b| a + b)
486            / T::from(task_performances.len() as f64).expect("unwrap failed");
487
488        let variance = task_performances
489            .iter()
490            .map(|&x| (x - mean) * (x - mean))
491            .fold(T::zero(), |a, b| a + b)
492            / T::from((task_performances.len() - 1) as f64).expect("unwrap failed");
493
494        variance
495    }
496
497    /// Calculate recent performance trend
498    fn calculate_performance_trend(&self) -> T {
499        if self.performance_history.len() < 10 {
500            return T::zero();
501        }
502
503        let recent: Vec<T> = self
504            .performance_history
505            .iter()
506            .rev()
507            .take(10)
508            .map(|record| record.performance)
509            .collect();
510
511        let first_half_avg = recent[5..].iter().cloned().fold(T::zero(), |a, b| a + b)
512            / scirs2_core::numeric::NumCast::from(5.0).unwrap_or_else(|| T::zero());
513        let second_half_avg = recent[..5].iter().cloned().fold(T::zero(), |a, b| a + b)
514            / scirs2_core::numeric::NumCast::from(5.0).unwrap_or_else(|| T::zero());
515
516        second_half_avg - first_half_avg
517    }
518
519    /// Get expected performance for current difficulty
520    fn get_expected_performance(&self) -> T {
521        // Simple model: expected performance decreases with difficulty
522        let difficulty_factor =
523            self.curriculum_state.current_difficulty / self.curriculum_params.max_difficulty;
524        T::one()
525            - difficulty_factor
526                * scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero())
527    }
528
529    /// Add task to curriculum
530    pub fn add_task(
531        &mut self,
532        task_id: String,
533        estimated_difficulty: T,
534        required_competency: T,
535    ) -> Result<()> {
536        let scheduled_task = ScheduledTask {
537            task_id: task_id.clone(),
538            priority: T::one() / estimated_difficulty, // Higher priority for easier tasks initially
539            difficulty: estimated_difficulty,
540            required_competency,
541            parameters: HashMap::new(),
542        };
543
544        self.task_scheduler.add_task(scheduled_task);
545
546        // Initialize competency tracking
547        self.progress_tracker
548            .competency_levels
549            .insert(task_id, T::zero());
550
551        Ok(())
552    }
553
554    /// Get curriculum statistics
555    pub fn get_curriculum_statistics(&self) -> HashMap<String, T> {
556        let mut stats = HashMap::new();
557
558        stats.insert(
559            "current_difficulty".to_string(),
560            self.curriculum_state.current_difficulty,
561        );
562        stats.insert(
563            "recent_performance".to_string(),
564            self.curriculum_state.recent_performance,
565        );
566        stats.insert(
567            "epochs_since_increase".to_string(),
568            scirs2_core::numeric::NumCast::from(self.curriculum_state.epochs_since_increase as f64)
569                .unwrap_or_else(|| T::zero()),
570        );
571        stats.insert(
572            "active_tasks_count".to_string(),
573            T::from(self.curriculum_state.active_tasks.len() as f64).expect("unwrap failed"),
574        );
575
576        // Average competency across all tasks
577        if !self.progress_tracker.competency_levels.is_empty() {
578            let avg_competency = self
579                .progress_tracker
580                .competency_levels
581                .values()
582                .cloned()
583                .fold(T::zero(), |a, b| a + b)
584                / T::from(self.progress_tracker.competency_levels.len() as f64)
585                    .expect("unwrap failed");
586            stats.insert("average_competency".to_string(), avg_competency);
587        }
588
589        stats
590    }
591
592    /// Reset curriculum state
593    pub fn reset(&mut self) {
594        self.curriculum_state = CurriculumState::new().expect("unwrap failed");
595        self.progress_tracker.reset();
596        self.performance_history.clear();
597        self.task_scheduler.reset();
598    }
599}
600
601// Supporting type implementations
602impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> TaskDifficultyEstimator<T> {
603    fn new() -> Result<Self> {
604        Ok(Self {
605            difficulty_predictor: DifficultyPredictor::new(10)?, // Default 10-dim input
606            task_features: HashMap::new(),
607            difficulty_history: HashMap::new(),
608            estimation_method: DifficultyEstimationMethod::PerformanceBased,
609        })
610    }
611}
612
613impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> LearningProgressTracker<T> {
614    fn new() -> Self {
615        Self {
616            performance_timeline: VecDeque::new(),
617            learning_rates: VecDeque::new(),
618            competency_levels: HashMap::new(),
619            milestones: Vec::new(),
620        }
621    }
622
623    fn update_performance(&mut self, performance: T) {
624        self.performance_timeline.push_back(performance);
625        if self.performance_timeline.len() > 1000 {
626            self.performance_timeline.pop_front();
627        }
628    }
629
630    fn reset(&mut self) {
631        self.performance_timeline.clear();
632        self.learning_rates.clear();
633        self.competency_levels.clear();
634        self.milestones.clear();
635    }
636}
637
638impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> CurriculumState<T> {
639    fn new() -> Result<Self> {
640        Ok(Self {
641            current_difficulty: scirs2_core::numeric::NumCast::from(0.1)
642                .unwrap_or_else(|| T::zero()),
643            active_tasks: Vec::new(),
644            recent_performance: T::zero(),
645            epochs_since_increase: 0,
646            learning_phase: LearningPhase::Exploration,
647            adaptive_params: HashMap::new(),
648        })
649    }
650}
651
652impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> TaskScheduler<T> {
653    fn new() -> Result<Self> {
654        Ok(Self {
655            task_queue: VecDeque::new(),
656            scheduling_policy: SchedulingPolicy::Priority,
657            task_weights: HashMap::new(),
658            load_balancing: HashMap::new(),
659        })
660    }
661
662    fn add_task(&mut self, task: ScheduledTask<T>) {
663        self.task_queue.push_back(task);
664    }
665
666    fn schedule_next_task(&mut self) -> Option<String> {
667        if let Some(task) = self.task_queue.pop_front() {
668            Some(task.task_id)
669        } else {
670            None
671        }
672    }
673
674    fn reset(&mut self) {
675        self.task_queue.clear();
676        self.task_weights.clear();
677        self.load_balancing.clear();
678    }
679}
680
681impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> DifficultyPredictor<T> {
682    fn new(input_dim: usize) -> Result<Self> {
683        Ok(Self {
684            input_dim,
685            hidden_layers: vec![Array2::eye(input_dim)],
686            output_layer: Array1::ones(input_dim),
687            training_history: Vec::new(),
688        })
689    }
690}
691
692impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> Default for CurriculumParams<T> {
693    fn default() -> Self {
694        Self {
695            initial_difficulty: scirs2_core::numeric::NumCast::from(0.1)
696                .unwrap_or_else(|| T::zero()),
697            max_difficulty: scirs2_core::numeric::NumCast::from(1.0).unwrap_or_else(|| T::zero()),
698            difficulty_increment: scirs2_core::numeric::NumCast::from(0.05)
699                .unwrap_or_else(|| T::zero()),
700            progression_threshold: scirs2_core::numeric::NumCast::from(0.8)
701                .unwrap_or_else(|| T::zero()),
702            patience: 5,
703            self_pacing_factor: scirs2_core::numeric::NumCast::from(0.1)
704                .unwrap_or_else(|| T::zero()),
705            diversity_weight: scirs2_core::numeric::NumCast::from(0.2).unwrap_or_else(|| T::zero()),
706            teacher_confidence: scirs2_core::numeric::NumCast::from(0.9)
707                .unwrap_or_else(|| T::zero()),
708        }
709    }
710}