1#![warn(missing_debug_implementations)]
14
15use baracuda_driver::{DeviceBuffer, Stream};
16use baracuda_nccl_sys::{
17 nccl, ncclComm_t, ncclDataType_t, ncclRedOp_t, ncclResult_t, ncclUniqueId,
18};
19use baracuda_types::DeviceRepr;
20
/// Error type used throughout this module: [`baracuda_core::Error`] parameterised by the raw
/// NCCL status code type `ncclResult_t`.
pub type Error = baracuda_core::Error<ncclResult_t>;
/// Result alias whose error type defaults to this module's [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;
25
/// Converts a raw NCCL status into this module's [`Result`] by delegating to `Error::check`.
#[inline]
fn check(status: ncclResult_t) -> Result<()> {
    Error::check(status)
}
30
/// Reduction operator handed to the reducing collectives.
///
/// The built-in variants are translated to the like-named `ncclRedOp_t` constants;
/// [`RedOp::Sum`] is the default.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum RedOp {
    /// Maps to `ncclRedOp_t::Sum`.
    #[default]
    Sum,
    /// Maps to `ncclRedOp_t::Prod`.
    Prod,
    /// Maps to `ncclRedOp_t::Max`.
    Max,
    /// Maps to `ncclRedOp_t::Min`.
    Min,
    /// Maps to `ncclRedOp_t::Avg`.
    Avg,
    /// Raw operator id that is wrapped unchanged into `ncclRedOp_t`, e.g. the value returned
    /// by `Communicator::create_pre_mul_sum`.
    Custom(i32),
}
45
46impl RedOp {
47 fn raw(self) -> ncclRedOp_t {
48 match self {
49 RedOp::Sum => ncclRedOp_t::Sum,
50 RedOp::Prod => ncclRedOp_t::Prod,
51 RedOp::Max => ncclRedOp_t::Max,
52 RedOp::Min => ncclRedOp_t::Min,
53 RedOp::Avg => ncclRedOp_t::Avg,
54 RedOp::Custom(id) => ncclRedOp_t(id),
55 }
56 }
57}
58
/// Tells `Communicator::create_pre_mul_sum` where its `scalar` pointer lives.
///
/// The discriminant is passed to the native call as an `i32`, so the explicit values
/// below are part of the FFI contract.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ScalarResidence {
    /// Discriminant `0`; presumably the scalar is read from host memory — confirm against NCCL.
    Host = 0,
    /// Discriminant `1`; presumably the scalar is read from device memory — confirm against NCCL.
    Device = 1,
}
67
/// Element types that the collective wrappers in this module accept.
///
/// The `sealed::Sealed` supertrait lives in a private module, so the trait cannot be
/// implemented outside this crate.
pub trait NcclScalar: DeviceRepr + sealed::Sealed {
    /// Returns the `ncclDataType_t` tag that describes `Self` to the native library.
    #[doc(hidden)]
    fn raw() -> ncclDataType_t;
}
74
// Implements `NcclScalar` (and the `Sealed` marker it requires) for one Rust type.
//
// `$scalar` is the Rust element type and `$dtype` is the name of the associated
// `ncclDataType_t` constant it maps to.
macro_rules! impl_nccl_scalar {
    ($scalar:ty, $dtype:ident) => {
        impl sealed::Sealed for $scalar {}

        impl NcclScalar for $scalar {
            fn raw() -> ncclDataType_t {
                ncclDataType_t::$dtype
            }
        }
    };
}
85
// Integer element types.
impl_nccl_scalar!(i8, Int8);
impl_nccl_scalar!(u8, Uint8);
impl_nccl_scalar!(i32, Int32);
impl_nccl_scalar!(u32, Uint32);
impl_nccl_scalar!(i64, Int64);
impl_nccl_scalar!(u64, Uint64);
// Floating-point element types.
impl_nccl_scalar!(f32, Float32);
impl_nccl_scalar!(f64, Float64);

