async_cuda_core/ffi/memory/
device2d.rs

1use cpp::cpp;
2
3use crate::ffi::memory::host::HostBuffer;
4use crate::ffi::ptr::DevicePtr;
5use crate::ffi::result;
6use crate::ffi::stream::Stream;
7
8type Result<T> = std::result::Result<T, crate::error::Error>;
9
10/// Synchronous implementation of [`crate::DeviceBuffer2D`].
11///
12/// Refer to [`crate::DeviceBuffer2D`] for documentation.
13pub struct DeviceBuffer2D<T: Copy> {
14    pub width: usize,
15    pub height: usize,
16    pub num_channels: usize,
17    pub pitch: usize,
18    internal: DevicePtr,
19    _phantom: std::marker::PhantomData<T>,
20}
21
22/// Implements [`Send`] for [`DeviceBuffer2D`].
23///
24/// # Safety
25///
26/// This property is inherited from the CUDA API, which is thread-safe.
27unsafe impl<T: Copy> Send for DeviceBuffer2D<T> {}
28
29/// Implements [`Sync`] for [`DeviceBuffer2D`].
30///
31/// # Safety
32///
33/// This property is inherited from the CUDA API, which is thread-safe.
34unsafe impl<T: Copy> Sync for DeviceBuffer2D<T> {}
35
36impl<T: Copy> DeviceBuffer2D<T> {
37    pub fn new(width: usize, height: usize, num_channels: usize) -> Self {
38        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
39        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
40        let mut pitch = 0_usize;
41        let pitch_ptr = std::ptr::addr_of_mut!(pitch);
42        let line_size = width * num_channels * std::mem::size_of::<T>();
43        let ret = cpp!(unsafe [
44            ptr_ptr as "void**",
45            pitch_ptr as "std::size_t*",
46            line_size as "std::size_t",
47            height as "std::size_t"
48        ] -> i32 as "std::int32_t" {
49            return cudaMallocPitch(
50                ptr_ptr,
51                pitch_ptr,
52                line_size,
53                height
54            );
55        });
56        match result!(ret, ptr.into()) {
57            Ok(internal) => Self {
58                width,
59                height,
60                num_channels,
61                pitch,
62                internal,
63                _phantom: Default::default(),
64            },
65            Err(err) => {
66                panic!("failed to allocate device memory: {err}");
67            }
68        }
69    }
70
71    #[cfg(feature = "ndarray")]
72    pub fn from_array(array: &ndarray::ArrayView3<T>, stream: &Stream) -> Result<Self> {
73        let host_buffer = HostBuffer::from_array(array);
74        let (height, width, num_channels) = array.dim();
75        let mut this = Self::new(width, height, num_channels);
76        // SAFETY: Safe because the stream is synchronized after this.
77        unsafe {
78            this.copy_from_async(&host_buffer, stream)?;
79        }
80        stream.synchronize()?;
81        Ok(this)
82    }
83
84    /// Copy from host buffer.
85    ///
86    /// # Safety
87    ///
88    /// This function is marked unsafe because it does not synchronize and the operation might not
89    /// have completed when it returns.
90    pub unsafe fn copy_from_async(&mut self, other: &HostBuffer<T>, stream: &Stream) -> Result<()> {
91        assert_eq!(self.num_elements(), other.num_elements);
92        let ptr_from = other.as_internal().as_ptr();
93        let ptr_to = self.as_mut_internal().as_mut_ptr();
94        let line_size = self.width * self.num_channels * std::mem::size_of::<T>();
95        let height = self.height;
96        let pitch = self.pitch;
97        let stream_ptr = stream.as_internal().as_ptr();
98        let ret = cpp!(unsafe [
99            ptr_from as "void*",
100            ptr_to as "void*",
101            pitch as "std::size_t",
102            line_size as "std::size_t",
103            height as "std::size_t",
104            stream_ptr as "const void*"
105        ] -> i32 as "std::int32_t" {
106            return cudaMemcpy2DAsync(
107                ptr_to,
108                pitch,
109                ptr_from,
110                line_size,
111                line_size,
112                height,
113                cudaMemcpyHostToDevice,
114                (cudaStream_t) stream_ptr
115            );
116        });
117        result!(ret)
118    }
119
120    /// Copy to host buffer.
121    ///
122    /// # Safety
123    ///
124    /// This function is marked unsafe because it does not synchronize and the operation might not
125    /// have completed when it returns.
126    pub unsafe fn copy_to_async(&self, other: &mut HostBuffer<T>, stream: &Stream) -> Result<()> {
127        assert_eq!(self.num_elements(), other.num_elements);
128        let ptr_from = self.as_internal().as_ptr();
129        let ptr_to = other.as_mut_internal().as_mut_ptr();
130        let line_size = self.width * self.num_channels * std::mem::size_of::<T>();
131        let height = self.height;
132        let pitch = self.pitch;
133        let stream_ptr = stream.as_internal().as_ptr();
134        let ret = cpp!(unsafe [
135            ptr_from as "void*",
136            ptr_to as "void*",
137            pitch as "std::size_t",
138            line_size as "std::size_t",
139            height as "std::size_t",
140            stream_ptr as "const void*"
141        ] -> i32 as "std::int32_t" {
142            return cudaMemcpy2DAsync(
143                ptr_to,
144                line_size,
145                ptr_from,
146                pitch,
147                line_size,
148                height,
149                cudaMemcpyDeviceToHost,
150                (cudaStream_t) stream_ptr
151            );
152        });
153        result!(ret)
154    }
155
156    /// Fill buffer with byte value.
157    pub fn fill_with_byte(&mut self, value: u8, stream: &Stream) -> Result<()> {
158        let ptr = self.as_internal().as_ptr();
159        let value = value as std::ffi::c_int;
160        let line_size = self.width * self.num_channels * std::mem::size_of::<T>();
161        let height = self.height;
162        let pitch = self.pitch;
163        let stream_ptr = stream.as_internal().as_ptr();
164        let ret = cpp!(unsafe [
165            ptr as "void*",
166            value as "int",
167            pitch as "std::size_t",
168            line_size as "std::size_t",
169            height as "std::size_t",
170            stream_ptr as "const void*"
171        ] -> i32 as "std::int32_t" {
172            return cudaMemset2DAsync(
173                ptr,
174                pitch,
175                value,
176                line_size,
177                height,
178                (cudaStream_t) stream_ptr
179            );
180        });
181        result!(ret)
182    }
183
184    #[inline(always)]
185    pub fn num_elements(&self) -> usize {
186        self.width * self.height * self.num_channels
187    }
188
189    /// Get readonly reference to internal [`DevicePtr`].
190    #[inline(always)]
191    pub fn as_internal(&self) -> &DevicePtr {
192        &self.internal
193    }
194
195    /// Get readonly reference to internal [`DevicePtr`].
196    #[inline(always)]
197    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
198        &mut self.internal
199    }
200}
201
202impl<T: Copy> Drop for DeviceBuffer2D<T> {
203    fn drop(&mut self) {
204        if self.internal.is_null() {
205            return;
206        }
207
208        // SAFETY: Safe because we won't use pointer after this.
209        let mut internal = unsafe { self.internal.take() };
210        let ptr = internal.as_mut_ptr();
211        let _ret = cpp!(unsafe [
212            ptr as "void*"
213        ] -> i32 as "std::int32_t" {
214            return cudaFree(ptr);
215        });
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let buffer = DeviceBuffer2D::<u32>::new(120, 80, 3);
        assert_eq!(buffer.width, 120);
        assert_eq!(buffer.height, 80);
        assert_eq!(buffer.num_channels, 3);
        assert_eq!(buffer.num_elements(), 120 * 80 * 3);
        // `pitch` is in bytes, so it must cover a full row of `u32` elements.
        assert!(buffer.pitch >= 120 * 3 * std::mem::size_of::<u32>());
    }

    #[test]
    #[should_panic]
    fn test_new_panics_when_row_size_overflows() {
        let _ = DeviceBuffer2D::<u32>::new(usize::MAX, 1, 2);
    }

    #[test]
    fn test_copy() {
        let stream = Stream::new().unwrap();
        let all_ones = vec![1_u32; 150];
        let host_buffer_all_ones = HostBuffer::from_slice(all_ones.as_slice());

        let mut device_buffer = DeviceBuffer2D::<u32>::new(10, 5, 3);
        unsafe {
            device_buffer
                .copy_from_async(&host_buffer_all_ones, &stream)
                .unwrap();
        }

        let mut host_buffer = HostBuffer::<u32>::new(150);
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }

        let mut another_device_buffer = DeviceBuffer2D::<u32>::new(10, 5, 3);
        unsafe {
            another_device_buffer
                .copy_from_async(&host_buffer, &stream)
                .unwrap();
        }

        let mut return_host_buffer = HostBuffer::<u32>::new(150);
        unsafe {
            another_device_buffer
                .copy_to_async(&mut return_host_buffer, &stream)
                .unwrap();
        }

        stream.synchronize().unwrap();

        assert_eq!(return_host_buffer.num_elements, 150);
        let return_data = return_host_buffer.to_vec();
        assert_eq!(return_data.len(), 150);
        assert!(return_data.into_iter().all(|v| v == 1_u32));
    }

    #[test]
    fn test_copy_2d() {
        let stream = Stream::new().unwrap();
        let image: [u8; 12] = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4];
        let host_buffer = HostBuffer::from_slice(&image);
        let mut device_buffer = DeviceBuffer2D::<u8>::new(2, 2, 3);
        unsafe {
            device_buffer
                .copy_from_async(&host_buffer, &stream)
                .unwrap();
        }
        let mut return_host_buffer = HostBuffer::<u8>::new(12);
        unsafe {
            device_buffer
                .copy_to_async(&mut return_host_buffer, &stream)
                .unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(
            &return_host_buffer.to_vec(),
            &[1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]
        );
    }

    #[test]
    fn test_fill_with_byte() {
        let stream = Stream::new().unwrap();
        let mut device_buffer = DeviceBuffer2D::<u8>::new(2, 2, 3);
        let mut host_buffer = HostBuffer::<u8>::new(2 * 2 * 3);
        device_buffer.fill_with_byte(0xab, &stream).unwrap();
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(
            host_buffer.to_vec(),
            &[0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab]
        );
    }

    #[test]
    fn test_fill_with_byte_multi_byte_elements() {
        // Every byte of every `u32` must be set, which checks that the row size is in bytes.
        let stream = Stream::new().unwrap();
        let mut device_buffer = DeviceBuffer2D::<u32>::new(3, 2, 2);
        let mut host_buffer = HostBuffer::<u32>::new(3 * 2 * 2);
        device_buffer.fill_with_byte(0x01, &stream).unwrap();
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(host_buffer.to_vec(), vec![0x0101_0101_u32; 12]);
    }

    #[test]
    #[should_panic]
    fn test_it_panics_when_copying_invalid_size() {
        let stream = Stream::new().unwrap();
        let device_buffer = DeviceBuffer2D::<u32>::new(5, 5, 3);
        let mut host_buffer = HostBuffer::<u32>::new(80);
        let _ = unsafe { device_buffer.copy_to_async(&mut host_buffer, &stream) };
    }
}