use core::ffi::c_void;
use baracuda_cuda_sys::runtime::runtime;
use baracuda_cuda_sys::runtime::types::cudaMemGenericAllocationHandle_t;
use baracuda_types::{supports, Feature};
use crate::device::Device;
use crate::error::{check, Error, Result};
/// Property block describing a multicast object.
///
/// A reference to this struct is reinterpreted as an untyped `*const c_void`
/// and handed to the multicast create / granularity entry points, so it is
/// `#[repr(C)]` and the field order and field types are part of the ABI.
///
/// NOTE(review): presumably mirrors the CUDA driver's
/// `CUmulticastObjectProp`. That C struct declares `handleTypes` as
/// `unsigned long long`, whereas `handle_types` here is a `c_int` followed by
/// padding — confirm the layout matches what the loaded entry points expect.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub struct MulticastProp {
    /// Device count for the object (presumably the maximum number of
    /// participating devices — confirm against the CUDA documentation).
    pub num_devices: core::ffi::c_uint,
    /// Size associated with the object, in bytes (presumably the maximum
    /// amount of memory that can be bound per device — confirm).
    pub size: usize,
    /// Handle-type selector (presumably a bitmask of shareable handle
    /// types — confirm).
    pub handle_types: core::ffi::c_int,
    /// Creation flags, forwarded to the callee unchanged.
    pub flags: u64,
}
fn require_multicast() -> Result<()> {
let installed = crate::init::driver_version()?;
if supports(installed, Feature::MulticastObjects) {
Ok(())
} else {
Err(Error::FeatureNotSupported {
api: "cudaMulticast*",
since: Feature::MulticastObjects.required_version(),
})
}
}
/// Owning wrapper around a multicast object handle.
///
/// The handle is obtained by `MulticastObject::new` and is passed to the
/// runtime's memory-release entry point when the value is dropped. The type
/// is deliberately not `Clone`/`Copy`, so each handle is released at most
/// once by this wrapper.
#[derive(Debug)]
pub struct MulticastObject {
/// Raw generic-allocation handle written by the create call in `new`.
handle: cudaMemGenericAllocationHandle_t,
}
impl MulticastObject {
    /// Creates a multicast object from the properties in `prop`.
    ///
    /// # Errors
    ///
    /// * [`Error::FeatureNotSupported`] (via `require_multicast`) when the
    ///   installed driver is too old for multicast objects.
    /// * Any error from resolving the runtime or the create entry point.
    /// * Whatever `check` reports for a failing status from the create call.
    pub fn new(prop: &MulticastProp) -> Result<Self> {
        require_multicast()?;
        let r = runtime()?;
        let cu = r.cuda_multicast_create()?;
        let mut h: cudaMemGenericAllocationHandle_t = 0;
        // SAFETY: `h` is a live local and `prop` is a valid shared reference,
        // so both pointers stay valid for the duration of the call.
        check(unsafe { cu(&mut h, (prop as *const MulticastProp).cast::<c_void>()) })?;
        Ok(Self { handle: h })
    }

    /// Returns the raw handle without giving up ownership.
    ///
    /// The handle is still released when `self` is dropped, so the returned
    /// value must not be used after that point.
    #[inline]
    pub fn as_raw(&self) -> cudaMemGenericAllocationHandle_t {
        self.handle
    }

    /// Passes this object's handle and `device.ordinal()` to the runtime's
    /// `cuda_multicast_add_device` entry point.
    ///
    /// # Errors
    ///
    /// Fails if the runtime or the entry point cannot be resolved, or if
    /// `check` rejects the returned status.
    pub fn add_device(&self, device: &Device) -> Result<()> {
        let r = runtime()?;
        let cu = r.cuda_multicast_add_device()?;
        // SAFETY: every argument is passed by value; within this impl
        // `self.handle` is only ever set from a successful create call.
        check(unsafe { cu(self.handle, device.ordinal()) })
    }

