1use std::{
2 ops::{Deref, DerefMut},
3 pin::Pin,
4 rc::Rc,
5};
6
7use crate::*;
8
/// A non-owning, cloneable view of a region of CUDA device memory.
///
/// Cloning copies the handle/pointer/length; it does not duplicate or free
/// the underlying device allocation (ownership lives in `DeviceBox`).
#[derive(Clone)]
pub struct DevicePtr<'a> {
    // Keeps the owning CUDA context alive for the lifetime 'a.
    pub(crate) handle: Rc<Handle<'a>>,
    // Raw device address, as passed to the `sys::cuMem*` driver calls.
    pub(crate) inner: u64,
    // Length of the region in bytes.
    pub(crate) len: u64,
}
16
17impl<'a> DevicePtr<'a> {
18 pub fn as_raw(&self) -> u64 {
19 self.inner
20 }
21
22 pub unsafe fn from_raw_parts(handle: Rc<Handle<'a>>, ptr: u64, len: u64) -> Self {
23 Self {
24 handle,
25 inner: ptr,
26 len,
27 }
28 }
29
30 pub fn copy_to<'b>(&self, target: &DevicePtr<'b>) -> CudaResult<()> {
32 if self.len > target.len {
33 panic!("overflow in DevicePtr::copy_to");
34 } else if self.len < target.len {
35 panic!("underflow in DevicePtr::copy_to");
36 }
37
38 if std::ptr::eq(self.handle.context, target.handle.context) {
39 cuda_error(unsafe { sys::cuMemcpy(target.inner, self.inner, self.len as sys::size_t) })
40 } else {
41 cuda_error(unsafe {
42 sys::cuMemcpyPeer(
43 target.inner,
44 target.handle.context.inner,
45 self.inner,
46 self.handle.context.inner,
47 self.len as sys::size_t,
48 )
49 })
50 }
51 }
52
53 pub fn copy_to_stream<'b, 'c: 'b + 'a>(
55 &self,
56 target: &DevicePtr<'b>,
57 stream: &mut Stream<'c>,
58 ) -> CudaResult<()>
59 where
60 'a: 'b,
61 {
62 if self.len > target.len {
63 panic!("overflow in DevicePtr::copy_to");
64 } else if self.len < target.len {
65 panic!("underflow in DevicePtr::copy_to");
66 }
67
68 if std::ptr::eq(self.handle.context, target.handle.context) {
69 cuda_error(unsafe {
70 sys::cuMemcpyAsync(
71 target.inner,
72 self.inner,
73 self.len as sys::size_t,
74 stream.inner,
75 )
76 })
77 } else {
78 cuda_error(unsafe {
79 sys::cuMemcpyPeerAsync(
80 target.inner,
81 target.handle.context.inner,
82 self.inner,
83 self.handle.context.inner,
84 self.len as sys::size_t,
85 stream.inner,
86 )
87 })
88 }
89 }
90
91 pub fn copy_from<'b>(&self, source: &DevicePtr<'b>) -> CudaResult<()> {
102 source.copy_to(self)
103 }
104
105 pub fn copy_from_stream<'b: 'a, 'c: 'a + 'b>(
107 &self,
108 source: &DevicePtr<'b>,
109 stream: &mut Stream<'c>,
110 ) -> CudaResult<()> {
111 source.copy_to_stream(self, stream)
112 }
113
114 pub fn subslice(&self, from: u64, to: u64) -> Self {
116 if from > self.len || from > to || to > self.len {
117 panic!("overflow in DevicePtr::subslice");
118 }
119 Self {
120 handle: self.handle.clone(),
121 inner: self.inner + from,
122 len: to - from,
123 }
124 }
125
126 pub fn len(&self) -> u64 {
128 self.len
129 }
130
131 pub fn is_empty(&self) -> bool {
133 self.len == 0
134 }
135
136 pub fn load(&self) -> CudaResult<Vec<u8>> {
138 let mut buf = Vec::with_capacity(self.len as usize);
139 cuda_error(unsafe {
140 sys::cuMemcpyDtoH_v2(
141 buf.as_mut_ptr() as *mut _,
142 self.inner,
143 self.len as sys::size_t,
144 )
145 })?;
146 unsafe { buf.set_len(self.len as usize) };
147 Ok(buf)
148 }
149
150 pub unsafe fn load_stream(&self, stream: &mut Stream<'a>) -> CudaResult<Vec<u8>> {
154 let mut buf = Vec::with_capacity(self.len as usize);
155 cuda_error(sys::cuMemcpyDtoHAsync_v2(
156 buf.as_mut_ptr() as *mut _,
157 self.inner,
158 self.len as sys::size_t,
159 stream.inner,
160 ))?;
161 buf.set_len(self.len as usize);
162 Ok(buf)
163 }
164
165 pub fn store(&self, data: &[u8]) -> CudaResult<()> {
167 if data.len() > self.len as usize {
168 panic!("overflow in DevicePtr::store");
169 } else if data.len() < self.len as usize {
170 panic!("underflow in DevicePtr::store");
171 }
172 cuda_error(unsafe {
173 sys::cuMemcpyHtoD_v2(
174 self.inner,
175 data.as_ptr() as *const _,
176 self.len as sys::size_t,
177 )
178 })?;
179 Ok(())
180 }
181
182 pub fn store_stream<'b>(&self, data: &'b [u8], stream: &'b mut Stream<'a>) -> CudaResult<()> {
185 if data.len() > self.len as usize {
186 panic!("overflow in DevicePtr::store");
187 } else if data.len() < self.len as usize {
188 panic!("underflow in DevicePtr::store");
189 }
190 cuda_error(unsafe {
191 sys::cuMemcpyHtoDAsync_v2(
192 self.inner,
193 data.as_ptr() as *const _,
194 self.len as sys::size_t,
195 stream.inner,
196 )
197 })?;
198 Ok(())
199 }
200
201 pub fn store_stream_buf(&self, data: Vec<u8>, stream: &mut Stream<'a>) -> CudaResult<()> {
204 if data.len() > self.len as usize {
205 panic!("overflow in DevicePtr::store");
206 } else if data.len() < self.len as usize {
207 panic!("underflow in DevicePtr::store");
208 }
209 let data: Pin<Box<[u8]>> = data.into_boxed_slice().into();
210 stream.pending_stores.push(data);
211 cuda_error(unsafe {
212 sys::cuMemcpyHtoDAsync_v2(
213 self.inner,
214 stream.pending_stores.last().unwrap().as_ptr() as *const _,
215 self.len as sys::size_t,
216 stream.inner,
217 )
218 })?;
219 Ok(())
220 }
221
222 pub fn memset_d8(&self, data: u8) -> CudaResult<()> {
224 cuda_error(unsafe { sys::cuMemsetD8_v2(self.inner, data, self.len as sys::size_t) })
225 }
226
227 pub fn memset_d8_stream(&self, data: u8, stream: &mut Stream<'a>) -> CudaResult<()> {
229 cuda_error(unsafe {
230 sys::cuMemsetD8Async(self.inner, data, self.len as sys::size_t, stream.inner)
231 })
232 }
233
234 pub fn memset_d16(&self, data: u16) -> CudaResult<()> {
237 if self.len % 2 != 0 {
238 panic!("alignment failure in DevicePtr::memset_d16");
239 }
240 cuda_error(unsafe { sys::cuMemsetD16_v2(self.inner, data, self.len as sys::size_t / 2) })
241 }
242
243 pub fn memset_d16_stream(&self, data: u16, stream: &mut Stream<'a>) -> CudaResult<()> {
246 if self.len % 2 != 0 {
247 panic!("alignment failure in DevicePtr::memset_d16_stream");
248 }
249 cuda_error(unsafe {
250 sys::cuMemsetD16Async(self.inner, data, self.len as sys::size_t / 2, stream.inner)
251 })
252 }
253
254 pub fn memset_d32(&self, data: u32) -> CudaResult<()> {
257 if self.len % 4 != 0 {
258 panic!("alignment failure in DevicePtr::memset_d32");
259 }
260 cuda_error(unsafe { sys::cuMemsetD32_v2(self.inner, data, self.len as sys::size_t / 4) })
261 }
262
263 pub fn memset_d32_stream(&self, data: u32, stream: &mut Stream<'a>) -> CudaResult<()> {
266 if self.len % 4 != 0 {
267 panic!("alignment failure in DevicePtr::memset_d32_stream");
268 }
269 cuda_error(unsafe {
270 sys::cuMemsetD32Async(self.inner, data, self.len as sys::size_t / 4, stream.inner)
271 })
272 }
273
274 pub fn handle(&self) -> &Rc<Handle<'a>> {
276 &self.handle
277 }
278}
279
/// An owned CUDA device allocation. Dereferences to [`DevicePtr`] for all
/// copy/memset operations, and frees the device memory on `Drop`.
pub struct DeviceBox<'a> {
    // The allocation being owned; its `inner` field is the pointer freed on Drop.
    pub(crate) inner: DevicePtr<'a>,
}
284
285impl<'a> DeviceBox<'a> {
286 pub fn alloc(handle: &Rc<Handle<'a>>, size: u64) -> CudaResult<Self> {
288 let mut out = 0u64;
289 cuda_error(unsafe { sys::cuMemAlloc_v2(&mut out as *mut u64, size as sys::size_t) })?;
290 Ok(DeviceBox {
291 inner: DevicePtr {
292 handle: handle.clone(),
293 inner: out,
294 len: size,
295 },
296 })
297 }
298
299 pub fn new(handle: &Rc<Handle<'a>>, input: &[u8]) -> CudaResult<Self> {
301 let buf = Self::alloc(handle, input.len() as u64)?;
302 buf.store(input)?;
303 Ok(buf)
304 }
305
306 pub fn new_stream<'b>(
310 handle: &Rc<Handle<'a>>,
311 input: &'b [u8],
312 stream: &'b mut Stream<'a>,
313 ) -> CudaResult<Self> {
314 let buf = Self::alloc(handle, input.len() as u64)?;
315 buf.store_stream(input, stream)?;
316 Ok(buf)
317 }
318
319 pub fn new_stream_buf(
323 handle: &Rc<Handle<'a>>,
324 input: Vec<u8>,
325 stream: &mut Stream<'a>,
326 ) -> CudaResult<Self> {
327 let buf = Self::alloc(handle, input.len() as u64)?;
328 buf.store_stream_buf(input, stream)?;
329 Ok(buf)
330 }
331
332 pub fn new_ffi<T>(handle: &Rc<Handle<'a>>, input: &[T]) -> CudaResult<Self> {
335 let raw = unsafe {
336 std::slice::from_raw_parts(
337 input.as_ptr() as *const u8,
338 input.len() * std::mem::size_of::<T>(),
339 )
340 };
341 let buf = Self::alloc(handle, raw.len() as u64)?;
342 buf.store(raw)?;
343 Ok(buf)
344 }
345
346 pub fn new_ffi_stream<'b, T>(
351 handle: &Rc<Handle<'a>>,
352 input: &'b [T],
353 stream: &'b mut Stream<'a>,
354 ) -> CudaResult<Self> {
355 let raw = unsafe {
356 std::slice::from_raw_parts(
357 input.as_ptr() as *const u8,
358 input.len() * std::mem::size_of::<T>(),
359 )
360 };
361 let buf = Self::alloc(handle, raw.len() as u64)?;
362 buf.store_stream(raw, stream)?;
363 Ok(buf)
364 }
365
366 pub fn new_ffi_stream_buf<'b, T>(
371 handle: &Rc<Handle<'a>>,
372 mut input: Vec<T>,
373 stream: &'b mut Stream<'a>,
374 ) -> CudaResult<Self> {
375 let raw = unsafe {
376 Vec::from_raw_parts(
377 input.as_mut_ptr() as *mut u8,
378 input.len() * std::mem::size_of::<T>(),
379 input.capacity() * std::mem::size_of::<T>(),
380 )
381 };
382 std::mem::forget(input);
383 let buf = Self::alloc(handle, raw.len() as u64)?;
384 buf.store_stream_buf(raw, stream)?;
385 Ok(buf)
386 }
387
388 pub fn leak(self) {
390 std::mem::forget(self);
391 }
392
393 pub unsafe fn from_raw(raw: DevicePtr<'a>) -> Self {
395 Self { inner: raw }
396 }
397}
398
399impl<'a> Drop for DeviceBox<'a> {
400 fn drop(&mut self) {
401 if let Err(e) = cuda_error(unsafe { sys::cuMemFree_v2(self.inner.inner) }) {
402 eprintln!("CUDA: failed freeing device buffer: {:?}", e);
403 }
404 }
405}
406
impl<'a> AsRef<DevicePtr<'a>> for DeviceBox<'a> {
    /// Borrows the owned allocation as a non-owning `DevicePtr`.
    fn as_ref(&self) -> &DevicePtr<'a> {
        &self.inner
    }
}
412
// Lets all `DevicePtr` methods (copy/load/store/memset/...) be called
// directly on a `DeviceBox`.
impl<'a> Deref for DeviceBox<'a> {
    type Target = DevicePtr<'a>;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
420
// Mutable counterpart of the `Deref` impl above.
impl<'a> DerefMut for DeviceBox<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}