Skip to main content

oxicuda_memory/
peer_copy.rs

1//! Peer-to-peer (P2P) memory copy operations for multi-GPU workloads.
2//!
3//! This module provides functions to check, enable, and disable peer access
4//! between CUDA devices, as well as copy data between device buffers on
5//! different GPUs.
6//!
7//! Peer access enables direct GPU-to-GPU memory transfers over PCIe or
8//! NVLink without staging through host memory, significantly improving
9//! transfer bandwidth in multi-GPU configurations.
10//!
11//! # Example
12//!
13//! ```rust,no_run
14//! use oxicuda_driver::device::Device;
15//! use oxicuda_memory::peer_copy;
16//!
17//! oxicuda_driver::init()?;
18//! let dev0 = Device::get(0)?;
19//! let dev1 = Device::get(1)?;
20//!
21//! if peer_copy::can_access_peer(&dev0, &dev1)? {
22//!     peer_copy::enable_peer_access(&dev0, &dev1)?;
23//!     // Now D2D copies between dev0 and dev1 can go over NVLink/PCIe
24//!     // peer_copy::copy_peer(&mut dst_buf, &dev1, &src_buf, &dev0)?;
25//! }
26//! # Ok::<(), oxicuda_driver::error::CudaError>(())
27//! ```
28
29use std::ffi::c_int;
30
31use oxicuda_driver::device::Device;
32use oxicuda_driver::error::{CudaError, CudaResult};
33use oxicuda_driver::loader::try_driver;
34use oxicuda_driver::primary_context::PrimaryContext;
35use oxicuda_driver::stream::Stream;
36
37use crate::device_buffer::DeviceBuffer;
38
39/// Checks whether `device` can directly access memory on `peer`.
40///
41/// Returns `true` if peer access is supported between the two devices
42/// (e.g., over NVLink or PCIe).  Returns `false` if the devices are the
43/// same or if the hardware topology does not support peer access.
44///
45/// # Errors
46///
47/// Returns a CUDA driver error if the query fails.
48pub fn can_access_peer(device: &Device, peer: &Device) -> CudaResult<bool> {
49    let api = try_driver()?;
50    let mut can_access: c_int = 0;
51    oxicuda_driver::error::check(unsafe {
52        (api.cu_device_can_access_peer)(&mut can_access, device.raw(), peer.raw())
53    })?;
54    Ok(can_access != 0)
55}
56
57/// Enables peer access from `device`'s primary context to `peer`'s primary context.
58///
59/// After calling this function, kernels and copy operations running on `device`
60/// can directly read from and write to memory allocated on `peer`.
61///
62/// # Errors
63///
64/// * [`CudaError::PeerAccessAlreadyEnabled`] if peer access is already enabled.
65/// * [`CudaError::PeerAccessUnsupported`] if the hardware topology does not
66///   support direct peer access between these devices.
67pub fn enable_peer_access(device: &Device, peer: &Device) -> CudaResult<()> {
68    let api = try_driver()?;
69
70    // Retain both primary contexts.  The peer context handle is needed by
71    // cuCtxEnablePeerAccess; the device context is set as current so that the
72    // enable operation applies to it.
73    let dev_ctx = PrimaryContext::retain(device)?;
74    let peer_ctx = PrimaryContext::retain(peer)?;
75
76    // Make the device context current on this thread.
77    oxicuda_driver::error::check(unsafe { (api.cu_ctx_set_current)(dev_ctx.raw()) })?;
78
79    // Enable access from the current (device) context to the peer context.
80    let rc =
81        oxicuda_driver::error::check(unsafe { (api.cu_ctx_enable_peer_access)(peer_ctx.raw(), 0) });
82
83    // Release retained contexts regardless of outcome.
84    let _ = peer_ctx.release();
85    let _ = dev_ctx.release();
86
87    rc
88}
89
90/// Disables peer access from `device`'s primary context to `peer`'s primary context.
91///
92/// # Errors
93///
94/// * [`CudaError::PeerAccessNotEnabled`] if peer access was not previously enabled.
95pub fn disable_peer_access(device: &Device, peer: &Device) -> CudaResult<()> {
96    let api = try_driver()?;
97
98    let dev_ctx = PrimaryContext::retain(device)?;
99    let peer_ctx = PrimaryContext::retain(peer)?;
100
101    oxicuda_driver::error::check(unsafe { (api.cu_ctx_set_current)(dev_ctx.raw()) })?;
102
103    let rc =
104        oxicuda_driver::error::check(unsafe { (api.cu_ctx_disable_peer_access)(peer_ctx.raw()) });
105
106    let _ = peer_ctx.release();
107    let _ = dev_ctx.release();
108
109    rc
110}
111
112/// Copies data between device buffers on different GPUs (synchronous).
113///
114/// Both buffers must have the same length.  Peer access should be enabled
115/// between the source and destination devices before calling this function.
116///
117/// # Errors
118///
119/// * [`CudaError::InvalidValue`] if buffer lengths do not match.
120/// * [`CudaError::PeerAccessNotEnabled`] if peer access has not been enabled.
121pub fn copy_peer<T: Copy>(
122    dst: &mut DeviceBuffer<T>,
123    dst_device: &Device,
124    src: &DeviceBuffer<T>,
125    src_device: &Device,
126) -> CudaResult<()> {
127    if dst.len() != src.len() {
128        return Err(CudaError::InvalidValue);
129    }
130    let api = try_driver()?;
131    let byte_size = src.byte_size();
132
133    let dst_ctx = PrimaryContext::retain(dst_device)?;
134    let src_ctx = PrimaryContext::retain(src_device)?;
135
136    let rc = oxicuda_driver::error::check(unsafe {
137        (api.cu_memcpy_peer)(
138            dst.as_device_ptr(),
139            dst_ctx.raw(),
140            src.as_device_ptr(),
141            src_ctx.raw(),
142            byte_size,
143        )
144    });
145
146    let _ = src_ctx.release();
147    let _ = dst_ctx.release();
148
149    rc
150}
151
152/// Copies data between device buffers on different GPUs (asynchronous).
153///
154/// The copy is enqueued on `stream` and may not be complete when this
155/// function returns.  Both buffers must have the same length.
156///
157/// # Errors
158///
159/// * [`CudaError::InvalidValue`] if buffer lengths do not match.
160pub fn copy_peer_async<T: Copy>(
161    dst: &mut DeviceBuffer<T>,
162    dst_device: &Device,
163    src: &DeviceBuffer<T>,
164    src_device: &Device,
165    stream: &Stream,
166) -> CudaResult<()> {
167    if dst.len() != src.len() {
168        return Err(CudaError::InvalidValue);
169    }
170    let api = try_driver()?;
171    let byte_size = src.byte_size();
172
173    let dst_ctx = PrimaryContext::retain(dst_device)?;
174    let src_ctx = PrimaryContext::retain(src_device)?;
175
176    let rc = oxicuda_driver::error::check(unsafe {
177        (api.cu_memcpy_peer_async)(
178            dst.as_device_ptr(),
179            dst_ctx.raw(),
180            src.as_device_ptr(),
181            src_ctx.raw(),
182            byte_size,
183            stream.raw(),
184        )
185    });
186
187    let _ = src_ctx.release();
188    let _ = dst_ctx.release();
189
190    rc
191}
192
193// ---------------------------------------------------------------------------
194// Tests
195// ---------------------------------------------------------------------------
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the synchronous public functions keep their
    /// signatures.  Nothing is called, so no GPU or driver is needed.
    #[test]
    fn function_signatures_compile() {
        let _f1: fn(&Device, &Device) -> CudaResult<bool> = can_access_peer;
        let _f2: fn(&Device, &Device) -> CudaResult<()> = enable_peer_access;
        let _f3: fn(&Device, &Device) -> CudaResult<()> = disable_peer_access;
        let _f4: fn(
            &mut DeviceBuffer<f32>,
            &Device,
            &DeviceBuffer<f32>,
            &Device,
        ) -> CudaResult<()> = copy_peer;
    }

    /// Compile-time check that `copy_peer_async` keeps its signature.
    ///
    /// This is a signature test only: it does not construct buffers and does
    /// not exercise the length-mismatch path, so it is named accordingly.
    #[test]
    fn copy_peer_async_signature_compiles() {
        type PeerAsyncFn = fn(
            &mut DeviceBuffer<f32>,
            &Device,
            &DeviceBuffer<f32>,
            &Device,
            &Stream,
        ) -> CudaResult<()>;
        let _f: PeerAsyncFn = copy_peer_async;
    }

    /// Smoke test on real hardware: the peer-access query must be callable
    /// without panicking.  The result itself is intentionally not asserted.
    #[cfg(feature = "gpu-tests")]
    #[test]
    fn can_access_peer_single_gpu() {
        oxicuda_driver::init().ok();
        let count = Device::count().unwrap_or(0);
        if count >= 1 {
            let dev0 = Device::get(0).expect("device 0");
            if count == 1 {
                // Single GPU: can_access_peer with itself returns false or an error.
                let _ = can_access_peer(&dev0, &dev0);
            } else {
                let dev1 = Device::get(1).expect("device 1");
                let _ = can_access_peer(&dev0, &dev1);
            }
        }
    }
}