1use cpp::cpp;
2
3use crate::device::DeviceId;
4use crate::ffi::device::Device;
5use crate::ffi::memory::host::HostBuffer;
6use crate::ffi::ptr::DevicePtr;
7use crate::ffi::result;
8use crate::ffi::stream::Stream;
9
10type Result<T> = std::result::Result<T, crate::error::Error>;
11
12pub struct DeviceBuffer2D<T: Copy> {
16 pub width: usize,
17 pub height: usize,
18 pub num_channels: usize,
19 pub pitch: usize,
20 internal: DevicePtr,
21 device: DeviceId,
22 _phantom: std::marker::PhantomData<T>,
23}
24
25unsafe impl<T: Copy> Send for DeviceBuffer2D<T> {}
31
32unsafe impl<T: Copy> Sync for DeviceBuffer2D<T> {}
38
39impl<T: Copy> DeviceBuffer2D<T> {
40 pub fn new(width: usize, height: usize, num_channels: usize) -> Self {
41 let device = Device::get_or_panic();
42 let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
43 let ptr_ptr = std::ptr::addr_of_mut!(ptr);
44 let mut pitch = 0_usize;
45 let pitch_ptr = std::ptr::addr_of_mut!(pitch);
46 let line_size = width * num_channels * std::mem::size_of::<T>();
47 let ret = cpp!(unsafe [
48 ptr_ptr as "void**",
49 pitch_ptr as "std::size_t*",
50 line_size as "std::size_t",
51 height as "std::size_t"
52 ] -> i32 as "std::int32_t" {
53 return cudaMallocPitch(
54 ptr_ptr,
55 pitch_ptr,
56 line_size,
57 height
58 );
59 });
60 match result!(ret, DevicePtr::from_addr(ptr)) {
61 Ok(internal) => Self {
62 width,
63 height,
64 num_channels,
65 pitch,
66 internal,
67 device,
68 _phantom: Default::default(),
69 },
70 Err(err) => {
71 panic!("failed to allocate device memory: {err}");
72 }
73 }
74 }
75
76 #[cfg(feature = "ndarray")]
77 pub fn from_array(array: &ndarray::ArrayView3<T>, stream: &Stream) -> Result<Self> {
78 let host_buffer = HostBuffer::from_array(array);
79 let (height, width, num_channels) = array.dim();
80 let mut this = Self::new(width, height, num_channels);
81 unsafe {
83 this.copy_from_async(&host_buffer, stream)?;
84 }
85 stream.synchronize()?;
86 Ok(this)
87 }
88
89 pub unsafe fn copy_from_async(&mut self, other: &HostBuffer<T>, stream: &Stream) -> Result<()> {
96 assert_eq!(self.num_elements(), other.num_elements);
97 let ptr_from = other.as_internal().as_ptr();
98 let ptr_to = self.as_mut_internal().as_mut_ptr();
99 let line_size = self.width * self.num_channels * std::mem::size_of::<T>();
100 let height = self.height;
101 let pitch = self.pitch;
102 let stream_ptr = stream.as_internal().as_ptr();
103 let ret = cpp!(unsafe [
104 ptr_from as "void*",
105 ptr_to as "void*",
106 pitch as "std::size_t",
107 line_size as "std::size_t",
108 height as "std::size_t",
109 stream_ptr as "const void*"
110 ] -> i32 as "std::int32_t" {
111 return cudaMemcpy2DAsync(
112 ptr_to,
113 pitch,
114 ptr_from,
115 line_size,
116 line_size,
117 height,
118 cudaMemcpyHostToDevice,
119 (cudaStream_t) stream_ptr
120 );
121 });
122 result!(ret)
123 }
124
125 pub unsafe fn copy_to_async(&self, other: &mut HostBuffer<T>, stream: &Stream) -> Result<()> {
132 assert_eq!(self.num_elements(), other.num_elements);
133 let ptr_from = self.as_internal().as_ptr();
134 let ptr_to = other.as_mut_internal().as_mut_ptr();
135 let line_size = self.width * self.num_channels * std::mem::size_of::<T>();
136 let height = self.height;
137 let pitch = self.pitch;
138 let stream_ptr = stream.as_internal().as_ptr();
139 let ret = cpp!(unsafe [
140 ptr_from as "void*",
141 ptr_to as "void*",
142 pitch as "std::size_t",
143 line_size as "std::size_t",
144 height as "std::size_t",
145 stream_ptr as "const void*"
146 ] -> i32 as "std::int32_t" {
147 return cudaMemcpy2DAsync(
148 ptr_to,
149 line_size,
150 ptr_from,
151 pitch,
152 line_size,
153 height,
154 cudaMemcpyDeviceToHost,
155 (cudaStream_t) stream_ptr
156 );
157 });
158 result!(ret)
159 }
160
161 pub fn fill_with_byte(&mut self, value: u8, stream: &Stream) -> Result<()> {
163 let ptr = self.as_internal().as_ptr();
164 let value = value as std::ffi::c_int;
165 let line_size = self.width * self.num_channels * std::mem::size_of::<T>();
166 let height = self.height;
167 let pitch = self.pitch;
168 let stream_ptr = stream.as_internal().as_ptr();
169 let ret = cpp!(unsafe [
170 ptr as "void*",
171 value as "int",
172 pitch as "std::size_t",
173 line_size as "std::size_t",
174 height as "std::size_t",
175 stream_ptr as "const void*"
176 ] -> i32 as "std::int32_t" {
177 return cudaMemset2DAsync(
178 ptr,
179 pitch,
180 value,
181 line_size,
182 height,
183 (cudaStream_t) stream_ptr
184 );
185 });
186 result!(ret)
187 }
188
189 #[inline(always)]
190 pub fn num_elements(&self) -> usize {
191 self.width * self.height * self.num_channels
192 }
193
194 #[inline(always)]
196 pub fn as_internal(&self) -> &DevicePtr {
197 &self.internal
198 }
199
200 #[inline(always)]
202 pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
203 &mut self.internal
204 }
205
206 pub unsafe fn free(&mut self) {
216 if self.internal.is_null() {
217 return;
218 }
219
220 Device::set_or_panic(self.device);
221
222 let mut internal = unsafe { self.internal.take() };
224 let ptr = internal.as_mut_ptr();
225 let _ret = cpp!(unsafe [
226 ptr as "void*"
227 ] -> i32 as "std::int32_t" {
228 return cudaFree(ptr);
229 });
230 }
231}
232
233impl<T: Copy> Drop for DeviceBuffer2D<T> {
234 #[inline]
235 fn drop(&mut self) {
236 unsafe {
238 self.free();
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let buffer = DeviceBuffer2D::<u32>::new(120, 80, 3);
        assert_eq!(buffer.width, 120);
        assert_eq!(buffer.height, 80);
        assert_eq!(buffer.num_channels, 3);
        assert_eq!(buffer.num_elements(), 120 * 80 * 3);
        // `pitch` is measured in bytes, so it must cover a whole row of u32 elements.
        assert!(buffer.pitch >= 120 * 3 * std::mem::size_of::<u32>());
    }

    #[test]
    fn test_copy() {
        let stream = Stream::new().unwrap();
        let all_ones = vec![1_u32; 150];
        let host_buffer_all_ones = HostBuffer::from_slice(all_ones.as_slice());

        let mut device_buffer = DeviceBuffer2D::<u32>::new(10, 5, 3);
        unsafe {
            device_buffer
                .copy_from_async(&host_buffer_all_ones, &stream)
                .unwrap();
        }

        let mut host_buffer = HostBuffer::<u32>::new(150);
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }

        let mut another_device_buffer = DeviceBuffer2D::<u32>::new(10, 5, 3);
        unsafe {
            another_device_buffer
                .copy_from_async(&host_buffer, &stream)
                .unwrap();
        }

        let mut return_host_buffer = HostBuffer::<u32>::new(150);
        unsafe {
            another_device_buffer
                .copy_to_async(&mut return_host_buffer, &stream)
                .unwrap();
        }

        stream.synchronize().unwrap();

        assert_eq!(return_host_buffer.num_elements, 150);
        let return_data = return_host_buffer.to_vec();
        assert_eq!(return_data.len(), 150);
        assert!(return_data.into_iter().all(|v| v == 1_u32));
    }

    #[test]
    fn test_copy_2d() {
        let stream = Stream::new().unwrap();
        let image: [u8; 12] = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4];
        let host_buffer = HostBuffer::from_slice(&image);
        let mut device_buffer = DeviceBuffer2D::<u8>::new(2, 2, 3);
        unsafe {
            device_buffer
                .copy_from_async(&host_buffer, &stream)
                .unwrap();
        }
        let mut return_host_buffer = HostBuffer::<u8>::new(12);
        unsafe {
            device_buffer
                .copy_to_async(&mut return_host_buffer, &stream)
                .unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(
            &return_host_buffer.to_vec(),
            &[1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]
        );
    }

    #[test]
    fn test_fill_with_byte() {
        let stream = Stream::new().unwrap();
        let mut device_buffer = DeviceBuffer2D::<u8>::new(2, 2, 3);
        let mut host_buffer = HostBuffer::<u8>::new(2 * 2 * 3);
        device_buffer.fill_with_byte(0xab, &stream).unwrap();
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(host_buffer.to_vec(), vec![0xab_u8; 2 * 2 * 3]);
    }

    #[test]
    // Pin the panic to the size assertion, so a failure in stream creation or allocation
    // (e.g. no GPU) cannot make this test pass by accident.
    #[should_panic(expected = "left == right")]
    fn test_it_panics_when_copying_invalid_size() {
        let stream = Stream::new().unwrap();
        let device_buffer = DeviceBuffer2D::<u32>::new(5, 5, 3);
        let mut host_buffer = HostBuffer::<u32>::new(80);
        let _ = unsafe { device_buffer.copy_to_async(&mut host_buffer, &stream) };
    }
}