use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    /// Loads the ring topology from the JSON file named by the `RING_CONFIG`
    /// environment variable, panicking if the variable is unset, the file is
    /// unreadable, or the config is invalid.
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    /// The master's IP, defaulting to 0.0.0.0 when unset; `load` guarantees it
    /// is set on non-master ranks.
    pub fn master_ip(&self) -> String {
        self.master_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string())
    }

    /// The right neighbour's IP, defaulting to 0.0.0.0 when unset.
    pub fn right_ip(&self) -> String {
        self.right_ip
            .clone()
            .unwrap_or_else(|| "0.0.0.0".to_string())
    }
}
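
// An illustrative `RING_CONFIG` file for rank 1 of a two-node ring (all
// addresses and ports below are assumed values):
//
//     {
//         "master_ip": "192.168.1.10",
//         "master_port": 12345,
//         "port": 12346,
//         "right_ip": "192.168.1.10",
//         "right_port": 12347,
//         "rank": 1,
//         "world_size": 2
//     }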

pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}
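
// Example (a sketch): a standard-library `Barrier` used through the trait
// object, e.g. to synchronize two per-rank worker threads:
//
//     let barrier: std::sync::Arc<dyn BarrierLike> = std::sync::Arc::new(Barrier::new(2));
//     barrier.wait()?;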

pub fn get_global_tp_size_from_devices() -> Result<usize> {
    // Mirror the backend precedence used by `Comm::from_device`: NCCL when it
    // is enabled, then ring, then a single-rank fallback.
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    if use_nccl() {
        return if let Ok(x) = std::env::var("MISTRALRS_MN_LOCAL_WORLD_SIZE") {
            use std::str::FromStr;
            Ok(usize::from_str(&x).expect("Not a number for MISTRALRS_MN_LOCAL_WORLD_SIZE!"))
        } else {
            use candle_core::cuda::WrapErr;
            candle_core::cuda::cudarc::driver::result::device::get_count()
                .w()
                .map(|x| x as usize)
        };
    }

    // The ring world size comes from the config file; this holds with or
    // without CUDA, so one `ring` gate replaces the former duplicated branches.
    #[cfg(feature = "ring")]
    {
        let config = RingConfig::load();
        return Ok(config.world_size);
    }

    #[allow(unreachable_code)]
    Ok(1)
}
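
// Illustrative override (the value is assumed): running with
// `MISTRALRS_MN_LOCAL_WORLD_SIZE=4` makes the NCCL path above report a
// tensor-parallel size of 4 instead of querying the CUDA device count.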

pub fn use_nccl() -> bool {
    // NCCL is used unless it is explicitly disabled with `MISTRALRS_NO_NCCL=1`,
    // and only when both the `nccl` and `cuda` features are compiled in.
    !std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x == "1")
        && cfg!(feature = "nccl")
        && cfg!(feature = "cuda")
}

pub fn use_ring() -> bool {
    cfg!(feature = "ring") && std::env::var("RING_CONFIG").is_ok()
}
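
// Backend selection at a glance: `MISTRALRS_NO_NCCL=1` disables NCCL even when
// it is compiled in, while pointing `RING_CONFIG` at a config file opts into
// the ring backend. If neither backend is active, the dummy single-rank
// implementations below are used.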

#[derive(Debug)]
pub enum Comm {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(nccl::NcclComm),
    #[cfg(feature = "ring")]
    Ring(ring::RingComm),
    Dummy(dummy::DummyComm),
}

impl Comm {
    pub fn from_device(
        id: Id,
        dev: &candle_core::Device,
        rank: usize,
        world_size: usize,
    ) -> Result<Self> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Ok(Self::Nccl(nccl::NcclComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[cfg(feature = "ring")]
        {
            return Ok(Self::Ring(ring::RingComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[allow(unreachable_code)]
        Ok(Self::Dummy(dummy::DummyComm::from_device(
            id, dev, rank, world_size,
        )?))
    }

    pub fn rank(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.rank(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.rank(),
            Self::Dummy(comm) => comm.rank(),
        }
    }

    pub fn world_size(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.world_size(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.world_size(),
            Self::Dummy(comm) => comm.world_size(),
        }
    }
}
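
// Typical construction (an illustrative sketch; `device`, `rank`, and
// `world_size` are assumed to come from the surrounding launcher code):
//
//     let id = Id::new(); // rank 0 creates this; other ranks receive it out of band
//     let comm = std::sync::Arc::new(Comm::from_device(id, &device, rank, world_size)?);
//     assert_eq!(comm.rank(), rank);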

#[derive(Debug, Clone, Copy)]
pub enum Id {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(cudarc::nccl::Id),
    Dummy,
}

impl Id {
    pub fn new() -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            return Self::Nccl(id);
        }
        Self::Dummy
    }

    pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Self::Nccl(cudarc::nccl::Id::uninit(_internal));
        }
        Self::Dummy
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(id) => id.internal(),
            Self::Dummy => {
                static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
                &ZEROED_ID
            }
        }
    }
}

impl Default for Id {
    fn default() -> Self {
        Self::new()
    }
}
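
// A small sanity check for the dummy `Id` (a sketch; with NCCL compiled in,
// `Id::new` would wrap a real NCCL unique id instead).
#[cfg(all(test, not(all(feature = "cuda", feature = "nccl"))))]
mod id_tests {
    use super::*;

    #[test]
    fn dummy_id_internal_is_zeroed() {
        let id = Id::new();
        assert!(id.internal().iter().all(|&c| c == 0));
    }
}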

#[cfg(all(feature = "cuda", feature = "nccl"))]
use candle_core::cuda::cudarc;

#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl {
    use candle_core::{cuda::cudarc, Device, Result};

    #[derive(Debug)]
    pub struct NcclComm {
        comm: cudarc::nccl::Comm,
    }

    impl NcclComm {
        pub fn from_device(
            id: super::Id,
            dev: &Device,
            rank: usize,
            world_size: usize,
        ) -> Result<Self> {
            if !super::use_nccl() {
                candle_core::bail!("NCCL is disabled but an NCCL Comm was requested");
            }
            if !world_size.is_power_of_two() {
                candle_core::bail!(
                    "NCCL backend requires world_size to be a power of 2, got {}",
                    world_size
                );
            }
            let stream = dev.as_cuda_device()?.cuda_stream();
            let device_ordinal = stream.context().ordinal();
            if rank != device_ordinal {
                candle_core::bail!(
                    "NCCL rank {} must match the device ordinal, but the device ordinal is {}. \
                     Ensure GPUs are visible in the correct order (check CUDA_VISIBLE_DEVICES).",
                    rank,
                    device_ordinal
                );
            }
            let nccl_id = match id {
                super::Id::Nccl(id) => id,
                _ => candle_core::bail!("Expected the NCCL Id variant for NCCL Comm initialization"),
            };
            tracing::info!(
                "Initializing NCCL communicator: rank={}, world_size={}, device={}",
                rank,
                world_size,
                device_ordinal
            );
            let comm = cudarc::nccl::Comm::from_rank(stream, rank, world_size, nccl_id)
                .map_err(|e| candle_core::Error::debug(e.0))?;
            Ok(Self { comm })
        }

        pub fn rank(&self) -> usize {
            self.comm.rank()
        }

        pub fn world_size(&self) -> usize {
            self.comm.world_size()
        }

        pub fn inner(&self) -> &cudarc::nccl::Comm {
            &self.comm
        }
    }

    // The raw NCCL communicator is only ever driven from its owning rank's
    // thread and stream in this crate, which is why these impls are assumed
    // to be sound here.
    unsafe impl Sync for NcclComm {}
    unsafe impl Send for NcclComm {}
}

#[cfg(feature = "ring")]
mod ring {
    use super::RingConfig;
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct RingComm {
        config: RingConfig,
    }

    impl RingComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            if !config.world_size.is_power_of_two() {
                candle_core::bail!(
                    "Ring backend requires world_size to be a power of 2, got {}",
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }

        pub fn config(&self) -> &RingConfig {
            &self.config
        }
    }
}

mod dummy {
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct DummyComm;

    impl DummyComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }
}

#[derive(Clone, Debug)]
pub struct SumAllReduce {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::SumAllReduce>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::SumAllReduce>,
    dummy: Option<dummy_ops::SumAllReduce>,
}

impl SumAllReduce {
    pub fn new(comm: &std::sync::Arc<Comm>) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::SumAllReduce::new(comm)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::SumAllReduce::new(comm)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::SumAllReduce::new(comm)),
            },
        }
    }

    pub fn sum_all_reduce(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.sum_all_reduce(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.sum_all_reduce(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.sum_all_reduce(xs);
        }
        candle_core::bail!("No valid SumAllReduce implementation available")
    }
}
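
// A minimal identity check for the dummy fallback (a sketch; gated off when a
// real backend is compiled in, since those need live GPUs or ring peers).
#[cfg(all(test, not(feature = "ring"), not(all(feature = "cuda", feature = "nccl"))))]
mod sum_all_reduce_tests {
    use super::*;

    #[test]
    fn dummy_sum_all_reduce_is_identity() -> Result<()> {
        let dev = candle_core::Device::Cpu;
        let comm = std::sync::Arc::new(Comm::from_device(Id::new(), &dev, 0, 1)?);
        let xs = candle_core::Tensor::ones((2, 3), candle_core::DType::F32, &dev)?;
        // With a single rank, an all-reduce sum returns the input unchanged.
        let ys = SumAllReduce::new(&comm).sum_all_reduce(&xs)?;
        assert_eq!(ys.to_vec2::<f32>()?, xs.to_vec2::<f32>()?);
        Ok(())
    }
}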

#[derive(Clone, Debug)]
pub struct AllGather {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::AllGather>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::AllGather>,
    dummy: Option<dummy_ops::AllGather>,
}

impl AllGather {
    pub fn new(comm: &std::sync::Arc<Comm>, dim: usize) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::AllGather::new(comm, dim)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::AllGather::new(comm, dim)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::AllGather::new(comm, dim)),
            },
        }
    }

    pub fn all_gather(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.all_gather(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.all_gather(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.all_gather(xs);
        }
        candle_core::bail!("No valid AllGather implementation available")
    }
}
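
// The matching identity check for `AllGather` (a sketch, under the same
// feature assumptions as the test above).
#[cfg(all(test, not(feature = "ring"), not(all(feature = "cuda", feature = "nccl"))))]
mod all_gather_tests {
    use super::*;

    #[test]
    fn dummy_all_gather_is_identity() -> Result<()> {
        let dev = candle_core::Device::Cpu;
        let comm = std::sync::Arc::new(Comm::from_device(Id::new(), &dev, 0, 1)?);
        let xs = candle_core::Tensor::zeros((4, 2), candle_core::DType::F32, &dev)?;
        // With world_size == 1 there is nothing to gather from other ranks, so
        // the shape along `dim` is unchanged.
        let ys = AllGather::new(&comm, 0).all_gather(&xs)?;
        assert_eq!(ys.dims(), xs.dims());
        Ok(())
    }
}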

#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl_ops {
    use std::{fmt::Debug, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, CpuStorage, CustomOp1, DType, Layout, Result, Shape,
        Tensor,
    };

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<super::Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            Self { comm: comm.clone() }
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::nccl::ReduceOp;
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    // These checks are dtype-independent, so they are hoisted out
                    // of the per-dtype arms below.
                    if elem_count == 0 {
                        candle_core::bail!("NCCL all_reduce: elem_count must be > 0");
                    }
                    let device_ordinal = dev.cuda_stream().context().ordinal();
                    if device_ordinal != nccl_comm.rank() {
                        candle_core::bail!(
                            "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                             Ensure each rank uses the correct GPU.",
                            device_ordinal,
                            nccl_comm.rank()
                        );
                    }
                    tracing::debug!(
                        "NCCL all_reduce ({:?}): rank={}, device={}, elem_count={}",
                        s.dtype(),
                        nccl_comm.rank(),
                        device_ordinal,
                        elem_count
                    );
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, l.shape().clone()))
                }
                _ => candle_core::bail!("SumAllReduce requires the NCCL backend"),
            }
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<super::Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use half::{bf16, f16};

            // Every rank contributes one slice along `self.dim`, so the output
            // grows by a factor of `world_size` along that dimension.
            let mut out_shape = l.shape().dims().to_vec();
            if self.dim >= out_shape.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    out_shape.len()
                );
            }
            out_shape[self.dim] *= self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    // As in SumAllReduce, dtype-independent checks are hoisted.
                    if elem_count == 0 {
                        candle_core::bail!("NCCL all_gather: elem_count must be > 0");
                    }
                    let device_ordinal = dev.cuda_stream().context().ordinal();
                    if device_ordinal != nccl_comm.rank() {
                        candle_core::bail!(
                            "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                             Ensure each rank uses the correct GPU.",
                            device_ordinal,
                            nccl_comm.rank()
                        );
                    }
                    tracing::debug!(
                        "NCCL all_gather ({:?}): rank={}, device={}, elem_count={}",
                        s.dtype(),
                        nccl_comm.rank(),
                        device_ordinal,
                        elem_count
                    );
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, out_shape))
                }
                _ => candle_core::bail!("AllGather requires the NCCL backend"),
            }
        }
    }
}

#[cfg(feature = "ring")]
mod ring_ops {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                // Listen for the left neighbour on our own port, then dial the
                // right neighbour. Every rank does the same, so the ring closes
                // once all ranks are up.
                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to the right node within the 10-second timeout");
                        }
                        // The neighbour may not be listening yet; back off briefly
                        // instead of busy-spinning on connection attempts.
                        Err(_) => std::thread::sleep(Duration::from_millis(10)),
                    }
                };

                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        // Reusable receive buffers, keyed by message size in bytes.
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                    }
                }
                _ => panic!("SumAllReduce requires the Ring backend"),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            let nbytes = x.len() * std::mem::size_of::<T>();

            let right = self.right.clone();
            let left = self.left.clone();

            // View the input as raw bytes for the wire.
            let data_bytes =
                unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

            if nbytes <= 8 * 1024 {
                // Small messages fit in the socket buffers: write the whole
                // payload, then read the neighbour's.
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                // For large payloads, interleave bounded writes and reads so no
                // pair of ranks can deadlock with both sides blocked in write_all
                // while the OS socket buffers are full.
                const CHUNK_SIZE: usize = 64 * 1024;
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            // Reinterpret the received bytes as `T`s; the sizes match by construction.
            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            // Note: a single neighbour exchange yields the full sum only for a
            // two-node ring (world_size == 2); a larger ring would need
            // world_size - 1 rounds.
            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for the ring backend"),
            };

            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                        dim,
                        world_size: ring_comm.world_size(),
                        rank: ring_comm.rank(),
                    }
                }
                _ => panic!("AllGather requires the Ring backend"),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            let nbytes = elem_cnt * std::mem::size_of::<T>();

            // The output holds one piece per rank; place our own piece first.
            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

            // Classic ring all-gather: in each of the world_size - 1 steps, send
            // the piece most recently obtained to the right and receive a new
            // piece from the left.
            for step in 0..(self.world_size - 1) {
                let bytes = unsafe {
                    std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes)
                };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

                // After `step` forwarding hops, the piece arriving from the left
                // originated at rank (self.rank - step - 1) mod world_size.
                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for the ring backend"),
            }
        }
    }
}

mod dummy_ops {
    use candle_core::{Result, Tensor};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<super::Comm>) -> Self {
            Self
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            // Single rank: the all-reduce of one contribution is the input itself.
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<super::Comm>, _dim: usize) -> Self {
            Self
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            // Single rank: gathering across a world of one is the identity.
            Ok(xs.clone())
        }
    }
}