use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use std::collections::HashMap;
/// A curriculum-learning policy: scores how hard each training sample is and
/// decides which samples are used at a given epoch.
pub trait CurriculumStrategy {
    /// Returns the indices (into `difficulties`) of the samples to train on at
    /// `epoch` of a run lasting `total_epochs` epochs.
    ///
    /// Whether and how `epoch` / `total_epochs` influence the selection is up to
    /// the implementation (some strategies ignore them).
    fn select_samples(
        &self,
        epoch: usize,
        total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>>;
    /// Computes a difficulty score per sample (row).
    ///
    /// `predictions` is optional; each implementation decides whether it needs
    /// it and what to return when it is `None`.
    fn compute_difficulty(
        &self,
        data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>>;
}
/// Curriculum whose training-set fraction grows linearly from
/// `start_percentage` at the first epoch to the whole dataset at the last one.
#[derive(Debug, Clone)]
pub struct LinearCurriculum {
    /// Fraction of the dataset (in `[0, 1]`) used at epoch 0.
    pub start_percentage: f64,
    /// When `true`, the easiest samples are taken first; when `false`, the
    /// leading samples in index order are taken.
    pub sort_by_difficulty: bool,
}

impl LinearCurriculum {
    /// Creates a linear curriculum that sorts samples by difficulty.
    ///
    /// # Errors
    /// Returns [`TrainError::InvalidParameter`] when `start_percentage` lies
    /// outside `[0, 1]` (NaN is rejected as well).
    pub fn new(start_percentage: f64) -> TrainResult<Self> {
        if (0.0..=1.0).contains(&start_percentage) {
            Ok(Self {
                start_percentage,
                sort_by_difficulty: true,
            })
        } else {
            Err(TrainError::InvalidParameter(
                "start_percentage must be in [0, 1]".to_string(),
            ))
        }
    }

    /// Turns off difficulty sorting, so samples are taken in index order.
    pub fn without_sorting(self) -> Self {
        Self {
            sort_by_difficulty: false,
            ..self
        }
    }
}

impl Default for LinearCurriculum {
    /// Starts from 20% of the data, easiest samples first.
    fn default() -> Self {
        Self {
            sort_by_difficulty: true,
            start_percentage: 0.2,
        }
    }
}
impl CurriculumStrategy for LinearCurriculum {
    /// Selects `ceil(n * p)` samples, where the fraction `p` grows linearly from
    /// `start_percentage` at epoch 0 to `1.0` at epoch `total_epochs - 1`
    /// (`p = 1.0` immediately when `total_epochs <= 1`).
    ///
    /// With `sort_by_difficulty` the easiest samples are returned (ties keep
    /// index order, NaN difficulties count as the hardest); otherwise the first
    /// `ceil(n * p)` indices are returned.
    fn select_samples(
        &self,
        epoch: usize,
        total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>> {
        let n = difficulties.len();
        if n == 0 {
            return Ok(Vec::new());
        }
        let progress = if total_epochs > 1 {
            epoch as f64 / (total_epochs - 1) as f64
        } else {
            1.0
        };
        let current_percentage = self.start_percentage + (1.0 - self.start_percentage) * progress;
        // `.min(n)` also caps epochs past the end of the schedule (progress > 1).
        let num_samples = ((n as f64 * current_percentage).ceil() as usize).min(n);
        if !self.sort_by_difficulty {
            return Ok((0..num_samples).collect());
        }
        let mut indices: Vec<usize> = (0..n).collect();
        // Stable sort so equal difficulties keep their index order. The comparator
        // must be a total order: treating NaN as `Equal` to everything is
        // intransitive, which lets `sort_by` panic (Rust >= 1.81) or return an
        // unspecified order. NaNs are therefore ordered after every real value.
        indices.sort_by(|&a, &b| {
            let (x, y) = (difficulties[a], difficulties[b]);
            x.partial_cmp(&y)
                .unwrap_or_else(|| x.is_nan().cmp(&y.is_nan()))
        });
        indices.truncate(num_samples);
        Ok(indices)
    }

    /// Scores each prediction row by its entropy `-Σ p·ln(p)`, skipping entries
    /// `<= 1e-10`; flatter rows therefore score as harder.
    ///
    /// NOTE(review): assumes each prediction row holds class probabilities —
    /// this is not checked. Without predictions, every sample (one per label
    /// row) gets difficulty `0.0`.
    fn compute_difficulty(
        &self,
        _data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>> {
        match predictions {
            Some(preds) => Ok(preds
                .outer_iter()
                .map(|row| {
                    row.iter()
                        .filter(|&&p| p > 1e-10)
                        .fold(0.0, |entropy, &p| entropy - p * p.ln())
                })
                .collect()),
            None => Ok(Array1::zeros(labels.nrows())),
        }
    }
}
/// Curriculum whose training-set fraction grows exponentially with progress:
/// `min(start_percentage * exp(growth_rate * progress), 1)`, where `progress`
/// runs from 0 at the first epoch to 1 at the last.
///
/// The full dataset is only reached by the final epoch when
/// `start_percentage * exp(growth_rate) >= 1`.
#[derive(Debug, Clone)]
pub struct ExponentialCurriculum {
    /// Fraction of the dataset (in `(0, 1]`) used at epoch 0.
    pub start_percentage: f64,
    /// Multiplier on normalized progress inside the exponent.
    pub growth_rate: f64,
}

impl ExponentialCurriculum {
    /// Creates an exponential curriculum.
    ///
    /// # Errors
    /// Returns [`TrainError::InvalidParameter`] when `start_percentage` is not in
    /// `(0, 1]`, or when `growth_rate` is not a positive, finite number.
    pub fn new(start_percentage: f64, growth_rate: f64) -> TrainResult<Self> {
        // Zero is excluded because `0 * exp(x)` stays 0: no sample would ever be
        // selected. NaN is rejected explicitly since it fails every comparison.
        if start_percentage.is_nan() || start_percentage <= 0.0 || start_percentage > 1.0 {
            return Err(TrainError::InvalidParameter(
                "start_percentage must be in (0, 1]".to_string(),
            ));
        }
        // `growth_rate <= 0.0` alone lets NaN and +inf through.
        if !growth_rate.is_finite() || growth_rate <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "growth_rate must be positive and finite".to_string(),
            ));
        }
        Ok(Self {
            start_percentage,
            growth_rate,
        })
    }
}

impl Default for ExponentialCurriculum {
    /// Starts from 10% of the data with a growth rate of 2.0.
    fn default() -> Self {
        Self {
            start_percentage: 0.1,
            growth_rate: 2.0,
        }
    }
}
impl CurriculumStrategy for ExponentialCurriculum {
    /// Selects the easiest `ceil(n * p)` samples, where
    /// `p = min(start_percentage * exp(growth_rate * progress), 1)` and
    /// `progress = epoch / (total_epochs - 1)` (`1.0` when `total_epochs <= 1`).
    ///
    /// Ties keep index order; NaN difficulties count as the hardest samples.
    fn select_samples(
        &self,
        epoch: usize,
        total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>> {
        let n = difficulties.len();
        if n == 0 {
            return Ok(Vec::new());
        }
        let progress = if total_epochs > 1 {
            epoch as f64 / (total_epochs - 1) as f64
        } else {
            1.0
        };
        let current_percentage =
            (self.start_percentage * (self.growth_rate * progress).exp()).min(1.0);
        let num_samples = ((n as f64 * current_percentage).ceil() as usize).min(n);
        let mut indices: Vec<usize> = (0..n).collect();
        // Stable sort so equal difficulties keep their index order. The comparator
        // must be a total order: treating NaN as `Equal` to everything is
        // intransitive, which lets `sort_by` panic (Rust >= 1.81) or return an
        // unspecified order. NaNs are therefore ordered after every real value.
        indices.sort_by(|&a, &b| {
            let (x, y) = (difficulties[a], difficulties[b]);
            x.partial_cmp(&y)
                .unwrap_or_else(|| x.is_nan().cmp(&y.is_nan()))
        });
        indices.truncate(num_samples);
        Ok(indices)
    }

    /// Scores each prediction row by its entropy `-Σ p·ln(p)`, skipping entries
    /// `<= 1e-10`; flatter rows therefore score as harder.
    ///
    /// NOTE(review): assumes each prediction row holds class probabilities —
    /// this is not checked. Without predictions, every sample (one per label
    /// row) gets difficulty `0.0`.
    fn compute_difficulty(
        &self,
        _data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>> {
        match predictions {
            Some(preds) => Ok(preds
                .outer_iter()
                .map(|row| {
                    row.iter()
                        .filter(|&&p| p > 1e-10)
                        .fold(0.0, |entropy, &p| entropy - p * p.ln())
                })
                .collect()),
            None => Ok(Array1::zeros(labels.nrows())),
        }
    }
}
/// Self-paced curriculum: selects every sample whose loss-based difficulty
/// falls strictly below a fixed threshold.
#[derive(Debug, Clone)]
pub struct SelfPacedCurriculum {
    /// Multiplier applied to each sample's loss to obtain its difficulty.
    pub lambda: f64,
    /// Samples with difficulty strictly below this value are selected.
    pub threshold: f64,
}

impl SelfPacedCurriculum {
    /// Creates a self-paced curriculum.
    ///
    /// # Errors
    /// Returns [`TrainError::InvalidParameter`] when `lambda` is not a positive,
    /// finite number, or when `threshold` is NaN.
    pub fn new(lambda: f64, threshold: f64) -> TrainResult<Self> {
        // `lambda <= 0.0` alone lets NaN and +inf through; both would poison
        // every difficulty score (e.g. `0 * inf = NaN`).
        if !lambda.is_finite() || lambda <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "lambda must be positive and finite".to_string(),
            ));
        }
        // Every comparison against NaN is false, so nothing would ever be selected.
        if threshold.is_nan() {
            return Err(TrainError::InvalidParameter(
                "threshold must not be NaN".to_string(),
            ));
        }
        Ok(Self { lambda, threshold })
    }
}

