// async_cuda_core/ffi/memory/host.rs

1use cpp::cpp;
2
3use crate::ffi::memory::device::DeviceBuffer;
4use crate::ffi::ptr::DevicePtr;
5use crate::ffi::result;
6use crate::ffi::stream::Stream;
7
8type Result<T> = std::result::Result<T, crate::error::Error>;
9
/// Synchronous implementation of [`crate::HostBuffer`].
///
/// Refer to [`crate::HostBuffer`] for documentation.
pub struct HostBuffer<T: Copy> {
    // Number of elements of type `T` the buffer was allocated for.
    pub num_elements: usize,
    // Raw pointer to the page-locked allocation made by `cudaMallocHost` in `new`;
    // released with `cudaFreeHost` in `Drop`.
    internal: DevicePtr,
    // Ties the untyped allocation to `T` without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
18
/// Implements [`Send`] for [`HostBuffer`].
///
/// # Safety
///
/// This property is inherited from the CUDA API, which is thread-safe.
unsafe impl<T: Copy> Send for HostBuffer<T> {}
25
/// Implements [`Sync`] for [`HostBuffer`].
///
/// # Safety
///
/// This property is inherited from the CUDA API, which is thread-safe.
unsafe impl<T: Copy> Sync for HostBuffer<T> {}
32
33impl<T: Copy> HostBuffer<T> {
34    pub fn new(num_elements: usize) -> Self {
35        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
36        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
37        let size = num_elements * std::mem::size_of::<T>();
38        let ret = cpp!(unsafe [
39            ptr_ptr as "void**",
40            size as "std::size_t"
41        ] -> i32 as "std::int32_t" {
42            return cudaMallocHost(ptr_ptr, size);
43        });
44        match result!(ret, ptr.into()) {
45            Ok(internal) => Self {
46                internal,
47                num_elements,
48                _phantom: Default::default(),
49            },
50            Err(err) => {
51                panic!("failed to allocate host memory: {err}");
52            }
53        }
54    }
55
56    pub fn from_slice(slice: &[T]) -> Self {
57        let mut this = Self::new(slice.len());
58        this.copy_from_slice(slice);
59        this
60    }
61
62    #[cfg(feature = "ndarray")]
63    pub fn from_array<D: ndarray::Dimension>(array: &ndarray::ArrayView<T, D>) -> Self {
64        let mut this = Self::new(array.len());
65        this.copy_from_array(array);
66        this
67    }
68
69    /// Copy from device buffer.
70    ///
71    /// # Safety
72    ///
73    /// This function is marked unsafe because it does not synchronize and the operation might not
74    /// have completed when it returns.
75    #[inline]
76    pub unsafe fn copy_from_async(
77        &mut self,
78        other: &DeviceBuffer<T>,
79        stream: &Stream,
80    ) -> Result<()> {
81        other.copy_to_async(self, stream)
82    }
83
84    /// Copy to device buffer.
85    ///
86    /// # Safety
87    ///
88    /// This function is marked unsafe because it does not synchronize and the operation might not
89    /// have completed when it returns.
90    #[inline]
91    pub unsafe fn copy_to_async(&self, other: &mut DeviceBuffer<T>, stream: &Stream) -> Result<()> {
92        other.copy_from_async(self, stream)
93    }
94
95    pub fn copy_from_slice(&mut self, slice: &[T]) {
96        // SAFETY: This is safe because we only instantiate the slice temporarily whilst having
97        // exclusive mutable access to it to copy the data into it.
98        let target = unsafe {
99            std::slice::from_raw_parts_mut(self.internal.as_mut_ptr() as *mut T, self.num_elements)
100        };
101        target.copy_from_slice(slice);
102    }
103
104    #[cfg(feature = "ndarray")]
105    pub fn copy_from_array<D: ndarray::Dimension>(&mut self, array: &ndarray::ArrayView<T, D>) {
106        assert!(
107            array.is_standard_layout(),
108            "array must be in standard layout"
109        );
110        // SAFETY: This is safe because we only instantiate the slice temporarily whilst having
111        // exclusive mutable access to it to copy the data into it.
112        let target = unsafe {
113            std::slice::from_raw_parts_mut(self.internal.as_mut_ptr() as *mut T, self.num_elements)
114        };
115        target.copy_from_slice(array.as_slice().unwrap());
116    }
117
118    #[inline]
119    pub fn to_vec(&self) -> Vec<T> {
120        // SAFETY: This is safe because we only instantiate the slice temporarily to copy the data
121        // to a safe Rust [`Vec`].
122        let source = unsafe {
123            std::slice::from_raw_parts(self.internal.as_ptr() as *const T, self.num_elements)
124        };
125        source.to_vec()
126    }
127
128    #[cfg(feature = "ndarray")]
129    pub fn to_array_with_shape<D: ndarray::Dimension>(
130        &self,
131        shape: impl Into<ndarray::StrideShape<D>>,
132    ) -> ndarray::Array<T, D> {
133        let shape = shape.into();
134        assert_eq!(
135            self.num_elements,
136            shape.size(),
137            "provided shape does not match number of elements in buffer"
138        );
139        ndarray::Array::from_shape_vec(shape, self.to_vec()).unwrap()
140    }
141
142    /// Get readonly reference to internal [`DevicePtr`].
143    #[inline(always)]
144    pub fn as_internal(&self) -> &DevicePtr {
145        &self.internal
146    }
147
148    /// Get readonly reference to internal [`DevicePtr`].
149    #[inline(always)]
150    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
151        &mut self.internal
152    }
153}
154
155impl<T: Copy> Drop for HostBuffer<T> {
156    fn drop(&mut self) {
157        if self.internal.is_null() {
158            return;
159        }
160
161        // SAFETY: Safe because we won't use the pointer after this.
162        let mut internal = unsafe { self.internal.take() };
163        let ptr = internal.as_mut_ptr();
164        let _ret = cpp!(unsafe [
165            ptr as "void*"
166        ] -> i32 as "std::int32_t" {
167            return cudaFreeHost(ptr);
168        });
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        // A fresh buffer reports the requested element count and reads back fully.
        let buf = HostBuffer::<u32>::new(100);
        assert_eq!(buf.num_elements, 100);
        assert_eq!(buf.to_vec().len(), 100);
    }

    #[test]
    fn test_from_slice() {
        let ones = vec![1_u32; 200];
        let buf = HostBuffer::from_slice(ones.as_slice());
        assert_eq!(buf.num_elements, 200);
        let copied = buf.to_vec();
        assert_eq!(copied.len(), 200);
        assert!(copied.iter().all(|&v| v == 1_u32));
    }

    #[test]
    fn test_copy() {
        // Round-trip: host buffer -> device buffer -> fresh host buffer.
        let stream = Stream::new().unwrap();
        let ones = vec![1_u32; 100];
        let source = HostBuffer::from_slice(ones.as_slice());

        let mut on_device = DeviceBuffer::<u32>::new(100, &stream);
        unsafe {
            source.copy_to_async(&mut on_device, &stream).unwrap();
        }

        let mut roundtrip = HostBuffer::<u32>::new(100);
        unsafe {
            roundtrip.copy_from_async(&on_device, &stream).unwrap();
        }

        stream.synchronize().unwrap();

        assert_eq!(roundtrip.num_elements, 100);
        let data = roundtrip.to_vec();
        assert_eq!(data.len(), 100);
        assert!(data.iter().all(|&v| v == 1_u32));
    }

    #[test]
    #[should_panic]
    fn test_it_panics_when_copying_invalid_size() {
        // Copying between buffers of mismatched element counts must panic.
        let stream = Stream::new().unwrap();
        let src = HostBuffer::<u32>::new(100);
        let mut dst = DeviceBuffer::<u32>::new(101, &Stream::null());
        let _ = unsafe { src.copy_to_async(&mut dst, &stream) };
    }
}