1use core::ffi::c_void;
8use core::marker::PhantomData;
9use core::mem::size_of;
10use core::ops::Range;
11
12use baracuda_cuda_sys::{driver, CUdeviceptr};
13use baracuda_types::{DeviceRepr, KernelArg};
14
15use crate::context::Context;
16use crate::error::{check, Result};
17use crate::stream::Stream;
18
19pub struct DeviceBuffer<T: DeviceRepr> {
26 ptr: CUdeviceptr,
27 len: usize,
28 context: Context,
29 _marker: PhantomData<T>,
30}
31
32unsafe impl<T: DeviceRepr + Send> Send for DeviceBuffer<T> {}
35
36impl<T: DeviceRepr> core::fmt::Debug for DeviceBuffer<T> {
37 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
38 f.debug_struct("DeviceBuffer")
39 .field("ptr", &format_args!("{:#x}", self.ptr.0))
40 .field("len", &self.len)
41 .field("type", &core::any::type_name::<T>())
42 .finish()
43 }
44}
45
46impl<T: DeviceRepr> DeviceBuffer<T> {
47 pub fn new(context: &Context, len: usize) -> Result<Self> {
54 let bytes = len
55 .checked_mul(size_of::<T>())
56 .expect("overflow computing allocation size");
57 if bytes == 0 {
58 return Ok(Self {
59 ptr: CUdeviceptr(0),
60 len,
61 context: context.clone(),
62 _marker: PhantomData,
63 });
64 }
65 context.set_current()?;
66 let d = driver()?;
67 let cu = d.cu_mem_alloc()?;
68 let mut ptr = CUdeviceptr(0);
69 check(unsafe { cu(&mut ptr, bytes) })?;
71 Ok(Self {
72 ptr,
73 len,
74 context: context.clone(),
75 _marker: PhantomData,
76 })
77 }
78
79 pub fn new_async(context: &Context, len: usize, stream: &Stream) -> Result<Self> {
87 let bytes = len
88 .checked_mul(size_of::<T>())
89 .expect("overflow computing allocation size");
90 if bytes == 0 {
91 return Ok(Self {
92 ptr: CUdeviceptr(0),
93 len,
94 context: context.clone(),
95 _marker: PhantomData,
96 });
97 }
98 context.set_current()?;
99 let d = driver()?;
100 let cu = d.cu_mem_alloc_async()?;
101 let mut ptr = CUdeviceptr(0);
102 check(unsafe { cu(&mut ptr, bytes, stream.as_raw()) })?;
104 Ok(Self {
105 ptr,
106 len,
107 context: context.clone(),
108 _marker: PhantomData,
109 })
110 }
111
112 pub fn free_async(mut self, stream: &Stream) -> Result<()> {
118 let ptr = core::mem::replace(&mut self.ptr, CUdeviceptr(0));
119 if ptr.0 == 0 {
120 return Ok(());
121 }
122 let d = driver()?;
123 let cu = d.cu_mem_free_async()?;
124 check(unsafe { cu(ptr, stream.as_raw()) })
125 }
126
127 pub fn zeros(context: &Context, len: usize) -> Result<Self> {
130 let buf = Self::new(context, len)?;
131 let bytes = len * size_of::<T>();
132 if bytes == 0 {
133 return Ok(buf);
134 }
135 let d = driver()?;
136 let cu = d.cu_memset_d8()?;
137 check(unsafe { cu(buf.ptr, 0, bytes) })?;
138 Ok(buf)
139 }
140
141 pub fn zero(&self) -> Result<()> {
146 let bytes = self.len * size_of::<T>();
147 if bytes == 0 {
148 return Ok(());
149 }
150 let d = driver()?;
151 let cu = d.cu_memset_d8()?;
152 check(unsafe { cu(self.ptr, 0, bytes) })
153 }
154
155 pub fn zero_async(&self, stream: &Stream) -> Result<()> {
159 let bytes = self.len * size_of::<T>();
160 if bytes == 0 {
161 return Ok(());
162 }
163 let d = driver()?;
164 let cu = d.cu_memset_d8_async()?;
165 check(unsafe { cu(self.ptr, 0, bytes, stream.as_raw()) })
166 }
167
168 pub fn from_slice(context: &Context, src: &[T]) -> Result<Self> {
171 let buf = Self::new(context, src.len())?;
172 buf.copy_from_host(src)?;
173 Ok(buf)
174 }
175
176 pub fn copy_from_host(&self, src: &[T]) -> Result<()> {
179 assert_eq!(
180 src.len(),
181 self.len,
182 "copy_from_host: source length {} != buffer length {}",
183 src.len(),
184 self.len
185 );
186 let bytes = self.len * size_of::<T>();
187 if bytes == 0 {
188 return Ok(());
189 }
190 let d = driver()?;
191 let cu = d.cu_memcpy_htod()?;
192 check(unsafe { cu(self.ptr, src.as_ptr() as *const c_void, bytes) })
195 }
196
197 pub fn copy_to_host(&self, dst: &mut [T]) -> Result<()> {
200 assert_eq!(
201 dst.len(),
202 self.len,
203 "copy_to_host: destination length {} != buffer length {}",
204 dst.len(),
205 self.len
206 );
207 let bytes = self.len * size_of::<T>();
208 if bytes == 0 {
209 return Ok(());
210 }
211 let d = driver()?;
212 let cu = d.cu_memcpy_dtoh()?;
213 check(unsafe { cu(dst.as_mut_ptr() as *mut c_void, self.ptr, bytes) })
215 }
216
217 pub fn copy_from_host_async(&self, src: &[T], stream: &Stream) -> Result<()> {
219 assert_eq!(src.len(), self.len);
220 let bytes = self.len * size_of::<T>();
221 if bytes == 0 {
222 return Ok(());
223 }
224 let d = driver()?;
225 let cu = d.cu_memcpy_htod_async()?;
226 check(unsafe {
227 cu(
228 self.ptr,
229 src.as_ptr() as *const c_void,
230 bytes,
231 stream.as_raw(),
232 )
233 })
234 }
235
236 pub fn copy_to_host_async(&self, dst: &mut [T], stream: &Stream) -> Result<()> {
238 assert_eq!(dst.len(), self.len);
239 let bytes = self.len * size_of::<T>();
240 if bytes == 0 {
241 return Ok(());
242 }
243 let d = driver()?;
244 let cu = d.cu_memcpy_dtoh_async()?;
245 check(unsafe {
246 cu(
247 dst.as_mut_ptr() as *mut c_void,
248 self.ptr,
249 bytes,
250 stream.as_raw(),
251 )
252 })
253 }
254
255 pub fn copy_to_device(&self, dst: &DeviceBuffer<T>) -> Result<()> {
258 assert_eq!(dst.len, self.len);
259 let bytes = self.len * size_of::<T>();
260 if bytes == 0 {
261 return Ok(());
262 }
263 let d = driver()?;
264 let cu = d.cu_memcpy_dtod()?;
265 check(unsafe { cu(dst.ptr, self.ptr, bytes) })
266 }
267
268 pub fn copy_to_device_async(&self, dst: &DeviceBuffer<T>, stream: &Stream) -> Result<()> {
270 assert_eq!(dst.len, self.len);
271 let bytes = self.len * size_of::<T>();
272 if bytes == 0 {
273 return Ok(());
274 }
275 let d = driver()?;
276 let cu = d.cu_memcpy_dtod_async()?;
277 check(unsafe { cu(dst.ptr, self.ptr, bytes, stream.as_raw()) })
278 }
279
280 #[inline]
282 pub fn len(&self) -> usize {
283 self.len
284 }
285
286 #[inline]
288 pub fn byte_size(&self) -> usize {
289 self.len * size_of::<T>()
290 }
291
292 #[inline]
294 pub fn is_empty(&self) -> bool {
295 self.len == 0
296 }
297
298 #[inline]
300 pub fn context(&self) -> &Context {
301 &self.context
302 }
303
304 #[inline]
306 pub fn as_raw(&self) -> CUdeviceptr {
307 self.ptr
308 }
309
310 #[inline]
312 pub fn as_slice(&self) -> DeviceSlice<'_, T> {
313 DeviceSlice {
314 ptr: self.ptr,
315 len: self.len,
316 _marker: PhantomData,
317 }
318 }
319
320 #[inline]
322 pub fn as_slice_mut(&mut self) -> DeviceSliceMut<'_, T> {
323 DeviceSliceMut {
324 ptr: self.ptr,
325 len: self.len,
326 _marker: PhantomData,
327 }
328 }
329
330 #[inline]
345 pub fn slice(&self, range: Range<usize>) -> DeviceSlice<'_, T> {
346 assert!(
347 range.start <= range.end && range.end <= self.len,
348 "DeviceBuffer::slice({}..{}) out of bounds for len {}",
349 range.start,
350 range.end,
351 self.len,
352 );
353 DeviceSlice {
354 ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
355 len: range.end - range.start,
356 _marker: PhantomData,
357 }
358 }
359
360 #[inline]
362 pub fn slice_mut(&mut self, range: Range<usize>) -> DeviceSliceMut<'_, T> {
363 assert!(
364 range.start <= range.end && range.end <= self.len,
365 "DeviceBuffer::slice_mut({}..{}) out of bounds for len {}",
366 range.start,
367 range.end,
368 self.len,
369 );
370 DeviceSliceMut {
371 ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
372 len: range.end - range.start,
373 _marker: PhantomData,
374 }
375 }
376}
377
378impl DeviceBuffer<u8> {
379 #[inline]
396 pub fn view_as<U: DeviceRepr>(&self) -> DeviceSlice<'_, U> {
397 let elem = size_of::<U>();
398 if elem == 0 {
399 return DeviceSlice {
400 ptr: self.ptr,
401 len: 0,
402 _marker: PhantomData,
403 };
404 }
405 assert!(
406 self.len % elem == 0,
407 "DeviceBuffer<u8>::view_as: byte length {} not divisible by size_of::<{}>() = {}",
408 self.len,
409 core::any::type_name::<U>(),
410 elem,
411 );
412 DeviceSlice {
413 ptr: self.ptr,
414 len: self.len / elem,
415 _marker: PhantomData,
416 }
417 }
418
419 #[inline]
421 pub fn view_as_mut<U: DeviceRepr>(&mut self) -> DeviceSliceMut<'_, U> {
422 let elem = size_of::<U>();
423 if elem == 0 {
424 return DeviceSliceMut {
425 ptr: self.ptr,
426 len: 0,
427 _marker: PhantomData,
428 };
429 }
430 assert!(
431 self.len % elem == 0,
432 "DeviceBuffer<u8>::view_as_mut: byte length {} not divisible by size_of::<{}>() = {}",
433 self.len,
434 core::any::type_name::<U>(),
435 elem,
436 );
437 DeviceSliceMut {
438 ptr: self.ptr,
439 len: self.len / elem,
440 _marker: PhantomData,
441 }
442 }
443}
444
/// Attach mode passed to the driver when allocating managed memory.
///
/// Each variant is translated to the matching `CUmemAttach_flags` constant.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ManagedAttach {
    /// Translated to `CUmemAttach_flags::GLOBAL`; this is the default.
    #[default]
    Global,
    /// Translated to `CUmemAttach_flags::HOST`.
    Host,
    /// Translated to `CUmemAttach_flags::SINGLE`.
    Single,
}
458
459impl ManagedAttach {
460 #[inline]
461 fn raw(self) -> u32 {
462 use baracuda_cuda_sys::types::CUmemAttach_flags as F;
463 match self {
464 ManagedAttach::Global => F::GLOBAL,
465 ManagedAttach::Host => F::HOST,
466 ManagedAttach::Single => F::SINGLE,
467 }
468 }
469}
470
/// Memory-usage hint for managed allocations.
///
/// Each variant is translated to the `CUmem_advise` constant of the same name.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MemAdvise {
    /// Translated to `CUmem_advise::SET_READ_MOSTLY`.
    SetReadMostly,
    /// Translated to `CUmem_advise::UNSET_READ_MOSTLY`.
    UnsetReadMostly,
    /// Translated to `CUmem_advise::SET_PREFERRED_LOCATION`.
    SetPreferredLocation,
    /// Translated to `CUmem_advise::UNSET_PREFERRED_LOCATION`.
    UnsetPreferredLocation,
    /// Translated to `CUmem_advise::SET_ACCESSED_BY`.
    SetAccessedBy,
    /// Translated to `CUmem_advise::UNSET_ACCESSED_BY`.
    UnsetAccessedBy,
}
481
482impl MemAdvise {
483 #[inline]
484 fn raw(self) -> i32 {
485 use baracuda_cuda_sys::types::CUmem_advise as A;
486 match self {
487 MemAdvise::SetReadMostly => A::SET_READ_MOSTLY,
488 MemAdvise::UnsetReadMostly => A::UNSET_READ_MOSTLY,
489 MemAdvise::SetPreferredLocation => A::SET_PREFERRED_LOCATION,
490 MemAdvise::UnsetPreferredLocation => A::UNSET_PREFERRED_LOCATION,
491 MemAdvise::SetAccessedBy => A::SET_ACCESSED_BY,
492 MemAdvise::UnsetAccessedBy => A::UNSET_ACCESSED_BY,
493 }
494 }
495}
496
497pub struct ManagedBuffer<T: DeviceRepr> {
507 ptr: CUdeviceptr,
508 len: usize,
509 context: Context,
510 _marker: PhantomData<T>,
511}
512
513unsafe impl<T: DeviceRepr + Send> Send for ManagedBuffer<T> {}
514
515impl<T: DeviceRepr> core::fmt::Debug for ManagedBuffer<T> {
516 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
517 f.debug_struct("ManagedBuffer")
518 .field("ptr", &format_args!("{:#x}", self.ptr.0))
519 .field("len", &self.len)
520 .field("type", &core::any::type_name::<T>())
521 .finish()
522 }
523}
524
525impl<T: DeviceRepr> ManagedBuffer<T> {
526 pub fn new(context: &Context, len: usize) -> Result<Self> {
529 Self::new_with_flags(context, len, ManagedAttach::Global)
530 }
531
532 pub fn new_with_flags(context: &Context, len: usize, attach: ManagedAttach) -> Result<Self> {
534 context.set_current()?;
535 let d = driver()?;
536 let cu = d.cu_mem_alloc_managed()?;
537 let bytes = len
538 .checked_mul(size_of::<T>())
539 .expect("overflow computing allocation size");
540 let mut ptr = CUdeviceptr(0);
541 check(unsafe { cu(&mut ptr, bytes, attach.raw()) })?;
543 Ok(Self {
544 ptr,
545 len,
546 context: context.clone(),
547 _marker: PhantomData,
548 })
549 }
550
551 pub fn advise(&self, advice: MemAdvise, device: &crate::Device) -> Result<()> {
556 let d = driver()?;
557 let cu = d.cu_mem_advise()?;
558 let bytes = self.len * size_of::<T>();
559 check(unsafe { cu(self.ptr, bytes, advice.raw(), device.as_raw()) })
560 }
561
562 pub fn prefetch_async(&self, device: &crate::Device, stream: &Stream) -> Result<()> {
564 let d = driver()?;
565 let cu = d.cu_mem_prefetch_async()?;
566 let bytes = self.len * size_of::<T>();
567 check(unsafe { cu(self.ptr, bytes, device.as_raw(), stream.as_raw()) })
568 }
569
570 pub unsafe fn as_host_slice(&self) -> &[T] { unsafe {
580 core::slice::from_raw_parts(self.ptr.0 as *const T, self.len)
581 }}
582
583 pub unsafe fn as_host_slice_mut(&mut self) -> &mut [T] { unsafe {
589 core::slice::from_raw_parts_mut(self.ptr.0 as *mut T, self.len)
590 }}
591
592 #[inline]
594 pub fn len(&self) -> usize {
595 self.len
596 }
597
598 #[inline]
600 pub fn is_empty(&self) -> bool {
601 self.len == 0
602 }
603
604 #[inline]
606 pub fn as_raw(&self) -> CUdeviceptr {
607 self.ptr
608 }
609
610 #[inline]
612 pub fn context(&self) -> &Context {
613 &self.context
614 }
615}
616
617impl<T: DeviceRepr> Drop for ManagedBuffer<T> {
618 fn drop(&mut self) {
619 if self.ptr.0 == 0 {
620 return;
621 }
622 if let Ok(d) = driver() {
623 if let Ok(cu) = d.cu_mem_free() {
624 let _ = unsafe { cu(self.ptr) };
625 }
626 }
627 }
628}
629
630pub fn mem_get_info() -> Result<(u64, u64)> {
634 let d = driver()?;
635 let cu = d.cu_mem_get_info()?;
636 let mut free: usize = 0;
637 let mut total: usize = 0;
638 check(unsafe { cu(&mut free, &mut total) })?;
639 Ok((free as u64, total as u64))
640}
641
642pub fn memcpy_peer<T: DeviceRepr>(
646 dst: &DeviceBuffer<T>,
647 dst_ctx: &Context,
648 src: &DeviceBuffer<T>,
649 src_ctx: &Context,
650) -> Result<()> {
651 assert_eq!(dst.len(), src.len());
652 let d = driver()?;
653 let cu = d.cu_memcpy_peer()?;
654 let bytes = src.len() * size_of::<T>();
655 check(unsafe {
656 cu(
657 dst.as_raw(),
658 dst_ctx.as_raw(),
659 src.as_raw(),
660 src_ctx.as_raw(),
661 bytes,
662 )
663 })
664}
665
666pub fn memcpy_peer_async<T: DeviceRepr>(
668 dst: &DeviceBuffer<T>,
669 dst_ctx: &Context,
670 src: &DeviceBuffer<T>,
671 src_ctx: &Context,
672 stream: &Stream,
673) -> Result<()> {
674 assert_eq!(dst.len(), src.len());
675 let d = driver()?;
676 let cu = d.cu_memcpy_peer_async()?;
677 let bytes = src.len() * size_of::<T>();
678 check(unsafe {
679 cu(
680 dst.as_raw(),
681 dst_ctx.as_raw(),
682 src.as_raw(),
683 src_ctx.as_raw(),
684 bytes,
685 stream.as_raw(),
686 )
687 })
688}
689
690pub fn memset_u16(dst: CUdeviceptr, value: u16, count: usize) -> Result<()> {
692 let d = driver()?;
693 let cu = d.cu_memset_d16()?;
694 check(unsafe { cu(dst, value, count) })
695}
696
697pub fn memset_u16_async(dst: CUdeviceptr, value: u16, count: usize, stream: &Stream) -> Result<()> {
699 let d = driver()?;
700 let cu = d.cu_memset_d16_async()?;
701 check(unsafe { cu(dst, value, count, stream.as_raw()) })
702}
703
704pub fn memset_u8_async(dst: CUdeviceptr, value: u8, count: usize, stream: &Stream) -> Result<()> {
706 let d = driver()?;
707 let cu = d.cu_memset_d8_async()?;
708 check(unsafe { cu(dst, value, count, stream.as_raw()) })
709}
710
711pub fn memset_u32_async(dst: CUdeviceptr, value: u32, count: usize, stream: &Stream) -> Result<()> {
713 let d = driver()?;
714 let cu = d.cu_memset_d32_async()?;
715 check(unsafe { cu(dst, value, count, stream.as_raw()) })
716}
717
718pub fn memset_u32(dst: CUdeviceptr, value: u32, count: usize) -> Result<()> {
720 let d = driver()?;
721 let cu = d.cu_memset_d32()?;
722 check(unsafe { cu(dst, value, count) })
723}
724
725pub fn memset_2d_u8(
728 dst: CUdeviceptr,
729 pitch: usize,
730 value: u8,
731 width: usize,
732 height: usize,
733) -> Result<()> {
734 let d = driver()?;
735 let cu = d.cu_memset_d2d8()?;
736 check(unsafe { cu(dst, pitch, value, width, height) })
737}
738
739pub fn memset_2d_u16(
741 dst: CUdeviceptr,
742 pitch: usize,
743 value: u16,
744 width: usize,
745 height: usize,
746) -> Result<()> {
747 let d = driver()?;
748 let cu = d.cu_memset_d2d16()?;
749 check(unsafe { cu(dst, pitch, value, width, height) })
750}
751
752pub fn memset_2d_u32(
754 dst: CUdeviceptr,
755 pitch: usize,
756 value: u32,
757 width: usize,
758 height: usize,
759) -> Result<()> {
760 let d = driver()?;
761 let cu = d.cu_memset_d2d32()?;
762 check(unsafe { cu(dst, pitch, value, width, height) })
763}
764
765pub unsafe fn memcpy(dst: CUdeviceptr, src: CUdeviceptr, bytes: usize) -> Result<()> { unsafe {
775 let d = driver()?;
776 let cu = d.cu_memcpy()?;
777 check(cu(dst, src, bytes))
778}}
779
780pub unsafe fn memcpy_async(
787 dst: CUdeviceptr,
788 src: CUdeviceptr,
789 bytes: usize,
790 stream: &Stream,
791) -> Result<()> { unsafe {
792 let d = driver()?;
793 let cu = d.cu_memcpy_async()?;
794 check(cu(dst, src, bytes, stream.as_raw()))
795}}
796
/// Destination for the `_v2` prefetch / advise calls.
///
/// Converted to a `CUmemLocation` (type tag plus id) before reaching the driver.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PrefetchTarget {
    /// A device; the payload is passed as the location id.
    Device(i32),
    /// The host; the location id is set to `0`.
    Host,
    /// A host NUMA node; the payload is passed as the location id.
    HostNuma(i32),
    /// Maps to the `HOST_NUMA_CURRENT` location type; the id is set to `0`.
    HostNumaCurrent,
}
812
813impl PrefetchTarget {
814 fn as_location(self) -> baracuda_cuda_sys::types::CUmemLocation {
815 use baracuda_cuda_sys::types::CUmemLocationType;
816 let (type_, id) = match self {
817 PrefetchTarget::Device(i) => (CUmemLocationType::DEVICE, i),
818 PrefetchTarget::Host => (CUmemLocationType::HOST, 0),
819 PrefetchTarget::HostNuma(n) => (CUmemLocationType::HOST_NUMA, n),
820 PrefetchTarget::HostNumaCurrent => (CUmemLocationType::HOST_NUMA_CURRENT, 0),
821 };
822 baracuda_cuda_sys::types::CUmemLocation { type_, id }
823 }
824}
825
826pub fn mem_prefetch_v2(
829 dptr: CUdeviceptr,
830 count: usize,
831 target: PrefetchTarget,
832 stream: &Stream,
833) -> Result<()> {
834 let d = driver()?;
835 let cu = d.cu_mem_prefetch_async_v2()?;
836 check(unsafe { cu(dptr, count, target.as_location(), 0, stream.as_raw()) })
837}
838
839pub fn mem_advise_v2(
841 dptr: CUdeviceptr,
842 count: usize,
843 advice: i32,
844 target: PrefetchTarget,
845) -> Result<()> {
846 let d = driver()?;
847 let cu = d.cu_mem_advise_v2()?;
848 check(unsafe { cu(dptr, count, advice, target.as_location()) })
849}
850
851pub fn retain_allocation_handle(
855 addr: CUdeviceptr,
856) -> Result<baracuda_cuda_sys::CUmemGenericAllocationHandle> {
857 let d = driver()?;
858 let cu = d.cu_mem_retain_allocation_handle()?;
859 let mut h: baracuda_cuda_sys::CUmemGenericAllocationHandle = 0;
860 check(unsafe { cu(&mut h, addr.0 as *mut core::ffi::c_void) })?;
861 Ok(h)
862}
863
864pub fn allocation_properties_from_handle(
866 handle: baracuda_cuda_sys::CUmemGenericAllocationHandle,
867) -> Result<baracuda_cuda_sys::types::CUmemAllocationProp> {
868 let d = driver()?;
869 let cu = d.cu_mem_get_allocation_properties_from_handle()?;
870 let mut prop = baracuda_cuda_sys::types::CUmemAllocationProp::default();
871 check(unsafe { cu(&mut prop, handle) })?;
872 Ok(prop)
873}
874
875pub unsafe fn get_handle_for_address_range(
884 handle_out: *mut core::ffi::c_void,
885 dptr: CUdeviceptr,
886 size: usize,
887 handle_type: i32,
888) -> Result<()> { unsafe {
889 let d = driver()?;
890 let cu = d.cu_mem_get_handle_for_address_range()?;
891 check(cu(handle_out, dptr, size, handle_type, 0))
892}}
893
894impl<T: DeviceRepr> Drop for DeviceBuffer<T> {
895 fn drop(&mut self) {
896 if self.ptr.0 == 0 {
897 return;
898 }
899 if let Ok(d) = driver() {
900 if let Ok(cu) = d.cu_mem_free() {
901 let _ = unsafe { cu(self.ptr) };
902 }
903 }
904 }
905}
906
907#[derive(Copy, Clone)]
909pub struct DeviceSlice<'a, T: DeviceRepr> {
910 pub(crate) ptr: CUdeviceptr,
911 pub(crate) len: usize,
912 pub(crate) _marker: PhantomData<&'a T>,
913}
914
915impl<'a, T: DeviceRepr> core::fmt::Debug for DeviceSlice<'a, T> {
916 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
917 f.debug_struct("DeviceSlice")
918 .field("ptr", &format_args!("{:#x}", self.ptr.0))
919 .field("len", &self.len)
920 .finish()
921 }
922}
923
924impl<'a, T: DeviceRepr> DeviceSlice<'a, T> {
925 #[inline]
926 pub fn len(&self) -> usize {
927 self.len
928 }
929 #[inline]
930 pub fn is_empty(&self) -> bool {
931 self.len == 0
932 }
933 #[inline]
934 pub fn as_raw(&self) -> CUdeviceptr {
935 self.ptr
936 }
937
938 #[inline]
962 pub unsafe fn from_raw_parts<'b>(ptr: CUdeviceptr, len: usize) -> DeviceSlice<'b, T> {
963 DeviceSlice {
964 ptr,
965 len,
966 _marker: PhantomData,
967 }
968 }
969
970 #[inline]
972 pub fn slice(&self, range: Range<usize>) -> DeviceSlice<'_, T> {
973 assert!(
974 range.start <= range.end && range.end <= self.len,
975 "DeviceSlice::slice({}..{}) out of bounds for len {}",
976 range.start,
977 range.end,
978 self.len,
979 );
980 DeviceSlice {
981 ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
982 len: range.end - range.start,
983 _marker: PhantomData,
984 }
985 }
986}
987
988pub struct DeviceSliceMut<'a, T: DeviceRepr> {
990 pub(crate) ptr: CUdeviceptr,
991 pub(crate) len: usize,
992 pub(crate) _marker: PhantomData<&'a mut T>,
993}
994
995impl<'a, T: DeviceRepr> core::fmt::Debug for DeviceSliceMut<'a, T> {
996 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
997 f.debug_struct("DeviceSliceMut")
998 .field("ptr", &format_args!("{:#x}", self.ptr.0))
999 .field("len", &self.len)
1000 .finish()
1001 }
1002}
1003
1004impl<'a, T: DeviceRepr> DeviceSliceMut<'a, T> {
1005 #[inline]
1006 pub fn len(&self) -> usize {
1007 self.len
1008 }
1009 #[inline]
1010 pub fn is_empty(&self) -> bool {
1011 self.len == 0
1012 }
1013 #[inline]
1014 pub fn as_raw(&self) -> CUdeviceptr {
1015 self.ptr
1016 }
1017
1018 #[inline]
1032 pub unsafe fn from_raw_parts<'b>(ptr: CUdeviceptr, len: usize) -> DeviceSliceMut<'b, T> {
1033 DeviceSliceMut {
1034 ptr,
1035 len,
1036 _marker: PhantomData,
1037 }
1038 }
1039
1040 #[inline]
1042 pub fn slice(&self, range: Range<usize>) -> DeviceSlice<'_, T> {
1043 assert!(
1044 range.start <= range.end && range.end <= self.len,
1045 "DeviceSliceMut::slice({}..{}) out of bounds for len {}",
1046 range.start,
1047 range.end,
1048 self.len,
1049 );
1050 DeviceSlice {
1051 ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
1052 len: range.end - range.start,
1053 _marker: PhantomData,
1054 }
1055 }
1056
1057 #[inline]
1059 pub fn slice_mut(&mut self, range: Range<usize>) -> DeviceSliceMut<'_, T> {
1060 assert!(
1061 range.start <= range.end && range.end <= self.len,
1062 "DeviceSliceMut::slice_mut({}..{}) out of bounds for len {}",
1063 range.start,
1064 range.end,
1065 self.len,
1066 );
1067 DeviceSliceMut {
1068 ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
1069 len: range.end - range.start,
1070 _marker: PhantomData,
1071 }
1072 }
1073
1074 pub fn copy_from_host_async(&self, src: &[T], stream: &Stream) -> Result<()> {
1079 assert_eq!(src.len(), self.len);
1080 let bytes = self.len * size_of::<T>();
1081 if bytes == 0 {
1082 return Ok(());
1083 }
1084 let d = driver()?;
1085 let cu = d.cu_memcpy_htod_async()?;
1086 check(unsafe {
1087 cu(
1088 self.ptr,
1089 src.as_ptr() as *const c_void,
1090 bytes,
1091 stream.as_raw(),
1092 )
1093 })
1094 }
1095}
1096
1097pub unsafe trait DevicePtr<T: DeviceRepr> {
1126 fn device_ptr(&self) -> CUdeviceptr;
1128
1129 fn len(&self) -> usize;
1131
1132 #[inline]
1134 fn is_empty(&self) -> bool {
1135 self.len() == 0
1136 }
1137
1138 #[inline]
1140 fn byte_size(&self) -> usize {
1141 self.len() * core::mem::size_of::<T>()
1142 }
1143}
1144
1145pub unsafe trait DevicePtrMut<T: DeviceRepr>: DevicePtr<T> {
1160 fn device_ptr_mut(&mut self) -> CUdeviceptr;
1162}
1163
1164unsafe impl<T: DeviceRepr> DevicePtr<T> for DeviceBuffer<T> {
1167 #[inline]
1168 fn device_ptr(&self) -> CUdeviceptr {
1169 self.ptr
1170 }
1171 #[inline]
1172 fn len(&self) -> usize {
1173 self.len
1174 }
1175}
1176
1177unsafe impl<T: DeviceRepr> DevicePtrMut<T> for DeviceBuffer<T> {
1178 #[inline]
1179 fn device_ptr_mut(&mut self) -> CUdeviceptr {
1180 self.ptr
1181 }
1182}
1183
1184unsafe impl<'a, T: DeviceRepr> DevicePtr<T> for DeviceSlice<'a, T> {
1185 #[inline]
1186 fn device_ptr(&self) -> CUdeviceptr {
1187 self.ptr
1188 }
1189 #[inline]
1190 fn len(&self) -> usize {
1191 self.len
1192 }
1193}
1194
1195unsafe impl<'a, T: DeviceRepr> DevicePtr<T> for DeviceSliceMut<'a, T> {
1196 #[inline]
1197 fn device_ptr(&self) -> CUdeviceptr {
1198 self.ptr
1199 }
1200 #[inline]
1201 fn len(&self) -> usize {
1202 self.len
1203 }
1204}
1205
1206unsafe impl<'a, T: DeviceRepr> DevicePtrMut<T> for DeviceSliceMut<'a, T> {
1207 #[inline]
1208 fn device_ptr_mut(&mut self) -> CUdeviceptr {
1209 self.ptr
1210 }
1211}
1212
1213unsafe impl<T: DeviceRepr, P: DevicePtr<T> + ?Sized> DevicePtr<T> for &P {
1215 #[inline]
1216 fn device_ptr(&self) -> CUdeviceptr {
1217 (**self).device_ptr()
1218 }
1219 #[inline]
1220 fn len(&self) -> usize {
1221 (**self).len()
1222 }
1223}
1224
1225unsafe impl<T: DeviceRepr, P: DevicePtr<T> + ?Sized> DevicePtr<T> for &mut P {
1226 #[inline]
1227 fn device_ptr(&self) -> CUdeviceptr {
1228 (**self).device_ptr()
1229 }
1230 #[inline]
1231 fn len(&self) -> usize {
1232 (**self).len()
1233 }
1234}
1235
1236unsafe impl<T: DeviceRepr, P: DevicePtrMut<T> + ?Sized> DevicePtrMut<T> for &mut P {
1237 #[inline]
1238 fn device_ptr_mut(&mut self) -> CUdeviceptr {
1239 (**self).device_ptr_mut()
1240 }
1241}
1242
1243unsafe impl<T: DeviceRepr> KernelArg for &DeviceBuffer<T> {
1260 #[inline]
1261 fn as_kernel_arg_ptr(&self) -> *mut c_void {
1262 &self.ptr as *const CUdeviceptr as *mut c_void
1263 }
1264}
1265
1266unsafe impl<T: DeviceRepr> KernelArg for &mut DeviceBuffer<T> {
1267 #[inline]
1268 fn as_kernel_arg_ptr(&self) -> *mut c_void {
1269 &self.ptr as *const CUdeviceptr as *mut c_void
1270 }
1271}
1272
1273unsafe impl<'a, T: DeviceRepr> KernelArg for &DeviceSlice<'a, T> {
1274 #[inline]
1275 fn as_kernel_arg_ptr(&self) -> *mut c_void {
1276 &self.ptr as *const CUdeviceptr as *mut c_void
1277 }
1278}
1279
1280unsafe impl<'a, T: DeviceRepr> KernelArg for &DeviceSliceMut<'a, T> {
1281 #[inline]
1282 fn as_kernel_arg_ptr(&self) -> *mut c_void {
1283 &self.ptr as *const CUdeviceptr as *mut c_void
1284 }
1285}
1286
1287unsafe impl<'a, T: DeviceRepr> KernelArg for &mut DeviceSliceMut<'a, T> {
1288 #[inline]
1289 fn as_kernel_arg_ptr(&self) -> *mut c_void {
1290 &self.ptr as *const CUdeviceptr as *mut c_void
1291 }
1292}
1293
/// Host-only tests for the view types: they exercise pointer arithmetic and
/// bounds checks on made-up addresses and never call into the driver.
#[cfg(test)]
mod slice_tests {
    use super::*;

