1use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11use std::marker::PhantomData;
12
/// Schedule that decides how [`CurriculumManager`] changes its current difficulty
/// each time `update_curriculum` is called.
#[derive(Debug, Clone)]
pub enum CurriculumStrategy {
    /// Linearly interpolates from `start_difficulty` to `end_difficulty`; the end
    /// value is reached after `num_steps` updates and held from then on.
    Linear {
        /// Difficulty before the first update.
        start_difficulty: f64,
        /// Difficulty reached once `num_steps` updates have been made.
        end_difficulty: f64,
        /// Number of updates over which the interpolation runs.
        num_steps: usize,
    },
    /// Moves from `start_difficulty` toward `end_difficulty` with progress
    /// `1 - exp(-growth_rate * step)` (no clamping is applied).
    Exponential {
        /// Difficulty before the first update.
        start_difficulty: f64,
        /// Value the difficulty approaches as the step count grows
        /// (assuming a positive `growth_rate`).
        end_difficulty: f64,
        /// Rate in the exponent; larger values approach `end_difficulty` faster.
        growth_rate: f64,
    },
    /// Adjusts difficulty from the average of the most recent reported
    /// performances, starting from 0.1.
    PerformanceBased {
        /// Average performance above which difficulty increases (capped at 1.0).
        advance_threshold: f64,
        /// Average performance below which difficulty decreases (floored at 0.0).
        reduce_threshold: f64,
        /// Amount added or subtracted per qualifying update.
        adjustment_step: f64,
        /// Number of most recent performances averaged; no adjustment happens
        /// until this many have been reported.
        window_size: usize,
    },
    /// Explicit per-step schedule.
    Custom {
        /// Difficulty for an exact step count (looked up after the step counter
        /// is incremented).
        schedule: HashMap<usize, f64>,
        /// Difficulty for any step not in `schedule`; also the starting difficulty.
        default_difficulty: f64,
    },
}
53
/// How [`CurriculumManager::compute_sample_weights`] assigns per-sample weights.
///
/// The score-based variants compute a softmax of `score / temperature` over the
/// batch and then raise every weight to at least `min_weight`. The floored
/// weights are not renormalised.
#[derive(Debug, Clone)]
pub enum ImportanceWeightingStrategy {
    /// Every sample gets weight 1.
    Uniform,
    /// Softmax over per-sample losses: higher loss gives higher weight.
    LossBased {
        /// Softmax temperature; larger values flatten the distribution.
        temperature: f64,
        /// Lower bound applied to every weight.
        min_weight: f64,
    },
    /// Softmax over per-sample gradient norms; falls back to weight 1 for every
    /// sample when no gradient norms are supplied.
    GradientNormBased {
        /// Softmax temperature; larger values flatten the distribution.
        temperature: f64,
        /// Lower bound applied to every weight.
        min_weight: f64,
    },
    /// Softmax over per-sample uncertainties; falls back to weight 1 for every
    /// sample when no uncertainties are supplied.
    UncertaintyBased {
        /// Softmax temperature; larger values flatten the distribution.
        temperature: f64,
        /// Lower bound applied to every weight.
        min_weight: f64,
    },
    /// Weight `exp(decayfactor * age)`, where `age` is the manager's step count
    /// minus the sample ID (saturating at 0).
    AgeBased {
        /// Exponent coefficient. A positive value makes weights grow with age.
        /// NOTE(review): confirm the intended sign; large ages overflow to infinity.
        decayfactor: f64,
    },
}
86
/// Settings used by [`CurriculumManager::generate_adversarial_examples`].
#[derive(Debug, Clone)]
pub struct AdversarialConfig<A: Float> {
    /// Maximum absolute per-element perturbation. FGSM steps by exactly this
    /// amount; the iterative attacks clamp the accumulated perturbation to
    /// `[-epsilon, epsilon]` around the original input.
    pub epsilon: A,
    /// Number of iterations for the iterative attacks (PGD, BIM, MIM); FGSM
    /// ignores it.
    pub num_steps: usize,
    /// Per-iteration step for PGD and MIM. BIM ignores it and uses
    /// `epsilon / num_steps` instead.
    pub step_size: A,
    /// Which attack to run.
    pub attack_type: AdversarialAttack,
    /// Not read by `CurriculumManager` in this file.
    /// NOTE(review): presumably the weight of the adversarial loss term — confirm
    /// with callers.
    pub adversarial_weight: A,
}
101
/// Gradient-sign attack used to craft adversarial examples.
///
/// Comparable and hashable so callers can branch on, or key maps by, the attack
/// kind.
// The variant names are the standard acronyms from the literature and part of the
// public API, so the acronym-casing lint is silenced rather than renaming them.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdversarialAttack {
    /// Fast Gradient Sign Method: a single step of size `epsilon` along the
    /// gradient sign.
    FGSM,
    /// Projected Gradient Descent: repeated sign steps of `step_size`, each
    /// followed by projection back into the `epsilon` box around the input.
    PGD,
    /// Basic Iterative Method: like PGD, but with step size `epsilon / num_steps`.
    BIM,
    /// Momentum Iterative Method: steps along the sign of a momentum accumulator
    /// of normalised gradients, with the same projection as PGD.
    MIM,
}
114
/// Tracks the current curriculum difficulty, per-sample difficulties and
/// importance weights, and an optional adversarial-example configuration.
///
/// `A` is the float type used for performances, weights and array elements; `D`
/// is the `ndarray` dimension of the arrays passed to adversarial generation.
#[derive(Debug)]
pub struct CurriculumManager<A: Float, D: Dimension> {
    /// Schedule that drives `current_difficulty`.
    strategy: CurriculumStrategy,
    /// Samples whose recorded difficulty exceeds this value are filtered out.
    current_difficulty: f64,
    /// Number of `update_curriculum` calls since construction or the last reset.
    step_count: usize,
    /// Reported performances. Only the performance-based strategy trims this to
    /// its window; the other strategies let it grow.
    performance_history: VecDeque<A>,
    /// Difficulty per sample ID; samples without an entry are always admitted.
    sample_difficulties: HashMap<usize, f64>,
    /// How `compute_sample_weights` assigns weights.
    importance_strategy: ImportanceWeightingStrategy,
    /// Most recently computed weight per sample ID (missing IDs read as 1).
    sample_weights: HashMap<usize, A>,
    /// Adversarial settings; `None` makes adversarial generation return the inputs.
    adversarial_config: Option<AdversarialConfig<A>>,
    /// Ties the dimension type `D` to the struct without storing a value of it.
    _phantom: PhantomData<D>,
}
137
138impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> CurriculumManager<A, D> {
139 pub fn new(
141 strategy: CurriculumStrategy,
142 importance_strategy: ImportanceWeightingStrategy,
143 ) -> Self {
144 let initial_difficulty = match &strategy {
145 CurriculumStrategy::Linear {
146 start_difficulty, ..
147 } => *start_difficulty,
148 CurriculumStrategy::Exponential {
149 start_difficulty, ..
150 } => *start_difficulty,
151 CurriculumStrategy::PerformanceBased { .. } => 0.1, CurriculumStrategy::Custom {
153 default_difficulty, ..
154 } => *default_difficulty,
155 };
156
157 Self {
158 strategy,
159 current_difficulty: initial_difficulty,
160 step_count: 0,
161 performance_history: VecDeque::new(),
162 sample_difficulties: HashMap::new(),
163 importance_strategy,
164 sample_weights: HashMap::new(),
165 adversarial_config: None,
166 _phantom: PhantomData,
167 }
168 }
169
170 pub fn enable_adversarial_training(&mut self, config: AdversarialConfig<A>) {
172 self.adversarial_config = Some(config);
173 }
174
175 pub fn disable_adversarial_training(&mut self) {
177 self.adversarial_config = None;
178 }
179
180 pub fn update_curriculum(&mut self, performance: A) -> Result<()> {
182 self.performance_history.push_back(performance);
183 self.step_count += 1;
184
185 match &self.strategy {
187 CurriculumStrategy::Linear {
188 start_difficulty,
189 end_difficulty,
190 num_steps,
191 } => {
192 let progress = (self.step_count as f64) / (*num_steps as f64);
193 let progress = progress.min(1.0);
194 self.current_difficulty =
195 start_difficulty + progress * (end_difficulty - start_difficulty);
196 }
197 CurriculumStrategy::Exponential {
198 start_difficulty,
199 end_difficulty,
200 growth_rate,
201 } => {
202 let progress = 1.0 - (-growth_rate * self.step_count as f64).exp();
203 self.current_difficulty =
204 start_difficulty + progress * (end_difficulty - start_difficulty);
205 }
206 CurriculumStrategy::PerformanceBased {
207 advance_threshold,
208 reduce_threshold,
209 adjustment_step,
210 window_size,
211 } => {
212 if self.performance_history.len() >= *window_size {
213 while self.performance_history.len() > *window_size {
215 self.performance_history.pop_front();
216 }
217
218 let avg_performance = self
220 .performance_history
221 .iter()
222 .fold(A::zero(), |acc, &perf| acc + perf)
223 / A::from(self.performance_history.len()).expect("unwrap failed");
224
225 let avg_perf_f64 = avg_performance.to_f64().unwrap_or(0.0);
226
227 if avg_perf_f64 > *advance_threshold {
229 self.current_difficulty =
230 (self.current_difficulty + adjustment_step).min(1.0);
231 } else if avg_perf_f64 < *reduce_threshold {
232 self.current_difficulty =
233 (self.current_difficulty - adjustment_step).max(0.0);
234 }
235 }
236 }
237 CurriculumStrategy::Custom {
238 schedule,
239 default_difficulty,
240 } => {
241 self.current_difficulty = schedule
242 .get(&self.step_count)
243 .copied()
244 .unwrap_or(*default_difficulty);
245 }
246 }
247
248 Ok(())
249 }
250
251 pub fn set_sample_difficulty(&mut self, sampleid: usize, difficulty: f64) {
253 self.sample_difficulties.insert(sampleid, difficulty);
254 }
255
256 pub fn should_include_sample(&self, sampleid: usize) -> bool {
258 if let Some(&sample_difficulty) = self.sample_difficulties.get(&sampleid) {
259 sample_difficulty <= self.current_difficulty
260 } else {
261 true }
263 }
264
265 pub fn get_current_difficulty(&self) -> f64 {
267 self.current_difficulty
268 }
269
270 pub fn compute_sample_weights(
272 &mut self,
273 sampleids: &[usize],
274 losses: &[A],
275 gradient_norms: Option<&[A]>,
276 uncertainties: Option<&[A]>,
277 ) -> Result<()> {
278 if sampleids.len() != losses.len() {
279 return Err(OptimError::DimensionMismatch(
280 "Sample IDs and losses must have same length".to_string(),
281 ));
282 }
283
284 match &self.importance_strategy {
285 ImportanceWeightingStrategy::Uniform => {
286 let uniform_weight = A::one();
287 for &sampleid in sampleids {
288 self.sample_weights.insert(sampleid, uniform_weight);
289 }
290 }
291 ImportanceWeightingStrategy::LossBased {
292 temperature,
293 min_weight,
294 } => {
295 self.compute_loss_based_weights(sampleids, losses, *temperature, *min_weight)?;
296 }
297 ImportanceWeightingStrategy::GradientNormBased {
298 temperature,
299 min_weight,
300 } => {
301 if let Some(grad_norms) = gradient_norms {
302 self.compute_gradient_norm_weights(
303 sampleids,
304 grad_norms,
305 *temperature,
306 *min_weight,
307 )?;
308 } else {
309 for &sampleid in sampleids {
311 self.sample_weights.insert(sampleid, A::one());
312 }
313 }
314 }
315 ImportanceWeightingStrategy::UncertaintyBased {
316 temperature,
317 min_weight,
318 } => {
319 if let Some(uncertainties_array) = uncertainties {
320 self.compute_uncertainty_weights(
321 sampleids,
322 uncertainties_array,
323 *temperature,
324 *min_weight,
325 )?;
326 } else {
327 for &sampleid in sampleids {
329 self.sample_weights.insert(sampleid, A::one());
330 }
331 }
332 }
333 ImportanceWeightingStrategy::AgeBased { decayfactor } => {
334 self.compute_age_based_weights(sampleids, *decayfactor)?;
335 }
336 }
337
338 Ok(())
339 }
340
341 fn compute_loss_based_weights(
343 &mut self,
344 sampleids: &[usize],
345 losses: &[A],
346 temperature: f64,
347 min_weight: f64,
348 ) -> Result<()> {
349 let temp = A::from(temperature).expect("unwrap failed");
351 let min_w = A::from(min_weight).expect("unwrap failed");
352
353 let max_loss = losses.iter().fold(A::neg_infinity(), |a, &b| A::max(a, b));
355
356 let mut unnormalized_weights = Vec::new();
358 for &loss in losses {
359 let normalized_loss = (loss - max_loss) / temp;
360 unnormalized_weights.push(A::exp(normalized_loss));
361 }
362
363 let sum_weights: A = unnormalized_weights
365 .iter()
366 .fold(A::zero(), |acc, &w| acc + w);
367
368 for (i, &sampleid) in sampleids.iter().enumerate() {
369 let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
370 self.sample_weights.insert(sampleid, weight);
371 }
372
373 Ok(())
374 }
375
376 fn compute_gradient_norm_weights(
378 &mut self,
379 sampleids: &[usize],
380 gradient_norms: &[A],
381 temperature: f64,
382 min_weight: f64,
383 ) -> Result<()> {
384 let temp = A::from(temperature).expect("unwrap failed");
385 let min_w = A::from(min_weight).expect("unwrap failed");
386
387 let max_norm = gradient_norms
389 .iter()
390 .fold(A::neg_infinity(), |a, &b| A::max(a, b));
391
392 let mut unnormalized_weights = Vec::new();
394 for &norm in gradient_norms {
395 let normalized_norm = (norm - max_norm) / temp;
396 unnormalized_weights.push(A::exp(normalized_norm));
397 }
398
399 let sum_weights: A = unnormalized_weights
400 .iter()
401 .fold(A::zero(), |acc, &w| acc + w);
402
403 for (i, &sampleid) in sampleids.iter().enumerate() {
404 let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
405 self.sample_weights.insert(sampleid, weight);
406 }
407
408 Ok(())
409 }
410
411 fn compute_uncertainty_weights(
413 &mut self,
414 sampleids: &[usize],
415 uncertainties: &[A],
416 temperature: f64,
417 min_weight: f64,
418 ) -> Result<()> {
419 let temp = A::from(temperature).expect("unwrap failed");
420 let min_w = A::from(min_weight).expect("unwrap failed");
421
422 let max_uncertainty = uncertainties
424 .iter()
425 .fold(A::neg_infinity(), |a, &b| A::max(a, b));
426
427 let mut unnormalized_weights = Vec::new();
429 for &uncertainty in uncertainties {
430 let normalized_uncertainty = (uncertainty - max_uncertainty) / temp;
431 unnormalized_weights.push(A::exp(normalized_uncertainty));
432 }
433
434 let sum_weights: A = unnormalized_weights
435 .iter()
436 .fold(A::zero(), |acc, &w| acc + w);
437
438 for (i, &sampleid) in sampleids.iter().enumerate() {
439 let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
440 self.sample_weights.insert(sampleid, weight);
441 }
442
443 Ok(())
444 }
445
446 fn compute_age_based_weights(&mut self, sampleids: &[usize], decayfactor: f64) -> Result<()> {
448 let decay = A::from(decayfactor).expect("unwrap failed");
449
450 for &sampleid in sampleids {
451 let age = A::from(self.step_count.saturating_sub(sampleid)).expect("unwrap failed");
453 let weight = A::exp(decay * age);
454 self.sample_weights.insert(sampleid, weight);
455 }
456
457 Ok(())
458 }
459
460 pub fn get_sample_weight(&self, sampleid: usize) -> A {
462 self.sample_weights
463 .get(&sampleid)
464 .copied()
465 .unwrap_or_else(|| A::one())
466 }
467
468 pub fn generate_adversarial_examples(
470 &self,
471 inputs: &Array<A, D>,
472 gradients: &Array<A, D>,
473 ) -> Result<Array<A, D>> {
474 if let Some(config) = &self.adversarial_config {
475 match config.attack_type {
476 AdversarialAttack::FGSM => self.fgsm_attack(inputs, gradients, config),
477 AdversarialAttack::PGD => self.pgd_attack(inputs, gradients, config),
478 AdversarialAttack::BIM => self.bim_attack(inputs, gradients, config),
479 AdversarialAttack::MIM => self.mim_attack(inputs, gradients, config),
480 }
481 } else {
482 Ok(inputs.clone()) }
484 }
485
486 fn fgsm_attack(
488 &self,
489 inputs: &Array<A, D>,
490 gradients: &Array<A, D>,
491 config: &AdversarialConfig<A>,
492 ) -> Result<Array<A, D>> {
493 let mut adversarial = inputs.clone();
494
495 let sign_gradients = gradients.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
497
498 Zip::from(&mut adversarial)
500 .and(&sign_gradients)
501 .for_each(|x, &sign| {
502 *x = *x + config.epsilon * sign;
503 });
504
505 Ok(adversarial)
506 }
507
508 fn pgd_attack(
510 &self,
511 inputs: &Array<A, D>,
512 gradients: &Array<A, D>,
513 config: &AdversarialConfig<A>,
514 ) -> Result<Array<A, D>> {
515 let mut adversarial = inputs.clone();
516
517 for _ in 0..config.num_steps {
519 let sign_gradients =
521 gradients.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
522
523 Zip::from(&mut adversarial)
524 .and(&sign_gradients)
525 .for_each(|x, &sign| {
526 *x = *x + config.step_size * sign;
527 });
528
529 Zip::from(&mut adversarial)
531 .and(inputs)
532 .for_each(|adv, &orig| {
533 let diff = *adv - orig;
534 let clamped_diff = A::max(-config.epsilon, A::min(config.epsilon, diff));
535 *adv = orig + clamped_diff;
536 });
537 }
538
539 Ok(adversarial)
540 }
541
542 fn bim_attack(
544 &self,
545 inputs: &Array<A, D>,
546 gradients: &Array<A, D>,
547 config: &AdversarialConfig<A>,
548 ) -> Result<Array<A, D>> {
549 let mut modified_config = config.clone();
551 modified_config.step_size =
552 config.epsilon / A::from(config.num_steps).expect("unwrap failed");
553
554 self.pgd_attack(inputs, gradients, &modified_config)
555 }
556
557 fn mim_attack(
559 &self,
560 inputs: &Array<A, D>,
561 gradients: &Array<A, D>,
562 config: &AdversarialConfig<A>,
563 ) -> Result<Array<A, D>> {
564 let mut adversarial = inputs.clone();
565 let mut momentum = Array::zeros(inputs.raw_dim());
566 let decayfactor = A::from(1.0).expect("unwrap failed"); for _ in 0..config.num_steps {
569 let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
571 let normalized_gradients = if grad_norm > A::zero() {
572 gradients.mapv(|x| x / grad_norm)
573 } else {
574 gradients.clone()
575 };
576
577 Zip::from(&mut momentum)
578 .and(&normalized_gradients)
579 .for_each(|m, &g| {
580 *m = decayfactor * *m + g;
581 });
582
583 let momentum_signs =
585 momentum.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
586
587 Zip::from(&mut adversarial)
588 .and(&momentum_signs)
589 .for_each(|x, &sign| {
590 *x = *x + config.step_size * sign;
591 });
592
593 Zip::from(&mut adversarial)
595 .and(inputs)
596 .for_each(|adv, &orig| {
597 let diff = *adv - orig;
598 let clamped_diff = A::max(-config.epsilon, A::min(config.epsilon, diff));
599 *adv = orig + clamped_diff;
600 });
601 }
602
603 Ok(adversarial)
604 }
605
606 pub fn filter_samples(&self, sampleids: &[usize]) -> Vec<usize> {
608 sampleids
609 .iter()
610 .copied()
611 .filter(|&id| self.should_include_sample(id))
612 .collect()
613 }
614
615 pub fn get_performance_history(&self) -> &VecDeque<A> {
617 &self.performance_history
618 }
619
620 pub fn step_count(&self) -> usize {
622 self.step_count
623 }
624
625 pub fn reset(&mut self) {
627 self.step_count = 0;
628 self.performance_history.clear();
629 self.sample_weights.clear();
630 self.current_difficulty = match &self.strategy {
631 CurriculumStrategy::Linear {
632 start_difficulty, ..
633 } => *start_difficulty,
634 CurriculumStrategy::Exponential {
635 start_difficulty, ..
636 } => *start_difficulty,
637 CurriculumStrategy::PerformanceBased { .. } => 0.1,
638 CurriculumStrategy::Custom {
639 default_difficulty, ..
640 } => *default_difficulty,
641 };
642 }
643
644 pub fn export_state(&self) -> CurriculumState<A> {
646 CurriculumState {
647 current_difficulty: self.current_difficulty,
648 step_count: self.step_count,
649 performance_history: self.performance_history.clone(),
650 sample_weights: self.sample_weights.clone(),
651 has_adversarial: self.adversarial_config.is_some(),
652 }
653 }
654}
655
/// Snapshot of a [`CurriculumManager`]'s progress, produced by `export_state`.
#[derive(Debug, Clone)]
pub struct CurriculumState<A: Float> {
    /// Difficulty at the time of export.
    pub current_difficulty: f64,
    /// Number of curriculum updates at the time of export.
    pub step_count: usize,
    /// Copy of the retained performance history.
    pub performance_history: VecDeque<A>,
    /// Copy of the computed per-sample importance weights.
    pub sample_weights: HashMap<usize, A>,
    /// Whether adversarial training was enabled at the time of export.
    pub has_adversarial: bool,
}
670
/// Runs one of several [`CurriculumManager`]s at a time and switches to another
/// when its average performance is sufficiently better.
#[derive(Debug)]
pub struct AdaptiveCurriculum<A: Float, D: Dimension> {
    /// Candidate curricula.
    curricula: Vec<CurriculumManager<A, D>>,
    /// Index into `curricula` of the curriculum currently receiving updates.
    active_curriculum: usize,
    /// Performances recorded per curriculum; a curriculum only accumulates
    /// entries while it is active.
    curriculum_performance: Vec<VecDeque<A>>,
    /// Margin by which another curriculum's average must exceed the active one's
    /// before switching.
    switchthreshold: A,
    /// Updates that must pass after a switch before switching is considered
    /// again (set to 100 by `new`).
    min_steps_before_switch: usize,
    /// Updates since the last switch (or since construction).
    steps_since_switch: usize,
    /// Ties the dimension type `D` to the struct without storing a value of it.
    _phantom: PhantomData<D>,
}
689
690impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> AdaptiveCurriculum<A, D> {
691 pub fn new(curricula: Vec<CurriculumManager<A, D>>, switchthreshold: A) -> Self {
693 let num_curricula = curricula.len();
694 Self {
695 curricula,
696 active_curriculum: 0,
697 curriculum_performance: vec![VecDeque::new(); num_curricula],
698 switchthreshold,
699 min_steps_before_switch: 100,
700 steps_since_switch: 0,
701 _phantom: PhantomData,
702 }
703 }
704
705 pub fn update(&mut self, performance: A) -> Result<()> {
707 self.curricula[self.active_curriculum].update_curriculum(performance)?;
709 self.curriculum_performance[self.active_curriculum].push_back(performance);
710 self.steps_since_switch += 1;
711
712 if self.steps_since_switch >= self.min_steps_before_switch {
714 self.consider_curriculum_switch()?;
715 }
716
717 Ok(())
718 }
719
720 fn consider_curriculum_switch(&mut self) -> Result<()> {
722 let current_performance = self.get_average_performance(self.active_curriculum);
723 let mut best_curriculum = self.active_curriculum;
724 let mut best_performance = current_performance;
725
726 for (i, _) in self.curricula.iter().enumerate() {
728 if i != self.active_curriculum {
729 let perf = self.get_average_performance(i);
730 if perf > best_performance + self.switchthreshold {
731 best_performance = perf;
732 best_curriculum = i;
733 }
734 }
735 }
736
737 if best_curriculum != self.active_curriculum {
739 self.active_curriculum = best_curriculum;
740 self.steps_since_switch = 0;
741 }
742
743 Ok(())
744 }
745
746 fn get_average_performance(&self, curriculumidx: usize) -> A {
748 let perf_history = &self.curriculum_performance[curriculumidx];
749 if perf_history.is_empty() {
750 A::zero()
751 } else {
752 let sum = perf_history.iter().fold(A::zero(), |acc, &perf| acc + perf);
753 sum / A::from(perf_history.len()).expect("unwrap failed")
754 }
755 }
756
757 pub fn active_curriculum(&self) -> &CurriculumManager<A, D> {
759 &self.curricula[self.active_curriculum]
760 }
761
762 pub fn active_curriculum_mut(&mut self) -> &mut CurriculumManager<A, D> {
764 &mut self.curricula[self.active_curriculum]
765 }
766
767 pub fn active_curriculum_index(&self) -> usize {
769 self.active_curriculum
770 }
771
772 pub fn get_curriculum_comparison(&self) -> Vec<(usize, A)> {
774 (0..self.curricula.len())
775 .map(|i| (i, self.get_average_performance(i)))
776 .collect()
777 }
778}
779
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{Array1, Ix1};

    /// Builds a one-dimensional `f64` curriculum manager.
    fn manager(
        strategy: CurriculumStrategy,
        importance_strategy: ImportanceWeightingStrategy,
    ) -> CurriculumManager<f64, Ix1> {
        CurriculumManager::new(strategy, importance_strategy)
    }

    /// A linear schedule whose difficulty stays pinned at `difficulty`.
    fn constant_linear(difficulty: f64) -> CurriculumStrategy {
        CurriculumStrategy::Linear {
            start_difficulty: difficulty,
            end_difficulty: difficulty,
            num_steps: 10,
        }
    }

    #[test]
    fn test_linear_curriculum() {
        let linear = CurriculumStrategy::Linear {
            start_difficulty: 0.1,
            end_difficulty: 1.0,
            num_steps: 10,
        };
        let mut curriculum = manager(linear, ImportanceWeightingStrategy::Uniform);
        assert_relative_eq!(curriculum.get_current_difficulty(), 0.1, epsilon = 1e-6);

        for _step in 0..5 {
            curriculum.update_curriculum(0.8).expect("unwrap failed");
        }

        // Halfway through the schedule the difficulty has grown but stays in range.
        let difficulty = curriculum.get_current_difficulty();
        assert!(difficulty > 0.1);
        assert!(difficulty <= 1.0);
    }

    #[test]
    fn test_performance_based_curriculum() {
        let performance_driven = CurriculumStrategy::PerformanceBased {
            advance_threshold: 0.8,
            reduce_threshold: 0.4,
            adjustment_step: 0.1,
            window_size: 3,
        };
        let mut curriculum = manager(performance_driven, ImportanceWeightingStrategy::Uniform);
        let before = curriculum.get_current_difficulty();

        // Sustained performance above the advance threshold must raise difficulty.
        for _step in 0..5 {
            curriculum.update_curriculum(0.9).expect("unwrap failed");
        }

        assert!(curriculum.get_current_difficulty() > before);
    }

    #[test]
    fn test_sample_filtering() {
        let mut curriculum = manager(constant_linear(0.5), ImportanceWeightingStrategy::Uniform);
        for &(id, difficulty) in &[(1, 0.3), (2, 0.7), (3, 0.5)] {
            curriculum.set_sample_difficulty(id, difficulty);
        }

        // Sample 2 is harder than the current difficulty of 0.5; sample 4 has no
        // recorded difficulty and is therefore always admitted.
        let filtered = curriculum.filter_samples(&[1, 2, 3, 4]);

        assert_eq!(filtered.len(), 3);
        for kept in [1, 3, 4].iter() {
            assert!(filtered.contains(kept));
        }
        assert!(!filtered.contains(&2));
    }

    #[test]
    fn test_loss_based_importance_weighting() {
        let loss_weighting = ImportanceWeightingStrategy::LossBased {
            temperature: 1.0,
            min_weight: 0.1,
        };
        let mut curriculum = manager(constant_linear(0.5), loss_weighting);

        curriculum
            .compute_sample_weights(&[1, 2, 3], &[0.1, 1.0, 0.5], None, None)
            .expect("unwrap failed");

        let low_loss = curriculum.get_sample_weight(1);
        let high_loss = curriculum.get_sample_weight(2);
        let mid_loss = curriculum.get_sample_weight(3);

        // Weights are ordered the same way as the losses.
        assert!(high_loss > mid_loss);
        assert!(mid_loss > low_loss);
    }

    #[test]
    fn test_adversarial_config() {
        let mut curriculum = manager(constant_linear(0.5), ImportanceWeightingStrategy::Uniform);
        curriculum.enable_adversarial_training(AdversarialConfig {
            epsilon: 0.1,
            num_steps: 5,
            step_size: 0.02,
            attack_type: AdversarialAttack::FGSM,
            adversarial_weight: 0.5,
        });

        let inputs = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, -0.2, 0.3]);
        let adversarial = curriculum
            .generate_adversarial_examples(&inputs, &gradients)
            .expect("unwrap failed");

        assert_ne!(
            adversarial.as_slice().expect("unwrap failed"),
            inputs.as_slice().expect("unwrap failed")
        );

        // Every coordinate stays within the epsilon budget around the original input.
        let within_budget = inputs
            .iter()
            .zip(adversarial.iter())
            .all(|(orig, adv)| (adv - orig).abs() <= 0.1 + 1e-6);
        assert!(within_budget);
    }

    #[test]
    fn test_adaptive_curriculum() {
        let gentle = CurriculumStrategy::Linear {
            start_difficulty: 0.1,
            end_difficulty: 0.5,
            num_steps: 100,
        };
        let steep = CurriculumStrategy::Linear {
            start_difficulty: 0.2,
            end_difficulty: 0.8,
            num_steps: 100,
        };
        let candidates = vec![
            manager(gentle, ImportanceWeightingStrategy::Uniform),
            manager(steep, ImportanceWeightingStrategy::Uniform),
        ];

        let mut adaptive = AdaptiveCurriculum::new(candidates, 0.1);
        assert_eq!(adaptive.active_curriculum_index(), 0);

        for _step in 0..150 {
            adaptive.update(0.7).expect("unwrap failed");
        }

        assert_eq!(adaptive.get_curriculum_comparison().len(), 2);
    }

    #[test]
    fn test_curriculum_state_export() {
        let linear = CurriculumStrategy::Linear {
            start_difficulty: 0.1,
            end_difficulty: 1.0,
            num_steps: 10,
        };
        let mut curriculum = manager(linear, ImportanceWeightingStrategy::Uniform);

        curriculum.update_curriculum(0.8).expect("unwrap failed");
        let snapshot = curriculum.export_state();

        assert_eq!(snapshot.step_count, 1);
        assert_eq!(snapshot.performance_history.len(), 1);
        assert!(!snapshot.has_adversarial);
    }
}