#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use super::core::{rng_utils, Sampler, SamplerIterator};
/// Pacing schedule that maps training progress to the maximum normalized
/// difficulty a `CurriculumSampler` admits.
#[derive(Debug, Clone, PartialEq)]
pub enum CurriculumStrategy {
    /// Threshold grows in lockstep with progress.
    Linear,
    /// Threshold follows `(base^progress - 1) / (base - 1)`; a base `<= 1.0`
    /// falls back to linear pacing.
    Exponential { base: f64 },
    /// Threshold is the fraction of the listed epoch numbers that the current
    /// epoch has reached; an empty list admits everything.
    Step { thresholds: Vec<usize> },
    /// Threshold is fixed at `1.0`, so every sample is admitted from the start.
    AntiCurriculum,
    /// Threshold is `progress * lambda`, capped at `1.0`.
    SelfPaced { lambda: f64 },
}

impl Default for CurriculumStrategy {
    /// The default schedule is [`CurriculumStrategy::Linear`].
    fn default() -> Self {
        Self::Linear
    }
}
/// Curriculum-learning sampler: admits only the samples whose normalized
/// difficulty does not exceed a threshold that rises as epochs advance.
///
/// Difficulty scores are computed once per index by `difficulty_fn` and cached;
/// call [`CurriculumSampler::update_difficulties`] to recompute them. The epoch
/// is driven externally through [`CurriculumSampler::set_epoch`].
#[derive(Clone)]
pub struct CurriculumSampler<F> {
    /// Cached difficulty score for each dataset index.
    difficulties: Vec<f64>,
    /// Scoring function used to (re)compute `difficulties`.
    difficulty_fn: F,
    /// Epoch the sampler is currently at.
    current_epoch: usize,
    /// Number of epochs the curriculum is paced over.
    total_epochs: usize,
    /// Strategy mapping progress to a difficulty threshold.
    curriculum_strategy: CurriculumStrategy,
    /// Optional seed handed to the index shuffler when iterating.
    generator: Option<u64>,
}

