use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use std::sync::Arc;
use crate::autograd::Variable;
use crate::data::BatchDataSet;
use crate::distributed::nccl;
use crate::nn::{Module, Optimizer, Parameter};
use crate::tensor::{Device, Result, Tensor, TensorError};
use super::{
ApplyPolicy, AverageBackend, CheckpointFn, DdpRunConfig, EpochFn,
TimingMsg, TrainedState, WorkerConfig,
};
use super::worker::GpuWorker;
use super::coordinator::Coordinator;
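/// Handle to a running (or already completed) distributed data-parallel
/// training job. It owns the worker and coordinator threads; call
/// [`DdpHandle::join`] to wait for completion and obtain the trained state.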
pub struct DdpHandle {
worker_handles: Vec<std::thread::JoinHandle<Result<()>>>,
coordinator_handle: Option<std::thread::JoinHandle<Result<TrainedState>>>,
devices: Vec<Device>,
shutdown: Arc<AtomicBool>,
nccl_abort_handles: Vec<Arc<nccl::NcclAbortHandle>>,
final_state: Option<TrainedState>,
metrics_rx: Option<mpsc::Receiver<super::EpochMetrics>>,
architecture_svg: Option<String>,
graph_label: Option<String>,
graph_hash: Option<String>,
training_meta: Option<serde_json::Value>,
}
impl DdpHandle {
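/// Launches DDP training across all usable CUDA devices with a default
/// [`DdpRunConfig`].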
#[allow(clippy::too_many_arguments)]
#[deprecated(since = "0.3.0", note = "Use Ddp::builder() instead")]
pub fn auto<F, M, G, O, T>(
model_factory: F,
optim_factory: G,
train_fn: T,
dataset: Arc<dyn BatchDataSet>,
batch_size: usize,
num_epochs: usize,
policy: ApplyPolicy,
backend: AverageBackend,
) -> Result<Self>
where
F: Fn(Device) -> Result<M> + Send + Sync + 'static,
M: Module + 'static,
G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
O: Optimizer + 'static,
T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
{
// Use a return statement so the allow attribute sits on a statement
// (attributes on trailing expressions are not stable Rust).
#[allow(deprecated)]
return Self::auto_with(
model_factory, optim_factory, train_fn,
dataset, batch_size, num_epochs,
policy, backend, DdpRunConfig::new(),
);
}
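/// Like [`DdpHandle::auto`], but with an explicit [`DdpRunConfig`].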
#[allow(clippy::too_many_arguments)]
#[deprecated(since = "0.3.0", note = "Use Ddp::builder() instead")]
pub fn auto_with<F, M, G, O, T>(
model_factory: F,
optim_factory: G,
train_fn: T,
dataset: Arc<dyn BatchDataSet>,
batch_size: usize,
num_epochs: usize,
policy: ApplyPolicy,
backend: AverageBackend,
config: DdpRunConfig,
) -> Result<Self>
where
F: Fn(Device) -> Result<M> + Send + Sync + 'static,
M: Module + 'static,
G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
O: Optimizer + 'static,
T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
{
Self::launch(
model_factory, optim_factory, train_fn,
dataset, batch_size, num_epochs,
policy, backend, config, None, None, None,
)
}
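/// Core launch path: spawns one worker thread per usable CUDA device plus a
/// coordinator thread that hands out epoch plans and drives gradient
/// averaging. Falls back to `run_single` when fewer than two devices are
/// available.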
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
fn launch<F, M, G, O, T>(
model_factory: F,
optim_factory: G,
train_fn: T,
dataset: Arc<dyn BatchDataSet>,
batch_size: usize,
num_epochs: usize,
policy: ApplyPolicy,
backend: AverageBackend,
config: DdpRunConfig,
checkpoint_fn: Option<CheckpointFn<M>>,
epoch_fn: Option<EpochFn<M>>,
scheduler_fn: Option<Box<dyn Fn(usize) -> Arc<dyn crate::nn::Scheduler> + Send + Sync>>,
) -> Result<Self>
where
F: Fn(Device) -> Result<M> + Send + Sync + 'static,
M: Module + 'static,
G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
O: Optimizer + 'static,
T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
{
let devices = crate::tensor::usable_cuda_devices();
if devices.len() < 2 {
let dev = devices.first().copied().unwrap_or(Device::CPU);
let scheduler = scheduler_fn.map(|f| f(1));
return Self::run_single(
&model_factory, &optim_factory, &train_fn,
dataset, batch_size, num_epochs, dev,
checkpoint_fn,
config.checkpoint_every,
epoch_fn,
config.max_grad_norm,
scheduler,
);
}
Self::print_summary(&devices, &policy, &backend);
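// Build the model once on the first device to capture the initial parameters
// and buffers (staged in pinned CPU memory for fast device uploads) along
// with optional architecture metadata for monitoring.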
let tmp_model = model_factory(devices[0])?;
let initial_params: Vec<Tensor> = tmp_model.parameters().iter()
.map(|p| p.variable.data().to_device(Device::CPU).and_then(|t| t.pin_memory()))
.collect::<Result<Vec<_>>>()?;
let initial_buffers: Vec<Tensor> = tmp_model.buffers().iter()
.map(|b| b.get().to_device(Device::CPU).and_then(|t| t.pin_memory()))
.collect::<Result<Vec<_>>>()?;
let graph_ref = tmp_model.as_graph();
let architecture_svg = graph_ref
.and_then(|g| g.svg(None).ok())
.map(|bytes| String::from_utf8_lossy(&bytes).into_owned());
let graph_label = graph_ref.and_then(|g| g.label().map(|s| s.to_string()));
let graph_hash = graph_ref.map(|g| g.structural_hash().to_string());
drop(tmp_model);
let world_size = devices.len();
let total_samples = dataset.len();
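// Linear LR scaling: with ratio r and world size W the factor is
// 1 + (W - 1) * r, i.e. 1x at r = 0 and a full W x at r = 1.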
let lr_scale_factor = if world_size > 1 && config.lr_scale_ratio > 0.0 {
let factor = 1.0 + (world_size as f64 - 1.0) * config.lr_scale_ratio;
crate::verbose!(
" ddp: LR scaled by {factor:.2}x (ratio={:.2}, world_size={world_size}). \
Adjust with .lr_scale_ratio()",
config.lr_scale_ratio,
);
factor
} else {
1.0
};
let progressive = config.progressive_dispatch
.unwrap_or(!matches!(policy, ApplyPolicy::Sync));
let training_meta = Some(Self::build_training_meta(
&devices, &policy, &backend, batch_size, num_epochs,
total_samples, progressive, &config,
));
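// Channel topology: workers push timing, metrics, and parameter messages to
// the coordinator; the coordinator pushes control messages (epoch plans)
// back, and each worker gets a dedicated channel for its final snapshot.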
let (timing_tx_main, timing_rx) = mpsc::channel();
let (metrics_tx_main, metrics_rx) = mpsc::channel();
let (param_tx_main, param_rx) = mpsc::channel();
let mut coord_control_txs = Vec::new();
let mut worker_control_rxs = Vec::new();
let mut worker_final_txs = Vec::new();
let mut coord_final_rxs = Vec::new();
for _ in 0..world_size {
let (tx, rx) = mpsc::channel();
coord_control_txs.push(tx);
worker_control_rxs.push(rx);
let (ftx, frx) = mpsc::channel();
worker_final_txs.push(ftx);
coord_final_rxs.push(frx);
}
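// For the NCCL backend, create every communicator up front and keep the
// abort handles so a failing rank can unblock collectives on the others.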
let (mut rank_comms, nccl_abort_handles): (Vec<Option<_>>, Vec<_>) =
if backend == AverageBackend::Nccl {
let group = nccl::NcclComms::new(&devices)?;
let comms = group.split()?;
let aborts = comms.iter().map(|c| c.abort_handle()).collect();
(comms.into_iter().map(Some).collect(), aborts)
} else {
((0..world_size).map(|_| None).collect(), Vec::new())
};
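// ElChe (the adaptive work partitioner) takes its tuning knobs from the
// run config.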
let anchor = config.anchor.unwrap_or(10);
let mut el_che = crate::distributed::ddp::ElChe::new(world_size, anchor);
if let Some(target) = config.overhead_target {
el_che = el_che.with_overhead_target(target);
}
if let Some(max) = config.max_anchor {
el_che = el_che.with_max_anchor(max);
}
if let Some(diff) = config.max_batch_diff {
el_che = el_che.with_max_batch_diff(diff);
}
let (epoch_metrics_tx, epoch_metrics_rx) = mpsc::channel();
let coord_device_indices: Vec<u8> = devices.iter().map(|d| match d {
Device::CUDA(idx) => *idx,
_ => 0,
}).collect();
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_coord = shutdown.clone();
let div_threshold = config.divergence_threshold;
let no_div_guard = config.no_divergence_guard;
let ckpt_every = config.checkpoint_every;
let snap_timeout = config.snapshot_timeout_secs;
let partition_ratios = config.partition_ratios.clone();
let max_grad_norm = config.max_grad_norm;
let timeline = config.timeline.clone();
let coord_timeline = timeline.clone();
let coord_batch_size = batch_size;
let seed: u64 = 42; // fixed seed, shared by every rank
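// The coordinator thread owns the control loop: it dispatches epoch plans,
// drains worker telemetry, and triggers averaging rounds until training
// completes or a worker fails.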
let coordinator_handle = std::thread::Builder::new()
.name("ddp-coordinator".into())
.spawn(move || -> Result<TrainedState> {
let mut builder = Coordinator::builder(
timing_rx, metrics_rx, param_rx,
coord_final_rxs,
coord_control_txs,
policy, backend,
world_size, total_samples, el_che,
)
.snapshot_timeout_secs(snap_timeout)
.epoch_metrics_tx(epoch_metrics_tx)
.device_indices(coord_device_indices)
.num_epochs(num_epochs)
.partition_ratios(partition_ratios)
.progressive(progressive)
.batch_size(coord_batch_size)
.timeline(coord_timeline.clone())
.max_overshoot(config.max_overshoot);
if let Some(dt) = div_threshold {
builder = builder.divergence_threshold(dt);
}
if no_div_guard {
builder = builder.no_divergence_guard();
}
if let Some(n) = ckpt_every {
builder = builder.checkpoint_every(n);
}
let mut coord = builder.build();
coord.send_all_plans(0);
let poll_timeout = std::time::Duration::from_micros(100);
let mut loop_tick: u64 = 0;
let mut last_state_dump = std::time::Instant::now();
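// Exit conditions: shutdown flag, all timing channels disconnected, all
// workers exited, or every epoch done; averaging errors break with Some(e).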
let loop_err = loop {
loop_tick += 1;
if shutdown_coord.load(Ordering::Relaxed) {
crate::verbose!(" ddp: coordinator exit: shutdown flag set (worker error?)");
break None;
}
if !coord.drain_timing_blocking(poll_timeout) {
crate::verbose!(" ddp: coordinator exit: all timing channels disconnected");
break None;
}
if coord.active_count == 0 {
crate::verbose!(" ddp: coordinator exit: all workers exited");
break None;
}
if coord.all_epochs_done() {
break None;
}
if last_state_dump.elapsed().as_secs() >= 2 {
last_state_dump = std::time::Instant::now();
coord.debug_state_dump(loop_tick);
}
coord.check_throttle();
if let Err(e) = coord.poll_cpu_averaging() {
shutdown_coord.store(true, Ordering::Relaxed);
break Some(e);
}
for m in coord.drain_metrics() {
crate::verbose!(
" ddp: rank {} epoch {} | loss={:.4} batches={} time={:.0}ms",
m.rank, m.epoch, m.avg_loss, m.batches_processed, m.epoch_ms
);
}
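// Re-check after draining queued timing messages so a decision based on
// stale counts does not trigger a spurious averaging round.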
if coord.should_average() {
coord.drain_timing();
if coord.should_average() {
if let Err(e) = coord.trigger_averaging() {
shutdown_coord.store(true, Ordering::Relaxed);
break Some(e);
}
}
}
};
coord.drain_avg_state();
coord.shutdown_workers();
match coord.collect_final_state() {
Some(state) => Ok(state),
None => match loop_err {
Some(e) => Err(e),
None => Err(TensorError::new(
"coordinator: no final snapshots received from workers"
)),
},
}
})
.map_err(|e| TensorError::new(&format!("failed to spawn coordinator: {e}")))?;
let scheduler = scheduler_fn.map(|f| f(world_size));
let model_factory = Arc::new(model_factory);
let optim_factory = Arc::new(optim_factory);
let train_fn = Arc::new(train_fn);
let mut worker_handles = Vec::new();
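// Spawn one worker thread per rank, each pinned to its own CUDA device.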
for (rank, control_rx) in worker_control_rxs.into_iter().enumerate() {
let device = devices[rank];
let mf = model_factory.clone();
let of = optim_factory.clone();
let tf = train_fn.clone();
let ds = dataset.clone();
let params = initial_params.clone();
let buffers = initial_buffers.clone();
let t_tx = timing_tx_main.clone();
let t_tx_err = timing_tx_main.clone();
let m_tx = metrics_tx_main.clone();
let p_tx = param_tx_main.clone();
let fp_tx = worker_final_txs.remove(0); // this rank's final-snapshot sender
let ckpt_fn = checkpoint_fn.clone();
let epoch_fn_w = epoch_fn.clone();
let scheduler_w = scheduler.clone();
let shutdown_w = shutdown.clone();
let worker_nccl = rank_comms[rank].take();
let worker_tl = timeline.clone();
let lr_scale = lr_scale_factor;
let config = WorkerConfig {
rank,
world_size,
device,
initial_params: params,
initial_buffers: buffers,
total_samples,
batch_size,
seed,
max_grad_norm,
timeline: worker_tl,
policy,
};
let handle = std::thread::Builder::new()
.name(format!("ddp-gpu-{rank}"))
.spawn(move || {
if let Device::CUDA(idx) = device {
crate::tensor::set_current_cuda_device(idx);
}
let result = (|| -> Result<()> {
let mut worker = GpuWorker::new(
&config,
|dev| (*mf)(dev),
|params| (*of)(params),
ds,
worker_nccl,
ckpt_fn,
t_tx,
m_tx,
p_tx,
fp_tx,
control_rx,
)?;
if lr_scale > 1.0 {
if scheduler_w.is_some() {
worker.set_lr_scale(lr_scale);
} else {
worker.scale_lr(lr_scale);
}
}
if let Some(ref sched) = scheduler_w {
worker.set_scheduler(Arc::clone(sched));
}
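// usize::MAX is a sentinel: the first plan's epoch always differs from it,
// so the epoch callback also fires for epoch 0.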
worker.current_epoch = usize::MAX;
loop {
if shutdown_w.load(Ordering::Relaxed) {
break;
}
let plan = match worker.wait_for_epoch_plan()? {
Some(p) => p,
None => break,
};
if plan.epoch != worker.current_epoch {
worker.current_epoch = plan.epoch;
if let Some(ref f) = epoch_fn_w {
f(plan.epoch, &mut worker);
}
}
if worker.run_epoch_plan(&plan, &*tf)? {
break;
}
}
worker.abort_nccl();
worker.send_final_snapshot();
worker.report_exiting();
worker.drain_until_shutdown();
Ok(())
})();
if let Err(ref e) = result {
eprintln!(" ddp: worker {rank} error: {e}");
let _ = t_tx_err.send(TimingMsg::Exiting { rank });
shutdown_w.store(true, Ordering::Relaxed);
}
result
})
.map_err(|e| TensorError::new(&format!("failed to spawn worker {rank}: {e}")))?;
worker_handles.push(handle);
}
drop(timing_tx_main);
drop(metrics_tx_main);
drop(param_tx_main);
Ok(DdpHandle {
worker_handles,
coordinator_handle: Some(coordinator_handle),
devices: devices.to_vec(),
shutdown,
nccl_abort_handles,
final_state: None,
metrics_rx: Some(epoch_metrics_rx),
architecture_svg,
graph_label,
graph_hash,
training_meta,
})
}
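/// Single-device fallback: trains synchronously on the calling thread with
/// no coordinator and returns a handle whose final state is already
/// populated.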
#[allow(clippy::too_many_arguments)]
fn run_single<F, M, G, O, T>(
model_factory: &F,
optim_factory: &G,
train_fn: &T,
dataset: Arc<dyn BatchDataSet>,
batch_size: usize,
num_epochs: usize,
device: Device,
checkpoint_fn: Option<CheckpointFn<M>>,
checkpoint_every: Option<usize>,
epoch_fn: Option<EpochFn<M>>,
max_grad_norm: Option<f64>,
scheduler: Option<Arc<dyn crate::nn::Scheduler>>,
) -> Result<Self>
where
F: Fn(Device) -> Result<M>,
M: Module + 'static,
G: Fn(&[Parameter]) -> O,
O: Optimizer + 'static,
T: Fn(&M, &[Tensor]) -> Result<Variable>,
{
crate::verbose!(" ddp: single device ({device:?}) | no coordination");
let total_samples = dataset.len();
let tmp_model = model_factory(device)?;
let initial_params: Vec<Tensor> = tmp_model.parameters().iter()
.map(|p| p.variable.data())
.collect();
let initial_buffers: Vec<Tensor> = tmp_model.buffers().iter()
.map(|b| b.get())
.collect();
let graph_ref = tmp_model.as_graph();
let architecture_svg = graph_ref
.and_then(|g| g.svg(None).ok())
.map(|bytes| String::from_utf8_lossy(&bytes).into_owned());
let graph_label = graph_ref.and_then(|g| g.label().map(|s| s.to_string()));
let graph_hash = graph_ref.map(|g| g.structural_hash().to_string());
drop(tmp_model);
let training_meta = Some(serde_json::json!({
"gpus": 1,
"device": format!("{device:?}"),
"batch_size": batch_size,
"num_epochs": num_epochs,
"total_samples": total_samples,
"mode": "single-gpu fallback",
}));
let config = WorkerConfig {
rank: 0,
world_size: 1,
device,
initial_params,
initial_buffers,
total_samples,
batch_size,
seed: 42,
max_grad_norm,
timeline: None,
policy: ApplyPolicy::Sync,
};
let ((timing_tx, metrics_tx, param_tx, final_param_tx, control_rx), _channels) =
GpuWorker::<M>::channels();
let mut worker = GpuWorker::new(
&config,
model_factory,
optim_factory,
dataset,
None, // no NCCL communicator on a single device
checkpoint_fn.clone(),
timing_tx,
metrics_tx,
param_tx,
final_param_tx,
control_rx,
)?;
if let Some(sched) = scheduler {
worker.set_scheduler(sched);
}
for epoch in 0..num_epochs {
worker.current_epoch = epoch;
if let Some(ref f) = epoch_fn {
f(epoch, &mut worker);
}
let plan = super::EpochPlan {
epoch,
partition_offset: 0,
partition_size: total_samples,
};
worker.run_epoch_plan(&plan, train_fn)?;
if let (Some(every), Some(f)) = (checkpoint_every, &checkpoint_fn) {
if every > 0 && (epoch + 1) % every == 0 {
if let Err(e) = f((epoch + 1) as u64, worker.model()) {
eprintln!(" ddp: checkpoint failed (epoch {}): {e}", epoch + 1);
}
}
}
}
let snap = worker.snapshot_params();
let final_state = TrainedState {
params: snap.params.iter()
.map(|t| t.to_device(Device::CPU))
.collect::<Result<Vec<_>>>()?,
buffers: snap.buffers.iter()
.map(|t| t.to_device(Device::CPU))
.collect::<Result<Vec<_>>>()?,
};
Ok(DdpHandle {
worker_handles: Vec::new(),
coordinator_handle: None,
devices: vec![device],
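// The run already finished synchronously, so the handle starts out shut down.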
shutdown: Arc::new(AtomicBool::new(true)),
nccl_abort_handles: Vec::new(),
final_state: Some(final_state),
metrics_rx: None,
architecture_svg,
graph_label,
graph_hash,
training_meta,
})
}
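/// Number of devices participating in this run.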
pub fn world_size(&self) -> usize {
self.devices.len()
}
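/// The devices this run was launched on.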
pub fn devices(&self) -> &[Device] {
&self.devices
}
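/// SVG rendering of the model architecture, if the model exposes a graph.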
pub fn architecture_svg(&self) -> Option<&str> {
self.architecture_svg.as_deref()
}
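/// Wires the architecture SVG, graph identity, and run metadata into a
/// monitor.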
pub fn setup_monitor(&self, monitor: &mut crate::monitor::Monitor) {
if let Some(svg) = &self.architecture_svg {
monitor.set_svg(svg);
}
monitor.set_identity(
self.graph_label.as_deref(),
self.graph_hash.as_deref(),
);
if let Some(meta) = &self.training_meta {
monitor.set_metadata(meta.clone());
}
}
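/// Drains every epoch metric received so far without blocking.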
pub fn poll_metrics(&self) -> Vec<super::EpochMetrics> {
match &self.metrics_rx {
Some(rx) => {
let mut out = Vec::new();
while let Ok(m) = rx.try_recv() {
out.push(m);
}
out
}
None => Vec::new(),
}
}
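/// Blocks until the next epoch metrics arrive; returns `None` once all
/// senders are gone (or immediately in single-device mode, which has no
/// metrics channel).
///
/// Illustrative consumption loop (field names as used elsewhere in this
/// module):
///
/// ```ignore
/// while let Some(m) = handle.next_metrics() {
///     println!("rank {} epoch {} loss {:.4}", m.rank, m.epoch, m.avg_loss);
/// }
/// ```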
pub fn next_metrics(&self) -> Option<super::EpochMetrics> {
self.metrics_rx.as_ref().and_then(|rx| rx.recv().ok())
}
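/// Best-effort abort of all NCCL communicators so blocked collectives return.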
fn abort_nccl(&self) {
for h in &self.nccl_abort_handles {
let _ = h.abort();
}
}
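/// Waits for every worker and the coordinator to finish and returns the
/// trained state. A state recovered by the coordinator wins over a worker
/// error (a warning is printed); otherwise the first error is returned.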
pub fn join(mut self) -> Result<TrainedState> {
if let Some(state) = self.final_state.take() {
return Ok(state);
}
let mut first_err: Option<TensorError> = None;
let handles: Vec<_> = self.worker_handles.drain(..).collect();
for h in handles {
match h.join() {
Ok(Ok(())) => {}
Ok(Err(e)) => {
self.shutdown.store(true, Ordering::Relaxed);
self.abort_nccl();
if first_err.is_none() {
first_err = Some(e);
}
}
Err(_) => {
self.shutdown.store(true, Ordering::Relaxed);
self.abort_nccl();
if first_err.is_none() {
first_err = Some(TensorError::new("worker thread panicked"));
}
}
}
}
self.shutdown.store(true, Ordering::Relaxed);
if let Some(h) = self.coordinator_handle.take() {
match h.join() {
Ok(Ok(state)) => {
if let Some(ref e) = first_err {
eprintln!(" ddp: WARNING: training state recovered but worker error occurred: {e}");
}
return Ok(state);
}
Ok(Err(e)) if first_err.is_none() => first_err = Some(e),
Err(_) if first_err.is_none() => {
first_err = Some(TensorError::new("coordinator thread panicked"));
}
_ => {}
}
}
Err(first_err.unwrap_or_else(|| TensorError::new("join: no trained state available")))
}
fn print_summary(devices: &[Device], policy: &ApplyPolicy, backend: &AverageBackend) {
use crate::tensor::{cuda_device_name_idx, cuda_memory_info_idx};
use crate::monitor::format_bytes;
let mut parts = Vec::with_capacity(devices.len());
let mut names = Vec::with_capacity(devices.len());
for &dev in devices {
if let Device::CUDA(idx) = dev {
let raw_name = cuda_device_name_idx(idx as i32)
.unwrap_or_else(|| format!("CUDA({})", idx));
let short = raw_name
.strip_prefix("NVIDIA ")
.unwrap_or(&raw_name)
.to_string();
let vram = cuda_memory_info_idx(idx as i32)
.ok()
.map(|(_, total)| format!(" ({})", format_bytes(total)))
.unwrap_or_default();
parts.push(format!("{}{}", short, vram));
names.push(raw_name);
}
}
let heterogeneous = names.windows(2).any(|w| w[0] != w[1]);
let mode = if heterogeneous { "heterogeneous" } else { "homogeneous" };
let policy_str = match policy {
ApplyPolicy::Sync => "sync",
ApplyPolicy::Cadence => "cadence",
ApplyPolicy::Async => "async",
};
let backend_str = match backend {
AverageBackend::Nccl => "nccl",
AverageBackend::Cpu => "cpu",
};
crate::verbose!(
" ddp: {} GPUs ({}) | {} | policy={} backend={}",
devices.len(), mode, parts.join(" | "), policy_str, backend_str,
);
}
#[allow(clippy::too_many_arguments)]
fn build_training_meta(
devices: &[Device],
policy: &ApplyPolicy,
backend: &AverageBackend,
batch_size: usize,
num_epochs: usize,
total_samples: usize,
progressive: bool,
config: &DdpRunConfig,
) -> serde_json::Value {
use crate::tensor::cuda_device_name_idx;
let gpu_names: Vec<String> = devices.iter().map(|d| {
if let Device::CUDA(idx) = d {
cuda_device_name_idx(*idx as i32)
.unwrap_or_else(|| format!("CUDA({})", idx))
} else {
format!("{d:?}")
}
}).collect();
let policy_str = match policy {
ApplyPolicy::Sync => "sync",
ApplyPolicy::Cadence => "cadence",
ApplyPolicy::Async => "async",
};
let backend_str = match backend {
AverageBackend::Nccl => "nccl",
AverageBackend::Cpu => "cpu",
};
let mut meta = serde_json::json!({
"gpus": devices.len(),
"gpu_names": gpu_names,
"policy": policy_str,
"backend": backend_str,
"batch_size": batch_size,
"num_epochs": num_epochs,
"total_samples": total_samples,
"progressive_dispatch": progressive,
});
if let Some(anchor) = config.anchor {
meta["anchor"] = serde_json::json!(anchor);
}
if let Some(target) = config.overhead_target {
meta["overhead_target"] = serde_json::json!(target);
}
if let Some(max) = config.max_anchor {
meta["max_anchor"] = serde_json::json!(max);
}
if let Some(diff) = config.max_batch_diff {
meta["max_batch_diff"] = serde_json::json!(diff);
}
if let Some(overshoot) = config.max_overshoot {
meta["max_overshoot"] = serde_json::json!(overshoot);
}
if let Some(dt) = config.divergence_threshold {
meta["divergence_threshold"] = serde_json::json!(dt);
}
meta
}
}
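/// Builder for configuring and launching a distributed data-parallel run.
///
/// A minimal sketch of the intended flow; `MyModel`, `Sgd`, and `dataset`
/// are illustrative placeholders, not items defined by this crate:
///
/// ```ignore
/// let handle = Ddp::builder(
///     |device| MyModel::new(device),    // model factory, one instance per rank
///     |params| Sgd::new(params, 0.01),  // optimizer factory
///     |model, batch| model.loss(batch), // forward pass returning the loss
/// )
/// .dataset(dataset)
/// .batch_size(64)
/// .num_epochs(10)
/// .run()?;
/// let trained = handle.join()?;
/// ```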
#[allow(clippy::type_complexity)]
pub struct DdpBuilder<F, M, G, O, T>
where
F: Fn(Device) -> Result<M> + Send + Sync + 'static,
M: Module + 'static,
G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
O: Optimizer + 'static,
T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
{
model_factory: F,
optim_factory: G,
train_fn: T,
dataset: Option<Arc<dyn BatchDataSet>>,
batch_size: Option<usize>,
num_epochs: Option<usize>,
policy: ApplyPolicy,
backend: AverageBackend,
config: DdpRunConfig,
checkpoint_fn: Option<CheckpointFn<M>>,
epoch_fn: Option<EpochFn<M>>,
scheduler_fn: Option<Box<dyn Fn(usize) -> Arc<dyn crate::nn::Scheduler> + Send + Sync>>,
_phantom: PhantomData<(M, O)>,
}
impl<F, M, G, O, T> DdpBuilder<F, M, G, O, T>
where
F: Fn(Device) -> Result<M> + Send + Sync + 'static,
M: Module + 'static,
G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
O: Optimizer + 'static,
T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
{
pub fn dataset(mut self, dataset: Arc<dyn BatchDataSet>) -> Self {
self.dataset = Some(dataset);
self
}
pub fn batch_size(mut self, size: usize) -> Self {
self.batch_size = Some(size);
self
}
pub fn num_epochs(mut self, n: usize) -> Self {
self.num_epochs = Some(n);
self
}
pub fn policy(mut self, policy: ApplyPolicy) -> Self {
self.policy = policy;
self
}
pub fn backend(mut self, backend: AverageBackend) -> Self {
self.backend = backend;
self
}
pub fn overhead_target(mut self, target: f64) -> Self {
self.config = self.config.with_overhead_target(target);
self
}
pub fn max_anchor(mut self, max: usize) -> Self {
self.config = self.config.with_max_anchor(max);
self
}
pub fn anchor(mut self, anchor: usize) -> Self {
self.config = self.config.with_anchor(anchor);
self
}
pub fn divergence_threshold(mut self, threshold: f64) -> Self {
self.config = self.config.with_divergence_threshold(threshold);
self
}
pub fn no_divergence_guard(mut self) -> Self {
self.config = self.config.with_no_divergence_guard();
self
}
pub fn max_batch_diff(mut self, max: usize) -> Self {
self.config = self.config.with_max_batch_diff(max);
self
}
pub fn max_overshoot(mut self, max: usize) -> Self {
self.config = self.config.with_max_overshoot(max);
self
}
pub fn checkpoint_every(mut self, n: usize) -> Self {
self.config = self.config.with_checkpoint_every(n);
self
}
pub fn progressive_dispatch(mut self, enabled: bool) -> Self {
self.config = self.config.with_progressive_dispatch(enabled);
self
}
pub fn max_grad_norm(mut self, max_norm: f64) -> Self {
self.config = self.config.with_max_grad_norm(max_norm);
self
}
pub fn timeline(mut self, tl: std::sync::Arc<crate::monitor::Timeline>) -> Self {
self.config = self.config.with_timeline(tl);
self
}
pub fn lr_scale_ratio(mut self, ratio: f64) -> Self {
self.config = self.config.with_lr_scale_ratio(ratio);
self
}
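/// Registers a checkpoint callback, invoked with the number of completed
/// epochs and a reference to the model (see [`DdpBuilder::checkpoint_every`]).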
pub fn checkpoint_fn<C>(mut self, f: C) -> Self
where
C: Fn(u64, &M) -> Result<()> + Send + Sync + 'static,
{
self.checkpoint_fn = Some(Arc::new(f));
self
}
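/// Installs a learning-rate scheduler factory. It is called once with the
/// effective world size, and the resulting scheduler is shared by all
/// workers.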
pub fn scheduler<S>(mut self, factory: S) -> Self
where
S: Fn(usize) -> Arc<dyn crate::nn::Scheduler> + Send + Sync + 'static,
{
self.scheduler_fn = Some(Box::new(factory));
self
}
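/// Registers a callback run on each worker whenever it starts a new epoch.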
pub fn epoch_fn<E>(mut self, f: E) -> Self
where
E: Fn(usize, &mut GpuWorker<M>) + Send + Sync + 'static,
{
self.epoch_fn = Some(Arc::new(f));
self
}
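/// Validates the required settings and launches the run. Returns an error if
/// `dataset`, `batch_size`, or `num_epochs` was not provided.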
pub fn run(self) -> Result<DdpHandle> {
let dataset = self.dataset
.ok_or_else(|| TensorError::new("DdpBuilder: dataset is required"))?;
let batch_size = self.batch_size
.ok_or_else(|| TensorError::new("DdpBuilder: batch_size is required"))?;
let num_epochs = self.num_epochs
.ok_or_else(|| TensorError::new("DdpBuilder: num_epochs is required"))?;
DdpHandle::launch(
self.model_factory,
self.optim_factory,
self.train_fn,
dataset,
batch_size,
num_epochs,
self.policy,
self.backend,
self.config,
self.checkpoint_fn,
self.epoch_fn,
self.scheduler_fn,
)
}
}
impl DdpHandle {
#[deprecated(since = "0.3.0", note = "Use Ddp::builder() instead")]
pub fn builder<F, M, G, O, T>(
model_factory: F,
optim_factory: G,
train_fn: T,
) -> DdpBuilder<F, M, G, O, T>
where
F: Fn(Device) -> Result<M> + Send + Sync + 'static,
M: Module + 'static,
G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
O: Optimizer + 'static,
T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
{
Self::new_builder(model_factory, optim_factory, train_fn)
}
pub(crate) fn new_builder<F, M, G, O, T>(
model_factory: F,
optim_factory: G,
train_fn: T,
) -> DdpBuilder<F, M, G, O, T>
where
F: Fn(Device) -> Result<M> + Send + Sync + 'static,
M: Module + 'static,
G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
O: Optimizer + 'static,
T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
{
DdpBuilder {
model_factory,
optim_factory,
train_fn,
dataset: None,
batch_size: None,
num_epochs: None,
policy: ApplyPolicy::Cadence,
backend: AverageBackend::Nccl,
config: DdpRunConfig::new(),
checkpoint_fn: None,
epoch_fn: None,
scheduler_fn: None,
_phantom: PhantomData,
}
}
}
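// Dropping the handle requests shutdown and aborts NCCL collectives so the
// background threads cannot block forever if `join` was never called.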
impl Drop for DdpHandle {
fn drop(&mut self) {
self.shutdown.store(true, std::sync::atomic::Ordering::Relaxed);
self.abort_nccl();
}
}