entrenar/train/transformer_trainer/
distributed_trainer.rs1#[cfg(feature = "cuda")]
31use std::sync::mpsc;
32
33#[cfg(feature = "cuda")]
34use super::config::DistributedTrainConfig;
35#[cfg(feature = "cuda")]
36use super::grad_accumulator::BlockGradientSet;
37
#[cfg(feature = "cuda")]
/// Transport a [`DistributedCudaTrainer`] uses to exchange gradients with its peers.
pub enum DistributedComm {
    /// In-process transport built on `std::sync::mpsc` channels.
    Local {
        /// Outgoing side: this worker's `BlockGradient` / `NonBlockGradient` messages.
        tx: mpsc::Sender<GradientMessage>,
        /// Incoming side: the matching `AveragedBlockGradient` / `AveragedNonBlockGradient`
        /// replies.
        rx: mpsc::Receiver<GradientMessage>,
    },
    /// Transport through a `WorkerClient` (`send_*_gradient` / `receive_averaged_*`).
    /// NOTE(review): presumably an out-of-process / networked peer — confirm.
    Remote {
        /// Client used for every gradient request/response pair.
        client: crate::finetune::WorkerClient,
    },
}
58
#[cfg(feature = "cuda")]
/// Messages exchanged over [`DistributedComm::Local`] channels.
///
/// Non-block components are identified by a small integer id. In this file the ids are
/// `0` = LM head, `1` = final norm, `2` = embedding.
#[derive(Debug)]
pub enum GradientMessage {
    /// One transformer block's gradients, flattened, sent by a worker.
    BlockGradient {
        /// Index of the block the gradients belong to.
        block_idx: usize,
        /// All of the block's gradient components concatenated into one buffer.
        gradients: Vec<f32>,
        /// Per-component sizes used by `BlockGradientSet::from_flat` to split `gradients`.
        component_sizes: Vec<u32>,
    },
    /// The averaged reply to a `BlockGradient`, in the same flattened layout.
    AveragedBlockGradient {
        /// Index of the block the averaged gradients belong to.
        block_idx: usize,
        /// Averaged gradients, flattened.
        gradients: Vec<f32>,
        /// Per-component sizes for splitting `gradients`.
        component_sizes: Vec<u32>,
    },
    /// A gradient that is not part of a transformer block, sent by a worker.
    NonBlockGradient {
        /// Component id (see the enum-level docs).
        component: u8,
        /// The component's gradient values.
        gradients: Vec<f32>,
    },
    /// The averaged reply to a `NonBlockGradient`.
    AveragedNonBlockGradient {
        /// Component id (see the enum-level docs).
        component: u8,
        /// Averaged gradient values.
        gradients: Vec<f32>,
    },
    /// Synchronization marker; it is not sent or handled anywhere in this file.
    Barrier,
}
74
#[cfg(feature = "cuda")]
/// Data-parallel wrapper around a `CudaTransformerTrainer`.
///
/// Each [`DistributedCudaTrainer::train_batch`] call computes gradients locally, replaces
/// them with values averaged across workers via [`DistributedComm`], and only then applies
/// them.
pub struct DistributedCudaTrainer {
    /// Underlying trainer that runs forward/backward and applies gradients.
    trainer: super::cuda_trainer::CudaTransformerTrainer,
    /// Transport used to exchange gradients with peers.
    comm: DistributedComm,
    /// Rank / world-size configuration of this worker.
    dist_config: DistributedTrainConfig,
    /// Number of completed `train_batch` calls; forwarded as the step number to remote peers.
    step: usize,
}
101
#[cfg(feature = "cuda")]
impl DistributedCudaTrainer {
    /// Panic message used when the trainer's gradient accumulator is unexpectedly absent.
    const MISSING_GRAD_ACCUM: &'static str =
        "gradient accumulator missing; DistributedCudaTrainer::new calls ensure_grad_accum";

    /// Wraps `trainer` so that every step averages gradients with peers over `comm`.
    ///
    /// `ensure_grad_accum` is called up front because the all-reduce reads the trainer's
    /// gradient accumulator through `grad_accum_mut()` and panics when it is absent.
    /// NOTE(review): presumably `ensure_grad_accum` creates the accumulator on demand — confirm.
    pub fn new(
        mut trainer: super::cuda_trainer::CudaTransformerTrainer,
        comm: DistributedComm,
        dist_config: DistributedTrainConfig,
    ) -> Self {
        trainer.ensure_grad_accum();

        Self { trainer, comm, dist_config, step: 0 }
    }

