use std::collections::HashMap;
use tensorlogic_train::{
CrossValidationResults, CrossValidationSplit, KFold, LeaveOneOut, StratifiedKFold,
TimeSeriesSplit,
};
/// Entry point: runs every cross-validation demonstration in order, then
/// prints the workflow template, the strategy guide and the best-practice list.
fn main() {
    println!("=== Cross-Validation Examples ===\n");
    demo_kfold();
    demo_stratified_kfold();
    demo_time_series_split();
    demo_leave_one_out();
    demo_cv_results();
    print_practical_workflow();
    print_strategy_guide();
    print_best_practices();
}

/// Shows plain and shuffled 5-fold splitting of a 20-sample dataset.
fn demo_kfold() {
    let n_samples = 20;
    println!("1. K-Fold Cross-Validation");
    println!(" Split data into K equally-sized folds\n");
    let kfold = KFold::new(5).expect("failed to create 5-fold KFold splitter");
    println!(" Configuration: {} folds", kfold.num_splits());
    println!(" Dataset size: {} samples\n", n_samples);
    // Only the first three of the five folds are printed to keep the output short.
    for fold in 0..3 {
        let (train_idx, val_idx) = kfold
            .get_split(fold, n_samples)
            .expect("failed to compute K-fold split");
        println!(" Fold {}:", fold);
        println!(" Training: {} samples", train_idx.len());
        println!(" Validation: {} samples", val_idx.len());
        println!(" Val indices: {:?}\n", val_idx);
    }
    let kfold_shuffled = KFold::new(5)
        .expect("failed to create 5-fold KFold splitter")
        .with_shuffle(42);
    println!(" With shuffling (seed=42):");
    let (_train, val) = kfold_shuffled
        .get_split(0, n_samples)
        .expect("failed to compute shuffled K-fold split");
    println!(" Fold 0 validation indices: {:?}\n", val);
}

/// Shows stratified 5-fold splitting on an imbalanced three-class label set
/// and prints the per-class composition of each validation fold.
fn demo_stratified_kfold() {
    println!("2. Stratified K-Fold Cross-Validation");
    println!(" Maintains class distribution in each fold\n");
    // 12 samples of class 0, 6 of class 1 and 2 of class 2.
    let labels = vec![
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2,
    ];
    println!(" Dataset:");
    println!(" Class 0: 12 samples (60%)");
    println!(" Class 1: 6 samples (30%)");
    println!(" Class 2: 2 samples (10%)");
    println!(" Total: 20 samples\n");
    let stratified =
        StratifiedKFold::new(5).expect("failed to create 5-fold StratifiedKFold splitter");
    println!(" Stratified splits (5 folds):");
    // Only the first three of the five folds are printed.
    for fold in 0..3 {
        let (_train_idx, val_idx) = stratified
            .get_stratified_split(fold, &labels)
            .expect("failed to compute stratified split");
        // Tally how many validation samples fall into each class.
        let mut class_counts = HashMap::new();
        for &idx in &val_idx {
            *class_counts.entry(labels[idx]).or_insert(0) += 1;
        }
        println!(" Fold {}:", fold);
        println!(" Validation: {} samples", val_idx.len());
        println!(
            " Class 0: {} samples",
            class_counts.get(&0).unwrap_or(&0)
        );
        println!(
            " Class 1: {} samples",
            class_counts.get(&1).unwrap_or(&0)
        );
        println!(
            " Class 2: {} samples\n",
            class_counts.get(&2).unwrap_or(&0)
        );
    }
    println!(" ✓ Class proportions maintained in each fold\n");
}

/// Shows time-series splitting of 30 time steps, with and without a cap on
/// the training-window size.
fn demo_time_series_split() {
    println!("3. Time Series Split");
    println!(" Respects temporal order (no data leakage from future)\n");
    let ts_split = TimeSeriesSplit::new(5).expect("failed to create 5-split TimeSeriesSplit");
    println!(" Temporal dataset: 30 time steps");
    println!(" Configuration: 5 splits\n");
    for fold in 0..5 {
        let (train_idx, val_idx) = ts_split
            .get_split(fold, 30)
            .expect("failed to compute time-series split");
        // A fold is reported only when both index sets are non-empty, i.e.
        // when each of them has a first and a last element.
        if let (Some(train_first), Some(train_last), Some(val_first), Some(val_last)) = (
            train_idx.first(),
            train_idx.last(),
            val_idx.first(),
            val_idx.last(),
        ) {
            let train_range = format!("{}-{}", train_first, train_last);
            let val_range = format!("{}-{}", val_first, val_last);
            println!(" Fold {}:", fold);
            println!(
                " Train: steps {} ({} samples)",
                train_range,
                train_idx.len()
            );
            println!(" Val: steps {} ({} samples)", val_range, val_idx.len());
            println!(" ✓ Train always before validation\n");
        }
    }
    let ts_split_window = TimeSeriesSplit::new(5)
        .expect("failed to create 5-split TimeSeriesSplit")
        .with_max_train_size(10);
    println!(" With sliding window (max_train_size=10):");
    for fold in 2..4 {
        let (train_idx, _val_idx) = ts_split_window
            .get_split(fold, 30)
            .expect("failed to compute windowed time-series split");
        println!(" Fold {}: Train size = {} (≤10)", fold, train_idx.len());
    }
}

