use crate::loss::LossFunction;
use crate::nn::TrainableNeuron;
use crate::optimizer::Optimizer;
use std::time::Instant;
/// One supervised example: a scalar input paired with its scalar target.
///
/// `Copy` and `PartialEq` are derived because the type is a tiny POD of two
/// `f32`s — copying is free and equality comparison is useful in tests.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TrainingExample {
    /// Scalar model input.
    pub input: f32,
    /// Desired model output for `input`.
    pub target: f32,
}
/// An in-memory, owned collection of training examples.
///
/// `Debug` is derived alongside `Clone` — the contained `TrainingExample`
/// is already `Debug`, and public types should be debuggable.
#[derive(Clone, Debug)]
pub struct Dataset {
    /// Examples in iteration order; `Dataset::shuffle` reorders them in place.
    pub examples: Vec<TrainingExample>,
}
impl Dataset {
    /// Wraps a vector of examples in a `Dataset`.
    pub fn new(examples: Vec<TrainingExample>) -> Self {
        Self { examples }
    }

    /// Number of examples in the dataset.
    pub fn len(&self) -> usize {
        self.examples.len()
    }

    /// Returns `true` when the dataset holds no examples.
    pub fn is_empty(&self) -> bool {
        self.examples.is_empty()
    }

    /// Shuffles the examples in place using the thread-local RNG.
    pub fn shuffle(&mut self) {
        use rand::seq::SliceRandom;
        let mut rng = rand::rng();
        self.examples.shuffle(&mut rng);
    }

    /// Splits into `(train, validation)` where `train_ratio` is the fraction
    /// of examples assigned to the training set.
    ///
    /// The split index is clamped to the dataset length so a `train_ratio`
    /// above 1.0 cannot panic on an out-of-range slice (a negative ratio
    /// already saturates to 0 via the float-to-usize cast).
    pub fn split(&self, train_ratio: f32) -> (Dataset, Dataset) {
        let split_idx =
            ((self.examples.len() as f32 * train_ratio) as usize).min(self.examples.len());
        let train_examples = self.examples[..split_idx].to_vec();
        let val_examples = self.examples[split_idx..].to_vec();
        (Dataset::new(train_examples), Dataset::new(val_examples))
    }
}
/// Hyper-parameters controlling a `Trainer` run.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Maximum number of passes over the training set.
    pub epochs: usize,
    /// Mini-batch size. NOTE(review): not consumed by the visible `Trainer`
    /// code, which updates after every example — confirm intended use.
    pub batch_size: usize,
    /// Learning rate. NOTE(review): only printed in the training banner here;
    /// the optimizer presumably owns the actual step size — confirm.
    pub learning_rate: f32,
    /// Fraction of the dataset held out for validation (e.g. 0.2 = 20%).
    pub validation_split: f32,
    /// Stop after this many epochs without a new best validation loss;
    /// `None` disables early stopping.
    pub early_stopping_patience: Option<usize>,
    /// When `true`, prints per-epoch progress to stdout.
    pub verbose: bool,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
