oxicuda_memory/
peer_copy.rs1use std::ffi::c_int;
30
31use oxicuda_driver::device::Device;
32use oxicuda_driver::error::{CudaError, CudaResult};
33use oxicuda_driver::loader::try_driver;
34use oxicuda_driver::primary_context::PrimaryContext;
35use oxicuda_driver::stream::Stream;
36
37use crate::device_buffer::DeviceBuffer;
38
39pub fn can_access_peer(device: &Device, peer: &Device) -> CudaResult<bool> {
49 let api = try_driver()?;
50 let mut can_access: c_int = 0;
51 oxicuda_driver::error::check(unsafe {
52 (api.cu_device_can_access_peer)(&mut can_access, device.raw(), peer.raw())
53 })?;
54 Ok(can_access != 0)
55}
56
57pub fn enable_peer_access(device: &Device, peer: &Device) -> CudaResult<()> {
68 let api = try_driver()?;
69
70 let dev_ctx = PrimaryContext::retain(device)?;
74 let peer_ctx = PrimaryContext::retain(peer)?;
75
76 oxicuda_driver::error::check(unsafe { (api.cu_ctx_set_current)(dev_ctx.raw()) })?;
78
79 let rc =
81 oxicuda_driver::error::check(unsafe { (api.cu_ctx_enable_peer_access)(peer_ctx.raw(), 0) });
82
83 let _ = peer_ctx.release();
85 let _ = dev_ctx.release();
86
87 rc
88}
89
90pub fn disable_peer_access(device: &Device, peer: &Device) -> CudaResult<()> {
96 let api = try_driver()?;
97
98 let dev_ctx = PrimaryContext::retain(device)?;
99 let peer_ctx = PrimaryContext::retain(peer)?;
100
101 oxicuda_driver::error::check(unsafe { (api.cu_ctx_set_current)(dev_ctx.raw()) })?;
102
103 let rc =
104 oxicuda_driver::error::check(unsafe { (api.cu_ctx_disable_peer_access)(peer_ctx.raw()) });
105
106 let _ = peer_ctx.release();
107 let _ = dev_ctx.release();
108
109 rc
110}
111
112pub fn copy_peer<T: Copy>(
122 dst: &mut DeviceBuffer<T>,
123 dst_device: &Device,
124 src: &DeviceBuffer<T>,
125 src_device: &Device,
126) -> CudaResult<()> {
127 if dst.len() != src.len() {
128 return Err(CudaError::InvalidValue);
129 }
130 let api = try_driver()?;
131 let byte_size = src.byte_size();
132
133 let dst_ctx = PrimaryContext::retain(dst_device)?;
134 let src_ctx = PrimaryContext::retain(src_device)?;
135
136 let rc = oxicuda_driver::error::check(unsafe {
137 (api.cu_memcpy_peer)(
138 dst.as_device_ptr(),
139 dst_ctx.raw(),
140 src.as_device_ptr(),
141 src_ctx.raw(),
142 byte_size,
143 )
144 });
145
146 let _ = src_ctx.release();
147 let _ = dst_ctx.release();
148
149 rc
150}
151
152pub fn copy_peer_async<T: Copy>(
161 dst: &mut DeviceBuffer<T>,
162 dst_device: &Device,
163 src: &DeviceBuffer<T>,
164 src_device: &Device,
165 stream: &Stream,
166) -> CudaResult<()> {
167 if dst.len() != src.len() {
168 return Err(CudaError::InvalidValue);
169 }
170 let api = try_driver()?;
171 let byte_size = src.byte_size();
172
173 let dst_ctx = PrimaryContext::retain(dst_device)?;
174 let src_ctx = PrimaryContext::retain(src_device)?;
175
176 let rc = oxicuda_driver::error::check(unsafe {
177 (api.cu_memcpy_peer_async)(
178 dst.as_device_ptr(),
179 dst_ctx.raw(),
180 src.as_device_ptr(),
181 src_ctx.raw(),
182 byte_size,
183 stream.raw(),
184 )
185 });
186
187 let _ = src_ctx.release();
188 let _ = dst_ctx.release();
189
190 rc
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the synchronous public functions keep their
    /// signatures; nothing is called.
    #[test]
    fn function_signatures_compile() {
        let _f1: fn(&Device, &Device) -> CudaResult<bool> = can_access_peer;
        let _f2: fn(&Device, &Device) -> CudaResult<()> = enable_peer_access;
        let _f3: fn(&Device, &Device) -> CudaResult<()> = disable_peer_access;
        let _f4: fn(
            &mut DeviceBuffer<f32>,
            &Device,
            &DeviceBuffer<f32>,
            &Device,
        ) -> CudaResult<()> = copy_peer;
    }

    /// Compile-time check that `copy_peer_async` keeps its signature.
    ///
    /// This test only coerces the function to a function pointer; it does not
    /// exercise the length-mismatch path, so it is named for what it checks.
    #[test]
    fn copy_peer_async_signature_compiles() {
        type PeerAsyncFn = fn(
            &mut DeviceBuffer<f32>,
            &Device,
            &DeviceBuffer<f32>,
            &Device,
            &Stream,
        ) -> CudaResult<()>;
        let _f: PeerAsyncFn = copy_peer_async;
    }

    /// Smoke test: the peer-access query can be issued without panicking.
    ///
    /// With a single device it queries the device against itself; with more
    /// than one it queries device 0 against device 1. The result is discarded.
    #[cfg(feature = "gpu-tests")]
    #[test]
    fn can_access_peer_single_gpu() {
        oxicuda_driver::init().ok();
        let count = Device::count().unwrap_or(0);
        if count >= 1 {
            let dev0 = Device::get(0).expect("device 0");
            if count == 1 {
                let _ = can_access_peer(&dev0, &dev0);
            } else {
                let dev1 = Device::get(1).expect("device 1");
                let _ = can_access_peer(&dev0, &dev1);
            }
        }
    }
}