//! Training Infrastructure for Neural Networks
//!
//! Provides training loops, data handling, and evaluation utilities.
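//!
//! # Example
//!
//! A minimal end-to-end sketch, marked `ignore` because `MseLoss`, `Sgd`,
//! and the `TrainableNeuron::new`/`Sgd::new` constructors are assumed names
//! (substitute the crate's actual [`LossFunction`] and [`Optimizer`]
//! implementations), and the `eenn::training` path is likewise an
//! assumption:
//!
//! ```ignore
//! use eenn::nn::TrainableNeuron;
//! use eenn::training::{Dataset, Trainer, TrainingConfig, TrainingExample};
//!
//! // Learn the mapping x -> 2x from synthetic data.
//! let examples = (0..100)
//!     .map(|i| {
//!         let x = i as f32 / 100.0;
//!         TrainingExample { input: x, target: 2.0 * x }
//!     })
//!     .collect();
//! let mut dataset = Dataset::new(examples);
//! dataset.shuffle();
//!
//! let trainer = Trainer::new(TrainingConfig::default());
//! let mut network = TrainableNeuron::new();
//! let mut optimizer = Sgd::new(0.01);
//! let history = trainer.train(&mut network, &dataset, &MseLoss, &mut optimizer);
//! println!("best val loss: {:?}", history.best_val_loss());
//! ```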

use crate::loss::LossFunction;
use crate::nn::TrainableNeuron;
use crate::optimizer::Optimizer;
use std::time::Instant;

/// Training data point
#[derive(Clone, Debug)]
pub struct TrainingExample {
    pub input: f32,
    pub target: f32,
}

/// Dataset for training
#[derive(Clone)]
pub struct Dataset {
    pub examples: Vec<TrainingExample>,
}

impl Dataset {
    pub fn new(examples: Vec<TrainingExample>) -> Self {
        Self { examples }
    }

    pub fn len(&self) -> usize {
        self.examples.len()
    }

    pub fn is_empty(&self) -> bool {
        self.examples.is_empty()
    }

    /// Shuffles the examples in place; useful before [`Dataset::split`]
    /// when the data is ordered.
    pub fn shuffle(&mut self) {
        use rand::seq::SliceRandom;
        // `rand::rng()` is the rand 0.9 name for the thread-local RNG
        // (formerly `thread_rng()`).
        let mut rng = rand::rng();
        self.examples.shuffle(&mut rng);
    }

    /// Splits the dataset into `(train, validation)` by `train_ratio`.
    /// The split is positional (the first `len * train_ratio` examples
    /// become the training set), so shuffle first if the data is ordered;
    /// e.g. `train_ratio = 0.8` on 10 examples yields an 8/2 split.
    pub fn split(&self, train_ratio: f32) -> (Dataset, Dataset) {
        let split_idx = (self.examples.len() as f32 * train_ratio) as usize;
        let train_examples = self.examples[..split_idx].to_vec();
        let val_examples = self.examples[split_idx..].to_vec();

        (Dataset::new(train_examples), Dataset::new(val_examples))
    }
}

/// Training configuration
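///
/// The `Default` values are 100 epochs, batch size 32, learning rate 0.01,
/// a 20% validation split, and early stopping after 10 epochs without
/// improvement. Note that `batch_size` is not yet consulted by [`Trainer`],
/// which updates parameters once per example; fields can be overridden with
/// struct-update syntax, e.g.
/// `TrainingConfig { epochs: 500, ..Default::default() }`.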
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    pub epochs: usize,
    pub batch_size: usize,
    pub learning_rate: f32,
    pub validation_split: f32,
    pub early_stopping_patience: Option<usize>,
    pub verbose: bool,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            epochs: 100,
            batch_size: 32,
            learning_rate: 0.01,
            validation_split: 0.2,
            early_stopping_patience: Some(10),
            verbose: true,
        }
    }
}

/// Training metrics for one epoch
#[derive(Clone, Debug)]
pub struct EpochMetrics {
    pub epoch: usize,
    pub train_loss: f32,
    pub train_accuracy: f32,
    pub val_loss: f32,
    pub val_accuracy: f32,
    pub duration_ms: f32,
}

/// Training history
#[derive(Clone, Debug, Default)]
pub struct TrainingHistory {
    pub metrics: Vec<EpochMetrics>,
}

impl TrainingHistory {
    pub fn new() -> Self {
        Self {
            metrics: Vec::new(),
        }
    }

    pub fn add_epoch(&mut self, metrics: EpochMetrics) {
        self.metrics.push(metrics);
    }

    pub fn best_val_loss(&self) -> Option<f32> {
        self.metrics.iter().map(|m| m.val_loss).reduce(f32::min)
    }

    pub fn best_val_accuracy(&self) -> Option<f32> {
        self.metrics.iter().map(|m| m.val_accuracy).reduce(f32::max)
    }
}

/// Neural network trainer
pub struct Trainer {
    config: TrainingConfig,
}

impl Trainer {
    pub fn new(config: TrainingConfig) -> Self {
        Self { config }
    }

    /// Runs the full training loop: splits `dataset` according to
    /// `config.validation_split`, then alternates a training pass and a
    /// validation pass each epoch, stopping early once validation loss has
    /// not improved for `early_stopping_patience` consecutive epochs.
    pub fn train<L: LossFunction, O: Optimizer>(
        &self,
        network: &mut TrainableNeuron,
        dataset: &Dataset,
        loss_fn: &L,
        optimizer: &mut O,
    ) -> TrainingHistory {
        let mut history = TrainingHistory::new();

        // Split dataset
        let (train_set, val_set) = dataset.split(1.0 - self.config.validation_split);

        let mut best_val_loss = f32::INFINITY;
        let mut patience_counter = 0;

        if self.config.verbose {
            println!(
                "🚀 Starting training with {} examples ({} train, {} val)",
                dataset.len(),
                train_set.len(),
                val_set.len()
            );
            println!(
                "   Epochs: {}, Learning Rate: {:.4}",
                self.config.epochs, self.config.learning_rate
            );
        }

        for epoch in 0..self.config.epochs {
            let epoch_start = Instant::now();

            // Training phase
            let (train_loss, train_accuracy) =
                self.train_epoch(network, &train_set, loss_fn, optimizer);

            // Validation phase
            let (val_loss, val_accuracy) = self.validate_epoch(network, &val_set, loss_fn);

            let duration = epoch_start.elapsed().as_secs_f32() * 1000.0;

            let metrics = EpochMetrics {
                epoch: epoch + 1,
                train_loss,
                train_accuracy,
                val_loss,
                val_accuracy,
                duration_ms: duration,
            };

            history.add_epoch(metrics.clone());

            if self.config.verbose {
                println!(
                    "Epoch {:3}: train_loss={:.4}, train_acc={:.2}%, val_loss={:.4}, val_acc={:.2}% ({:.1}ms)",
                    epoch + 1,
                    train_loss,
                    train_accuracy * 100.0,
                    val_loss,
                    val_accuracy * 100.0,
                    duration
                );
            }

            // Early stopping
            if let Some(patience) = self.config.early_stopping_patience {
                if val_loss < best_val_loss {
                    best_val_loss = val_loss;
                    patience_counter = 0;
                } else {
                    patience_counter += 1;
                    if patience_counter >= patience {
                        if self.config.verbose {
                            println!(
                                "🛑 Early stopping at epoch {} (patience: {})",
                                epoch + 1,
                                patience
                            );
                        }
                        break;
                    }
                }
            }
        }

        if self.config.verbose {
            println!("✅ Training completed!");
            if let Some(best_loss) = history.best_val_loss() {
                println!("   Best validation loss: {:.4}", best_loss);
            }
            if let Some(best_acc) = history.best_val_accuracy() {
                println!("   Best validation accuracy: {:.2}%", best_acc * 100.0);
            }
        }

        history
    }

    fn train_epoch<L: LossFunction, O: Optimizer>(
        &self,
        network: &mut TrainableNeuron,
        dataset: &Dataset,
        loss_fn: &L,
        optimizer: &mut O,
    ) -> (f32, f32) {
        let mut total_loss = 0.0;
        let mut correct_predictions = 0;

        for example in &dataset.examples {
            // Forward pass
            let prediction = network.forward(example.input);
            let loss = loss_fn.forward(prediction, example.target);
            total_loss += loss;

            // Count a prediction as correct if it is within an absolute
            // tolerance of 0.1 of the target (a crude regression accuracy).
            let error_tolerance = 0.1;
            if (prediction - example.target).abs() <= error_tolerance {
                correct_predictions += 1;
            }

            // Backward pass: clear stale gradients, backpropagate from the
            // target, then apply the optimizer update. The loss returned by
            // `backward` is discarded; it was already computed above.
            optimizer.zero_grad(network.parameters_mut());
            let _loss = network.backward(example.target);
            optimizer.step(network.parameters_mut());
        }

        let avg_loss = total_loss / dataset.len() as f32;
        let accuracy = correct_predictions as f32 / dataset.len() as f32;

        (avg_loss, accuracy)
    }

    fn validate_epoch<L: LossFunction>(
        &self,
        network: &mut TrainableNeuron,
        dataset: &Dataset,
        loss_fn: &L,
    ) -> (f32, f32) {
        let mut total_loss = 0.0;
        let mut correct_predictions = 0;

        for example in &dataset.examples {
            let prediction = network.forward(example.input);
            let loss = loss_fn.forward(prediction, example.target);
            total_loss += loss;

            // Same absolute tolerance of 0.1 as in the training pass.
            let error_tolerance = 0.1;
            if (prediction - example.target).abs() <= error_tolerance {
                correct_predictions += 1;
            }
        }

        let avg_loss = total_loss / dataset.len() as f32;
        let accuracy = correct_predictions as f32 / dataset.len() as f32;

        (avg_loss, accuracy)
    }
}

/// Evaluates a model on `dataset`, returning `(average loss, mean absolute
/// error)`; typically used on a held-out test set after [`Trainer::train`].
pub fn evaluate_model<L: LossFunction>(
    network: &mut TrainableNeuron,
    dataset: &Dataset,
    loss_fn: &L,
) -> (f32, f32) {
    let mut total_loss = 0.0;
    let mut total_error = 0.0;

    for example in &dataset.examples {
        let prediction = network.forward(example.input);
        let loss = loss_fn.forward(prediction, example.target);
        total_loss += loss;

        let error = (prediction - example.target).abs();
        total_error += error;
    }

    let avg_loss = total_loss / dataset.len() as f32;
    let mae = total_error / dataset.len() as f32;

    (avg_loss, mae)
}
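
// A couple of sanity tests that double as a usage sketch; they exercise
// only `Dataset` and `TrainingHistory` (fully defined above) so they stay
// self-contained, without needing a `TrainableNeuron` or an optimizer.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_preserves_all_examples() {
        let examples: Vec<TrainingExample> = (0..10)
            .map(|i| TrainingExample {
                input: i as f32,
                target: 2.0 * i as f32,
            })
            .collect();
        let dataset = Dataset::new(examples);
        let (train, val) = dataset.split(0.8);
        assert_eq!(train.len(), 8);
        assert_eq!(val.len(), 2);
    }

    #[test]
    fn history_tracks_best_metrics() {
        let mut history = TrainingHistory::new();
        for (i, (loss, acc)) in [(0.5, 0.6), (0.3, 0.8), (0.4, 0.7)].into_iter().enumerate() {
            history.add_epoch(EpochMetrics {
                epoch: i + 1,
                train_loss: loss,
                train_accuracy: acc,
                val_loss: loss,
                val_accuracy: acc,
                duration_ms: 1.0,
            });
        }
        assert_eq!(history.best_val_loss(), Some(0.3));
        assert_eq!(history.best_val_accuracy(), Some(0.8));
    }
}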