// async_cuda/ffi/memory/host.rs

use cpp::cpp;

use crate::device::DeviceId;
use crate::ffi::device::Device;
use crate::ffi::memory::device::DeviceBuffer;
use crate::ffi::ptr::DevicePtr;
use crate::ffi::result;
use crate::ffi::stream::Stream;
/// Module-local alias: every fallible operation here reports `crate::error::Error`.
type Result<T> = std::result::Result<T, crate::error::Error>;
11
/// Page-locked (pinned) host memory buffer, allocated with `cudaMallocHost`
/// and released with `cudaFreeHost` (see `new` and `free` below).
pub struct HostBuffer<T: Copy> {
    /// Number of `T` elements the buffer holds (element count, not bytes).
    pub num_elements: usize,
    /// Pointer to the pinned host allocation, wrapped in `DevicePtr` for FFI use.
    internal: DevicePtr,
    /// Device that was current at allocation time; `free` rebinds it before releasing.
    device: DeviceId,
    /// Marks logical ownership of `T` values without storing any inline.
    _phantom: std::marker::PhantomData<T>,
}
21
// SAFETY: NOTE(review) — assumes the pinned host allocation behind `internal`
// may be moved to and accessed from any thread, which is the usual contract
// for CUDA page-locked memory; confirm `DevicePtr` carries no thread affinity.
unsafe impl<T: Copy> Send for HostBuffer<T> {}

// SAFETY: NOTE(review) — shared references only expose reads of the buffer
// (`to_vec`, `as_internal`, `copy_to_async`); assumes concurrent reads of the
// pinned allocation are sound — confirm against the CUDA runtime's guarantees.
unsafe impl<T: Copy> Sync for HostBuffer<T> {}
impl<T: Copy> HostBuffer<T> {
    /// Allocates a pinned host buffer sized for `num_elements` values of `T`
    /// via `cudaMallocHost`. The buffer contents are uninitialized.
    ///
    /// # Panics
    ///
    /// Panics if no device can be bound (`Device::get_or_panic`) or if the
    /// CUDA allocation fails.
    pub fn new(num_elements: usize) -> Self {
        // Remember the currently bound device so `free` can rebind it later.
        let device = Device::get_or_panic();
        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
        // Out-parameter for cudaMallocHost; addr_of_mut! takes the address
        // without creating an intermediate reference.
        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
        // Byte size of the allocation. NOTE(review): unchecked multiply —
        // a pathological num_elements could overflow and wrap the size.
        let size = num_elements * std::mem::size_of::<T>();
        let ret = cpp!(unsafe [
            ptr_ptr as "void**",
            size as "std::size_t"
        ] -> i32 as "std::int32_t" {
            return cudaMallocHost(ptr_ptr, size);
        });
        // `result!` turns the CUDA status code plus the out-pointer into a Result.
        match result!(ret, DevicePtr::from_addr(ptr)) {
            Ok(internal) => Self {
                internal,
                device,
                num_elements,
                _phantom: Default::default(),
            },
            Err(err) => {
                panic!("failed to allocate host memory: {err}");
            }
        }
    }

    /// Allocates a pinned buffer of `slice.len()` elements and copies `slice`
    /// into it synchronously.
    pub fn from_slice(slice: &[T]) -> Self {
        let mut this = Self::new(slice.len());
        this.copy_from_slice(slice);
        this
    }

    /// Allocates a pinned buffer of `array.len()` elements and copies `array`
    /// into it synchronously. The view must be in standard (contiguous) layout;
    /// see `copy_from_array`.
    #[cfg(feature = "ndarray")]
    pub fn from_array<D: ndarray::Dimension>(array: &ndarray::ArrayView<T, D>) -> Self {
        let mut this = Self::new(array.len());
        this.copy_from_array(array);
        this
    }

    /// Schedules an asynchronous device-to-host copy from `other` on `stream`.
    /// Delegates to `DeviceBuffer::copy_to_async`.
    ///
    /// # Safety
    ///
    /// Caller must uphold `DeviceBuffer::copy_to_async`'s contract — in
    /// particular, both buffers must remain alive and unmodified until the
    /// copy has completed on `stream`.
    #[inline]
    pub unsafe fn copy_from_async(
        &mut self,
        other: &DeviceBuffer<T>,
        stream: &Stream,
    ) -> Result<()> {
        other.copy_to_async(self, stream)
    }

    /// Schedules an asynchronous host-to-device copy into `other` on `stream`.
    /// Delegates to `DeviceBuffer::copy_from_async`.
    ///
    /// # Safety
    ///
    /// Caller must uphold `DeviceBuffer::copy_from_async`'s contract — both
    /// buffers must remain alive until the copy has completed on `stream`.
    #[inline]
    pub unsafe fn copy_to_async(&self, other: &mut DeviceBuffer<T>, stream: &Stream) -> Result<()> {
        other.copy_from_async(self, stream)
    }

    /// Copies `slice` into the pinned buffer synchronously.
    ///
    /// # Panics
    ///
    /// `<[T]>::copy_from_slice` panics if `slice.len() != self.num_elements`.
    pub fn copy_from_slice(&mut self, slice: &[T]) {
        // SAFETY: the allocation was sized in `new` for exactly
        // `num_elements` values of `T`, so this view covers valid memory.
        let target = unsafe {
            std::slice::from_raw_parts_mut(self.internal.as_mut_ptr() as *mut T, self.num_elements)
        };
        target.copy_from_slice(slice);
    }

