Skip to main content

baracuda_nccl/
lib.rs

//! Safe Rust wrappers for NVIDIA NCCL.
//!
//! Covers communicator setup (single-process multi-GPU via
//! `ncclCommInitAll`, multi-process via `ncclCommInitRank` + [`UniqueId`]),
//! the `all_reduce`, `broadcast`, `reduce`, `all_gather` and
//! `reduce_scatter` collectives, point-to-point `send`/`recv`, group calls,
//! communicator splitting and custom pre-multiplied-sum reduction ops —
//! enough for synchronous data-parallel training.
//!
//! NCCL is a Linux library; Windows has experimental support but no
//! general distribution. On hosts without NCCL, [`Communicator::init_all`]
//! fails with the loader's `LibraryNotFound` error — callers can fall back
//! to single-device execution.
12
13#![warn(missing_debug_implementations)]
14
15use baracuda_driver::{DeviceBuffer, Stream};
16use baracuda_nccl_sys::{
17    nccl, ncclComm_t, ncclDataType_t, ncclRedOp_t, ncclResult_t, ncclUniqueId,
18};
19use baracuda_types::DeviceRepr;
20
21/// Error type for NCCL operations.
22pub type Error = baracuda_core::Error<ncclResult_t>;
23/// Result alias.
24pub type Result<T, E = Error> = core::result::Result<T, E>;
25
26#[inline]
27fn check(status: ncclResult_t) -> Result<()> {
28    Error::check(status)
29}
30
31/// Reduction operation for `all_reduce` / `reduce`.
32#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
33pub enum RedOp {
34    #[default]
35    Sum,
36    Prod,
37    Max,
38    Min,
39    /// Arithmetic mean. NCCL 2.10+.
40    Avg,
41    /// Custom op id returned by [`Communicator::create_pre_mul_sum`].
42    /// NCCL 2.11+.
43    Custom(i32),
44}
45
46impl RedOp {
47    fn raw(self) -> ncclRedOp_t {
48        match self {
49            RedOp::Sum => ncclRedOp_t::Sum,
50            RedOp::Prod => ncclRedOp_t::Prod,
51            RedOp::Max => ncclRedOp_t::Max,
52            RedOp::Min => ncclRedOp_t::Min,
53            RedOp::Avg => ncclRedOp_t::Avg,
54            RedOp::Custom(id) => ncclRedOp_t(id),
55        }
56    }
57}
58
/// Where the scalar passed to [`Communicator::create_pre_mul_sum`] lives.
///
/// The discriminants mirror NCCL's `ncclScalarResidence_t`
/// (`ncclScalarDevice = 0`, `ncclScalarHostImmediate = 1`). The value is
/// forwarded to `ncclRedOpCreatePreMulSum` via `as i32`, so these numbers
/// are ABI and must not be reordered.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ScalarResidence {
    /// Scalar pointer is in host memory (`ncclScalarHostImmediate`); NCCL
    /// reads the value before the create call returns.
    Host = 1,
    /// Scalar pointer is in device memory (`ncclScalarDevice`); NCCL reads
    /// it when the collective runs.
    Device = 0,
}
67
68/// Element type for NCCL buffers. Implemented by baracuda-types primitives
69/// via a sealed trait.
70pub trait NcclScalar: DeviceRepr + sealed::Sealed {
71    #[doc(hidden)]
72    fn raw() -> ncclDataType_t;
73}
74
75macro_rules! impl_nccl_scalar {
76    ($ty:ty, $variant:ident) => {
77        impl NcclScalar for $ty {
78            fn raw() -> ncclDataType_t {
79                ncclDataType_t::$variant
80            }
81        }
82        impl sealed::Sealed for $ty {}
83    };
84}
85
86impl_nccl_scalar!(i8, Int8);
87impl_nccl_scalar!(u8, Uint8);
88impl_nccl_scalar!(i32, Int32);
89impl_nccl_scalar!(u32, Uint32);
90impl_nccl_scalar!(i64, Int64);
91impl_nccl_scalar!(u64, Uint64);
92impl_nccl_scalar!(f32, Float32);
93impl_nccl_scalar!(f64, Float64);
94
95// Half-precision types from the `half` crate. Gated on `half-crate`
96// (which transitively pulls in `baracuda-types/half-crate` so the
97// `DeviceRepr` supertrait is already satisfied).
98#[cfg(feature = "half-crate")]
99impl_nccl_scalar!(half::f16, Float16);
100#[cfg(feature = "half-crate")]
101impl_nccl_scalar!(half::bf16, BFloat16);
102
mod sealed {
    /// Private supertrait of `NcclScalar`.
    ///
    /// This module is private, so downstream crates cannot name this trait
    /// and therefore cannot add `NcclScalar` impls. Every impl, including the
    /// feature-gated half-precision ones, is generated by `impl_nccl_scalar!`
    /// in the parent module.
    pub trait Sealed {}
}
109
#[cfg(all(test, feature = "half-crate"))]
mod half_scalar_tests {
    use super::*;

