use crate::error::{AprenderError, Result};
/// Assigns a scalar difficulty to a single `(features, target)` sample.
///
/// In this module, samples whose difficulty is at or below a threshold become
/// eligible first. Larger scores therefore denote harder samples.
pub trait DifficultyScorer: Send + Sync {
    /// Computes the difficulty of the sample described by `features` and `target`.
    fn score(&self, features: &[f64], target: f64) -> f64;
}
/// Drives the progression of a curriculum from easy to hard samples.
///
/// A scheduler exposes a notion of progress (`stage`), a difficulty cut-off
/// for the current point in training (`current_threshold`), and the ability to
/// step forward, detect completion, and start over.
pub trait CurriculumScheduler: Send + Sync {
    /// Current progress through the curriculum. The schedulers in this module
    /// report values nominally in `[0.0, 1.0]`.
    fn stage(&self) -> f64;
    /// Moves the curriculum one step forward.
    fn advance(&mut self);
    /// Returns `true` once the curriculum has reached its final stage.
    fn is_complete(&self) -> bool;
    /// Difficulty cut-off that applies at the current stage.
    fn current_threshold(&self) -> f64;
    /// Returns the scheduler to its initial state.
    fn reset(&mut self);
}
/// A training sample annotated with a difficulty score.
#[derive(Debug, Clone)]
pub struct ScoredSample {
    /// Input feature vector.
    pub features: Vec<f64>,
    /// Target value associated with `features`.
    pub target: f64,
    /// Difficulty score; lower values become eligible earlier in a curriculum.
    pub difficulty: f64,
}

impl ScoredSample {
    /// Bundles `features`, `target` and `difficulty` into a sample.
    #[must_use]
    pub fn new(features: Vec<f64>, target: f64, difficulty: f64) -> Self {
        ScoredSample {
            features,
            target,
            difficulty,
        }
    }
}
/// Curriculum whose stage grows linearly from `0.0` to `1.0` in equal steps.
///
/// The difficulty threshold is interpolated linearly across
/// `difficulty_range` as the stage advances.
#[derive(Debug, Clone)]
pub struct LinearCurriculum {
    /// Current progress in `[0.0, 1.0]`.
    stage: f64,
    /// Increment applied by each `advance` (`1 / n_stages`, or `1.0` when
    /// `n_stages == 0`).
    step_size: f64,
    /// `(min, max)` difficulty mapped to stage `0.0` and `1.0` respectively.
    difficulty_range: (f64, f64),
}

impl Default for LinearCurriculum {
    /// Ten stages over the difficulty range `(0.0, 1.0)`.
    fn default() -> Self {
        Self::new(10)
    }
}

impl LinearCurriculum {
    /// Creates a curriculum that completes after `n_stages` calls to `advance`.
    ///
    /// `n_stages == 0` is treated like a single stage, so one `advance`
    /// completes it. The difficulty range defaults to `(0.0, 1.0)`.
    #[must_use]
    pub fn new(n_stages: usize) -> Self {
        let step_size = if n_stages > 0 {
            1.0 / n_stages as f64
        } else {
            1.0
        };
        Self {
            stage: 0.0,
            step_size,
            difficulty_range: (0.0, 1.0),
        }
    }

    /// Sets the difficulty interval swept by `current_threshold`.
    ///
    /// The range is used as given. If `min > max`, the threshold decreases as
    /// the stage advances.
    #[must_use]
    pub fn with_difficulty_range(mut self, min: f64, max: f64) -> Self {
        self.difficulty_range = (min, max);
        self
    }
}

impl CurriculumScheduler for LinearCurriculum {
    fn stage(&self) -> f64 {
        self.stage
    }

    /// Moves forward by one step, reaching exactly `1.0` on the final step.
    fn advance(&mut self) {
        let next = self.stage + self.step_size;
        // Repeatedly adding `1 / n` drifts below 1.0: ten additions of 0.1 give
        // 0.9999999999999999. That would leave the curriculum incomplete after
        // `n_stages` advances. Snap to 1.0 once within half a step of the end;
        // every earlier stage is at least a full step below 1.0.
        self.stage = if next >= 1.0 - self.step_size / 2.0 {
            1.0
        } else {
            next
        };
    }

    fn is_complete(&self) -> bool {
        self.stage >= 1.0
    }

    /// Linear interpolation `min + (max - min) * stage` over `difficulty_range`.
    fn current_threshold(&self) -> f64 {
        let (min, max) = self.difficulty_range;
        min + (max - min) * self.stage
    }

    fn reset(&mut self) {
        self.stage = 0.0;
    }
}
/// Curriculum whose stage approaches `1.0` exponentially.
///
/// After `k` advances the stage equals `1 - exp(-growth_rate * k)`. The
/// difficulty threshold is the stage itself.
#[derive(Debug, Clone)]
pub struct ExponentialCurriculum {
    /// Current progress, `1 - exp(-growth_rate * n_advances)`.
    stage: f64,
    /// Rate constant in the exponent.
    growth_rate: f64,
    /// Number of `advance` calls since construction or the last `reset`.
    n_advances: u64,
}

