use core::panic;
use std::sync::{Arc, Mutex};
use burn_collective::CollectiveConfig;
use burn_core::tensor::Device;
use crate::ddp::worker::DdpWorker;
use crate::metric::store::EventStoreClient;
use crate::{
EarlyStoppingStrategyRef, Interrupter, Learner, LearningComponentsTypes,
SupervisedLearningStrategy, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend,
TrainingComponents, TrainingModel, ValidLoader,
};
use burn_core::data::dataloader::split::split_dataloader;
/// Configuration shared by every DDP worker thread.
///
/// Built once in `DdpTrainingStrategy::fit` and cloned into each spawned
/// [`DdpWorker`].
#[derive(Clone)]
pub(crate) struct WorkerComponents {
    /// Total number of epochs to train for.
    pub num_epochs: usize,
    /// Gradient accumulation setting; `None` disables it. Presumably the
    /// number of accumulation steps — confirm against `DdpWorker`.
    pub grad_accumulation: Option<usize>,
    /// Cooperative cancellation handle; `fit` checks it after all workers
    /// have joined to log the interruption reason.
    pub interrupter: Interrupter,
    /// Optional early-stopping strategy shared across workers.
    pub early_stopping: Option<EarlyStoppingStrategyRef>,
    /// Client for recording training events, shared across workers.
    pub event_store: Arc<EventStoreClient>,
}
/// Supervised learning strategy that trains using distributed data
/// parallelism: one worker per device, coordinated via `burn_collective`
/// (see [`CollectiveConfig`]).
pub struct DdpTrainingStrategy<LC: LearningComponentsTypes> {
    /// Devices to train on; the first entry acts as the main device
    /// (checkpointing + validation).
    devices: Vec<Device<TrainingBackend<LC>>>,
    /// Collective communication configuration; its device count is set from
    /// `devices.len()` in [`Self::new`].
    config: CollectiveConfig,
}
impl<LC: LearningComponentsTypes> DdpTrainingStrategy<LC> {
    /// Creates a new DDP training strategy over the given devices.
    ///
    /// The collective `config` is updated so that its device count matches
    /// `devices.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `devices` is empty — DDP training requires at least one
    /// device, and failing here gives a clearer message than the later
    /// main-device lookup in `fit` would.
    pub fn new(devices: Vec<Device<TrainingBackend<LC>>>, config: CollectiveConfig) -> Self {
        assert!(
            !devices.is_empty(),
            "DDP training requires at least one device"
        );
        let config = config.with_num_devices(devices.len());
        Self { devices, config }
    }
}
impl<LC: LearningComponentsTypes + Send + 'static> SupervisedLearningStrategy<LC>
    for DdpTrainingStrategy<LC>
{
    /// Runs distributed data-parallel training.
    ///
    /// Spawns one [`DdpWorker`] per configured device. The first device hosts
    /// the main worker (peer 0), which receives the checkpointer and the
    /// validation dataloader; every other worker only trains on its shard of
    /// the training data. Blocks until all workers have joined, then returns
    /// the trained model together with the reclaimed event processor.
    ///
    /// # Panics
    ///
    /// Panics if no devices were configured, if any worker thread panics, or
    /// if the shared event processor is still referenced / poisoned after all
    /// workers have joined (which would indicate a worker-lifecycle bug).
    fn fit(
        &self,
        training_components: TrainingComponents<LC>,
        learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        starting_epoch: usize,
    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {
        // Resolve the main device exactly once (the previous code looked it
        // up twice: a borrow via `first().unwrap()` and a clone of `[0]`).
        let main_device = self
            .devices
            .first()
            .expect("DDP training requires at least one device")
            .clone();
        let peer_count = self.devices.len();

        // One training shard per device; validation runs on the main device.
        let mut dataloaders_train = split_dataloader(dataloader_train, &self.devices);
        let dataloader_valid = dataloader_valid.to_device(&main_device);

        // Shared across all workers; exclusive ownership is reclaimed below
        // once every worker has joined.
        let event_processor = Arc::new(Mutex::new(training_components.event_processor));
        let interrupter = training_components.interrupter;
        let worker_components = WorkerComponents {
            num_epochs: training_components.num_epochs,
            grad_accumulation: training_components.grad_accumulation,
            interrupter: interrupter.clone(),
            early_stopping: training_components.early_stopping,
            event_store: training_components.event_store,
        };

        // Main worker (peer 0): the only one with a checkpointer and a
        // validation dataloader.
        let main_handle = DdpWorker::<LC>::start(
            0.into(),
            main_device,
            learner.clone(),
            event_processor.clone(),
            worker_components.clone(),
            training_components.checkpointer,
            dataloaders_train.remove(0),
            Some(dataloader_valid),
            self.config.clone(),
            starting_epoch,
            peer_count,
            true,
        );

        // Secondary workers (peers 1..peer_count): training only.
        let mut peer_id = 1;
        let mut secondary_workers = Vec::with_capacity(peer_count.saturating_sub(1));
        for device in &self.devices[1..] {
            secondary_workers.push(DdpWorker::<LC>::start(
                peer_id.into(),
                device.clone(),
                learner.clone(),
                event_processor.clone(),
                worker_components.clone(),
                None,
                dataloaders_train.remove(0),
                None,
                self.config.clone(),
                starting_epoch,
                peer_count,
                false,
            ));
            peer_id += 1;
        }

        // Block until training is complete; a panic in any worker aborts.
        for worker in secondary_workers {
            worker
                .join()
                .expect("Distributed data parallel worker failed");
        }
        let model = main_handle
            .join()
            .expect("Distributed data parallel main worker failed");

        if interrupter.should_stop() {
            // `unwrap_or_else` avoids allocating the fallback string on the
            // happy path (clippy: or_fun_call).
            let reason = interrupter
                .get_message()
                .unwrap_or_else(|| String::from("Reason unknown"));
            log::info!("Training interrupted: {reason}");
        }

        // All worker clones of the Arc have been dropped at this point, so we
        // can take back exclusive ownership of the event processor.
        let Ok(event_processor) = Arc::try_unwrap(event_processor) else {
            panic!("Event processor still held!");
        };
        let Ok(event_processor) = event_processor.into_inner() else {
            panic!("Event processor lock poisoned");
        };
        (model, event_processor)
    }
}