// async_cuda_core/ffi/memory/host.rs

use cpp::cpp;

use crate::ffi::memory::device::DeviceBuffer;
use crate::ffi::ptr::DevicePtr;
use crate::ffi::result;
use crate::ffi::stream::Stream;

type Result<T> = std::result::Result<T, crate::error::Error>;
/// A buffer of page-locked ("pinned") host memory.
///
/// The backing allocation is obtained from `cudaMallocHost` (see
/// [`HostBuffer::new`]) and released with `cudaFreeHost` when the buffer is
/// dropped, so it can participate in asynchronous host/device transfers.
pub struct HostBuffer<T: Copy> {
    /// Number of elements of type `T` in the buffer (not a byte count).
    pub num_elements: usize,
    /// Owned raw pointer to the pinned host allocation.
    /// NOTE(review): despite its name, `DevicePtr` is used here to wrap the
    /// host pointer returned by `cudaMallocHost` — confirm this is intended.
    internal: DevicePtr,
    /// Expresses logical ownership of `T` values without storing any.
    _phantom: std::marker::PhantomData<T>,
}
18
// SAFETY: NOTE(review) — the buffer exclusively owns its pinned host
// allocation and `T: Copy` carries no drop/alias obligations, so moving the
// buffer to another thread should be sound; confirm `DevicePtr` has no
// thread-affine state.
unsafe impl<T: Copy> Send for HostBuffer<T> {}

// SAFETY: NOTE(review) — shared references only permit reads of the pinned
// memory (all mutation goes through `&mut self` methods), so concurrent `&`
// access should be sound under the same assumption as `Send` above.
unsafe impl<T: Copy> Sync for HostBuffer<T> {}
32
impl<T: Copy> HostBuffer<T> {
    /// Allocate a new pinned host buffer holding `num_elements` values of `T`.
    ///
    /// The memory is allocated with `cudaMallocHost` via the `cpp!` FFI
    /// bridge.
    ///
    /// # Panics
    ///
    /// Panics if the CUDA allocation fails.
    pub fn new(num_elements: usize) -> Self {
        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
        // Out-pointer passed to `cudaMallocHost` as `void**`.
        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
        // NOTE(review): this multiplication can overflow for pathological
        // `num_elements` (wraps in release builds) — consider `checked_mul`.
        let size = num_elements * std::mem::size_of::<T>();
        let ret = cpp!(unsafe [
            ptr_ptr as "void**",
            size as "std::size_t"
        ] -> i32 as "std::int32_t" {
            return cudaMallocHost(ptr_ptr, size);
        });
        // `result!` turns the raw CUDA status code plus the filled-in pointer
        // into a `Result`, wrapping the pointer as a `DevicePtr` on success.
        match result!(ret, ptr.into()) {
            Ok(internal) => Self {
                internal,
                num_elements,
                _phantom: Default::default(),
            },
            Err(err) => {
                panic!("failed to allocate host memory: {err}");
            }
        }
    }

    /// Allocate a pinned buffer sized to `slice` and copy its contents in.
    pub fn from_slice(slice: &[T]) -> Self {
        let mut this = Self::new(slice.len());
        this.copy_from_slice(slice);
        this
    }

    /// Allocate a pinned buffer sized to `array` and copy its contents in.
    ///
    /// # Panics
    ///
    /// Panics if `array` is not in standard (contiguous, row-major) layout
    /// (see [`HostBuffer::copy_from_array`]).
    #[cfg(feature = "ndarray")]
    pub fn from_array<D: ndarray::Dimension>(array: &ndarray::ArrayView<T, D>) -> Self {
        let mut this = Self::new(array.len());
        this.copy_from_array(array);
        this
    }

    /// Asynchronously copy from `other` (device) into this host buffer on
    /// `stream`. Delegates to [`DeviceBuffer::copy_to_async`].
    ///
    /// # Safety
    ///
    /// NOTE(review): contract inferred from the async API shape — the copy is
    /// enqueued on `stream`, so both buffers must remain alive and unmodified
    /// until the stream has been synchronized; confirm against the delegate's
    /// documentation.
    #[inline]
    pub unsafe fn copy_from_async(
        &mut self,
        other: &DeviceBuffer<T>,
        stream: &Stream,
    ) -> Result<()> {
        other.copy_to_async(self, stream)
    }

    /// Asynchronously copy this host buffer into `other` (device) on
    /// `stream`. Delegates to [`DeviceBuffer::copy_from_async`].
    ///
    /// # Safety
    ///
    /// Same asynchronous-lifetime contract as [`HostBuffer::copy_from_async`].
    #[inline]
    pub unsafe fn copy_to_async(&self, other: &mut DeviceBuffer<T>, stream: &Stream) -> Result<()> {
        other.copy_from_async(self, stream)
    }

    /// Copy `slice` into the pinned buffer.
    ///
    /// # Panics
    ///
    /// Panics (in `<[T]>::copy_from_slice`) if `slice.len()` differs from
    /// `self.num_elements`.
    pub fn copy_from_slice(&mut self, slice: &[T]) {
        // SAFETY: `internal` points to a pinned allocation of exactly
        // `num_elements` values of `T` made in `new`, and `&mut self`
        // guarantees exclusive access for the lifetime of `target`.
        let target = unsafe {
            std::slice::from_raw_parts_mut(self.internal.as_mut_ptr() as *mut T, self.num_elements)
        };
        target.copy_from_slice(slice);
    }

