burn_train/
components.rs

1use crate::{InferenceStep, TrainStep};
2use burn_core::{module::AutodiffModule, tensor::backend::AutodiffBackend};
3use burn_optim::{Optimizer, lr_scheduler::LrScheduler};
4use std::marker::PhantomData;
5
6/// Components used for a model to learn, grouped in one trait.
7pub trait LearningComponentsTypes {
8    /// The backend used for training.
9    type Backend: AutodiffBackend;
10    /// The learning rate scheduler used for training.
11    type LrScheduler: LrScheduler + 'static;
12    /// The model to train.
13    type TrainingModel: TrainStep
14        + AutodiffModule<Self::Backend, InnerModule = Self::InferenceModel>
15        + core::fmt::Display
16        + 'static;
17    /// The non-autodiff type of the model.
18    type InferenceModel: InferenceStep;
19    /// The optimizer used for training.
20    type Optimizer: Optimizer<Self::TrainingModel, Self::Backend> + 'static;
21}
22
/// Concrete type that implements the [`LearningComponentsTypes`] trait.
///
/// The generic parameters stand for the backend (`B`), the learning rate scheduler (`LR`),
/// the model (`M`) and the optimizer (`O`). Every field is a [`PhantomData`], so the struct
/// holds no value of any of those types.
pub struct LearningComponentsMarker<B, LR, M, O> {
    // Zero-sized markers whose only job is to tie each generic parameter to the struct.
    _backend: PhantomData<B>,
    _lr_scheduler: PhantomData<LR>,
    _model: PhantomData<M>,
    _optimizer: PhantomData<O>,
}
30
31impl<B, LR, M, O> LearningComponentsTypes for LearningComponentsMarker<B, LR, M, O>
32where
33    B: AutodiffBackend,
34    LR: LrScheduler + 'static,
35    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
36    M::InnerModule: InferenceStep,
37    O: Optimizer<M, B> + 'static,
38{
39    type Backend = B;
40    type LrScheduler = LR;
41    type TrainingModel = M;
42    type InferenceModel = M::InnerModule;
43    type Optimizer = O;
44}
45
46/// The training backend.
47pub type TrainingBackend<LC> = <LC as LearningComponentsTypes>::Backend;
48/// The inference backend.
49pub(crate) type InferenceBackend<LC> =
50    <<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend;
51/// The model used for training.
52pub type TrainingModel<LC> = <LC as LearningComponentsTypes>::TrainingModel;
53/// The non-autodiff model.
54pub(crate) type InferenceModel<LC> = <LC as LearningComponentsTypes>::InferenceModel;
55/// Type for training input.
56pub(crate) type TrainingModelInput<LC> =
57    <<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Input;
58/// Type for inference input.
59pub(crate) type InferenceModelInput<LC> =
60    <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Input;
61/// Type for training output.
62pub(crate) type TrainingModelOutput<LC> =
63    <<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Output;
64/// Type for inference output.
65pub(crate) type InferenceModelOutput<LC> =
66    <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Output;