#[cfg(test)]
mod tests {
use super::super::OnlineLinearRegression;
use super::*;
#[test]
fn test_orchestrator_creation() {
let model = OnlineLinearRegression::new(3);
let orchestrator = RetrainOrchestrator::new(model, 3);
assert_eq!(orchestrator.buffer_size(), 0);
assert_eq!(orchestrator.stats().samples_observed, 0);
}
#[test]
fn test_orchestrator_observe_stable() {
let model = OnlineLinearRegression::new(2);
let mut orchestrator = RetrainOrchestrator::new(model, 2);
for _ in 0..50 {
let result = orchestrator.observe(&[1.0, 2.0], &[5.0], &[5.0]).unwrap();
assert!(result == ObserveResult::Stable || result == ObserveResult::Skipped);
}
assert_eq!(orchestrator.stats().retrain_count, 0);
}
#[test]
fn test_orchestrator_observe_drift() {
let config = RetrainConfig {
min_samples: 10,
incremental_updates: false,
curriculum_learning: false,
retrain_epochs: 1,
..Default::default()
};
let model = OnlineLinearRegression::new(2);
let detector = ADWIN::with_delta(0.1); let mut orchestrator = RetrainOrchestrator::with_config(model, detector, 2, config);
for i in 0..50 {
orchestrator
.observe(&[i as f64, 0.0], &[0.0], &[0.1]) .unwrap();
}
for i in 0..150 {
orchestrator
.observe(&[(50 + i) as f64, 1.0], &[1.0], &[0.0]) .unwrap();
}
let stats = orchestrator.stats();
assert!(
stats.samples_observed >= 100,
"Should have observed at least 100 samples, got {}",
stats.samples_observed
);
assert!(
stats.drift_status == DriftStatus::Warning
|| stats.drift_status == DriftStatus::Drift
|| stats.retrain_count > 0
|| stats.buffer_size > 0,
"Expected some activity, got stats={:?}",
stats
);
}
#[test]
fn test_orchestrator_force_retrain() {
let config = RetrainConfig {
min_samples: 5,
curriculum_learning: false,
retrain_epochs: 1,
..Default::default()
};
let model = OnlineLinearRegression::new(2);
let detector = ADWIN::new();
let mut orchestrator = RetrainOrchestrator::with_config(model, detector, 2, config);
for i in 0..10 {
orchestrator
.observe(&[i as f64, (i * 2) as f64], &[(i * 3) as f64], &[0.0])
.unwrap();
}
assert!(orchestrator.buffer_size() > 0);
orchestrator.force_retrain().unwrap();
assert_eq!(orchestrator.stats().retrain_count, 1);
}
#[test]
fn test_orchestrator_should_retrain() {
let config = RetrainConfig {
min_samples: 5,
..Default::default()
};
let model = OnlineLinearRegression::new(2);
let detector = ADWIN::with_delta(1.0); let orchestrator = RetrainOrchestrator::with_config(model, detector, 2, config);
assert!(!orchestrator.should_retrain());
}
#[test]
fn test_orchestrator_builder() {
let model = OnlineLinearRegression::new(3);
let orchestrator = OrchestratorBuilder::new(model, 3)
.min_samples(50)
.max_buffer_size(1000)
.incremental_updates(false)
.curriculum_learning(true)
.curriculum_stages(3)
.learning_rate(0.05)
.retrain_epochs(5)
.adwin_delta(0.01)
.build();
assert_eq!(orchestrator.config().min_samples, 50);
assert_eq!(orchestrator.config().max_buffer_size, 1000);
assert!(!orchestrator.config().incremental_updates);
assert!(orchestrator.config().curriculum_learning);
}
#[test]
fn test_observe_result_equality() {
assert_eq!(ObserveResult::Stable, ObserveResult::Stable);
assert_ne!(ObserveResult::Stable, ObserveResult::Warning);
assert_ne!(ObserveResult::Warning, ObserveResult::Retrained);
}
#[test]
fn test_retrain_config_default() {
let config = RetrainConfig::default();
assert_eq!(config.min_samples, 100);
assert_eq!(config.max_buffer_size, 10_000);
assert!(config.incremental_updates);
assert!(config.curriculum_learning);
}
#[test]
fn test_orchestrator_stats_default() {
let stats = OrchestratorStats::default();
assert_eq!(stats.samples_observed, 0);
assert_eq!(stats.retrain_count, 0);
assert_eq!(stats.buffer_size, 0);
}
#[test]
fn test_compute_error_regression() {
let model = OnlineLinearRegression::new(2);
let orchestrator = RetrainOrchestrator::new(model, 2);
assert!(!orchestrator.compute_error(&[10.0], &[10.5]));
assert!(orchestrator.compute_error(&[10.0], &[15.0]));
}
#[test]
fn test_compute_error_classification() {
let model = OnlineLinearRegression::new(2);
let orchestrator = RetrainOrchestrator::new(model, 2);
assert!(!orchestrator.compute_error(&[0.0], &[0.3]));
assert!(orchestrator.compute_error(&[0.0], &[0.8]));
}
#[test]
fn test_compute_error_multiclass() {
let model = OnlineLinearRegression::new(2);
let orchestrator = RetrainOrchestrator::new(model, 2);
assert!(!orchestrator.compute_error(&[0.1, 0.9, 0.0], &[0.2, 0.7, 0.1]));
assert!(orchestrator.compute_error(&[0.1, 0.9, 0.0], &[0.8, 0.1, 0.1]));
}
#[test]
fn test_orchestrator_model_access() {
let model = OnlineLinearRegression::new(2);
let mut orchestrator = RetrainOrchestrator::new(model, 2);
assert_eq!(orchestrator.model().n_samples_seen(), 0);
orchestrator
.model_mut()
.partial_fit(&[1.0, 2.0], &[3.0], None)
.unwrap();
assert!(orchestrator.model().n_samples_seen() > 0);
}
#[test]
fn test_orchestrator_curriculum_retraining() {
let config = RetrainConfig {
min_samples: 10,
curriculum_learning: true,
curriculum_stages: 3,
retrain_epochs: 3,
incremental_updates: false,
..Default::default()
};
let model = OnlineLinearRegression::new(2);
let detector = ADWIN::new();
let mut orchestrator = RetrainOrchestrator::with_config(model, detector, 2, config);
for i in 0..20 {
orchestrator
.observe(&[i as f64, (i * 2) as f64], &[(i * 3) as f64], &[0.0])
.unwrap();
}
orchestrator.force_retrain().unwrap();
assert_eq!(orchestrator.stats().retrain_count, 1);
}
#[test]
fn test_compute_error_empty_target() {
let model = OnlineLinearRegression::new(2);
let orchestrator = RetrainOrchestrator::new(model, 2);
assert!(orchestrator.compute_error(&[], &[1.0]));
}
#[test]
fn test_compute_error_empty_prediction() {
let model = OnlineLinearRegression::new(2);
let orchestrator = RetrainOrchestrator::new(model, 2);
assert!(orchestrator.compute_error(&[1.0], &[]));
}
#[test]
fn test_compute_error_both_empty() {
let model = OnlineLinearRegression::new(2);
let orchestrator = RetrainOrchestrator::new(model, 2);
assert!(orchestrator.compute_error(&[], &[]));
}
#[test]
fn test_compute_error_regression_small_relative_error() {
let model = OnlineLinearRegression::new(2);
let orchestrator = RetrainOrchestrator::new(model, 2);
assert!(!orchestrator.compute_error(&[100.0], &[105.0]));
}
#[test]
fn test_compute_error_regression_large_relative_error() {
let model = OnlineLinearRegression::new(2);
let orchestrator = RetrainOrchestrator::new(model, 2);
assert!(orchestrator.compute_error(&[100.0], &[150.0]));
}
#[test]
fn test_drift_status_method() {
let model = OnlineLinearRegression::new(2);
let orchestrator = RetrainOrchestrator::new(model, 2);
assert_eq!(orchestrator.drift_status(), DriftStatus::Stable);
}
#[test]
fn test_observe_result_debug_clone() {
let result = ObserveResult::Stable;
let debug_str = format!("{:?}", result);
assert!(debug_str.contains("Stable"));
let cloned = result;
assert_eq!(cloned, ObserveResult::Stable);
let warning = ObserveResult::Warning;
assert_eq!(format!("{:?}", warning), "Warning");
let retrained = ObserveResult::Retrained;
assert_eq!(format!("{:?}", retrained), "Retrained");
let skipped = ObserveResult::Skipped;
assert_eq!(format!("{:?}", skipped), "Skipped");
}
#[test]
fn test_retrain_config_debug_clone() {
let config = RetrainConfig::default();
let debug_str = format!("{:?}", config);
assert!(debug_str.contains("RetrainConfig"));
let cloned = config.clone();
assert_eq!(cloned.min_samples, 100);
assert_eq!(cloned.retrain_epochs, 10);
}
#[test]
fn test_orchestrator_stats_debug_clone() {
let stats = OrchestratorStats::default();
let debug_str = format!("{:?}", stats);
assert!(debug_str.contains("OrchestratorStats"));
let cloned = stats.clone();
assert_eq!(cloned.samples_observed, 0);
assert_eq!(cloned.retrain_count, 0);
}
#[test]
fn test_force_retrain_empty_buffer() {
let config = RetrainConfig {
min_samples: 5,
curriculum_learning: false,
retrain_epochs: 1,
..Default::default()
};
let model = OnlineLinearRegression::new(2);
let detector = ADWIN::new();
let mut orchestrator = RetrainOrchestrator::with_config(model, detector, 2, config);
orchestrator.force_retrain().unwrap();
}
#[test]
fn test_orchestrator_with_no_incremental_updates() {
let config = RetrainConfig {
min_samples: 5,
incremental_updates: false,
curriculum_learning: false,
retrain_epochs: 1,
..Default::default()
};
let model = OnlineLinearRegression::new(2);
let detector = ADWIN::with_delta(1.0); let mut orchestrator = RetrainOrchestrator::with_config(model, detector, 2, config);
let result = orchestrator.observe(&[1.0, 2.0], &[3.0], &[3.0]).unwrap();
assert!(result == ObserveResult::Stable || result == ObserveResult::Skipped);
}
#[test]
fn test_orchestrator_detector_access() {
let model = OnlineLinearRegression::new(2);
let orchestrator = RetrainOrchestrator::new(model, 2);
let detector = orchestrator.detector();
let debug_str = format!("{:?}", detector);
assert!(debug_str.contains("ADWIN"));
}
#[test]
fn test_orchestrator_retrain_clears_and_keeps_recent() {
let config = RetrainConfig {
min_samples: 5,
curriculum_learning: false,
retrain_epochs: 1,
..Default::default()
};
let model = OnlineLinearRegression::new(2);
let detector = ADWIN::new();
let mut orchestrator = RetrainOrchestrator::with_config(model, detector, 2, config);
for i in 0..20 {
orchestrator
.observe(&[i as f64, (i * 2) as f64], &[(i * 3) as f64], &[0.0])
.unwrap();
}
let pre_retrain_size = orchestrator.buffer_size();
assert!(pre_retrain_size > 0);
orchestrator.force_retrain().unwrap();
assert!(orchestrator.buffer_size() <= pre_retrain_size);
assert_eq!(orchestrator.stats().retrain_count, 1);
assert_eq!(orchestrator.stats().samples_since_retrain, 0);
}
#[test]
fn test_orchestrator_standard_retrain_without_curriculum() {
let config = RetrainConfig {
min_samples: 5,
curriculum_learning: false,
retrain_epochs: 2,
incremental_updates: false,
..Default::default()
};
let model = OnlineLinearRegression::new(2);
let detector = ADWIN::new();
let mut orchestrator = RetrainOrchestrator::with_config(model, detector, 2, config);
for i in 0..15 {
orchestrator
.observe(&[i as f64, (i * 2) as f64], &[(i * 3) as f64], &[0.0])
.unwrap();
}
orchestrator.force_retrain().unwrap();
assert_eq!(orchestrator.stats().retrain_count, 1);
assert!(orchestrator.stats().last_retrain_samples > 0);
}
}