1use crate::driver::{
2 result::{self, DriverError},
3 sys::{self, CUfunc_cache_enum, CUfunction_attribute_enum},
4};
5
6use std::{
7 ffi::CString,
8 marker::PhantomData,
9 ops::{Bound, RangeBounds},
10 string::String,
11 sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering},
12 sync::Arc,
13 vec::Vec,
14};
15
/// Safe handle to a CUDA driver context ([sys::CUcontext]) for one device.
///
/// Shared behind an [Arc]; dropping the last handle releases the primary
/// context or destroys a non-primary one (see the [Drop] impl).
#[derive(Debug)]
pub struct CudaContext {
    pub(crate) cu_device: sys::CUdevice,
    pub(crate) cu_ctx: sys::CUcontext,
    /// Device ordinal this context was created for.
    pub(crate) ordinal: usize,
    /// Whether the device reports `CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED`,
    /// enabling stream-ordered (async) alloc/free.
    pub(crate) has_async_alloc: bool,
    /// True when this wraps the device's primary context; affects how Drop
    /// disposes of the handle (release vs destroy).
    pub(crate) is_primary: bool,
    /// Count of live non-default streams created from this context.
    pub(crate) num_streams: AtomicUsize,
    /// When true, new allocations get read/write events used to order access
    /// across streams.
    pub(crate) event_tracking: AtomicBool,
    /// Last driver error recorded by `record_err` (0 = none); drained by `check_err`.
    pub(crate) error_state: AtomicU32,
}
44
// SAFETY: NOTE(review) — relies on CUDA driver context handles being usable
// from any thread; every method re-binds the context to the calling thread
// via `bind_to_thread` before issuing driver calls.
unsafe impl Send for CudaContext {}
unsafe impl Sync for CudaContext {}
47
impl Drop for CudaContext {
    /// Releases the primary context (refcount decrement) or destroys an owned
    /// one. Errors cannot be returned from `drop`, so they are stashed via
    /// `record_err` for a later `check_err`.
    fn drop(&mut self) {
        self.record_err(self.bind_to_thread());
        // Null out the handle first so the disposal below runs at most once.
        let ctx = std::mem::replace(&mut self.cu_ctx, std::ptr::null_mut());
        if !ctx.is_null() {
            if self.is_primary {
                self.record_err(unsafe { result::primary_ctx::release(self.cu_device) });
            } else {
                self.record_err(unsafe { sys::cuCtxDestroy_v2(ctx).result() });
            }
        }
    }
}
62
63impl PartialEq for CudaContext {
64 fn eq(&self, other: &Self) -> bool {
65 self.cu_device == other.cu_device
66 && self.cu_ctx == other.cu_ctx
67 && self.ordinal == other.ordinal
68 }
69}
70impl Eq for CudaContext {}
71
72impl CudaContext {
    /// Creates a context handle for device `ordinal` by retaining the device's
    /// *primary* context, then binds it to the calling thread.
    ///
    /// Also probes `CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED` so later
    /// allocations can pick stream-ordered alloc when available.
    pub fn new(ordinal: usize) -> Result<Arc<Self>, DriverError> {
        result::init()?;
        let cu_device = result::device::get(ordinal as i32)?;
        let cu_ctx = unsafe { result::primary_ctx::retain(cu_device) }?;
        let has_async_alloc = unsafe {
            let memory_pools_supported = result::device::get_attribute(
                cu_device,
                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
            )?;
            memory_pools_supported > 0
        };
        let ctx = Arc::new(CudaContext {
            cu_device,
            cu_ctx,
            ordinal,
            has_async_alloc,
            is_primary: true,
            num_streams: AtomicUsize::new(0),
            event_tracking: AtomicBool::new(true),
            error_state: AtomicU32::new(0),
        });
        ctx.bind_to_thread()?;
        Ok(ctx)
    }
98
    /// Creates a *non-primary* (owned) context for device `ordinal` with the
    /// given raw `CUctx_flags` bits, then binds it to the calling thread.
    ///
    /// Unlike [CudaContext::new], dropping this handle destroys the context
    /// outright instead of releasing a primary-context refcount.
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010"
    ))]
    pub fn new_non_primary(ordinal: usize, flags: u32) -> Result<Arc<Self>, DriverError> {
        result::init()?;
        let cu_device = result::device::get(ordinal as i32)?;

        // Toolkits >= 12.5 expose the v4 creation entry point; older ones
        // only have v3. Exactly one of the two bindings below compiles.
        #[cfg(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010"
        ))]
        let cu_ctx = unsafe { result::ctx::create_v4(std::ptr::null_mut(), flags, cu_device) }?;

        #[cfg(not(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010"
        )))]
        let cu_ctx = unsafe { result::ctx::create_v3(flags, cu_device) }?;

        let has_async_alloc = unsafe {
            let memory_pools_supported = result::device::get_attribute(
                cu_device,
                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
            )?;
            memory_pools_supported > 0
        };
        let ctx = Arc::new(CudaContext {
            cu_device,
            cu_ctx,
            ordinal,
            has_async_alloc,
            is_primary: false,
            num_streams: AtomicUsize::new(0),
            event_tracking: AtomicBool::new(true),
            error_state: AtomicU32::new(0),
        });
        ctx.bind_to_thread()?;
        Ok(ctx)
    }
170
    /// Creates a CIG context for device `ordinal` via `cuCtxCreate_v4`,
    /// populating only the `cigParams` member of the creation params.
    ///
    /// Only available on CUDA >= 12.5 toolkits (see the cfg gate).
    #[cfg(any(
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010"
    ))]
    pub fn new_cig(
        ordinal: usize,
        flags: u32,
        cig_params: &mut sys::CUctxCigParam,
    ) -> Result<Arc<Self>, DriverError> {
        result::init()?;
        let cu_device = result::device::get(ordinal as i32)?;
        // No execution-affinity parameters are supplied; only CIG params.
        let mut ctx_create_params = sys::CUctxCreateParams_st {
            execAffinityParams: std::ptr::null_mut(),
            numExecAffinityParams: 0,
            cigParams: cig_params as *mut sys::CUctxCigParam,
        };
        let cu_ctx = unsafe { result::ctx::create_v4(&mut ctx_create_params, flags, cu_device) }?;
        let has_async_alloc = unsafe {
            let memory_pools_supported = result::device::get_attribute(
                cu_device,
                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
            )?;
            memory_pools_supported > 0
        };
        let ctx = Arc::new(CudaContext {
            cu_device,
            cu_ctx,
            ordinal,
            has_async_alloc,
            is_primary: false,
            num_streams: AtomicUsize::new(0),
            event_tracking: AtomicBool::new(true),
            error_state: AtomicU32::new(0),
        });
        ctx.bind_to_thread()?;
        Ok(ctx)
    }
221
    /// Wraps an externally created device/context pair in a [CudaContext].
    ///
    /// # Safety
    /// `cu_ctx` must be a valid, live context for `cu_device`, and ownership
    /// transfers to the returned handle: it is stored as non-primary, so
    /// dropping it will destroy the context (see the [Drop] impl).
    pub unsafe fn from_raw_context(
        ordinal: usize,
        cu_device: sys::CUdevice,
        cu_ctx: sys::CUcontext,
    ) -> Result<Arc<Self>, DriverError> {
        let has_async_alloc = {
            let memory_pools_supported = result::device::get_attribute(
                cu_device,
                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
            )?;
            memory_pools_supported > 0
        };
        let ctx = Arc::new(CudaContext {
            cu_device,
            cu_ctx,
            ordinal,
            has_async_alloc,
            is_primary: false,
            num_streams: AtomicUsize::new(0),
            event_tracking: AtomicBool::new(true),
            error_state: AtomicU32::new(0),
        });
        ctx.bind_to_thread()?;
        Ok(ctx)
    }
258
    /// Whether this handle wraps the device's primary context.
    pub fn is_primary(&self) -> bool {
        self.is_primary
    }

    /// Whether the device supports stream-ordered (memory-pool) allocation.
    pub fn has_async_alloc(&self) -> bool {
        self.has_async_alloc
    }

    /// Number of CUDA devices visible to the driver; initializes the driver
    /// if needed.
    pub fn device_count() -> Result<i32, DriverError> {
        result::init()?;
        result::device::get_count()
    }

    /// The device ordinal this context was created for.
    pub fn ordinal(&self) -> usize {
        self.ordinal
    }
288
    /// The device's human-readable name (e.g. the GPU model string).
    pub fn name(&self) -> Result<String, result::DriverError> {
        self.check_err()?;
        result::device::get_name(self.cu_device)
    }

    /// The device's UUID.
    pub fn uuid(&self) -> Result<sys::CUuuid, result::DriverError> {
        self.check_err()?;
        result::device::get_uuid(self.cu_device)
    }

    /// The device's `(major, minor)` compute capability.
    pub fn compute_capability(&self) -> Result<(i32, i32), result::DriverError> {
        self.check_err()?;
        let capability_major =
            self.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
        let capability_minor =
            self.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)?;

        Ok((capability_major, capability_minor))
    }
311
    /// Total device memory in bytes.
    pub fn total_mem(&self) -> Result<usize, DriverError> {
        self.check_err()?;
        unsafe { result::device::total_mem(self.cu_device) }
    }

    /// `(free, total)` device memory in bytes for this context.
    pub fn mem_get_info(&self) -> Result<(usize, usize), DriverError> {
        self.bind_to_thread()?;
        result::mem_get_info()
    }
    /// The raw driver device handle.
    pub fn cu_device(&self) -> sys::CUdevice {
        self.cu_device
    }

    /// The raw driver context handle.
    pub fn cu_ctx(&self) -> sys::CUcontext {
        self.cu_ctx
    }
348
349 pub fn bind_to_thread(&self) -> Result<(), DriverError> {
351 self.check_err()?;
352 if match result::ctx::get_current()? {
353 Some(curr_ctx) => curr_ctx != self.cu_ctx,
354 None => true,
355 } {
356 unsafe { result::ctx::set_current(self.cu_ctx) }?;
357 }
358 Ok(())
359 }
360
    /// Queries a raw device attribute via `cuDeviceGetAttribute`.
    pub fn attribute(&self, attrib: sys::CUdevice_attribute) -> Result<i32, result::DriverError> {
        self.check_err()?;
        unsafe { result::device::get_attribute(self.cu_device, attrib) }
    }

    /// Blocks the calling thread until all work on this context completes.
    pub fn synchronize(&self) -> Result<(), DriverError> {
        self.bind_to_thread()?;
        result::ctx::synchronize()
    }
