1#[allow(dead_code)]
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11
12use crate::error::{OptimError, Result};
13
/// Strategy used by `CurriculumLearner` to react to reported performance.
///
/// Derives `PartialEq`/`Eq`/`Hash` so strategies can be compared (`==`) and used
/// as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CurriculumStrategy {
    /// No curriculum: updates are recorded but difficulty is never changed, and
    /// no task is ever scheduled.
    None,
    /// Raise difficulty by a fixed increment after `patience` consecutive updates
    /// whose performance exceeds the progression threshold.
    DifficultyProgression,
    /// Diversity-driven curriculum; currently handled by the generic difficulty rule.
    DiversityBased,
    /// Move difficulty in proportion to how far performance deviates from the
    /// difficulty-dependent expected performance.
    SelfPaced,
    /// Teacher/student curriculum; currently handled by the generic difficulty rule.
    TeacherStudent,
    /// Adversarial curriculum; currently handled by the generic difficulty rule.
    Adversarial,
    /// Multi-task curriculum; currently handled by the generic difficulty rule.
    MultiTask,
    /// Track per-task competency and tune the progression threshold and patience
    /// from recent performance statistics.
    Adaptive,
}
34
35#[derive(Debug, Clone)]
37pub struct CurriculumLearner<T: Float + Debug + Send + Sync + 'static> {
38 strategy: CurriculumStrategy,
40
41 curriculum_params: CurriculumParams<T>,
43
44 difficulty_estimator: TaskDifficultyEstimator<T>,
46
47 progress_tracker: LearningProgressTracker<T>,
49
50 curriculum_state: CurriculumState<T>,
52
53 task_scheduler: TaskScheduler<T>,
55
56 performance_history: VecDeque<PerformanceRecord<T>>,
58}
59
60#[derive(Debug, Clone)]
62pub struct CurriculumParams<T: Float + Debug + Send + Sync + 'static> {
63 initial_difficulty: T,
65
66 max_difficulty: T,
68
69 difficulty_increment: T,
71
72 progression_threshold: T,
74
75 patience: usize,
77
78 self_pacing_factor: T,
80
81 diversity_weight: T,
83
84 teacher_confidence: T,
86}
87
88#[derive(Debug, Clone)]
90pub struct TaskDifficultyEstimator<T: Float + Debug + Send + Sync + 'static> {
91 difficulty_predictor: DifficultyPredictor<T>,
93
94 task_features: HashMap<String, Array1<T>>,
96
97 difficulty_history: HashMap<String, Vec<T>>,
99
100 estimation_method: DifficultyEstimationMethod,
102}
103
104#[derive(Debug, Clone)]
106pub struct LearningProgressTracker<T: Float + Debug + Send + Sync + 'static> {
107 performance_timeline: VecDeque<T>,
109
110 learning_rates: VecDeque<T>,
112
113 competency_levels: HashMap<String, T>,
115
116 milestones: Vec<ProgressMilestone<T>>,
118}
119
120#[derive(Debug, Clone)]
122pub struct CurriculumState<T: Float + Debug + Send + Sync + 'static> {
123 current_difficulty: T,
125
126 active_tasks: Vec<String>,
128
129 recent_performance: T,
131
132 epochs_since_increase: usize,
134
135 learning_phase: LearningPhase,
137
138 adaptive_params: HashMap<String, T>,
140}
141
142#[derive(Debug, Clone)]
144pub struct TaskScheduler<T: Float + Debug + Send + Sync + 'static> {
145 task_queue: VecDeque<ScheduledTask<T>>,
147
148 scheduling_policy: SchedulingPolicy,
150
151 task_weights: HashMap<String, T>,
153
154 load_balancing: HashMap<String, T>,
156}
157
158#[derive(Debug, Clone)]
160pub struct PerformanceRecord<T: Float + Debug + Send + Sync + 'static> {
161 task_id: String,
163
164 performance: T,
166
167 difficulty_level: T,
169
170 training_steps: usize,
172
173 timestamp: usize,
175
176 metrics: HashMap<String, T>,
178}
179
180#[derive(Debug, Clone)]
182pub struct DifficultyPredictor<T: Float + Debug + Send + Sync + 'static> {
183 input_dim: usize,
185
186 hidden_layers: Vec<Array2<T>>,
188
189 output_layer: Array1<T>,
191
192 training_history: Vec<(Array1<T>, T)>,
194}
195
196#[derive(Debug, Clone)]
198pub struct ProgressMilestone<T: Float + Debug + Send + Sync + 'static> {
199 name: String,
201
202 threshold: T,
204
205 achieved: bool,
207
208 achieved_at: Option<usize>,
210}
211
212#[derive(Debug, Clone)]
214pub struct ScheduledTask<T: Float + Debug + Send + Sync + 'static> {
215 task_id: String,
217
218 priority: T,
220
221 difficulty: T,
223
224 required_competency: T,
226
227 parameters: HashMap<String, T>,
229}
230
/// Coarse training stage, derived from `current_difficulty / max_difficulty`.
///
/// Bands used by `CurriculumLearner`: `< 0.2` Exploration, `< 0.4` SkillBuilding,
/// `< 0.7` Mastery, `< 0.9` Transfer, otherwise Generalization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LearningPhase {
    /// Difficulty ratio below 0.2.
    Exploration,
    /// Difficulty ratio in `[0.2, 0.4)`.
    SkillBuilding,
    /// Difficulty ratio in `[0.4, 0.7)`.
    Mastery,
    /// Difficulty ratio in `[0.7, 0.9)`.
    Transfer,
    /// Difficulty ratio of 0.9 or above (or not comparable, e.g. NaN).
    Generalization,
}
245
/// Method used to estimate task difficulty.
///
/// `TaskDifficultyEstimator::new` selects `PerformanceBased`; the remaining
/// variants' semantics are presumed from their names — TODO confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DifficultyEstimationMethod {
    /// Estimate from observed performance (the default).
    PerformanceBased,
    /// Estimate from task feature vectors.
    FeatureBased,
    /// Estimate from gradient information.
    GradientBased,
    /// Estimate from model uncertainty.
    UncertaintyBased,
    /// Combination of several estimation signals.
    MultiModal,
}
260
/// Policy a `TaskScheduler` is configured with for picking the next task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SchedulingPolicy {
    /// First-in, first-out queue order.
    FIFO,
    /// Priority-based selection; the policy `TaskScheduler::new` configures.
    Priority,
    /// Weighted random selection (presumed from the name; no dedicated selection
    /// logic exists in this file — TODO confirm).
    WeightedRandom,
    /// Load-balanced selection (presumed from the name; no dedicated selection
    /// logic exists in this file — TODO confirm).
    Balanced,
    /// Adaptive selection (presumed from the name; no dedicated selection logic
    /// exists in this file — TODO confirm).
    Adaptive,
}
275
276impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> CurriculumLearner<T> {
277 pub fn new(strategy: CurriculumStrategy) -> Result<Self> {
279 Ok(Self {
280 strategy,
281 curriculum_params: CurriculumParams::default(),
282 difficulty_estimator: TaskDifficultyEstimator::new()?,
283 progress_tracker: LearningProgressTracker::new(),
284 curriculum_state: CurriculumState::new()?,
285 task_scheduler: TaskScheduler::new()?,
286 performance_history: VecDeque::new(),
287 })
288 }
289
290 pub fn update_curriculum(
292 &mut self,
293 task_id: &str,
294 performance: T,
295 training_steps: usize,
296 ) -> Result<()> {
297 let record = PerformanceRecord {
299 task_id: task_id.to_string(),
300 performance,
301 difficulty_level: self.curriculum_state.current_difficulty,
302 training_steps,
303 timestamp: self.performance_history.len(),
304 metrics: HashMap::new(),
305 };
306
307 self.performance_history.push_back(record);
308 if self.performance_history.len() > 1000 {
309 self.performance_history.pop_front();
310 }
311
312 self.progress_tracker.update_performance(performance);
314
315 match self.strategy {
317 CurriculumStrategy::None => Ok(()),
318 CurriculumStrategy::DifficultyProgression => {
319 self.update_difficulty_progression(performance)
320 }
321 CurriculumStrategy::SelfPaced => self.update_self_paced_curriculum(performance),
322 CurriculumStrategy::Adaptive => self.update_adaptive_curriculum(task_id, performance),
323 _ => self.update_generic_curriculum(performance),
324 }
325 }
326
327 pub fn get_next_task(&mut self) -> Result<Option<String>> {
329 match self.strategy {
330 CurriculumStrategy::None => Ok(None),
331 _ => Ok(self.task_scheduler.schedule_next_task()),
332 }
333 }
334
335 fn update_difficulty_progression(&mut self, performance: T) -> Result<()> {
337 self.curriculum_state.recent_performance = performance;
338
339 if performance > self.curriculum_params.progression_threshold {
340 self.curriculum_state.epochs_since_increase += 1;
341
342 if self.curriculum_state.epochs_since_increase >= self.curriculum_params.patience {
343 let new_difficulty = (self.curriculum_state.current_difficulty
345 + self.curriculum_params.difficulty_increment)
346 .min(self.curriculum_params.max_difficulty);
347
348 self.curriculum_state.current_difficulty = new_difficulty;
349 self.curriculum_state.epochs_since_increase = 0;
350
351 self.update_learning_phase();
353 }
354 } else {
355 self.curriculum_state.epochs_since_increase = 0;
356 }
357
358 Ok(())
359 }
360
361 fn update_self_paced_curriculum(&mut self, performance: T) -> Result<()> {
363 let pacing_factor = self.curriculum_params.self_pacing_factor;
364
365 let performance_ratio = performance / self.get_expected_performance();
367 let difficulty_adjustment = (performance_ratio - T::one()) * pacing_factor;
368
369 let new_difficulty = (self.curriculum_state.current_difficulty + difficulty_adjustment)
370 .max(self.curriculum_params.initial_difficulty)
371 .min(self.curriculum_params.max_difficulty);
372
373 self.curriculum_state.current_difficulty = new_difficulty;
374
375 Ok(())
376 }
377
378 fn update_adaptive_curriculum(&mut self, task_id: &str, performance: T) -> Result<()> {
380 let competency = self
382 .progress_tracker
383 .competency_levels
384 .get(task_id)
385 .copied()
386 .unwrap_or(T::zero());
387
388 let alpha = scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero());
389 let new_competency = competency * (T::one() - alpha) + performance * alpha;
390
391 self.progress_tracker
392 .competency_levels
393 .insert(task_id.to_string(), new_competency);
394
395 self.adapt_curriculum_parameters(task_id, performance)?;
397
398 Ok(())
399 }
400
401 fn update_generic_curriculum(&mut self, performance: T) -> Result<()> {
403 if performance > scirs2_core::numeric::NumCast::from(0.8).unwrap_or_else(|| T::zero()) {
405 let increment = self.curriculum_params.difficulty_increment
406 * scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero());
407 self.curriculum_state.current_difficulty = (self.curriculum_state.current_difficulty
408 + increment)
409 .min(self.curriculum_params.max_difficulty);
410 }
411
412 Ok(())
413 }
414
415 fn update_learning_phase(&mut self) {
417 let difficulty_ratio =
418 self.curriculum_state.current_difficulty / self.curriculum_params.max_difficulty;
419
420 self.curriculum_state.learning_phase = match difficulty_ratio {
421 x if x < scirs2_core::numeric::NumCast::from(0.2).unwrap_or_else(|| T::zero()) => {
422 LearningPhase::Exploration
423 }
424 x if x < scirs2_core::numeric::NumCast::from(0.4).unwrap_or_else(|| T::zero()) => {
425 LearningPhase::SkillBuilding
426 }
427 x if x < scirs2_core::numeric::NumCast::from(0.7).unwrap_or_else(|| T::zero()) => {
428 LearningPhase::Mastery
429 }
430 x if x < scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()) => {
431 LearningPhase::Transfer
432 }
433 _ => LearningPhase::Generalization,
434 };
435 }
436
437 fn adapt_curriculum_parameters(&mut self, task_id: &str, performance: T) -> Result<()> {
439 let performance_variance = self.calculate_performance_variance(task_id);
441 if performance_variance
442 > scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero())
443 {
444 self.curriculum_params.patience = self.curriculum_params.patience.max(5);
445 } else {
446 self.curriculum_params.patience =
447 (self.curriculum_params.patience.saturating_sub(1)).max(1);
448 }
449
450 let trend = self.calculate_performance_trend();
452 if trend > T::zero() {
453 self.curriculum_params.progression_threshold =
455 (self.curriculum_params.progression_threshold
456 * scirs2_core::numeric::NumCast::from(0.95).unwrap_or_else(|| T::zero()))
457 .max(scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero()));
458 } else {
459 self.curriculum_params.progression_threshold =
461 (self.curriculum_params.progression_threshold
462 * scirs2_core::numeric::NumCast::from(1.05).unwrap_or_else(|| T::zero()))
463 .min(scirs2_core::numeric::NumCast::from(0.95).unwrap_or_else(|| T::zero()));
464 }
465
466 Ok(())
467 }
468
469 fn calculate_performance_variance(&self, task_id: &str) -> T {
471 let task_performances: Vec<T> = self
472 .performance_history
473 .iter()
474 .filter(|record| record.task_id == task_id)
475 .map(|record| record.performance)
476 .collect();
477
478 if task_performances.len() < 2 {
479 return T::zero();
480 }
481
482 let mean = task_performances
483 .iter()
484 .cloned()
485 .fold(T::zero(), |a, b| a + b)
486 / T::from(task_performances.len() as f64).expect("unwrap failed");
487
488 let variance = task_performances
489 .iter()
490 .map(|&x| (x - mean) * (x - mean))
491 .fold(T::zero(), |a, b| a + b)
492 / T::from((task_performances.len() - 1) as f64).expect("unwrap failed");
493
494 variance
495 }
496
497 fn calculate_performance_trend(&self) -> T {
499 if self.performance_history.len() < 10 {
500 return T::zero();
501 }
502
503 let recent: Vec<T> = self
504 .performance_history
505 .iter()
506 .rev()
507 .take(10)
508 .map(|record| record.performance)
509 .collect();
510
511 let first_half_avg = recent[5..].iter().cloned().fold(T::zero(), |a, b| a + b)
512 / scirs2_core::numeric::NumCast::from(5.0).unwrap_or_else(|| T::zero());
513 let second_half_avg = recent[..5].iter().cloned().fold(T::zero(), |a, b| a + b)
514 / scirs2_core::numeric::NumCast::from(5.0).unwrap_or_else(|| T::zero());
515
516 second_half_avg - first_half_avg
517 }
518
519 fn get_expected_performance(&self) -> T {
521 let difficulty_factor =
523 self.curriculum_state.current_difficulty / self.curriculum_params.max_difficulty;
524 T::one()
525 - difficulty_factor
526 * scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero())
527 }
528
529 pub fn add_task(
531 &mut self,
532 task_id: String,
533 estimated_difficulty: T,
534 required_competency: T,
535 ) -> Result<()> {
536 let scheduled_task = ScheduledTask {
537 task_id: task_id.clone(),
538 priority: T::one() / estimated_difficulty, difficulty: estimated_difficulty,
540 required_competency,
541 parameters: HashMap::new(),
542 };
543
544 self.task_scheduler.add_task(scheduled_task);
545
546 self.progress_tracker
548 .competency_levels
549 .insert(task_id, T::zero());
550
551 Ok(())
552 }
553
554 pub fn get_curriculum_statistics(&self) -> HashMap<String, T> {
556 let mut stats = HashMap::new();
557
558 stats.insert(
559 "current_difficulty".to_string(),
560 self.curriculum_state.current_difficulty,
561 );
562 stats.insert(
563 "recent_performance".to_string(),
564 self.curriculum_state.recent_performance,
565 );
566 stats.insert(
567 "epochs_since_increase".to_string(),
568 scirs2_core::numeric::NumCast::from(self.curriculum_state.epochs_since_increase as f64)
569 .unwrap_or_else(|| T::zero()),
570 );
571 stats.insert(
572 "active_tasks_count".to_string(),
573 T::from(self.curriculum_state.active_tasks.len() as f64).expect("unwrap failed"),
574 );
575
576 if !self.progress_tracker.competency_levels.is_empty() {
578 let avg_competency = self
579 .progress_tracker
580 .competency_levels
581 .values()
582 .cloned()
583 .fold(T::zero(), |a, b| a + b)
584 / T::from(self.progress_tracker.competency_levels.len() as f64)
585 .expect("unwrap failed");
586 stats.insert("average_competency".to_string(), avg_competency);
587 }
588
589 stats
590 }
591
592 pub fn reset(&mut self) {
594 self.curriculum_state = CurriculumState::new().expect("unwrap failed");
595 self.progress_tracker.reset();
596 self.performance_history.clear();
597 self.task_scheduler.reset();
598 }
599}
600
601impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> TaskDifficultyEstimator<T> {
603 fn new() -> Result<Self> {
604 Ok(Self {
605 difficulty_predictor: DifficultyPredictor::new(10)?, task_features: HashMap::new(),
607 difficulty_history: HashMap::new(),
608 estimation_method: DifficultyEstimationMethod::PerformanceBased,
609 })
610 }
611}
612
613impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> LearningProgressTracker<T> {
614 fn new() -> Self {
615 Self {
616 performance_timeline: VecDeque::new(),
617 learning_rates: VecDeque::new(),
618 competency_levels: HashMap::new(),
619 milestones: Vec::new(),
620 }
621 }
622
623 fn update_performance(&mut self, performance: T) {
624 self.performance_timeline.push_back(performance);
625 if self.performance_timeline.len() > 1000 {
626 self.performance_timeline.pop_front();
627 }
628 }
629
630 fn reset(&mut self) {
631 self.performance_timeline.clear();
632 self.learning_rates.clear();
633 self.competency_levels.clear();
634 self.milestones.clear();
635 }
636}
637
638impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> CurriculumState<T> {
639 fn new() -> Result<Self> {
640 Ok(Self {
641 current_difficulty: scirs2_core::numeric::NumCast::from(0.1)
642 .unwrap_or_else(|| T::zero()),
643 active_tasks: Vec::new(),
644 recent_performance: T::zero(),
645 epochs_since_increase: 0,
646 learning_phase: LearningPhase::Exploration,
647 adaptive_params: HashMap::new(),
648 })
649 }
650}
651
652impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> TaskScheduler<T> {
653 fn new() -> Result<Self> {
654 Ok(Self {
655 task_queue: VecDeque::new(),
656 scheduling_policy: SchedulingPolicy::Priority,
657 task_weights: HashMap::new(),
658 load_balancing: HashMap::new(),
659 })
660 }
661
662 fn add_task(&mut self, task: ScheduledTask<T>) {
663 self.task_queue.push_back(task);
664 }
665
666 fn schedule_next_task(&mut self) -> Option<String> {
667 if let Some(task) = self.task_queue.pop_front() {
668 Some(task.task_id)
669 } else {
670 None
671 }
672 }
673
674 fn reset(&mut self) {
675 self.task_queue.clear();
676 self.task_weights.clear();
677 self.load_balancing.clear();
678 }
679}
680
681impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> DifficultyPredictor<T> {
682 fn new(input_dim: usize) -> Result<Self> {
683 Ok(Self {
684 input_dim,
685 hidden_layers: vec![Array2::eye(input_dim)],
686 output_layer: Array1::ones(input_dim),
687 training_history: Vec::new(),
688 })
689 }
690}
691
692impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> Default for CurriculumParams<T> {
693 fn default() -> Self {
694 Self {
695 initial_difficulty: scirs2_core::numeric::NumCast::from(0.1)
696 .unwrap_or_else(|| T::zero()),
697 max_difficulty: scirs2_core::numeric::NumCast::from(1.0).unwrap_or_else(|| T::zero()),
698 difficulty_increment: scirs2_core::numeric::NumCast::from(0.05)
699 .unwrap_or_else(|| T::zero()),
700 progression_threshold: scirs2_core::numeric::NumCast::from(0.8)
701 .unwrap_or_else(|| T::zero()),
702 patience: 5,
703 self_pacing_factor: scirs2_core::numeric::NumCast::from(0.1)
704 .unwrap_or_else(|| T::zero()),
705 diversity_weight: scirs2_core::numeric::NumCast::from(0.2).unwrap_or_else(|| T::zero()),
706 teacher_confidence: scirs2_core::numeric::NumCast::from(0.9)
707 .unwrap_or_else(|| T::zero()),
708 }
709 }
710}