epochs: 100,
batch_size: 32,
learning_rate: 0.01,
validation_split: 0.2,
early_stopping_patience: Some(10),
verbose: true,
}
}
}
/// Measurements recorded for one completed training epoch.
#[derive(Clone, Debug)]
pub struct EpochMetrics {
    /// 1-based epoch index, as reported in the progress logs.
    pub epoch: usize,
    /// Mean loss over the training split.
    pub train_loss: f32,
    /// Fraction of training examples predicted within tolerance (0.0–1.0).
    pub train_accuracy: f32,
    /// Mean loss over the validation split.
    pub val_loss: f32,
    /// Fraction of validation examples predicted within tolerance (0.0–1.0).
    pub val_accuracy: f32,
    /// Wall-clock duration of the epoch in milliseconds.
    pub duration_ms: f32,
}
/// Accumulated per-epoch metrics for a full training run.
///
/// `Default` is derived (an empty `Vec`) so the type satisfies clippy's
/// `new_without_default` expectation alongside `TrainingHistory::new`.
#[derive(Clone, Debug, Default)]
pub struct TrainingHistory {
    /// Metrics in epoch order (index 0 is the first epoch).
    pub metrics: Vec<EpochMetrics>,
}
impl TrainingHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self {
            metrics: Vec::new(),
        }
    }

    /// Records the metrics of one completed epoch.
    pub fn add_epoch(&mut self, metrics: EpochMetrics) {
        self.metrics.push(metrics);
    }

    /// Lowest validation loss seen so far, or `None` if no epochs were
    /// recorded. `reduce(f32::min)` replaces the hand-rolled `fold` over
    /// `Option` with identical semantics.
    pub fn best_val_loss(&self) -> Option<f32> {
        self.metrics.iter().map(|m| m.val_loss).reduce(f32::min)
    }

    /// Highest validation accuracy seen so far, or `None` if no epochs were
    /// recorded.
    pub fn best_val_accuracy(&self) -> Option<f32> {
        self.metrics.iter().map(|m| m.val_accuracy).reduce(f32::max)
    }
}
/// Drives the training loop described by a `TrainingConfig`.
///
/// `Debug` and `Clone` are derived — `TrainingConfig` already implements both.
#[derive(Debug, Clone)]
pub struct Trainer {
    /// Hyper-parameters for this run.
    config: TrainingConfig,
}
impl Trainer {
    /// Maximum absolute prediction error for an example to count as correct
    /// when computing accuracy (shared by training and validation passes).
    const ERROR_TOLERANCE: f32 = 0.1;

    /// Creates a trainer driven by `config`.
    pub fn new(config: TrainingConfig) -> Self {
        Self { config }
    }

    /// Runs the full training loop and returns the per-epoch history.
    ///
    /// The dataset is partitioned into train/validation sets according to
    /// `config.validation_split`. When `config.early_stopping_patience` is
    /// `Some(p)`, training stops after `p` consecutive epochs without a new
    /// best validation loss.
    pub fn train<L: LossFunction, O: Optimizer>(
        &self,
        network: &mut TrainableNeuron,
        dataset: &Dataset,
        loss_fn: &L,
        optimizer: &mut O,
    ) -> TrainingHistory {
        let mut history = TrainingHistory::new();
        let (train_set, val_set) = dataset.split(1.0 - self.config.validation_split);
        let mut best_val_loss = f32::INFINITY;
        let mut patience_counter = 0;
        if self.config.verbose {
            println!(
                "🚀 Starting training with {} examples ({} train, {} val)",
                dataset.len(),
                train_set.len(),
                val_set.len()
            );
            println!(
                " Epochs: {}, Learning Rate: {:.4}",
                self.config.epochs, self.config.learning_rate
            );
        }
        for epoch in 0..self.config.epochs {
            let epoch_start = Instant::now();
            let (train_loss, train_accuracy) =
                self.train_epoch(network, &train_set, loss_fn, optimizer);
            let (val_loss, val_accuracy) = self.validate_epoch(network, &val_set, loss_fn);
            let duration = epoch_start.elapsed().as_secs_f32() * 1000.0;
            let metrics = EpochMetrics {
                epoch: epoch + 1, // 1-based for display
                train_loss,
                train_accuracy,
                val_loss,
                val_accuracy,
                duration_ms: duration,
            };
            // `metrics` is not used again, so move it instead of cloning.
            history.add_epoch(metrics);
            if self.config.verbose {
                println!(
                    "Epoch {:3}: train_loss={:.4}, train_acc={:.2}%, val_loss={:.4}, val_acc={:.2}% ({:.1}ms)",
                    epoch + 1,
                    train_loss,
                    train_accuracy * 100.0,
                    val_loss,
                    val_accuracy * 100.0,
                    duration
                );
            }
            if let Some(patience) = self.config.early_stopping_patience {
                if val_loss < best_val_loss {
                    best_val_loss = val_loss;
                    patience_counter = 0;
                } else {
                    patience_counter += 1;
                    if patience_counter >= patience {
                        if self.config.verbose {
                            println!(
                                "🛑 Early stopping at epoch {} (patience: {})",
                                epoch + 1,
                                patience
                            );
                        }
                        break;
                    }
                }
            }
        }
        if self.config.verbose {
            println!("✅ Training completed!");
            if let Some(best_loss) = history.best_val_loss() {
                println!(" Best validation loss: {:.4}", best_loss);
            }
            if let Some(best_acc) = history.best_val_accuracy() {
                println!(" Best validation accuracy: {:.2}%", best_acc * 100.0);
            }
        }
        history
    }

