// oxicuda_runtime/memory.rs

//! Device and host memory management.
//!
//! Implements the CUDA Runtime memory API:
//! - `cudaMalloc` / `cudaFree`
//! - `cudaMallocHost` / `cudaFreeHost` (pinned host memory)
//! - `cudaMallocManaged` (unified memory)
//! - `cudaMallocPitch` (pitched 2-D allocation)
//! - `cudaMemcpy` / `cudaMemcpyAsync`
//! - `cudaMemset` (byte-wise, plus a 32-bit-word variant)
//! - `cudaMemGetInfo`
//!
//! All memory addresses returned for device allocations are represented as
//! [`DevicePtr`], a newtype around `u64` that matches the driver API's
//! `CUdeviceptr`.

16use std::ffi::c_void;
17
18use oxicuda_driver::loader::try_driver;
19
20use crate::error::{CudaRtError, CudaRtResult};
21use crate::stream::CudaStream;
22
// ─── DevicePtr ───────────────────────────────────────────────────────────────

/// Opaque CUDA device-memory address (mirrors `CUdeviceptr`).
///
/// This is a plain `u64` wrapped in a newtype to prevent accidental
/// dereferencing from host code.  The zero address is [`DevicePtr::NULL`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DevicePtr(pub u64);

impl DevicePtr {
    /// The null (zero) device pointer.
    pub const NULL: Self = Self(0);

    /// Returns `true` if this is the null pointer.
    #[must_use]
    pub fn is_null(self) -> bool {
        self.0 == 0
    }

    /// Offset this pointer by `offset` bytes, returning a new `DevicePtr`.
    ///
    /// The addition wraps modulo 2⁶⁴.  It never panics, including in debug
    /// builds for addresses at or above 2⁶³.  No bounds checking against the
    /// owning allocation is performed.
    #[must_use]
    pub fn offset(self, offset: isize) -> Self {
        // `isize` -> `i64` is lossless on every target with pointers <= 64 bits.
        Self(self.0.wrapping_add_signed(offset as i64))
    }
}

// ─── MemcpyKind ──────────────────────────────────────────────────────────────

/// Transfer direction for a memory copy.
///
/// The discriminants are the numeric values of CUDA's `cudaMemcpyKind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemcpyKind {
    /// Source and destination are both host memory.
    HostToHost = 0,
    /// Host source, device destination.
    HostToDevice = 1,
    /// Device source, host destination.
    DeviceToHost = 2,
    /// Source and destination are both device memory.
    DeviceToDevice = 3,
    /// Mirrors `cudaMemcpyDefault`: direction inferred from pointer
    /// attributes under unified addressing.  NOTE: this module's copy
    /// functions do not infer it yet and follow the host-to-device path.
    Default = 4,
}

// ─── MemAttachFlags ──────────────────────────────────────────────────────────

/// Attachment flags for managed (unified) memory.
///
/// Mirrors `cudaMemAttachFlags`.  The discriminants equal the driver's
/// `CU_MEM_ATTACH_GLOBAL` / `CU_MEM_ATTACH_HOST` / `CU_MEM_ATTACH_SINGLE`
/// values and are passed to the driver as plain integers.
///
/// NOTE(review): the CUDA documentation for `cuMemAllocManaged` /
/// `cudaMallocManaged` only accepts `Global` and `Host`.  `Single` is meant
/// for per-stream attachment (`cudaStreamAttachMemAsync`), and the driver is
/// expected to reject it for allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemAttachFlags {
    /// Memory is accessible from any stream on any device, and from the host.
    Global = 1,
    /// Memory is not accessible from any stream on any device until it is
    /// explicitly attached (e.g. with `cudaStreamAttachMemAsync`).
    Host = 2,
    /// Memory is accessible only from a single stream on its associated
    /// device.
    Single = 4,
}

