use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::marker::PhantomData;
/// How the global curriculum difficulty evolves as `update_curriculum` is called.
///
/// Difficulties are plain `f64`s. Only [`CurriculumStrategy::PerformanceBased`] clamps them
/// to `[0.0, 1.0]`; the other strategies produce whatever range their parameters span.
#[derive(Debug, Clone, PartialEq)]
pub enum CurriculumStrategy {
    /// Interpolate linearly from `start_difficulty` to `end_difficulty` over `num_steps`
    /// updates, then hold at `end_difficulty`.
    Linear {
        start_difficulty: f64,
        end_difficulty: f64,
        num_steps: usize,
    },
    /// Approach `end_difficulty` as
    /// `start + (end - start) * (1 - exp(-growth_rate * step))`.
    Exponential {
        start_difficulty: f64,
        end_difficulty: f64,
        growth_rate: f64,
    },
    /// Start at 0.1 and, once `window_size` performance values are available, move the
    /// difficulty by `adjustment_step`. It goes up when the windowed mean exceeds
    /// `advance_threshold` and down when it falls below `reduce_threshold`. Higher
    /// performance is treated as better.
    PerformanceBased {
        advance_threshold: f64,
        reduce_threshold: f64,
        adjustment_step: f64,
        window_size: usize,
    },
    /// Look up the difficulty for the current step (the first update is step 1) in
    /// `schedule`, falling back to `default_difficulty` for unlisted steps.
    Custom {
        schedule: HashMap<usize, f64>,
        default_difficulty: f64,
    },
}
/// How `compute_sample_weights` assigns an importance weight to each sample.
#[derive(Debug, Clone, PartialEq)]
pub enum ImportanceWeightingStrategy {
    /// Every sample gets weight 1.
    Uniform,
    /// Softmax of the per-sample losses divided by `temperature`, each weight floored at
    /// `min_weight`. Higher loss gives a higher weight.
    LossBased {
        temperature: f64,
        min_weight: f64,
    },
    /// Softmax of per-sample gradient norms divided by `temperature`, floored at
    /// `min_weight`. Falls back to weight 1 when no gradient norms are supplied.
    GradientNormBased {
        temperature: f64,
        min_weight: f64,
    },
    /// Softmax of per-sample uncertainties divided by `temperature`, floored at
    /// `min_weight`. Falls back to weight 1 when no uncertainties are supplied.
    UncertaintyBased {
        temperature: f64,
        min_weight: f64,
    },
    /// Weight `exp(decayfactor * age)`, where `age = step_count - sample_id` (saturating at 0).
    /// NOTE(review): `decayfactor` must be negative for weights to actually decay with age;
    /// treating the sample id as a birth step is an assumption to confirm with callers.
    AgeBased {
        decayfactor: f64,
    },
}
/// Parameters for generating adversarial examples with an [`AdversarialAttack`].
#[derive(Debug, Clone)]
pub struct AdversarialConfig<A: Float> {
    /// Maximum per-element (L∞) perturbation. FGSM uses it as its single step size; the
    /// iterative attacks clip the accumulated perturbation to `[-epsilon, epsilon]`.
    pub epsilon: A,
    /// Number of iterations for the iterative attacks (PGD, BIM, MIM); FGSM ignores it.
    pub num_steps: usize,
    /// Per-iteration step for PGD and MIM. BIM ignores it and uses `epsilon / num_steps`.
    pub step_size: A,
    /// Which attack `generate_adversarial_examples` runs.
    pub attack_type: AdversarialAttack,
    /// Not read anywhere in this module. NOTE(review): presumably the weight of the
    /// adversarial loss term in the caller's objective — confirm at the call site.
    pub adversarial_weight: A,
}
/// Gradient-sign attack used to craft adversarial examples.
///
/// Every variant works from the single gradient supplied by the caller; the gradient is
/// never recomputed between iterations.
// The upper-case names are the conventional acronyms for these attacks and are part of the
// public API, so Clippy's acronym-casing lint is silenced rather than renaming the variants.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdversarialAttack {
    /// Fast Gradient Sign Method: one step of size `epsilon` along the gradient sign.
    FGSM,
    /// Projected gradient attack: `num_steps` steps of `step_size` along the gradient sign,
    /// each followed by clipping back into the `epsilon` ball around the clean input.
    PGD,
    /// Basic Iterative Method: the projected attack with `step_size = epsilon / num_steps`.
    BIM,
    /// Momentum Iterative Method: like the projected attack, but stepping along the sign of
    /// an accumulated momentum of the normalised gradient.
    MIM,
}
/// Tracks curriculum difficulty, per-sample importance weights and optional adversarial
/// example generation for a training loop.
///
/// `A` is the scalar type of performance values, weights and array elements; `D` is the
/// dimensionality of the arrays passed to the adversarial-example methods.
#[derive(Debug)]
pub struct CurriculumManager<A: Float, D: Dimension> {
    /// Schedule that drives `current_difficulty`.
    strategy: CurriculumStrategy,
    /// Difficulty level that per-sample difficulties are compared against when filtering.
    current_difficulty: f64,
    /// Number of `update_curriculum` calls since construction or the last `reset`.
    step_count: usize,
    /// Performance values passed to `update_curriculum`. Only the performance-based strategy
    /// trims this (to its window); for the other strategies every value is kept.
    performance_history: VecDeque<A>,
    /// Per-sample difficulty keyed by sample id; samples without an entry are always included.
    sample_difficulties: HashMap<usize, f64>,
    /// How `compute_sample_weights` turns per-sample statistics into weights.
    importance_strategy: ImportanceWeightingStrategy,
    /// Most recently computed weight per sample id; ids without an entry read as 1.
    sample_weights: HashMap<usize, A>,
    /// Adversarial attack settings; `None` means adversarial generation is disabled.
    adversarial_config: Option<AdversarialConfig<A>>,
    /// Makes the struct generic over `D`, which otherwise appears only in method signatures.
    _phantom: PhantomData<D>,
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> CurriculumManager<A, D> {
pub fn new(
strategy: CurriculumStrategy,
importance_strategy: ImportanceWeightingStrategy,
) -> Self {
let initial_difficulty = match &strategy {
CurriculumStrategy::Linear {
start_difficulty, ..
} => *start_difficulty,
CurriculumStrategy::Exponential {
start_difficulty, ..
} => *start_difficulty,
CurriculumStrategy::PerformanceBased { .. } => 0.1, CurriculumStrategy::Custom {
default_difficulty, ..
} => *default_difficulty,
};
Self {
strategy,
current_difficulty: initial_difficulty,
step_count: 0,
performance_history: VecDeque::new(),
sample_difficulties: HashMap::new(),
importance_strategy,
sample_weights: HashMap::new(),
adversarial_config: None,
_phantom: PhantomData,
}
}
pub fn enable_adversarial_training(&mut self, config: AdversarialConfig<A>) {
self.adversarial_config = Some(config);
}
pub fn disable_adversarial_training(&mut self) {
self.adversarial_config = None;
}
pub fn update_curriculum(&mut self, performance: A) -> Result<()> {
self.performance_history.push_back(performance);
self.step_count += 1;
match &self.strategy {
CurriculumStrategy::Linear {
start_difficulty,
end_difficulty,
num_steps,
} => {
let progress = (self.step_count as f64) / (*num_steps as f64);
let progress = progress.min(1.0);
self.current_difficulty =
start_difficulty + progress * (end_difficulty - start_difficulty);
}
CurriculumStrategy::Exponential {
start_difficulty,
end_difficulty,
growth_rate,
} => {
let progress = 1.0 - (-growth_rate * self.step_count as f64).exp();
self.current_difficulty =
start_difficulty + progress * (end_difficulty - start_difficulty);
}
CurriculumStrategy::PerformanceBased {
advance_threshold,
reduce_threshold,
adjustment_step,
window_size,
} => {
if self.performance_history.len() >= *window_size {
while self.performance_history.len() > *window_size {
self.performance_history.pop_front();
}
let avg_performance = self
.performance_history
.iter()
.fold(A::zero(), |acc, &perf| acc + perf)
/ A::from(self.performance_history.len()).expect("unwrap failed");
let avg_perf_f64 = avg_performance.to_f64().unwrap_or(0.0);
if avg_perf_f64 > *advance_threshold {
self.current_difficulty =
(self.current_difficulty + adjustment_step).min(1.0);
} else if avg_perf_f64 < *reduce_threshold {
self.current_difficulty =
(self.current_difficulty - adjustment_step).max(0.0);
}
}
}
CurriculumStrategy::Custom {
schedule,
default_difficulty,
} => {
self.current_difficulty = schedule
.get(&self.step_count)
.copied()
.unwrap_or(*default_difficulty);
}
}
Ok(())
}
pub fn set_sample_difficulty(&mut self, sampleid: usize, difficulty: f64) {
self.sample_difficulties.insert(sampleid, difficulty);
}
pub fn should_include_sample(&self, sampleid: usize) -> bool {
if let Some(&sample_difficulty) = self.sample_difficulties.get(&sampleid) {
sample_difficulty <= self.current_difficulty
} else {
true }
}
pub fn get_current_difficulty(&self) -> f64 {
self.current_difficulty
}
pub fn compute_sample_weights(
&mut self,
sampleids: &[usize],
losses: &[A],
gradient_norms: Option<&[A]>,
uncertainties: Option<&[A]>,
) -> Result<()> {
if sampleids.len() != losses.len() {
return Err(OptimError::DimensionMismatch(
"Sample IDs and losses must have same length".to_string(),
));
}
match &self.importance_strategy {
ImportanceWeightingStrategy::Uniform => {
let uniform_weight = A::one();
for &sampleid in sampleids {
self.sample_weights.insert(sampleid, uniform_weight);
}
}
ImportanceWeightingStrategy::LossBased {
temperature,
min_weight,
} => {
self.compute_loss_based_weights(sampleids, losses, *temperature, *min_weight)?;
}
ImportanceWeightingStrategy::GradientNormBased {
temperature,
min_weight,
} => {
if let Some(grad_norms) = gradient_norms {
self.compute_gradient_norm_weights(
sampleids,
grad_norms,
*temperature,
*min_weight,
)?;
} else {
for &sampleid in sampleids {
self.sample_weights.insert(sampleid, A::one());
}
}
}
ImportanceWeightingStrategy::UncertaintyBased {
temperature,
min_weight,
} => {
if let Some(uncertainties_array) = uncertainties {
self.compute_uncertainty_weights(
sampleids,
uncertainties_array,
*temperature,
*min_weight,
)?;
} else {
for &sampleid in sampleids {
self.sample_weights.insert(sampleid, A::one());
}
}
}
ImportanceWeightingStrategy::AgeBased { decayfactor } => {
self.compute_age_based_weights(sampleids, *decayfactor)?;
}
}
Ok(())
}
fn compute_loss_based_weights(
&mut self,
sampleids: &[usize],
losses: &[A],
temperature: f64,
min_weight: f64,
) -> Result<()> {
let temp = A::from(temperature).expect("unwrap failed");
let min_w = A::from(min_weight).expect("unwrap failed");
let max_loss = losses.iter().fold(A::neg_infinity(), |a, &b| A::max(a, b));
let mut unnormalized_weights = Vec::new();
for &loss in losses {
let normalized_loss = (loss - max_loss) / temp;
unnormalized_weights.push(A::exp(normalized_loss));
}
let sum_weights: A = unnormalized_weights
.iter()
.fold(A::zero(), |acc, &w| acc + w);
for (i, &sampleid) in sampleids.iter().enumerate() {
let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
self.sample_weights.insert(sampleid, weight);
}
Ok(())
}
fn compute_gradient_norm_weights(
&mut self,
sampleids: &[usize],
gradient_norms: &[A],
temperature: f64,
min_weight: f64,
) -> Result<()> {
let temp = A::from(temperature).expect("unwrap failed");
let min_w = A::from(min_weight).expect("unwrap failed");
let max_norm = gradient_norms
.iter()
.fold(A::neg_infinity(), |a, &b| A::max(a, b));
let mut unnormalized_weights = Vec::new();
for &norm in gradient_norms {
let normalized_norm = (norm - max_norm) / temp;
unnormalized_weights.push(A::exp(normalized_norm));
}
let sum_weights: A = unnormalized_weights
.iter()
.fold(A::zero(), |acc, &w| acc + w);
for (i, &sampleid) in sampleids.iter().enumerate() {
let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
self.sample_weights.insert(sampleid, weight);
}
Ok(())
}
fn compute_uncertainty_weights(
&mut self,
sampleids: &[usize],
uncertainties: &[A],
temperature: f64,
min_weight: f64,
) -> Result<()> {
let temp = A::from(temperature).expect("unwrap failed");
let min_w = A::from(min_weight).expect("unwrap failed");
let max_uncertainty = uncertainties
.iter()
.fold(A::neg_infinity(), |a, &b| A::max(a, b));
let mut unnormalized_weights = Vec::new();
for &uncertainty in uncertainties {
let normalized_uncertainty = (uncertainty - max_uncertainty) / temp;
unnormalized_weights.push(A::exp(normalized_uncertainty));
}
let sum_weights: A = unnormalized_weights
.iter()
.fold(A::zero(), |acc, &w| acc + w);
for (i, &sampleid) in sampleids.iter().enumerate() {
let weight = A::max(min_w, unnormalized_weights[i] / sum_weights);
self.sample_weights.insert(sampleid, weight);
}
Ok(())
}
fn compute_age_based_weights(&mut self, sampleids: &[usize], decayfactor: f64) -> Result<()> {
let decay = A::from(decayfactor).expect("unwrap failed");
for &sampleid in sampleids {
let age = A::from(self.step_count.saturating_sub(sampleid)).expect("unwrap failed");
let weight = A::exp(decay * age);
self.sample_weights.insert(sampleid, weight);
}
Ok(())
}
pub fn get_sample_weight(&self, sampleid: usize) -> A {
self.sample_weights
.get(&sampleid)
.copied()
.unwrap_or_else(|| A::one())
}
pub fn generate_adversarial_examples(
&self,
inputs: &Array<A, D>,
gradients: &Array<A, D>,
) -> Result<Array<A, D>> {
if let Some(config) = &self.adversarial_config {
match config.attack_type {
AdversarialAttack::FGSM => self.fgsm_attack(inputs, gradients, config),
AdversarialAttack::PGD => self.pgd_attack(inputs, gradients, config),
AdversarialAttack::BIM => self.bim_attack(inputs, gradients, config),
AdversarialAttack::MIM => self.mim_attack(inputs, gradients, config),
}
} else {
Ok(inputs.clone()) }
}
fn fgsm_attack(
&self,
inputs: &Array<A, D>,
gradients: &Array<A, D>,
config: &AdversarialConfig<A>,
) -> Result<Array<A, D>> {
let mut adversarial = inputs.clone();
let sign_gradients = gradients.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
Zip::from(&mut adversarial)
.and(&sign_gradients)
.for_each(|x, &sign| {
*x = *x + config.epsilon * sign;
});
Ok(adversarial)
}
fn pgd_attack(
&self,
inputs: &Array<A, D>,
gradients: &Array<A, D>,
config: &AdversarialConfig<A>,
) -> Result<Array<A, D>> {
let mut adversarial = inputs.clone();
for _ in 0..config.num_steps {
let sign_gradients =
gradients.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
Zip::from(&mut adversarial)
.and(&sign_gradients)
.for_each(|x, &sign| {
*x = *x + config.step_size * sign;
});
Zip::from(&mut adversarial)
.and(inputs)
.for_each(|adv, &orig| {
let diff = *adv - orig;
let clamped_diff = A::max(-config.epsilon, A::min(config.epsilon, diff));
*adv = orig + clamped_diff;
});
}
Ok(adversarial)
}
fn bim_attack(
&self,
inputs: &Array<A, D>,
gradients: &Array<A, D>,
config: &AdversarialConfig<A>,
) -> Result<Array<A, D>> {
let mut modified_config = config.clone();
modified_config.step_size =
config.epsilon / A::from(config.num_steps).expect("unwrap failed");
self.pgd_attack(inputs, gradients, &modified_config)
}
fn mim_attack(
&self,
inputs: &Array<A, D>,
gradients: &Array<A, D>,
config: &AdversarialConfig<A>,
) -> Result<Array<A, D>> {
let mut adversarial = inputs.clone();
let mut momentum = Array::zeros(inputs.raw_dim());
let decayfactor = A::from(1.0).expect("unwrap failed");
for _ in 0..config.num_steps {
let grad_norm = gradients.mapv(|x| x * x).sum().sqrt();
let normalized_gradients = if grad_norm > A::zero() {
gradients.mapv(|x| x / grad_norm)
} else {
gradients.clone()
};
Zip::from(&mut momentum)
.and(&normalized_gradients)
.for_each(|m, &g| {
*m = decayfactor * *m + g;
});
let momentum_signs =
momentum.mapv(|x| if x >= A::zero() { A::one() } else { -A::one() });
Zip::from(&mut adversarial)
.and(&momentum_signs)
.for_each(|x, &sign| {
*x = *x + config.step_size * sign;
});
Zip::from(&mut adversarial)
.and(inputs)
.for_each(|adv, &orig| {
let diff = *adv - orig;
let clamped_diff = A::max(-config.epsilon, A::min(config.epsilon, diff));
*adv = orig + clamped_diff;
});
}
Ok(adversarial)
}
pub fn filter_samples(&self, sampleids: &[usize]) -> Vec<usize> {
sampleids
.iter()
.copied()
.filter(|&id| self.should_include_sample(id))
.collect()
}
pub fn get_performance_history(&self) -> &VecDeque<A> {
&self.performance_history
}
pub fn step_count(&self) -> usize {
self.step_count
}
pub fn reset(&mut self) {
self.step_count = 0;
self.performance_history.clear();
self.sample_weights.clear();
self.current_difficulty = match &self.strategy {
CurriculumStrategy::Linear {
start_difficulty, ..
} => *start_difficulty,
CurriculumStrategy::Exponential {
start_difficulty, ..
} => *start_difficulty,
CurriculumStrategy::PerformanceBased { .. } => 0.1,
CurriculumStrategy::Custom {
default_difficulty, ..
} => *default_difficulty,
};
}
pub fn export_state(&self) -> CurriculumState<A> {
CurriculumState {
current_difficulty: self.current_difficulty,
step_count: self.step_count,
performance_history: self.performance_history.clone(),
sample_weights: self.sample_weights.clone(),
has_adversarial: self.adversarial_config.is_some(),
}
}
}
/// Snapshot of a [`CurriculumManager`]'s training state, as returned by `export_state`.
#[derive(Debug, Clone)]
pub struct CurriculumState<A: Float> {
    /// Difficulty at the time of the snapshot.
    pub current_difficulty: f64,
    /// Number of curriculum updates since construction or the last reset.
    pub step_count: usize,
    /// Copy of the recorded performance values.
    pub performance_history: VecDeque<A>,
    /// Copy of the per-sample weights, keyed by sample id.
    pub sample_weights: HashMap<usize, A>,
    /// Whether adversarial generation was enabled; the configuration itself is not captured.
    pub has_adversarial: bool,
}
/// Drives one of several [`CurriculumManager`]s at a time. It may switch to another
/// candidate whose average recorded performance beats the current best by a margin.
#[derive(Debug)]
pub struct AdaptiveCurriculum<A: Float, D: Dimension> {
    /// Candidate curricula; only the one at `active_curriculum` receives updates.
    curricula: Vec<CurriculumManager<A, D>>,
    /// Index into `curricula` of the curriculum currently receiving updates.
    active_curriculum: usize,
    /// Performance values per curriculum, recorded only while that curriculum was active.
    curriculum_performance: Vec<VecDeque<A>>,
    /// Margin by which a candidate's average must exceed the best so far to be chosen.
    switchthreshold: A,
    /// Minimum updates since the last switch before switching is considered (`new` sets 100).
    min_steps_before_switch: usize,
    /// Updates applied since the last switch (or since construction).
    steps_since_switch: usize,
    /// Type marker for `D`; `curricula` already mentions `D`, so this is not strictly needed.
    _phantom: PhantomData<D>,
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> AdaptiveCurriculum<A, D> {
    /// Wraps `curricula`, starting with the first one active.
    ///
    /// `switchthreshold` is the margin a candidate's average performance must beat the best
    /// one by before a switch happens. Switching is considered only after 100 updates since
    /// the last switch. An empty `curricula` is accepted here, but [`Self::update`] and the
    /// active-curriculum accessors then panic on indexing.
    pub fn new(curricula: Vec<CurriculumManager<A, D>>, switchthreshold: A) -> Self {
        let curriculum_performance = curricula.iter().map(|_| VecDeque::new()).collect();
        Self {
            curricula,
            active_curriculum: 0,
            curriculum_performance,
            switchthreshold,
            min_steps_before_switch: 100,
            steps_since_switch: 0,
            _phantom: PhantomData,
        }
    }

    /// Feeds `performance` to the active curriculum, records it, and possibly switches.
    ///
    /// # Errors
    /// Propagates any error from [`CurriculumManager::update_curriculum`].
    ///
    /// # Panics
    /// If no curricula were supplied to [`Self::new`].
    pub fn update(&mut self, performance: A) -> Result<()> {
        let active = self.active_curriculum;
        self.curricula[active].update_curriculum(performance)?;
        self.curriculum_performance[active].push_back(performance);
        self.steps_since_switch += 1;
        if self.steps_since_switch < self.min_steps_before_switch {
            return Ok(());
        }
        self.consider_curriculum_switch()
    }

    /// Moves to the candidate with the best average performance. A candidate must beat
    /// the best average seen so far by more than `switchthreshold`. Untried candidates
    /// average 0.
    fn consider_curriculum_switch(&mut self) -> Result<()> {
        let current = self.active_curriculum;
        let (mut winner, mut winner_score) = (current, self.get_average_performance(current));
        for candidate in (0..self.curricula.len()).filter(|&idx| idx != current) {
            let score = self.get_average_performance(candidate);
            if score > winner_score + self.switchthreshold {
                winner = candidate;
                winner_score = score;
            }
        }
        if winner != current {
            self.active_curriculum = winner;
            self.steps_since_switch = 0;
        }
        Ok(())
    }

    /// Mean of the values recorded while `curriculumidx` was active, or zero if none exist.
    fn get_average_performance(&self, curriculumidx: usize) -> A {
        let history = &self.curriculum_performance[curriculumidx];
        match history.len() {
            0 => A::zero(),
            count => {
                let total = history.iter().fold(A::zero(), |acc, &value| acc + value);
                total / A::from(count).expect("unwrap failed")
            }
        }
    }

    /// The curriculum currently receiving updates.
    pub fn active_curriculum(&self) -> &CurriculumManager<A, D> {
        &self.curricula[self.active_curriculum]
    }

    /// Mutable access to the curriculum currently receiving updates.
    pub fn active_curriculum_mut(&mut self) -> &mut CurriculumManager<A, D> {
        &mut self.curricula[self.active_curriculum]
    }

    /// Index of the curriculum currently receiving updates.
    pub fn active_curriculum_index(&self) -> usize {
        self.active_curriculum
    }

    /// `(index, average performance)` for every curriculum, in index order.
    pub fn get_curriculum_comparison(&self) -> Vec<(usize, A)> {
        self.curricula
            .iter()
            .enumerate()
            .map(|(idx, _)| (idx, self.get_average_performance(idx)))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{Array1, Ix1};

    /// Builds the 1-D `f64` manager every test below works with.
    fn make_manager(
        strategy: CurriculumStrategy,
        importance: ImportanceWeightingStrategy,
    ) -> CurriculumManager<f64, Ix1> {
        CurriculumManager::new(strategy, importance)
    }

    /// A linear schedule whose start and end coincide, i.e. a constant difficulty.
    fn flat_linear(difficulty: f64) -> CurriculumStrategy {
        CurriculumStrategy::Linear {
            start_difficulty: difficulty,
            end_difficulty: difficulty,
            num_steps: 10,
        }
    }

    #[test]
    fn test_linear_curriculum() {
        let mut curriculum = make_manager(
            CurriculumStrategy::Linear {
                start_difficulty: 0.1,
                end_difficulty: 1.0,
                num_steps: 10,
            },
            ImportanceWeightingStrategy::Uniform,
        );
        assert_relative_eq!(curriculum.get_current_difficulty(), 0.1, epsilon = 1e-6);

        // Half-way through the schedule the difficulty has risen but not overshot.
        for _ in 0..5 {
            curriculum.update_curriculum(0.8).expect("unwrap failed");
        }
        let difficulty = curriculum.get_current_difficulty();
        assert!(difficulty > 0.1);
        assert!(difficulty <= 1.0);
    }

    #[test]
    fn test_performance_based_curriculum() {
        let mut curriculum = make_manager(
            CurriculumStrategy::PerformanceBased {
                advance_threshold: 0.8,
                reduce_threshold: 0.4,
                adjustment_step: 0.1,
                window_size: 3,
            },
            ImportanceWeightingStrategy::Uniform,
        );
        let before = curriculum.get_current_difficulty();

        // Performance consistently above the advance threshold pushes difficulty up.
        for _ in 0..5 {
            curriculum.update_curriculum(0.9).expect("unwrap failed");
        }
        assert!(curriculum.get_current_difficulty() > before);
    }

    #[test]
    fn test_sample_filtering() {
        let mut curriculum = make_manager(flat_linear(0.5), ImportanceWeightingStrategy::Uniform);
        for (id, difficulty) in [(1, 0.3), (2, 0.7), (3, 0.5)] {
            curriculum.set_sample_difficulty(id, difficulty);
        }

        // Sample 4 has no recorded difficulty and is always kept; only 2 (0.7 > 0.5) is dropped.
        let filtered = curriculum.filter_samples(&[1, 2, 3, 4]);
        assert_eq!(filtered.len(), 3);
        assert!(filtered.contains(&1));
        assert!(filtered.contains(&3));
        assert!(filtered.contains(&4));
        assert!(!filtered.contains(&2));
    }

    #[test]
    fn test_loss_based_importance_weighting() {
        let mut curriculum = make_manager(
            flat_linear(0.5),
            ImportanceWeightingStrategy::LossBased {
                temperature: 1.0,
                min_weight: 0.1,
            },
        );
        curriculum
            .compute_sample_weights(&[1, 2, 3], &[0.1, 1.0, 0.5], None, None)
            .expect("unwrap failed");

        // Weights follow the losses: sample 2 (1.0) > sample 3 (0.5) > sample 1 (0.1).
        let low = curriculum.get_sample_weight(1);
        let high = curriculum.get_sample_weight(2);
        let mid = curriculum.get_sample_weight(3);
        assert!(high > mid);
        assert!(mid > low);
    }

    #[test]
    fn test_adversarial_config() {
        let mut curriculum = make_manager(flat_linear(0.5), ImportanceWeightingStrategy::Uniform);
        curriculum.enable_adversarial_training(AdversarialConfig {
            epsilon: 0.1,
            num_steps: 5,
            step_size: 0.02,
            attack_type: AdversarialAttack::FGSM,
            adversarial_weight: 0.5,
        });

        let inputs = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, -0.2, 0.3]);
        let adversarial = curriculum
            .generate_adversarial_examples(&inputs, &gradients)
            .expect("unwrap failed");

        assert_ne!(
            adversarial.as_slice().expect("unwrap failed"),
            inputs.as_slice().expect("unwrap failed")
        );
        // The perturbation must stay inside the epsilon ball (with float slack).
        for (orig, adv) in inputs.iter().zip(adversarial.iter()) {
            assert!((adv - orig).abs() <= 0.1 + 1e-6);
        }
    }

