use crate::{InferenceStep, TrainStep};
use burn_core::{module::AutodiffModule, tensor::backend::AutodiffBackend};
use burn_optim::{Optimizer, lr_scheduler::LrScheduler};
use std::marker::PhantomData;
/// Groups the concrete component types (backend, training/inference models,
/// optimizer and learning-rate scheduler) behind a single set of associated
/// types.
///
/// The bounds tie the slots to each other: the training model is an
/// [`AutodiffModule`] over [`Self::Backend`] whose `InnerModule` is exactly
/// [`Self::InferenceModel`], and the optimizer is an [`Optimizer`] for that
/// training model on that backend.
pub trait LearningComponentsTypes {
    /// Backend with automatic differentiation support.
    type Backend: AutodiffBackend;

    /// Model driven through [`TrainStep`].
    ///
    /// It must be an [`AutodiffModule`] on [`Self::Backend`] whose inner
    /// (non-autodiff) module is [`Self::InferenceModel`], and it must be
    /// printable via `Display`.
    type TrainingModel: core::fmt::Display
        + TrainStep
        + AutodiffModule<Self::Backend, InnerModule = Self::InferenceModel>
        + 'static;

    /// Model driven through [`InferenceStep`]; fixed by the training model's
    /// `InnerModule`.
    type InferenceModel: InferenceStep;

    /// Optimizer operating on [`Self::TrainingModel`] with [`Self::Backend`].
    type Optimizer: Optimizer<Self::TrainingModel, Self::Backend> + 'static;

    /// Learning-rate scheduler.
    type LrScheduler: LrScheduler + 'static;
}
/// Marker type that names a set of learning components purely through its
/// type parameters: backend `B`, learning-rate scheduler `LR`, training model
/// `M` and optimizer `O`.
///
/// It holds no data. Every field is a zero-sized `PhantomData`, so the struct
/// itself is zero-sized. The fields exist only because Rust requires each
/// type parameter to be used in the struct definition.
pub struct LearningComponentsMarker<B, LR, M, O> {
    // Records the backend type parameter `B`.
    _backend: PhantomData<B>,
    // Records the learning-rate scheduler type parameter `LR`.
    _lr_scheduler: PhantomData<LR>,
    // Records the training model type parameter `M`.
    _model: PhantomData<M>,
    // Records the optimizer type parameter `O`.
    _optimizer: PhantomData<O>,
}
/// Maps each type parameter of the marker onto the associated type of the
/// same role.
///
/// The inference model is not a separate parameter. It is taken from the
/// training model's `AutodiffModule<B>::InnerModule`, which must implement
/// [`InferenceStep`].
impl<B, LR, M, O> LearningComponentsTypes for LearningComponentsMarker<B, LR, M, O>
where
    B: AutodiffBackend,
    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
    <M as AutodiffModule<B>>::InnerModule: InferenceStep,
    O: Optimizer<M, B> + 'static,
    LR: LrScheduler + 'static,
{
    type Backend = B;
    type TrainingModel = M;
    type InferenceModel = <M as AutodiffModule<B>>::InnerModule;
    type Optimizer = O;
    type LrScheduler = LR;
}
/// Autodiff backend chosen by the [`LearningComponentsTypes`] implementation `LC`.
pub type TrainingBackend<LC> = <LC as LearningComponentsTypes>::Backend;

/// Training model chosen by the [`LearningComponentsTypes`] implementation `LC`.
pub type TrainingModel<LC> = <LC as LearningComponentsTypes>::TrainingModel;

/// `InnerBackend` of the training backend's [`AutodiffBackend`] implementation.
pub(crate) type InferenceBackend<LC> = <TrainingBackend<LC> as AutodiffBackend>::InnerBackend;

/// Inference model chosen by the [`LearningComponentsTypes`] implementation `LC`.
pub(crate) type InferenceModel<LC> = <LC as LearningComponentsTypes>::InferenceModel;

/// `Input` type of the training model's [`TrainStep`] implementation.
pub(crate) type TrainingModelInput<LC> = <TrainingModel<LC> as TrainStep>::Input;

/// `Output` type of the training model's [`TrainStep`] implementation.
pub(crate) type TrainingModelOutput<LC> = <TrainingModel<LC> as TrainStep>::Output;

/// `Input` type of the inference model's [`InferenceStep`] implementation.
pub(crate) type InferenceModelInput<LC> = <InferenceModel<LC> as InferenceStep>::Input;

/// `Output` type of the inference model's [`InferenceStep`] implementation.
pub(crate) type InferenceModelOutput<LC> = <InferenceModel<LC> as InferenceStep>::Output;