// optirs_core/curriculum_optimization/mod.rs
// Curriculum optimization for adaptive training
//
// This module provides curriculum learning capabilities including task difficulty progression,
// sample importance weighting, and adversarial training support.

use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::marker::PhantomData;
12
/// Curriculum learning strategy
///
/// Controls how the difficulty level reported by `CurriculumManager` evolves
/// over training steps. Difficulties are on a 0.0 (easiest) to 1.0 (hardest)
/// scale.
#[derive(Debug, Clone)]
pub enum CurriculumStrategy {
    /// Linear difficulty progression
    Linear {
        /// Starting difficulty (0.0 to 1.0)
        start_difficulty: f64,
        /// Ending difficulty (0.0 to 1.0)
        end_difficulty: f64,
        /// Number of steps to reach end difficulty
        num_steps: usize,
    },
    /// Exponential difficulty progression (approaches `end_difficulty`
    /// asymptotically at a rate controlled by `growth_rate`)
    Exponential {
        /// Starting difficulty (0.0 to 1.0)
        start_difficulty: f64,
        /// Ending difficulty (0.0 to 1.0)
        end_difficulty: f64,
        /// Growth rate
        growth_rate: f64,
    },
    /// Performance-based curriculum: difficulty moves up or down by
    /// `adjustment_step` depending on the windowed average performance
    PerformanceBased {
        /// Threshold for advancing difficulty
        advance_threshold: f64,
        /// Threshold for reducing difficulty
        reduce_threshold: f64,
        /// Difficulty adjustment step size
        adjustment_step: f64,
        /// Window size for performance averaging
        window_size: usize,
    },
    /// Custom curriculum with predefined schedule
    Custom {
        /// Difficulty schedule (step -> difficulty)
        schedule: HashMap<usize, f64>,
        /// Default difficulty for unspecified steps
        default_difficulty: f64,
    },
}
53
/// Sample importance weighting strategy
///
/// Determines how `CurriculumManager::compute_sample_weights` assigns a
/// per-sample weight. The softmax-style strategies subtract the batch maximum
/// before exponentiating and clamp each weight at `min_weight`.
#[derive(Debug, Clone)]
pub enum ImportanceWeightingStrategy {
    /// Uniform weighting (all samples equal)
    Uniform,
    /// Loss-based weighting (higher loss = higher weight)
    LossBased {
        /// Temperature parameter for softmax weighting
        temperature: f64,
        /// Minimum weight to avoid zero weights
        min_weight: f64,
    },
    /// Gradient norm based weighting
    GradientNormBased {
        /// Temperature parameter
        temperature: f64,
        /// Minimum weight
        min_weight: f64,
    },
    /// Uncertainty-based weighting
    UncertaintyBased {
        /// Temperature parameter
        temperature: f64,
        /// Minimum weight
        min_weight: f64,
    },
    /// Age-based weighting (older samples get higher weight)
    AgeBased {
        /// Decay factor for age
        // NOTE(review): despite the name, a positive value *increases* the
        // weight of older samples (weight = exp(decayfactor * age)) — confirm
        // intended sign convention.
        decayfactor: f64,
    },
}
86
/// Adversarial training configuration
///
/// Shared parameter set for all attack types; fields a particular attack does
/// not use (e.g. `num_steps` for FGSM) are ignored by that attack.
#[derive(Debug, Clone)]
pub struct AdversarialConfig<A: Float> {
    /// Adversarial perturbation magnitude (per-element clamp radius used by
    /// the iterative attacks' projection step)
    pub epsilon: A,
    /// Number of adversarial steps
    pub num_steps: usize,
    /// Step size for adversarial perturbation
    pub step_size: A,
    /// Type of adversarial attack
    pub attack_type: AdversarialAttack,
    /// Regularization weight for adversarial loss
    // NOTE(review): not read anywhere in this module — presumably consumed by
    // the training loop combining clean and adversarial losses; confirm.
    pub adversarial_weight: A,
}
101
/// Types of adversarial attacks
///
/// Selects the algorithm used by
/// `CurriculumManager::generate_adversarial_examples`.
#[derive(Debug, Clone, Copy)]
pub enum AdversarialAttack {
    /// Fast Gradient Sign Method (FGSM): a single signed-gradient step
    FGSM,
    /// Projected Gradient Descent (PGD): iterated signed steps with projection
    PGD,
    /// Basic Iterative Method (BIM): PGD with step size epsilon / num_steps
    BIM,
    /// Momentum Iterative Method (MIM): signed steps along accumulated momentum
    MIM,
}
114
/// Curriculum learning manager
///
/// Tracks the current difficulty level (driven by a `CurriculumStrategy`),
/// per-sample difficulty scores and importance weights, and optionally
/// generates adversarial examples during training.
#[derive(Debug)]
pub struct CurriculumManager<A: Float, D: Dimension> {
    /// Curriculum strategy
    strategy: CurriculumStrategy,
    /// Current difficulty level (0.0 to 1.0)
    current_difficulty: f64,
    /// Number of `update_curriculum` calls so far
    step_count: usize,
    /// Performance history
    // NOTE(review): trimmed to the averaging window only for the
    // `PerformanceBased` strategy; grows unbounded otherwise — confirm intended.
    performance_history: VecDeque<A>,
    /// Sample difficulty scores keyed by sample ID
    sample_difficulties: HashMap<usize, f64>,
    /// Importance weighting strategy
    importance_strategy: ImportanceWeightingStrategy,
    /// Sample weights keyed by sample ID
    sample_weights: HashMap<usize, A>,
    /// Adversarial training configuration (`None` = disabled)
    adversarial_config: Option<AdversarialConfig<A>>,
    /// Phantom data for dimension
    _phantom: PhantomData<D>,
}
137
138impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> CurriculumManager<A, D> {
139    /// Create a new curriculum manager
140    pub fn new(
141        strategy: CurriculumStrategy,
142        importance_strategy: ImportanceWeightingStrategy,
143    ) -> Self {
144        let initial_difficulty = match &strategy {
145            CurriculumStrategy::Linear {
146                start_difficulty, ..
147            } => *start_difficulty,
148            CurriculumStrategy::Exponential {
149                start_difficulty, ..
150            } => *start_difficulty,
151            CurriculumStrategy::PerformanceBased { .. } => 0.1, // Start easy
152            CurriculumStrategy::Custom {
153                default_difficulty, ..
154            } => *default_difficulty,
155        };
156
157        Self {
158            strategy,
159            current_difficulty: initial_difficulty,
160            step_count: 0,
161            performance_history: VecDeque::new(),
162            sample_difficulties: HashMap::new(),
163            importance_strategy,
164            sample_weights: HashMap::new(),
165            adversarial_config: None,
166            _phantom: PhantomData,
167        }
168    }
169
170    /// Enable adversarial training
171    pub fn enable_adversarial_training(&mut self, config: AdversarialConfig<A>) {
172        self.adversarial_config = Some(config);
173    }
174
175    /// Disable adversarial training
176    pub fn disable_adversarial_training(&mut self) {
177        self.adversarial_config = None;
178    }
179
180    /// Update curriculum based on performance
181    pub fn update_curriculum(&mut self, performance: A) -> Result<()> {
182        self.performance_history.push_back(performance);
183        self.step_count += 1;
184
185        // Update difficulty based on strategy
186        match &self.strategy {
187            CurriculumStrategy::Linear {
188                start_difficulty,
189                end_difficulty,
190                num_steps,
191            } => {
192                let progress = (self.step_count as f64) / (*num_steps as f64);
193                let progress = progress.min(1.0);
194                self.current_difficulty =
195                    start_difficulty + progress * (end_difficulty - start_difficulty);
196            }
197            CurriculumStrategy::Exponential {
198                start_difficulty,
199                end_difficulty,
200                growth_rate,
201            } => {
202                let progress = 1.0 - (-growth_rate * self.step_count as f64).exp();
203                self.current_difficulty =
204                    start_difficulty + progress * (end_difficulty - start_difficulty);
205            }
206            CurriculumStrategy::PerformanceBased {
207                advance_threshold,
208                reduce_threshold,
209                adjustment_step,
210                window_size,
211            } => {
212                if self.performance_history.len() >= *window_size {
213                    // Keep only recent performance
214                    while self.performance_history.len() > *window_size {
215                        self.performance_history.pop_front();
216                    }
217
218                    // Calculate average performance
219                    let avg_performance = self
220                        .performance_history
221                        .iter()
222                        .fold(A::zero(), |acc, &perf| acc + perf)
223                        / A::from(self.performance_history.len()).expect("unwrap failed");
224
225                    let avg_perf_f64 = avg_performance.to_f64().unwrap_or(0.0);
226
227                    // Adjust difficulty based on performance
228                    if avg_perf_f64 > *advance_threshold {
229                        self.current_difficulty =
230                            (self.current_difficulty + adjustment_step).min(1.0);
231                    } else if avg_perf_f64 < *reduce_threshold {
232                        self.current_difficulty =
233                            (self.current_difficulty - adjustment_step).max(0.0);
234                    }
235                }
236            }
237            CurriculumStrategy::Custom {
238                schedule,
239                default_difficulty,
240            } => {
241                self.current_difficulty = schedule
242                    .get(&self.step_count)
243                    .copied()
244                    .unwrap_or(*default_difficulty);
245            }
246        }
247
248        Ok(())
249    }
250
251    /// Set difficulty score for a sample
252    pub fn set_sample_difficulty(&mut self, sampleid: usize, difficulty: f64) {
253        self.sample_difficulties.insert(sampleid, difficulty);
254    }
255
256    /// Check if sample should be included based on current difficulty
257    pub fn should_include_sample(&self, sampleid: usize) -> bool {
258        if let Some(&sample_difficulty) = self.sample_difficulties.get(&sampleid) {
259            sample_difficulty <= self.current_difficulty
260        } else {
261            true // Include unknown samples
262        }
263    }
264
265    /// Get current difficulty level
266    pub fn get_current_difficulty(&self) -> f64 {
267        self.current_difficulty
268    }
269
270    /// Compute importance weights for samples
271    pub fn compute_sample_weights(
272        &mut self,
273        sampleids: &[usize],
274        losses: &[A],
275        gradient_norms: Option<&[A]>,
276        uncertainties: Option<&[A]>,
277    ) -> Result<()> {
278        if sampleids.len() != losses.len() {
279            return Err(OptimError::DimensionMismatch(
280                "Sample IDs and losses must have same length".to_string(),
281            ));
282        }
283
284        match &self.importance_strategy {
285            ImportanceWeightingStrategy::Uniform => {
286                let uniform_weight = A::one();
287                for &sampleid in sampleids {
288                    self.sample_weights.insert(sampleid, uniform_weight);
289                }
290            }
291            ImportanceWeightingStrategy::LossBased {
292                temperature,
293                min_weight,
294            } => {
295                self.compute_loss_based_weights(sampleids, losses, *temperature, *min_weight)?;
296            }
297            ImportanceWeightingStrategy::GradientNormBased {
298                temperature,
299                min_weight,
300            } => {
301                if let Some(grad_norms) = gradient_norms {
302                    self.compute_gradient_norm_weights(
303                        sampleids,
304                        grad_norms,
305                        *temperature,
306                        *min_weight,
307                    )?;
308                } else {
309                    // Fall back to uniform weights
310                    for &sampleid in sampleids {
311                        self.sample_weights.insert(sampleid, A::one());
312                    }
313                }
314            }
315            ImportanceWeightingStrategy::UncertaintyBased {
316                temperature,
317                min_weight,
318            } => {
319                if let Some(uncertainties_array) = uncertainties {
320                    self.compute_uncertainty_weights(
321                        sampleids,
322                        uncertainties_array,
323                        *temperature,
324                        *min_weight,
325                    )?;
326                } else {
327                    // Fall back to uniform weights
328                    for &sampleid in sampleids {
329                        self.sample_weights.insert(sampleid, A::one());
330                    }
331                }
332            }
333            ImportanceWeightingStrategy::AgeBased { decayfactor } => {
334                self.compute_age_based_weights(sampleids, *decayfactor)?;
335            }
336        }
337
338        Ok(())
339    }
340
341    /// Compute loss-based importance weights
342    fn compute_loss_based_weights(
343        &mut self,
344        sampleids: &[usize],
345        losses: &[A],
346        temperature: f64,
347        min_weight: f64,
348    ) -> Result<()> {
349        // Compute softmax weights based on losses
350        let temp = A::from(temperature).expect("unwrap failed");
351        let min_w = A::from(min_weight).expect("unwrap failed");
352
353        // Find max loss for numerical stability
354        let max_loss = losses.iter().fold(A::neg_infinity(), |a, &b| A::max(a, b));
355
356        // Compute unnormalized weights
357        let mut unnormalized_weights = Vec::new();
358        for &loss in losses {
359            let normalized_loss = (loss - max_loss) / temp;
360            unnormalized_weights.push(A::exp(normalized_loss));
361        }
362
363        // Normalize weights
364        let sum_weights: A = unnormalized_weights
365            .iter()
366            .fold(A::zero(), |acc, &w| acc + w);
367
368        for (i, &sampleid) in sampleids.iter().enumerate() {
369            let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
370            self.sample_weights.insert(sampleid, weight);
371        }
372
373        Ok(())
374    }
375
376    /// Compute gradient norm based weights
377    fn compute_gradient_norm_weights(
378        &mut self,
379        sampleids: &[usize],
380        gradient_norms: &[A],
381        temperature: f64,
382        min_weight: f64,
383    ) -> Result<()> {
384        let temp = A::from(temperature).expect("unwrap failed");
385        let min_w = A::from(min_weight).expect("unwrap failed");
386
387        // Find max gradient norm for numerical stability
388        let max_norm = gradient_norms
389            .iter()
390            .fold(A::neg_infinity(), |a, &b| A::max(a, b));
391
392        // Compute softmax weights
393        let mut unnormalized_weights = Vec::new();
394        for &norm in gradient_norms {
395            let normalized_norm = (norm - max_norm) / temp;
396            unnormalized_weights.push(A::exp(normalized_norm));
397        }
398
399        let sum_weights: A = unnormalized_weights
400            .iter()
401            .fold(A::zero(), |acc, &w| acc + w);
402
403        for (i, &sampleid) in sampleids.iter().enumerate() {
404            let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
405            self.sample_weights.insert(sampleid, weight);
406        }
407
408        Ok(())
409    }
410
411    /// Compute uncertainty-based weights
412    fn compute_uncertainty_weights(
413        &mut self,
414        sampleids: &[usize],
415        uncertainties: &[A],
416        temperature: f64,
417        min_weight: f64,
418    ) -> Result<()> {
419        let temp = A::from(temperature).expect("unwrap failed");
420        let min_w = A::from(min_weight).expect("unwrap failed");
421
422        // Find max uncertainty for numerical stability
423        let max_uncertainty = uncertainties
424            .iter()
425            .fold(A::neg_infinity(), |a, &b| A::max(a, b));
426
427        // Compute softmax weights (higher uncertainty = higher weight)
428        let mut unnormalized_weights = Vec::new();
429        for &uncertainty in uncertainties {
430            let normalized_uncertainty = (uncertainty - max_uncertainty) / temp;
431            unnormalized_weights.push(A::exp(normalized_uncertainty));
432        }
433
434        let sum_weights: A = unnormalized_weights
435            .iter()
436            .fold(A::zero(), |acc, &w| acc + w);
437
438        for (i, &sampleid) in sampleids.iter().enumerate() {
439            let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
440            self.sample_weights.insert(sampleid, weight);
441        }
442
443        Ok(())
444    }
445
    /// Compute age-based weights
    ///
    /// Weight = exp(decayfactor * age), where age = step_count - sampleid
    /// (saturating at zero).
    ///
    /// NOTE(review): this uses the sample ID as a proxy for the step at which
    /// the sample was introduced — confirm IDs are assigned chronologically.
    fn compute_age_based_weights(&mut self, sampleids: &[usize], decayfactor: f64) -> Result<()> {
        let decay = A::from(decayfactor).expect("unwrap failed");

        for &sampleid in sampleids {
            // Simple age-based weighting (older samples get exponentially higher weight)
            let age = A::from(self.step_count.saturating_sub(sampleid)).expect("unwrap failed");
            let weight = A::exp(decay * age);
            self.sample_weights.insert(sampleid, weight);
        }

        Ok(())
    }
459
460    /// Get importance weight for a sample
461    pub fn get_sample_weight(&self, sampleid: usize) -> A {
462        self.sample_weights
463            .get(&sampleid)
464            .copied()
465            .unwrap_or_else(|| A::one())
466    }
467
468    /// Generate adversarial examples
469    pub fn generate_adversarial_examples(
470        &self,
471        inputs: &Array<A, D>,
472        gradients: &Array<A, D>,
473    ) -> Result<Array<A, D>> {
474        if let Some(config) = &self.adversarial_config {
475            match config.attack_type {
476                AdversarialAttack::FGSM => self.fgsm_attack(inputs, gradients, config),
477                AdversarialAttack::PGD => self.pgd_attack(inputs, gradients, config),
478                AdversarialAttack::BIM => self.bim_attack(inputs, gradients, config),
479                AdversarialAttack::MIM => self.mim_attack(inputs, gradients, config),
480            }
481        } else {
482            Ok(inputs.clone()) // No adversarial training
483        }
484    }
485
486    /// Fast Gradient Sign Method (FGSM)
487    fn fgsm_attack(
488        &self,
489        inputs: &Array<A, D>,
490        gradients: &Array<A, D>,
491        config: &AdversarialConfig<A>,
492    ) -> Result<Array<A, D>> {
493        let mut adversarial = inputs.clone();
494
495        // Sign of gradients
496        let sign_gradients = gradients.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
497
498        // Add perturbation
499        Zip::from(&mut adversarial)
500            .and(&sign_gradients)
501            .for_each(|x, &sign| {
502                *x = *x + config.epsilon * sign;
503            });
504
505        Ok(adversarial)
506    }
507
508    /// Projected Gradient Descent (PGD)
509    fn pgd_attack(
510        &self,
511        inputs: &Array<A, D>,
512        gradients: &Array<A, D>,
513        config: &AdversarialConfig<A>,
514    ) -> Result<Array<A, D>> {
515        let mut adversarial = inputs.clone();
516
517        // Multiple PGD steps
518        for _ in 0..config.num_steps {
519            // Gradient step
520            let sign_gradients =
521                gradients.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
522
523            Zip::from(&mut adversarial)
524                .and(&sign_gradients)
525                .for_each(|x, &sign| {
526                    *x = *x + config.step_size * sign;
527                });
528
529            // Project back to epsilon ball
530            Zip::from(&mut adversarial)
531                .and(inputs)
532                .for_each(|adv, &orig| {
533                    let diff = *adv - orig;
534                    let clamped_diff = A::max(-config.epsilon, A::min(config.epsilon, diff));
535                    *adv = orig + clamped_diff;
536                });
537        }
538
539        Ok(adversarial)
540    }
541
542    /// Basic Iterative Method (BIM)
543    fn bim_attack(
544        &self,
545        inputs: &Array<A, D>,
546        gradients: &Array<A, D>,
547        config: &AdversarialConfig<A>,
548    ) -> Result<Array<A, D>> {
549        // BIM is similar to PGD but with smaller steps
550        let mut modified_config = config.clone();
551        modified_config.step_size =
552            config.epsilon / A::from(config.num_steps).expect("unwrap failed");
553
554        self.pgd_attack(inputs, gradients, &modified_config)
555    }
556
557    /// Momentum Iterative Method (MIM)
558    fn mim_attack(
559        &self,
560        inputs: &Array<A, D>,
561        gradients: &Array<A, D>,
562        config: &AdversarialConfig<A>,
563    ) -> Result<Array<A, D>> {
564        let mut adversarial = inputs.clone();
565        let mut momentum = Array::zeros(inputs.raw_dim());
566        let decayfactor = A::from(1.0).expect("unwrap failed"); // Momentum decay factor
567
568        for _ in 0..config.num_steps {
569            // Update momentum
570            let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
571            let normalized_gradients = if grad_norm > A::zero() {
572                gradients.mapv(|x| x / grad_norm)
573            } else {
574                gradients.clone()
575            };
576
577            Zip::from(&mut momentum)
578                .and(&normalized_gradients)
579                .for_each(|m, &g| {
580                    *m = decayfactor * *m + g;
581                });
582
583            // Apply momentum-based update
584            let momentum_signs =
585                momentum.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
586
587            Zip::from(&mut adversarial)
588                .and(&momentum_signs)
589                .for_each(|x, &sign| {
590                    *x = *x + config.step_size * sign;
591                });
592
593            // Project back to epsilon ball
594            Zip::from(&mut adversarial)
595                .and(inputs)
596                .for_each(|adv, &orig| {
597                    let diff = *adv - orig;
598                    let clamped_diff = A::max(-config.epsilon, A::min(config.epsilon, diff));
599                    *adv = orig + clamped_diff;
600                });
601        }
602
603        Ok(adversarial)
604    }
605
606    /// Get filtered samples based on current curriculum
607    pub fn filter_samples(&self, sampleids: &[usize]) -> Vec<usize> {
608        sampleids
609            .iter()
610            .copied()
611            .filter(|&id| self.should_include_sample(id))
612            .collect()
613    }
614
    /// Get performance history
    ///
    /// NOTE(review): for strategies other than `PerformanceBased` this
    /// history is never trimmed and grows with every update — confirm
    /// intended.
    pub fn get_performance_history(&self) -> &VecDeque<A> {
        &self.performance_history
    }

    /// Get step count (number of `update_curriculum` calls so far)
    pub fn step_count(&self) -> usize {
        self.step_count
    }
624
625    /// Reset curriculum state
626    pub fn reset(&mut self) {
627        self.step_count = 0;
628        self.performance_history.clear();
629        self.sample_weights.clear();
630        self.current_difficulty = match &self.strategy {
631            CurriculumStrategy::Linear {
632                start_difficulty, ..
633            } => *start_difficulty,
634            CurriculumStrategy::Exponential {
635                start_difficulty, ..
636            } => *start_difficulty,
637            CurriculumStrategy::PerformanceBased { .. } => 0.1,
638            CurriculumStrategy::Custom {
639                default_difficulty, ..
640            } => *default_difficulty,
641        };
642    }
643
644    /// Export curriculum state for analysis
645    pub fn export_state(&self) -> CurriculumState<A> {
646        CurriculumState {
647            current_difficulty: self.current_difficulty,
648            step_count: self.step_count,
649            performance_history: self.performance_history.clone(),
650            sample_weights: self.sample_weights.clone(),
651            has_adversarial: self.adversarial_config.is_some(),
652        }
653    }
654}
655
/// Curriculum state for analysis and visualization
///
/// Owned snapshot produced by `CurriculumManager::export_state`.
#[derive(Debug, Clone)]
pub struct CurriculumState<A: Float> {
    /// Current difficulty level
    pub current_difficulty: f64,
    /// Current step count
    pub step_count: usize,
    /// Performance history
    pub performance_history: VecDeque<A>,
    /// Sample weights keyed by sample ID
    pub sample_weights: HashMap<usize, A>,
    /// Whether adversarial training is enabled
    pub has_adversarial: bool,
}
670
/// Adaptive curriculum that automatically adjusts strategy
///
/// Runs one active `CurriculumManager` at a time and periodically compares
/// average performance across the candidates, switching when another
/// curriculum outperforms the active one by more than `switchthreshold`.
#[derive(Debug)]
pub struct AdaptiveCurriculum<A: Float, D: Dimension> {
    /// Collection of curriculum managers with different strategies
    curricula: Vec<CurriculumManager<A, D>>,
    /// Current active curriculum index
    active_curriculum: usize,
    /// Performance tracking for each curriculum
    /// (only the active curriculum accrues new entries — see `update`)
    curriculum_performance: Vec<VecDeque<A>>,
    /// Switch threshold for changing curriculum
    switchthreshold: A,
    /// Minimum steps before switching
    min_steps_before_switch: usize,
    /// Steps since last switch
    steps_since_switch: usize,
    /// Phantom data for dimension
    // NOTE(review): redundant — `D` already appears in `curricula`; kept to
    // avoid changing the struct layout and constructor.
    _phantom: PhantomData<D>,
}
689
690impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> AdaptiveCurriculum<A, D> {
691    /// Create a new adaptive curriculum
692    pub fn new(curricula: Vec<CurriculumManager<A, D>>, switchthreshold: A) -> Self {
693        let num_curricula = curricula.len();
694        Self {
695            curricula,
696            active_curriculum: 0,
697            curriculum_performance: vec![VecDeque::new(); num_curricula],
698            switchthreshold,
699            min_steps_before_switch: 100,
700            steps_since_switch: 0,
701            _phantom: PhantomData,
702        }
703    }
704
    /// Update with performance and potentially switch curriculum
    ///
    /// Feeds `performance` to the active curriculum and, once
    /// `min_steps_before_switch` updates have accumulated since the last
    /// switch, re-evaluates which curriculum should be active.
    ///
    /// NOTE(review): panics if `curricula` is empty; and only the active
    /// curriculum accrues performance samples, so switch comparisons against
    /// inactive curricula use stale (or zero) averages — confirm intended.
    pub fn update(&mut self, performance: A) -> Result<()> {
        // Update current curriculum
        self.curricula[self.active_curriculum].update_curriculum(performance)?;
        self.curriculum_performance[self.active_curriculum].push_back(performance);
        self.steps_since_switch += 1;

        // Consider switching if enough steps have passed
        if self.steps_since_switch >= self.min_steps_before_switch {
            self.consider_curriculum_switch()?;
        }

        Ok(())
    }
719
720    /// Consider switching to a better performing curriculum
721    fn consider_curriculum_switch(&mut self) -> Result<()> {
722        let current_performance = self.get_average_performance(self.active_curriculum);
723        let mut best_curriculum = self.active_curriculum;
724        let mut best_performance = current_performance;
725
726        // Find best performing curriculum
727        for (i, _) in self.curricula.iter().enumerate() {
728            if i != self.active_curriculum {
729                let perf = self.get_average_performance(i);
730                if perf > best_performance + self.switchthreshold {
731                    best_performance = perf;
732                    best_curriculum = i;
733                }
734            }
735        }
736
737        // Switch if a better curriculum is found
738        if best_curriculum != self.active_curriculum {
739            self.active_curriculum = best_curriculum;
740            self.steps_since_switch = 0;
741        }
742
743        Ok(())
744    }
745
746    /// Get average performance for a curriculum
747    fn get_average_performance(&self, curriculumidx: usize) -> A {
748        let perf_history = &self.curriculum_performance[curriculumidx];
749        if perf_history.is_empty() {
750            A::zero()
751        } else {
752            let sum = perf_history.iter().fold(A::zero(), |acc, &perf| acc + perf);
753            sum / A::from(perf_history.len()).expect("unwrap failed")
754        }
755    }
756
757    /// Get active curriculum manager
758    pub fn active_curriculum(&self) -> &CurriculumManager<A, D> {
759        &self.curricula[self.active_curriculum]
760    }
761
762    /// Get mutable active curriculum manager
763    pub fn active_curriculum_mut(&mut self) -> &mut CurriculumManager<A, D> {
764        &mut self.curricula[self.active_curriculum]
765    }
766
767    /// Get active curriculum index
768    pub fn active_curriculum_index(&self) -> usize {
769        self.active_curriculum
770    }
771
772    /// Get performance comparison across curricula
773    pub fn get_curriculum_comparison(&self) -> Vec<(usize, A)> {
774        (0..self.curricula.len())
775            .map(|i| (i, self.get_average_performance(i)))
776            .collect()
777    }
778}
779
780#[cfg(test)]
781mod tests {
782    use super::*;
783    use approx::assert_relative_eq;
784    use scirs2_core::ndarray::Array1;
785
    #[test]
    fn test_linear_curriculum() {
        let strategy = CurriculumStrategy::Linear {
            start_difficulty: 0.1,
            end_difficulty: 1.0,
            num_steps: 10,
        };

        let importance_strategy = ImportanceWeightingStrategy::Uniform;
        let mut curriculum =
            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);

        // Test initial difficulty
        assert_relative_eq!(curriculum.get_current_difficulty(), 0.1, epsilon = 1e-6);

        // Update curriculum multiple times (5 of the 10 scheduled steps)
        for _ in 0..5 {
            curriculum.update_curriculum(0.8).expect("unwrap failed");
        }

        // Difficulty should have increased but stay within the schedule bounds
        assert!(curriculum.get_current_difficulty() > 0.1);
        assert!(curriculum.get_current_difficulty() <= 1.0);
    }

    #[test]
    fn test_performance_based_curriculum() {
        let strategy = CurriculumStrategy::PerformanceBased {
            advance_threshold: 0.8,
            reduce_threshold: 0.4,
            adjustment_step: 0.1,
            window_size: 3,
        };

        let importance_strategy = ImportanceWeightingStrategy::Uniform;
        let mut curriculum =
            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);

        let initial_difficulty = curriculum.get_current_difficulty();

        // Simulate good performance above advance_threshold
        // (should increase difficulty once the window fills)
        for _ in 0..5 {
            curriculum.update_curriculum(0.9).expect("unwrap failed");
        }

        assert!(curriculum.get_current_difficulty() > initial_difficulty);
    }

    #[test]
    fn test_sample_filtering() {
        // Constant difficulty of 0.5 throughout
        let strategy = CurriculumStrategy::Linear {
            start_difficulty: 0.5,
            end_difficulty: 0.5,
            num_steps: 10,
        };

        let importance_strategy = ImportanceWeightingStrategy::Uniform;
        let mut curriculum =
            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);

        // Set sample difficulties
        curriculum.set_sample_difficulty(1, 0.3); // Easy
        curriculum.set_sample_difficulty(2, 0.7); // Hard
        curriculum.set_sample_difficulty(3, 0.5); // Medium

        let sampleids = vec![1, 2, 3, 4]; // 4 has no difficulty set
        let filtered = curriculum.filter_samples(&sampleids);

        // Should include samples 1, 3, 4 (difficulty <= 0.5 or unknown)
        assert_eq!(filtered.len(), 3);
        assert!(filtered.contains(&1));
        assert!(filtered.contains(&3));
        assert!(filtered.contains(&4));
        assert!(!filtered.contains(&2));
    }

    #[test]
    fn test_loss_based_importance_weighting() {
        let strategy = CurriculumStrategy::Linear {
            start_difficulty: 0.5,
            end_difficulty: 0.5,
            num_steps: 10,
        };

        let importance_strategy = ImportanceWeightingStrategy::LossBased {
            temperature: 1.0,
            min_weight: 0.1,
        };

        let mut curriculum =
            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);

        let sampleids = vec![1, 2, 3];
        let losses = vec![0.1, 1.0, 0.5]; // Low, high, medium loss

        curriculum
            .compute_sample_weights(&sampleids, &losses, None, None)
            .expect("unwrap failed");

        // Sample with highest loss should have highest weight
        // (softmax over losses is monotonically increasing in the loss)
        let weight1 = curriculum.get_sample_weight(1);
        let weight2 = curriculum.get_sample_weight(2);
        let weight3 = curriculum.get_sample_weight(3);

        assert!(weight2 > weight3); // High loss > medium loss
        assert!(weight3 > weight1); // Medium loss > low loss
    }
