async_cuda/ffi/memory/device.rs

use cpp::cpp;

use crate::device::DeviceId;
use crate::ffi::device::Device;
use crate::ffi::memory::host::HostBuffer;
use crate::ffi::ptr::DevicePtr;
use crate::ffi::result;
use crate::ffi::stream::Stream;

type Result<T> = std::result::Result<T, crate::error::Error>;

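/// A buffer of `num_elements` elements of type `T` that lives in CUDA device memory.
///
/// The buffer records the device it was allocated on so that it can be freed on that
/// same device when it is dropped.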
pub struct DeviceBuffer<T: Copy> {
    pub num_elements: usize,
    internal: DevicePtr,
    device: DeviceId,
    _phantom: std::marker::PhantomData<T>,
}

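/// SAFETY: The wrapped device allocation is not tied to the thread that created it, so
/// the buffer can be moved to another thread.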
unsafe impl<T: Copy> Send for DeviceBuffer<T> {}

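/// SAFETY: All operations that mutate the buffer take `&mut self`, so a shared reference
/// never permits unsynchronized mutation.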
unsafe impl<T: Copy> Sync for DeviceBuffer<T> {}

impl<T: Copy> DeviceBuffer<T> {
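    /// Allocate a new buffer of `num_elements` elements of `T` in device memory.
    ///
    /// The allocation is performed asynchronously on the given stream with
    /// `cudaMallocAsync`.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.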
    pub fn new(num_elements: usize, stream: &Stream) -> Self {
        let device = Device::get_or_panic();
        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
        let size = num_elements * std::mem::size_of::<T>();
        let stream_ptr = stream.as_internal().as_ptr();
        let ret = cpp!(unsafe [
            ptr_ptr as "void**",
            size as "std::size_t",
            stream_ptr as "const void*"
        ] -> i32 as "std::int32_t" {
            return cudaMallocAsync(ptr_ptr, size, (cudaStream_t) stream_ptr);
        });
        match result!(ret, DevicePtr::from_addr(ptr)) {
            Ok(internal) => Self {
                internal,
                device,
                num_elements,
                _phantom: Default::default(),
            },
            Err(err) => {
                panic!("failed to allocate device memory: {err}");
            }
        }
    }

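    /// Allocate a device buffer and copy the contents of `slice` into it.
    ///
    /// The data is staged through a `HostBuffer` and the stream is synchronized before
    /// this function returns.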
    pub fn from_slice(slice: &[T], stream: &Stream) -> Result<Self> {
        let host_buffer = HostBuffer::from_slice(slice);
        let mut this = Self::new(slice.len(), stream);
        // SAFETY: both buffers are kept alive until the synchronize below completes.
        unsafe {
            this.copy_from_async(&host_buffer, stream)?;
        }
        stream.synchronize()?;
        Ok(this)
    }

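    /// Allocate a device buffer and copy the contents of `array` into it.
    ///
    /// The data is staged through a `HostBuffer` and the stream is synchronized before
    /// this function returns. Only available with the `ndarray` feature.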
    #[cfg(feature = "ndarray")]
    pub fn from_array<D: ndarray::Dimension>(
        array: &ndarray::ArrayView<T, D>,
        stream: &Stream,
    ) -> Result<Self> {
        let host_buffer = HostBuffer::from_array(array);
        let mut this = Self::new(array.len(), stream);
        // SAFETY: both buffers are kept alive until the synchronize below completes.
        unsafe {
            this.copy_from_async(&host_buffer, stream)?;
        }
        stream.synchronize()?;
        Ok(this)
    }

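    /// Copy data from the provided host buffer into this device buffer, asynchronously on
    /// the given stream (`cudaMemcpyAsync`, host to device).
    ///
    /// # Panics
    ///
    /// Panics if `other` does not hold the same number of elements as this buffer.
    ///
    /// # Safety
    ///
    /// The copy is only enqueued on the stream. The caller must keep both buffers alive
    /// and unmodified until the copy has completed, for example by synchronizing the
    /// stream.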
    pub unsafe fn copy_from_async(&mut self, other: &HostBuffer<T>, stream: &Stream) -> Result<()> {
        assert_eq!(self.num_elements, other.num_elements);
        let ptr_to = self.as_mut_internal().as_mut_ptr();
        let ptr_from = other.as_internal().as_ptr();
        let stream_ptr = stream.as_internal().as_ptr();
        let size = self.num_elements * std::mem::size_of::<T>();
        let ret = cpp!(unsafe [
            ptr_from as "void*",
            ptr_to as "void*",
            size as "std::size_t",
            stream_ptr as "const void*"
        ] -> i32 as "std::int32_t" {
            return cudaMemcpyAsync(
                ptr_to,
                ptr_from,
                size,
                cudaMemcpyHostToDevice,
                (cudaStream_t) stream_ptr
            );
        });
        result!(ret)
    }

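    /// Copy data from this device buffer into the provided host buffer, asynchronously on
    /// the given stream (`cudaMemcpyAsync`, device to host).
    ///
    /// # Panics
    ///
    /// Panics if `other` does not hold the same number of elements as this buffer.
    ///
    /// # Safety
    ///
    /// The copy is only enqueued on the stream. The caller must keep both buffers alive
    /// and must not read the destination until the copy has completed, for example by
    /// synchronizing the stream.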
    pub unsafe fn copy_to_async(&self, other: &mut HostBuffer<T>, stream: &Stream) -> Result<()> {
        assert_eq!(self.num_elements, other.num_elements);
        let ptr_from = self.as_internal().as_ptr();
        let ptr_to = other.as_mut_internal().as_mut_ptr();
        let size = self.num_elements * std::mem::size_of::<T>();
        let stream_ptr = stream.as_internal().as_ptr();
        let ret = cpp!(unsafe [
            ptr_from as "void*",
            ptr_to as "void*",
            size as "std::size_t",
            stream_ptr as "const void*"
        ] -> i32 as "std::int32_t" {
            return cudaMemcpyAsync(
                ptr_to,
                ptr_from,
                size,
                cudaMemcpyDeviceToHost,
                (cudaStream_t) stream_ptr
            );
        });
        result!(ret)
    }

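    /// Fill the entire buffer with the given byte value, asynchronously on the given
    /// stream (`cudaMemsetAsync`).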
    pub fn fill_with_byte(&mut self, value: u8, stream: &Stream) -> Result<()> {
        let ptr = self.as_mut_internal().as_mut_ptr();
        let value = value as std::ffi::c_int;
        let size = self.num_elements * std::mem::size_of::<T>();
        let stream_ptr = stream.as_internal().as_ptr();
        let ret = cpp!(unsafe [
            ptr as "void*",
            value as "int",
            size as "std::size_t",
            stream_ptr as "const void*"
        ] -> i32 as "std::int32_t" {
            return cudaMemsetAsync(
                ptr,
                value,
                size,
                (cudaStream_t) stream_ptr
            );
        });
        result!(ret)
    }

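    /// Get a reference to the internal device pointer.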
    #[inline(always)]
    pub fn as_internal(&self) -> &DevicePtr {
        &self.internal
    }

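    /// Get a mutable reference to the internal device pointer.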
    #[inline(always)]
    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
        &mut self.internal
    }

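    /// Free the device memory backing this buffer.
    ///
    /// The current device is first switched to the device the buffer was allocated on,
    /// then the memory is released with `cudaFree`. If the buffer holds no allocation,
    /// this does nothing. It is called automatically when the buffer is dropped.
    ///
    /// # Safety
    ///
    /// The caller must ensure the memory is no longer in use on the device, for example
    /// by synchronizing all streams that operate on it, and must not use the buffer
    /// afterwards.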
    pub unsafe fn free(&mut self) {
        if self.internal.is_null() {
            return;
        }

        Device::set_or_panic(self.device);

        let mut internal = unsafe { self.internal.take() };
        let ptr = internal.as_mut_ptr();
        let _ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaFree(ptr);
        });
    }
}

impl<T: Copy> Drop for DeviceBuffer<T> {
    #[inline]
    fn drop(&mut self) {
        // SAFETY: `free` is called exactly once here and the buffer is not used afterwards.
        unsafe {
            self.free();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let buffer = DeviceBuffer::<u32>::new(100, &Stream::null());
        assert_eq!(buffer.num_elements, 100);
    }

    #[test]
    fn test_copy() {
        let stream = Stream::new().unwrap();
        let all_ones = vec![1_u32; 100];
        let host_buffer_all_ones = HostBuffer::from_slice(all_ones.as_slice());

        let mut device_buffer = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            device_buffer
                .copy_from_async(&host_buffer_all_ones, &stream)
                .unwrap();
        }

        let mut host_buffer = HostBuffer::<u32>::new(100);
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }

        let mut another_device_buffer = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            another_device_buffer
                .copy_from_async(&host_buffer, &stream)
                .unwrap();
        }

        let mut return_host_buffer = HostBuffer::<u32>::new(100);
        unsafe {
            another_device_buffer
                .copy_to_async(&mut return_host_buffer, &stream)
                .unwrap();
        }

        stream.synchronize().unwrap();

        assert_eq!(return_host_buffer.num_elements, 100);
        let return_data = return_host_buffer.to_vec();
        assert_eq!(return_data.len(), 100);
        assert!(return_data.into_iter().all(|v| v == 1_u32));
    }

    #[test]
    fn test_fill_with_byte() {
        let stream = Stream::new().unwrap();
        let mut device_buffer = DeviceBuffer::<u8>::new(4, &stream);
        let mut host_buffer = HostBuffer::<u8>::new(4);
        device_buffer.fill_with_byte(0xab, &stream).unwrap();
        unsafe {
            device_buffer
                .copy_to_async(&mut host_buffer, &stream)
                .unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(host_buffer.to_vec(), &[0xab, 0xab, 0xab, 0xab]);
    }

284
285 #[test]
286 #[should_panic]
287 fn test_it_panics_when_copying_invalid_size() {
288 let stream = Stream::new().unwrap();
289 let device_buffer = DeviceBuffer::<u32>::new(101, &stream);
290 let mut host_buffer = HostBuffer::<u32>::new(100);
291 let _ = unsafe { device_buffer.copy_to_async(&mut host_buffer, &stream) };
292 }
293}