impl Default for SelfPacedCurriculum {
    /// `lambda = 1.0`, `threshold = 0.5`.
    fn default() -> Self {
        Self {
            lambda: 1.0,
            threshold: 0.5,
        }
    }
}
impl CurriculumStrategy for SelfPacedCurriculum {
    /// Selects every sample whose difficulty is strictly below `threshold`.
    ///
    /// The selection does not depend on the epoch. NaN difficulties are never
    /// selected.
    fn select_samples(
        &self,
        _epoch: usize,
        _total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>> {
        let indices: Vec<usize> = difficulties
            .iter()
            .enumerate()
            .filter(|(_, &d)| d < self.threshold)
            .map(|(i, _)| i)
            .collect();
        Ok(indices)
    }

    /// Scores each sample as `lambda` times its cross-entropy loss
    /// `-Σ label·ln(p)`, with predictions clamped to `[1e-10, 1 - 1e-10]`.
    ///
    /// # Errors
    /// Returns [`TrainError::InvalidParameter`] when `predictions` is `None`, or
    /// when `labels` and `predictions` do not have the same shape.
    fn compute_difficulty(
        &self,
        _data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>> {
        let preds = match predictions {
            Some(preds) => preds,
            None => {
                return Err(TrainError::InvalidParameter(
                    "SelfPacedCurriculum requires predictions for difficulty computation"
                        .to_string(),
                ))
            }
        };
        // Mismatched shapes mean labels and predictions are misaligned. Indexing
        // rows/columns would panic, and zipping would silently truncate, so
        // report the mismatch instead.
        if labels.dim() != preds.dim() {
            return Err(TrainError::InvalidParameter(format!(
                "SelfPacedCurriculum: labels shape {:?} does not match predictions shape {:?}",
                labels.dim(),
                preds.dim()
            )));
        }
        Ok(preds
            .outer_iter()
            .zip(labels.outer_iter())
            .map(|(pred, label)| {
                let loss = pred
                    .iter()
                    .zip(label.iter())
                    .fold(0.0, |loss, (&p, &y)| {
                        loss - y * p.clamp(1e-10, 1.0 - 1e-10).ln()
                    });
                loss * self.lambda
            })
            .collect())
    }
}
/// Competence-based curriculum: a sample is used once its difficulty is at most
/// the current competence, which grows linearly with the epoch up to
/// `max_competence`.
#[derive(Debug, Clone)]
pub struct CompetenceCurriculum {
    /// Competence at epoch 0, in `[0, 1]`.
    pub initial_competence: f64,
    /// Competence added per epoch.
    pub growth_rate: f64,
    /// Upper bound on competence; `new` and `default` set it to `1.0`.
    pub max_competence: f64,
}

impl CompetenceCurriculum {
    /// Creates a competence curriculum capped at `1.0`.
    ///
    /// # Errors
    /// Returns [`TrainError::InvalidParameter`] when `initial_competence` is not
    /// in `[0, 1]`, or when `growth_rate` is negative or not finite.
    pub fn new(initial_competence: f64, growth_rate: f64) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&initial_competence) {
            return Err(TrainError::InvalidParameter(
                "initial_competence must be in [0, 1]".to_string(),
            ));
        }
        // A negative rate would shrink the selected set over time. A NaN rate
        // makes competence NaN, and `f64::min(NaN, max)` returns `max`, so the
        // curriculum would silently use every sample from epoch 0.
        if !growth_rate.is_finite() || growth_rate < 0.0 {
            return Err(TrainError::InvalidParameter(
                "growth_rate must be non-negative and finite".to_string(),
            ));
        }
        Ok(Self {
            initial_competence,
            growth_rate,
            max_competence: 1.0,
        })
    }
}

