// baracuda_driver/multicast.rs

//! Multicast objects (CUDA 12.0+, NVSwitch systems only).
//!
//! A multicast object aliases one virtual-memory range across a group of
//! peer devices so that a write from any one of them fans out to all of
//! them. Requires NVSwitch fabric (HGX, DGX H100) — on a single-GPU
//! system most of these calls will return `CUDA_ERROR_NOT_SUPPORTED`.
//!
//! Typical flow:
//!
//! 1. [`multicast_granularity`] → round your allocation size to a valid
//!    multicast chunk.
//! 2. [`MulticastObject::new`] → create the object.
//! 3. Call [`MulticastObject::add_device`] for each participating device.
//! 4. Bind VMM allocations with [`MulticastObject::bind_addr`] (or
//!    [`bind_mem`](MulticastObject::bind_mem)).
//! 5. Kernels on any device write into the multicast range and every
//!    bound allocation receives the update.
18
19use std::sync::Arc;
20
21use baracuda_cuda_sys::types::{
22    CUmemAllocationHandleType, CUmulticastGranularity_flags, CUmulticastObjectProp,
23};
24use baracuda_cuda_sys::{driver, CUdeviceptr, CUmemGenericAllocationHandle};
25
26use crate::device::Device;
27use crate::error::{check, Result};
28
29/// Query the minimum or recommended multicast granularity (bytes) for the
30/// given number of peer devices and `size_bytes`.
31pub fn multicast_granularity(
32    num_devices: u32,
33    size_bytes: usize,
34    recommended: bool,
35) -> Result<usize> {
36    let d = driver()?;
37    let cu = d.cu_multicast_get_granularity()?;
38    let prop = CUmulticastObjectProp {
39        num_devices,
40        size: size_bytes,
41        handle_types: CUmemAllocationHandleType::NONE as u64,
42        flags: 0,
43    };
44    let mut g: usize = 0;
45    let option = if recommended {
46        CUmulticastGranularity_flags::RECOMMENDED
47    } else {
48        CUmulticastGranularity_flags::MINIMUM
49    };
50    check(unsafe { cu(&mut g, &prop, option) })?;
51    Ok(g)
52}
53
54/// A multicast object. Drop releases the underlying
55/// `CUmemGenericAllocationHandle` via `cuMemRelease`.
56#[derive(Clone)]
57pub struct MulticastObject {
58    inner: Arc<MulticastObjectInner>,
59}
60
61struct MulticastObjectInner {
62    handle: CUmemGenericAllocationHandle,
63}
64
65unsafe impl Send for MulticastObjectInner {}
66unsafe impl Sync for MulticastObjectInner {}
67
68impl core::fmt::Debug for MulticastObjectInner {
69    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
70        f.debug_struct("MulticastObject")
71            .field("handle", &self.handle)
72            .finish_non_exhaustive()
73    }
74}
75
76impl core::fmt::Debug for MulticastObject {
77    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
78        self.inner.fmt(f)
79    }
80}
81
82impl MulticastObject {
83    pub fn new(num_devices: u32, size_bytes: usize) -> Result<Self> {
84        let d = driver()?;
85        let cu = d.cu_multicast_create()?;
86        let prop = CUmulticastObjectProp {
87            num_devices,
88            size: size_bytes,
89            handle_types: CUmemAllocationHandleType::NONE as u64,
90            flags: 0,
91        };
92        let mut handle: CUmemGenericAllocationHandle = 0;
93        check(unsafe { cu(&mut handle, &prop) })?;
94        Ok(Self {
95            inner: Arc::new(MulticastObjectInner { handle }),
96        })
97    }
98
99    pub fn add_device(&self, device: &Device) -> Result<()> {
100        let d = driver()?;
101        let cu = d.cu_multicast_add_device()?;
102        check(unsafe { cu(self.inner.handle, device.as_raw()) })
103    }
104
105    /// Bind a VMM allocation handle at `mc_offset` → `(mem_handle, mem_offset, size)`.
106    pub fn bind_mem(
107        &self,
108        mc_offset: usize,
109        mem_handle: CUmemGenericAllocationHandle,
110        mem_offset: usize,
111        size: usize,
112    ) -> Result<()> {
113        let d = driver()?;
114        let cu = d.cu_multicast_bind_mem()?;
115        check(unsafe {
116            cu(
117                self.inner.handle,
118                mc_offset,
119                mem_handle,
120                mem_offset,
121                size,
122                0,
123            )
124        })
125    }
126
127    /// Bind an already-mapped device pointer into the multicast object.
128    pub fn bind_addr(&self, mc_offset: usize, ptr: CUdeviceptr, size: usize) -> Result<()> {
129        let d = driver()?;
130        let cu = d.cu_multicast_bind_addr()?;
131        check(unsafe { cu(self.inner.handle, mc_offset, ptr, size, 0) })
132    }
133
134    pub fn unbind(&self, device: &Device, mc_offset: usize, size: usize) -> Result<()> {
135        let d = driver()?;
136        let cu = d.cu_multicast_unbind()?;
137        check(unsafe { cu(self.inner.handle, device.as_raw(), mc_offset, size) })
138    }
139
140    #[inline]
141    pub fn as_raw(&self) -> CUmemGenericAllocationHandle {
142        self.inner.handle
143    }
144}
145
146impl Drop for MulticastObjectInner {
147    fn drop(&mut self) {
148        if self.handle == 0 {
149            return;
150        }
151        if let Ok(d) = driver() {
152            if let Ok(cu) = d.cu_mem_release() {
153                let _ = unsafe { cu(self.handle) };
154            }
155        }
156    }
157}