concision_neural/train/
trainer.rs1use crate::Model;
7use cnc::data::{Dataset, IntoDataset, Records};
8
9pub struct Trainer<'a, M, T, R>
10where
11 M: Model<T>,
12 R: Records,
13{
14 pub(crate) dataset: Dataset<R::Inputs, R::Targets>,
16 pub(crate) model: &'a mut M,
17 pub(crate) loss: T,
19}
20
21impl<'a, M, T, R> Trainer<'a, M, T, R>
22where
23 M: Model<T>,
24 R: Records,
25{
26 pub fn new(model: &'a mut M, dataset: R) -> Self
27 where
28 R: IntoDataset<R::Inputs, R::Targets>,
29 T: Default,
30 {
31 Self {
32 dataset: dataset.into_dataset(),
33 model,
34 loss: T::default(),
35 }
36 }
37 pub const fn loss(&self) -> &T {
39 &self.loss
40 }
41 pub fn loss_mut(&mut self) -> &mut T {
43 &mut self.loss
44 }
45 pub const fn dataset(&self) -> &Dataset<R::Inputs, R::Targets> {
47 &self.dataset
48 }
49 pub fn dataset_mut(&mut self) -> &mut Dataset<R::Inputs, R::Targets> {
51 &mut self.dataset
52 }
53 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
54 pub fn begin(&self) -> &Self {
55 todo!("Define a generic training loop...")
56 }
57}