optirs_core/curriculum_optimization/
mod.rs

1// Curriculum optimization for adaptive training
2//
3// This module provides curriculum learning capabilities including task difficulty progression,
4// sample importance weighting, and adversarial training support.
5
6use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11use std::marker::PhantomData;
12
/// Curriculum learning strategy
///
/// Determines how `CurriculumManager::update_curriculum` moves the current
/// difficulty after each step. Difficulties are conventionally in `[0.0, 1.0]`;
/// only the performance-based strategy clamps its result to that range.
#[derive(Debug, Clone, PartialEq)]
pub enum CurriculumStrategy {
    /// Linear difficulty progression
    ///
    /// Interpolates linearly from `start_difficulty` to `end_difficulty` over
    /// `num_steps` updates, then holds at `end_difficulty`.
    Linear {
        /// Starting difficulty (0.0 to 1.0)
        start_difficulty: f64,
        /// Ending difficulty (0.0 to 1.0)
        end_difficulty: f64,
        /// Number of steps to reach end difficulty
        num_steps: usize,
    },
    /// Exponential difficulty progression
    ///
    /// `difficulty = start + (1 - exp(-growth_rate * step)) * (end - start)`, which
    /// approaches `end_difficulty` asymptotically for a positive `growth_rate`.
    Exponential {
        /// Starting difficulty (0.0 to 1.0)
        start_difficulty: f64,
        /// Ending difficulty (0.0 to 1.0)
        end_difficulty: f64,
        /// Growth rate
        growth_rate: f64,
    },
    /// Performance-based curriculum
    ///
    /// Starts at difficulty `0.1`. Once `window_size` performance values have been
    /// recorded, every update compares the mean of the last `window_size` values
    /// against the thresholds. A mean above `advance_threshold` raises the
    /// difficulty by `adjustment_step`, and a mean below `reduce_threshold` lowers
    /// it. The result is clamped to `[0.0, 1.0]`. Higher performance is treated as
    /// better.
    PerformanceBased {
        /// Threshold for advancing difficulty
        advance_threshold: f64,
        /// Threshold for reducing difficulty
        reduce_threshold: f64,
        /// Difficulty adjustment step size
        adjustment_step: f64,
        /// Window size for performance averaging
        window_size: usize,
    },
    /// Custom curriculum with predefined schedule
    ///
    /// At each step the difficulty is looked up by exact step number. Steps absent
    /// from `schedule` use `default_difficulty`; the previously scheduled value is
    /// *not* carried forward. `default_difficulty` is also the starting difficulty.
    Custom {
        /// Difficulty schedule (step -> difficulty)
        schedule: HashMap<usize, f64>,
        /// Default difficulty for unspecified steps
        default_difficulty: f64,
    },
}
53
/// Sample importance weighting strategy
///
/// Selects how `CurriculumManager::compute_sample_weights` assigns a weight to each
/// sample in a batch. The softmax-based variants compute
/// `max(min_weight, softmax(score / temperature))` over the batch. Weights are not
/// renormalised after the `min_weight` clamp, so they need not sum to 1.
#[derive(Debug, Clone, PartialEq)]
pub enum ImportanceWeightingStrategy {
    /// Uniform weighting (all samples get weight 1)
    Uniform,
    /// Loss-based weighting (higher loss = higher weight)
    ///
    /// Softmax over the batch's losses.
    LossBased {
        /// Temperature parameter for softmax weighting (larger = flatter weights)
        temperature: f64,
        /// Minimum weight to avoid zero weights
        min_weight: f64,
    },
    /// Gradient norm based weighting
    ///
    /// Softmax over per-sample gradient norms; falls back to uniform weights when
    /// no gradient norms are supplied.
    GradientNormBased {
        /// Temperature parameter
        temperature: f64,
        /// Minimum weight
        min_weight: f64,
    },
    /// Uncertainty-based weighting
    ///
    /// Softmax over per-sample uncertainties; falls back to uniform weights when no
    /// uncertainties are supplied.
    UncertaintyBased {
        /// Temperature parameter
        temperature: f64,
        /// Minimum weight
        min_weight: f64,
    },
    /// Age-based weighting (older samples get higher weight for a positive factor)
    ///
    /// `weight = exp(decayfactor * age)`, where `age` is the manager's current step
    /// count minus the sample ID, saturating at 0.
    /// NOTE(review): this treats sample IDs as the step at which each sample was
    /// introduced — confirm that callers assign IDs that way.
    AgeBased {
        /// Decay factor for age
        decayfactor: f64,
    },
}
86
/// Adversarial training configuration
///
/// Consumed by `CurriculumManager::generate_adversarial_examples` once enabled via
/// `CurriculumManager::enable_adversarial_training`.
#[derive(Debug, Clone)]
pub struct AdversarialConfig<A: Float> {
    /// Adversarial perturbation magnitude: the per-element (L-infinity) radius
    /// around the clean input. FGSM uses it directly as its step size; the
    /// iterative attacks project back into this radius after every step.
    pub epsilon: A,
    /// Number of adversarial steps for the iterative attacks (PGD, BIM, MIM);
    /// FGSM ignores it.
    pub num_steps: usize,
    /// Step size for adversarial perturbation, used by PGD and MIM. BIM replaces it
    /// with `epsilon / num_steps`; FGSM ignores it.
    pub step_size: A,
    /// Type of adversarial attack to run
    pub attack_type: AdversarialAttack,
    /// Regularization weight for adversarial loss.
    /// NOTE(review): not read by the code in this module; presumably consumed by the
    /// training loop that combines clean and adversarial losses — confirm.
    pub adversarial_weight: A,
}
101
/// Types of adversarial attacks
///
/// Fieldless and `Copy`. It also derives `PartialEq`/`Eq`/`Hash`, so attack types
/// can be compared with `==` or used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdversarialAttack {
    /// Fast Gradient Sign Method (FGSM): a single step of size `epsilon` along the
    /// gradient sign.
    FGSM,
    /// Projected Gradient Descent (PGD): `num_steps` sign steps of size
    /// `step_size`, each followed by projection back onto the `epsilon` ball.
    PGD,
    /// Basic Iterative Method (BIM): PGD with the step size replaced by
    /// `epsilon / num_steps`.
    BIM,
    /// Momentum Iterative Method (MIM): like PGD, but each step follows the sign of
    /// an accumulated momentum of the L2-normalised gradient.
    MIM,
}
114
/// Curriculum learning manager
///
/// Tracks:
/// - a difficulty level that evolves according to a [`CurriculumStrategy`];
/// - per-sample difficulty scores used to filter samples;
/// - per-sample importance weights computed by an [`ImportanceWeightingStrategy`];
/// - an optional adversarial-example configuration.
///
/// `D` is the dimensionality of the arrays passed to adversarial generation; no
/// `D`-typed value is stored.
#[derive(Debug)]
pub struct CurriculumManager<A: Float, D: Dimension> {
    /// Curriculum strategy
    strategy: CurriculumStrategy,
    /// Current difficulty level; samples whose difficulty score exceeds it are
    /// filtered out
    current_difficulty: f64,
    /// Number of curriculum updates since construction or the last reset
    step_count: usize,
    /// Performance history, oldest first; trimmed to the averaging window only by
    /// the performance-based strategy
    performance_history: VecDeque<A>,
    /// Sample difficulty scores, keyed by sample ID
    sample_difficulties: HashMap<usize, f64>,
    /// Importance weighting strategy
    importance_strategy: ImportanceWeightingStrategy,
    /// Most recently computed importance weight per sample ID
    sample_weights: HashMap<usize, A>,
    /// Adversarial training configuration; `None` means adversarial generation
    /// returns the inputs unchanged
    adversarial_config: Option<AdversarialConfig<A>>,
    /// Phantom data for dimension
    _phantom: PhantomData<D>,
}
137
138impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> CurriculumManager<A, D> {
139    /// Create a new curriculum manager
140    pub fn new(
141        strategy: CurriculumStrategy,
142        importance_strategy: ImportanceWeightingStrategy,
143    ) -> Self {
144        let initial_difficulty = match &strategy {
145            CurriculumStrategy::Linear {
146                start_difficulty, ..
147            } => *start_difficulty,
148            CurriculumStrategy::Exponential {
149                start_difficulty, ..
150            } => *start_difficulty,
151            CurriculumStrategy::PerformanceBased { .. } => 0.1, // Start easy
152            CurriculumStrategy::Custom {
153                default_difficulty, ..
154            } => *default_difficulty,
155        };
156
157        Self {
158            strategy,
159            current_difficulty: initial_difficulty,
160            step_count: 0,
161            performance_history: VecDeque::new(),
162            sample_difficulties: HashMap::new(),
163            importance_strategy,
164            sample_weights: HashMap::new(),
165            adversarial_config: None,
166            _phantom: PhantomData,
167        }
168    }
169
170    /// Enable adversarial training
171    pub fn enable_adversarial_training(&mut self, config: AdversarialConfig<A>) {
172        self.adversarial_config = Some(config);
173    }
174
175    /// Disable adversarial training
176    pub fn disable_adversarial_training(&mut self) {
177        self.adversarial_config = None;
178    }
179
180    /// Update curriculum based on performance
181    pub fn update_curriculum(&mut self, performance: A) -> Result<()> {
182        self.performance_history.push_back(performance);
183        self.step_count += 1;
184
185        // Update difficulty based on strategy
186        match &self.strategy {
187            CurriculumStrategy::Linear {
188                start_difficulty,
189                end_difficulty,
190                num_steps,
191            } => {
192                let progress = (self.step_count as f64) / (*num_steps as f64);
193                let progress = progress.min(1.0);
194                self.current_difficulty =
195                    start_difficulty + progress * (end_difficulty - start_difficulty);
196            }
197            CurriculumStrategy::Exponential {
198                start_difficulty,
199                end_difficulty,
200                growth_rate,
201            } => {
202                let progress = 1.0 - (-growth_rate * self.step_count as f64).exp();
203                self.current_difficulty =
204                    start_difficulty + progress * (end_difficulty - start_difficulty);
205            }
206            CurriculumStrategy::PerformanceBased {
207                advance_threshold,
208                reduce_threshold,
209                adjustment_step,
210                window_size,
211            } => {
212                if self.performance_history.len() >= *window_size {
213                    // Keep only recent performance
214                    while self.performance_history.len() > *window_size {
215                        self.performance_history.pop_front();
216                    }
217
218                    // Calculate average performance
219                    let avg_performance = self
220                        .performance_history
221                        .iter()
222                        .fold(A::zero(), |acc, &perf| acc + perf)
223                        / A::from(self.performance_history.len()).unwrap();
224
225                    let avg_perf_f64 = avg_performance.to_f64().unwrap_or(0.0);
226
227                    // Adjust difficulty based on performance
228                    if avg_perf_f64 > *advance_threshold {
229                        self.current_difficulty =
230                            (self.current_difficulty + adjustment_step).min(1.0);
231                    } else if avg_perf_f64 < *reduce_threshold {
232                        self.current_difficulty =
233                            (self.current_difficulty - adjustment_step).max(0.0);
234                    }
235                }
236            }
237            CurriculumStrategy::Custom {
238                schedule,
239                default_difficulty,
240            } => {
241                self.current_difficulty = schedule
242                    .get(&self.step_count)
243                    .copied()
244                    .unwrap_or(*default_difficulty);
245            }
246        }
247
248        Ok(())
249    }
250
251    /// Set difficulty score for a sample
252    pub fn set_sample_difficulty(&mut self, sampleid: usize, difficulty: f64) {
253        self.sample_difficulties.insert(sampleid, difficulty);
254    }
255
256    /// Check if sample should be included based on current difficulty
257    pub fn should_include_sample(&self, sampleid: usize) -> bool {
258        if let Some(&sample_difficulty) = self.sample_difficulties.get(&sampleid) {
259            sample_difficulty <= self.current_difficulty
260        } else {
261            true // Include unknown samples
262        }
263    }
264
265    /// Get current difficulty level
266    pub fn get_current_difficulty(&self) -> f64 {
267        self.current_difficulty
268    }
269
270    /// Compute importance weights for samples
271    pub fn compute_sample_weights(
272        &mut self,
273        sampleids: &[usize],
274        losses: &[A],
275        gradient_norms: Option<&[A]>,
276        uncertainties: Option<&[A]>,
277    ) -> Result<()> {
278        if sampleids.len() != losses.len() {
279            return Err(OptimError::DimensionMismatch(
280                "Sample IDs and losses must have same length".to_string(),
281            ));
282        }
283
284        match &self.importance_strategy {
285            ImportanceWeightingStrategy::Uniform => {
286                let uniform_weight = A::one();
287                for &sampleid in sampleids {
288                    self.sample_weights.insert(sampleid, uniform_weight);
289                }
290            }
291            ImportanceWeightingStrategy::LossBased {
292                temperature,
293                min_weight,
294            } => {
295                self.compute_loss_based_weights(sampleids, losses, *temperature, *min_weight)?;
296            }
297            ImportanceWeightingStrategy::GradientNormBased {
298                temperature,
299                min_weight,
300            } => {
301                if let Some(grad_norms) = gradient_norms {
302                    self.compute_gradient_norm_weights(
303                        sampleids,
304                        grad_norms,
305                        *temperature,
306                        *min_weight,
307                    )?;
308                } else {
309                    // Fall back to uniform weights
310                    for &sampleid in sampleids {
311                        self.sample_weights.insert(sampleid, A::one());
312                    }
313                }
314            }
315            ImportanceWeightingStrategy::UncertaintyBased {
316                temperature,
317                min_weight,
318            } => {
319                if let Some(uncertainties_array) = uncertainties {
320                    self.compute_uncertainty_weights(
321                        sampleids,
322                        uncertainties_array,
323                        *temperature,
324                        *min_weight,
325                    )?;
326                } else {
327                    // Fall back to uniform weights
328                    for &sampleid in sampleids {
329                        self.sample_weights.insert(sampleid, A::one());
330                    }
331                }
332            }
333            ImportanceWeightingStrategy::AgeBased { decayfactor } => {
334                self.compute_age_based_weights(sampleids, *decayfactor)?;
335            }
336        }
337
338        Ok(())
339    }
340
341    /// Compute loss-based importance weights
342    fn compute_loss_based_weights(
343        &mut self,
344        sampleids: &[usize],
345        losses: &[A],
346        temperature: f64,
347        min_weight: f64,
348    ) -> Result<()> {
349        // Compute softmax weights based on losses
350        let temp = A::from(temperature).unwrap();
351        let min_w = A::from(min_weight).unwrap();
352
353        // Find max loss for numerical stability
354        let max_loss = losses.iter().fold(A::neg_infinity(), |a, &b| A::max(a, b));
355
356        // Compute unnormalized weights
357        let mut unnormalized_weights = Vec::new();
358        for &loss in losses {
359            let normalized_loss = (loss - max_loss) / temp;
360            unnormalized_weights.push(A::exp(normalized_loss));
361        }
362
363        // Normalize weights
364        let sum_weights: A = unnormalized_weights
365            .iter()
366            .fold(A::zero(), |acc, &w| acc + w);
367
368        for (i, &sampleid) in sampleids.iter().enumerate() {
369            let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
370            self.sample_weights.insert(sampleid, weight);
371        }
372
373        Ok(())
374    }
375
376    /// Compute gradient norm based weights
377    fn compute_gradient_norm_weights(
378        &mut self,
379        sampleids: &[usize],
380        gradient_norms: &[A],
381        temperature: f64,
382        min_weight: f64,
383    ) -> Result<()> {
384        let temp = A::from(temperature).unwrap();
385        let min_w = A::from(min_weight).unwrap();
386
387        // Find max gradient norm for numerical stability
388        let max_norm = gradient_norms
389            .iter()
390            .fold(A::neg_infinity(), |a, &b| A::max(a, b));
391
392        // Compute softmax weights
393        let mut unnormalized_weights = Vec::new();
394        for &norm in gradient_norms {
395            let normalized_norm = (norm - max_norm) / temp;
396            unnormalized_weights.push(A::exp(normalized_norm));
397        }
398
399        let sum_weights: A = unnormalized_weights
400            .iter()
401            .fold(A::zero(), |acc, &w| acc + w);
402
403        for (i, &sampleid) in sampleids.iter().enumerate() {
404            let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
405            self.sample_weights.insert(sampleid, weight);
406        }
407
408        Ok(())
409    }
410
411    /// Compute uncertainty-based weights
412    fn compute_uncertainty_weights(
413        &mut self,
414        sampleids: &[usize],
415        uncertainties: &[A],
416        temperature: f64,
417        min_weight: f64,
418    ) -> Result<()> {
419        let temp = A::from(temperature).unwrap();
420        let min_w = A::from(min_weight).unwrap();
421
422        // Find max uncertainty for numerical stability
423        let max_uncertainty = uncertainties
424            .iter()
425            .fold(A::neg_infinity(), |a, &b| A::max(a, b));
426
427        // Compute softmax weights (higher uncertainty = higher weight)
428        let mut unnormalized_weights = Vec::new();
429        for &uncertainty in uncertainties {
430            let normalized_uncertainty = (uncertainty - max_uncertainty) / temp;
431            unnormalized_weights.push(A::exp(normalized_uncertainty));
432        }
433
434        let sum_weights: A = unnormalized_weights
435            .iter()
436            .fold(A::zero(), |acc, &w| acc + w);
437
438        for (i, &sampleid) in sampleids.iter().enumerate() {
439            let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
440            self.sample_weights.insert(sampleid, weight);
441        }
442
443        Ok(())
444    }
445
446    /// Compute age-based weights
447    fn compute_age_based_weights(&mut self, sampleids: &[usize], decayfactor: f64) -> Result<()> {
448        let decay = A::from(decayfactor).unwrap();
449
450        for &sampleid in sampleids {
451            // Simple age-based weighting (older samples get exponentially higher weight)
452            let age = A::from(self.step_count.saturating_sub(sampleid)).unwrap();
453            let weight = A::exp(decay * age);
454            self.sample_weights.insert(sampleid, weight);
455        }
456
457        Ok(())
458    }
459
460    /// Get importance weight for a sample
461    pub fn get_sample_weight(&self, sampleid: usize) -> A {
462        self.sample_weights
463            .get(&sampleid)
464            .copied()
465            .unwrap_or_else(|| A::one())
466    }
467
468    /// Generate adversarial examples
469    pub fn generate_adversarial_examples(
470        &self,
471        inputs: &Array<A, D>,
472        gradients: &Array<A, D>,
473    ) -> Result<Array<A, D>> {
474        if let Some(config) = &self.adversarial_config {
475            match config.attack_type {
476                AdversarialAttack::FGSM => self.fgsm_attack(inputs, gradients, config),
477                AdversarialAttack::PGD => self.pgd_attack(inputs, gradients, config),
478                AdversarialAttack::BIM => self.bim_attack(inputs, gradients, config),
479                AdversarialAttack::MIM => self.mim_attack(inputs, gradients, config),
480            }
481        } else {
482            Ok(inputs.clone()) // No adversarial training
483        }
484    }
485
486    /// Fast Gradient Sign Method (FGSM)
487    fn fgsm_attack(
488        &self,
489        inputs: &Array<A, D>,
490        gradients: &Array<A, D>,
491        config: &AdversarialConfig<A>,
492    ) -> Result<Array<A, D>> {
493        let mut adversarial = inputs.clone();
494
495        // Sign of gradients
496        let sign_gradients = gradients.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
497
498        // Add perturbation
499        Zip::from(&mut adversarial)
500            .and(&sign_gradients)
501            .for_each(|x, &sign| {
502                *x = *x + config.epsilon * sign;
503            });
504
505        Ok(adversarial)
506    }
507
508    /// Projected Gradient Descent (PGD)
509    fn pgd_attack(
510        &self,
511        inputs: &Array<A, D>,
512        gradients: &Array<A, D>,
513        config: &AdversarialConfig<A>,
514    ) -> Result<Array<A, D>> {
515        let mut adversarial = inputs.clone();
516
517        // Multiple PGD steps
518        for _ in 0..config.num_steps {
519            // Gradient step
520            let sign_gradients =
521                gradients.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
522
523            Zip::from(&mut adversarial)
524                .and(&sign_gradients)
525                .for_each(|x, &sign| {
526                    *x = *x + config.step_size * sign;
527                });
528
529            // Project back to epsilon ball
530            Zip::from(&mut adversarial)
531                .and(inputs)
532                .for_each(|adv, &orig| {
533                    let diff = *adv - orig;
534                    let clamped_diff = A::max(-config.epsilon, A::min(config.epsilon, diff));
535                    *adv = orig + clamped_diff;
536                });
537        }
538
539        Ok(adversarial)
540    }
541
542    /// Basic Iterative Method (BIM)
543    fn bim_attack(
544        &self,
545        inputs: &Array<A, D>,
546        gradients: &Array<A, D>,
547        config: &AdversarialConfig<A>,
548    ) -> Result<Array<A, D>> {
549        // BIM is similar to PGD but with smaller steps
550        let mut modified_config = config.clone();
551        modified_config.step_size = config.epsilon / A::from(config.num_steps).unwrap();
552
553        self.pgd_attack(inputs, gradients, &modified_config)
554    }
555
556    /// Momentum Iterative Method (MIM)
557    fn mim_attack(
558        &self,
559        inputs: &Array<A, D>,
560        gradients: &Array<A, D>,
561        config: &AdversarialConfig<A>,
562    ) -> Result<Array<A, D>> {
563        let mut adversarial = inputs.clone();
564        let mut momentum = Array::zeros(inputs.raw_dim());
565        let decayfactor = A::from(1.0).unwrap(); // Momentum decay factor
566
567        for _ in 0..config.num_steps {
568            // Update momentum
569            let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
570            let normalized_gradients = if grad_norm > A::zero() {
571                gradients.mapv(|x| x / grad_norm)
572            } else {
573                gradients.clone()
574            };
575
576            Zip::from(&mut momentum)
577                .and(&normalized_gradients)
578                .for_each(|m, &g| {
579                    *m = decayfactor * *m + g;
580                });
581
582            // Apply momentum-based update
583            let momentum_signs =
584                momentum.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
585
586            Zip::from(&mut adversarial)
587                .and(&momentum_signs)
588                .for_each(|x, &sign| {
589                    *x = *x + config.step_size * sign;
590                });
591
592            // Project back to epsilon ball
593            Zip::from(&mut adversarial)
594                .and(inputs)
595                .for_each(|adv, &orig| {
596                    let diff = *adv - orig;
597                    let clamped_diff = A::max(-config.epsilon, A::min(config.epsilon, diff));
598                    *adv = orig + clamped_diff;
599                });
600        }
601
602        Ok(adversarial)
603    }
604
605    /// Get filtered samples based on current curriculum
606    pub fn filter_samples(&self, sampleids: &[usize]) -> Vec<usize> {
607        sampleids
608            .iter()
609            .copied()
610            .filter(|&id| self.should_include_sample(id))
611            .collect()
612    }
613
614    /// Get performance history
615    pub fn get_performance_history(&self) -> &VecDeque<A> {
616        &self.performance_history
617    }
618
619    /// Get step count
620    pub fn step_count(&self) -> usize {
621        self.step_count
622    }
623
624    /// Reset curriculum state
625    pub fn reset(&mut self) {
626        self.step_count = 0;
627        self.performance_history.clear();
628        self.sample_weights.clear();
629        self.current_difficulty = match &self.strategy {
630            CurriculumStrategy::Linear {
631                start_difficulty, ..
632            } => *start_difficulty,
633            CurriculumStrategy::Exponential {
634                start_difficulty, ..
635            } => *start_difficulty,
636            CurriculumStrategy::PerformanceBased { .. } => 0.1,
637            CurriculumStrategy::Custom {
638                default_difficulty, ..
639            } => *default_difficulty,
640        };
641    }
642
643    /// Export curriculum state for analysis
644    pub fn export_state(&self) -> CurriculumState<A> {
645        CurriculumState {
646            current_difficulty: self.current_difficulty,
647            step_count: self.step_count,
648            performance_history: self.performance_history.clone(),
649            sample_weights: self.sample_weights.clone(),
650            has_adversarial: self.adversarial_config.is_some(),
651        }
652    }
653}
654
/// Curriculum state for analysis and visualization
///
/// Snapshot produced by `CurriculumManager::export_state`. It owns clones of the
/// manager's data, so later updates to the manager do not change it.
#[derive(Debug, Clone)]
pub struct CurriculumState<A: Float> {
    /// Current difficulty level at the time of export
    pub current_difficulty: f64,
    /// Number of curriculum updates at the time of export
    pub step_count: usize,
    /// Performance history, oldest first
    pub performance_history: VecDeque<A>,
    /// Importance weight per sample ID
    pub sample_weights: HashMap<usize, A>,
    /// Whether adversarial training is enabled
    pub has_adversarial: bool,
}
669
/// Adaptive curriculum that automatically adjusts strategy
///
/// Holds several [`CurriculumManager`]s and routes every update to exactly one
/// active manager. It switches to another manager when that one's recorded average
/// performance is higher by more than `switchthreshold`.
#[derive(Debug)]
pub struct AdaptiveCurriculum<A: Float, D: Dimension> {
    /// Collection of curriculum managers with different strategies
    curricula: Vec<CurriculumManager<A, D>>,
    /// Index into `curricula` of the manager currently receiving updates
    active_curriculum: usize,
    /// Performance values recorded per curriculum (same indexing as `curricula`).
    /// Only the active curriculum's entry grows, and entries are never trimmed.
    curriculum_performance: Vec<VecDeque<A>>,
    /// Margin by which another curriculum's average performance must exceed the
    /// active one's before switching
    switchthreshold: A,
    /// Minimum number of updates on the active curriculum before a switch is
    /// considered
    min_steps_before_switch: usize,
    /// Updates since construction or the last switch
    steps_since_switch: usize,
    /// Phantom data for dimension
    /// NOTE(review): `curricula` already mentions `D`, so this marker looks redundant.
    _phantom: PhantomData<D>,
}
688
689impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> AdaptiveCurriculum<A, D> {
690    /// Create a new adaptive curriculum
691    pub fn new(curricula: Vec<CurriculumManager<A, D>>, switchthreshold: A) -> Self {
692        let num_curricula = curricula.len();
693        Self {
694            curricula,
695            active_curriculum: 0,
696            curriculum_performance: vec![VecDeque::new(); num_curricula],
697            switchthreshold,
698            min_steps_before_switch: 100,
699            steps_since_switch: 0,
700            _phantom: PhantomData,
701        }
702    }
703
704    /// Update with performance and potentially switch curriculum
705    pub fn update(&mut self, performance: A) -> Result<()> {
706        // Update current curriculum
707        self.curricula[self.active_curriculum].update_curriculum(performance)?;
708        self.curriculum_performance[self.active_curriculum].push_back(performance);
709        self.steps_since_switch += 1;
710
711        // Consider switching if enough steps have passed
712        if self.steps_since_switch >= self.min_steps_before_switch {
713            self.consider_curriculum_switch()?;
714        }
715
716        Ok(())
717    }
718
719    /// Consider switching to a better performing curriculum
720    fn consider_curriculum_switch(&mut self) -> Result<()> {
721        let current_performance = self.get_average_performance(self.active_curriculum);
722        let mut best_curriculum = self.active_curriculum;
723        let mut best_performance = current_performance;
724
725        // Find best performing curriculum
726        for (i, _) in self.curricula.iter().enumerate() {
727            if i != self.active_curriculum {
728                let perf = self.get_average_performance(i);
729                if perf > best_performance + self.switchthreshold {
730                    best_performance = perf;
731                    best_curriculum = i;
732                }
733            }
734        }
735
736        // Switch if a better curriculum is found
737        if best_curriculum != self.active_curriculum {
738            self.active_curriculum = best_curriculum;
739            self.steps_since_switch = 0;
740        }
741
742        Ok(())
743    }
744
745    /// Get average performance for a curriculum
746    fn get_average_performance(&self, curriculumidx: usize) -> A {
747        let perf_history = &self.curriculum_performance[curriculumidx];
748        if perf_history.is_empty() {
749            A::zero()
750        } else {
751            let sum = perf_history.iter().fold(A::zero(), |acc, &perf| acc + perf);
752            sum / A::from(perf_history.len()).unwrap()
753        }
754    }
755
756    /// Get active curriculum manager
757    pub fn active_curriculum(&self) -> &CurriculumManager<A, D> {
758        &self.curricula[self.active_curriculum]
759    }
760
761    /// Get mutable active curriculum manager
762    pub fn active_curriculum_mut(&mut self) -> &mut CurriculumManager<A, D> {
763        &mut self.curricula[self.active_curriculum]
764    }
765
766    /// Get active curriculum index
767    pub fn active_curriculum_index(&self) -> usize {
768        self.active_curriculum
769    }
770
771    /// Get performance comparison across curricula
772    pub fn get_curriculum_comparison(&self) -> Vec<(usize, A)> {
773        (0..self.curricula.len())
774            .map(|i| (i, self.get_average_performance(i)))
775            .collect()
776    }
777}
778
779#[cfg(test)]
780mod tests {
781    use super::*;
782    use approx::assert_relative_eq;
783    use scirs2_core::ndarray::Array1;
784
785    #[test]
786    fn test_linear_curriculum() {
787        let strategy = CurriculumStrategy::Linear {
788            start_difficulty: 0.1,
789            end_difficulty: 1.0,
790            num_steps: 10,
791        };
792
793        let importance_strategy = ImportanceWeightingStrategy::Uniform;
794        let mut curriculum =
795            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);
796
797        // Test initial difficulty
798        assert_relative_eq!(curriculum.get_current_difficulty(), 0.1, epsilon = 1e-6);
799
800        // Update curriculum multiple times
801        for _ in 0..5 {
802            curriculum.update_curriculum(0.8).unwrap();
803        }
804
805        // Difficulty should have increased
806        assert!(curriculum.get_current_difficulty() > 0.1);
807        assert!(curriculum.get_current_difficulty() <= 1.0);
808    }
809
810    #[test]
811    fn test_performance_based_curriculum() {
812        let strategy = CurriculumStrategy::PerformanceBased {
813            advance_threshold: 0.8,
814            reduce_threshold: 0.4,
815            adjustment_step: 0.1,
816            window_size: 3,
817        };
818
819        let importance_strategy = ImportanceWeightingStrategy::Uniform;
820        let mut curriculum =
821            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);
822
823        let initial_difficulty = curriculum.get_current_difficulty();
824
825        // Simulate good performance (should increase difficulty)
826        for _ in 0..5 {
827            curriculum.update_curriculum(0.9).unwrap();
828        }
829
830        assert!(curriculum.get_current_difficulty() > initial_difficulty);
831    }
832
833    #[test]
834    fn test_sample_filtering() {
835        let strategy = CurriculumStrategy::Linear {
836            start_difficulty: 0.5,
837            end_difficulty: 0.5,
838            num_steps: 10,
839        };
840
841        let importance_strategy = ImportanceWeightingStrategy::Uniform;
842        let mut curriculum =
843            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);
844
845        // Set sample difficulties
846        curriculum.set_sample_difficulty(1, 0.3); // Easy
847        curriculum.set_sample_difficulty(2, 0.7); // Hard
848        curriculum.set_sample_difficulty(3, 0.5); // Medium
849
850        let sampleids = vec![1, 2, 3, 4]; // 4 has no difficulty set
851        let filtered = curriculum.filter_samples(&sampleids);
852
853        // Should include samples 1, 3, 4 (difficulty <= 0.5 or unknown)
854        assert_eq!(filtered.len(), 3);
855        assert!(filtered.contains(&1));
856        assert!(filtered.contains(&3));
857        assert!(filtered.contains(&4));
858        assert!(!filtered.contains(&2));
859    }
860
861    #[test]
862    fn test_loss_based_importance_weighting() {
863        let strategy = CurriculumStrategy::Linear {
864            start_difficulty: 0.5,
865            end_difficulty: 0.5,
866            num_steps: 10,
867        };
868
869        let importance_strategy = ImportanceWeightingStrategy::LossBased {
870            temperature: 1.0,
871            min_weight: 0.1,
872        };
873
874        let mut curriculum =
875            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);
876
877        let sampleids = vec![1, 2, 3];
878        let losses = vec![0.1, 1.0, 0.5]; // Low, high, medium loss
879
880        curriculum
881            .compute_sample_weights(&sampleids, &losses, None, None)
882            .unwrap();
883
884        // Sample with highest loss should have highest weight
885        let weight1 = curriculum.get_sample_weight(1);
886        let weight2 = curriculum.get_sample_weight(2);
887        let weight3 = curriculum.get_sample_weight(3);
888
889        assert!(weight2 > weight3); // High loss > medium loss
890        assert!(weight3 > weight1); // Medium loss > low loss
891    }
892
893    #[test]
894    fn test_adversarial_config() {
895        let strategy = CurriculumStrategy::Linear {
896            start_difficulty: 0.5,
897            end_difficulty: 0.5,
898            num_steps: 10,
899        };
900
901        let importance_strategy = ImportanceWeightingStrategy::Uniform;
902        let mut curriculum =
903            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);
904
905        let adversarial_config = AdversarialConfig {
906            epsilon: 0.1,
907            num_steps: 5,
908            step_size: 0.02,
909            attack_type: AdversarialAttack::FGSM,
910            adversarial_weight: 0.5,
911        };
912
913        curriculum.enable_adversarial_training(adversarial_config);
914
915        let inputs = Array1::from_vec(vec![1.0, 2.0, 3.0]);
916        let gradients = Array1::from_vec(vec![0.1, -0.2, 0.3]);
917
918        let adversarial = curriculum
919            .generate_adversarial_examples(&inputs, &gradients)
920            .unwrap();
921
922        // Adversarial examples should be different from original
923        assert_ne!(adversarial.as_slice().unwrap(), inputs.as_slice().unwrap());
924
925        // Check that perturbation is bounded
926        for (orig, adv) in inputs.iter().zip(adversarial.iter()) {
927            assert!((adv - orig).abs() <= 0.1 + 1e-6); // epsilon + small tolerance
928        }
929    }
930
931    #[test]
932    fn test_adaptive_curriculum() {
933        let strategy1 = CurriculumStrategy::Linear {
934            start_difficulty: 0.1,
935            end_difficulty: 0.5,
936            num_steps: 100,
937        };
938
939        let strategy2 = CurriculumStrategy::Linear {
940            start_difficulty: 0.2,
941            end_difficulty: 0.8,
942            num_steps: 100,
943        };
944
945        let importance_strategy = ImportanceWeightingStrategy::Uniform;
946        let curriculum1 = CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(
947            strategy1,
948            importance_strategy.clone(),
949        );
950        let curriculum2 = CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(
951            strategy2,
952            importance_strategy,
953        );
954
955        let mut adaptive = AdaptiveCurriculum::new(vec![curriculum1, curriculum2], 0.1);
956
957        assert_eq!(adaptive.active_curriculum_index(), 0);
958
959        // Update with some performance values
960        for _ in 0..150 {
961            adaptive.update(0.7).unwrap();
962        }
963
964        // Should potentially have switched curriculum
965        let comparison = adaptive.get_curriculum_comparison();
966        assert_eq!(comparison.len(), 2);
967    }
968
969    #[test]
970    fn test_curriculum_state_export() {
971        let strategy = CurriculumStrategy::Linear {
972            start_difficulty: 0.1,
973            end_difficulty: 1.0,
974            num_steps: 10,
975        };
976
977        let importance_strategy = ImportanceWeightingStrategy::Uniform;
978        let mut curriculum =
979            CurriculumManager::<f64, scirs2_core::ndarray::Ix1>::new(strategy, importance_strategy);
980
981        curriculum.update_curriculum(0.8).unwrap();
982        let state = curriculum.export_state();
983
984        assert_eq!(state.step_count, 1);
985        assert_eq!(state.performance_history.len(), 1);
986        assert!(!state.has_adversarial);
987    }
988}