meuron 0.4.0

Meuron is a modular neural network library written in Rust for training simple neural networks.
Documentation
use crate::NeuralNetwork;
use crate::backend::Backend;
use crate::cost::Cost;
use crate::layer::Layer;
use crate::optimizer::Optimizer;
use ndarray::RemoveAxis;
use serde::{Deserialize, Serialize};

/// Hook invoked by the training loop once per epoch.
///
/// Implementors receive the progress and loss figures of the epoch that just
/// finished and decide whether training should go on.
pub trait TrainCallback {
    /// Called after an epoch has been processed.
    ///
    /// * `epoch` - 1-based number of the epoch that just finished.
    /// * `total` - total number of epochs that were requested.
    /// * `loss` - training loss averaged over the batches of this epoch.
    /// * `val_loss` - validation loss, or `None` when no validation data was held out.
    ///
    /// Return `true` to keep training, `false` to stop after this epoch.
    fn on_epoch(&mut self, epoch: usize, total: usize, loss: f32, val_loss: Option<f32>) -> bool;
}

/// Callback that writes a one-line progress report to standard output after
/// every epoch and never requests an early stop.
pub struct PrintCallback;

impl TrainCallback for PrintCallback {
    /// Prints the epoch counter and training loss, followed by the validation
    /// loss when one is supplied, then terminates the line.
    ///
    /// Always returns `true`, so training continues.
    fn on_epoch(&mut self, epoch: usize, total: usize, loss: f32, val_loss: Option<f32>) -> bool {
        print!("Epoch {epoch}/{total}: loss = {loss:.6}");
        match val_loss {
            // Finish the line with the validation figure.
            Some(v) => println!(" val_loss = {v:.6}"),
            // Nothing more to report; just end the line.
            None => println!(),
        }
        true
    }
}

/// Lets any closure (or function) with the `on_epoch` signature act as a
/// [`TrainCallback`] without a wrapper type.
impl<Func: FnMut(usize, usize, f32, Option<f32>) -> bool> TrainCallback for Func {
    /// Forwards all arguments to the wrapped callable and returns its verdict.
    fn on_epoch(&mut self, epoch: usize, total: usize, loss: f32, val_loss: Option<f32>) -> bool {
        (*self)(epoch, total, loss, val_loss)
    }
}

/// Settings that control a training run.
pub struct TrainOptions {
    /// Maximum number of passes over the training data.
    pub epochs: usize,
    /// Number of samples processed per mini-batch.
    pub batch_size: usize,
    /// Fraction of the samples, taken from the end of the data set, that is
    /// held out for validation instead of being trained on.
    pub validation_split: f32,
    /// Callback invoked after every epoch; when `None`, the training loop
    /// falls back to [`PrintCallback`].
    pub callback: Option<Box<dyn TrainCallback>>,
}

impl Default for TrainOptions {
    /// Returns the stock configuration: 10 epochs, mini-batches of 32 samples,
    /// no validation split and no explicit callback.
    fn default() -> Self {
        let (epochs, batch_size) = (10, 32);
        let validation_split = 0.0;

        Self {
            callback: None,
            validation_split,
            batch_size,
            epochs,
        }
    }
}

impl TrainOptions {
    /// Creates options populated with the [`Default`] values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the number of epochs and returns the updated options.
    pub fn epochs(self, epochs: usize) -> Self {
        Self { epochs, ..self }
    }

    /// Sets the mini-batch size and returns the updated options.
    pub fn batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Sets the fraction of samples held out for validation and returns the
    /// updated options.
    pub fn validation_split(self, validation_split: f32) -> Self {
        Self {
            validation_split,
            ..self
        }
    }

    /// Installs `callback` as the per-epoch hook, replacing any previous one,
    /// and returns the updated options.
    pub fn callback<C>(self, callback: C) -> Self
    where
        C: TrainCallback + 'static,
    {
        let boxed: Box<dyn TrainCallback> = Box::new(callback);
        Self {
            callback: Some(boxed),
            ..self
        }
    }
}