impl Default for CompetenceCurriculum {
    /// Starts at competence 0.3 and grows by 0.05 per epoch, capped at 1.0.
    fn default() -> Self {
        Self {
            initial_competence: 0.3,
            growth_rate: 0.05,
            max_competence: 1.0,
        }
    }
}
impl CurriculumStrategy for CompetenceCurriculum {
    /// Selects every sample whose difficulty is `<=` the competence at `epoch`,
    /// i.e. `min(initial_competence + growth_rate * epoch, max_competence)`.
    ///
    /// `total_epochs` is not used.
    fn select_samples(
        &self,
        epoch: usize,
        _total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>> {
        let uncapped = self.initial_competence + self.growth_rate * epoch as f64;
        let competence = uncapped.min(self.max_competence);
        let mut chosen = Vec::new();
        for (index, &difficulty) in difficulties.iter().enumerate() {
            if difficulty <= competence {
                chosen.push(index);
            }
        }
        Ok(chosen)
    }

    /// Scores each prediction row by its entropy `-Σ p·ln(p)` (entries
    /// `<= 1e-10` skipped), then divides by the largest score so the result
    /// lies in `[0, 1]` whenever that maximum is positive.
    ///
    /// Without predictions, every sample (one per label row) gets `0.0`.
    fn compute_difficulty(
        &self,
        _data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>> {
        let preds = match predictions {
            Some(preds) => preds,
            None => return Ok(Array1::zeros(labels.nrows())),
        };
        let mut scores: Array1<f64> = preds
            .outer_iter()
            .map(|row| {
                row.iter()
                    .filter(|&&p| p > 1e-10)
                    .fold(0.0, |acc, &p| acc - p * p.ln())
            })
            .collect();
        // Folding from 0.0 keeps the scale non-negative; an all-zero vector is
        // left untouched to avoid dividing by zero.
        let scale = scores.iter().copied().fold(0.0f64, f64::max);
        if scale > 0.0 {
            scores.mapv_inplace(|d| d / scale);
        }
        Ok(scores)
    }
}
/// Multi-task curriculum: each task becomes active at its start epoch and
/// stays active from then on.
#[derive(Debug, Clone)]
pub struct TaskCurriculum {
    /// `(start_epoch, task_id)` pairs, kept sorted by `start_epoch`.
    task_schedule: Vec<(usize, usize)>,
}

impl TaskCurriculum {
    /// Builds a curriculum from `(start_epoch, task_id)` pairs given in any order.
    ///
    /// The pairs are stably sorted by start epoch, so tasks that share a start
    /// epoch keep their relative input order.
    pub fn new(schedule: Vec<(usize, usize)>) -> Self {
        let mut task_schedule = schedule;
        task_schedule.sort_by_key(|&(start_epoch, _)| start_epoch);
        Self { task_schedule }
    }

    /// Returns the ids of all tasks whose start epoch is `<= epoch`, ordered by
    /// start epoch.
    pub fn get_active_tasks(&self, epoch: usize) -> Vec<usize> {
        // The schedule is sorted by start epoch (constructors guarantee it), so
        // the active tasks always form a prefix of it.
        self.task_schedule
            .iter()
            .take_while(|&&(start_epoch, _)| start_epoch <= epoch)
            .map(|&(_, task_id)| task_id)
            .collect()
    }
}

impl Default for TaskCurriculum {
    /// A single task (id `0`) active from epoch 0.
    fn default() -> Self {
        let task_schedule = vec![(0, 0)];
        Self { task_schedule }
    }
}
/// Pairs a curriculum strategy with a per-key cache of difficulty scores and
/// the epoch used when selecting samples.
pub struct CurriculumManager {
    // Type-erased strategy; the boxed helper trait also provides `clone_box`.
    strategy: Box<dyn CurriculumStrategyClone>,
    // Difficulty scores stored under a caller-chosen string key.
    difficulty_cache: HashMap<String, Array1<f64>>,
    // Epoch passed to the strategy's `select_samples`; starts at 0.
    current_epoch: usize,
}
impl std::fmt::Debug for CurriculumManager {
    /// Shows the current epoch and how many difficulty arrays are cached; the
    /// boxed strategy is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cached_entries = self.difficulty_cache.len();
        let mut builder = f.debug_struct("CurriculumManager");
        builder.field("current_epoch", &self.current_epoch);
        builder.field("num_cached_difficulties", &cached_entries);
        builder.finish()
    }
}
/// Object-safe companion to `CurriculumStrategy` that lets a boxed strategy be
/// cloned (a trait object cannot require `Clone` directly, since `Clone: Sized`).
trait CurriculumStrategyClone: CurriculumStrategy {
    /// Returns a heap-allocated copy of this strategy.
    fn clone_box(&self) -> Box<dyn CurriculumStrategyClone>;
}
/// Every cloneable, `'static` strategy can be cloned into a box.
///
/// Note: `Box<dyn CurriculumStrategyClone>` itself meets these bounds (it
/// implements both `CurriculumStrategy` and `Clone` in this module), so callers
/// holding a `&Box<dyn …>` must deref to the inner trait object before calling
/// `clone_box`, or they get this impl for the box instead.
impl<T> CurriculumStrategyClone for T
where
    T: CurriculumStrategy + Clone + 'static,
{
    fn clone_box(&self) -> Box<dyn CurriculumStrategyClone> {
        let copy: T = T::clone(self);
        Box::new(copy)
    }
}
impl Clone for Box<dyn CurriculumStrategyClone> {
    /// Clones the boxed strategy by dispatching `clone_box` on the inner trait
    /// object.
    fn clone(&self) -> Self {
        // Deref to the trait object first. `Box<dyn CurriculumStrategyClone>` is
        // itself `CurriculumStrategy + Clone + 'static`, so it matches the blanket
        // `CurriculumStrategyClone` impl. A plain `self.clone_box()` would resolve
        // to that impl, which calls `self.clone()` again and recurses until the
        // stack overflows.
        (**self).clone_box()
    }
}
/// Lets a boxed strategy be used as a `CurriculumStrategy` by forwarding every
/// call to the inner trait object.
impl CurriculumStrategy for Box<dyn CurriculumStrategyClone> {
    fn select_samples(
        &self,
        epoch: usize,
        total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>> {
        // Take `&**self` explicitly rather than relying on a `&Box<_> -> &dyn _`
        // coercion: the box itself implements the trait, so an unsizing coercion
        // could wrap the box and dispatch straight back into this method.
        let inner: &dyn CurriculumStrategyClone = &**self;
        inner.select_samples(epoch, total_epochs, difficulties)
    }

    fn compute_difficulty(
        &self,
        data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>> {
        let inner: &dyn CurriculumStrategyClone = &**self;
        inner.compute_difficulty(data, labels, predictions)
    }
}
impl CurriculumManager {
pub fn new<S: CurriculumStrategy + Clone + 'static>(strategy: S) -> Self {
Self {
strategy: Box::new(strategy),
difficulty_cache: HashMap::new(),
current_epoch: 0,
}
}
pub fn set_epoch(&mut self, epoch: usize) {
self.current_epoch = epoch;
}
pub fn compute_difficulty(
&mut self,
key: &str,
data: &Array2<f64>,
labels: &Array2<f64>,
predictions: Option<&Array2<f64>>,
) -> TrainResult<()> {
let difficulties = self
.strategy
.compute_difficulty(data, labels, predictions)?;
self.difficulty_cache.insert(key.to_string(), difficulties);
Ok(())
}
pub fn get_selected_samples(&self, key: &str, total_epochs: usize) -> TrainResult<Vec<usize>> {
let difficulties = self.difficulty_cache.get(key).ok_or_else(|| {
TrainError::InvalidParameter(format!("No difficulty scores cached for key: {}", key))
})?;
self.strategy
.select_samples(self.current_epoch, total_epochs, &difficulties.view())
}
pub fn clear_cache(&mut self) {
self.difficulty_cache.clear();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_linear_curriculum() {
        let curriculum = LinearCurriculum::new(0.2).expect("unwrap");
        let difficulties = array![0.1, 0.5, 0.3, 0.9, 0.2];
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .expect("unwrap");
        assert_eq!(selected.len(), 1);
        let selected = curriculum
            .select_samples(9, 10, &difficulties.view())
            .expect("unwrap");
        assert_eq!(selected.len(), 5);
    }

    #[test]
    fn test_linear_curriculum_selects_easiest_first() {
        let curriculum = LinearCurriculum::new(0.5).expect("unwrap");
        let difficulties = array![0.9, 0.1, 0.5, 0.3];
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .expect("unwrap");
        assert_eq!(selected, vec![1, 3]);
    }

    #[test]
    fn test_linear_curriculum_invalid() {
        assert!(LinearCurriculum::new(-0.1).is_err());
        assert!(LinearCurriculum::new(1.5).is_err());
    }

    #[test]
    fn test_exponential_curriculum() {
        let curriculum = ExponentialCurriculum::new(0.1, 2.0).expect("unwrap");
        let difficulties = array![0.1, 0.5, 0.3, 0.9, 0.2];
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .expect("unwrap");
        assert!(!selected.is_empty());
        let selected = curriculum
            .select_samples(9, 10, &difficulties.view())
            .expect("unwrap");
        assert!(selected.len() >= 4);
    }

    #[test]
    fn test_exponential_curriculum_invalid_growth() {
        assert!(ExponentialCurriculum::new(0.1, 0.0).is_err());
        assert!(ExponentialCurriculum::new(0.1, -1.0).is_err());
    }

    #[test]
    fn test_self_paced_curriculum() {
        let curriculum = SelfPacedCurriculum::new(1.0, 0.5).expect("unwrap");
        let difficulties = array![0.1, 0.6, 0.3, 0.9, 0.2];
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .expect("unwrap");
        assert_eq!(selected.len(), 3);
    }

    #[test]
    fn test_self_paced_requires_predictions() {
        let curriculum = SelfPacedCurriculum::default();
        let data = array![[1.0, 2.0]];
        let labels = array![[1.0, 0.0]];
        assert!(curriculum.compute_difficulty(&data, &labels, None).is_err());
    }

    #[test]
    fn test_competence_curriculum() {
        let curriculum = CompetenceCurriculum::new(0.3, 0.1).expect("unwrap");
        let difficulties = array![0.1, 0.5, 0.3, 0.9, 0.2];
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .expect("unwrap");
        assert_eq!(selected.len(), 3);
        let selected = curriculum
            .select_samples(5, 10, &difficulties.view())
            .expect("unwrap");
        assert!(selected.len() >= 3);
    }

    #[test]
    fn test_competence_difficulty_normalized() {
        let curriculum = CompetenceCurriculum::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0]];
        let predictions = array![[0.5, 0.5], [1.0, 0.0]];
        let difficulties = curriculum
            .compute_difficulty(&data, &labels, Some(&predictions))
            .expect("unwrap");
        assert_eq!(difficulties.len(), 2);
        assert!((difficulties[0] - 1.0).abs() < 1e-12);
        assert_eq!(difficulties[1], 0.0);
    }