/// Shows leave-one-out splitting on an 8-sample dataset (first four folds only).
fn demo_leave_one_out() {
    println!("\n4. Leave-One-Out Cross-Validation (LOO)");
    println!(" Use each sample once as validation\n");
    let loo = LeaveOneOut::new();
    let small_n = 8;
    println!(" Dataset size: {} samples", small_n);
    println!(" Number of folds: {} (= n_samples)\n", small_n);
    for fold in 0..4 {
        let (train_idx, val_idx) = loo
            .get_split(fold, small_n)
            .expect("failed to compute leave-one-out split");
        // Checked access so an unexpected empty split fails with a clear message
        // instead of a bare index-out-of-bounds panic.
        let held_out = val_idx
            .first()
            .expect("leave-one-out split returned an empty validation set");
        println!(" Fold {}:", fold);
        println!(" Train: {} samples", train_idx.len());
        println!(" Val: 1 sample (index {})", held_out);
    }
    println!("\n Use case: Very small datasets (n < 50)");
    println!(" Warning: Computationally expensive for large n\n");
}

/// Feeds five simulated fold scores into `CrossValidationResults` and prints
/// the aggregated statistics.
fn demo_cv_results() {
    println!("5. Cross-Validation Results");
    println!(" Aggregate and analyze results across folds\n");
    let mut cv_results = CrossValidationResults::new();
    let fold_scores = [0.85, 0.87, 0.83, 0.86, 0.84];
    println!(" Simulating 5-fold CV:");
    for (fold, &score) in fold_scores.iter().enumerate() {
        let mut metrics = HashMap::new();
        metrics.insert("accuracy".to_string(), score);
        metrics.insert("loss".to_string(), 1.0 - score);
        metrics.insert("f1_score".to_string(), score - 0.02);
        cv_results.add_fold(score, metrics);
        println!(" Fold {}: accuracy = {:.3}", fold, score);
    }

    // Query the aggregates once and reuse them below.
    let mean = cv_results.mean_score();
    let std_dev = cv_results.std_score();
    // A confidence interval for the mean is built from the standard error
    // (std / sqrt(n_folds)), not from the raw standard deviation. 1.96 is the
    // normal-approximation critical value for a 95% interval.
    let margin = 1.96 * std_dev / (fold_scores.len() as f64).sqrt();

    println!("\n Aggregated Results:");
    println!(" Mean accuracy: {:.4}", mean);
    println!(" Std accuracy: {:.4}", std_dev);
    println!(
        " Mean F1: {:.4}",
        cv_results
            .mean_metric("f1_score")
            .expect("f1_score metric was not recorded for the folds")
    );
    println!(" 95% CI: [{:.4}, {:.4}]", mean - margin, mean + margin);
}

/// Prints a copy-pasteable Rust template of a typical cross-validation loop.
fn print_practical_workflow() {
    println!("\n=== Practical Workflow ===\n");
    println!("```rust");
    println!("// 1. Choose CV strategy");
    println!("let cv_strategy = KFold::new(5)?.with_shuffle(42);");
    println!("// or");
    println!("let cv_strategy = StratifiedKFold::new(5)?;");
    println!("// or");
    println!("let cv_strategy = TimeSeriesSplit::new(5)?;");
    println!();
    println!("// 2. Initialize results tracker");
    println!("let mut cv_results = CrossValidationResults::new();");
    println!();
    println!("// 3. Run cross-validation loop");
    println!("for fold in 0..cv_strategy.num_splits() {{");
    println!(" // Get train/val split");
    println!(" let (train_idx, val_idx) = cv_strategy.get_split(fold, n_samples)?;");
    println!(" ");
    println!(" // Extract data");
    println!(" let train_data = data.select(Axis(0), &train_idx);");
    println!(" let val_data = data.select(Axis(0), &val_idx);");
    println!(" ");
    println!(" // Train model");
    println!(" let model = train_model(&train_data)?;");
    println!(" ");
    println!(" // Evaluate");
    println!(" let score = evaluate(&model, &val_data);");
    println!(" let metrics = compute_metrics(&model, &val_data);");
    println!(" ");
    println!(" // Record results");
    println!(" cv_results.add_fold(score, metrics);");
    println!("}}");
    println!();
    println!("// 4. Analyze results");
    println!("println!(\"Mean: {{:.4}} ± {{:.4}}\",");
    println!(" cv_results.mean_score(),");
    println!(" cv_results.std_score()");
    println!(");");
    println!("```");
}

/// Prints when to use each splitting strategy, with its pros and cons.
fn print_strategy_guide() {
    println!("\n=== Strategy Selection Guide ===\n");
    println!("K-Fold:");
    println!(" • Use for: General-purpose CV");
    println!(" • Pros: Simple, efficient, widely used");
    println!(" • Cons: May not preserve class distribution");
    println!(" • Recommended: 5 or 10 folds\n");
    println!("Stratified K-Fold:");
    println!(" • Use for: Imbalanced classification");
    println!(" • Pros: Maintains class proportions");
    println!(" • Cons: Slightly more complex");
    println!(" • Recommended: Always use for classification\n");
    println!("Time Series Split:");
    println!(" • Use for: Temporal data");
    println!(" • Pros: No data leakage from future");
    println!(" • Cons: Unequal fold sizes");
    println!(" • Recommended: Financial, forecasting tasks\n");
    println!("Leave-One-Out:");
    println!(" • Use for: Very small datasets (n < 50)");
    println!(" • Pros: Maximum training data per fold");
    println!(" • Cons: Computationally expensive");
    println!(" • Recommended: Only when n is very small\n");
}

/// Prints the closing checklist of cross-validation best practices.
fn print_best_practices() {
    println!("=== Best Practices ===");
    println!("1. Use stratified K-fold for classification (maintains class balance)");
    println!("2. Use time series split for temporal data (prevents data leakage)");
    println!("3. Use 5-10 folds for K-fold (good bias-variance tradeoff)");
    println!("4. Always shuffle data (except time series) to reduce sampling bias");
    println!("5. Report mean ± std dev for transparency");
    println!("6. Use same CV strategy when comparing models");
}