use ndarray::{Array2, arr2};
use rust_lstm::{
LSTMNetwork, create_basic_trainer, TrainingConfig, EarlyStoppingConfig, EarlyStoppingMetric,
MSELoss, Adam
};
/// Entry point: builds a synthetic dataset prone to overfitting, then walks
/// through four early-stopping configurations one after another.
fn main() {
    println!("Early Stopping Demonstration");
    println!("================================\n");

    // Synthetic sine-wave sequences; validation phases differ from training
    // phases so the two losses can diverge.
    let (training_set, validation_set) = generate_overfitting_data();
    println!(
        "Generated {} training sequences and {} validation sequences",
        training_set.len(),
        validation_set.len()
    );

    // Each demo trains a fresh network with a different stopping policy.
    demonstrate_validation_early_stopping(&training_set, &validation_set);
    demonstrate_train_loss_early_stopping(&training_set, &validation_set);
    demonstrate_no_weight_restoration(&training_set, &validation_set);
    demonstrate_custom_patience(&training_set, &validation_set);
}
/// Demo 1: stop when the *validation* loss plateaus, restoring the best
/// weights observed before the plateau.
fn demonstrate_validation_early_stopping(
    train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
    val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]
) {
    println!("1. VALIDATION LOSS EARLY STOPPING");
    println!("==================================");

    // Small 1-input / 8-hidden / 1-output network, same shape as the other demos.
    let lstm = LSTMNetwork::new(1, 8, 1);

    // Give up after 5 epochs without a validation-loss improvement of at
    // least 1e-4, and roll back to the best checkpoint when stopping.
    let stopping = EarlyStoppingConfig {
        patience: 5,
        min_delta: 1e-4,
        restore_best_weights: true,
        monitor: EarlyStoppingMetric::ValidationLoss,
    };

    let config = TrainingConfig {
        epochs: 100,
        print_every: 1,
        clip_gradient: Some(1.0),
        log_lr_changes: false,
        early_stopping: Some(stopping),
    };

    let mut trainer = create_basic_trainer(lstm, 0.01).with_config(config);

    println!("Training with validation loss monitoring (patience=5)...");
    trainer.train(train_data, Some(val_data));

    // Summarize where training actually ended.
    if let Some(metrics) = trainer.get_latest_metrics() {
        println!("Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}\n",
            metrics.epoch,
            metrics.train_loss,
            metrics.validation_loss.unwrap_or(0.0));
    }
}
/// Demo 2: monitor the *training* loss instead of the validation loss.
/// Useful when no validation split is available, at the cost of not
/// detecting overfitting directly.
fn demonstrate_train_loss_early_stopping(
    train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
    val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]
) {
    println!("2. TRAINING LOSS EARLY STOPPING");
    println!("===============================");

    let lstm = LSTMNetwork::new(1, 8, 1);

    // Tighter min_delta and a longer patience window than demo 1, since the
    // training loss tends to keep creeping downward in tiny increments.
    let stopping = EarlyStoppingConfig {
        patience: 8,
        min_delta: 1e-5,
        restore_best_weights: true,
        monitor: EarlyStoppingMetric::TrainLoss,
    };

    let config = TrainingConfig {
        epochs: 100,
        print_every: 1,
        clip_gradient: Some(1.0),
        log_lr_changes: false,
        early_stopping: Some(stopping),
    };

    let mut trainer = create_basic_trainer(lstm, 0.01).with_config(config);

    println!("Training with training loss monitoring (patience=8)...");
    // Validation data is still passed so the final report can show both losses.
    trainer.train(train_data, Some(val_data));

    if let Some(metrics) = trainer.get_latest_metrics() {
        println!("Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}\n",
            metrics.epoch,
            metrics.train_loss,
            metrics.validation_loss.unwrap_or(0.0));
    }
}
/// Demo 3: same stopping rule as demo 1, but with weight restoration turned
/// off — the model keeps whatever weights it had at the final epoch.
fn demonstrate_no_weight_restoration(
    train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
    val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]
) {
    println!("3. EARLY STOPPING WITHOUT WEIGHT RESTORATION");
    println!("=============================================");

    let lstm = LSTMNetwork::new(1, 8, 1);

    // Identical thresholds to demo 1; the only difference is
    // restore_best_weights = false.
    let stopping = EarlyStoppingConfig {
        patience: 5,
        min_delta: 1e-4,
        restore_best_weights: false,
        monitor: EarlyStoppingMetric::ValidationLoss,
    };

    let config = TrainingConfig {
        epochs: 100,
        print_every: 1,
        clip_gradient: Some(1.0),
        log_lr_changes: false,
        early_stopping: Some(stopping),
    };

    let mut trainer = create_basic_trainer(lstm, 0.01).with_config(config);

    println!("Training without weight restoration...");
    trainer.train(train_data, Some(val_data));

    if let Some(metrics) = trainer.get_latest_metrics() {
        println!("Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}",
            metrics.epoch,
            metrics.train_loss,
            metrics.validation_loss.unwrap_or(0.0));
        println!("Note: Weights are from the last epoch, not the best epoch\n");
    }
}
/// Demo 4: a very patient configuration — 15 epochs without improvement and
/// a near-zero min_delta — which lets training run much longer before the
/// stopping rule fires.
fn demonstrate_custom_patience(
    train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
    val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]
) {
    println!("4. EARLY STOPPING WITH HIGH PATIENCE");
    println!("====================================");

    let lstm = LSTMNetwork::new(1, 8, 1);

    // High patience plus a tiny min_delta makes almost any decrease count
    // as progress, so stopping is deferred as long as possible.
    let stopping = EarlyStoppingConfig {
        patience: 15,
        min_delta: 1e-6,
        restore_best_weights: true,
        monitor: EarlyStoppingMetric::ValidationLoss,
    };

    let config = TrainingConfig {
        epochs: 100,
        print_every: 2, // log every other epoch to keep the output shorter
        clip_gradient: Some(1.0),
        log_lr_changes: false,
        early_stopping: Some(stopping),
    };

    let mut trainer = create_basic_trainer(lstm, 0.01).with_config(config);

    println!("Training with high patience (patience=15)...");
    trainer.train(train_data, Some(val_data));

    if let Some(metrics) = trainer.get_latest_metrics() {
        println!("Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}\n",
            metrics.epoch,
            metrics.train_loss,
            metrics.validation_loss.unwrap_or(0.0));
    }
}
/// Generates a small synthetic next-step-prediction dataset designed to
/// encourage overfitting: 20 training sequences and only 5 validation
/// sequences, each 10 steps of a phase-shifted sine wave where the target
/// is the input advanced by one step.
///
/// Validation phases (10.0..10.4) lie far outside the training phase range
/// (0.0..1.9), so validation loss visibly diverges once the network overfits.
///
/// Returns `(train_data, val_data)` as lists of `(inputs, targets)` pairs of
/// 1x1 `ndarray` matrices.
fn generate_overfitting_data() -> (Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)>, Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)>) {
    // One sequence: 10 steps of sin(t * 0.3 + phase); target is the value
    // one step ahead. Shared by both splits (the original duplicated this
    // loop verbatim for train and validation).
    fn make_sequence(phase: f64) -> (Vec<Array2<f64>>, Vec<Array2<f64>>) {
        let mut inputs = Vec::with_capacity(10);
        let mut targets = Vec::with_capacity(10);
        for t in 0..10 {
            let x = (t as f64 * 0.3 + phase).sin();
            let y = ((t + 1) as f64 * 0.3 + phase).sin();
            inputs.push(arr2(&[[x]]));
            targets.push(arr2(&[[y]]));
        }
        (inputs, targets)
    }

    // 20 training sequences with phases 0.0, 0.1, ..., 1.9.
    let train_data: Vec<_> = (0..20)
        .map(|i| make_sequence(i as f64 * 0.1))
        .collect();

    // 5 validation sequences with phases 10.0, 10.1, ..., 10.4.
    let val_data: Vec<_> = (0..5)
        .map(|i| make_sequence((i as f64 + 100.0) * 0.1))
        .collect();

    (train_data, val_data)
}