//! Device-memory FFI wrappers (`async_cuda_core/ffi/memory/device.rs`).
use cpp::cpp;
2
3use crate::ffi::memory::host::HostBuffer;
4use crate::ffi::ptr::DevicePtr;
5use crate::ffi::result;
6use crate::ffi::stream::Stream;
7
8type Result<T> = std::result::Result<T, crate::error::Error>;
9
/// A buffer of `num_elements` values of `T` residing in CUDA device memory.
///
/// Allocation happens in [`DeviceBuffer::new`] via `cudaMallocAsync`; the
/// memory is released in `Drop` via `cudaFree`.
pub struct DeviceBuffer<T: Copy> {
    /// Number of `T` elements in the buffer (not a byte count).
    pub num_elements: usize,
    /// Raw device pointer; `Drop` skips the free when it is null
    /// (e.g. after the pointer has been `take`n).
    internal: DevicePtr,
    /// Marks the buffer as logically holding `T` without storing one.
    _phantom: std::marker::PhantomData<T>,
}
18
// SAFETY: NOTE(review) — asserted by the original author: the struct only
// holds a raw device address and a length, so moving it to another thread
// is presumed sound. Confirm `DevicePtr` carries no thread-affine state.
unsafe impl<T: Copy> Send for DeviceBuffer<T> {}

