#[cfg(feature = "cuda")]
use std::sync::mpsc;
#[cfg(feature = "cuda")]
use super::config::DistributedTrainConfig;
#[cfg(feature = "cuda")]
use super::grad_accumulator::BlockGradientSet;
#[cfg(feature = "cuda")]
/// Transport used to exchange gradients with the coordinator during a
/// distributed training step.
pub enum DistributedComm {
    /// In-process duplex link over std mpsc channels: `tx` sends this rank's
    /// local gradients out, `rx` receives the averaged gradients back.
    Local {
        tx: mpsc::Sender<GradientMessage>,
        rx: mpsc::Receiver<GradientMessage>,
    },
    /// Cross-process link via an RPC-style client to a remote coordinator.
    Remote {
        client: crate::finetune::WorkerClient,
    },
}
#[cfg(feature = "cuda")]
#[derive(Debug)]
/// Messages exchanged between a worker and the coordinator during the
/// per-step gradient all-reduce. Each `…Gradient` message has an
/// `Averaged…` counterpart flowing in the opposite direction.
pub enum GradientMessage {
    /// Worker -> coordinator: flattened gradients for one transformer block,
    /// plus the per-component sizes needed to un-flatten them.
    BlockGradient { block_idx: usize, gradients: Vec<f32>, component_sizes: Vec<u32> },
    /// Coordinator -> worker: gradients for one block, averaged across ranks.
    AveragedBlockGradient { block_idx: usize, gradients: Vec<f32>, component_sizes: Vec<u32> },
    /// Worker -> coordinator: gradients for a non-block component.
    /// Component ids used in this file: 0 = lm_head, 1 = final_norm,
    /// 2 = embeddings.
    NonBlockGradient { component: u8, gradients: Vec<f32> },
    /// Coordinator -> worker: averaged non-block component gradients.
    AveragedNonBlockGradient { component: u8, gradients: Vec<f32> },
    /// Synchronization marker with no payload. NOTE(review): not produced or
    /// consumed anywhere in this file — confirm external users before removal.
    Barrier,
}
#[cfg(feature = "cuda")]
/// Data-parallel wrapper around a single-GPU `CudaTransformerTrainer`.
///
/// Each rank trains on its own shard of batches; after every batch the ranks
/// all-reduce (average) gradients through `comm` before applying the update.
pub struct DistributedCudaTrainer {
    // Underlying single-device trainer owning model state and gradients.
    trainer: super::cuda_trainer::CudaTransformerTrainer,
    // Transport to the coordinator (in-process channels or remote client).
    comm: DistributedComm,
    // Topology info for this rank (rank index, world size, ...).
    dist_config: DistributedTrainConfig,
    // Number of batches trained so far on this rank; sent as the step id
    // in remote gradient messages.
    step: usize,
}
#[cfg(feature = "cuda")]
impl DistributedCudaTrainer {
    /// Wraps a single-GPU trainer for data-parallel training.
    ///
    /// Eagerly ensures the trainer's gradient accumulator exists, because the
    /// all-reduce path unwraps `grad_accum_mut()` unconditionally.
    pub fn new(
        mut trainer: super::cuda_trainer::CudaTransformerTrainer,
        comm: DistributedComm,
        dist_config: DistributedTrainConfig,
    ) -> Self {
        trainer.ensure_grad_accum();
        Self { trainer, comm, dist_config, step: 0 }
    }

    /// Runs one DDP step: local forward/backward on `batch`, gradient
    /// all-reduce with the other ranks, then parameter update.
    ///
    /// Returns this rank's local (pre-averaging) loss for the batch.
    pub fn train_batch(&mut self, batch: &super::batch::LMBatch) -> f32 {
        let loss = self.trainer.forward_backward_batch(batch);
        let step = self.step as u64;
        Self::allreduce_impl(step, &self.comm, &mut self.trainer);
        self.trainer.apply_ddp_gradients();
        self.step += 1;
        loss
    }

