1use std::ffi::c_int;
10
11use oxicuda_driver::loader::try_driver;
12
13use crate::error::{CudaRtError, CudaRtResult};
14use crate::memory::DevicePtr;
15use crate::stream::CudaStream;
16
17pub fn device_can_access_peer(device: u32, peer_device: u32) -> CudaRtResult<bool> {
27 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
28 let mut can_access: c_int = 0;
29 let rc = unsafe {
31 (api.cu_device_can_access_peer)(&raw mut can_access, device as c_int, peer_device as c_int)
32 };
33 if rc != 0 {
34 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
35 }
36 Ok(can_access != 0)
37}
38
39pub fn device_enable_peer_access(peer_device: u32, flags: u32) -> CudaRtResult<()> {
49 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
50 let mut peer_ctx = oxicuda_driver::ffi::CUcontext::default();
51 let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut peer_ctx, peer_device as c_int) };
54 if rc != 0 {
55 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
56 }
57 let rc2 = unsafe { (api.cu_ctx_enable_peer_access)(peer_ctx, flags) };
60 if rc2 != 0 {
61 unsafe { (api.cu_device_primary_ctx_release_v2)(peer_device as c_int) };
64 return Err(CudaRtError::from_code(rc2).unwrap_or(CudaRtError::PeerAccessUnsupported));
65 }
66 Ok(())
67}
68
69pub fn device_disable_peer_access(peer_device: u32) -> CudaRtResult<()> {
77 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
78 let mut peer_ctx = oxicuda_driver::ffi::CUcontext::default();
79 let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut peer_ctx, peer_device as c_int) };
81 if rc != 0 {
82 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
83 }
84 let rc2 = unsafe { (api.cu_ctx_disable_peer_access)(peer_ctx) };
86 if rc2 != 0 {
87 unsafe { (api.cu_device_primary_ctx_release_v2)(peer_device as c_int) };
89 return Err(CudaRtError::from_code(rc2).unwrap_or(CudaRtError::PeerAccessNotEnabled));
90 }
91 Ok(())
92}
93
94pub fn memcpy_peer(
102 dst: DevicePtr,
103 dst_device: u32,
104 src: DevicePtr,
105 src_device: u32,
106 count: usize,
107) -> CudaRtResult<()> {
108 if count == 0 {
109 return Ok(());
110 }
111 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
112 let mut dst_ctx = oxicuda_driver::ffi::CUcontext::default();
113 let mut src_ctx = oxicuda_driver::ffi::CUcontext::default();
114 unsafe { (api.cu_device_primary_ctx_retain)(&raw mut dst_ctx, dst_device as c_int) };
116 unsafe { (api.cu_device_primary_ctx_retain)(&raw mut src_ctx, src_device as c_int) };
117 let rc = unsafe { (api.cu_memcpy_peer)(dst.0, dst_ctx, src.0, src_ctx, count) };
119 if rc != 0 {
120 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidMemcpyDirection));
121 }
122 Ok(())
123}
124
125pub fn memcpy_peer_async(
133 dst: DevicePtr,
134 dst_device: u32,
135 src: DevicePtr,
136 src_device: u32,
137 count: usize,
138 stream: CudaStream,
139) -> CudaRtResult<()> {
140 if count == 0 {
141 return Ok(());
142 }
143 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
144 let mut dst_ctx = oxicuda_driver::ffi::CUcontext::default();
145 let mut src_ctx = oxicuda_driver::ffi::CUcontext::default();
146 unsafe { (api.cu_device_primary_ctx_retain)(&raw mut dst_ctx, dst_device as c_int) };
148 unsafe { (api.cu_device_primary_ctx_retain)(&raw mut src_ctx, src_device as c_int) };
149 let rc =
151 unsafe { (api.cu_memcpy_peer_async)(dst.0, dst_ctx, src.0, src_ctx, count, stream.raw()) };
152 if rc != 0 {
153 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidMemcpyDirection));
154 }
155 Ok(())
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Querying a device against itself must either succeed (with any answer) or fail with
    /// one of the errors tolerated on hosts without a usable driver or GPU.
    #[test]
    fn peer_access_self_check() {
        let err = match device_can_access_peer(0, 0) {
            Ok(_) => return,
            Err(err) => err,
        };
        let tolerated = matches!(
            err,
            CudaRtError::DriverNotAvailable
                | CudaRtError::NoGpu
                | CudaRtError::InitializationError
                | CudaRtError::InvalidDevice
        );
        assert!(tolerated, "unexpected: {err}");
    }
}