1use crate::{
2 TrainStep, ValidStep,
3 checkpoint::{Checkpointer, CheckpointingStrategy},
4 metric::{ItemLazy, processor::EventProcessorTraining},
5};
6use burn_core::{
7 module::{AutodiffModule, Module},
8 tensor::backend::AutodiffBackend,
9};
10use burn_optim::{Optimizer, lr_scheduler::LrScheduler};
11use std::marker::PhantomData;
12
13pub trait LearnerComponentTypes {
15 type Backend: AutodiffBackend;
17 type LrScheduler: LrScheduler;
19 type Model: AutodiffModule<Self::Backend, InnerModule = Self::InnerModel>
21 + TrainStep<
22 <Self::LearningData as LearningData>::TrainInput,
23 <Self::LearningData as LearningData>::TrainOutput,
24 > + core::fmt::Display
25 + 'static;
26 type InnerModel: ValidStep<
28 <Self::LearningData as LearningData>::ValidInput,
29 <Self::LearningData as LearningData>::ValidOutput,
30 >;
31 type Optimizer: Optimizer<Self::Model, Self::Backend>;
33 type CheckpointerModel: Checkpointer<<Self::Model as Module<Self::Backend>>::Record, Self::Backend>;
35 type CheckpointerOptimizer: Checkpointer<
37 <Self::Optimizer as Optimizer<Self::Model, Self::Backend>>::Record,
38 Self::Backend,
39 > + Send;
40 type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record<Self::Backend>, Self::Backend>;
42 type EventProcessor: EventProcessorTraining<
43 ItemTrain = <Self::LearningData as LearningData>::TrainOutput,
44 ItemValid = <Self::LearningData as LearningData>::ValidOutput,
45 > + 'static;
46 type CheckpointerStrategy: CheckpointingStrategy;
48 type LearningData: LearningData;
50}
51
/// Type-level implementation of [`LearnerComponentTypes`].
///
/// Each generic parameter selects one component type:
/// - `B`: backend
/// - `LR`: learning rate scheduler
/// - `M`: model
/// - `O`: optimizer
/// - `CM` / `CO` / `CS`: model / optimizer / scheduler checkpointers
/// - `EP`: event processor
/// - `S`: checkpointing strategy
/// - `LD`: learning data
///
/// Every field is `PhantomData`, so the marker is zero-sized and never holds a value of any
/// component type.
pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S, LD> {
    _backend: PhantomData<B>,
    _lr_scheduler: PhantomData<LR>,
    _model: PhantomData<M>,
    _optimizer: PhantomData<O>,
    _checkpointer_model: PhantomData<CM>,
    _checkpointer_optim: PhantomData<CO>,
    _checkpointer_scheduler: PhantomData<CS>,
    _event_processor: PhantomData<EP>,
    // `PhantomData<S>` rather than `S` by value: the strategy is only needed as a type here,
    // exactly like every other component, so the marker must not store (or be sized by) one.
    _strategy: PhantomData<S>,
    _learning_data: PhantomData<LD>,
}
65
66impl<B, LR, M, O, CM, CO, CS, EP, S, LD> LearnerComponentTypes
67 for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S, LD>
68where
69 B: AutodiffBackend,
70 LR: LrScheduler,
71 M: AutodiffModule<B>
72 + TrainStep<LD::TrainInput, LD::TrainOutput>
73 + core::fmt::Display
74 + 'static,
75 M::InnerModule: ValidStep<LD::ValidInput, LD::ValidOutput>,
76 O: Optimizer<M, B>,
77 CM: Checkpointer<M::Record, B>,
78 CO: Checkpointer<O::Record, B>,
79 CS: Checkpointer<LR::Record<B>, B>,
80 EP: EventProcessorTraining<ItemTrain = LD::TrainOutput, ItemValid = LD::ValidOutput> + 'static,
81 S: CheckpointingStrategy,
82 LD: LearningData,
83{
84 type Backend = B;
85 type LrScheduler = LR;
86 type Model = M;
87 type InnerModel = M::InnerModule;
88 type Optimizer = O;
89 type CheckpointerModel = CM;
90 type CheckpointerOptimizer = CO;
91 type CheckpointerLrScheduler = CS;
92 type EventProcessor = EP;
93 type CheckpointerStrategy = S;
94 type LearningData = LD;
95}
96
97pub type TrainBackend<LC> = <LC as LearnerComponentTypes>::Backend;
99
100pub type ValidBackend<LC> =
102 <<LC as LearnerComponentTypes>::Backend as AutodiffBackend>::InnerBackend;
103
104pub(crate) type InputTrain<LC> =
106 <<LC as LearnerComponentTypes>::LearningData as LearningData>::TrainInput;
107
108pub(crate) type InputValid<LC> =
110 <<LC as LearnerComponentTypes>::LearningData as LearningData>::ValidInput;
111
112pub(crate) type OutputTrain<LC> =
114 <<LC as LearnerComponentTypes>::LearningData as LearningData>::TrainOutput;
115
116#[allow(unused)]
118pub(crate) type OutputValid<LC> =
119 <<LC as LearnerComponentTypes>::LearningData as LearningData>::ValidOutput;
120
121pub trait LearningData {
123 type TrainInput: Send + 'static;
125 type ValidInput: Send + 'static;
127 type TrainOutput: ItemLazy + 'static;
129 type ValidOutput: ItemLazy + 'static;
131}
132
/// Zero-sized [`LearningData`] implementation that picks its item types through generics.
///
/// `TI` / `VI` are the train / valid inputs; `TO` / `VO` are the train / valid outputs.
pub struct LearningDataMarker<TI, VI, TO, VO> {
    // A single PhantomData over all four parameters; nothing is stored at runtime.
    _phantom_data: PhantomData<(
        TI, // train input
        VI, // valid input
        TO, // train output
        VO, // valid output
    )>,
}
137
138impl<TI, VI, TO, VO> LearningData for LearningDataMarker<TI, VI, TO, VO>
139where
140 TI: Send + 'static,
141 VI: Send + 'static,
142 TO: ItemLazy + 'static,
143 VO: ItemLazy + 'static,
144{
145 type TrainInput = TI;
146 type ValidInput = VI;
147 type TrainOutput = TO;
148 type ValidOutput = VO;
149}