#![warn(missing_debug_implementations)]
use core::ffi::{c_int, c_void};
use baracuda_driver::Stream;
use baracuda_nvshmem_sys::{nvshmem, nvshmemResult_t, nvshmem_team_t, nvshmemx_uniqueid_t};
use baracuda_types::DeviceRepr;
/// Error type for this module: `baracuda_core::Error` specialized to NVSHMEM's
/// `nvshmemResult_t` status code.
pub type Error = baracuda_core::Error<nvshmemResult_t>;
/// `Result` alias whose error type defaults to this module's [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;
#[inline]
fn check(status: nvshmemResult_t) -> Result<()> {
Error::check(status)
}
/// Extracts the raw handle of a driver [`Stream`] as the runtime-API
/// `cudaStream_t` that the NVSHMEM `*_on_stream` entry points take.
///
/// NOTE(review): relies on the driver's stream handle and `cudaStream_t`
/// being interchangeable pointer types, as the `as _` cast assumes.
#[inline]
fn stream_raw(s: &Stream) -> baracuda_cuda_sys::runtime::cudaStream_t {
    let handle = s.as_raw();
    handle as _
}
/// Handle to an initialized NVSHMEM runtime.
///
/// Caches this PE's index and the PE count captured at initialization. It also
/// records whether the runtime has already been finalized, so that finalization
/// happens at most once, either explicitly or on drop.
pub struct Context {
    /// Index of this PE, as reported by `nvshmem_my_pe` at init time.
    my_pe: i32,
    /// Number of PEs, as reported by `nvshmem_n_pes` at init time.
    n_pes: i32,
    /// Set once `nvshmem_finalize` has been called through this handle.
    finalized: bool,
}

impl core::fmt::Debug for Context {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `finalized` is the only field that changes after construction, so it
        // is the most useful one when diagnosing lifecycle problems.
        f.debug_struct("nvshmem::Context")
            .field("my_pe", &self.my_pe)
            .field("n_pes", &self.n_pes)
            .field("finalized", &self.finalized)
            .finish()
    }
}
impl Context {
    /// Initializes the NVSHMEM runtime with `nvshmem_init` and captures this
    /// PE's identity.
    ///
    /// # Errors
    ///
    /// Returns an error if the NVSHMEM library or a required entry point cannot
    /// be resolved. If initialization succeeded but a later query fails, the
    /// runtime is finalized before the error is returned.
    pub fn init() -> Result<Self> {
        let n = nvshmem()?;
        let init = n.nvshmem_init()?;
        // SAFETY: `nvshmem_init` takes no arguments; the entry point was just
        // resolved from the loaded library.
        unsafe { init() };
        Self::from_initialized()
    }

    /// Initializes the NVSHMEM runtime with `nvshmemx_init_attr(flags, attr)`
    /// and captures this PE's identity.
    ///
    /// # Safety
    ///
    /// `attr` is forwarded to NVSHMEM unchanged. The caller must ensure that
    /// `nvshmemx_init_attr` accepts it for the given `flags`. That is presumably
    /// a valid, initialized `nvshmemx_init_attr_t`; confirm against the NVSHMEM
    /// version in use. It must also stay valid for the duration of the call.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the library or an entry point cannot be resolved;
    /// - `nvshmemx_init_attr` reports a failure status.
    ///
    /// A failure after successful initialization finalizes the runtime before
    /// returning.
    pub unsafe fn init_with_attr(flags: u32, attr: *mut c_void) -> Result<Self> {
        let n = nvshmem()?;
        let init = n.nvshmemx_init_attr()?;
        // SAFETY: the caller upholds this function's contract on `attr`.
        check(unsafe { init(flags, attr) })?;
        Self::from_initialized()
    }

    /// Builds a `Context` for a runtime the caller has just initialized.
    ///
    /// The value is constructed *before* the PE-identity queries run. If one of
    /// them fails, the early `?` drops the half-built `Context`, and its `Drop`
    /// impl finalizes the runtime. Without this, a failed query would leave
    /// NVSHMEM initialized with no owner able to finalize it.
    fn from_initialized() -> Result<Self> {
        // Placeholder values: `ctx` is only returned after both are overwritten.
        let mut ctx = Self {
            my_pe: -1,
            n_pes: 0,
            finalized: false,
        };
        let n = nvshmem()?;
        // SAFETY: both queries take no arguments and run after a successful
        // initialization.
        ctx.my_pe = unsafe { (n.nvshmem_my_pe()?)() };
        ctx.n_pes = unsafe { (n.nvshmem_n_pes()?)() };
        Ok(ctx)
    }

