use std::ffi::c_int;
use oxicuda_driver::device::Device;
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::loader::try_driver;
use oxicuda_driver::primary_context::PrimaryContext;
use oxicuda_driver::stream::Stream;
use crate::device_buffer::DeviceBuffer;
pub fn can_access_peer(device: &Device, peer: &Device) -> CudaResult<bool> {
let api = try_driver()?;
let mut can_access: c_int = 0;
oxicuda_driver::error::check(unsafe {
(api.cu_device_can_access_peer)(&mut can_access, device.raw(), peer.raw())
})?;
Ok(can_access != 0)
}
pub fn enable_peer_access(device: &Device, peer: &Device) -> CudaResult<()> {
let api = try_driver()?;
let dev_ctx = PrimaryContext::retain(device)?;
let peer_ctx = PrimaryContext::retain(peer)?;
oxicuda_driver::error::check(unsafe { (api.cu_ctx_set_current)(dev_ctx.raw()) })?;
let rc =
oxicuda_driver::error::check(unsafe { (api.cu_ctx_enable_peer_access)(peer_ctx.raw(), 0) });
let _ = peer_ctx.release();
let _ = dev_ctx.release();
rc
}
pub fn disable_peer_access(device: &Device, peer: &Device) -> CudaResult<()> {
let api = try_driver()?;
let dev_ctx = PrimaryContext::retain(device)?;
let peer_ctx = PrimaryContext::retain(peer)?;
oxicuda_driver::error::check(unsafe { (api.cu_ctx_set_current)(dev_ctx.raw()) })?;
let rc =
oxicuda_driver::error::check(unsafe { (api.cu_ctx_disable_peer_access)(peer_ctx.raw()) });
let _ = peer_ctx.release();
let _ = dev_ctx.release();
rc
}
/// Copies the entire contents of `src` (on `src_device`) into `dst` (on
/// `dst_device`) using `cu_memcpy_peer`.
///
/// The two buffers must hold the same number of elements. The primary
/// contexts of both devices are retained for the duration of the call and
/// released afterwards, on success and on every failure path.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] when `dst.len() != src.len()`.
/// * Any error from loading the driver, retaining either primary context, or
///   the copy itself. Errors from releasing the contexts are ignored so that
///   they never mask the primary result.
pub fn copy_peer<T: Copy>(
    dst: &mut DeviceBuffer<T>,
    dst_device: &Device,
    src: &DeviceBuffer<T>,
    src_device: &Device,
) -> CudaResult<()> {
    if dst.len() != src.len() {
        return Err(CudaError::InvalidValue);
    }
    let api = try_driver()?;
    let byte_size = src.byte_size();
    let dst_ctx = PrimaryContext::retain(dst_device)?;
    let src_ctx = match PrimaryContext::retain(src_device) {
        Ok(ctx) => ctx,
        Err(err) => {
            // Give back the destination reference before bailing out.
            let _ = dst_ctx.release();
            return Err(err);
        }
    };

    // SAFETY: both contexts were retained successfully and are still held;
    // the lengths were checked equal above, so `byte_size` (taken from `src`)
    // is the byte length passed for both buffers.
    let rc = oxicuda_driver::error::check(unsafe {
        (api.cu_memcpy_peer)(
            dst.as_device_ptr(),
            dst_ctx.raw(),
            src.as_device_ptr(),
            src_ctx.raw(),
            byte_size,
        )
    });

    // Release in reverse order of acquisition, regardless of the outcome.
    let _ = src_ctx.release();
    let _ = dst_ctx.release();
    rc
}
/// Copies `count` elements from `src[src_offset..]` (on `src_device`) into
/// `dst[dst_offset..]` (on `dst_device`) using `cu_memcpy_peer`.
///
/// Offsets and `count` are expressed in elements of `T`, not bytes. Both
/// ranges are validated against the buffer lengths before anything else
/// happens, so out-of-range arguments are rejected even when `count` is
/// zero. A valid request with `count == 0` returns `Ok(())` without touching
/// the driver.
///
/// The primary contexts of both devices are retained for the duration of the
/// copy and released afterwards, on success and on every failure path.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] when `offset + count` overflows or exceeds
///   the corresponding buffer length, or when a byte offset / byte count
///   overflows `usize`.
/// * Any error from loading the driver, retaining either primary context, or
///   the copy itself. Errors from releasing the contexts are ignored so that
///   they never mask the primary result.
pub fn copy_peer_region<T: Copy>(
    dst: &mut DeviceBuffer<T>,
    dst_device: &Device,
    dst_offset: usize,
    src: &DeviceBuffer<T>,
    src_device: &Device,
    src_offset: usize,
    count: usize,
) -> CudaResult<()> {
    let elem_size = std::mem::size_of::<T>();

    // Element-level bounds check; `checked_add` guards against wrap-around.
    let src_end = src_offset
        .checked_add(count)
        .ok_or(CudaError::InvalidValue)?;
    let dst_end = dst_offset
        .checked_add(count)
        .ok_or(CudaError::InvalidValue)?;
    if src_end > src.len() || dst_end > dst.len() {
        return Err(CudaError::InvalidValue);
    }
    if count == 0 {
        return Ok(());
    }

    // Convert element quantities to bytes, again with overflow detection.
    let byte_count = count
        .checked_mul(elem_size)
        .ok_or(CudaError::InvalidValue)?;
    let src_byte_offset = src_offset
        .checked_mul(elem_size)
        .ok_or(CudaError::InvalidValue)? as u64;
    let dst_byte_offset = dst_offset
        .checked_mul(elem_size)
        .ok_or(CudaError::InvalidValue)? as u64;

    let api = try_driver()?;
    let dst_ctx = PrimaryContext::retain(dst_device)?;
    let src_ctx = match PrimaryContext::retain(src_device) {
        Ok(ctx) => ctx,
        Err(err) => {
            // Give back the destination reference before bailing out.
            let _ = dst_ctx.release();
            return Err(err);
        }
    };

    // SAFETY: both contexts were retained successfully and are still held;
    // the bounds check above keeps `offset + count` within each buffer's
    // element count, so the offset pointers and `byte_count` stay inside the
    // ranges described by `src` and `dst`.
    let rc = oxicuda_driver::error::check(unsafe {
        (api.cu_memcpy_peer)(
            dst.as_device_ptr() + dst_byte_offset,
            dst_ctx.raw(),
            src.as_device_ptr() + src_byte_offset,
            src_ctx.raw(),
            byte_count,
        )
    });

    // Release in reverse order of acquisition, regardless of the outcome.
    let _ = src_ctx.release();
    let _ = dst_ctx.release();
    rc
}
/// Submits a copy of the entire contents of `src` (on `src_device`) into
/// `dst` (on `dst_device`) on `stream` via `cu_memcpy_peer_async`.
///
/// The two buffers must hold the same number of elements. The primary
/// contexts of both devices are retained while the call is issued and
/// released immediately afterwards, on success and on every failure path.
///
/// NOTE(review): as an async driver call, the copy has presumably only been
/// enqueued when this returns; callers should synchronize `stream` before
/// reading `dst` or modifying `src` — confirm against the driver docs.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] when `dst.len() != src.len()`.
/// * Any error from loading the driver, retaining either primary context, or
///   submitting the copy. Errors from releasing the contexts are ignored so
///   that they never mask the primary result.
pub fn copy_peer_async<T: Copy>(
    dst: &mut DeviceBuffer<T>,
    dst_device: &Device,
    src: &DeviceBuffer<T>,
    src_device: &Device,
    stream: &Stream,
) -> CudaResult<()> {
    if dst.len() != src.len() {
        return Err(CudaError::InvalidValue);
    }
    let api = try_driver()?;
    let byte_size = src.byte_size();
    let dst_ctx = PrimaryContext::retain(dst_device)?;
    let src_ctx = match PrimaryContext::retain(src_device) {
        Ok(ctx) => ctx,
        Err(err) => {
            // Give back the destination reference before bailing out.
            let _ = dst_ctx.release();
            return Err(err);
        }
    };

    // SAFETY: both contexts were retained successfully and are still held;
    // the lengths were checked equal above, so `byte_size` (taken from `src`)
    // is the byte length passed for both buffers.
    let rc = oxicuda_driver::error::check(unsafe {
        (api.cu_memcpy_peer_async)(
            dst.as_device_ptr(),
            dst_ctx.raw(),
            src.as_device_ptr(),
            src_ctx.raw(),
            byte_size,
            stream.raw(),
        )
    });

    // Release in reverse order of acquisition, regardless of the outcome.
    let _ = src_ctx.release();
    let _ = dst_ctx.release();
    rc
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: the synchronous entry points coerce to the
    /// expected function-pointer types. Nothing is executed.
    #[test]
    fn function_signatures_compile() {
        let _f1: fn(&Device, &Device) -> CudaResult<bool> = can_access_peer;
        let _f2: fn(&Device, &Device) -> CudaResult<()> = enable_peer_access;
        let _f3: fn(&Device, &Device) -> CudaResult<()> = disable_peer_access;
        let _f4: fn(
            &mut DeviceBuffer<f32>,
            &Device,
            &DeviceBuffer<f32>,
            &Device,
        ) -> CudaResult<()> = copy_peer;
    }

    /// Compile-time check for `copy_peer_async`. Exercising the runtime
    /// length-mismatch path needs real device buffers, so only the signature
    /// is pinned here.
    #[test]
    fn copy_peer_async_signature_compiles() {
        type PeerAsyncFn = fn(
            &mut DeviceBuffer<f32>,
            &Device,
            &DeviceBuffer<f32>,
            &Device,
            &Stream,
        ) -> CudaResult<()>;
        let _f: PeerAsyncFn = copy_peer_async;
    }

    /// Compile-time check for `copy_peer_region`.
    #[test]
    fn copy_peer_region_signature_compiles() {
        type PeerRegionFn = fn(
            &mut DeviceBuffer<f32>,
            &Device,
            usize,
            &DeviceBuffer<f32>,
            &Device,
            usize,
            usize,
        ) -> CudaResult<()>;
        let _f: PeerRegionFn = copy_peer_region;
    }

    /// Copies three elements between two buffers on device 0 and checks that
    /// exactly the targeted destination slice changed. Every setup or copy
    /// failure is treated as "skip" rather than as a test failure.
    #[cfg(feature = "gpu-tests")]
    #[test]
    fn copy_peer_region_within_device_moves_exact_slice() {
        if oxicuda_driver::init().is_err() {
            eprintln!("skipping: CUDA init failed");
            return;
        }
        let device = match Device::get(0) {
            Ok(d) => d,
            Err(_) => {
                eprintln!("skipping: no CUDA device");
                return;
            }
        };
        let host_src: Vec<u32> = (10..18).collect();
        let src = match DeviceBuffer::<u32>::from_host(&host_src) {
            Ok(b) => b,
            Err(_) => {
                eprintln!("skipping: device alloc failed");
                return;
            }
        };
        let mut dst = match DeviceBuffer::<u32>::from_host(&[0u32; 8]) {
            Ok(b) => b,
            Err(_) => {
                eprintln!("skipping: device alloc failed");
                return;
            }
        };
        // src[2..5] == [12, 13, 14] should land in dst[5..8].
        if copy_peer_region(&mut dst, &device, 5, &src, &device, 2, 3).is_err() {
            eprintln!("skipping: peer-region copy failed");
            return;
        }
        let mut out = [0u32; 8];
        if dst.copy_to_host(&mut out).is_err() {
            eprintln!("skipping: copy back failed");
            return;
        }
        assert_eq!(out, [0, 0, 0, 0, 0, 12, 13, 14]);
    }

    /// Pins the arithmetic facts the bounds check in `copy_peer_region`
    /// relies on: `checked_add` reports overflow as `None`, and `u32` is four
    /// bytes. It does not call `copy_peer_region` itself, which would need
    /// real device buffers.
    #[test]
    fn checked_add_detects_region_offset_overflow() {
        let elem = std::mem::size_of::<u32>();
        let huge = usize::MAX;
        assert_eq!(huge.checked_add(1), None);
        assert_eq!(elem, 4);
    }

    /// Smoke test: `can_access_peer` can be called without panicking. The
    /// result is deliberately ignored; with a single GPU the device is
    /// queried against itself.
    #[cfg(feature = "gpu-tests")]
    #[test]
    fn can_access_peer_single_gpu() {
        oxicuda_driver::init().ok();
        let count = oxicuda_driver::device::Device::count().unwrap_or(0);
        if count >= 1 {
            let dev0 = Device::get(0).expect("device 0");
            if count == 1 {
                let _ = can_access_peer(&dev0, &dev0);
            } else {
                let dev1 = Device::get(1).expect("device 1");
                let _ = can_access_peer(&dev0, &dev1);
            }
        }
    }
}