async_cuda/ffi/memory/
host.rs

1use cpp::cpp;
2
3use crate::device::DeviceId;
4use crate::ffi::device::Device;
5use crate::ffi::memory::device::DeviceBuffer;
6use crate::ffi::ptr::DevicePtr;
7use crate::ffi::result;
8use crate::ffi::stream::Stream;
9
/// Result alias for the fallible functions in this module, using [`crate::error::Error`] as the
/// error type.
type Result<T> = std::result::Result<T, crate::error::Error>;
11
12/// Synchronous implementation of [`crate::HostBuffer`].
13///
14/// Refer to [`crate::HostBuffer`] for documentation.
15pub struct HostBuffer<T: Copy> {
16    pub num_elements: usize,
17    internal: DevicePtr,
18    device: DeviceId,
19    _phantom: std::marker::PhantomData<T>,
20}
21
22/// Implements [`Send`] for [`HostBuffer`].
23///
24/// # Safety
25///
26/// This property is inherited from the CUDA API, which is thread-safe.
27unsafe impl<T: Copy> Send for HostBuffer<T> {}
28
29/// Implements [`Sync`] for [`HostBuffer`].
30///
31/// # Safety
32///
33/// This property is inherited from the CUDA API, which is thread-safe.
34unsafe impl<T: Copy> Sync for HostBuffer<T> {}
35
36impl<T: Copy> HostBuffer<T> {
37    pub fn new(num_elements: usize) -> Self {
38        let device = Device::get_or_panic();
39        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
40        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
41        let size = num_elements * std::mem::size_of::<T>();
42        let ret = cpp!(unsafe [
43            ptr_ptr as "void**",
44            size as "std::size_t"
45        ] -> i32 as "std::int32_t" {
46            return cudaMallocHost(ptr_ptr, size);
47        });
48        match result!(ret, DevicePtr::from_addr(ptr)) {
49            Ok(internal) => Self {
50                internal,
51                device,
52                num_elements,
53                _phantom: Default::default(),
54            },
55            Err(err) => {
56                panic!("failed to allocate host memory: {err}");
57            }
58        }
59    }
60
61    pub fn from_slice(slice: &[T]) -> Self {
62        let mut this = Self::new(slice.len());
63        this.copy_from_slice(slice);
64        this
65    }
66
67    #[cfg(feature = "ndarray")]
68    pub fn from_array<D: ndarray::Dimension>(array: &ndarray::ArrayView<T, D>) -> Self {
69        let mut this = Self::new(array.len());
70        this.copy_from_array(array);
71        this
72    }
73
74    /// Copy from device buffer.
75    ///
76    /// # Safety
77    ///
78    /// This function is marked unsafe because it does not synchronize and the operation might not
79    /// have completed when it returns.
80    #[inline]
81    pub unsafe fn copy_from_async(
82        &mut self,
83        other: &DeviceBuffer<T>,
84        stream: &Stream,
85    ) -> Result<()> {
86        other.copy_to_async(self, stream)
87    }
88
89    /// Copy to device buffer.
90    ///
91    /// # Safety
92    ///
93    /// This function is marked unsafe because it does not synchronize and the operation might not
94    /// have completed when it returns.
95    #[inline]
96    pub unsafe fn copy_to_async(&self, other: &mut DeviceBuffer<T>, stream: &Stream) -> Result<()> {
97        other.copy_from_async(self, stream)
98    }
99
100    pub fn copy_from_slice(&mut self, slice: &[T]) {
101        // SAFETY: This is safe because we only instantiate the slice temporarily whilst having
102        // exclusive mutable access to it to copy the data into it.
103        let target = unsafe {
104            std::slice::from_raw_parts_mut(self.internal.as_mut_ptr() as *mut T, self.num_elements)
105        };
106        target.copy_from_slice(slice);
107    }
108
109    #[cfg(feature = "ndarray")]
110    pub fn copy_from_array<D: ndarray::Dimension>(&mut self, array: &ndarray::ArrayView<T, D>) {
111        assert!(
112            array.is_standard_layout(),
113            "array must be in standard layout"
114        );
115        // SAFETY: This is safe because we only instantiate the slice temporarily whilst having
116        // exclusive mutable access to it to copy the data into it.
117        let target = unsafe {
118            std::slice::from_raw_parts_mut(self.internal.as_mut_ptr() as *mut T, self.num_elements)
119        };
120        target.copy_from_slice(array.as_slice().unwrap());
121    }
122
123    #[inline]
124    pub fn to_vec(&self) -> Vec<T> {
125        // SAFETY: This is safe because we only instantiate the slice temporarily to copy the data
126        // to a safe Rust [`Vec`].
127        let source = unsafe {
128            std::slice::from_raw_parts(self.internal.as_ptr() as *const T, self.num_elements)
129        };
130        source.to_vec()
131    }
132
133    #[cfg(feature = "ndarray")]
134    pub fn to_array_with_shape<D: ndarray::Dimension>(
135        &self,
136        shape: impl Into<ndarray::StrideShape<D>>,
137    ) -> ndarray::Array<T, D> {
138        let shape = shape.into();
139        assert_eq!(
140            self.num_elements,
141            shape.size(),
142            "provided shape does not match number of elements in buffer"
143        );
144        ndarray::Array::from_shape_vec(shape, self.to_vec()).unwrap()
145    }
146
147    /// Get readonly reference to internal [`DevicePtr`].
148    #[inline(always)]
149    pub fn as_internal(&self) -> &DevicePtr {
150        &self.internal
151    }
152
153    /// Get mutable reference to internal [`DevicePtr`].
154    #[inline(always)]
155    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
156        &mut self.internal
157    }
158
159    /// Release the buffer memory.
160    ///
161    /// # Panics
162    ///
163    /// This function panics if binding to the corresponding device fails.
164    ///
165    /// # Safety
166    ///
167    /// The buffer may not be used after this function is called, except for being dropped.
168    pub unsafe fn free(&mut self) {
169        if self.internal.is_null() {
170            return;
171        }
172
173        Device::set_or_panic(self.device);
174
175        // SAFETY: Safe because we won't use the pointer after this.
176        let mut internal = unsafe { self.internal.take() };
177        let ptr = internal.as_mut_ptr();
178        let _ret = cpp!(unsafe [
179            ptr as "void*"
180        ] -> i32 as "std::int32_t" {
181            return cudaFreeHost(ptr);
182        });
183    }
184}
185
186impl<T: Copy> Drop for HostBuffer<T> {
187    #[inline]
188    fn drop(&mut self) {
189        // SAFETY: This is safe since the buffer cannot be used after this.
190        unsafe {
191            self.free();
192        }
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly allocated buffer reports the requested number of elements.
    #[test]
    fn test_new() {
        const LEN: usize = 100;
        let buffer: HostBuffer<u32> = HostBuffer::new(LEN);
        assert_eq!(buffer.num_elements, LEN);
        assert_eq!(buffer.to_vec().len(), LEN);
    }

    /// Data copied in through `from_slice` comes back unchanged from `to_vec`.
    #[test]
    fn test_from_slice() {
        const LEN: usize = 200;
        let ones = vec![1_u32; LEN];
        let buffer = HostBuffer::from_slice(&ones);
        assert_eq!(buffer.num_elements, LEN);

        let contents = buffer.to_vec();
        assert_eq!(contents.len(), LEN);
        assert!(contents.iter().all(|&value| value == 1_u32));
    }

    /// Round trip: host -> device -> host on a single stream.
    #[test]
    fn test_copy() {
        const LEN: usize = 100;
        let stream = Stream::new().unwrap();
        let ones = vec![1_u32; LEN];
        let source = HostBuffer::from_slice(&ones);

        let mut on_device = DeviceBuffer::<u32>::new(LEN, &stream);
        unsafe { source.copy_to_async(&mut on_device, &stream) }.unwrap();

        let mut round_trip = HostBuffer::<u32>::new(LEN);
        unsafe { round_trip.copy_from_async(&on_device, &stream) }.unwrap();

        // Both copies are asynchronous; wait for them before inspecting the result.
        stream.synchronize().unwrap();

        assert_eq!(round_trip.num_elements, LEN);
        let contents = round_trip.to_vec();
        assert_eq!(contents.len(), LEN);
        assert!(contents.iter().all(|&value| value == 1_u32));
    }

    /// Copying between buffers with different element counts must panic.
    #[test]
    #[should_panic]
    fn test_it_panics_when_copying_invalid_size() {
        let stream = Stream::new().unwrap();
        let source = HostBuffer::<u32>::new(100);
        let mut too_large = DeviceBuffer::<u32>::new(101, &Stream::null());
        let _ = unsafe { source.copy_to_async(&mut too_large, &stream) };
    }
}
253}