pub trait LearningMethod<LC: LearnerComponentTypes> {
type PreparedDataloaders;
type PreparedModel;
// Required methods
fn prepare_dataloaders(
&self,
dataloader_train: TrainLoader<LC>,
dataloader_valid: ValidLoader<LC>,
) -> Self::PreparedDataloaders;
fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel;
fn learn(
&self,
model: Self::PreparedModel,
dataloaders: Self::PreparedDataloaders,
starting_epoch: usize,
components: LearnerComponents<LC>,
) -> (LC::Model, LC::EventProcessor);
// Provided method
fn fit(
&self,
learner: Learner<LC>,
dataloader_train: TrainLoader<LC>,
dataloader_valid: ValidLoader<LC>,
) -> TrainingResult<LC::InnerModel> { ... }
}
Provides the `fit` function for any learning strategy.
Required Associated Types§
Sourcetype PreparedDataloaders
type PreparedDataloaders
The dataloaders after being prepared for this training strategy (e.g., split across multiple devices).
Sourcetype PreparedModel
type PreparedModel
The model after being prepared for this training strategy.
The prepared model will be correctly initialized on the proper device for training.
Required Methods§
Sourcefn prepare_dataloaders(
&self,
dataloader_train: TrainLoader<LC>,
dataloader_valid: ValidLoader<LC>,
) -> Self::PreparedDataloaders
fn prepare_dataloaders( &self, dataloader_train: TrainLoader<LC>, dataloader_valid: ValidLoader<LC>, ) -> Self::PreparedDataloaders
Prepare the dataloaders for this strategy. The output will be used by the `learn` function.
Sourcefn prepare_model(&self, model: LC::Model) -> Self::PreparedModel
fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel
Prepare the model for this training strategy. The output will be used by the `learn` function.
Sourcefn learn(
&self,
model: Self::PreparedModel,
dataloaders: Self::PreparedDataloaders,
starting_epoch: usize,
components: LearnerComponents<LC>,
) -> (LC::Model, LC::EventProcessor)
fn learn( &self, model: Self::PreparedModel, dataloaders: Self::PreparedDataloaders, starting_epoch: usize, components: LearnerComponents<LC>, ) -> (LC::Model, LC::EventProcessor)
Training loop for this strategy.
Provided Methods§
Sourcefn fit(
&self,
learner: Learner<LC>,
dataloader_train: TrainLoader<LC>,
dataloader_valid: ValidLoader<LC>,
) -> TrainingResult<LC::InnerModel>
fn fit( &self, learner: Learner<LC>, dataloader_train: TrainLoader<LC>, dataloader_valid: ValidLoader<LC>, ) -> TrainingResult<LC::InnerModel>
Fit the learner’s model with this strategy.