use crate::components::{
InputTrain, InputValid, LearnerComponentTypes, TrainBackend, ValidBackend,
};
#[cfg(feature = "ddp")]
use crate::ddp::DdpLearningStrategy;
use crate::multi::MultiDeviceLearningStrategy;
use crate::renderer::MetricsRenderer;
use crate::single::SingleDeviceLearningStrategy;
use crate::{Learner, LearningMethod, LearningStrategy};
use burn_core::data::dataloader::DataLoader;
use burn_core::module::AutodiffModule;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::{GradientsParams, Optimizer};
use std::sync::Arc;
/// Result of a single training step: parameter gradients paired with a
/// caller-defined output item.
///
/// `TO` is the type of the item carried next to the gradients; this type
/// places no bounds on it.
pub struct TrainOutput<TO> {
/// Gradients stored per parameter, in the form accepted by
/// [`TrainStep::optimize`].
pub grads: GradientsParams,
/// The item produced by the step, handed back to the caller unchanged.
pub item: TO,
}
impl<TO> TrainOutput<TO> {
    /// Builds a [`TrainOutput`] from raw backend gradients.
    ///
    /// # Arguments
    ///
    /// * `module` - The module the gradients are associated with; it is only
    ///   borrowed for the conversion.
    /// * `grads` - The backend-specific gradients container (`B::Gradients`).
    /// * `item` - The item to store alongside the converted gradients.
    ///
    /// # Returns
    ///
    /// A [`TrainOutput`] whose `grads` field is the result of
    /// [`GradientsParams::from_grads`] applied to `grads` and `module`.
    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
        module: &M,
        grads: B::Gradients,
        item: TO,
    ) -> Self {
        Self {
            grads: GradientsParams::from_grads(grads, module),
            item,
        }
    }
}
/// A training step over input items of type `TI` that yields output items of
/// type `TO` together with gradients.
pub trait TrainStep<TI, TO> {
    /// Runs one training step on `item`.
    ///
    /// # Returns
    ///
    /// A [`TrainOutput`] holding the gradients and the output item produced by
    /// the implementation.
    fn step(&self, item: TI) -> TrainOutput<TO>;

    /// Applies `grads` to this module through `optim` and returns the module
    /// handed back by the optimizer.
    ///
    /// The default implementation simply forwards to the optimizer's own
    /// `step`; implementors may override it.
    ///
    /// # Arguments
    ///
    /// * `optim` - The optimizer performing the update.
    /// * `lr` - The learning rate passed to the optimizer.
    /// * `grads` - The gradients passed to the optimizer.
    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
    where
        B: AutodiffBackend,
        O: Optimizer<Self, B>,
        Self: AutodiffModule<B>,
    {
        // Fully-qualified call: makes it explicit that this is the optimizer's
        // `step`, not this trait's `step`.
        <O as Optimizer<Self, B>>::step(optim, lr, self, grads)
    }
}
/// A validation step over input items of type `VI` that yields output items
/// of type `VO`.
///
/// Unlike [`TrainStep`], the step returns the output directly rather than a
/// [`TrainOutput`] with gradients.
pub trait ValidStep<VI, VO> {
/// Runs one validation step on `item` and returns the resulting output.
fn step(&self, item: VI) -> VO;
}
/// Reference-counted, dynamically dispatched dataloader for training,
/// parameterized by the learner components' train backend and train input.
pub(crate) type TrainLoader<LC> = Arc<dyn DataLoader<TrainBackend<LC>, InputTrain<LC>>>;
/// Reference-counted, dynamically dispatched dataloader for validation,
/// parameterized by the learner components' valid backend and valid input.
pub(crate) type ValidLoader<LC> = Arc<dyn DataLoader<ValidBackend<LC>, InputValid<LC>>>;
/// Value returned by [`Learner::fit`]: a model together with a metrics
/// renderer.
pub struct TrainingResult<M> {
/// The model returned by the fit.
pub model: M,
/// The boxed metrics renderer returned alongside the model.
pub renderer: Box<dyn MetricsRenderer>,
}
impl<LC: LearnerComponentTypes + Send + 'static> Learner<LC> {
    /// Fits the model using the learner's configured learning strategy.
    ///
    /// The model is logged at `info` level, then the learner is consumed and
    /// handed, together with both dataloaders, to the strategy implementation
    /// matching `self.learning_strategy`:
    ///
    /// * `SingleDevice` -> [`SingleDeviceLearningStrategy`]
    /// * `MultiDeviceNaive` -> [`MultiDeviceLearningStrategy`]
    /// * `DistributedDataParallel` -> `DdpLearningStrategy` (only compiled
    ///   with the `ddp` feature)
    ///
    /// # Arguments
    ///
    /// * `dataloader_train` - The dataloader used for training.
    /// * `dataloader_valid` - The dataloader used for validation.
    ///
    /// # Returns
    ///
    /// The [`TrainingResult`] produced by the selected strategy.
    pub fn fit(
        self,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> TrainingResult<LC::InnerModel> {
        log::info!("Fitting the model:\n {}", self.model);

        // The strategy's settings are cloned out of `self` first, so the
        // learner itself can then be moved into the strategy's `fit`.
        match &self.learning_strategy {
            LearningStrategy::SingleDevice(device) => {
                SingleDeviceLearningStrategy::new(device.clone()).fit(
                    self,
                    dataloader_train,
                    dataloader_valid,
                )
            }
            LearningStrategy::MultiDeviceNaive(devices) => {
                MultiDeviceLearningStrategy::new(devices.clone()).fit(
                    self,
                    dataloader_train,
                    dataloader_valid,
                )
            }
            #[cfg(feature = "ddp")]
            LearningStrategy::DistributedDataParallel { devices, config } => {
                DdpLearningStrategy::new(devices.clone(), config.clone()).fit(
                    self,
                    dataloader_train,
                    dataloader_valid,
                )
            }
        }
    }
}