// Half-precision types from the `half` crate; only compiled when the `half-crate`
// feature is enabled.
#[cfg(feature = "half-crate")]
impl_nccl_scalar!(half::f16, Float16);
#[cfg(feature = "half-crate")]
impl_nccl_scalar!(half::bf16, BFloat16);
102
// Private module holding the marker trait used as a supertrait of `NcclScalar`.
mod sealed {
    /// Marker trait; implemented only by the `impl_nccl_scalar!` macro in this file.
    pub trait Sealed {}
}
109
// Compiled only for test builds that also enable the `half-crate` feature.
#[cfg(all(test, feature = "half-crate"))]
mod half_scalar_tests {
    use super::*;

    /// Resolves the data-type tag for `T`; only compiles when `T: NcclScalar`.
    fn dtype_of<T: NcclScalar>() -> ncclDataType_t {
        <T as NcclScalar>::raw()
    }

    /// Both `half` types must implement `NcclScalar` and map to their 16-bit tags.
    #[test]
    fn half_types_are_nccl_scalars() {
        let f16_tag = dtype_of::<half::f16>();
        assert_eq!(
            f16_tag,
            ncclDataType_t::Float16,
            "half::f16 must map to ncclFloat16"
        );

        let bf16_tag = dtype_of::<half::bf16>();
        assert_eq!(
            bf16_tag,
            ncclDataType_t::BFloat16,
            "half::bf16 must map to ncclBfloat16"
        );
    }
}
131
/// Wrapper around the raw `ncclUniqueId`.
///
/// Obtained from [`UniqueId::new`] or rebuilt from 128 bytes with [`UniqueId::from_bytes`],
/// and consumed by `Communicator::init_rank` / `Communicator::init_rank_config`.
#[derive(Copy, Clone, Debug)]
pub struct UniqueId(ncclUniqueId);
138
139impl UniqueId {
140 pub fn new() -> Result<Self> {
142 let n = nccl()?;
143 let cu = n.nccl_get_unique_id()?;
144 let mut id = ncclUniqueId::default();
145 check(unsafe { cu(&mut id) })?;
146 Ok(Self(id))
147 }
148
149 pub fn as_bytes(&self) -> [u8; 128] {
151 let mut out = [0u8; 128];
152 for (o, b) in out.iter_mut().zip(&self.0.internal) {
153 *o = *b as u8;
154 }
155 out
156 }
157
158 pub fn from_bytes(bytes: [u8; 128]) -> Self {
160 let mut id = ncclUniqueId::default();
161 for (i, b) in id.internal.iter_mut().zip(&bytes) {
162 *i = *b as i8;
163 }
164 Self(id)
165 }
166}
167
/// Owning wrapper around a raw `ncclComm_t` handle.
///
/// The `Drop` implementation passes the handle to the `nccl_comm_destroy` entry point on a
/// best-effort basis.
pub struct Communicator {
    /// Raw communicator handle; also exposed read-only through `as_raw`.
    handle: ncclComm_t,
}
172
// SAFETY: NOTE(review) — this asserts that the raw handle may be moved to another thread.
// Nothing in this file proves that; confirm against NCCL's thread-safety guarantees.
// `Sync` is deliberately not implemented here.
unsafe impl Send for Communicator {}
174
175impl core::fmt::Debug for Communicator {
176 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
177 f.debug_struct("nccl::Communicator")
178 .field("handle", &self.handle)
179 .finish()
180 }
181}
182
183impl Communicator {
184 pub fn init_all(devices: &[i32]) -> Result<Vec<Self>> {
189 let n = nccl()?;
190 let cu = n.nccl_comm_init_all()?;
191 let ndev = devices.len() as core::ffi::c_int;
192 let mut comms = vec![core::ptr::null_mut::<core::ffi::c_void>(); devices.len()];
193 check(unsafe { cu(comms.as_mut_ptr(), ndev, devices.as_ptr()) })?;
194 Ok(comms.into_iter().map(|handle| Self { handle }).collect())
195 }
196
197 pub fn init_rank(nranks: i32, id: UniqueId, rank: i32) -> Result<Self> {
199 let n = nccl()?;
200 let cu = n.nccl_comm_init_rank()?;
201 let mut handle: ncclComm_t = core::ptr::null_mut();
202 check(unsafe { cu(&mut handle, nranks, id.0, rank) })?;
203 Ok(Self { handle })
204 }
205
206 pub unsafe fn init_rank_config(
218 nranks: i32,
219 id: UniqueId,
220 rank: i32,
221 config: *mut core::ffi::c_void,
222 ) -> Result<Self> {
223 let n = nccl()?;
224 let cu = n.nccl_comm_init_rank_config()?;
225 let mut handle: ncclComm_t = core::ptr::null_mut();
226 check(cu(&mut handle, nranks, id.0, rank, config))?;
227 Ok(Self { handle })
228 }
229
230 pub fn nranks(&self) -> Result<i32> {
232 let n = nccl()?;
233 let cu = n.nccl_comm_count()?;
234 let mut c: core::ffi::c_int = 0;
235 check(unsafe { cu(self.handle, &mut c) })?;
236 Ok(c)
237 }
238
239 pub fn rank(&self) -> Result<i32> {
241 let n = nccl()?;
242 let cu = n.nccl_comm_user_rank()?;
243 let mut r: core::ffi::c_int = 0;
244 check(unsafe { cu(self.handle, &mut r) })?;
245 Ok(r)
246 }
247
248 #[inline]
250 pub fn as_raw(&self) -> ncclComm_t {
251 self.handle
252 }
253}
254
255impl Drop for Communicator {
256 fn drop(&mut self) {
257 if let Ok(n) = nccl() {
258 if let Ok(cu) = n.nccl_comm_destroy() {
259 let _ = unsafe { cu(self.handle) };
260 }
261 }
262 }
263}
264
265#[allow(clippy::too_many_arguments)]
268pub fn all_reduce<T: NcclScalar>(
269 send: &DeviceBuffer<T>,
270 recv: &mut DeviceBuffer<T>,
271 count: usize,
272 op: RedOp,
273 comm: &Communicator,
274 stream: &Stream,
275) -> Result<()> {
276 assert!(send.len() >= count && recv.len() >= count);
277 let n = nccl()?;
278 let cu = n.nccl_all_reduce()?;
279 check(unsafe {
280 cu(
281 send.as_raw().0 as *const core::ffi::c_void,
282 recv.as_raw().0 as *mut core::ffi::c_void,
283 count,
284 T::raw(),
285 op.raw(),
286 comm.handle,
287 stream.as_raw() as _,
288 )
289 })
290}
291
292pub fn broadcast<T: NcclScalar>(
294 send: &DeviceBuffer<T>,
295 recv: &mut DeviceBuffer<T>,
296 count: usize,
297 root: i32,
298 comm: &Communicator,
299 stream: &Stream,
300) -> Result<()> {
301 let n = nccl()?;
302 let cu = n.nccl_broadcast()?;
303 check(unsafe {
304 cu(
305 send.as_raw().0 as *const core::ffi::c_void,
306 recv.as_raw().0 as *mut core::ffi::c_void,
307 count,
308 T::raw(),
309 root,
310 comm.handle,
311 stream.as_raw() as _,
312 )
313 })
314}
315
316pub fn group_start() -> Result<()> {
319 let n = nccl()?;
320 let cu = n.nccl_group_start()?;
321 check(unsafe { cu() })
322}
323
324pub fn group_end() -> Result<()> {
326 let n = nccl()?;
327 let cu = n.nccl_group_end()?;
328 check(unsafe { cu() })
329}
330
331pub fn version() -> Result<i32> {
333 let n = nccl()?;
334 let cu = n.nccl_get_version()?;
335 let mut v: core::ffi::c_int = 0;
336 check(unsafe { cu(&mut v) })?;
337 Ok(v)
338}
339
340pub fn error_string(status: ncclResult_t) -> Result<&'static str> {
342 let n = nccl()?;
343 let cu = n.nccl_get_error_string()?;
344 let p = unsafe { cu(status) };
345 if p.is_null() {
346 return Ok("unknown");
347 }
348 Ok(unsafe { core::ffi::CStr::from_ptr(p) }
349 .to_str()
350 .unwrap_or("unknown"))
351}
352
353impl Communicator {
356 pub fn reduce<T: NcclScalar>(
358 &self,
359 sendbuf: &DeviceBuffer<T>,
360 recvbuf: &mut DeviceBuffer<T>,
361 count: usize,
362 op: RedOp,
363 root: i32,
364 stream: &Stream,
365 ) -> Result<()> {
366 let n = nccl()?;
367 let cu = n.nccl_reduce()?;
368 check(unsafe {
369 cu(
370 sendbuf.as_raw().0 as *const core::ffi::c_void,
371 recvbuf.as_raw().0 as *mut core::ffi::c_void,
372 count,
373 T::raw(),
374 op.raw(),
375 root,
376 self.handle,
377 stream.as_raw(),
378 )
379 })
380 }
381
382 pub fn all_gather<T: NcclScalar>(
384 &self,
385 sendbuf: &DeviceBuffer<T>,
386 recvbuf: &mut DeviceBuffer<T>,
387 sendcount: usize,
388 stream: &Stream,
389 ) -> Result<()> {
390 let n = nccl()?;
391 let cu = n.nccl_all_gather()?;
392 check(unsafe {
393 cu(
394 sendbuf.as_raw().0 as *const core::ffi::c_void,
395 recvbuf.as_raw().0 as *mut core::ffi::c_void,
396 sendcount,
397 T::raw(),
398 self.handle,
399 stream.as_raw(),
400 )
401 })
402 }
403
404 pub fn reduce_scatter<T: NcclScalar>(
407 &self,
408 sendbuf: &DeviceBuffer<T>,
409 recvbuf: &mut DeviceBuffer<T>,
410 recvcount: usize,
411 op: RedOp,
412 stream: &Stream,
413 ) -> Result<()> {
414 let n = nccl()?;
415 let cu = n.nccl_reduce_scatter()?;
416 check(unsafe {
417 cu(
418 sendbuf.as_raw().0 as *const core::ffi::c_void,
419 recvbuf.as_raw().0 as *mut core::ffi::c_void,
420 recvcount,
421 T::raw(),
422 op.raw(),
423 self.handle,
424 stream.as_raw(),
425 )
426 })
427 }
428
429 pub fn send<T: NcclScalar>(
432 &self,
433 sendbuf: &DeviceBuffer<T>,
434 count: usize,
435 peer: i32,
436 stream: &Stream,
437 ) -> Result<()> {
438 let n = nccl()?;
439 let cu = n.nccl_send()?;
440 check(unsafe {
441 cu(
442 sendbuf.as_raw().0 as *const core::ffi::c_void,
443 count,
444 T::raw(),
445 peer,
446 self.handle,
447 stream.as_raw(),
448 )
449 })
450 }
451
452 pub fn recv<T: NcclScalar>(
454 &self,
455 recvbuf: &mut DeviceBuffer<T>,
456 count: usize,
457 peer: i32,
458 stream: &Stream,
459 ) -> Result<()> {
460 let n = nccl()?;
461 let cu = n.nccl_recv()?;
462 check(unsafe {
463 cu(
464 recvbuf.as_raw().0 as *mut core::ffi::c_void,
465 count,
466 T::raw(),
467 peer,
468 self.handle,
469 stream.as_raw(),
470 )
471 })
472 }
473
474 pub fn abort(&self) -> Result<()> {
477 let n = nccl()?;
478 let cu = n.nccl_comm_abort()?;
479 check(unsafe { cu(self.handle) })
480 }
481
482 pub fn finalize(&self) -> Result<()> {
485 let n = nccl()?;
486 let cu = n.nccl_comm_finalize()?;
487 check(unsafe { cu(self.handle) })
488 }
489
490 pub fn get_async_error(&self) -> Result<ncclResult_t> {
493 let n = nccl()?;
494 let cu = n.nccl_comm_get_async_error()?;
495 let mut s = ncclResult_t::Success;
496 check(unsafe { cu(self.handle, &mut s) })?;
497 Ok(s)
498 }
499
500 pub fn cuda_device(&self) -> Result<i32> {
502 let n = nccl()?;
503 let cu = n.nccl_comm_cu_device()?;
504 let mut d: core::ffi::c_int = 0;
505 check(unsafe { cu(self.handle, &mut d) })?;
506 Ok(d)
507 }
508
509 pub fn split(&self, color: i32, key: i32) -> Result<Communicator> {
513 let n = nccl()?;
514 let cu = n.nccl_comm_split()?;
515 let mut new_comm: ncclComm_t = core::ptr::null_mut();
516 check(unsafe { cu(self.handle, color, key, &mut new_comm, core::ptr::null_mut()) })?;
517 Ok(Communicator { handle: new_comm })
518 }
519
520 pub unsafe fn register(
527 &self,
528 dev_ptr: *mut core::ffi::c_void,
529 size: usize,
530 ) -> Result<*mut core::ffi::c_void> {
531 let n = nccl()?;
532 let cu = n.nccl_comm_register()?;
533 let mut handle: *mut core::ffi::c_void = core::ptr::null_mut();
534 check(cu(self.handle, dev_ptr, size, &mut handle))?;
535 Ok(handle)
536 }
537
538 pub unsafe fn deregister(&self, handle: *mut core::ffi::c_void) -> Result<()> {
544 let n = nccl()?;
545 let cu = n.nccl_comm_deregister()?;
546 check(cu(self.handle, handle))
547 }
548
549 pub unsafe fn create_pre_mul_sum<T: NcclScalar>(
562 &self,
563 scalar: *mut core::ffi::c_void,
564 residence: ScalarResidence,
565 ) -> Result<RedOp> {
566 let n = nccl()?;
567 let cu = n.nccl_red_op_create_pre_mul_sum()?;
568 let mut op = ncclRedOp_t(0);
569 check(cu(&mut op, scalar, T::raw(), residence as i32, self.handle))?;
570 Ok(RedOp::Custom(op.0))
571 }
572
573 pub fn destroy_red_op(&self, op: RedOp) -> Result<()> {
577 let n = nccl()?;
578 let cu = n.nccl_red_op_destroy()?;
579 check(unsafe { cu(op.raw(), self.handle) })
580 }
581
582 pub fn last_error(&self) -> Result<&'static str> {
586 let n = nccl()?;
587 let cu = n.nccl_get_last_error()?;
588 let p = unsafe { cu(self.handle) };
589 if p.is_null() {
590 return Ok("unknown");
591 }
592 Ok(unsafe { core::ffi::CStr::from_ptr(p) }
593 .to_str()
594 .unwrap_or("unknown"))
595 }
596}
597
/// Owning wrapper around a pointer obtained from the `nccl_mem_alloc` entry point.
///
/// The `Drop` implementation passes the pointer to `nccl_mem_free` on a best-effort basis.
#[derive(Debug)]
pub struct NcclMem {
    /// Raw pointer written by the allocation call; also exposed through `as_raw`.
    ptr: *mut core::ffi::c_void,
}
603
604impl NcclMem {
605 pub fn new(size: usize) -> Result<Self> {
609 let n = nccl()?;
610 let cu = n.nccl_mem_alloc()?;
611 let mut p: *mut core::ffi::c_void = core::ptr::null_mut();
612 check(unsafe { cu(&mut p, size) })?;
613 Ok(Self { ptr: p })
614 }
615
616 #[inline]
617 pub fn as_raw(&self) -> *mut core::ffi::c_void {
618 self.ptr
619 }
620}
621
622impl Drop for NcclMem {
623 fn drop(&mut self) {
624 if let Ok(n) = nccl() {
625 if let Ok(cu) = n.nccl_mem_free() {
626 let _ = unsafe { cu(self.ptr) };
627 }
628 }
629 }
630}