1use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11use std::marker::PhantomData;
12
/// Schedule that controls how a curriculum's difficulty evolves during training.
///
/// Difficulties are plain `f64` values; a sample is included in training when its
/// assigned difficulty is less than or equal to the current difficulty.
#[derive(Debug, Clone, PartialEq)]
pub enum CurriculumStrategy {
    /// Interpolates linearly from `start_difficulty` to `end_difficulty` over
    /// `num_steps` updates, then holds at `end_difficulty`.
    Linear {
        /// Difficulty before the first update.
        start_difficulty: f64,
        /// Difficulty reached after `num_steps` updates.
        end_difficulty: f64,
        /// Number of updates the interpolation spans.
        num_steps: usize,
    },
    /// Moves from `start_difficulty` towards `end_difficulty` with progress
    /// `1 - exp(-growth_rate * step)`.
    Exponential {
        /// Difficulty before the first update.
        start_difficulty: f64,
        /// Difficulty approached asymptotically.
        end_difficulty: f64,
        /// Rate of the exponential approach.
        growth_rate: f64,
    },
    /// Adjusts difficulty from the mean of the most recent `window_size`
    /// performance values: above `advance_threshold` it rises by
    /// `adjustment_step` (capped at 1.0), below `reduce_threshold` it falls by
    /// `adjustment_step` (floored at 0.0).
    PerformanceBased {
        /// Mean performance above which difficulty increases.
        advance_threshold: f64,
        /// Mean performance below which difficulty decreases.
        reduce_threshold: f64,
        /// Amount difficulty changes per adjustment.
        adjustment_step: f64,
        /// Number of recent performance values averaged.
        window_size: usize,
    },
    /// Looks up the difficulty for each step in `schedule`; steps without an
    /// entry use `default_difficulty`.
    Custom {
        /// Step number -> difficulty at that step.
        schedule: HashMap<usize, f64>,
        /// Difficulty for steps not present in `schedule` (and the initial value).
        default_difficulty: f64,
    },
}
53
/// Rule used to compute per-sample importance weights.
#[derive(Debug, Clone, PartialEq)]
pub enum ImportanceWeightingStrategy {
    /// Every sample gets weight 1.
    Uniform,
    /// Softmax over per-sample losses divided by `temperature`, so higher-loss
    /// samples get larger weights; no weight is lower than `min_weight`.
    LossBased {
        /// Softmax temperature; larger values flatten the weights.
        temperature: f64,
        /// Lower bound applied to each normalized weight.
        min_weight: f64,
    },
    /// Like `LossBased`, but over per-sample gradient norms. Falls back to
    /// uniform weights when no gradient norms are supplied.
    GradientNormBased {
        /// Softmax temperature; larger values flatten the weights.
        temperature: f64,
        /// Lower bound applied to each normalized weight.
        min_weight: f64,
    },
    /// Like `LossBased`, but over per-sample uncertainty estimates. Falls back
    /// to uniform weights when no uncertainties are supplied.
    UncertaintyBased {
        /// Softmax temperature; larger values flatten the weights.
        temperature: f64,
        /// Lower bound applied to each normalized weight.
        min_weight: f64,
    },
    /// Weights that depend on each sample's age, controlled by `decayfactor`.
    AgeBased {
        /// Factor controlling how strongly age affects the weight.
        decayfactor: f64,
    },
}
86
87#[derive(Debug, Clone)]
89pub struct AdversarialConfig<A: Float> {
90 pub epsilon: A,
92 pub num_steps: usize,
94 pub step_size: A,
96 pub attack_type: AdversarialAttack,
98 pub adversarial_weight: A,
100}
101
/// Attack used to craft adversarial examples from input gradients.
// The all-caps variant names are the conventional acronyms and part of the
// public API, so they are kept rather than renamed to UpperCamelCase.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdversarialAttack {
    /// Fast Gradient Sign Method: a single step of size `epsilon` along the
    /// gradient sign.
    FGSM,
    /// Projected Gradient Descent: `num_steps` sign steps of `step_size`, each
    /// projected back to within `epsilon` of the original input.
    PGD,
    /// Basic Iterative Method: `PGD` with the step size set to
    /// `epsilon / num_steps`.
    BIM,
    /// Momentum Iterative Method: like `PGD`, but stepping along the sign of an
    /// accumulated momentum of the normalized gradient.
    MIM,
}
114
115#[derive(Debug)]
117pub struct CurriculumManager<A: Float, D: Dimension> {
118 strategy: CurriculumStrategy,
120 current_difficulty: f64,
122 step_count: usize,
124 performance_history: VecDeque<A>,
126 sample_difficulties: HashMap<usize, f64>,
128 importance_strategy: ImportanceWeightingStrategy,
130 sample_weights: HashMap<usize, A>,
132 adversarial_config: Option<AdversarialConfig<A>>,
134 _phantom: PhantomData<D>,
136}
137
138impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> CurriculumManager<A, D> {
139 pub fn new(
141 strategy: CurriculumStrategy,
142 importance_strategy: ImportanceWeightingStrategy,
143 ) -> Self {
144 let initial_difficulty = match &strategy {
145 CurriculumStrategy::Linear {
146 start_difficulty, ..
147 } => *start_difficulty,
148 CurriculumStrategy::Exponential {
149 start_difficulty, ..
150 } => *start_difficulty,
151 CurriculumStrategy::PerformanceBased { .. } => 0.1, CurriculumStrategy::Custom {
153 default_difficulty, ..
154 } => *default_difficulty,
155 };
156
157 Self {
158 strategy,
159 current_difficulty: initial_difficulty,
160 step_count: 0,
161 performance_history: VecDeque::new(),
162 sample_difficulties: HashMap::new(),
163 importance_strategy,
164 sample_weights: HashMap::new(),
165 adversarial_config: None,
166 _phantom: PhantomData,
167 }
168 }
169
170 pub fn enable_adversarial_training(&mut self, config: AdversarialConfig<A>) {
172 self.adversarial_config = Some(config);
173 }
174
175 pub fn disable_adversarial_training(&mut self) {
177 self.adversarial_config = None;
178 }
179
180 pub fn update_curriculum(&mut self, performance: A) -> Result<()> {
182 self.performance_history.push_back(performance);
183 self.step_count += 1;
184
185 match &self.strategy {
187 CurriculumStrategy::Linear {
188 start_difficulty,
189 end_difficulty,
190 num_steps,
191 } => {
192 let progress = (self.step_count as f64) / (*num_steps as f64);
193 let progress = progress.min(1.0);
194 self.current_difficulty =
195 start_difficulty + progress * (end_difficulty - start_difficulty);
196 }
197 CurriculumStrategy::Exponential {
198 start_difficulty,
199 end_difficulty,
200 growth_rate,
201 } => {
202 let progress = 1.0 - (-growth_rate * self.step_count as f64).exp();
203 self.current_difficulty =
204 start_difficulty + progress * (end_difficulty - start_difficulty);
205 }
206 CurriculumStrategy::PerformanceBased {
207 advance_threshold,
208 reduce_threshold,
209 adjustment_step,
210 window_size,
211 } => {
212 if self.performance_history.len() >= *window_size {
213 while self.performance_history.len() > *window_size {
215 self.performance_history.pop_front();
216 }
217
218 let avg_performance = self
220 .performance_history
221 .iter()
222 .fold(A::zero(), |acc, &perf| acc + perf)
223 / A::from(self.performance_history.len()).unwrap();
224
225 let avg_perf_f64 = avg_performance.to_f64().unwrap_or(0.0);
226
227 if avg_perf_f64 > *advance_threshold {
229 self.current_difficulty =
230 (self.current_difficulty + adjustment_step).min(1.0);
231 } else if avg_perf_f64 < *reduce_threshold {
232 self.current_difficulty =
233 (self.current_difficulty - adjustment_step).max(0.0);
234 }
235 }
236 }
237 CurriculumStrategy::Custom {
238 schedule,
239 default_difficulty,
240 } => {
241 self.current_difficulty = schedule
242 .get(&self.step_count)
243 .copied()
244 .unwrap_or(*default_difficulty);
245 }
246 }
247
248 Ok(())
249 }
250
251 pub fn set_sample_difficulty(&mut self, sampleid: usize, difficulty: f64) {
253 self.sample_difficulties.insert(sampleid, difficulty);
254 }
255
256 pub fn should_include_sample(&self, sampleid: usize) -> bool {
258 if let Some(&sample_difficulty) = self.sample_difficulties.get(&sampleid) {
259 sample_difficulty <= self.current_difficulty
260 } else {
261 true }
263 }
264
265 pub fn get_current_difficulty(&self) -> f64 {
267 self.current_difficulty
268 }
269
270 pub fn compute_sample_weights(
272 &mut self,
273 sampleids: &[usize],
274 losses: &[A],
275 gradient_norms: Option<&[A]>,
276 uncertainties: Option<&[A]>,
277 ) -> Result<()> {
278 if sampleids.len() != losses.len() {
279 return Err(OptimError::DimensionMismatch(
280 "Sample IDs and losses must have same length".to_string(),
281 ));
282 }
283
284 match &self.importance_strategy {
285 ImportanceWeightingStrategy::Uniform => {
286 let uniform_weight = A::one();
287 for &sampleid in sampleids {
288 self.sample_weights.insert(sampleid, uniform_weight);
289 }
290 }
291 ImportanceWeightingStrategy::LossBased {
292 temperature,
293 min_weight,
294 } => {
295 self.compute_loss_based_weights(sampleids, losses, *temperature, *min_weight)?;
296 }
297 ImportanceWeightingStrategy::GradientNormBased {
298 temperature,
299 min_weight,
300 } => {
301 if let Some(grad_norms) = gradient_norms {
302 self.compute_gradient_norm_weights(
303 sampleids,
304 grad_norms,
305 *temperature,
306 *min_weight,
307 )?;
308 } else {
309 for &sampleid in sampleids {
311 self.sample_weights.insert(sampleid, A::one());
312 }
313 }
314 }
315 ImportanceWeightingStrategy::UncertaintyBased {
316 temperature,
317 min_weight,
318 } => {
319 if let Some(uncertainties_array) = uncertainties {
320 self.compute_uncertainty_weights(
321 sampleids,
322 uncertainties_array,
323 *temperature,
324 *min_weight,
325 )?;
326 } else {
327 for &sampleid in sampleids {
329 self.sample_weights.insert(sampleid, A::one());
330 }
331 }
332 }
333 ImportanceWeightingStrategy::AgeBased { decayfactor } => {
334 self.compute_age_based_weights(sampleids, *decayfactor)?;
335 }
336 }
337
338 Ok(())
339 }
340
341 fn compute_loss_based_weights(
343 &mut self,
344 sampleids: &[usize],
345 losses: &[A],
346 temperature: f64,
347 min_weight: f64,
348 ) -> Result<()> {
349 let temp = A::from(temperature).unwrap();
351 let min_w = A::from(min_weight).unwrap();
352
353 let max_loss = losses.iter().fold(A::neg_infinity(), |a, &b| A::max(a, b));
355
356 let mut unnormalized_weights = Vec::new();
358 for &loss in losses {
359 let normalized_loss = (loss - max_loss) / temp;
360 unnormalized_weights.push(A::exp(normalized_loss));
361 }
362
363 let sum_weights: A = unnormalized_weights
365 .iter()
366 .fold(A::zero(), |acc, &w| acc + w);
367
368 for (i, &sampleid) in sampleids.iter().enumerate() {
369 let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
370 self.sample_weights.insert(sampleid, weight);
371 }
372
373 Ok(())
374 }
375
376 fn compute_gradient_norm_weights(
378 &mut self,
379 sampleids: &[usize],
380 gradient_norms: &[A],
381 temperature: f64,
382 min_weight: f64,
383 ) -> Result<()> {
384 let temp = A::from(temperature).unwrap();
385 let min_w = A::from(min_weight).unwrap();
386
387 let max_norm = gradient_norms
389 .iter()
390 .fold(A::neg_infinity(), |a, &b| A::max(a, b));
391
392 let mut unnormalized_weights = Vec::new();
394 for &norm in gradient_norms {
395 let normalized_norm = (norm - max_norm) / temp;
396 unnormalized_weights.push(A::exp(normalized_norm));
397 }
398
399 let sum_weights: A = unnormalized_weights
400 .iter()
401 .fold(A::zero(), |acc, &w| acc + w);
402
403 for (i, &sampleid) in sampleids.iter().enumerate() {
404 let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
405 self.sample_weights.insert(sampleid, weight);
406 }
407
408 Ok(())
409 }
410
411 fn compute_uncertainty_weights(
413 &mut self,
414 sampleids: &[usize],
415 uncertainties: &[A],
416 temperature: f64,
417 min_weight: f64,
418 ) -> Result<()> {
419 let temp = A::from(temperature).unwrap();
420 let min_w = A::from(min_weight).unwrap();
421
422 let max_uncertainty = uncertainties
424 .iter()
425 .fold(A::neg_infinity(), |a, &b| A::max(a, b));
426
427 let mut unnormalized_weights = Vec::new();
429 for &uncertainty in uncertainties {
430 let normalized_uncertainty = (uncertainty - max_uncertainty) / temp;
431 unnormalized_weights.push(A::exp(normalized_uncertainty));
432 }
433
434 let sum_weights: A = unnormalized_weights
435 .iter()
436 .fold(A::zero(), |acc, &w| acc + w);
437
438 for (i, &sampleid) in sampleids.iter().enumerate() {
439 let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
440 self.sample_weights.insert(sampleid, weight);
441 }
442
443 Ok(())
444 }
445
446 fn compute_age_based_weights(&mut self, sampleids: &[usize], decayfactor: f64) -> Result<()> {
448 let decay = A::from(decayfactor).unwrap();
449
450 for &sampleid in sampleids {
451 let age = A::from(self.step_count.saturating_sub(sampleid)).unwrap();
453 let weight = A::exp(decay * age);
454 self.sample_weights.insert(sampleid, weight);
455 }
456
457 Ok(())
458 }
459
460 pub fn get_sample_weight(&self, sampleid: usize) -> A {
462 self.sample_weights
463 .get(&sampleid)
464 .copied()
465 .unwrap_or_else(|| A::one())
466 }
467
468 pub fn generate_adversarial_examples(
470 &self,
471 inputs: &Array<A, D>,
472 gradients: &Array<A, D>,
473 ) -> Result<Array<A, D>> {
474 if let Some(config) = &self.adversarial_config {
475 match config.attack_type {
476 AdversarialAttack::FGSM => self.fgsm_attack(inputs, gradients, config),
477 AdversarialAttack::PGD => self.pgd_attack(inputs, gradients, config),
478 AdversarialAttack::BIM => self.bim_attack(inputs, gradients, config),
479 AdversarialAttack::MIM => self.mim_attack(inputs, gradients, config),
480 }
481 } else {
482 Ok(inputs.clone()) }
484 }
485
486 fn fgsm_attack(
488 &self,
489 inputs: &Array<A, D>,
490 gradients: &Array<A, D>,
491 config: &AdversarialConfig<A>,
492 ) -> Result<Array<A, D>> {
493 let mut adversarial = inputs.clone();
494
495 let sign_gradients = gradients.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
497
498 Zip::from(&mut adversarial)
500 .and(&sign_gradients)
501 .for_each(|x, &sign| {
502 *x = *x + config.epsilon * sign;
503 });
504
505 Ok(adversarial)
506 }
507
508 fn pgd_attack(
510 &self,
511 inputs: &Array<A, D>,
512 gradients: &Array<A, D>,
513 config: &AdversarialConfig<A>,
514 ) -> Result<Array<A, D>> {
515 let mut adversarial = inputs.clone();
516
517 for _ in 0..config.num_steps {
519 let sign_gradients =
521 gradients.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
522
523 Zip::from(&mut adversarial)
524 .and(&sign_gradients)
525 .for_each(|x, &sign| {
526 *x = *x + config.step_size * sign;
527 });
528
529 Zip::from(&mut adversarial)
531 .and(inputs)
532 .for_each(|adv, &orig| {
533 let diff = *adv - orig;
534 let clamped_diff = A::max(-config.epsilon, A::min(config.epsilon, diff));
535 *adv = orig + clamped_diff;
536 });
537 }
538
539 Ok(adversarial)
540 }
541
542 fn bim_attack(
544 &self,
545 inputs: &Array<A, D>,
546 gradients: &Array<A, D>,
547 config: &AdversarialConfig<A>,
548 ) -> Result<Array<A, D>> {
549 let mut modified_config = config.clone();
551 modified_config.step_size = config.epsilon / A::from(config.num_steps).unwrap();
552
553 self.pgd_attack(inputs, gradients, &modified_config)
554 }
555
556 fn mim_attack(
558 &self,
559 inputs: &Array<A, D>,
560 gradients: &Array<A, D>,
561 config: &AdversarialConfig<A>,
562 ) -> Result<Array<A, D>> {
563 let mut adversarial = inputs.clone();
564 let mut momentum = Array::zeros(inputs.raw_dim());
565 let decayfactor = A::from(1.0).unwrap(); for _ in 0..config.num_steps {
568 let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
570 let normalized_gradients = if grad_norm > A::zero() {
571 gradients.mapv(|x| x / grad_norm)
572 } else {
573 gradients.clone()
574 };
575
576 Zip::from(&mut momentum)
577 .and(&normalized_gradients)
578 .for_each(|m, &g| {
579 *m = decayfactor * *m + g;
580 });
581
582 let momentum_signs =
584 momentum.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
585
586 Zip::from(&mut adversarial)
587 .and(&momentum_signs)
588 .for_each(|x, &sign| {
589 *x = *x + config.step_size * sign;
590 });
591
592 Zip::from(&mut adversarial)
594 .and(inputs)
595 .for_each(|adv, &orig| {
596 let diff = *adv - orig;
597 let clamped_diff = A::max(-config.epsilon, A::min(config.epsilon, diff));
598 *adv = orig + clamped_diff;
599 });
600 }
601
602 Ok(adversarial)
603 }
604
605 pub fn filter_samples(&self, sampleids: &[usize]) -> Vec<usize> {
607 sampleids
608 .iter()
609 .copied()
610 .filter(|&id| self.should_include_sample(id))
611 .collect()
612 }
613
614 pub fn get_performance_history(&self) -> &VecDeque<A> {
616 &self.performance_history
617 }
618
619 pub fn step_count(&self) -> usize {
621 self.step_count
622 }
623
624 pub fn reset(&mut self) {
626 self.step_count = 0;
627 self.performance_history.clear();
628 self.sample_weights.clear();
629 self.current_difficulty = match &self.strategy {
630 CurriculumStrategy::Linear {
631 start_difficulty, ..
632 } => *start_difficulty,
633 CurriculumStrategy::Exponential {
634 start_difficulty, ..
635 } => *start_difficulty,
636 CurriculumStrategy::PerformanceBased { .. } => 0.1,
637 CurriculumStrategy::Custom {
638 default_difficulty, ..
639 } => *default_difficulty,
640 };
641 }
642
643 pub fn export_state(&self) -> CurriculumState<A> {
645 CurriculumState {
646 current_difficulty: self.current_difficulty,
647 step_count: self.step_count,
648 performance_history: self.performance_history.clone(),
649 sample_weights: self.sample_weights.clone(),
650 has_adversarial: self.adversarial_config.is_some(),
651 }
652 }
653}
654
655#[derive(Debug, Clone)]
657pub struct CurriculumState<A: Float> {
658 pub current_difficulty: f64,
660 pub step_count: usize,
662 pub performance_history: VecDeque<A>,
664 pub sample_weights: HashMap<usize, A>,
666 pub has_adversarial: bool,
668}
669
670#[derive(Debug)]
672pub struct AdaptiveCurriculum<A: Float, D: Dimension> {
673 curricula: Vec<CurriculumManager<A, D>>,
675 active_curriculum: usize,
677 curriculum_performance: Vec<VecDeque<A>>,
679 switchthreshold: A,
681 min_steps_before_switch: usize,
683 steps_since_switch: usize,
685 _phantom: PhantomData<D>,
687}
688
689impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> AdaptiveCurriculum<A, D> {
690 pub fn new(curricula: Vec<CurriculumManager<A, D>>, switchthreshold: A) -> Self {
692 let num_curricula = curricula.len();
693 Self {
694 curricula,
695 active_curriculum: 0,
696 curriculum_performance: vec![VecDeque::new(); num_curricula],
697 switchthreshold,
698 min_steps_before_switch: 100,
699 steps_since_switch: 0,
700 _phantom: PhantomData,
701 }
702 }
703
704 pub fn update(&mut self, performance: A) -> Result<()> {
706 self.curricula[self.active_curriculum].update_curriculum(performance)?;
708 self.curriculum_performance[self.active_curriculum].push_back(performance);
709 self.steps_since_switch += 1;
710
711 if self.steps_since_switch >= self.min_steps_before_switch {
713 self.consider_curriculum_switch()?;
714 }
715
716 Ok(())
717 }
718
719 fn consider_curriculum_switch(&mut self) -> Result<()> {
721 let current_performance = self.get_average_performance(self.active_curriculum);
722 let mut best_curriculum = self.active_curriculum;
723 let mut best_performance = current_performance;
724
725 for (i, _) in self.curricula.iter().enumerate() {
727 if i != self.active_curriculum {
728 let perf = self.get_average_performance(i);
729 if perf > best_performance + self.switchthreshold {
730 best_performance = perf;
731 best_curriculum = i;
732 }
733 }
734 }
735
736 if best_curriculum != self.active_curriculum {
738 self.active_curriculum = best_curriculum;
739 self.steps_since_switch = 0;
740 }
741
742 Ok(())
743 }
744
745 fn get_average_performance(&self, curriculumidx: usize) -> A {
747 let perf_history = &self.curriculum_performance[curriculumidx];
748 if perf_history.is_empty() {
749 A::zero()
750 } else {
751 let sum = perf_history.iter().fold(A::zero(), |acc, &perf| acc + perf);
752 sum / A::from(perf_history.len()).unwrap()
753 }
754 }
755
756 pub fn active_curriculum(&self) -> &CurriculumManager<A, D> {
758 &self.curricula[self.active_curriculum]
759 }
760
761 pub fn active_curriculum_mut(&mut self) -> &mut CurriculumManager<A, D> {
763 &mut self.curricula[self.active_curriculum]
764 }
765
766 pub fn active_curriculum_index(&self) -> usize {
768 self.active_curriculum
769 }
770
771 pub fn get_curriculum_comparison(&self) -> Vec<(usize, A)> {
773 (0..self.curricula.len())
774 .map(|i| (i, self.get_average_performance(i)))
775 .collect()
776 }
777}
778
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{Array1, Ix1};

    /// Builds the 1-D `f64` manager every test in this module uses.
    fn make_manager(
        strategy: CurriculumStrategy,
        importance: ImportanceWeightingStrategy,
    ) -> CurriculumManager<f64, Ix1> {
        CurriculumManager::new(strategy, importance)
    }

    /// Linear schedule pinned at difficulty 0.5 for every step.
    fn constant_half() -> CurriculumStrategy {
        CurriculumStrategy::Linear {
            start_difficulty: 0.5,
            end_difficulty: 0.5,
            num_steps: 10,
        }
    }

    #[test]
    fn test_linear_curriculum() {
        let mut manager = make_manager(
            CurriculumStrategy::Linear {
                start_difficulty: 0.1,
                end_difficulty: 1.0,
                num_steps: 10,
            },
            ImportanceWeightingStrategy::Uniform,
        );

        assert_relative_eq!(manager.get_current_difficulty(), 0.1, epsilon = 1e-6);

        (0..5).for_each(|_| manager.update_curriculum(0.8).unwrap());

        // Halfway through the schedule: strictly harder, but never past the end.
        let difficulty = manager.get_current_difficulty();
        assert!(difficulty > 0.1);
        assert!(difficulty <= 1.0);
    }

    #[test]
    fn test_performance_based_curriculum() {
        let mut manager = make_manager(
            CurriculumStrategy::PerformanceBased {
                advance_threshold: 0.8,
                reduce_threshold: 0.4,
                adjustment_step: 0.1,
                window_size: 3,
            },
            ImportanceWeightingStrategy::Uniform,
        );
        let before = manager.get_current_difficulty();

        // Consistently strong performance must push the difficulty up.
        (0..5).for_each(|_| manager.update_curriculum(0.9).unwrap());

        assert!(manager.get_current_difficulty() > before);
    }

    #[test]
    fn test_sample_filtering() {
        let mut manager = make_manager(constant_half(), ImportanceWeightingStrategy::Uniform);

        // Easier than, harder than, and equal to the current difficulty (0.5).
        for &(id, difficulty) in &[(1, 0.3), (2, 0.7), (3, 0.5)] {
            manager.set_sample_difficulty(id, difficulty);
        }

        // Sample 4 has no difficulty and is therefore always kept.
        let kept = manager.filter_samples(&[1, 2, 3, 4]);

        assert_eq!(kept.len(), 3);
        for id in [1, 3, 4] {
            assert!(kept.contains(&id));
        }
        assert!(!kept.contains(&2));
    }

    #[test]
    fn test_loss_based_importance_weighting() {
        let mut manager = make_manager(
            constant_half(),
            ImportanceWeightingStrategy::LossBased {
                temperature: 1.0,
                min_weight: 0.1,
            },
        );

        manager
            .compute_sample_weights(&[1, 2, 3], &[0.1, 1.0, 0.5], None, None)
            .unwrap();

        let [w1, w2, w3] = [1, 2, 3].map(|id| manager.get_sample_weight(id));

        // Weight ordering follows loss ordering: 1.0 > 0.5 > 0.1.
        assert!(w2 > w3);
        assert!(w3 > w1);
    }

    #[test]
    fn test_adversarial_config() {
        let mut manager = make_manager(constant_half(), ImportanceWeightingStrategy::Uniform);
        manager.enable_adversarial_training(AdversarialConfig {
            epsilon: 0.1,
            num_steps: 5,
            step_size: 0.02,
            attack_type: AdversarialAttack::FGSM,
            adversarial_weight: 0.5,
        });

        let inputs = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, -0.2, 0.3]);

        let perturbed = manager
            .generate_adversarial_examples(&inputs, &gradients)
            .unwrap();

        assert_ne!(perturbed.as_slice().unwrap(), inputs.as_slice().unwrap());

        // No element may move further than epsilon from its original value.
        assert!(inputs
            .iter()
            .zip(perturbed.iter())
            .all(|(orig, adv)| (adv - orig).abs() <= 0.1 + 1e-6));
    }

    #[test]
    fn test_adaptive_curriculum() {
        let importance = ImportanceWeightingStrategy::Uniform;
        let gentle = make_manager(
            CurriculumStrategy::Linear {
                start_difficulty: 0.1,
                end_difficulty: 0.5,
                num_steps: 100,
            },
            importance.clone(),
        );
        let steep = make_manager(
            CurriculumStrategy::Linear {
                start_difficulty: 0.2,
                end_difficulty: 0.8,
                num_steps: 100,
            },
            importance,
        );

        let mut adaptive = AdaptiveCurriculum::new(vec![gentle, steep], 0.1);
        assert_eq!(adaptive.active_curriculum_index(), 0);

        (0..150).for_each(|_| adaptive.update(0.7).unwrap());

        assert_eq!(adaptive.get_curriculum_comparison().len(), 2);
    }

    #[test]
    fn test_curriculum_state_export() {
        let mut manager = make_manager(
            CurriculumStrategy::Linear {
                start_difficulty: 0.1,
                end_difficulty: 1.0,
                num_steps: 10,
            },
            ImportanceWeightingStrategy::Uniform,
        );

        manager.update_curriculum(0.8).unwrap();
        let snapshot = manager.export_state();

        assert_eq!(snapshot.step_count, 1);
        assert_eq!(snapshot.performance_history.len(), 1);
        assert!(!snapshot.has_adversarial);
    }
}