    /// Runs one data-parallel training step on `batch` and returns this worker's loss.
    ///
    /// The step works as follows:
    /// 1. Compute gradients locally with `forward_backward_batch`.
    /// 2. Replace them with the values averaged across workers.
    /// 3. Apply them with `apply_ddp_gradients`.
    /// 4. Increment the step counter.
    ///
    /// The returned loss is the local value from `forward_backward_batch`. It is not averaged
    /// across workers.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the gradient accumulator is missing;
    /// - sending or receiving a gradient fails;
    /// - a local peer replies with an unexpected message.
    pub fn train_batch(&mut self, batch: &super::batch::LMBatch) -> f32 {
        let loss = self.trainer.forward_backward_batch(batch);

        let step = self.step as u64;
        Self::allreduce_impl(step, &self.comm, &mut self.trainer);

        self.trainer.apply_ddp_gradients();

        self.step += 1;
        loss
    }

    /// Normalizes the locally accumulated gradients, then averages them across workers.
    ///
    /// This is an associated function taking `comm` and `trainer` separately, instead of
    /// `&mut self`, so that both fields can be borrowed at the same time.
    fn allreduce_impl(
        step: u64,
        comm: &DistributedComm,
        trainer: &mut super::cuda_trainer::CudaTransformerTrainer,
    ) {
        // Normalize by the number of accumulated contributions before the exchange.
        // NOTE(review): `average()` presumably does this for everything held in the
        // accumulator. The embedding gradient is read through the trainer rather than the
        // accumulator, so it is scaled by hand here — confirm `average()` does not already
        // cover it.
        let local_count = {
            let accum = trainer.grad_accum_mut().expect(Self::MISSING_GRAD_ACCUM);
            let count = accum.accumulated_count;
            accum.average();
            count
        };
        if local_count > 1 {
            if let Some(mut embed_grad) = trainer.embed_grad_vec() {
                let inv = 1.0 / local_count as f32;
                embed_grad.iter_mut().for_each(|g| *g *= inv);
                trainer.set_embed_grad(embed_grad);
            }
        }

        match comm {
            DistributedComm::Remote { client } => Self::allreduce_remote(step, client, trainer),
            DistributedComm::Local { tx, rx } => Self::allreduce_local(step, tx, rx, trainer),
        }
    }

    /// Averages all gradients through a `WorkerClient`, one request/response pair at a time.
    ///
    /// Exchange order:
    /// 1. Transformer blocks, from last to first.
    /// 2. The LM head (component 0).
    /// 3. The final norm (component 1).
    /// 4. The embedding (component 2).
    ///
    /// Every local gradient is overwritten with the averaged value that comes back.
    fn allreduce_remote(
        step: u64,
        client: &crate::finetune::WorkerClient,
        trainer: &mut super::cuda_trainer::CudaTransformerTrainer,
    ) {
        {
            let accum = trainer.grad_accum_mut().expect(Self::MISSING_GRAD_ACCUM);

            let num_blocks = accum.num_blocks();
            for block_idx in (0..num_blocks).rev() {
                let flat = accum.block_grads[block_idx].flatten();
                let sizes = accum.block_grads[block_idx].component_sizes_u32();
                client
                    .send_block_gradient(step, block_idx as u32, num_blocks as u32, flat, sizes)
                    .expect("block gradient send failed");
                let avg = client.receive_averaged_block().expect("block gradient receive failed");
                accum.block_grads[block_idx] =
                    BlockGradientSet::from_flat(&avg.gradients, &avg.component_sizes);
            }

            // The local values are replaced by the averaged ones below, so move them out
            // instead of cloning (the LM-head gradient in particular can be large).
            let lm_grad = std::mem::take(&mut accum.lm_head_grad);
            client
                .send_non_block_gradient(step, 0, lm_grad)
                .expect("lm_head gradient send failed");
            let avg = client.receive_averaged_non_block().expect("lm_head gradient receive failed");
            accum.lm_head_grad = avg.gradients;

            let norm_grad = std::mem::take(&mut accum.final_norm_grad);
            client
                .send_non_block_gradient(step, 1, norm_grad)
                .expect("final_norm gradient send failed");
            let avg =
                client.receive_averaged_non_block().expect("final_norm gradient receive failed");
            accum.final_norm_grad = avg.gradients;

            // The accumulator now holds a single (averaged) contribution.
            accum.accumulated_count = 1;
        }

        let embed_grad = trainer.embed_grad_vec().unwrap_or_default();
        client
            .send_non_block_gradient(step, 2, embed_grad)
            .expect("embedding gradient send failed");
        let avg = client.receive_averaged_non_block().expect("embedding gradient receive failed");
        trainer.set_embed_grad(avg.gradients);
    }