    /// Index of this PE, captured at initialization (`nvshmem_my_pe`).
    #[inline]
    pub fn my_pe(&self) -> i32 {
        self.my_pe
    }

    /// Total number of PEs, captured at initialization (`nvshmem_n_pes`).
    #[inline]
    pub fn n_pes(&self) -> i32 {
        self.n_pes
    }

    /// Returns the NVSHMEM library version as `(major, minor)`.
    ///
    /// Delegates to the module-level `version()` function; both report the
    /// same value.
    pub fn version(&self) -> Result<(i32, i32)> {
        // Unqualified `version` resolves to the module-level free function.
        // Methods are not in scope without `Self::`, so this is not recursive.
        version()
    }

    /// Allocates a symmetric buffer of `len` elements of `T`. Equivalent to
    /// [`SymmetricBuffer::new`].
    ///
    /// # Panics
    ///
    /// Panics if `len * size_of::<T>()` overflows `usize`.
    ///
    /// NOTE(review): the returned buffer is not lifetime-bound to `self`. A
    /// buffer dropped after [`Context::finalize`] will still call
    /// `nvshmem_free`. Confirm that this ownership model is intended.
    pub fn malloc<T: DeviceRepr>(&self, len: usize) -> Result<SymmetricBuffer<T>> {
        SymmetricBuffer::new(len)
    }

    /// Returns the predefined world team ([`Team::WORLD`]).
    #[inline]
    pub fn world(&self) -> Team {
        Team::WORLD
    }

    /// Host-side barrier across all PEs (`nvshmem_barrier_all`).
    pub fn barrier_all(&self) -> Result<()> {
        let n = nvshmem()?;
        // SAFETY: argument-less entry point resolved from the loaded library.
        unsafe { (n.nvshmem_barrier_all()?)() };
        Ok(())
    }

    /// Calls `nvshmemx_barrier_all_on_stream` with `stream`'s raw handle.
    pub fn barrier_all_on_stream(&self, stream: &Stream) -> Result<()> {
        let n = nvshmem()?;
        let cu = n.nvshmemx_barrier_all_on_stream()?;
        // SAFETY: `stream` is borrowed for the call, so its handle is live.
        unsafe { cu(stream_raw(stream)) };
        Ok(())
    }

    /// Calls `nvshmem_sync_all`.
    pub fn sync_all(&self) -> Result<()> {
        let n = nvshmem()?;
        // SAFETY: argument-less entry point resolved from the loaded library.
        unsafe { (n.nvshmem_sync_all()?)() };
        Ok(())
    }

    /// Calls `nvshmem_quiet`. Per the NVSHMEM API, this waits for completion of
    /// this PE's outstanding remote operations.
    pub fn quiet(&self) -> Result<()> {
        let n = nvshmem()?;
        // SAFETY: argument-less entry point resolved from the loaded library.
        unsafe { (n.nvshmem_quiet()?)() };
        Ok(())
    }

    /// Calls `nvshmem_fence`. Per the NVSHMEM API, this orders this PE's remote
    /// operations to each target PE.
    pub fn fence(&self) -> Result<()> {
        let n = nvshmem()?;
        // SAFETY: argument-less entry point resolved from the loaded library.
        unsafe { (n.nvshmem_fence()?)() };
        Ok(())
    }

    /// Copies `count` elements from the local `src` into `dest` on PE `pe`.
    ///
    /// Calls `nvshmem_putmem(dest, src, count * size_of::<T>(), pe)`.
    ///
    /// # Panics
    ///
    /// Panics with `"put out of range"` if `count` exceeds either buffer's
    /// length.
    pub fn put<T: DeviceRepr>(
        &self,
        dest: &SymmetricBuffer<T>,
        src: &SymmetricBuffer<T>,
        count: usize,
        pe: i32,
    ) -> Result<()> {
        assert!(count <= dest.len() && count <= src.len(), "put out of range");
        let n = nvshmem()?;
        let cu = n.nvshmem_putmem()?;
        // SAFETY: both pointers come from live symmetric allocations that hold
        // at least `count` elements (asserted above). The byte count cannot
        // overflow because `SymmetricBuffer::new` checked `len * size_of::<T>()`.
        unsafe {
            cu(
                dest.ptr,
                src.ptr as *const c_void,
                count * core::mem::size_of::<T>(),
                pe,
            )
        };
        Ok(())
    }