373
    /// Convenience wrapper for [Self::set_flags] with
    /// `CU_CTX_SCHED_BLOCKING_SYNC`.
    #[cfg(not(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000"
    )))]
    pub fn set_blocking_synchronize(&self) -> Result<(), DriverError> {
        self.set_flags(sys::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC)
    }

    /// Sets this context's scheduling flags. Not available on toolkits older
    /// than 12.1 (see the cfg gate).
    #[cfg(not(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000"
    )))]
    pub fn set_flags(&self, flags: sys::CUctx_flags) -> Result<(), DriverError> {
        self.bind_to_thread()?;
        result::ctx::set_flags(flags)
    }
402
    /// Reads a context resource limit (e.g. stack size, printf FIFO size).
    pub fn get_limit(&self, limit: sys::CUlimit) -> Result<usize, DriverError> {
        self.bind_to_thread()?;
        result::ctx::get_limit(limit)
    }

    /// Sets a context resource limit to `value`.
    pub fn set_limit(&self, limit: sys::CUlimit, value: usize) -> Result<(), DriverError> {
        self.bind_to_thread()?;
        result::ctx::set_limit(limit, value)
    }

    /// Reads the context's preferred L1/shared-memory cache configuration.
    pub fn get_cache_config(&self) -> Result<sys::CUfunc_cache, DriverError> {
        self.bind_to_thread()?;
        result::ctx::get_cache_config()
    }

    /// Sets the context's preferred L1/shared-memory cache configuration.
    pub fn set_cache_config(&self, config: sys::CUfunc_cache) -> Result<(), DriverError> {
        self.bind_to_thread()?;
        result::ctx::set_cache_config(config)
    }
445
    /// True when at least one non-default stream created from this context is
    /// still alive.
    pub fn is_in_multi_stream_mode(&self) -> bool {
        self.num_streams.load(Ordering::Relaxed) > 0
    }

    /// Whether allocations receive read/write events for cross-stream ordering.
    pub fn is_event_tracking(&self) -> bool {
        self.event_tracking.load(Ordering::Relaxed)
    }

    /// True when this context both has multiple streams and is tracking
    /// events — i.e. it actively orders buffer access across streams.
    pub fn is_managing_stream_synchronization(&self) -> bool {
        self.is_in_multi_stream_mode() && self.is_event_tracking()
    }

    /// Re-enables event tracking for subsequently created allocations.
    ///
    /// # Safety
    /// NOTE(review): toggling this changes how buffer access is synchronized
    /// across streams; callers must ensure no in-flight work relies on the
    /// previous setting.
    pub unsafe fn enable_event_tracking(&self) {
        self.event_tracking.store(true, Ordering::Relaxed);
    }

    /// Disables event tracking for subsequently created allocations.
    ///
    /// # Safety
    /// Same caveat as [Self::enable_event_tracking]: existing buffers created
    /// with events keep them, but new ones will not be synchronized.
    pub unsafe fn disable_event_tracking(&self) {
        self.event_tracking.store(false, Ordering::Relaxed);
    }
496
    /// Takes (and clears) the most recently recorded driver error, if any.
    ///
    /// Returns `Ok(())` when nothing has been recorded since the last check.
    pub fn check_err(&self) -> Result<(), DriverError> {
        // swap(0) atomically drains the slot so each error is reported once.
        let error_state = self.error_state.swap(0, Ordering::Relaxed);
        if error_state == 0 {
            Ok(())
        } else {
            Err(result::DriverError(unsafe {
                // SAFETY: only values written by `record_err` — a valid
                // `cudaError_enum` cast to u32 — are ever stored here.
                std::mem::transmute::<u32, sys::cudaError_enum>(error_state)
            }))
        }
    }

    /// Stashes `result`'s error (if any) for a later [Self::check_err]; used
    /// by `Drop` impls where errors cannot be returned.
    /// NOTE: a newer error overwrites an older, unreported one.
    pub fn record_err<T>(&self, result: Result<T, DriverError>) {
        if let Err(err) = result {
            self.error_state.store(err.0 as u32, Ordering::Relaxed)
        }
    }
519}
520
/// Safe wrapper around a driver event ([sys::CUevent]), used to order work
/// between streams and to time or await completion.
#[derive(Debug)]
pub struct CudaEvent {
    pub(crate) cu_event: sys::CUevent,
    /// Keeps the owning context alive for as long as the event exists.
    pub(crate) ctx: Arc<CudaContext>,
}

// SAFETY: NOTE(review) — event handles are used only after re-binding the
// owning context to the calling thread (see the methods below).
unsafe impl Send for CudaEvent {}
unsafe impl Sync for CudaEvent {}
539
impl Drop for CudaEvent {
    /// Destroys the driver event; errors are recorded on the context since
    /// `drop` cannot return them.
    fn drop(&mut self) {
        self.ctx.record_err(self.ctx.bind_to_thread());
        self.ctx
            .record_err(unsafe { result::event::destroy(self.cu_event) });
    }
}
547
548impl CudaContext {
549 pub fn new_event(
552 self: &Arc<Self>,
553 flags: Option<sys::CUevent_flags>,
554 ) -> Result<CudaEvent, DriverError> {
555 let flags = flags.unwrap_or(sys::CUevent_flags::CU_EVENT_DISABLE_TIMING);
556 self.bind_to_thread()?;
557 let cu_event = result::event::create(flags)?;
558 Ok(CudaEvent {
559 cu_event,
560 ctx: self.clone(),
561 })
562 }
563}
564
impl CudaEvent {
    /// The raw driver event handle.
    pub fn cu_event(&self) -> sys::CUevent {
        self.cu_event
    }

    /// The context this event was created on.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Records this event on `stream`, capturing all work submitted to the
    /// stream so far. Fails with `CUDA_ERROR_INVALID_CONTEXT` when the event
    /// and stream belong to different contexts.
    pub fn record(&self, stream: &CudaStream) -> Result<(), DriverError> {
        if self.ctx != stream.ctx {
            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_CONTEXT));
        }
        self.ctx.bind_to_thread()?;
        unsafe { result::event::record(self.cu_event, stream.cu_stream) }
    }

    /// Blocks the calling thread until this event has completed.
    pub fn synchronize(&self) -> Result<(), DriverError> {
        self.ctx.bind_to_thread()?;
        unsafe { result::event::synchronize(self.cu_event) }
    }

    /// Milliseconds elapsed between `self` and `end`; synchronizes both
    /// events first. Both must belong to the same context.
    /// NOTE(review): meaningful timing presumably requires events created
    /// without `CU_EVENT_DISABLE_TIMING` — confirm against driver docs.
    pub fn elapsed_ms(&self, end: &Self) -> Result<f32, DriverError> {
        if self.ctx != end.ctx {
            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_CONTEXT));
        }
        self.ctx.bind_to_thread()?;
        self.synchronize()?;
        end.synchronize()?;
        unsafe { result::event::elapsed(self.cu_event, end.cu_event) }
    }

    /// Non-blocking: whether all work captured by this event has finished.
    pub fn is_complete(&self) -> bool {
        unsafe { result::event::query(self.cu_event) }.is_ok()
    }
}
618
/// Safe wrapper around a driver stream ([sys::CUstream]). A null handle is
/// the default stream; `0x2` is the per-thread default-stream sentinel (see
/// [CudaContext::per_thread_stream]).
#[derive(Debug, PartialEq, Eq)]
pub struct CudaStream {
    pub(crate) cu_stream: sys::CUstream,
    /// Keeps the owning context alive for as long as the stream exists.
    pub(crate) ctx: Arc<CudaContext>,
}

// SAFETY: NOTE(review) — stream handles are used only after re-binding the
// owning context to the calling thread (see the methods below).
unsafe impl Send for CudaStream {}
unsafe impl Sync for CudaStream {}
638
impl Drop for CudaStream {
    /// Destroys the stream unless it is the default (null) or per-thread
    /// (`0x2`) special handle, and decrements the context's stream count.
    fn drop(&mut self) {
        self.ctx.record_err(self.ctx.bind_to_thread());
        // Take the handle so disposal runs at most once.
        let cu_stream = std::mem::replace(&mut self.cu_stream, std::ptr::null_mut());
        if !cu_stream.is_null() && cu_stream != (0x2 as _) {
            self.ctx.num_streams.fetch_sub(1, Ordering::Relaxed);
            self.ctx
                .record_err(unsafe { result::stream::destroy(cu_stream) });
        }
    }
}
650
651impl CudaContext {
    /// Returns a handle to the context's default (null) stream.
    pub fn default_stream(self: &Arc<Self>) -> Arc<CudaStream> {
        Arc::new(CudaStream {
            cu_stream: std::ptr::null_mut(),
            ctx: self.clone(),
        })
    }

    /// Returns a handle to the driver's per-thread default-stream sentinel
    /// (`0x2`); never destroyed by [Drop].
    pub fn per_thread_stream(self: &Arc<Self>) -> Arc<CudaStream> {
        Arc::new(CudaStream {
            cu_stream: 0x2 as _,
            ctx: self.clone(),
        })
    }
669
670 pub fn new_stream(self: &Arc<Self>) -> Result<Arc<CudaStream>, DriverError> {
675 self.bind_to_thread()?;
676 let prev_num_streams = self.num_streams.fetch_add(1, Ordering::Relaxed);
677 if prev_num_streams == 0 && self.is_event_tracking() {
678 self.synchronize()?;
679 }
680 let cu_stream = result::stream::create(result::stream::StreamKind::NonBlocking)?;
681 Ok(Arc::new(CudaStream {
682 cu_stream,
683 ctx: self.clone(),
684 }))
685 }
686}
687
688impl CudaStream {
689 pub fn fork(&self) -> Result<Arc<Self>, DriverError> {
691 self.ctx.bind_to_thread()?;
692 self.ctx.num_streams.fetch_add(1, Ordering::Relaxed);
693 let cu_stream = result::stream::create(result::stream::StreamKind::NonBlocking)?;
694 let stream = Arc::new(CudaStream {
695 cu_stream,
696 ctx: self.ctx.clone(),
697 });
698 stream.join(self)?;
699 Ok(stream)
700 }
701
    /// The raw driver stream handle.
    pub fn cu_stream(&self) -> sys::CUstream {
        self.cu_stream
    }

    /// The context this stream belongs to.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Blocks the calling thread until all work on this stream completes.
    pub fn synchronize(&self) -> Result<(), DriverError> {
        self.ctx.bind_to_thread()?;
        unsafe { result::stream::synchronize(self.cu_stream) }
    }
722
723 pub fn record_event(
725 &self,
726 flags: Option<sys::CUevent_flags>,
727 ) -> Result<CudaEvent, DriverError> {
728 let event = self.ctx.new_event(flags)?;
729 event.record(self)?;
730 Ok(event)
731 }
732
    /// Makes all *future* work on this stream wait until `event` completes.
    /// Fails with `CUDA_ERROR_INVALID_CONTEXT` when the stream and event
    /// belong to different contexts.
    pub fn wait(&self, event: &CudaEvent) -> Result<(), DriverError> {
        if self.ctx != event.ctx {
            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_CONTEXT));
        }
        self.ctx.bind_to_thread()?;
        unsafe {
            result::stream::wait_event(
                self.cu_stream,
                event.cu_event,
                sys::CUevent_wait_flags::CU_EVENT_WAIT_DEFAULT,
            )
        }
    }

    /// Makes this stream wait for all work currently submitted to `other`
    /// (records an event on `other` and waits on it).
    pub fn join(&self, other: &CudaStream) -> Result<(), DriverError> {
        self.wait(&other.record_event(None)?)
    }