    /// Averages all gradients over in-process `mpsc` channels.
    ///
    /// Uses the same order as the remote path:
    /// 1. Transformer blocks, from last to first.
    /// 2. The LM head (component 0).
    /// 3. The final norm (component 1).
    /// 4. The embedding (component 2).
    ///
    /// Every message sent on `tx` must be answered on `rx` by the matching `Averaged*`
    /// message.
    ///
    /// `_step` exists for signature parity with `allreduce_remote`; channel messages carry no
    /// step number.
    fn allreduce_local(
        _step: u64,
        tx: &mpsc::Sender<GradientMessage>,
        rx: &mpsc::Receiver<GradientMessage>,
        trainer: &mut super::cuda_trainer::CudaTransformerTrainer,
    ) {
        {
            let accum = trainer.grad_accum_mut().expect(Self::MISSING_GRAD_ACCUM);

            let num_blocks = accum.num_blocks();
            for block_idx in (0..num_blocks).rev() {
                let gradients = accum.block_grads[block_idx].flatten();
                let component_sizes = accum.block_grads[block_idx].component_sizes_u32();
                tx.send(GradientMessage::BlockGradient { block_idx, gradients, component_sizes })
                    .expect("channel send failed");

                match rx.recv().expect("channel recv failed") {
                    GradientMessage::AveragedBlockGradient { gradients, component_sizes, .. } => {
                        accum.block_grads[block_idx] =
                            BlockGradientSet::from_flat(&gradients, &component_sizes);
                    }
                    other => panic!("expected AveragedBlockGradient, got {other:?}"),
                }
            }

            // Move the local values out rather than cloning them; they are overwritten by
            // the averaged replies.
            let lm_grad = std::mem::take(&mut accum.lm_head_grad);
            accum.lm_head_grad = Self::exchange_local_non_block(tx, rx, 0, lm_grad);

            let norm_grad = std::mem::take(&mut accum.final_norm_grad);
            accum.final_norm_grad = Self::exchange_local_non_block(tx, rx, 1, norm_grad);

            // The accumulator now holds a single (averaged) contribution.
            accum.accumulated_count = 1;
        }

        let embed_grad = trainer.embed_grad_vec().unwrap_or_default();
        let averaged = Self::exchange_local_non_block(tx, rx, 2, embed_grad);
        trainer.set_embed_grad(averaged);
    }

    /// Sends one non-block gradient on `tx` and returns the averaged reply read from `rx`.
    ///
    /// # Panics
    ///
    /// Panics if either channel is disconnected or if the reply is not an
    /// `AveragedNonBlockGradient`.
    fn exchange_local_non_block(
        tx: &mpsc::Sender<GradientMessage>,
        rx: &mpsc::Receiver<GradientMessage>,
        component: u8,
        gradients: Vec<f32>,
    ) -> Vec<f32> {
        tx.send(GradientMessage::NonBlockGradient { component, gradients })
            .expect("channel send failed");
        match rx.recv().expect("channel recv failed") {
            GradientMessage::AveragedNonBlockGradient { gradients, .. } => gradients,
            other => panic!("expected AveragedNonBlockGradient, got {other:?}"),
        }
    }

    /// Returns the distributed configuration this trainer was created with.
    pub fn dist_config(&self) -> &DistributedTrainConfig {
        &self.dist_config
    }

    /// Returns the number of completed [`DistributedCudaTrainer::train_batch`] calls.
    pub fn step(&self) -> usize {
        self.step
    }

    /// Borrows the wrapped trainer.
    pub fn trainer(&self) -> &super::cuda_trainer::CudaTransformerTrainer {
        &self.trainer
    }

    /// Mutably borrows the wrapped trainer.
    pub fn trainer_mut(&mut self) -> &mut super::cuda_trainer::CudaTransformerTrainer {
        &mut self.trainer
    }

    /// Returns `true` for rank 0, which acts as the coordinator.
    pub fn is_coordinator(&self) -> bool {
        self.dist_config.rank == 0
    }

    /// Total number of workers, as configured in `dist_config`.
    pub fn world_size(&self) -> usize {
        self.dist_config.world_size
    }

    /// This worker's rank, as configured in `dist_config`.
    pub fn rank(&self) -> usize {
        self.dist_config.rank
    }