    /// Binds `size` bytes of the allocation `mem_handle`, starting at
    /// `mem_offset`, into this multicast object at `mc_offset`.
    ///
    /// All arguments are forwarded unchanged to the runtime's
    /// `cuda_multicast_bind_mem` entry point.
    ///
    /// # Safety
    ///
    /// This wrapper performs no validation. The caller must ensure that
    /// `mem_handle`, both offsets, `size` and `flags` satisfy the
    /// requirements of the underlying CUDA call (presumably: a live
    /// allocation handle and offsets/size that are in range and suitably
    /// aligned — confirm against the CUDA documentation).
    ///
    /// # Errors
    ///
    /// Fails if the runtime or the entry point cannot be resolved, or if
    /// `check` rejects the returned status.
    pub unsafe fn bind_mem(
        &self,
        mc_offset: usize,
        mem_handle: cudaMemGenericAllocationHandle_t,
        mem_offset: usize,
        size: usize,
        flags: u64,
    ) -> Result<()> {
        // Entry-point lookup is safe; only the FFI call needs `unsafe`.
        let r = runtime()?;
        let cu = r.cuda_multicast_bind_mem()?;
        // SAFETY: the caller upholds this function's `# Safety` contract for
        // `mem_handle`, the offsets, `size` and `flags`.
        check(unsafe { cu(self.handle, mc_offset, mem_handle, mem_offset, size, flags) })
    }

    /// Binds `size` bytes starting at the address `mem_ptr` into this
    /// multicast object at `mc_offset`.
    ///
    /// All arguments are forwarded unchanged to the runtime's
    /// `cuda_multicast_bind_addr` entry point.
    ///
    /// # Safety
    ///
    /// This wrapper performs no validation. The caller must ensure that
    /// `mem_ptr`, `mc_offset`, `size` and `flags` satisfy the requirements of
    /// the underlying CUDA call (presumably: `mem_ptr` refers to a live
    /// mapping covering `size` bytes — confirm against the CUDA
    /// documentation).
    ///
    /// # Errors
    ///
    /// Fails if the runtime or the entry point cannot be resolved, or if
    /// `check` rejects the returned status.
    pub unsafe fn bind_addr(
        &self,
        mc_offset: usize,
        mem_ptr: *mut c_void,
        size: usize,
        flags: u64,
    ) -> Result<()> {
        // Entry-point lookup is safe; only the FFI call needs `unsafe`.
        let r = runtime()?;
        let cu = r.cuda_multicast_bind_addr()?;
        // SAFETY: the caller upholds this function's `# Safety` contract for
        // `mem_ptr`, `mc_offset`, `size` and `flags`.
        check(unsafe { cu(self.handle, mc_offset, mem_ptr, size, flags) })
    }

    /// Passes this object's handle, `device.ordinal()`, `mc_offset` and
    /// `size` to the runtime's `cuda_multicast_unbind` entry point.
    ///
    /// # Errors
    ///
    /// Fails if the runtime or the entry point cannot be resolved, or if
    /// `check` rejects the returned status.
    pub fn unbind(&self, device: &Device, mc_offset: usize, size: usize) -> Result<()> {
        let r = runtime()?;
        let cu = r.cuda_multicast_unbind()?;
        // SAFETY: every argument is passed by value; within this impl
        // `self.handle` is only ever set from a successful create call.
        check(unsafe { cu(self.handle, device.ordinal(), mc_offset, size) })
    }
}
impl Drop for MulticastObject {
    /// Releases the handle on a best-effort basis.
    ///
    /// `drop` cannot report failure, so if the runtime or the release entry
    /// point is unavailable nothing is done, and the status returned by the
    /// release call is intentionally ignored.
    fn drop(&mut self) {
        let Ok(rt) = runtime() else {
            return;
        };
        let Ok(release) = rt.cuda_mem_release() else {
            return;
        };
        // SAFETY: the handle is passed by value; within this file's visible
        // code it is only ever set from a successful create call.
        let _ = unsafe { release(self.handle) };
    }
}
/// Queries a granularity value for a multicast object described by `prop`.
///
/// `option` is forwarded unchanged to the runtime's
/// `cuda_multicast_get_granularity` entry point (presumably it selects which
/// granularity is reported, e.g. minimum vs. recommended — confirm against
/// the CUDA documentation).
///
/// # Errors
///
/// * [`Error::FeatureNotSupported`] (via `require_multicast`) when the
///   installed driver is too old for multicast objects.
/// * Any error from resolving the runtime or the entry point.
/// * Whatever `check` reports for a failing status from the query.
pub fn multicast_granularity(prop: &MulticastProp, option: i32) -> Result<usize> {
    require_multicast()?;

    let rt = runtime()?;
    let query = rt.cuda_multicast_get_granularity()?;

    let prop_ptr = prop as *const MulticastProp as *const c_void;
    let mut granularity: usize = 0;

    // SAFETY: `granularity` is a live local and `prop` is a valid shared
    // reference, so both pointers stay valid for the duration of the call.
    let status = unsafe { query(&mut granularity, prop_ptr, option) };
    check(status)?;

    Ok(granularity)
}