1use super::{result, sys};
2use crate::driver::{CudaContext, CudaStream, DevicePtr, DevicePtrMut};
3use std::{mem::MaybeUninit, sync::Arc, vec, vec::Vec};
4
5pub use result::{group_end, group_start};
6
/// A NCCL communicator bound to a single device/stream pair.
///
/// One `Comm` exists per rank; all collectives issued through it are
/// enqueued on the associated [`CudaStream`].
#[derive(Debug)]
pub struct Comm {
    // Raw NCCL communicator handle; aborted (not destroyed) on drop.
    comm: sys::ncclComm_t,
    // Stream on which this communicator's operations are launched.
    stream: Arc<CudaStream>,
    // This communicator's rank within the group (0..world_size).
    rank: usize,
    // Total number of ranks participating in the group.
    world_size: usize,
}
14
/// Wrapper around [`sys::ncclUniqueId`]: the opaque 128-byte token that all
/// ranks must share before calling [`Comm::from_rank`].
#[derive(Debug, Clone, Copy)]
pub struct Id {
    id: sys::ncclUniqueId,
}
19
20impl Id {
21 pub fn new() -> Result<Self, result::NcclError> {
22 let id = result::get_uniqueid()?;
23 Ok(Self { id })
24 }
25
26 pub fn uninit(internal: [::core::ffi::c_char; 128usize]) -> Self {
27 let id = sys::ncclUniqueId { internal };
28 Self { id }
29 }
30
31 pub fn internal(&self) -> &[::core::ffi::c_char; 128usize] {
32 &self.id.internal
33 }
34}
35
/// Reduction operator applied by the reducing collectives
/// (`all_reduce`, `reduce`, `reduce_scatter`).
pub enum ReduceOp {
    Sum,
    Prod,
    Max,
    Min,
    Avg,
}
43
44fn convert_to_nccl_reduce_op(op: &ReduceOp) -> sys::ncclRedOp_t {
45 match op {
46 ReduceOp::Sum => sys::ncclRedOp_t::ncclSum,
47 ReduceOp::Prod => sys::ncclRedOp_t::ncclProd,
48 ReduceOp::Max => sys::ncclRedOp_t::ncclMax,
49 ReduceOp::Min => sys::ncclRedOp_t::ncclMin,
50 ReduceOp::Avg => sys::ncclRedOp_t::ncclAvg,
51 }
52}
53
impl Drop for Comm {
    fn drop(&mut self) {
        unsafe {
            // Tears down the communicator via abort. NOTE(review): abort is
            // presumably chosen over a graceful destroy so teardown cannot
            // block on in-flight collectives — confirm against `result`'s
            // docs. Also note `expect` panics inside `drop`, which aborts the
            // process if a panic is already unwinding.
            result::comm_abort(self.comm).expect("Error when aborting Comm.");
        }
    }
}
62
/// Maps a Rust element type to its NCCL wire datatype
/// (implemented for primitives below via `define_nccl_type!`).
pub trait NcclType {
    /// Returns the [`sys::ncclDataType_t`] tag corresponding to `Self`.
    fn as_nccl_type() -> sys::ncclDataType_t;
}
66
/// Implements [`NcclType`] for `$t`, returning the given NCCL datatype tag.
macro_rules! define_nccl_type {
    ($t:ty, $nccl_type:expr) => {
        impl NcclType for $t {
            fn as_nccl_type() -> sys::ncclDataType_t {
                $nccl_type
            }
        }
    };
}
76
// Primitive <-> NCCL datatype mappings.
define_nccl_type!(f32, sys::ncclDataType_t::ncclFloat32);
define_nccl_type!(f64, sys::ncclDataType_t::ncclFloat64);
define_nccl_type!(i8, sys::ncclDataType_t::ncclInt8);
define_nccl_type!(i32, sys::ncclDataType_t::ncclInt32);
define_nccl_type!(i64, sys::ncclDataType_t::ncclInt64);
define_nccl_type!(u8, sys::ncclDataType_t::ncclUint8);
define_nccl_type!(u32, sys::ncclDataType_t::ncclUint32);
define_nccl_type!(u64, sys::ncclDataType_t::ncclUint64);
// NOTE(review): Rust `char` is 4 bytes but is mapped to the 1-byte
// ncclUint8 here — confirm the element-count/byte-size semantics are
// intentional for `char` buffers.
define_nccl_type!(char, sys::ncclDataType_t::ncclUint8);
#[cfg(feature = "f16")]
define_nccl_type!(half::f16, sys::ncclDataType_t::ncclFloat16);
#[cfg(feature = "f16")]
define_nccl_type!(half::bf16, sys::ncclDataType_t::ncclBfloat16);
90impl Comm {
91 pub fn from_devices(streams: Vec<Arc<CudaStream>>) -> Result<Vec<Self>, result::NcclError> {
113 let n_streams = streams.len();
114 let mut comms = vec![std::ptr::null_mut(); n_streams];
115 let ordinals: Vec<_> = streams
116 .iter()
117 .map(|d| d.context().ordinal() as i32)
118 .collect();
119 unsafe {
120 result::comm_init_all(comms.as_mut_ptr(), n_streams as i32, ordinals.as_ptr())?;
121 }
122
123 let comms: Vec<Self> = comms
124 .into_iter()
125 .zip(streams.iter().cloned())
126 .enumerate()
127 .map(|(rank, (comm, stream))| Self {
128 comm,
129 stream,
130 rank,
131 world_size: n_streams,
132 })
133 .collect();
134
135 Ok(comms)
136 }
137
138 pub fn stream(&self) -> Arc<CudaStream> {
139 self.stream.clone()
140 }
141
142 pub fn context(&self) -> &Arc<CudaContext> {
143 self.stream.context()
144 }
145
146 pub fn ordinal(&self) -> usize {
147 self.stream.ctx.ordinal
148 }
149
150 pub fn rank(&self) -> usize {
151 self.rank
152 }
153
154 pub fn world_size(&self) -> usize {
155 self.world_size
156 }
157
158 pub fn from_rank(
185 stream: Arc<CudaStream>,
186 rank: usize,
187 world_size: usize,
188 id: Id,
189 ) -> Result<Self, result::NcclError> {
190 let mut comm = MaybeUninit::uninit();
191
192 let comm = unsafe {
193 result::comm_init_rank(
194 comm.as_mut_ptr(),
195 world_size
196 .try_into()
197 .expect("World_size cannot be casted to i32"),
198 id.id,
199 rank.try_into().expect("Rank cannot be cast to i32"),
200 )?;
201 comm.assume_init()
202 };
203 Ok(Self {
204 comm,
205 stream,
206 rank,
207 world_size,
208 })
209 }
210}
211
impl Comm {
    /// Point-to-point send of `data.len()` elements to rank `peer`.
    ///
    /// Enqueued on this communicator's stream (asynchronous w.r.t. the host);
    /// must be matched by a `recv` on `peer`.
    pub fn send<S: DevicePtr<T>, T: NcclType>(
        &self,
        data: &S,
        peer: i32,
    ) -> Result<(), result::NcclError> {
        // `_record_src` is a guard returned alongside the raw pointer; it must
        // stay alive until the call is enqueued (NOTE(review): exact guard
        // semantics live in `DevicePtr::device_ptr` — confirm).
        let (src, _record_src) = data.device_ptr(&self.stream);
        unsafe {
            result::send(
                src as _,
                data.len(),
                T::as_nccl_type(),
                peer,
                self.comm,
                self.stream.cu_stream as _,
            )
        }?;
        Ok(())
    }

