use std::collections::HashMap;
use super::CurriculumScheduler;
/// Curriculum that adapts to per-error-class performance.
///
/// For every error code passed to [`AdaptiveCurriculum::update_class`] it records how many
/// times the class was attempted, plus an exponential moving average (EMA) of whether those
/// attempts were correct. The mean of the per-class EMAs is stored in `overall_difficulty`.
#[derive(Debug, Clone)]
pub struct AdaptiveCurriculum {
    /// Per-error-code EMA of correctness (starts at 0.0; each update weights the new sample by 0.1).
    pub(crate) class_accuracy: HashMap<String, f32>,
    /// Number of `update_class` calls recorded per error code.
    pub(crate) class_attempts: HashMap<String, usize>,
    /// Tier returned by `tier_for_error` for the first attempt (attempt 0) of a non-ICE error.
    default_tier: usize,
    /// Aggregate score. Despite the name it grows with accuracy: `update_class` sets it to the
    /// mean of `class_accuracy`, and the scheduler's `step` EMA-blends epoch accuracy into it.
    overall_difficulty: f32,
}

impl AdaptiveCurriculum {
    /// Smoothing factor for the per-class correctness EMA in `update_class`.
    const CLASS_EMA_ALPHA: f32 = 0.1;

    /// Creates an empty curriculum with default tier 1 and an aggregate score of 0.0.
    pub fn new() -> Self {
        Self {
            class_accuracy: HashMap::new(),
            class_attempts: HashMap::new(),
            default_tier: 1,
            overall_difficulty: 0.0,
        }
    }

    /// Chooses the tier to use for retrying `error_code` on the given (0-based) `attempt`.
    ///
    /// * Codes starting with `"ICE"` always get tier 4.
    /// * `E0308`, `E0277` and `E0382` escalate to tier 3 from the first retry.
    /// * `E0425` and `E0433` escalate to tier 3 from the second retry.
    /// * Everything else: `default_tier` on attempt 0, tier 2 on attempt 1, tier 3 afterwards.
    pub fn tier_for_error(&self, error_code: &str, attempt: usize) -> usize {
        if error_code.starts_with("ICE") {
            return 4;
        }
        if matches!(error_code, "E0308" | "E0277" | "E0382") && attempt >= 1 {
            return 3;
        }
        // NOTE(review): with the default escalation below, this rule currently produces the
        // same tier it would get anyway; it is kept so the policy for these codes stays explicit.
        if matches!(error_code, "E0425" | "E0433") && attempt >= 2 {
            return 3;
        }
        match attempt {
            0 => self.default_tier,
            1 => 2,
            2.. => 3,
        }
    }

    /// Records one attempt at `error_code` and whether it was `correct`.
    ///
    /// Increments the class's attempt count and updates its correctness EMA as
    /// `acc = acc * (1 - 0.1) + (0.1 if correct else 0.0)`, starting from 0.0 for a new class.
    /// Then it recomputes `overall_difficulty` as the mean of all per-class EMAs.
    ///
    /// NOTE(review): that recomputation overwrites any value previously blended in by the
    /// scheduler's `step`; confirm the two update paths are meant to share one field.
    pub fn update_class(&mut self, error_code: &str, correct: bool) {
        // Look up with the borrowed key first: an already-known class costs no allocation,
        // and only a first sighting pays for `to_owned()`.
        match self.class_attempts.get_mut(error_code) {
            Some(attempts) => *attempts += 1,
            None => {
                self.class_attempts.insert(error_code.to_owned(), 1);
            }
        }

        let alpha = Self::CLASS_EMA_ALPHA;
        let sample = if correct { alpha } else { 0.0 };
        match self.class_accuracy.get_mut(error_code) {
            Some(acc) => *acc = *acc * (1.0 - alpha) + sample,
            None => {
                // A new class starts from an EMA of 0.0, and 0.0 * (1 - alpha) + sample == sample.
                self.class_accuracy.insert(error_code.to_owned(), sample);
            }
        }

        // `error_code` was just inserted or updated, so the map is non-empty and the division
        // is well-defined.
        let total: f32 = self.class_accuracy.values().sum();
        self.overall_difficulty = total / self.class_accuracy.len() as f32;
    }

    /// Sampling weight for `error_code`: `1 + 1/sqrt(attempts + 1) + (1 - accuracy)`,
    /// clamped to at most 3.0.
    ///
    /// Rarely-seen and poorly-solved classes get larger weights. An unknown class
    /// (0 attempts, accuracy 0.0) gets the maximum, 3.0.
    pub fn weight_for_class(&self, error_code: &str) -> f32 {
        let attempts = self.class_attempts.get(error_code).copied().unwrap_or(0);
        let accuracy = self.class_accuracy.get(error_code).copied().unwrap_or(0.0);
        let rarity_weight = 1.0 / (attempts as f32 + 1.0).sqrt();
        let difficulty_weight = 1.0 - accuracy;
        (1.0 + rarity_weight + difficulty_weight).min(3.0)
    }
}

