use crate::{
TrainStep, ValidStep,
checkpoint::{Checkpointer, CheckpointingStrategy},
metric::{ItemLazy, processor::EventProcessorTraining},
};
use burn_core::{
module::{AutodiffModule, Module},
tensor::backend::AutodiffBackend,
};
use burn_optim::{Optimizer, lr_scheduler::LrScheduler};
use std::marker::PhantomData;
/// Type-level bundle of every component a learner is built from.
///
/// Implementors only name types; there are no methods. The bounds tie the pieces
/// together: the model, optimizer and checkpointers are all expressed against
/// [`Self::Backend`], and the step inputs/outputs come from [`Self::LearningData`].
pub trait LearnerComponentTypes {
    /// Autodiff-capable backend the model is trained on.
    type Backend: AutodiffBackend;

    /// Input and output item types for the training and validation steps.
    type LearningData: LearningData;

    /// Trainable model. Its autodiff inner module is pinned to [`Self::InnerModel`].
    type Model: 'static
        + core::fmt::Display
        + AutodiffModule<Self::Backend, InnerModule = Self::InnerModel>
        + TrainStep<
            <Self::LearningData as LearningData>::TrainInput,
            <Self::LearningData as LearningData>::TrainOutput,
        >;

    /// Inner (non-autodiff) module of [`Self::Model`], which must run the validation step.
    type InnerModel: ValidStep<
        <Self::LearningData as LearningData>::ValidInput,
        <Self::LearningData as LearningData>::ValidOutput,
    >;

    /// Optimizer that updates [`Self::Model`] on [`Self::Backend`].
    type Optimizer: Optimizer<Self::Model, Self::Backend>;

    /// Learning-rate scheduler.
    type LrScheduler: LrScheduler;

    /// Checkpointer for the model's record.
    type CheckpointerModel: Checkpointer<<Self::Model as Module<Self::Backend>>::Record, Self::Backend>;

    /// Checkpointer for the optimizer's record; additionally required to be `Send`.
    type CheckpointerOptimizer: Send
        + Checkpointer<
            <Self::Optimizer as Optimizer<Self::Model, Self::Backend>>::Record,
            Self::Backend,
        >;

    /// Checkpointer for the learning-rate scheduler's record.
    type CheckpointerLrScheduler: Checkpointer<
        <Self::LrScheduler as LrScheduler>::Record<Self::Backend>,
        Self::Backend,
    >;

    /// Checkpointing strategy (see [`CheckpointingStrategy`]).
    type CheckpointerStrategy: CheckpointingStrategy;

