burn_train/
components.rs

1use crate::{
2    TrainStep, ValidStep,
3    checkpoint::{Checkpointer, CheckpointingStrategy},
4    metric::{ItemLazy, processor::EventProcessorTraining},
5};
6use burn_core::{
7    module::{AutodiffModule, Module},
8    tensor::backend::AutodiffBackend,
9};
10use burn_optim::{Optimizer, lr_scheduler::LrScheduler};
11use std::marker::PhantomData;
12
13/// All components necessary to train a model grouped in one trait.
14pub trait LearnerComponentTypes {
15    /// The backend in used for the training.
16    type Backend: AutodiffBackend;
17    /// The learning rate scheduler used for the training.
18    type LrScheduler: LrScheduler;
19    /// The model to train.
20    type Model: AutodiffModule<Self::Backend, InnerModule = Self::InnerModel>
21        + TrainStep<
22            <Self::LearningData as LearningData>::TrainInput,
23            <Self::LearningData as LearningData>::TrainOutput,
24        > + core::fmt::Display
25        + 'static;
26    /// The non-autodiff type of the model, should implement ValidationStep
27    type InnerModel: ValidStep<
28            <Self::LearningData as LearningData>::ValidInput,
29            <Self::LearningData as LearningData>::ValidOutput,
30        >;
31    /// The optimizer used for the training.
32    type Optimizer: Optimizer<Self::Model, Self::Backend>;
33    /// The checkpointer used for the model.
34    type CheckpointerModel: Checkpointer<<Self::Model as Module<Self::Backend>>::Record, Self::Backend>;
35    /// The checkpointer used for the optimizer.
36    type CheckpointerOptimizer: Checkpointer<
37            <Self::Optimizer as Optimizer<Self::Model, Self::Backend>>::Record,
38            Self::Backend,
39        > + Send;
40    /// The checkpointer used for the scheduler.
41    type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record<Self::Backend>, Self::Backend>;
42    type EventProcessor: EventProcessorTraining<
43            ItemTrain = <Self::LearningData as LearningData>::TrainOutput,
44            ItemValid = <Self::LearningData as LearningData>::ValidOutput,
45        > + 'static;
46    /// The strategy to save and delete checkpoints.
47    type CheckpointerStrategy: CheckpointingStrategy;
48    /// The data used to perform training and validation.
49    type LearningData: LearningData;
50}
51
/// Concrete type that implements the [`LearnerComponentTypes`] trait.
///
/// This is a type-level marker: each generic parameter is recorded through a
/// zero-sized [`PhantomData`] field, so the struct holds no runtime data.
pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S, LD> {
    _backend: PhantomData<B>,
    _lr_scheduler: PhantomData<LR>,
    _model: PhantomData<M>,
    _optimizer: PhantomData<O>,
    _checkpointer_model: PhantomData<CM>,
    _checkpointer_optim: PhantomData<CO>,
    _checkpointer_scheduler: PhantomData<CS>,
    _event_processor: PhantomData<EP>,
    // `PhantomData<S>` like every other field: the marker only names the
    // strategy type and must not own a strategy value.
    _strategy: PhantomData<S>,
    _learning_data: PhantomData<LD>,
}
65
66impl<B, LR, M, O, CM, CO, CS, EP, S, LD> LearnerComponentTypes
67    for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S, LD>
68where
69    B: AutodiffBackend,
70    LR: LrScheduler,
71    M: AutodiffModule<B>
72        + TrainStep<LD::TrainInput, LD::TrainOutput>
73        + core::fmt::Display
74        + 'static,
75    M::InnerModule: ValidStep<LD::ValidInput, LD::ValidOutput>,
76    O: Optimizer<M, B>,
77    CM: Checkpointer<M::Record, B>,
78    CO: Checkpointer<O::Record, B>,
79    CS: Checkpointer<LR::Record<B>, B>,
80    EP: EventProcessorTraining<ItemTrain = LD::TrainOutput, ItemValid = LD::ValidOutput> + 'static,
81    S: CheckpointingStrategy,
82    LD: LearningData,
83{
84    type Backend = B;
85    type LrScheduler = LR;
86    type Model = M;
87    type InnerModel = M::InnerModule;
88    type Optimizer = O;
89    type CheckpointerModel = CM;
90    type CheckpointerOptimizer = CO;
91    type CheckpointerLrScheduler = CS;
92    type EventProcessor = EP;
93    type CheckpointerStrategy = S;
94    type LearningData = LD;
95}
96
97/// The training backend.
98pub type TrainBackend<LC> = <LC as LearnerComponentTypes>::Backend;
99
100/// The validation backend.
101pub type ValidBackend<LC> =
102    <<LC as LearnerComponentTypes>::Backend as AutodiffBackend>::InnerBackend;
103
104/// Type for training input
105pub(crate) type InputTrain<LC> =
106    <<LC as LearnerComponentTypes>::LearningData as LearningData>::TrainInput;
107
108/// Type for validation input
109pub(crate) type InputValid<LC> =
110    <<LC as LearnerComponentTypes>::LearningData as LearningData>::ValidInput;
111
112/// Type for training output
113pub(crate) type OutputTrain<LC> =
114    <<LC as LearnerComponentTypes>::LearningData as LearningData>::TrainOutput;
115
116/// Type for validation output
117#[allow(unused)]
118pub(crate) type OutputValid<LC> =
119    <<LC as LearnerComponentTypes>::LearningData as LearningData>::ValidOutput;
120
121/// Regroups types of input and outputs for training and validation
122pub trait LearningData {
123    /// Type of input to the training stop
124    type TrainInput: Send + 'static;
125    /// Type of input to the validation step
126    type ValidInput: Send + 'static;
127    /// Type of output of the training step
128    type TrainOutput: ItemLazy + 'static;
129    /// Type of output of the validation step
130    type ValidOutput: ItemLazy + 'static;
131}
132
/// Concrete type that implements the [`LearningData`] trait.
///
/// The four type parameters are, in order, the training input, validation
/// input, training output and validation output. They are only recorded
/// through [`PhantomData`], so the struct holds no runtime data.
pub struct LearningDataMarker<TI, VI, TO, VO> {
    _phantom_data: PhantomData<(TI, VI, TO, VO)>,
}
137
138impl<TI, VI, TO, VO> LearningData for LearningDataMarker<TI, VI, TO, VO>
139where
140    TI: Send + 'static,
141    VI: Send + 'static,
142    TO: ItemLazy + 'static,
143    VO: ItemLazy + 'static,
144{
145    type TrainInput = TI;
146    type ValidInput = VI;
147    type TrainOutput = TO;
148    type ValidOutput = VO;
149}