Skip to main content

baracuda_nvshmem/
lib.rs

1//! Safe Rust wrappers for the **NVIDIA NVSHMEM host API**.
2//!
3//! NVSHMEM is the OpenSHMEM symmetric-heap model on GPUs: every PE
4//! (processing element — typically one GPU) allocates from a *symmetric
5//! heap* at a shared virtual address, and any PE can read/write another PE's
6//! heap directly via one-sided `put` / `get`. This is the fine-grained,
7//! one-sided complement to [`baracuda-nccl`]'s collectives — the two coexist
8//! and a single program may use both.
9//!
10//! ## What this crate covers (Tier 1)
11//!
12//! - [`Context`] — process-wide NVSHMEM lifetime (init / finalize) plus
13//!   cached `my_pe` / `n_pes`, and the barrier / quiet / fence ordering
14//!   primitives.
15//! - [`Team`] — a subset of PEs created via strided split.
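//! - [`SymmetricBuffer`] — a typed allocation on the symmetric heap.
//! - Host-initiated RMA — blocking and stream-ordered [`Context::put`] /
//!   [`Context::get`].
//! - [`UniqueId`] — the identifier exchanged between PEs for the unique-id
//!   bootstrap, wired into initialization through [`Context::init_with_attr`].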
19//!
20//! ## What it does *not* cover
21//!
22//! The **device-side** API — the `__device__` `nvshmem_int_p` /
23//! `nvshmem_putmem_nbi` calls issued from *inside* a CUDA kernel — requires
24//! linking `libnvshmem_device.a` into the consumer's kernel binary and is
25//! out of scope (it cannot be a lazily-loaded host symbol). A consumer that
26//! needs device-side NVSHMEM writes its own `.cu` that includes the NVSHMEM
27//! headers and links the device archive.
28//!
29//! ## Availability
30//!
31//! NVSHMEM is a Linux library requiring compute capability sm_70+ (every
32//! baracuda-supported GPU qualifies). On hosts without the NVSHMEM runtime
33//! installed, [`Context::init`] returns `LoaderError::LibraryNotFound`, so
34//! callers can fall back to single-process execution.
35//!
36//! [`baracuda-nccl`]: https://docs.rs/baracuda-nccl
37
38#![warn(missing_debug_implementations)]
39
40use core::ffi::{c_int, c_void};
41
42use baracuda_driver::Stream;
43use baracuda_nvshmem_sys::{nvshmem, nvshmemResult_t, nvshmem_team_t, nvshmemx_uniqueid_t};
44use baracuda_types::DeviceRepr;
45
46/// Error type for NVSHMEM operations.
47pub type Error = baracuda_core::Error<nvshmemResult_t>;
48/// Result alias.
49pub type Result<T, E = Error> = core::result::Result<T, E>;
50
51#[inline]
52fn check(status: nvshmemResult_t) -> Result<()> {
53    Error::check(status)
54}
55
56#[inline]
57fn stream_raw(stream: &Stream) -> baracuda_cuda_sys::runtime::cudaStream_t {
58    // `CUstream` (driver) and `cudaStream_t` (runtime) are the same opaque
59    // handle at the ABI level — the runtime stream wraps the driver one.
60    stream.as_raw() as _
61}
62
63// ---- Context --------------------------------------------------------------
64
/// This PE's handle on the process-wide NVSHMEM runtime.
///
/// NVSHMEM is a **process singleton**: `nvshmem_init` and `nvshmem_finalize`
/// must each run exactly once per process. Create one `Context` early in the
/// program, then either drop it or call [`Context::finalize`] at shutdown.
/// The PE index and PE count are queried once during initialization and
/// stored here, which is why [`Context::my_pe`] / [`Context::n_pes`] are
/// infallible.
pub struct Context {
    /// Set once [`Context::finalize`] has successfully called
    /// `nvshmem_finalize`.
    finalized: bool,
    /// Global index of this PE, captured at init.
    my_pe: i32,
    /// Total number of PEs in the program, captured at init.
    n_pes: i32,
}
77
78impl core::fmt::Debug for Context {
79    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
80        f.debug_struct("nvshmem::Context")
81            .field("my_pe", &self.my_pe)
82            .field("n_pes", &self.n_pes)
83            .finish()
84    }
85}
86
87impl Context {
88    /// Initialize NVSHMEM using the environment-selected bootstrap
89    /// (`NVSHMEM_BOOTSTRAP` — PMI / MPI / PMIx). This is the common launcher
90    /// path (`nvshmrun` / `mpirun`). Call exactly once per process.
91    pub fn init() -> Result<Self> {
92        let n = nvshmem()?;
93        let init = n.nvshmem_init()?;
94        unsafe { init() };
95        Self::from_initialized()
96    }
97
98    /// Initialize NVSHMEM through `nvshmemx_init_attr`, passing a caller-built
99    /// attributes struct (e.g. one populated from a
100    /// [`nvshmemx_uniqueid_t`](baracuda_nvshmem_sys::nvshmemx_uniqueid_t) via
101    /// the raw `nvshmemx_set_attr_uniqueid_args`). Pass `flags` and a pointer
102    /// to a valid `nvshmemx_init_attr_t` (or null for defaults).
103    ///
104    /// The attribute struct's layout is NVSHMEM-version-specific and so is
105    /// **not** modeled as a typed Rust struct — build it through the raw
106    /// [`baracuda-nvshmem-sys`] helpers.
107    ///
108    /// # Safety
109    ///
110    /// `attr` must be null or a properly-initialized `nvshmemx_init_attr_t`
111    /// for the installed NVSHMEM version, and must outlive the call.
112    ///
113    /// [`baracuda-nvshmem-sys`]: baracuda_nvshmem_sys
114    pub unsafe fn init_with_attr(flags: u32, attr: *mut c_void) -> Result<Self> {
115        let n = nvshmem()?;
116        let init = n.nvshmemx_init_attr()?;
117        check(unsafe { init(flags, attr) })?;
118        Self::from_initialized()
119    }
120
121    fn from_initialized() -> Result<Self> {
122        let n = nvshmem()?;
123        let my_pe = unsafe { (n.nvshmem_my_pe()?)() };
124        let n_pes = unsafe { (n.nvshmem_n_pes()?)() };
125        Ok(Self {
126            my_pe,
127            n_pes,
128            finalized: false,
129        })
130    }
131
132    /// This PE's global index (0..`n_pes`). Cached at init.
133    #[inline]
134    pub fn my_pe(&self) -> i32 {
135        self.my_pe
136    }
137
138    /// Total number of PEs in the program. Cached at init.
139    #[inline]
140    pub fn n_pes(&self) -> i32 {
141        self.n_pes
142    }
143
144    /// NVSHMEM version as `(major, minor)`.
145    pub fn version(&self) -> Result<(i32, i32)> {
146        let n = nvshmem()?;
147        let cu = n.nvshmem_info_get_version()?;
148        let mut major: c_int = 0;
149        let mut minor: c_int = 0;
150        unsafe { cu(&mut major, &mut minor) };
151        Ok((major, minor))
152    }
153
154    /// Allocate `len` elements of `T` on the symmetric heap. The returned
155    /// buffer occupies the **same virtual address on every PE**, so its
156    /// pointer doubles as a remote address in [`Context::put`] /
157    /// [`Context::get`].
158    pub fn malloc<T: DeviceRepr>(&self, len: usize) -> Result<SymmetricBuffer<T>> {
159        SymmetricBuffer::new(len)
160    }
161
162    /// The team of all PEs ([`Team::WORLD`]).
163    #[inline]
164    pub fn world(&self) -> Team {
165        Team::WORLD
166    }
167
168    // -- ordering / synchronization --
169
170    /// Global barrier: every PE arrives **and** all RMA issued before the
171    /// call has completed remotely.
172    pub fn barrier_all(&self) -> Result<()> {
173        let n = nvshmem()?;
174        unsafe { (n.nvshmem_barrier_all()?)() };
175        Ok(())
176    }
177
178    /// Stream-ordered [`Self::barrier_all`].
179    pub fn barrier_all_on_stream(&self, stream: &Stream) -> Result<()> {
180        let n = nvshmem()?;
181        let cu = n.nvshmemx_barrier_all_on_stream()?;
182        unsafe { cu(stream_raw(stream)) };
183        Ok(())
184    }
185
186    /// Lighter barrier — PE arrival only, without the RMA remote-completion
187    /// guarantee of [`Self::barrier_all`].
188    pub fn sync_all(&self) -> Result<()> {
189        let n = nvshmem()?;
190        unsafe { (n.nvshmem_sync_all()?)() };
191        Ok(())
192    }
193
194    /// Block until all RMA issued by this PE has completed remotely.
195    pub fn quiet(&self) -> Result<()> {
196        let n = nvshmem()?;
197        unsafe { (n.nvshmem_quiet()?)() };
198        Ok(())
199    }
200
201    /// Order (but do not wait for completion of) outstanding RMA from this PE.
202    pub fn fence(&self) -> Result<()> {
203        let n = nvshmem()?;
204        unsafe { (n.nvshmem_fence()?)() };
205        Ok(())
206    }
207
208    // -- host-initiated RMA --
209
210    /// Blocking host put: copy `count` elements from the local `src` buffer
211    /// into PE `pe`'s copy of the symmetric `dest` buffer. Returns after the
212    /// data has left the local PE.
213    pub fn put<T: DeviceRepr>(
214        &self,
215        dest: &SymmetricBuffer<T>,
216        src: &SymmetricBuffer<T>,
217        count: usize,
218        pe: i32,
219    ) -> Result<()> {
220        assert!(count <= dest.len() && count <= src.len(), "put out of range");
221        let n = nvshmem()?;
222        let cu = n.nvshmem_putmem()?;
223        unsafe {
224            cu(
225                dest.ptr,
226                src.ptr as *const c_void,
227                count * core::mem::size_of::<T>(),
228                pe,
229            )
230        };
231        Ok(())
232    }
233
234    /// Blocking host get: copy `count` elements from PE `pe`'s copy of the
235    /// symmetric `src` buffer into the local `dest` buffer.
236    pub fn get<T: DeviceRepr>(
237        &self,
238        dest: &SymmetricBuffer<T>,
239        src: &SymmetricBuffer<T>,
240        count: usize,
241        pe: i32,
242    ) -> Result<()> {
243        assert!(count <= dest.len() && count <= src.len(), "get out of range");
244        let n = nvshmem()?;
245        let cu = n.nvshmem_getmem()?;
246        unsafe {
247            cu(
248                dest.ptr,
249                src.ptr as *const c_void,
250                count * core::mem::size_of::<T>(),
251                pe,
252            )
253        };
254        Ok(())
255    }
256
257    /// Stream-ordered [`Self::put`]. Completes in `stream` order; pair with
258    /// [`Self::quiet`] (or a [`Self::barrier_all`]) for remote completion.
259    pub fn put_on_stream<T: DeviceRepr>(
260        &self,
261        dest: &SymmetricBuffer<T>,
262        src: &SymmetricBuffer<T>,
263        count: usize,
264        pe: i32,
265        stream: &Stream,
266    ) -> Result<()> {
267        assert!(count <= dest.len() && count <= src.len(), "put out of range");
268        let n = nvshmem()?;
269        let cu = n.nvshmemx_putmem_on_stream()?;
270        unsafe {
271            cu(
272                dest.ptr,
273                src.ptr as *const c_void,
274                count * core::mem::size_of::<T>(),
275                pe,
276                stream_raw(stream),
277            )
278        };
279        Ok(())
280    }
281
282    /// Stream-ordered [`Self::get`].
283    pub fn get_on_stream<T: DeviceRepr>(
284        &self,
285        dest: &SymmetricBuffer<T>,
286        src: &SymmetricBuffer<T>,
287        count: usize,
288        pe: i32,
289        stream: &Stream,
290    ) -> Result<()> {
291        assert!(count <= dest.len() && count <= src.len(), "get out of range");
292        let n = nvshmem()?;
293        let cu = n.nvshmemx_getmem_on_stream()?;
294        unsafe {
295            cu(
296                dest.ptr,
297                src.ptr as *const c_void,
298                count * core::mem::size_of::<T>(),
299                pe,
300                stream_raw(stream),
301            )
302        };
303        Ok(())
304    }
305
306    /// Explicitly finalize NVSHMEM. Idempotent — also run on [`Drop`]. After
307    /// this no further NVSHMEM calls are valid.
308    pub fn finalize(&mut self) -> Result<()> {
309        if self.finalized {
310            return Ok(());
311        }
312        let n = nvshmem()?;
313        unsafe { (n.nvshmem_finalize()?)() };
314        self.finalized = true;
315        Ok(())
316    }
317}
318
319impl Drop for Context {
320    fn drop(&mut self) {
321        if self.finalized {
322            return;
323        }
324        if let Ok(n) = nvshmem() {
325            if let Ok(cu) = n.nvshmem_finalize() {
326                unsafe { cu() };
327            }
328        }
329    }
330}
331
332// ---- Team -----------------------------------------------------------------
333
334/// A team — a named subset of PEs. Teams created via
335/// [`Team::split_strided`] must be released with [`Team::destroy`]; the
336/// predefined [`Team::WORLD`] / [`Team::SHARED`] must **not** be destroyed.
337#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
338pub struct Team(nvshmem_team_t);
339
340impl Team {
341    /// The team of every PE in the program.
342    pub const WORLD: Self = Self(nvshmem_team_t::WORLD);
343    /// The team of PEs sharing a compute node.
344    pub const SHARED: Self = Self(nvshmem_team_t::SHARED);
345
346    /// Create a sub-team of `size` PEs from this team, starting at PE `start`
347    /// (in this team's index space) and taking every `stride`-th PE. Returns
348    /// `None` on the PEs that are **not** members of the new team.
349    ///
350    /// Defaults are used for the team config (`config = null`,
351    /// `config_mask = 0`).
352    pub fn split_strided(
353        &self,
354        start: i32,
355        stride: i32,
356        size: i32,
357    ) -> Result<Option<Team>> {
358        let n = nvshmem()?;
359        let cu = n.nvshmem_team_split_strided()?;
360        let mut new_team = nvshmem_team_t::INVALID;
361        check(unsafe {
362            cu(
363                self.0,
364                start,
365                stride,
366                size,
367                core::ptr::null(),
368                0,
369                &mut new_team,
370            )
371        })?;
372        if new_team == nvshmem_team_t::INVALID {
373            Ok(None)
374        } else {
375            Ok(Some(Team(new_team)))
376        }
377    }
378
379    /// This PE's index *within this team* (0..`n_pes`), or `-1` if this PE is
380    /// not a member.
381    pub fn my_pe(&self) -> Result<i32> {
382        let n = nvshmem()?;
383        let cu = n.nvshmem_team_my_pe()?;
384        Ok(unsafe { cu(self.0) })
385    }
386
387    /// Number of PEs in this team.
388    pub fn n_pes(&self) -> Result<i32> {
389        let n = nvshmem()?;
390        let cu = n.nvshmem_team_n_pes()?;
391        Ok(unsafe { cu(self.0) })
392    }
393
394    /// Translate `src_pe` (an index in this team) into its index in
395    /// `dest_team`. Returns `-1` if `src_pe` is not in `dest_team`.
396    pub fn translate_pe(&self, src_pe: i32, dest_team: Team) -> Result<i32> {
397        let n = nvshmem()?;
398        let cu = n.nvshmem_team_translate_pe()?;
399        Ok(unsafe { cu(self.0, src_pe, dest_team.0) })
400    }
401
402    /// Destroy a team created via [`Self::split_strided`]. Destroying a
403    /// predefined team ([`Self::WORLD`] / [`Self::SHARED`]) is a programmer
404    /// error and is rejected here.
405    pub fn destroy(self) -> Result<()> {
406        if self == Team::WORLD || self == Team::SHARED {
407            // Predefined teams are owned by the runtime — don't free them.
408            return Ok(());
409        }
410        let n = nvshmem()?;
411        let cu = n.nvshmem_team_destroy()?;
412        unsafe { cu(self.0) };
413        Ok(())
414    }
415
416    /// The raw team handle.
417    #[inline]
418    pub fn as_raw(&self) -> nvshmem_team_t {
419        self.0
420    }
421}
422
423// ---- SymmetricBuffer ------------------------------------------------------
424
425/// A typed allocation on the NVSHMEM symmetric heap. The same virtual address
426/// is valid on every PE, so the pointer can be used as a remote address in
427/// [`Context::put`] / [`Context::get`]. Freed on [`Drop`] via `nvshmem_free`.
428///
429/// `SymmetricBuffer` is **not** `Send`/`Sync`: NVSHMEM is bound to the PE's
430/// owning thread/process and the buffer must be freed on the same.
431pub struct SymmetricBuffer<T: DeviceRepr> {
432    ptr: *mut c_void,
433    len: usize,
434    _marker: core::marker::PhantomData<T>,
435}
436
437impl<T: DeviceRepr> core::fmt::Debug for SymmetricBuffer<T> {
438    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
439        f.debug_struct("SymmetricBuffer")
440            .field("ptr", &self.ptr)
441            .field("len", &self.len)
442            .finish()
443    }
444}
445
446impl<T: DeviceRepr> SymmetricBuffer<T> {
447    /// Allocate `len` elements on the symmetric heap. This is a **collective**
448    /// call — every PE must call it with the same `len` (NVSHMEM contract).
449    pub fn new(len: usize) -> Result<Self> {
450        let n = nvshmem()?;
451        let cu = n.nvshmem_malloc()?;
452        let bytes = len.checked_mul(core::mem::size_of::<T>()).expect("size overflow");
453        let ptr = unsafe { cu(bytes) };
454        if ptr.is_null() && bytes != 0 {
455            // nvshmem_malloc aborts internally on real OOM; a null with a
456            // non-zero request still warrants an error rather than a silent
457            // dangling buffer.
458            return Err(Error::Status {
459                status: nvshmemResult_t(1),
460            });
461        }
462        Ok(Self {
463            ptr,
464            len,
465            _marker: core::marker::PhantomData,
466        })
467    }
468
469    /// Element count.
470    #[inline]
471    pub fn len(&self) -> usize {
472        self.len
473    }
474
475    /// Whether the buffer is empty.
476    #[inline]
477    pub fn is_empty(&self) -> bool {
478        self.len == 0
479    }
480
481    /// The symmetric device pointer (same VA on every PE).
482    #[inline]
483    pub fn as_ptr(&self) -> *const T {
484        self.ptr as *const T
485    }
486
487    /// The symmetric device pointer, mutable.
488    #[inline]
489    pub fn as_mut_ptr(&self) -> *mut T {
490        self.ptr as *mut T
491    }
492}
493
494impl<T: DeviceRepr> Drop for SymmetricBuffer<T> {
495    fn drop(&mut self) {
496        if self.ptr.is_null() {
497            return;
498        }
499        if let Ok(n) = nvshmem() {
500            if let Ok(cu) = n.nvshmem_free() {
501                unsafe { cu(self.ptr) };
502            }
503        }
504    }
505}
506
507// ---- UniqueId -------------------------------------------------------------
508
509/// A 128-byte identifier for the unique-id bootstrap (the NVSHMEM analogue of
510/// NCCL's `UniqueId`). One PE calls [`UniqueId::new`] and distributes the
511/// bytes to every other PE; each then feeds it to the raw
512/// `nvshmemx_set_attr_uniqueid_args` + [`Context::init_with_attr`] path.
513///
514/// Wiring the id into init requires the version-specific
515/// `nvshmemx_init_attr_t` struct, which this safe layer deliberately does not
516/// model — use the raw [`baracuda-nvshmem-sys`] helpers for that step.
517///
518/// [`baracuda-nvshmem-sys`]: baracuda_nvshmem_sys
519#[derive(Copy, Clone, Debug)]
520pub struct UniqueId(nvshmemx_uniqueid_t);
521
522impl UniqueId {
523    /// Generate a fresh unique id on this PE.
524    pub fn new() -> Result<Self> {
525        let n = nvshmem()?;
526        let cu = n.nvshmemx_get_uniqueid()?;
527        let mut id = nvshmemx_uniqueid_t::default();
528        check(unsafe { cu(&mut id) })?;
529        Ok(Self(id))
530    }
531
532    /// Raw representation. Transmit verbatim to the other PEs.
533    pub fn as_raw(&self) -> nvshmemx_uniqueid_t {
534        self.0
535    }
536
537    /// Rebuild from a raw id received from another PE.
538    pub fn from_raw(id: nvshmemx_uniqueid_t) -> Self {
539        Self(id)
540    }
541}
542
543/// Convenience: NVSHMEM library version as `(major, minor)` without holding a
544/// [`Context`]. Useful for capability probes. Errors if NVSHMEM is not
545/// installed.
546pub fn version() -> Result<(i32, i32)> {
547    let n = nvshmem()?;
548    let cu = n.nvshmem_info_get_version()?;
549    let mut major: c_int = 0;
550    let mut minor: c_int = 0;
551    unsafe { cu(&mut major, &mut minor) };
552    Ok((major, minor))
553}