    /// One optimization pass over `dataset`; returns `(average loss, accuracy)`.
    fn train_epoch<L: LossFunction, O: Optimizer>(
        &self,
        network: &mut TrainableNeuron,
        dataset: &Dataset,
        loss_fn: &L,
        optimizer: &mut O,
    ) -> (f32, f32) {
        // Guard: an empty split (validation_split of 1.0) would otherwise
        // divide by zero below and report NaN metrics.
        if dataset.is_empty() {
            return (0.0, 0.0);
        }
        let mut total_loss = 0.0;
        let mut correct_predictions = 0;
        for example in &dataset.examples {
            let prediction = network.forward(example.input);
            total_loss += loss_fn.forward(prediction, example.target);
            if (prediction - example.target).abs() <= Self::ERROR_TOLERANCE {
                correct_predictions += 1;
            }
            // Per-example update: clear gradients, backprop, apply step.
            // NOTE(review): `backward` takes only the target and its return
            // is discarded — presumably TrainableNeuron hard-codes its own
            // loss gradient rather than using `loss_fn`; confirm.
            optimizer.zero_grad(network.parameters_mut());
            let _loss = network.backward(example.target);
            optimizer.step(network.parameters_mut());
        }
        let n = dataset.len() as f32;
        (total_loss / n, correct_predictions as f32 / n)
    }

    /// Forward-only pass over `dataset`; returns `(average loss, accuracy)`.
    fn validate_epoch<L: LossFunction>(
        &self,
        network: &mut TrainableNeuron,
        dataset: &Dataset,
        loss_fn: &L,
    ) -> (f32, f32) {
        // Guard: an empty split (validation_split of 0.0) would otherwise
        // divide by zero below and report NaN metrics.
        if dataset.is_empty() {
            return (0.0, 0.0);
        }
        let mut total_loss = 0.0;
        let mut correct_predictions = 0;
        for example in &dataset.examples {
            let prediction = network.forward(example.input);
            total_loss += loss_fn.forward(prediction, example.target);
            if (prediction - example.target).abs() <= Self::ERROR_TOLERANCE {
                correct_predictions += 1;
            }
        }
        let n = dataset.len() as f32;
        (total_loss / n, correct_predictions as f32 / n)
    }
}
/// Evaluates `network` on `dataset`, returning `(average loss, mean absolute error)`.
///
/// Returns `(0.0, 0.0)` for an empty dataset instead of dividing by zero,
/// which would yield NaN for both metrics.
pub fn evaluate_model<L: LossFunction>(
    network: &mut TrainableNeuron,
    dataset: &Dataset,
    loss_fn: &L,
) -> (f32, f32) {
    if dataset.is_empty() {
        return (0.0, 0.0);
    }
    let mut total_loss = 0.0;
    let mut total_error = 0.0;
    for example in &dataset.examples {
        let prediction = network.forward(example.input);
        total_loss += loss_fn.forward(prediction, example.target);
        total_error += (prediction - example.target).abs();
    }
    let n = dataset.len() as f32;
    (total_loss / n, total_error / n)
}