    /// Resolve `T`'s NCCL data type through the `NcclScalar` bound; this
    /// only compiles if the impl exists.
    fn dtype_of<T: NcclScalar>() -> ncclDataType_t {
        <T as NcclScalar>::raw()
    }

    #[test]
    fn half_types_are_nccl_scalars() {
        let f16_dtype = dtype_of::<half::f16>();
        assert_eq!(
            f16_dtype,
            ncclDataType_t::Float16,
            "half::f16 must map to ncclFloat16"
        );

        let bf16_dtype = dtype_of::<half::bf16>();
        assert_eq!(
            bf16_dtype,
            ncclDataType_t::BFloat16,
            "half::bf16 must map to ncclBfloat16"
        );
    }
}
131
132/// A 128-byte opaque identifier for establishing a multi-process NCCL
133/// communicator. One process calls [`UniqueId::new`] and distributes the
134/// bytes to all other processes via a user-provided channel (TCP, MPI, …);
135/// every process then calls [`Communicator::init_rank`] with the same id.
136#[derive(Copy, Clone, Debug)]
137pub struct UniqueId(ncclUniqueId);
138
139impl UniqueId {
140    /// Generate a fresh unique id on this process.
141    pub fn new() -> Result<Self> {
142        let n = nccl()?;
143        let cu = n.nccl_get_unique_id()?;
144        let mut id = ncclUniqueId::default();
145        check(unsafe { cu(&mut id) })?;
146        Ok(Self(id))
147    }
148
149    /// Raw 128-byte representation. Transmit over the wire as-is.
150    pub fn as_bytes(&self) -> [u8; 128] {
151        let mut out = [0u8; 128];
152        for (o, b) in out.iter_mut().zip(&self.0.internal) {
153            *o = *b as u8;
154        }
155        out
156    }
157
158    /// Rebuild from the 128 bytes received from another process.
159    pub fn from_bytes(bytes: [u8; 128]) -> Self {
160        let mut id = ncclUniqueId::default();
161        for (i, b) in id.internal.iter_mut().zip(&bytes) {
162            *i = *b as i8;
163        }
164        Self(id)
165    }
166}
167
168/// A NCCL communicator — one rank's view of a distributed group.
169pub struct Communicator {
170    handle: ncclComm_t,
171}
172
173unsafe impl Send for Communicator {}
174
175impl core::fmt::Debug for Communicator {
176    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
177        f.debug_struct("nccl::Communicator")
178            .field("handle", &self.handle)
179            .finish()
180    }
181}
182
183impl Communicator {
184    /// Initialize `ndev` communicators (one per device) in this process.
185    /// The returned vector is ordered to match `devices`.
186    ///
187    /// This is the single-process "data-parallel on local GPUs" path.
188    pub fn init_all(devices: &[i32]) -> Result<Vec<Self>> {
189        let n = nccl()?;
190        let cu = n.nccl_comm_init_all()?;
191        let ndev = devices.len() as core::ffi::c_int;
192        let mut comms = vec![core::ptr::null_mut::<core::ffi::c_void>(); devices.len()];
193        check(unsafe { cu(comms.as_mut_ptr(), ndev, devices.as_ptr()) })?;
194        Ok(comms.into_iter().map(|handle| Self { handle }).collect())
195    }
196
197    /// Initialize one rank of a multi-process communicator.
198    pub fn init_rank(nranks: i32, id: UniqueId, rank: i32) -> Result<Self> {
199        let n = nccl()?;
200        let cu = n.nccl_comm_init_rank()?;
201        let mut handle: ncclComm_t = core::ptr::null_mut();
202        check(unsafe { cu(&mut handle, nranks, id.0, rank) })?;
203        Ok(Self { handle })
204    }
205
206    /// Like [`Self::init_rank`] but takes a pointer to a configured
207    /// `ncclConfig_t`. NCCL 2.13+. Pass `core::ptr::null_mut()` for
208    /// defaults — equivalent to [`Self::init_rank`]. The struct shape
209    /// (blocking flag, CGA cluster size, splitShare, netName, …)
210    /// changes between NCCL versions, so we don't model it as a typed
211    /// Rust struct; build it through the C API or as a `[u8; N]`.
212    ///
213    /// # Safety
214    ///
215    /// `config` must be a properly-initialized `ncclConfig_t` for the
216    /// installed NCCL version, or null.
217    pub unsafe fn init_rank_config(
218        nranks: i32,
219        id: UniqueId,
220        rank: i32,
221        config: *mut core::ffi::c_void,
222    ) -> Result<Self> {
223        let n = nccl()?;
224        let cu = n.nccl_comm_init_rank_config()?;
225        let mut handle: ncclComm_t = core::ptr::null_mut();
226        check(cu(&mut handle, nranks, id.0, rank, config))?;
227        Ok(Self { handle })
228    }
229
230    /// Number of ranks in the communicator.
231    pub fn nranks(&self) -> Result<i32> {
232        let n = nccl()?;
233        let cu = n.nccl_comm_count()?;
234        let mut c: core::ffi::c_int = 0;
235        check(unsafe { cu(self.handle, &mut c) })?;
236        Ok(c)
237    }
238
239    /// Rank of this communicator within the group.
240    pub fn rank(&self) -> Result<i32> {
241        let n = nccl()?;
242        let cu = n.nccl_comm_user_rank()?;
243        let mut r: core::ffi::c_int = 0;
244        check(unsafe { cu(self.handle, &mut r) })?;
245        Ok(r)
246    }
247
248    /// Raw communicator handle. Use with care.
249    #[inline]
250    pub fn as_raw(&self) -> ncclComm_t {
251        self.handle
252    }
253}
254
255impl Drop for Communicator {
256    fn drop(&mut self) {
257        if let Ok(n) = nccl() {
258            if let Ok(cu) = n.nccl_comm_destroy() {
259                let _ = unsafe { cu(self.handle) };
260            }
261        }
262    }
263}
264
265/// All-reduce: each rank sends `send` and receives the per-element
266/// reduction (across every rank) into `recv`. In-place use (`send == recv`) is legal.
267#[allow(clippy::too_many_arguments)]
268pub fn all_reduce<T: NcclScalar>(
269    send: &DeviceBuffer<T>,
270    recv: &mut DeviceBuffer<T>,
271    count: usize,
272    op: RedOp,
273    comm: &Communicator,
274    stream: &Stream,
275) -> Result<()> {
276    assert!(send.len() >= count && recv.len() >= count);
277    let n = nccl()?;
278    let cu = n.nccl_all_reduce()?;
279    check(unsafe {
280        cu(
281            send.as_raw().0 as *const core::ffi::c_void,
282            recv.as_raw().0 as *mut core::ffi::c_void,
283            count,
284            T::raw(),
285            op.raw(),
286            comm.handle,
287            stream.as_raw() as _,
288        )
289    })
290}
291
292/// Broadcast the data at `root`'s `send` buffer to every rank's `recv` buffer.
293pub fn broadcast<T: NcclScalar>(
294    send: &DeviceBuffer<T>,
295    recv: &mut DeviceBuffer<T>,
296    count: usize,
297    root: i32,
298    comm: &Communicator,
299    stream: &Stream,
300) -> Result<()> {
301    let n = nccl()?;
302    let cu = n.nccl_broadcast()?;
303    check(unsafe {
304        cu(
305            send.as_raw().0 as *const core::ffi::c_void,
306            recv.as_raw().0 as *mut core::ffi::c_void,
307            count,
308            T::raw(),
309            root,
310            comm.handle,
311            stream.as_raw() as _,
312        )
313    })
314}
315
316/// Begin a group of collectives that must be submitted atomically (e.g.
317/// in single-process multi-GPU all-reduce).
318pub fn group_start() -> Result<()> {
319    let n = nccl()?;
320    let cu = n.nccl_group_start()?;
321    check(unsafe { cu() })
322}
323
324/// End the current collective group.
325pub fn group_end() -> Result<()> {
326    let n = nccl()?;
327    let cu = n.nccl_group_end()?;
328    check(unsafe { cu() })
329}
330
331/// NCCL library version as a packed integer (e.g. `22100` for NCCL 2.21.0).
332pub fn version() -> Result<i32> {
333    let n = nccl()?;
334    let cu = n.nccl_get_version()?;
335    let mut v: core::ffi::c_int = 0;
336    check(unsafe { cu(&mut v) })?;
337    Ok(v)
338}
339
340/// Human-readable name for a status code.
341pub fn error_string(status: ncclResult_t) -> Result<&'static str> {
342    let n = nccl()?;
343    let cu = n.nccl_get_error_string()?;
344    let p = unsafe { cu(status) };
345    if p.is_null() {
346        return Ok("unknown");
347    }
348    Ok(unsafe { core::ffi::CStr::from_ptr(p) }
349        .to_str()
350        .unwrap_or("unknown"))
351}
352
353// ---- Full collective surface ----
354
355impl Communicator {
356    /// `recvbuf = reduce(sendbuf[root])` on root only; non-root `recvbuf` is unchanged.
357    pub fn reduce<T: NcclScalar>(
358        &self,
359        sendbuf: &DeviceBuffer<T>,
360        recvbuf: &mut DeviceBuffer<T>,
361        count: usize,
362        op: RedOp,
363        root: i32,
364        stream: &Stream,
365    ) -> Result<()> {
366        let n = nccl()?;
367        let cu = n.nccl_reduce()?;
368        check(unsafe {
369            cu(
370                sendbuf.as_raw().0 as *const core::ffi::c_void,
371                recvbuf.as_raw().0 as *mut core::ffi::c_void,
372                count,
373                T::raw(),
374                op.raw(),
375                root,
376                self.handle,
377                stream.as_raw(),
378            )
379        })
380    }
381
382    /// `recvbuf[r * sendcount..] = sendbuf` from rank `r`.
383    pub fn all_gather<T: NcclScalar>(
384        &self,
385        sendbuf: &DeviceBuffer<T>,
386        recvbuf: &mut DeviceBuffer<T>,
387        sendcount: usize,
388        stream: &Stream,
389    ) -> Result<()> {
390        let n = nccl()?;
391        let cu = n.nccl_all_gather()?;
392        check(unsafe {
393            cu(
394                sendbuf.as_raw().0 as *const core::ffi::c_void,
395                recvbuf.as_raw().0 as *mut core::ffi::c_void,
396                sendcount,
397                T::raw(),
398                self.handle,
399                stream.as_raw(),
400            )
401        })
402    }
403
404    /// Combined reduce + scatter: `recvbuf = reduce(sendbuf[r * recvcount..])`
405    /// across ranks r = 0..nranks.
406    pub fn reduce_scatter<T: NcclScalar>(
407        &self,
408        sendbuf: &DeviceBuffer<T>,
409        recvbuf: &mut DeviceBuffer<T>,
410        recvcount: usize,
411        op: RedOp,
412        stream: &Stream,
413    ) -> Result<()> {
414        let n = nccl()?;
415        let cu = n.nccl_reduce_scatter()?;
416        check(unsafe {
417            cu(
418                sendbuf.as_raw().0 as *const core::ffi::c_void,
419                recvbuf.as_raw().0 as *mut core::ffi::c_void,
420                recvcount,
421                T::raw(),
422                op.raw(),
423                self.handle,
424                stream.as_raw(),
425            )
426        })
427    }
428
429    /// Point-to-point send to `peer`. Pair with [`Self::recv`] inside a
430    /// group-call bracket.
431    pub fn send<T: NcclScalar>(
432        &self,
433        sendbuf: &DeviceBuffer<T>,
434        count: usize,
435        peer: i32,
436        stream: &Stream,
437    ) -> Result<()> {
438        let n = nccl()?;
439        let cu = n.nccl_send()?;
440        check(unsafe {
441            cu(
442                sendbuf.as_raw().0 as *const core::ffi::c_void,
443                count,
444                T::raw(),
445                peer,
446                self.handle,
447                stream.as_raw(),
448            )
449        })
450    }
451
452    /// Point-to-point recv from `peer`.
453    pub fn recv<T: NcclScalar>(
454        &self,
455        recvbuf: &mut DeviceBuffer<T>,
456        count: usize,
457        peer: i32,
458        stream: &Stream,
459    ) -> Result<()> {
460        let n = nccl()?;
461        let cu = n.nccl_recv()?;
462        check(unsafe {
463            cu(
464                recvbuf.as_raw().0 as *mut core::ffi::c_void,
465                count,
466                T::raw(),
467                peer,
468                self.handle,
469                stream.as_raw(),
470            )
471        })
472    }
473
474    /// Abort all outstanding operations on this communicator. Forces
475    /// pending collectives to return with an error. Drop still destroys.
476    pub fn abort(&self) -> Result<()> {
477        let n = nccl()?;
478        let cu = n.nccl_comm_abort()?;
479        check(unsafe { cu(self.handle) })
480    }
481
482    /// Mark the communicator as done. After `finalize` you can still
483    /// call [`Communicator::get_async_error`] but no new collectives.
484    pub fn finalize(&self) -> Result<()> {
485        let n = nccl()?;
486        let cu = n.nccl_comm_finalize()?;
487        check(unsafe { cu(self.handle) })
488    }
489
490    /// Poll the communicator's async error state (non-blocking).
491    /// Returns `Ok(Success)` if there's no pending error.
492    pub fn get_async_error(&self) -> Result<ncclResult_t> {
493        let n = nccl()?;
494        let cu = n.nccl_comm_get_async_error()?;
495        let mut s = ncclResult_t::Success;
496        check(unsafe { cu(self.handle, &mut s) })?;
497        Ok(s)
498    }
499
500    /// CUDA device ordinal this communicator is bound to.
501    pub fn cuda_device(&self) -> Result<i32> {
502        let n = nccl()?;
503        let cu = n.nccl_comm_cu_device()?;
504        let mut d: core::ffi::c_int = 0;
505        check(unsafe { cu(self.handle, &mut d) })?;
506        Ok(d)
507    }
508
509    /// Split a communicator — ranks with the same `color` end up in the
510    /// same new communicator, ordered by `key`. Pass `color = -1` to
511    /// drop a rank from the new communicator.
512    pub fn split(&self, color: i32, key: i32) -> Result<Communicator> {
513        let n = nccl()?;
514        let cu = n.nccl_comm_split()?;
515        let mut new_comm: ncclComm_t = core::ptr::null_mut();
516        check(unsafe { cu(self.handle, color, key, &mut new_comm, core::ptr::null_mut()) })?;
517        Ok(Communicator { handle: new_comm })
518    }
519
520    /// Register a device buffer for zero-copy collective use. Returns an
521    /// opaque handle to pass to [`Self::deregister`] later.
522    ///
523    /// # Safety
524    ///
525    /// `dev_ptr` must be a live device-memory allocation.
526    pub unsafe fn register(
527        &self,
528        dev_ptr: *mut core::ffi::c_void,
529        size: usize,
530    ) -> Result<*mut core::ffi::c_void> {
531        let n = nccl()?;
532        let cu = n.nccl_comm_register()?;
533        let mut handle: *mut core::ffi::c_void = core::ptr::null_mut();
534        check(cu(self.handle, dev_ptr, size, &mut handle))?;
535        Ok(handle)
536    }
537
538    /// Deregister a previously-registered buffer.
539    ///
540    /// # Safety
541    ///
542    /// `handle` must come from a [`Self::register`] call on this comm.
543    pub unsafe fn deregister(&self, handle: *mut core::ffi::c_void) -> Result<()> {
544        let n = nccl()?;
545        let cu = n.nccl_comm_deregister()?;
546        check(cu(self.handle, handle))
547    }
548
549    /// Create a custom pre-multiplied-sum reduction op:
550    /// `out = sum_i (scalar * x_i)`. Use the returned [`RedOp::Custom`]
551    /// in any subsequent [`all_reduce`] / [`Communicator::reduce`] /
552    /// [`Communicator::reduce_scatter`] on this communicator.
553    /// Destroy it with [`Self::destroy_red_op`] when you're done.
554    /// NCCL 2.11+.
555    ///
556    /// # Safety
557    ///
558    /// `scalar` must point to a single value of type `T` whose
559    /// residence matches `residence` (host or device memory) and stay
560    /// valid until the next collective using this op completes.
561    pub unsafe fn create_pre_mul_sum<T: NcclScalar>(
562        &self,
563        scalar: *mut core::ffi::c_void,
564        residence: ScalarResidence,
565    ) -> Result<RedOp> {
566        let n = nccl()?;
567        let cu = n.nccl_red_op_create_pre_mul_sum()?;
568        let mut op = ncclRedOp_t(0);
569        check(cu(&mut op, scalar, T::raw(), residence as i32, self.handle))?;
570        Ok(RedOp::Custom(op.0))
571    }
572
573    /// Destroy a custom op previously returned by [`Self::create_pre_mul_sum`].
574    /// NCCL 2.11+. Calling on a built-in op (Sum/Prod/Max/Min/Avg) is a
575    /// no-op error from NCCL — guard against that yourself.
576    pub fn destroy_red_op(&self, op: RedOp) -> Result<()> {
577        let n = nccl()?;
578        let cu = n.nccl_red_op_destroy()?;
579        check(unsafe { cu(op.raw(), self.handle) })
580    }
581
582    /// Most recent error string produced on this communicator.
583    /// NCCL 2.13+. Returns `"unknown"` if the loader can't resolve
584    /// the symbol or the C library returns null.
585    pub fn last_error(&self) -> Result<&'static str> {
586        let n = nccl()?;
587        let cu = n.nccl_get_last_error()?;
588        let p = unsafe { cu(self.handle) };
589        if p.is_null() {
590            return Ok("unknown");
591        }
592        Ok(unsafe { core::ffi::CStr::from_ptr(p) }
593            .to_str()
594            .unwrap_or("unknown"))
595    }
596}
597
598/// NCCL-managed device allocation. Drop calls `ncclMemFree`.
599#[derive(Debug)]
600pub struct NcclMem {
601    ptr: *mut core::ffi::c_void,
602}
603
604impl NcclMem {
605    /// Allocate `size` bytes through NCCL — these are GPU-direct-
606    /// friendly (pre-registered with the transport). Use with
607    /// [`Communicator::register`] for zero-copy collectives.
608    pub fn new(size: usize) -> Result<Self> {
609        let n = nccl()?;
610        let cu = n.nccl_mem_alloc()?;
611        let mut p: *mut core::ffi::c_void = core::ptr::null_mut();
612        check(unsafe { cu(&mut p, size) })?;
613        Ok(Self { ptr: p })
614    }
615
616    #[inline]
617    pub fn as_raw(&self) -> *mut core::ffi::c_void {
618        self.ptr
619    }
620}
621
622impl Drop for NcclMem {
623    fn drop(&mut self) {
624        if let Ok(n) = nccl() {
625            if let Ok(cu) = n.nccl_mem_free() {
626                let _ = unsafe { cu(self.ptr) };
627            }
628        }
629    }
630}