    /// Event processor whose train/valid item types match the step outputs of
    /// [`Self::LearningData`].
    type EventProcessor: 'static
        + EventProcessorTraining<
            ItemTrain = <Self::LearningData as LearningData>::TrainOutput,
            ItemValid = <Self::LearningData as LearningData>::ValidOutput,
        >;
}
/// Type-level carrier that implements [`LearnerComponentTypes`] from its generic
/// parameters.
///
/// Parameter → associated type: `B` → `Backend`, `LR` → `LrScheduler`,
/// `M` → `Model`, `O` → `Optimizer`, `CM` → `CheckpointerModel`,
/// `CO` → `CheckpointerOptimizer`, `CS` → `CheckpointerLrScheduler`,
/// `EP` → `EventProcessor`, `S` → `CheckpointerStrategy`, `LD` → `LearningData`.
///
/// The struct holds no data. Every field is `PhantomData`, so the type is
/// zero-sized, and only its type parameters carry information.
pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S, LD> {
    _backend: PhantomData<B>,
    _lr_scheduler: PhantomData<LR>,
    _model: PhantomData<M>,
    _optimizer: PhantomData<O>,
    _checkpointer_model: PhantomData<CM>,
    _checkpointer_optim: PhantomData<CO>,
    _checkpointer_scheduler: PhantomData<CS>,
    _event_processor: PhantomData<EP>,
    // `PhantomData` like every other parameter. Storing `S` by value would make the
    // marker non-zero-sized and require an actual strategy value to build it.
    _strategy: PhantomData<S>,
    _learning_data: PhantomData<LD>,
}
impl<B, LR, M, O, CM, CO, CS, EP, S, LD> LearnerComponentTypes
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S, LD>
where
B: AutodiffBackend,
LR: LrScheduler,
M: AutodiffModule<B>
+ TrainStep<LD::TrainInput, LD::TrainOutput>
+ core::fmt::Display
+ 'static,
M::InnerModule: ValidStep<LD::ValidInput, LD::ValidOutput>,
O: Optimizer<M, B>,
CM: Checkpointer<M::Record, B>,
CO: Checkpointer<O::Record, B>,
CS: Checkpointer<LR::Record<B>, B>,
EP: EventProcessorTraining<ItemTrain = LD::TrainOutput, ItemValid = LD::ValidOutput> + 'static,
S: CheckpointingStrategy,
LD: LearningData,
{
type Backend = B;
type LrScheduler = LR;
type Model = M;
type InnerModel = M::InnerModule;
type Optimizer = O;
type CheckpointerModel = CM;
type CheckpointerOptimizer = CO;
type CheckpointerLrScheduler = CS;
type EventProcessor = EP;
type CheckpointerStrategy = S;
type LearningData = LD;
}
/// Backend the learner components train on (`LearnerComponentTypes::Backend`).
pub type TrainBackend<LC> = <LC as LearnerComponentTypes>::Backend;

/// Inner backend of [`TrainBackend`] (`AutodiffBackend::InnerBackend`), used for validation.
pub type ValidBackend<LC> = <TrainBackend<LC> as AutodiffBackend>::InnerBackend;

/// Input type of the training step for components `LC`.
pub(crate) type InputTrain<LC> =
    <<LC as LearnerComponentTypes>::LearningData as LearningData>::TrainInput;

/// Input type of the validation step for components `LC`.
pub(crate) type InputValid<LC> =
    <<LC as LearnerComponentTypes>::LearningData as LearningData>::ValidInput;

/// Output type of the training step for components `LC`.
pub(crate) type OutputTrain<LC> =
    <<LC as LearnerComponentTypes>::LearningData as LearningData>::TrainOutput;

/// Output type of the validation step for components `LC`.
// NOTE(review): the lint allowance suggests this alias currently has no in-crate
// user; presumably it is kept for symmetry with `OutputTrain`. Confirm before removing.
#[allow(unused)]
pub(crate) type OutputValid<LC> =
    <<LC as LearnerComponentTypes>::LearningData as LearningData>::ValidOutput;
/// Item types exchanged by the training and validation steps.
///
/// Both inputs must be `Send`, both outputs must implement [`ItemLazy`], and all
/// four types are `'static`.
pub trait LearningData {
    /// Input consumed by the training step.
    type TrainInput: 'static + Send;
    /// Output produced by the training step.
    type TrainOutput: 'static + ItemLazy;
    /// Input consumed by the validation step.
    type ValidInput: 'static + Send;
    /// Output produced by the validation step.
    type ValidOutput: 'static + ItemLazy;
}
/// Zero-sized carrier that implements [`LearningData`] directly from its
/// type parameters: train input `TI`, valid input `VI`, train output `TO`,
/// valid output `VO`.
pub struct LearningDataMarker<TI, VI, TO, VO> {
    _phantom_data: PhantomData<(TI, VI, TO, VO)>,
}

// Bounds are identical to those declared on the `LearningData` associated types.
impl<TI: Send + 'static, VI: Send + 'static, TO: ItemLazy + 'static, VO: ItemLazy + 'static>
    LearningData for LearningDataMarker<TI, VI, TO, VO>
{
    type TrainInput = TI;
    type TrainOutput = TO;
    type ValidInput = VI;
    type ValidOutput = VO;
}