hermes_llm/
distributed.rs1use anyhow::Result;
7use candle_core::Tensor;
8
9#[cfg(feature = "nccl")]
10use cudarc::driver::safe::{CudaContext, CudaStream};
11#[cfg(feature = "nccl")]
12use cudarc::nccl::safe::{Comm, Id};
13#[cfg(feature = "nccl")]
14use std::sync::Arc;
15
/// Process-group settings for (optionally) distributed training.
///
/// The default describes a single process: `world_size == 1`, `rank == 0`.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Total number of participating processes.
    pub world_size: usize,
    /// Index of this process within the group (assumed to be in `0..world_size`; not validated).
    pub rank: usize,
    /// Path of the file through which rank 0 hands the NCCL unique ID to the other ranks.
    pub comm_file: String,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self {
            rank: 0,
            world_size: 1,
            comm_file: String::from("nccl_id.txt"),
        }
    }
}

impl DistributedConfig {
    /// Returns `true` when at least two processes take part.
    pub fn is_distributed(&self) -> bool {
        self.world_size >= 2
    }

    /// Returns `true` for rank 0, the process that owns one-off duties.
    pub fn is_main_process(&self) -> bool {
        matches!(self.rank, 0)
    }
}
46
/// NCCL-backed collective communicator for one rank of a process group.
///
/// Tensors are staged through host memory and copied onto the CUDA stream held here
/// before each collective runs.
#[cfg(feature = "nccl")]
pub struct NcclCommunicator {
    /// NCCL communicator handle for this rank.
    comm: Comm,
    /// CUDA stream used for buffer allocation, copies and the collectives themselves.
    stream: Arc<CudaStream>,
    /// This process's rank within the group.
    rank: usize,
    /// Number of ranks in the group.
    world_size: usize,
}
55
#[cfg(feature = "nccl")]
impl NcclCommunicator {
    /// Creates the communicator for `config.rank` in a group of `config.world_size` ranks.
    ///
    /// The NCCL unique ID is exchanged through the file at `config.comm_file`.
    /// Rank 0 generates the ID and publishes it atomically (it writes a `.tmp` sibling,
    /// then renames it). Every other rank polls until the file exists and reads the ID back.
    /// Once the NCCL communicator is initialized, rank 0 removes the file on a best-effort basis.
    ///
    /// Ranks other than 0 wait for the ID file with no timeout.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the ID file cannot be written or read, or has the wrong length;
    /// - the CUDA context cannot be created;
    /// - the NCCL communicator cannot be created.
    pub fn new(config: &DistributedConfig) -> Result<Self> {
        let comm_file = std::path::PathBuf::from(&config.comm_file);

        let id = if config.rank == 0 {
            Self::publish_nccl_id(&comm_file)?
        } else {
            Self::wait_for_nccl_id(&comm_file, config.rank)?
        };

        // NOTE(review): every rank opens GPU ordinal 0; presumably each process is launched
        // with its own CUDA_VISIBLE_DEVICES — confirm against the launcher.
        let gpu_ordinal = 0;
        let ctx = CudaContext::new(gpu_ordinal).map_err(|e| {
            anyhow::anyhow!("Failed to create CUDA context {}: {:?}", gpu_ordinal, e)
        })?;
        let stream = ctx.default_stream();

        let comm = Comm::from_rank(stream.clone(), config.rank, config.world_size, id)
            .map_err(|e| anyhow::anyhow!("Failed to create NCCL communicator: {:?}", e.0))?;

        tracing::info!("Rank {}: NCCL communicator initialized", config.rank);

        if config.rank == 0 {
            // Short grace period before deleting the ID file; removal is deliberately best effort.
            std::thread::sleep(std::time::Duration::from_millis(500));
            if comm_file.exists() {
                let _ = std::fs::remove_file(&comm_file);
            }
        }

        Ok(Self {
            comm,
            stream,
            rank: config.rank,
            world_size: config.world_size,
        })
    }

    /// All-reduces `tensor` across ranks and divides by `world_size` (element-wise mean).
    ///
    /// # Errors
    /// Propagates any failure from [`Self::all_reduce_sum`] or the scaling step.
    pub fn all_reduce_avg(&self, tensor: &Tensor) -> Result<Tensor> {
        let reduced = self.all_reduce_sum(tensor)?;
        let avg = reduced.affine(1.0 / self.world_size as f64, 0.0)?;
        Ok(avg)
    }

