use std::sync::Arc;
#[cfg(feature = "ddp")]
use burn_core::tensor::backend::distributed::{DistributedBackend, DistributedConfig};
use burn_core::{module::AutodiffModule, prelude::Backend};
use crate::{
EarlyStoppingStrategyRef, InferenceModel, Interrupter, Learner, LearnerSummaryConfig,
LearningCheckpointer, LearningResult, SupervisedTrainingEventProcessor, TrainLoader,
TrainingModel, ValidLoader,
components::LearningComponentsTypes,
metric::{
processor::{EventProcessorTraining, LearnerEvent},
store::EventStoreClient,
},
};
/// Shared, reference-counted handle to a user-provided [`SupervisedLearningStrategy`].
pub type CustomLearningStrategy<LC> = Arc<dyn SupervisedLearningStrategy<LC>>;
/// Selects how the optimizer is handled when training across multiple devices.
///
/// NOTE(review): the actual placement/sharding behavior is implemented by the
/// training loop that consumes this flag, not visible in this file.
#[derive(Clone, Copy, Debug)]
pub enum MultiDeviceOptim {
    /// Keep the optimizer on the main (first) device.
    OptimMainDevice,
    /// Shard the optimizer across the participating devices.
    OptimSharded,
}
/// Describes how training is executed across hardware devices.
pub enum ExecutionStrategy<B: Backend> {
    /// Run everything on a single device.
    SingleDevice(B::Device),
    /// Run across several devices; the flag selects optimizer placement
    /// (see [`MultiDeviceOptim`]).
    MultiDevice(Vec<B::Device>, MultiDeviceOptim),
    /// Distributed data-parallel training driven by a [`DistributedRuntime`].
    #[cfg(feature = "ddp")]
    DistributedDataParallel {
        /// Devices participating in this process.
        devices: Vec<B::Device>,
        /// Runtime used to start/close the distributed communication server.
        runtime: Box<dyn DistributedRuntime>,
    },
}
impl<B: Backend> ExecutionStrategy<B> {
    /// Returns the main device of this strategy.
    ///
    /// For the multi-device and distributed variants this is the first
    /// configured device.
    ///
    /// # Panics
    ///
    /// Panics if a multi-device or distributed strategy was constructed with
    /// an empty device list.
    pub fn main_device(&self) -> &B::Device {
        match self {
            ExecutionStrategy::SingleDevice(device) => device,
            // `.first().expect(...)` instead of `&devices[0]` so an empty
            // device list fails with a descriptive message rather than a raw
            // index-out-of-bounds panic.
            ExecutionStrategy::MultiDevice(devices, _optim) => devices
                .first()
                .expect("multi-device strategy requires at least one device"),
            #[cfg(feature = "ddp")]
            ExecutionStrategy::DistributedDataParallel { devices, .. } => devices
                .first()
                .expect("distributed strategy requires at least one device"),
        }
    }

    /// Creates a single-device execution strategy.
    pub fn single(device: B::Device) -> Self {
        Self::SingleDevice(device)
    }

    /// Creates a multi-device execution strategy with the given optimizer
    /// placement.
    pub fn multi(devices: Vec<B::Device>, optim: MultiDeviceOptim) -> Self {
        Self::MultiDevice(devices, optim)
    }
}
#[cfg(feature = "ddp")]
impl<B: DistributedBackend> ExecutionStrategy<B> {
    /// Creates a distributed-data-parallel execution strategy.
    ///
    /// Packs the devices and config into a [`DistributedSession`], which later
    /// starts/closes the backend's communication server via the
    /// [`DistributedRuntime`] trait.
    pub fn ddp(devices: Vec<B::Device>, config: DistributedConfig) -> Self {
        let session = DistributedSession::<B> {
            // Cloned because `devices` is also moved into the variant below.
            devices: devices.clone(),
            config,
        };
        Self::DistributedDataParallel {
            devices,
            runtime: Box::new(session),
        }
    }
}
/// Strategy used to drive the supervised training loop.
pub enum TrainingStrategy<LC: LearningComponentsTypes> {
    /// Built-in training loop, parameterized by an [`ExecutionStrategy`].
    Default(ExecutionStrategy<LC::Backend>),
    /// User-supplied training loop implementation.
    Custom(CustomLearningStrategy<LC>),
}
impl<LC: LearningComponentsTypes> From<ExecutionStrategy<LC::Backend>> for TrainingStrategy<LC> {
fn from(value: ExecutionStrategy<LC::Backend>) -> Self {
Self::Default(value)
}
}
/// Lifecycle hooks for a distributed communication runtime.
#[cfg(feature = "ddp")]
pub trait DistributedRuntime: Send + Sync + 'static {
    /// Starts the runtime (e.g. the backend's communication server).
    fn start(&self);
    /// Shuts the runtime down.
    fn close(&self);
}
/// State of a distributed training session: the participating devices and the
/// backend's distributed configuration.
#[cfg(feature = "ddp")]
pub struct DistributedSession<B: DistributedBackend> {
    // Devices participating in this process.
    devices: Vec<B::Device>,
    // Cloned and handed to the backend when the communication server starts.
    config: DistributedConfig,
}
#[cfg(feature = "ddp")]
impl<B: DistributedBackend> DistributedRuntime for DistributedSession<B> {
    /// Starts the backend's communication server on all session devices.
    fn start(&self) {
        B::start_communication_server(&self.devices, self.config.clone());
    }