    /// Delegates to the wrapped trainer's `reached_max_steps`.
    pub fn reached_max_steps(&self) -> bool {
        self.trainer.reached_max_steps()
    }
}
364
#[cfg(feature = "cuda")]
/// Builds two in-process endpoints connected by a pair of `mpsc` channels.
///
/// Returns `(coordinator, worker)`, each a `(Sender, Receiver)` tuple:
/// - messages sent on the worker's sender arrive at the coordinator's receiver;
/// - messages sent on the coordinator's sender arrive at the worker's receiver.
///
/// Each tuple has the field types of `DistributedComm::Local { tx, rx }`.
// NOTE(review): the reason for `allow(dead_code)` is not visible here — presumably the
// helper is only exercised from tests or examples. Confirm and document why.
#[allow(dead_code)]
pub fn create_local_comm_pair() -> (
    (mpsc::Sender<GradientMessage>, mpsc::Receiver<GradientMessage>),
    (mpsc::Sender<GradientMessage>, mpsc::Receiver<GradientMessage>),
) {
    // One channel per direction.
    let (to_coordinator, coordinator_inbox) = mpsc::channel();
    let (to_worker, worker_inbox) = mpsc::channel();

    let coordinator_end = (to_worker, coordinator_inbox);
    let worker_end = (to_coordinator, worker_inbox);
    (coordinator_end, worker_end)
}
379
/// Returns the indices of the batches that worker `rank` should process.
///
/// Batches are dealt round-robin: rank `r` receives `r, r + world_size, r + 2 * world_size,
/// ...`, all below `num_batches`. For a fixed `world_size`, the shards of ranks
/// `0..world_size` are disjoint and together cover `0..num_batches` exactly once. Ranks at or
/// above `num_batches` get an empty shard.
///
/// # Panics
///
/// Panics if `world_size` is zero. Also panics if `rank >= world_size`, because such a rank
/// would receive a shard overlapping another rank's, silently duplicating training data.
#[must_use]
pub fn shard_batches(num_batches: usize, rank: usize, world_size: usize) -> Vec<usize> {
    // `step_by(0)` would also panic, but with an opaque message; state the precondition.
    assert!(world_size > 0, "shard_batches: world_size must be at least 1");
    assert!(
        rank < world_size,
        "shard_batches: rank {rank} is out of range for world_size {world_size}"
    );
    (rank..num_batches).step_by(world_size).collect()
}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the shards of all ranks are disjoint and together cover `0..num_batches`.
    fn assert_exact_partition(num_batches: usize, world_size: usize) {
        let mut all: Vec<usize> =
            (0..world_size).flat_map(|rank| shard_batches(num_batches, rank, world_size)).collect();
        all.sort_unstable();
        assert_eq!(
            all,
            (0..num_batches).collect::<Vec<_>>(),
            "shards for num_batches={num_batches}, world_size={world_size} are not a partition"
        );
    }

    #[test]
    fn test_data_sharding_by_rank() {
        let shard0 = shard_batches(10, 0, 2);
        let shard1 = shard_batches(10, 1, 2);

        assert_eq!(shard0, vec![0, 2, 4, 6, 8]);
        assert_eq!(shard1, vec![1, 3, 5, 7, 9]);

        assert!(shard0.iter().all(|idx| !shard1.contains(idx)));
        assert_exact_partition(10, 2);
    }

    #[test]
    fn test_data_sharding_uneven() {
        assert_eq!(shard_batches(7, 0, 3), vec![0, 3, 6]);
        assert_eq!(shard_batches(7, 1, 3), vec![1, 4]);
        assert_eq!(shard_batches(7, 2, 3), vec![2, 5]);
        assert_exact_partition(7, 3);
    }

    #[test]
    fn test_data_sharding_single_worker() {
        assert_eq!(shard_batches(5, 0, 1), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_data_sharding_more_workers_than_batches() {
        assert_eq!(shard_batches(2, 0, 4), vec![0]);
        assert_eq!(shard_batches(2, 1, 4), vec![1]);
        assert!(shard_batches(2, 2, 4).is_empty());
        assert!(shard_batches(2, 3, 4).is_empty());
    }

    #[test]
    fn test_data_sharding_no_batches() {
        assert!(shard_batches(0, 0, 1).is_empty());
        assert!(shard_batches(0, 2, 3).is_empty());
    }

    #[test]
    fn test_data_sharding_always_partitions() {
        for num_batches in 0..=20 {
            for world_size in 1..=6 {
                assert_exact_partition(num_batches, world_size);
            }
        }
    }

    #[test]
    #[should_panic]
    fn test_data_sharding_zero_world_size_panics() {
        let _ = shard_batches(4, 0, 0);
    }
}