    /// Sums `tensor` element-wise across all ranks.
    ///
    /// The data is staged through host memory as `f32`. The result has the input's shape,
    /// device and dtype; non-`f32` inputs are converted to `f32` for the reduction and back.
    ///
    /// # Errors
    /// Fails on a tensor conversion error, a CUDA copy/allocation error, an NCCL error,
    /// or a stream synchronization error.
    pub fn all_reduce_sum(&self, tensor: &Tensor) -> Result<Tensor> {
        use cudarc::nccl::safe::ReduceOp;

        let data = Self::host_f32(tensor)?;
        let len = data.len();
        if len == 0 {
            // NCCL requires every rank to pass the same element count, so all ranks take this
            // branch together; it avoids zero-byte CUDA allocations.
            return Ok(tensor.clone());
        }

        let gpu_data = self
            .stream
            .clone_htod(&data)
            .map_err(|e| anyhow::anyhow!("Failed to copy data to GPU: {:?}", e))?;
        let mut gpu_output = self
            .stream
            .alloc_zeros::<f32>(len)
            .map_err(|e| anyhow::anyhow!("Failed to allocate GPU buffer: {:?}", e))?;

        self.comm
            .all_reduce(&gpu_data, &mut gpu_output, &ReduceOp::Sum)
            .map_err(|e| anyhow::anyhow!("NCCL all-reduce failed: {:?}", e.0))?;
        self.sync_stream()?;

        let output = self
            .stream
            .clone_dtoh(&gpu_output)
            .map_err(|e| anyhow::anyhow!("Failed to copy data from GPU: {:?}", e))?;
        Self::tensor_like(output, tensor)
    }

    /// Returns rank 0's copy of `tensor` on every rank.
    ///
    /// Only rank 0's values are read. On the other ranks only the shape, device and dtype of
    /// `tensor` matter, so they skip the host copy of their own data.
    ///
    /// # Errors
    /// Fails on a tensor conversion error, a CUDA copy/allocation error, an NCCL error,
    /// or a stream synchronization error.
    pub fn broadcast(&self, tensor: &Tensor) -> Result<Tensor> {
        let len = tensor.elem_count();
        if len == 0 {
            // Same element count on every rank, so all ranks skip the collective together.
            return Ok(tensor.clone());
        }

        let gpu_data = if self.rank == 0 {
            let data = Self::host_f32(tensor)?;
            Some(
                self.stream
                    .clone_htod(&data)
                    .map_err(|e| anyhow::anyhow!("Failed to copy data to GPU: {:?}", e))?,
            )
        } else {
            None
        };

        let mut gpu_output = self
            .stream
            .alloc_zeros::<f32>(len)
            .map_err(|e| anyhow::anyhow!("Failed to allocate GPU buffer: {:?}", e))?;

        self.comm
            .broadcast(gpu_data.as_ref(), &mut gpu_output, 0)
            .map_err(|e| anyhow::anyhow!("NCCL broadcast failed: {:?}", e.0))?;
        self.sync_stream()?;

        let output = self
            .stream
            .clone_dtoh(&gpu_output)
            .map_err(|e| anyhow::anyhow!("Failed to copy data from GPU: {:?}", e))?;
        Self::tensor_like(output, tensor)
    }

    /// Blocks until every rank reaches this point.
    ///
    /// Implemented as a one-element all-reduce followed by a stream synchronize.
    ///
    /// # Errors
    /// Fails on a CUDA copy/allocation error, an NCCL error, or a stream synchronization error.
    pub fn barrier(&self) -> Result<()> {
        use cudarc::nccl::safe::ReduceOp;

        let dummy = [0.0f32];
        let gpu_dummy = self
            .stream
            .clone_htod(&dummy)
            .map_err(|e| anyhow::anyhow!("Failed to copy data to GPU: {:?}", e))?;
        let mut gpu_output = self
            .stream
            .alloc_zeros::<f32>(1)
            .map_err(|e| anyhow::anyhow!("Failed to allocate GPU buffer: {:?}", e))?;

        self.comm
            .all_reduce(&gpu_dummy, &mut gpu_output, &ReduceOp::Sum)
            .map_err(|e| anyhow::anyhow!("NCCL barrier failed: {:?}", e.0))?;
        self.sync_stream()
    }

    /// This process's rank.
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Number of ranks in the group.
    pub fn world_size(&self) -> usize {
        self.world_size
    }

    /// Waits for outstanding stream work, then drops the communicator.
    ///
    /// # Errors
    /// Fails if the stream cannot be synchronized.
    pub fn finalize(self) -> Result<()> {
        self.sync_stream()
    }

    /// Rank 0 side of the ID exchange: generates a fresh NCCL ID and publishes it to `comm_file`.
    fn publish_nccl_id(comm_file: &std::path::Path) -> Result<Id> {
        // Remove a leftover file from an earlier run before publishing the new ID.
        if comm_file.exists() {
            std::fs::remove_file(comm_file)?;
        }

        let id = Id::new().map_err(|e| anyhow::anyhow!("Failed to create NCCL ID: {:?}", e))?;
        let bytes: Vec<u8> = id.internal().iter().map(|&i| i as u8).collect();

        // `fs::write` closes the temp file before the rename. Readers therefore never see the
        // final path while the handle is still open, which matters on NFS and on Windows.
        let tmp_file = comm_file.with_extension("tmp");
        std::fs::write(&tmp_file, &bytes)?;
        std::fs::rename(&tmp_file, comm_file)?;

        tracing::info!("Rank 0: Created NCCL ID and wrote to {:?}", comm_file);
        Ok(id)
    }