    /// Averages locally accumulated gradients, then exchanges them with the
    /// other ranks over whichever transport `comm` provides.
    ///
    /// Written as an associated fn (not a method) so the shared borrow of
    /// `comm` and the mutable borrow of `trainer` can be split out of `self`
    /// by the caller.
    fn allreduce_impl(
        step: u64,
        comm: &DistributedComm,
        trainer: &mut super::cuda_trainer::CudaTransformerTrainer,
    ) {
        // Average the accumulator-held gradients over the micro-batches
        // accumulated since the last step, remembering how many there were.
        let local_count = {
            let accum = trainer.grad_accum_mut().unwrap();
            let count = accum.accumulated_count;
            accum.average(); count
        };
        // The embedding gradient lives on the trainer, not in the accumulator,
        // so `average()` above did not touch it — scale it by 1/count here to
        // match. NOTE(review): assumes `average()` divides by
        // `accumulated_count`; confirm in grad_accumulator.
        if local_count > 1 {
            if let Some(mut eg) = trainer.embed_grad_vec() {
                let inv = 1.0 / local_count as f32;
                for g in &mut eg {
                    *g *= inv;
                }
                trainer.set_embed_grad(eg);
            }
        }
        match comm {
            DistributedComm::Remote { client } => {
                Self::allreduce_remote(step, client, trainer);
            }
            DistributedComm::Local { tx, rx } => {
                Self::allreduce_local(step, tx, rx, trainer);
            }
        }
    }

    /// All-reduce against a remote coordinator: for each component, send the
    /// local gradients, then block until the averaged gradients come back and
    /// overwrite the local copy with them.
    ///
    /// This is a lock-step request/response protocol — the exchange order
    /// (blocks high-to-low, then lm_head, final_norm, embeddings) must match
    /// the coordinator's expectations exactly.
    fn allreduce_remote(
        step: u64,
        client: &crate::finetune::WorkerClient,
        trainer: &mut super::cuda_trainer::CudaTransformerTrainer,
    ) {
        {
            let accum = trainer.grad_accum_mut().unwrap();
            let num_blocks = accum.num_blocks();
            // Blocks are exchanged in reverse index order. NOTE(review):
            // presumably mirrors backprop order so the coordinator can start
            // averaging the last block first — confirm against the server.
            for block_idx in (0..num_blocks).rev() {
                let flat = accum.block_grads[block_idx].flatten();
                let sizes = accum.block_grads[block_idx].component_sizes_u32();
                client
                    .send_block_gradient(step, block_idx as u32, num_blocks as u32, flat, sizes)
                    .expect("block gradient send failed");
                let avg = client.receive_averaged_block().expect("block gradient receive failed");
                accum.block_grads[block_idx] =
                    BlockGradientSet::from_flat(&avg.gradients, &avg.component_sizes);
            }
        }
        {
            // Non-block components use numeric ids: 0 = lm_head,
            // 1 = final_norm, 2 = embeddings (third scope below).
            let accum = trainer.grad_accum_mut().unwrap();
            let lm_grad = accum.lm_head_grad.clone();
            client.send_non_block_gradient(step, 0, lm_grad).expect("lm_head gradient send failed");
            let avg = client.receive_averaged_non_block().expect("lm_head gradient receive failed");
            accum.lm_head_grad = avg.gradients;
            let norm_grad = accum.final_norm_grad.clone();
            client
                .send_non_block_gradient(step, 1, norm_grad)
                .expect("final_norm gradient send failed");
            let avg =
                client.receive_averaged_non_block().expect("final_norm gradient receive failed");
            accum.final_norm_grad = avg.gradients;
            // Gradients are already averaged now; mark count = 1 so later
            // consumers do not divide again.
            accum.accumulated_count = 1;
        }
        {
            // Embedding gradients are held by the trainer, not the accumulator.
            let embed_grad = trainer.embed_grad_vec().unwrap_or_default();
            client
                .send_non_block_gradient(step, 2, embed_grad)
                .expect("embedding gradient send failed");
            let avg =
                client.receive_averaged_non_block().expect("embedding gradient receive failed");
            trainer.set_embed_grad(avg.gradients);
        }
    }

