concision_neural/train/traits/
train.rs

/*
    Appellation: train <module>
    Contrib: @FL03
*/
5use crate::train::error::TrainingError;
6
7use crate::error::ModelResult;
8
9/// This trait defines the training process for the network
10pub trait Train<X, Y> {
11    type Output;
12
13    fn train(&mut self, input: &X, target: &Y) -> ModelResult<Self::Output>;
14
15    fn train_for(&mut self, input: &X, target: &Y, epochs: usize) -> ModelResult<Self::Output> {
16        let mut output = None;
17
18        for _ in 0..epochs {
19            output = match self.train(input, target) {
20                Ok(o) => Some(o),
21                Err(e) => {
22                    #[cfg(feature = "tracing")]
23                    tracing::error!("Training failed: {e}");
24                    return Err(e);
25                }
26            }
27        }
28        output.ok_or_else(|| TrainingError::TrainingFailed.into())
29    }
30}