impl<F> CurriculumSampler<F>
where
    F: Fn(usize) -> f64 + Send + Clone,
{
    /// Creates a sampler over `dataset_size` items, scoring every index with
    /// `difficulty_fn` immediately.
    ///
    /// The sampler starts at epoch 0 with no shuffle seed.
    pub fn new(
        dataset_size: usize,
        difficulty_fn: F,
        total_epochs: usize,
        strategy: CurriculumStrategy,
    ) -> Self {
        let difficulties: Vec<f64> = (0..dataset_size).map(&difficulty_fn).collect();
        Self {
            difficulties,
            difficulty_fn,
            current_epoch: 0,
            total_epochs,
            curriculum_strategy: strategy,
            generator: None,
        }
    }

    /// Creates a sampler from precomputed difficulty scores.
    ///
    /// The scoring function is set to `F::default()`. NOTE(review): a later call
    /// to [`CurriculumSampler::update_difficulties`] rescores every index with
    /// that default function and discards the scores supplied here.
    pub fn from_difficulties(
        difficulties: Vec<f64>,
        total_epochs: usize,
        strategy: CurriculumStrategy,
    ) -> Self
    where
        F: Default,
    {
        Self {
            difficulty_fn: F::default(),
            difficulties,
            current_epoch: 0,
            total_epochs,
            curriculum_strategy: strategy,
            generator: None,
        }
    }

    /// Sets the current epoch. Epochs past `total_epochs` are accepted;
    /// [`CurriculumSampler::progress`] caps at `1.0`.
    pub fn set_epoch(&mut self, epoch: usize) {
        self.current_epoch = epoch;
    }

    /// Builder-style setter for the seed used when shuffling indices.
    pub fn with_generator(mut self, seed: u64) -> Self {
        self.generator = Some(seed);
        self
    }

    /// Epoch the sampler is currently at.
    pub fn current_epoch(&self) -> usize {
        self.current_epoch
    }

    /// Number of epochs the curriculum is paced over.
    pub fn total_epochs(&self) -> usize {
        self.total_epochs
    }

    /// Active pacing strategy.
    pub fn strategy(&self) -> &CurriculumStrategy {
        &self.curriculum_strategy
    }

    /// Cached (un-normalized) difficulty score for every index.
    pub fn difficulties(&self) -> &[f64] {
        &self.difficulties
    }

    /// Shuffle seed, if one was set.
    pub fn generator(&self) -> Option<u64> {
        self.generator
    }

    /// Fraction of the curriculum completed, in `[0, 1]`.
    ///
    /// Epoch `total_epochs - 1` (and anything later) maps to `1.0`. When
    /// `total_epochs <= 1` the curriculum is treated as fully progressed.
    pub fn progress(&self) -> f64 {
        if self.total_epochs <= 1 {
            1.0
        } else {
            (self.current_epoch as f64 / (self.total_epochs - 1) as f64).min(1.0)
        }
    }

    /// Maximum normalized difficulty (`0.0` = easiest, `1.0` = hardest) that is
    /// admitted at the current epoch.
    ///
    /// With `p = self.progress()`:
    /// - `Linear`: `p`.
    /// - `Exponential { base }`: `(base^p - 1) / (base - 1)`; a base that is not a
    ///   finite value greater than `1.0` falls back to linear pacing.
    /// - `Step { thresholds }`: fraction of the listed epochs already reached
    ///   (`1.0` when the list is empty).
    /// - `AntiCurriculum`: always `1.0`.
    /// - `SelfPaced { lambda }`: `p * lambda`, clamped to `[0, 1]`.
    pub fn get_difficulty_threshold(&self) -> f64 {
        let progress = self.progress();
        match &self.curriculum_strategy {
            CurriculumStrategy::Linear => progress,
            CurriculumStrategy::Exponential { base } => {
                let base = *base;
                // A base <= 1, NaN or infinity would give a flat, inverted or NaN
                // curve (NaN admits nothing), so such bases use linear pacing.
                if base > 1.0 && base.is_finite() {
                    (base.powf(progress) - 1.0) / (base - 1.0)
                } else {
                    progress
                }
            }
            CurriculumStrategy::Step { thresholds } => {
                if thresholds.is_empty() {
                    1.0
                } else {
                    // Count first and divide once: adding `1/len` per threshold
                    // accumulates rounding error (ten steps of 0.1 sum to
                    // 0.9999999999999999), which would keep the hardest sample
                    // out even after every threshold has been passed.
                    let reached = thresholds
                        .iter()
                        .filter(|&&epoch| self.current_epoch >= epoch)
                        .count();
                    reached as f64 / thresholds.len() as f64
                }
            }
            // NOTE(review): a constant 1.0 means anti-curriculum admits every
            // sample from epoch 0 and never emphasizes hard examples first; the
            // existing tests pin this behavior, so it is left unchanged.
            CurriculumStrategy::AntiCurriculum => 1.0,
            // `min` before `max` keeps the original NaN-lambda result (1.0) while
            // stopping a negative lambda from emptying the curriculum.
            CurriculumStrategy::SelfPaced { lambda } => (progress * *lambda).min(1.0).max(0.0),
        }
    }

    /// Indices (ascending) admitted at the current epoch.
    ///
    /// Difficulties are min-max normalized to `[0, 1]`; an index is admitted when
    /// its normalized difficulty is `<=` [`CurriculumSampler::get_difficulty_threshold`].
    /// When all scores are equal (zero spread) every index is admitted. Indices
    /// with a NaN score are never admitted while there is a positive spread.
    pub fn get_curriculum_indices(&self) -> Vec<usize> {
        if self.difficulties.is_empty() {
            return Vec::new();
        }
        let (min_difficulty, max_difficulty) = self.difficulties.iter().fold(
            (f64::INFINITY, f64::NEG_INFINITY),
            |(lo, hi), &d| (lo.min(d), hi.max(d)),
        );
        let range = max_difficulty - min_difficulty;
        if range > 0.0 {
            let threshold = self.get_difficulty_threshold();
            self.difficulties
                .iter()
                .enumerate()
                .filter(|&(_, &difficulty)| (difficulty - min_difficulty) / range <= threshold)
                .map(|(idx, _)| idx)
                .collect()
        } else {
            // No spread to order by: every sample is equally easy.
            (0..self.difficulties.len()).collect()
        }
    }

    /// Recomputes every cached score by calling `difficulty_fn` on each index.
    pub fn update_difficulties(&mut self) {
        let dataset_size = self.difficulties.len();
        self.difficulties = (0..dataset_size).map(&self.difficulty_fn).collect();
    }

    /// Replaces the pacing strategy.
    pub fn set_strategy(&mut self, strategy: CurriculumStrategy) {
        self.curriculum_strategy = strategy;
    }

    /// Rewinds the sampler to epoch 0.
    pub fn reset(&mut self) {
        self.current_epoch = 0;
    }

    /// `true` once the epoch has reached `total_epochs` or the difficulty
    /// threshold has reached `1.0` (the full difficulty range is admitted).
    pub fn is_complete(&self) -> bool {
        self.current_epoch >= self.total_epochs || self.get_difficulty_threshold() >= 1.0
    }

    /// Snapshot of epoch, progress, threshold and how many samples are admitted.
    ///
    /// `inclusion_ratio` is `0.0` for an empty dataset.
    pub fn curriculum_stats(&self) -> CurriculumStats {
        let total_samples = self.difficulties.len();
        let included_samples = self.get_curriculum_indices().len();
        let inclusion_ratio = if total_samples > 0 {
            included_samples as f64 / total_samples as f64
        } else {
            0.0
        };
        CurriculumStats {
            current_epoch: self.current_epoch,
            total_epochs: self.total_epochs,
            progress: self.progress(),
            difficulty_threshold: self.get_difficulty_threshold(),
            included_samples,
            total_samples,
            inclusion_ratio,
        }
    }
}
impl<F> Sampler for CurriculumSampler<F>
where
    F: Fn(usize) -> f64 + Send + Clone,
{
    type Iter = SamplerIterator;

    /// Yields the indices admitted at the current epoch, shuffled by
    /// `rng_utils::shuffle_indices` using the sampler's optional seed.
    fn iter(&self) -> Self::Iter {
        let mut admitted = self.get_curriculum_indices();
        rng_utils::shuffle_indices(&mut admitted, self.generator);
        SamplerIterator::new(admitted)
    }

    /// Number of indices admitted at the current epoch.
    ///
    /// The admitted set is recomputed on every call.
    fn len(&self) -> usize {
        let admitted = self.get_curriculum_indices();
        admitted.len()
    }
}
/// Snapshot of a curriculum sampler's state, as returned by
/// `CurriculumSampler::curriculum_stats`.
#[derive(Clone, Debug, PartialEq)]
pub struct CurriculumStats {
    /// Epoch the sampler was at when the snapshot was taken.
    pub current_epoch: usize,
    /// Number of epochs the curriculum is paced over.
    pub total_epochs: usize,
    /// Curriculum progress as reported by the sampler's `progress()`.
    pub progress: f64,
    /// Normalized difficulty threshold in effect at that epoch.
    pub difficulty_threshold: f64,
    /// Number of samples admitted at that epoch.
    pub included_samples: usize,
    /// Size of the whole dataset.
    pub total_samples: usize,
    /// `included_samples / total_samples`, or `0.0` for an empty dataset.
    pub inclusion_ratio: f64,
}
/// Attaches `seed` as the shuffle seed when one is given; otherwise returns the
/// sampler unchanged.
fn with_optional_seed<F>(sampler: CurriculumSampler<F>, seed: Option<u64>) -> CurriculumSampler<F>
where
    F: Fn(usize) -> f64 + Send + Clone,
{
    match seed {
        Some(s) => sampler.with_generator(s),
        None => sampler,
    }
}