    /// Copy `array` into the pinned buffer.
    ///
    /// # Panics
    ///
    /// Panics if `array` is not in standard layout, or if its length differs
    /// from `self.num_elements`.
    #[cfg(feature = "ndarray")]
    pub fn copy_from_array<D: ndarray::Dimension>(&mut self, array: &ndarray::ArrayView<T, D>) {
        // Standard layout is required so `as_slice()` below yields a
        // contiguous view (and does not return `None`).
        assert!(
            array.is_standard_layout(),
            "array must be in standard layout"
        );
        // SAFETY: same invariant as in `copy_from_slice` above.
        let target = unsafe {
            std::slice::from_raw_parts_mut(self.internal.as_mut_ptr() as *mut T, self.num_elements)
        };
        target.copy_from_slice(array.as_slice().unwrap());
    }

    /// Copy the buffer's contents into a newly allocated `Vec<T>`.
    #[inline]
    pub fn to_vec(&self) -> Vec<T> {
        // SAFETY: `internal` points to `num_elements` initialized values of
        // `T` allocated in `new`; a shared borrow suffices for reading.
        let source = unsafe {
            std::slice::from_raw_parts(self.internal.as_ptr() as *const T, self.num_elements)
        };
        source.to_vec()
    }

    /// Copy the buffer's contents into an owned `ndarray` with `shape`.
    ///
    /// # Panics
    ///
    /// Panics if `shape` does not describe exactly `self.num_elements`
    /// elements.
    #[cfg(feature = "ndarray")]
    pub fn to_array_with_shape<D: ndarray::Dimension>(
        &self,
        shape: impl Into<ndarray::StrideShape<D>>,
    ) -> ndarray::Array<T, D> {
        let shape = shape.into();
        assert_eq!(
            self.num_elements,
            shape.size(),
            "provided shape does not match number of elements in buffer"
        );
        // `unwrap` cannot fail: the size equality was just asserted.
        ndarray::Array::from_shape_vec(shape, self.to_vec()).unwrap()
    }

    /// Borrow the internal pointer wrapper.
    #[inline(always)]
    pub fn as_internal(&self) -> &DevicePtr {
        &self.internal
    }

    /// Mutably borrow the internal pointer wrapper.
    #[inline(always)]
    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
        &mut self.internal
    }
}
154
impl<T: Copy> Drop for HostBuffer<T> {
    fn drop(&mut self) {
        // Nothing to free if the pointer was never set or already taken.
        if self.internal.is_null() {
            return;
        }

        // SAFETY: NOTE(review) — `take` presumably moves the pointer out and
        // leaves `internal` null, preventing a double free; confirm against
        // `DevicePtr::take`.
        let mut internal = unsafe { self.internal.take() };
        let ptr = internal.as_mut_ptr();
        // The `cudaFreeHost` status is deliberately ignored: there is no way
        // to report it from `drop`, and panicking here could abort the
        // process during unwinding.
        let _ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaFreeHost(ptr);
        });
    }
}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocating a buffer yields the requested element count.
    #[test]
    fn test_new() {
        let allocated = HostBuffer::<u32>::new(100);
        assert_eq!(allocated.num_elements, 100);
        assert_eq!(allocated.to_vec().len(), 100);
    }

    /// A buffer built from a slice round-trips its contents.
    #[test]
    fn test_from_slice() {
        let source = vec![1_u32; 200];
        let pinned = HostBuffer::from_slice(source.as_slice());
        assert_eq!(pinned.num_elements, 200);
        let contents = pinned.to_vec();
        assert_eq!(contents.len(), 200);
        assert!(contents.iter().all(|&value| value == 1_u32));
    }

    /// Host -> device -> host copy preserves the data.
    #[test]
    fn test_copy() {
        let stream = Stream::new().unwrap();
        let ones = vec![1_u32; 100];
        let source_host = HostBuffer::from_slice(ones.as_slice());

        let mut scratch_device = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            source_host
                .copy_to_async(&mut scratch_device, &stream)
                .unwrap();
        }

        let mut sink_host = HostBuffer::<u32>::new(100);
        unsafe {
            sink_host
                .copy_from_async(&scratch_device, &stream)
                .unwrap();
        }

        stream.synchronize().unwrap();

        assert_eq!(sink_host.num_elements, 100);
        let round_tripped = sink_host.to_vec();
        assert_eq!(round_tripped.len(), 100);
        assert!(round_tripped.iter().all(|&value| value == 1_u32));
    }

    /// Copying between mismatched sizes must panic.
    #[test]
    #[should_panic]
    fn test_it_panics_when_copying_invalid_size() {
        let stream = Stream::new().unwrap();
        let pinned = HostBuffer::<u32>::new(100);
        let mut mismatched = DeviceBuffer::<u32>::new(101, &Stream::null());
        let _ = unsafe { pinned.copy_to_async(&mut mismatched, &stream) };
    }
}