    /// Copies `count` elements from `src` on PE `pe` into the local `dest`.
    ///
    /// Calls `nvshmem_getmem(dest, src, count * size_of::<T>(), pe)`.
    ///
    /// # Panics
    ///
    /// Panics with `"get out of range"` if `count` exceeds either buffer's
    /// length.
    pub fn get<T: DeviceRepr>(
        &self,
        dest: &SymmetricBuffer<T>,
        src: &SymmetricBuffer<T>,
        count: usize,
        pe: i32,
    ) -> Result<()> {
        assert!(count <= dest.len() && count <= src.len(), "get out of range");
        let n = nvshmem()?;
        let cu = n.nvshmem_getmem()?;
        // SAFETY: same reasoning as in `put`.
        unsafe {
            cu(
                dest.ptr,
                src.ptr as *const c_void,
                count * core::mem::size_of::<T>(),
                pe,
            )
        };
        Ok(())
    }

    /// Stream-ordered variant of [`Context::put`]
    /// (`nvshmemx_putmem_on_stream`).
    ///
    /// NOTE(review): the transfer is presumably enqueued asynchronously on
    /// `stream`. Keep `dest` and `src` alive until that stream work has
    /// completed, because nothing here enforces it.
    ///
    /// # Panics
    ///
    /// Panics with `"put out of range"` if `count` exceeds either buffer's
    /// length.
    pub fn put_on_stream<T: DeviceRepr>(
        &self,
        dest: &SymmetricBuffer<T>,
        src: &SymmetricBuffer<T>,
        count: usize,
        pe: i32,
        stream: &Stream,
    ) -> Result<()> {
        assert!(count <= dest.len() && count <= src.len(), "put out of range");
        let n = nvshmem()?;
        let cu = n.nvshmemx_putmem_on_stream()?;
        // SAFETY: bounds are asserted above and the buffers are live at enqueue
        // time. See the lifetime note in the doc comment.
        unsafe {
            cu(
                dest.ptr,
                src.ptr as *const c_void,
                count * core::mem::size_of::<T>(),
                pe,
                stream_raw(stream),
            )
        };
        Ok(())
    }

    /// Stream-ordered variant of [`Context::get`]
    /// (`nvshmemx_getmem_on_stream`).
    ///
    /// NOTE(review): the transfer is presumably enqueued asynchronously on
    /// `stream`. Keep `dest` and `src` alive until that stream work has
    /// completed, because nothing here enforces it.
    ///
    /// # Panics
    ///
    /// Panics with `"get out of range"` if `count` exceeds either buffer's
    /// length.
    pub fn get_on_stream<T: DeviceRepr>(
        &self,
        dest: &SymmetricBuffer<T>,
        src: &SymmetricBuffer<T>,
        count: usize,
        pe: i32,
        stream: &Stream,
    ) -> Result<()> {
        assert!(count <= dest.len() && count <= src.len(), "get out of range");
        let n = nvshmem()?;
        let cu = n.nvshmemx_getmem_on_stream()?;
        // SAFETY: bounds are asserted above and the buffers are live at enqueue
        // time. See the lifetime note in the doc comment.
        unsafe {
            cu(
                dest.ptr,
                src.ptr as *const c_void,
                count * core::mem::size_of::<T>(),
                pe,
                stream_raw(stream),
            )
        };
        Ok(())
    }

    /// Finalizes the NVSHMEM runtime with `nvshmem_finalize`.
    ///
    /// Idempotent: once a call has succeeded, later calls and `Drop` do
    /// nothing. If the library or entry point cannot be resolved, the context
    /// stays un-finalized, so `Drop` will try again.
    pub fn finalize(&mut self) -> Result<()> {
        if self.finalized {
            return Ok(());
        }
        let n = nvshmem()?;
        // SAFETY: argument-less entry point, called at most once via this handle.
        unsafe { (n.nvshmem_finalize()?)() };
        self.finalized = true;
        Ok(())
    }
}
impl Drop for Context {
fn drop(&mut self) {
if self.finalized {
return;
}
if let Ok(n) = nvshmem() {
if let Ok(cu) = n.nvshmem_finalize() {
unsafe { cu() };
}
}
}
}
/// An NVSHMEM team handle: a `Copy` wrapper around `nvshmem_team_t`.
///
/// Because it is `Copy`, a `Team` does not own the underlying team. Teams
/// returned by [`Team::split_strided`] are released only by an explicit
/// [`Team::destroy`]. Copies that still exist after `destroy` refer to a
/// destroyed team.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Team(nvshmem_team_t);