    #[test]
    fn test_task_curriculum() {
        let curriculum = TaskCurriculum::new(vec![(0, 0), (5, 1), (10, 2)]);
        let tasks = curriculum.get_active_tasks(0);
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0], 0);
        let tasks = curriculum.get_active_tasks(7);
        assert_eq!(tasks.len(), 2);
        assert!(tasks.contains(&0));
        assert!(tasks.contains(&1));
        let tasks = curriculum.get_active_tasks(15);
        assert_eq!(tasks.len(), 3);
    }

    #[test]
    fn test_task_curriculum_default() {
        let curriculum = TaskCurriculum::default();
        assert_eq!(curriculum.get_active_tasks(0), vec![0]);
    }

    #[test]
    fn test_difficulty_computation() {
        let curriculum = LinearCurriculum::default();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0]];
        let predictions = array![[0.8, 0.2], [0.3, 0.7]];
        let difficulties = curriculum
            .compute_difficulty(&data, &labels, Some(&predictions))
            .expect("unwrap");
        assert_eq!(difficulties.len(), 2);
        assert!(difficulties.iter().all(|&d| d >= 0.0));
        let difficulties = curriculum
            .compute_difficulty(&data, &labels, None)
            .expect("unwrap");
        assert_eq!(difficulties.len(), 2);
        assert!(difficulties.iter().all(|&d| d == 0.0));
    }

    #[test]
    fn test_curriculum_manager() {
        let mut manager = CurriculumManager::new(LinearCurriculum::default());
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
        let predictions = array![[0.8, 0.2], [0.3, 0.7], [0.6, 0.4]];
        manager
            .compute_difficulty("train", &data, &labels, Some(&predictions))
            .expect("unwrap");
        manager.set_epoch(0);
        let selected = manager.get_selected_samples("train", 10).expect("unwrap");
        assert!(!selected.is_empty());
        manager.clear_cache();
        assert!(manager.get_selected_samples("train", 10).is_err());
    }

    #[test]
    fn test_curriculum_manager_missing_key() {
        let manager = CurriculumManager::new(LinearCurriculum::default());
        let result = manager.get_selected_samples("nonexistent", 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_linear_curriculum_without_sorting() {
        let curriculum = LinearCurriculum::new(0.5)
            .expect("unwrap")
            .without_sorting();
        let difficulties = array![0.9, 0.1, 0.5, 0.3, 0.7];
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .expect("unwrap");
        assert_eq!(selected.len(), 3);
        assert_eq!(selected, vec![0, 1, 2]);
    }

    #[test]
    fn test_empty_difficulties() {
        let curriculum = LinearCurriculum::default();
        let difficulties = Array1::<f64>::zeros(0);
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .expect("unwrap");
        assert_eq!(selected.len(), 0);
    }
}