Skip to main content

hermes_llm/
distributed.rs

1//! Distributed training support using NCCL
2//!
3//! This module provides multi-GPU training capabilities using NVIDIA's NCCL library
4//! for efficient gradient synchronization across GPUs.
5
6use anyhow::Result;
7use candle_core::Tensor;
8
9#[cfg(feature = "nccl")]
10use cudarc::driver::safe::{CudaContext, CudaStream};
11#[cfg(feature = "nccl")]
12use cudarc::nccl::safe::{Comm, Id};
13#[cfg(feature = "nccl")]
14use std::sync::Arc;
15
/// Settings describing how this process takes part in multi-GPU training.
#[derive(Clone, Debug)]
pub struct DistributedConfig {
    /// Number of participating processes.
    pub world_size: usize,
    /// Index of this process, expected to be in `0..world_size`.
    pub rank: usize,
    /// File through which rank 0 hands the NCCL unique ID to the other ranks.
    pub comm_file: String,
}

impl Default for DistributedConfig {
    /// Single-process setup: one rank (rank 0), exchanging the ID via `nccl_id.txt`.
    fn default() -> Self {
        let comm_file = String::from("nccl_id.txt");
        Self {
            rank: 0,
            world_size: 1,
            comm_file,
        }
    }
}

impl DistributedConfig {
    /// Returns `true` when more than one process participates.
    pub fn is_distributed(&self) -> bool {
        1 < self.world_size
    }

    /// Returns `true` for rank 0, the rank that creates and publishes the NCCL ID.
    pub fn is_main_process(&self) -> bool {
        matches!(self.rank, 0)
    }
}
46
/// NCCL communicator wrapper for gradient synchronization.
///
/// Holds the NCCL [`Comm`] together with the CUDA stream it was created on, plus this
/// process's rank and the world size it was initialized with.
#[cfg(feature = "nccl")]
pub struct NcclCommunicator {
    comm: Comm,
    stream: Arc<CudaStream>,
    rank: usize,
    world_size: usize,
}

#[cfg(feature = "nccl")]
impl NcclCommunicator {
    /// Initialize the NCCL communicator.
    ///
    /// Rank 0 creates the NCCL unique ID and publishes it to `config.comm_file`; every
    /// other rank polls until that file appears and reads the ID from it. All ranks
    /// then create a CUDA context on device ordinal 0, take its default stream, and
    /// join the communicator via [`Comm::from_rank`].
    ///
    /// # Errors
    ///
    /// Fails if the ID file cannot be written or read, has the wrong length, or if
    /// CUDA context / NCCL communicator creation fails.
    ///
    /// NOTE(review): waiting ranks poll with no timeout, and a rank that starts before
    /// rank 0 may read a stale ID file from an earlier run (rank 0 only deletes it once
    /// it starts). A per-run file name would avoid both problems.
    pub fn new(config: &DistributedConfig) -> Result<Self> {
        let comm_file = std::path::PathBuf::from(&config.comm_file);

        // Rank 0 creates the ID, others wait for it.
        let id = if config.rank == 0 {
            Self::publish_nccl_id(&comm_file)?
        } else {
            Self::await_nccl_id(&comm_file, config.rank)?
        };

        // Device ordinal 0 is used unconditionally: when CUDA_VISIBLE_DEVICES exposes a
        // single GPU per process, that GPU is ordinal 0. NOTE(review): if the launcher
        // does not set CUDA_VISIBLE_DEVICES per rank, every rank lands on the same GPU.
        let gpu_ordinal = 0;
        let ctx = CudaContext::new(gpu_ordinal).map_err(|e| {
            anyhow::anyhow!("Failed to create CUDA context {}: {:?}", gpu_ordinal, e)
        })?;
        let stream = ctx.default_stream();

        let comm = Comm::from_rank(stream.clone(), config.rank, config.world_size, id)
            .map_err(|e| anyhow::anyhow!("Failed to create NCCL communicator: {:?}", e.0))?;

        tracing::info!("Rank {}: NCCL communicator initialized", config.rank);

        // Best-effort cleanup of the ID file by rank 0 after a short grace period;
        // the other ranks have already read the ID by the time they joined above.
        if config.rank == 0 {
            std::thread::sleep(std::time::Duration::from_millis(500));
            if comm_file.exists() {
                let _ = std::fs::remove_file(&comm_file);
            }
        }

        Ok(Self {
            comm,
            stream,
            rank: config.rank,
            world_size: config.world_size,
        })
    }

