use crate::error::{ModelError, ModelResult};
use crate::training_loop::{ArrayDataProvider, DataProvider};
use scirs2_core::ndarray::{Array1, Array2};
/// Pacing strategy for curriculum learning: decides which training samples
/// are admitted at a given point in training, based on a per-sample
/// difficulty score (expected in `[0, 1]`).
#[derive(Debug, Clone)]
pub enum CurriculumStrategy {
    /// Competence-based pacing: admit samples whose difficulty is at most the
    /// current competence, which grows linearly from `initial` by `increment`
    /// per epoch (see `CurriculumScheduler::step`).
    Competence {
        initial: f32,
        increment: f32,
    },
    /// Self-paced learning: samples are weighted by `max(0, 1 - d / lambda)`
    /// and admitted when that weight exceeds 0.5; `lambda` grows
    /// multiplicatively each epoch so harder samples are admitted over time.
    SelfPaced {
        lambda: f32,
    },
    /// Linear annealing of the difficulty threshold from `start_difficulty`
    /// to `end` over a fixed ramp (100 epochs, hard-coded in
    /// `CurriculumScheduler::step`).
    Annealing {
        start_difficulty: f32,
        end: f32,
    },
}
/// Tracks curriculum pacing state across training epochs for one strategy.
#[derive(Debug, Clone)]
pub struct CurriculumScheduler {
    strategy: CurriculumStrategy,
    // Current admission threshold in [0, 1]; samples at or below it are
    // included by threshold-based strategies.
    current_competence: f32,
    // Number of `step()` calls performed so far.
    epoch: usize,
    // Self-paced learning threshold; only meaningful for `SelfPaced`
    // (initialized to 0.0 for the other strategies).
    spl_lambda: f32,
    // Multiplicative per-epoch growth factor applied to `spl_lambda`.
    spl_growth: f32,
}
impl CurriculumScheduler {
    /// Create a scheduler for `strategy`, deriving the starting competence
    /// (clamped to `[0, 1]`) and the initial SPL lambda from the strategy.
    pub fn new(strategy: CurriculumStrategy) -> Self {
        // All fields are Copy, so matching by value leaves `strategy` usable.
        let (start, lambda) = match strategy {
            CurriculumStrategy::Competence { initial, .. } => (initial, 0.0),
            CurriculumStrategy::SelfPaced { lambda } => {
                // SPL: initial competence tracks half the starting lambda.
                ((lambda * 0.5).clamp(0.0, 1.0), lambda)
            }
            CurriculumStrategy::Annealing {
                start_difficulty, ..
            } => (start_difficulty, 0.0),
        };
        Self {
            strategy,
            current_competence: start.clamp(0.0, 1.0),
            epoch: 0,
            spl_lambda: lambda,
            spl_growth: 1.2,
        }
    }

    /// Builder-style setter for the SPL growth factor, floored at 1.0 so
    /// lambda never shrinks between epochs.
    pub fn with_spl_growth(mut self, growth: f32) -> Self {
        self.spl_growth = growth.max(1.0);
        self
    }

    /// Advance one epoch and return the updated competence threshold.
    pub fn step(&mut self) -> f32 {
        self.epoch += 1;
        let e = self.epoch as f32;
        self.current_competence = match self.strategy {
            CurriculumStrategy::Competence { initial, increment } => {
                // Linear growth from the initial competence, saturating at 1.
                (initial + e * increment).clamp(0.0, 1.0)
            }
            CurriculumStrategy::SelfPaced { .. } => {
                // Grow lambda so progressively harder samples pass the gate.
                self.spl_lambda *= self.spl_growth;
                (self.spl_lambda * 0.5).clamp(0.0, 1.0)
            }
            CurriculumStrategy::Annealing {
                start_difficulty,
                end,
            } => {
                // Linear ramp from start to end over a fixed 100-epoch window.
                let t = (e / 100.0_f32).clamp(0.0, 1.0);
                (start_difficulty + (end - start_difficulty) * t).clamp(0.0, 1.0)
            }
        };
        self.current_competence
    }

    /// Current admission threshold in `[0, 1]`.
    pub fn current_competence(&self) -> f32 {
        self.current_competence
    }

    /// Number of completed `step()` calls.
    pub fn epoch(&self) -> usize {
        self.epoch
    }

    /// Whether a sample of the given difficulty is admitted right now.
    pub fn should_include(&self, difficulty: f32) -> bool {
        if matches!(self.strategy, CurriculumStrategy::SelfPaced { .. }) {
            // SPL gate: the weight max(0, 1 - d/lambda) must exceed 0.5,
            // i.e. difficulty < lambda / 2. Non-positive lambda admits nothing.
            self.spl_lambda > 0.0
                && (1.0 - difficulty / self.spl_lambda).max(0.0) > 0.5
        } else {
            // Threshold strategies admit anything at or below the competence.
            difficulty <= self.current_competence
        }
    }

