
mistralrs_quant/distributed/mod.rs

use std::{fmt::Debug, fs::File, sync::Barrier};

use candle_core::Result;
pub mod layers;
pub mod socket;

use serde::{Deserialize, Serialize};

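/// Configuration for the TCP ring backend, deserialized from the JSON file pointed to
/// by the `RING_CONFIG` environment variable.
///
/// A minimal sketch of such a file for rank 0 of a two-node ring (the addresses and
/// ports below are illustrative, not defaults):
///
/// ```json
/// {
///   "master_ip": "192.168.1.10",
///   "master_port": 26000,
///   "port": 26001,
///   "right_port": 26001,
///   "right_ip": "192.168.1.11",
///   "rank": 0,
///   "world_size": 2
/// }
/// ```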
#[derive(Debug, Deserialize, Serialize)]
pub struct RingConfig {
    master_ip: Option<String>,
    pub master_port: u16,
    pub port: u16,
    pub right_port: u16,
    right_ip: Option<String>,
    pub rank: usize,
    pub world_size: usize,
}

impl RingConfig {
    /// Loads the ring backend config from the JSON file at the path given by the
    /// `RING_CONFIG` environment variable.
    pub fn load() -> Self {
        let config_json = std::env::var("RING_CONFIG").expect("RING_CONFIG must be set");
        let config: RingConfig = serde_json::from_reader(
            &File::open(config_json).expect("Could not access Ring config JSON"),
        )
        .expect("Invalid JSON config");

        if config.master_ip.is_none() && !config.is_master_rank() {
            panic!("Invalid Ring config. Non-master ranks (rank != 0) must specify master_ip.");
        }
        config
    }

    pub fn is_master_rank(&self) -> bool {
        self.rank == 0
    }

    pub fn master_ip(&self) -> String {
        self.master_ip.clone().unwrap_or("0.0.0.0".to_string())
    }

    pub fn right_ip(&self) -> String {
        self.right_ip.clone().unwrap_or("0.0.0.0".to_string())
    }
}

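/// Object-safe barrier abstraction shared by the distributed backends; the impl below
/// lets a plain `std::sync::Barrier` be used behind an `Arc<dyn BarrierLike>`.
/// A minimal sketch (assuming two participating threads):
///
/// ```ignore
/// let barrier: std::sync::Arc<dyn BarrierLike> = std::sync::Arc::new(std::sync::Barrier::new(2));
/// barrier.wait()?;
/// ```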
pub trait BarrierLike: Debug + Send + Sync {
    fn wait(&self) -> Result<()>;
}

impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self);
        Ok(())
    }
}

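/// Best-effort tensor-parallel world size: the ring config's `world_size` when the ring
/// backend is compiled in, the `MISTRALRS_MN_LOCAL_WORLD_SIZE` override or the CUDA
/// device count for the NCCL backend, and 1 when neither backend is enabled.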
pub fn get_global_tp_size_from_devices() -> Result<usize> {
    #[cfg(feature = "ring")]
    {
        let config = RingConfig::load();
        return Ok(config.world_size);
    }

    #[cfg(all(feature = "cuda", feature = "nccl"))]
    {
        // In case the TP size was set manually
        if let Ok(x) = std::env::var("MISTRALRS_MN_LOCAL_WORLD_SIZE") {
            use std::str::FromStr;
            Ok(usize::from_str(&x).expect("Not a number for MISTRALRS_MN_LOCAL_WORLD_SIZE!"))
        } else {
            use candle_core::cuda::WrapErr;
            candle_core::cuda::cudarc::driver::result::device::get_count()
                .w()
                .map(|x| x as usize)
        }
    }

    #[cfg(all(not(feature = "ring"), not(feature = "nccl")))]
    Ok(1)
}

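/// Whether the NCCL backend should be used: requires the `nccl` and `cuda` features
/// and is disabled by setting the environment variable `MISTRALRS_NO_NCCL=1`.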
pub fn use_nccl() -> bool {
    (std::env::var("MISTRALRS_NO_NCCL").is_err()
        || std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
        && (cfg!(feature = "nccl") && cfg!(feature = "cuda"))
}

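/// Whether the ring backend should be used: requires the `ring` feature and the
/// `RING_CONFIG` environment variable pointing at a ring config JSON file.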
pub fn use_ring() -> bool {
    cfg!(feature = "ring") && std::env::var("RING_CONFIG").is_ok()
}

// Unified Comm enum
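/// Backend-agnostic communicator. Construction picks NCCL, ring, or the no-op dummy
/// backend depending on enabled features and runtime configuration.
///
/// A rough usage sketch (assuming neither the `nccl` nor the `ring` backend is active,
/// so the dummy backend is selected and both collectives are identity operations):
///
/// ```ignore
/// use std::sync::Arc;
///
/// let dev = candle_core::Device::Cpu;
/// let comm = Arc::new(Comm::from_device(Id::new(), &dev, 0, 1)?);
///
/// let x = candle_core::Tensor::ones((2, 8), candle_core::DType::F32, &dev)?;
/// let summed = SumAllReduce::new(&comm).sum_all_reduce(&x)?; // same values as `x`
/// let gathered = AllGather::new(&comm, 1).all_gather(&x)?;   // shape stays (2, 8) for world_size 1
/// ```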
#[derive(Debug)]
pub enum Comm {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(nccl::NcclComm),
    #[cfg(feature = "ring")]
    Ring(ring::RingComm),
    Dummy(dummy::DummyComm),
}

impl Comm {
    pub fn from_device(
        id: Id,
        dev: &candle_core::Device,
        rank: usize,
        world_size: usize,
    ) -> Result<Self> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Ok(Self::Nccl(nccl::NcclComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[cfg(feature = "ring")]
        {
            return Ok(Self::Ring(ring::RingComm::from_device(
                id, dev, rank, world_size,
            )?));
        }

        #[allow(unreachable_code)]
        Ok(Self::Dummy(dummy::DummyComm::from_device(
            id, dev, rank, world_size,
        )?))
    }

    pub fn rank(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.rank(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.rank(),
            Self::Dummy(comm) => comm.rank(),
        }
    }

    pub fn world_size(&self) -> usize {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(comm) => comm.world_size(),
            #[cfg(feature = "ring")]
            Self::Ring(comm) => comm.world_size(),
            Self::Dummy(comm) => comm.world_size(),
        }
    }
}

// Unified Id enum
#[derive(Debug, Clone, Copy)]
pub enum Id {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    Nccl(cudarc::nccl::Id),
    Dummy,
}

impl Id {
    pub fn new() -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            let id = cudarc::nccl::Id::new().expect("Failed to create `Id`.");
            return Self::Nccl(id);
        }
        Self::Dummy
    }

    pub fn uninit(_internal: [::core::ffi::c_char; 128usize]) -> Self {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if use_nccl() {
            return Self::Nccl(cudarc::nccl::Id::uninit(_internal));
        }
        Self::Dummy
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        match self {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Self::Nccl(id) => id.internal(),
            Self::Dummy => {
                static ZEROED_ID: [::core::ffi::c_char; 128] = [0; 128];
                &ZEROED_ID
            }
        }
    }
}

impl Default for Id {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(feature = "cuda", feature = "nccl"))]
use candle_core::cuda::cudarc;

// NCCL backend implementation
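// Note: as enforced in `NcclComm::from_device` below, rank N is expected to run on CUDA
// device ordinal N (e.g. rank 1 on device 1), with the visible ordering controlled by
// CUDA_VISIBLE_DEVICES.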
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl {
    use candle_core::{cuda::cudarc, Device, Result};

    #[derive(Debug)]
    pub struct NcclComm {
        comm: cudarc::nccl::Comm,
    }

    impl NcclComm {
        pub fn from_device(
            id: super::Id,
            dev: &Device,
            rank: usize,
            world_size: usize,
        ) -> Result<Self> {
            if !super::use_nccl() {
                candle_core::bail!("NCCL is disabled but NCCL Comm was requested");
            }
            if !world_size.is_power_of_two() {
                candle_core::bail!(
                    "NCCL backend requires world_size to be a power of 2, got {}",
                    world_size
                );
            }
            let stream = dev.as_cuda_device()?.cuda_stream();
            let device_ordinal = stream.context().ordinal();
            if rank != device_ordinal {
                candle_core::bail!(
                    "NCCL rank {} must match device ordinal, but device ordinal is {}. \
                     Ensure GPUs are visible in the correct order (check CUDA_VISIBLE_DEVICES).",
                    rank,
                    device_ordinal
                );
            }
            let nccl_id = match id {
                super::Id::Nccl(id) => id,
                _ => candle_core::bail!("Expected NCCL Id variant for NCCL Comm initialization"),
            };
            tracing::info!(
                "Initializing NCCL communicator: rank={}, world_size={}, device={}",
                rank,
                world_size,
                device_ordinal
            );
            let comm = cudarc::nccl::Comm::from_rank(stream, rank, world_size, nccl_id)
                .map_err(|e| candle_core::Error::debug(e.0))?;
            Ok(Self { comm })
        }

        pub fn rank(&self) -> usize {
            self.comm.rank()
        }

        pub fn world_size(&self) -> usize {
            self.comm.world_size()
        }

        pub fn inner(&self) -> &cudarc::nccl::Comm {
            &self.comm
        }
    }

    /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
    unsafe impl Sync for NcclComm {}
    unsafe impl Send for NcclComm {}
}

// Ring backend implementation
#[cfg(feature = "ring")]
mod ring {
    use super::RingConfig;
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct RingComm {
        config: RingConfig,
    }

    impl RingComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            let config = RingConfig::load();
            // Validate ring configuration
            if config.world_size < 2 {
                candle_core::bail!(
                    "Ring backend requires world_size >= 2, got {}",
                    config.world_size
                );
            }
            if config.rank >= config.world_size {
                candle_core::bail!(
                    "Ring backend invalid config: rank {} >= world_size {}",
                    config.rank,
                    config.world_size
                );
            }
            if !config.world_size.is_power_of_two() {
                candle_core::bail!(
                    "Ring backend requires world_size to be a power of 2, got {}",
                    config.world_size
                );
            }
            Ok(Self { config })
        }

        pub fn rank(&self) -> usize {
            self.config.rank
        }

        pub fn world_size(&self) -> usize {
            self.config.world_size
        }

        pub fn config(&self) -> &RingConfig {
            &self.config
        }
    }
}

// Dummy backend implementation
mod dummy {
    use candle_core::{Device, Result};

    #[derive(Debug)]
    pub struct DummyComm;

    impl DummyComm {
        pub fn from_device(
            _id: super::Id,
            _dev: &Device,
            _rank: usize,
            _world_size: usize,
        ) -> Result<Self> {
            Ok(Self)
        }

        pub fn rank(&self) -> usize {
            0
        }

        pub fn world_size(&self) -> usize {
            1
        }
    }
}

// Unified operations that work with the Comm enum
#[derive(Clone, Debug)]
pub struct SumAllReduce {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::SumAllReduce>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::SumAllReduce>,
    dummy: Option<dummy_ops::SumAllReduce>,
}

impl SumAllReduce {
    pub fn new(comm: &std::sync::Arc<Comm>) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::SumAllReduce::new(comm)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::SumAllReduce::new(comm)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::SumAllReduce::new(comm)),
            },
        }
    }

    pub fn sum_all_reduce(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.sum_all_reduce(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.sum_all_reduce(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.sum_all_reduce(xs);
        }
        candle_core::bail!("No valid SumAllReduce implementation available")
    }
}

#[derive(Clone, Debug)]
pub struct AllGather {
    #[cfg(all(feature = "cuda", feature = "nccl"))]
    nccl: Option<nccl_ops::AllGather>,
    #[cfg(feature = "ring")]
    ring: Option<ring_ops::AllGather>,
    dummy: Option<dummy_ops::AllGather>,
}

impl AllGather {
    pub fn new(comm: &std::sync::Arc<Comm>, dim: usize) -> Self {
        match &**comm {
            #[cfg(all(feature = "cuda", feature = "nccl"))]
            Comm::Nccl(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: Some(nccl_ops::AllGather::new(comm, dim)),
                #[cfg(feature = "ring")]
                ring: None,
                dummy: None,
            },
            #[cfg(feature = "ring")]
            Comm::Ring(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: Some(ring_ops::AllGather::new(comm, dim)),
                dummy: None,
            },
            Comm::Dummy(_) => Self {
                #[cfg(all(feature = "cuda", feature = "nccl"))]
                nccl: None,
                #[cfg(feature = "ring")]
                ring: None,
                dummy: Some(dummy_ops::AllGather::new(comm, dim)),
            },
        }
    }

    pub fn all_gather(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        #[cfg(all(feature = "cuda", feature = "nccl"))]
        if let Some(ref nccl) = self.nccl {
            return nccl.all_gather(xs);
        }
        #[cfg(feature = "ring")]
        if let Some(ref ring) = self.ring {
            return ring.all_gather(xs);
        }
        if let Some(ref dummy) = self.dummy {
            return dummy.all_gather(xs);
        }
        candle_core::bail!("No valid AllGather implementation available")
    }
}

// Implementation modules for each backend
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod nccl_ops {
    use std::{fmt::Debug, sync::Arc};

    use candle_core::{
        backend::BackendStorage, cuda::cudarc, CpuStorage, CustomOp1, DType, Layout, Result, Shape,
        Tensor,
    };

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        comm: Arc<super::Comm>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            Self { comm: comm.clone() }
        }
    }

    impl SumAllReduce {
        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for SumAllReduce {
        fn name(&self) -> &'static str {
            "SumAllReduce"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("SumAllReduce is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use cudarc::nccl::ReduceOp;
            use half::{bf16, f16};

            let elem_count = l.shape().elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_reduce: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_reduce (BF16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_reduce: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_reduce (F16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_reduce: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_reduce (F32): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_reduce(s, &mut dst, &ReduceOp::Sum)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, l.shape().clone()))
                }
                _ => candle_core::bail!("SumAllReduce requires NCCL backend"),
            }
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        comm: Arc<super::Comm>,
        dim: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            Self {
                comm: comm.clone(),
                dim,
            }
        }
    }

    impl AllGather {
        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            xs.apply_op1_no_bwd(self)
        }
    }

    impl CustomOp1 for AllGather {
        fn name(&self) -> &'static str {
            "AllGather"
        }

        fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
            candle_core::bail!("AllGather is never used on cpu")
        }

        fn cuda_fwd(
            &self,
            s: &candle_core::CudaStorage,
            l: &Layout,
        ) -> Result<(candle_core::CudaStorage, Shape)> {
            use half::{bf16, f16};

            let mut out_shape = l.shape().dims().to_vec();
            out_shape[self.dim] = out_shape[self.dim] * self.comm.world_size();
            let out_shape = Shape::from(out_shape);

            let elem_count = out_shape.elem_count();
            let dev = s.device().clone();

            match self.comm.as_ref() {
                super::Comm::Nccl(nccl_comm) => {
                    let dst = match s.dtype() {
                        DType::BF16 => {
                            let s = s.as_cuda_slice::<bf16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_gather: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_gather (BF16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F16 => {
                            let s = s.as_cuda_slice::<f16>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_gather: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_gather (F16): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        DType::F32 => {
                            let s = s.as_cuda_slice::<f32>()?;
                            let s = match l.contiguous_offsets() {
                                Some((0, l)) if l == s.len() => s,
                                Some(_) | None => candle_core::bail!("input has to be contiguous"),
                            };
                            if elem_count == 0 {
                                candle_core::bail!("NCCL all_gather: elem_count must be > 0");
                            }
                            let device_ordinal = dev.cuda_stream().context().ordinal();
                            if device_ordinal != nccl_comm.rank() {
                                candle_core::bail!(
                                    "NCCL device mismatch: tensor on device {} but NCCL rank is {}. \
                                     Ensure each rank uses the correct GPU.",
                                    device_ordinal,
                                    nccl_comm.rank()
                                );
                            }
                            tracing::debug!(
                                "NCCL all_gather (F32): rank={}, device={}, elem_count={}",
                                nccl_comm.rank(),
                                device_ordinal,
                                elem_count
                            );
                            let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
                            nccl_comm
                                .inner()
                                .all_gather(s, &mut dst)
                                .map_err(candle_core::Error::debug)?;
                            candle_core::CudaStorage::wrap_cuda_slice(dst, dev)
                        }
                        dtype => candle_core::bail!("unsupported dtype {dtype:?}"),
                    };
                    Ok((dst, out_shape))
                }
                _ => candle_core::bail!("AllGather requires NCCL backend"),
            }
        }
    }
}

// Ring operations
#[cfg(feature = "ring")]
mod ring_ops {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
        time::{Duration, Instant},
    };

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};

    // Friendly aliases to tame type complexity.
    type SharedTcpStream = Arc<Mutex<TcpStream>>;
    type LeftRight = (SharedTcpStream, SharedTcpStream);

    use candle_core::{
        backend::BackendStorage, CpuStorage, Device, Result, Storage, Tensor, WithDType,
    };

    use super::RingConfig;

    // Lazily-initialized pair of TCP streams shared by every ring-based collective op
    static LEFT_RIGHT_STREAMS: OnceLock<LeftRight> = OnceLock::new();

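    // Each rank binds a listener on `0.0.0.0:port` (its left side) and dials its right
    // neighbour at `right_ip:right_port`, so the sockets form a directed ring.
    // Illustrative 2-rank layout (addresses and ports are examples, not defaults):
    //   rank 0: port = 26001, right = 192.168.1.11:26001
    //   rank 1: port = 26001, right = 192.168.1.10:26001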
    fn get_ring_streams(config: &RingConfig) -> LeftRight {
        LEFT_RIGHT_STREAMS
            .get_or_init(|| {
                let cur_port = config.port;

                let right_ip = config.right_ip();
                let right_port = config.right_port;

                let left_listener =
                    TcpListener::bind(format!("0.0.0.0:{cur_port}")).expect("bind left");

                let start = Instant::now();
                // Connect to the right neighbor using the provided IP
                let right = loop {
                    match TcpStream::connect(format!("{}:{}", right_ip, right_port)) {
                        Ok(s) => break s,
                        Err(_) if start.elapsed() > Duration::from_secs(10) => {
                            panic!("Failed to connect to right node within the 10-second timeout");
                        }
                        Err(_) => continue,
                    }
                };

                // Accept connection from the left neighbour
                let (left, _) = left_listener.accept().expect("accept left neighbour");

                left.set_nodelay(true).unwrap();
                left.set_nonblocking(false).unwrap();
                right.set_nodelay(true).unwrap();
                right.set_nonblocking(false).unwrap();

                (Arc::new(Mutex::new(left)), Arc::new(Mutex::new(right)))
            })
            .clone()
    }

    #[derive(Clone, Debug)]
    pub struct SumAllReduce {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
    }

    impl SumAllReduce {
        pub fn new(comm: &Arc<super::Comm>) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                    }
                }
                _ => panic!("SumAllReduce requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            let nbytes = x.len() * std::mem::size_of::<T>();

            // --- ping‑pong to overlap latency ---------------------------------------
            // Clone the Arc references
            let right = self.right.clone();
            let left = self.left.clone();

            // View the local slice as bytes that can be written on the wire.
            let data_bytes = unsafe { std::slice::from_raw_parts(x.as_ptr() as *const u8, nbytes) };

            // Re‑use (or allocate) a receive buffer of identical size.
            let mut buffers_guard = self.buffers.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
            })?;
            let recv_buf = buffers_guard
                .entry(nbytes)
                .or_insert_with(|| vec![0u8; nbytes]);

            // Lock both sockets once to avoid per-call mutex overhead.
            let mut right_guard = right.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock right stream mutex: {:?}", e))
            })?;
            let mut left_guard = left.lock().map_err(|e| {
                candle_core::Error::msg(format!("Failed to lock left stream mutex: {:?}", e))
            })?;

            // For the typical tensor size we see (~ 6 KiB) a single
            // write/read pair is faster than chunking because the extra
            // system‑call and loop overhead dominates.  Only fall back to the
            // chunked "ping‑pong" pipeline for larger transfers.
            if nbytes <= 8 * 1024 {
                // --- fast path: one shot ------------------------------------
                right_guard
                    .write_all(data_bytes)
                    .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                left_guard
                    .read_exact(recv_buf)
                    .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
            } else {
                // --- slow path: chunked ping‑pong ---------------------------
                const CHUNK_SIZE: usize = 64 * 1024; // 64 KiB
                let mut offset = 0;

                while offset < nbytes {
                    let len = std::cmp::min(CHUNK_SIZE, nbytes - offset);

                    // send this chunk to the right neighbour
                    right_guard
                        .write_all(&data_bytes[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;

                    // receive the matching chunk from the left neighbour
                    left_guard
                        .read_exact(&mut recv_buf[offset..offset + len])
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;

                    offset += len;
                }
            }

            drop(left_guard);
            drop(right_guard);

            // -------------------------------------------------------------------------
            // Interpret the received bytes as a slice of T; the caller adds the resulting
            // tensor element-wise to the local values.
            let received: &[T] =
                unsafe { std::slice::from_raw_parts(recv_buf.as_ptr() as *const T, x.len()) };

            Tensor::from_slice(received, dims, device)
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(storage) => storage,
                Storage::Cuda(storage) => &storage.to_cpu_storage()?,
                Storage::Metal(storage) => &storage.to_cpu_storage()?,
            };

            let delta = match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device())?,
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            };

            xs + delta
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather {
        left: SharedTcpStream,
        right: SharedTcpStream,
        buffers: Arc<Mutex<HashMap<usize, Vec<u8>>>>,
        dim: usize,
        world_size: usize,
        rank: usize,
    }

    impl AllGather {
        pub fn new(comm: &Arc<super::Comm>, dim: usize) -> Self {
            match &**comm {
                super::Comm::Ring(ring_comm) => {
                    let (left, right) = get_ring_streams(ring_comm.config());
                    Self {
                        left,
                        right,
                        buffers: Arc::new(Mutex::new(HashMap::new())),
                        dim,
                        world_size: ring_comm.world_size(),
                        rank: ring_comm.rank(),
                    }
                }
                _ => panic!("AllGather requires Ring backend"),
            }
        }

        fn run<T: WithDType + Copy + Default>(
            &self,
            x: &[T],
            dims: &[usize],
            device: &Device,
        ) -> Result<Tensor> {
            // Validate gather dimension
            if self.dim >= dims.len() {
                candle_core::bail!(
                    "AllGather: invalid dimension {} for tensor of rank {}",
                    self.dim,
                    dims.len()
                );
            }
            let elem_cnt = x.len();
            let nbytes = elem_cnt * std::mem::size_of::<T>();

            // Prepare output buffer that will hold slices from every rank.
            let mut out: Vec<T> = vec![T::default(); elem_cnt * self.world_size];

            // Copy this rank's slice into its final slot.
            let start = self.rank * elem_cnt;
            out[start..start + elem_cnt].copy_from_slice(x);

            let right = self.right.clone();
            let left = self.left.clone();
            let mut send_piece: &[T] = x;

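            // Each iteration forwards one slice around the ring: after `step` steps the
            // piece arriving from the left originated at rank (rank - step - 1) mod
            // world_size. Worked example (world_size = 4, rank = 2): step 0 receives
            // rank 1's slice, step 1 receives rank 0's, and step 2 receives rank 3's.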
            for step in 0..(self.world_size - 1) {
                // ---------- send to the right ----------
                let bytes =
                    unsafe { std::slice::from_raw_parts(send_piece.as_ptr() as *const u8, nbytes) };
                {
                    let mut rg = right.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock right stream mutex: {:?}",
                            e
                        ))
                    })?;
                    rg.write_all(bytes)
                        .map_err(|e| candle_core::Error::msg(format!("write error: {:?}", e)))?;
                }

                // ---------- receive from the left ----------
                let mut bg = self.buffers.lock().map_err(|e| {
                    candle_core::Error::msg(format!("Failed to lock buffers mutex: {:?}", e))
                })?;
                let buf = bg.entry(nbytes).or_insert_with(|| vec![0u8; nbytes]);
                {
                    let mut lg = left.lock().map_err(|e| {
                        candle_core::Error::msg(format!(
                            "Failed to lock left stream mutex: {:?}",
                            e
                        ))
                    })?;
                    lg.read_exact(buf)
                        .map_err(|e| candle_core::Error::msg(format!("read error: {:?}", e)))?;
                }
                let recv_piece: &[T] =
                    unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const T, elem_cnt) };

                // Determine which global rank the received slice came from.
                let src_rank = (self.rank + self.world_size - step - 1) % self.world_size;
                let dst = src_rank * elem_cnt;
                out[dst..dst + elem_cnt].copy_from_slice(recv_piece);

                // Forward that slice in the next iteration.
                send_piece = recv_piece;
            }

            let mut out_dims = dims.to_vec();
            out_dims[self.dim] *= self.world_size;
            Tensor::from_slice(&out, out_dims, device)
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            let storage = xs.storage_and_layout().0;
            let cpu_storage = match &*storage {
                Storage::Cpu(s) => s,
                Storage::Cuda(s) => &s.to_cpu_storage()?,
                Storage::Metal(s) => &s.to_cpu_storage()?,
            };

            match cpu_storage {
                CpuStorage::BF16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F32(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                CpuStorage::F16(x) => self.run(x.as_slice(), xs.dims(), xs.device()),
                _ => candle_core::bail!("Unsupported dtype for ring backend"),
            }
        }
    }
}

// Dummy operations
mod dummy_ops {
    use candle_core::{Result, Tensor};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    pub struct SumAllReduce;

    impl SumAllReduce {
        pub fn new(_comm: &Arc<super::Comm>) -> Self {
            Self
        }

        pub fn sum_all_reduce(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }

    #[derive(Clone, Debug)]
    pub struct AllGather;

    impl AllGather {
        pub fn new(_comm: &Arc<super::Comm>, _dim: usize) -> Self {
            Self
        }

        pub fn all_gather(&self, xs: &Tensor) -> Result<Tensor> {
            Ok(xs.clone())
        }
    }
}