cudarc/nccl/safe.rs

use super::{result, sys};
use crate::driver::{CudaContext, CudaStream, DevicePtr, DevicePtrMut};
use std::{mem::MaybeUninit, sync::Arc, vec, vec::Vec};

pub use result::{group_start, group_end};

#[derive(Debug)]
pub struct Comm {
    comm: sys::ncclComm_t,
    stream: Arc<CudaStream>,
    rank: usize,
    world_size: usize,
}

#[derive(Debug, Clone, Copy)]
pub struct Id {
    id: sys::ncclUniqueId,
}

impl Id {
    pub fn new() -> Result<Self, result::NcclError> {
        let id = result::get_uniqueid()?;
        Ok(Self { id })
    }

    pub fn uninit(internal: [::core::ffi::c_char; 128usize]) -> Self {
        let id = sys::ncclUniqueId { internal };
        Self { id }
    }

    pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
        &self.id.internal
    }
}

pub enum ReduceOp {
    Sum,
    Prod,
    Max,
    Min,
    Avg,
}
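
// NB: `ReduceOp::Avg` maps to `ncclAvg` below, which is only available in
// NCCL 2.10 and newer; older NCCL versions do not define it.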

fn convert_to_nccl_reduce_op(op: &ReduceOp) -> sys::ncclRedOp_t {
    match op {
        ReduceOp::Sum => sys::ncclRedOp_t::ncclSum,
        ReduceOp::Prod => sys::ncclRedOp_t::ncclProd,
        ReduceOp::Max => sys::ncclRedOp_t::ncclMax,
        ReduceOp::Min => sys::ncclRedOp_t::ncclMin,
        ReduceOp::Avg => sys::ncclRedOp_t::ncclAvg,
    }
}

impl Drop for Comm {
    fn drop(&mut self) {
        // TODO(thenerdstation): Should we instead do finalize then destroy?
        unsafe {
            result::comm_abort(self.comm).expect("Error when aborting Comm.");
        }
    }
}

pub trait NcclType {
    fn as_nccl_type() -> sys::ncclDataType_t;
}

macro_rules! define_nccl_type {
    ($t:ty, $nccl_type:expr) => {
        impl NcclType for $t {
            fn as_nccl_type() -> sys::ncclDataType_t {
                $nccl_type
            }
        }
    };
}

define_nccl_type!(f32, sys::ncclDataType_t::ncclFloat32);
define_nccl_type!(f64, sys::ncclDataType_t::ncclFloat64);
define_nccl_type!(i8, sys::ncclDataType_t::ncclInt8);
define_nccl_type!(i32, sys::ncclDataType_t::ncclInt32);
define_nccl_type!(i64, sys::ncclDataType_t::ncclInt64);
define_nccl_type!(u8, sys::ncclDataType_t::ncclUint8);
define_nccl_type!(u32, sys::ncclDataType_t::ncclUint32);
define_nccl_type!(u64, sys::ncclDataType_t::ncclUint64);
define_nccl_type!(char, sys::ncclDataType_t::ncclUint8);
#[cfg(feature = "f16")]
define_nccl_type!(half::f16, sys::ncclDataType_t::ncclFloat16);
#[cfg(feature = "f16")]
define_nccl_type!(half::bf16, sys::ncclDataType_t::ncclBfloat16);
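
// For reference, a `define_nccl_type!` invocation such as
// `define_nccl_type!(f32, sys::ncclDataType_t::ncclFloat32)` expands to:
//
//     impl NcclType for f32 {
//         fn as_nccl_type() -> sys::ncclDataType_t {
//             sys::ncclDataType_t::ncclFloat32
//         }
//     }
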
impl Comm {
    /// Primitive to create new communication links, one per stream, all on a
    /// single thread.
    /// WARNING: You are likely to get limited throughput using a single core
    /// to control multiple GPUs (see issue #169).
    /// ```
    /// # use cudarc::driver::CudaContext;
    /// # use cudarc::nccl::safe::{Comm, ReduceOp, group_start, group_end};
    /// let n = 2;
    /// let n_devices = CudaContext::device_count().unwrap() as usize;
    /// let streams: Vec<_> = (0..n_devices)
    ///     .map(|i| CudaContext::new(i).unwrap().default_stream())
    ///     .collect();
    /// let comms = Comm::from_devices(streams).unwrap();
    /// // Allocate before the group so the buffers outlive the queued ops.
    /// let mut buffers: Vec<_> = comms
    ///     .iter()
    ///     .map(|comm| {
    ///         let stream = comm.stream();
    ///         let slice = stream.clone_htod(&vec![(comm.rank() + 1) as f32; n]).unwrap();
    ///         let slice_receive = stream.alloc_zeros::<f32>(n).unwrap();
    ///         (slice, slice_receive)
    ///     })
    ///     .collect();
    /// group_start().unwrap();
    /// for (comm, (slice, slice_receive)) in comms.iter().zip(buffers.iter_mut()) {
    ///     comm.all_reduce(slice, slice_receive, &ReduceOp::Sum).unwrap();
    /// }
    /// group_end().unwrap();
    /// ```
    pub fn from_devices(streams: Vec<Arc<CudaStream>>) -> Result<Vec<Self>, result::NcclError> {
        let n_streams = streams.len();
        let mut comms = vec![std::ptr::null_mut(); n_streams];
        let ordinals: Vec<_> = streams
            .iter()
            .map(|d| d.context().ordinal() as i32)
            .collect();
        unsafe {
            result::comm_init_all(comms.as_mut_ptr(), n_streams as i32, ordinals.as_ptr())?;
        }

        let comms: Vec<Self> = comms
            .into_iter()
            .zip(streams.iter().cloned())
            .enumerate()
            .map(|(rank, (comm, stream))| Self {
                comm,
                stream,
                rank,
                world_size: n_streams,
            })
            .collect();

        Ok(comms)
    }

    pub fn stream(&self) -> Arc<CudaStream> {
        self.stream.clone()
    }

    pub fn context(&self) -> &Arc<CudaContext> {
        self.stream.context()
    }

    pub fn ordinal(&self) -> usize {
        self.stream.ctx.ordinal
    }

    pub fn rank(&self) -> usize {
        self.rank
    }

    pub fn world_size(&self) -> usize {
        self.world_size
    }

    /// Primitive to create a new communication link on each process (threads
    /// are possible but not recommended).
    ///
    /// WARNING: If using threads, you are likely to get limited throughput
    /// using a single core to control multiple GPUs. CUDA drivers effectively
    /// use a global mutex, thrashing performance on multi-threaded multi-GPU
    /// setups (see issue #169).
    /// ```
    /// # use cudarc::driver::CudaContext;
    /// # use cudarc::nccl::safe::{Comm, Id, ReduceOp};
    /// let n = 2;
    /// let n_devices = 1; // This is to simplify this example.
    /// // Create this only on rank 0
    /// let id = Id::new().unwrap();
    /// // Send id.internal() to other ranks
    /// // let id = Id::uninit(id.internal().clone()); on other ranks
    ///
    /// let rank = 0;
    /// let ctx = CudaContext::new(rank).unwrap();
    /// let stream = ctx.default_stream();
    /// let comm = Comm::from_rank(stream.clone(), rank, n_devices, id).unwrap();
    /// let slice = stream.clone_htod(&vec![(rank + 1) as f32; n]).unwrap();
    /// let mut slice_receive = stream.alloc_zeros::<f32>(n).unwrap();
    /// comm.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum)
    ///     .unwrap();
    /// let out = stream.clone_dtoh(&slice_receive).unwrap();
    /// assert_eq!(out, vec![(n_devices * (n_devices + 1)) as f32 / 2.0; n]);
    /// ```
    pub fn from_rank(
        stream: Arc<CudaStream>,
        rank: usize,
        world_size: usize,
        id: Id,
    ) -> Result<Self, result::NcclError> {
        let mut comm = MaybeUninit::uninit();

        let comm = unsafe {
            result::comm_init_rank(
                comm.as_mut_ptr(),
                world_size
                    .try_into()
                    .expect("world_size cannot be cast to i32"),
                id.id,
                rank.try_into().expect("Rank cannot be cast to i32"),
            )?;
            comm.assume_init()
        };
        Ok(Self {
            comm,
            stream,
            rank,
            world_size,
        })
    }
}

impl Comm {
    /// Send data to one peer, see [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend)
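    ///
    /// A minimal point-to-point sketch (not a doctest; it assumes two ranks
    /// that were each set up with [Comm::from_rank] on `stream`):
    /// ```ignore
    /// if comm.rank() == 0 {
    ///     let buf = stream.clone_htod(&vec![1.0f32; 8]).unwrap();
    ///     comm.send(&buf, 1).unwrap(); // rank 0 sends to peer 1
    /// } else {
    ///     let mut buf = stream.alloc_zeros::<f32>(8).unwrap();
    ///     comm.recv(&mut buf, 0).unwrap(); // rank 1 receives from peer 0
    /// }
    /// ```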
    pub fn send<S: DevicePtr<T>, T: NcclType>(
        &self,
        data: &S,
        peer: i32,
    ) -> Result<(), result::NcclError> {
        let (src, _record_src) = data.device_ptr(&self.stream);
        unsafe {
            result::send(
                src as _,
                data.len(),
                T::as_nccl_type(),
                peer,
                self.comm,
                self.stream.cu_stream as _,
            )
        }?;
        Ok(())
    }

    /// Receive data from one peer, see [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv)
    pub fn recv<R: DevicePtrMut<T>, T: NcclType>(
        &self,
        buff: &mut R,
        peer: i32,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let count = buff.len();
        let (dst, _record_dst) = buff.device_ptr_mut(&self.stream);
        unsafe {
            result::recv(
                dst as _,
                count,
                T::as_nccl_type(),
                peer,
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// Broadcasts a value from the `root` rank to every other rank's `recvbuff`.
    /// `sendbuff` is ignored on ranks other than `root`, so you can pass `None`
    /// on non-root ranks.
    ///
    /// `sendbuff` must be `Some` on the root rank!
    ///
    /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast)
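    ///
    /// A hedged sketch of the calling convention (assuming the per-rank setup
    /// from [Comm::from_rank]; `CudaSlice` is the driver's owned buffer type):
    /// ```ignore
    /// use cudarc::driver::CudaSlice;
    /// let root = 0;
    /// let mut recvbuff = stream.alloc_zeros::<f32>(n).unwrap();
    /// if comm.rank() == root as usize {
    ///     let src: CudaSlice<f32> = stream.clone_htod(&vec![42.0f32; n]).unwrap();
    ///     comm.broadcast(Some(&src), &mut recvbuff, root).unwrap();
    /// } else {
    ///     // sendbuff is ignored off-root, but `None` still needs a concrete type
    ///     comm.broadcast(None::<&CudaSlice<f32>>, &mut recvbuff, root).unwrap();
    /// }
    /// ```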
    pub fn broadcast<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
        &self,
        sendbuff: Option<&S>,
        recvbuff: &mut R,
        root: i32,
    ) -> Result<result::NcclStatus, result::NcclError> {
        debug_assert!(sendbuff.is_some() || self.rank != root as usize);
        let count = recvbuff.len();
        let (src, _record_src) = sendbuff.map(|b| b.device_ptr(&self.stream)).unzip();
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::broadcast(
                src.map(|ptr| ptr as _).unwrap_or(std::ptr::null()),
                dst as _,
                count,
                T::as_nccl_type(),
                root,
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// In place version of [Comm::broadcast()].
    /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast)
    pub fn broadcast_in_place<R: DevicePtrMut<T>, T: NcclType>(
        &self,
        recvbuff: &mut R,
        root: i32,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let count = recvbuff.len();
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::broadcast(
                dst as _,
                dst as _,
                count,
                T::as_nccl_type(),
                root,
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather)
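    ///
    /// `recvbuff` must hold `world_size * sendbuff.len()` elements; the data
    /// from rank `i` lands at offset `i * sendbuff.len()`. A sketch (assuming
    /// the per-rank setup from [Comm::from_rank]):
    /// ```ignore
    /// let src = stream.clone_htod(&vec![comm.rank() as f32; n]).unwrap();
    /// let mut dst = stream.alloc_zeros::<f32>(n * comm.world_size()).unwrap();
    /// comm.all_gather(&src, &mut dst).unwrap();
    /// ```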
    pub fn all_gather<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
        &self,
        sendbuff: &S,
        recvbuff: &mut R,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let (src, _record_src) = sendbuff.device_ptr(&self.stream);
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::all_gather(
                src as _,
                dst as _,
                sendbuff.len(),
                T::as_nccl_type(),
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce)
    pub fn all_reduce<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
        &self,
        sendbuff: &S,
        recvbuff: &mut R,
        reduce_op: &ReduceOp,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let (src, _record_src) = sendbuff.device_ptr(&self.stream);
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::all_reduce(
                src as _,
                dst as _,
                sendbuff.len(),
                T::as_nccl_type(),
                convert_to_nccl_reduce_op(reduce_op),
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// In place version of [Comm::all_reduce()].
    /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce)
    pub fn all_reduce_in_place<R: DevicePtrMut<T>, T: NcclType>(
        &self,
        buff: &mut R,
        reduce_op: &ReduceOp,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let count = buff.len();
        let (dst, _record_dst) = buff.device_ptr_mut(&self.stream);
        unsafe {
            result::all_reduce(
                dst as _,
                dst as _,
                count,
                T::as_nccl_type(),
                convert_to_nccl_reduce_op(reduce_op),
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// Reduces the `sendbuff` from all ranks into the `recvbuff` on the
    /// `root` rank.
    ///
    /// `recvbuff` must be `Some` on the root rank!
    ///
    /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reduce)
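    ///
    /// A hedged sketch of the off-root call (`recvbuff` is unused there, but
    /// `None` still needs a concrete type, e.g. the driver's `CudaSlice`):
    /// ```ignore
    /// use cudarc::driver::CudaSlice;
    /// comm.reduce(&src, None::<&mut CudaSlice<f32>>, &ReduceOp::Sum, 0).unwrap();
    /// ```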
    pub fn reduce<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
        &self,
        sendbuff: &S,
        recvbuff: Option<&mut R>,
        reduce_op: &ReduceOp,
        root: i32,
    ) -> Result<result::NcclStatus, result::NcclError> {
        debug_assert!(recvbuff.is_some() || self.rank != root as usize);

        let (src, _record_src) = sendbuff.device_ptr(&self.stream);
        let (dst, _record_dst) = recvbuff.map(|b| b.device_ptr_mut(&self.stream)).unzip();
        unsafe {
            result::reduce(
                src as _,
                dst.map(|ptr| ptr as _).unwrap_or(std::ptr::null_mut()),
                sendbuff.len(),
                T::as_nccl_type(),
                convert_to_nccl_reduce_op(reduce_op),
                root,
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// In place version of [Comm::reduce()].
    /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reduce)
    pub fn reduce_in_place<R: DevicePtrMut<T>, T: NcclType>(
        &self,
        recvbuff: &mut R,
        reduce_op: &ReduceOp,
        root: i32,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let count = recvbuff.len();
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::reduce(
                dst as _,
                dst as _,
                count,
                T::as_nccl_type(),
                convert_to_nccl_reduce_op(reduce_op),
                root,
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter)
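    ///
    /// `sendbuff` must hold `world_size * recvbuff.len()` elements; rank `i`
    /// receives the reduction of block `i`. A sketch (assuming the per-rank
    /// setup from [Comm::from_rank]):
    /// ```ignore
    /// let src = stream.clone_htod(&vec![1.0f32; n * comm.world_size()]).unwrap();
    /// let mut dst = stream.alloc_zeros::<f32>(n).unwrap();
    /// comm.reduce_scatter(&src, &mut dst, &ReduceOp::Sum).unwrap();
    /// ```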
    pub fn reduce_scatter<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
        &self,
        sendbuff: &S,
        recvbuff: &mut R,
        reduce_op: &ReduceOp,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let count = recvbuff.len();
        let (src, _record_src) = sendbuff.device_ptr(&self.stream);
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::reduce_scatter(
                src as _,
                dst as _,
                count,
                T::as_nccl_type(),
                convert_to_nccl_reduce_op(reduce_op),
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }
}

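/// Brackets a block between [group_start] and [group_end] so the NCCL calls
/// inside it are fused into one group. A hedged usage sketch (assuming the
/// `result` module is in scope at the call site, since the macro refers to
/// it unqualified):
/// ```ignore
/// group!({
///     comm.all_reduce(&src, &mut dst, &ReduceOp::Sum).unwrap();
/// });
/// ```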
#[macro_export]
macro_rules! group {
    ($x:block) => {
        unsafe {
            result::group_start().unwrap();
        }
        $x
        unsafe {
            result::group_end().unwrap();
        }
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "no-std")]
    use no_std_compat::println;

    #[test]
    fn test_all_reduce() {
        let n = 2;
        let n_devices = CudaContext::device_count().unwrap() as usize;
        let id = Id::new().unwrap();
        let threads: Vec<_> = (0..n_devices)
            .map(|i| {
                println!("III {i}");
                std::thread::spawn(move || {
                    println!("Within thread {i}");
                    let ctx = CudaContext::new(i).unwrap();
                    let stream = ctx.default_stream();
                    let comm = Comm::from_rank(stream.clone(), i, n_devices, id).unwrap();
                    let slice = stream.clone_htod(&vec![(i + 1) as f32 * 1.0; n]).unwrap();
                    let mut slice_receive = stream.alloc_zeros::<f32>(n).unwrap();
                    comm.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum)
                        .unwrap();

                    let out = stream.clone_dtoh(&slice_receive).unwrap();

                    assert_eq!(out, vec![(n_devices * (n_devices + 1)) as f32 / 2.0; n]);
                })
            })
            .collect();
        for t in threads {
            t.join().unwrap()
        }
    }

    #[test]
    fn test_all_reduce_views() {
        let n = 2;
        let n_devices = CudaContext::device_count().unwrap() as usize;
        let id = Id::new().unwrap();
        let threads: Vec<_> = (0..n_devices)
            .map(|i| {
                println!("III {i}");
                std::thread::spawn(move || {
                    println!("Within thread {i}");
                    let ctx = CudaContext::new(i).unwrap();
                    let stream = ctx.default_stream();
                    let comm = Comm::from_rank(stream.clone(), i, n_devices, id).unwrap();
                    let slice = stream.clone_htod(&vec![(i + 1) as f32 * 1.0; n]).unwrap();
                    let mut slice_receive = stream.alloc_zeros::<f32>(n).unwrap();
                    let slice_view = slice.slice(..);
                    let mut slice_receive_view = slice_receive.slice_mut(..);

                    comm.all_reduce(&slice_view, &mut slice_receive_view, &ReduceOp::Sum)
                        .unwrap();

                    let out = stream.clone_dtoh(&slice_receive).unwrap();

                    assert_eq!(out, vec![(n_devices * (n_devices + 1)) as f32 / 2.0; n]);
                })
            })
            .collect();
        for t in threads {
            t.join().unwrap()
        }
    }
}