impl Default for AdaptiveCurriculum {
    /// Same as [`AdaptiveCurriculum::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl CurriculumScheduler for AdaptiveCurriculum {
    /// Reports the aggregate score stored in `overall_difficulty`.
    fn difficulty(&self) -> f32 {
        self.overall_difficulty
    }

    /// Maps the aggregate score onto tiers 1-4 using upper bounds 0.25, 0.5 and 0.75.
    ///
    /// Each bound is exclusive, so a score exactly on a bound falls into the next tier up.
    /// A score of 0.75 or more lands in tier 4, and so does NaN, which is not less than
    /// any bound.
    fn tier(&self) -> usize {
        let score = self.overall_difficulty;
        [0.25_f32, 0.5, 0.75]
            .iter()
            .position(|&upper| score < upper)
            .map_or(4, |bucket| bucket + 1)
    }

    /// Blends this epoch's `accuracy` into the aggregate score with an EMA weight of 0.1.
    /// The epoch index is not used.
    fn step(&mut self, _epoch: usize, accuracy: f32) {
        const WEIGHT: f32 = 0.1;
        let retained = self.overall_difficulty * (1.0 - WEIGHT);
        self.overall_difficulty = retained + accuracy * WEIGHT;
    }

    /// Forgets all per-class statistics and zeroes the aggregate score.
    /// `default_tier` is left unchanged.
    fn reset(&mut self) {
        self.overall_difficulty = 0.0;
        self.class_attempts.clear();
        self.class_accuracy.clear();
    }

    /// Human-readable scheduler name.
    fn name(&self) -> &'static str {
        "AdaptiveCurriculum"
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `AdaptiveCurriculum`: the inherent API and its `CurriculumScheduler` impl.
    use super::*;

    #[test]
    fn test_adaptive_curriculum_new() {
        let curriculum = AdaptiveCurriculum::new();
        assert!(curriculum.class_accuracy.is_empty());
        assert!(curriculum.class_attempts.is_empty());
        assert_eq!(curriculum.overall_difficulty, 0.0);
    }

    #[test]
    fn test_adaptive_curriculum_default() {
        let curriculum = AdaptiveCurriculum::default();
        assert_eq!(curriculum.difficulty(), 0.0);
    }

    #[test]
    fn test_tier_for_error_ice() {
        let curriculum = AdaptiveCurriculum::new();
        assert_eq!(curriculum.tier_for_error("ICE001", 0), 4);
        assert_eq!(curriculum.tier_for_error("ICE-crash", 5), 4);
    }

    #[test]
    fn test_tier_for_error_type_errors() {
        let curriculum = AdaptiveCurriculum::new();
        assert_eq!(curriculum.tier_for_error("E0308", 0), 1);
        assert_eq!(curriculum.tier_for_error("E0308", 1), 3);
        assert_eq!(curriculum.tier_for_error("E0277", 2), 3);
        assert_eq!(curriculum.tier_for_error("E0382", 1), 3);
    }

    #[test]
    fn test_tier_for_error_name_resolution() {
        let curriculum = AdaptiveCurriculum::new();
        assert_eq!(curriculum.tier_for_error("E0425", 0), 1);
        assert_eq!(curriculum.tier_for_error("E0425", 1), 2);
        assert_eq!(curriculum.tier_for_error("E0425", 2), 3);
        assert_eq!(curriculum.tier_for_error("E0433", 1), 2);
        assert_eq!(curriculum.tier_for_error("E0433", 3), 3);
    }

    #[test]
    fn test_tier_for_error_default_escalation() {
        let curriculum = AdaptiveCurriculum::new();
        assert_eq!(curriculum.tier_for_error("E0001", 0), 1);
        assert_eq!(curriculum.tier_for_error("E0001", 1), 2);
        assert_eq!(curriculum.tier_for_error("E0001", 2), 3);
        assert_eq!(curriculum.tier_for_error("E0001", 5), 3);
    }

    #[test]
    fn test_update_class() {
        let mut curriculum = AdaptiveCurriculum::new();
        curriculum.update_class("E0308", true);
        assert_eq!(curriculum.class_attempts.get("E0308"), Some(&1));
        assert!(
            (curriculum.class_accuracy.get("E0308").expect("key should exist") - 0.1).abs() < 0.001
        );
        curriculum.update_class("E0308", false);
        assert_eq!(curriculum.class_attempts.get("E0308"), Some(&2));
        assert!(
            (curriculum.class_accuracy.get("E0308").expect("key should exist") - 0.09).abs()
                < 0.001
        );
    }