758}
759
/// An owned device-memory allocation of `len` elements of `T`, tied to the
/// stream it was allocated on.
#[derive(Debug)]
pub struct CudaSlice<T> {
    pub(crate) cu_device_ptr: sys::CUdeviceptr,
    /// Number of `T` elements (not bytes).
    pub(crate) len: usize,
    /// Event recorded after reads of this buffer (when event tracking is on).
    pub(crate) read: Option<CudaEvent>,
    /// Event recorded after writes to this buffer (when event tracking is on).
    pub(crate) write: Option<CudaEvent>,
    /// Stream the allocation/free are ordered on.
    pub(crate) stream: Arc<CudaStream>,
    /// Marks the element type without owning any host data.
    pub(crate) marker: PhantomData<*const T>,
}

// SAFETY: NOTE(review) — the device pointer is only dereferenced by driver
// calls, which re-bind the owning context to the calling thread first.
unsafe impl<T> Send for CudaSlice<T> {}
unsafe impl<T> Sync for CudaSlice<T> {}
775
impl<T> Drop for CudaSlice<T> {
    /// Frees the allocation, first waiting on any outstanding read/write
    /// events so in-flight work on other streams completes.
    fn drop(&mut self) {
        let ctx = &self.stream.ctx;
        if let Some(read) = self.read.as_ref() {
            ctx.record_err(self.stream.wait(read));
        }
        if let Some(write) = self.write.as_ref() {
            ctx.record_err(self.stream.wait(write));
        }
        if ctx.has_async_alloc {
            // Stream-ordered free: ordered after prior work on this stream.
            ctx.record_err(unsafe {
                result::free_async(self.cu_device_ptr, self.stream.cu_stream)
            });
        } else {
            // No memory pools: synchronize fully, then free synchronously.
            ctx.record_err(self.stream.synchronize());
            ctx.record_err(unsafe { result::free_sync(self.cu_device_ptr) });
        }
    }
}
795
impl<T> CudaSlice<T> {
    /// Number of `T` elements in the allocation.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Size of the allocation in bytes.
    pub fn num_bytes(&self) -> usize {
        self.len * std::mem::size_of::<T>()
    }

    /// True when the allocation holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Ordinal of the device this slice lives on.
    pub fn ordinal(&self) -> usize {
        self.stream.ctx.ordinal
    }

    /// The context this slice belongs to.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.stream.ctx
    }

    /// The stream this slice's alloc/free are ordered on.
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }
}
827
impl<T: DeviceRepr> CudaSlice<T> {
    /// Device-to-device copy into a new allocation; fallible variant of
    /// [Clone::clone].
    pub fn try_clone(&self) -> Result<Self, result::DriverError> {
        self.stream.clone_dtod(self)
    }
}

impl<T: DeviceRepr> Clone for CudaSlice<T> {
    /// Device-to-device copy. Panics if the driver copy fails; use
    /// [CudaSlice::try_clone] to handle the error instead.
    fn clone(&self) -> Self {
        self.try_clone().unwrap()
    }
}

impl<T: Clone + Default + DeviceRepr> TryFrom<CudaSlice<T>> for Vec<T> {
    type Error = result::DriverError;
    /// Copies the device buffer back to a host `Vec`, consuming the slice.
    fn try_from(value: CudaSlice<T>) -> Result<Self, Self::Error> {
        value.stream.clone_dtoh(&value)
    }
}
847
/// A borrowed, immutable view into (part of) a [CudaSlice].
///
/// Carries references to the parent's read/write events and stream so the
/// same cross-stream synchronization protocol applies.
#[derive(Debug)]
pub struct CudaView<'a, T> {
    pub(crate) ptr: sys::CUdeviceptr,
    pub(crate) len: usize,
    pub(crate) read: &'a Option<CudaEvent>,
    pub(crate) write: &'a Option<CudaEvent>,
    pub(crate) stream: &'a Arc<CudaStream>,
    marker: PhantomData<&'a [T]>,
}
858
impl<T> CudaSlice<T> {
    /// Borrows the whole slice as an immutable [CudaView].
    pub fn as_view(&self) -> CudaView<'_, T> {
        CudaView {
            ptr: self.cu_device_ptr,
            len: self.len,
            read: &self.read,
            write: &self.write,
            stream: &self.stream,
            marker: PhantomData,
        }
    }
}
871
impl<T> CudaView<'_, T> {
    /// Number of `T` elements in the view.
    pub fn len(&self) -> usize {
        self.len
    }

    /// True when the view holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Narrows the view to element indices `[start, end)`.
    ///
    /// Panics when `start > end` or `end > self.len`.
    fn resize(&self, start: usize, end: usize) -> Self {
        assert!(start <= end && end <= self.len);
        Self {
            // Pointer offset is in bytes: element index times element size.
            ptr: self.ptr + (start * std::mem::size_of::<T>()) as u64,
            len: end - start,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        }
    }
}
894
/// A borrowed, mutable view into (part of) a [CudaSlice].
///
/// Like [CudaView], but the phantom marker is `&mut [T]` so exclusive access
/// is enforced by the borrow checker.
#[derive(Debug)]
pub struct CudaViewMut<'a, T> {
    pub(crate) ptr: sys::CUdeviceptr,
    pub(crate) len: usize,
    pub(crate) read: &'a Option<CudaEvent>,
    pub(crate) write: &'a Option<CudaEvent>,
    pub(crate) stream: &'a Arc<CudaStream>,
    marker: PhantomData<&'a mut [T]>,
}
905
impl<T> CudaSlice<T> {
    /// Borrows the whole slice as a mutable [CudaViewMut].
    pub fn as_view_mut(&mut self) -> CudaViewMut<'_, T> {
        CudaViewMut {
            ptr: self.cu_device_ptr,
            len: self.len,
            read: &self.read,
            write: &self.write,
            stream: &self.stream,
            marker: PhantomData,
        }
    }
}
918
impl<T> CudaViewMut<'_, T> {
    /// Number of `T` elements in the view.
    pub fn len(&self) -> usize {
        self.len
    }
    /// True when the view holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Reborrows this mutable view as an immutable [CudaView].
    pub fn as_view<'b>(&'b self) -> CudaView<'b, T> {
        CudaView {
            ptr: self.ptr,
            len: self.len,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        }
    }
}
940
/// Marker trait for types where the all-zeros bit pattern is a valid value,
/// so zero-filled device memory (see [CudaStream::alloc_zeros] /
/// [CudaStream::memset_zeros]) may be interpreted as `Self`.
///
/// # Safety
/// Implementors must guarantee a zeroed byte pattern is a valid `Self`.
pub unsafe trait ValidAsZeroBits {}
unsafe impl ValidAsZeroBits for bool {}
unsafe impl ValidAsZeroBits for i8 {}
unsafe impl ValidAsZeroBits for i16 {}
unsafe impl ValidAsZeroBits for i32 {}
unsafe impl ValidAsZeroBits for i64 {}
unsafe impl ValidAsZeroBits for i128 {}
unsafe impl ValidAsZeroBits for isize {}
unsafe impl ValidAsZeroBits for u8 {}
unsafe impl ValidAsZeroBits for u16 {}
unsafe impl ValidAsZeroBits for u32 {}
unsafe impl ValidAsZeroBits for u64 {}
unsafe impl ValidAsZeroBits for u128 {}
unsafe impl ValidAsZeroBits for usize {}
unsafe impl ValidAsZeroBits for f32 {}
unsafe impl ValidAsZeroBits for f64 {}
#[cfg(feature = "f16")]
unsafe impl ValidAsZeroBits for half::f16 {}
#[cfg(feature = "f16")]
unsafe impl ValidAsZeroBits for half::bf16 {}
// Arrays of zeroable elements are themselves zeroable.
unsafe impl<T: ValidAsZeroBits, const M: usize> ValidAsZeroBits for [T; M] {}
// Implements [ValidAsZeroBits] for tuples of zeroable elements up to the
// arity of the invocation below (12). The second arm peels one type per
// recursion so every shorter tuple also gets an impl; the `@` arm emits
// the actual impl for one arity.
macro_rules! impl_tuples {
    ($t:tt) => {
        impl_tuples!(@ $t);
    };
    ($l:tt $(,$t:tt)+) => {
        impl_tuples!($($t),+);
        impl_tuples!(@ $l $(,$t)+);
    };
    (@ $($t:tt),+) => {
        unsafe impl<$($t: ValidAsZeroBits,)+> ValidAsZeroBits for ($($t,)+) {}
    };
}
impl_tuples!(A, B, C, D, E, F, G, H, I, J, K, L);
986
/// Marker trait for plain-old-data types that can be copied to/from CUDA
/// device memory byte-for-byte.
///
/// # Safety
/// Implementors must be valid to memcpy between host and device (no pointers
/// to host memory, no drop-sensitive state).
pub unsafe trait DeviceRepr {}
unsafe impl DeviceRepr for bool {}
unsafe impl DeviceRepr for i8 {}
unsafe impl DeviceRepr for i16 {}
unsafe impl DeviceRepr for i32 {}
unsafe impl DeviceRepr for i64 {}
unsafe impl DeviceRepr for i128 {}
unsafe impl DeviceRepr for isize {}
unsafe impl DeviceRepr for u8 {}
unsafe impl DeviceRepr for u16 {}
unsafe impl DeviceRepr for u32 {}
unsafe impl DeviceRepr for u64 {}
unsafe impl DeviceRepr for u128 {}
unsafe impl DeviceRepr for usize {}
unsafe impl DeviceRepr for f32 {}
unsafe impl DeviceRepr for f64 {}
#[cfg(feature = "f16")]
unsafe impl DeviceRepr for half::f16 {}
#[cfg(feature = "f16")]
unsafe impl DeviceRepr for half::bf16 {}

#[cfg(feature = "f8")]
unsafe impl DeviceRepr for float8::F8E4M3 {}
#[cfg(feature = "f8")]
unsafe impl ValidAsZeroBits for float8::F8E4M3 {}

