async_cuda/ffi/
ptr.rs

/// Represents a device-local pointer. Pointers qualify as device-local if they refer to memory that
/// lives on the device, and not on the host.
///
/// # Safety
///
/// ## Null
///
/// Creating a null pointer is always unsafe, because any CUDA operations on null pointers can cause
/// undefined behavior.
///
/// Use the `unsafe` function [`DevicePtr::null`] to create a null pointer in cases where usage is
/// safe.
pub struct DevicePtr {
    addr: *mut std::ffi::c_void,
}

impl DevicePtr {
    /// Create from device address.
    ///
    /// # Arguments
    ///
    /// * `addr` - Address of pointer.
    ///
    /// # Panics
    ///
    /// Panics with "unexpected null pointer" if `addr` is null. Use the `unsafe`
    /// [`DevicePtr::null`] to deliberately construct a null pointer instead.
    #[inline]
    pub fn from_addr(addr: *mut std::ffi::c_void) -> Self {
        assert!(!addr.is_null(), "unexpected null pointer");
        DevicePtr { addr }
    }

    /// Create null pointer.
    ///
    /// # Safety
    ///
    /// This is unsafe because operating on a `null` pointer in CUDA code can cause crashes. In some
    /// cases it is allowed though, for example, a `null` pointer can designate the default stream
    /// in stream-related operations.
    #[inline]
    pub unsafe fn null() -> Self {
        DevicePtr {
            addr: std::ptr::null_mut(),
        }
    }

    /// Whether or not the device pointer is a null pointer.
    #[inline]
    pub fn is_null(&self) -> bool {
        self.addr.is_null()
    }

    /// Get the readonly pointer value.
    #[inline(always)]
    pub fn as_ptr(&self) -> *const std::ffi::c_void {
        self.addr as *const std::ffi::c_void
    }

    /// Get the mutable pointer value.
    #[inline(always)]
    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
        self.addr
    }

    /// Take the pointer from this wrapper and replace it with a null pointer.
    ///
    /// Returns a new [`DevicePtr`] holding the address that `self` held before the call. After the
    /// call, `self` holds a null pointer.
    ///
    /// # Safety
    ///
    /// This operation is unsafe because it leaves `self` as a null pointer. The caller must make
    /// sure `self` is not passed to any CUDA operation afterwards. Typically this is done by only
    /// calling `take` when `self` is about to be dropped.
    ///
    /// # Usage
    ///
    /// This function can be used inside [`Drop`] if it is known that the pointer object will not be
    /// used for the remainder of the function scope, and the object is to be dropped.
    ///
    /// # Example
    ///
    /// ```ignore
    /// # use async_cuda::ffi::DevicePtr;
    /// pub struct Object {
    ///     internal: DevicePtr,
    /// }
    ///
    /// impl Drop for Object {
    ///     fn drop(&mut self) {
    ///         // SAFETY: This is safe because `self` and `self.internal`
    ///         // are not used beyond this unsafe block.
    ///         let ptr = unsafe { self.internal.take() };
    ///         // Properly deallocate the pointer here and do *NOT* use
    ///         // `self` for anything!
    ///     }
    /// }
    /// ```
    #[inline]
    pub unsafe fn take(&mut self) -> DevicePtr {
        DevicePtr {
            // Swap NULL into `self.addr` and move the previous address into the new wrapper.
            addr: std::mem::replace(&mut self.addr, std::ptr::null_mut()),
        }
    }
}

103impl std::fmt::Display for DevicePtr {
104    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
105        write!(f, "{:?}", self.addr)
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// A non-null address is stored and returned unchanged.
    #[test]
    fn test_it_holds_on() {
        let fake = 0xffffffff as *mut std::ffi::c_void;
        let ptr = DevicePtr::from_addr(fake);
        assert!(!ptr.is_null());
        assert_eq!(ptr.as_ptr(), 0xffffffff as *const std::ffi::c_void);
    }

    /// `from_addr` rejects null with its documented panic message.
    #[test]
    #[should_panic(expected = "unexpected null pointer")]
    fn test_it_panics_when_null() {
        let _ = DevicePtr::from_addr(std::ptr::null_mut());
    }

    #[test]
    fn test_null() {
        let ptr = unsafe { DevicePtr::null() };
        assert!(ptr.is_null());
        assert_eq!(ptr.as_ptr(), std::ptr::null());
    }

    #[test]
    fn test_take() {
        let fake = 0xffffffff as *mut std::ffi::c_void;
        let mut ptr = DevicePtr::from_addr(fake);
        let taken = unsafe { ptr.take() };
        assert_eq!(taken.as_ptr(), 0xffffffff as *const std::ffi::c_void);
        // `take` must leave the original wrapper holding a null pointer.
        assert!(ptr.is_null());
    }

    /// `Display` prints the same text as the raw pointer's `Debug` output.
    #[test]
    fn test_display() {
        let fake = 0xffffffff as *mut std::ffi::c_void;
        let ptr = DevicePtr::from_addr(fake);
        assert_eq!(ptr.to_string(), format!("{:?}", fake));
    }
}
142}