// burn 0.3.0
//
// BURN: Burn Unstoppable Rusty Neurons
// Documentation
use burn_tensor::backend::ADBackend;

use super::Learner;
use crate::data::dataloader::DataLoader;
use crate::module::ADModule;
use crate::optim::Optimizer;
use crate::train::LearnerItem;
use std::sync::Arc;

/// Result of a single training step: the gradients together with the
/// step's output item.
///
/// Both fields are private; values are built through the `new` constructor
/// generated by `#[derive(new)]` (arguments follow field order: `grads`,
/// then `item`).
#[derive(new)]
pub struct TrainOutput<TO, G> {
    /// Gradients produced by the step. In this file they are handed to
    /// `update_params` together with the optimizer.
    grads: G,
    /// Output of the step. In this file it is wrapped in a `LearnerItem`
    /// and forwarded to the training callback.
    item: TO,
}

/// A single training step.
///
/// Type parameters:
/// - `TI`: input item type (what the training data loader yields).
/// - `TO`: output item type carried inside the returned [`TrainOutput`].
/// - `G`: gradients type carried inside the returned [`TrainOutput`].
pub trait TrainStep<TI, TO, G> {
    /// Runs one training step on `item` and returns its output bundled
    /// with the gradients.
    fn step(&self, item: TI) -> TrainOutput<TO, G>;
}

/// A single validation step.
///
/// Type parameters:
/// - `VI`: input item type (what the validation data loader yields).
/// - `VO`: output item type returned by the step.
pub trait ValidStep<VI, VO> {
    /// Runs one validation step on `item` and returns its output.
    /// Unlike [`TrainStep::step`], no gradients are returned.
    fn step(&self, item: VI) -> VO;
}

impl<M, O, TO, VO> Learner<M, O, TO, VO>
where
    VO: Send + Sync + 'static,
    TO: Send + Sync + 'static,
    M: ADModule,
    O: Optimizer<Backend = M::Backend>,
{
    /// Trains the model for the configured number of epochs and returns it.
    ///
    /// Each epoch runs, in order: one pass over `dataloader_train`
    /// (`train_step`), one pass over `dataloader_valid` (`valid_step`), and a
    /// call to `self.checkpoint(epoch)`.
    ///
    /// Epochs are numbered from 1 up to and including `self.num_epochs`.
    ///
    /// # Resuming
    ///
    /// When `self.checkpoint` is `Some(n)`, the checkpoint for epoch `n` is
    /// loaded and training resumes at epoch `n + 1`. The checkpoint for an
    /// epoch is written only after that epoch's training and validation
    /// passes have completed, so the restored state is presumed to already
    /// include epoch `n`; restarting at `n` would repeat a finished epoch and
    /// overwrite its checkpoint. If `n >= self.num_epochs`, no epoch is run
    /// and the restored model is returned as-is.
    pub fn fit<TI, VI>(
        mut self,
        dataloader_train: Arc<dyn DataLoader<TI>>,
        dataloader_valid: Arc<dyn DataLoader<VI>>,
    ) -> M
    where
        M: TrainStep<TI, TO, <M::ADBackend as ADBackend>::Gradients>,
        M::InnerModule: ValidStep<VI, VO>,
    {
        log::info!("Fitting {}", self.model.to_string());

        let starting_epoch = match self.checkpoint {
            Some(checkpoint) => {
                self.load_checkpoint(checkpoint);
                // Epoch `checkpoint` is already complete in the restored
                // state: continue with the one after it.
                checkpoint + 1
            }
            None => 1,
        };

        for epoch in starting_epoch..=self.num_epochs {
            self.train_step(&dataloader_train, epoch);
            self.valid_step(&dataloader_valid, epoch);
            self.checkpoint(epoch);
        }

        self.model
    }

    /// Runs one training epoch over `dataloader_train`.
    ///
    /// For every item yielded by the data loader: runs the model's
    /// [`TrainStep::step`], applies the resulting gradients through
    /// `update_params` with `self.optim`, then reports the step output to
    /// `self.callback.on_train_item`. Once the loader is exhausted,
    /// `self.callback.on_train_end_epoch(epoch)` is called.
    ///
    /// `iteration` reported to the callback is 1-based within the epoch.
    fn train_step<TI>(&mut self, dataloader_train: &Arc<dyn DataLoader<TI>>, epoch: usize)
    where
        M: TrainStep<TI, TO, <M::ADBackend as ADBackend>::Gradients>,
    {
        log::info!("Executing training step for epoch {}", epoch);

        let mut iterator = dataloader_train.iter();
        let mut iteration = 0;

        // `while let` rather than `for`: the iterator has to stay reachable
        // inside the body so `progress()` can be queried after each `next()`.
        while let Some(item) = iterator.next() {
            let progress = iterator.progress();
            iteration += 1;

            let output = self.model.step(item);
            self.model.update_params(&output.grads, &mut self.optim);

            let learner_item =
                LearnerItem::new(output.item, progress, epoch, self.num_epochs, iteration);
            self.callback.on_train_item(learner_item);
        }

        self.callback.on_train_end_epoch(epoch);
    }

    /// Runs one validation epoch over `dataloader_valid`.
    ///
    /// The steps are executed on the module returned by `self.model.inner()`
    /// (obtained once, before the loop), not on `self.model` directly. For
    /// every item yielded by the data loader the output of
    /// [`ValidStep::step`] is reported to `self.callback.on_valid_item`. Once
    /// the loader is exhausted, `self.callback.on_valid_end_epoch(epoch)` is
    /// called.
    ///
    /// `iteration` reported to the callback is 1-based within the epoch.
    fn valid_step<VI>(&mut self, dataloader_valid: &Arc<dyn DataLoader<VI>>, epoch: usize)
    where
        M::InnerModule: ValidStep<VI, VO>,
    {
        log::info!("Executing validation step for epoch {}", epoch);

        let model = self.model.inner();

        let mut iterator = dataloader_valid.iter();
        let mut iteration = 0;

        // Same reason as in `train_step` for keeping `while let`.
        while let Some(item) = iterator.next() {
            let progress = iterator.progress();
            iteration += 1;

            let output = model.step(item);

            let learner_item =
                LearnerItem::new(output, progress, epoch, self.num_epochs, iteration);
            self.callback.on_valid_item(learner_item);
        }

        self.callback.on_valid_end_epoch(epoch);
    }
}