use std::ffi::c_int;
use oxicuda_driver::device::Device;
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::loader::try_driver;
use oxicuda_driver::primary_context::PrimaryContext;
use oxicuda_driver::stream::Stream;
use crate::device_buffer::DeviceBuffer;
pub fn can_access_peer(device: &Device, peer: &Device) -> CudaResult<bool> {
let api = try_driver()?;
let mut can_access: c_int = 0;
oxicuda_driver::error::check(unsafe {
(api.cu_device_can_access_peer)(&mut can_access, device.raw(), peer.raw())
})?;
Ok(can_access != 0)
}
pub fn enable_peer_access(device: &Device, peer: &Device) -> CudaResult<()> {
let api = try_driver()?;
let dev_ctx = PrimaryContext::retain(device)?;
let peer_ctx = PrimaryContext::retain(peer)?;
oxicuda_driver::error::check(unsafe { (api.cu_ctx_set_current)(dev_ctx.raw()) })?;
let rc =
oxicuda_driver::error::check(unsafe { (api.cu_ctx_enable_peer_access)(peer_ctx.raw(), 0) });
let _ = peer_ctx.release();
let _ = dev_ctx.release();
rc
}
pub fn disable_peer_access(device: &Device, peer: &Device) -> CudaResult<()> {
let api = try_driver()?;
let dev_ctx = PrimaryContext::retain(device)?;
let peer_ctx = PrimaryContext::retain(peer)?;
oxicuda_driver::error::check(unsafe { (api.cu_ctx_set_current)(dev_ctx.raw()) })?;
let rc =
oxicuda_driver::error::check(unsafe { (api.cu_ctx_disable_peer_access)(peer_ctx.raw()) });
let _ = peer_ctx.release();
let _ = dev_ctx.release();
rc
}
/// Copies the contents of `src` (on `src_device`) into `dst` (on
/// `dst_device`) using `cu_memcpy_peer`.
///
/// The primary contexts of both devices are retained for the duration of the
/// call and released again before returning, on success and on every error
/// path after the first retain.
///
/// # Errors
///
/// * `CudaError::InvalidValue` if `dst.len() != src.len()`; this is checked
///   before the driver is touched.
/// * The error from `try_driver`, from retaining either primary context, or
///   from the copy itself.
///
/// Errors from releasing the contexts are deliberately ignored so they cannot
/// mask the result of the copy.
pub fn copy_peer<T: Copy>(
    dst: &mut DeviceBuffer<T>,
    dst_device: &Device,
    src: &DeviceBuffer<T>,
    src_device: &Device,
) -> CudaResult<()> {
    if dst.len() != src.len() {
        return Err(CudaError::InvalidValue);
    }
    let api = try_driver()?;
    let byte_size = src.byte_size();
    let dst_ctx = PrimaryContext::retain(dst_device)?;
    let src_ctx = match PrimaryContext::retain(src_device) {
        Ok(ctx) => ctx,
        Err(err) => {
            // Balance the destination retain when the source retain fails.
            let _ = dst_ctx.release();
            return Err(err);
        }
    };

    // SAFETY: both contexts stay retained until after the call, and the
    // device pointers come from buffers borrowed for the whole call. The
    // element counts were checked equal above, so the byte count taken from
    // `src` should describe `dst` as well (same `T`).
    let rc = oxicuda_driver::error::check(unsafe {
        (api.cu_memcpy_peer)(
            dst.as_device_ptr(),
            dst_ctx.raw(),
            src.as_device_ptr(),
            src_ctx.raw(),
            byte_size,
        )
    });

    let _ = src_ctx.release();
    let _ = dst_ctx.release();
    rc
}
/// Submits a copy of `src` (on `src_device`) into `dst` (on `dst_device`) on
/// `stream` using `cu_memcpy_peer_async`.
///
/// This function does not synchronise `stream`; it returns once the driver
/// call has returned. The primary contexts of both devices are retained only
/// for the duration of the call and released again before returning, on
/// success and on every error path after the first retain.
///
/// NOTE(review): presumably the copy may still be in flight after this
/// returns, while the borrows of `dst`/`src` and the context retains have
/// already ended — confirm callers keep the buffers and contexts alive until
/// the stream has been synchronised.
///
/// # Errors
///
/// * `CudaError::InvalidValue` if `dst.len() != src.len()`; this is checked
///   before the driver is touched.
/// * The error from `try_driver`, from retaining either primary context, or
///   from the driver call.
///
/// Errors from releasing the contexts are deliberately ignored so they cannot
/// mask the result of the driver call.
pub fn copy_peer_async<T: Copy>(
    dst: &mut DeviceBuffer<T>,
    dst_device: &Device,
    src: &DeviceBuffer<T>,
    src_device: &Device,
    stream: &Stream,
) -> CudaResult<()> {
    if dst.len() != src.len() {
        return Err(CudaError::InvalidValue);
    }
    let api = try_driver()?;
    let byte_size = src.byte_size();
    let dst_ctx = PrimaryContext::retain(dst_device)?;
    let src_ctx = match PrimaryContext::retain(src_device) {
        Ok(ctx) => ctx,
        Err(err) => {
            // Balance the destination retain when the source retain fails.
            let _ = dst_ctx.release();
            return Err(err);
        }
    };

    // SAFETY: both contexts stay retained until after the call, and the
    // device pointers and stream handle come from values borrowed for the
    // whole call. The element counts were checked equal above, so the byte
    // count taken from `src` should describe `dst` as well (same `T`).
    let rc = oxicuda_driver::error::check(unsafe {
        (api.cu_memcpy_peer_async)(
            dst.as_device_ptr(),
            dst_ctx.raw(),
            src.as_device_ptr(),
            src_ctx.raw(),
            byte_size,
            stream.raw(),
        )
    });

    let _ = src_ctx.release();
    let _ = dst_ctx.release();
    rc
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the public signatures of the synchronous peer helpers.
    ///
    /// Nothing is called; the assignments only have to type-check, so this
    /// test needs neither a GPU nor a loaded driver.
    #[test]
    fn function_signatures_compile() {
        let _f1: fn(&Device, &Device) -> CudaResult<bool> = can_access_peer;
        let _f2: fn(&Device, &Device) -> CudaResult<()> = enable_peer_access;
        let _f3: fn(&Device, &Device) -> CudaResult<()> = disable_peer_access;
        let _f4: fn(
            &mut DeviceBuffer<f32>,
            &Device,
            &DeviceBuffer<f32>,
            &Device,
        ) -> CudaResult<()> = copy_peer;
    }

    /// Pins the public signature of `copy_peer_async`.
    ///
    /// This test was previously named as if it checked the length-mismatch
    /// error, but it never calls the function; it is a compile-time signature
    /// check only, and the name now says so.
    #[test]
    fn copy_peer_async_signature_compiles() {
        type PeerAsyncFn = fn(
            &mut DeviceBuffer<f32>,
            &Device,
            &DeviceBuffer<f32>,
            &Device,
            &Stream,
        ) -> CudaResult<()>;
        let _f: PeerAsyncFn = copy_peer_async;
    }

    /// Smoke test: `can_access_peer` can be called without panicking.
    ///
    /// With a single device the query is made against the device itself;
    /// with more than one, device 0 is checked against device 1. The result
    /// is intentionally ignored, since only the call path is exercised.
    #[cfg(feature = "gpu-tests")]
    #[test]
    fn can_access_peer_single_gpu() {
        oxicuda_driver::init().ok();
        let count = Device::count().unwrap_or(0);
        if count >= 1 {
            let dev0 = Device::get(0).expect("device 0");
            if count == 1 {
                let _ = can_access_peer(&dev0, &dev0);
            } else {
                let dev1 = Device::get(1).expect("device 1");
                let _ = can_access_peer(&dev0, &dev1);
            }
        }
    }
}