    /// Point-to-point receive of `buff.len()` elements from rank `peer`.
    /// Must be matched by a `send` on `peer`.
    pub fn recv<R: DevicePtrMut<T>, T: NcclType>(
        &self,
        buff: &mut R,
        peer: i32,
    ) -> Result<result::NcclStatus, result::NcclError> {
        // Capture the length before mutably borrowing for the device pointer.
        let count = buff.len();
        let (dst, _record_dst) = buff.device_ptr_mut(&self.stream);
        unsafe {
            result::recv(
                dst as _,
                count,
                T::as_nccl_type(),
                peer,
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// Broadcasts `recvbuff.len()` elements from rank `root` to all ranks.
    ///
    /// Only the root rank needs a `sendbuff`; non-root ranks may pass `None`,
    /// in which case a null source pointer is handed to NCCL.
    pub fn broadcast<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
        &self,
        sendbuff: Option<&S>,
        recvbuff: &mut R,
        root: i32,
    ) -> Result<result::NcclStatus, result::NcclError> {
        // The root rank must supply a source buffer.
        debug_assert!(sendbuff.is_some() || self.rank != root as usize);
        let count = recvbuff.len();
        let (src, _record_src) = sendbuff.map(|b| b.device_ptr(&self.stream)).unzip();
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::broadcast(
                src.map(|ptr| ptr as _).unwrap_or(std::ptr::null()),
                dst as _,
                count,
                T::as_nccl_type(),
                root,
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// In-place broadcast: `recvbuff` serves as both source (on `root`) and
    /// destination (on all ranks).
    pub fn broadcast_in_place<R: DevicePtrMut<T>, T: NcclType>(
        &self,
        recvbuff: &mut R,
        root: i32,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let count = recvbuff.len();
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::broadcast(
                // Same pointer for src and dst triggers NCCL's in-place path.
                dst as _,
                dst as _,
                count,
                T::as_nccl_type(),
                root,
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// Gathers `sendbuff.len()` elements from every rank into `recvbuff`.
    ///
    /// Per NCCL allGather semantics, `recvbuff` is expected to hold
    /// `world_size * sendbuff.len()` elements, ordered by rank.
    pub fn all_gather<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
        &self,
        sendbuff: &S,
        recvbuff: &mut R,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let (src, _record_src) = sendbuff.device_ptr(&self.stream);
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::all_gather(
                src as _,
                dst as _,
                // Count is the per-rank element count, not the total.
                sendbuff.len(),
                T::as_nccl_type(),
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// Element-wise reduction of `sendbuff` across all ranks with `reduce_op`;
    /// every rank receives the full result in `recvbuff`.
    pub fn all_reduce<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
        &self,
        sendbuff: &S,
        recvbuff: &mut R,
        reduce_op: &ReduceOp,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let (src, _record_src) = sendbuff.device_ptr(&self.stream);
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::all_reduce(
                src as _,
                dst as _,
                sendbuff.len(),
                T::as_nccl_type(),
                convert_to_nccl_reduce_op(reduce_op),
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// In-place [`Comm::all_reduce`]: `buff` is both input and output.
    pub fn all_reduce_in_place<R: DevicePtrMut<T>, T: NcclType>(
        &self,
        buff: &mut R,
        reduce_op: &ReduceOp,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let count = buff.len();
        let (dst, _record_dst) = buff.device_ptr_mut(&self.stream);
        unsafe {
            result::all_reduce(
                // Same pointer for src and dst triggers NCCL's in-place path.
                dst as _,
                dst as _,
                count,
                T::as_nccl_type(),
                convert_to_nccl_reduce_op(reduce_op),
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// Reduces `sendbuff` across all ranks with `reduce_op`; only rank `root`
    /// receives the result.
    ///
    /// Only the root rank needs a `recvbuff`; non-root ranks may pass `None`,
    /// in which case a null destination pointer is handed to NCCL.
    pub fn reduce<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
        &self,
        sendbuff: &S,
        recvbuff: Option<&mut R>,
        reduce_op: &ReduceOp,
        root: i32,
    ) -> Result<result::NcclStatus, result::NcclError> {
        // The root rank must supply a destination buffer.
        debug_assert!(recvbuff.is_some() || self.rank != root as usize);

        let (src, _record_src) = sendbuff.device_ptr(&self.stream);
        let (dst, _record_dst) = recvbuff.map(|b| b.device_ptr_mut(&self.stream)).unzip();
        unsafe {
            result::reduce(
                src as _,
                dst.map(|ptr| ptr as _).unwrap_or(std::ptr::null_mut()),
                sendbuff.len(),
                T::as_nccl_type(),
                convert_to_nccl_reduce_op(reduce_op),
                root,
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// In-place [`Comm::reduce`]: `recvbuff` is both input and output; the
    /// reduced result lands in `recvbuff` on rank `root`.
    pub fn reduce_in_place<R: DevicePtrMut<T>, T: NcclType>(
        &self,
        recvbuff: &mut R,
        reduce_op: &ReduceOp,
        root: i32,
    ) -> Result<result::NcclStatus, result::NcclError> {
        let count = recvbuff.len();
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::reduce(
                // Same pointer for src and dst triggers NCCL's in-place path.
                dst as _,
                dst as _,
                count,
                T::as_nccl_type(),
                convert_to_nccl_reduce_op(reduce_op),
                root,
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }

    /// Reduces across all ranks with `reduce_op`, scattering the result:
    /// each rank receives `recvbuff.len()` elements of the reduced output.
    ///
    /// Per NCCL reduceScatter semantics, `sendbuff` is expected to hold
    /// `world_size * recvbuff.len()` elements.
    pub fn reduce_scatter<S: DevicePtr<T>, R: DevicePtrMut<T>, T: NcclType>(
        &self,
        sendbuff: &S,
        recvbuff: &mut R,
        reduce_op: &ReduceOp,
    ) -> Result<result::NcclStatus, result::NcclError> {
        // Count is the per-rank OUTPUT element count.
        let count = recvbuff.len();
        let (src, _record_src) = sendbuff.device_ptr(&self.stream);
        let (dst, _record_dst) = recvbuff.device_ptr_mut(&self.stream);
        unsafe {
            result::reduce_scatter(
                src as _,
                dst as _,
                count,
                T::as_nccl_type(),
                convert_to_nccl_reduce_op(reduce_op),
                self.comm,
                self.stream.cu_stream as _,
            )
        }
    }
}
447
/// Wraps a block in `ncclGroupStart` / `ncclGroupEnd` so the enclosed NCCL
/// calls are fused into a single group launch.
///
/// Requires `result` to be in scope at the call site.
/// NOTE(review): if `$x` panics, `group_end` is never called — confirm that
/// is acceptable for callers.
#[macro_export]
macro_rules! group {
    ($x:block) => {
        unsafe {
            result::group_start().unwrap();
        }
        $x
        unsafe {
            result::group_end().unwrap();
        }
    };
}
460
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "no-std")]
    use no_std_compat::println;

    /// Shared driver for the all-reduce tests (the two tests previously
    /// duplicated this scaffolding verbatim): spawns one thread per visible
    /// device, all-reduces `n` copies of `rank + 1` with [`ReduceOp::Sum`],
    /// and checks every rank sees `1 + 2 + ... + n_devices`.
    ///
    /// `use_views` switches between reducing the slices directly and reducing
    /// through `slice(..)` / `slice_mut(..)` views.
    fn run_all_reduce_sum(n: usize, use_views: bool) {
        let n_devices = CudaContext::device_count().unwrap() as usize;
        let id = Id::new().unwrap();
        let threads: Vec<_> = (0..n_devices)
            .map(|i| {
                println!("III {i}");
                std::thread::spawn(move || {
                    println!("Within thread {i}");
                    let ctx = CudaContext::new(i).unwrap();
                    let stream = ctx.default_stream();
                    let comm = Comm::from_rank(stream.clone(), i, n_devices, id).unwrap();
                    let slice = stream.clone_htod(&vec![(i + 1) as f32 * 1.0; n]).unwrap();
                    let mut slice_receive = stream.alloc_zeros::<f32>(n).unwrap();

                    if use_views {
                        let slice_view = slice.slice(..);
                        let mut slice_receive_view = slice_receive.slice_mut(..);
                        comm.all_reduce(&slice_view, &mut slice_receive_view, &ReduceOp::Sum)
                            .unwrap();
                    } else {
                        comm.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum)
                            .unwrap();
                    }

                    let out = stream.clone_dtoh(&slice_receive).unwrap();

                    // Sum over ranks of (rank + 1) = n_devices * (n_devices + 1) / 2.
                    assert_eq!(out, vec![(n_devices * (n_devices + 1)) as f32 / 2.0; n]);
                })
            })
            .collect();
        for t in threads {
            t.join().unwrap()
        }
    }

    #[test]
    fn test_all_reduce() {
        run_all_reduce_sum(2, false);
    }

    #[test]
    fn test_all_reduce_views() {
        run_all_reduce_sum(2, true);
    }
}