#[cfg(feature = "f8")]
unsafe impl DeviceRepr for float8::F8E5M2 {}
#[cfg(feature = "f8")]
unsafe impl ValidAsZeroBits for float8::F8E5M2 {}

#[cfg(feature = "f4")]
unsafe impl DeviceRepr for float4::F4E2M1 {}
#[cfg(feature = "f4")]
unsafe impl ValidAsZeroBits for float4::F4E2M1 {}

#[cfg(feature = "f4")]
unsafe impl DeviceRepr for float4::E8M0 {}
#[cfg(feature = "f4")]
unsafe impl ValidAsZeroBits for float4::E8M0 {}

#[cfg(feature = "f4")]
unsafe impl DeviceRepr for float4::F4E2M1x2 {}
#[cfg(feature = "f4")]
unsafe impl ValidAsZeroBits for float4::F4E2M1x2 {}

// Arrays of device-representable elements are themselves device-representable.
unsafe impl<const N: usize, T> DeviceRepr for [T; N] where T: DeviceRepr {}
1042
/// Common length/stream accessors for device-memory containers
/// ([CudaSlice], [CudaView], [CudaViewMut]).
pub trait DeviceSlice<T> {
    /// Number of `T` elements.
    fn len(&self) -> usize;
    /// Size in bytes.
    fn num_bytes(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }
    /// True when there are no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// The stream this buffer's operations are ordered on.
    fn stream(&self) -> &Arc<CudaStream>;
}
1056
impl<T> DeviceSlice<T> for CudaSlice<T> {
    fn len(&self) -> usize {
        self.len
    }
    fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }
}

impl<T> DeviceSlice<T> for CudaView<'_, T> {
    fn len(&self) -> usize {
        self.len
    }
    fn stream(&self) -> &Arc<CudaStream> {
        self.stream
    }
}

impl<T> DeviceSlice<T> for CudaViewMut<'_, T> {
    fn len(&self) -> usize {
        self.len
    }
    fn stream(&self) -> &Arc<CudaStream> {
        self.stream
    }
}
1083
/// A drop guard that defers stream/event synchronization until a borrowed
/// device or host pointer goes out of scope.
#[derive(Debug)]
#[must_use]
pub enum SyncOnDrop<'a> {
    /// Records the event on the stream when dropped (no-op when `None`).
    Record(Option<(&'a CudaEvent, &'a CudaStream)>),
    /// Blocks on the stream when dropped (no-op when `None`).
    Sync(Option<&'a CudaStream>),
}

impl<'a> SyncOnDrop<'a> {
    /// Guard that records `event` (when present) on `stream` at drop time.
    pub fn record_event(event: &'a Option<CudaEvent>, stream: &'a CudaStream) -> Self {
        SyncOnDrop::Record(event.as_ref().map(|e| (e, stream)))
    }
    /// Guard that synchronizes `stream` at drop time.
    pub fn sync_stream(stream: &'a CudaStream) -> Self {
        SyncOnDrop::Sync(Some(stream))
    }
}
1105
1106impl Drop for SyncOnDrop<'_> {
1107 fn drop(&mut self) {
1108 match self {
1109 SyncOnDrop::Record(target) => {
1110 if let Some((event, stream)) = std::mem::take(target) {
1111 stream.ctx.record_err(event.record(stream));
1112 }
1113 }
1114 SyncOnDrop::Sync(target) => {
1115 if let Some(stream) = std::mem::take(target) {
1116 stream.ctx.record_err(stream.synchronize());
1117 }
1118 }
1119 }
1120 }
1121}
1122
/// Read access to a device pointer, ordered against pending writes.
pub trait DevicePtr<T>: DeviceSlice<T> {
    /// Returns the raw device pointer plus a [SyncOnDrop] guard that records
    /// this buffer's read event on `stream` when dropped.
    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>);
}
1141
impl<T> DevicePtr<T> for CudaSlice<T> {
    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        // When the context actively manages multi-stream ordering, make the
        // reading stream wait for the last recorded write to this buffer.
        if self.stream.context().is_managing_stream_synchronization() {
            if let Some(write) = self.write.as_ref() {
                stream.ctx.record_err(stream.wait(write));
            }
        }
        (
            self.cu_device_ptr,
            // Record the read event once the caller is done with the pointer.
            SyncOnDrop::record_event(&self.read, stream),
        )
    }
}

impl<T> DevicePtr<T> for CudaView<'_, T> {
    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        // Same protocol as CudaSlice: wait on pending write, record read.
        if self.stream.context().is_managing_stream_synchronization() {
            if let Some(write) = self.write.as_ref() {
                stream.ctx.record_err(stream.wait(write));
            }
        }
        (self.ptr, SyncOnDrop::record_event(self.read, stream))
    }
}

impl<T> DevicePtr<T> for CudaViewMut<'_, T> {
    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        // Same protocol as CudaSlice: wait on pending write, record read.
        if self.stream.context().is_managing_stream_synchronization() {
            if let Some(write) = self.write.as_ref() {
                stream.ctx.record_err(stream.wait(write));
            }
        }
        (self.ptr, SyncOnDrop::record_event(self.read, stream))
    }
}
1177
/// Write access to a device pointer, ordered against pending reads and writes.
pub trait DevicePtrMut<T>: DeviceSlice<T> {
    /// Returns the raw device pointer plus a [SyncOnDrop] guard that records
    /// this buffer's write event on `stream` when dropped.
    fn device_ptr_mut<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, SyncOnDrop<'a>);
}
1199
impl<T> DevicePtrMut<T> for CudaSlice<T> {
    fn device_ptr_mut<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        // A writer must wait on both outstanding readers and writers.
        if self.stream.context().is_managing_stream_synchronization() {
            if let Some(read) = self.read.as_ref() {
                stream.ctx.record_err(stream.wait(read));
            }
            if let Some(write) = self.write.as_ref() {
                stream.ctx.record_err(stream.wait(write));
            }
        }
        (
            self.cu_device_ptr,
            // Record the write event once the caller is done with the pointer.
            SyncOnDrop::record_event(&self.write, stream),
        )
    }
}

impl<T> DevicePtrMut<T> for CudaViewMut<'_, T> {
    fn device_ptr_mut<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        // Same protocol as CudaSlice: wait on reads and writes, record write.
        if self.stream.context().is_managing_stream_synchronization() {
            if let Some(read) = self.read.as_ref() {
                stream.ctx.record_err(stream.wait(read));
            }
            if let Some(write) = self.write.as_ref() {
                stream.ctx.record_err(stream.wait(write));
            }
        }
        (self.ptr, SyncOnDrop::record_event(self.write, stream))
    }
}
1236
/// A host-side buffer that can be handed to asynchronous CUDA memcpys.
pub trait HostSlice<T> {
    /// Number of `T` elements.
    fn len(&self) -> usize;
    /// True when there are no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the host slice plus a [SyncOnDrop] guard tying its lifetime
    /// to `stream`'s pending work.
    ///
    /// # Safety
    /// The returned slice must remain valid until all stream work reading it
    /// has completed; the guard only enforces this while it is held.
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>);

    /// Mutable variant of [Self::stream_synced_slice].
    ///
    /// # Safety
    /// Same requirements as [Self::stream_synced_slice].
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>);
}
1260
impl<T, const N: usize> HostSlice<T> for [T; N] {
    fn len(&self) -> usize {
        N
    }
    // Pageable host memory: no deferred action (`Sync(None)` is a no-op on
    // drop); callers are responsible for keeping the borrow alive.
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        _stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        _stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
}

impl<T> HostSlice<T> for [T] {
    fn len(&self) -> usize {
        // Calls the inherent slice `len`, not this trait method.
        self.len()
    }
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        _stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        _stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
}

impl<T> HostSlice<T> for Vec<T> {
    fn len(&self) -> usize {
        // Calls the inherent Vec `len`, not this trait method.
        self.len()
    }
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        _stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        _stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>) {
        (self, SyncOnDrop::Sync(None))
    }
}
1314
/// Page-locked (pinned) host memory allocated via `cuMemHostAlloc`, with an
/// event used to order host access against stream work.
#[derive(Debug)]
pub struct PinnedHostSlice<T> {
    pub(crate) ptr: *mut T,
    /// Number of `T` elements.
    pub(crate) len: usize,
    /// Recorded on streams that use the buffer; synchronized before host access.
    pub(crate) event: CudaEvent,
}

// SAFETY: NOTE(review) — host access always synchronizes `event` first, and
// driver calls re-bind the owning context to the calling thread.
unsafe impl<T> Send for PinnedHostSlice<T> {}
unsafe impl<T> Sync for PinnedHostSlice<T> {}
1330
impl<T> Drop for PinnedHostSlice<T> {
    /// Waits for any stream work recorded on the event, then frees the pinned
    /// allocation. Errors are recorded on the context.
    fn drop(&mut self) {
        let ctx = &self.event.ctx;
        ctx.record_err(self.event.synchronize());
        ctx.record_err(unsafe { result::free_host(self.ptr as _) });
    }
}
1338
1339impl CudaContext {
1340 pub unsafe fn alloc_pinned<T: DeviceRepr>(
1347 self: &Arc<Self>,
1348 len: usize,
1349 ) -> Result<PinnedHostSlice<T>, DriverError> {
1350 self.bind_to_thread()?;
1351 let ptr = result::malloc_host(
1352 len * std::mem::size_of::<T>(),
1353 sys::CU_MEMHOSTALLOC_WRITECOMBINED,
1354 )?;
1355 let ptr = ptr as *mut T;
1356 assert!(!ptr.is_null());
1357 assert!(len * std::mem::size_of::<T>() < isize::MAX as usize);
1358 assert!(ptr.is_aligned());
1359 let event = self.new_event(Some(sys::CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
1360 Ok(PinnedHostSlice { ptr, len, event })
1361 }
1362}
1363
impl<T> PinnedHostSlice<T> {
    /// The context this buffer was allocated on.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.event.ctx
    }

    /// Number of `T` elements.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Size of the buffer in bytes.
    pub fn num_bytes(&self) -> usize {
        self.len * std::mem::size_of::<T>()
    }

    /// True when the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1384
impl<T: ValidAsZeroBits> PinnedHostSlice<T> {
    /// Waits for pending stream work on this buffer, then returns the raw
    /// host pointer for reading.
    pub fn as_ptr(&self) -> Result<*const T, DriverError> {
        self.event.synchronize()?;
        Ok(self.ptr)
    }

    /// Waits for pending stream work on this buffer, then returns the raw
    /// host pointer for writing.
    pub fn as_mut_ptr(&mut self) -> Result<*mut T, DriverError> {
        self.event.synchronize()?;
        Ok(self.ptr)
    }

    /// Waits for pending stream work, then borrows the buffer as a host slice.
    pub fn as_slice(&self) -> Result<&[T], DriverError> {
        self.event.synchronize()?;
        Ok(unsafe { std::slice::from_raw_parts(self.ptr, self.len) })
    }

    /// Waits for pending stream work, then mutably borrows the buffer.
    pub fn as_mut_slice(&mut self) -> Result<&mut [T], DriverError> {
        self.event.synchronize()?;
        Ok(unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) })
    }
}
1414
impl<T> HostSlice<T> for PinnedHostSlice<T> {
    fn len(&self) -> usize {
        self.len
    }

