concision_neural/model/
trainer.rs1use crate::Model;
7use concision_data::{DatasetBase, IntoDataset, Records};
8
9pub struct Trainer<'a, M, T, R>
10where
11 M: Model<T>,
12 R: Records,
13{
14 pub(crate) dataset: DatasetBase<R::Inputs, R::Targets>,
16 pub(crate) model: &'a mut M,
17 pub(crate) loss: T,
19}
20
21impl<'a, M, T, R> Trainer<'a, M, T, R>
22where
23 M: Model<T>,
24 R: Records,
25{
26 pub fn new(model: &'a mut M, dataset: R) -> Self
27 where
28 R: IntoDataset<R::Inputs, R::Targets>,
29 T: Default,
30 {
31 Self {
32 dataset: dataset.into_dataset(),
33 model,
34 loss: T::default(),
35 }
36 }
37 pub const fn loss(&self) -> &T {
39 &self.loss
40 }
41 pub fn loss_mut(&mut self) -> &mut T {
43 &mut self.loss
44 }
45 pub const fn dataset(&self) -> &DatasetBase<R::Inputs, R::Targets> {
47 &self.dataset
48 }
49 pub fn dataset_mut(&mut self) -> &mut DatasetBase<R::Inputs, R::Targets> {
51 &mut self.dataset
52 }
53 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
54 pub fn begin(&self) -> &Self {
55 todo!("Define a generic training loop...")
56 }
57}
58
59impl<'a, M, T, R> core::ops::Deref for Trainer<'a, M, T, R>
60where
61 M: Model<T>,
62 R: Records,
63{
64 type Target = M;
65
66 fn deref(&self) -> &Self::Target {
67 self.model
68 }
69}
70impl<'a, M, T, R> core::ops::DerefMut for Trainer<'a, M, T, R>
71where
72 M: Model<T>,
73 R: Records,
74{
75 fn deref_mut(&mut self) -> &mut Self::Target {
76 self.model
77 }
78}