1use crate::{
2 TrainStep, ValidStep,
3 checkpoint::{Checkpointer, CheckpointingStrategy},
4 metric::{ItemLazy, processor::EventProcessorTraining},
5};
6use burn_core::{
7 module::{AutodiffModule, Module},
8 tensor::backend::AutodiffBackend,
9};
10use burn_optim::{Optimizer, lr_scheduler::LrScheduler};
11use std::marker::PhantomData;
12
13pub trait LearnerComponentTypes {
15 type Backend: AutodiffBackend;
17 type LrScheduler: LrScheduler;
19 type Model: AutodiffModule<Self::Backend, InnerModule = Self::InnerModel>
21 + TrainStep<
22 <Self::LearningData as LearningData>::TrainInput,
23 <Self::LearningData as LearningData>::TrainOutput,
24 > + core::fmt::Display
25 + 'static;
26 type InnerModel: ValidStep<
28 <Self::LearningData as LearningData>::ValidInput,
29 <Self::LearningData as LearningData>::ValidOutput,
30 >;
31 type Optimizer: Optimizer<Self::Model, Self::Backend>;
33 type CheckpointerModel: Checkpointer<<Self::Model as Module<Self::Backend>>::Record, Self::Backend>;
35 type CheckpointerOptimizer: Checkpointer<
37 <Self::Optimizer as Optimizer<Self::Model, Self::Backend>>::Record,
38 Self::Backend,
39 > + Send;
40 type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record<Self::Backend>, Self::Backend>;
42 type EventProcessor: EventProcessorTraining<
44 ItemTrain = <Self::LearningData as LearningData>::TrainOutput,
45 ItemValid = <Self::LearningData as LearningData>::ValidOutput,
46 > + 'static;
47 type CheckpointerStrategy: CheckpointingStrategy;
49 type LearningData: LearningData;
51}
52
/// Type-level marker implementing [`LearnerComponentTypes`] by mapping each generic
/// parameter onto the associated type of the same role:
///
/// - `B`: backend, `LR`: learning-rate scheduler, `M`: model, `O`: optimizer
/// - `CM` / `CO` / `CS`: model / optimizer / learning-rate-scheduler checkpointers
/// - `EP`: event processor, `S`: checkpointing strategy, `LD`: learning data
///
/// Every parameter is held through [`PhantomData`], so the marker is zero-sized and
/// never requires a value of any of its parameter types.
pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S, LD> {
    _backend: PhantomData<B>,
    _lr_scheduler: PhantomData<LR>,
    _model: PhantomData<M>,
    _optimizer: PhantomData<O>,
    _checkpointer_model: PhantomData<CM>,
    _checkpointer_optim: PhantomData<CO>,
    _checkpointer_scheduler: PhantomData<CS>,
    _event_processor: PhantomData<EP>,
    // Held as `PhantomData` like every other parameter: a type-level marker must not
    // store (or need) an actual strategy value.
    _strategy: PhantomData<S>,
    _learning_data: PhantomData<LD>,
}
66
67impl<B, LR, M, O, CM, CO, CS, EP, S, LD> LearnerComponentTypes
68 for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S, LD>
69where
70 B: AutodiffBackend,
71 LR: LrScheduler,
72 M: AutodiffModule<B>
73 + TrainStep<LD::TrainInput, LD::TrainOutput>
74 + core::fmt::Display
75 + 'static,
76 M::InnerModule: ValidStep<LD::ValidInput, LD::ValidOutput>,
77 O: Optimizer<M, B>,
78 CM: Checkpointer<M::Record, B>,
79 CO: Checkpointer<O::Record, B>,
80 CS: Checkpointer<LR::Record<B>, B>,
81 EP: EventProcessorTraining<ItemTrain = LD::TrainOutput, ItemValid = LD::ValidOutput> + 'static,
82 S: CheckpointingStrategy,
83 LD: LearningData,
84{
85 type Backend = B;
86 type LrScheduler = LR;
87 type Model = M;
88 type InnerModel = M::InnerModule;
89 type Optimizer = O;
90 type CheckpointerModel = CM;
91 type CheckpointerOptimizer = CO;
92 type CheckpointerLrScheduler = CS;
93 type EventProcessor = EP;
94 type CheckpointerStrategy = S;
95 type LearningData = LD;
96}
97
98pub type TrainBackend<LC> = <LC as LearnerComponentTypes>::Backend;
100
101pub type ValidBackend<LC> =
103 <<LC as LearnerComponentTypes>::Backend as AutodiffBackend>::InnerBackend;
104
105pub(crate) type InputTrain<LC> =
107 <<LC as LearnerComponentTypes>::LearningData as LearningData>::TrainInput;
108
109pub(crate) type InputValid<LC> =
111 <<LC as LearnerComponentTypes>::LearningData as LearningData>::ValidInput;
112
113pub(crate) type OutputTrain<LC> =
115 <<LC as LearnerComponentTypes>::LearningData as LearningData>::TrainOutput;
116
117#[allow(unused)]
119pub(crate) type OutputValid<LC> =
120 <<LC as LearnerComponentTypes>::LearningData as LearningData>::ValidOutput;
121
122pub trait LearningData {
124 type TrainInput: Send + 'static;
126 type ValidInput: Send + 'static;
128 type TrainOutput: ItemLazy + 'static;
130 type ValidOutput: ItemLazy + 'static;
132}
133
/// Zero-sized marker carrying the four learning-data types as generic parameters:
/// training input `TI`, validation input `VI`, training output `TO` and validation
/// output `VO`.
pub struct LearningDataMarker<TI, VI, TO, VO> {
    // Binds all four type parameters without storing any value of them.
    _phantom_data: PhantomData<(TI, VI, TO, VO)>,
}
138
139impl<TI, VI, TO, VO> LearningData for LearningDataMarker<TI, VI, TO, VO>
140where
141 TI: Send + 'static,
142 VI: Send + 'static,
143 TO: ItemLazy + 'static,
144 VO: ItemLazy + 'static,
145{
146 type TrainInput = TI;
147 type ValidInput = VI;
148 type TrainOutput = TO;
149 type ValidOutput = VO;
150}