/// Builds a sampler using the `Linear` strategy; `seed`, when given, is set via
/// `with_generator`.
pub fn linear_curriculum<F>(
    dataset_size: usize,
    difficulty_fn: F,
    total_epochs: usize,
    seed: Option<u64>,
) -> CurriculumSampler<F>
where
    F: Fn(usize) -> f64 + Send + Clone,
{
    let strategy = CurriculumStrategy::Linear;
    with_optional_seed(
        CurriculumSampler::new(dataset_size, difficulty_fn, total_epochs, strategy),
        seed,
    )
}

/// Builds a sampler using the `Exponential { base }` strategy; `seed`, when
/// given, is set via `with_generator`.
pub fn exponential_curriculum<F>(
    dataset_size: usize,
    difficulty_fn: F,
    total_epochs: usize,
    base: f64,
    seed: Option<u64>,
) -> CurriculumSampler<F>
where
    F: Fn(usize) -> f64 + Send + Clone,
{
    let strategy = CurriculumStrategy::Exponential { base };
    with_optional_seed(
        CurriculumSampler::new(dataset_size, difficulty_fn, total_epochs, strategy),
        seed,
    )
}

/// Builds a sampler using the `Step { thresholds }` strategy; `seed`, when
/// given, is set via `with_generator`.
pub fn step_curriculum<F>(
    dataset_size: usize,
    difficulty_fn: F,
    total_epochs: usize,
    thresholds: Vec<usize>,
    seed: Option<u64>,
) -> CurriculumSampler<F>
where
    F: Fn(usize) -> f64 + Send + Clone,
{
    let strategy = CurriculumStrategy::Step { thresholds };
    with_optional_seed(
        CurriculumSampler::new(dataset_size, difficulty_fn, total_epochs, strategy),
        seed,
    )
}

