use std::sync::Arc;
#[cfg(feature = "ddp")]
use burn_collective::CollectiveConfig;
use burn_core::{module::AutodiffModule, prelude::Backend};
use crate::{
EarlyStoppingStrategyRef, InferenceModel, Interrupter, Learner, LearnerSummaryConfig,
LearningCheckpointer, LearningResult, SupervisedTrainingEventProcessor, TrainLoader,
TrainingModel, ValidLoader,
components::LearningComponentsTypes,
metric::{
processor::{EventProcessorTraining, LearnerEvent},
store::EventStoreClient,
},
};
/// Shorthand for the device type of the learning components' backend.
type LearnerDevice<LC> = <<LC as LearningComponentsTypes>::Backend as Backend>::Device;

/// Shared handle to a user-provided supervised learning strategy
/// (used by [`TrainingStrategy::Custom`]).
pub type CustomLearningStrategy<LC> = Arc<dyn SupervisedLearningStrategy<LC>>;
/// Optimizer placement policy for multi-device training.
///
/// NOTE(review): semantics inferred from variant names — the optimization
/// step itself happens elsewhere; confirm against the multi-device strategy
/// implementation.
#[derive(Clone, Copy, Debug)]
pub enum MultiDeviceOptim {
    /// Perform the optimization step on the main device.
    OptimMainDevice,
    /// Shard the optimization step across the participating devices.
    OptimSharded,
}
/// Selects how a supervised training run is distributed across devices.
#[derive(Clone)]
pub enum TrainingStrategy<LC: LearningComponentsTypes> {
    /// Train on a single device.
    SingleDevice(LearnerDevice<LC>),
    /// Train across several devices with the given optimizer placement policy.
    MultiDevice(Vec<LearnerDevice<LC>>, MultiDeviceOptim),
    /// Use a caller-supplied [`SupervisedLearningStrategy`] implementation.
    Custom(CustomLearningStrategy<LC>),
    /// Distributed data-parallel training (requires the `ddp` feature).
    #[cfg(feature = "ddp")]
    DistributedDataParallel {
        /// Devices participating in the distributed run.
        devices: Vec<LearnerDevice<LC>>,
        /// Collective-communication configuration for gradient synchronization.
        config: CollectiveConfig,
    },
}
/// Convenience constructor for [`TrainingStrategy::DistributedDataParallel`].
///
/// Pairs the participating `devices` with the collective-communication
/// `config` used to synchronize them.
#[cfg(feature = "ddp")]
pub fn ddp<LC: LearningComponentsTypes>(
    devices: Vec<LearnerDevice<LC>>,
    config: CollectiveConfig,
) -> TrainingStrategy<LC> {
    TrainingStrategy::DistributedDataParallel { config, devices }
}
impl<LC: LearningComponentsTypes> Default for TrainingStrategy<LC> {
fn default() -> Self {
Self::SingleDevice(Default::default())
}
}
/// Everything a [`SupervisedLearningStrategy`] needs to drive a training run,
/// aside from the learner and the data loaders themselves.
pub struct TrainingComponents<LC: LearningComponentsTypes> {
    /// Total number of epochs to train for.
    pub num_epochs: usize,
    /// Checkpoint index to resume from, if any; training restarts at
    /// `checkpoint + 1` (see the default `train` implementation).
    pub checkpoint: Option<usize>,
    /// Checkpointer used to restore (and presumably save) training state,
    /// when checkpointing is enabled.
    pub checkpointer: Option<LearningCheckpointer<LC>>,
    /// Number of steps over which to accumulate gradients, if enabled.
    /// NOTE(review): inferred from the name — confirm against the strategy
    /// implementations that consume it.
    pub grad_accumulation: Option<usize>,
    /// Handle used to interrupt training from outside the loop.
    pub interrupter: Interrupter,
    /// Optional early-stopping strategy.
    pub early_stopping: Option<EarlyStoppingStrategyRef>,
    /// Processor receiving training/validation events (start, end, metrics).
    pub event_processor: SupervisedTrainingEventProcessor<LC>,
    /// Client to the event store collecting metric events.
    pub event_store: Arc<EventStoreClient>,
    /// Configuration for the end-of-training summary, if one was requested.
    pub summary: Option<LearnerSummaryConfig>,
}
/// A strategy describing how a supervised training loop is executed.
///
/// Implementors provide [`fit`](SupervisedLearningStrategy::fit); the
/// provided [`train`](SupervisedLearningStrategy::train) wraps it with
/// checkpoint restoration, start/end event signaling and optional summary
/// generation.
pub trait SupervisedLearningStrategy<LC: LearningComponentsTypes> {
    /// Runs the full training procedure and returns the trained model,
    /// converted for inference, together with the renderer.
    fn train(
        &self,
        mut learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        mut training_components: TrainingComponents<LC>,
    ) -> LearningResult<InferenceModel<LC>> {
        // Restore from a checkpoint when one is configured; epoch numbering
        // is 1-based, so a fresh run starts at epoch 1.
        let starting_epoch = if let Some(checkpoint) = training_components.checkpoint {
            if let Some(checkpointer) = training_components.checkpointer.as_mut() {
                learner =
                    checkpointer.load_checkpoint(learner, &Default::default(), checkpoint);
            }
            checkpoint + 1
        } else {
            1
        };

        // Keep the summary configuration around: `training_components` is
        // consumed by `fit` below.
        let summary_config = training_components.summary.clone();

        training_components
            .event_processor
            .process_train(LearnerEvent::Start);

        let (model, mut event_processor) = self.fit(
            training_components,
            learner,
            dataloader_train,
            dataloader_valid,
            starting_epoch,
        );

        // Best effort: a summary that fails to initialize is silently dropped.
        let summary = match summary_config {
            Some(config) => config
                .init()
                .ok()
                .map(|summary| summary.with_model(model.to_string())),
            None => None,
        };
        event_processor.process_train(LearnerEvent::End(summary));

        let renderer = event_processor.renderer();
        LearningResult::<InferenceModel<LC>> {
            model: model.valid(),
            renderer,
        }
    }

    /// Executes the actual training loop, returning the trained model and the
    /// event processor so the caller can emit the final events.
    fn fit(
        &self,
        training_components: TrainingComponents<LC>,
        learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        starting_epoch: usize,
    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>);
}