// async_cuda/ffi/memory/device2d.rs

1use cpp::cpp;
2
3use crate::device::DeviceId;
4use crate::ffi::device::Device;
5use crate::ffi::memory::host::HostBuffer;
6use crate::ffi::ptr::DevicePtr;
7use crate::ffi::result;
8use crate::ffi::stream::Stream;
9
10type Result<T> = std::result::Result<T, crate::error::Error>;
11
12/// Synchronous implementation of [`crate::DeviceBuffer2D`].
13///
14/// Refer to [`crate::DeviceBuffer2D`] for documentation.
15pub struct DeviceBuffer2D<T: Copy> {
16    pub width: usize,
17    pub height: usize,
18    pub num_channels: usize,
19    pub pitch: usize,
20    internal: DevicePtr,
21    device: DeviceId,
22    _phantom: std::marker::PhantomData<T>,
23}
24
25/// Implements [`Send`] for [`DeviceBuffer2D`].
26///
27/// # Safety
28///
29/// This property is inherited from the CUDA API, which is thread-safe.
30unsafe impl<T: Copy> Send for DeviceBuffer2D<T> {}
31
32/// Implements [`Sync`] for [`DeviceBuffer2D`].
33///
34/// # Safety
35///
36/// This property is inherited from the CUDA API, which is thread-safe.
37unsafe impl<T: Copy> Sync for DeviceBuffer2D<T> {}
38
39impl<T: Copy> DeviceBuffer2D<T> {
40    pub fn new(width: usize, height: usize, num_channels: usize) -> Self {
41        let device = Device::get_or_panic();
42        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
43        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
44        let mut pitch = 0_usize;
45        let pitch_ptr = std::ptr::addr_of_mut!(pitch);
46        let line_size = width * num_channels * std::mem::size_of::<T>();
47        let ret = cpp!(unsafe [
48            ptr_ptr as "void**",
49            pitch_ptr as "std::size_t*",
50            line_size as "std::size_t",
51            height as "std::size_t"
52        ] -> i32 as "std::int32_t" {
53            return cudaMallocPitch(
54                ptr_ptr,
55                pitch_ptr,
56                line_size,
57                height
58            );
59        });
60        match result!(ret, DevicePtr::from_addr(ptr)) {
61            Ok(internal) => Self {
62                width,
63                height,
64                num_channels,
65                pitch,
66                internal,
67                device,
68                _phantom: Default::default(),
69            },
70            Err(err) => {
71                panic!("failed to allocate device memory: {err}");
72            }
73        }
74    }
75
76    #[cfg(feature = "ndarray")]
77    pub fn from_array(array: &ndarray::ArrayView3<T>, stream: &Stream) -> Result<Self> {
78        let host_buffer = HostBuffer::from_array(array);
79        let (height, width, num_channels) = array.dim();
80        let mut this = Self::new(width, height, num_channels);
81        // SAFETY: Safe because the stream is synchronized after this.
82        unsafe {
83            this.copy_from_async(&host_buffer, stream)?;
84        }
85        stream.synchronize()?;
86        Ok(this)
87    }
88
89    /// Copy from host buffer.
90    ///
91    /// # Safety
92    ///
93    /// This function is marked unsafe because it does not synchronize and the operation might not
94    /// have completed when it returns.
95    pub unsafe fn copy_from_async(&mut self, other: &HostBuffer<T>, stream: &Stream) -> Result<()> {
96        assert_eq!(self.num_elements(), other.num_elements);
97        let ptr_from = other.as_internal().as_ptr();
98        let ptr_to = self.as_mut_internal().as_mut_ptr();
99        let line_size = self.width * self.num_channels * std::mem::size_of::<T>();
100        let height = self.height;
101        let pitch = self.pitch;
102        let stream_ptr = stream.as_internal().as_ptr();
103        let ret = cpp!(unsafe [
104            ptr_from as "void*",
105            ptr_to as "void*",
106            pitch as "std::size_t",
107            line_size as "std::size_t",
108            height as "std::size_t",
109            stream_ptr as "const void*"
110        ] -> i32 as "std::int32_t" {
111            return cudaMemcpy2DAsync(
112                ptr_to,
113                pitch,
114                ptr_from,
115                line_size,
116                line_size,
117                height,
118                cudaMemcpyHostToDevice,
119                (cudaStream_t) stream_ptr
120            );
121        });
122        result!(ret)
123    }
124
125    /// Copy to host buffer.
126    ///
127    /// # Safety
128    ///
129    /// This function is marked unsafe because it does not synchronize and the operation might not
130    /// have completed when it returns.
131    pub unsafe fn copy_to_async(&self, other: &mut HostBuffer<T>, stream: &Stream) -> Result<()> {
132        assert_eq!(self.num_elements(), other.num_elements);
133        let ptr_from = self.as_internal().as_ptr();
134        let ptr_to = other.as_mut_internal().as_mut_ptr();
135        let line_size = self.width * self.num_channels * std::mem::size_of::<T>();
136        let height = self.height;
137        let pitch = self.pitch;
138        let stream_ptr = stream.as_internal().as_ptr();
139        let ret = cpp!(unsafe [
140            ptr_from as "void*",
141            ptr_to as "void*",
142            pitch as "std::size_t",
143            line_size as "std::size_t",
144            height as "std::size_t",
145            stream_ptr as "const void*"
146        ] -> i32 as "std::int32_t" {
147            return cudaMemcpy2DAsync(
148                ptr_to,
149                line_size,
150                ptr_from,
151                pitch,
152                line_size,
153                height,
154                cudaMemcpyDeviceToHost,
155                (cudaStream_t) stream_ptr
156            );
157        });
158        result!(ret)
159    }
160
161    /// Fill buffer with byte value.
162    pub fn fill_with_byte(&mut self, value: u8, stream: &Stream) -> Result<()> {
163        let ptr = self.as_internal().as_ptr();
164        let value = value as std::ffi::c_int;
165        let line_size = self.width * self.num_channels * std::mem::size_of::<T>();
166        let height = self.height;
167        let pitch = self.pitch;
168        let stream_ptr = stream.as_internal().as_ptr();
169        let ret = cpp!(unsafe [
170            ptr as "void*",
171            value as "int",
172            pitch as "std::size_t",
173            line_size as "std::size_t",
174            height as "std::size_t",
175            stream_ptr as "const void*"
176        ] -> i32 as "std::int32_t" {
177            return cudaMemset2DAsync(
178                ptr,
179                pitch,
180                value,
181                line_size,
182                height,
183                (cudaStream_t) stream_ptr
184            );
185        });
186        result!(ret)
187    }
188
189    #[inline(always)]
190    pub fn num_elements(&self) -> usize {
191        self.width * self.height * self.num_channels
192    }
193
194    /// Get readonly reference to internal [`DevicePtr`].
195    #[inline(always)]
196    pub fn as_internal(&self) -> &DevicePtr {
197        &self.internal
198    }
199
200    /// Get mutable reference to internal [`DevicePtr`].
201    #[inline(always)]
202    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
203        &mut self.internal
204    }
205
206    /// Release the buffer memory.
207    ///
208    /// # Panics
209    ///
210    /// This function panics if binding to the corresponding device fails.
211    ///
212    /// # Safety
213    ///
214    /// The buffer may not be used after this function is called, except for being dropped.
215    pub unsafe fn free(&mut self) {
216        if self.internal.is_null() {
217            return;
218        }
219
220        Device::set_or_panic(self.device);
221
222        // SAFETY: Safe because we won't use pointer after this.
223        let mut internal = unsafe { self.internal.take() };
224        let ptr = internal.as_mut_ptr();
225        let _ret = cpp!(unsafe [
226            ptr as "void*"
227        ] -> i32 as "std::int32_t" {
228            return cudaFree(ptr);
229        });
230    }
231}
232
233impl<T: Copy> Drop for DeviceBuffer2D<T> {
234    #[inline]
235    fn drop(&mut self) {
236        // SAFETY: This is safe since the buffer cannot be used after this.
237        unsafe {
238            self.free();
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let buffer = DeviceBuffer2D::<u32>::new(120, 80, 3);
        assert_eq!(buffer.width, 120);
        assert_eq!(buffer.height, 80);
        assert_eq!(buffer.num_channels, 3);
        assert_eq!(buffer.num_elements(), 120 * 80 * 3);
        // `pitch` is a row stride in bytes, so it must cover a full row of `u32` values
        // (comparing against the element count, 360, would not catch an undersized pitch).
        assert!(buffer.pitch >= 120 * 3 * std::mem::size_of::<u32>());
    }

    #[test]
    fn test_copy() {
        let stream = Stream::new().unwrap();
        let all_ones = vec![1_u32; 150];
        let host_buffer_all_ones = HostBuffer::from_slice(all_ones.as_slice());

        let mut device_buffer = DeviceBuffer2D::<u32>::new(10, 5, 3);
        unsafe {
            device_buffer
                .copy_from_async(&host_buffer_all_ones, &stream)
                .unwrap();
        }

        let mut host_buffer = HostBuffer::<u32>::new(150);
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }

        let mut another_device_buffer = DeviceBuffer2D::<u32>::new(10, 5, 3);
        unsafe {
            another_device_buffer
                .copy_from_async(&host_buffer, &stream)
                .unwrap();
        }

        let mut return_host_buffer = HostBuffer::<u32>::new(150);
        unsafe {
            another_device_buffer
                .copy_to_async(&mut return_host_buffer, &stream)
                .unwrap();
        }

        stream.synchronize().unwrap();

        assert_eq!(return_host_buffer.num_elements, 150);
        let return_data = return_host_buffer.to_vec();
        assert_eq!(return_data.len(), 150);
        assert!(return_data.into_iter().all(|v| v == 1_u32));
    }

    #[test]
    fn test_copy_2d() {
        let stream = Stream::new().unwrap();
        let image: [u8; 12] = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4];
        let host_buffer = HostBuffer::from_slice(&image);
        let mut device_buffer = DeviceBuffer2D::<u8>::new(2, 2, 3);
        unsafe {
            device_buffer
                .copy_from_async(&host_buffer, &stream)
                .unwrap();
        }
        let mut return_host_buffer = HostBuffer::<u8>::new(12);
        unsafe {
            device_buffer
                .copy_to_async(&mut return_host_buffer, &stream)
                .unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(
            &return_host_buffer.to_vec(),
            &[1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]
        );
    }

    #[cfg(feature = "ndarray")]
    #[test]
    fn test_from_array() {
        let stream = Stream::new().unwrap();
        // Shape is (height, width, num_channels).
        let array = ndarray::Array3::<u8>::from_elem((2, 3, 1), 7);
        let device_buffer = DeviceBuffer2D::from_array(&array.view(), &stream).unwrap();
        assert_eq!(device_buffer.height, 2);
        assert_eq!(device_buffer.width, 3);
        assert_eq!(device_buffer.num_channels, 1);
        let mut host_buffer = HostBuffer::<u8>::new(6);
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(host_buffer.to_vec(), vec![7_u8; 6]);
    }

    #[test]
    fn test_fill_with_byte() {
        let stream = Stream::new().unwrap();
        let mut device_buffer = DeviceBuffer2D::<u8>::new(2, 2, 3);
        let mut host_buffer = HostBuffer::<u8>::new(2 * 2 * 3);
        device_buffer.fill_with_byte(0xab, &stream).unwrap();
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(host_buffer.to_vec(), &[0xab_u8; 12]);
    }

    #[test]
    #[should_panic]
    fn test_it_panics_when_copying_invalid_size() {
        let stream = Stream::new().unwrap();
        let device_buffer = DeviceBuffer2D::<u32>::new(5, 5, 3);
        let mut host_buffer = HostBuffer::<u32>::new(80);
        let _ = unsafe { device_buffer.copy_to_async(&mut host_buffer, &stream) };
    }
}
346        let mut host_buffer = HostBuffer::<u32>::new(80);
347        let _ = unsafe { device_buffer.copy_to_async(&mut host_buffer, &stream) };
348    }
349}