concision_neural/traits/
train.rs

/*
    Appellation: train <module>
    Contrib: @FL03
*/
5use crate::error::{NeuralResult, TrainingError};
6
7/// This trait defines the training process for the network
8pub trait Train<X, Y> {
9    type Output;
10
11    fn train(&mut self, input: &X, target: &Y) -> NeuralResult<Self::Output>;
12
13    fn train_for(&mut self, input: &X, target: &Y, epochs: usize) -> NeuralResult<Self::Output> {
14        let mut output = None;
15
16        for _ in 0..epochs {
17            output = match self.train(input, target) {
18                Ok(o) => Some(o),
19                Err(e) => {
20                    #[cfg(feature = "tracing")]
21                    tracing::error!("Training failed: {e}");
22                    return Err(e);
23                }
24            }
25        }
26        output.ok_or_else(|| TrainingError::TrainingFailed.into())
27    }
28}