baracuda_nvshmem/lib.rs
1//! Safe Rust wrappers for the **NVIDIA NVSHMEM host API**.
2//!
3//! NVSHMEM is the OpenSHMEM symmetric-heap model on GPUs: every PE
4//! (processing element — typically one GPU) allocates from a *symmetric
5//! heap* at a shared virtual address, and any PE can read/write another PE's
6//! heap directly via one-sided `put` / `get`. This is the fine-grained,
7//! one-sided complement to [`baracuda-nccl`]'s collectives — the two coexist
8//! and a single program may use both.
9//!
10//! ## What this crate covers (Tier 1)
11//!
12//! - [`Context`] — process-wide NVSHMEM lifetime (init / finalize) plus
13//! cached `my_pe` / `n_pes`, and the barrier / quiet / fence ordering
14//! primitives.
15//! - [`Team`] — a subset of PEs created via strided split.
16//! - [`SymmetricBuffer`] — a typed allocation on the symmetric heap.
17//! - Host-initiated RMA — blocking and stream-ordered [`Context::put`] /
18//! [`Context::get`].
19//!
20//! ## What it does *not* cover
21//!
22//! The **device-side** API — the `__device__` `nvshmem_int_p` /
23//! `nvshmem_putmem_nbi` calls issued from *inside* a CUDA kernel — requires
24//! linking `libnvshmem_device.a` into the consumer's kernel binary and is
25//! out of scope (it cannot be a lazily-loaded host symbol). A consumer that
26//! needs device-side NVSHMEM writes its own `.cu` that includes the NVSHMEM
27//! headers and links the device archive.
28//!
29//! ## Availability
30//!
31//! NVSHMEM is a Linux library requiring compute capability sm_70+ (every
32//! baracuda-supported GPU qualifies). On hosts without the NVSHMEM runtime
33//! installed, [`Context::init`] returns `LoaderError::LibraryNotFound`, so
34//! callers can fall back to single-process execution.
35//!
36//! [`baracuda-nccl`]: https://docs.rs/baracuda-nccl
37
38#![warn(missing_debug_implementations)]
39
40use core::ffi::{c_int, c_void};
41
42use baracuda_driver::Stream;
43use baracuda_nvshmem_sys::{nvshmem, nvshmemResult_t, nvshmem_team_t, nvshmemx_uniqueid_t};
44use baracuda_types::DeviceRepr;
45
46/// Error type for NVSHMEM operations.
47pub type Error = baracuda_core::Error<nvshmemResult_t>;
48/// Result alias.
49pub type Result<T, E = Error> = core::result::Result<T, E>;
50
51#[inline]
52fn check(status: nvshmemResult_t) -> Result<()> {
53 Error::check(status)
54}
55
56#[inline]
57fn stream_raw(stream: &Stream) -> baracuda_cuda_sys::runtime::cudaStream_t {
58 // `CUstream` (driver) and `cudaStream_t` (runtime) are the same opaque
59 // handle at the ABI level — the runtime stream wraps the driver one.
60 stream.as_raw() as _
61}
62
63// ---- Context --------------------------------------------------------------
64
/// This PE's handle on the process-wide NVSHMEM runtime.
///
/// NVSHMEM is a **process singleton**: `nvshmem_init` and `nvshmem_finalize`
/// must each be called exactly once per process. Create one `Context` near
/// program start and either drop it or call [`Context::finalize`] at
/// shutdown. The PE index and PE count are queried once at init and stored
/// here, which is why the corresponding accessors cannot fail.
pub struct Context {
    /// Global index of this PE, cached at init.
    my_pe: i32,
    /// Total number of PEs, cached at init.
    n_pes: i32,
    /// Set once finalize has run, so the runtime is not finalized twice.
    finalized: bool,
}
77
78impl core::fmt::Debug for Context {
79 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
80 f.debug_struct("nvshmem::Context")
81 .field("my_pe", &self.my_pe)
82 .field("n_pes", &self.n_pes)
83 .finish()
84 }
85}
86
87impl Context {
88 /// Initialize NVSHMEM using the environment-selected bootstrap
89 /// (`NVSHMEM_BOOTSTRAP` — PMI / MPI / PMIx). This is the common launcher
90 /// path (`nvshmrun` / `mpirun`). Call exactly once per process.
91 pub fn init() -> Result<Self> {
92 let n = nvshmem()?;
93 let init = n.nvshmem_init()?;
94 unsafe { init() };
95 Self::from_initialized()
96 }
97
98 /// Initialize NVSHMEM through `nvshmemx_init_attr`, passing a caller-built
99 /// attributes struct (e.g. one populated from a
100 /// [`nvshmemx_uniqueid_t`](baracuda_nvshmem_sys::nvshmemx_uniqueid_t) via
101 /// the raw `nvshmemx_set_attr_uniqueid_args`). Pass `flags` and a pointer
102 /// to a valid `nvshmemx_init_attr_t` (or null for defaults).
103 ///
104 /// The attribute struct's layout is NVSHMEM-version-specific and so is
105 /// **not** modeled as a typed Rust struct — build it through the raw
106 /// [`baracuda-nvshmem-sys`] helpers.
107 ///
108 /// # Safety
109 ///
110 /// `attr` must be null or a properly-initialized `nvshmemx_init_attr_t`
111 /// for the installed NVSHMEM version, and must outlive the call.
112 ///
113 /// [`baracuda-nvshmem-sys`]: baracuda_nvshmem_sys
114 pub unsafe fn init_with_attr(flags: u32, attr: *mut c_void) -> Result<Self> {
115 let n = nvshmem()?;
116 let init = n.nvshmemx_init_attr()?;
117 check(unsafe { init(flags, attr) })?;
118 Self::from_initialized()
119 }
120
121 fn from_initialized() -> Result<Self> {
122 let n = nvshmem()?;
123 let my_pe = unsafe { (n.nvshmem_my_pe()?)() };
124 let n_pes = unsafe { (n.nvshmem_n_pes()?)() };
125 Ok(Self {
126 my_pe,
127 n_pes,
128 finalized: false,
129 })
130 }
131
132 /// This PE's global index (0..`n_pes`). Cached at init.
133 #[inline]
134 pub fn my_pe(&self) -> i32 {
135 self.my_pe
136 }
137
138 /// Total number of PEs in the program. Cached at init.
139 #[inline]
140 pub fn n_pes(&self) -> i32 {
141 self.n_pes
142 }
143
144 /// NVSHMEM version as `(major, minor)`.
145 pub fn version(&self) -> Result<(i32, i32)> {
146 let n = nvshmem()?;
147 let cu = n.nvshmem_info_get_version()?;
148 let mut major: c_int = 0;
149 let mut minor: c_int = 0;
150 unsafe { cu(&mut major, &mut minor) };
151 Ok((major, minor))
152 }
153
154 /// Allocate `len` elements of `T` on the symmetric heap. The returned
155 /// buffer occupies the **same virtual address on every PE**, so its
156 /// pointer doubles as a remote address in [`Context::put`] /
157 /// [`Context::get`].
158 pub fn malloc<T: DeviceRepr>(&self, len: usize) -> Result<SymmetricBuffer<T>> {
159 SymmetricBuffer::new(len)
160 }
161
162 /// The team of all PEs ([`Team::WORLD`]).
163 #[inline]
164 pub fn world(&self) -> Team {
165 Team::WORLD
166 }
167
168 // -- ordering / synchronization --
169
170 /// Global barrier: every PE arrives **and** all RMA issued before the
171 /// call has completed remotely.
172 pub fn barrier_all(&self) -> Result<()> {
173 let n = nvshmem()?;
174 unsafe { (n.nvshmem_barrier_all()?)() };
175 Ok(())
176 }
177
178 /// Stream-ordered [`Self::barrier_all`].
179 pub fn barrier_all_on_stream(&self, stream: &Stream) -> Result<()> {
180 let n = nvshmem()?;
181 let cu = n.nvshmemx_barrier_all_on_stream()?;
182 unsafe { cu(stream_raw(stream)) };
183 Ok(())
184 }
185
186 /// Lighter barrier — PE arrival only, without the RMA remote-completion
187 /// guarantee of [`Self::barrier_all`].
188 pub fn sync_all(&self) -> Result<()> {
189 let n = nvshmem()?;
190 unsafe { (n.nvshmem_sync_all()?)() };
191 Ok(())
192 }
193
194 /// Block until all RMA issued by this PE has completed remotely.
195 pub fn quiet(&self) -> Result<()> {
196 let n = nvshmem()?;
197 unsafe { (n.nvshmem_quiet()?)() };
198 Ok(())
199 }
200
201 /// Order (but do not wait for completion of) outstanding RMA from this PE.
202 pub fn fence(&self) -> Result<()> {
203 let n = nvshmem()?;
204 unsafe { (n.nvshmem_fence()?)() };
205 Ok(())
206 }
207
208 // -- host-initiated RMA --
209
210 /// Blocking host put: copy `count` elements from the local `src` buffer
211 /// into PE `pe`'s copy of the symmetric `dest` buffer. Returns after the
212 /// data has left the local PE.
213 pub fn put<T: DeviceRepr>(
214 &self,
215 dest: &SymmetricBuffer<T>,
216 src: &SymmetricBuffer<T>,
217 count: usize,
218 pe: i32,
219 ) -> Result<()> {
220 assert!(count <= dest.len() && count <= src.len(), "put out of range");
221 let n = nvshmem()?;
222 let cu = n.nvshmem_putmem()?;
223 unsafe {
224 cu(
225 dest.ptr,
226 src.ptr as *const c_void,
227 count * core::mem::size_of::<T>(),
228 pe,
229 )
230 };
231 Ok(())
232 }
233
234 /// Blocking host get: copy `count` elements from PE `pe`'s copy of the
235 /// symmetric `src` buffer into the local `dest` buffer.
236 pub fn get<T: DeviceRepr>(
237 &self,
238 dest: &SymmetricBuffer<T>,
239 src: &SymmetricBuffer<T>,
240 count: usize,
241 pe: i32,
242 ) -> Result<()> {
243 assert!(count <= dest.len() && count <= src.len(), "get out of range");
244 let n = nvshmem()?;
245 let cu = n.nvshmem_getmem()?;
246 unsafe {
247 cu(
248 dest.ptr,
249 src.ptr as *const c_void,
250 count * core::mem::size_of::<T>(),
251 pe,
252 )
253 };
254 Ok(())
255 }
256
257 /// Stream-ordered [`Self::put`]. Completes in `stream` order; pair with
258 /// [`Self::quiet`] (or a [`Self::barrier_all`]) for remote completion.
259 pub fn put_on_stream<T: DeviceRepr>(
260 &self,
261 dest: &SymmetricBuffer<T>,
262 src: &SymmetricBuffer<T>,
263 count: usize,
264 pe: i32,
265 stream: &Stream,
266 ) -> Result<()> {
267 assert!(count <= dest.len() && count <= src.len(), "put out of range");
268 let n = nvshmem()?;
269 let cu = n.nvshmemx_putmem_on_stream()?;
270 unsafe {
271 cu(
272 dest.ptr,
273 src.ptr as *const c_void,
274 count * core::mem::size_of::<T>(),
275 pe,
276 stream_raw(stream),
277 )
278 };
279 Ok(())
280 }
281
282 /// Stream-ordered [`Self::get`].
283 pub fn get_on_stream<T: DeviceRepr>(
284 &self,
285 dest: &SymmetricBuffer<T>,
286 src: &SymmetricBuffer<T>,
287 count: usize,
288 pe: i32,
289 stream: &Stream,
290 ) -> Result<()> {
291 assert!(count <= dest.len() && count <= src.len(), "get out of range");
292 let n = nvshmem()?;
293 let cu = n.nvshmemx_getmem_on_stream()?;
294 unsafe {
295 cu(
296 dest.ptr,
297 src.ptr as *const c_void,
298 count * core::mem::size_of::<T>(),
299 pe,
300 stream_raw(stream),
301 )
302 };
303 Ok(())
304 }
305
306 /// Explicitly finalize NVSHMEM. Idempotent — also run on [`Drop`]. After
307 /// this no further NVSHMEM calls are valid.
308 pub fn finalize(&mut self) -> Result<()> {
309 if self.finalized {
310 return Ok(());
311 }
312 let n = nvshmem()?;
313 unsafe { (n.nvshmem_finalize()?)() };
314 self.finalized = true;
315 Ok(())
316 }
317}
318
319impl Drop for Context {
320 fn drop(&mut self) {
321 if self.finalized {
322 return;
323 }
324 if let Ok(n) = nvshmem() {
325 if let Ok(cu) = n.nvshmem_finalize() {
326 unsafe { cu() };
327 }
328 }
329 }
330}
331
332// ---- Team -----------------------------------------------------------------
333
334/// A team — a named subset of PEs. Teams created via
335/// [`Team::split_strided`] must be released with [`Team::destroy`]; the
336/// predefined [`Team::WORLD`] / [`Team::SHARED`] must **not** be destroyed.
337#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
338pub struct Team(nvshmem_team_t);
339
340impl Team {
341 /// The team of every PE in the program.
342 pub const WORLD: Self = Self(nvshmem_team_t::WORLD);
343 /// The team of PEs sharing a compute node.
344 pub const SHARED: Self = Self(nvshmem_team_t::SHARED);
345
346 /// Create a sub-team of `size` PEs from this team, starting at PE `start`
347 /// (in this team's index space) and taking every `stride`-th PE. Returns
348 /// `None` on the PEs that are **not** members of the new team.
349 ///
350 /// Defaults are used for the team config (`config = null`,
351 /// `config_mask = 0`).
352 pub fn split_strided(
353 &self,
354 start: i32,
355 stride: i32,
356 size: i32,
357 ) -> Result<Option<Team>> {
358 let n = nvshmem()?;
359 let cu = n.nvshmem_team_split_strided()?;
360 let mut new_team = nvshmem_team_t::INVALID;
361 check(unsafe {
362 cu(
363 self.0,
364 start,
365 stride,
366 size,
367 core::ptr::null(),
368 0,
369 &mut new_team,
370 )
371 })?;
372 if new_team == nvshmem_team_t::INVALID {
373 Ok(None)
374 } else {
375 Ok(Some(Team(new_team)))
376 }
377 }
378
379 /// This PE's index *within this team* (0..`n_pes`), or `-1` if this PE is
380 /// not a member.
381 pub fn my_pe(&self) -> Result<i32> {
382 let n = nvshmem()?;
383 let cu = n.nvshmem_team_my_pe()?;
384 Ok(unsafe { cu(self.0) })
385 }
386
387 /// Number of PEs in this team.
388 pub fn n_pes(&self) -> Result<i32> {
389 let n = nvshmem()?;
390 let cu = n.nvshmem_team_n_pes()?;
391 Ok(unsafe { cu(self.0) })
392 }
393
394 /// Translate `src_pe` (an index in this team) into its index in
395 /// `dest_team`. Returns `-1` if `src_pe` is not in `dest_team`.
396 pub fn translate_pe(&self, src_pe: i32, dest_team: Team) -> Result<i32> {
397 let n = nvshmem()?;
398 let cu = n.nvshmem_team_translate_pe()?;
399 Ok(unsafe { cu(self.0, src_pe, dest_team.0) })
400 }
401
402 /// Destroy a team created via [`Self::split_strided`]. Destroying a
403 /// predefined team ([`Self::WORLD`] / [`Self::SHARED`]) is a programmer
404 /// error and is rejected here.
405 pub fn destroy(self) -> Result<()> {
406 if self == Team::WORLD || self == Team::SHARED {
407 // Predefined teams are owned by the runtime — don't free them.
408 return Ok(());
409 }
410 let n = nvshmem()?;
411 let cu = n.nvshmem_team_destroy()?;
412 unsafe { cu(self.0) };
413 Ok(())
414 }
415
416 /// The raw team handle.
417 #[inline]
418 pub fn as_raw(&self) -> nvshmem_team_t {
419 self.0
420 }
421}
422
423// ---- SymmetricBuffer ------------------------------------------------------
424
425/// A typed allocation on the NVSHMEM symmetric heap. The same virtual address
426/// is valid on every PE, so the pointer can be used as a remote address in
427/// [`Context::put`] / [`Context::get`]. Freed on [`Drop`] via `nvshmem_free`.
428///
429/// `SymmetricBuffer` is **not** `Send`/`Sync`: NVSHMEM is bound to the PE's
430/// owning thread/process and the buffer must be freed on the same.
431pub struct SymmetricBuffer<T: DeviceRepr> {
432 ptr: *mut c_void,
433 len: usize,
434 _marker: core::marker::PhantomData<T>,
435}
436
437impl<T: DeviceRepr> core::fmt::Debug for SymmetricBuffer<T> {
438 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
439 f.debug_struct("SymmetricBuffer")
440 .field("ptr", &self.ptr)
441 .field("len", &self.len)
442 .finish()
443 }
444}
445
446impl<T: DeviceRepr> SymmetricBuffer<T> {
447 /// Allocate `len` elements on the symmetric heap. This is a **collective**
448 /// call — every PE must call it with the same `len` (NVSHMEM contract).
449 pub fn new(len: usize) -> Result<Self> {
450 let n = nvshmem()?;
451 let cu = n.nvshmem_malloc()?;
452 let bytes = len.checked_mul(core::mem::size_of::<T>()).expect("size overflow");
453 let ptr = unsafe { cu(bytes) };
454 if ptr.is_null() && bytes != 0 {
455 // nvshmem_malloc aborts internally on real OOM; a null with a
456 // non-zero request still warrants an error rather than a silent
457 // dangling buffer.
458 return Err(Error::Status {
459 status: nvshmemResult_t(1),
460 });
461 }
462 Ok(Self {
463 ptr,
464 len,
465 _marker: core::marker::PhantomData,
466 })
467 }
468
469 /// Element count.
470 #[inline]
471 pub fn len(&self) -> usize {
472 self.len
473 }
474
475 /// Whether the buffer is empty.
476 #[inline]
477 pub fn is_empty(&self) -> bool {
478 self.len == 0
479 }
480
481 /// The symmetric device pointer (same VA on every PE).
482 #[inline]
483 pub fn as_ptr(&self) -> *const T {
484 self.ptr as *const T
485 }
486
487 /// The symmetric device pointer, mutable.
488 #[inline]
489 pub fn as_mut_ptr(&self) -> *mut T {
490 self.ptr as *mut T
491 }
492}
493
494impl<T: DeviceRepr> Drop for SymmetricBuffer<T> {
495 fn drop(&mut self) {
496 if self.ptr.is_null() {
497 return;
498 }
499 if let Ok(n) = nvshmem() {
500 if let Ok(cu) = n.nvshmem_free() {
501 unsafe { cu(self.ptr) };
502 }
503 }
504 }
505}
506
507// ---- UniqueId -------------------------------------------------------------
508
509/// A 128-byte identifier for the unique-id bootstrap (the NVSHMEM analogue of
510/// NCCL's `UniqueId`). One PE calls [`UniqueId::new`] and distributes the
511/// bytes to every other PE; each then feeds it to the raw
512/// `nvshmemx_set_attr_uniqueid_args` + [`Context::init_with_attr`] path.
513///
514/// Wiring the id into init requires the version-specific
515/// `nvshmemx_init_attr_t` struct, which this safe layer deliberately does not
516/// model — use the raw [`baracuda-nvshmem-sys`] helpers for that step.
517///
518/// [`baracuda-nvshmem-sys`]: baracuda_nvshmem_sys
519#[derive(Copy, Clone, Debug)]
520pub struct UniqueId(nvshmemx_uniqueid_t);
521
522impl UniqueId {
523 /// Generate a fresh unique id on this PE.
524 pub fn new() -> Result<Self> {
525 let n = nvshmem()?;
526 let cu = n.nvshmemx_get_uniqueid()?;
527 let mut id = nvshmemx_uniqueid_t::default();
528 check(unsafe { cu(&mut id) })?;
529 Ok(Self(id))
530 }
531
532 /// Raw representation. Transmit verbatim to the other PEs.
533 pub fn as_raw(&self) -> nvshmemx_uniqueid_t {
534 self.0
535 }
536
537 /// Rebuild from a raw id received from another PE.
538 pub fn from_raw(id: nvshmemx_uniqueid_t) -> Self {
539 Self(id)
540 }
541}
542
543/// Convenience: NVSHMEM library version as `(major, minor)` without holding a
544/// [`Context`]. Useful for capability probes. Errors if NVSHMEM is not
545/// installed.
546pub fn version() -> Result<(i32, i32)> {
547 let n = nvshmem()?;
548 let cu = n.nvshmem_info_get_version()?;
549 let mut major: c_int = 0;
550 let mut minor: c_int = 0;
551 unsafe { cu(&mut major, &mut minor) };
552 Ok((major, minor))
553}