impl Team {
    /// The predefined world team (`nvshmem_team_t::WORLD`).
    pub const WORLD: Self = Self(nvshmem_team_t::WORLD);
    /// The predefined shared team (`nvshmem_team_t::SHARED`).
    pub const SHARED: Self = Self(nvshmem_team_t::SHARED);

    /// Creates a child team with `nvshmem_team_split_strided`.
    ///
    /// The child contains `size` PEs of `self`, starting at index `start` and
    /// advancing by `stride`. No team configuration is passed (null config,
    /// zero mask). Returns `Ok(None)` when the call succeeds but yields
    /// `nvshmem_team_t::INVALID`; per NVSHMEM, that is the result on PEs that
    /// are not members of the new team.
    pub fn split_strided(
        &self,
        start: i32,
        stride: i32,
        size: i32,
    ) -> Result<Option<Team>> {
        let lib = nvshmem()?;
        let split = lib.nvshmem_team_split_strided()?;
        let mut child = nvshmem_team_t::INVALID;
        // SAFETY: `child` is a valid out-parameter for the duration of the call,
        // and a null config with mask 0 requests default configuration.
        let status = unsafe {
            split(
                self.0,
                start,
                stride,
                size,
                core::ptr::null(),
                0,
                &mut child,
            )
        };
        check(status)?;
        let team = if child == nvshmem_team_t::INVALID {
            None
        } else {
            Some(Team(child))
        };
        Ok(team)
    }

    /// This PE's index within the team (`nvshmem_team_my_pe`).
    ///
    /// NOTE(review): NVSHMEM presumably reports `-1` for non-members and
    /// invalid teams. That value is passed through unchanged.
    pub fn my_pe(&self) -> Result<i32> {
        let lib = nvshmem()?;
        let team_my_pe = lib.nvshmem_team_my_pe()?;
        // SAFETY: the call takes the team handle by value.
        let pe = unsafe { team_my_pe(self.0) };
        Ok(pe)
    }

    /// Number of PEs in the team (`nvshmem_team_n_pes`).
    pub fn n_pes(&self) -> Result<i32> {
        let lib = nvshmem()?;
        let team_n_pes = lib.nvshmem_team_n_pes()?;
        // SAFETY: the call takes the team handle by value.
        let count = unsafe { team_n_pes(self.0) };
        Ok(count)
    }

    /// Translates `src_pe` (an index in `self`) into the corresponding index in
    /// `dest_team` (`nvshmem_team_translate_pe`).
    pub fn translate_pe(&self, src_pe: i32, dest_team: Team) -> Result<i32> {
        let lib = nvshmem()?;
        let translate = lib.nvshmem_team_translate_pe()?;
        // SAFETY: all arguments are passed by value.
        let mapped = unsafe { translate(self.0, src_pe, dest_team.0) };
        Ok(mapped)
    }

    /// Destroys the team with `nvshmem_team_destroy`.
    ///
    /// Calling this on [`Team::WORLD`] or [`Team::SHARED`] does nothing and
    /// returns `Ok(())`.
    pub fn destroy(self) -> Result<()> {
        // Predefined teams are never destroyed here.
        if [Team::WORLD, Team::SHARED].contains(&self) {
            return Ok(());
        }
        let lib = nvshmem()?;
        let team_destroy = lib.nvshmem_team_destroy()?;
        // SAFETY: the handle is passed by value. It is not a predefined team
        // (checked above).
        unsafe { team_destroy(self.0) };
        Ok(())
    }

    /// The raw `nvshmem_team_t` handle.
    #[inline]
    pub fn as_raw(&self) -> nvshmem_team_t {
        let Self(raw) = *self;
        raw
    }
}
/// A typed allocation obtained from `nvshmem_malloc` and released with
/// `nvshmem_free` when dropped.
///
/// `len` counts elements of `T`, not bytes. A zero-length buffer may hold a
/// null pointer. The raw-pointer field makes the type neither `Send` nor
/// `Sync`.
pub struct SymmetricBuffer<T: DeviceRepr> {
    /// Base address returned by `nvshmem_malloc`.
    ptr: *mut c_void,
    /// Number of `T` elements in the allocation.
    len: usize,
    _marker: core::marker::PhantomData<T>,
}

