1use crate::{InferenceStep, TrainStep};
2use burn_core::{module::AutodiffModule, tensor::backend::AutodiffBackend};
3use burn_optim::{Optimizer, lr_scheduler::LrScheduler};
4use std::marker::PhantomData;
5
6pub trait LearningComponentsTypes {
8 type Backend: AutodiffBackend;
10 type LrScheduler: LrScheduler + 'static;
12 type TrainingModel: TrainStep
14 + AutodiffModule<Self::Backend, InnerModule = Self::InferenceModel>
15 + core::fmt::Display
16 + 'static;
17 type InferenceModel: InferenceStep;
19 type Optimizer: Optimizer<Self::TrainingModel, Self::Backend> + 'static;
21}
22
/// Zero-sized marker that implements `LearningComponentsTypes` purely from its
/// four type parameters: backend `B`, scheduler `LR`, model `M`, optimizer `O`.
///
/// It stores no values of those types; the single `PhantomData` over the tuple
/// only records them, so the marker is usable solely at the type level. No
/// constructor is defined in this file.
pub struct LearningComponentsMarker<B, LR, M, O> {
    // One phantom over the whole tuple: same size (zero), variance, auto
    // traits and drop-check as one `PhantomData` per parameter.
    _components: PhantomData<(B, LR, M, O)>,
}
30
31impl<B, LR, M, O> LearningComponentsTypes for LearningComponentsMarker<B, LR, M, O>
32where
33 B: AutodiffBackend,
34 LR: LrScheduler + 'static,
35 M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
36 M::InnerModule: InferenceStep,
37 O: Optimizer<M, B> + 'static,
38{
39 type Backend = B;
40 type LrScheduler = LR;
41 type TrainingModel = M;
42 type InferenceModel = M::InnerModule;
43 type Optimizer = O;
44}
45
46pub type TrainingBackend<LC> = <LC as LearningComponentsTypes>::Backend;
48pub(crate) type InferenceBackend<LC> =
50 <<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend;
51pub type TrainingModel<LC> = <LC as LearningComponentsTypes>::TrainingModel;
53pub(crate) type InferenceModel<LC> = <LC as LearningComponentsTypes>::InferenceModel;
55pub(crate) type TrainingModelInput<LC> =
57 <<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Input;
58pub(crate) type InferenceModelInput<LC> =
60 <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Input;
61pub(crate) type TrainingModelOutput<LC> =
63 <<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Output;
64pub(crate) type InferenceModelOutput<LC> =
66 <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Output;