    /// Same lock-step exchange as `allreduce_remote`, but over in-process
    /// mpsc channels: send a gradient message, block on `rx` for the matching
    /// `Averaged…` reply, and panic on any out-of-order message.
    fn allreduce_local(
        step: u64,
        tx: &mpsc::Sender<GradientMessage>,
        rx: &mpsc::Receiver<GradientMessage>,
        trainer: &mut super::cuda_trainer::CudaTransformerTrainer,
    ) {
        // `step` is only needed by the remote protocol; channel messages carry
        // no step id because the channel pair is private to one rank.
        let _ = step;
        {
            let accum = trainer.grad_accum_mut().unwrap();
            let num_blocks = accum.num_blocks();
            // Reverse order to match `allreduce_remote` (and the peer's
            // expected receive order).
            for block_idx in (0..num_blocks).rev() {
                let flat = accum.block_grads[block_idx].flatten();
                let sizes = accum.block_grads[block_idx].component_sizes_u32();
                tx.send(GradientMessage::BlockGradient {
                    block_idx,
                    gradients: flat,
                    component_sizes: sizes,
                })
                .expect("channel send failed");
                match rx.recv().expect("channel recv failed") {
                    GradientMessage::AveragedBlockGradient {
                        gradients, component_sizes, ..
                    } => {
                        accum.block_grads[block_idx] =
                            BlockGradientSet::from_flat(&gradients, &component_sizes);
                    }
                    other => panic!("expected AveragedBlockGradient, got {other:?}"),
                }
            }
        }
        {
            // Non-block ids as in `allreduce_remote`: 0 = lm_head,
            // 1 = final_norm, 2 = embeddings.
            let accum = trainer.grad_accum_mut().unwrap();
            let lm_grad = accum.lm_head_grad.clone();
            tx.send(GradientMessage::NonBlockGradient { component: 0, gradients: lm_grad })
                .expect("channel send failed");
            match rx.recv().expect("channel recv failed") {
                GradientMessage::AveragedNonBlockGradient { gradients, .. } => {
                    accum.lm_head_grad = gradients;
                }
                other => panic!("expected AveragedNonBlockGradient, got {other:?}"),
            }
            let norm_grad = accum.final_norm_grad.clone();
            tx.send(GradientMessage::NonBlockGradient { component: 1, gradients: norm_grad })
                .expect("channel send failed");
            match rx.recv().expect("channel recv failed") {
                GradientMessage::AveragedNonBlockGradient { gradients, .. } => {
                    accum.final_norm_grad = gradients;
                }
                other => panic!("expected AveragedNonBlockGradient, got {other:?}"),
            }
            // Averaged in place; reset so nothing divides by the count again.
            accum.accumulated_count = 1;
        }
        {
            let embed_grad = trainer.embed_grad_vec().unwrap_or_default();
            tx.send(GradientMessage::NonBlockGradient { component: 2, gradients: embed_grad })
                .expect("channel send failed");
            match rx.recv().expect("channel recv failed") {
                GradientMessage::AveragedNonBlockGradient { gradients, .. } => {
                    trainer.set_embed_grad(gradients);
                }
                other => panic!("expected AveragedNonBlockGradient, got {other:?}"),
            }
        }
    }

    /// Distributed-topology configuration for this rank.
    pub fn dist_config(&self) -> &DistributedTrainConfig {
        &self.dist_config
    }

    /// Number of batches trained so far on this rank.
    pub fn step(&self) -> usize {
        self.step
    }

    /// Shared access to the wrapped single-GPU trainer.
    pub fn trainer(&self) -> &super::cuda_trainer::CudaTransformerTrainer {
        &self.trainer
    }

    /// Mutable access to the wrapped single-GPU trainer.
    pub fn trainer_mut(&mut self) -> &mut super::cuda_trainer::CudaTransformerTrainer {
        &mut self.trainer
    }

    /// True when this rank is the coordinator (rank 0).
    pub fn is_coordinator(&self) -> bool {
        self.dist_config.rank == 0
    }

    /// Total number of participating ranks.
    pub fn world_size(&self) -> usize {
        self.dist_config.world_size
    }

    /// This process's rank index.
    pub fn rank(&self) -> usize {
        self.dist_config.rank
    }