    /// Closes the backend's communication server.
    ///
    /// NOTE(review): only the first device is passed here — presumably the
    /// backend tears the whole server down from it; confirm against the
    /// `DistributedBackend` API.
    fn close(&self) {
        B::close_communication_server(&self.devices[0]);
    }
}
impl<LC: LearningComponentsTypes> Default for TrainingStrategy<LC> {
fn default() -> Self {
Self::Default(ExecutionStrategy::SingleDevice(Default::default()))
}
}
/// Bundle of everything the training loop needs besides the model and data.
pub struct TrainingComponents<LC: LearningComponentsTypes> {
    /// Total number of epochs to train for.
    pub num_epochs: usize,
    /// Checkpoint (epoch) to resume from, if any.
    pub checkpoint: Option<usize>,
    /// Saves/restores learner state; optional.
    pub checkpointer: Option<LearningCheckpointer<LC>>,
    /// Number of steps to accumulate gradients over, if enabled.
    pub grad_accumulation: Option<usize>,
    /// Cooperative cancellation handle for the training run.
    pub interrupter: Interrupter,
    /// Optional early-stopping policy.
    pub early_stopping: Option<EarlyStoppingStrategyRef>,
    /// Receives training/validation events (metrics, progress, lifecycle).
    pub event_processor: SupervisedTrainingEventProcessor<LC>,
    /// Shared store for recorded events.
    pub event_store: Arc<EventStoreClient>,
    /// Configuration for the end-of-training summary, if one should be built.
    pub summary: Option<LearnerSummaryConfig>,
}
/// A supervised training strategy.
///
/// The provided [`Self::train`] method handles the shared orchestration
/// (checkpoint restore, start/end event signaling, summary construction);
/// implementors supply the actual epoch loop in [`Self::fit`].
pub trait SupervisedLearningStrategy<LC: LearningComponentsTypes> {
    /// Runs the full training session and returns the trained model in its
    /// inference form, along with the renderer from the event processor.
    fn train(
        &self,
        mut learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        mut training_components: TrainingComponents<LC>,
    ) -> LearningResult<InferenceModel<LC>> {
        // Resume from the configured checkpoint (training continues at the
        // following epoch); otherwise start at epoch 1.
        let starting_epoch = match training_components.checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut training_components.checkpointer {
                    // NOTE(review): `&Default::default()` — restores onto a
                    // default-constructed target (device?); confirm this is
                    // correct for multi-device strategies.
                    learner =
                        checkpointer.load_checkpoint(learner, &Default::default(), checkpoint);
                }
                checkpoint + 1
            }
            None => 1,
        };
        // Clone the summary config now: `training_components` is moved into
        // `fit` below.
        let summary_config = training_components.summary.clone();
        training_components
            .event_processor
            .process_train(LearnerEvent::Start);
        let (model, mut event_processor) = self.fit(
            training_components,
            learner,
            dataloader_train,
            dataloader_valid,
            starting_epoch,
        );
        // Best-effort summary: initialization failures are discarded (`.ok()`).
        let summary = summary_config.and_then(|summary| {
            summary
                .init()
                .map(|summary| summary.with_model(model.to_string()))
                .ok()
        });
        event_processor.process_train(LearnerEvent::End(summary));
        // Convert the autodiff training model into its inference counterpart.
        let model = model.valid();
        let renderer = event_processor.renderer();
        LearningResult::<InferenceModel<LC>> { model, renderer }
    }

    /// Executes the training loop proper, starting at `starting_epoch`, and
    /// returns the trained model together with the event processor.
    fn fit(
        &self,
        training_components: TrainingComponents<LC>,
        learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        starting_epoch: usize,
    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>);
}