    // Pinned memory participates in event ordering: the stream first waits
    // for prior uses of this buffer, and the guard re-records the event when
    // the borrow ends.
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>) {
        stream.ctx.record_err(stream.wait(&self.event));
        (
            std::slice::from_raw_parts(self.ptr, self.len),
            SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>) {
        stream.ctx.record_err(stream.wait(&self.event));
        (
            std::slice::from_raw_parts_mut(self.ptr, self.len),
            SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
}
1441
1442impl CudaStream {
    /// Allocates an empty (zero-length) [CudaSlice] on this stream, without
    /// read/write events.
    pub fn null<T>(self: &Arc<Self>) -> Result<CudaSlice<T>, result::DriverError> {
        self.ctx.bind_to_thread()?;
        let cu_device_ptr = if self.ctx.has_async_alloc {
            unsafe { result::malloc_async(self.cu_stream, 0) }?
        } else {
            unsafe { result::malloc_sync(0) }?
        };
        Ok(CudaSlice {
            cu_device_ptr,
            len: 0,
            read: None,
            write: None,
            stream: self.clone(),
            marker: PhantomData,
        })
    }
1460
    /// Allocates device memory for `len` elements of `T` on this stream.
    ///
    /// # Safety
    /// The returned memory is uninitialized; it must be written before being
    /// read (see [Self::alloc_zeros] for a safe alternative).
    pub unsafe fn alloc<T: DeviceRepr>(
        self: &Arc<Self>,
        len: usize,
    ) -> Result<CudaSlice<T>, DriverError> {
        self.ctx.bind_to_thread()?;
        // Prefer stream-ordered allocation when the device supports pools.
        let cu_device_ptr = if self.ctx.has_async_alloc {
            result::malloc_async(self.cu_stream, len * std::mem::size_of::<T>())?
        } else {
            result::malloc_sync(len * std::mem::size_of::<T>())?
        };
        // With event tracking on, each allocation carries read/write events
        // used by the DevicePtr/DevicePtrMut impls to order access.
        let (read, write) = if self.ctx.is_event_tracking() {
            (
                Some(self.ctx.new_event(None)?),
                Some(self.ctx.new_event(None)?),
            )
        } else {
            (None, None)
        };
        Ok(CudaSlice {
            cu_device_ptr,
            len,
            read,
            write,
            stream: self.clone(),
            marker: PhantomData,
        })
    }
1491
    /// Allocates device memory for `len` elements and fills it with zeros.
    /// Safe because `T: ValidAsZeroBits` guarantees zeroes are valid values.
    pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(
        self: &Arc<Self>,
        len: usize,
    ) -> Result<CudaSlice<T>, DriverError> {
        let mut dst = unsafe { self.alloc(len) }?;
        self.memset_zeros(&mut dst)?;
        Ok(dst)
    }

    /// Asynchronously fills `dst` with zero bytes on this stream.
    pub fn memset_zeros<T: DeviceRepr + ValidAsZeroBits, Dst: DevicePtrMut<T>>(
        self: &Arc<Self>,
        dst: &mut Dst,
    ) -> Result<(), DriverError> {
        self.ctx.bind_to_thread()?;
        let num_bytes = dst.num_bytes();
        // The guard records dst's write event when dropped at end of scope.
        let (dptr, _record) = dst.device_ptr_mut(self);
        unsafe { result::memset_d8_async(dptr, 0, num_bytes, self.cu_stream) }?;
        Ok(())
    }
1513
1514 #[deprecated = "Use clone_htod"]
1516 pub fn memcpy_stod<T: DeviceRepr, Src: HostSlice<T> + ?Sized>(
1517 self: &Arc<Self>,
1518 src: &Src,
1519 ) -> Result<CudaSlice<T>, DriverError> {
1520 let mut dst = unsafe { self.alloc(src.len()) }?;
1521 self.memcpy_htod(src, &mut dst)?;
1522 Ok(dst)
1523 }
1524
1525 pub fn clone_htod<T: DeviceRepr, Src: HostSlice<T> + ?Sized>(
1527 self: &Arc<Self>,
1528 src: &Src,
1529 ) -> Result<CudaSlice<T>, DriverError> {
1530 let mut dst = unsafe { self.alloc(src.len()) }?;
1531 self.memcpy_htod(src, &mut dst)?;
1532 Ok(dst)
1533 }
1534
1535 pub fn memcpy_htod<T: DeviceRepr, Src: HostSlice<T> + ?Sized, Dst: DevicePtrMut<T>>(
1537 self: &Arc<Self>,
1538 src: &Src,
1539 dst: &mut Dst,
1540 ) -> Result<(), DriverError> {
1541 assert!(dst.len() >= src.len());
1542 self.ctx.bind_to_thread()?;
1543 let (src, _record_src) = unsafe { src.stream_synced_slice(self) };
1544 let (dst, _record_dst) = dst.device_ptr_mut(self);
1545 unsafe { result::memcpy_htod_async(dst, src, self.cu_stream) }
1546 }
1547
1548 #[deprecated = "Use clone_dtoh"]
1550 pub fn memcpy_dtov<T: DeviceRepr, Src: DevicePtr<T>>(
1551 self: &Arc<Self>,
1552 src: &Src,
1553 ) -> Result<Vec<T>, DriverError> {
1554 let mut dst = Vec::with_capacity(src.len());
1555 #[allow(clippy::uninit_vec)]
1556 unsafe {
1557 dst.set_len(src.len())
1558 };
1559 self.memcpy_dtoh(src, &mut dst)?;
1560 Ok(dst)
1561 }
1562
1563 pub fn clone_dtoh<T: DeviceRepr, Src: DevicePtr<T>>(
1565 self: &Arc<Self>,
1566 src: &Src,
1567 ) -> Result<Vec<T>, DriverError> {
1568 let mut dst = Vec::with_capacity(src.len());
1569 #[allow(clippy::uninit_vec)]
1570 unsafe {
1571 dst.set_len(src.len())
1572 };
1573 self.memcpy_dtoh(src, &mut dst)?;
1574 Ok(dst)
1575 }
1576
1577 pub fn memcpy_dtoh<T: DeviceRepr, Src: DevicePtr<T>, Dst: HostSlice<T> + ?Sized>(
1579 self: &Arc<Self>,
1580 src: &Src,
1581 dst: &mut Dst,
1582 ) -> Result<(), DriverError> {
1583 assert!(dst.len() >= src.len());
1584 self.ctx.bind_to_thread()?;
1585 let (src, _record_src) = src.device_ptr(self);
1586 let (dst, _record_dst) = unsafe { dst.stream_synced_mut_slice(self) };
1587 unsafe { result::memcpy_dtoh_async(dst, src, self.cu_stream) }
1588 }
1589
1590 pub fn memcpy_dtod<T, Src: DevicePtr<T>, Dst: DevicePtrMut<T>>(
1592 self: &Arc<Self>,
1593 src: &Src,
1594 dst: &mut Dst,
1595 ) -> Result<(), DriverError> {
1596 assert!(dst.len() >= src.len());
1597 self.ctx.bind_to_thread()?;
1598
1599 let num_bytes = src.num_bytes();
1600
1601 let src_ctx = src.stream().context();
1602 let dst_ctx = self.context();
1603
1604 let (src, _record_src) = src.device_ptr(self);
1605 let (dst, _record_dst) = dst.device_ptr_mut(self);
1606
1607 if src_ctx == dst_ctx {
1608 unsafe { result::memcpy_dtod_async(dst, src, num_bytes, self.cu_stream) }
1609 } else {
1610 unsafe {
1611 result::memcpy_peer_async(
1612 dst_ctx.cu_ctx,
1613 dst,
1614 src_ctx.cu_ctx,
1615 src,
1616 num_bytes,
1617 self.cu_stream,
1618 )
1619 }
1620 }
1621 }
1622
1623 pub fn clone_dtod<T: DeviceRepr, Src: DevicePtr<T>>(
1625 self: &Arc<Self>,
1626 src: &Src,
1627 ) -> Result<CudaSlice<T>, DriverError> {
1628 let mut dst = unsafe { self.alloc(src.len()) }?;
1629 self.memcpy_dtod(src, &mut dst)?;
1630 Ok(dst)
1631 }
1632}
1633
1634impl<T> CudaSlice<T> {
1635 pub fn slice(&self, bounds: impl RangeBounds<usize>) -> CudaView<'_, T> {
1665 self.as_view().slice(bounds)
1666 }
1667
1668 pub fn try_slice(&self, bounds: impl RangeBounds<usize>) -> Option<CudaView<'_, T>> {
1670 self.as_view().try_slice(bounds)
1671 }
1672
1673 pub fn slice_mut(&mut self, bounds: impl RangeBounds<usize>) -> CudaViewMut<'_, T> {
1717 self.try_slice_mut(bounds).unwrap()
1718 }
1719
1720 pub fn try_slice_mut(&mut self, bounds: impl RangeBounds<usize>) -> Option<CudaViewMut<'_, T>> {
1722 to_range(bounds, self.len).map(|(start, end)| CudaViewMut {
1723 ptr: self.cu_device_ptr + (start * std::mem::size_of::<T>()) as u64,
1724 len: end - start,
1725 read: &self.read,
1726 write: &self.write,
1727 stream: &self.stream,
1728 marker: PhantomData,
1729 })
1730 }
1731
1732 pub unsafe fn transmute<S>(&self, len: usize) -> Option<CudaView<'_, S>> {
1740 self.as_view().transmute(len)
1741 }
1742
1743 pub unsafe fn transmute_mut<S>(&mut self, len: usize) -> Option<CudaViewMut<'_, S>> {
1751 (len * std::mem::size_of::<S>() <= self.len * std::mem::size_of::<T>()).then_some(
1752 CudaViewMut {
1753 ptr: self.cu_device_ptr,
1754 len,
1755 read: &self.read,
1756 write: &self.write,
1757 stream: &self.stream,
1758 marker: PhantomData,
1759 },
1760 )
1761 }
1762
1763 pub fn split_at(&self, mid: usize) -> (CudaView<'_, T>, CudaView<'_, T>) {
1764 self.as_view().split_at(mid)
1765 }
1766
1767 pub fn try_split_at(&self, mid: usize) -> Option<(CudaView<'_, T>, CudaView<'_, T>)> {
1769 self.as_view().try_split_at(mid)
1770 }
1771
1772 pub fn split_at_mut(&mut self, mid: usize) -> (CudaViewMut<'_, T>, CudaViewMut<'_, T>) {
1788 self.try_split_at_mut(mid).unwrap()
1789 }
1790
1791 pub fn try_split_at_mut(
1795 &mut self,
1796 mid: usize,
1797 ) -> Option<(CudaViewMut<'_, T>, CudaViewMut<'_, T>)> {
1798 let length = self.len;
1799 (mid <= length).then(|| {
1800 let a = CudaViewMut {
1801 ptr: self.cu_device_ptr,
1802 len: mid,
1803 read: &self.read,
1804 write: &self.write,
1805 stream: &self.stream,
1806 marker: PhantomData,
1807 };
1808 let b = CudaViewMut {
1809 ptr: self.cu_device_ptr + (mid * std::mem::size_of::<T>()) as u64,
1810 len: length - mid,
1811 read: &self.read,
1812 write: &self.write,
1813 stream: &self.stream,
1814 marker: PhantomData,
1815 };
1816 (a, b)
1817 })
1818 }
1819}
1820
1821impl<'a, T> CudaView<'a, T> {
1822 pub fn slice(&self, bounds: impl RangeBounds<usize>) -> Self {
1839 self.try_slice(bounds).unwrap()
1840 }
1841
1842 pub fn try_slice(&self, bounds: impl RangeBounds<usize>) -> Option<Self> {
1844 to_range(bounds, self.len).map(|(start, end)| self.resize(start, end))
1845 }
1846
1847 pub unsafe fn transmute<S>(&self, len: usize) -> Option<CudaView<'a, S>> {
1855 (len * std::mem::size_of::<S>() <= self.len * std::mem::size_of::<T>()).then_some(
1856 CudaView {
1857 ptr: self.ptr,
1858 len,
1859 read: self.read,
1860 write: self.write,
1861 stream: self.stream,
1862 marker: PhantomData,
1863 },
1864 )
1865 }
1866
1867 pub fn split_at(&self, mid: usize) -> (Self, Self) {
1868 self.try_split_at(mid).unwrap()
1869 }
1870
1871 pub fn try_split_at(&self, mid: usize) -> Option<(Self, Self)> {
1875 (mid <= self.len()).then(|| (self.resize(0, mid), self.resize(mid, self.len)))
1876 }
1877}
1878
1879impl<'a, T> CudaViewMut<'a, T> {
1880 pub fn slice<'b>(&'b self, bounds: impl RangeBounds<usize>) -> CudaView<'b, T> {
1912 self.try_slice(bounds).unwrap()
1913 }
1914
1915 pub fn try_slice<'b>(&'b self, bounds: impl RangeBounds<usize>) -> Option<CudaView<'b, T>> {
1917 to_range(bounds, self.len).map(move |(start, end)| self.as_view().resize(start, end))
1918 }
1919
1920 pub unsafe fn transmute<'b, S>(&'b self, len: usize) -> Option<CudaView<'b, S>> {
1928 (len * std::mem::size_of::<S>() <= self.len * std::mem::size_of::<T>()).then_some(
1929 CudaView {
1930 ptr: self.ptr,
1931 len,
1932 read: self.read,
1933 write: self.write,
1934 stream: self.stream,
1935 marker: PhantomData,
1936 },
1937 )
1938 }
1939
1940 pub fn slice_mut<'b>(&'b mut self, bounds: impl RangeBounds<usize>) -> CudaViewMut<'b, T> {
1944 self.try_slice_mut(bounds).unwrap()
1945 }
1946
1947 pub fn try_slice_mut<'b>(
1949 &'b mut self,
1950 bounds: impl RangeBounds<usize>,
1951 ) -> Option<CudaViewMut<'b, T>> {
1952 to_range(bounds, self.len).map(|(start, end)| CudaViewMut {
1953 ptr: self.ptr + (start * std::mem::size_of::<T>()) as u64,
1954 len: end - start,
1955 read: self.read,
1956 write: self.write,
1957 stream: self.stream,
1958 marker: PhantomData,
1959 })
1960 }
1961
1962 pub fn split_at_mut<'b>(&'b mut self, mid: usize) -> (CudaViewMut<'b, T>, CudaViewMut<'b, T>) {
1978 self.try_split_at_mut(mid).unwrap()
1979 }
1980
1981 pub fn try_split_at_mut<'b>(
1985 &'b mut self,
1986 mid: usize,
1987 ) -> Option<(CudaViewMut<'b, T>, CudaViewMut<'b, T>)> {
1988 let length = self.len;
1989 (mid <= length).then(|| {
1990 let a = CudaViewMut {
1991 ptr: self.ptr,
1992 len: mid,
1993 read: self.read,
1994 write: self.write,
1995 stream: self.stream,
1996 marker: PhantomData,
1997 };
1998 let b = CudaViewMut {
1999 ptr: self.ptr + (mid * std::mem::size_of::<T>()) as u64,
2000 len: length - mid,
2001 read: self.read,
2002 write: self.write,
2003 stream: self.stream,
2004 marker: PhantomData,
2005 };
2006 (a, b)
2007 })
2008 }
2009
2010 pub unsafe fn transmute_mut<'b, S>(&'b mut self, len: usize) -> Option<CudaViewMut<'b, S>> {
2018 (len * std::mem::size_of::<S>() <= self.len * std::mem::size_of::<T>()).then_some(
2019 CudaViewMut {
2020 ptr: self.ptr,
2021 len,
2022 read: self.read,
2023 write: self.write,
2024 stream: self.stream,
2025 marker: PhantomData,
2026 },
2027 )
2028 }
2029}
2030
2031pub(super) fn to_range(range: impl RangeBounds<usize>, len: usize) -> Option<(usize, usize)> {
2032 let start = match range.start_bound() {
2033 Bound::Included(&n) => n,
2034 Bound::Excluded(&n) => n + 1,
2035 Bound::Unbounded => 0,
2036 };
2037 let end = match range.end_bound() {
2038 Bound::Included(&n) => n + 1,
2039 Bound::Excluded(&n) => n,
2040 Bound::Unbounded => len,
2041 };
2042 (start <= end && end <= len).then_some((start, end))
2043}
2044
/// A loaded CUDA module (PTX / cubin image), tied to the [CudaContext] it was
/// loaded into.
#[derive(Debug)]
pub struct CudaModule {
    // Raw driver handle; unloaded in `Drop`.
    pub(crate) cu_module: sys::CUmodule,
    // Keeps the owning context alive for as long as the module exists.
    pub(crate) ctx: Arc<CudaContext>,
}
2053
// SAFETY: NOTE(review): `CUmodule` is an opaque driver handle and the CUDA
// driver API is intended to be callable from any thread; assumed sound to send
// and share — confirm against the CUDA driver thread-safety documentation.
unsafe impl Send for CudaModule {}
unsafe impl Sync for CudaModule {}
2056
impl Drop for CudaModule {
    fn drop(&mut self) {
        // Bind the owning context to this thread before issuing driver calls;
        // any failure is recorded on the context rather than panicking in Drop.
        self.ctx.record_err(self.ctx.bind_to_thread());
        self.ctx
            .record_err(unsafe { result::module::unload(self.cu_module) });
    }
}
2064
2065impl CudaContext {
2066 #[cfg(feature = "nvrtc")]
2070 pub fn load_module(
2071 self: &Arc<Self>,
2072 ptx: crate::nvrtc::Ptx,
2073 ) -> Result<Arc<CudaModule>, result::DriverError> {
2074 self.bind_to_thread()?;
2075
2076 let cu_module = match ptx.0 {
2077 crate::nvrtc::PtxKind::Image(image) => unsafe {
2078 result::module::load_data(image.as_ptr() as *const _)
2079 },
2080 crate::nvrtc::PtxKind::Src(src) => {
2081 let c_src = CString::new(src).unwrap();
2082 unsafe { result::module::load_data(c_src.as_ptr() as *const _) }
2083 }
2084 crate::nvrtc::PtxKind::File(path) => {
2085 let name_c = CString::new(path.to_str().unwrap()).unwrap();
2086 result::module::load(name_c)
2087 }
2088 crate::nvrtc::PtxKind::Binary(data) => unsafe {
2089 result::module::load_data(data.as_ptr() as *const _)
2090 },
2091 }?;
2092 Ok(Arc::new(CudaModule {
2093 cu_module,
2094 ctx: self.clone(),
2095 }))
2096 }
2097}
2098
/// A kernel function loaded from a [CudaModule]. Cloning is cheap (handle +
/// `Arc` bump).
#[derive(Debug, Clone)]
pub struct CudaFunction {
    // Raw driver handle; owned by the module, not freed here.
    pub(crate) cu_function: sys::CUfunction,
    // Keeps the module (and its context) alive while the function exists.
    #[allow(unused)]
    pub(crate) module: Arc<CudaModule>,
}
2106
// SAFETY: NOTE(review): `CUfunction` is an opaque driver handle and the CUDA
// driver API is intended to be callable from any thread; assumed sound to send
// and share — confirm against the CUDA driver thread-safety documentation.
unsafe impl Send for CudaFunction {}
unsafe impl Sync for CudaFunction {}
2109
2110impl CudaModule {
2111 pub fn load_function(self: &Arc<Self>, fn_name: &str) -> Result<CudaFunction, DriverError> {
2113 let fn_name_c = CString::new(fn_name).unwrap();
2114 let cu_function = unsafe { result::module::get_function(self.cu_module, fn_name_c) }?;
2115 Ok(CudaFunction {
2116 cu_function,
2117 module: self.clone(),
2118 })
2119 }
2120
2121 pub fn get_global<'a>(
2136 self: &'a Arc<Self>,
2137 name: &str,
2138 stream: &'a Arc<CudaStream>,
2139 ) -> Result<CudaViewMut<'a, u8>, DriverError> {
2140 let name_c =
2141 CString::new(name).map_err(|_| DriverError(sys::CUresult::CUDA_ERROR_INVALID_VALUE))?;
2142 let (cu_device_ptr, bytes) = unsafe { result::module::get_global(self.cu_module, name_c) }?;
2143 Ok(CudaViewMut {
2144 ptr: cu_device_ptr,
2145 len: bytes,
2146 read: &None,
2147 write: &None,
2148 stream,
2149 marker: PhantomData,
2150 })
2151 }
2152}
2153
impl CudaFunction {
    /// Returns how much dynamic shared memory (in bytes) is available per
    /// block for this function at the given occupancy point.
    pub fn occupancy_available_dynamic_smem_per_block(
        &self,
        num_blocks: u32,
        block_size: u32,
    ) -> Result<usize, result::DriverError> {
        let mut dynamic_smem_size: usize = 0;

        // SAFETY: out-pointer and function handle are valid for the call.
        unsafe {
            sys::cuOccupancyAvailableDynamicSMemPerBlock(
                &mut dynamic_smem_size,
                self.cu_function,
                num_blocks as std::ffi::c_int,
                block_size as std::ffi::c_int,
            )
            .result()?
        };

        Ok(dynamic_smem_size)
    }

