burn_train/
components.rs

1use crate::{
2    TrainStep, ValidStep,
3    checkpoint::{Checkpointer, CheckpointingStrategy},
4    metric::{ItemLazy, processor::EventProcessorTraining},
5};
6use burn_core::{
7    module::{AutodiffModule, Module},
8    tensor::backend::AutodiffBackend,
9};
10use burn_optim::{Optimizer, lr_scheduler::LrScheduler};
11use std::marker::PhantomData;
12
13/// All components necessary to train a model grouped in one trait.
14pub trait LearnerComponentTypes {
15    /// The backend in used for the training.
16    type Backend: AutodiffBackend;
17    /// The learning rate scheduler used for the training.
18    type LrScheduler: LrScheduler;
19    /// The model to train.
20    type Model: AutodiffModule<Self::Backend, InnerModule = Self::InnerModel>
21        + TrainStep<
22            <Self::LearningData as LearningData>::TrainInput,
23            <Self::LearningData as LearningData>::TrainOutput,
24        > + core::fmt::Display
25        + 'static;
26    /// The non-autodiff type of the model, should implement ValidationStep
27    type InnerModel: ValidStep<
28            <Self::LearningData as LearningData>::ValidInput,
29            <Self::LearningData as LearningData>::ValidOutput,
30        >;
31    /// The optimizer used for the training.
32    type Optimizer: Optimizer<Self::Model, Self::Backend>;
33    /// The checkpointer used for the model.
34    type CheckpointerModel: Checkpointer<<Self::Model as Module<Self::Backend>>::Record, Self::Backend>;
35    /// The checkpointer used for the optimizer.
36    type CheckpointerOptimizer: Checkpointer<
37            <Self::Optimizer as Optimizer<Self::Model, Self::Backend>>::Record,
38            Self::Backend,
39        > + Send;
40    /// The checkpointer used for the scheduler.
41    type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record<Self::Backend>, Self::Backend>;
42    /// Processes events happening during training and valid.
43    type EventProcessor: EventProcessorTraining<
44            ItemTrain = <Self::LearningData as LearningData>::TrainOutput,
45            ItemValid = <Self::LearningData as LearningData>::ValidOutput,
46        > + 'static;
47    /// The strategy to save and delete checkpoints.
48    type CheckpointerStrategy: CheckpointingStrategy;
49    /// The data used to perform training and validation.
50    type LearningData: LearningData;
51}
52
/// Concrete type-level marker that implements [`LearnerComponentTypes`].
///
/// The struct carries no data: every type parameter is held through
/// [`PhantomData`], so the marker is zero-sized.
///
/// Type parameters, in order: backend, learning rate scheduler, model,
/// optimizer, model checkpointer, optimizer checkpointer, scheduler
/// checkpointer, event processor, checkpointing strategy, learning data.
pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S, LD> {
    _backend: PhantomData<B>,
    _lr_scheduler: PhantomData<LR>,
    _model: PhantomData<M>,
    _optimizer: PhantomData<O>,
    _checkpointer_model: PhantomData<CM>,
    _checkpointer_optim: PhantomData<CO>,
    _checkpointer_scheduler: PhantomData<CS>,
    _event_processor: PhantomData<EP>,
    // Held as `PhantomData` like every other parameter so the marker never
    // needs (or stores) an actual strategy value.
    _strategy: PhantomData<S>,
    _learning_data: PhantomData<LD>,
}
66
67impl<B, LR, M, O, CM, CO, CS, EP, S, LD> LearnerComponentTypes
68    for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S, LD>
69where
70    B: AutodiffBackend,
71    LR: LrScheduler,
72    M: AutodiffModule<B>
73        + TrainStep<LD::TrainInput, LD::TrainOutput>
74        + core::fmt::Display
75        + 'static,
76    M::InnerModule: ValidStep<LD::ValidInput, LD::ValidOutput>,
77    O: Optimizer<M, B>,
78    CM: Checkpointer<M::Record, B>,
79    CO: Checkpointer<O::Record, B>,
80    CS: Checkpointer<LR::Record<B>, B>,
81    EP: EventProcessorTraining<ItemTrain = LD::TrainOutput, ItemValid = LD::ValidOutput> + 'static,
82    S: CheckpointingStrategy,
83    LD: LearningData,
84{
85    type Backend = B;
86    type LrScheduler = LR;
87    type Model = M;
88    type InnerModel = M::InnerModule;
89    type Optimizer = O;
90    type CheckpointerModel = CM;
91    type CheckpointerOptimizer = CO;
92    type CheckpointerLrScheduler = CS;
93    type EventProcessor = EP;
94    type CheckpointerStrategy = S;
95    type LearningData = LD;
96}
97
98/// The training backend.
99pub type TrainBackend<LC> = <LC as LearnerComponentTypes>::Backend;
100
101/// The validation backend.
102pub type ValidBackend<LC> =
103    <<LC as LearnerComponentTypes>::Backend as AutodiffBackend>::InnerBackend;
104
105/// Type for training input
106pub(crate) type InputTrain<LC> =
107    <<LC as LearnerComponentTypes>::LearningData as LearningData>::TrainInput;
108
109/// Type for validation input
110pub(crate) type InputValid<LC> =
111    <<LC as LearnerComponentTypes>::LearningData as LearningData>::ValidInput;
112
113/// Type for training output
114pub(crate) type OutputTrain<LC> =
115    <<LC as LearnerComponentTypes>::LearningData as LearningData>::TrainOutput;
116
117/// Type for validation output
118#[allow(unused)]
119pub(crate) type OutputValid<LC> =
120    <<LC as LearnerComponentTypes>::LearningData as LearningData>::ValidOutput;
121
122/// Regroups types of input and outputs for training and validation
123pub trait LearningData {
124    /// Type of input to the training stop
125    type TrainInput: Send + 'static;
126    /// Type of input to the validation step
127    type ValidInput: Send + 'static;
128    /// Type of output of the training step
129    type TrainOutput: ItemLazy + 'static;
130    /// Type of output of the validation step
131    type ValidOutput: ItemLazy + 'static;
132}
133
/// Concrete type-level marker that implements [`LearningData`].
///
/// No values are stored. The four type parameters are held only through
/// [`PhantomData`], so the marker is zero-sized. In order, they are: training
/// input, validation input, training output and validation output.
pub struct LearningDataMarker<TI, VI, TO, VO> {
    _phantom_data: PhantomData<(TI, VI, TO, VO)>,
}
138
139impl<TI, VI, TO, VO> LearningData for LearningDataMarker<TI, VI, TO, VO>
140where
141    TI: Send + 'static,
142    VI: Send + 'static,
143    TO: ItemLazy + 'static,
144    VO: ItemLazy + 'static,
145{
146    type TrainInput = TI;
147    type ValidInput = VI;
148    type TrainOutput = TO;
149    type ValidOutput = VO;
150}