1#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals)]
13#![warn(missing_debug_implementations)]
14
15use core::ffi::{c_int, c_void};
16use std::sync::OnceLock;
17
18use baracuda_core::{Library, LoaderError};
19use baracuda_cuda_sys::runtime::cudaStream_t;
20use baracuda_types::CudaStatus;
21
/// Opaque NCCL communicator handle.
///
/// Declared as an untyped raw pointer: nothing in this file dereferences it,
/// it is only threaded through the NCCL function-pointer signatures.
pub type ncclComm_t = *mut c_void;
24
/// Opaque 128-byte identifier blob.
///
/// It is written by `ncclGetUniqueId` through an out-pointer and handed *by
/// value* to `ncclCommInitRank` / `ncclCommInitRankConfig`, so the layout is
/// pinned with `#[repr(C)]`.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct ncclUniqueId {
    /// Raw identifier bytes; nothing in this file interprets them.
    pub internal: [i8; 128],
}

impl core::fmt::Debug for ncclUniqueId {
    /// Prints only the type name; the 128 payload bytes are elided.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut elided = f.debug_struct("ncclUniqueId");
        elided.finish_non_exhaustive()
    }
}

impl Default for ncclUniqueId {
    /// Returns an identifier whose bytes are all zero.
    fn default() -> Self {
        ncclUniqueId {
            internal: [0_i8; 128],
        }
    }
}
44
/// Element type tag passed to NCCL collectives and point-to-point calls.
///
/// `#[repr(i32)]` fixes both the size and the discriminants, so a value can
/// be passed straight across the C ABI.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ncclDataType_t {
    /// Signed 8-bit integer.
    Int8 = 0,
    /// Unsigned 8-bit integer.
    Uint8 = 1,
    /// Signed 32-bit integer.
    Int32 = 2,
    /// Unsigned 32-bit integer.
    Uint32 = 3,
    /// Signed 64-bit integer.
    Int64 = 4,
    /// Unsigned 64-bit integer.
    Uint64 = 5,
    /// 16-bit (half precision) float.
    Float16 = 6,
    /// 32-bit (single precision) float.
    Float32 = 7,
    /// 64-bit (double precision) float.
    Float64 = 8,
    /// 16-bit bfloat16 float.
    BFloat16 = 9,
}
70
/// Reduction operator passed to NCCL reductions.
///
/// Modelled as a transparent wrapper over `i32` instead of a Rust enum, so
/// values other than the named constants stay representable — for instance
/// whatever `ncclRedOpCreatePreMulSum` stores through its `op` out-pointer.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub struct ncclRedOp_t(pub i32);

#[allow(non_upper_case_globals)]
impl ncclRedOp_t {
    /// Element-wise sum.
    pub const Sum: Self = ncclRedOp_t(0);
    /// Element-wise product.
    pub const Prod: Self = ncclRedOp_t(1);
    /// Element-wise maximum.
    pub const Max: Self = ncclRedOp_t(2);
    /// Element-wise minimum.
    pub const Min: Self = ncclRedOp_t(3);
    /// Element-wise average.
    pub const Avg: Self = ncclRedOp_t(4);
}
92
/// Status code returned by most NCCL entry points declared in this file.
///
/// Kept as a transparent `i32` wrapper so that codes not listed below can
/// still be carried without being an invalid value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct ncclResult_t(pub i32);

impl ncclResult_t {
    /// The call succeeded.
    pub const Success: Self = ncclResult_t(0);
    /// A CUDA call failed and was not handled.
    pub const UnhandledCudaError: Self = ncclResult_t(1);
    /// A system call failed.
    pub const SystemError: Self = ncclResult_t(2);
    /// Internal NCCL error.
    pub const InternalError: Self = ncclResult_t(3);
    /// An argument was invalid.
    pub const InvalidArgument: Self = ncclResult_t(4);
    /// The API was used incorrectly.
    pub const InvalidUsage: Self = ncclResult_t(5);
    /// Remote error (another rank failed).
    pub const RemoteError: Self = ncclResult_t(6);
    /// Operation still in progress (non-blocking communicator).
    pub const InProgress: Self = ncclResult_t(7);