impl Default for ExponentialCurriculum {
    /// Uses a growth rate of `0.3`.
    fn default() -> Self {
        Self::new(0.3)
    }
}

impl ExponentialCurriculum {
    /// Stage at or above which the curriculum counts as complete. The
    /// exponential only approaches `1.0` asymptotically.
    const COMPLETION_STAGE: f64 = 0.99;

    /// Creates a curriculum at stage `0.0` with the given rate constant.
    ///
    /// NOTE(review): a non-positive `growth_rate` never reaches
    /// `COMPLETION_STAGE`.
    #[must_use]
    pub fn new(growth_rate: f64) -> Self {
        ExponentialCurriculum {
            growth_rate,
            stage: 0.0,
            n_advances: 0,
        }
    }
}

impl CurriculumScheduler for ExponentialCurriculum {
    fn stage(&self) -> f64 {
        self.stage
    }

    /// Increments the advance count and recomputes the stage from it.
    fn advance(&mut self) {
        self.n_advances += 1;
        let exponent = -self.growth_rate * self.n_advances as f64;
        self.stage = 1.0 - exponent.exp();
    }

    fn is_complete(&self) -> bool {
        self.stage >= Self::COMPLETION_STAGE
    }

    /// The threshold tracks the stage directly.
    fn current_threshold(&self) -> f64 {
        self.stage
    }

    /// Returns to stage `0.0` with zero advances, keeping `growth_rate`.
    fn reset(&mut self) {
        *self = Self::new(self.growth_rate);
    }
}
/// Self-paced curriculum.
///
/// A sample is eligible once its difficulty is at or below the current
/// threshold, which is multiplied by `growth_rate` on each `advance`. Samples
/// are kept sorted by ascending difficulty.
#[derive(Debug, Clone)]
pub struct SelfPacedCurriculum {
    /// All samples, sorted by ascending difficulty.
    samples: Vec<ScoredSample>,
    /// Current (inclusive) eligibility cut-off.
    threshold: f64,
    /// Threshold restored by `reset`.
    initial_threshold: f64,
    /// Multiplier applied to the threshold on each `advance`.
    growth_rate: f64,
    /// Upper bound for the threshold (`f64::MAX` unless configured).
    max_threshold: f64,
    /// Offset into the eligible samples where the next batch starts.
    batch_idx: usize,
}

impl SelfPacedCurriculum {
    /// Creates an empty curriculum.
    ///
    /// `initial_threshold` is the starting (and post-`reset`) cut-off, and
    /// `growth_rate` multiplies it on every `advance`. A threshold of `0.0`
    /// never grows under multiplication.
    #[must_use]
    pub fn new(initial_threshold: f64, growth_rate: f64) -> Self {
        Self {
            samples: Vec::new(),
            threshold: initial_threshold,
            initial_threshold,
            growth_rate,
            max_threshold: f64::MAX,
            batch_idx: 0,
        }
    }

    /// Caps the threshold so that `advance` never raises it above `max`.
    #[must_use]
    pub fn with_max_threshold(mut self, max: f64) -> Self {
        self.max_threshold = max;
        self
    }

    /// Appends `samples` and re-sorts the whole collection by difficulty.
    pub fn add_samples(&mut self, samples: Vec<ScoredSample>) {
        self.samples.extend(samples);
        self.sort_samples_by_difficulty();
    }

    /// Re-scores every sample as `loss_fn(features, target)` and re-sorts.
    ///
    /// Accepts any `FnMut`, so stateful scorers (e.g. ones that count calls)
    /// work too. Plain `Fn` closures are still accepted.
    pub fn update_difficulties<F>(&mut self, mut loss_fn: F)
    where
        F: FnMut(&[f64], f64) -> f64,
    {
        for sample in &mut self.samples {
            sample.difficulty = loss_fn(&sample.features, sample.target);
        }
        self.sort_samples_by_difficulty();
    }

    /// Returns up to `batch_size` eligible samples, continuing from the
    /// previous batch.
    ///
    /// Once the eligible samples are exhausted (or none are eligible), this
    /// returns an empty batch and rewinds, so the next call starts again from
    /// the first eligible sample.
    pub fn next_batch(&mut self, batch_size: usize) -> Vec<&ScoredSample> {
        contract_pre_streaming_data_loader!();
        // Copy the cut-off so the filter closure borrows nothing from `self`.
        // This keeps the borrow of `self.samples` disjoint from the
        // `self.batch_idx` writes below.
        let threshold = self.threshold;
        let start = self.batch_idx;
        // Stream over the eligible samples instead of collecting all of them
        // on every call.
        let mut remaining = self
            .samples
            .iter()
            .filter(move |s| s.difficulty <= threshold)
            .skip(start)
            .peekable();
        if remaining.peek().is_none() {
            // Epoch exhausted: rewind for the next pass.
            self.batch_idx = 0;
            let result: Vec<&ScoredSample> = vec![];
            contract_post_streaming_data_loader!(&result);
            return result;
        }
        let result: Vec<&ScoredSample> = remaining.take(batch_size).collect();
        self.batch_idx = start + result.len();
        contract_post_streaming_data_loader!(&result);
        result
    }

