use std::sync::Arc;
use baracuda_cuda_sys::types::{
CUmemAllocationHandleType, CUmulticastGranularity_flags, CUmulticastObjectProp,
};
use baracuda_cuda_sys::{driver, CUdeviceptr, CUmemGenericAllocationHandle};
use crate::device::Device;
use crate::error::{check, Result};
/// Queries the driver for the granularity of a multicast object described by
/// `num_devices` and `size_bytes`.
///
/// When `recommended` is `true` the `RECOMMENDED` granularity option is
/// requested; otherwise the `MINIMUM` option is requested.
///
/// # Errors
///
/// Returns an error if the driver cannot be obtained, if the
/// `cu_multicast_get_granularity` entry point cannot be resolved, or if the
/// driver call itself reports a failure status.
pub fn multicast_granularity(
    num_devices: u32,
    size_bytes: usize,
    recommended: bool,
) -> Result<usize> {
    let drv = driver()?;
    let query = drv.cu_multicast_get_granularity()?;

    let which = if recommended {
        CUmulticastGranularity_flags::RECOMMENDED
    } else {
        CUmulticastGranularity_flags::MINIMUM
    };
    // No shareable handle type is requested and no flags are set.
    let props = CUmulticastObjectProp {
        num_devices,
        size: size_bytes,
        handle_types: CUmemAllocationHandleType::NONE as u64,
        flags: 0,
    };

    let mut granularity: usize = 0;
    // SAFETY: both pointers refer to locals that stay alive for the whole
    // call; `query` is the entry point resolved from the driver above.
    let status = unsafe { query(&mut granularity, &props, which) };
    check(status).map(|()| granularity)
}
/// Handle to a multicast object created through the driver.
///
/// Cloning only bumps the reference count of the shared inner value; the
/// underlying handle is released by the inner value's `Drop` impl, i.e. once
/// the last clone has been dropped.
#[derive(Clone)]
pub struct MulticastObject {
    // Reference-counted owner of the raw handle, shared by all clones.
    inner: Arc<MulticastObjectInner>,
}
/// Owner of the raw multicast handle.
///
/// Its `Drop` impl passes the handle to the driver's `cu_mem_release` entry
/// point, skipping the call when the handle is `0`.
struct MulticastObjectInner {
    // Raw handle value as written by `cu_multicast_create`.
    handle: CUmemGenericAllocationHandle,
}
// SAFETY: the only field is the raw handle value, which this file merely
// copies into driver calls and never mutates after construction.
// NOTE(review): this assumes the driver accepts the handle from any thread —
// confirm against the CUDA driver documentation.
unsafe impl Send for MulticastObjectInner {}
unsafe impl Sync for MulticastObjectInner {}
impl core::fmt::Debug for MulticastObjectInner {
    /// Renders as `MulticastObject { handle: .., .. }`.
    ///
    /// The public type's name is used on purpose so that output produced
    /// through the wrapper does not expose the private inner type.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("MulticastObject");
        out.field("handle", &self.handle);
        out.finish_non_exhaustive()
    }
}
impl core::fmt::Debug for MulticastObject {
    /// Formats the object by delegating to the `Debug` impl of the shared
    /// inner value, so the output is identical to formatting the inner value.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Fully-qualified call on purpose: method syntax (`.fmt(f)`) only
        // resolves when a `fmt` trait is imported into scope, and turns
        // ambiguous as soon as more than one is (`Arc` implements both
        // `Debug` and `Pointer`). This form compiles regardless of imports.
        core::fmt::Debug::fmt(&*self.inner, f)
    }
}
impl MulticastObject {
    /// Creates a multicast object for `num_devices` devices spanning
    /// `size_bytes` bytes.
    ///
    /// No shareable handle type is requested and no flags are set.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver cannot be obtained, if the
    /// `cu_multicast_create` entry point cannot be resolved, or if the driver
    /// call reports a failure status.
    pub fn new(num_devices: u32, size_bytes: usize) -> Result<Self> {
        let drv = driver()?;
        let create = drv.cu_multicast_create()?;

        let props = CUmulticastObjectProp {
            num_devices,
            size: size_bytes,
            handle_types: CUmemAllocationHandleType::NONE as u64,
            flags: 0,
        };
        let mut raw: CUmemGenericAllocationHandle = 0;
        // SAFETY: both pointers refer to locals that outlive the call.
        let status = unsafe { create(&mut raw, &props) };
        check(status)?;

        let inner = Arc::new(MulticastObjectInner { handle: raw });
        Ok(Self { inner })
    }