    /// Rank-0 half of the ID exchange: create a fresh NCCL ID and publish it at
    /// `comm_file` by writing a sibling temp file and renaming it into place, so the
    /// final path only ever appears with its full contents.
    fn publish_nccl_id(comm_file: &std::path::Path) -> Result<Id> {
        // Clean up any comm file left over from a previous run (absence is fine).
        match std::fs::remove_file(comm_file) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => return Err(e.into()),
        }

        let id = Id::new().map_err(|e| anyhow::anyhow!("Failed to create NCCL ID: {:?}", e))?;

        // Append ".tmp" to the whole file name instead of replacing the extension: with
        // `with_extension("tmp")`, a comm file already named `*.tmp` would be its own
        // temp file and become visible to readers before it was fully written.
        let tmp_file = {
            let mut name = comm_file.as_os_str().to_os_string();
            name.push(".tmp");
            std::path::PathBuf::from(name)
        };
        let bytes: Vec<u8> = id.internal().iter().map(|&c| c as u8).collect();
        std::fs::write(&tmp_file, &bytes)?;
        std::fs::rename(&tmp_file, comm_file)?;

        tracing::info!("Rank 0: Created NCCL ID and wrote to {:?}", comm_file);
        Ok(id)
    }

    /// Non-zero-rank half of the ID exchange: poll every 100 ms until `comm_file`
    /// exists, then read back the NCCL ID that rank 0 wrote there.
    fn await_nccl_id(comm_file: &std::path::Path, rank: usize) -> Result<Id> {
        tracing::info!("Rank {}: Waiting for NCCL ID file...", rank);
        while !comm_file.exists() {
            std::thread::sleep(std::time::Duration::from_millis(100));
        }
        // Extra settle delay; the writer's rename should already make the file complete.
        std::thread::sleep(std::time::Duration::from_millis(100));

        let data = std::fs::read(comm_file)?;
        // The NCCL ID is a `c_char` array, and `c_char` is `i8` on x86_64 but `u8` on
        // aarch64 Linux, so use the FFI alias rather than hard-coding `i8`.
        let internal: [std::ffi::c_char; 128] = data
            .into_iter()
            .map(|b| b as std::ffi::c_char)
            .collect::<Vec<_>>()
            .try_into()
            .map_err(|_| anyhow::anyhow!("Invalid NCCL ID file"))?;

        let id = Id::uninit(internal);
        tracing::info!("Rank {}: Read NCCL ID from {:?}", rank, comm_file);
        Ok(id)
    }

    /// All-reduce a tensor and average it: sum across all ranks, then divide by
    /// `world_size`.
    ///
    /// Collective: every rank must call it with a tensor of the same shape.
    pub fn all_reduce_avg(&self, tensor: &Tensor) -> Result<Tensor> {
        let reduced = self.all_reduce_sum(tensor)?;
        let avg = reduced.affine(1.0 / self.world_size as f64, 0.0)?;
        Ok(avg)
    }

    /// All-reduce a tensor (element-wise sum across all ranks).
    ///
    /// The data is pulled to the host with `to_vec1::<f32>` (so `tensor` is expected to
    /// hold `f32`), staged into fresh buffers on this communicator's stream, reduced
    /// with NCCL, copied back, and rebuilt with `tensor`'s shape on `tensor`'s device.
    ///
    /// Collective: every rank must call it with a tensor of the same shape.
    pub fn all_reduce_sum(&self, tensor: &Tensor) -> Result<Tensor> {
        use cudarc::nccl::safe::ReduceOp;

        // Nothing to reduce. Shapes match on every rank, so all ranks take this branch
        // together, and no zero-length device buffers are requested.
        if tensor.elem_count() == 0 {
            return Ok(tensor.clone());
        }

        let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
        let len = data.len();

        let gpu_data = self
            .stream
            .clone_htod(&data)
            .map_err(|e| anyhow::anyhow!("Failed to copy data to GPU: {:?}", e))?;

        let mut gpu_output = self
            .stream
            .alloc_zeros::<f32>(len)
            .map_err(|e| anyhow::anyhow!("Failed to allocate GPU buffer: {:?}", e))?;

        self.comm
            .all_reduce(&gpu_data, &mut gpu_output, &ReduceOp::Sum)
            .map_err(|e| anyhow::anyhow!("NCCL all-reduce failed: {:?}", e.0))?;

        // The reduction is enqueued on the stream; wait for it before reading back.
        self.stream
            .synchronize()
            .map_err(|e| anyhow::anyhow!("Stream sync failed: {:?}", e))?;

        let output = self
            .stream
            .clone_dtoh(&gpu_output)
            .map_err(|e| anyhow::anyhow!("Failed to copy data from GPU: {:?}", e))?;

        let result = Tensor::from_vec(output, tensor.shape(), tensor.device())?;
        Ok(result)
    }

    /// Broadcast a tensor from rank 0 to all other ranks.
    ///
    /// Only rank 0 uploads its data; the other ranks receive into a fresh buffer. As in
    /// [`Self::all_reduce_sum`], data goes through the host as `f32` and the result is
    /// rebuilt with `tensor`'s shape on `tensor`'s device.
    ///
    /// Collective: every rank must call it with a tensor of the same shape.
    pub fn broadcast(&self, tensor: &Tensor) -> Result<Tensor> {
        // Nothing to send; all ranks share the shape, so all skip the collective.
        if tensor.elem_count() == 0 {
            return Ok(tensor.clone());
        }

        let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
        let len = data.len();

        // Only rank 0 has meaningful data to send.
        let gpu_data = if self.rank == 0 {
            Some(
                self.stream
                    .clone_htod(&data)
                    .map_err(|e| anyhow::anyhow!("Failed to copy data to GPU: {:?}", e))?,
            )
        } else {
            None
        };

        let mut gpu_output = self
            .stream
            .alloc_zeros::<f32>(len)
            .map_err(|e| anyhow::anyhow!("Failed to allocate GPU buffer: {:?}", e))?;

        self.comm
            .broadcast(gpu_data.as_ref(), &mut gpu_output, 0)
            .map_err(|e| anyhow::anyhow!("NCCL broadcast failed: {:?}", e.0))?;

        self.stream
            .synchronize()
            .map_err(|e| anyhow::anyhow!("Stream sync failed: {:?}", e))?;

        let output = self
            .stream
            .clone_dtoh(&gpu_output)
            .map_err(|e| anyhow::anyhow!("Failed to copy data from GPU: {:?}", e))?;

        let result = Tensor::from_vec(output, tensor.shape(), tensor.device())?;
        Ok(result)
    }

    /// Synchronize all ranks (barrier).
    ///
    /// Implemented as a one-element all-reduce followed by a stream synchronize, so it
    /// returns only after every rank has joined the collective.
    pub fn barrier(&self) -> Result<()> {
        use cudarc::nccl::safe::ReduceOp;

        let dummy = [0.0f32];
        let gpu_dummy = self
            .stream
            .clone_htod(&dummy)
            .map_err(|e| anyhow::anyhow!("Failed to copy data to GPU: {:?}", e))?;
        let mut gpu_output = self
            .stream
            .alloc_zeros::<f32>(1)
            .map_err(|e| anyhow::anyhow!("Failed to allocate GPU buffer: {:?}", e))?;

        self.comm
            .all_reduce(&gpu_dummy, &mut gpu_output, &ReduceOp::Sum)
            .map_err(|e| anyhow::anyhow!("NCCL barrier failed: {:?}", e.0))?;

        self.stream
            .synchronize()
            .map_err(|e| anyhow::anyhow!("Stream sync failed: {:?}", e))?;
        Ok(())
    }

    /// This process's rank.
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Total number of ranks in the communicator.
    pub fn world_size(&self) -> usize {
        self.world_size
    }

    /// Finalize the communicator: synchronize the stream so all enqueued NCCL work has
    /// completed, then drop `self` (and with it the `Comm`).
    pub fn finalize(self) -> Result<()> {
        self.stream
            .synchronize()
            .map_err(|e| anyhow::anyhow!("Stream sync failed: {:?}", e))?;
        Ok(())
    }
}
272
273/// Stub communicator for non-NCCL builds
274#[cfg(not(feature = "nccl"))]
275pub struct NcclCommunicator {
276    rank: usize,
277    world_size: usize,
278}
279
280#[cfg(not(feature = "nccl"))]
281impl NcclCommunicator {
282    pub fn new(_config: &DistributedConfig) -> Result<Self> {
283        anyhow::bail!("NCCL support not enabled. Build with --features nccl")
284    }
285
286    pub fn all_reduce_avg(&self, tensor: &Tensor) -> Result<Tensor> {
287        Ok(tensor.clone())
288    }
289
290    pub fn all_reduce_sum(&self, tensor: &Tensor) -> Result<Tensor> {
291        Ok(tensor.clone())
292    }
293
294    pub fn broadcast(&self, tensor: &Tensor) -> Result<Tensor> {
295        Ok(tensor.clone())
296    }
297
298    pub fn barrier(&self) -> Result<()> {
299        Ok(())
300    }
301
302    pub fn rank(&self) -> usize {
303        self.rank
304    }
305
306    pub fn world_size(&self) -> usize {
307        self.world_size
308    }
309
310    pub fn finalize(self) -> Result<()> {
311        Ok(())
312    }
313}
314
315/// Flatten all variables, apply a collective operation, and unflatten back.
316/// Shared helper for sync_model (broadcast) and sync_gradients (all-reduce).
317fn sync_vars(
318    var_map: &candle_nn::VarMap,
319    op: impl FnOnce(&Tensor) -> Result<Tensor>,
320) -> Result<()> {
321    use candle_core::Shape;
322
323    let vars: Vec<candle_core::Var> = var_map.all_vars();
324    if vars.is_empty() {
325        return Ok(());
326    }
327
328    // Collect shapes and flatten all tensors into one
329    let mut shapes: Vec<Shape> = Vec::with_capacity(vars.len());
330    let mut flat_data: Vec<f32> = Vec::new();
331
332    for var in &vars {
333        let tensor = var.as_tensor();
334        shapes.push(tensor.shape().clone());
335        let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
336        flat_data.extend(data);
337    }
338
339    let device = vars[0].as_tensor().device();
340    let len = flat_data.len();
341    let flat_tensor = Tensor::from_vec(flat_data, len, device)?;
342    let synced = op(&flat_tensor)?;
343    let synced_data: Vec<f32> = synced.to_vec1()?;
344
345    // Unflatten and update each variable
346    let mut offset = 0;
347    for (var, shape) in vars.iter().zip(shapes.iter()) {
348        let size = shape.elem_count();
349        let data = &synced_data[offset..offset + size];
350        let tensor = Tensor::from_vec(data.to_vec(), shape.dims(), device)?;
351        var.set(&tensor)?;
352        offset += size;
353    }
354
355    Ok(())
356}
357
358/// Broadcast model weights from rank 0 to all other ranks
359/// Must be called after model initialization to ensure all ranks have identical weights
360pub fn sync_model(var_map: &candle_nn::VarMap, comm: &NcclCommunicator) -> Result<()> {
361    sync_vars(var_map, |t| comm.broadcast(t))
362}
363
364/// Synchronize gradients across all ranks using NCCL
365/// Flattens all gradients into a single tensor for efficient all-reduce
366pub fn sync_gradients(var_map: &candle_nn::VarMap, comm: &NcclCommunicator) -> Result<()> {
367    sync_vars(var_map, |t| comm.all_reduce_avg(t))
368}