flodl/distributed/ddp.rs

//! Training entry points for flodl.
//!
//! The primary entry point is [`Trainer`]. It works transparently on 1 or
//! N GPUs - single-device training has zero DDP overhead. Reach for
//! [`Trainer`] by default; drop to [`Ddp`] only when you need explicit
//! multi-GPU control.
//!
//! **Default** ([`Trainer::setup()`], [`Trainer::builder()`]): user-owned or
//! framework-owned training loop, transparent single/multi-GPU. Same API in
//! both cases.
//!
//! **Explicit multi-GPU** ([`Ddp::wrap()`]): manual control over gradient
//! sync and parameter broadcast for advanced patterns (GAN, RL, progressive).
//!
//! # Setup mode (user owns the loop)
//!
//! ```ignore
//! Trainer::setup(&model, |dev| build_model(dev), |p| Adam::new(p, 0.001))?;
//!
//! // Training loop is identical for 1 or N GPUs:
//! for (x, y) in &train_loader {
//!     let out = model.forward(&x)?;
//!     let loss = cross_entropy_loss(&out, &y)?;
//!     loss.backward()?;
//!     model.step()?;
//! }
//! ```
//!
//! # Builder mode (framework owns the loop)
//!
//! ```ignore
//! let handle = Trainer::builder(model_factory, optim_factory, train_fn)
//!     .dataset(dataset)
//!     .batch_size(32)
//!     .num_epochs(10)
//!     .run()?;
//!
//! let state = handle.join()?;
//! ```
//!
//! # Manual DDP
//!
//! ```ignore
//! let ddp = Ddp::wrap(&[&model0, &model1], &devices)?;
//! ddp.sync_params()?;
//! // ... custom forward/backward ...
//! ddp.all_reduce_gradients()?;
//! ```

use crate::autograd::Variable;
use crate::graph::Graph;
use crate::nn::{Buffer, Module, Optimizer, Parameter};
use super::cuda_event::CudaEvent;
use super::nccl::{NcclComms, ReduceOp};
use super::ddp_run::{DdpBuilder, DdpHandle};
pub use super::el_che::ElChe;
use crate::tensor::{Device, Result, Tensor, TensorError};


/// Shared lock for serializing NCCL communicator creation across test modules.
/// NCCL init is a collective operation that deadlocks if two tests try to
/// create communicators simultaneously.
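/// Tests that create communicators typically hold this lock for the duration
/// of NCCL setup, e.g. (illustrative): `let _guard = NCCL_LOCK.lock().unwrap();`.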
#[cfg(test)]
pub(crate) static NCCL_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Default number of steps before the first rebalance.
pub(crate) const DEFAULT_CALIBRATION_STEPS: usize = 10;

/// How often to re-evaluate chunk ratios after calibration.
pub(crate) const DEFAULT_REBALANCE_INTERVAL: usize = 50;

/// EMA smoothing factor for throughput tracking (higher = more reactive).
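/// Update rule (see `update_balance`): `ema = EMA_ALPHA * sample + (1.0 - EMA_ALPHA) * ema`.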
const EMA_ALPHA: f64 = 0.3;

/// Minimum ratio any device can receive (prevents starving a GPU entirely).
const MIN_CHUNK_RATIO: f64 = 0.05;

// ---------------------------------------------------------------------------
// Internal distributed state (held by Graph)
// ---------------------------------------------------------------------------

/// Internal distributed state held by Graph when `distribute()` is called.
pub(crate) struct DistributedState {
    /// Model replicas for ranks 1..N (rank 0 is the Graph itself).
    pub replicas: Vec<Box<dyn Module>>,
    /// NCCL communicators (one per device).
    pub comms: NcclComms,
    /// All devices including rank 0.
    pub devices: Vec<Device>,
    /// Per-replica optimizers indexed by rank (including rank 0).
    pub optimizers: Vec<Box<dyn Optimizer>>,
    /// Chunk ratios for auto-balancing (sum = 1.0). Default: equal.
    pub chunk_ratios: Vec<f64>,
    /// Parameters matched across replicas: param_groups\[param_idx\]\[rank\].
    pub param_groups: Vec<Vec<Variable>>,
    /// Buffers matched across replicas: buffer_groups\[buf_idx\]\[rank\].
    pub buffer_groups: Vec<Vec<Buffer>>,

    // -- Auto-balancer state --

    /// Per-rank forward timing events from last forward pass: (start, end).
    /// Set by forward_distributed(), read by step().
    pub last_timing: Option<Vec<(CudaEvent, CudaEvent)>>,
    /// Shard sizes from last forward pass (for throughput calculation).
    pub last_shard_sizes: Vec<i64>,
    /// EMA throughput per rank (samples/ms). Zero until first measurement.
    pub ema_throughput: Vec<f64>,
    /// Number of completed training steps.
    pub step_count: usize,
    /// Steps of equal-split calibration before first rebalance.
    pub calibration_steps: usize,
    /// Steps between ratio recalculations after calibration.
    pub rebalance_interval: usize,

    // -- El Che cadence (heterogeneous DDP) --

    /// El Che cadence strategy. When Some, Graph uses per-device multi-batch
    /// forward instead of per-batch scatter. When None, existing scatter path.
    pub el_che: Option<ElChe>,
    /// Per-rank batch counts from the last El Che forward pass.
    /// Set by forward_distributed_el_che(), read by step().
    pub last_el_che_counts: Vec<usize>,
    /// Wall-clock time at end of last El Che AllReduce.
    pub last_el_che_sync: Option<std::time::Instant>,
    /// Maximum gradient norm for per-rank clipping in El Che mode.
    pub max_grad_norm: Option<f64>,
    /// Optional system timeline for high-frequency profiling.
    pub timeline: Option<std::sync::Arc<crate::monitor::Timeline>>,
}

impl DistributedState {
    /// AllReduce-average gradients across all replicas.
    pub fn all_reduce_gradients(&self) -> Result<()> {
        for group in &self.param_groups {
            // Skip frozen parameters (no gradient on rank 0)
            if group[0].grad().is_none() {
                continue;
            }
            let grads: Vec<Tensor> = group
                .iter()
                .map(|v| v.grad().expect("gradient missing on replica"))
                .collect();
            let refs: Vec<&Tensor> = grads.iter().collect();
            self.comms.all_reduce(&refs, ReduceOp::Avg)?;
        }
        Ok(())
    }

    /// Broadcast buffers from rank 0 to all replicas (BatchNorm stats etc).
    pub fn sync_buffers(&self) -> Result<()> {
        for group in &self.buffer_groups {
            let tensors: Vec<Tensor> = group.iter().map(|b| b.get()).collect();
            let refs: Vec<&Tensor> = tensors.iter().collect();
            self.comms.broadcast(&refs, 0)?;
        }
        Ok(())
    }

    /// Broadcast parameters and buffers from rank 0 to all replicas.
    pub fn sync_params(&self) -> Result<()> {
        for group in &self.param_groups {
            let tensors: Vec<Tensor> = group.iter().map(|v| v.data()).collect();
            let refs: Vec<&Tensor> = tensors.iter().collect();
            self.comms.broadcast(&refs, 0)?;
        }
        self.sync_buffers()
    }

    /// Compute shard sizes from chunk ratios, guaranteeing they sum to batch_size.
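    ///
    /// For example (illustrative values), `batch_size = 100` with
    /// `chunk_ratios = [0.7, 0.3]` yields shard sizes `[70, 30]`; the last
    /// device always absorbs any rounding remainder.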
    pub fn compute_shard_sizes(&self, batch_size: i64) -> Vec<i64> {
        let n = self.devices.len();
        let mut sizes = Vec::with_capacity(n);
        let mut remaining = batch_size;

        for i in 0..n {
            if i == n - 1 {
                // Last device gets whatever is left
                sizes.push(remaining);
            } else {
                let s = (batch_size as f64 * self.chunk_ratios[i]).round() as i64;
                let s = s.max(1).min(remaining - (n - i - 1) as i64); // leave at least 1 per remaining device
                sizes.push(s);
                remaining -= s;
            }
        }

        sizes
    }

    /// Number of devices.
    pub fn world_size(&self) -> usize {
        self.devices.len()
    }

    /// Whether all chunk ratios are effectively equal. When this returns
    /// false, the split is uneven and gradients need weighted averaging.
    pub fn is_balanced(&self) -> bool {
        let first = self.chunk_ratios[0];
        self.chunk_ratios.iter().all(|r| (r - first).abs() < 1e-6)
    }

    /// AllReduce gradients with weighted averaging for unequal shard sizes.
    ///
    /// Each replica's gradient is scaled by `(shard_size / batch_size)` before
    /// AllReduce Sum, which produces the correct mean gradient regardless of
    /// how the batch was split.
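    ///
    /// In symbols, with shard sizes `s_r` and `B = s_0 + ... + s_{N-1}` (each
    /// rank's gradient being the mean over its own shard), the reduced gradient
    /// is `sum_r (s_r / B) * g_r`, i.e. the mean gradient over the full batch.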
    pub fn weighted_all_reduce_gradients(&self, batch_size: i64) -> Result<()> {
        for group in &self.param_groups {
            if group[0].grad().is_none() {
                continue;
            }
            let mut grads: Vec<Tensor> = Vec::with_capacity(group.len());
            for (rank, v) in group.iter().enumerate() {
                let g = v.grad().expect("gradient missing on replica");
                let weight = self.last_shard_sizes[rank] as f64 / batch_size as f64;
                // Propagate scaling failures instead of silently summing an
                // unscaled gradient.
                g.mul_scalar_(weight)?;
                grads.push(g);
            }
            let refs: Vec<&Tensor> = grads.iter().collect();
            self.comms.all_reduce(&refs, ReduceOp::Sum)?;
        }
        Ok(())
    }

    /// Read timing from last forward pass, update EMA throughput, and
    /// rebalance chunk ratios if it's time.
    ///
    /// Called from Graph::step() after gradient sync. Returns true if
    /// chunk ratios were updated this step.
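    ///
    /// With the defaults (`DEFAULT_CALIBRATION_STEPS` = 10,
    /// `DEFAULT_REBALANCE_INTERVAL` = 50) ratios are recomputed at steps
    /// 10, 60, 110, and so on.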
    pub fn update_balance(&mut self) -> Result<bool> {
        self.step_count += 1;

        // Read timing events (set by forward_distributed)
        if let Some(timing) = self.last_timing.take() {
            for (rank, (start, end)) in timing.iter().enumerate() {
                let ms = CudaEvent::elapsed_time(start, end)?;
                if ms > 0.0 && self.last_shard_sizes[rank] > 0 {
                    let throughput = self.last_shard_sizes[rank] as f64 / ms as f64;
                    if self.ema_throughput[rank] == 0.0 {
                        // First measurement: initialize directly
                        self.ema_throughput[rank] = throughput;
                    } else {
                        self.ema_throughput[rank] =
                            EMA_ALPHA * throughput + (1.0 - EMA_ALPHA) * self.ema_throughput[rank];
                    }
                }
            }
        }

        // Check if it's time to rebalance
        let should_rebalance = if self.step_count == self.calibration_steps {
            true
        } else if self.step_count > self.calibration_steps {
            (self.step_count - self.calibration_steps) % self.rebalance_interval == 0
        } else {
            false
        };

        if should_rebalance {
            self.rebalance();
            return Ok(true);
        }

        Ok(false)
    }

    /// Recompute chunk_ratios proportional to EMA throughput.
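    ///
    /// For example (illustrative numbers), EMA throughputs of `[300.0, 100.0]`
    /// samples/ms give raw ratios `[0.75, 0.25]`; the `MIN_CHUNK_RATIO` clamp
    /// does not bind here, so those become the final chunk ratios.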
    fn rebalance(&mut self) {
        let total: f64 = self.ema_throughput.iter().sum();
        if total <= 0.0 {
            return; // no data yet
        }

        let n = self.devices.len();
        let min_total = MIN_CHUNK_RATIO * n as f64;

        // Compute raw proportional ratios
        let mut ratios: Vec<f64> = self.ema_throughput.iter().map(|t| t / total).collect();

        // Clamp: no device below MIN_CHUNK_RATIO
        let mut deficit = 0.0;
        let mut unclamped = 0;
        for r in &mut ratios {
            if *r < MIN_CHUNK_RATIO {
                deficit += MIN_CHUNK_RATIO - *r;
                *r = MIN_CHUNK_RATIO;
            } else {
                unclamped += 1;
            }
        }

        // Redistribute deficit from unclamped devices proportionally
        if deficit > 0.0 && unclamped > 0 {
            let unclamped_total: f64 = ratios
                .iter()
                .filter(|&&r| r > MIN_CHUNK_RATIO + 1e-9)
                .sum();
            if unclamped_total > min_total {
                for r in &mut ratios {
                    if *r > MIN_CHUNK_RATIO + 1e-9 {
                        *r -= deficit * (*r / unclamped_total);
                        *r = r.max(MIN_CHUNK_RATIO);
                    }
                }
            }
        }

        // Normalize to sum exactly to 1.0
        let sum: f64 = ratios.iter().sum();
        if sum > 0.0 {
            for r in &mut ratios {
                *r /= sum;
            }
        }

        self.chunk_ratios = ratios;
    }

    /// Configure El Che cadence from a [`DdpConfig`].
    ///
    /// Creates an internal ElChe when enabled (max_anchor != Some(0)),
    /// seeds chunk_ratios from speed_hint if provided.
    pub(crate) fn configure_el_che(&mut self, config: &DdpConfig) {
        let n = self.devices.len();
        if n < 2 {
            return;
        }

        // max_anchor = Some(0) → disabled (traditional DDP)
        if config.max_anchor == Some(0) {
            self.el_che = None;
            return;
        }

        // Build ElChe with sensible defaults
        let anchor = 10; // initial anchor, auto-tunes from timing
        let mut el_che = ElChe::new(n, anchor);

        if let Some(target) = config.overhead_target {
            el_che = el_che.with_overhead_target(target);
        }
        if let Some(max) = config.max_anchor {
            el_che = el_che.with_max_anchor(max);
        }
        if let Some((slow_rank, ratio)) = config.speed_hint {
            el_che = el_che.with_speed_ratio(slow_rank, ratio);
            // Also seed chunk_ratios for the existing auto-balancer
            self.apply_speed_hint(slow_rank, ratio);
        }

        self.el_che = Some(el_che);
        self.max_grad_norm = config.max_grad_norm;
    }

    /// Seed chunk_ratios from a speed hint.
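    ///
    /// For example, with 3 devices, `slow_rank = 1` and `ratio = 2.0`, the
    /// weights `[2.0, 1.0, 2.0]` normalize to chunk ratios `[0.4, 0.2, 0.4]`.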
    fn apply_speed_hint(&mut self, slow_rank: usize, ratio: f64) {
        let n = self.devices.len();
        if slow_rank >= n {
            return;
        }
        let ratio = ratio.max(1.0);
        let mut weights = vec![ratio; n];
        weights[slow_rank] = 1.0;
        let total: f64 = weights.iter().sum();
        self.chunk_ratios = weights.iter().map(|w| w / total).collect();
    }
}

// ---------------------------------------------------------------------------
// Manual DDP coordinator
// ---------------------------------------------------------------------------

/// Manual DDP coordinator for multi-GPU gradient sync.
///
/// For complex training patterns (GAN, RL, progressive) where transparent
/// Graph-level DDP doesn't fit. Provides explicit control over parameter
/// broadcast and gradient averaging.
///
/// For standard training, use [`crate::graph::Graph::distribute`] instead.
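///
/// A full manual step looks like the sketch below (illustrative: `model0`/`model1`
/// are replicas already living on `devices[0]`/`devices[1]`, `opt0`/`opt1` are
/// their per-replica optimizers, and `x0`/`y0`, `x1`/`y1` are the per-device
/// shards):
///
/// ```ignore
/// let ddp = Ddp::wrap(&[&model0, &model1], &devices)?;
/// ddp.sync_params()?;                      // start from identical weights
///
/// // Each replica runs forward/backward on its own shard.
/// let loss0 = cross_entropy_loss(&model0.forward(&x0)?, &y0)?;
/// loss0.backward()?;
/// let loss1 = cross_entropy_loss(&model1.forward(&x1)?, &y1)?;
/// loss1.backward()?;
///
/// ddp.all_reduce_gradients()?;             // average gradients across replicas
/// opt0.step()?;                            // identical updates keep replicas in sync
/// opt1.step()?;
/// ddp.sync_buffers()?;                     // optional: align BatchNorm running stats
/// ```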
pub struct Ddp {
    comms: NcclComms,
    devices: Vec<Device>,
    param_groups: Vec<Vec<Variable>>,
    buffer_groups: Vec<Vec<Buffer>>,
}

impl Ddp {
    /// Wrap pre-created model replicas for manual DDP control.
    ///
    /// Models must have identical architecture (same parameter count/shapes).
    /// Each model should already reside on its target device.
    pub fn wrap(models: &[&dyn Module], devices: &[Device]) -> Result<Self> {
        if models.len() < 2 {
            return Err(TensorError::new("Ddp::wrap requires at least 2 models"));
        }
        if models.len() != devices.len() {
            return Err(TensorError::new(
                "Ddp::wrap: model count must match device count",
            ));
        }

        let comms = NcclComms::new(devices)?;

        // Match parameters across models
        let all_params: Vec<Vec<Parameter>> =
            models.iter().map(|m| m.parameters()).collect();
        let n_params = all_params[0].len();
        for (rank, params) in all_params.iter().enumerate().skip(1) {
            if params.len() != n_params {
                return Err(TensorError::new(&format!(
                    "Ddp: replica {} has {} parameters, expected {}",
                    rank,
                    params.len(),
                    n_params
                )));
            }
        }

        let mut param_groups = Vec::with_capacity(n_params);
        for pi in 0..n_params {
            let group: Vec<Variable> =
                all_params.iter().map(|p| p[pi].variable.clone()).collect();
            param_groups.push(group);
        }

        // Match buffers
        let all_buffers: Vec<Vec<Buffer>> =
            models.iter().map(|m| m.buffers()).collect();
        let n_buffers = all_buffers[0].len();
        let mut buffer_groups = Vec::with_capacity(n_buffers);
        for bi in 0..n_buffers {
            let group: Vec<Buffer> =
                all_buffers.iter().map(|b| b[bi].clone()).collect();
            buffer_groups.push(group);
        }

        Ok(Ddp {
            comms,
            devices: devices.to_vec(),
            param_groups,
            buffer_groups,
        })
    }

    /// Broadcast all parameters and buffers from rank 0 to all replicas.
    pub fn sync_params(&self) -> Result<()> {
        for group in &self.param_groups {
            let tensors: Vec<Tensor> = group.iter().map(|v| v.data()).collect();
            let refs: Vec<&Tensor> = tensors.iter().collect();
            self.comms.broadcast(&refs, 0)?;
        }
        for group in &self.buffer_groups {
            let tensors: Vec<Tensor> = group.iter().map(|b| b.get()).collect();
            let refs: Vec<&Tensor> = tensors.iter().collect();
            self.comms.broadcast(&refs, 0)?;
        }
        Ok(())
    }

    /// AllReduce-average gradients across all replicas.
    /// Call after backward(), before optimizer.step().
    pub fn all_reduce_gradients(&self) -> Result<()> {
        for group in &self.param_groups {
            if group[0].grad().is_none() {
                continue;
            }
            let grads: Vec<Tensor> = group
                .iter()
                .map(|v| v.grad().expect("gradient missing on replica"))
                .collect();
            let refs: Vec<&Tensor> = grads.iter().collect();
            self.comms.all_reduce(&refs, ReduceOp::Avg)?;
        }
        Ok(())
    }

    /// Broadcast buffers from rank 0 (BatchNorm running stats etc).
    pub fn sync_buffers(&self) -> Result<()> {
        for group in &self.buffer_groups {
            let tensors: Vec<Tensor> = group.iter().map(|b| b.get()).collect();
            let refs: Vec<&Tensor> = tensors.iter().collect();
            self.comms.broadcast(&refs, 0)?;
        }
        Ok(())
    }

    /// AllReduce gradients weighted by per-device batch contribution.
    ///
    /// For heterogeneous DDP where devices process different numbers of
    /// batches per sync step. Each replica's gradient is scaled by
    /// `(batch_counts[rank] / total)` before AllReduce Sum, producing
    /// the correct mean gradient.
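    ///
    /// In symbols (assuming each rank's accumulated gradient is already averaged
    /// over its own batches): the reduced gradient is `sum_r (c_r / total) * g_r`
    /// with `c_r = batch_counts[r]`.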
    ///
    /// Use with [`ElChe::batch_counts`] for automatic weighting
    /// (see [`ElChe`] for the full heterogeneous DDP strategy):
    ///
    /// ```ignore
    /// ddp.weighted_all_reduce_gradients(cadence.batch_counts())?;
    /// ```
    pub fn weighted_all_reduce_gradients(&self, batch_counts: &[usize]) -> Result<()> {
        if batch_counts.len() != self.devices.len() {
            return Err(TensorError::new(&format!(
                "weighted_all_reduce: batch_counts len ({}) != device count ({})",
                batch_counts.len(),
                self.devices.len(),
            )));
        }
        let total: usize = batch_counts.iter().sum();
        if total == 0 {
            return Err(TensorError::new("weighted_all_reduce: total batch count is 0"));
        }
        for group in &self.param_groups {
            if group[0].grad().is_none() {
                continue;
            }
            let mut grads: Vec<Tensor> = Vec::with_capacity(group.len());
            for (rank, v) in group.iter().enumerate() {
                let g = v.grad().expect("gradient missing on replica");
                let weight = batch_counts[rank] as f64 / total as f64;
                // Propagate scaling failures instead of silently summing an
                // unscaled gradient.
                g.mul_scalar_(weight)?;
                grads.push(g);
            }
            let refs: Vec<&Tensor> = grads.iter().collect();
            self.comms.all_reduce(&refs, ReduceOp::Sum)?;
        }
        Ok(())
    }

    /// Number of devices.
    pub fn world_size(&self) -> usize {
        self.devices.len()
    }

    /// Devices in use.
    pub fn devices(&self) -> &[Device] {
        &self.devices
    }

    // --- Deprecated aliases: Trainer is now the primary entry point ---

    /// Deprecated: use [`Trainer::setup()`] instead.
    ///
    /// [`Trainer`] is now the primary training entry point and carries the
    /// same behavior for 1 or N GPUs. [`Ddp`] remains for explicit
    /// multi-GPU control via [`Ddp::wrap`].
    #[deprecated(note = "use Trainer::setup() - same behavior. Ddp::setup will be removed in a future release.")]
    pub fn setup<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Trainer::setup(model, builder, optimizer)
    }

    /// Deprecated: use [`Trainer::setup_with()`] instead.
    #[deprecated(note = "use Trainer::setup_with() - same behavior. Ddp::setup_with will be removed in a future release.")]
    pub fn setup_with<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
        config: DdpConfig,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Trainer::setup_with(model, builder, optimizer, config)
    }

    /// Deprecated: renamed to [`Trainer::setup()`].
    #[deprecated(since = "0.3.0", note = "Renamed to Trainer::setup()")]
    pub fn auto<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Trainer::setup(model, builder, optimizer)
    }

    /// Deprecated: renamed to [`Trainer::setup_with()`].
    #[deprecated(since = "0.3.0", note = "Renamed to Trainer::setup_with()")]
    pub fn auto_with<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
        config: DdpConfig,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Trainer::setup_with(model, builder, optimizer, config)
    }

    // -------------------------------------------------------------------
    // Deprecated builder entry: use Trainer::builder instead
    // -------------------------------------------------------------------

    /// Deprecated: use [`Trainer::builder()`] instead.
    ///
    /// [`Trainer`] is the primary training entry point and works
    /// transparently for single-GPU and multi-GPU. This alias is retained
    /// for backwards compatibility and will be removed in a future release.
    #[deprecated(note = "use Trainer::builder() - same behavior. Ddp::builder will be removed in a future release.")]
    pub fn builder<F, M, G, O, T>(
        model_factory: F,
        optim_factory: G,
        train_fn: T,
    ) -> DdpBuilder<F, M, G, O, T>
    where
        F: Fn(Device) -> Result<M> + Send + Sync + 'static,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
        O: Optimizer + 'static,
        T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
    {
        Trainer::builder(model_factory, optim_factory, train_fn)
    }

    /// Detect whether the current CUDA setup has different GPU models.
    fn is_heterogeneous() -> bool {
        use crate::tensor::{cuda_available, cuda_device_count, cuda_device_name_idx};
        if !cuda_available() || cuda_device_count() < 2 {
            return false;
        }
        let n = cuda_device_count();
        let names: Vec<Option<String>> = (0..n)
            .map(cuda_device_name_idx)
            .collect();
        names.windows(2).any(|w| w[0] != w[1])
    }

    /// Print a diagnostic summary of detected CUDA devices to stderr.
    fn print_device_summary() {
        use crate::tensor::{
            cuda_available, cuda_device_count,
            cuda_device_name_idx, cuda_memory_info_idx,
        };
        use crate::monitor::format_bytes;

        if !cuda_available() || cuda_device_count() == 0 {
            crate::verbose!("  ddp: no CUDA available | CPU mode");
            return;
        }

        let n = cuda_device_count();
        let mut names = Vec::with_capacity(n as usize);
        let mut parts = Vec::with_capacity(n as usize);

        for i in 0..n {
            let raw_name = cuda_device_name_idx(i)
                .unwrap_or_else(|| format!("CUDA({})", i));
            let short = raw_name
                .strip_prefix("NVIDIA ")
                .unwrap_or(&raw_name)
                .to_string();
            let vram = cuda_memory_info_idx(i)
                .map(|(_, total)| format!(" ({})", format_bytes(total)))
                .unwrap_or_default();
            parts.push(format!("{}{}", short, vram));
            names.push(raw_name);
        }

        let heterogeneous = names.windows(2).any(|w| w[0] != w[1]);

        if n == 1 {
            crate::verbose!("  ddp: 1 GPU | {} | single-device mode", parts[0]);
        } else if heterogeneous {
            crate::verbose!(
                "  ddp: {} GPUs (heterogeneous) | {}",
                n,
                parts.join(" | "),
            );
        } else {
            crate::verbose!("  ddp: {} GPUs | {}", n, parts.join(" | "));
        }
    }
}

// ---------------------------------------------------------------------------
// Trainer: primary training entry point
// ---------------------------------------------------------------------------

/// Primary entry point for training in flodl.
///
/// `Trainer` is the default API for training a model, whether you have one
/// GPU, many GPUs, or no GPU at all. The training loop is identical in all
/// cases: [`Trainer::setup`] (or [`Trainer::builder`]) configures the model,
/// detects the hardware, and enables distributed training automatically when
/// multiple CUDA devices are available. On a single GPU or CPU it's a no-op
/// wrapper with zero DDP overhead.
///
/// For explicit multi-GPU control (manual gradient sync, custom replica
/// wrapping) use [`Ddp`] directly. [`Ddp::wrap`] remains the entry point for
/// advanced patterns (GAN, RL, progressive).
///
/// # Setup mode (user owns the loop)
///
/// ```ignore
/// Trainer::setup(&model, |dev| build_model(dev), |p| Adam::new(p, 0.001))?;
///
/// for (x, y) in &train_loader {
///     let out = model.forward(&x)?;
///     let loss = cross_entropy_loss(&out, &y)?;
///     loss.backward()?;
///     model.step()?;
/// }
/// ```
///
/// # Builder mode (framework owns the loop)
///
/// ```ignore
/// let handle = Trainer::builder(model_factory, optim_factory, train_fn)
///     .dataset(dataset)
///     .batch_size(32)
///     .num_epochs(10)
///     .run()?;
///
/// let state = handle.join()?;
/// ```
pub struct Trainer;

impl Trainer {
    /// One-call setup: auto-detect GPUs, distribute the model, set the
    /// optimizer, and enable training mode.
    ///
    /// - **Multi-GPU** (2+ usable CUDA devices): replicates via
    ///   [`Graph::distribute`], creates per-replica optimizers, enables training.
    /// - **Single-GPU / CPU**: sets optimizer and training mode only (no DDP
    ///   overhead).
    ///
    /// Always prints a diagnostic summary to stderr showing detected hardware.
    ///
    /// ```ignore
    /// Trainer::setup(&model, |dev| build_model(dev), |p| Adam::new(p, 0.001))?;
    ///
    /// for batch in model.epoch(epoch).activate() {
    ///     let out = model.forward_batch(&batch?)?;
    ///     // ... compute `loss` from `out` ...
    ///     loss.backward()?;
    ///     model.step()?;
    /// }
    /// ```
    pub fn setup<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Ddp::print_device_summary();
        model.distribute(builder)?;
        model.set_optimizer(optimizer);
        model.set_training(true);

        // Auto-enable El Che for heterogeneous GPU setups
        if Ddp::is_heterogeneous() {
            model.configure_el_che(&DdpConfig::new());
        }

        Ok(())
    }

    /// One-call setup with explicit configuration.
    ///
    /// Like [`setup()`](Self::setup) but accepts a [`DdpConfig`] for
    /// controlling El Che cadence, speed hints, and overhead targets.
    ///
    /// ```ignore
    /// Trainer::setup_with(&model, builder, optimizer,
    ///     DdpConfig::new().speed_hint(1, 2.3))?;
    /// ```
    pub fn setup_with<F, M, G, O>(
        model: &Graph,
        builder: F,
        optimizer: G,
        config: DdpConfig,
    ) -> Result<()>
    where
        F: Fn(Device) -> Result<M>,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Ddp::print_device_summary();
        model.distribute(builder)?;
        model.set_optimizer(optimizer);
        model.set_training(true);
        model.configure_el_che(&config);
        // Pass timeline to distributed state for event injection in step().
        if let Some(tl) = config.timeline {
            if let Some(ref mut state) = *model.distributed.borrow_mut() {
                state.timeline = Some(tl);
            }
        }
        Ok(())
    }

    /// Create a builder for framework-managed training.
    ///
    /// The framework owns the training loop, data pipeline, and epoch
    /// management. On multi-GPU hardware, each device gets its own model
    /// replica and optimizer, and a coordinator triggers periodic
    /// parameter averaging based on the configured [`ApplyPolicy`] and
    /// [`AverageBackend`]. On a single GPU, training runs on the main
    /// thread with no coordination - the API is identical in both cases.
    ///
    /// Returns a [`DdpBuilder`] for fluent configuration. Call `.run()` to
    /// spawn training, then `.join()` on the returned [`DdpHandle`] to
    /// block until completion.
    ///
    /// [`ApplyPolicy`]: crate::distributed::ApplyPolicy
    /// [`AverageBackend`]: crate::distributed::AverageBackend
    ///
    /// # Example
    ///
    /// ```ignore
    /// use flodl::*;
    ///
    /// let handle = Trainer::builder(
    ///     |dev| model_factory(dev),
    ///     |params| Adam::new(params, 0.001),
    ///     |model, batch| { /* forward + loss */ },
    /// )
    /// .dataset(dataset)
    /// .batch_size(32)
    /// .num_epochs(10)
    /// .policy(ApplyPolicy::Cadence)
    /// .backend(AverageBackend::Nccl)
    /// .run()?;
    ///
    /// let state = handle.join()?;
    /// ```
    pub fn builder<F, M, G, O, T>(
        model_factory: F,
        optim_factory: G,
        train_fn: T,
    ) -> DdpBuilder<F, M, G, O, T>
    where
        F: Fn(Device) -> Result<M> + Send + Sync + 'static,
        M: Module + 'static,
        G: Fn(&[Parameter]) -> O + Send + Sync + 'static,
        O: Optimizer + 'static,
        T: Fn(&M, &[Tensor]) -> Result<Variable> + Send + Sync + 'static,
    {
        DdpHandle::new_builder(model_factory, optim_factory, train_fn)
    }

    /// One-call setup for a task-head wrapper (e.g. `flodl-hf`'s
    /// `BertForSequenceClassification`). The wrapper must implement
    /// [`HasGraph`] so `Trainer` can reach the underlying [`Graph`].
    ///
    /// Semantics match [`Trainer::setup`] exactly; the only difference is
    /// that `head_factory` builds a fresh wrapper (not a bare `Graph`) on
    /// each replica device. Useful when the training-loop code holds onto
    /// the wrapper's richer surface (`compute_loss`, `predict`, attached
    /// tokenizer) but still wants transparent 1-or-N-GPU DDP.
    ///
    /// ```ignore
    /// let head = DistilBertForSequenceClassification::from_pretrained(repo)?;
    /// let config = head.config().clone();
    /// let num_labels = head.labels().len() as i64;
    ///
    /// Trainer::setup_head(
    ///     &head,
    ///     move |dev| DistilBertForSequenceClassification::on_device(&config, num_labels, dev),
    ///     |p| Adam::new(p, 5e-5),
    /// )?;
    ///
    /// for (enc, labels) in &batches {
    ///     let loss = head.compute_loss(&enc, &labels)?;
    ///     loss.backward()?;
    ///     head.graph().step()?;
    /// }
    /// ```
    pub fn setup_head<H, F, G, O>(
        head: &H,
        head_factory: F,
        optimizer: G,
    ) -> Result<()>
    where
        H: HasGraph + 'static,
        F: Fn(Device) -> Result<H> + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Ddp::print_device_summary();
        let graph = head.graph();
        graph.distribute(move |dev| head_factory(dev).map(|h| HeadReplica { head: h }))?;
        graph.set_optimizer(optimizer);
        graph.set_training(true);

        if Ddp::is_heterogeneous() {
            graph.configure_el_che(&DdpConfig::new());
        }

        Ok(())
    }

    /// Task-head variant of [`Trainer::setup_with`]. Same behavior as
    /// [`Trainer::setup_head`] but takes an explicit [`DdpConfig`].
    pub fn setup_head_with<H, F, G, O>(
        head: &H,
        head_factory: F,
        optimizer: G,
        config: DdpConfig,
    ) -> Result<()>
    where
        H: HasGraph + 'static,
        F: Fn(Device) -> Result<H> + 'static,
        G: Fn(&[Parameter]) -> O,
        O: Optimizer + 'static,
    {
        Ddp::print_device_summary();
        let graph = head.graph();
        graph.distribute(move |dev| head_factory(dev).map(|h| HeadReplica { head: h }))?;
        graph.set_optimizer(optimizer);
        graph.set_training(true);
        graph.configure_el_che(&config);
        if let Some(tl) = config.timeline {
            if let Some(ref mut state) = *graph.distributed.borrow_mut() {
                state.timeline = Some(tl);
            }
        }
        Ok(())
    }
}

// ---------------------------------------------------------------------------
// HasGraph trait: lets wrapper types plug into Trainer::setup_head
// ---------------------------------------------------------------------------

/// A wrapper type that exposes an inner [`Graph`].
///
/// Implement on any wrapper around a `Graph` that should participate in
/// [`Trainer::setup_head`] or other graph-aware DDP machinery. The returned
/// reference must point at the same graph the wrapper uses for its own
/// forward / loss calls.
///
/// [`Graph`] implements this trivially (returns `self`) so bare-graph
/// callers can pass a `&Graph` wherever `&impl HasGraph` is accepted.
///
/// ```ignore
/// impl HasGraph for BertForSequenceClassification {
///     fn graph(&self) -> &Graph { &self.graph }
/// }
/// ```
pub trait HasGraph {
    /// Borrow the inner training graph.
    fn graph(&self) -> &Graph;
}

impl HasGraph for Graph {
    fn graph(&self) -> &Graph { self }
}

/// Internal Module adapter used by [`Trainer::setup_head`] to feed a
/// `HasGraph` replica through [`Graph::distribute`].
///
/// `distribute` boxes each replica as `Box<dyn Module>`. Task-head
/// wrappers don't implement `Module` directly (their true forward is
/// multi-input via [`Graph::forward_multi`], which doesn't fit the
/// single-Variable `Module::forward` signature). `HeadReplica` delegates
/// every Module method through to the inner graph and overrides
/// [`Module::as_graph`] so DDP's multi-input replica paths downcast
/// cleanly rather than hitting the single-input fallback.
struct HeadReplica<H: HasGraph + 'static> {
    head: H,
}

impl<H: HasGraph + 'static> Module for HeadReplica<H> {
    fn forward(&self, input: &Variable) -> Result<Variable> {
        // Single-input fallback. Task-head DDP paths reach
        // forward_multi via `as_graph()` below, so this is only
        // exercised on single-input replica paths (e.g. the scatter
        // forward in `forward_distributed_scatter`). For multi-input
        // heads that path is never triggered because the user calls
        // the head's own `compute_loss` / `forward_encoded`, which
        // route through `Graph::forward_multi` directly.
        self.head.graph().forward(input)
    }
    fn parameters(&self) -> Vec<Parameter> { self.head.graph().parameters() }
    fn buffers(&self) -> Vec<Buffer> { self.head.graph().buffers() }
    fn name(&self) -> &str { "head_replica" }
    fn set_training(&self, training: bool) { self.head.graph().set_training(training); }
    fn as_graph(&self) -> Option<&Graph> { Some(self.head.graph()) }
}

// ---------------------------------------------------------------------------
// DDP configuration
// ---------------------------------------------------------------------------

/// Configuration for [`Trainer::setup_with()`].
///
/// Controls El Che cadence behavior for heterogeneous multi-GPU training.
/// Use [`DdpConfig::new()`] for defaults or build with method chaining.
///
/// ```ignore
/// Trainer::setup_with(&model, builder, optimizer,
///     DdpConfig::new()
///         .speed_hint(1, 2.3)     // rank 1 is slow, 2.3x ratio
///         .overhead_target(0.08)  // tune to 8% overhead
/// )?;
/// ```
#[derive(Debug, Clone)]
pub struct DdpConfig {
    /// Initial speed ratio hint: (slow_rank, fast_to_slow_ratio).
    /// Applied before the first timing measurement.
    pub speed_hint: Option<(usize, f64)>,
    /// AllReduce overhead target for anchor auto-tune (default: 0.10).
    pub overhead_target: Option<f64>,
    /// Max batches on slow device before AllReduce.
    /// - `None` = auto (El Che decides, default).
    /// - `Some(0)` = disabled (traditional per-batch DDP, no El Che).
    /// - `Some(n)` = fixed anchor at n.
    pub max_anchor: Option<usize>,
    /// Maximum gradient norm for per-rank clipping in El Che mode.
    ///
    /// When set, each rank's accumulated gradients are clipped (L2 norm)
    /// before the normalize-by-count and weighted AllReduce steps. This
    /// ensures replica gradients (which the caller cannot reach) are bounded
    /// identically to rank 0.
    ///
    /// Standard DDP does not need this because the caller clips rank 0's
    /// gradients and AllReduce averages them.
    pub max_grad_norm: Option<f64>,
    /// Optional system timeline for high-frequency profiling.
    pub timeline: Option<std::sync::Arc<crate::monitor::Timeline>>,
}

impl DdpConfig {
    /// Default configuration: El Che auto-enabled for heterogeneous GPUs.
    pub fn new() -> Self {
        DdpConfig {
            speed_hint: None,
            overhead_target: None,
            max_anchor: None,
            max_grad_norm: None,
            timeline: None,
        }
    }

    /// Set initial speed ratio hint.
    ///
    /// `slow_rank`: which device is slowest.
    /// `ratio`: how many times faster the fastest device is (e.g., 2.3).
    ///
    /// After the first AllReduce, El Che discovers actual speeds and
    /// self-corrects even a wrong guess.
    pub fn speed_hint(mut self, slow_rank: usize, ratio: f64) -> Self {
        self.speed_hint = Some((slow_rank, ratio));
        self
    }

    /// Set AllReduce overhead target (fraction of compute time).
    ///
    /// Default: 0.10 (10%). Lower values = fewer AllReduces = more
    /// gradient accumulation. El Che auto-tunes the anchor to stay
    /// below this target.
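    ///
    /// For example (illustrative timings): a 15 ms AllReduce against 200 ms of
    /// compute between syncs is 7.5% overhead, within the default 0.10 target.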
    pub fn overhead_target(mut self, target: f64) -> Self {
        self.overhead_target = Some(target.clamp(0.01, 0.50));
        self
    }

    /// Set max batches on slow device before AllReduce.
    ///
    /// - `None` (default): El Che auto-tunes from overhead measurement.
    /// - `Some(0)`: disable El Che entirely (traditional per-batch sync).
    /// - `Some(n)`: fixed anchor at n (fast device gets proportionally more).
    pub fn max_anchor(mut self, max: Option<usize>) -> Self {
        self.max_anchor = max;
        self
    }

    /// Set maximum gradient norm for per-rank clipping in El Che mode.
    ///
    /// When set, each rank's accumulated gradients are clipped to this L2
    /// norm before normalize-by-count and AllReduce. Essential for
    /// heterogeneous DDP where replica gradients are otherwise unreachable
    /// by the caller.
    pub fn max_grad_norm(mut self, max_norm: f64) -> Self {
        self.max_grad_norm = Some(max_norm);
        self
    }

    /// Attach a system timeline for high-frequency profiling.
    pub fn timeline(mut self, tl: std::sync::Arc<crate::monitor::Timeline>) -> Self {
        self.timeline = Some(tl);
        self
    }
}

impl Default for DdpConfig {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
#[path = "ddp_tests.rs"]
mod tests;