Skip to main content

baracuda_runtime/
multicast.rs

//! Multicast objects (CUDA 12.0+).
//!
//! A multicast object is a single VMM handle bound to multiple devices;
//! writes through the handle are implicitly replicated across them.
//! This is useful for NVLink-connected GPUs (A100/H100) on all-reduce-like
//! workloads when you don't want to go through NCCL.
//!
//! Multicast is not supported on older drivers; on those,
//! [`MulticastObject::new`] and [`multicast_granularity`] return
//! [`crate::Error::FeatureNotSupported`].
10
11use core::ffi::c_void;
12
13use baracuda_cuda_sys::runtime::runtime;
14use baracuda_cuda_sys::runtime::types::cudaMemGenericAllocationHandle_t;
15use baracuda_types::{supports, Feature};
16
17use crate::device::Device;
18use crate::error::{check, Error, Result};
19
/// Properties passed to [`MulticastObject::new`] and
/// [`multicast_granularity`]. Layout is intended to match
/// `cudaMulticastObjectProp`.
///
/// The struct is `#[repr(C)]` and is handed to the runtime as an opaque
/// `*const c_void`, so field order and field widths are part of the ABI
/// contract: do not reorder or resize fields without checking the C header.
///
/// `Default` yields a value with every field set to zero.
#[repr(C)]
#[derive(Copy, Clone, Debug, Default)]
pub struct MulticastProp {
    /// Device count for the object (presumably the number of devices that
    /// will later be registered via `MulticastObject::add_device` — confirm
    /// against the CUDA documentation).
    pub num_devices: core::ffi::c_uint,
    /// Size of the multicast object (presumably in bytes — confirm).
    pub size: usize,
    /// Handle-type bitmask forwarded to the runtime unchanged.
    ///
    /// NOTE(review): the CUDA driver header appears to declare the
    /// corresponding member (`handleTypes`) as `unsigned long long`. Here it
    /// is a `c_int` followed by compiler-inserted padding, whose bytes Rust
    /// does not guarantee to be zero. Confirm against the header before
    /// relying on this layout.
    pub handle_types: core::ffi::c_int,
    /// Flags forwarded to the runtime unchanged.
    pub flags: u64,
}
31
32fn require_multicast() -> Result<()> {
33    let installed = crate::init::driver_version()?;
34    if supports(installed, Feature::MulticastObjects) {
35        Ok(())
36    } else {
37        Err(Error::FeatureNotSupported {
38            api: "cudaMulticast*",
39            since: Feature::MulticastObjects.required_version(),
40        })
41    }
42}
43
44/// A multicast object. Drop releases it via `cudaMemRelease`.
45#[derive(Debug)]
46pub struct MulticastObject {
47    handle: cudaMemGenericAllocationHandle_t,
48}
49
50impl MulticastObject {
51    /// Create a multicast object with the given props.
52    pub fn new(prop: &MulticastProp) -> Result<Self> {
53        require_multicast()?;
54        let r = runtime()?;
55        let cu = r.cuda_multicast_create()?;
56        let mut h: cudaMemGenericAllocationHandle_t = 0;
57        check(unsafe { cu(&mut h, prop as *const MulticastProp as *const c_void) })?;
58        Ok(Self { handle: h })
59    }
60
61    #[inline]
62    pub fn as_raw(&self) -> cudaMemGenericAllocationHandle_t {
63        self.handle
64    }
65
66    /// Add a participating device to this object.
67    pub fn add_device(&self, device: &Device) -> Result<()> {
68        let r = runtime()?;
69        let cu = r.cuda_multicast_add_device()?;
70        check(unsafe { cu(self.handle, device.ordinal()) })
71    }
72
73    /// Bind a physical-memory handle (from [`crate::vmm::MemHandle`]) at
74    /// `mc_offset` within this object, `size` bytes.
75    ///
76    /// # Safety
77    ///
78    /// `mem_handle` must be a live VMM allocation on a device that was
79    /// already added via [`Self::add_device`].
80    pub unsafe fn bind_mem(
81        &self,
82        mc_offset: usize,
83        mem_handle: cudaMemGenericAllocationHandle_t,
84        mem_offset: usize,
85        size: usize,
86        flags: u64,
87    ) -> Result<()> {
88        let r = runtime()?;
89        let cu = r.cuda_multicast_bind_mem()?;
90        check(cu(
91            self.handle,
92            mc_offset,
93            mem_handle,
94            mem_offset,
95            size,
96            flags,
97        ))
98    }
99
100    /// Bind a device address (instead of a handle).
101    ///
102    /// # Safety
103    ///
104    /// `mem_ptr` must be a mapped VMM address on a registered device.
105    pub unsafe fn bind_addr(
106        &self,
107        mc_offset: usize,
108        mem_ptr: *mut c_void,
109        size: usize,
110        flags: u64,
111    ) -> Result<()> {
112        let r = runtime()?;
113        let cu = r.cuda_multicast_bind_addr()?;
114        check(cu(self.handle, mc_offset, mem_ptr, size, flags))
115    }
116
117    /// Unbind the region `[mc_offset, mc_offset + size)` from `device`.
118    pub fn unbind(&self, device: &Device, mc_offset: usize, size: usize) -> Result<()> {
119        let r = runtime()?;
120        let cu = r.cuda_multicast_unbind()?;
121        check(unsafe { cu(self.handle, device.ordinal(), mc_offset, size) })
122    }
123}
124
125impl Drop for MulticastObject {
126    fn drop(&mut self) {
127        if let Ok(r) = runtime() {
128            if let Ok(cu) = r.cuda_mem_release() {
129                let _ = unsafe { cu(self.handle) };
130            }
131        }
132    }
133}
134
135/// Report the granularity (min alignment / min size) for multicast
136/// objects with the given props. `option`: 0 = minimum, 1 = recommended.
137pub fn multicast_granularity(prop: &MulticastProp, option: i32) -> Result<usize> {
138    require_multicast()?;
139    let r = runtime()?;
140    let cu = r.cuda_multicast_get_granularity()?;
141    let mut g: usize = 0;
142    check(unsafe {
143        cu(
144            &mut g,
145            prop as *const MulticastProp as *const c_void,
146            option,
147        )
148    })?;
149    Ok(g)
150}