use super::*;
use crate::training::checkpoint::CheckpointMetadata;
use crate::training::metrics::ConfusionMatrix;
use crate::training::trainer::CallbackSignal;
#[cfg(test)]
mod basic_tests {
use super::*;
#[test]
fn test_trainer_config_creation() {
let config = TrainerConfig::default();
assert_eq!(config.epochs, 10);
assert_eq!(config.log_frequency, 100);
assert_eq!(config.validation_frequency, 1);
assert_eq!(config.device, "cpu");
assert!(!config.use_mixed_precision);
assert_eq!(config.accumulation_steps, 1);
}
#[test]
fn test_training_state() {
let mut state = TrainingState::<f32>::new(5);
assert_eq!(state.total_epochs, 5);
assert_eq!(state.current_epoch, 0);
assert!(!state.is_completed());
assert_eq!(state.progress(), 0.0);
let epoch_state = EpochState::new(0);
state.add_epoch(epoch_state);
assert_eq!(state.current_epoch, 1);
assert_eq!(state.progress(), 20.0);
}
#[test]
fn test_epoch_state() {
let mut epoch_state: EpochState<f32> = EpochState::new(5);
assert_eq!(epoch_state.epoch, 5);
assert!(epoch_state.train_metrics.is_none());
assert!(epoch_state.val_metrics.is_none());
assert!(epoch_state.batch_history.is_empty());
epoch_state.set_metadata("learning_rate".to_string(), "0.001".to_string());
assert_eq!(
epoch_state.metadata.get("learning_rate"),
Some(&"0.001".to_string())
);
}
#[test]
fn test_batch_state() {
let mut batch_state = BatchState::new(42);
assert_eq!(batch_state.batch, 42);
assert!(batch_state.loss.is_none());
assert!(batch_state.metrics.is_empty());
assert!(batch_state.batch_size.is_none());
batch_state.set_metric("accuracy".to_string(), 0.85);
batch_state.set_metadata("device".to_string(), "cuda".to_string());
assert_eq!(batch_state.get_metric("accuracy"), Some(0.85));
assert_eq!(
batch_state.metadata.get("device"),
Some(&"cuda".to_string())
);
}
#[test]
fn test_metrics_collector() {
let _collector = MetricsCollector::<f32>::new();
}
#[test]
fn test_early_stopping_callback() {
let _early_stopping: EarlyStopping<f32> = EarlyStopping::monitor_val_loss(3, 0.01);
}
#[test]
fn test_lr_scheduler() {
let _scheduler = LearningRateScheduler::exponential_decay(0.001, 0.95);
}
#[test]
fn test_progress_bar() {
let _progress = ProgressBar::simple();
}
#[test]
fn test_confusion_matrix() {
let mut matrix = ConfusionMatrix::new();
matrix.true_positives = 80;
matrix.false_positives = 10;
matrix.true_negatives = 90;
matrix.false_negatives = 20;
assert_eq!(matrix.total(), 200);
assert_eq!(matrix.accuracy(), 0.85);
let precision = matrix.precision();
let recall = matrix.recall();
assert!(precision > 0.0 && precision <= 1.0);
assert!(recall > 0.0 && recall <= 1.0);
let f1 = matrix.f1_score();
assert!(f1 > 0.0 && f1 <= 1.0);
}
#[test]
fn test_checkpoint_config() {
let config = SaveConfig::default();
assert_eq!(config.prefix, "model");
assert_eq!(config.max_checkpoints, 5);
assert!(!config.save_best_only);
assert_eq!(config.monitor, "val_loss");
assert!(!config.mode_max);
}
#[test]
fn test_checkpoint_metadata() {
let mut metadata = CheckpointMetadata::new(5, 0.25);
assert_eq!(metadata.epoch, 5);
assert_eq!(metadata.train_loss, 0.25);
assert!(metadata.val_loss.is_none());
assert!(metadata.timestamp > 0);
metadata.set_metric("accuracy".to_string(), 0.85);
metadata.set_extra("model_type".to_string(), "transformer".to_string());
assert_eq!(metadata.metrics.get("accuracy"), Some(&0.85));
assert_eq!(
metadata.extra.get("model_type"),
Some(&"transformer".to_string())
);
}
}
#[cfg(test)]
mod performance_tests {
use super::*;
use std::time::{Duration, Instant};
#[test]
fn test_state_creation_performance() {
let start_time = Instant::now();
for _ in 0..1000 {
let _state = TrainingState::<f32>::new(100);
}
let duration = start_time.elapsed();
assert!(duration < Duration::from_millis(100));
println!("State creation performance: {:?}", duration);
}
#[test]
fn test_metrics_collector_performance() {
let start_time = Instant::now();
for _ in 0..1000 {
let _collector = MetricsCollector::<f32>::new();
}
let duration = start_time.elapsed();
assert!(duration < Duration::from_millis(50));
println!("MetricsCollector creation performance: {:?}", duration);
}
#[test]
fn test_callback_creation_performance() {
let start_time = Instant::now();
for _ in 0..1000 {
let _early_stopping: EarlyStopping<f32> = EarlyStopping::monitor_val_loss(5, 0.01);
let _scheduler = LearningRateScheduler::exponential_decay(0.001, 0.95);
let _progress = ProgressBar::simple();
}
let duration = start_time.elapsed();
assert!(duration < Duration::from_millis(100));
println!("Callback creation performance: {:?}", duration);
}
}
#[cfg(test)]
mod integration_tests {
use super::*;
#[test]
fn test_training_state_progression() {
let mut state = TrainingState::<f32>::new(3);
for epoch in 0..3 {
let mut epoch_state = EpochState::new(epoch);
epoch_state.learning_rate = Some(0.001 * 0.9_f64.powi(epoch as i32));
let mut batch_state = BatchState::new(0);
batch_state.loss = Some(0.5 - epoch as f64 * 0.1);
batch_state.set_metric("accuracy".to_string(), 0.5 + epoch as f64 * 0.15);
epoch_state.add_batch(batch_state);
state.add_epoch(epoch_state);
}
assert!(state.is_completed());
assert_eq!(state.progress(), 100.0);
assert_eq!(state.epoch_history.len(), 3);
for (i, epoch) in state.epoch_history.iter().enumerate() {
let expected_lr = 0.001 * 0.9_f64.powi(i as i32);
assert!((epoch.learning_rate.unwrap() - expected_lr).abs() < 1e-6);
}
}
#[test]
fn test_early_stopping_logic() {
let mut early_stopping: EarlyStopping<f32> = EarlyStopping::monitor_val_loss(3, 0.01);
let mut state = TrainingState::<f32>::new(10);
early_stopping.on_train_begin(&mut state).unwrap();
let mut epoch_state = EpochState::new(0);
let val_metrics = trainer::EpochMetrics {
total_loss: 0.5,
avg_loss: 0.5,
batch_count: 10,
};
epoch_state.val_metrics = Some(val_metrics);
let result = early_stopping
.on_epoch_end(&mut state, &epoch_state)
.unwrap();
assert!(matches!(result, Some(CallbackSignal::Continue)));
let mut epoch_state = EpochState::new(1);
let val_metrics = trainer::EpochMetrics {
total_loss: 0.3,
avg_loss: 0.3,
batch_count: 10,
};
epoch_state.val_metrics = Some(val_metrics);
let result = early_stopping
.on_epoch_end(&mut state, &epoch_state)
.unwrap();
assert!(matches!(result, Some(CallbackSignal::Continue)));
}
#[test]
fn test_learning_rate_scheduler_logic() {
let mut scheduler = LearningRateScheduler::exponential_decay(0.1, 0.9);
let mut state = TrainingState::<f32>::new(5);
scheduler.on_train_begin(&mut state).unwrap();
for epoch in 0..5 {
let mut epoch_state = EpochState::new(epoch);
scheduler
.on_epoch_begin(&mut state, &mut epoch_state)
.unwrap();
assert!(epoch_state.learning_rate.is_some());
}
}
#[test]
fn test_checkpoint_metadata_progression() {
let mut metadata = CheckpointMetadata::new(0, 1.0);
metadata.val_loss = Some(0.8);
metadata.learning_rate = Some(0.001);
metadata.set_metric("accuracy".to_string(), 0.75);
metadata.set_metric("precision".to_string(), 0.78);
metadata.set_metric("recall".to_string(), 0.72);
metadata.set_extra("optimizer".to_string(), "adam".to_string());
metadata.set_extra("batch_size".to_string(), "32".to_string());
assert_eq!(metadata.epoch, 0);
assert_eq!(metadata.train_loss, 1.0);
assert_eq!(metadata.val_loss, Some(0.8));
assert_eq!(metadata.learning_rate, Some(0.001));
assert_eq!(metadata.metrics.len(), 3);
assert_eq!(metadata.extra.len(), 2);
assert_eq!(metadata.metrics.get("accuracy"), Some(&0.75));
assert_eq!(metadata.extra.get("optimizer"), Some(&"adam".to_string()));
}
}
pub mod simple_benchmarks {
use super::*;
use std::time::{Duration, Instant};
pub fn benchmark_object_creation(iterations: usize) -> Duration {
let start_time = Instant::now();
for _ in 0..iterations {
let _state = TrainingState::<f32>::new(100);
let _collector = MetricsCollector::<f32>::new();
let _early_stopping: EarlyStopping<f32> = EarlyStopping::monitor_val_loss(5, 0.01);
}
start_time.elapsed()
}
pub fn benchmark_state_management(num_epochs: usize) -> Duration {
let start_time = Instant::now();
let mut state = TrainingState::<f32>::new(num_epochs);
for epoch in 0..num_epochs {
let mut epoch_state = EpochState::new(epoch);
for batch in 0..10 {
let mut batch_state = BatchState::new(batch);
batch_state.loss = Some(0.5 - batch as f64 * 0.01);
batch_state.set_metric("accuracy".to_string(), 0.8 + batch as f64 * 0.01);
epoch_state.add_batch(batch_state);
}
state.add_epoch(epoch_state);
}
start_time.elapsed()
}
}
#[cfg(test)]
mod benchmark_tests {
use super::simple_benchmarks::*;
use std::time::Duration;
#[test]
fn test_object_creation_benchmark() {
let duration = benchmark_object_creation(100);
println!("Object creation benchmark (100 iterations): {:?}", duration);
assert!(duration < Duration::from_millis(100));
}
#[test]
fn test_state_management_benchmark() {
let duration = benchmark_state_management(10);
println!("State management benchmark (10 epochs): {:?}", duration);
assert!(duration < Duration::from_millis(50));
}
}