flodl/distributed/ddp.rs
//! Training entry points for flodl.
//!
//! The primary entry point is [`Trainer`]. It works transparently on 1 or
//! N GPUs - single-device training has zero DDP overhead. Reach for
//! [`Trainer`] by default; drop to [`Ddp`] only when you need explicit
//! multi-GPU control.
//!
//! **Default** ([`Trainer::setup()`], [`Trainer::builder()`]): user-owned or
//! framework-owned training loop, transparent single/multi-GPU. Same API in
//! both cases.
//!
//! **Explicit multi-GPU** ([`Ddp::wrap()`]): manual control over gradient
//! sync and parameter broadcast for advanced patterns (GAN, RL, progressive).
//!
//! # Setup mode (user owns the loop)
//!
//! ```ignore
//! Trainer::setup(&model, |dev| build_model(dev), |p| Adam::new(p, 0.001))?;
//!
//! // Training loop is identical for 1 or N GPUs:
//! for (x, y) in &train_loader {
//!     let out = model.forward(&x)?;
//!     let loss = cross_entropy_loss(&out, &y)?;
//!     loss.backward()?;
//!     model.step()?;
//! }
//! ```
//!
//! # Builder mode (framework owns the loop)
//!
//! ```ignore
//! let handle = Trainer::builder(model_factory, optim_factory, train_fn)
//!     .dataset(dataset)
//!     .batch_size(32)
//!     .num_epochs(10)
//!     .run()?;
//!
//! let state = handle.join()?;
//! ```
//!
//! # Manual DDP
//!
//! ```ignore
//! let ddp = Ddp::wrap(&[&model0, &model1], &devices)?;
//! ddp.sync_params()?;
//! // ... custom forward/backward ...
//! ddp.all_reduce_gradients()?;
//! ```

use crate::autograd::Variable;
use crate::graph::Graph;
use crate::nn::{Buffer, Module, Optimizer, Parameter};
use crate::tensor::{Device, Result, Tensor, TensorError};
use super::cuda_event::CudaEvent;
use super::ddp_run::{DdpBuilder, DdpHandle};
use super::nccl::{NcclComms, ReduceOp};
pub use super::el_che::ElChe;

/// Shared lock for serializing NCCL communicator creation across test modules.
/// NCCL init is a collective operation that deadlocks if two tests try to
/// create communicators simultaneously.
#[cfg(test)]
pub(crate) static NCCL_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Default number of steps before the first rebalance.
pub(crate) const DEFAULT_CALIBRATION_STEPS: usize = 10;

/// How often to re-evaluate chunk ratios after calibration.
pub(crate) const DEFAULT_REBALANCE_INTERVAL: usize = 50;

/// EMA smoothing factor for throughput tracking (higher = more reactive).
const EMA_ALPHA: f64 = 0.3;

/// Minimum ratio any device can receive (prevents starving a GPU entirely).
const MIN_CHUNK_RATIO: f64 = 0.05;

// ---------------------------------------------------------------------------
// Internal distributed state (held by Graph)
// ---------------------------------------------------------------------------

/// Internal distributed state held by Graph when `distribute()` is called.
pub(crate) struct DistributedState {
    /// Model replicas for ranks 1..N (rank 0 is the Graph itself).
    pub replicas: Vec<Box<dyn Module>>,
    /// NCCL communicators (one per device).
    pub comms: NcclComms,
    /// All devices including rank 0.
    pub devices: Vec<Device>,
    /// Per-replica optimizers indexed by rank (including rank 0).
    pub optimizers: Vec<Box<dyn Optimizer>>,
    /// Chunk ratios for auto-balancing (sum = 1.0). Default: equal.
    pub chunk_ratios: Vec<f64>,
    /// Parameters matched across replicas: param_groups\[param_idx\]\[rank\].
    pub param_groups: Vec<Vec<Variable>>,
    /// Buffers matched across replicas: buffer_groups\[buf_idx\]\[rank\].
    pub buffer_groups: Vec<Vec<Buffer>>,

    // -- Auto-balancer state --

    /// Per-rank forward timing events from last forward pass: (start, end).
    /// Set by forward_distributed(), read by step().
    pub last_timing: Option<Vec<(CudaEvent, CudaEvent)>>,
    /// Shard sizes from last forward pass (for throughput calculation).
    pub last_shard_sizes: Vec<i64>,
    /// EMA throughput per rank (samples/ms). Zero until first measurement.
    pub ema_throughput: Vec<f64>,
    /// Number of completed training steps.
    pub step_count: usize,
    /// Steps of equal-split calibration before first rebalance.
    pub calibration_steps: usize,
    /// Steps between ratio recalculations after calibration.
    pub rebalance_interval: usize,

    // -- El Che cadence (heterogeneous DDP) --

    /// El Che cadence strategy. When `Some`, Graph uses per-device multi-batch
    /// forward instead of per-batch scatter. When `None`, the existing
    /// per-batch scatter path is used.
    pub el_che: Option<ElChe>,
    /// Per-rank batch counts from the last El Che forward pass.
    /// Set by forward_distributed_el_che(), read by step().
    pub last_el_che_counts: Vec<usize>,
    /// Wall-clock time at end of last El Che AllReduce.
    pub last_el_che_sync: Option<std::time::Instant>,
    /// Maximum gradient norm for per-rank clipping in El Che mode.
    pub max_grad_norm: Option<f64>,
    /// Optional system timeline for high-frequency profiling.
    pub timeline: Option<std::sync::Arc<crate::monitor::Timeline>>,
}

impl DistributedState {
    /// AllReduce-average gradients across all replicas.
    pub fn all_reduce_gradients(&self) -> Result<()> {
        for group in &self.param_groups {
            // Skip frozen parameters (no gradient on rank 0)
            if group[0].grad().is_none() {
                continue;
            }
            let grads: Vec<Tensor> = group
                .iter()
                .map(|v| v.grad().expect("gradient missing on replica"))
                .collect();
            let refs: Vec<&Tensor> = grads.iter().collect();
            self.comms.all_reduce(&refs, ReduceOp::Avg)?;
        }
        Ok(())
    }

    /// Broadcast buffers from rank 0 to all replicas (BatchNorm stats etc).
    pub fn sync_buffers(&self) -> Result<()> {
        for group in &self.buffer_groups {
            let tensors: Vec<Tensor> = group.iter().map(|b| b.get()).collect();
            let refs: Vec<&Tensor> = tensors.iter().collect();
            self.comms.broadcast(&refs, 0)?;
        }
        Ok(())
    }

    /// Broadcast parameters and buffers from rank 0 to all replicas.
    pub fn sync_params(&self) -> Result<()> {
        for group in &self.param_groups {
            let tensors: Vec<Tensor> = group.iter().map(|v| v.data()).collect();
            let refs: Vec<&Tensor> = tensors.iter().collect();
            self.comms.broadcast(&refs, 0)?;
        }
        self.sync_buffers()
    }

    /// Compute shard sizes from chunk ratios, guaranteeing they sum to batch_size.
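    ///
    /// Illustrative sketch (the values are made up): with two devices,
    /// `chunk_ratios = [0.75, 0.25]`, and a batch of 10, rank 0 gets
    /// `round(10 * 0.75) = 8` and the last rank absorbs the remainder.
    ///
    /// ```ignore
    /// // assuming `state.chunk_ratios == [0.75, 0.25]`
    /// let sizes = state.compute_shard_sizes(10);
    /// assert_eq!(sizes, vec![8, 2]); // 8 + 2 == batch_size
    /// ```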
    pub fn compute_shard_sizes(&self, batch_size: i64) -> Vec<i64> {
        let n = self.devices.len();
        let mut sizes = Vec::with_capacity(n);
        let mut remaining = batch_size;

        for i in 0..n {
            if i == n - 1 {
                // Last device gets whatever is left
                sizes.push(remaining);
            } else {
                let s = (batch_size as f64 * self.chunk_ratios[i]).round() as i64;
                let s = s.max(1).min(remaining - (n - i - 1) as i64); // leave at least 1 per remaining device
                sizes.push(s);
                remaining -= s;
            }
        }

        sizes
    }

    /// Number of devices.
    pub fn world_size(&self) -> usize {
        self.devices.len()
    }

    /// Whether chunk ratios are (approximately) equal across devices.
    /// When this returns false, the ratios are meaningfully unequal and
    /// gradients need weighted averaging.
    pub fn is_balanced(&self) -> bool {
        let first = self.chunk_ratios[0];
        self.chunk_ratios.iter().all(|r| (r - first).abs() < 1e-6)
    }

    /// AllReduce gradients with weighted averaging for unequal shard sizes.
    ///
    /// Each replica's gradient is scaled by `(shard_size / batch_size)` before
    /// AllReduce Sum, which produces the correct mean gradient regardless of
    /// how the batch was split.
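    ///
    /// For example (illustrative numbers), shards of `[96, 32]` from a
    /// 128-sample batch are weighted `96/128 = 0.75` and `32/128 = 0.25`,
    /// so the summed result equals the mean gradient over all 128 samples.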
    pub fn weighted_all_reduce_gradients(&self, batch_size: i64) -> Result<()> {
        for group in &self.param_groups {
            if group[0].grad().is_none() {
                continue;
            }
            let grads: Vec<Tensor> = group
                .iter()
                .enumerate()
                .map(|(rank, v)| {
                    let g = v.grad().expect("gradient missing on replica");
                    let weight = self.last_shard_sizes[rank] as f64 / batch_size as f64;
                    g.mul_scalar_(weight).ok();
                    g
                })
                .collect();
            let refs: Vec<&Tensor> = grads.iter().collect();
            self.comms.all_reduce(&refs, ReduceOp::Sum)?;
        }
        Ok(())
    }

    /// Read timing from last forward pass, update EMA throughput, and
    /// rebalance chunk ratios if it's time.
    ///
    /// Called from Graph::step() after gradient sync. Returns true if
    /// chunk ratios were updated this step.
    pub fn update_balance(&mut self) -> Result<bool> {
        self.step_count += 1;

        // Read timing events (set by forward_distributed)
        if let Some(timing) = self.last_timing.take() {
            for (rank, (start, end)) in timing.iter().enumerate() {
                let ms = CudaEvent::elapsed_time(start, end)?;
                if ms > 0.0 && self.last_shard_sizes[rank] > 0 {
                    let throughput = self.last_shard_sizes[rank] as f64 / ms as f64;
                    if self.ema_throughput[rank] == 0.0 {
                        // First measurement: initialize directly
                        self.ema_throughput[rank] = throughput;
                    } else {
                        self.ema_throughput[rank] =
                            EMA_ALPHA * throughput + (1.0 - EMA_ALPHA) * self.ema_throughput[rank];
                    }
                }
            }
        }

        // Check if it's time to rebalance
        let should_rebalance = if self.step_count == self.calibration_steps {
            true
        } else if self.step_count > self.calibration_steps {
            (self.step_count - self.calibration_steps) % self.rebalance_interval == 0
        } else {
            false
        };

        if should_rebalance {
            self.rebalance();
            return Ok(true);
        }

        Ok(false)
    }

    /// Recompute chunk_ratios proportional to EMA throughput.
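    ///
    /// Illustrative (made-up throughputs): if EMA throughput is
    /// `[300.0, 100.0]` samples/ms, the raw ratios become
    /// `[300/400, 100/400] = [0.75, 0.25]`; both exceed `MIN_CHUNK_RATIO`,
    /// so they are kept as-is and already sum to 1.0.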
    fn rebalance(&mut self) {
        let total: f64 = self.ema_throughput.iter().sum();
        if total <= 0.0 {
            return; // no data yet
        }

        let n = self.devices.len();
        let min_total = MIN_CHUNK_RATIO * n as f64;

        // Compute raw proportional ratios
        let mut ratios: Vec<f64> = self.ema_throughput.iter().map(|t| t / total).collect();

        // Clamp: no device below MIN_CHUNK_RATIO
        let mut deficit = 0.0;
        let mut unclamped = 0;
        for r in &mut ratios {
            if *r < MIN_CHUNK_RATIO {
                deficit += MIN_CHUNK_RATIO - *r;
                *r = MIN_CHUNK_RATIO;
            } else {
                unclamped += 1;
            }
        }

        // Redistribute deficit from unclamped devices proportionally
        if deficit > 0.0 && unclamped > 0 {
            let unclamped_total: f64 = ratios
                .iter()
                .filter(|&&r| r > MIN_CHUNK_RATIO + 1e-9)
                .sum();
            if unclamped_total > min_total {
                for r in &mut ratios {
                    if *r > MIN_CHUNK_RATIO + 1e-9 {
                        *r -= deficit * (*r / unclamped_total);
                        *r = r.max(MIN_CHUNK_RATIO);
                    }
                }
            }
        }

        // Normalize to sum exactly to 1.0
        let sum: f64 = ratios.iter().sum();
        if sum > 0.0 {
            for r in &mut ratios {
                *r /= sum;
            }
        }

        self.chunk_ratios = ratios;
    }

    /// Configure El Che cadence from a [`DdpConfig`].
    ///
    /// Creates an internal [`ElChe`] when enabled (`max_anchor != Some(0)`)
    /// and seeds chunk_ratios from `speed_hint` if provided.
    pub(crate) fn configure_el_che(&mut self, config: &DdpConfig) {
        let n = self.devices.len();
        if n < 2 {
            return;
        }

        // max_anchor = Some(0) → disabled (traditional DDP)
        if config.max_anchor == Some(0) {
            self.el_che = None;
            return;
        }

        // Build ElChe with sensible defaults
        let anchor = 10; // initial anchor, auto-tunes from timing
        let mut el_che = ElChe::new(n, anchor);

        if let Some(target) = config.overhead_target {
            el_che = el_che.with_overhead_target(target);
        }
        if let Some(max) = config.max_anchor {
            el_che = el_che.with_max_anchor(max);
        }
        if let Some((slow_rank, ratio)) = config.speed_hint {
            el_che = el_che.with_speed_ratio(slow_rank, ratio);
            // Also seed chunk_ratios for the existing auto-balancer
            self.apply_speed_hint(slow_rank, ratio);
        }

        self.el_che = Some(el_che);
        self.max_grad_norm = config.max_grad_norm;
    }

    /// Seed chunk_ratios from a speed hint.
    fn apply_speed_hint(&mut self, slow_rank: usize, ratio: f64) {
        let n = self.devices.len();
        if slow_rank >= n {
            return;
        }
        let ratio = ratio.max(1.0);
        let mut weights = vec![ratio; n];
        weights[slow_rank] = 1.0;
        let total: f64 = weights.iter().sum();
        self.chunk_ratios = weights.iter().map(|w| w / total).collect();
    }
}

// ---------------------------------------------------------------------------
// Manual DDP coordinator
// ---------------------------------------------------------------------------

/// Manual DDP coordinator for multi-GPU gradient sync.
///
/// For complex training patterns (GAN, RL, progressive) where transparent
/// Graph-level DDP doesn't fit. Provides explicit control over parameter
/// broadcast and gradient averaging.
///
/// For standard training, use [`crate::graph::Graph::distribute`] instead.
pub struct Ddp {
    comms: NcclComms,
    devices: Vec<Device>,
    param_groups: Vec<Vec<Variable>>,
    buffer_groups: Vec<Vec<Buffer>>,
}

impl Ddp {
    /// Wrap pre-created model replicas for manual DDP control.
    ///
    /// Models must have identical architecture (same parameter count/shapes).
    /// Each model should already reside on its target device.
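    ///
    /// Minimal sketch (assumes `build_model` is a user-supplied constructor
    /// and `devices` holds one entry per replica):
    ///
    /// ```ignore
    /// let model0 = build_model(devices[0])?;
    /// let model1 = build_model(devices[1])?;
    /// let ddp = Ddp::wrap(&[&model0, &model1], &devices)?;
    /// ddp.sync_params()?; // start all replicas from rank 0's weights
    /// ```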
    pub fn wrap(models: &[&dyn Module], devices: &[Device]) -> Result<Self> {
        if models.len() < 2 {
            return Err(TensorError::new("Ddp::wrap requires at least 2 models"));
        }
        if models.len() != devices.len() {
            return Err(TensorError::new(
                "Ddp::wrap: model count must match device count",
            ));
        }

        let comms = NcclComms::new(devices)?;

        // Match parameters across models
        let all_params: Vec<Vec<Parameter>> =
            models.iter().map(|m| m.parameters()).collect();
        let n_params = all_params[0].len();
        for (rank, params) in all_params.iter().enumerate().skip(1) {
            if params.len() != n_params {
                return Err(TensorError::new(&format!(
                    "Ddp: replica {} has {} parameters, expected {}",
                    rank,
                    params.len(),
                    n_params
                )));
            }
        }

        let mut param_groups = Vec::with_capacity(n_params);
        for pi in 0..n_params {
            let group: Vec<Variable> =
                all_params.iter().map(|p| p[pi].variable.clone()).collect();
            param_groups.push(group);
        }

        // Match buffers
        let all_buffers: Vec<Vec<Buffer>> =
            models.iter().map(|m| m.buffers()).collect();
        let n_buffers = all_buffers[0].len();
        let mut buffer_groups = Vec::with_capacity(n_buffers);
        for bi in 0..n_buffers {
            let group: Vec<Buffer> =
                all_buffers.iter().map(|b| b[bi].clone()).collect();
            buffer_groups.push(group);
        }

        Ok(Ddp {
            comms,
            devices: devices.to_vec(),
            param_groups,
            buffer_groups,
        })
    }

    /// Broadcast all parameters and buffers from rank 0 to all replicas.
    pub fn sync_params(&self) -> Result<()> {
        for group in &self.param_groups {
            let tensors: Vec<Tensor> = group.iter().map(|v| v.data()).collect();
            let refs: Vec<&Tensor> = tensors.iter().collect();
            self.comms.broadcast(&refs, 0)?;
        }
        for group in &self.buffer_groups {
            let tensors: Vec<Tensor> = group.iter().map(|b| b.get()).collect();
            let refs: Vec<&Tensor> = tensors.iter().collect();
            self.comms.broadcast(&refs, 0)?;
        }
        Ok(())
    }

    /// AllReduce-average gradients across all replicas.
    /// Call after backward(), before optimizer.step().
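    ///
    /// Minimal sketch of the intended ordering (losses, models, and
    /// per-replica optimizers are user-owned; names are illustrative):
    ///
    /// ```ignore
    /// loss0.backward()?;
    /// loss1.backward()?;
    /// ddp.all_reduce_gradients()?; // every replica now holds the averaged grads
    /// opt0.step()?;
    /// opt1.step()?;
    /// ```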
    pub fn all_reduce_gradients(&self) -> Result<()> {
        for group in &self.param_groups {
            if group[0].grad().is_none() {
                continue;
            }
            let grads: Vec<Tensor> = group
                .iter()
                .map(|v| v.grad().expect("gradient missing on replica"))
                .collect();
            let refs: Vec<&Tensor> = grads.iter().collect();
            self.comms.all_reduce(&refs, ReduceOp::Avg)?;
        }
        Ok(())
    }

    /// Broadcast buffers from rank 0 (BatchNorm running stats etc).
    pub fn sync_buffers(&self) -> Result<()> {
        for group in &self.buffer_groups {
            let tensors: Vec<Tensor> = group.iter().map(|b| b.get()).collect();
            let refs: Vec<&Tensor> = tensors.iter().collect();
            self.comms.broadcast(&refs, 0)?;
        }
        Ok(())
    }

    /// AllReduce gradients weighted by per-device batch contribution.
    ///
    /// For heterogeneous DDP where devices process different numbers of
    /// batches per sync step. Each replica's gradient is scaled by
    /// `(batch_counts[rank] / total)` before AllReduce Sum, producing
    /// the correct mean gradient.
    ///
    /// Use with [`ElChe::batch_counts`] for automatic weighting
    /// (see [`ElChe`] for the full heterogeneous DDP strategy):
    ///
    /// ```ignore
    /// ddp.weighted_all_reduce_gradients(cadence.batch_counts())?;
    /// ```
    pub fn weighted_all_reduce_gradients(&self, batch_counts: &[usize]) -> Result<()> {
        if batch_counts.len() != self.devices.len() {
            return Err(TensorError::new(&format!(
                "weighted_all_reduce: batch_counts len ({}) != device count ({})",
                batch_counts.len(),
                self.devices.len(),
            )));
        }
        let total: usize = batch_counts.iter().sum();
        if total == 0 {
            return Err(TensorError::new("weighted_all_reduce: total batch count is 0"));
        }
        for group in &self.param_groups {
            if group[0].grad().is_none() {
                continue;
            }
            let grads: Vec<Tensor> = group
                .iter()
                .enumerate()
                .map(|(rank, v)| {
                    let g = v.grad().expect("gradient missing on replica");
                    let weight = batch_counts[rank] as f64 / total as f64;
                    g.mul_scalar_(weight).ok();
                    g
                })
                .collect();
            let refs: Vec<&Tensor> = grads.iter().collect();
            self.comms.all_reduce(&refs, ReduceOp::Sum)?;
        }
        Ok(())
    }

    /// Number of devices.
    pub fn world_size(&self) -> usize {
        self.devices.len()
    }

    /// Devices in use.
    pub fn devices(&self) -> &[Device] {
        &self.devices
    }

    // --- Deprecated aliases: use Trainer:: as the primary entry point ---

    /// Deprecated: use [`Trainer::setup()`] instead.
    ///
    /// [`Trainer`] is now the primary training entry point and carries the
    /// same behavior for 1 or N GPUs. [`Ddp`] remains for explicit
    /// multi-GPU control via [`Ddp::wrap`].
    #[deprecated(note = "use Trainer::setup() - same behavior. Ddp::setup will be removed in a future release.")]
    pub fn setup<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Trainer::setup(model, builder, optimizer)
    }

    /// Deprecated: use [`Trainer::setup_with()`] instead.
    #[deprecated(note = "use Trainer::setup_with() - same behavior. Ddp::setup_with will be removed in a future release.")]
    pub fn setup_with<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
        config: DdpConfig,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Trainer::setup_with(model, builder, optimizer, config)
    }

    /// Deprecated: renamed to [`Trainer::setup()`].
    #[deprecated(since = "0.3.0", note = "Renamed to Trainer::setup()")]
    pub fn auto<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Trainer::setup(model, builder, optimizer)
    }

    /// Deprecated: renamed to [`Trainer::setup_with()`].
    #[deprecated(since = "0.3.0", note = "Renamed to Trainer::setup_with()")]
    pub fn auto_with<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
        config: DdpConfig,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Trainer::setup_with(model, builder, optimizer, config)
    }

    // -------------------------------------------------------------------
    // Deprecated builder entry: use Trainer::builder instead
    // -------------------------------------------------------------------

    /// Deprecated: use [`Trainer::builder()`] instead.
    ///
    /// [`Trainer`] is the primary training entry point and works
    /// transparently for single-GPU and multi-GPU. This alias is retained
    /// for backwards compatibility and will be removed in a future release.
    #[deprecated(note = "use Trainer::builder() - same behavior. Ddp::builder will be removed in a future release.")]
    pub fn builder<F, M, G, O, T>(
        model_factory: F,
        optim_factory: G,
        train_fn: T,
    ) -> DdpBuilder<F, M, G, O, T>
    where
        F: Fn(Device) -> Result<M> + Send + Sync + 'static,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
        O: Optimizer + 'static,
        T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
    {
        Trainer::builder(model_factory, optim_factory, train_fn)
    }

    /// Detect whether the current CUDA setup has different GPU models.
    fn is_heterogeneous() -> bool {
        use crate::tensor::{cuda_available, cuda_device_count, cuda_device_name_idx};
        if !cuda_available() || cuda_device_count() < 2 {
            return false;
        }
        let n = cuda_device_count();
        let names: Vec<Option<String>> = (0..n)
            .map(cuda_device_name_idx)
            .collect();
        names.windows(2).any(|w| w[0] != w[1])
    }

    /// Print a diagnostic summary of detected CUDA devices to stderr.
    fn print_device_summary() {
        use crate::tensor::{
            cuda_available, cuda_device_count,
            cuda_device_name_idx, cuda_memory_info_idx,
        };
        use crate::monitor::format_bytes;

        if !cuda_available() || cuda_device_count() == 0 {
            crate::verbose!(" ddp: no CUDA available | CPU mode");
            return;
        }

        let n = cuda_device_count();
        let mut names = Vec::with_capacity(n as usize);
        let mut parts = Vec::with_capacity(n as usize);

        for i in 0..n {
            let raw_name = cuda_device_name_idx(i)
                .unwrap_or_else(|| format!("CUDA({})", i));
            let short = raw_name
                .strip_prefix("NVIDIA ")
                .unwrap_or(&raw_name)
                .to_string();
            let vram = cuda_memory_info_idx(i)
                .map(|(_, total)| format!(" ({})", format_bytes(total)))
                .unwrap_or_default();
            parts.push(format!("{}{}", short, vram));
            names.push(raw_name);
        }

        let heterogeneous = names.windows(2).any(|w| w[0] != w[1]);

        if n == 1 {
            crate::verbose!(" ddp: 1 GPU | {} | single-device mode", parts[0]);
        } else if heterogeneous {
            crate::verbose!(
                " ddp: {} GPUs (heterogeneous) | {}",
                n,
                parts.join(" | "),
            );
        } else {
            crate::verbose!(" ddp: {} GPUs | {}", n, parts.join(" | "));
        }
    }
}

// ---------------------------------------------------------------------------
// Trainer: primary training entry point
// ---------------------------------------------------------------------------

/// Primary entry point for training in flodl.
///
/// `Trainer` is the default API for training a model, whether you have one
/// GPU, many GPUs, or no GPU at all. The training loop is identical in all
/// cases: [`Trainer::setup`] (or [`Trainer::builder`]) configures the model,
/// detects the hardware, and enables distributed training automatically when
/// multiple CUDA devices are available. On a single GPU or CPU it's a no-op
/// wrapper with zero DDP overhead.
///
/// For explicit multi-GPU control (manual gradient sync, custom replica
/// wrapping) use [`Ddp`] directly. [`Ddp::wrap`] remains the entry point for
/// advanced patterns (GAN, RL, progressive).
///
/// # Setup mode (user owns the loop)
///
/// ```ignore
/// Trainer::setup(&model, |dev| build_model(dev), |p| Adam::new(p, 0.001))?;
///
/// for (x, y) in &train_loader {
///     let out = model.forward(&x)?;
///     let loss = cross_entropy_loss(&out, &y)?;
///     loss.backward()?;
///     model.step()?;
/// }
/// ```
///
/// # Builder mode (framework owns the loop)
///
/// ```ignore
/// let handle = Trainer::builder(model_factory, optim_factory, train_fn)
///     .dataset(dataset)
///     .batch_size(32)
///     .num_epochs(10)
///     .run()?;
///
/// let state = handle.join()?;
/// ```
pub struct Trainer;

impl Trainer {
    /// One-call setup: auto-detect GPUs, distribute the model, set the
    /// optimizer, and enable training mode.
    ///
    /// - **Multi-GPU** (2+ usable CUDA devices): replicates via
    ///   [`Graph::distribute`], creates per-replica optimizers, enables training.
    /// - **Single-GPU / CPU**: sets optimizer and training mode only (no DDP
    ///   overhead).
    ///
    /// Always prints a diagnostic summary to stderr showing detected hardware.
    ///
    /// ```ignore
    /// Trainer::setup(&model, |dev| build_model(dev), |p| Adam::new(p, 0.001))?;
    ///
    /// for batch in model.epoch(epoch).activate() {
    ///     let out = model.forward_batch(&batch?)?;
    ///     loss.backward()?;
    ///     model.step()?;
    /// }
    /// ```
    pub fn setup<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Ddp::print_device_summary();
        model.distribute(builder)?;
        model.set_optimizer(optimizer);
        model.set_training(true);

        // Auto-enable El Che for heterogeneous GPU setups
        if Ddp::is_heterogeneous() {
            model.configure_el_che(&DdpConfig::new());
        }

        Ok(())
    }

    /// One-call setup with explicit configuration.
    ///
    /// Like [`setup()`](Self::setup) but accepts a [`DdpConfig`] for
    /// controlling El Che cadence, speed hints, and overhead targets.
    ///
    /// ```ignore
    /// Trainer::setup_with(&model, builder, optimizer,
    ///     DdpConfig::new().speed_hint(1, 2.3))?;
    /// ```
    pub fn setup_with<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
        config: DdpConfig,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Ddp::print_device_summary();
        model.distribute(builder)?;
        model.set_optimizer(optimizer);
        model.set_training(true);
        model.configure_el_che(&config);
        // Pass timeline to distributed state for event injection in step().
        if let Some(tl) = config.timeline {
            if let Some(ref mut state) = *model.distributed.borrow_mut() {
                state.timeline = Some(tl);
            }
        }
        Ok(())
    }

    /// Create a builder for framework-managed training.
    ///
    /// The framework owns the training loop, data pipeline, and epoch
    /// management. On multi-GPU hardware, each device gets its own model
    /// replica and optimizer, and a coordinator triggers periodic
    /// parameter averaging based on the configured [`ApplyPolicy`] and
    /// [`AverageBackend`]. On a single GPU, training runs on the main
    /// thread with no coordination - the API is identical in both cases.
    ///
    /// Returns a [`DdpBuilder`] for fluent configuration. Call `.run()` to
    /// spawn training, then `.join()` on the returned [`DdpHandle`] to
    /// block until completion.
    ///
    /// [`ApplyPolicy`]: crate::distributed::ApplyPolicy
    /// [`AverageBackend`]: crate::distributed::AverageBackend
    ///
    /// # Example
    ///
    /// ```ignore
    /// use flodl::*;
    ///
    /// let handle = Trainer::builder(
    ///     |dev| model_factory(dev),
    ///     |params| Adam::new(params, 0.001),
    ///     |model, batch| { /* forward + loss */ },
    /// )
    /// .dataset(dataset)
    /// .batch_size(32)
    /// .num_epochs(10)
    /// .policy(ApplyPolicy::Cadence)
    /// .backend(AverageBackend::Nccl)
    /// .run()?;
    ///
    /// let state = handle.join()?;
    /// ```
    pub fn builder<F, M, G, O, T>(
        model_factory: F,
        optim_factory: G,
        train_fn: T,
    ) -> DdpBuilder<F, M, G, O, T>
    where
        F: Fn(Device) -> Result<M> + Send + Sync + 'static,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
        O: Optimizer + 'static,
        T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
    {
        DdpHandle::new_builder(model_factory, optim_factory, train_fn)
    }

    /// One-call setup for a task-head wrapper (e.g. `flodl-hf`'s
    /// `BertForSequenceClassification`). The wrapper must implement
    /// [`HasGraph`] so `Trainer` can reach the underlying [`Graph`].
    ///
    /// Semantics match [`Trainer::setup`] exactly; the only difference is
    /// that `head_factory` builds a fresh wrapper (not a bare `Graph`) on
    /// each replica device. Useful when the training-loop code holds onto
    /// the wrapper's richer surface (`compute_loss`, `predict`, attached
    /// tokenizer) but still wants transparent 1-or-N-GPU DDP.
    ///
    /// ```ignore
    /// let head = DistilBertForSequenceClassification::from_pretrained(repo)?;
    /// let config = head.config().clone();
    /// let num_labels = head.labels().len() as i64;
    ///
    /// Trainer::setup_head(
    ///     &head,
    ///     move |dev| DistilBertForSequenceClassification::on_device(&config, num_labels, dev),
    ///     |p| Adam::new(p, 5e-5),
    /// )?;
    ///
    /// for (enc, labels) in &batches {
    ///     let loss = head.compute_loss(&enc, &labels)?;
    ///     loss.backward()?;
    ///     head.graph().step()?;
    /// }
    /// ```
    pub fn setup_head<H, F, G, O>(
        head: &H,
        head_factory: F,
        optimizer: G,
    ) -> Result<()>
    where
        H: HasGraph + 'static,
        F: Fn(Device) -> Result<H> + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Ddp::print_device_summary();
        let graph = head.graph();
        graph.distribute(move |dev| head_factory(dev).map(|h| HeadReplica { head: h }))?;
        graph.set_optimizer(optimizer);
        graph.set_training(true);

        if Ddp::is_heterogeneous() {
            graph.configure_el_che(&DdpConfig::new());
        }

        Ok(())
    }

    /// Task-head variant of [`Trainer::setup_with`]. Same behaviour as
    /// [`Trainer::setup_head`] but takes an explicit [`DdpConfig`].
    pub fn setup_head_with<H, F, G, O>(
        head: &H,
        head_factory: F,
        optimizer: G,
        config: DdpConfig,
    ) -> Result<()>
    where
        H: HasGraph + 'static,
        F: Fn(Device) -> Result<H> + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Ddp::print_device_summary();
        let graph = head.graph();
        graph.distribute(move |dev| head_factory(dev).map(|h| HeadReplica { head: h }))?;
        graph.set_optimizer(optimizer);
        graph.set_training(true);
        graph.configure_el_che(&config);
        if let Some(tl) = config.timeline {
            if let Some(ref mut state) = *graph.distributed.borrow_mut() {
                state.timeline = Some(tl);
            }
        }
        Ok(())
    }
}

// ---------------------------------------------------------------------------
// HasGraph trait: lets wrapper types plug into Trainer::setup_head
// ---------------------------------------------------------------------------

/// A wrapper type that exposes an inner [`Graph`].
///
/// Implement on any wrapper around a `Graph` that should participate in
/// [`Trainer::setup_head`] or other graph-aware DDP machinery. The returned
/// reference borrows from `self` and must point at the same graph used for
/// the wrapper's forward / loss calls.
///
/// [`Graph`] implements this trivially (returns `self`) so bare-graph
/// callers can pass a `&Graph` wherever `&impl HasGraph` is accepted.
///
/// ```ignore
/// impl HasGraph for BertForSequenceClassification {
///     fn graph(&self) -> &Graph { &self.graph }
/// }
/// ```
pub trait HasGraph {
    /// Borrow the inner training graph.
    fn graph(&self) -> &Graph;
}

impl HasGraph for Graph {
    fn graph(&self) -> &Graph { self }
}

/// Internal Module adapter used by [`Trainer::setup_head`] to feed a
/// `HasGraph` replica through [`Graph::distribute`].
///
/// `distribute` boxes each replica as `Box<dyn Module>`. Task-head
/// wrappers don't implement `Module` directly (their true forward is
/// multi-input via [`Graph::forward_multi`], which doesn't fit the
/// single-Variable `Module::forward` signature). `HeadReplica` delegates
/// every Module method through to the inner graph and overrides
/// [`Module::as_graph`] so DDP's multi-input replica paths downcast
/// cleanly rather than hitting the single-input fallback.
struct HeadReplica<H: HasGraph + 'static> {
    head: H,
}

impl<H: HasGraph + 'static> Module for HeadReplica<H> {
    fn forward(&self, input: &Variable) -> Result<Variable> {
        // Single-input fallback. Task-head DDP paths reach
        // forward_multi via `as_graph()` below, so this is only
        // exercised on single-input replica paths (e.g. the scatter
        // forward in `forward_distributed_scatter`). For multi-input
        // heads that path is never triggered because the user calls
        // the head's own `compute_loss` / `forward_encoded`, which
        // route through `Graph::forward_multi` directly.
        self.head.graph().forward(input)
    }
    fn parameters(&self) -> Vec<Parameter> { self.head.graph().parameters() }
    fn buffers(&self) -> Vec<Buffer> { self.head.graph().buffers() }
    fn name(&self) -> &str { "head_replica" }
    fn set_training(&self, training: bool) { self.head.graph().set_training(training); }
    fn as_graph(&self) -> Option<&Graph> { Some(self.head.graph()) }
}

// ---------------------------------------------------------------------------
// DDP configuration
// ---------------------------------------------------------------------------

/// Configuration for [`Trainer::setup_with()`].
///
/// Controls El Che cadence behavior for heterogeneous multi-GPU training.
/// Use [`DdpConfig::new()`] for defaults or build with method chaining.
///
/// ```ignore
/// Trainer::setup_with(&model, builder, optimizer,
///     DdpConfig::new()
///         .speed_hint(1, 2.3)     // rank 1 is slow, 2.3x ratio
///         .overhead_target(0.08)  // tune to 8% overhead
/// )?;
/// ```
#[derive(Debug, Clone)]
pub struct DdpConfig {
    /// Initial speed ratio hint: (slow_rank, fast_to_slow_ratio).
    /// Applied before the first timing measurement.
    pub speed_hint: Option<(usize, f64)>,
    /// AllReduce overhead target for anchor auto-tune (default: 0.10).
    pub overhead_target: Option<f64>,
    /// Max batches on slow device before AllReduce.
    /// - `None` = auto (El Che decides, default).
    /// - `Some(0)` = disabled (traditional per-batch DDP, no El Che).
    /// - `Some(n)` = fixed anchor at n.
    pub max_anchor: Option<usize>,
    /// Maximum gradient norm for per-rank clipping in El Che mode.
    ///
    /// When set, each rank's accumulated gradients are clipped (L2 norm)
    /// before the normalize-by-count and weighted AllReduce steps. This
    /// ensures replica gradients (which the caller cannot reach) are bounded
    /// identically to rank 0.
    ///
    /// Standard DDP does not need this because the caller clips rank 0's
    /// gradients and AllReduce averages them.
    pub max_grad_norm: Option<f64>,
    /// Optional system timeline for high-frequency profiling.
    pub timeline: Option<std::sync::Arc<crate::monitor::Timeline>>,
}

impl DdpConfig {
    /// Default configuration: El Che auto-enabled for heterogeneous GPUs.
    pub fn new() -> Self {
        DdpConfig {
            speed_hint: None,
            overhead_target: None,
            max_anchor: None,
            max_grad_norm: None,
            timeline: None,
        }
    }

    /// Set initial speed ratio hint.
    ///
    /// `slow_rank`: which device is slowest.
    /// `ratio`: how many times faster the fastest device is (e.g., 2.3).
    ///
    /// After the first AllReduce, El Che discovers actual speeds and
    /// self-corrects even a wrong guess.
    pub fn speed_hint(mut self, slow_rank: usize, ratio: f64) -> Self {
        self.speed_hint = Some((slow_rank, ratio));
        self
    }

    /// Set AllReduce overhead target (fraction of compute time).
    ///
    /// Default: 0.10 (10%). Lower values = fewer AllReduces = more
    /// gradient accumulation. El Che auto-tunes the anchor to stay
    /// below this target.
    pub fn overhead_target(mut self, target: f64) -> Self {
        self.overhead_target = Some(target.clamp(0.01, 0.50));
        self
    }

    /// Set max batches on slow device before AllReduce.
    ///
    /// - `None` (default): El Che auto-tunes from overhead measurement.
    /// - `Some(0)`: disable El Che entirely (traditional per-batch sync).
    /// - `Some(n)`: fixed anchor at n (fast device gets proportionally more).
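    ///
    /// Illustrative configurations (sketch):
    ///
    /// ```ignore
    /// let auto  = DdpConfig::new();                      // El Che picks the anchor
    /// let fixed = DdpConfig::new().max_anchor(Some(4));  // sync after 4 slow-device batches
    /// let off   = DdpConfig::new().max_anchor(Some(0));  // traditional per-batch DDP
    /// ```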
    pub fn max_anchor(mut self, max: Option<usize>) -> Self {
        self.max_anchor = max;
        self
    }

    /// Set maximum gradient norm for per-rank clipping in El Che mode.
    ///
    /// When set, each rank's accumulated gradients are clipped to this L2
    /// norm before normalize-by-count and AllReduce. Essential for
    /// heterogeneous DDP where replica gradients are otherwise unreachable
    /// by the caller.
    pub fn max_grad_norm(mut self, max_norm: f64) -> Self {
        self.max_grad_norm = Some(max_norm);
        self
    }

    /// Attach a system timeline for high-frequency profiling.
    pub fn timeline(mut self, tl: std::sync::Arc<crate::monitor::Timeline>) -> Self {
        self.timeline = Some(tl);
        self
    }
}

impl Default for DdpConfig {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
#[path = "ddp_tests.rs"]
mod tests;