cuda_oxide/
mem.rs

1use std::{
2    ops::{Deref, DerefMut},
3    pin::Pin,
4    rc::Rc,
5};
6
7use crate::*;
8
9/// A slice into the device memory.
10#[derive(Clone)]
11pub struct DevicePtr<'a> {
12    pub(crate) handle: Rc<Handle<'a>>,
13    pub(crate) inner: u64,
14    pub(crate) len: u64,
15}
16
17impl<'a> DevicePtr<'a> {
18    pub fn as_raw(&self) -> u64 {
19        self.inner
20    }
21
22    pub unsafe fn from_raw_parts(handle: Rc<Handle<'a>>, ptr: u64, len: u64) -> Self {
23        Self {
24            handle,
25            inner: ptr,
26            len,
27        }
28    }
29
30    /// Synchronously copies data from `self` to `target`. Panics if length is not equal.
31    pub fn copy_to<'b>(&self, target: &DevicePtr<'b>) -> CudaResult<()> {
32        if self.len > target.len {
33            panic!("overflow in DevicePtr::copy_to");
34        } else if self.len < target.len {
35            panic!("underflow in DevicePtr::copy_to");
36        }
37
38        if std::ptr::eq(self.handle.context, target.handle.context) {
39            cuda_error(unsafe { sys::cuMemcpy(target.inner, self.inner, self.len as sys::size_t) })
40        } else {
41            cuda_error(unsafe {
42                sys::cuMemcpyPeer(
43                    target.inner,
44                    target.handle.context.inner,
45                    self.inner,
46                    self.handle.context.inner,
47                    self.len as sys::size_t,
48                )
49            })
50        }
51    }
52
53    /// Asynchronously copies data from `self` to `target`. Panics if length is not equal.
54    pub fn copy_to_stream<'b, 'c: 'b + 'a>(
55        &self,
56        target: &DevicePtr<'b>,
57        stream: &mut Stream<'c>,
58    ) -> CudaResult<()>
59    where
60        'a: 'b,
61    {
62        if self.len > target.len {
63            panic!("overflow in DevicePtr::copy_to");
64        } else if self.len < target.len {
65            panic!("underflow in DevicePtr::copy_to");
66        }
67
68        if std::ptr::eq(self.handle.context, target.handle.context) {
69            cuda_error(unsafe {
70                sys::cuMemcpyAsync(
71                    target.inner,
72                    self.inner,
73                    self.len as sys::size_t,
74                    stream.inner,
75                )
76            })
77        } else {
78            cuda_error(unsafe {
79                sys::cuMemcpyPeerAsync(
80                    target.inner,
81                    target.handle.context.inner,
82                    self.inner,
83                    self.handle.context.inner,
84                    self.len as sys::size_t,
85                    stream.inner,
86                )
87            })
88        }
89    }
90
91    // pub fn copy_to_async<'b>(&self, target: &DevicePtr<'b>) -> CudaResult<CudaFuture<'a, ()>>
92    // where
93    //     'a: 'b,
94    // {
95    //     let mut stream = self.handle.get_async_stream()?;
96    //     unsafe { self.copy_to_stream(target, &mut stream) }?;
97    //     Ok(CudaFuture::new(self.handle.clone(), stream))
98    // }
99
100    /// Synchronously copies data from `source` to `self`. Panics if length is not equal.
101    pub fn copy_from<'b>(&self, source: &DevicePtr<'b>) -> CudaResult<()> {
102        source.copy_to(self)
103    }
104
105    /// Asynchronously copies data from `source` to `self`. Panics if length is not equal.
106    pub fn copy_from_stream<'b: 'a, 'c: 'a + 'b>(
107        &self,
108        source: &DevicePtr<'b>,
109        stream: &mut Stream<'c>,
110    ) -> CudaResult<()> {
111        source.copy_to_stream(self, stream)
112    }
113
114    /// Gets a subslice of this slice from `[from:to]`
115    pub fn subslice(&self, from: u64, to: u64) -> Self {
116        if from > self.len || from > to || to > self.len {
117            panic!("overflow in DevicePtr::subslice");
118        }
119        Self {
120            handle: self.handle.clone(),
121            inner: self.inner + from,
122            len: to - from,
123        }
124    }
125
126    /// Gets the length of this slice
127    pub fn len(&self) -> u64 {
128        self.len
129    }
130
131    /// Check if the slice's length is 0
132    pub fn is_empty(&self) -> bool {
133        self.len == 0
134    }
135
136    /// Synchronously loads the data from this slice into a local buffer
137    pub fn load(&self) -> CudaResult<Vec<u8>> {
138        let mut buf = Vec::with_capacity(self.len as usize);
139        cuda_error(unsafe {
140            sys::cuMemcpyDtoH_v2(
141                buf.as_mut_ptr() as *mut _,
142                self.inner,
143                self.len as sys::size_t,
144            )
145        })?;
146        unsafe { buf.set_len(self.len as usize) };
147        Ok(buf)
148    }
149
150    /// Asynchronously loads the data from this slice into a local buffer.
151    /// The contents of the buffer are undefined until `stream.sync` is called.
152    /// The output must not be dropped until the stream is synced.
153    pub unsafe fn load_stream(&self, stream: &mut Stream<'a>) -> CudaResult<Vec<u8>> {
154        let mut buf = Vec::with_capacity(self.len as usize);
155        cuda_error(sys::cuMemcpyDtoHAsync_v2(
156            buf.as_mut_ptr() as *mut _,
157            self.inner,
158            self.len as sys::size_t,
159            stream.inner,
160        ))?;
161        buf.set_len(self.len as usize);
162        Ok(buf)
163    }
164
165    /// Synchronously stores host data from `data` to `self`.
166    pub fn store(&self, data: &[u8]) -> CudaResult<()> {
167        if data.len() > self.len as usize {
168            panic!("overflow in DevicePtr::store");
169        } else if data.len() < self.len as usize {
170            panic!("underflow in DevicePtr::store");
171        }
172        cuda_error(unsafe {
173            sys::cuMemcpyHtoD_v2(
174                self.inner,
175                data.as_ptr() as *const _,
176                self.len as sys::size_t,
177            )
178        })?;
179        Ok(())
180    }
181
182    /// Asynchronously stores host data from `data` to `self`.
183    /// The `data` must not be dropped or mutated until `stream.sync` is called.
184    pub fn store_stream<'b>(&self, data: &'b [u8], stream: &'b mut Stream<'a>) -> CudaResult<()> {
185        if data.len() > self.len as usize {
186            panic!("overflow in DevicePtr::store");
187        } else if data.len() < self.len as usize {
188            panic!("underflow in DevicePtr::store");
189        }
190        cuda_error(unsafe {
191            sys::cuMemcpyHtoDAsync_v2(
192                self.inner,
193                data.as_ptr() as *const _,
194                self.len as sys::size_t,
195                stream.inner,
196            )
197        })?;
198        Ok(())
199    }
200
201    /// Asynchronously stores host data from `data` to `self`.
202    /// `data` will be dropped once the [`Stream`] is synced or dropped.
203    pub fn store_stream_buf(&self, data: Vec<u8>, stream: &mut Stream<'a>) -> CudaResult<()> {
204        if data.len() > self.len as usize {
205            panic!("overflow in DevicePtr::store");
206        } else if data.len() < self.len as usize {
207            panic!("underflow in DevicePtr::store");
208        }
209        let data: Pin<Box<[u8]>> = data.into_boxed_slice().into();
210        stream.pending_stores.push(data);
211        cuda_error(unsafe {
212            sys::cuMemcpyHtoDAsync_v2(
213                self.inner,
214                stream.pending_stores.last().unwrap().as_ptr() as *const _,
215                self.len as sys::size_t,
216                stream.inner,
217            )
218        })?;
219        Ok(())
220    }
221
222    /// Synchronously set the contents of `self` to `data` repeated to fill length
223    pub fn memset_d8(&self, data: u8) -> CudaResult<()> {
224        cuda_error(unsafe { sys::cuMemsetD8_v2(self.inner, data, self.len as sys::size_t) })
225    }
226
227    /// Asynchronously set the contents of `self` to `data` repeated to fill length
228    pub fn memset_d8_stream(&self, data: u8, stream: &mut Stream<'a>) -> CudaResult<()> {
229        cuda_error(unsafe {
230            sys::cuMemsetD8Async(self.inner, data, self.len as sys::size_t, stream.inner)
231        })
232    }
233
234    /// Synchronously set the contents of `self` to `data` repeated to fill length.
235    /// Panics if [`Self::len`] is not a multiple of 2.
236    pub fn memset_d16(&self, data: u16) -> CudaResult<()> {
237        if self.len % 2 != 0 {
238            panic!("alignment failure in DevicePtr::memset_d16");
239        }
240        cuda_error(unsafe { sys::cuMemsetD16_v2(self.inner, data, self.len as sys::size_t / 2) })
241    }
242
243    /// Asynchronously set the contents of `self` to `data` repeated to fill length.
244    /// Panics if [`Self::len`] is not a multiple of 2.
245    pub fn memset_d16_stream(&self, data: u16, stream: &mut Stream<'a>) -> CudaResult<()> {
246        if self.len % 2 != 0 {
247            panic!("alignment failure in DevicePtr::memset_d16_stream");
248        }
249        cuda_error(unsafe {
250            sys::cuMemsetD16Async(self.inner, data, self.len as sys::size_t / 2, stream.inner)
251        })
252    }
253
254    /// Synchronously set the contents of `self` to `data` repeated to fill length.
255    /// Panics if [`Self::len`] is not a multiple of 4.
256    pub fn memset_d32(&self, data: u32) -> CudaResult<()> {
257        if self.len % 4 != 0 {
258            panic!("alignment failure in DevicePtr::memset_d32");
259        }
260        cuda_error(unsafe { sys::cuMemsetD32_v2(self.inner, data, self.len as sys::size_t / 4) })
261    }
262
263    /// Asynchronously set the contents of `self` to `data` repeated to fill length.
264    /// Panics if [`Self::len`] is not a multiple of 4.
265    pub fn memset_d32_stream(&self, data: u32, stream: &mut Stream<'a>) -> CudaResult<()> {
266        if self.len % 4 != 0 {
267            panic!("alignment failure in DevicePtr::memset_d32_stream");
268        }
269        cuda_error(unsafe {
270            sys::cuMemsetD32Async(self.inner, data, self.len as sys::size_t / 4, stream.inner)
271        })
272    }
273
274    /// Gets a reference to the owning handle
275    pub fn handle(&self) -> &Rc<Handle<'a>> {
276        &self.handle
277    }
278}
279
280/// An owned device-allocated buffer
281pub struct DeviceBox<'a> {
282    pub(crate) inner: DevicePtr<'a>,
283}
284
285impl<'a> DeviceBox<'a> {
286    /// Allocate an uninitialized buffer of size `size` on the device
287    pub fn alloc(handle: &Rc<Handle<'a>>, size: u64) -> CudaResult<Self> {
288        let mut out = 0u64;
289        cuda_error(unsafe { sys::cuMemAlloc_v2(&mut out as *mut u64, size as sys::size_t) })?;
290        Ok(DeviceBox {
291            inner: DevicePtr {
292                handle: handle.clone(),
293                inner: out,
294                len: size,
295            },
296        })
297    }
298
299    /// Allocate a new initialized buffer on the device matching the size and content of `input`.
300    pub fn new(handle: &Rc<Handle<'a>>, input: &[u8]) -> CudaResult<Self> {
301        let buf = Self::alloc(handle, input.len() as u64)?;
302        buf.store(input)?;
303        Ok(buf)
304    }
305
306    /// Allocates a new uninitialized buffer on the device, then asynchronously fills it with `input`.
307    /// `input` must not be dropped or mutated until `stream.sync` is called.
308    /// Does not allocate the memory asynchronously.
309    pub fn new_stream<'b>(
310        handle: &Rc<Handle<'a>>,
311        input: &'b [u8],
312        stream: &'b mut Stream<'a>,
313    ) -> CudaResult<Self> {
314        let buf = Self::alloc(handle, input.len() as u64)?;
315        buf.store_stream(input, stream)?;
316        Ok(buf)
317    }
318
319    /// Allocates a new uninitialized buffer on the device, then synchronously fills it with `input`.
320    /// `input` will be dropped when the stream is synced or dropped.
321    /// Does not allocate the memory asynchronously.
322    pub fn new_stream_buf(
323        handle: &Rc<Handle<'a>>,
324        input: Vec<u8>,
325        stream: &mut Stream<'a>,
326    ) -> CudaResult<Self> {
327        let buf = Self::alloc(handle, input.len() as u64)?;
328        buf.store_stream_buf(input, stream)?;
329        Ok(buf)
330    }
331
332    /// Allocates a new initialized buffer on the device matching the size and content of `input`.
333    /// Note that memory is directly copied, so [`T`] must be [`Sized`] should not contain any pointers, references, unsized types, or other non-FFI safe types.
334    pub fn new_ffi<T>(handle: &Rc<Handle<'a>>, input: &[T]) -> CudaResult<Self> {
335        let raw = unsafe {
336            std::slice::from_raw_parts(
337                input.as_ptr() as *const u8,
338                input.len() * std::mem::size_of::<T>(),
339            )
340        };
341        let buf = Self::alloc(handle, raw.len() as u64)?;
342        buf.store(raw)?;
343        Ok(buf)
344    }
345
346    /// Allocates a new uninitialized buffer on the device, then synchronously fills it with `input`.
347    /// Note that memory is directly copied, so [`T`] must be [`Sized`] *should* not contain any pointers, references, unsized types, or other non-FFI safe types.
348    /// `input` must not be dropped or mutated until `stream.sync` is called.
349    /// Does not allocate the memory asynchronously.
350    pub fn new_ffi_stream<'b, T>(
351        handle: &Rc<Handle<'a>>,
352        input: &'b [T],
353        stream: &'b mut Stream<'a>,
354    ) -> CudaResult<Self> {
355        let raw = unsafe {
356            std::slice::from_raw_parts(
357                input.as_ptr() as *const u8,
358                input.len() * std::mem::size_of::<T>(),
359            )
360        };
361        let buf = Self::alloc(handle, raw.len() as u64)?;
362        buf.store_stream(raw, stream)?;
363        Ok(buf)
364    }
365
366    /// Allocates a new uninitialized buffer on the device, then synchronously fills it with `input`.
367    /// Note that memory is directly copied, so [`T`] must be [`Sized`] *should* not contain any pointers, references, unsized types, or other non-FFI safe types.
368    /// `input` will be dropped when the stream is synced or dropped.
369    /// Does not allocate the memory asynchronously.
370    pub fn new_ffi_stream_buf<'b, T>(
371        handle: &Rc<Handle<'a>>,
372        mut input: Vec<T>,
373        stream: &'b mut Stream<'a>,
374    ) -> CudaResult<Self> {
375        let raw = unsafe {
376            Vec::from_raw_parts(
377                input.as_mut_ptr() as *mut u8,
378                input.len() * std::mem::size_of::<T>(),
379                input.capacity() * std::mem::size_of::<T>(),
380            )
381        };
382        std::mem::forget(input);
383        let buf = Self::alloc(handle, raw.len() as u64)?;
384        buf.store_stream_buf(raw, stream)?;
385        Ok(buf)
386    }
387
388    /// Leaks the DeviceBox, similar to [`Box::leak`].
389    pub fn leak(self) {
390        std::mem::forget(self);
391    }
392
393    /// Constructs a [`DeviceBox`] from a device pointer.
394    pub unsafe fn from_raw(raw: DevicePtr<'a>) -> Self {
395        Self { inner: raw }
396    }
397}
398
399impl<'a> Drop for DeviceBox<'a> {
400    fn drop(&mut self) {
401        if let Err(e) = cuda_error(unsafe { sys::cuMemFree_v2(self.inner.inner) }) {
402            eprintln!("CUDA: failed freeing device buffer: {:?}", e);
403        }
404    }
405}
406
407impl<'a> AsRef<DevicePtr<'a>> for DeviceBox<'a> {
408    fn as_ref(&self) -> &DevicePtr<'a> {
409        &self.inner
410    }
411}
412
413impl<'a> Deref for DeviceBox<'a> {
414    type Target = DevicePtr<'a>;
415
416    fn deref(&self) -> &Self::Target {
417        &self.inner
418    }
419}
420
421impl<'a> DerefMut for DeviceBox<'a> {
422    fn deref_mut(&mut self) -> &mut Self::Target {
423        &mut self.inner
424    }
425}