83// ─── Allocation ──────────────────────────────────────────────────────────────
84
85/// Allocate `size` bytes of device memory.
86///
87/// Mirrors `cudaMalloc`.
88///
89/// # Errors
90///
91/// - [`CudaRtError::DriverNotAvailable`] — driver not loaded.
92/// - [`CudaRtError::MemoryAllocation`] — out of device memory.
93pub fn malloc(size: usize) -> CudaRtResult<DevicePtr> {
94    if size == 0 {
95        return Ok(DevicePtr::NULL);
96    }
97    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
98    let mut ptr: u64 = 0;
99    // SAFETY: FFI; ptr is a valid stack-allocated u64.
100    let rc = unsafe { (api.cu_mem_alloc_v2)(&raw mut ptr, size) };
101    if rc != 0 {
102        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::MemoryAllocation));
103    }
104    Ok(DevicePtr(ptr))
105}
106
107/// Free device memory previously allocated with [`malloc`].
108///
109/// Mirrors `cudaFree`.
110///
111/// # Errors
112///
113/// Propagates driver errors.  Passing [`DevicePtr::NULL`] is a no-op.
114pub fn free(ptr: DevicePtr) -> CudaRtResult<()> {
115    if ptr.is_null() {
116        return Ok(());
117    }
118    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
119    // SAFETY: FFI; ptr was returned by cu_mem_alloc_v2.
120    let rc = unsafe { (api.cu_mem_free_v2)(ptr.0) };
121    if rc != 0 {
122        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevicePointer));
123    }
124    Ok(())
125}
126
127/// Allocate `size` bytes of pinned (page-locked) host memory.
128///
129/// Mirrors `cudaMallocHost`.
130///
131/// Returns a raw host pointer that must be freed with [`free_host`].
132///
133/// # Errors
134///
135/// - [`CudaRtError::MemoryAllocation`] — out of host memory.
136pub fn malloc_host(size: usize) -> CudaRtResult<*mut c_void> {
137    if size == 0 {
138        return Ok(std::ptr::null_mut());
139    }
140    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
141    let mut ptr: *mut c_void = std::ptr::null_mut();
142    // SAFETY: FFI; ptr is a valid stack-allocated pointer.
143    let rc = unsafe { (api.cu_mem_alloc_host_v2)(&raw mut ptr, size) };
144    if rc != 0 {
145        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::MemoryAllocation));
146    }
147    Ok(ptr)
148}
149
150/// Free page-locked host memory previously allocated with [`malloc_host`].
151///
152/// Mirrors `cudaFreeHost`.
153///
154/// # Errors
155///
156/// Propagates driver errors.
157///
158/// # Safety
159///
160/// `ptr` must have been returned by [`malloc_host`] and must not have been
161/// freed already.
162pub unsafe fn free_host(ptr: *mut c_void) -> CudaRtResult<()> {
163    if ptr.is_null() {
164        return Ok(());
165    }
166    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
167    // SAFETY: FFI; ptr was returned by cu_mem_alloc_host_v2.
168    let rc = unsafe { (api.cu_mem_free_host)(ptr) };
169    if rc != 0 {
170        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidHostPointer));
171    }
172    Ok(())
173}
174
175/// Allocate unified managed memory accessible from both CPU and GPU.
176///
177/// Mirrors `cudaMallocManaged`.
178///
179/// # Errors
180///
181/// - [`CudaRtError::NotSupported`] — device does not support managed memory.
182/// - [`CudaRtError::MemoryAllocation`] — out of memory.
183pub fn malloc_managed(size: usize, flags: MemAttachFlags) -> CudaRtResult<DevicePtr> {
184    if size == 0 {
185        return Ok(DevicePtr::NULL);
186    }
187    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
188    let mut ptr: u64 = 0;
189    // SAFETY: FFI; ptr is valid and flags maps to CU_MEM_ATTACH_* values.
190    let rc = unsafe { (api.cu_mem_alloc_managed)(&raw mut ptr, size, flags as u32) };
191    if rc != 0 {
192        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::MemoryAllocation));
193    }
194    Ok(DevicePtr(ptr))
195}
196
197/// Allocate pitched device memory for 2-D arrays.
198///
199/// Mirrors `cudaMallocPitch`.
200///
201/// Returns `(device_ptr, pitch_bytes)`.  `pitch_bytes` is ≥ `width_bytes`
202/// and aligned to the hardware's texture alignment.
203///
204/// # Errors
205///
206/// Propagates driver errors.
207pub fn malloc_pitch(width_bytes: usize, height: usize) -> CudaRtResult<(DevicePtr, usize)> {
208    if width_bytes == 0 || height == 0 {
209        return Ok((DevicePtr::NULL, 0));
210    }
211    // Compute the pitch: round width_bytes up to 512-byte alignment, which
212    // matches the driver's cuMemAllocPitch behaviour for most hardware.
213    let align: usize = 512;
214    let pitch = width_bytes.div_ceil(align) * align;
215    let size = pitch * height;
216    let ptr = malloc(size)?;
217    Ok((ptr, pitch))
218}
219
220// ─── Memcpy ──────────────────────────────────────────────────────────────────
221
222/// Synchronously copy `count` bytes between memory regions.
223///
224/// Mirrors `cudaMemcpy`.
225///
226/// # Safety
227///
228/// `src` and `dst` must point to valid memory of the appropriate kind
229/// (host or device) and must not overlap.
230///
231/// # Errors
232///
233/// - [`CudaRtError::InvalidMemcpyDirection`] for unsupported `kind`.
234/// - Driver errors for invalid pointers or counts.
235pub unsafe fn memcpy(
236    dst: *mut c_void,
237    src: *const c_void,
238    count: usize,
239    kind: MemcpyKind,
240) -> CudaRtResult<()> {
241    if count == 0 {
242        return Ok(());
243    }
244    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
245    let rc = match kind {
246        MemcpyKind::HostToHost => {
247            // Pure host copy — no driver involvement.
248            // SAFETY: Caller ensures src/dst are valid and non-overlapping.
249            unsafe { std::ptr::copy_nonoverlapping(src as *const u8, dst as *mut u8, count) };
250            0u32
251        }
252        MemcpyKind::HostToDevice => {
253            let dst_ptr = dst as u64;
254            // SAFETY: FFI; src/dst valid per caller contract.
255            unsafe { (api.cu_memcpy_htod_v2)(dst_ptr, src, count) }
256        }
257        MemcpyKind::DeviceToHost => {
258            let src_ptr = src as u64;
259            // SAFETY: FFI; src/dst valid per caller contract.
260            unsafe { (api.cu_memcpy_dtoh_v2)(dst, src_ptr, count) }
261        }
262        MemcpyKind::DeviceToDevice => {
263            let dst_ptr = dst as u64;
264            let src_ptr = src as u64;
265            // SAFETY: FFI; src/dst valid per caller contract.
266            unsafe { (api.cu_memcpy_dtod_v2)(dst_ptr, src_ptr, count) }
267        }
268        MemcpyKind::Default => {
269            // Fall back to H2D (common case; real implementation would use
270            // cuPointerGetAttribute to determine actual memory type).
271            let dst_ptr = dst as u64;
272            // SAFETY: FFI.
273            unsafe { (api.cu_memcpy_htod_v2)(dst_ptr, src, count) }
274        }
275    };
276    if rc != 0 {
277        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidMemcpyDirection));
278    }
279    Ok(())
280}
281
282/// Asynchronously copy `count` bytes on `stream`.
283///
284/// Mirrors `cudaMemcpyAsync`.
285///
286/// # Safety
287///
288/// Same requirements as [`memcpy`] plus `stream` must be valid.
289///
290/// # Errors
291///
292/// Propagates driver errors.
293pub unsafe fn memcpy_async(
294    dst: *mut c_void,
295    src: *const c_void,
296    count: usize,
297    kind: MemcpyKind,
298    stream: &CudaStream,
299) -> CudaRtResult<()> {
300    if count == 0 {
301        return Ok(());
302    }
303    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
304    let rc = match kind {
305        MemcpyKind::HostToHost => {
306            // SAFETY: host-to-host can be dispatched synchronously.
307            unsafe { std::ptr::copy_nonoverlapping(src as *const u8, dst as *mut u8, count) };
308            0u32
309        }
310        MemcpyKind::HostToDevice | MemcpyKind::Default => {
311            let dst_ptr = dst as u64;
312            // SAFETY: FFI; caller guarantees validity.
313            unsafe { (api.cu_memcpy_htod_async_v2)(dst_ptr, src, count, stream.raw()) }
314        }
315        MemcpyKind::DeviceToHost => {
316            let src_ptr = src as u64;
317            // SAFETY: FFI.
318            unsafe { (api.cu_memcpy_dtoh_async_v2)(dst, src_ptr, count, stream.raw()) }
319        }
320        MemcpyKind::DeviceToDevice => {
321            // Fall back to synchronous D2D (driver lacks async D2D helper in v1).
322            let dst_ptr = dst as u64;
323            let src_ptr = src as u64;
324            // SAFETY: FFI.
325            unsafe { (api.cu_memcpy_dtod_v2)(dst_ptr, src_ptr, count) }
326        }
327    };
328    if rc != 0 {
329        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidMemcpyDirection));
330    }
331    Ok(())
332}
333
334// ─── Typed helpers ────────────────────────────────────────────────────────────
335
336/// Copy a slice of host data to a device allocation.
337///
338/// # Errors
339///
340/// Propagates driver errors.
341pub fn memcpy_h2d<T: Copy>(dst: DevicePtr, src: &[T]) -> CudaRtResult<()> {
342    let bytes = std::mem::size_of_val(src);
343    // SAFETY: src is a valid slice; dst is a device allocation.
344    unsafe {
345        memcpy(
346            dst.0 as *mut c_void,
347            src.as_ptr() as *const c_void,
348            bytes,
349            MemcpyKind::HostToDevice,
350        )
351    }
352}
353
354/// Copy device memory to a host slice.
355///
356/// # Errors
357///
358/// Propagates driver errors.
359pub fn memcpy_d2h<T: Copy>(dst: &mut [T], src: DevicePtr) -> CudaRtResult<()> {
360    let bytes = std::mem::size_of_val(dst);
361    // SAFETY: dst is a valid mutable slice; src is a device allocation.
362    unsafe {
363        memcpy(
364            dst.as_mut_ptr() as *mut c_void,
365            src.0 as *const c_void,
366            bytes,
367            MemcpyKind::DeviceToHost,
368        )
369    }
370}
371
372/// Copy between two device allocations.
373///
374/// # Errors
375///
376/// Propagates driver errors.
377pub fn memcpy_d2d(dst: DevicePtr, src: DevicePtr, bytes: usize) -> CudaRtResult<()> {
378    // SAFETY: both ptrs are device allocations.
379    unsafe {
380        memcpy(
381            dst.0 as *mut c_void,
382            src.0 as *const c_void,
383            bytes,
384            MemcpyKind::DeviceToDevice,
385        )
386    }
387}
388
389// ─── Memset ──────────────────────────────────────────────────────────────────
390
391/// Set `count` bytes of device memory starting at `ptr` to `value`.
392///
393/// Mirrors `cudaMemset`.
394///
395/// # Errors
396///
397/// Propagates driver errors.
398pub fn memset(ptr: DevicePtr, value: u8, count: usize) -> CudaRtResult<()> {
399    if count == 0 || ptr.is_null() {
400        return Ok(());
401    }
402    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
403    // SAFETY: FFI; ptr is a valid device allocation.
404    let rc = unsafe { (api.cu_memset_d8_v2)(ptr.0, value, count) };
405    if rc != 0 {
406        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevicePointer));
407    }
408    Ok(())
409}
410
411/// Set device memory to 32-bit value pattern.
412///
413/// `count` is the number of 32-bit words (not bytes) to set.
414/// Mirrors `cudaMemset` for 4-byte granularity.
415///
416/// # Errors
417///
418/// Propagates driver errors.
419pub fn memset32(ptr: DevicePtr, value: u32, count: usize) -> CudaRtResult<()> {
420    if count == 0 || ptr.is_null() {
421        return Ok(());
422    }
423    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
424    // SAFETY: FFI; ptr is a valid device allocation.
425    let rc = unsafe { (api.cu_memset_d32_v2)(ptr.0, value, count) };
426    if rc != 0 {
427        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevicePointer));
428    }
429    Ok(())
430}
431
432// ─── MemGetInfo ──────────────────────────────────────────────────────────────
433
434/// Returns `(free_bytes, total_bytes)` for the current device's global memory.
435///
436/// Mirrors `cudaMemGetInfo`.
437///
438/// # Errors
439///
440/// Propagates driver errors.
441pub fn mem_get_info() -> CudaRtResult<(usize, usize)> {
442    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
443    let mut free: usize = 0;
444    let mut total: usize = 0;
445    // SAFETY: FFI; both pointers are valid stack-allocated usizes.
446    let rc = unsafe { (api.cu_mem_get_info_v2)(&raw mut free, &raw mut total) };
447    if rc != 0 {
448        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::Unknown));
449    }
450    Ok((free, total))
451}
452
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn malloc_zero_returns_null() {
        // A zero-byte allocation returns NULL before the driver is consulted,
        // so this holds even without a GPU.
        assert!(matches!(malloc(0), Ok(DevicePtr(0))));
    }

    #[test]
    fn free_null_is_noop() {
        // Freeing NULL short-circuits before the driver is loaded, so it must
        // succeed even on machines without CUDA.
        assert!(free(DevicePtr::NULL).is_ok());
    }

    #[test]
    fn device_ptr_offset() {
        let p = DevicePtr(1000);
        assert_eq!(p.offset(8), DevicePtr(1008));
        assert_eq!(p.offset(-8), DevicePtr(992));
        assert_eq!(DevicePtr(0).offset(-1), DevicePtr(u64::MAX));
    }

    #[test]
    fn device_ptr_is_null() {
        assert!(DevicePtr::NULL.is_null());
        assert!(!DevicePtr(1).is_null());
    }

    #[test]
    fn malloc_pitch_zero_dimension_returns_null() {
        // Zero-sized requests never reach the driver.
        assert!(matches!(malloc_pitch(0, 32), Ok((DevicePtr(0), 0))));
        assert!(matches!(malloc_pitch(100, 0), Ok((DevicePtr(0), 0))));
    }

    #[test]
    fn malloc_pitch_returns_aligned_pitch() {
        // Without a GPU `malloc` fails and there is nothing to check.  With
        // one, verify the pitch and release the allocation again.
        if let Ok((ptr, pitch)) = malloc_pitch(100, 32) {
            assert_eq!(pitch % 512, 0);
            assert!(pitch >= 100);
            let _ = free(ptr);
        }
    }

    #[test]
    fn zero_length_operations_are_noops() {
        assert!(memset(DevicePtr::NULL, 0, 0).is_ok());
        assert!(memset32(DevicePtr(0x1000), 0, 0).is_ok());
        // SAFETY: `count == 0`, so neither pointer is dereferenced.
        let copied = unsafe {
            memcpy(
                std::ptr::null_mut(),
                std::ptr::null(),
                0,
                MemcpyKind::DeviceToDevice,
            )
        };
        assert!(copied.is_ok());
    }

    #[test]
    fn memcpy_kind_values() {
        assert_eq!(MemcpyKind::HostToHost as u32, 0);
        assert_eq!(MemcpyKind::HostToDevice as u32, 1);
        assert_eq!(MemcpyKind::DeviceToHost as u32, 2);
        assert_eq!(MemcpyKind::DeviceToDevice as u32, 3);
        assert_eq!(MemcpyKind::Default as u32, 4);
    }

    #[test]
    fn mem_attach_flag_values() {
        assert_eq!(MemAttachFlags::Global as u32, 1);
        assert_eq!(MemAttachFlags::Host as u32, 2);
        assert_eq!(MemAttachFlags::Single as u32, 4);
    }
}