Skip to main content

oxicuda_memory/
copy.rs

1//! Explicit memory copy operations between host and device.
2//!
3//! This module provides freestanding functions for copying data between
4//! host memory, device memory, and pinned host memory.  Each function
5//! validates that the source and destination have matching lengths before
6//! issuing the underlying CUDA driver call.
7//!
8//! For simple cases, the methods on [`DeviceBuffer`]
9//! (e.g. [`DeviceBuffer::copy_from_host`]) are more
10//! ergonomic.  These freestanding functions are useful when you want to be
11//! explicit about the direction of the transfer or when working with
12//! [`PinnedBuffer`] for async operations.
13//!
14//! # Length validation
15//!
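//! The whole-buffer functions return [`CudaError::InvalidValue`] if the element
//! counts of source and destination do not match.  The `*_region_async`
//! functions instead return it when the requested sub-range does not fit
//! inside the source or destination buffer.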
18
19use std::ffi::c_void;
20
21use oxicuda_driver::error::{CudaError, CudaResult};
22use oxicuda_driver::loader::try_driver;
23use oxicuda_driver::stream::Stream;
24
25use crate::device_buffer::DeviceBuffer;
26use crate::host_buffer::PinnedBuffer;
27
28// ---------------------------------------------------------------------------
29// Synchronous copies
30// ---------------------------------------------------------------------------
31
32/// Copies data from a host slice into a device buffer (host-to-device).
33///
34/// This is a synchronous operation: it blocks the calling thread until the
35/// transfer completes.
36///
37/// # Errors
38///
39/// * [`CudaError::InvalidValue`] if `src.len() != dst.len()`.
40/// * Other driver errors from `cuMemcpyHtoD_v2`.
41pub fn copy_htod<T: Copy>(dst: &mut DeviceBuffer<T>, src: &[T]) -> CudaResult<()> {
42    if src.len() != dst.len() {
43        return Err(CudaError::InvalidValue);
44    }
45    let byte_size = dst.byte_size();
46    let api = try_driver()?;
47    // SAFETY: `src` is a valid host slice, `dst` owns a valid device allocation,
48    // and the byte counts match.
49    let rc = unsafe {
50        (api.cu_memcpy_htod_v2)(
51            dst.as_device_ptr(),
52            src.as_ptr().cast::<c_void>(),
53            byte_size,
54        )
55    };
56    oxicuda_driver::check(rc)
57}
58
59/// Copies data from a device buffer into a host slice (device-to-host).
60///
61/// This is a synchronous operation: it blocks the calling thread until the
62/// transfer completes.
63///
64/// # Errors
65///
66/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
67/// * Other driver errors from `cuMemcpyDtoH_v2`.
68pub fn copy_dtoh<T: Copy>(dst: &mut [T], src: &DeviceBuffer<T>) -> CudaResult<()> {
69    if dst.len() != src.len() {
70        return Err(CudaError::InvalidValue);
71    }
72    let byte_size = src.byte_size();
73    let api = try_driver()?;
74    // SAFETY: `dst` is a valid host slice, `src` owns a valid device allocation,
75    // and the byte counts match.
76    let rc = unsafe {
77        (api.cu_memcpy_dtoh_v2)(
78            dst.as_mut_ptr().cast::<c_void>(),
79            src.as_device_ptr(),
80            byte_size,
81        )
82    };
83    oxicuda_driver::check(rc)
84}
85
86/// Copies data from one device buffer to another (device-to-device).
87///
88/// This is a synchronous operation that blocks until the copy completes.
89///
90/// # Errors
91///
92/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
93/// * Other driver errors from `cuMemcpyDtoD_v2`.
94pub fn copy_dtod<T: Copy>(dst: &mut DeviceBuffer<T>, src: &DeviceBuffer<T>) -> CudaResult<()> {
95    if dst.len() != src.len() {
96        return Err(CudaError::InvalidValue);
97    }
98    let byte_size = src.byte_size();
99    let api = try_driver()?;
100    // SAFETY: both buffers own valid device allocations of the same size.
101    let rc =
102        unsafe { (api.cu_memcpy_dtod_v2)(dst.as_device_ptr(), src.as_device_ptr(), byte_size) };
103    oxicuda_driver::check(rc)
104}
105
106// ---------------------------------------------------------------------------
107// Asynchronous copies
108// ---------------------------------------------------------------------------
109
110// ---------------------------------------------------------------------------
111// Asynchronous copies (raw slice variants)
112// ---------------------------------------------------------------------------
113
114/// Asynchronously copies data from a host slice into a device buffer.
115///
116/// The copy is enqueued on `stream` and may not be complete when this
117/// function returns.  The caller must ensure that `src` remains valid
118/// (i.e., is not moved or dropped) until the stream has been synchronised.
119/// For guaranteed correctness with DMA, prefer using a [`PinnedBuffer`]
120/// as the source.
121///
122/// # Errors
123///
124/// * [`CudaError::InvalidValue`] if `src.len() != dst.len()`.
125/// * Other driver errors from `cuMemcpyHtoDAsync_v2`.
126pub fn copy_htod_async_raw<T: Copy>(
127    dst: &mut DeviceBuffer<T>,
128    src: &[T],
129    stream: &Stream,
130) -> CudaResult<()> {
131    if src.len() != dst.len() {
132        return Err(CudaError::InvalidValue);
133    }
134    let byte_size = dst.byte_size();
135    let api = try_driver()?;
136    let rc = unsafe {
137        (api.cu_memcpy_htod_async_v2)(
138            dst.as_device_ptr(),
139            src.as_ptr().cast::<c_void>(),
140            byte_size,
141            stream.raw(),
142        )
143    };
144    oxicuda_driver::check(rc)
145}
146
147/// Asynchronously copies data from a device buffer into a host slice.
148///
149/// The copy is enqueued on `stream` and may not be complete when this
150/// function returns.  The caller must ensure that `dst` remains valid
151/// and is not read until the stream has been synchronised.
152///
153/// # Errors
154///
155/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
156/// * Other driver errors from `cuMemcpyDtoHAsync_v2`.
157pub fn copy_dtoh_async_raw<T: Copy>(
158    dst: &mut [T],
159    src: &DeviceBuffer<T>,
160    stream: &Stream,
161) -> CudaResult<()> {
162    if dst.len() != src.len() {
163        return Err(CudaError::InvalidValue);
164    }
165    let byte_size = src.byte_size();
166    let api = try_driver()?;
167    let rc = unsafe {
168        (api.cu_memcpy_dtoh_async_v2)(
169            dst.as_mut_ptr().cast::<c_void>(),
170            src.as_device_ptr(),
171            byte_size,
172            stream.raw(),
173        )
174    };
175    oxicuda_driver::check(rc)
176}
177
178/// Asynchronously copies data from one device buffer to another.
179///
180/// Both buffers must have the same length.  The copy is enqueued on
181/// `stream`.
182///
183/// Note: The CUDA Driver API does not provide `cuMemcpyDtoDAsync` directly;
184/// this uses `cuMemcpyHtoDAsync_v2` semantics via the driver's internal
185/// routing for device-to-device copies.  For true async D2D, consider
186/// using peer copy functions or ensuring both buffers are in the same
187/// context.
188///
189/// # Errors
190///
191/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
192/// * Other driver errors.
193pub fn copy_dtod_async<T: Copy>(
194    dst: &mut DeviceBuffer<T>,
195    src: &DeviceBuffer<T>,
196    stream: &Stream,
197) -> CudaResult<()> {
198    if dst.len() != src.len() {
199        return Err(CudaError::InvalidValue);
200    }
201    // Use synchronous D2D copy followed by stream ordering via event.
202    // The CUDA driver routes D2D copies internally; we use the sync version
203    // and rely on stream ordering at the caller level.
204    // A future enhancement can add cuMemcpyDtoDAsync when the driver
205    // exposes it.
206    let _ = stream;
207    copy_dtod(dst, src)
208}
209
210// ---------------------------------------------------------------------------
211// Asynchronous copies (pinned buffer variants)
212// ---------------------------------------------------------------------------
213
214/// Asynchronously copies data from a pinned host buffer into a device buffer.
215///
216/// The copy is enqueued on `stream` and may not be complete when this
217/// function returns.  The caller must not modify `src` or read `dst` until
218/// the stream has been synchronised.
219///
220/// Using a [`PinnedBuffer`] as the source guarantees that the host memory
221/// is page-locked, which is required for correct async DMA transfers.
222///
223/// # Errors
224///
225/// * [`CudaError::InvalidValue`] if `src.len() != dst.len()`.
226/// * Other driver errors from `cuMemcpyHtoDAsync_v2`.
227pub fn copy_htod_async<T: Copy>(
228    dst: &mut DeviceBuffer<T>,
229    src: &PinnedBuffer<T>,
230    stream: &Stream,
231) -> CudaResult<()> {
232    if src.len() != dst.len() {
233        return Err(CudaError::InvalidValue);
234    }
235    let byte_size = dst.byte_size();
236    let api = try_driver()?;
237    // SAFETY: `src` is pinned host memory, `dst` is a valid device allocation,
238    // byte counts match, and the stream will order the transfer.
239    let rc = unsafe {
240        (api.cu_memcpy_htod_async_v2)(
241            dst.as_device_ptr(),
242            src.as_ptr().cast::<c_void>(),
243            byte_size,
244            stream.raw(),
245        )
246    };
247    oxicuda_driver::check(rc)
248}
249
250/// Asynchronously copies data from a device buffer into a pinned host buffer.
251///
252/// The copy is enqueued on `stream` and may not be complete when this
253/// function returns.  The caller must not read `dst` until the stream
254/// has been synchronised.
255///
256/// Using a [`PinnedBuffer`] as the destination guarantees that the host
257/// memory is page-locked, which is required for correct async DMA transfers.
258///
259/// # Errors
260///
261/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
262/// * Other driver errors from `cuMemcpyDtoHAsync_v2`.
263pub fn copy_dtoh_async<T: Copy>(
264    dst: &mut PinnedBuffer<T>,
265    src: &DeviceBuffer<T>,
266    stream: &Stream,
267) -> CudaResult<()> {
268    if dst.len() != src.len() {
269        return Err(CudaError::InvalidValue);
270    }
271    let byte_size = src.byte_size();
272    let api = try_driver()?;
273    // SAFETY: `dst` is pinned host memory, `src` is a valid device allocation,
274    // byte counts match, and the stream will order the transfer.
275    let rc = unsafe {
276        (api.cu_memcpy_dtoh_async_v2)(
277            dst.as_mut_ptr().cast::<c_void>(),
278            src.as_device_ptr(),
279            byte_size,
280            stream.raw(),
281        )
282    };
283    oxicuda_driver::check(rc)
284}
285
286// ---------------------------------------------------------------------------
287// Asynchronous sub-region copies (pinned buffer staging)
288// ---------------------------------------------------------------------------
289
290/// Asynchronously copies a contiguous sub-region of a device buffer into a
291/// pinned host buffer.
292///
293/// Exactly `count` elements starting at element index `src_offset` within
294/// `src` are copied into `dst[0..count]`.  The pinned buffer must be large
295/// enough to receive `count` elements.
296///
297/// This is the device→host leg of a host-staged inter-device transfer: the
298/// caller stages a slab slice into pinned memory here, then pushes it onto a
299/// different device with [`copy_htod_region_async`].
300///
301/// The copy is enqueued on `stream`; the caller must synchronise the stream
302/// before reading `dst`.
303///
304/// # Errors
305///
306/// * [`CudaError::InvalidValue`] if `src_offset + count` exceeds `src.len()`,
307///   if `count` exceeds `dst.len()`, or on offset overflow.
308/// * Other driver errors from `cuMemcpyDtoHAsync_v2`.
309pub fn copy_dtoh_region_async<T: Copy>(
310    dst: &mut PinnedBuffer<T>,
311    src: &DeviceBuffer<T>,
312    src_offset: usize,
313    count: usize,
314    stream: &Stream,
315) -> CudaResult<()> {
316    let elem_size = std::mem::size_of::<T>();
317    let src_end = src_offset
318        .checked_add(count)
319        .ok_or(CudaError::InvalidValue)?;
320    if src_end > src.len() || count > dst.len() {
321        return Err(CudaError::InvalidValue);
322    }
323    if count == 0 {
324        return Ok(());
325    }
326    let byte_count = count
327        .checked_mul(elem_size)
328        .ok_or(CudaError::InvalidValue)?;
329    let src_byte_offset = src_offset
330        .checked_mul(elem_size)
331        .ok_or(CudaError::InvalidValue)? as u64;
332    let api = try_driver()?;
333    // SAFETY: `dst` is pinned host memory with room for `count` elements,
334    // the source sub-range lies within `src`, and byte counts match.
335    let rc = unsafe {
336        (api.cu_memcpy_dtoh_async_v2)(
337            dst.as_mut_ptr().cast::<c_void>(),
338            src.as_device_ptr() + src_byte_offset,
339            byte_count,
340            stream.raw(),
341        )
342    };
343    oxicuda_driver::check(rc)
344}
345
346/// Asynchronously copies from a pinned host buffer into a contiguous
347/// sub-region of a device buffer.
348///
349/// The first `count` elements of `src` are written into `dst` starting at
350/// element index `dst_offset`.
351///
352/// This is the host→device leg of a host-staged inter-device transfer; see
353/// [`copy_dtoh_region_async`] for the device→host leg.
354///
355/// The copy is enqueued on `stream`; the caller must synchronise the stream
356/// before reusing `src`.
357///
358/// # Errors
359///
360/// * [`CudaError::InvalidValue`] if `dst_offset + count` exceeds `dst.len()`,
361///   if `count` exceeds `src.len()`, or on offset overflow.
362/// * Other driver errors from `cuMemcpyHtoDAsync_v2`.
363pub fn copy_htod_region_async<T: Copy>(
364    dst: &mut DeviceBuffer<T>,
365    dst_offset: usize,
366    src: &PinnedBuffer<T>,
367    count: usize,
368    stream: &Stream,
369) -> CudaResult<()> {
370    let elem_size = std::mem::size_of::<T>();
371    let dst_end = dst_offset
372        .checked_add(count)
373        .ok_or(CudaError::InvalidValue)?;
374    if dst_end > dst.len() || count > src.len() {
375        return Err(CudaError::InvalidValue);
376    }
377    if count == 0 {
378        return Ok(());
379    }
380    let byte_count = count
381        .checked_mul(elem_size)
382        .ok_or(CudaError::InvalidValue)?;
383    let dst_byte_offset = dst_offset
384        .checked_mul(elem_size)
385        .ok_or(CudaError::InvalidValue)? as u64;
386    let api = try_driver()?;
387    // SAFETY: `src` is pinned host memory holding at least `count` elements,
388    // the destination sub-range lies within `dst`, and byte counts match.
389    let rc = unsafe {
390        (api.cu_memcpy_htod_async_v2)(
391            dst.as_device_ptr() + dst_byte_offset,
392            src.as_ptr().cast::<c_void>(),
393            byte_count,
394            stream.raw(),
395        )
396    };
397    oxicuda_driver::check(rc)
398}
399
400// ---------------------------------------------------------------------------
401// Tests
402// ---------------------------------------------------------------------------
403
/// Compile-time signature checks.
///
/// None of these tests touch a GPU: each one only coerces a public copy
/// function to an explicit `fn` pointer type, so an accidental change to a
/// parameter or return type fails the build.
#[cfg(test)]
mod tests {
    #[test]
    fn copy_htod_signature_compiles() {
        let _f: fn(&mut super::DeviceBuffer<f32>, &[f32]) -> super::CudaResult<()> =
            super::copy_htod;
        let _f2: fn(&mut [f32], &super::DeviceBuffer<f32>) -> super::CudaResult<()> =
            super::copy_dtoh;
    }

    #[test]
    fn copy_dtod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &super::DeviceBuffer<f32>,
        ) -> super::CudaResult<()> = super::copy_dtod;
    }

    #[test]
    fn async_raw_htod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &[f32],
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_htod_async_raw;
    }

    #[test]
    fn async_raw_dtoh_signature_compiles() {
        let _f: fn(
            &mut [f32],
            &super::DeviceBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_dtoh_async_raw;
    }

    #[test]
    fn async_dtod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &super::DeviceBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_dtod_async;
    }

    #[test]
    fn async_pinned_htod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &super::PinnedBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_htod_async;
    }

    // Counterpart of the pinned host-to-device check above, so that every
    // public copy function in this module has a signature test.
    #[test]
    fn async_pinned_dtoh_signature_compiles() {
        let _f: fn(
            &mut super::PinnedBuffer<f32>,
            &super::DeviceBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_dtoh_async;
    }

    #[test]
    fn region_dtoh_signature_compiles() {
        type RegionDtohFn = fn(
            &mut super::PinnedBuffer<f32>,
            &super::DeviceBuffer<f32>,
            usize,
            usize,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()>;
        let _f: RegionDtohFn = super::copy_dtoh_region_async;
    }

    #[test]
    fn region_htod_signature_compiles() {
        type RegionHtodFn = fn(
            &mut super::DeviceBuffer<f32>,
            usize,
            &super::PinnedBuffer<f32>,
            usize,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()>;
        let _f: RegionHtodFn = super::copy_htod_region_async;
    }
}
481}