use super::{Corpus, DifficultyLevel, Sample};
/// Tuning knobs for [`CurriculumScheduler`]'s level-advancement policy.
#[derive(Clone, Debug)]
pub struct CurriculumConfig {
    /// Threshold used twice by the scheduler: a reported accuracy counts as
    /// correct when it is `>=` this value, and the scheduler may advance once
    /// the fraction of correct samples at the current level is `>=` it.
    pub advance_threshold: f64,
    /// Intended number of samples per level.
    /// NOTE(review): not read by `CurriculumScheduler` in this file — confirm
    /// whether anything else consults it.
    pub samples_per_level: usize,
    /// Samples that must be recorded at the current level before advancement
    /// is considered at all.
    pub min_samples_before_advance: usize,
    /// When `true`, recording a sample advances the level automatically as soon
    /// as the advancement criteria are met.
    pub auto_advance: bool,
}
impl Default for CurriculumConfig {
    /// Advance at 85% accuracy after at least 10 samples, automatically;
    /// `samples_per_level` defaults to 100.
    fn default() -> Self {
        let advance_threshold = 0.85;
        let min_samples_before_advance = 10;
        Self {
            auto_advance: true,
            advance_threshold,
            samples_per_level: 100,
            min_samples_before_advance,
        }
    }
}
/// Tracks progress through difficulty levels and decides when to move on.
#[derive(Debug)]
pub struct CurriculumScheduler {
    /// Level samples are currently drawn at; constructors start at
    /// `DifficultyLevel::Easy`.
    current_level: DifficultyLevel,
    /// Advancement policy this scheduler follows.
    config: CurriculumConfig,
    /// Samples recorded since entering `current_level`.
    samples_at_level: usize,
    /// How many of those samples were recorded as correct.
    correct_at_level: usize,
    /// `(level left, accuracy at that level)` pairs, in the order the levels
    /// were left.
    advancement_history: Vec<(DifficultyLevel, f64)>,
}
impl CurriculumScheduler {
    /// Builds a scheduler from [`CurriculumConfig::default`], starting at
    /// [`DifficultyLevel::Easy`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(CurriculumConfig::default())
    }

    /// Builds a scheduler driven by `config`, starting at
    /// [`DifficultyLevel::Easy`] with zeroed counters and an empty history.
    #[must_use]
    pub fn with_config(config: CurriculumConfig) -> Self {
        Self {
            config,
            current_level: DifficultyLevel::Easy,
            samples_at_level: 0,
            correct_at_level: 0,
            advancement_history: Vec::new(),
        }
    }

    /// The difficulty level samples are currently drawn at.
    #[must_use]
    pub fn current_level(&self) -> DifficultyLevel {
        self.current_level
    }

    /// Borrows the configuration this scheduler was built with.
    #[must_use]
    pub fn config(&self) -> &CurriculumConfig {
        &self.config
    }

    /// Records one observation expressed as an accuracy score.
    ///
    /// The observation counts as correct when `accuracy` is at least
    /// `config.advance_threshold`, and is then handled exactly like
    /// [`Self::record_prediction`] (including the auto-advance check). As a
    /// result, [`Self::level_accuracy`] is the fraction of reports that met
    /// the threshold, not the mean of the reported values.
    pub fn report_accuracy(&mut self, accuracy: f64) {
        let met_threshold = accuracy >= self.config.advance_threshold;
        self.record_prediction(met_threshold);
    }

    /// Records one prediction outcome at the current level.
    ///
    /// When `config.auto_advance` is set and [`Self::should_advance`] holds
    /// after recording, the scheduler moves to the next level immediately.
    pub fn record_prediction(&mut self, correct: bool) {
        self.samples_at_level += 1;
        self.correct_at_level += usize::from(correct);

        if self.config.auto_advance && self.should_advance() {
            self.advance();
        }
    }

    /// Whether the current level counts as mastered.
    ///
    /// Always `false` at [`DifficultyLevel::Expert`] or before
    /// `config.min_samples_before_advance` samples have been recorded at this
    /// level. Otherwise `true` once [`Self::level_accuracy`] reaches
    /// `config.advance_threshold`.
    #[must_use]
    pub fn should_advance(&self) -> bool {
        let at_top = self.current_level == DifficultyLevel::Expert;
        let enough_samples = self.samples_at_level >= self.config.min_samples_before_advance;
        !at_top && enough_samples && self.level_accuracy() >= self.config.advance_threshold
    }

    /// Fraction of samples at the current level recorded as correct, or `0.0`
    /// when none have been recorded yet.
    #[must_use]
    pub fn level_accuracy(&self) -> f64 {
        match self.samples_at_level {
            0 => 0.0,
            total => self.correct_at_level as f64 / total as f64,
        }
    }

    /// Moves to the next difficulty level unconditionally.
    ///
    /// The outgoing level and its accuracy are appended to the history, and
    /// the per-level counters are cleared. This does nothing at
    /// [`DifficultyLevel::Expert`]. It does not check
    /// [`Self::should_advance`], so callers can use it for manual progression.
    pub fn advance(&mut self) {
        if self.current_level == DifficultyLevel::Expert {
            return;
        }
        let leaving = self.current_level;
        // Capture accuracy before the counters are cleared below.
        let accuracy = self.level_accuracy();
        self.advancement_history.push((leaving, accuracy));
        self.current_level = leaving.next();
        self.reset_level_counters();
    }

    /// Returns up to `limit` samples, in corpus order, whose difficulty does
    /// not exceed the current level's score.
    #[must_use]
    pub fn next_batch<'a>(&self, corpus: &'a Corpus, limit: usize) -> Vec<&'a Sample> {
        let ceiling = self.current_level.score();
        corpus
            .samples()
            .iter()
            .filter(|sample| sample.difficulty <= ceiling)
            .take(limit)
            .collect()
    }

    /// Returns every sample whose difficulty lies within ±0.125 (inclusive)
    /// of `level`'s score, independent of the current level.
    #[must_use]
    pub fn samples_at_difficulty<'a>(
        &self,
        corpus: &'a Corpus,
        level: DifficultyLevel,
    ) -> Vec<&'a Sample> {
        let center = level.score();
        let max_distance = 0.125;
        corpus
            .samples()
            .iter()
            .filter(|sample| (sample.difficulty - center).abs() <= max_distance)
            .collect()
    }

    /// Returns to [`DifficultyLevel::Easy`] and discards counters and history.
    /// The configuration is kept.
    pub fn reset(&mut self) {
        self.current_level = DifficultyLevel::Easy;
        self.reset_level_counters();
        self.advancement_history.clear();
    }

    /// Levels left so far, each paired with the accuracy recorded when it was
    /// left, oldest first.
    #[must_use]
    pub fn advancement_history(&self) -> &[(DifficultyLevel, f64)] {
        &self.advancement_history
    }

    /// Number of samples recorded at the *current* level.
    ///
    /// [`Self::advance`] and [`Self::reset`] both clear this counter, so it is
    /// not a running total across levels.
    #[must_use]
    pub fn samples_processed(&self) -> usize {
        self.samples_at_level
    }

    /// Zeroes the per-level sample and correct counters.
    fn reset_level_counters(&mut self) {
        self.samples_at_level = 0;
        self.correct_at_level = 0;
    }
}
impl Default for CurriculumScheduler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oracle::{ErrorCategory, SampleSource};

    /// Threshold 0.85 with a single sample enough to trigger advancement.
    fn single_sample_config() -> CurriculumConfig {
        CurriculumConfig {
            advance_threshold: 0.85,
            min_samples_before_advance: 1,
            ..Default::default()
        }
    }

    /// Default config except for the minimum sample count before advancing.
    fn config_with_min_samples(min_samples: usize) -> CurriculumConfig {
        CurriculumConfig {
            min_samples_before_advance: min_samples,
            ..Default::default()
        }
    }

    /// Reports a passing accuracy of 0.90 `times` times.
    fn report_passing(scheduler: &mut CurriculumScheduler, times: usize) {
        for _ in 0..times {
            scheduler.report_accuracy(0.90);
        }
    }

    #[test]
    fn test_curriculum_scheduler_new() {
        assert_eq!(CurriculumScheduler::new().current_level(), DifficultyLevel::Easy);
    }

    #[test]
    fn test_curriculum_scheduler_with_config() {
        let scheduler = CurriculumScheduler::with_config(CurriculumConfig {
            advance_threshold: 0.90,
            samples_per_level: 50,
            ..Default::default()
        });
        assert_eq!(scheduler.config().advance_threshold, 0.90);
    }

    #[test]
    fn test_curriculum_advance_on_high_accuracy() {
        let mut scheduler = CurriculumScheduler::with_config(single_sample_config());
        assert_eq!(scheduler.current_level(), DifficultyLevel::Easy);
        scheduler.report_accuracy(0.90);
        assert_eq!(scheduler.current_level(), DifficultyLevel::Medium);
    }

    #[test]
    fn test_curriculum_no_advance_on_low_accuracy() {
        let mut scheduler = CurriculumScheduler::with_config(single_sample_config());
        scheduler.report_accuracy(0.70);
        assert_eq!(scheduler.current_level(), DifficultyLevel::Easy);
    }

    #[test]
    fn test_curriculum_advance_through_all_levels() {
        let mut scheduler = CurriculumScheduler::with_config(single_sample_config());
        // Level expected after each passing report; Expert is sticky.
        let expected = [
            DifficultyLevel::Medium,
            DifficultyLevel::Hard,
            DifficultyLevel::Expert,
            DifficultyLevel::Expert,
        ];
        for level in expected {
            scheduler.report_accuracy(0.90);
            assert_eq!(scheduler.current_level(), level);
        }
    }

    #[test]
    fn test_curriculum_next_batch_filters() {
        let mut corpus = Corpus::new();
        corpus.add(
            Sample::new("easy", Some("E0308".into()), ErrorCategory::TypeMismatch)
                .with_difficulty(0.25),
        );
        corpus.add(
            Sample::new("medium", Some("E0382".into()), ErrorCategory::BorrowChecker)
                .with_difficulty(0.50),
        );
        corpus.add(
            Sample::new("hard", Some("E0597".into()), ErrorCategory::LifetimeError)
                .with_difficulty(0.75),
        );
        let batch = CurriculumScheduler::new().next_batch(&corpus, 10);
        // NOTE(review): `all` is vacuously true for an empty batch, so this
        // only checks that nothing too hard slipped through.
        assert!(batch.iter().all(|sample| sample.difficulty <= 0.25));
    }

    #[test]
    fn test_curriculum_next_batch_limit() {
        let mut corpus = Corpus::new();
        for i in 0..100 {
            corpus.add(
                Sample::new(format!("error {i}"), None, ErrorCategory::TypeMismatch)
                    .with_difficulty(0.25)
                    .with_source(SampleSource::Synthetic),
            );
        }
        let batch = CurriculumScheduler::new().next_batch(&corpus, 10);
        assert!(batch.len() <= 10);
    }

    #[test]
    fn test_curriculum_level_accuracy() {
        let mut scheduler = CurriculumScheduler::with_config(config_with_min_samples(100));
        for outcome in [true, true, false, true] {
            scheduler.record_prediction(outcome);
        }
        assert!((scheduler.level_accuracy() - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_curriculum_reset() {
        let mut scheduler = CurriculumScheduler::with_config(config_with_min_samples(1));
        report_passing(&mut scheduler, 2);
        assert_eq!(scheduler.current_level(), DifficultyLevel::Hard);
        scheduler.reset();
        assert_eq!(scheduler.current_level(), DifficultyLevel::Easy);
        assert_eq!(scheduler.samples_processed(), 0);
    }

    #[test]
    fn test_curriculum_config_default() {
        let config = CurriculumConfig::default();
        assert!((config.advance_threshold - 0.85).abs() < f64::EPSILON);
        assert_eq!(config.samples_per_level, 100);
    }

    #[test]
    fn test_curriculum_scheduler_default() {
        let scheduler = CurriculumScheduler::default();
        assert_eq!(scheduler.current_level(), DifficultyLevel::Easy);
        assert_eq!(scheduler.samples_processed(), 0);
    }

    #[test]
    fn test_curriculum_config_fields() {
        let scheduler = CurriculumScheduler::new();
        let config = scheduler.config();
        assert!(config.auto_advance);
        assert_eq!(config.min_samples_before_advance, 10);
    }

    #[test]
    fn test_curriculum_samples_at_difficulty() {
        let mut corpus = Corpus::new();
        for message in [
            "mismatched types: expected i32, found String",
            "mismatched types: expected bool, found char",
        ] {
            corpus.add(
                Sample::new(message, Some("E0308".into()), ErrorCategory::TypeMismatch)
                    .with_difficulty(0.25),
            );
        }
        corpus.add(
            Sample::new(
                "borrow of moved value",
                Some("E0382".into()),
                ErrorCategory::BorrowChecker,
            )
            .with_difficulty(0.50),
        );
        let scheduler = CurriculumScheduler::new();
        assert_eq!(
            scheduler
                .samples_at_difficulty(&corpus, DifficultyLevel::Easy)
                .len(),
            2
        );
    }

    #[test]
    fn test_curriculum_advancement_history() {
        let mut scheduler = CurriculumScheduler::with_config(single_sample_config());
        report_passing(&mut scheduler, 2);
        let levels_left: Vec<DifficultyLevel> = scheduler
            .advancement_history()
            .iter()
            .map(|&(level, _)| level)
            .collect();
        assert_eq!(levels_left, vec![DifficultyLevel::Easy, DifficultyLevel::Medium]);
    }

    #[test]
    fn test_curriculum_should_advance_min_samples() {
        let mut scheduler = CurriculumScheduler::with_config(CurriculumConfig {
            advance_threshold: 0.85,
            min_samples_before_advance: 10,
            auto_advance: false,
            ..Default::default()
        });
        scheduler.record_prediction(true);
        assert!(!scheduler.should_advance());
    }

    #[test]
    fn test_curriculum_level_accuracy_zero_samples() {
        assert_eq!(CurriculumScheduler::new().level_accuracy(), 0.0);
    }

    #[test]
    fn test_curriculum_advance_at_expert_noop() {
        let mut scheduler = CurriculumScheduler::with_config(single_sample_config());
        report_passing(&mut scheduler, 3);
        assert_eq!(scheduler.current_level(), DifficultyLevel::Expert);
        scheduler.advance();
        assert_eq!(scheduler.current_level(), DifficultyLevel::Expert);
    }

    #[test]
    fn test_curriculum_should_advance_at_expert() {
        let mut scheduler = CurriculumScheduler::with_config(single_sample_config());
        report_passing(&mut scheduler, 3);
        assert!(!scheduler.should_advance());
    }

    #[test]
    fn test_curriculum_record_prediction_false() {
        let mut scheduler = CurriculumScheduler::with_config(config_with_min_samples(100));
        for _ in 0..3 {
            scheduler.record_prediction(false);
        }
        assert_eq!(scheduler.level_accuracy(), 0.0);
        assert_eq!(scheduler.samples_processed(), 3);
    }

    #[test]
    fn test_curriculum_manual_advance() {
        let mut scheduler = CurriculumScheduler::with_config(CurriculumConfig {
            auto_advance: false,
            min_samples_before_advance: 1,
            ..Default::default()
        });
        scheduler.record_prediction(true);
        assert_eq!(scheduler.current_level(), DifficultyLevel::Easy);
        scheduler.advance();
        assert_eq!(scheduler.current_level(), DifficultyLevel::Medium);
    }

    #[test]
    fn test_curriculum_config_all_fields() {
        let config = CurriculumConfig {
            advance_threshold: 0.99,
            samples_per_level: 500,
            min_samples_before_advance: 50,
            auto_advance: false,
        };
        assert_eq!(config.advance_threshold, 0.99);
        assert_eq!(config.samples_per_level, 500);
        assert_eq!(config.min_samples_before_advance, 50);
        assert!(!config.auto_advance);
    }

    #[test]
    fn test_curriculum_reset_clears_history() {
        let mut scheduler = CurriculumScheduler::with_config(config_with_min_samples(1));
        scheduler.report_accuracy(0.90);
        assert!(!scheduler.advancement_history().is_empty());
        scheduler.reset();
        assert!(scheduler.advancement_history().is_empty());
    }
}