    /// Returns `true` only when the code equals [`Self::Success`] (zero).
    ///
    /// Note that [`Self::InProgress`] therefore counts as "not success".
    pub const fn is_success(self) -> bool {
        self.0 == Self::Success.0
    }
}
123
124impl CudaStatus for ncclResult_t {
125 fn code(self) -> i32 {
126 self.0
127 }
128 fn name(self) -> &'static str {
129 match self.0 {
130 0 => "ncclSuccess",
131 1 => "ncclUnhandledCudaError",
132 2 => "ncclSystemError",
133 3 => "ncclInternalError",
134 4 => "ncclInvalidArgument",
135 5 => "ncclInvalidUsage",
136 6 => "ncclRemoteError",
137 7 => "ncclInProgress",
138 _ => "ncclUnrecognizedResult",
139 }
140 }
141 fn description(self) -> &'static str {
142 match self.0 {
143 0 => "success",
144 1 => "unhandled CUDA error",
145 2 => "system error",
146 3 => "internal NCCL error",
147 4 => "invalid argument",
148 5 => "invalid usage",
149 6 => "remote error (another rank failed)",
150 7 => "operation in progress (non-blocking comm)",
151 _ => "unrecognized NCCL status code",
152 }
153 }
154 fn is_success(self) -> bool {
155 ncclResult_t::is_success(self)
156 }
157 fn library(self) -> &'static str {
158 "nccl"
159 }
160}
161
162pub type PFN_ncclGetVersion = unsafe extern "C" fn(version: *mut c_int) -> ncclResult_t;
166pub type PFN_ncclGetUniqueId = unsafe extern "C" fn(id: *mut ncclUniqueId) -> ncclResult_t;
168pub type PFN_ncclCommInitRank = unsafe extern "C" fn(
170 comm: *mut ncclComm_t,
171 nranks: c_int,
172 comm_id: ncclUniqueId,
173 rank: c_int,
174) -> ncclResult_t;
175pub type PFN_ncclCommInitAll = unsafe extern "C" fn(
177 comms: *mut ncclComm_t,
178 ndev: c_int,
179 dev_list: *const c_int,
180) -> ncclResult_t;
181pub type PFN_ncclCommDestroy = unsafe extern "C" fn(comm: ncclComm_t) -> ncclResult_t;
183pub type PFN_ncclCommCount =
185 unsafe extern "C" fn(comm: ncclComm_t, count: *mut c_int) -> ncclResult_t;
186pub type PFN_ncclCommUserRank =
188 unsafe extern "C" fn(comm: ncclComm_t, rank: *mut c_int) -> ncclResult_t;
189
190pub type PFN_ncclAllReduce = unsafe extern "C" fn(
192 sendbuff: *const c_void,
193 recvbuff: *mut c_void,
194 count: usize,
195 datatype: ncclDataType_t,
196 op: ncclRedOp_t,
197 comm: ncclComm_t,
198 stream: cudaStream_t,
199) -> ncclResult_t;
200
201pub type PFN_ncclBroadcast = unsafe extern "C" fn(
203 sendbuff: *const c_void,
204 recvbuff: *mut c_void,
205 count: usize,
206 datatype: ncclDataType_t,
207 root: c_int,
208 comm: ncclComm_t,
209 stream: cudaStream_t,
210) -> ncclResult_t;
211
212pub type PFN_ncclGroupStart = unsafe extern "C" fn() -> ncclResult_t;
214pub type PFN_ncclGroupEnd = unsafe extern "C" fn() -> ncclResult_t;
216
217pub type PFN_ncclReduce = unsafe extern "C" fn(
221 sendbuff: *const c_void,
222 recvbuff: *mut c_void,
223 count: usize,
224 datatype: ncclDataType_t,
225 op: ncclRedOp_t,
226 root: c_int,
227 comm: ncclComm_t,
228 stream: cudaStream_t,
229) -> ncclResult_t;
230
231pub type PFN_ncclAllGather = unsafe extern "C" fn(
233 sendbuff: *const c_void,
234 recvbuff: *mut c_void,
235 sendcount: usize,
236 datatype: ncclDataType_t,
237 comm: ncclComm_t,
238 stream: cudaStream_t,
239) -> ncclResult_t;
240
241pub type PFN_ncclReduceScatter = unsafe extern "C" fn(
243 sendbuff: *const c_void,
244 recvbuff: *mut c_void,
245 recvcount: usize,
246 datatype: ncclDataType_t,
247 op: ncclRedOp_t,
248 comm: ncclComm_t,
249 stream: cudaStream_t,
250) -> ncclResult_t;
251
252pub type PFN_ncclSend = unsafe extern "C" fn(
254 sendbuff: *const c_void,
255 count: usize,
256 datatype: ncclDataType_t,
257 peer: c_int,
258 comm: ncclComm_t,
259 stream: cudaStream_t,
260) -> ncclResult_t;
261
262pub type PFN_ncclRecv = unsafe extern "C" fn(
264 recvbuff: *mut c_void,
265 count: usize,
266 datatype: ncclDataType_t,
267 peer: c_int,
268 comm: ncclComm_t,
269 stream: cudaStream_t,
270) -> ncclResult_t;
271
272pub type PFN_ncclCommAbort = unsafe extern "C" fn(comm: ncclComm_t) -> ncclResult_t;
276pub type PFN_ncclCommFinalize = unsafe extern "C" fn(comm: ncclComm_t) -> ncclResult_t;
278pub type PFN_ncclCommGetAsyncError =
280 unsafe extern "C" fn(comm: ncclComm_t, async_error: *mut ncclResult_t) -> ncclResult_t;
281pub type PFN_ncclCommCuDevice =
283 unsafe extern "C" fn(comm: ncclComm_t, device: *mut c_int) -> ncclResult_t;
284pub type PFN_ncclCommSplit = unsafe extern "C" fn(
286 comm: ncclComm_t,
287 color: c_int,
288 key: c_int,
289 new_comm: *mut ncclComm_t,
290 config: *mut c_void, ) -> ncclResult_t;
292
293pub type PFN_ncclCommInitRankConfig = unsafe extern "C" fn(
295 comm: *mut ncclComm_t,
296 nranks: c_int,
297 comm_id: ncclUniqueId,
298 rank: c_int,
299 config: *mut c_void, ) -> ncclResult_t;
301
302pub type PFN_ncclMemAlloc =
306 unsafe extern "C" fn(ptr: *mut *mut c_void, size: usize) -> ncclResult_t;
307pub type PFN_ncclMemFree = unsafe extern "C" fn(ptr: *mut c_void) -> ncclResult_t;
309
310pub type PFN_ncclCommRegister = unsafe extern "C" fn(
312 comm: ncclComm_t,
313 buff: *mut c_void,
314 size: usize,
315 handle: *mut *mut c_void,
316) -> ncclResult_t;
317
318pub type PFN_ncclCommDeregister =
320 unsafe extern "C" fn(comm: ncclComm_t, handle: *mut c_void) -> ncclResult_t;
321
322pub type PFN_ncclRedOpCreatePreMulSum = unsafe extern "C" fn(
326 op: *mut ncclRedOp_t,
327 scalar: *mut c_void,
328 datatype: ncclDataType_t,
329 residence: i32, comm: ncclComm_t,
331) -> ncclResult_t;
332
333pub type PFN_ncclRedOpDestroy =
335 unsafe extern "C" fn(op: ncclRedOp_t, comm: ncclComm_t) -> ncclResult_t;
336
337pub type PFN_ncclGetErrorString =
341 unsafe extern "C" fn(result: ncclResult_t) -> *const core::ffi::c_char;
342pub type PFN_ncclGetLastError =
344 unsafe extern "C" fn(comm: ncclComm_t) -> *const core::ffi::c_char;
345
/// Shared-library file names handed to `Library::open` when loading NCCL.
///
/// Linux and Windows each get two names (presumably tried in the listed
/// order — confirm against `Library::open`); every other target gets an empty
/// list.
fn nccl_candidates() -> &'static [&'static str] {
    // `cfg!` folds to a compile-time constant, so exactly one arm survives.
    if cfg!(target_os = "linux") {
        &["libnccl.so.2", "libnccl.so"]
    } else if cfg!(target_os = "windows") {
        &["nccl.dll", "libnccl.dll"]
    } else {
        &[]
    }
}
362
// Generates the `Nccl` symbol table from a list of
// `accessor_name as "cSymbolName": FunctionPointerType;` entries.
//
// For every entry the macro emits one `OnceLock` cache field and one public
// accessor that resolves the symbol on first use.
macro_rules! nccl_fns {
    ($($name:ident as $sym:literal : $pfn:ty);* $(;)?) => {
        /// Handle to the loaded NCCL library plus one lazily filled cache slot
        /// per entry point.
        pub struct Nccl {
            lib: Library,
            $($name: OnceLock<$pfn>,)*
        }

        impl core::fmt::Debug for Nccl {
            /// Shows only the library handle; the cache slots are elided.
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                let mut out = f.debug_struct("Nccl");
                out.field("lib", &self.lib);
                out.finish_non_exhaustive()
            }
        }

        impl Nccl {
            $(
                /// Returns the function pointer for this entry point, looking
                /// the symbol up on first use and serving the cached pointer
                /// afterwards.
                ///
                /// # Errors
                ///
                /// Propagates the loader error when the symbol lookup fails;
                /// a failure is not cached, so a later call retries.
                pub fn $name(&self) -> Result<$pfn, LoaderError> {
                    match self.$name.get() {
                        Some(cached) => Ok(*cached),
                        None => {
                            // SAFETY: NOTE(review) relies on the contract of
                            // `Library::raw_symbol`, which is not visible here.
                            let raw: *mut () = unsafe { self.lib.raw_symbol($sym) }?;
                            // SAFETY: assumes the looked-up symbol really has
                            // the declared signature and is non-null; the
                            // pointer is reinterpreted bit-for-bit.
                            let resolved: $pfn =
                                unsafe { core::mem::transmute_copy::<*mut (), $pfn>(&raw) };
                            // A concurrent caller may have filled the slot
                            // already; losing that race is harmless.
                            let _ = self.$name.set(resolved);
                            Ok(resolved)
                        }
                    }
                }
            )*

            /// Wraps an opened library with every cache slot still empty.
            fn empty(lib: Library) -> Self {
                Self { lib, $($name: OnceLock::new(),)* }
            }
        }
    };
}
392
393nccl_fns! {
394 nccl_get_version as "ncclGetVersion": PFN_ncclGetVersion;
395 nccl_get_unique_id as "ncclGetUniqueId": PFN_ncclGetUniqueId;
396 nccl_comm_init_rank as "ncclCommInitRank": PFN_ncclCommInitRank;
397 nccl_comm_init_rank_config as "ncclCommInitRankConfig": PFN_ncclCommInitRankConfig;
398 nccl_comm_init_all as "ncclCommInitAll": PFN_ncclCommInitAll;
399 nccl_comm_destroy as "ncclCommDestroy": PFN_ncclCommDestroy;
400 nccl_comm_abort as "ncclCommAbort": PFN_ncclCommAbort;
401 nccl_comm_finalize as "ncclCommFinalize": PFN_ncclCommFinalize;
402 nccl_comm_get_async_error as "ncclCommGetAsyncError": PFN_ncclCommGetAsyncError;
403 nccl_comm_count as "ncclCommCount": PFN_ncclCommCount;
404 nccl_comm_user_rank as "ncclCommUserRank": PFN_ncclCommUserRank;
405 nccl_comm_cu_device as "ncclCommCuDevice": PFN_ncclCommCuDevice;
406 nccl_comm_split as "ncclCommSplit": PFN_ncclCommSplit;
407 nccl_all_reduce as "ncclAllReduce": PFN_ncclAllReduce;
408 nccl_reduce as "ncclReduce": PFN_ncclReduce;
409 nccl_broadcast as "ncclBroadcast": PFN_ncclBroadcast;
410 nccl_all_gather as "ncclAllGather": PFN_ncclAllGather;
411 nccl_reduce_scatter as "ncclReduceScatter": PFN_ncclReduceScatter;
412 nccl_send as "ncclSend": PFN_ncclSend;
413 nccl_recv as "ncclRecv": PFN_ncclRecv;
414 nccl_group_start as "ncclGroupStart": PFN_ncclGroupStart;
415 nccl_group_end as "ncclGroupEnd": PFN_ncclGroupEnd;
416 nccl_mem_alloc as "ncclMemAlloc": PFN_ncclMemAlloc;
417 nccl_mem_free as "ncclMemFree": PFN_ncclMemFree;
418 nccl_comm_register as "ncclCommRegister": PFN_ncclCommRegister;
419 nccl_comm_deregister as "ncclCommDeregister": PFN_ncclCommDeregister;
420 nccl_red_op_create_pre_mul_sum as "ncclRedOpCreatePreMulSum": PFN_ncclRedOpCreatePreMulSum;
421 nccl_red_op_destroy as "ncclRedOpDestroy": PFN_ncclRedOpDestroy;
422 nccl_get_error_string as "ncclGetErrorString": PFN_ncclGetErrorString;
423 nccl_get_last_error as "ncclGetLastError": PFN_ncclGetLastError;
424}
425
426pub fn nccl() -> Result<&'static Nccl, LoaderError> {
428 static NCCL: OnceLock<Nccl> = OnceLock::new();
429 if let Some(n) = NCCL.get() {
430 return Ok(n);
431 }
432 let lib = Library::open("nccl", nccl_candidates())?;
433 let n = Nccl::empty(lib);
434 let _ = NCCL.set(n);
435 Ok(NCCL.get().expect("OnceLock set or lost race"))
436}