    /// Indices of `difficulties` that are currently admitted, in order.
    pub fn filter_indices(&self, difficulties: &[f32]) -> Vec<usize> {
        let mut admitted = Vec::new();
        for (idx, &d) in difficulties.iter().enumerate() {
            if self.should_include(d) {
                admitted.push(idx);
            }
        }
        admitted
    }

    /// Per-sample training weight. Under self-paced learning this is the
    /// clamped SPL weight `1 - d / lambda`; the other strategies produce a
    /// hard 0/1 admission weight.
    pub fn spl_weight(&self, difficulty: f32) -> f32 {
        if matches!(self.strategy, CurriculumStrategy::SelfPaced { .. }) {
            if self.spl_lambda <= 0.0 {
                0.0
            } else {
                (1.0 - difficulty / self.spl_lambda).clamp(0.0, 1.0)
            }
        } else if self.should_include(difficulty) {
            1.0
        } else {
            0.0
        }
    }
}
/// `DataProvider` wrapper that exposes only the subset of samples currently
/// admitted by a curriculum scheduler, re-filtered each epoch.
pub struct CurriculumDataProvider {
    // Underlying full dataset.
    inner: ArrayDataProvider,
    // One difficulty score per underlying sample (same length as `inner`).
    difficulties: Array1<f32>,
    scheduler: CurriculumScheduler,
    // Indices into `inner` of the samples admitted at the current epoch;
    // refreshed by `recompute_active`.
    cached_active: Vec<usize>,
}
impl CurriculumDataProvider {
    /// Wrap `data` with per-sample `difficulties` paced by `strategy`.
    ///
    /// # Errors
    /// Returns `ModelError::dimension_mismatch` when the difficulty vector
    /// length does not match the number of samples in `data`.
    pub fn new(
        data: ArrayDataProvider,
        difficulties: Array1<f32>,
        strategy: CurriculumStrategy,
    ) -> ModelResult<Self> {
        if difficulties.len() != data.num_samples() {
            return Err(ModelError::dimension_mismatch(
                "CurriculumDataProvider::new",
                data.num_samples(),
                difficulties.len(),
            ));
        }
        let mut provider = Self {
            inner: data,
            difficulties,
            scheduler: CurriculumScheduler::new(strategy),
            cached_active: Vec::new(),
        };
        // Populate the active-index cache for epoch 0.
        provider.recompute_active();
        Ok(provider)
    }

    /// Advance the curriculum by one epoch and refresh the active-sample cache.
    pub fn advance_epoch(&mut self) {
        self.scheduler.step();
        self.recompute_active();
    }

    /// Rebuild the cached list of admitted sample indices.
    ///
    /// Iterates the difficulty array directly instead of going through
    /// `as_slice().unwrap_or(&[])`: the slice-based path silently produced an
    /// EMPTY active set whenever the array was not contiguous in standard
    /// layout (e.g. a strided view), which would starve training without any
    /// error being reported.
    fn recompute_active(&mut self) {
        let scheduler = &self.scheduler;
        let active: Vec<usize> = self
            .difficulties
            .iter()
            .enumerate()
            .filter(|&(_, &d)| scheduler.should_include(d))
            .map(|(i, _)| i)
            .collect();
        self.cached_active = active;
    }

    /// Snapshot of the currently active underlying sample indices.
    pub fn active_indices(&self) -> Vec<usize> {
        self.cached_active.clone()
    }

    /// Number of currently active samples.
    pub fn active_count(&self) -> usize {
        self.cached_active.len()
    }

    /// Total number of samples in the underlying provider.
    pub fn total_samples(&self) -> usize {
        self.inner.num_samples()
    }

    /// Fraction of the dataset currently active; 0.0 for an empty dataset.
    pub fn active_fraction(&self) -> f32 {
        let total = self.inner.num_samples();
        if total == 0 {
            return 0.0;
        }
        self.cached_active.len() as f32 / total as f32
    }

    /// The underlying curriculum scheduler.
    pub fn scheduler(&self) -> &CurriculumScheduler {
        &self.scheduler
    }

    /// Per-sample difficulty scores.
    pub fn difficulties(&self) -> &Array1<f32> {
        &self.difficulties
    }