    /// Returns the maximum number of active blocks per multiprocessor for this
    /// function with the given launch parameters.
    pub fn occupancy_max_active_blocks_per_multiprocessor(
        &self,
        block_size: u32,
        dynamic_smem_size: usize,
        flags: Option<sys::CUoccupancy_flags_enum>,
    ) -> Result<u32, result::DriverError> {
        let mut num_blocks: std::ffi::c_int = 0;
        // Absent flags fall back to the driver default behavior.
        let flags = flags.unwrap_or(sys::CUoccupancy_flags_enum::CU_OCCUPANCY_DEFAULT);

        // SAFETY: out-pointer and function handle are valid for the call.
        unsafe {
            sys::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
                &mut num_blocks,
                self.cu_function,
                block_size as std::ffi::c_int,
                dynamic_smem_size,
                flags as std::ffi::c_uint,
            )
            .result()?
        };

        Ok(num_blocks as u32)
    }

    /// Returns the maximum number of active clusters for launching this
    /// function with `config` on `stream`. (Not available before CUDA 11.8.)
    #[cfg(not(any(
        feature = "cuda-11070",
        feature = "cuda-11060",
        feature = "cuda-11050",
        feature = "cuda-11040"
    )))]
    pub fn occupancy_max_active_clusters(
        &self,
        config: crate::driver::LaunchConfig,
        stream: &CudaStream,
    ) -> Result<u32, result::DriverError> {
        let mut num_clusters: std::ffi::c_int = 0;

        // Translate the crate's LaunchConfig into the driver's launch struct;
        // no extra launch attributes are passed.
        let cfg = sys::CUlaunchConfig {
            gridDimX: config.grid_dim.0,
            gridDimY: config.grid_dim.1,
            gridDimZ: config.grid_dim.2,
            blockDimX: config.block_dim.0,
            blockDimY: config.block_dim.1,
            blockDimZ: config.block_dim.2,
            sharedMemBytes: config.shared_mem_bytes,
            hStream: stream.cu_stream,
            attrs: std::ptr::null_mut(),
            numAttrs: 0,
        };

        // SAFETY: out-pointer, function handle and config are valid.
        unsafe {
            sys::cuOccupancyMaxActiveClusters(&mut num_clusters, self.cu_function, &cfg).result()?
        };

        Ok(num_clusters as u32)
    }

    /// Returns `(min_grid_size, block_size)` that maximize occupancy for this
    /// function, given a callback mapping block size to dynamic smem usage.
    pub fn occupancy_max_potential_block_size(
        &self,
        block_size_to_dynamic_smem_size: extern "C" fn(block_size: std::ffi::c_int) -> usize,
        dynamic_smem_size: usize,
        block_size_limit: u32,
        flags: Option<sys::CUoccupancy_flags_enum>,
    ) -> Result<(u32, u32), result::DriverError> {
        let mut min_grid_size: std::ffi::c_int = 0;
        let mut block_size: std::ffi::c_int = 0;
        // Absent flags fall back to the driver default behavior.
        let flags = flags.unwrap_or(sys::CUoccupancy_flags_enum::CU_OCCUPANCY_DEFAULT);

        // SAFETY: out-pointers, function handle and callback are valid for
        // the duration of the call.
        unsafe {
            sys::cuOccupancyMaxPotentialBlockSizeWithFlags(
                &mut min_grid_size,
                &mut block_size,
                self.cu_function,
                Some(block_size_to_dynamic_smem_size),
                dynamic_smem_size,
                block_size_limit as std::ffi::c_int,
                flags as std::ffi::c_uint,
            )
            .result()?
        };

        Ok((min_grid_size as u32, block_size as u32))
    }

    /// Returns the maximum cluster size that can be launched for this function
    /// with `config` on `stream`. (Not available before CUDA 11.8.)
    #[cfg(not(any(
        feature = "cuda-11070",
        feature = "cuda-11060",
        feature = "cuda-11050",
        feature = "cuda-11040"
    )))]
    pub fn occupancy_max_potential_cluster_size(
        &self,
        config: crate::driver::LaunchConfig,
        stream: &CudaStream,
    ) -> Result<u32, result::DriverError> {
        let mut cluster_size: std::ffi::c_int = 0;

        // Translate the crate's LaunchConfig into the driver's launch struct;
        // no extra launch attributes are passed.
        let cfg = sys::CUlaunchConfig {
            gridDimX: config.grid_dim.0,
            gridDimY: config.grid_dim.1,
            gridDimZ: config.grid_dim.2,
            blockDimX: config.block_dim.0,
            blockDimY: config.block_dim.1,
            blockDimZ: config.block_dim.2,
            sharedMemBytes: config.shared_mem_bytes,
            hStream: stream.cu_stream,
            attrs: std::ptr::null_mut(),
            numAttrs: 0,
        };

        // SAFETY: out-pointer, function handle and config are valid.
        unsafe {
            sys::cuOccupancyMaxPotentialClusterSize(&mut cluster_size, self.cu_function, &cfg)
                .result()?
        };

        Ok(cluster_size as u32)
    }

    /// Reads the given function attribute from the driver.
    pub fn get_attribute(
        &self,
        attribute: CUfunction_attribute_enum,
    ) -> Result<i32, result::DriverError> {
        // Ensure the owning context is current before the driver query.
        self.module.ctx.bind_to_thread()?;
        unsafe { result::function::get_function_attribute(self.cu_function, attribute) }
    }

    /// Number of registers used by each thread of this function.
    pub fn num_regs(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_NUM_REGS)
    }

    /// Bytes of statically-allocated shared memory required by this function.
    pub fn shared_size_bytes(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
    }

    /// Bytes of user-allocated constant memory required by this function.
    pub fn const_size_bytes(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES)
    }

    /// Bytes of local memory used by each thread of this function.
    pub fn local_size_bytes(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
    }

    /// Maximum number of threads per block for this function.
    pub fn max_threads_per_block(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
    }

    /// PTX virtual architecture version this function was compiled for.
    pub fn ptx_version(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_PTX_VERSION)
    }

    /// Binary architecture version this function was compiled for.
    pub fn binary_version(&self) -> Result<i32, result::DriverError> {
        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_BINARY_VERSION)
    }

    /// Sets a function attribute (e.g. max dynamic shared memory).
    // NOTE(review): unlike get_attribute(), this does not bind the owning
    // context to the thread first — confirm whether that is intentional.
    pub fn set_attribute(
        &self,
        attribute: CUfunction_attribute_enum,
        value: i32,
    ) -> Result<(), result::DriverError> {
        unsafe { result::function::set_function_attribute(self.cu_function, attribute, value) }
    }

    /// Sets the preferred cache configuration for this function.
    pub fn set_function_cache_config(
        &self,
        attribute: CUfunc_cache_enum,
    ) -> Result<(), result::DriverError> {
        unsafe { result::function::set_function_cache_config(self.cu_function, attribute) }
    }
}
2355
impl<T> CudaSlice<T> {
    /// Consumes the slice and returns the raw device pointer WITHOUT freeing
    /// the device memory; the caller becomes responsible for it (e.g. via
    /// [CudaStream::upgrade_device_ptr]).
    pub fn leak(self) -> sys::CUdeviceptr {
        // Prevent `Drop` from running, which would free the device allocation.
        let mut s = std::mem::ManuallyDrop::new(self);
        let ptr = s.cu_device_ptr;

        // Drain any outstanding event dependencies before handing the pointer
        // out; wait errors are recorded on the context rather than panicking.
        if let Some(read) = s.read.as_ref() {
            s.stream.ctx.record_err(s.stream.wait(read));
        }
        if let Some(write) = s.write.as_ref() {
            s.stream.ctx.record_err(s.stream.wait(write));
        }

        // SAFETY: `s` is wrapped in ManuallyDrop, so each field is dropped
        // exactly once here and never again; `cu_device_ptr` is deliberately
        // NOT dropped so the device memory stays alive.
        unsafe {
            std::ptr::drop_in_place(&mut s.read);
            std::ptr::drop_in_place(&mut s.write);
            std::ptr::drop_in_place(&mut s.stream);
        }

        ptr
    }
}
2383
impl CudaStream {
    /// Re-wraps a raw device pointer (e.g. one produced by [CudaSlice::leak])
    /// into a [CudaSlice] of `len` elements owned by this stream.
    ///
    /// # Safety
    /// `cu_device_ptr` must point to a live device allocation of at least
    /// `len` elements of `T` in this stream's context, and ownership must not
    /// be held anywhere else (the returned slice will free it on drop).
    pub unsafe fn upgrade_device_ptr<T>(
        self: &Arc<Self>,
        cu_device_ptr: sys::CUdeviceptr,
        len: usize,
    ) -> CudaSlice<T> {
        // Mirror alloc(): only create tracking events when enabled. Event
        // creation failure panics here because this fn is infallible by
        // signature.
        let (read, write) = if self.ctx.is_event_tracking() {
            (
                Some(self.ctx.new_event(None).unwrap()),
                Some(self.ctx.new_event(None).unwrap()),
            )
        } else {
            (None, None)
        };
        CudaSlice {
            cu_device_ptr,
            len,
            read,
            write,
            stream: self.clone(),
            marker: PhantomData,
        }
    }
}
2416
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use super::*;

    // Transmute succeeds exactly when the requested byte span fits:
    // 100 u8 bytes hold 25 f32s but not 26.
    #[test]
    fn test_transmutes() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
        assert!(unsafe { slice.transmute::<f32>(25) }.is_some());
        assert!(unsafe { slice.transmute::<f32>(26) }.is_none());
        assert!(unsafe { slice.transmute_mut::<f32>(25) }.is_some());
        assert!(unsafe { slice.transmute_mut::<f32>(26) }.is_none());

        {
            let view = slice.slice(0..100);
            assert!(unsafe { view.transmute::<f32>(25) }.is_some());
            assert!(unsafe { view.transmute::<f32>(26) }.is_none());
        }

        {
            let mut view_mut = slice.slice_mut(0..100);
            assert!(unsafe { view_mut.transmute::<f32>(25) }.is_some());
            assert!(unsafe { view_mut.transmute::<f32>(26) }.is_none());
            assert!(unsafe { view_mut.transmute_mut::<f32>(25) }.is_some());
            assert!(unsafe { view_mut.transmute_mut::<f32>(26) }.is_none());
        }
    }

    // A context clone can be bound and used from multiple threads.
    #[test]
    fn test_threading() {
        let ctx1 = CudaContext::new(0).unwrap();
        let ctx2 = ctx1.clone();

        let thread1 = std::thread::spawn(move || {
            ctx1.bind_to_thread()?;
            ctx1.default_stream().alloc_zeros::<f32>(10)
        });
        let thread2 = std::thread::spawn(move || {
            ctx2.bind_to_thread()?;
            ctx2.default_stream().alloc_zeros::<f32>(10)
        });

        let _: crate::driver::CudaSlice<f32> = thread1.join().unwrap().unwrap();
        let _: crate::driver::CudaSlice<f32> = thread2.join().unwrap().unwrap();
    }

    // A fresh context holds exactly one strong Arc reference.
    #[test]
    fn test_post_build_arc_count() {
        let ctx = CudaContext::new(0).unwrap();
        assert_eq!(Arc::strong_count(&ctx), 1);
    }

    // Tracks how stream/slice creation and destruction affect Arc counts:
    // the default stream holds the ctx; an allocated slice holds the stream
    // plus (with event tracking) additional ctx references.
    #[test]
    fn test_post_alloc_arc_counts() {
        let ctx = CudaContext::new(0).unwrap();
        assert_eq!(Arc::strong_count(&ctx), 1);
        let stream = ctx.default_stream();
        assert_eq!(Arc::strong_count(&ctx), 2);
        let t = stream.alloc_zeros::<f32>(1).unwrap();
        assert_eq!(Arc::strong_count(&ctx), 4);
        assert_eq!(Arc::strong_count(&stream), 2);
        drop(t);
        assert_eq!(Arc::strong_count(&ctx), 2);
        assert_eq!(Arc::strong_count(&stream), 1);
        drop(stream);
        assert_eq!(Arc::strong_count(&ctx), 1);
    }

    // Free-memory accounting must be restored once a slice is dropped and the
    // context synchronized. Isolated because other tests' allocations would
    // perturb the free-memory readings.
    #[test]
    #[ignore = "must be executed by itself"]
    fn test_post_alloc_memory() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();

        let (free1, total1) = ctx.mem_get_info().unwrap();

        let t = stream.clone_htod(&[0.0f32; 5]).unwrap();
        let (free2, total2) = ctx.mem_get_info().unwrap();
        assert_eq!(total1, total2);
        assert!(free2 < free1);

        drop(t);
        ctx.synchronize().unwrap();

        let (free3, total3) = ctx.mem_get_info().unwrap();
        assert_eq!(total2, total3);
        assert!(free3 > free2);
        assert_eq!(free3, free1);
    }

    // dtod copies into mutable sub-views stitch several small buffers into one
    // larger buffer at the right offsets.
    #[test]
    fn test_ctx_copy_to_views() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();

        let smalls = [
            stream.clone_htod(&[-1.0f32, -0.8]).unwrap(),
            stream.clone_htod(&[-0.6, -0.4]).unwrap(),
            stream.clone_htod(&[-0.2, 0.0]).unwrap(),
            stream.clone_htod(&[0.2, 0.4]).unwrap(),
            stream.clone_htod(&[0.6, 0.8]).unwrap(),
        ];
        let mut big = stream.alloc_zeros::<f32>(10).unwrap();

        let mut offset = 0;
        for small in smalls.iter() {
            let mut sub = big.slice_mut(offset..offset + small.len());
            stream.memcpy_dtod(small, &mut sub).unwrap();
            offset += small.len();
        }

        assert_eq!(
            stream.clone_dtoh(&big).unwrap(),
            [-1.0, -0.8, -0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6, 0.8]
        );
    }

    // leak() + upgrade_device_ptr() round-trips the raw pointer, including
    // re-wrapping with a shorter (then again longer) length.
    #[test]
    fn test_leak_and_upgrade() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();

        let a = stream.clone_htod(&[1.0f32, 2.0, 3.0, 4.0, 5.0]).unwrap();

        let ptr = a.leak();
        let b = unsafe { stream.upgrade_device_ptr::<f32>(ptr, 3) };
        assert_eq!(stream.clone_dtoh(&b).unwrap(), &[1.0, 2.0, 3.0]);

        let ptr = b.leak();
        let c = unsafe { stream.upgrade_device_ptr::<f32>(ptr, 5) };
        assert_eq!(stream.clone_dtoh(&c).unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // Dropping a slice while a DIFFERENT context is bound to the thread must
    // still free it in its owning context.
    #[test]
    fn test_slice_is_freed_with_correct_context() {
        let ctx0 = CudaContext::new(0).unwrap();
        let slice = ctx0.default_stream().clone_htod(&[1.0; 10]).unwrap();
        let ctx1 = CudaContext::new(0).unwrap();
        ctx1.bind_to_thread().unwrap();
        drop(ctx0);
        drop(slice);
        drop(ctx1);
    }

    // Copies must target the slice's owning context even when another context
    // exists on the same device.
    #[test]
    fn test_copy_uses_correct_context() {
        let ctx0 = CudaContext::new(0).unwrap();
        let _ctx1 = CudaContext::new(0).unwrap();
        let slice = ctx0.default_stream().clone_htod(&[1.0; 10]).unwrap();
        let _out = ctx0.default_stream().clone_dtoh(&slice).unwrap();
    }

    // Pinned host memory round-trips through the device unchanged.
    #[test]
    fn test_htod_copy_pinned() {
        let truth = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let mut pinned = unsafe { ctx.alloc_pinned::<f32>(10) }.unwrap();
        pinned.as_mut_slice().unwrap().clone_from_slice(&truth);
        assert_eq!(pinned.as_slice().unwrap(), &truth);
        let dst = stream.clone_htod(&pinned).unwrap();
        let host = stream.clone_dtoh(&dst).unwrap();
        assert_eq!(&host, &truth);
    }

    // Rough timing check: pinned-memory htod copies should be measurably
    // faster than pageable ones. NOTE(review): timing-based and potentially
    // flaky on loaded machines.
    #[test]
    fn test_pinned_copy_is_faster() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.new_stream().unwrap();

        let n = 100_000;
        let n_samples = 5;
        let not_pinned = std::vec![0.0f32; n];

        let start = Instant::now();
        for _ in 0..n_samples {
            let _ = stream.clone_htod(&not_pinned).unwrap();
            stream.synchronize().unwrap();
        }
        let unpinned_elapsed = start.elapsed() / n_samples;

        let pinned = unsafe { ctx.alloc_pinned::<f32>(n) }.unwrap();

        let start = Instant::now();
        for _ in 0..n_samples {
            let _ = stream.clone_htod(&pinned).unwrap();
            stream.synchronize().unwrap();
        }
        let pinned_elapsed = start.elapsed() / n_samples;

        assert!(
            pinned_elapsed.as_secs_f32() * 1.5 < unpinned_elapsed.as_secs_f32(),
            "{unpinned_elapsed:?} vs {pinned_elapsed:?}"
        );
    }

    // CudaContext::new retains the device's primary context.
    #[test]
    fn test_primary_context_is_primary() {
        let ctx = CudaContext::new(0).unwrap();
        assert!(ctx.is_primary());
    }

    // Helper: create a standalone (non-primary) driver context, using the
    // cuCtxCreate variant available for the selected CUDA feature version.
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
    ))]
    fn create_non_primary_context() -> (sys::CUdevice, sys::CUcontext) {
        result::init().unwrap();
        let cu_device = result::device::get(0).unwrap();

        #[cfg(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010",
        ))]
        let cu_ctx = unsafe { result::ctx::create_v4(std::ptr::null_mut(), 0, cu_device) }
            .expect("cuCtxCreate_v4 failed");

        #[cfg(not(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010",
        )))]
        let cu_ctx =
            unsafe { result::ctx::create_v3(0, cu_device) }.expect("cuCtxCreate_v3 failed");

        assert!(!cu_ctx.is_null());
        (cu_device, cu_ctx)
    }

    // Wrapping an externally-created context must mark it non-primary and
    // destroy (not release) it on drop.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
    ))]
    fn test_from_raw_context_creates_and_destroys() {
        let (cu_device, cu_ctx) = create_non_primary_context();

        let ctx = unsafe { CudaContext::from_raw_context(0, cu_device, cu_ctx) }.unwrap();
        assert!(!ctx.is_primary());
        ctx.bind_to_thread().unwrap();
        drop(ctx);
    }

    // A wrapped raw context must be usable from another thread after
    // bind_to_thread().
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
    ))]
    fn test_from_raw_context_bind_to_thread() {
        let (cu_device, cu_ctx) = create_non_primary_context();

        let ctx = unsafe { CudaContext::from_raw_context(0, cu_device, cu_ctx) }.unwrap();

        let ctx2 = ctx.clone();
        let handle = std::thread::spawn(move || {
            ctx2.bind_to_thread().unwrap();
            let stream = ctx2.default_stream();
            let data = stream.clone_htod(&[1.0f32, 2.0, 3.0]).unwrap();
            let result = stream.clone_dtoh(&data).unwrap();
            assert_eq!(result, std::vec![1.0f32, 2.0, 3.0]);
        });
        handle.join().unwrap();
    }

    // new_non_primary() produces a non-primary context that can be bound and
    // dropped cleanly.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010",
    ))]
    fn test_new_non_primary_creates_and_destroys() {
        let ctx = CudaContext::new_non_primary(0, 0).unwrap();
        assert!(!ctx.is_primary());
        ctx.bind_to_thread().unwrap();
        drop(ctx);
    }

    // Host round-trip through a non-primary context.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010",
    ))]
    fn test_new_non_primary_htod_dtoh() {
        let ctx = CudaContext::new_non_primary(0, 0).unwrap();
        let stream = ctx.default_stream();
        let data = stream.clone_htod(&[1.0f32, 2.0, 3.0]).unwrap();
        let result = stream.clone_dtoh(&data).unwrap();
        assert_eq!(result, std::vec![1.0f32, 2.0, 3.0]);
    }

    // A non-primary context must also work from another thread.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010",
    ))]
    fn test_new_non_primary_cross_thread() {
        let ctx = CudaContext::new_non_primary(0, 0).unwrap();
        let ctx2 = ctx.clone();
        let handle = std::thread::spawn(move || {
            ctx2.bind_to_thread().unwrap();
            let stream = ctx2.default_stream();
            let data = stream.clone_htod(&[4.0f32, 5.0, 6.0]).unwrap();
            let result = stream.clone_dtoh(&data).unwrap();
            assert_eq!(result, std::vec![4.0f32, 5.0, 6.0]);
        });
        handle.join().unwrap();
    }
}