// async_cuda_core/ffi/memory/device.rs

1use cpp::cpp;
2
3use crate::ffi::memory::host::HostBuffer;
4use crate::ffi::ptr::DevicePtr;
5use crate::ffi::result;
6use crate::ffi::stream::Stream;
7
/// Shorthand result type for this file; every error is a [`crate::error::Error`].
type Result<T> = std::result::Result<T, crate::error::Error>;
9
10/// Synchronous implementation of [`crate::DeviceBuffer`].
11///
12/// Refer to [`crate::DeviceBuffer`] for documentation.
13pub struct DeviceBuffer<T: Copy> {
14    pub num_elements: usize,
15    internal: DevicePtr,
16    _phantom: std::marker::PhantomData<T>,
17}
18
19/// Implements [`Send`] for [`DeviceBuffer`].
20///
21/// # Safety
22///
23/// This property is inherited from the CUDA API, which is thread-safe.
24unsafe impl<T: Copy> Send for DeviceBuffer<T> {}
25
26/// Implements [`Sync`] for [`DeviceBuffer`].
27///
28/// # Safety
29///
30/// This property is inherited from the CUDA API, which is thread-safe.
31unsafe impl<T: Copy> Sync for DeviceBuffer<T> {}
32
33impl<T: Copy> DeviceBuffer<T> {
34    pub fn new(num_elements: usize, stream: &Stream) -> Self {
35        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
36        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
37        let size = num_elements * std::mem::size_of::<T>();
38        let stream_ptr = stream.as_internal().as_ptr();
39        let ret = cpp!(unsafe [
40            ptr_ptr as "void**",
41            size as "std::size_t",
42            stream_ptr as "const void*"
43        ] -> i32 as "std::int32_t" {
44            return cudaMallocAsync(ptr_ptr, size, (cudaStream_t) stream_ptr);
45        });
46        match result!(ret, ptr.into()) {
47            Ok(internal) => Self {
48                internal,
49                num_elements,
50                _phantom: Default::default(),
51            },
52            Err(err) => {
53                panic!("failed to allocate device memory: {err}");
54            }
55        }
56    }
57
58    pub fn from_slice(slice: &[T], stream: &Stream) -> Result<Self> {
59        let host_buffer = HostBuffer::from_slice(slice);
60        let mut this = Self::new(slice.len(), stream);
61        // SAFETY: Safe because the stream is synchronized after this.
62        unsafe {
63            this.copy_from_async(&host_buffer, stream)?;
64        }
65        stream.synchronize()?;
66        Ok(this)
67    }
68
69    #[cfg(feature = "ndarray")]
70    pub fn from_array<D: ndarray::Dimension>(
71        array: &ndarray::ArrayView<T, D>,
72        stream: &Stream,
73    ) -> Result<Self> {
74        let host_buffer = HostBuffer::from_array(array);
75        let mut this = Self::new(array.len(), stream);
76        // SAFETY: Safe because the stream is synchronized after this.
77        unsafe {
78            this.copy_from_async(&host_buffer, stream)?;
79        }
80        stream.synchronize()?;
81        Ok(this)
82    }
83
84    /// Copy from host buffer.
85    ///
86    /// # Safety
87    ///
88    /// This function is marked unsafe because it does not synchronize and the operation might not
89    /// have completed when it returns.
90    pub unsafe fn copy_from_async(&mut self, other: &HostBuffer<T>, stream: &Stream) -> Result<()> {
91        assert_eq!(self.num_elements, other.num_elements);
92        let ptr_to = self.as_mut_internal().as_mut_ptr();
93        let ptr_from = other.as_internal().as_ptr();
94        let stream_ptr = stream.as_internal().as_ptr();
95        let size = self.num_elements * std::mem::size_of::<T>();
96        let ret = cpp!(unsafe [
97            ptr_from as "void*",
98            ptr_to as "void*",
99            size as "std::size_t",
100            stream_ptr as "const void*"
101        ] -> i32 as "std::int32_t" {
102            return cudaMemcpyAsync(
103                ptr_to,
104                ptr_from,
105                size,
106                cudaMemcpyHostToDevice,
107                (cudaStream_t) stream_ptr
108            );
109        });
110        result!(ret)
111    }
112
113    /// Copy to host buffer.
114    ///
115    /// # Safety
116    ///
117    /// This function is marked unsafe because it does not synchronize and the operation might not
118    /// have completed when it returns.
119    pub unsafe fn copy_to_async(&self, other: &mut HostBuffer<T>, stream: &Stream) -> Result<()> {
120        assert_eq!(self.num_elements, other.num_elements);
121        let ptr_from = self.as_internal().as_ptr();
122        let ptr_to = other.as_mut_internal().as_mut_ptr();
123        let size = self.num_elements * std::mem::size_of::<T>();
124        let stream_ptr = stream.as_internal().as_ptr();
125        let ret = cpp!(unsafe [
126            ptr_from as "void*",
127            ptr_to as "void*",
128            size as "std::size_t",
129            stream_ptr as "const void*"
130        ] -> i32 as "std::int32_t" {
131            return cudaMemcpyAsync(
132                ptr_to,
133                ptr_from,
134                size,
135                cudaMemcpyDeviceToHost,
136                (cudaStream_t) stream_ptr
137            );
138        });
139        result!(ret)
140    }
141
142    /// Fill buffer with byte value.
143    pub fn fill_with_byte(&mut self, value: u8, stream: &Stream) -> Result<()> {
144        let ptr = self.as_internal().as_ptr();
145        let value = value as std::ffi::c_int;
146        let size = self.num_elements * std::mem::size_of::<T>();
147        let stream_ptr = stream.as_internal().as_ptr();
148        let ret = cpp!(unsafe [
149            ptr as "void*",
150            value as "int",
151            size as "std::size_t",
152            stream_ptr as "const void*"
153        ] -> i32 as "std::int32_t" {
154            return cudaMemsetAsync(
155                ptr,
156                value,
157                size,
158                (cudaStream_t) stream_ptr
159            );
160        });
161        result!(ret)
162    }
163
164    /// Get readonly reference to internal [`DevicePtr`].
165    #[inline(always)]
166    pub fn as_internal(&self) -> &DevicePtr {
167        &self.internal
168    }
169
170    /// Get readonly reference to internal [`DevicePtr`].
171    #[inline(always)]
172    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
173        &mut self.internal
174    }
175}
176
177impl<T: Copy> Drop for DeviceBuffer<T> {
178    fn drop(&mut self) {
179        if self.internal.is_null() {
180            return;
181        }
182
183        // SAFETY: Safe because we won't use pointer after this.
184        let mut internal = unsafe { self.internal.take() };
185        let ptr = internal.as_mut_ptr();
186        let _ret = cpp!(unsafe [
187            ptr as "void*"
188        ] -> i32 as "std::int32_t" {
189            return cudaFree(ptr);
190        });
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly allocated buffer reports the requested element count.
    #[test]
    fn test_new() {
        let device = DeviceBuffer::<u32>::new(100, &Stream::null());
        assert_eq!(device.num_elements, 100);
    }

    /// Data survives host -> device -> host -> device -> host round trips on one stream.
    #[test]
    fn test_copy() {
        let stream = Stream::new().unwrap();
        let ones = vec![1_u32; 100];
        let source = HostBuffer::from_slice(ones.as_slice());

        let mut first_device = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            first_device.copy_from_async(&source, &stream).unwrap();
        }

        let mut intermediate = HostBuffer::<u32>::new(100);
        unsafe {
            first_device.copy_to_async(&mut intermediate, &stream).unwrap();
        }

        let mut second_device = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            second_device.copy_from_async(&intermediate, &stream).unwrap();
        }

        let mut destination = HostBuffer::<u32>::new(100);
        unsafe {
            second_device.copy_to_async(&mut destination, &stream).unwrap();
        }

        stream.synchronize().unwrap();

        assert_eq!(destination.num_elements, 100);
        let values = destination.to_vec();
        assert_eq!(values.len(), 100);
        assert!(values.iter().all(|&value| value == 1_u32));
    }

    /// Every byte is set to the fill value.
    #[test]
    fn test_fill_with_byte() {
        let stream = Stream::new().unwrap();
        let mut device = DeviceBuffer::<u8>::new(4, &stream);
        let mut host = HostBuffer::<u8>::new(4);
        device.fill_with_byte(0xab, &stream).unwrap();
        unsafe {
            device.copy_to_async(&mut host, &stream).unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(host.to_vec(), vec![0xab_u8; 4]);
    }

    /// Copying between buffers of different lengths trips the size assertion.
    #[test]
    #[should_panic]
    fn test_it_panics_when_copying_invalid_size() {
        let stream = Stream::new().unwrap();
        let device = DeviceBuffer::<u32>::new(101, &stream);
        let mut host = HostBuffer::<u32>::new(100);
        let _ = unsafe { device.copy_to_async(&mut host, &stream) };
    }
}