impl<L, C, B> NeuralNetwork<L, C, B>
where
    B: Backend,
    L: Layer<B> + Serialize + for<'de> Deserialize<'de>,
    C: Cost<B>,
{
    /// Trains the network with mini-batch updates and returns the mean
    /// training loss of every epoch that was run.
    ///
    /// The samples are laid out along axis 0 of `inputs` and `targets`. The
    /// last `validation_split` fraction of them is held out (without
    /// shuffling) as validation data; the rest is reshuffled at the start of
    /// every epoch and processed in batches of `batch_size` samples. For each
    /// batch the network runs a forward pass, accumulates the cost, runs a
    /// backward pass with the cost gradient and lets `optimizer` update the
    /// layers.
    ///
    /// After each epoch the callback from `options` (or [`PrintCallback`] when
    /// none was set) receives the 1-based epoch number, the epoch count, the
    /// mean batch loss and, if validation data exists, the loss over the whole
    /// validation set. Training stops early when the callback returns `false`;
    /// the loss of that last epoch is still part of the result.
    ///
    /// # Edge cases
    ///
    /// * A `batch_size` of `0` is treated as `1`.
    /// * `validation_split` is clamped to `0.0..=1.0`; `NaN` behaves like `0.0`.
    /// * If no training samples remain (empty data or everything held out),
    ///   nothing is trained, the callback is never invoked and an empty
    ///   vector is returned.
    ///
    /// # Panics
    ///
    /// `targets` is sliced using the sample count of `inputs`, so a `targets`
    /// array with fewer samples than `inputs` is expected to panic inside
    /// `ndarray`.
    pub fn train<O, I, T>(
        &mut self,
        inputs: I,
        targets: T,
        mut optimizer: O,
        options: TrainOptions,
    ) -> Vec<f32>
    where
        O: Optimizer<B>,
        I: Into<B::Tensor<L::Input>>,
        T: Into<B::Tensor<L::Output>>,
        L::Input: RemoveAxis,
        L::Output: RemoveAxis,
    {
        use ndarray::{Axis, Slice};
        use rand::rng;
        use rand::seq::SliceRandom;

        let TrainOptions {
            epochs,
            batch_size,
            validation_split,
            callback,
        } = options;

        // A zero batch size cannot be used to split the data; fall back to
        // one sample per batch instead of panicking.
        let batch_size = batch_size.max(1);

        let mut callback: Box<dyn TrainCallback> =
            callback.unwrap_or_else(|| Box::new(PrintCallback));

        let all_inputs = B::to_array(&inputs.into());
        let all_targets = B::to_array(&targets.into());

        let total = all_inputs.shape()[0];
        // Clamp the fraction and cap the count so `total - val_n` can never
        // underflow, even with f32 rounding on large data sets. A NaN split
        // survives the clamp but casts to 0, i.e. no validation data.
        let val_fraction = validation_split.clamp(0.0, 1.0);
        let val_n = ((total as f32 * val_fraction) as usize).min(total);
        let train_n = total - val_n;

        // Without training samples there are no batches to average over;
        // bail out instead of reporting NaN losses.
        if train_n == 0 {
            return Vec::new();
        }

        let train_inputs = all_inputs
            .slice_axis(Axis(0), Slice::from(..train_n))
            .to_owned();
        let train_targets = all_targets
            .slice_axis(Axis(0), Slice::from(..train_n))
            .to_owned();

        let val_data = (val_n > 0).then(|| {
            (
                all_inputs
                    .slice_axis(Axis(0), Slice::from(train_n..))
                    .to_owned(),
                all_targets
                    .slice_axis(Axis(0), Slice::from(train_n..))
                    .to_owned(),
            )
        });

        let mut losses = Vec::with_capacity(epochs);

        // Allocated once; reshuffling the previous permutation is just as
        // random as shuffling a freshly built one.
        let mut indices: Vec<usize> = (0..train_n).collect();

        for epoch in 0..epochs {
            indices.shuffle(&mut rng());

            let mut total_loss = 0.0_f32;
            let mut batch_count = 0_usize;

            // `chunks` yields full batches followed by one shorter tail batch.
            for batch_indices in indices.chunks(batch_size) {
                let batch_input = B::from_array(train_inputs.select(Axis(0), batch_indices));
                let batch_target = B::from_array(train_targets.select(Axis(0), batch_indices));

                let output = self.forward(batch_input);
                total_loss += self.cost.loss(&output, &batch_target);
                let grad = self.cost.gradient(&output, &batch_target);
                self.backward(grad);
                self.layers.update(&mut optimizer);

                B::flush();
                batch_count += 1;
            }

            // `train_n > 0` guarantees at least one batch, so this is finite
            // whenever the batch losses are.
            let loss = total_loss / batch_count as f32;
            losses.push(loss);

            let val_loss = val_data.as_ref().map(|(vi, vt)| {
                let out = self.forward(B::from_array(vi.clone()));
                self.cost.loss(&out, &B::from_array(vt.clone()))
            });

            if !callback.on_epoch(epoch + 1, epochs, loss, val_loss) {
                break;
            }
        }

        losses
    }
}