    #[test]
    fn test_adaptive_curriculum() {
        let importance = ImportanceWeightingStrategy::Uniform;
        let gentle = make_manager(
            CurriculumStrategy::Linear {
                start_difficulty: 0.1,
                end_difficulty: 0.5,
                num_steps: 100,
            },
            importance.clone(),
        );
        let steep = make_manager(
            CurriculumStrategy::Linear {
                start_difficulty: 0.2,
                end_difficulty: 0.8,
                num_steps: 100,
            },
            importance,
        );

        let mut adaptive = AdaptiveCurriculum::new(vec![gentle, steep], 0.1);
        assert_eq!(adaptive.active_curriculum_index(), 0);

        for _ in 0..150 {
            adaptive.update(0.7).expect("unwrap failed");
        }
        assert_eq!(adaptive.get_curriculum_comparison().len(), 2);
    }

    #[test]
    fn test_curriculum_state_export() {
        let mut curriculum = make_manager(
            CurriculumStrategy::Linear {
                start_difficulty: 0.1,
                end_difficulty: 1.0,
                num_steps: 10,
            },
            ImportanceWeightingStrategy::Uniform,
        );
        curriculum.update_curriculum(0.8).expect("unwrap failed");

        let snapshot = curriculum.export_state();
        assert_eq!(snapshot.step_count, 1);
        assert_eq!(snapshot.performance_history.len(), 1);
        assert!(!snapshot.has_adversarial);
    }
}