concision_neural/train/
trainer.rs

/*
    Appellation: trainer <module>
    Contrib: @FL03
*/
5
6use crate::Model;
7use cnc::data::{Dataset, IntoDataset, Records};
8
9pub struct Trainer<'a, M, T, R>
10where
11    M: Model<T>,
12    R: Records,
13{
14    /// the training dataset
15    pub(crate) dataset: Dataset<R::Inputs, R::Targets>,
16    pub(crate) model: &'a mut M,
17    /// the accumulated loss
18    pub(crate) loss: T,
19}
20
21impl<'a, M, T, R> Trainer<'a, M, T, R>
22where
23    M: Model<T>,
24    R: Records,
25{
26    pub fn new(model: &'a mut M, dataset: R) -> Self
27    where
28        R: IntoDataset<R::Inputs, R::Targets>,
29        T: Default,
30    {
31        Self {
32            dataset: dataset.into_dataset(),
33            model,
34            loss: T::default(),
35        }
36    }
37    /// returns an immutable reference to the total loss
38    pub const fn loss(&self) -> &T {
39        &self.loss
40    }
41    /// returns a mutable reference to the total loss
42    pub fn loss_mut(&mut self) -> &mut T {
43        &mut self.loss
44    }
45    /// returns an immutable reference to the training session's dataset
46    pub const fn dataset(&self) -> &Dataset<R::Inputs, R::Targets> {
47        &self.dataset
48    }
49    /// returns a mutable reference to the training session's dataset
50    pub fn dataset_mut(&mut self) -> &mut Dataset<R::Inputs, R::Targets> {
51        &mut self.dataset
52    }
53    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
54    pub fn begin(&self) -> &Self {
55        todo!("Define a generic training loop...")
56    }
57}