/// Builds a sampler using the `AntiCurriculum` strategy; `seed`, when given,
/// is set via `with_generator`.
pub fn anti_curriculum<F>(
    dataset_size: usize,
    difficulty_fn: F,
    total_epochs: usize,
    seed: Option<u64>,
) -> CurriculumSampler<F>
where
    F: Fn(usize) -> f64 + Send + Clone,
{
    let strategy = CurriculumStrategy::AntiCurriculum;
    with_optional_seed(
        CurriculumSampler::new(dataset_size, difficulty_fn, total_epochs, strategy),
        seed,
    )
}
#[cfg(test)]
// Unit tests for the curriculum sampler, its pacing strategies, and the
// convenience constructors.
mod tests {
use super::*;
// Difficulty rises linearly with the index: idx / 100.
fn linear_difficulty(idx: usize) -> f64 {
idx as f64 / 100.0
}
// Difficulty is the distance from index 50, scaled so indices 0 and 100 score 1.0
// and index 50 scores 0.0.
fn center_distance_difficulty(idx: usize) -> f64 {
(idx as f64 - 50.0).abs() / 50.0
}
// Accessors plus the Linear schedule: epoch 0 admits a strict, non-empty subset;
// the final epoch (total_epochs - 1) admits everything and counts as complete.
#[test]
fn test_curriculum_sampler_basic() {
let mut sampler =
CurriculumSampler::new(100, linear_difficulty, 10, CurriculumStrategy::Linear)
.with_generator(42);
assert_eq!(sampler.total_epochs(), 10);
assert_eq!(sampler.current_epoch(), 0);
assert_eq!(sampler.generator(), Some(42));
assert_eq!(sampler.difficulties().len(), 100);
assert!(!sampler.is_complete());
sampler.set_epoch(0);
let early_indices = sampler.get_curriculum_indices();
assert!(!early_indices.is_empty());
assert!(early_indices.len() < 100);
sampler.set_epoch(9);
let late_indices = sampler.get_curriculum_indices();
assert_eq!(late_indices.len(), 100);
assert!(sampler.is_complete());
}
// Compares how many samples each strategy admits at early, middle and late epochs.
#[test]
fn test_curriculum_strategies() {
let dataset_size = 100;
let total_epochs = 10;
let mut linear_sampler = CurriculumSampler::new(
dataset_size,
linear_difficulty,
total_epochs,
CurriculumStrategy::Linear,
);
linear_sampler.set_epoch(0);
let linear_early = linear_sampler.get_curriculum_indices().len();
linear_sampler.set_epoch(5);
let linear_mid = linear_sampler.get_curriculum_indices().len();
linear_sampler.set_epoch(9);
let linear_late = linear_sampler.get_curriculum_indices().len();
assert!(linear_early < linear_mid);
assert!(linear_mid < linear_late);
assert_eq!(linear_late, dataset_size);
let mut exp_sampler = CurriculumSampler::new(
dataset_size,
linear_difficulty,
total_epochs,
CurriculumStrategy::Exponential { base: 2.0 },
);
exp_sampler.set_epoch(0);
let exp_early = exp_sampler.get_curriculum_indices().len();
exp_sampler.set_epoch(5);
let _exp_mid = exp_sampler.get_curriculum_indices().len();
assert!(exp_early <= linear_early);
// Step thresholds at epochs 3, 6 and 9: crossing epoch 3 must admit more samples.
let mut step_sampler = CurriculumSampler::new(
dataset_size,
linear_difficulty,
total_epochs,
CurriculumStrategy::Step {
thresholds: vec![3, 6, 9],
},
);
step_sampler.set_epoch(2);
let step_before = step_sampler.get_curriculum_indices().len();
step_sampler.set_epoch(3);
let step_after = step_sampler.get_curriculum_indices().len();
assert!(step_after > step_before);
// AntiCurriculum's threshold is fixed at 1.0, so it admits every sample at
// every epoch; these assertions pin that behavior.
let mut anti_sampler = CurriculumSampler::new(
dataset_size,
linear_difficulty,
total_epochs,
CurriculumStrategy::AntiCurriculum,
);
anti_sampler.set_epoch(0);
let anti_early = anti_sampler.get_curriculum_indices().len();
anti_sampler.set_epoch(9);
let anti_late = anti_sampler.get_curriculum_indices().len();
assert!(anti_early > linear_early);
assert_eq!(anti_late, dataset_size);
}
// Pins progress() and get_difficulty_threshold() at specific epochs for
// Linear, Exponential and AntiCurriculum.
#[test]
fn test_difficulty_threshold_calculation() {
let sampler =
CurriculumSampler::new(100, linear_difficulty, 10, CurriculumStrategy::Linear);
assert_eq!(sampler.progress(), 0.0);
let mut sampler = sampler;
sampler.set_epoch(5);
assert!((sampler.progress() - 5.0 / 9.0).abs() < f64::EPSILON);
sampler.set_epoch(9);
assert_eq!(sampler.progress(), 1.0);
assert_eq!(sampler.get_difficulty_threshold(), 1.0);
sampler.set_strategy(CurriculumStrategy::Exponential { base: 2.0 });
sampler.set_epoch(0);
assert_eq!(sampler.get_difficulty_threshold(), 0.0);
sampler.set_strategy(CurriculumStrategy::AntiCurriculum);
sampler.set_epoch(0);
assert_eq!(sampler.get_difficulty_threshold(), 1.0);
}
// Scores produced by a closure over a fixed vector are cached verbatim.
// NOTE(review): despite its name, this test builds the sampler with `new`, not
// `from_difficulties`, so `from_difficulties` itself is not exercised here.
#[test]
fn test_curriculum_from_difficulties() {
let difficulties = vec![0.1, 0.3, 0.5, 0.7, 0.9];
let difficulty_fn = |idx: usize| difficulties.get(idx).copied().unwrap_or(0.0);
let sampler = CurriculumSampler::new(
difficulties.len(),
difficulty_fn,
5,
CurriculumStrategy::Linear,
);
assert_eq!(sampler.difficulties(), &difficulties);
assert_eq!(sampler.total_epochs(), 5);
}
// With difficulties spread over 0.0..=1.0, epoch 0 admits the easiest sample but
// not the hardest, and the final epoch admits all six.
#[test]
fn test_curriculum_indices_selection() {
let difficulties = vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
let difficulty_fn = |idx: usize| difficulties.get(idx).copied().unwrap_or(0.0);
let mut sampler = CurriculumSampler::new(
difficulties.len(),
difficulty_fn,
6,
CurriculumStrategy::Linear,
);
sampler.set_epoch(0);
let indices = sampler.get_curriculum_indices();
assert!(indices.contains(&0)); assert!(!indices.contains(&5));
sampler.set_epoch(5);
let indices = sampler.get_curriculum_indices();
assert_eq!(indices.len(), 6);
for i in 0..6 {
assert!(indices.contains(&i));
}
}
// Mid-curriculum stats lie strictly between an empty and a full curriculum.
#[test]
fn test_curriculum_stats() {
let mut sampler =
CurriculumSampler::new(100, linear_difficulty, 10, CurriculumStrategy::Linear);
sampler.set_epoch(5);
let stats = sampler.curriculum_stats();
assert_eq!(stats.current_epoch, 5);
assert_eq!(stats.total_epochs, 10);
assert_eq!(stats.total_samples, 100);
assert!(stats.progress > 0.0 && stats.progress < 1.0);
assert!(stats.difficulty_threshold > 0.0 && stats.difficulty_threshold < 1.0);
assert!(stats.included_samples > 0 && stats.included_samples < 100);
assert!(stats.inclusion_ratio > 0.0 && stats.inclusion_ratio < 1.0);
}
// A fixed seed yields the same shuffled order on repeated iteration, iter()
// agrees with len(), and later epochs yield at least as many indices.
#[test]
fn test_curriculum_sampler_iter() {
let mut sampler =
CurriculumSampler::new(20, linear_difficulty, 5, CurriculumStrategy::Linear)
.with_generator(42);
sampler.set_epoch(0);
let indices1: Vec<usize> = sampler.iter().collect();
let indices2: Vec<usize> = sampler.iter().collect();
assert_eq!(indices1.len(), sampler.len());
assert_eq!(indices2.len(), sampler.len());
assert_eq!(indices1, indices2);
sampler.set_epoch(4);
let late_indices: Vec<usize> = sampler.iter().collect();
assert!(late_indices.len() >= indices1.len());
}
// Each convenience constructor selects the expected strategy (and seed).
#[test]
fn test_convenience_functions() {
let linear = linear_curriculum(50, linear_difficulty, 10, Some(42));
assert_eq!(linear.total_epochs(), 10);
assert_eq!(linear.generator(), Some(42));
assert!(matches!(linear.strategy(), CurriculumStrategy::Linear));
let exponential = exponential_curriculum(50, linear_difficulty, 10, 2.0, Some(42));
assert!(
matches!(exponential.strategy(), CurriculumStrategy::Exponential { base } if *base == 2.0)
);
let step = step_curriculum(50, linear_difficulty, 10, vec![2, 5, 8], Some(42));
assert!(
matches!(step.strategy(), CurriculumStrategy::Step { thresholds } if thresholds == &vec![2, 5, 8])
);
let anti = anti_curriculum(50, linear_difficulty, 10, Some(42));
assert!(matches!(
anti.strategy(),
CurriculumStrategy::AntiCurriculum
));
}
// reset() returns to epoch 0, set_strategy() swaps the strategy, and
// update_difficulties() with a pure scoring function leaves scores unchanged.
#[test]
fn test_curriculum_methods() {
let mut sampler =
CurriculumSampler::new(100, linear_difficulty, 10, CurriculumStrategy::Linear);
sampler.set_epoch(5);
assert_eq!(sampler.current_epoch(), 5);
sampler.reset();
assert_eq!(sampler.current_epoch(), 0);
sampler.set_strategy(CurriculumStrategy::Exponential { base: 3.0 });
assert!(
matches!(sampler.strategy(), CurriculumStrategy::Exponential { base } if *base == 3.0)
);
let original_difficulties = sampler.difficulties().to_vec();
sampler.update_difficulties();
assert_eq!(sampler.difficulties(), &original_difficulties);
}
// Edge cases: empty dataset, single-epoch schedule (progress is 1.0 at once),
// identical difficulties (zero spread admits everything), and an exponential
// base below 1.0 (threshold must stay non-negative).
#[test]
fn test_edge_cases() {
let empty_sampler =
CurriculumSampler::new(0, linear_difficulty, 5, CurriculumStrategy::Linear);
assert_eq!(empty_sampler.len(), 0);
assert!(empty_sampler.get_curriculum_indices().is_empty());
let mut single_epoch =
CurriculumSampler::new(10, linear_difficulty, 1, CurriculumStrategy::Linear);
single_epoch.set_epoch(0);
assert_eq!(single_epoch.progress(), 1.0);
assert!(single_epoch.is_complete());
let same_difficulties = vec![0.5; 10];
let same_difficulty_fn = |idx: usize| same_difficulties.get(idx).copied().unwrap_or(0.0);
let mut same_sampler = CurriculumSampler::new(
same_difficulties.len(),
same_difficulty_fn,
5,
CurriculumStrategy::Linear,
);
same_sampler.set_epoch(0);
assert_eq!(same_sampler.get_curriculum_indices().len(), 10);
let mut invalid_exp = CurriculumSampler::new(
10,
linear_difficulty,
5,
CurriculumStrategy::Exponential { base: 0.5 }, );
invalid_exp.set_epoch(2);
assert!(invalid_exp.get_difficulty_threshold() >= 0.0);
}
// Derived PartialEq on strategies.
#[test]
fn test_curriculum_strategy_equality() {
assert_eq!(CurriculumStrategy::Linear, CurriculumStrategy::Linear);
assert_eq!(
CurriculumStrategy::Exponential { base: 2.0 },
CurriculumStrategy::Exponential { base: 2.0 }
);
assert_ne!(
CurriculumStrategy::Linear,
CurriculumStrategy::AntiCurriculum
);
}
// The default strategy is Linear.
#[test]
fn test_curriculum_strategy_default() {
assert_eq!(CurriculumStrategy::default(), CurriculumStrategy::Linear);
}
// Non-monotonic difficulty: the centre index (50) is easiest and admitted first;
// the edges (0 and 100) are admitted by the final epoch.
#[test]
fn test_center_distance_difficulty() {
let mut sampler = CurriculumSampler::new(
101, center_distance_difficulty,
10,
CurriculumStrategy::Linear,
)
.with_generator(42);
sampler.set_epoch(0);
let early_indices = sampler.get_curriculum_indices();
assert!(early_indices.contains(&50));
sampler.set_epoch(9);
let late_indices = sampler.get_curriculum_indices();
assert!(late_indices.len() > early_indices.len());
assert!(late_indices.contains(&0) || late_indices.contains(&100)); }
}