    /// Builds a view over an arbitrary address without touching a device.
    fn fake_slice<T: DeviceRepr>(ptr: u64, len: usize) -> DeviceSlice<'static, T> {
        let ptr = CUdeviceptr(ptr);
        DeviceSlice {
            ptr,
            len,
            _marker: PhantomData,
        }
    }

    #[test]
    fn slice_offsets_ptr_by_element_bytes() {
        let whole: DeviceSlice<'_, f32> = fake_slice(0x1000, 16);
        let part = whole.slice(4..12);
        assert_eq!(part.len(), 8);
        // Four f32 elements skipped, four bytes each.
        assert_eq!(part.as_raw().0, 0x1000 + 4 * 4);
    }

    #[test]
    fn slice_of_slice_stays_correct() {
        let whole: DeviceSlice<'_, f64> = fake_slice(0x2000, 100);
        let outer = whole.slice(10..90);
        let nested = outer.slice(5..15);
        assert_eq!(nested.len(), 10);
        // 10 + 5 f64 elements skipped in total, eight bytes each.
        assert_eq!(nested.as_raw().0, 0x2000 + 15 * 8);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn slice_end_past_len_panics() {
        let whole: DeviceSlice<'_, u8> = fake_slice(0, 10);
        let _ = whole.slice(0..11);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    // The inverted range is the point of the test.
    #[allow(clippy::reversed_empty_ranges)]
    fn slice_inverted_range_panics() {
        let whole: DeviceSlice<'_, u8> = fake_slice(0, 10);
        let _ = whole.slice(5..3);
    }

    #[test]
    fn from_raw_parts_preserves_ptr_and_len() {
        // SAFETY: the view is never dereferenced; only its fields are read.
        let view: DeviceSlice<'static, f32> =
            unsafe { DeviceSlice::from_raw_parts(CUdeviceptr(0x4000), 32) };
        assert_eq!(view.as_raw().0, 0x4000);
        assert_eq!(view.len(), 32);
    }

    #[test]
    fn from_raw_parts_mut_preserves_ptr_and_len() {
        // SAFETY: the view is never dereferenced; only its fields are read.
        let view: DeviceSliceMut<'static, u32> =
            unsafe { DeviceSliceMut::from_raw_parts(CUdeviceptr(0x8000), 64) };
        assert_eq!(view.as_raw().0, 0x8000);
        assert_eq!(view.len(), 64);
    }
}
1362
/// Host-only test for the `KernelArg` implementations: checks that the
/// argument pointer targets the view's stored device pointer.
#[cfg(test)]
mod kernel_arg_tests {
    use super::*;
    use core::mem::size_of;

    #[test]
    fn slice_kernel_arg_points_at_ptr_field() {
        let view: DeviceSlice<'_, f32> = DeviceSlice {
            ptr: CUdeviceptr(0xDEAD_BEEF_u64),
            len: 42,
            _marker: PhantomData,
        };
        let arg = (&view).as_kernel_arg_ptr();

        // Reading through the argument pointer must yield the device address.
        // SAFETY: `arg` points at `view.ptr`, which is alive and holds a u64.
        let stored = unsafe { *(arg as *const u64) };
        assert_eq!(stored, 0xDEAD_BEEF);

        // The argument pointer must lie inside the view value itself.
        let first_byte = &view as *const _ as usize;
        let past_end = first_byte + size_of::<DeviceSlice<'_, f32>>();
        let arg_addr = arg as usize;
        assert!(first_byte <= arg_addr && arg_addr < past_end);
    }
}