    #[test]
    fn test_update_class_sets_overall_to_mean_accuracy() {
        let mut curriculum = AdaptiveCurriculum::new();
        curriculum.update_class("E0308", true);
        curriculum.update_class("E0425", false);
        // Mean of the two class EMAs: (0.1 + 0.0) / 2.
        assert!((curriculum.difficulty() - 0.05).abs() < 0.001);
    }

    #[test]
    fn test_weight_for_class_unknown() {
        let curriculum = AdaptiveCurriculum::new();
        let weight = curriculum.weight_for_class("unknown");
        assert!((weight - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_weight_for_class_known() {
        let mut curriculum = AdaptiveCurriculum::new();
        for _ in 0..10 {
            curriculum.update_class("E0308", true);
        }
        let weight = curriculum.weight_for_class("E0308");
        assert!(weight < 3.0);
        assert!(weight >= 1.0);
    }

    #[test]
    fn test_weight_for_class_exact_value() {
        let mut curriculum = AdaptiveCurriculum::new();
        for _ in 0..3 {
            curriculum.update_class("E0599", false);
        }
        // 1 + 1/sqrt(3 + 1) + (1 - 0.0) = 2.5
        assert!((curriculum.weight_for_class("E0599") - 2.5).abs() < 0.001);
    }

    #[test]
    fn test_curriculum_scheduler_difficulty() {
        let mut curriculum = AdaptiveCurriculum::new();
        assert_eq!(curriculum.difficulty(), 0.0);
        curriculum.step(0, 0.5);
        assert!(curriculum.difficulty() > 0.0);
    }

    #[test]
    fn test_curriculum_scheduler_tier() {
        let mut curriculum = AdaptiveCurriculum::new();
        assert_eq!(curriculum.tier(), 1);
        curriculum.overall_difficulty = 0.3;
        assert_eq!(curriculum.tier(), 2);
        curriculum.overall_difficulty = 0.6;
        assert_eq!(curriculum.tier(), 3);
        curriculum.overall_difficulty = 0.8;
        assert_eq!(curriculum.tier(), 4);
    }

    #[test]
    fn test_curriculum_scheduler_tier_boundaries() {
        let mut curriculum = AdaptiveCurriculum::new();
        // Each threshold belongs to the higher tier (the comparisons are strict `<`).
        curriculum.overall_difficulty = 0.25;
        assert_eq!(curriculum.tier(), 2);
        curriculum.overall_difficulty = 0.5;
        assert_eq!(curriculum.tier(), 3);
        curriculum.overall_difficulty = 0.75;
        assert_eq!(curriculum.tier(), 4);
    }

    #[test]
    fn test_curriculum_scheduler_step() {
        let mut curriculum = AdaptiveCurriculum::new();
        curriculum.step(0, 1.0);
        assert!((curriculum.difficulty() - 0.1).abs() < 0.001);
        curriculum.step(1, 1.0);
        assert!((curriculum.difficulty() - 0.19).abs() < 0.001);
    }

    #[test]
    fn test_curriculum_scheduler_reset() {
        let mut curriculum = AdaptiveCurriculum::new();
        curriculum.update_class("E0308", true);
        curriculum.step(0, 0.5);
        assert!(!curriculum.class_accuracy.is_empty());
        assert!(curriculum.difficulty() > 0.0);
        curriculum.reset();
        assert!(curriculum.class_accuracy.is_empty());
        assert!(curriculum.class_attempts.is_empty());
        assert_eq!(curriculum.difficulty(), 0.0);
    }

    #[test]
    fn test_reset_keeps_default_tier() {
        let mut curriculum = AdaptiveCurriculum::new();
        curriculum.update_class("E0308", true);
        curriculum.reset();
        assert_eq!(curriculum.tier_for_error("E0001", 0), 1);
        assert_eq!(curriculum.tier(), 1);
    }

    #[test]
    fn test_curriculum_scheduler_name() {
        let curriculum = AdaptiveCurriculum::new();
        assert_eq!(curriculum.name(), "AdaptiveCurriculum");
    }

    #[test]
    fn test_adaptive_curriculum_clone() {
        let mut curriculum = AdaptiveCurriculum::new();
        curriculum.update_class("E0308", true);
        let cloned = curriculum.clone();
        assert_eq!(curriculum.class_attempts, cloned.class_attempts);
        assert_eq!(curriculum.class_accuracy, cloned.class_accuracy);
    }

    #[test]
    fn test_adaptive_curriculum_debug() {
        let curriculum = AdaptiveCurriculum::new();
        let debug = format!("{curriculum:?}");
        assert!(debug.contains("AdaptiveCurriculum"));
    }
}