893
894    #[test]
895    fn test_adversarial_config() {
896        let strategy = CurriculumStrategy::Linear {
897            start_difficulty: 0.5,
898            end_difficulty: 0.5,
899            num_steps: 10,
900        };
901
902        let importance_strategy = ImportanceWeightingStrategy::Uniform;
903        let mut curriculum =
904            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);
905
906        let adversarial_config = AdversarialConfig {
907            epsilon: 0.1,
908            num_steps: 5,
909            step_size: 0.02,
910            attack_type: AdversarialAttack::FGSM,
911            adversarial_weight: 0.5,
912        };
913
914        curriculum.enable_adversarial_training(adversarial_config);
915
916        let inputs = Array1::from_vec(vec![1.0, 2.0, 3.0]);
917        let gradients = Array1::from_vec(vec![0.1, -0.2, 0.3]);
918
919        let adversarial = curriculum
920            .generate_adversarial_examples(&inputs, &gradients)
921            .expect("unwrap failed");
922
923        // Adversarial examples should be different from original
924        assert_ne!(
925            adversarial.as_slice().expect("unwrap failed"),
926            inputs.as_slice().expect("unwrap failed")
927        );
928
929        // Check that perturbation is bounded
930        for (orig, adv) in inputs.iter().zip(adversarial.iter()) {
931            assert!((adv - orig).abs() <= 0.1 + 1e-6); // epsilon + small tolerance
932        }
933    }
934
935    #[test]
936    fn test_adaptive_curriculum() {
937        let strategy1 = CurriculumStrategy::Linear {
938            start_difficulty: 0.1,
939            end_difficulty: 0.5,
940            num_steps: 100,
941        };
942
943        let strategy2 = CurriculumStrategy::Linear {
944            start_difficulty: 0.2,
945            end_difficulty: 0.8,
946            num_steps: 100,
947        };
948
949        let importance_strategy = ImportanceWeightingStrategy::Uniform;
950        let curriculum1 = CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(
951            strategy1,
952            importance_strategy.clone(),
953        );
954        let curriculum2 = CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(
955            strategy2,
956            importance_strategy,
957        );
958
959        let mut adaptive = AdaptiveCurriculum::new(vec![curriculum1, curriculum2], 0.1);
960
961        assert_eq!(adaptive.active_curriculum_index(), 0);
962
963        // Update with some performance values
964        for _ in 0..150 {
965            adaptive.update(0.7).expect("unwrap failed");
966        }
967
968        // Should potentially have switched curriculum
969        let comparison = adaptive.get_curriculum_comparison();
970        assert_eq!(comparison.len(), 2);
971    }
972
973    #[test]
974    fn test_curriculum_state_export() {
975        let strategy = CurriculumStrategy::Linear {
976            start_difficulty: 0.1,
977            end_difficulty: 1.0,
978            num_steps: 10,
979        };
980
981        let importance_strategy = ImportanceWeightingStrategy::Uniform;
982        let mut curriculum =
983            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);
984
985        curriculum.update_curriculum(0.8).expect("unwrap failed");
986        let state = curriculum.export_state();
987
988        assert_eq!(state.step_count, 1);
989        assert_eq!(state.performance_history.len(), 1);
990        assert!(!state.has_adversarial);
991    }
992}