concision_neural/model/
trainer.rs

1/*
2    Appellation: trainer <module>
3    Contrib: @FL03
4*/
5
6use crate::Model;
7use concision_data::{DatasetBase, IntoDataset, Records};
8
9pub struct Trainer<'a, M, T, R>
10where
11    M: Model<T>,
12    R: Records,
13{
14    /// the training dataset
15    pub(crate) dataset: DatasetBase<R::Inputs, R::Targets>,
16    pub(crate) model: &'a mut M,
17    /// the accumulated loss
18    pub(crate) loss: T,
19}
20
21impl<'a, M, T, R> Trainer<'a, M, T, R>
22where
23    M: Model<T>,
24    R: Records,
25{
26    pub fn new(model: &'a mut M, dataset: R) -> Self
27    where
28        R: IntoDataset<R::Inputs, R::Targets>,
29        T: Default,
30    {
31        Self {
32            dataset: dataset.into_dataset(),
33            model,
34            loss: T::default(),
35        }
36    }
37    /// returns an immutable reference to the total loss
38    pub const fn loss(&self) -> &T {
39        &self.loss
40    }
41    /// returns a mutable reference to the total loss
42    pub fn loss_mut(&mut self) -> &mut T {
43        &mut self.loss
44    }
45    /// returns an immutable reference to the training session's dataset
46    pub const fn dataset(&self) -> &DatasetBase<R::Inputs, R::Targets> {
47        &self.dataset
48    }
49    /// returns a mutable reference to the training session's dataset
50    pub fn dataset_mut(&mut self) -> &mut DatasetBase<R::Inputs, R::Targets> {
51        &mut self.dataset
52    }
53    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
54    pub fn begin(&self) -> &Self {
55        todo!("Define a generic training loop...")
56    }
57}
58
59impl<'a, M, T, R> core::ops::Deref for Trainer<'a, M, T, R>
60where
61    M: Model<T>,
62    R: Records,
63{
64    type Target = M;
65
66    fn deref(&self) -> &Self::Target {
67        self.model
68    }
69}
70impl<'a, M, T, R> core::ops::DerefMut for Trainer<'a, M, T, R>
71where
72    M: Model<T>,
73    R: Records,
74{
75    fn deref_mut(&mut self) -> &mut Self::Target {
76        self.model
77    }
78}