use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{errors::invalid_input, tensor::Tensor, traits::Model, Result};
/// Configuration for curriculum learning: which strategy schedules the data,
/// how per-example difficulty is measured, and how the data budget grows over
/// the course of training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CurriculumConfig {
/// Strategy deciding when and how the curriculum threshold advances.
pub strategy: CurriculumStrategy,
/// How the difficulty of an individual example is estimated.
pub difficulty_measure: DifficultyMeasure,
/// Shape of the growth curve mapping training progress to data fraction.
pub pacing_function: PacingFunction,
/// Fraction of the (easiest) data used at the very start of training.
pub initial_data_percentage: f32,
/// Whether the curriculum stays active for the whole run.
pub use_throughout_training: bool,
/// Number of epochs over which the curriculum ramps up (used by the
/// default progress-based pacing in `update_curriculum_threshold`).
pub curriculum_epochs: usize,
/// Whether easy examples are shuffled within the selected subset.
pub shuffle_easy_examples: bool,
/// Whether the threshold adapts to measured performance.
pub adaptive_threshold: bool,
/// Lower clamp applied to the current threshold after every update.
pub min_difficulty_threshold: f32,
/// Upper clamp applied to the current threshold after every update.
pub max_difficulty_threshold: f32,
/// How often (in steps) performance is evaluated.
pub evaluation_frequency: usize,
}
impl Default for CurriculumConfig {
    /// Sensible defaults: self-paced strategy (lambda 0.5, gamma 1.1),
    /// loss-based difficulty, linear pacing, and an initial budget of 10%
    /// of the data, clamped to the [0.1, 0.9] threshold band.
    fn default() -> Self {
        let strategy = CurriculumStrategy::SelfPaced {
            lambda: 0.5,
            gamma: 1.1,
        };
        Self {
            strategy,
            difficulty_measure: DifficultyMeasure::LossBasedDifficulty,
            pacing_function: PacingFunction::Linear,
            initial_data_percentage: 0.1,
            curriculum_epochs: 10,
            evaluation_frequency: 1000,
            min_difficulty_threshold: 0.1,
            max_difficulty_threshold: 0.9,
            use_throughout_training: true,
            shuffle_easy_examples: true,
            adaptive_threshold: true,
        }
    }
}
/// Strategy governing how the curriculum threshold evolves during training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CurriculumStrategy {
/// Grow the data budget by `gamma` whenever recent performance is high.
/// `lambda` is carried in the config but unused by the current update rule.
SelfPaced { lambda: f32, gamma: f32 },
/// Grow the budget by `increase_rate` once measured competence exceeds
/// `competence_threshold`.
CompetenceBased {
competence_threshold: f32,
increase_rate: f32,
},
/// Fixed schedule: hold each entry of `difficulty_levels` for the matching
/// entry of `level_durations` (in steps).
Predefined {
difficulty_levels: Vec<f32>,
level_durations: Vec<usize>,
},
/// Increase the budget by `step_size` after `patience` high-accuracy steps.
BabySteps { step_size: f32, patience: usize },
/// Present hardest examples first when `reverse_pacing` is set.
AntiCurriculum { reverse_pacing: bool },
/// Threshold oscillates with period `cycle_length` (in steps).
Cyclical {
cycle_length: usize,
num_cycles: usize,
},
/// Adversarial teacher/student weighting (parameters currently unused by
/// the trainer's update rule; handled by the default pacing branch).
Minimax {
teacher_lambda: f32,
student_lambda: f32,
},
/// No ordering preference; default pacing applies.
Random,
}
/// How the difficulty of a single example is estimated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DifficultyMeasure {
/// Difficulty = forward-pass loss on the example.
LossBasedDifficulty,
/// Placeholder: currently scored as a constant 0.5 by the trainer.
GradientNormDifficulty,
/// Difficulty = 1 - max softmax probability (low confidence = hard).
ConfidenceDifficulty,
/// Difficulty scales with sequence length.
LengthDifficulty,
/// Placeholder: currently scored as a constant 0.5 by the trainer.
ComplexityDifficulty,
/// Weighted combination of several sub-measures.
MultiCriteria {
measures: Vec<DifficultyMeasure>,
weights: Vec<f32>,
},
/// Difficulty produced by a learned scorer (see `DifficultyScorer`).
LearnedDifficulty {
difficulty_network: Option<String>, },
/// Difficulty supplied by the caller on each example.
ManualDifficulty,
}
/// Curve mapping training progress in [0, 1] to the fraction of data used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PacingFunction {
/// Straight interpolation from the initial percentage to 1.0.
Linear,
/// Saturating exponential approach to 1.0 at the given `rate`.
Exponential { rate: f32 },
/// Logarithmic growth; see the review note in `apply_pacing_function`.
Logarithmic { base: f32 },
/// S-curve centered at `midpoint` with the given `steepness`.
Sigmoid { steepness: f32, midpoint: f32 },
/// Piecewise-constant schedule of `(step_threshold, value)` pairs.
StepWise { steps: Vec<(usize, f32)> },
/// Progress raised to `degree`.
Polynomial { degree: f32 },
/// Named custom function; currently falls back to linear pacing.
Custom { function_name: String },
}
/// One training example annotated with a difficulty score and a loss weight.
#[derive(Debug, Clone)]
pub struct CurriculumExample {
/// Model input tensor.
pub input: Tensor,
/// Target tensor (class index vector or one-hot/probability matrix).
pub target: Tensor,
/// Difficulty score; 0.0 is treated as "not yet estimated" by
/// `CurriculumLearningTrainer::estimate_difficulties`.
pub difficulty: f32,
/// Free-form key/value annotations (e.g. data source).
pub metadata: HashMap<String, String>,
/// Multiplier applied to this example's loss (defaults to 1.0).
pub weight: f32,
}
impl CurriculumExample {
    /// Builds an example with the given difficulty, no metadata, and the
    /// default loss weight of 1.0.
    pub fn new(input: Tensor, target: Tensor, difficulty: f32) -> Self {
        Self::with_metadata(input, target, difficulty, HashMap::new())
    }
    /// Builds an example carrying caller-supplied metadata; the loss weight
    /// defaults to 1.0.
    pub fn with_metadata(
        input: Tensor,
        target: Tensor,
        difficulty: f32,
        metadata: HashMap<String, String>,
    ) -> Self {
        Self {
            weight: 1.0,
            input,
            target,
            difficulty,
            metadata,
        }
    }
    /// Builder-style setter overriding the loss weight.
    pub fn with_weight(self, weight: f32) -> Self {
        Self { weight, ..self }
    }
}
/// Trainer that feeds a model a gradually growing, difficulty-sorted subset
/// of its example pool according to the configured curriculum strategy.
pub struct CurriculumLearningTrainer<M: Model> {
/// Model under training.
pub model: M,
/// Curriculum configuration (strategy, measure, pacing).
pub config: CurriculumConfig,
/// Full example pool, kept sorted ascending by difficulty.
pub examples: Vec<CurriculumExample>,
/// Current data-fraction threshold in [min, max] from the config.
/// NOTE(review): the `Predefined` strategy writes a *difficulty level*
/// here while `get_current_curriculum` reads it as a *data fraction* —
/// confirm the intended semantics.
pub current_threshold: f32,
/// Epochs completed via `train_epoch`.
pub current_epoch: usize,
/// Steps completed via `train_step`.
pub step_counter: usize,
/// Recent per-step accuracies (trimmed to ~500-1000 entries).
pub performance_history: Vec<f32>,
/// Present only for `DifficultyMeasure::LearnedDifficulty`.
pub difficulty_scorer: Option<DifficultyScorer>,
}
impl<M: Model<Input = Tensor, Output = Tensor>> CurriculumLearningTrainer<M> {
/// Creates a trainer with an empty example pool. A `DifficultyScorer` is
/// constructed only for the `LearnedDifficulty` measure.
pub fn new(model: M, config: CurriculumConfig) -> Result<Self> {
let difficulty_scorer = match &config.difficulty_measure {
DifficultyMeasure::LearnedDifficulty { .. } => {
Some(DifficultyScorer::new(&config.difficulty_measure)?)
},
_ => None,
};
let initial_data_percentage = config.initial_data_percentage;
Ok(Self {
model,
config,
examples: Vec::new(),
// The threshold starts at the configured initial data fraction.
current_threshold: initial_data_percentage,
current_epoch: 0,
step_counter: 0,
performance_history: Vec::new(),
difficulty_scorer,
})
}
/// Appends a batch of examples and re-sorts the pool by difficulty.
pub fn add_examples(&mut self, examples: Vec<CurriculumExample>) {
self.examples.extend(examples);
self.sort_examples_by_difficulty();
}
/// Appends one example and re-sorts the pool by difficulty.
/// NOTE(review): O(n log n) per insertion; batch via `add_examples` when possible.
pub fn add_example(&mut self, example: CurriculumExample) {
self.examples.push(example);
self.sort_examples_by_difficulty();
}
/// Scores every example whose difficulty is still the 0.0 sentinel, then
/// re-sorts. Examples legitimately scored 0.0 will be re-scored too.
pub fn estimate_difficulties(&mut self) -> Result<()> {
let mut indices_to_update = Vec::new();
for (i, example) in self.examples.iter().enumerate() {
// difficulty == 0.0 is treated as "not yet estimated".
if example.difficulty == 0.0 {
indices_to_update.push(i);
}
}
for i in indices_to_update {
// Clone input/target so `compute_difficulty(&self)` can borrow `self`.
let input = self.examples[i].input.clone();
let target = self.examples[i].target.clone();
let difficulty = self.compute_difficulty(&input, &target)?;
self.examples[i].difficulty = difficulty;
}
self.sort_examples_by_difficulty();
Ok(())
}
/// Scores one example with the configured difficulty measure.
/// Several measures are placeholders returning a constant 0.5.
fn compute_difficulty(&self, input: &Tensor, target: &Tensor) -> Result<f32> {
match &self.config.difficulty_measure {
DifficultyMeasure::LossBasedDifficulty => {
let outputs = self.model.forward(input.clone())?;
let loss = self.compute_loss(&outputs, target)?;
loss.to_scalar().map_err(|e| {
invalid_input(format!("Failed to convert loss tensor to scalar: {}", e))
})
},
DifficultyMeasure::GradientNormDifficulty => {
// Placeholder: gradient-norm scoring not implemented.
Ok(0.5) },
DifficultyMeasure::ConfidenceDifficulty => {
let outputs = self.model.forward(input.clone())?;
let probs = outputs.softmax(-1)?;
let max_prob = self.compute_max_probability(&probs)?;
// Low confidence (small max prob) means high difficulty.
Ok(1.0 - max_prob) },
DifficultyMeasure::LengthDifficulty => {
// NOTE(review): shape()[1] panics for rank-1 inputs — confirm
// callers always pass (batch, seq, ...) tensors.
let seq_len = input.shape()[1] as f32; Ok(seq_len / 1000.0) },
DifficultyMeasure::ComplexityDifficulty => {
// Placeholder: complexity scoring not implemented.
Ok(0.5) },
DifficultyMeasure::MultiCriteria { measures, weights } => {
// Weighted average of sub-measures; extra entries in the longer
// of the two vecs are silently ignored by `zip`.
let mut total_difficulty = 0.0;
let mut total_weight = 0.0;
for (measure, &weight) in measures.iter().zip(weights.iter()) {
let difficulty = self.compute_individual_difficulty(measure, input, target)?;
total_difficulty += difficulty * weight;
total_weight += weight;
}
Ok(if total_weight > 0.0 { total_difficulty / total_weight } else { 0.5 })
},
DifficultyMeasure::LearnedDifficulty { .. } => {
if let Some(scorer) = &self.difficulty_scorer {
scorer.score_difficulty(input, target)
} else {
Ok(0.5)
}
},
DifficultyMeasure::ManualDifficulty => {
// Manual difficulties come from the example itself; default 0.5.
Ok(0.5) },
}
}
/// Scores one sub-measure for `MultiCriteria`. Nested `MultiCriteria`
/// is deliberately not recursed into and returns the 0.5 default.
fn compute_individual_difficulty(
&self,
measure: &DifficultyMeasure,
input: &Tensor,
target: &Tensor,
) -> Result<f32> {
match measure {
DifficultyMeasure::LossBasedDifficulty => {
let outputs = self.model.forward(input.clone())?;
let loss = self.compute_loss(&outputs, target)?;
loss.to_scalar().map_err(|e| {
invalid_input(format!("Failed to convert loss tensor to scalar: {}", e))
})
},
DifficultyMeasure::LengthDifficulty => {
// NOTE(review): same shape()[1] panic risk as compute_difficulty.
let seq_len = input.shape()[1] as f32; Ok(seq_len / 1000.0) },
DifficultyMeasure::GradientNormDifficulty => {
Ok(0.5) },
DifficultyMeasure::ConfidenceDifficulty => {
// NOTE(review): forward pass result is discarded here, unlike the
// full implementation in compute_difficulty — confirm intended.
let _outputs = self.model.forward(input.clone())?;
Ok(0.5) },
DifficultyMeasure::ComplexityDifficulty => {
Ok(0.5) },
DifficultyMeasure::LearnedDifficulty { .. } => {
if let Some(scorer) = &self.difficulty_scorer {
scorer.score_difficulty(input, target)
} else {
Ok(0.5)
}
},
DifficultyMeasure::ManualDifficulty => {
Ok(0.5) },
DifficultyMeasure::MultiCriteria { .. } => {
Ok(0.5)
},
}
}
/// Sorts the pool ascending by difficulty; NaN difficulties compare equal
/// and keep their relative order (stable sort).
fn sort_examples_by_difficulty(&mut self) {
self.examples.sort_by(|a, b| {
a.difficulty.partial_cmp(&b.difficulty).unwrap_or(std::cmp::Ordering::Equal)
});
}
/// Returns the current training subset: the easiest `threshold` fraction
/// of the pool (or the hardest, for reversed anti-curriculum).
pub fn get_current_curriculum(&self) -> Vec<CurriculumExample> {
let num_examples = self.examples.len();
let threshold_count = (num_examples as f32 * self.current_threshold) as usize;
match &self.config.strategy {
CurriculumStrategy::AntiCurriculum { reverse_pacing } => {
if *reverse_pacing {
// Pool is sorted ascending, so rev() yields hardest-first.
self.examples.iter().rev().take(threshold_count).cloned().collect()
} else {
// NOTE(review): identical to the default branch — confirm.
self.examples.iter().take(threshold_count).cloned().collect()
}
},
_ => {
self.examples.iter().take(threshold_count).cloned().collect()
},
}
}
/// Advances `current_threshold` per the configured strategy, then clamps
/// it to the configured [min, max] band.
pub fn update_curriculum_threshold(&mut self) -> Result<()> {
match &self.config.strategy {
CurriculumStrategy::SelfPaced { lambda: _, gamma } => {
let recent_performance = self.get_recent_performance();
// Expand the budget multiplicatively once accuracy is high.
if recent_performance > 0.8 {
self.current_threshold = (self.current_threshold * gamma).min(1.0);
}
},
CurriculumStrategy::CompetenceBased {
competence_threshold,
increase_rate,
} => {
let competence = self.compute_competence()?;
if competence > *competence_threshold {
self.current_threshold = (self.current_threshold + increase_rate).min(1.0);
}
},
CurriculumStrategy::Predefined {
difficulty_levels,
level_durations,
} => {
// NOTE(review): panics (mod by zero) if level_durations is empty
// or sums to zero — confirm config validation upstream.
let total_steps: usize = level_durations.iter().sum();
let current_step = self.step_counter % total_steps;
let mut cumulative_steps = 0;
for (i, &duration) in level_durations.iter().enumerate() {
cumulative_steps += duration;
if current_step < cumulative_steps {
if i < difficulty_levels.len() {
// NOTE(review): stores a difficulty *level* into a field
// read elsewhere as a data *fraction* — confirm.
self.current_threshold = difficulty_levels[i];
}
break;
}
}
},
CurriculumStrategy::BabySteps {
step_size,
patience,
} => {
if self.performance_history.len() >= *patience {
let recent_avg =
self.performance_history.iter().rev().take(*patience).sum::<f32>()
/ *patience as f32;
if recent_avg > 0.85 {
self.current_threshold = (self.current_threshold + step_size).min(1.0);
}
}
},
CurriculumStrategy::Cyclical { cycle_length, .. } => {
// NOTE(review): panics (mod by zero) if cycle_length == 0.
let cycle_position =
(self.step_counter % cycle_length) as f32 / *cycle_length as f32;
self.current_threshold = self.apply_pacing_function(cycle_position);
},
_ => {
// Default (AntiCurriculum/Minimax/Random): epoch-based pacing.
let progress = self.current_epoch as f32 / self.config.curriculum_epochs as f32;
self.current_threshold = self.apply_pacing_function(progress);
},
}
// Final clamp keeps the threshold inside the configured band; this also
// rescues non-finite values produced by some pacing functions.
self.current_threshold = self
.current_threshold
.max(self.config.min_difficulty_threshold)
.min(self.config.max_difficulty_threshold);
Ok(())
}
/// Maps progress in [0, 1] to a data fraction via the configured curve.
fn apply_pacing_function(&self, progress: f32) -> f32 {
let clamped_progress = progress.clamp(0.0, 1.0);
match &self.config.pacing_function {
PacingFunction::Linear => {
self.config.initial_data_percentage
+ (1.0 - self.config.initial_data_percentage) * clamped_progress
},
PacingFunction::Exponential { rate } => {
self.config.initial_data_percentage
+ (1.0 - self.config.initial_data_percentage)
* (1.0 - (-rate * clamped_progress).exp())
},
PacingFunction::Logarithmic { base } => {
// NOTE(review): ln(progress * base) is -inf at progress == 0 and
// negative for small progress; only the caller's min-threshold
// clamp rescues the result — confirm this is intended.
self.config.initial_data_percentage
+ (1.0 - self.config.initial_data_percentage) * (clamped_progress * base).ln()
/ base.ln()
},
PacingFunction::Sigmoid {
steepness,
midpoint,
} => {
let sigmoid = 1.0 / (1.0 + (-steepness * (clamped_progress - midpoint)).exp());
self.config.initial_data_percentage
+ (1.0 - self.config.initial_data_percentage) * sigmoid
},
PacingFunction::StepWise { steps } => {
// Uses the absolute step counter, not `progress`, by design.
let total_steps = self.step_counter;
for &(step_threshold, threshold_value) in steps {
if total_steps <= step_threshold {
return threshold_value;
}
}
// Past the last configured step: use the full data set.
1.0 },
PacingFunction::Polynomial { degree } => {
self.config.initial_data_percentage
+ (1.0 - self.config.initial_data_percentage) * clamped_progress.powf(*degree)
},
PacingFunction::Custom { .. } => {
// Named custom functions are not wired up; fall back to linear.
self.apply_pacing_function_linear(clamped_progress)
},
}
}
/// Linear pacing fallback shared by `Custom`.
fn apply_pacing_function_linear(&self, progress: f32) -> f32 {
self.config.initial_data_percentage + (1.0 - self.config.initial_data_percentage) * progress
}
/// Competence = recent mean accuracy; 0.0 before any history exists.
fn compute_competence(&self) -> Result<f32> {
if self.performance_history.is_empty() {
return Ok(0.0);
}
let recent_performance = self.get_recent_performance();
Ok(recent_performance)
}
/// Mean of the last (up to) 10 recorded accuracies; 0.0 if none.
fn get_recent_performance(&self) -> f32 {
if self.performance_history.is_empty() {
return 0.0;
}
let window_size = 10.min(self.performance_history.len());
self.performance_history.iter().rev().take(window_size).sum::<f32>() / window_size as f32
}
/// Runs one training step: updates the threshold, picks the next example
/// from the current subset round-robin, and records accuracy.
/// NOTE(review): forward-only — no backward pass / optimizer step visible.
pub fn train_step(&mut self) -> Result<CurriculumLearningOutput> {
self.update_curriculum_threshold()?;
let curriculum_examples = self.get_current_curriculum();
if curriculum_examples.is_empty() {
return Err(invalid_input(
"No examples available for training".to_string(),
));
}
// Round-robin over the current subset.
let example = &curriculum_examples[self.step_counter % curriculum_examples.len()];
let outputs = self.model.forward(example.input.clone())?;
let loss = self.compute_loss(&outputs, &example.target)?;
let weighted_loss = loss.scalar_mul(example.weight)?;
let accuracy = self.compute_accuracy(&outputs, &example.target)?;
self.performance_history.push(accuracy);
// Keep the history bounded: drop the oldest 500 once past 1000.
if self.performance_history.len() > 1000 {
self.performance_history = self.performance_history.split_off(500);
}
self.step_counter += 1;
Ok(CurriculumLearningOutput {
loss: weighted_loss,
accuracy,
difficulty_threshold: self.current_threshold,
examples_used: curriculum_examples.len(),
current_difficulty: example.difficulty,
})
}
/// Runs one pass over the current curriculum subset and reports averages.
/// NOTE(review): average_loss/accuracy are NaN when the subset is empty
/// (num_steps == 0) — confirm callers guard against an empty curriculum.
pub fn train_epoch(&mut self) -> Result<CurriculumEpochOutput> {
let mut total_loss = 0.0;
let mut total_accuracy = 0.0;
let mut num_steps = 0;
let curriculum_examples = self.get_current_curriculum();
for example in &curriculum_examples {
let outputs = self.model.forward(example.input.clone())?;
let loss = self.compute_loss(&outputs, &example.target)?;
let accuracy = self.compute_accuracy(&outputs, &example.target)?;
let loss_scalar = loss.to_scalar().map_err(|e| {
invalid_input(format!("Failed to convert loss tensor to scalar: {}", e))
})?;
total_loss += loss_scalar * example.weight;
total_accuracy += accuracy;
num_steps += 1;
}
self.current_epoch += 1;
Ok(CurriculumEpochOutput {
epoch: self.current_epoch,
average_loss: total_loss / num_steps as f32,
average_accuracy: total_accuracy / num_steps as f32,
difficulty_threshold: self.current_threshold,
examples_used: curriculum_examples.len(),
total_examples: self.examples.len(),
})
}
/// Loss used throughout the trainer (cross-entropy).
fn compute_loss(&self, outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
self.compute_cross_entropy_loss(outputs, targets)
}
/// Accuracy = fraction of samples whose argmax matches the target argmax.
fn compute_accuracy(&self, outputs: &Tensor, targets: &Tensor) -> Result<f32> {
let predicted = self.compute_argmax(outputs)?;
let target_indices = self.compute_argmax(targets)?;
let total_samples = predicted.len() as f32;
if total_samples == 0.0 {
return Ok(0.0);
}
let mut correct = 0.0;
for (pred, target) in predicted.iter().zip(target_indices.iter()) {
// Indices are stored as f32; epsilon compare is an exact match here.
if (pred - target).abs() < f32::EPSILON {
correct += 1.0;
}
}
Ok(correct / total_samples)
}
/// Snapshot of the current curriculum state for logging/monitoring.
pub fn get_curriculum_stats(&self) -> CurriculumStats {
let curriculum_examples = self.get_current_curriculum();
let difficulties: Vec<f32> = curriculum_examples.iter().map(|e| e.difficulty).collect();
// Empty subset yields +inf/-inf min/max sentinels.
let min_difficulty = difficulties.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_difficulty = difficulties.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let avg_difficulty = if !difficulties.is_empty() {
difficulties.iter().sum::<f32>() / difficulties.len() as f32
} else {
0.0
};
CurriculumStats {
current_threshold: self.current_threshold,
examples_in_curriculum: curriculum_examples.len(),
total_examples: self.examples.len(),
min_difficulty,
max_difficulty,
avg_difficulty,
epoch: self.current_epoch,
step: self.step_counter,
}
}
/// Maximum element over the whole probability tensor (all batches pooled).
fn compute_max_probability(&self, probs: &Tensor) -> Result<f32> {
match probs {
Tensor::F32(arr) => {
// Softmax outputs are >= 0, so 0.0 is a safe fold seed.
let max_val = arr.iter().fold(0.0f32, |acc, &x| acc.max(x));
Ok(max_val)
},
_ => {
// Non-f32 tensors: neutral default.
Ok(0.5) },
}
}
/// Cross-entropy over a batch. Supports rank-1 class-index targets and
/// rank-2 one-hot/probability targets; other layouts fall back to a
/// constant loss of 1.0.
fn compute_cross_entropy_loss(&self, outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
let probs = outputs.softmax(-1)?;
let log_probs = probs.log()?;
match (log_probs, targets) {
(Tensor::F32(log_prob_arr), Tensor::F32(target_arr)) => {
let batch_size = log_prob_arr.shape()[0];
let num_classes = log_prob_arr.shape().get(1).copied().ok_or_else(|| {
invalid_input(format!(
"Invalid tensor shape: expected at least 2 dimensions, got {}",
log_prob_arr.shape().len()
))
})?;
let mut total_loss = 0.0f32;
for batch_idx in 0..batch_size {
if target_arr.shape().len() == 1 {
// Targets are class indices; out-of-range classes are
// silently skipped.
let target_class = target_arr[[batch_idx]] as usize;
if target_class < num_classes {
total_loss -= log_prob_arr[[batch_idx, target_class]];
}
} else if target_arr.shape().len() >= 2 && target_arr.shape()[1] == num_classes
{
// Targets are per-class probabilities (soft labels).
for class_idx in 0..num_classes {
let target_prob = target_arr[[batch_idx, class_idx]];
if target_prob > 0.0 {
total_loss -= target_prob * log_prob_arr[[batch_idx, class_idx]];
}
}
}
}
let mean_loss = total_loss / batch_size as f32;
Ok(Tensor::scalar(mean_loss)?)
},
_ => {
// Unsupported dtypes: constant fallback loss.
Ok(Tensor::scalar(1.0f32)?)
},
}
}
/// Per-row argmax indices (returned as f32). Rank 1 yields one index,
/// rank 2 yields one per batch row, higher ranks flatten to one index.
fn compute_argmax(&self, tensor: &Tensor) -> Result<Vec<f32>> {
match tensor {
Tensor::F32(arr) => {
let mut argmax_values = Vec::new();
if arr.ndim() == 1 {
let mut max_idx = 0;
let mut max_val = arr[0];
for (idx, &val) in arr.iter().enumerate() {
if val > max_val {
max_val = val;
max_idx = idx;
}
}
argmax_values.push(max_idx as f32);
} else if arr.ndim() == 2 {
let batch_size = arr.shape()[0];
let num_classes = arr.shape()[1];
for batch_idx in 0..batch_size {
let mut max_idx = 0;
let mut max_val = arr[[batch_idx, 0]];
for class_idx in 1..num_classes {
let val = arr[[batch_idx, class_idx]];
if val > max_val {
max_val = val;
max_idx = class_idx;
}
}
argmax_values.push(max_idx as f32);
}
} else {
// Higher-rank tensors: argmax over the flattened buffer.
let mut max_idx = 0;
let mut max_val = arr.iter().next().copied().ok_or_else(|| {
invalid_input("Cannot compute argmax on empty tensor".to_string())
})?;
for (idx, &val) in arr.iter().enumerate() {
if val > max_val {
max_val = val;
max_idx = idx;
}
}
argmax_values.push(max_idx as f32);
}
Ok(argmax_values)
},
_ => {
// Non-f32 tensors: neutral single index.
Ok(vec![0.0])
},
}
}
}
/// Learned difficulty scorer stub for `DifficultyMeasure::LearnedDifficulty`.
/// Currently stores the measure but performs no learned scoring.
pub struct DifficultyScorer {
#[allow(dead_code)]
method: DifficultyMeasure,
}
impl DifficultyScorer {
    /// Wraps the chosen difficulty measure; no network is loaded yet.
    pub fn new(method: &DifficultyMeasure) -> Result<Self> {
        let method = method.clone();
        Ok(Self { method })
    }
    /// Placeholder scorer: always reports a mid-range difficulty of 0.5.
    pub fn score_difficulty(&self, _input: &Tensor, _target: &Tensor) -> Result<f32> {
        Ok(0.5)
    }
}
/// Result of a single `train_step`.
#[derive(Debug, Clone)]
pub struct CurriculumLearningOutput {
/// Weight-scaled loss tensor for the trained example.
pub loss: Tensor,
/// Accuracy on the trained example.
pub accuracy: f32,
/// Threshold in effect when the step ran.
pub difficulty_threshold: f32,
/// Size of the curriculum subset at this step.
pub examples_used: usize,
/// Difficulty score of the example trained on.
pub current_difficulty: f32,
}
/// Aggregated result of one `train_epoch` pass.
#[derive(Debug, Clone)]
pub struct CurriculumEpochOutput {
/// 1-based epoch number after this pass.
pub epoch: usize,
/// Mean (weight-scaled) loss over the subset.
pub average_loss: f32,
/// Mean accuracy over the subset.
pub average_accuracy: f32,
/// Threshold in effect during the epoch.
pub difficulty_threshold: f32,
/// Size of the curriculum subset used.
pub examples_used: usize,
/// Total pool size.
pub total_examples: usize,
}
/// Monitoring snapshot produced by `get_curriculum_stats`.
#[derive(Debug, Clone)]
pub struct CurriculumStats {
/// Current data-fraction threshold.
pub current_threshold: f32,
/// Number of examples currently in the curriculum subset.
pub examples_in_curriculum: usize,
/// Total pool size.
pub total_examples: usize,
/// Minimum difficulty in the subset (+inf when the subset is empty).
pub min_difficulty: f32,
/// Maximum difficulty in the subset (-inf when the subset is empty).
pub max_difficulty: f32,
/// Mean difficulty in the subset (0.0 when empty).
pub avg_difficulty: f32,
/// Epochs completed so far.
pub epoch: usize,
/// Steps completed so far.
pub step: usize,
}
/// Convenience constructors for common curriculum configurations plus helpers
/// for building scored example sets and evaluating curriculum effectiveness.
pub mod utils {
use super::*;
/// Self-paced config with the given lambda/gamma; other fields default.
pub fn self_paced_config(lambda: f32, gamma: f32) -> CurriculumConfig {
CurriculumConfig {
strategy: CurriculumStrategy::SelfPaced { lambda, gamma },
..Default::default()
}
}
/// Competence-based config; threshold grows by `increase_rate` once
/// competence exceeds `threshold`.
pub fn competence_based_config(threshold: f32, increase_rate: f32) -> CurriculumConfig {
CurriculumConfig {
strategy: CurriculumStrategy::CompetenceBased {
competence_threshold: threshold,
increase_rate,
},
..Default::default()
}
}
/// Baby-steps config with linear pacing.
pub fn baby_steps_config(step_size: f32, patience: usize) -> CurriculumConfig {
CurriculumConfig {
strategy: CurriculumStrategy::BabySteps {
step_size,
patience,
},
pacing_function: PacingFunction::Linear,
..Default::default()
}
}
/// Predefined schedule config; levels and durations are matched by index.
pub fn predefined_config(
difficulty_levels: Vec<f32>,
level_durations: Vec<usize>,
) -> CurriculumConfig {
CurriculumConfig {
strategy: CurriculumStrategy::Predefined {
difficulty_levels,
level_durations,
},
..Default::default()
}
}
/// Anti-curriculum config (hardest examples first).
pub fn anti_curriculum_config() -> CurriculumConfig {
CurriculumConfig {
strategy: CurriculumStrategy::AntiCurriculum {
reverse_pacing: true,
},
..Default::default()
}
}
/// Cyclical config oscillating with the given period.
pub fn cyclical_config(cycle_length: usize, num_cycles: usize) -> CurriculumConfig {
CurriculumConfig {
strategy: CurriculumStrategy::Cyclical {
cycle_length,
num_cycles,
},
..Default::default()
}
}
/// Pairs inputs with targets, scoring difficulty by sequence length
/// (length/512, capped at 1.0).
/// NOTE(review): shape()[1] panics for rank-1 inputs — confirm callers
/// always pass (batch, seq, ...) tensors.
pub fn create_length_based_examples(
inputs: Vec<Tensor>,
targets: Vec<Tensor>,
) -> Vec<CurriculumExample> {
inputs
.into_iter()
.zip(targets)
.map(|(input, target)| {
let length = input.shape()[1] as f32; let difficulty = (length / 512.0).min(1.0); CurriculumExample::new(input, target, difficulty)
})
.collect()
}
/// Pairs inputs with targets, scoring difficulty by the model's current
/// cross-entropy loss on each example.
pub fn create_loss_based_examples<M: Model<Input = Tensor, Output = Tensor>>(
model: &M,
inputs: Vec<Tensor>,
targets: Vec<Tensor>,
) -> Result<Vec<CurriculumExample>> {
let mut examples = Vec::new();
for (input, target) in inputs.into_iter().zip(targets) {
let outputs = model.forward(input.clone())?;
let loss = simple_cross_entropy_loss(&outputs, &target)?;
let difficulty = loss.to_scalar().map_err(|e| {
invalid_input(format!(
"Failed to convert loss tensor to scalar for difficulty estimation: {}",
e
))
})?;
examples.push(CurriculumExample::new(input, target, difficulty));
}
Ok(examples)
}
/// Lightweight cross-entropy used only for difficulty estimation; falls
/// back to a constant 1.0 loss when tensor data is unavailable.
/// NOTE(review): `prob_data[target_idx]` indexes the flattened probability
/// buffer without a per-batch offset, so this is only correct for
/// batch_size == 1 — confirm intended usage.
fn simple_cross_entropy_loss(outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
let probs = outputs.softmax(-1)?;
match targets.data() {
Ok(target_data) => {
if let Ok(prob_data) = probs.data() {
let batch_size = targets.shape()[0];
let mut total_loss = 0.0f32;
for i in 0..batch_size {
let target_idx = target_data[i] as usize;
if target_idx < prob_data.len() {
// Clamp avoids ln(0) = -inf for zero probabilities.
let prob = prob_data[target_idx].max(1e-8); total_loss += -prob.ln();
}
}
let mean_loss = total_loss / batch_size as f32;
Ok(Tensor::scalar(mean_loss)?)
} else {
Ok(Tensor::scalar(1.0f32)?)
}
},
Err(_) => Ok(Tensor::scalar(1.0f32)?),
}
}
/// Pairs inputs, targets, and caller-supplied difficulties; errors when
/// the three vectors differ in length.
pub fn create_manual_examples(
inputs: Vec<Tensor>,
targets: Vec<Tensor>,
difficulties: Vec<f32>,
) -> Result<Vec<CurriculumExample>> {
if inputs.len() != targets.len() || inputs.len() != difficulties.len() {
return Err(invalid_input("Mismatched array lengths".to_string()));
}
Ok(inputs
.into_iter()
.zip(targets)
.zip(difficulties)
.map(|((input, target), difficulty)| CurriculumExample::new(input, target, difficulty))
.collect())
}
/// Compares baseline vs curriculum accuracy curves: final-accuracy delta
/// plus a speedup proxy (ratio of summed accuracies, i.e. discrete AUC).
pub fn analyze_curriculum_effectiveness(
baseline_accuracies: &[f32],
curriculum_accuracies: &[f32],
) -> CurriculumAnalysis {
let baseline_final = baseline_accuracies.last().copied().unwrap_or_else(|| {
eprintln!("Warning: Empty baseline accuracies array, using 0.0");
0.0
});
let curriculum_final = curriculum_accuracies.last().copied().unwrap_or_else(|| {
eprintln!("Warning: Empty curriculum accuracies array, using 0.0");
0.0
});
let improvement = curriculum_final - baseline_final;
let baseline_auc = baseline_accuracies.iter().sum::<f32>();
let curriculum_auc = curriculum_accuracies.iter().sum::<f32>();
// max(1e-8) guards against division by zero on an empty baseline.
let convergence_speedup = curriculum_auc / baseline_auc.max(1e-8);
CurriculumAnalysis {
final_accuracy_improvement: improvement,
convergence_speedup,
baseline_final_accuracy: baseline_final,
curriculum_final_accuracy: curriculum_final,
}
}
}
/// Comparison between a baseline run and a curriculum run,
/// produced by `utils::analyze_curriculum_effectiveness`.
#[derive(Debug, Clone)]
pub struct CurriculumAnalysis {
/// Final-epoch accuracy delta (curriculum - baseline).
pub final_accuracy_improvement: f32,
/// Ratio of summed accuracy curves (discrete AUC); > 1.0 favors curriculum.
pub convergence_speedup: f32,
/// Last accuracy of the baseline curve (0.0 when empty).
pub baseline_final_accuracy: f32,
/// Last accuracy of the curriculum curve (0.0 when empty).
pub curriculum_final_accuracy: f32,
}
#[cfg(test)]
mod tests {
use super::*;
// Default config exposes the documented defaults, including the
// self-paced strategy parameters.
#[test]
fn test_curriculum_config_default() {
let config = CurriculumConfig::default();
assert_eq!(config.initial_data_percentage, 0.1);
assert!(config.use_throughout_training);
assert!(config.shuffle_easy_examples);
if let CurriculumStrategy::SelfPaced { lambda, gamma } = config.strategy {
assert_eq!(lambda, 0.5);
assert_eq!(gamma, 1.1);
} else {
panic!("Expected SelfPaced strategy");
}
}
// `new` sets weight 1.0 and empty metadata.
#[test]
fn test_curriculum_example() {
let input = Tensor::zeros(&[1, 10]).expect("operation failed");
let target = Tensor::zeros(&[1]).expect("operation failed");
let example = CurriculumExample::new(input, target, 0.5);
assert_eq!(example.difficulty, 0.5);
assert_eq!(example.weight, 1.0);
assert!(example.metadata.is_empty());
}
// `with_metadata` preserves caller-supplied annotations.
#[test]
fn test_curriculum_example_with_metadata() {
let input = Tensor::zeros(&[1, 10]).expect("operation failed");
let target = Tensor::zeros(&[1]).expect("operation failed");
let mut metadata = HashMap::new();
metadata.insert("source".to_string(), "test".to_string());
let example = CurriculumExample::with_metadata(input, target, 0.7, metadata);
assert_eq!(example.difficulty, 0.7);
assert_eq!(
example.metadata.get("source").expect("operation failed"),
"test"
);
}
// `with_weight` overrides only the loss weight.
#[test]
fn test_curriculum_example_with_weight() {
let input = Tensor::zeros(&[1, 10]).expect("operation failed");
let target = Tensor::zeros(&[1]).expect("operation failed");
let example = CurriculumExample::new(input, target, 0.3).with_weight(2.0);
assert_eq!(example.difficulty, 0.3);
assert_eq!(example.weight, 2.0);
}
// utils constructor round-trips the self-paced parameters.
#[test]
fn test_self_paced_config() {
let config = utils::self_paced_config(0.8, 1.2);
if let CurriculumStrategy::SelfPaced { lambda, gamma } = config.strategy {
assert_eq!(lambda, 0.8);
assert_eq!(gamma, 1.2);
} else {
panic!("Expected SelfPaced strategy");
}
}
// utils constructor round-trips the competence-based parameters.
#[test]
fn test_competence_based_config() {
let config = utils::competence_based_config(0.85, 0.1);
if let CurriculumStrategy::CompetenceBased {
competence_threshold,
increase_rate,
} = config.strategy
{
assert_eq!(competence_threshold, 0.85);
assert_eq!(increase_rate, 0.1);
} else {
panic!("Expected CompetenceBased strategy");
}
}
// utils constructor round-trips the baby-steps parameters.
#[test]
fn test_baby_steps_config() {
let config = utils::baby_steps_config(0.05, 5);
if let CurriculumStrategy::BabySteps {
step_size,
patience,
} = config.strategy
{
assert_eq!(step_size, 0.05);
assert_eq!(patience, 5);
} else {
panic!("Expected BabySteps strategy");
}
}
// utils constructor round-trips the predefined schedule.
#[test]
fn test_predefined_config() {
let levels = vec![0.2, 0.5, 0.8, 1.0];
let durations = vec![1000, 1500, 2000, 2500];
let config = utils::predefined_config(levels.clone(), durations.clone());
if let CurriculumStrategy::Predefined {
difficulty_levels,
level_durations,
} = config.strategy
{
assert_eq!(difficulty_levels, levels);
assert_eq!(level_durations, durations);
} else {
panic!("Expected Predefined strategy");
}
}
// Anti-curriculum constructor enables reverse pacing.
#[test]
fn test_anti_curriculum_config() {
let config = utils::anti_curriculum_config();
if let CurriculumStrategy::AntiCurriculum { reverse_pacing } = config.strategy {
assert!(reverse_pacing);
} else {
panic!("Expected AntiCurriculum strategy");
}
}
// utils constructor round-trips the cyclical parameters.
#[test]
fn test_cyclical_config() {
let config = utils::cyclical_config(1000, 3);
if let CurriculumStrategy::Cyclical {
cycle_length,
num_cycles,
} = config.strategy
{
assert_eq!(cycle_length, 1000);
assert_eq!(num_cycles, 3);
} else {
panic!("Expected Cyclical strategy");
}
}
// Manual examples preserve caller-supplied difficulties, in order.
#[test]
fn test_create_manual_examples() {
let inputs = vec![
Tensor::zeros(&[1, 10]).expect("operation failed"),
Tensor::ones(&[1, 10]).expect("operation failed"),
];
let targets = vec![
Tensor::zeros(&[1]).expect("operation failed"),
Tensor::ones(&[1]).expect("operation failed"),
];
let difficulties = vec![0.2, 0.8];
let examples =
utils::create_manual_examples(inputs, targets, difficulties).expect("operation failed");
assert_eq!(examples.len(), 2);
assert_eq!(examples[0].difficulty, 0.2);
assert_eq!(examples[1].difficulty, 0.8);
}
// Length mismatch across the three vectors is rejected.
#[test]
fn test_create_manual_examples_mismatched_lengths() {
let inputs = vec![Tensor::zeros(&[1, 10]).expect("operation failed")];
let targets = vec![Tensor::zeros(&[1]).expect("operation failed")];
let difficulties = vec![0.2, 0.8];
let result = utils::create_manual_examples(inputs, targets, difficulties);
assert!(result.is_err());
}
// Effectiveness analysis: final-accuracy delta and AUC-ratio speedup.
#[test]
fn test_curriculum_analysis() {
let baseline = vec![0.6, 0.7, 0.75, 0.8];
let curriculum = vec![0.7, 0.8, 0.85, 0.9];
let analysis = utils::analyze_curriculum_effectiveness(&baseline, &curriculum);
assert!((analysis.final_accuracy_improvement - 0.1).abs() < 1e-6); assert!((analysis.baseline_final_accuracy - 0.8).abs() < 1e-6);
assert!((analysis.curriculum_final_accuracy - 0.9).abs() < 1e-6);
assert!(analysis.convergence_speedup > 1.0);
}
}