
Trait LearningMethod 

pub trait LearningMethod<LC: LearnerComponentTypes> {
    type PreparedDataloaders;
    type PreparedModel;

    // Required methods
    fn prepare_dataloaders(
        &self,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> Self::PreparedDataloaders;
    fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel;
    fn learn(
        &self,
        model: Self::PreparedModel,
        dataloaders: Self::PreparedDataloaders,
        starting_epoch: usize,
        components: LearnerComponents<LC>,
    ) -> (LC::Model, LC::EventProcessor);

    // Provided method
    fn fit(
        &self,
        learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> TrainingResult<LC::InnerModel> { ... }
}

Provides the fit function for any learning strategy. Implementors define the associated types and required methods below; fit is provided by default.
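The shape of this trait (required hooks plus a provided fit) can be sketched with simplified stand-ins. Model, Loader, Learner, SimpleLearningMethod, and SingleDevice below are hypothetical placeholders, not the crate's real types, and the body of fit is an assumption: the actual provided implementation is elided on this page. Epochs are numbered from 1 here, which is also an assumption.

```rust
// Hypothetical stand-in for a model: a single scalar weight.
#[derive(Debug, Clone, PartialEq)]
pub struct Model {
    pub weight: f64,
}

// Hypothetical stand-in for a dataloader: just the raw samples.
pub type Loader = Vec<f64>;

// Hypothetical learner: bundles the model with the epoch to start from.
pub struct Learner {
    pub model: Model,
    pub starting_epoch: usize,
}

// Simplified mirror of the trait: implementors supply three hooks,
// and `fit` is provided on top of them.
pub trait SimpleLearningMethod {
    type PreparedDataloaders;
    type PreparedModel;

    fn prepare_dataloaders(&self, train: Loader, valid: Loader) -> Self::PreparedDataloaders;
    fn prepare_model(&self, model: Model) -> Self::PreparedModel;
    fn learn(
        &self,
        model: Self::PreparedModel,
        dataloaders: Self::PreparedDataloaders,
        starting_epoch: usize,
    ) -> Model;

    // Assumed composition: prepare both inputs, then hand them to `learn`.
    fn fit(&self, learner: Learner, train: Loader, valid: Loader) -> Model {
        let dataloaders = self.prepare_dataloaders(train, valid);
        let model = self.prepare_model(learner.model);
        self.learn(model, dataloaders, learner.starting_epoch)
    }
}

// Hypothetical single-device strategy: nothing to split or replicate,
// so the prepared types are simply the unprepared inputs.
pub struct SingleDevice {
    pub num_epochs: usize,
}

impl SimpleLearningMethod for SingleDevice {
    type PreparedDataloaders = (Loader, Loader);
    type PreparedModel = Model;

    fn prepare_dataloaders(&self, train: Loader, valid: Loader) -> Self::PreparedDataloaders {
        (train, valid)
    }

    fn prepare_model(&self, model: Model) -> Self::PreparedModel {
        model
    }

    fn learn(&self, mut model: Model, (train, _valid): (Loader, Loader), starting_epoch: usize) -> Model {
        // Toy "training": move the weight halfway toward the training mean each epoch.
        let mean = train.iter().sum::<f64>() / train.len() as f64;
        for _epoch in starting_epoch..=self.num_epochs {
            model.weight += 0.5 * (mean - model.weight);
        }
        model
    }
}
```

Calling SingleDevice { num_epochs: 2 }.fit(learner, vec![1.0, 3.0], vec![2.0]) with a learner whose weight starts at 0.0 and whose starting epoch is 1 runs two epochs and yields a weight of 1.5. A strategy only decides how inputs are prepared and how the loop runs; the entry point stays the same for every strategy.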

Required Associated Types

type PreparedDataloaders

The dataloaders after being prepared for this training strategy (e.g., split across multiple devices).

type PreparedModel

The model after being prepared for this training strategy.

The prepared model is initialized on the appropriate device for training.

Required Methods

fn prepare_dataloaders(&self, dataloader_train: TrainLoader<LC>, dataloader_valid: ValidLoader<LC>) -> Self::PreparedDataloaders

Prepares the dataloaders for this strategy. The output is passed to learn.
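As an illustration of what a multi-device strategy might return from this method, the sketch below deals training items round-robin into one shard per device and keeps the validation data whole. ShardedLoaders, split_round_robin, prepare_dataloaders as a free function, and the choice to leave validation unsplit are all hypothetical; plain Vec<T> stands in for TrainLoader and ValidLoader, and a real multi-device strategy may shard differently.

```rust
/// Hypothetical prepared dataloaders for a multi-device strategy:
/// one training shard per device, validation data kept whole (an assumption).
#[derive(Debug, PartialEq)]
pub struct ShardedLoaders<T> {
    pub train_shards: Vec<Vec<T>>,
    pub valid: Vec<T>,
}

/// Deals items round-robin into `num_devices` shards,
/// so shard sizes differ by at most one item.
pub fn split_round_robin<T>(items: Vec<T>, num_devices: usize) -> Vec<Vec<T>> {
    assert!(num_devices > 0, "need at least one device");
    let mut shards: Vec<Vec<T>> = (0..num_devices).map(|_| Vec::new()).collect();
    for (i, item) in items.into_iter().enumerate() {
        shards[i % num_devices].push(item);
    }
    shards
}

/// Hypothetical preparation step: shard the training data, pass validation through.
pub fn prepare_dataloaders<T>(train: Vec<T>, valid: Vec<T>, num_devices: usize) -> ShardedLoaders<T> {
    ShardedLoaders {
        train_shards: split_round_robin(train, num_devices),
        valid,
    }
}
```

For example, preparing training items 1 through 5 for two devices gives the shards [1, 3, 5] and [2, 4]. Keeping shard sizes within one item of each other helps the devices finish each epoch at roughly the same time.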

fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel

Prepares the model for this training strategy. The output is passed to learn.
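A prepared model might be the original model moved to the training device (a single-device strategy) or one replica per device (a data-parallel strategy). The Device enum, the Model struct, to_device, prepare_single, and prepare_replicas below are hypothetical stand-ins; moving to a device is modeled as updating a tag rather than transferring tensors.

```rust
/// Hypothetical device identifier.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Device {
    Cpu,
    Gpu(usize),
}

/// Hypothetical model whose parameters record the device they live on.
#[derive(Debug, Clone, PartialEq)]
pub struct Model {
    pub params: Vec<f32>,
    pub device: Device,
}

impl Model {
    /// Moves the model onto `device` (here only the tag changes;
    /// a real backend would transfer the parameter tensors).
    pub fn to_device(mut self, device: Device) -> Self {
        self.device = device;
        self
    }
}

/// Single-device preparation: the prepared model is the model on the target device.
pub fn prepare_single(model: Model, device: Device) -> Model {
    model.to_device(device)
}

/// Data-parallel preparation: the prepared model is one replica per device.
pub fn prepare_replicas(model: Model, devices: &[Device]) -> Vec<Model> {
    devices.iter().map(|&d| model.clone().to_device(d)).collect()
}
```

The two functions suggest why PreparedModel is an associated type rather than a fixed one: a single-device strategy can use Model itself, while a data-parallel strategy might use Vec<Model>.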

fn learn(&self, model: Self::PreparedModel, dataloaders: Self::PreparedDataloaders, starting_epoch: usize, components: LearnerComponents<LC>) -> (LC::Model, LC::EventProcessor)

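Runs the training loop for this strategy, beginning at starting_epoch, and returns the trained model together with the event processor.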
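To illustrate how starting_epoch and the (model, event processor) return pair fit together, here is a toy loop that fits a single weight by full-batch gradient descent on mean squared error. EventLog, Model, mse, the learn free function, its lr and num_epochs parameters, and 1-based epoch numbering are hypothetical choices for this sketch, not the crate's implementation.

```rust
/// Hypothetical stand-in for the event processor:
/// records (epoch, validation loss) once per epoch.
#[derive(Debug, Default, PartialEq)]
pub struct EventLog {
    pub events: Vec<(usize, f64)>,
}

/// Hypothetical toy model: a single scalar weight.
#[derive(Debug, Clone, PartialEq)]
pub struct Model {
    pub weight: f64,
}

/// Mean squared error between the model's weight and each target value.
fn mse(model: &Model, data: &[f64]) -> f64 {
    assert!(!data.is_empty(), "dataset must not be empty");
    data.iter().map(|x| (model.weight - x).powi(2)).sum::<f64>() / data.len() as f64
}

/// Sketch of a learn-style loop. Epochs are numbered from 1 (an assumption),
/// training resumes at `starting_epoch`, and the trained model is returned
/// together with the event log, mirroring the (Model, EventProcessor) pair.
pub fn learn(
    mut model: Model,
    train: &[f64],
    valid: &[f64],
    starting_epoch: usize,
    num_epochs: usize,
    lr: f64,
) -> (Model, EventLog) {
    assert!(!train.is_empty(), "training set must not be empty");
    let mut log = EventLog::default();
    let train_mean = train.iter().sum::<f64>() / train.len() as f64;
    for epoch in starting_epoch..=num_epochs {
        // Training step: the MSE gradient w.r.t. the weight is 2 * (weight - mean(train)).
        let grad = 2.0 * (model.weight - train_mean);
        model.weight -= lr * grad;
        // Validation step: record the loss so callers can inspect progress.
        log.events.push((epoch, mse(&model, valid)));
    }
    (model, log)
}
```

Resuming with starting_epoch set to 2 skips the first epoch entirely, which is how a restored checkpoint can continue where a previous run stopped instead of repeating completed epochs.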

Provided Methods

fn fit(&self, learner: Learner<LC>, dataloader_train: TrainLoader<LC>, dataloader_valid: ValidLoader<LC>) -> TrainingResult<LC::InnerModel>

Fits the learner’s model with this strategy and returns a TrainingResult for the inner model.
