async_cuda/ffi/memory/device.rs

use cpp::cpp;

use crate::device::DeviceId;
use crate::ffi::device::Device;
use crate::ffi::memory::host::HostBuffer;
use crate::ffi::ptr::DevicePtr;
use crate::ffi::result;
use crate::ffi::stream::Stream;

type Result<T> = std::result::Result<T, crate::error::Error>;

/// Synchronous implementation of [`crate::DeviceBuffer`].
///
/// Refer to [`crate::DeviceBuffer`] for documentation.
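///
/// # Example
///
/// A minimal round-trip sketch, assuming a CUDA-capable device is available
/// and a `Result`-returning context (the example is marked `ignore` because
/// it cannot run without a GPU):
///
/// ```ignore
/// let stream = Stream::new()?;
/// let device_buffer = DeviceBuffer::from_slice(&[1_u32, 2, 3], &stream)?;
/// let mut host_buffer = HostBuffer::<u32>::new(3);
/// // SAFETY: the stream is synchronized before the host buffer is read.
/// unsafe { device_buffer.copy_to_async(&mut host_buffer, &stream)? };
/// stream.synchronize()?;
/// assert_eq!(host_buffer.to_vec(), vec![1, 2, 3]);
/// ```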
pub struct DeviceBuffer<T: Copy> {
    pub num_elements: usize,
    internal: DevicePtr,
    device: DeviceId,
    _phantom: std::marker::PhantomData<T>,
}

/// Implements [`Send`] for [`DeviceBuffer`].
///
/// # Safety
///
/// This property is inherited from the CUDA API, which is thread-safe.
unsafe impl<T: Copy> Send for DeviceBuffer<T> {}

/// Implements [`Sync`] for [`DeviceBuffer`].
///
/// # Safety
///
/// This property is inherited from the CUDA API, which is thread-safe.
unsafe impl<T: Copy> Sync for DeviceBuffer<T> {}

impl<T: Copy> DeviceBuffer<T> {
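    /// Allocate a new device buffer with room for `num_elements` elements of
    /// type `T`, using `cudaMallocAsync` on the given stream.
    ///
    /// # Panics
    ///
    /// Panics if the current device cannot be determined or if the allocation
    /// fails.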
    pub fn new(num_elements: usize, stream: &Stream) -> Self {
        let device = Device::get_or_panic();
        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
        let size = num_elements * std::mem::size_of::<T>();
        let stream_ptr = stream.as_internal().as_ptr();
        let ret = cpp!(unsafe [
            ptr_ptr as "void**",
            size as "std::size_t",
            stream_ptr as "const void*"
        ] -> i32 as "std::int32_t" {
            return cudaMallocAsync(ptr_ptr, size, (cudaStream_t) stream_ptr);
        });
        match result!(ret, DevicePtr::from_addr(ptr)) {
            Ok(internal) => Self {
                internal,
                device,
                num_elements,
                _phantom: Default::default(),
            },
            Err(err) => {
                panic!("failed to allocate device memory: {err}");
            }
        }
    }

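    /// Allocate a new device buffer and copy the contents of `slice` into it.
    ///
    /// The stream is synchronized before this function returns, so the data
    /// is in place on the device afterwards.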
    pub fn from_slice(slice: &[T], stream: &Stream) -> Result<Self> {
        let host_buffer = HostBuffer::from_slice(slice);
        let mut this = Self::new(slice.len(), stream);
        // SAFETY: Safe because the stream is synchronized after this.
        unsafe {
            this.copy_from_async(&host_buffer, stream)?;
        }
        stream.synchronize()?;
        Ok(this)
    }

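    /// Allocate a new device buffer and copy the contents of `array` into it.
    ///
    /// Only available with the `ndarray` feature. The stream is synchronized
    /// before this function returns.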
    #[cfg(feature = "ndarray")]
    pub fn from_array<D: ndarray::Dimension>(
        array: &ndarray::ArrayView<T, D>,
        stream: &Stream,
    ) -> Result<Self> {
        let host_buffer = HostBuffer::from_array(array);
        let mut this = Self::new(array.len(), stream);
        // SAFETY: Safe because the stream is synchronized after this.
        unsafe {
            this.copy_from_async(&host_buffer, stream)?;
        }
        stream.synchronize()?;
        Ok(this)
    }

    /// Copy from host buffer.
    ///
    /// # Safety
    ///
    /// This function is marked unsafe because it does not synchronize and the operation might not
    /// have completed when it returns.
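    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` do not hold the same number of elements.
    ///
    /// # Example
    ///
    /// A minimal usage sketch, assuming a CUDA-capable device and a
    /// `Result`-returning context:
    ///
    /// ```ignore
    /// let stream = Stream::new()?;
    /// let host_buffer = HostBuffer::from_slice(&[1_u32, 2, 3]);
    /// let mut device_buffer = DeviceBuffer::<u32>::new(3, &stream);
    /// // SAFETY: the stream is synchronized before the host buffer is reused.
    /// unsafe { device_buffer.copy_from_async(&host_buffer, &stream)? };
    /// stream.synchronize()?;
    /// ```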
    pub unsafe fn copy_from_async(&mut self, other: &HostBuffer<T>, stream: &Stream) -> Result<()> {
        assert_eq!(self.num_elements, other.num_elements);
        let ptr_to = self.as_mut_internal().as_mut_ptr();
        let ptr_from = other.as_internal().as_ptr();
        let stream_ptr = stream.as_internal().as_ptr();
        let size = self.num_elements * std::mem::size_of::<T>();
        let ret = cpp!(unsafe [
            ptr_from as "void*",
            ptr_to as "void*",
            size as "std::size_t",
            stream_ptr as "const void*"
        ] -> i32 as "std::int32_t" {
            return cudaMemcpyAsync(
                ptr_to,
                ptr_from,
                size,
                cudaMemcpyHostToDevice,
                (cudaStream_t) stream_ptr
            );
        });
        result!(ret)
    }

    /// Copy to host buffer.
    ///
    /// # Safety
    ///
    /// This function is marked unsafe because it does not synchronize and the operation might not
    /// have completed when it returns.
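    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` do not hold the same number of elements.
    ///
    /// # Example
    ///
    /// A minimal usage sketch, assuming a CUDA-capable device and a
    /// `Result`-returning context:
    ///
    /// ```ignore
    /// let stream = Stream::new()?;
    /// let device_buffer = DeviceBuffer::<u32>::new(3, &stream);
    /// let mut host_buffer = HostBuffer::<u32>::new(3);
    /// // SAFETY: the stream is synchronized before the host buffer is read.
    /// unsafe { device_buffer.copy_to_async(&mut host_buffer, &stream)? };
    /// stream.synchronize()?;
    /// ```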
    pub unsafe fn copy_to_async(&self, other: &mut HostBuffer<T>, stream: &Stream) -> Result<()> {
        assert_eq!(self.num_elements, other.num_elements);
        let ptr_from = self.as_internal().as_ptr();
        let ptr_to = other.as_mut_internal().as_mut_ptr();
        let size = self.num_elements * std::mem::size_of::<T>();
        let stream_ptr = stream.as_internal().as_ptr();
        let ret = cpp!(unsafe [
            ptr_from as "void*",
            ptr_to as "void*",
            size as "std::size_t",
            stream_ptr as "const void*"
        ] -> i32 as "std::int32_t" {
            return cudaMemcpyAsync(
                ptr_to,
                ptr_from,
                size,
                cudaMemcpyDeviceToHost,
                (cudaStream_t) stream_ptr
            );
        });
        result!(ret)
    }

    /// Fill the buffer with the given byte value.
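    ///
    /// This schedules a `cudaMemsetAsync` on the given stream; the operation
    /// may not have completed yet when this function returns.
    ///
    /// # Example
    ///
    /// A minimal usage sketch, assuming a CUDA-capable device and a
    /// `Result`-returning context:
    ///
    /// ```ignore
    /// let stream = Stream::new()?;
    /// let mut device_buffer = DeviceBuffer::<u8>::new(4, &stream);
    /// device_buffer.fill_with_byte(0xab, &stream)?;
    /// stream.synchronize()?;
    /// ```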
    pub fn fill_with_byte(&mut self, value: u8, stream: &Stream) -> Result<()> {
        let ptr = self.as_internal().as_ptr();
        let value = value as std::ffi::c_int;
        let size = self.num_elements * std::mem::size_of::<T>();
        let stream_ptr = stream.as_internal().as_ptr();
        let ret = cpp!(unsafe [
            ptr as "void*",
            value as "int",
            size as "std::size_t",
            stream_ptr as "const void*"
        ] -> i32 as "std::int32_t" {
            return cudaMemsetAsync(
                ptr,
                value,
                size,
                (cudaStream_t) stream_ptr
            );
        });
        result!(ret)
    }

    /// Get a read-only reference to the internal [`DevicePtr`].
    #[inline(always)]
    pub fn as_internal(&self) -> &DevicePtr {
        &self.internal
    }

    /// Get a mutable reference to the internal [`DevicePtr`].
    #[inline(always)]
    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
        &mut self.internal
    }

    /// Release the buffer memory.
    ///
    /// # Panics
    ///
    /// This function panics if binding to the corresponding device fails.
    ///
    /// # Safety
    ///
    /// The buffer may not be used after this function is called, except for being dropped.
    pub unsafe fn free(&mut self) {
        if self.internal.is_null() {
            return;
        }

        Device::set_or_panic(self.device);

        // SAFETY: Safe because we won't use the pointer after this.
        let mut internal = unsafe { self.internal.take() };
        let ptr = internal.as_mut_ptr();
        let _ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaFree(ptr);
        });
    }
}

impl<T: Copy> Drop for DeviceBuffer<T> {
    #[inline]
    fn drop(&mut self) {
        // SAFETY: This is safe since the buffer cannot be used after this.
        unsafe {
            self.free();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let buffer = DeviceBuffer::<u32>::new(100, &Stream::null());
        assert_eq!(buffer.num_elements, 100);
    }

    #[test]
    fn test_copy() {
        let stream = Stream::new().unwrap();
        let all_ones = vec![1_u32; 100];
        let host_buffer_all_ones = HostBuffer::from_slice(all_ones.as_slice());

        let mut device_buffer = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            device_buffer
                .copy_from_async(&host_buffer_all_ones, &stream)
                .unwrap();
        }

        let mut host_buffer = HostBuffer::<u32>::new(100);
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }

        let mut another_device_buffer = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            another_device_buffer
                .copy_from_async(&host_buffer, &stream)
                .unwrap();
        }

        let mut return_host_buffer = HostBuffer::<u32>::new(100);
        unsafe {
            another_device_buffer
                .copy_to_async(&mut return_host_buffer, &stream)
                .unwrap();
        }

        stream.synchronize().unwrap();

        assert_eq!(return_host_buffer.num_elements, 100);
        let return_data = return_host_buffer.to_vec();
        assert_eq!(return_data.len(), 100);
        assert!(return_data.into_iter().all(|v| v == 1_u32));
    }

    #[test]
    fn test_fill_with_byte() {
        let stream = Stream::new().unwrap();
        let mut device_buffer = DeviceBuffer::<u8>::new(4, &stream);
        let mut host_buffer = HostBuffer::<u8>::new(4);
        device_buffer.fill_with_byte(0xab, &stream).unwrap();
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(host_buffer.to_vec(), &[0xab, 0xab, 0xab, 0xab]);
    }

    #[test]
    #[should_panic]
    fn test_it_panics_when_copying_invalid_size() {
        let stream = Stream::new().unwrap();
        let device_buffer = DeviceBuffer::<u32>::new(101, &stream);
        let mut host_buffer = HostBuffer::<u32>::new(100);
        let _ = unsafe { device_buffer.copy_to_async(&mut host_buffer, &stream) };
    }
}