    /// Non-root side of the ID exchange: polls until `comm_file` exists, then parses the ID.
    fn wait_for_nccl_id(comm_file: &std::path::Path, rank: usize) -> Result<Id> {
        tracing::info!("Rank {}: Waiting for NCCL ID file...", rank);
        // No timeout: a wrong `comm_file` path makes this rank wait forever.
        while !comm_file.exists() {
            std::thread::sleep(std::time::Duration::from_millis(100));
        }
        std::thread::sleep(std::time::Duration::from_millis(100));

        let data = std::fs::read(comm_file)?;
        let internal: [i8; 128] = data
            .into_iter()
            .map(|i| i as i8)
            .collect::<Vec<_>>()
            .try_into()
            .map_err(|_| anyhow::anyhow!("Invalid NCCL ID file"))?;

        let id = Id::uninit(internal);
        tracing::info!("Rank {}: Read NCCL ID from {:?}", rank, comm_file);
        Ok(id)
    }

    /// Copies `tensor` to host memory as a flat `f32` vector, converting other dtypes.
    fn host_f32(tensor: &Tensor) -> Result<Vec<f32>> {
        Ok(tensor
            .flatten_all()?
            .to_dtype(candle_core::DType::F32)?
            .to_vec1::<f32>()?)
    }

    /// Rebuilds a tensor with `like`'s shape, device and dtype from a flat `f32` buffer.
    fn tensor_like(data: Vec<f32>, like: &Tensor) -> Result<Tensor> {
        Ok(Tensor::from_vec(data, like.shape(), like.device())?.to_dtype(like.dtype())?)
    }

    /// Blocks until all work queued on the communicator's stream has completed.
    fn sync_stream(&self) -> Result<()> {
        self.stream
            .synchronize()
            .map_err(|e| anyhow::anyhow!("Stream sync failed: {:?}", e))
    }
}
272
/// Placeholder communicator used when the crate is built without the `nccl` feature.
///
/// Its constructor always returns an error, so no instance is created in practice.
#[cfg(not(feature = "nccl"))]
pub struct NcclCommunicator {
    /// Rank this placeholder stands for.
    rank: usize,
    /// Group size this placeholder stands for.
    world_size: usize,
}
279
280#[cfg(not(feature = "nccl"))]
281impl NcclCommunicator {
282 pub fn new(_config: &DistributedConfig) -> Result<Self> {
283 anyhow::bail!("NCCL support not enabled. Build with --features nccl")
284 }
285
286 pub fn all_reduce_avg(&self, tensor: &Tensor) -> Result<Tensor> {
287 Ok(tensor.clone())
288 }
289
290 pub fn all_reduce_sum(&self, tensor: &Tensor) -> Result<Tensor> {
291 Ok(tensor.clone())
292 }
293
294 pub fn broadcast(&self, tensor: &Tensor) -> Result<Tensor> {
295 Ok(tensor.clone())
296 }
297
298 pub fn barrier(&self) -> Result<()> {
299 Ok(())
300 }
301
302 pub fn rank(&self) -> usize {
303 self.rank
304 }
305
306 pub fn world_size(&self) -> usize {
307 self.world_size
308 }
309
310 pub fn finalize(self) -> Result<()> {
311 Ok(())
312 }
313}
314
315fn sync_vars(
318 var_map: &candle_nn::VarMap,
319 op: impl FnOnce(&Tensor) -> Result<Tensor>,
320) -> Result<()> {
321 use candle_core::Shape;
322
323 let vars: Vec<candle_core::Var> = var_map.all_vars();
324 if vars.is_empty() {
325 return Ok(());
326 }
327
328 let mut shapes: Vec<Shape> = Vec::with_capacity(vars.len());
330 let mut flat_data: Vec<f32> = Vec::new();
331
332 for var in &vars {
333 let tensor = var.as_tensor();
334 shapes.push(tensor.shape().clone());
335 let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
336 flat_data.extend(data);
337 }
338
339 let device = vars[0].as_tensor().device();
340 let len = flat_data.len();
341 let flat_tensor = Tensor::from_vec(flat_data, len, device)?;
342 let synced = op(&flat_tensor)?;
343 let synced_data: Vec<f32> = synced.to_vec1()?;
344
345 let mut offset = 0;
347 for (var, shape) in vars.iter().zip(shapes.iter()) {
348 let size = shape.elem_count();
349 let data = &synced_data[offset..offset + size];
350 let tensor = Tensor::from_vec(data.to_vec(), shape.dims(), device)?;
351 var.set(&tensor)?;
352 offset += size;
353 }
354
355 Ok(())
356}
357
358pub fn sync_model(var_map: &candle_nn::VarMap, comm: &NcclCommunicator) -> Result<()> {
361 sync_vars(var_map, |t| comm.broadcast(t))
362}
363
364pub fn sync_gradients(var_map: &candle_nn::VarMap, comm: &NcclCommunicator) -> Result<()> {
367 sync_vars(var_map, |t| comm.all_reduce_avg(t))
368}