    /// Delegates to the inner trainer's step-limit check.
    pub fn reached_max_steps(&self) -> bool {
        self.trainer.reached_max_steps()
    }
}
#[cfg(feature = "cuda")]
#[allow(dead_code)]
/// Builds channel endpoints for an in-process coordinator/worker pair.
///
/// Returns `(coordinator_end, worker_end)`, where each end is a
/// `(sender-to-peer, receiver-from-peer)` pair; two simplex mpsc channels
/// are wired together to form one duplex link.
pub fn create_local_comm_pair() -> (
    (mpsc::Sender<GradientMessage>, mpsc::Receiver<GradientMessage>),
    (mpsc::Sender<GradientMessage>, mpsc::Receiver<GradientMessage>),
) {
    let worker_to_coord = mpsc::channel();
    let coord_to_worker = mpsc::channel();
    // Coordinator sends on the worker-bound channel and listens on the
    // coordinator-bound one; the worker gets the mirror image.
    let coordinator_end = (coord_to_worker.0, worker_to_coord.1);
    let worker_end = (worker_to_coord.0, coord_to_worker.1);
    (coordinator_end, worker_end)
}
/// Round-robin assignment of batch indices to one worker.
///
/// Worker `rank` receives indices `rank, rank + world_size,
/// rank + 2 * world_size, …` below `num_batches`, so all ranks' shards are
/// pairwise disjoint and together cover `0..num_batches`. Returns an empty
/// shard when `rank >= num_batches`.
///
/// # Panics
/// Panics if `world_size` is zero (previously this panicked deep inside
/// `step_by` with an unhelpful message) or if `rank >= world_size`
/// (previously this silently produced a shard overlapping other ranks'
/// shards, e.g. rank 5 with world_size 2 yielded a subset of rank 1's shard).
pub fn shard_batches(num_batches: usize, rank: usize, world_size: usize) -> Vec<usize> {
    assert!(world_size > 0, "world_size must be positive");
    assert!(rank < world_size, "rank {rank} out of range for world_size {world_size}");
    (rank..num_batches).step_by(world_size).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the module builds under the test configuration.
    #[test]
    fn test_module_compiles() {
        assert!(true);
    }

    /// Two workers over ten batches: even/odd split, disjoint, full coverage.
    #[test]
    fn test_data_sharding_by_rank() {
        let shard0 = shard_batches(10, 0, 2);
        let shard1 = shard_batches(10, 1, 2);
        assert_eq!(shard0, vec![0, 2, 4, 6, 8]);
        assert_eq!(shard1, vec![1, 3, 5, 7, 9]);
        assert!(shard0.iter().all(|idx| !shard1.contains(idx)));
        let mut combined: Vec<usize> =
            shard0.iter().chain(shard1.iter()).copied().collect();
        combined.sort_unstable();
        assert_eq!(combined, (0..10).collect::<Vec<_>>());
    }

    /// Seven batches over three workers: rank 0 takes the extra batch;
    /// the union is still exactly 0..7.
    #[test]
    fn test_data_sharding_uneven() {
        let shards: Vec<Vec<usize>> = (0..3).map(|r| shard_batches(7, r, 3)).collect();
        let expected: [&[usize]; 3] = [&[0, 3, 6], &[1, 4], &[2, 5]];
        for (shard, want) in shards.iter().zip(expected) {
            assert_eq!(shard.as_slice(), want);
        }
        let mut combined: Vec<usize> = shards.into_iter().flatten().collect();
        combined.sort_unstable();
        assert_eq!(combined, (0..7).collect::<Vec<_>>());
    }

    /// A lone worker receives every batch index in order.
    #[test]
    fn test_data_sharding_single_worker() {
        assert_eq!(shard_batches(5, 0, 1), (0..5).collect::<Vec<_>>());
    }

    /// With more workers than batches, high-rank workers get empty shards.
    #[test]
    fn test_data_sharding_more_workers_than_batches() {
        assert_eq!(shard_batches(2, 0, 4), vec![0]);
        assert_eq!(shard_batches(2, 1, 4), vec![1]);
        assert!(shard_batches(2, 2, 4).is_empty());
        assert!(shard_batches(2, 3, 4).is_empty());
    }
}