impl<T: DeviceRepr> core::fmt::Debug for SymmetricBuffer<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `_marker` carries no data and is intentionally omitted.
        let mut out = f.debug_struct("SymmetricBuffer");
        out.field("ptr", &self.ptr);
        out.field("len", &self.len);
        out.finish()
    }
}
impl<T: DeviceRepr> SymmetricBuffer<T> {
    /// Allocates room for `len` elements of `T` with `nvshmem_malloc`.
    ///
    /// A zero-byte request may return a null pointer, which is accepted.
    ///
    /// NOTE(review): NVSHMEM documents `nvshmem_malloc` as a collective.
    /// Presumably every PE must call this with the same `len`; confirm for the
    /// targeted NVSHMEM version.
    ///
    /// # Panics
    ///
    /// Panics with `"size overflow"` if `len * size_of::<T>()` overflows
    /// `usize`.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the library or `nvshmem_malloc` cannot be resolved;
    /// - a non-empty allocation comes back null, reported as `Error::Status`
    ///   with status `1`.
    pub fn new(len: usize) -> Result<Self> {
        let lib = nvshmem()?;
        let alloc = lib.nvshmem_malloc()?;
        let elem_size = core::mem::size_of::<T>();
        let bytes = len.checked_mul(elem_size).expect("size overflow");
        // SAFETY: `nvshmem_malloc` takes only a byte count.
        let ptr = unsafe { alloc(bytes) };
        let allocation_failed = bytes != 0 && ptr.is_null();
        if allocation_failed {
            return Err(Error::Status {
                status: nvshmemResult_t(1),
            });
        }
        Ok(Self {
            ptr,
            len,
            _marker: core::marker::PhantomData,
        })
    }

    /// Number of `T` elements (not bytes).
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the buffer holds zero elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Base address as a typed const pointer. It may be null when empty.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.as_mut_ptr() as *const T
    }

    /// Base address as a typed mut pointer. It may be null when empty.
    #[inline]
    pub fn as_mut_ptr(&self) -> *mut T {
        self.ptr.cast::<T>()
    }
}
/// Releases the allocation with `nvshmem_free` (best effort).
///
/// NOTE(review): NVSHMEM documents `nvshmem_free` as a collective. Nothing here
/// ties the buffer's lifetime to a live `Context`, so a buffer dropped after
/// finalization presumably calls into a finalized library. Confirm the
/// intended ownership model.
impl<T: DeviceRepr> Drop for SymmetricBuffer<T> {
    fn drop(&mut self) {
        // `new` rejects null for non-empty requests, so a null pointer here
        // means a zero-byte allocation with nothing to free.
        if !self.ptr.is_null() {
            // Errors cannot escape `drop`. If the library or symbol is
            // unavailable, the pointer is simply not freed.
            if let Ok(lib) = nvshmem() {
                if let Ok(release) = lib.nvshmem_free() {
                    // SAFETY: `ptr` came from `nvshmem_malloc` and is freed only here.
                    unsafe { release(self.ptr) };
                }
            }
        }
    }
}
/// Wrapper around NVSHMEM's opaque `nvshmemx_uniqueid_t`.
///
/// The raw value passes through `as_raw` and `from_raw` unchanged.
#[derive(Copy, Clone, Debug)]
pub struct UniqueId(nvshmemx_uniqueid_t);

impl UniqueId {
    /// Fetches a new unique ID via `nvshmemx_get_uniqueid`.
    ///
    /// # Errors
    ///
    /// Returns an error if the library or entry point cannot be resolved, or
    /// if the call reports a failure status.
    pub fn new() -> Result<Self> {
        let lib = nvshmem()?;
        let get_uniqueid = lib.nvshmemx_get_uniqueid()?;
        let mut raw = nvshmemx_uniqueid_t::default();
        // SAFETY: `raw` is a valid out-parameter for the duration of the call.
        let status = unsafe { get_uniqueid(&mut raw) };
        check(status).map(|()| Self(raw))
    }

    /// Returns a copy of the underlying `nvshmemx_uniqueid_t`.
    #[inline]
    pub fn as_raw(&self) -> nvshmemx_uniqueid_t {
        let Self(raw) = *self;
        raw
    }

    /// Wraps an existing `nvshmemx_uniqueid_t` without validation.
    #[inline]
    pub fn from_raw(id: nvshmemx_uniqueid_t) -> Self {
        UniqueId(id)
    }
}
/// Returns the NVSHMEM library version as `(major, minor)`, as reported by
/// `nvshmem_info_get_version`.
///
/// # Errors
///
/// Returns an error if the library or entry point cannot be resolved.
pub fn version() -> Result<(i32, i32)> {
    let lib = nvshmem()?;
    let get_version = lib.nvshmem_info_get_version()?;
    let (mut major, mut minor): (c_int, c_int) = (0, 0);
    // SAFETY: both out-parameters are valid, writable locals for the call.
    unsafe { get_version(&mut major, &mut minor) };
    Ok((major, minor))
}