    /// Adds `device` to this multicast object via the driver's
    /// `cu_multicast_add_device` entry point.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver or the entry point is unavailable, or
    /// if the driver call reports a failure status.
    pub fn add_device(&self, device: &Device) -> Result<()> {
        let drv = driver()?;
        let add = drv.cu_multicast_add_device()?;
        // SAFETY: only plain handle values are passed; the multicast handle
        // is kept alive by `self.inner` for the duration of the call.
        let status = unsafe { add(self.inner.handle, device.as_raw()) };
        check(status)
    }

    /// Binds `size` bytes of the allocation `mem_handle`, starting at
    /// `mem_offset`, to this multicast object at `mc_offset`.
    ///
    /// The trailing flags argument passed to the driver is always `0`.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver or the `cu_multicast_bind_mem` entry
    /// point is unavailable, or if the driver call reports a failure status.
    pub fn bind_mem(
        &self,
        mc_offset: usize,
        mem_handle: CUmemGenericAllocationHandle,
        mem_offset: usize,
        size: usize,
    ) -> Result<()> {
        let drv = driver()?;
        let bind = drv.cu_multicast_bind_mem()?;
        let mc_handle = self.inner.handle;
        // SAFETY: only handle values and integers are passed; validating
        // them is left to the driver, whose status is checked below.
        let status = unsafe { bind(mc_handle, mc_offset, mem_handle, mem_offset, size, 0) };
        check(status)
    }

    /// Binds `size` bytes starting at the device address `ptr` to this
    /// multicast object at `mc_offset`.
    ///
    /// The trailing flags argument passed to the driver is always `0`.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver or the `cu_multicast_bind_addr` entry
    /// point is unavailable, or if the driver call reports a failure status.
    pub fn bind_addr(&self, mc_offset: usize, ptr: CUdeviceptr, size: usize) -> Result<()> {
        let drv = driver()?;
        let bind = drv.cu_multicast_bind_addr()?;
        let mc_handle = self.inner.handle;
        // SAFETY: `ptr` is forwarded untouched; validating it is left to the
        // driver, whose status is checked below.
        let status = unsafe { bind(mc_handle, mc_offset, ptr, size, 0) };
        check(status)
    }

    /// Unbinds the range of `size` bytes at `mc_offset` for `device`.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver or the `cu_multicast_unbind` entry
    /// point is unavailable, or if the driver call reports a failure status.
    pub fn unbind(&self, device: &Device, mc_offset: usize, size: usize) -> Result<()> {
        let drv = driver()?;
        let unbind = drv.cu_multicast_unbind()?;
        let mc_handle = self.inner.handle;
        // SAFETY: only handle values and integers are passed to the driver.
        let status = unsafe { unbind(mc_handle, device.as_raw(), mc_offset, size) };
        check(status)
    }

    /// Returns the raw handle value.
    ///
    /// Ownership is not transferred: the handle is still released when the
    /// last clone of this object is dropped.
    #[inline]
    pub fn as_raw(&self) -> CUmemGenericAllocationHandle {
        self.inner.handle
    }
}
impl Drop for MulticastObjectInner {
    /// Releases the handle through the driver's `cu_mem_release` entry point.
    ///
    /// Cleanup is best-effort: a zero handle, an unavailable driver, a
    /// missing entry point, or a failing release call are all silently
    /// ignored, because `drop` has no way to report an error.
    fn drop(&mut self) {
        if self.handle == 0 {
            return;
        }
        let drv = match driver() {
            Ok(drv) => drv,
            Err(_) => return,
        };
        let release = match drv.cu_mem_release() {
            Ok(release) => release,
            Err(_) => return,
        };
        // SAFETY: the handle is non-zero and is released exactly once, here,
        // when the last owner goes away. The status is deliberately ignored.
        let _ = unsafe { release(self.handle) };
    }
}