    /// All samples whose difficulty is at or below the current threshold, in
    /// ascending difficulty order.
    #[must_use]
    pub fn eligible_samples(&self) -> Vec<&ScoredSample> {
        self.eligible_iter().collect()
    }

    /// Number of samples whose difficulty is at or below the current threshold.
    #[must_use]
    pub fn n_eligible(&self) -> usize {
        self.eligible_iter().count()
    }

    /// Total number of samples held, eligible or not.
    #[must_use]
    pub fn n_total(&self) -> usize {
        self.samples.len()
    }

    /// Current eligibility threshold.
    #[must_use]
    pub fn threshold(&self) -> f64 {
        self.threshold
    }

    /// Iterates over samples with `difficulty <= threshold`. NaN difficulties
    /// are never eligible.
    fn eligible_iter(&self) -> impl Iterator<Item = &ScoredSample> + '_ {
        let threshold = self.threshold;
        self.samples
            .iter()
            .filter(move |s| s.difficulty <= threshold)
    }

    /// Stable sort by ascending difficulty.
    fn sort_samples_by_difficulty(&mut self) {
        // `total_cmp` is a total order even when a difficulty is NaN (e.g. a
        // loss function returned NaN). `partial_cmp(..).unwrap_or(Equal)` is
        // not transitive in that case: it can leave the slice mis-sorted, and
        // `sort_by` may panic on a non-total comparator.
        self.samples
            .sort_by(|a, b| a.difficulty.total_cmp(&b.difficulty));
    }
}
impl CurriculumScheduler for SelfPacedCurriculum {
    /// Fraction of samples currently eligible, or `1.0` when there are none.
    fn stage(&self) -> f64 {
        match self.samples.len() {
            0 => 1.0,
            total => self.n_eligible() as f64 / total as f64,
        }
    }

    /// Multiplies the threshold by `growth_rate` (capped at `max_threshold`)
    /// and restarts batching from the first eligible sample.
    fn advance(&mut self) {
        let grown = self.threshold * self.growth_rate;
        self.threshold = grown.min(self.max_threshold);
        self.batch_idx = 0;
    }

    /// Complete when every sample is eligible. This is trivially true with no
    /// samples.
    fn is_complete(&self) -> bool {
        self.samples.len() <= self.n_eligible()
    }

    fn current_threshold(&self) -> f64 {
        self.threshold
    }

    /// Restores the initial threshold and batch position. Samples are kept.
    fn reset(&mut self) {
        self.batch_idx = 0;
        self.threshold = self.initial_threshold;
    }
}
/// Scores difficulty as the absolute deviation of a sample's target from a
/// reference mean. Features are ignored.
#[derive(Debug, Clone)]
pub struct LossDifficultyScorer {
    /// Reference value targets are compared against (`0.0` unless set or fitted).
    target_mean: f64,
}

impl LossDifficultyScorer {
    /// Creates a scorer with a reference mean of `0.0`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_mean(0.0)
    }

    /// Creates a scorer with an explicit reference mean.
    #[must_use]
    pub fn with_mean(target_mean: f64) -> Self {
        LossDifficultyScorer { target_mean }
    }

    /// Sets the reference mean to the arithmetic mean of `targets`.
    ///
    /// An empty slice leaves the current mean unchanged.
    pub fn fit(&mut self, targets: &[f64]) {
        if targets.is_empty() {
            return;
        }
        let total: f64 = targets.iter().sum();
        self.target_mean = total / targets.len() as f64;
    }
}

impl Default for LossDifficultyScorer {
    fn default() -> Self {
        LossDifficultyScorer::new()
    }
}

impl DifficultyScorer for LossDifficultyScorer {
    /// `|target - target_mean|`.
    fn score(&self, _features: &[f64], target: f64) -> f64 {
        let deviation = target - self.target_mean;
        deviation.abs()
    }
}
/// Scores difficulty as the Euclidean (L2) norm of the feature vector. The
/// target is ignored.
#[derive(Debug, Clone, Default)]
pub struct FeatureNormScorer;

impl FeatureNormScorer {
    /// Creates the (stateless) scorer.
    #[must_use]
    pub fn new() -> Self {
        FeatureNormScorer
    }
}

impl DifficultyScorer for FeatureNormScorer {
    /// `sqrt(sum(x_i^2))` over `features`.
    fn score(&self, features: &[f64], _target: f64) -> f64 {
        let sum_of_squares = features.iter().map(|&v| v * v).sum::<f64>();
        sum_of_squares.sqrt()
    }
}
/// Pairs a curriculum scheduler with the pool of scored samples it paces.
#[derive(Debug)]
pub struct CurriculumTrainer<S>
where
    S: CurriculumScheduler + std::fmt::Debug,
{
    /// Scheduler controlling curriculum progression.
    scheduler: S,
    /// Samples available to the trainer.
    samples: Vec<ScoredSample>,
}
include!("curriculum_trainer.rs");
include!("curriculum_tests.rs");