// SAFETY: NOTE(review) — shared references expose no interior mutability
// here (mutation requires `&mut self`), so concurrent `&DeviceBuffer` access
// is presumed sound. Confirm the same holds for `DevicePtr`.
unsafe impl<T: Copy> Sync for DeviceBuffer<T> {}
33impl<T: Copy> DeviceBuffer<T> {
34 pub fn new(num_elements: usize, stream: &Stream) -> Self {
35 let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
36 let ptr_ptr = std::ptr::addr_of_mut!(ptr);
37 let size = num_elements * std::mem::size_of::<T>();
38 let stream_ptr = stream.as_internal().as_ptr();
39 let ret = cpp!(unsafe [
40 ptr_ptr as "void**",
41 size as "std::size_t",
42 stream_ptr as "const void*"
43 ] -> i32 as "std::int32_t" {
44 return cudaMallocAsync(ptr_ptr, size, (cudaStream_t) stream_ptr);
45 });
46 match result!(ret, ptr.into()) {
47 Ok(internal) => Self {
48 internal,
49 num_elements,
50 _phantom: Default::default(),
51 },
52 Err(err) => {
53 panic!("failed to allocate device memory: {err}");
54 }
55 }
56 }
57
58 pub fn from_slice(slice: &[T], stream: &Stream) -> Result<Self> {
59 let host_buffer = HostBuffer::from_slice(slice);
60 let mut this = Self::new(slice.len(), stream);
61 unsafe {
63 this.copy_from_async(&host_buffer, stream)?;
64 }
65 stream.synchronize()?;
66 Ok(this)
67 }
68
69 #[cfg(feature = "ndarray")]
70 pub fn from_array<D: ndarray::Dimension>(
71 array: &ndarray::ArrayView<T, D>,
72 stream: &Stream,
73 ) -> Result<Self> {
74 let host_buffer = HostBuffer::from_array(array);
75 let mut this = Self::new(array.len(), stream);
76 unsafe {
78 this.copy_from_async(&host_buffer, stream)?;
79 }
80 stream.synchronize()?;
81 Ok(this)
82 }
83
84 pub unsafe fn copy_from_async(&mut self, other: &HostBuffer<T>, stream: &Stream) -> Result<()> {
91 assert_eq!(self.num_elements, other.num_elements);
92 let ptr_to = self.as_mut_internal().as_mut_ptr();
93 let ptr_from = other.as_internal().as_ptr();
94 let stream_ptr = stream.as_internal().as_ptr();
95 let size = self.num_elements * std::mem::size_of::<T>();
96 let ret = cpp!(unsafe [
97 ptr_from as "void*",
98 ptr_to as "void*",
99 size as "std::size_t",
100 stream_ptr as "const void*"
101 ] -> i32 as "std::int32_t" {
102 return cudaMemcpyAsync(
103 ptr_to,
104 ptr_from,
105 size,
106 cudaMemcpyHostToDevice,
107 (cudaStream_t) stream_ptr
108 );
109 });
110 result!(ret)
111 }
112
113 pub unsafe fn copy_to_async(&self, other: &mut HostBuffer<T>, stream: &Stream) -> Result<()> {
120 assert_eq!(self.num_elements, other.num_elements);
121 let ptr_from = self.as_internal().as_ptr();
122 let ptr_to = other.as_mut_internal().as_mut_ptr();
123 let size = self.num_elements * std::mem::size_of::<T>();
124 let stream_ptr = stream.as_internal().as_ptr();
125 let ret = cpp!(unsafe [
126 ptr_from as "void*",
127 ptr_to as "void*",
128 size as "std::size_t",
129 stream_ptr as "const void*"
130 ] -> i32 as "std::int32_t" {
131 return cudaMemcpyAsync(
132 ptr_to,
133 ptr_from,
134 size,
135 cudaMemcpyDeviceToHost,
136 (cudaStream_t) stream_ptr
137 );
138 });
139 result!(ret)
140 }
141
142 pub fn fill_with_byte(&mut self, value: u8, stream: &Stream) -> Result<()> {
144 let ptr = self.as_internal().as_ptr();
145 let value = value as std::ffi::c_int;
146 let size = self.num_elements * std::mem::size_of::<T>();
147 let stream_ptr = stream.as_internal().as_ptr();
148 let ret = cpp!(unsafe [
149 ptr as "void*",
150 value as "int",
151 size as "std::size_t",
152 stream_ptr as "const void*"
153 ] -> i32 as "std::int32_t" {
154 return cudaMemsetAsync(
155 ptr,
156 value,
157 size,
158 (cudaStream_t) stream_ptr
159 );
160 });
161 result!(ret)
162 }
163
164 #[inline(always)]
166 pub fn as_internal(&self) -> &DevicePtr {
167 &self.internal
168 }
169
170 #[inline(always)]
172 pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
173 &mut self.internal
174 }
175}
176
impl<T: Copy> Drop for DeviceBuffer<T> {
    /// Frees the device allocation with `cudaFree`.
    fn drop(&mut self) {
        // Nothing to free if the pointer was already taken or never set.
        if self.internal.is_null() {
            return;
        }

        // Take ownership of the pointer so the field is left null and a
        // second drop cannot double-free it.
        let mut internal = unsafe { self.internal.take() };
        let ptr = internal.as_mut_ptr();
        // The return code is deliberately ignored: `drop` cannot propagate
        // errors, and panicking here could abort during unwinding.
        let _ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaFree(ptr);
        });
    }
}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        // Allocation on the null stream should record the element count.
        let buffer = DeviceBuffer::<u32>::new(100, &Stream::null());
        assert_eq!(buffer.num_elements, 100);
    }

    #[test]
    fn test_copy() {
        let stream = Stream::new().unwrap();

        // Stage one hundred ones in a host buffer as the copy source.
        let ones = vec![1_u32; 100];
        let source = HostBuffer::from_slice(ones.as_slice());

        // Round-trip the data: host -> device -> host -> device -> host.
        let mut first_device = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            first_device.copy_from_async(&source, &stream).unwrap();
        }

        let mut staging = HostBuffer::<u32>::new(100);
        unsafe {
            first_device.copy_to_async(&mut staging, &stream).unwrap();
        }

        let mut second_device = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            second_device.copy_from_async(&staging, &stream).unwrap();
        }

        let mut round_tripped = HostBuffer::<u32>::new(100);
        unsafe {
            second_device
                .copy_to_async(&mut round_tripped, &stream)
                .unwrap();
        }

        stream.synchronize().unwrap();

        // Every element must have survived the round trip unchanged.
        assert_eq!(round_tripped.num_elements, 100);
        let values = round_tripped.to_vec();
        assert_eq!(values.len(), 100);
        assert!(values.iter().all(|&v| v == 1_u32));
    }

    #[test]
    fn test_fill_with_byte() {
        let stream = Stream::new().unwrap();
        let mut device = DeviceBuffer::<u8>::new(4, &stream);
        let mut host = HostBuffer::<u8>::new(4);
        device.fill_with_byte(0xab, &stream).unwrap();
        unsafe {
            device.copy_to_async(&mut host, &stream).unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(host.to_vec(), &[0xab, 0xab, 0xab, 0xab]);
    }

    #[test]
    #[should_panic]
    fn test_it_panics_when_copying_invalid_size() {
        // Mismatched element counts must trip the assert in `copy_to_async`.
        let stream = Stream::new().unwrap();
        let device = DeviceBuffer::<u32>::new(101, &stream);
        let mut host = HostBuffer::<u32>::new(100);
        let _ = unsafe { device.copy_to_async(&mut host, &stream) };
    }
}