    /// `(underlying_index, weight)` pairs for the active samples, with
    /// weights taken from `CurriculumScheduler::spl_weight`.
    pub fn active_weights(&self) -> Vec<(usize, f32)> {
        self.cached_active
            .iter()
            .map(|&idx| (idx, self.scheduler.spl_weight(self.difficulties[idx])))
            .collect()
    }
}
impl DataProvider for CurriculumDataProvider {
    /// Number of currently active (curriculum-admitted) samples.
    fn num_samples(&self) -> usize {
        self.cached_active.len()
    }

    fn num_features(&self) -> usize {
        self.inner.num_features()
    }

    /// Translate batch indices (positions within the active subset) to
    /// underlying sample indices, then delegate to the inner provider.
    /// Out-of-range positions clamp to the last active sample — or to
    /// underlying index 0 when nothing is active. NOTE(review): that empty
    /// fallback still touches sample 0 of the raw data; confirm callers never
    /// request batches from an empty active set.
    fn get_batch(&self, indices: &[usize]) -> (Array2<f32>, Array1<f32>) {
        let fallback = self.cached_active.last().copied().unwrap_or(0);
        let mapped: Vec<usize> = indices
            .iter()
            .map(|&pos| self.cached_active.get(pos).copied().unwrap_or(fallback))
            .collect();
        self.inner.get_batch(&mapped)
    }