    /// Copies `array` into the pinned buffer synchronously.
    ///
    /// # Panics
    ///
    /// Panics if `array` is not in standard layout, or if its length differs
    /// from `self.num_elements` (via `copy_from_slice`).
    #[cfg(feature = "ndarray")]
    pub fn copy_from_array<D: ndarray::Dimension>(&mut self, array: &ndarray::ArrayView<T, D>) {
        assert!(
            array.is_standard_layout(),
            "array must be in standard layout"
        );
        // SAFETY: same invariant as `copy_from_slice` — the allocation holds
        // exactly `num_elements` values of `T`.
        let target = unsafe {
            std::slice::from_raw_parts_mut(self.internal.as_mut_ptr() as *mut T, self.num_elements)
        };
        // `unwrap` is fine: standard layout was asserted above.
        target.copy_from_slice(array.as_slice().unwrap());
    }

    /// Copies the buffer contents into a freshly allocated `Vec<T>`.
    #[inline]
    pub fn to_vec(&self) -> Vec<T> {
        // SAFETY: the allocation holds exactly `num_elements` values of `T`.
        let source = unsafe {
            std::slice::from_raw_parts(self.internal.as_ptr() as *const T, self.num_elements)
        };
        source.to_vec()
    }

    /// Copies the buffer into an owned `ndarray::Array` with the given shape.
    ///
    /// # Panics
    ///
    /// Panics if `shape.size()` does not equal `self.num_elements`.
    #[cfg(feature = "ndarray")]
    pub fn to_array_with_shape<D: ndarray::Dimension>(
        &self,
        shape: impl Into<ndarray::StrideShape<D>>,
    ) -> ndarray::Array<T, D> {
        let shape = shape.into();
        assert_eq!(
            self.num_elements,
            shape.size(),
            "provided shape does not match number of elements in buffer"
        );
        // `unwrap` is fine: the size was checked against the shape above.
        ndarray::Array::from_shape_vec(shape, self.to_vec()).unwrap()
    }

    /// Borrows the underlying pointer wrapper (read-only).
    #[inline(always)]
    pub fn as_internal(&self) -> &DevicePtr {
        &self.internal
    }

    /// Borrows the underlying pointer wrapper mutably.
    #[inline(always)]
    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
        &mut self.internal
    }

    /// Releases the pinned allocation with `cudaFreeHost`. Safe to call more
    /// than once: a null/taken pointer makes this a no-op.
    ///
    /// # Safety
    ///
    /// NOTE(review): caller must ensure no outstanding asynchronous copies
    /// still reference this buffer — confirm at call sites. The status code
    /// returned by `cudaFreeHost` is deliberately ignored.
    pub unsafe fn free(&mut self) {
        if self.internal.is_null() {
            return;
        }

        // Rebind the device this buffer was allocated on before freeing.
        Device::set_or_panic(self.device);

        // `take` nulls out `self.internal`, making a second `free` (e.g. from
        // Drop after an explicit call) a no-op.
        let mut internal = unsafe { self.internal.take() };
        let ptr = internal.as_mut_ptr();
        let _ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaFreeHost(ptr);
        });
    }
}
185
impl<T: Copy> Drop for HostBuffer<T> {
    #[inline]
    fn drop(&mut self) {
        // SAFETY: `free` is a no-op if the allocation was already released.
        // NOTE(review): assumes any async copies involving this buffer have
        // completed by drop time — confirm at call sites.
        unsafe {
            self.free();
        }
    }
}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh buffer reports the requested element count and round-trips
    /// through `to_vec` with the same length.
    #[test]
    fn test_new() {
        let buf = HostBuffer::<u32>::new(100);
        assert_eq!(buf.num_elements, 100);
        assert_eq!(buf.to_vec().len(), 100);
    }

    /// `from_slice` copies every element of the source slice.
    #[test]
    fn test_from_slice() {
        let ones = vec![1_u32; 200];
        let buf = HostBuffer::from_slice(&ones);
        assert_eq!(buf.num_elements, 200);
        let copied = buf.to_vec();
        assert_eq!(copied.len(), 200);
        assert!(copied.iter().all(|&v| v == 1_u32));
    }

    /// Host -> device -> host round trip preserves the data.
    #[test]
    fn test_copy() {
        let stream = Stream::new().unwrap();
        let ones = vec![1_u32; 100];
        let source = HostBuffer::from_slice(&ones);

        // Upload to the device...
        let mut on_device = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            source.copy_to_async(&mut on_device, &stream).unwrap();
        }

        // ...and download into a second host buffer.
        let mut round_trip = HostBuffer::<u32>::new(100);
        unsafe {
            round_trip.copy_from_async(&on_device, &stream).unwrap();
        }

        // Both copies are async; wait before inspecting the result.
        stream.synchronize().unwrap();

        assert_eq!(round_trip.num_elements, 100);
        let returned = round_trip.to_vec();
        assert_eq!(returned.len(), 100);
        assert!(returned.iter().all(|&v| v == 1_u32));
    }

    /// Copying between buffers of different element counts must panic.
    #[test]
    #[should_panic]
    fn test_it_panics_when_copying_invalid_size() {
        let stream = Stream::new().unwrap();
        let source = HostBuffer::<u32>::new(100);
        let mut too_big = DeviceBuffer::<u32>::new(101, &Stream::null());
        let _ = unsafe { source.copy_to_async(&mut too_big, &stream) };
    }
}
253}