Skip to main content

oxicuda_memory/
copy.rs

//! Explicit memory copy operations between host and device.
//!
//! This module provides freestanding functions for copying data between
//! host memory, device memory, and pinned host memory.  Each function
//! validates that the source and destination have matching lengths before
//! issuing the underlying CUDA driver call.
//!
//! For simple cases, the methods on [`DeviceBuffer`] (e.g.
//! [`DeviceBuffer::copy_from_host`]) are more ergonomic.  These freestanding
//! functions are useful when you want to be explicit about the direction of
//! the transfer, or when working with [`PinnedBuffer`] for async operations.
//!
//! # Length validation
//!
//! All functions return [`CudaError::InvalidValue`] if the element counts
//! of source and destination do not match.
18
19use std::ffi::c_void;
20
21use oxicuda_driver::error::{CudaError, CudaResult};
22use oxicuda_driver::loader::try_driver;
23use oxicuda_driver::stream::Stream;
24
25use crate::device_buffer::DeviceBuffer;
26use crate::host_buffer::PinnedBuffer;
27
28// ---------------------------------------------------------------------------
29// Synchronous copies
30// ---------------------------------------------------------------------------
31
32/// Copies data from a host slice into a device buffer (host-to-device).
33///
34/// This is a synchronous operation: it blocks the calling thread until the
35/// transfer completes.
36///
37/// # Errors
38///
39/// * [`CudaError::InvalidValue`] if `src.len() != dst.len()`.
40/// * Other driver errors from `cuMemcpyHtoD_v2`.
41pub fn copy_htod<T: Copy>(dst: &mut DeviceBuffer<T>, src: &[T]) -> CudaResult<()> {
42    if src.len() != dst.len() {
43        return Err(CudaError::InvalidValue);
44    }
45    let byte_size = dst.byte_size();
46    let api = try_driver()?;
47    // SAFETY: `src` is a valid host slice, `dst` owns a valid device allocation,
48    // and the byte counts match.
49    let rc = unsafe {
50        (api.cu_memcpy_htod_v2)(
51            dst.as_device_ptr(),
52            src.as_ptr().cast::<c_void>(),
53            byte_size,
54        )
55    };
56    oxicuda_driver::check(rc)
57}
58
59/// Copies data from a device buffer into a host slice (device-to-host).
60///
61/// This is a synchronous operation: it blocks the calling thread until the
62/// transfer completes.
63///
64/// # Errors
65///
66/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
67/// * Other driver errors from `cuMemcpyDtoH_v2`.
68pub fn copy_dtoh<T: Copy>(dst: &mut [T], src: &DeviceBuffer<T>) -> CudaResult<()> {
69    if dst.len() != src.len() {
70        return Err(CudaError::InvalidValue);
71    }
72    let byte_size = src.byte_size();
73    let api = try_driver()?;
74    // SAFETY: `dst` is a valid host slice, `src` owns a valid device allocation,
75    // and the byte counts match.
76    let rc = unsafe {
77        (api.cu_memcpy_dtoh_v2)(
78            dst.as_mut_ptr().cast::<c_void>(),
79            src.as_device_ptr(),
80            byte_size,
81        )
82    };
83    oxicuda_driver::check(rc)
84}
85
86/// Copies data from one device buffer to another (device-to-device).
87///
88/// This is a synchronous operation that blocks until the copy completes.
89///
90/// # Errors
91///
92/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
93/// * Other driver errors from `cuMemcpyDtoD_v2`.
94pub fn copy_dtod<T: Copy>(dst: &mut DeviceBuffer<T>, src: &DeviceBuffer<T>) -> CudaResult<()> {
95    if dst.len() != src.len() {
96        return Err(CudaError::InvalidValue);
97    }
98    let byte_size = src.byte_size();
99    let api = try_driver()?;
100    // SAFETY: both buffers own valid device allocations of the same size.
101    let rc =
102        unsafe { (api.cu_memcpy_dtod_v2)(dst.as_device_ptr(), src.as_device_ptr(), byte_size) };
103    oxicuda_driver::check(rc)
104}
105
// ---------------------------------------------------------------------------
// Asynchronous copies (raw slice variants)
// ---------------------------------------------------------------------------
113
114/// Asynchronously copies data from a host slice into a device buffer.
115///
116/// The copy is enqueued on `stream` and may not be complete when this
117/// function returns.  The caller must ensure that `src` remains valid
118/// (i.e., is not moved or dropped) until the stream has been synchronised.
119/// For guaranteed correctness with DMA, prefer using a [`PinnedBuffer`]
120/// as the source.
121///
122/// # Errors
123///
124/// * [`CudaError::InvalidValue`] if `src.len() != dst.len()`.
125/// * Other driver errors from `cuMemcpyHtoDAsync_v2`.
126pub fn copy_htod_async_raw<T: Copy>(
127    dst: &mut DeviceBuffer<T>,
128    src: &[T],
129    stream: &Stream,
130) -> CudaResult<()> {
131    if src.len() != dst.len() {
132        return Err(CudaError::InvalidValue);
133    }
134    let byte_size = dst.byte_size();
135    let api = try_driver()?;
136    let rc = unsafe {
137        (api.cu_memcpy_htod_async_v2)(
138            dst.as_device_ptr(),
139            src.as_ptr().cast::<c_void>(),
140            byte_size,
141            stream.raw(),
142        )
143    };
144    oxicuda_driver::check(rc)
145}
146
147/// Asynchronously copies data from a device buffer into a host slice.
148///
149/// The copy is enqueued on `stream` and may not be complete when this
150/// function returns.  The caller must ensure that `dst` remains valid
151/// and is not read until the stream has been synchronised.
152///
153/// # Errors
154///
155/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
156/// * Other driver errors from `cuMemcpyDtoHAsync_v2`.
157pub fn copy_dtoh_async_raw<T: Copy>(
158    dst: &mut [T],
159    src: &DeviceBuffer<T>,
160    stream: &Stream,
161) -> CudaResult<()> {
162    if dst.len() != src.len() {
163        return Err(CudaError::InvalidValue);
164    }
165    let byte_size = src.byte_size();
166    let api = try_driver()?;
167    let rc = unsafe {
168        (api.cu_memcpy_dtoh_async_v2)(
169            dst.as_mut_ptr().cast::<c_void>(),
170            src.as_device_ptr(),
171            byte_size,
172            stream.raw(),
173        )
174    };
175    oxicuda_driver::check(rc)
176}
177
178/// Asynchronously copies data from one device buffer to another.
179///
180/// Both buffers must have the same length.  The copy is enqueued on
181/// `stream`.
182///
183/// Note: The CUDA Driver API does not provide `cuMemcpyDtoDAsync` directly;
184/// this uses `cuMemcpyHtoDAsync_v2` semantics via the driver's internal
185/// routing for device-to-device copies.  For true async D2D, consider
186/// using peer copy functions or ensuring both buffers are in the same
187/// context.
188///
189/// # Errors
190///
191/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
192/// * Other driver errors.
193pub fn copy_dtod_async<T: Copy>(
194    dst: &mut DeviceBuffer<T>,
195    src: &DeviceBuffer<T>,
196    stream: &Stream,
197) -> CudaResult<()> {
198    if dst.len() != src.len() {
199        return Err(CudaError::InvalidValue);
200    }
201    // Use synchronous D2D copy followed by stream ordering via event.
202    // The CUDA driver routes D2D copies internally; we use the sync version
203    // and rely on stream ordering at the caller level.
204    // A future enhancement can add cuMemcpyDtoDAsync when the driver
205    // exposes it.
206    let _ = stream;
207    copy_dtod(dst, src)
208}
209
210// ---------------------------------------------------------------------------
211// Asynchronous copies (pinned buffer variants)
212// ---------------------------------------------------------------------------
213
214/// Asynchronously copies data from a pinned host buffer into a device buffer.
215///
216/// The copy is enqueued on `stream` and may not be complete when this
217/// function returns.  The caller must not modify `src` or read `dst` until
218/// the stream has been synchronised.
219///
220/// Using a [`PinnedBuffer`] as the source guarantees that the host memory
221/// is page-locked, which is required for correct async DMA transfers.
222///
223/// # Errors
224///
225/// * [`CudaError::InvalidValue`] if `src.len() != dst.len()`.
226/// * Other driver errors from `cuMemcpyHtoDAsync_v2`.
227pub fn copy_htod_async<T: Copy>(
228    dst: &mut DeviceBuffer<T>,
229    src: &PinnedBuffer<T>,
230    stream: &Stream,
231) -> CudaResult<()> {
232    if src.len() != dst.len() {
233        return Err(CudaError::InvalidValue);
234    }
235    let byte_size = dst.byte_size();
236    let api = try_driver()?;
237    // SAFETY: `src` is pinned host memory, `dst` is a valid device allocation,
238    // byte counts match, and the stream will order the transfer.
239    let rc = unsafe {
240        (api.cu_memcpy_htod_async_v2)(
241            dst.as_device_ptr(),
242            src.as_ptr().cast::<c_void>(),
243            byte_size,
244            stream.raw(),
245        )
246    };
247    oxicuda_driver::check(rc)
248}
249
250/// Asynchronously copies data from a device buffer into a pinned host buffer.
251///
252/// The copy is enqueued on `stream` and may not be complete when this
253/// function returns.  The caller must not read `dst` until the stream
254/// has been synchronised.
255///
256/// Using a [`PinnedBuffer`] as the destination guarantees that the host
257/// memory is page-locked, which is required for correct async DMA transfers.
258///
259/// # Errors
260///
261/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
262/// * Other driver errors from `cuMemcpyDtoHAsync_v2`.
263pub fn copy_dtoh_async<T: Copy>(
264    dst: &mut PinnedBuffer<T>,
265    src: &DeviceBuffer<T>,
266    stream: &Stream,
267) -> CudaResult<()> {
268    if dst.len() != src.len() {
269        return Err(CudaError::InvalidValue);
270    }
271    let byte_size = src.byte_size();
272    let api = try_driver()?;
273    // SAFETY: `dst` is pinned host memory, `src` is a valid device allocation,
274    // byte counts match, and the stream will order the transfer.
275    let rc = unsafe {
276        (api.cu_memcpy_dtoh_async_v2)(
277            dst.as_mut_ptr().cast::<c_void>(),
278            src.as_device_ptr(),
279            byte_size,
280            stream.raw(),
281        )
282    };
283    oxicuda_driver::check(rc)
284}
285
286// ---------------------------------------------------------------------------
287// Tests
288// ---------------------------------------------------------------------------
289
#[cfg(test)]
mod tests {
    // Compile-time signature checks: each test coerces a public copy function
    // to an explicit `fn` pointer type, so a signature change fails to build.
    // None of these tests call into the driver or require a GPU.

    #[test]
    fn copy_htod_signature_compiles() {
        let _f: fn(&mut super::DeviceBuffer<f32>, &[f32]) -> super::CudaResult<()> =
            super::copy_htod;
        let _f2: fn(&mut [f32], &super::DeviceBuffer<f32>) -> super::CudaResult<()> =
            super::copy_dtoh;
    }

    #[test]
    fn copy_dtod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &super::DeviceBuffer<f32>,
        ) -> super::CudaResult<()> = super::copy_dtod;
    }

    #[test]
    fn async_raw_htod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &[f32],
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_htod_async_raw;
    }

    #[test]
    fn async_raw_dtoh_signature_compiles() {
        let _f: fn(
            &mut [f32],
            &super::DeviceBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_dtoh_async_raw;
    }

    #[test]
    fn async_dtod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &super::DeviceBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_dtod_async;
    }

    #[test]
    fn async_pinned_htod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &super::PinnedBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_htod_async;
    }

    // Previously missing: the pinned device-to-host variant had no check.
    #[test]
    fn async_pinned_dtoh_signature_compiles() {
        let _f: fn(
            &mut super::PinnedBuffer<f32>,
            &super::DeviceBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_dtoh_async;
    }
}