    /// Deterministic Fisher-Yates shuffle of active-subset positions, driven
    /// by a 64-bit LCG seeded from `rng_seed`.
    fn shuffle_indices(&self, rng_seed: u64) -> Vec<usize> {
        let len = self.cached_active.len();
        let mut order: Vec<usize> = (0..len).collect();
        let mut state = rng_seed.wrapping_add(1);
        for i in (1..len).rev() {
            // LCG step; the upper bits are used for the swap draw because
            // they have better statistical quality than the low bits.
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            let j = (state >> 33) as usize % (i + 1);
            order.swap(i, j);
        }
        order
    }
}
/// Map raw per-sample losses to difficulty scores in `[0, 1]` via min-max
/// normalization; higher loss means higher difficulty. A (near-)constant loss
/// vector maps every sample to the neutral difficulty 0.5.
///
/// # Errors
/// Returns an `invalid_config` error when `losses` is empty.
pub fn estimate_difficulty_from_loss(losses: &[f32]) -> ModelResult<Array1<f32>> {
    if losses.is_empty() {
        return Err(ModelError::invalid_config(
            "Cannot estimate difficulty from empty loss vector",
        ));
    }
    let mut lo = f32::INFINITY;
    let mut hi = f32::NEG_INFINITY;
    for &loss in losses {
        lo = lo.min(loss);
        hi = hi.max(loss);
    }
    let span = hi - lo;
    if span.abs() < f32::EPSILON {
        // Degenerate case: all losses (effectively) equal.
        return Ok(Array1::from_elem(losses.len(), 0.5));
    }
    let scaled = losses
        .iter()
        .map(|&loss| ((loss - lo) / span).clamp(0.0, 1.0))
        .collect::<Vec<f32>>();
    Ok(Array1::from_vec(scaled))
}
/// Proxy difficulty by per-row feature variance: rows whose features vary
/// more are treated as harder. The raw variances are then min-max normalized
/// through `estimate_difficulty_from_loss`.
///
/// # Errors
/// Returns an `invalid_config` error when `features` has no rows.
pub fn estimate_difficulty_from_variance(features: &Array2<f32>) -> ModelResult<Array1<f32>> {
    if features.nrows() == 0 {
        return Err(ModelError::invalid_config(
            "Cannot estimate difficulty from empty feature matrix",
        ));
    }
    // Guard against a zero-column matrix: the denominator is at least 1.
    let denom = features.ncols().max(1) as f32;
    let variances: Vec<f32> = features
        .rows()
        .into_iter()
        .map(|row| {
            let mean = row.iter().sum::<f32>() / denom;
            // Population variance (divide by N, not N - 1).
            row.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / denom
        })
        .collect();
    estimate_difficulty_from_loss(&variances)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Zero-filled provider with `n` samples of `nf` features each.
    fn make_provider(n: usize, nf: usize) -> ArrayDataProvider {
        ArrayDataProvider::new(Array2::<f32>::zeros((n, nf)), Array1::<f32>::zeros(n))
    }

    #[test]
    fn test_curriculum_competence_pacing() {
        let mut scheduler = CurriculumScheduler::new(CurriculumStrategy::Competence {
            initial: 0.2,
            increment: 0.1,
        });
        // Competence starts at `initial` and grows by `increment` per step.
        assert!((scheduler.current_competence() - 0.2).abs() < 1e-6);
        assert!((scheduler.step() - 0.3).abs() < 1e-6);
        assert!((scheduler.step() - 0.4).abs() < 1e-6);
        let difficulties = [0.1, 0.35, 0.5, 0.9];
        assert_eq!(scheduler.filter_indices(&difficulties), vec![0, 1]);
        // After enough epochs competence saturates at 1.0 and admits all.
        for _ in 0..20 {
            scheduler.step();
        }
        assert!((scheduler.current_competence() - 1.0).abs() < 1e-6);
        assert_eq!(scheduler.filter_indices(&difficulties).len(), 4);
    }

    #[test]
    fn test_curriculum_self_paced() {
        let mut scheduler =
            CurriculumScheduler::new(CurriculumStrategy::SelfPaced { lambda: 0.4 });
        let difficulties = [0.1, 0.19, 0.3, 0.5, 0.8];
        // Initially only difficulties below lambda / 2 = 0.2 pass the gate.
        assert_eq!(scheduler.filter_indices(&difficulties), vec![0, 1]);
        for _ in 0..5 {
            scheduler.step();
        }
        let included = scheduler.filter_indices(&difficulties);
        assert!(
            included.len() >= 3,
            "expected >= 3 included, got {}",
            included.len()
        );
        // Eventually lambda grows enough to admit every sample.
        for _ in 0..20 {
            scheduler.step();
        }
        assert_eq!(
            scheduler.filter_indices(&difficulties).len(),
            difficulties.len()
        );
    }

    #[test]
    fn test_curriculum_annealing() {
        let mut scheduler = CurriculumScheduler::new(CurriculumStrategy::Annealing {
            start_difficulty: 0.1,
            end: 1.0,
        });
        assert!((scheduler.current_competence() - 0.1).abs() < 1e-6);
        // Halfway through the 100-epoch ramp: 0.1 + 0.9 * 0.5 = 0.55.
        for _ in 0..50 {
            scheduler.step();
        }
        assert!(
            (scheduler.current_competence() - 0.55).abs() < 0.02,
            "expected ~0.55, got {}",
            scheduler.current_competence()
        );
        // Past the end of the ramp the threshold pins at `end`.
        for _ in 0..50 {
            scheduler.step();
        }
        assert!(
            (scheduler.current_competence() - 1.0).abs() < 1e-6,
            "expected 1.0, got {}",
            scheduler.current_competence()
        );
        assert_eq!(scheduler.filter_indices(&[0.0, 0.3, 0.7, 1.0]).len(), 4);
    }

    #[test]
    fn test_curriculum_provider_active_indices() {
        let scores = Array1::from_vec((0..10).map(|i| i as f32 * 0.1).collect());
        let mut provider = CurriculumDataProvider::new(
            make_provider(10, 2),
            scores,
            CurriculumStrategy::Competence {
                initial: 0.0,
                increment: 0.2,
            },
        )
        .expect("construction should succeed");
        // The active set widens as competence grows by 0.2 per epoch.
        let epoch0 = provider.active_indices();
        assert_eq!(epoch0, vec![0]);
        provider.advance_epoch();
        let epoch1 = provider.active_indices();
        assert_eq!(epoch1, vec![0, 1, 2]);
        provider.advance_epoch();
        let epoch2 = provider.active_indices();
        assert_eq!(epoch2, vec![0, 1, 2, 3, 4]);
        assert!(epoch2.len() > epoch1.len());
        assert!(epoch1.len() > epoch0.len());
    }

    #[test]
    fn test_curriculum_provider_implements_data_provider() {
        let features =
            Array2::from_shape_vec((5, 2), (1..=10).map(|v| v as f32).collect::<Vec<f32>>())
                .expect("shape ok");
        let targets = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let provider = CurriculumDataProvider::new(
            ArrayDataProvider::new(features, targets),
            Array1::from_vec(vec![0.1, 0.15, 0.5, 0.8, 0.95]),
            CurriculumStrategy::Competence {
                initial: 0.2,
                increment: 0.1,
            },
        )
        .expect("construction should succeed");
        // Only the two easiest samples (0.1, 0.15) are active at competence 0.2.
        assert_eq!(provider.num_samples(), 2);
        assert_eq!(provider.num_features(), 2);
        let (feat, tgt) = provider.get_batch(&[0, 1]);
        assert_eq!(feat.shape(), &[2, 2]);
        assert_eq!(tgt.len(), 2);
        assert!((feat[[0, 0]] - 1.0).abs() < 1e-6);
        assert!((feat[[1, 0]] - 3.0).abs() < 1e-6);
        assert!((tgt[0] - 10.0).abs() < 1e-6);
        assert!((tgt[1] - 20.0).abs() < 1e-6);
    }
}