use std::ffi::c_int;
use oxicuda_driver::loader::try_driver;
use crate::error::{CudaRtError, CudaRtResult};
use crate::memory::DevicePtr;
use crate::stream::CudaStream;
/// Asks the driver whether `device` can directly access memory on `peer_device`.
///
/// Returns `Ok(true)` when the driver writes a non-zero flag and `Ok(false)`
/// when it writes zero.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] if the driver API table cannot be obtained.
/// * Otherwise the driver status translated by [`CudaRtError::from_code`], falling
///   back to [`CudaRtError::InvalidDevice`] when the code has no mapping.
pub fn device_can_access_peer(device: u32, peer_device: u32) -> CudaRtResult<bool> {
    let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;

    let mut peer_flag: c_int = 0;
    // SAFETY: `peer_flag` is a live local for the whole call and is the only
    // pointer handed to the driver; the ordinals are passed by value.
    let status = unsafe {
        (driver.cu_device_can_access_peer)(
            &raw mut peer_flag,
            device as c_int,
            peer_device as c_int,
        )
    };

    if status == 0 {
        Ok(peer_flag != 0)
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidDevice))
    }
}
/// Enables peer access to `peer_device`'s primary context.
///
/// The peer's primary context is retained to obtain its handle, and that handle
/// is passed together with `flags` to the driver's enable-peer-access entry point.
///
/// On success the retained primary-context reference is *not* released here.
/// NOTE(review): presumably this keeps the peer context alive while peer access
/// is enabled — confirm who is expected to drop that reference.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] if the driver API table cannot be obtained.
/// * If retaining the peer context fails: the mapped driver status, or
///   [`CudaRtError::InvalidDevice`] when the code has no mapping.
/// * If enabling peer access fails: the retained reference is released and the
///   mapped driver status is returned, or [`CudaRtError::PeerAccessUnsupported`]
///   when the code has no mapping.
pub fn device_enable_peer_access(peer_device: u32, flags: u32) -> CudaRtResult<()> {
    let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let peer_ordinal = peer_device as c_int;

    let mut ctx = oxicuda_driver::ffi::CUcontext::default();
    // SAFETY: `ctx` is a live local that the driver writes the context handle into.
    let retain_status =
        unsafe { (driver.cu_device_primary_ctx_retain)(&raw mut ctx, peer_ordinal) };
    if retain_status != 0 {
        return Err(CudaRtError::from_code(retain_status).unwrap_or(CudaRtError::InvalidDevice));
    }

    // SAFETY: `ctx` was produced by the successful retain above.
    let enable_status = unsafe { (driver.cu_ctx_enable_peer_access)(ctx, flags) };
    if enable_status == 0 {
        return Ok(());
    }

    // Enabling failed: drop the reference taken above before reporting the error.
    // SAFETY: balances the successful retain on the same device ordinal.
    unsafe { (driver.cu_device_primary_ctx_release_v2)(peer_ordinal) };
    Err(CudaRtError::from_code(enable_status).unwrap_or(CudaRtError::PeerAccessUnsupported))
}
/// Disables peer access to `peer_device`'s primary context.
///
/// The peer's primary context is retained only to obtain its handle for the
/// driver's disable-peer-access call. That reference is released again on both
/// the success and the failure path, so this function does not change the
/// primary context's reference count.
///
/// NOTE(review): a reference that an earlier enable call may have left
/// outstanding is not touched here — confirm whether it should be dropped too.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] if the driver API table cannot be obtained.
/// * If retaining the peer context fails: the mapped driver status, or
///   [`CudaRtError::InvalidDevice`] when the code has no mapping.
/// * If disabling peer access fails: the mapped driver status, or
///   [`CudaRtError::PeerAccessNotEnabled`] when the code has no mapping.
pub fn device_disable_peer_access(peer_device: u32) -> CudaRtResult<()> {
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;

    let mut peer_ctx = oxicuda_driver::ffi::CUcontext::default();
    // SAFETY: `peer_ctx` is a live local that the driver writes the handle into.
    let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut peer_ctx, peer_device as c_int) };
    if rc != 0 {
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
    }

    // SAFETY: `peer_ctx` was produced by the successful retain above.
    let rc2 = unsafe { (api.cu_ctx_disable_peer_access)(peer_ctx) };

    // The handle is no longer needed whatever the outcome; release the reference
    // taken above so that a successful call does not leak it.
    // SAFETY: balances the successful retain on the same device ordinal.
    unsafe { (api.cu_device_primary_ctx_release_v2)(peer_device as c_int) };

    if rc2 != 0 {
        return Err(CudaRtError::from_code(rc2).unwrap_or(CudaRtError::PeerAccessNotEnabled));
    }
    Ok(())
}
/// Copies `count` bytes from `src` on `src_device` to `dst` on `dst_device`
/// using the driver's peer-copy entry point.
///
/// A `count` of zero returns `Ok(())` immediately without touching the driver.
///
/// Both devices' primary contexts are retained to obtain their handles. Each
/// retain is checked, and every reference taken here is released before
/// returning, on success and on failure alike.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] if the driver API table cannot be obtained.
/// * If retaining either primary context fails: the mapped driver status, or
///   [`CudaRtError::InvalidDevice`] when the code has no mapping.
/// * If the copy fails: the mapped driver status, or
///   [`CudaRtError::InvalidMemcpyDirection`] when the code has no mapping.
pub fn memcpy_peer(
    dst: DevicePtr,
    dst_device: u32,
    src: DevicePtr,
    src_device: u32,
    count: usize,
) -> CudaRtResult<()> {
    if count == 0 {
        return Ok(());
    }
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;

    let mut dst_ctx = oxicuda_driver::ffi::CUcontext::default();
    // SAFETY: `dst_ctx` is a live local that the driver writes the handle into.
    let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut dst_ctx, dst_device as c_int) };
    if rc != 0 {
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
    }

    let mut src_ctx = oxicuda_driver::ffi::CUcontext::default();
    // SAFETY: `src_ctx` is a live local that the driver writes the handle into.
    let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut src_ctx, src_device as c_int) };
    if rc != 0 {
        // Undo the destination retain so a bad source ordinal does not leak it.
        // SAFETY: balances the successful retain on `dst_device`.
        unsafe { (api.cu_device_primary_ctx_release_v2)(dst_device as c_int) };
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
    }

    // SAFETY: both context handles come from successful retains above; the
    // pointers and byte count are forwarded unchanged from the caller.
    let rc = unsafe { (api.cu_memcpy_peer)(dst.0, dst_ctx, src.0, src_ctx, count) };

    // The handles are only needed for the call itself; drop both references
    // regardless of the copy's outcome.
    // SAFETY: each release balances one successful retain above.
    unsafe { (api.cu_device_primary_ctx_release_v2)(src_device as c_int) };
    unsafe { (api.cu_device_primary_ctx_release_v2)(dst_device as c_int) };

    if rc != 0 {
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidMemcpyDirection));
    }
    Ok(())
}
/// Enqueues a copy of `count` bytes from `src` on `src_device` to `dst` on
/// `dst_device` on `stream`, using the driver's asynchronous peer-copy entry point.
///
/// A `count` of zero returns `Ok(())` immediately without touching the driver.
///
/// Both devices' primary contexts are retained to obtain their handles. Each
/// retain is checked, and every reference taken here is released before
/// returning, on success and on failure alike.
///
/// NOTE(review): the references are released as soon as the copy is enqueued.
/// This assumes the caller already holds both primary contexts alive (the
/// allocations behind `dst` and `src` live in them) — confirm.
///
/// # Errors
///
/// * [`CudaRtError::DriverNotAvailable`] if the driver API table cannot be obtained.
/// * If retaining either primary context fails: the mapped driver status, or
///   [`CudaRtError::InvalidDevice`] when the code has no mapping.
/// * If enqueuing the copy fails: the mapped driver status, or
///   [`CudaRtError::InvalidMemcpyDirection`] when the code has no mapping.
pub fn memcpy_peer_async(
    dst: DevicePtr,
    dst_device: u32,
    src: DevicePtr,
    src_device: u32,
    count: usize,
    stream: CudaStream,
) -> CudaRtResult<()> {
    if count == 0 {
        return Ok(());
    }
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;

    let mut dst_ctx = oxicuda_driver::ffi::CUcontext::default();
    // SAFETY: `dst_ctx` is a live local that the driver writes the handle into.
    let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut dst_ctx, dst_device as c_int) };
    if rc != 0 {
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
    }

    let mut src_ctx = oxicuda_driver::ffi::CUcontext::default();
    // SAFETY: `src_ctx` is a live local that the driver writes the handle into.
    let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut src_ctx, src_device as c_int) };
    if rc != 0 {
        // Undo the destination retain so a bad source ordinal does not leak it.
        // SAFETY: balances the successful retain on `dst_device`.
        unsafe { (api.cu_device_primary_ctx_release_v2)(dst_device as c_int) };
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
    }

    // SAFETY: both context handles come from successful retains above; the
    // pointers, byte count and stream handle are forwarded unchanged.
    let rc =
        unsafe { (api.cu_memcpy_peer_async)(dst.0, dst_ctx, src.0, src_ctx, count, stream.raw()) };

    // Drop the references taken here regardless of whether the enqueue succeeded.
    // SAFETY: each release balances one successful retain above.
    unsafe { (api.cu_device_primary_ctx_release_v2)(src_device as c_int) };
    unsafe { (api.cu_device_primary_ctx_release_v2)(dst_device as c_int) };

    if rc != 0 {
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidMemcpyDirection));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Querying device 0 against itself must either succeed (any boolean is
    /// accepted) or fail with one of the errors expected when no usable
    /// driver or GPU is available; any other error fails the test.
    #[test]
    fn peer_access_self_check() {
        let Err(e) = device_can_access_peer(0, 0) else {
            // Any successful answer is acceptable.
            return;
        };

        let tolerated = matches!(
            e,
            CudaRtError::DriverNotAvailable
                | CudaRtError::NoGpu
                | CudaRtError::InitializationError
                | CudaRtError::InvalidDevice
        );
        if !tolerated {
            panic!("unexpected: {e}");
        }
    }
}