Skip to main content

entrenar/train/transformer_trainer/
distributed_trainer.rs

1//! Distributed CUDA trainer for data-parallel pretraining.
2//!
3//! Wraps `CudaTransformerTrainer` with per-block gradient accumulation
4//! and AllReduce communication for multi-GPU / multi-node DDP.
5//!
6//! # Architecture
7//!
8//! ```text
9//! DistributedCudaTrainer
10//! ├── trainer: CudaTransformerTrainer     (local GPU training)
11//! ├── comm: DistributedComm               (local channels or TCP)
12//! └── dist_config: DistributedTrainConfig
13//! ```
14//!
15//! # Training Step (DDP)
16//!
17//! 1. Each worker runs forward+backward on its data shard (accumulate_only=true)
18//! 2. Pre-average local gradients (divide by accumulated_count)
19//! 3. Per-block AllReduce (reverse order):
20//!    - Send block gradients to coordinator
21//!    - Receive averaged gradients
22//!    - Overwrite local accum
23//! 4. AllReduce non-block: LM head, final norm, embedding
24//! 5. Upload averaged gradients to GPU, run optimizer step
25//!
26//! # Contract
27//!
28//! C-DDP-001: After AllReduce + optimizer step, all workers hold identical weights.
29
30#[cfg(feature = "cuda")]
31use std::sync::mpsc;
32
33#[cfg(feature = "cuda")]
34use super::config::DistributedTrainConfig;
35#[cfg(feature = "cuda")]
36use super::grad_accumulator::BlockGradientSet;
37
/// Communication backend for distributed training.
///
/// Selects how a worker exchanges gradients during AllReduce: in-process
/// channels for single-machine runs, or a TCP client for multi-node runs.
#[cfg(feature = "cuda")]
pub enum DistributedComm {
    /// Single-machine multi-GPU via `std::sync::mpsc` channels.
    ///
    /// Holds this worker's two channel endpoints: gradients are sent on
    /// `tx` and the averaged replies are read from `rx`. The coordinator
    /// side that averages and replies is not part of this type.
    Local {
        /// Send gradient to coordinator
        tx: mpsc::Sender<GradientMessage>,
        /// Receive averaged gradient from coordinator
        rx: mpsc::Receiver<GradientMessage>,
    },
    /// Multi-node via TCP using the existing WorkerClient/GradientServer.
    Remote {
        /// TCP client for gradient exchange
        client: crate::finetune::WorkerClient,
    },
}
58
/// Message types for local (channel-based) gradient exchange.
///
/// Workers send the non-averaged variants and expect the matching
/// `Averaged*` variant as the reply to each send.
///
/// The `component` code on the non-block variants, as used by the AllReduce
/// code in this file: `0` = LM head, `1` = final norm, `2` = embedding.
#[cfg(feature = "cuda")]
#[derive(Debug)]
pub enum GradientMessage {
    /// Per-block gradient from a worker.
    ///
    /// `gradients` is the flattened gradient of block `block_idx`;
    /// `component_sizes` are the sizes passed back to
    /// `BlockGradientSet::from_flat` to rebuild the block.
    BlockGradient { block_idx: usize, gradients: Vec<f32>, component_sizes: Vec<u32> },
    /// Averaged per-block gradient from coordinator (same layout as `BlockGradient`).
    AveragedBlockGradient { block_idx: usize, gradients: Vec<f32>, component_sizes: Vec<u32> },
    /// Non-block gradient (LM head, final norm, embedding)
    NonBlockGradient { component: u8, gradients: Vec<f32> },
    /// Averaged non-block gradient
    AveragedNonBlockGradient { component: u8, gradients: Vec<f32> },
    /// Synchronization barrier.
    ///
    /// NOTE(review): not sent or matched by the worker-side AllReduce in this
    /// file; presumably reserved for the coordinator. Confirm before relying on it.
    Barrier,
}
74
/// Distributed CUDA trainer for data-parallel pretraining.
///
/// Wraps a single-GPU `CudaTransformerTrainer` with communication
/// and gradient averaging logic. The actual CUDA operations remain
/// in the underlying trainer — this layer only handles:
///
/// 1. Running forward+backward in accumulate-only mode
/// 2. Pre-averaging local gradients
/// 3. AllReducing gradients across workers
/// 4. Applying averaged gradients via the underlying trainer
///
/// # Safety
///
/// C-STREAMSYNC-001 applies: stream.synchronize() before all D2H transfers
/// is handled by the underlying CudaTransformerTrainer.
#[cfg(feature = "cuda")]
pub struct DistributedCudaTrainer {
    /// Underlying single-GPU trainer; owns the gradient accumulator that
    /// the AllReduce phases read from and write back into.
    trainer: super::cuda_trainer::CudaTransformerTrainer,
    /// Communication backend (local channels or TCP client)
    comm: DistributedComm,
    /// Distributed configuration (source of `rank` and `world_size`)
    dist_config: DistributedTrainConfig,
    /// Current training step; starts at 0 and is incremented once per
    /// `train_batch` call.
    step: usize,
}
101
102#[cfg(feature = "cuda")]
103impl DistributedCudaTrainer {
104    /// Create a new distributed trainer.
105    ///
106    /// Ensures the underlying trainer has gradient accumulation buffers
107    /// (required for DDP even with accumulation_steps=1, since gradients
108    /// must be downloaded to CPU for AllReduce).
109    ///
110    /// # Arguments
111    /// * `trainer` - Pre-initialized single-GPU trainer (with ensure_grad_accum called)
112    /// * `comm` - Communication backend (local channels or TCP)
113    /// * `dist_config` - Distributed training configuration
114    pub fn new(
115        mut trainer: super::cuda_trainer::CudaTransformerTrainer,
116        comm: DistributedComm,
117        dist_config: DistributedTrainConfig,
118    ) -> Self {
119        // DDP always needs grad accum buffers for CPU-side AllReduce
120        trainer.ensure_grad_accum();
121
122        Self { trainer, comm, dist_config, step: 0 }
123    }
124
125    /// DDP training step: forward+backward → AllReduce → optimizer.
126    ///
127    /// 1. Local forward+backward (accumulate_only=true → grads to CPU accum)
128    /// 2. Pre-average local gradients
129    /// 3. AllReduce all gradient components across workers
130    /// 4. Apply averaged gradients (upload to GPU + optimizer step)
131    ///
132    /// Returns average loss for this worker's batch.
133    pub fn train_batch(&mut self, batch: &super::batch::LMBatch) -> f32 {
134        // 1. Local forward+backward (accumulate only)
135        let loss = self.trainer.forward_backward_batch(batch);
136
137        // 2-3. Pre-average and AllReduce
138        let step = self.step as u64;
139        Self::allreduce_impl(step, &self.comm, &mut self.trainer);
140
141        // 4. Apply averaged gradients
142        self.trainer.apply_ddp_gradients();
143
144        self.step += 1;
145        loss
146    }
147
148    /// Pre-average local gradients, then AllReduce across workers.
149    ///
150    /// Separated as a static method to satisfy the borrow checker:
151    /// `comm` and `trainer` are disjoint fields borrowed independently.
152    fn allreduce_impl(
153        step: u64,
154        comm: &DistributedComm,
155        trainer: &mut super::cuda_trainer::CudaTransformerTrainer,
156    ) {
157        // Phase 0: Pre-average local gradients before AllReduce.
158        // Each worker divides by its local accumulated_count so the coordinator
159        // averages per-sample means (not raw sums). This ensures C-DDP-001 even
160        // if workers process different numbers of valid sequences.
161        let local_count = {
162            let accum = trainer.grad_accum_mut().unwrap();
163            let count = accum.accumulated_count;
164            accum.average(); // divides block + non-block grads by count
165            count
166        };
167        // Average embedding grad separately (lives in CPU model, not in accum)
168        if local_count > 1 {
169            if let Some(mut eg) = trainer.embed_grad_vec() {
170                let inv = 1.0 / local_count as f32;
171                for g in &mut eg {
172                    *g *= inv;
173                }
174                trainer.set_embed_grad(eg);
175            }
176        }
177
178        // Phase 1-3: AllReduce via configured transport
179        match comm {
180            DistributedComm::Remote { client } => {
181                Self::allreduce_remote(step, client, trainer);
182            }
183            DistributedComm::Local { tx, rx } => {
184                Self::allreduce_local(step, tx, rx, trainer);
185            }
186        }
187    }
188
189    /// AllReduce via TCP (multi-process DDP).
190    fn allreduce_remote(
191        step: u64,
192        client: &crate::finetune::WorkerClient,
193        trainer: &mut super::cuda_trainer::CudaTransformerTrainer,
194    ) {
195        // Phase 1: Per-block AllReduce (reverse order matches backward pass)
196        {
197            let accum = trainer.grad_accum_mut().unwrap();
198            let num_blocks = accum.num_blocks();
199            for block_idx in (0..num_blocks).rev() {
200                let flat = accum.block_grads[block_idx].flatten();
201                let sizes = accum.block_grads[block_idx].component_sizes_u32();
202                client
203                    .send_block_gradient(step, block_idx as u32, num_blocks as u32, flat, sizes)
204                    .expect("block gradient send failed");
205                let avg = client.receive_averaged_block().expect("block gradient receive failed");
206                accum.block_grads[block_idx] =
207                    BlockGradientSet::from_flat(&avg.gradients, &avg.component_sizes);
208            }
209        }
210
211        // Phase 2: Non-block AllReduce (LM head + final norm)
212        {
213            let accum = trainer.grad_accum_mut().unwrap();
214
215            // LM head (component=0)
216            let lm_grad = accum.lm_head_grad.clone();
217            client.send_non_block_gradient(step, 0, lm_grad).expect("lm_head gradient send failed");
218            let avg = client.receive_averaged_non_block().expect("lm_head gradient receive failed");
219            accum.lm_head_grad = avg.gradients;
220
221            // Final norm (component=1)
222            let norm_grad = accum.final_norm_grad.clone();
223            client
224                .send_non_block_gradient(step, 1, norm_grad)
225                .expect("final_norm gradient send failed");
226            let avg =
227                client.receive_averaged_non_block().expect("final_norm gradient receive failed");
228            accum.final_norm_grad = avg.gradients;
229
230            // Prevent re-averaging in gpu_optimizer_from_accum
231            accum.accumulated_count = 1;
232        }
233
234        // Phase 3: Embedding AllReduce (CPU gradient, separate from accum)
235        {
236            let embed_grad = trainer.embed_grad_vec().unwrap_or_default();
237            client
238                .send_non_block_gradient(step, 2, embed_grad)
239                .expect("embedding gradient send failed");
240            let avg =
241                client.receive_averaged_non_block().expect("embedding gradient receive failed");
242            trainer.set_embed_grad(avg.gradients);
243        }
244    }
245
246    /// AllReduce via mpsc channels (single-machine multi-GPU).
247    fn allreduce_local(
248        step: u64,
249        tx: &mpsc::Sender<GradientMessage>,
250        rx: &mpsc::Receiver<GradientMessage>,
251        trainer: &mut super::cuda_trainer::CudaTransformerTrainer,
252    ) {
253        let _ = step; // used for logging in future
254
255        // Phase 1: Per-block AllReduce via channels
256        {
257            let accum = trainer.grad_accum_mut().unwrap();
258            let num_blocks = accum.num_blocks();
259            for block_idx in (0..num_blocks).rev() {
260                let flat = accum.block_grads[block_idx].flatten();
261                let sizes = accum.block_grads[block_idx].component_sizes_u32();
262                tx.send(GradientMessage::BlockGradient {
263                    block_idx,
264                    gradients: flat,
265                    component_sizes: sizes,
266                })
267                .expect("channel send failed");
268
269                match rx.recv().expect("channel recv failed") {
270                    GradientMessage::AveragedBlockGradient {
271                        gradients, component_sizes, ..
272                    } => {
273                        accum.block_grads[block_idx] =
274                            BlockGradientSet::from_flat(&gradients, &component_sizes);
275                    }
276                    other => panic!("expected AveragedBlockGradient, got {other:?}"),
277                }
278            }
279        }
280
281        // Phase 2: Non-block AllReduce via channels
282        {
283            let accum = trainer.grad_accum_mut().unwrap();
284
285            // LM head
286            let lm_grad = accum.lm_head_grad.clone();
287            tx.send(GradientMessage::NonBlockGradient { component: 0, gradients: lm_grad })
288                .expect("channel send failed");
289            match rx.recv().expect("channel recv failed") {
290                GradientMessage::AveragedNonBlockGradient { gradients, .. } => {
291                    accum.lm_head_grad = gradients;
292                }
293                other => panic!("expected AveragedNonBlockGradient, got {other:?}"),
294            }
295
296            // Final norm
297            let norm_grad = accum.final_norm_grad.clone();
298            tx.send(GradientMessage::NonBlockGradient { component: 1, gradients: norm_grad })
299                .expect("channel send failed");
300            match rx.recv().expect("channel recv failed") {
301                GradientMessage::AveragedNonBlockGradient { gradients, .. } => {
302                    accum.final_norm_grad = gradients;
303                }
304                other => panic!("expected AveragedNonBlockGradient, got {other:?}"),
305            }
306
307            accum.accumulated_count = 1;
308        }
309
310        // Phase 3: Embedding AllReduce
311        {
312            let embed_grad = trainer.embed_grad_vec().unwrap_or_default();
313            tx.send(GradientMessage::NonBlockGradient { component: 2, gradients: embed_grad })
314                .expect("channel send failed");
315            match rx.recv().expect("channel recv failed") {
316                GradientMessage::AveragedNonBlockGradient { gradients, .. } => {
317                    trainer.set_embed_grad(gradients);
318                }
319                other => panic!("expected AveragedNonBlockGradient, got {other:?}"),
320            }
321        }
322    }
323
    /// Borrow the distributed configuration this trainer was constructed with.
    pub fn dist_config(&self) -> &DistributedTrainConfig {
        &self.dist_config
    }
328
    /// Get the current step: the number of `train_batch` calls completed so far.
    pub fn step(&self) -> usize {
        self.step
    }
333
    /// Get a shared reference to the underlying single-GPU trainer.
    pub fn trainer(&self) -> &super::cuda_trainer::CudaTransformerTrainer {
        &self.trainer
    }
338
    /// Get a mutable reference to the underlying single-GPU trainer.
    ///
    /// The AllReduce path unwraps the trainer's gradient accumulator, so
    /// callers mutating the trainer should leave the accumulator in place.
    pub fn trainer_mut(&mut self) -> &mut super::cuda_trainer::CudaTransformerTrainer {
        &mut self.trainer
    }
343
    /// Check if this worker is the coordinator, i.e. its configured rank is 0.
    pub fn is_coordinator(&self) -> bool {
        self.dist_config.rank == 0
    }
348
    /// Get the world size from the distributed configuration.
    pub fn world_size(&self) -> usize {
        self.dist_config.world_size
    }
353
    /// Get this worker's rank from the distributed configuration.
    pub fn rank(&self) -> usize {
        self.dist_config.rank
    }
358
    /// Check if max_steps has been reached.
    ///
    /// Delegates to the underlying trainer; the local `step` counter of this
    /// wrapper is not consulted.
    pub fn reached_max_steps(&self) -> bool {
        self.trainer.reached_max_steps()
    }
363}
364
/// Create the two ends of a coordinator <-> worker link for single-machine training.
///
/// Two `mpsc` channels are created, one per direction. The return value is
/// `(coordinator_end, worker_end)`, each a `(Sender, Receiver)` tuple:
///
/// * `coordinator_end` sends to the worker and receives what the worker sent.
/// * `worker_end` sends to the coordinator and receives what the coordinator
///   sent; its two halves have the types of the `tx` / `rx` fields of
///   `DistributedComm::Local`.
///
/// One call yields the link for a single worker.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
pub fn create_local_comm_pair() -> (
    (mpsc::Sender<GradientMessage>, mpsc::Receiver<GradientMessage>),
    (mpsc::Sender<GradientMessage>, mpsc::Receiver<GradientMessage>),
) {
    // Worker -> coordinator direction.
    let (worker_tx, coordinator_rx) = mpsc::channel();
    // Coordinator -> worker direction.
    let (coordinator_tx, worker_rx) = mpsc::channel();

    let coordinator_end = (coordinator_tx, coordinator_rx);
    let worker_end = (worker_tx, worker_rx);
    (coordinator_end, worker_end)
}
379
/// Shard batch indices across workers by interleaving.
///
/// Worker `rank` gets batches `rank, rank + world_size, rank + 2 * world_size, ...`
/// below `num_batches`. For ranks `0..world_size` the shards are pairwise
/// disjoint and together cover `0..num_batches`.
///
/// # Arguments
/// * `num_batches` - Total number of batches in the dataset
/// * `rank` - Index of this worker; expected to be `< world_size` (not checked)
/// * `world_size` - Total number of workers
///
/// # Returns
///
/// The batch indices assigned to `rank`, in ascending order. The result is
/// empty when `rank >= num_batches`, and also when `world_size == 0`
/// (no workers, so no assignments) instead of panicking inside `step_by`.
pub fn shard_batches(num_batches: usize, rank: usize, world_size: usize) -> Vec<usize> {
    // `Iterator::step_by` panics on a step of 0, so handle that up front.
    if world_size == 0 {
        return Vec::new();
    }
    (rank..num_batches).step_by(world_size).collect()
}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: it passes as long as this module compiles, so the body is
    /// intentionally empty (a constant `assert!(true)` adds nothing and is
    /// rejected by `clippy::assertions_on_constants`).
    #[test]
    fn test_module_compiles() {}

    /// Two workers over an even number of batches: even/odd split.
    #[test]
    fn test_data_sharding_by_rank() {
        // 10 batches, 2 workers
        let shard0 = shard_batches(10, 0, 2);
        let shard1 = shard_batches(10, 1, 2);

        // Worker 0 gets even indices
        assert_eq!(shard0, vec![0, 2, 4, 6, 8]);
        // Worker 1 gets odd indices
        assert_eq!(shard1, vec![1, 3, 5, 7, 9]);

        // Disjoint
        assert!(shard0.iter().all(|idx| !shard1.contains(idx)));

        // Complete
        let mut all: Vec<usize> = shard0.iter().chain(shard1.iter()).copied().collect();
        all.sort_unstable();
        assert_eq!(all, (0..10).collect::<Vec<_>>());
    }

    /// Batch count not divisible by the worker count: low ranks get the extras.
    #[test]
    fn test_data_sharding_uneven() {
        // 7 batches, 3 workers
        let shard0 = shard_batches(7, 0, 3);
        let shard1 = shard_batches(7, 1, 3);
        let shard2 = shard_batches(7, 2, 3);

        assert_eq!(shard0, vec![0, 3, 6]);
        assert_eq!(shard1, vec![1, 4]);
        assert_eq!(shard2, vec![2, 5]);

        // Pairwise disjoint
        assert!(shard0.iter().all(|idx| !shard1.contains(idx) && !shard2.contains(idx)));
        assert!(shard1.iter().all(|idx| !shard2.contains(idx)));

        // Complete
        let mut all: Vec<usize> =
            shard0.iter().chain(shard1.iter()).chain(shard2.iter()).copied().collect();
        all.sort_unstable();
        assert_eq!(all, (0..7).collect::<Vec<_>>());
    }

    /// A single worker receives every batch in order.
    #[test]
    fn test_data_sharding_single_worker() {
        let shard = shard_batches(5, 0, 1);
        assert_eq!(shard, vec![0, 1, 2, 3, 4]);
    }

    /// Ranks at or beyond the batch count receive nothing.
    #[test]
    fn test_data_sharding_more_workers_than_batches() {
        let shard0 = shard_batches(2, 0, 4);
        let shard1 = shard_batches(2, 1, 4);
        let shard2 = shard_batches(2, 2, 4);
        let shard3 = shard_batches(2, 3, 4);

        assert_eq!(shard0, vec![0]);
        assert_eq!(shard1, vec![1]);
        assert!(shard2.is_empty());
        assert!(shard3.is_empty());
    }
}
453}