Skip to main content

oxicuda_runtime/
peer.rs

1//! Peer-to-peer device access.
2//!
3//! Implements:
4//! - `cudaDeviceCanAccessPeer`
5//! - `cudaDeviceEnablePeerAccess`
6//! - `cudaDeviceDisablePeerAccess`
7//! - `cudaMemcpyPeer` / `cudaMemcpyPeerAsync`
8
9use std::ffi::c_int;
10
11use oxicuda_driver::loader::try_driver;
12
13use crate::error::{CudaRtError, CudaRtResult};
14use crate::memory::DevicePtr;
15use crate::stream::CudaStream;
16
17/// Check whether `device` can directly access the memory of `peer_device`.
18///
19/// Mirrors `cudaDeviceCanAccessPeer`.
20///
21/// Returns `Ok(true)` if peer access is supported.
22///
23/// # Errors
24///
25/// Propagates driver errors.
26pub fn device_can_access_peer(device: u32, peer_device: u32) -> CudaRtResult<bool> {
27    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
28    let mut can_access: c_int = 0;
29    // SAFETY: FFI; both ordinals are checked against count by caller if needed.
30    let rc = unsafe {
31        (api.cu_device_can_access_peer)(&raw mut can_access, device as c_int, peer_device as c_int)
32    };
33    if rc != 0 {
34        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
35    }
36    Ok(can_access != 0)
37}
38
39/// Enable peer access from the current context to the context owning `peer_device`.
40///
41/// Mirrors `cudaDeviceEnablePeerAccess`.
42///
43/// # Errors
44///
45/// - [`CudaRtError::PeerAccessUnsupported`] — link does not support peer access.
46/// - [`CudaRtError::PeerAccessAlreadyEnabled`] — already enabled.
47/// - Other driver errors.
48pub fn device_enable_peer_access(peer_device: u32, flags: u32) -> CudaRtResult<()> {
49    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
50    let mut peer_ctx = oxicuda_driver::ffi::CUcontext::default();
51    // Retain the primary context of the peer device.
52    // SAFETY: FFI.
53    let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut peer_ctx, peer_device as c_int) };
54    if rc != 0 {
55        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
56    }
57    // Enable peer access to that context.
58    // SAFETY: FFI.
59    let rc2 = unsafe { (api.cu_ctx_enable_peer_access)(peer_ctx, flags) };
60    if rc2 != 0 {
61        // Release the retained context regardless.
62        // SAFETY: FFI.
63        unsafe { (api.cu_device_primary_ctx_release_v2)(peer_device as c_int) };
64        return Err(CudaRtError::from_code(rc2).unwrap_or(CudaRtError::PeerAccessUnsupported));
65    }
66    Ok(())
67}
68
69/// Disable peer access from the current context to `peer_device`.
70///
71/// Mirrors `cudaDeviceDisablePeerAccess`.
72///
73/// # Errors
74///
75/// Propagates driver errors.
76pub fn device_disable_peer_access(peer_device: u32) -> CudaRtResult<()> {
77    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
78    let mut peer_ctx = oxicuda_driver::ffi::CUcontext::default();
79    // SAFETY: FFI.
80    let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut peer_ctx, peer_device as c_int) };
81    if rc != 0 {
82        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
83    }
84    // SAFETY: FFI.
85    let rc2 = unsafe { (api.cu_ctx_disable_peer_access)(peer_ctx) };
86    if rc2 != 0 {
87        // SAFETY: FFI.
88        unsafe { (api.cu_device_primary_ctx_release_v2)(peer_device as c_int) };
89        return Err(CudaRtError::from_code(rc2).unwrap_or(CudaRtError::PeerAccessNotEnabled));
90    }
91    Ok(())
92}
93
94/// Copy `count` bytes from `src` on `src_device` to `dst` on `dst_device`.
95///
96/// Mirrors `cudaMemcpyPeer`.
97///
98/// # Errors
99///
100/// Propagates driver errors.
101pub fn memcpy_peer(
102    dst: DevicePtr,
103    dst_device: u32,
104    src: DevicePtr,
105    src_device: u32,
106    count: usize,
107) -> CudaRtResult<()> {
108    if count == 0 {
109        return Ok(());
110    }
111    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
112    let mut dst_ctx = oxicuda_driver::ffi::CUcontext::default();
113    let mut src_ctx = oxicuda_driver::ffi::CUcontext::default();
114    // SAFETY: FFI.
115    unsafe { (api.cu_device_primary_ctx_retain)(&raw mut dst_ctx, dst_device as c_int) };
116    unsafe { (api.cu_device_primary_ctx_retain)(&raw mut src_ctx, src_device as c_int) };
117    // SAFETY: FFI; pointers are valid device allocations on the specified devices.
118    let rc = unsafe { (api.cu_memcpy_peer)(dst.0, dst_ctx, src.0, src_ctx, count) };
119    if rc != 0 {
120        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidMemcpyDirection));
121    }
122    Ok(())
123}
124
125/// Asynchronously copy across devices on `stream`.
126///
127/// Mirrors `cudaMemcpyPeerAsync`.
128///
129/// # Errors
130///
131/// Propagates driver errors.
132pub fn memcpy_peer_async(
133    dst: DevicePtr,
134    dst_device: u32,
135    src: DevicePtr,
136    src_device: u32,
137    count: usize,
138    stream: CudaStream,
139) -> CudaRtResult<()> {
140    if count == 0 {
141        return Ok(());
142    }
143    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
144    let mut dst_ctx = oxicuda_driver::ffi::CUcontext::default();
145    let mut src_ctx = oxicuda_driver::ffi::CUcontext::default();
146    // SAFETY: FFI.
147    unsafe { (api.cu_device_primary_ctx_retain)(&raw mut dst_ctx, dst_device as c_int) };
148    unsafe { (api.cu_device_primary_ctx_retain)(&raw mut src_ctx, src_device as c_int) };
149    // SAFETY: FFI.
150    let rc =
151        unsafe { (api.cu_memcpy_peer_async)(dst.0, dst_ctx, src.0, src_ctx, count, stream.raw()) };
152    if rc != 0 {
153        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidMemcpyDirection));
154    }
155    Ok(())
156}
157
158// ─── Tests ───────────────────────────────────────────────────────────────────
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: querying peer capability of device 0 against itself must
    /// either succeed (any boolean is accepted) or fail with one of the errors
    /// expected on a machine without a usable GPU/driver.
    #[test]
    fn peer_access_self_check() {
        let outcome = device_can_access_peer(0, 0);

        // Any `Ok` value is fine — the answer for self-access is not pinned.
        if let Err(err) = outcome {
            // Driver absent or not initialised — both are expected without a GPU.
            let tolerated = matches!(
                err,
                CudaRtError::DriverNotAvailable
                    | CudaRtError::NoGpu
                    | CudaRtError::InitializationError
                    | CudaRtError::InvalidDevice
            );
            assert!(tolerated, "unexpected: {err}");
        }
    }
}