// cudarc/driver/safe/core.rs — safe wrappers around the CUDA driver API
// (contexts, events, streams, and device allocations).
1use crate::driver::{
2    result::{self, DriverError},
3    sys::{self, CUfunc_cache_enum, CUfunction_attribute_enum},
4};
5
6use std::{
7    ffi::CString,
8    marker::PhantomData,
9    ops::{Bound, RangeBounds},
10    string::String,
11    sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering},
12    sync::Arc,
13    vec::Vec,
14};
15
/// Represents a CUDA context on a certain device.
///
/// - [`CudaContext::new()`] retains the device's primary context.
/// - [`CudaContext::new_non_primary()`] creates an independent non-primary context.
/// - [`CudaContext::new_cig()`] creates a non-primary context with CiG (CUDA in Graphics) parameters (CUDA 12.050+).
/// - [`CudaContext::from_raw_context()`] wraps a pre-existing raw `CUcontext`.
///
/// This is the entrypoint to using any cuda calls, all objects maintain a pointer to `Arc<CudaContext>`
/// to ensure proper lifetimes.
///
/// # On thread safety
///
/// This object is thread safe and can be shared/used on multiple threads. All safe apis call
/// [CudaContext::bind_to_thread()] before doing work in a certain context.
#[derive(Debug)]
pub struct CudaContext {
    /// Raw driver handle for the device this context lives on.
    pub(crate) cu_device: sys::CUdevice,
    /// Raw driver context handle; nulled out during `Drop` once released/destroyed.
    pub(crate) cu_ctx: sys::CUcontext,
    /// Device ordinal this context was created for.
    pub(crate) ordinal: usize,
    /// True when the device reports `CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED > 0`.
    pub(crate) has_async_alloc: bool,
    /// Whether this wraps a primary context (true) or a non-primary context (false).
    /// Primary contexts are released via `cuDevicePrimaryCtxRelease`, while non-primary
    /// contexts are destroyed via `cuCtxDestroy_v2`.
    pub(crate) is_primary: bool,
    /// Count of owned (non-default, non-per-thread) streams; incremented by
    /// [CudaContext::new_stream()]/[CudaStream::fork()], decremented on stream drop.
    pub(crate) num_streams: AtomicUsize,
    /// When true (the default), newly created [CudaSlice]s record read/write events
    /// for cross-stream synchronization.
    pub(crate) event_tracking: AtomicBool,
    /// Latest CUDA error code stored by [CudaContext::record_err()]; 0 means "no error".
    pub(crate) error_state: AtomicU32,
}
44
// SAFETY: all safe methods re-bind this context to the calling thread via
// `bind_to_thread()` before issuing driver work, and the only interior mutability
// is through atomics (`num_streams`, `event_tracking`, `error_state`), so sharing
// across threads is sound.
unsafe impl Send for CudaContext {}
unsafe impl Sync for CudaContext {}
47
impl Drop for CudaContext {
    /// Releases (primary) or destroys (non-primary) the wrapped context.
    ///
    /// `drop` cannot return errors, so any driver failures are stashed via
    /// [CudaContext::record_err()] for a later [CudaContext::check_err()].
    fn drop(&mut self) {
        self.record_err(self.bind_to_thread());
        // Take the handle and null the field so the teardown below runs at most once.
        let ctx = std::mem::replace(&mut self.cu_ctx, std::ptr::null_mut());
        if !ctx.is_null() {
            if self.is_primary {
                self.record_err(unsafe { result::primary_ctx::release(self.cu_device) });
            } else {
                // Non-primary contexts (e.g., CiG) are destroyed directly.
                self.record_err(unsafe { sys::cuCtxDestroy_v2(ctx).result() });
            }
        }
    }
}
62
63impl PartialEq for CudaContext {
64    fn eq(&self, other: &Self) -> bool {
65        self.cu_device == other.cu_device
66            && self.cu_ctx == other.cu_ctx
67            && self.ordinal == other.ordinal
68    }
69}
70impl Eq for CudaContext {}
71
72impl CudaContext {
73    /// Creates a new context on the specified device ordinal.
74    pub fn new(ordinal: usize) -> Result<Arc<Self>, DriverError> {
75        result::init()?;
76        let cu_device = result::device::get(ordinal as i32)?;
77        let cu_ctx = unsafe { result::primary_ctx::retain(cu_device) }?;
78        let has_async_alloc = unsafe {
79            let memory_pools_supported = result::device::get_attribute(
80                cu_device,
81                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
82            )?;
83            memory_pools_supported > 0
84        };
85        let ctx = Arc::new(CudaContext {
86            cu_device,
87            cu_ctx,
88            ordinal,
89            has_async_alloc,
90            is_primary: true,
91            num_streams: AtomicUsize::new(0),
92            event_tracking: AtomicBool::new(true),
93            error_state: AtomicU32::new(0),
94        });
95        ctx.bind_to_thread()?;
96        Ok(ctx)
97    }
98
    /// Creates a new non-primary CUDA context on the specified device ordinal.
    ///
    /// Unlike [`CudaContext::new()`] which retains the device's primary context,
    /// this creates an independent context via `cuCtxCreate_v4` (CUDA 12.050+)
    /// or `cuCtxCreate_v3` (CUDA 11.040–12.040). On drop, the context is
    /// destroyed via `cuCtxDestroy_v2`.
    ///
    /// `flags` controls scheduling policy and other options — use 0 for defaults
    /// (`CU_CTX_SCHED_AUTO`). See [`sys::CUctx_flags`] for available flags.
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010"
    ))]
    pub fn new_non_primary(ordinal: usize, flags: u32) -> Result<Arc<Self>, DriverError> {
        result::init()?;
        let cu_device = result::device::get(ordinal as i32)?;

        // CUDA 12.050+ exposes cuCtxCreate_v4; passing null create-params makes
        // it behave like a plain context creation.
        #[cfg(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010"
        ))]
        let cu_ctx = unsafe { result::ctx::create_v4(std::ptr::null_mut(), flags, cu_device) }?;

        // Older toolkits (11.040–12.040) only have cuCtxCreate_v3.
        #[cfg(not(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010"
        )))]
        let cu_ctx = unsafe { result::ctx::create_v3(flags, cu_device) }?;

        // Stream-ordered allocation requires memory-pool support on the device.
        let has_async_alloc = unsafe {
            let memory_pools_supported = result::device::get_attribute(
                cu_device,
                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
            )?;
            memory_pools_supported > 0
        };
        let ctx = Arc::new(CudaContext {
            cu_device,
            cu_ctx,
            ordinal,
            has_async_alloc,
            // Drop must use cuCtxDestroy_v2, not primary-ctx release.
            is_primary: false,
            num_streams: AtomicUsize::new(0),
            event_tracking: AtomicBool::new(true),
            error_state: AtomicU32::new(0),
        });
        ctx.bind_to_thread()?;
        Ok(ctx)
    }
170
171    /// Creates a new CUDA context with CiG (CUDA in Graphics) parameters.
172    ///
173    /// This uses `cuCtxCreate_v4` to create a non-primary context that shares
174    /// resources with a graphics API (e.g., D3D12). Requires CUDA 12.050+.
175    ///
176    /// `flags` controls scheduling policy and other options — use 0 for defaults.
177    /// `cig_params` specifies the CiG shared data type and pointer.
178    ///
179    /// On drop, the context is destroyed via `cuCtxDestroy_v2`.
180    #[cfg(any(
181        feature = "cuda-12050",
182        feature = "cuda-12060",
183        feature = "cuda-12080",
184        feature = "cuda-12090",
185        feature = "cuda-13000",
186        feature = "cuda-13010"
187    ))]
188    pub fn new_cig(
189        ordinal: usize,
190        flags: u32,
191        cig_params: &mut sys::CUctxCigParam,
192    ) -> Result<Arc<Self>, DriverError> {
193        result::init()?;
194        let cu_device = result::device::get(ordinal as i32)?;
195        let mut ctx_create_params = sys::CUctxCreateParams_st {
196            execAffinityParams: std::ptr::null_mut(),
197            numExecAffinityParams: 0,
198            cigParams: cig_params as *mut sys::CUctxCigParam,
199        };
200        let cu_ctx = unsafe { result::ctx::create_v4(&mut ctx_create_params, flags, cu_device) }?;
201        let has_async_alloc = unsafe {
202            let memory_pools_supported = result::device::get_attribute(
203                cu_device,
204                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
205            )?;
206            memory_pools_supported > 0
207        };
208        let ctx = Arc::new(CudaContext {
209            cu_device,
210            cu_ctx,
211            ordinal,
212            has_async_alloc,
213            is_primary: false,
214            num_streams: AtomicUsize::new(0),
215            event_tracking: AtomicBool::new(true),
216            error_state: AtomicU32::new(0),
217        });
218        ctx.bind_to_thread()?;
219        Ok(ctx)
220    }
221
222    /// Wrap a pre-existing raw CUcontext (e.g., a CiG context created via `cuCtxCreate_v4`).
223    ///
224    /// The context must already be valid and will be made current on the calling thread.
225    /// On drop, calls `cuCtxDestroy_v2` instead of `cuDevicePrimaryCtxRelease`.
226    ///
227    /// # Safety
228    ///
229    /// - `cu_ctx` must be a valid CUDA context that was created (not yet destroyed).
230    /// - `cu_device` must be the device the context was created for.
231    /// - The caller must not destroy or release the context after calling this function;
232    ///   ownership is transferred to the returned `Arc<CudaContext>`.
233    pub unsafe fn from_raw_context(
234        ordinal: usize,
235        cu_device: sys::CUdevice,
236        cu_ctx: sys::CUcontext,
237    ) -> Result<Arc<Self>, DriverError> {
238        let has_async_alloc = {
239            let memory_pools_supported = result::device::get_attribute(
240                cu_device,
241                sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
242            )?;
243            memory_pools_supported > 0
244        };
245        let ctx = Arc::new(CudaContext {
246            cu_device,
247            cu_ctx,
248            ordinal,
249            has_async_alloc,
250            is_primary: false,
251            num_streams: AtomicUsize::new(0),
252            event_tracking: AtomicBool::new(true),
253            error_state: AtomicU32::new(0),
254        });
255        ctx.bind_to_thread()?;
256        Ok(ctx)
257    }
258
259    /// Returns whether this context wraps a primary context.
260    ///
261    /// Primary contexts are created via `cuDevicePrimaryCtxRetain` and released on drop.
262    /// Non-primary contexts (e.g., CiG) are destroyed via `cuCtxDestroy_v2` on drop.
263    pub fn is_primary(&self) -> bool {
264        self.is_primary
265    }
266
267    /// Returns whether this context supports asynchronous memory allocation.
268    ///
269    /// By default, the value of this parameter is filled by querying the
270    /// `CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED` attribute and checking if the number of pools
271    /// is greater than 0.
272    /// Memory allocations performed through the default [CudaStream] will use `cuMemAllocAsync`
273    /// over `cuMemAlloc` if this method returns `true`.
274    pub fn has_async_alloc(&self) -> bool {
275        self.has_async_alloc
276    }
277
278    /// The number of devices available.
279    pub fn device_count() -> Result<i32, DriverError> {
280        result::init()?;
281        result::device::get_count()
282    }
283
284    /// Get the `ordinal` index of the device this is on.
285    pub fn ordinal(&self) -> usize {
286        self.ordinal
287    }
288
289    /// Get the name of this device.
290    pub fn name(&self) -> Result<String, result::DriverError> {
291        self.check_err()?;
292        result::device::get_name(self.cu_device)
293    }
294
295    /// Get the UUID of this device.
296    pub fn uuid(&self) -> Result<sys::CUuuid, result::DriverError> {
297        self.check_err()?;
298        result::device::get_uuid(self.cu_device)
299    }
300
301    /// Get the compute capability of this device as a (major,minor) tuple
302    pub fn compute_capability(&self) -> Result<(i32, i32), result::DriverError> {
303        self.check_err()?;
304        let capability_major =
305            self.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
306        let capability_minor =
307            self.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)?;
308
309        Ok((capability_major, capability_minor))
310    }
311
312    /// Get the total memory available on this device, in bytes.
313    pub fn total_mem(&self) -> Result<usize, DriverError> {
314        self.check_err()?;
315        unsafe { result::device::total_mem(self.cu_device) }
316    }
317
318    /// Returns the free and total device memory in bytes as a `(free, total)` tuple.
319    /// Note: this calls [CudaContext::bind_to_thread()] to ensure the query
320    /// runs against this device's context.
321    pub fn mem_get_info(&self) -> Result<(usize, usize), DriverError> {
322        self.bind_to_thread()?;
323        result::mem_get_info()
324    }
325    /// Get the underlying [sys::CUdevice] of this [CudaContext].
326    ///
327    /// # Safety
328    /// While this function is marked as safe, actually using the
329    /// returned object is unsafe.
330    ///
331    /// **You must not free/release the device pointer**, as it is still
332    /// owned by the [CudaContext].
333    pub fn cu_device(&self) -> sys::CUdevice {
334        self.cu_device
335    }
336
337    /// Get the underlying [sys::CUcontext] of this [CudaContext].
338    ///
339    /// # Safety
340    /// While this function is marked as safe, actually using the
341    /// returned object is unsafe.
342    ///
343    /// **You must not free/release the context pointer**, as it is still
344    /// owned by the [CudaContext].
345    pub fn cu_ctx(&self) -> sys::CUcontext {
346        self.cu_ctx
347    }
348
349    /// Binds this context to the calling thread. Calling this is key for thread safety.
350    pub fn bind_to_thread(&self) -> Result<(), DriverError> {
351        self.check_err()?;
352        if match result::ctx::get_current()? {
353            Some(curr_ctx) => curr_ctx != self.cu_ctx,
354            None => true,
355        } {
356            unsafe { result::ctx::set_current(self.cu_ctx) }?;
357        }
358        Ok(())
359    }
360
361    /// Get the value of the specified attribute of the device in [CudaContext].
362    pub fn attribute(&self, attrib: sys::CUdevice_attribute) -> Result<i32, result::DriverError> {
363        self.check_err()?;
364        unsafe { result::device::get_attribute(self.cu_device, attrib) }
365    }
366
367    /// Synchronize this context. Will only block CPU if you call [CudaContext::set_flags()] with
368    /// [sys::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC].
369    pub fn synchronize(&self) -> Result<(), DriverError> {
370        self.bind_to_thread()?;
371        result::ctx::synchronize()
372    }
373
374    /// Ensures calls to [CudaContext::synchronize()] block the calling thread.
375    ///
376    /// Sets [sys::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC]
377    #[cfg(not(any(
378        feature = "cuda-11040",
379        feature = "cuda-11050",
380        feature = "cuda-11060",
381        feature = "cuda-11070",
382        feature = "cuda-11080",
383        feature = "cuda-12000"
384    )))]
385    pub fn set_blocking_synchronize(&self) -> Result<(), DriverError> {
386        self.set_flags(sys::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC)
387    }
388
389    /// Set flags for this context
390    #[cfg(not(any(
391        feature = "cuda-11040",
392        feature = "cuda-11050",
393        feature = "cuda-11060",
394        feature = "cuda-11070",
395        feature = "cuda-11080",
396        feature = "cuda-12000"
397    )))]
398    pub fn set_flags(&self, flags: sys::CUctx_flags) -> Result<(), DriverError> {
399        self.bind_to_thread()?;
400        result::ctx::set_flags(flags)
401    }
402
403    /// Gets the value of a context limit.
404    ///
405    /// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g9f2d47d1745752aa16da7ed0d111b6a8)
406    pub fn get_limit(&self, limit: sys::CUlimit) -> Result<usize, DriverError> {
407        self.bind_to_thread()?;
408        result::ctx::get_limit(limit)
409    }
410
411    /// Sets the value of a context limit.
412    ///
413    /// Common limits:
414    /// - `CU_LIMIT_STACK_SIZE` - Stack size for each thread
415    /// - `CU_LIMIT_PRINTF_FIFO_SIZE` - Size of printf buffer
416    /// - `CU_LIMIT_MALLOC_HEAP_SIZE` - Heap size for malloc() in kernels
417    ///
418    /// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g0651954dfb9788173e60a9af7201e65a)
419    pub fn set_limit(&self, limit: sys::CUlimit, value: usize) -> Result<(), DriverError> {
420        self.bind_to_thread()?;
421        result::ctx::set_limit(limit, value)
422    }
423
424    /// Gets the L1/shared memory cache configuration preference.
425    ///
426    /// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g40b6b141698f76744dea6e39b9a25360)
427    pub fn get_cache_config(&self) -> Result<sys::CUfunc_cache, DriverError> {
428        self.bind_to_thread()?;
429        result::ctx::get_cache_config()
430    }
431
432    /// Sets the L1/shared memory cache configuration preference.
433    ///
434    /// Options:
435    /// - `CU_FUNC_CACHE_PREFER_NONE` - No preference
436    /// - `CU_FUNC_CACHE_PREFER_SHARED` - Prefer larger shared memory, smaller L1
437    /// - `CU_FUNC_CACHE_PREFER_L1` - Prefer larger L1, smaller shared memory
438    /// - `CU_FUNC_CACHE_PREFER_EQUAL` - Equal split between L1 and shared
439    ///
440    /// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g54699acf7e2ef27279d013ca2095f4a3)
441    pub fn set_cache_config(&self, config: sys::CUfunc_cache) -> Result<(), DriverError> {
442        self.bind_to_thread()?;
443        result::ctx::set_cache_config(config)
444    }
445
446    /// Whether multiple streams have been created in this context. If so,
447    /// the [CudaSlice::read] and [CudaSlice::write] events will be activated.
448    ///
449    /// This only get's set to true by [CudaContext::new_stream()].
450    pub fn is_in_multi_stream_mode(&self) -> bool {
451        self.num_streams.load(Ordering::Relaxed) > 0
452    }
453
454    /// Whether event tracking is being managed by this context
455    /// (via [CudaContext::enable_event_tracking()], which is the default behavior),
456    /// or `false` if the user is manually managing stream synchronization
457    /// (via [CudaContext::disable_event_tracking()]).
458    pub fn is_event_tracking(&self) -> bool {
459        self.event_tracking.load(Ordering::Relaxed)
460    }
461
462    /// Whether the context is automatically managing multiple stream synchronization.
463    /// Both of these must be true:
464    /// - [CudaContext::is_in_multi_stream_mode()]
465    /// - [CudaContext::is_event_tracking()]
466    pub fn is_managing_stream_synchronization(&self) -> bool {
467        self.is_in_multi_stream_mode() && self.is_event_tracking()
468    }
469
470    /// When turned on, all [CudaSlice] **created after calling this function** will
471    /// record usages using [CudaEvent] to ensure proper synchronization between streams.
472    ///
473    /// # Safety
474    ///
475    /// If [CudaContext::disable_event_tracking()] was called previously, then any
476    /// [CudaSlice] created after that and before this current call won't have [CudaEvent]
477    /// tracking their uses. Those [CudaSlice] will not manage their synchronization, even
478    /// after this call.
479    pub unsafe fn enable_event_tracking(&self) {
480        self.event_tracking.store(true, Ordering::Relaxed);
481    }
482
483    /// When turned on, all [CudaSlice] **created after calling this function** will
484    /// not track uses via [CudaEvent]s.
485    ///
486    /// # Safety
487    ///
488    /// It is up to the user to ensure proper synchronization between multiple streams:
489    /// - Ensure that no [CudaSlice] is freed before a use on another stream is finished.
490    /// - Ensure that a [CudaSlice] is not used on another stream before allocation on the
491    ///   allocating stream finishes.
492    /// - Ensure that a [CudaSlice] is not written two concurrently by multiple streams.
493    pub unsafe fn disable_event_tracking(&self) {
494        self.event_tracking.store(false, Ordering::Relaxed);
495    }
496
    /// Checks to see if there have been any calls that stored an Err in a function
    /// that couldn't return a result (e.g. Drop calls).
    ///
    /// If there are any errors stored, this method will return the Err value, and
    /// then clear the stored error state.
    pub fn check_err(&self) -> Result<(), DriverError> {
        // Atomically take-and-clear the stored code; 0 doubles as "no error" since
        // it coincides with CUDA_SUCCESS.
        let error_state = self.error_state.swap(0, Ordering::Relaxed);
        if error_state == 0 {
            Ok(())
        } else {
            // SAFETY: `error_state` was stored from a `sys::cudaError_enum` value by
            // `record_err`, so transmuting it back yields a valid discriminant.
            Err(result::DriverError(unsafe {
                std::mem::transmute::<u32, sys::cudaError_enum>(error_state)
            }))
        }
    }
512
513    /// Records a result for later inspection when a Result can be returned.
514    pub fn record_err<T>(&self, result: Result<T, DriverError>) {
515        if let Err(err) = result {
516            self.error_state.store(err.0 as u32, Ordering::Relaxed)
517        }
518    }
519}
520
/// A lightweight synchronization primitive used to synchronize between [CudaStream]s.
///
/// - Create using [CudaContext::new_event()].
/// - Record a point of time in a stream using [CudaEvent::record()].
/// - Either call [CudaEvent::synchronize()] or [CudaStream::wait()] to use.
///
/// Note that calls to [CudaEvent::record()] will not change any **previous calls** to [CudaStream::wait()].
///
/// # Thread safety
/// This object is thread safe
#[derive(Debug)]
pub struct CudaEvent {
    /// Raw driver event handle; destroyed in `Drop`.
    pub(crate) cu_event: sys::CUevent,
    /// Keeps the owning context alive as long as the event exists.
    pub(crate) ctx: Arc<CudaContext>,
}
536
// SAFETY: the event handle is only mutated through driver calls that re-bind the
// owning context first, and the context itself is Send + Sync.
unsafe impl Send for CudaEvent {}
unsafe impl Sync for CudaEvent {}
539
540impl Drop for CudaEvent {
541    fn drop(&mut self) {
542        self.ctx.record_err(self.ctx.bind_to_thread());
543        self.ctx
544            .record_err(unsafe { result::event::destroy(self.cu_event) });
545    }
546}
547
548impl CudaContext {
549    /// Creates a new [CudaEvent] with no work recorded. If `flags` is None, the event is created with
550    /// [sys::CUevent_flags::CU_EVENT_DISABLE_TIMING].
551    pub fn new_event(
552        self: &Arc<Self>,
553        flags: Option<sys::CUevent_flags>,
554    ) -> Result<CudaEvent, DriverError> {
555        let flags = flags.unwrap_or(sys::CUevent_flags::CU_EVENT_DISABLE_TIMING);
556        self.bind_to_thread()?;
557        let cu_event = result::event::create(flags)?;
558        Ok(CudaEvent {
559            cu_event,
560            ctx: self.clone(),
561        })
562    }
563}
564
565impl CudaEvent {
566    /// The underlying cu_event object.
567    ///
568    /// # Safety
569    /// Do not destroy this value
570    pub fn cu_event(&self) -> sys::CUevent {
571        self.cu_event
572    }
573
574    /// The context this was created in.
575    pub fn context(&self) -> &Arc<CudaContext> {
576        &self.ctx
577    }
578
579    /// Records the current amount of work in [CudaStream] into this event.
580    ///
581    /// **This does not affect any previous calls to [CudaStream::wait()]**
582    ///
583    /// If `stream` belongs to a different [CudaContext], this will fail with
584    /// [sys::cudaError_enum::CUDA_ERROR_INVALID_CONTEXT].
585    ///
586    /// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g95424d3be52c4eb95d83861b70fb89d1)
587    pub fn record(&self, stream: &CudaStream) -> Result<(), DriverError> {
588        if self.ctx != stream.ctx {
589            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_CONTEXT));
590        }
591        self.ctx.bind_to_thread()?;
592        unsafe { result::event::record(self.cu_event, stream.cu_stream) }
593    }
594
595    /// Will only block CPU thraed if [sys::CUevent_flags::CU_EVENT_BLOCKING_SYNC] was used to create this event.
596    pub fn synchronize(&self) -> Result<(), DriverError> {
597        self.ctx.bind_to_thread()?;
598        unsafe { result::event::synchronize(self.cu_event) }
599    }
600
601    /// The time between two events. `self` is the start event, and `end` is the end event.
602    /// This is effectively `end - self`.
603    pub fn elapsed_ms(&self, end: &Self) -> Result<f32, DriverError> {
604        if self.ctx != end.ctx {
605            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_CONTEXT));
606        }
607        self.ctx.bind_to_thread()?;
608        self.synchronize()?;
609        end.synchronize()?;
610        unsafe { result::event::elapsed(self.cu_event, end.cu_event) }
611    }
612
613    /// Returns `true` if all recorded work has been completed, `false` otherwise.
614    pub fn is_complete(&self) -> bool {
615        unsafe { result::event::query(self.cu_event) }.is_ok()
616    }
617}
618
/// A wrapper around [sys::CUstream] that you can schedule work on.
///
/// - Create with [CudaContext::new_stream()], [CudaContext::default_stream()], or [CudaStream::fork()].
///
/// **Work done on this is asynchronous with respect to the host.**
///
/// See [CUDA C/C++ Streams and Concurrency](https://developer.download.nvidia.com/CUDA/training/StreamsAndConcurrencyWebinar.pdf)
/// See [3. Stream synchronization behavior](https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html)
/// See [6.6. Event Management](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html)
/// See [Out-of-order execution](https://en.wikipedia.org/wiki/Out-of-order_execution)
/// See [Dependence analysis](https://en.wikipedia.org/wiki/Dependence_analysis)
#[derive(Debug, PartialEq, Eq)]
pub struct CudaStream {
    /// Raw driver stream handle. Null means the default stream; the special
    /// per-thread sentinel is handled in `Drop`.
    pub(crate) cu_stream: sys::CUstream,
    /// Keeps the owning context alive as long as the stream exists.
    pub(crate) ctx: Arc<CudaContext>,
}
635
// SAFETY: all stream operations re-bind the owning context to the calling thread
// before issuing driver calls, and the context itself is Send + Sync.
unsafe impl Send for CudaStream {}
unsafe impl Sync for CudaStream {}
638
impl Drop for CudaStream {
    /// Destroys the stream if it is one we own, and decrements the context's
    /// stream count. Errors are stashed via [CudaContext::record_err()].
    fn drop(&mut self) {
        self.ctx.record_err(self.ctx.bind_to_thread());
        let cu_stream = std::mem::replace(&mut self.cu_stream, std::ptr::null_mut());
        // Null is the default stream and 0x2 is the CU_STREAM_PER_THREAD sentinel;
        // neither is owned by us, so neither is destroyed nor counted in `num_streams`.
        if !cu_stream.is_null() && cu_stream != (0x2 as _) {
            self.ctx.num_streams.fetch_sub(1, Ordering::Relaxed);
            self.ctx
                .record_err(unsafe { result::stream::destroy(cu_stream) });
        }
    }
}
650
651impl CudaContext {
652    /// Get's the default stream for this context (the null ptr stream). Note that context's
653    /// on the same device can all submit to the same default stream from separate context objects.
654    pub fn default_stream(self: &Arc<Self>) -> Arc<CudaStream> {
655        Arc::new(CudaStream {
656            cu_stream: std::ptr::null_mut(),
657            ctx: self.clone(),
658        })
659    }
660
661    /// Get's the per-thread stream handle. See https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html#stream-sync-behavior
662    pub fn per_thread_stream(self: &Arc<Self>) -> Arc<CudaStream> {
663        Arc::new(CudaStream {
664            // See https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g7b7129befd6f52708309acafd1c46197
665            cu_stream: 0x2 as _,
666            ctx: self.clone(),
667        })
668    }
669
670    /// Create a new [sys::CUstream_flags::CU_STREAM_NON_BLOCKING] stream.
671    ///
672    /// This will swap the calling context to multi stream mode [CudaContext::is_in_multi_stream_mode()].
673    /// If the context is not already in multiple stream mode, then this function will also call [CudaContext::synchronize()].
674    pub fn new_stream(self: &Arc<Self>) -> Result<Arc<CudaStream>, DriverError> {
675        self.bind_to_thread()?;
676        let prev_num_streams = self.num_streams.fetch_add(1, Ordering::Relaxed);
677        if prev_num_streams == 0 && self.is_event_tracking() {
678            self.synchronize()?;
679        }
680        let cu_stream = result::stream::create(result::stream::StreamKind::NonBlocking)?;
681        Ok(Arc::new(CudaStream {
682            cu_stream,
683            ctx: self.clone(),
684        }))
685    }
686}
687
688impl CudaStream {
689    /// Create's a new stream and then makes the new stream wait on `self`
690    pub fn fork(&self) -> Result<Arc<Self>, DriverError> {
691        self.ctx.bind_to_thread()?;
692        self.ctx.num_streams.fetch_add(1, Ordering::Relaxed);
693        let cu_stream = result::stream::create(result::stream::StreamKind::NonBlocking)?;
694        let stream = Arc::new(CudaStream {
695            cu_stream,
696            ctx: self.ctx.clone(),
697        });
698        stream.join(self)?;
699        Ok(stream)
700    }
701
702    /// The underlying cuda stream object
703    /// # Safety
704    /// Do not destroy this value.
705    pub fn cu_stream(&self) -> sys::CUstream {
706        self.cu_stream
707    }
708
709    /// The context the stream belongs to.
710    pub fn context(&self) -> &Arc<CudaContext> {
711        &self.ctx
712    }
713
714    /// Will only block CPU if you call [CudaContext::set_flags()] with
715    /// [sys::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC].
716    ///
717    /// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g15e49dd91ec15991eb7c0a741beb7dad)
718    pub fn synchronize(&self) -> Result<(), DriverError> {
719        self.ctx.bind_to_thread()?;
720        unsafe { result::stream::synchronize(self.cu_stream) }
721    }
722
723    /// Creates a new [CudaEvent] and records the current work in the stream to the event.
724    pub fn record_event(
725        &self,
726        flags: Option<sys::CUevent_flags>,
727    ) -> Result<CudaEvent, DriverError> {
728        let event = self.ctx.new_event(flags)?;
729        event.record(self)?;
730        Ok(event)
731    }
732
733    /// Waits for the work recorded in [CudaEvent] to be completed.
734    ///
735    /// You can record new work in `event` after calling this method without
736    /// affecting this call.
737    ///
738    /// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g6a898b652dfc6aa1d5c8d97062618b2f)
739    pub fn wait(&self, event: &CudaEvent) -> Result<(), DriverError> {
740        if self.ctx != event.ctx {
741            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_INVALID_CONTEXT));
742        }
743        self.ctx.bind_to_thread()?;
744        unsafe {
745            result::stream::wait_event(
746                self.cu_stream,
747                event.cu_event,
748                sys::CUevent_wait_flags::CU_EVENT_WAIT_DEFAULT,
749            )
750        }
751    }
752
753    /// Ensures this stream waits for the current workload in `other` to complete.
754    /// This is shorthand for `self.wait(other.record_event())`
755    pub fn join(&self, other: &CudaStream) -> Result<(), DriverError> {
756        self.wait(&other.record_event(None)?)
757    }
758}
759
/// `Vec<T>` on a cuda device. You can allocate and modify this with [CudaStream].
///
/// This object is thread safe.
#[derive(Debug)]
pub struct CudaSlice<T> {
    /// Raw device pointer to the allocation; freed in `Drop`.
    pub(crate) cu_device_ptr: sys::CUdeviceptr,
    /// Number of `T` elements (not bytes).
    pub(crate) len: usize,
    /// Event tracking the last read of this slice, when event tracking is enabled.
    pub(crate) read: Option<CudaEvent>,
    /// Event tracking the last write of this slice, when event tracking is enabled.
    pub(crate) write: Option<CudaEvent>,
    /// Stream this was allocated on; the free in `Drop` is issued on it too.
    pub(crate) stream: Arc<CudaStream>,
    // PhantomData ties the element type to the allocation without storing host-side `T`s.
    pub(crate) marker: PhantomData<*const T>,
}
772
// SAFETY: the `T`s live only in device memory; the host side never holds or
// dereferences a `T` value directly.
// NOTE(review): there is no `T: Send`/`T: Sync` bound here, so this asserts
// thread-safety for *any* `T` — confirm that device-only ownership of the
// elements justifies the unconditional impls.
unsafe impl<T> Send for CudaSlice<T> {}
unsafe impl<T> Sync for CudaSlice<T> {}
775
impl<T> Drop for CudaSlice<T> {
    /// Frees the device allocation, first waiting on any outstanding read/write
    /// events so the memory is not reclaimed while another stream still uses it.
    /// Errors are stashed on the context via [CudaContext::record_err()].
    fn drop(&mut self) {
        let ctx = &self.stream.ctx;
        // Order matters: make the stream wait on tracked uses before freeing.
        if let Some(read) = self.read.as_ref() {
            ctx.record_err(self.stream.wait(read));
        }
        if let Some(write) = self.write.as_ref() {
            ctx.record_err(self.stream.wait(write));
        }
        if ctx.has_async_alloc {
            // Stream-ordered free: executes once prior work on the stream completes.
            ctx.record_err(unsafe {
                result::free_async(self.cu_device_ptr, self.stream.cu_stream)
            });
        } else {
            // Synchronous free requires the stream to be idle first.
            ctx.record_err(self.stream.synchronize());
            ctx.record_err(unsafe { result::free_sync(self.cu_device_ptr) });
        }
    }
}
795
796impl<T> CudaSlice<T> {
797    /// The number of elements of `T` in this object.
798    pub fn len(&self) -> usize {
799        self.len
800    }
801
802    /// The number of bytes in this object.
803    pub fn num_bytes(&self) -> usize {
804        self.len * std::mem::size_of::<T>()
805    }
806
807    /// True if there are no elements in the object.
808    pub fn is_empty(&self) -> bool {
809        self.len == 0
810    }
811
812    /// The device ordinal this belongs to
813    pub fn ordinal(&self) -> usize {
814        self.stream.ctx.ordinal
815    }
816
817    /// The context this belongs to
818    pub fn context(&self) -> &Arc<CudaContext> {
819        &self.stream.ctx
820    }
821
822    /// The stream this object was allocated on and later will be dropped on.
823    pub fn stream(&self) -> &Arc<CudaStream> {
824        &self.stream
825    }
826}
827
828impl<T: DeviceRepr> CudaSlice<T> {
829    /// Allocates copy of self and schedules a device to device copy of memory.
830    pub fn try_clone(&self) -> Result<Self, result::DriverError> {
831        self.stream.clone_dtod(self)
832    }
833}
834
835impl<T: DeviceRepr> Clone for CudaSlice<T> {
836    fn clone(&self) -> Self {
837        self.try_clone().unwrap()
838    }
839}
840
impl<T: Clone + Default + DeviceRepr> TryFrom<CudaSlice<T>> for Vec<T> {
    type Error = result::DriverError;
    /// Copies the device memory into a new host `Vec<T>` (device-to-host),
    /// consuming the slice (it is freed when `value` drops).
    fn try_from(value: CudaSlice<T>) -> Result<Self, Self::Error> {
        value.stream.clone_dtoh(&value)
    }
}
847
/// `&[T]` on a cuda device. An immutable sub-view into a [CudaSlice] created by [CudaSlice::as_view()]/[CudaSlice::slice()].
#[derive(Debug)]
pub struct CudaView<'a, T> {
    // Device pointer to the start of the viewed range (already offset).
    pub(crate) ptr: sys::CUdeviceptr,
    // Number of elements of `T` in the view.
    pub(crate) len: usize,
    // Borrow of the parent slice's read-tracking event.
    pub(crate) read: &'a Option<CudaEvent>,
    // Borrow of the parent slice's write-tracking event.
    pub(crate) write: &'a Option<CudaEvent>,
    // The stream the parent slice was allocated on.
    pub(crate) stream: &'a Arc<CudaStream>,
    marker: PhantomData<&'a [T]>,
}
858
859impl<T> CudaSlice<T> {
860    pub fn as_view(&self) -> CudaView<'_, T> {
861        CudaView {
862            ptr: self.cu_device_ptr,
863            len: self.len,
864            read: &self.read,
865            write: &self.write,
866            stream: &self.stream,
867            marker: PhantomData,
868        }
869    }
870}
871
impl<T> CudaView<'_, T> {
    /// The number of elements `T` in this view.
    pub fn len(&self) -> usize {
        self.len
    }

    /// True if the view contains no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns a sub-view covering `start..end`, element indices relative to
    /// the start of this view.
    ///
    /// Panics if `start > end` or `end > self.len`.
    fn resize(&self, start: usize, end: usize) -> Self {
        assert!(start <= end && end <= self.len);
        Self {
            // Advance the device pointer by `start` elements, in bytes.
            ptr: self.ptr + (start * std::mem::size_of::<T>()) as u64,
            len: end - start,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        }
    }
}
894
/// `&mut [T]` on a cuda device. A mutable sub-view into a [CudaSlice] created by [CudaSlice::as_view_mut()]/[CudaSlice::slice_mut()].
#[derive(Debug)]
pub struct CudaViewMut<'a, T> {
    // Device pointer to the start of the viewed range (already offset).
    pub(crate) ptr: sys::CUdeviceptr,
    // Number of elements of `T` in the view.
    pub(crate) len: usize,
    // Borrow of the parent slice's read-tracking event.
    pub(crate) read: &'a Option<CudaEvent>,
    // Borrow of the parent slice's write-tracking event.
    pub(crate) write: &'a Option<CudaEvent>,
    // The stream the parent slice was allocated on.
    pub(crate) stream: &'a Arc<CudaStream>,
    marker: PhantomData<&'a mut [T]>,
}
905
906impl<T> CudaSlice<T> {
907    pub fn as_view_mut(&mut self) -> CudaViewMut<'_, T> {
908        CudaViewMut {
909            ptr: self.cu_device_ptr,
910            len: self.len,
911            read: &self.read,
912            write: &self.write,
913            stream: &self.stream,
914            marker: PhantomData,
915        }
916    }
917}
918
919impl<T> CudaViewMut<'_, T> {
920    /// Number of elements `T` that are in this view.
921    pub fn len(&self) -> usize {
922        self.len
923    }
924    pub fn is_empty(&self) -> bool {
925        self.len == 0
926    }
927
928    /// Downgrade this to a `&[T]`
929    pub fn as_view<'b>(&'b self) -> CudaView<'b, T> {
930        CudaView {
931            ptr: self.ptr,
932            len: self.len,
933            read: self.read,
934            write: self.write,
935            stream: self.stream,
936            marker: PhantomData,
937        }
938    }
939}
940
/// Marker trait to indicate that the type is valid
/// when all of its bits are set to 0.
///
/// # Safety
/// Not all types are valid when all bits are set to 0.
/// Be very sure when implementing this trait!
pub unsafe trait ValidAsZeroBits {}
// SAFETY: for all primitives below the all-zero bit pattern is a valid value
// (`false` for bool, numeric zero for the integer/float types).
unsafe impl ValidAsZeroBits for bool {}
unsafe impl ValidAsZeroBits for i8 {}
unsafe impl ValidAsZeroBits for i16 {}
unsafe impl ValidAsZeroBits for i32 {}
unsafe impl ValidAsZeroBits for i64 {}
unsafe impl ValidAsZeroBits for i128 {}
unsafe impl ValidAsZeroBits for isize {}
unsafe impl ValidAsZeroBits for u8 {}
unsafe impl ValidAsZeroBits for u16 {}
unsafe impl ValidAsZeroBits for u32 {}
unsafe impl ValidAsZeroBits for u64 {}
unsafe impl ValidAsZeroBits for u128 {}
unsafe impl ValidAsZeroBits for usize {}
unsafe impl ValidAsZeroBits for f32 {}
unsafe impl ValidAsZeroBits for f64 {}
#[cfg(feature = "f16")]
unsafe impl ValidAsZeroBits for half::f16 {}
#[cfg(feature = "f16")]
unsafe impl ValidAsZeroBits for half::bf16 {}
// SAFETY: an array is all-zero iff every element is, and `T: ValidAsZeroBits`.
unsafe impl<T: ValidAsZeroBits, const M: usize> ValidAsZeroBits for [T; M] {}
/// Implement `ValidAsZeroBits` for tuples if all elements are `ValidAsZeroBits`,
///
/// # Note
/// This will also implement `ValidAsZeroBits` for a tuple with one element
macro_rules! impl_tuples {
    ($t:tt) => {
        impl_tuples!(@ $t);
    };
    // the $l is in front of the repetition to prevent parsing ambiguities
    ($l:tt $(,$t:tt)+) => {
        // Recurse on the tail first, then emit the impl for the full tuple,
        // so every tuple arity from 1..=N gets an impl.
        impl_tuples!($($t),+);
        impl_tuples!(@ $l $(,$t)+);
    };
    (@ $($t:tt),+) => {
        unsafe impl<$($t: ValidAsZeroBits,)+> ValidAsZeroBits for ($($t,)+) {}
    };
}
impl_tuples!(A, B, C, D, E, F, G, H, I, J, K, L);
986
/// Something that can be copied to device memory and
/// turned into a parameter for [result::launch_kernel].
///
/// # Safety
///
/// This is unsafe because a struct should likely
/// be `#[repr(C)]` to be represented in cuda memory,
/// and not all types are valid.
pub unsafe trait DeviceRepr {}
unsafe impl DeviceRepr for bool {}
unsafe impl DeviceRepr for i8 {}
unsafe impl DeviceRepr for i16 {}
unsafe impl DeviceRepr for i32 {}
unsafe impl DeviceRepr for i64 {}
unsafe impl DeviceRepr for i128 {}
unsafe impl DeviceRepr for isize {}
unsafe impl DeviceRepr for u8 {}
unsafe impl DeviceRepr for u16 {}
unsafe impl DeviceRepr for u32 {}
unsafe impl DeviceRepr for u64 {}
unsafe impl DeviceRepr for u128 {}
unsafe impl DeviceRepr for usize {}
unsafe impl DeviceRepr for f32 {}
unsafe impl DeviceRepr for f64 {}
#[cfg(feature = "f16")]
unsafe impl DeviceRepr for half::f16 {}
#[cfg(feature = "f16")]
unsafe impl DeviceRepr for half::bf16 {}

// Optional narrow-float support: each type is both device-representable and
// valid when zeroed.
#[cfg(feature = "f8")]
unsafe impl DeviceRepr for float8::F8E4M3 {}
#[cfg(feature = "f8")]
unsafe impl ValidAsZeroBits for float8::F8E4M3 {}

#[cfg(feature = "f8")]
unsafe impl DeviceRepr for float8::F8E5M2 {}
#[cfg(feature = "f8")]
unsafe impl ValidAsZeroBits for float8::F8E5M2 {}

#[cfg(feature = "f4")]
unsafe impl DeviceRepr for float4::F4E2M1 {}
#[cfg(feature = "f4")]
unsafe impl ValidAsZeroBits for float4::F4E2M1 {}

#[cfg(feature = "f4")]
unsafe impl DeviceRepr for float4::E8M0 {}
#[cfg(feature = "f4")]
unsafe impl ValidAsZeroBits for float4::E8M0 {}

#[cfg(feature = "f4")]
unsafe impl DeviceRepr for float4::F4E2M1x2 {}
#[cfg(feature = "f4")]
unsafe impl ValidAsZeroBits for float4::F4E2M1x2 {}

unsafe impl<const N: usize, T> DeviceRepr for [T; N] where T: DeviceRepr {}
1042
/// Base trait for abstracting over [CudaSlice]/[CudaView]/[CudaViewMut].
///
/// Don't use this directly - use [DevicePtr]/[DevicePtrMut].
pub trait DeviceSlice<T> {
    /// Number of elements of `T`.
    fn len(&self) -> usize;
    /// Number of bytes (`len * size_of::<T>()`).
    fn num_bytes(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }
    /// True if there are no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// The stream the underlying memory is associated with.
    fn stream(&self) -> &Arc<CudaStream>;
}
1056
1057impl<T> DeviceSlice<T> for CudaSlice<T> {
1058    fn len(&self) -> usize {
1059        self.len
1060    }
1061    fn stream(&self) -> &Arc<CudaStream> {
1062        &self.stream
1063    }
1064}
1065
1066impl<T> DeviceSlice<T> for CudaView<'_, T> {
1067    fn len(&self) -> usize {
1068        self.len
1069    }
1070    fn stream(&self) -> &Arc<CudaStream> {
1071        self.stream
1072    }
1073}
1074
1075impl<T> DeviceSlice<T> for CudaViewMut<'_, T> {
1076    fn len(&self) -> usize {
1077        self.len
1078    }
1079    fn stream(&self) -> &Arc<CudaStream> {
1080        self.stream
1081    }
1082}
1083
/// A synchronization primitive to enable stream & event synchronization.
/// Primarily used with [DevicePtr] and [DevicePtrMut]
///
/// The inner `Option`s are taken by the [Drop] impl so each guard fires
/// at most once; a `None` payload means "nothing to do on drop".
#[derive(Debug)]
#[must_use]
pub enum SyncOnDrop<'a> {
    /// Will record the stream's workload to the event on drop.
    Record(Option<(&'a CudaEvent, &'a CudaStream)>),
    /// Will call stream synchronize on drop.
    Sync(Option<&'a CudaStream>),
}
1094
1095impl<'a> SyncOnDrop<'a> {
1096    /// Construct a [SyncOnDrop::Record] variant
1097    pub fn record_event(event: &'a Option<CudaEvent>, stream: &'a CudaStream) -> Self {
1098        SyncOnDrop::Record(event.as_ref().map(|e| (e, stream)))
1099    }
1100    /// Construct a [SyncOnDrop::Sync] variant
1101    pub fn sync_stream(stream: &'a CudaStream) -> Self {
1102        SyncOnDrop::Sync(Some(stream))
1103    }
1104}
1105
1106impl Drop for SyncOnDrop<'_> {
1107    fn drop(&mut self) {
1108        match self {
1109            SyncOnDrop::Record(target) => {
1110                if let Some((event, stream)) = std::mem::take(target) {
1111                    stream.ctx.record_err(event.record(stream));
1112                }
1113            }
1114            SyncOnDrop::Sync(target) => {
1115                if let Some(stream) = std::mem::take(target) {
1116                    stream.ctx.record_err(stream.synchronize());
1117                }
1118            }
1119        }
1120    }
1121}
1122
/// Abstraction over [CudaSlice]/[CudaView]
pub trait DevicePtr<T>: DeviceSlice<T> {
    /// Retrieve the device pointer with the intent to read the device memory
    /// associated with it.
    ///
    /// Implementations of this method should ensure `stream` waits for any previous
    /// writes of this memory before continuing (do not need to wait for any previous reads).
    ///
    /// The [SyncOnDrop] item of the return tuple should be dropped **after** the read of
    /// the [sys::CUdeviceptr] is scheduled.
    ///
    /// In most cases you can use like:
    /// ```ignore
    /// let (src, _record_src) = src.device_ptr(&stream);
    /// ```
    /// Which will drop the [SyncOnDrop] at the end of the scope.
    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>);
}
1141
1142impl<T> DevicePtr<T> for CudaSlice<T> {
1143    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
1144        if self.stream.context().is_managing_stream_synchronization() {
1145            if let Some(write) = self.write.as_ref() {
1146                stream.ctx.record_err(stream.wait(write));
1147            }
1148        }
1149        (
1150            self.cu_device_ptr,
1151            SyncOnDrop::record_event(&self.read, stream),
1152        )
1153    }
1154}
1155
1156impl<T> DevicePtr<T> for CudaView<'_, T> {
1157    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
1158        if self.stream.context().is_managing_stream_synchronization() {
1159            if let Some(write) = self.write.as_ref() {
1160                stream.ctx.record_err(stream.wait(write));
1161            }
1162        }
1163        (self.ptr, SyncOnDrop::record_event(self.read, stream))
1164    }
1165}
1166
1167impl<T> DevicePtr<T> for CudaViewMut<'_, T> {
1168    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
1169        if self.stream.context().is_managing_stream_synchronization() {
1170            if let Some(write) = self.write.as_ref() {
1171                stream.ctx.record_err(stream.wait(write));
1172            }
1173        }
1174        (self.ptr, SyncOnDrop::record_event(self.read, stream))
1175    }
1176}
1177
/// Abstraction over [CudaSlice]/[CudaViewMut]
pub trait DevicePtrMut<T>: DeviceSlice<T> {
    /// Retrieve the device pointer with the intent to modify the device memory
    /// associated with it.
    ///
    /// Implementations of this method should ensure `stream` waits for any previous
    /// reads/writes of this memory before continuing.
    ///
    /// The [SyncOnDrop] item of the return tuple should be dropped **after** the write of
    /// the [sys::CUdeviceptr] is scheduled.
    ///
    /// In most cases you can use like:
    /// ```ignore
    /// let (src, _record_src) = src.device_ptr_mut(&stream);
    /// ```
    /// Which will drop the [SyncOnDrop] at the end of the scope.
    fn device_ptr_mut<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, SyncOnDrop<'a>);
}
1199
1200impl<T> DevicePtrMut<T> for CudaSlice<T> {
1201    fn device_ptr_mut<'a>(
1202        &'a mut self,
1203        stream: &'a CudaStream,
1204    ) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
1205        if self.stream.context().is_managing_stream_synchronization() {
1206            if let Some(read) = self.read.as_ref() {
1207                stream.ctx.record_err(stream.wait(read));
1208            }
1209            if let Some(write) = self.write.as_ref() {
1210                stream.ctx.record_err(stream.wait(write));
1211            }
1212        }
1213        (
1214            self.cu_device_ptr,
1215            SyncOnDrop::record_event(&self.write, stream),
1216        )
1217    }
1218}
1219
1220impl<T> DevicePtrMut<T> for CudaViewMut<'_, T> {
1221    fn device_ptr_mut<'a>(
1222        &'a mut self,
1223        stream: &'a CudaStream,
1224    ) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
1225        if self.stream.context().is_managing_stream_synchronization() {
1226            if let Some(read) = self.read.as_ref() {
1227                stream.ctx.record_err(stream.wait(read));
1228            }
1229            if let Some(write) = self.write.as_ref() {
1230                stream.ctx.record_err(stream.wait(write));
1231            }
1232        }
1233        (self.ptr, SyncOnDrop::record_event(self.write, stream))
1234    }
1235}
1236
/// Abstraction over `&[T]`, `&Vec<T>` and [`PinnedHostSlice<T>`].
pub trait HostSlice<T> {
    /// Number of elements of `T` in the host buffer.
    fn len(&self) -> usize;
    /// True if there are no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// # Safety
    /// This is **only** safe if the resulting slice is used with `stream`. Otherwise
    /// You may run into device synchronization errors
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>);

    /// # Safety
    /// This is **only** safe if the resulting slice is used with `stream`. Otherwise
    /// You may run into device synchronization errors
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>);
}
1260
1261impl<T, const N: usize> HostSlice<T> for [T; N] {
1262    fn len(&self) -> usize {
1263        N
1264    }
1265    unsafe fn stream_synced_slice<'a>(
1266        &'a self,
1267        _stream: &'a CudaStream,
1268    ) -> (&'a [T], SyncOnDrop<'a>) {
1269        (self, SyncOnDrop::Sync(None))
1270    }
1271    unsafe fn stream_synced_mut_slice<'a>(
1272        &'a mut self,
1273        _stream: &'a CudaStream,
1274    ) -> (&'a mut [T], SyncOnDrop<'a>) {
1275        (self, SyncOnDrop::Sync(None))
1276    }
1277}
1278
1279impl<T> HostSlice<T> for [T] {
1280    fn len(&self) -> usize {
1281        self.len()
1282    }
1283    unsafe fn stream_synced_slice<'a>(
1284        &'a self,
1285        _stream: &'a CudaStream,
1286    ) -> (&'a [T], SyncOnDrop<'a>) {
1287        (self, SyncOnDrop::Sync(None))
1288    }
1289    unsafe fn stream_synced_mut_slice<'a>(
1290        &'a mut self,
1291        _stream: &'a CudaStream,
1292    ) -> (&'a mut [T], SyncOnDrop<'a>) {
1293        (self, SyncOnDrop::Sync(None))
1294    }
1295}
1296
1297impl<T> HostSlice<T> for Vec<T> {
1298    fn len(&self) -> usize {
1299        self.len()
1300    }
1301    unsafe fn stream_synced_slice<'a>(
1302        &'a self,
1303        _stream: &'a CudaStream,
1304    ) -> (&'a [T], SyncOnDrop<'a>) {
1305        (self, SyncOnDrop::Sync(None))
1306    }
1307    unsafe fn stream_synced_mut_slice<'a>(
1308        &'a mut self,
1309        _stream: &'a CudaStream,
1310    ) -> (&'a mut [T], SyncOnDrop<'a>) {
1311        (self, SyncOnDrop::Sync(None))
1312    }
1313}
1314
/// Rust side data that the `cuda` driver knows is pinned. This is different
/// than `Pin<Vec<T>>` mainly because cuda driver manages this memory and ensures
/// it is page locked.
///
/// Allocate this with [CudaContext::alloc_pinned()], and do device copies with
/// [CudaStream::clone_htod()]/[CudaStream::memcpy_htod()]/[CudaStream::memcpy_dtoh()]
#[derive(Debug)]
pub struct PinnedHostSlice<T> {
    // Host pointer returned by [result::malloc_host]; freed in Drop.
    pub(crate) ptr: *mut T,
    // Number of elements of `T` (not bytes).
    pub(crate) len: usize,
    // Event used to order host access against device work on this memory.
    pub(crate) event: CudaEvent,
}

// SAFETY: the raw pointer is only dereferenced behind &self/&mut self, and
// host/device access is ordered through `event` synchronization.
unsafe impl<T> Send for PinnedHostSlice<T> {}
unsafe impl<T> Sync for PinnedHostSlice<T> {}
1330
impl<T> Drop for PinnedHostSlice<T> {
    /// Waits for any device work recorded on `event` to finish, then frees
    /// the pinned allocation. Errors are recorded on the context.
    fn drop(&mut self) {
        let ctx = &self.event.ctx;
        ctx.record_err(self.event.synchronize());
        ctx.record_err(unsafe { result::free_host(self.ptr as _) });
    }
}
1338
impl CudaContext {
    /// Allocates page locked host memory with [sys::CU_MEMHOSTALLOC_WRITECOMBINED] flags.
    ///
    /// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g572ca4011bfcb25034888a14d4e035b9)
    ///
    /// # Safety
    /// 1. This is unsafe because the memory is unset after this call.
    pub unsafe fn alloc_pinned<T: DeviceRepr>(
        self: &Arc<Self>,
        len: usize,
    ) -> Result<PinnedHostSlice<T>, DriverError> {
        self.bind_to_thread()?;
        let ptr = result::malloc_host(
            len * std::mem::size_of::<T>(),
            sys::CU_MEMHOSTALLOC_WRITECOMBINED,
        )?;
        let ptr = ptr as *mut T;
        // Sanity checks that make later `slice::from_raw_parts` calls on this
        // pointer sound: non-null, under `isize::MAX` bytes, properly aligned.
        assert!(!ptr.is_null());
        assert!(len * std::mem::size_of::<T>() < isize::MAX as usize);
        assert!(ptr.is_aligned());
        // Blocking-sync event used to order host access against device work.
        let event = self.new_event(Some(sys::CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
        Ok(PinnedHostSlice { ptr, len, event })
    }
}
1363
1364impl<T> PinnedHostSlice<T> {
1365    /// The context this was created in.
1366    pub fn context(&self) -> &Arc<CudaContext> {
1367        &self.event.ctx
1368    }
1369
1370    /// The number of elements `T` in this slice.
1371    pub fn len(&self) -> usize {
1372        self.len
1373    }
1374
1375    /// The number of bytes in this slice.
1376    pub fn num_bytes(&self) -> usize {
1377        self.len * std::mem::size_of::<T>()
1378    }
1379
1380    pub fn is_empty(&self) -> bool {
1381        self.len() == 0
1382    }
1383}
1384
impl<T: ValidAsZeroBits> PinnedHostSlice<T> {
    /// Waits for any scheduled work to complete and then returns a pointer
    /// to the host side data.
    pub fn as_ptr(&self) -> Result<*const T, DriverError> {
        self.event.synchronize()?;
        Ok(self.ptr)
    }

    /// Waits for any scheduled work to complete and then returns a mutable
    /// pointer to the host side data.
    pub fn as_mut_ptr(&mut self) -> Result<*mut T, DriverError> {
        self.event.synchronize()?;
        Ok(self.ptr)
    }

    /// Waits for any scheduled work to complete and then returns a reference
    /// to the host side data.
    pub fn as_slice(&self) -> Result<&[T], DriverError> {
        self.event.synchronize()?;
        // SAFETY: `alloc_pinned` asserted the pointer is non-null, aligned,
        // and spans `len` elements; the event sync above orders prior device
        // work before this host read.
        Ok(unsafe { std::slice::from_raw_parts(self.ptr, self.len) })
    }

    /// Waits for any scheduled work to complete and then returns a mutable
    /// reference to the host side data.
    pub fn as_mut_slice(&mut self) -> Result<&mut [T], DriverError> {
        self.event.synchronize()?;
        // SAFETY: same invariants as `as_slice`; `&mut self` guarantees
        // exclusive host-side access.
        Ok(unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) })
    }
}
1414
impl<T> HostSlice<T> for PinnedHostSlice<T> {
    fn len(&self) -> usize {
        self.len
    }

    unsafe fn stream_synced_slice<'a>(
        &'a self,
        stream: &'a CudaStream,
    ) -> (&'a [T], SyncOnDrop<'a>) {
        // Order this use after previously recorded work on our event, then
        // re-record the event when the guard drops so later host access
        // (e.g. `as_slice`) can synchronize against it.
        stream.ctx.record_err(stream.wait(&self.event));
        (
            std::slice::from_raw_parts(self.ptr, self.len),
            SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (&'a mut [T], SyncOnDrop<'a>) {
        // Same wait/record protocol as `stream_synced_slice`, with exclusive
        // host-side access via `&mut self`.
        stream.ctx.record_err(stream.wait(&self.event));
        (
            std::slice::from_raw_parts_mut(self.ptr, self.len),
            SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
}
1441
1442impl CudaStream {
1443    /// Allocates an empty [CudaSlice] with 0 length.
1444    pub fn null<T>(self: &Arc<Self>) -> Result<CudaSlice<T>, result::DriverError> {
1445        self.ctx.bind_to_thread()?;
1446        let cu_device_ptr = if self.ctx.has_async_alloc {
1447            unsafe { result::malloc_async(self.cu_stream, 0) }?
1448        } else {
1449            unsafe { result::malloc_sync(0) }?
1450        };
1451        Ok(CudaSlice {
1452            cu_device_ptr,
1453            len: 0,
1454            read: None,
1455            write: None,
1456            stream: self.clone(),
1457            marker: PhantomData,
1458        })
1459    }
1460
1461    /// Allocates a [CudaSlice] with `len` elements of type `T`.
1462    /// # Safety
1463    /// This is unsafe because the memory is unset.
1464    pub unsafe fn alloc<T: DeviceRepr>(
1465        self: &Arc<Self>,
1466        len: usize,
1467    ) -> Result<CudaSlice<T>, DriverError> {
1468        self.ctx.bind_to_thread()?;
1469        let cu_device_ptr = if self.ctx.has_async_alloc {
1470            result::malloc_async(self.cu_stream, len * std::mem::size_of::<T>())?
1471        } else {
1472            result::malloc_sync(len * std::mem::size_of::<T>())?
1473        };
1474        let (read, write) = if self.ctx.is_event_tracking() {
1475            (
1476                Some(self.ctx.new_event(None)?),
1477                Some(self.ctx.new_event(None)?),
1478            )
1479        } else {
1480            (None, None)
1481        };
1482        Ok(CudaSlice {
1483            cu_device_ptr,
1484            len,
1485            read,
1486            write,
1487            stream: self.clone(),
1488            marker: PhantomData,
1489        })
1490    }
1491
1492    /// Allocates a [CudaSlice] with `len` elements of type `T`. All values are zero'd out.
1493    pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(
1494        self: &Arc<Self>,
1495        len: usize,
1496    ) -> Result<CudaSlice<T>, DriverError> {
1497        let mut dst = unsafe { self.alloc(len) }?;
1498        self.memset_zeros(&mut dst)?;
1499        Ok(dst)
1500    }
1501
1502    /// Set's all the memory in `dst` to 0. `dst` can be a [CudaSlice] or [CudaViewMut]
1503    pub fn memset_zeros<T: DeviceRepr + ValidAsZeroBits, Dst: DevicePtrMut<T>>(
1504        self: &Arc<Self>,
1505        dst: &mut Dst,
1506    ) -> Result<(), DriverError> {
1507        self.ctx.bind_to_thread()?;
1508        let num_bytes = dst.num_bytes();
1509        let (dptr, _record) = dst.device_ptr_mut(self);
1510        unsafe { result::memset_d8_async(dptr, 0, num_bytes, self.cu_stream) }?;
1511        Ok(())
1512    }
1513
1514    /// Copy a `[T]`/`Vec<T>`/[`PinnedHostSlice<T>`] to a new [`CudaSlice`].
1515    #[deprecated = "Use clone_htod"]
1516    pub fn memcpy_stod<T: DeviceRepr, Src: HostSlice<T> + ?Sized>(
1517        self: &Arc<Self>,
1518        src: &Src,
1519    ) -> Result<CudaSlice<T>, DriverError> {
1520        let mut dst = unsafe { self.alloc(src.len()) }?;
1521        self.memcpy_htod(src, &mut dst)?;
1522        Ok(dst)
1523    }
1524
1525    /// Copy a `[T]`/`Vec<T>`/[`PinnedHostSlice<T>`] to a new [`CudaSlice`].
1526    pub fn clone_htod<T: DeviceRepr, Src: HostSlice<T> + ?Sized>(
1527        self: &Arc<Self>,
1528        src: &Src,
1529    ) -> Result<CudaSlice<T>, DriverError> {
1530        let mut dst = unsafe { self.alloc(src.len()) }?;
1531        self.memcpy_htod(src, &mut dst)?;
1532        Ok(dst)
1533    }
1534
1535    /// Copy a `[T]`/`Vec<T>`/[`PinnedHostSlice<T>`] into an existing [`CudaSlice`]/[`CudaViewMut`].
1536    pub fn memcpy_htod<T: DeviceRepr, Src: HostSlice<T> + ?Sized, Dst: DevicePtrMut<T>>(
1537        self: &Arc<Self>,
1538        src: &Src,
1539        dst: &mut Dst,
1540    ) -> Result<(), DriverError> {
1541        assert!(dst.len() >= src.len());
1542        self.ctx.bind_to_thread()?;
1543        let (src, _record_src) = unsafe { src.stream_synced_slice(self) };
1544        let (dst, _record_dst) = dst.device_ptr_mut(self);
1545        unsafe { result::memcpy_htod_async(dst, src, self.cu_stream) }
1546    }
1547
1548    /// Copy a [`CudaSlice`]/[`CudaView`] to a new [`Vec<T>`].
1549    #[deprecated = "Use clone_dtoh"]
1550    pub fn memcpy_dtov<T: DeviceRepr, Src: DevicePtr<T>>(
1551        self: &Arc<Self>,
1552        src: &Src,
1553    ) -> Result<Vec<T>, DriverError> {
1554        let mut dst = Vec::with_capacity(src.len());
1555        #[allow(clippy::uninit_vec)]
1556        unsafe {
1557            dst.set_len(src.len())
1558        };
1559        self.memcpy_dtoh(src, &mut dst)?;
1560        Ok(dst)
1561    }
1562
1563    /// Copy a [`CudaSlice`]/[`CudaView`] to a new [`Vec<T>`].
1564    pub fn clone_dtoh<T: DeviceRepr, Src: DevicePtr<T>>(
1565        self: &Arc<Self>,
1566        src: &Src,
1567    ) -> Result<Vec<T>, DriverError> {
1568        let mut dst = Vec::with_capacity(src.len());
1569        #[allow(clippy::uninit_vec)]
1570        unsafe {
1571            dst.set_len(src.len())
1572        };
1573        self.memcpy_dtoh(src, &mut dst)?;
1574        Ok(dst)
1575    }
1576
1577    /// Copy a [`CudaSlice`]/[`CudaView`] to a existing `[T]`/[`Vec<T>`]/[`PinnedHostSlice<T>`].
1578    pub fn memcpy_dtoh<T: DeviceRepr, Src: DevicePtr<T>, Dst: HostSlice<T> + ?Sized>(
1579        self: &Arc<Self>,
1580        src: &Src,
1581        dst: &mut Dst,
1582    ) -> Result<(), DriverError> {
1583        assert!(dst.len() >= src.len());
1584        self.ctx.bind_to_thread()?;
1585        let (src, _record_src) = src.device_ptr(self);
1586        let (dst, _record_dst) = unsafe { dst.stream_synced_mut_slice(self) };
1587        unsafe { result::memcpy_dtoh_async(dst, src, self.cu_stream) }
1588    }
1589
1590    /// Copy a [`CudaSlice`]/[`CudaView`] to a existing [`CudaSlice`]/[`CudaViewMut`].
1591    pub fn memcpy_dtod<T, Src: DevicePtr<T>, Dst: DevicePtrMut<T>>(
1592        self: &Arc<Self>,
1593        src: &Src,
1594        dst: &mut Dst,
1595    ) -> Result<(), DriverError> {
1596        assert!(dst.len() >= src.len());
1597        self.ctx.bind_to_thread()?;
1598
1599        let num_bytes = src.num_bytes();
1600
1601        let src_ctx = src.stream().context();
1602        let dst_ctx = self.context();
1603
1604        let (src, _record_src) = src.device_ptr(self);
1605        let (dst, _record_dst) = dst.device_ptr_mut(self);
1606
1607        if src_ctx == dst_ctx {
1608            unsafe { result::memcpy_dtod_async(dst, src, num_bytes, self.cu_stream) }
1609        } else {
1610            unsafe {
1611                result::memcpy_peer_async(
1612                    dst_ctx.cu_ctx,
1613                    dst,
1614                    src_ctx.cu_ctx,
1615                    src,
1616                    num_bytes,
1617                    self.cu_stream,
1618                )
1619            }
1620        }
1621    }
1622
1623    /// Copy a [`CudaSlice`]/[`CudaView`] to a new [`CudaSlice`].
1624    pub fn clone_dtod<T: DeviceRepr, Src: DevicePtr<T>>(
1625        self: &Arc<Self>,
1626        src: &Src,
1627    ) -> Result<CudaSlice<T>, DriverError> {
1628        let mut dst = unsafe { self.alloc(src.len()) }?;
1629        self.memcpy_dtod(src, &mut dst)?;
1630        Ok(dst)
1631    }
1632}
1633
1634impl<T> CudaSlice<T> {
1635    /// Creates a [CudaView] at the specified offset from the start of `self`.
1636    ///
1637    /// Panics if `range.start >= self.len`.
1638    ///
1639    /// # Example
1640    ///
1641    /// ```rust
1642    /// # use cudarc::driver::safe::{CudaContext, CudaSlice, CudaView};
1643    /// # fn do_something(view: &CudaView<u8>) {}
1644    /// # let ctx = CudaContext::new(0).unwrap();
1645    /// # let stream = ctx.default_stream();
1646    /// let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
1647    /// let mut view = slice.slice(0..50);
1648    /// do_something(&view);
1649    /// ```
1650    ///
1651    /// Like a normal slice, borrow checking prevents the underlying [CudaSlice] from being dropped.
1652    /// ```rust,compile_fail
1653    /// # use cudarc::driver::safe::{CudaContext, CudaSlice, CudaView};
1654    /// # fn do_something(view: &CudaView<u8>) {}
1655    /// # let ctx = CudaContext::new(0).unwrap();
1656    /// # let stream = ctx.default_stream();
1657    /// let view = {
1658    ///     let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
1659    ///     // cannot return view, since it borrows from slice
1660    ///     slice.slice(0..50)
1661    /// };
1662    /// do_something(&view);
1663    /// ```
    pub fn slice(&self, bounds: impl RangeBounds<usize>) -> CudaView<'_, T> {
        // Range-bound handling (inclusive/exclusive/unbounded) is delegated
        // to [CudaView::slice] on a view of the whole allocation.
        self.as_view().slice(bounds)
    }
1667
    /// Fallible version of [CudaSlice::slice()]: returns `None` instead of
    /// panicking when the bounds are not valid for this slice.
    pub fn try_slice(&self, bounds: impl RangeBounds<usize>) -> Option<CudaView<'_, T>> {
        self.as_view().try_slice(bounds)
    }
1672
1673    /// Creates a [CudaViewMut] at the specified offset from the start of `self`.
1674    ///
1675    /// Panics if `range` and `0...self.len()` are not overlapping.
1676    ///
1677    /// # Example
1678    ///
1679    /// ```rust
1680    /// # use cudarc::driver::safe::{CudaContext, CudaSlice, CudaViewMut};
1681    /// # fn do_something(view: &mut CudaViewMut<u8>) {}
1682    /// # let ctx = CudaContext::new(0).unwrap();
1683    /// # let stream = ctx.default_stream();
1684    /// let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
1685    /// let mut view = slice.slice_mut(0..50);
1686    /// do_something(&mut view);
1687    /// ```
1688    ///
1689    /// Like a normal mutable slice, borrow checking prevents the underlying [CudaSlice] from being dropped.
1690    /// ```rust,compile_fail
1691    /// # use cudarc::driver::safe::{CudaContext, CudaSlice, CudaViewMut};
1692    /// # fn do_something(view: &mut CudaViewMut<u8>) {}
1693    /// # let ctx = CudaContext::new(0).unwrap();
1694    /// # let stream = ctx.default_stream();
1695    /// let mut view = {
1696    ///     let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
1697    ///     // cannot return view, since it borrows from slice
1698    ///     slice.slice_mut(0..50)
1699    /// };
1700    /// do_something(&mut view);
1701    /// ```
1702    ///
1703    /// Like with normal mutable slices, one cannot mutably slice twice into the same [CudaSlice]:
1704    /// ```rust,compile_fail
1705    /// # use cudarc::driver::safe::{CudaContext, CudaSlice, CudaViewMut};
1706    /// # fn do_something(view: CudaViewMut<u8>, view2: CudaViewMut<u8>) {}
1707    /// # let ctx = CudaContext::new(0).unwrap();
1708    /// # let stream = ctx.default_stream();
1709    /// let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
1710    /// let mut view1 = slice.slice_mut(0..50);
1711    /// // cannot borrow twice from slice
1712    /// let mut view2 = slice.slice_mut(50..100);
1713    /// do_something(view1, view2);
1714    /// ```
1715    /// If you need non-overlapping mutable views into a [CudaSlice], you can use [CudaSlice::split_at_mut()].
1716    pub fn slice_mut(&mut self, bounds: impl RangeBounds<usize>) -> CudaViewMut<'_, T> {
1717        self.try_slice_mut(bounds).unwrap()
1718    }
1719
1720    /// Fallible version of [CudaSlice::slice_mut]
1721    pub fn try_slice_mut(&mut self, bounds: impl RangeBounds<usize>) -> Option<CudaViewMut<'_, T>> {
1722        to_range(bounds, self.len).map(|(start, end)| CudaViewMut {
1723            ptr: self.cu_device_ptr + (start * std::mem::size_of::<T>()) as u64,
1724            len: end - start,
1725            read: &self.read,
1726            write: &self.write,
1727            stream: &self.stream,
1728            marker: PhantomData,
1729        })
1730    }
1731
1732    /// Reinterprets the slice of memory into a different type. `len` is the number
1733    /// of elements of the new type `S` that are expected. If not enough bytes
1734    /// are allocated in `self` for the view, then this returns `None`.
1735    ///
1736    /// # Safety
1737    /// This is unsafe because not the memory for the view may not be a valid interpretation
1738    /// for the type `S`.
1739    pub unsafe fn transmute<S>(&self, len: usize) -> Option<CudaView<'_, S>> {
1740        self.as_view().transmute(len)
1741    }
1742
1743    /// Reinterprets the slice of memory into a different type. `len` is the number
1744    /// of elements of the new type `S` that are expected. If not enough bytes
1745    /// are allocated in `self` for the view, then this returns `None`.
1746    ///
1747    /// # Safety
1748    /// This is unsafe because not the memory for the view may not be a valid interpretation
1749    /// for the type `S`.
1750    pub unsafe fn transmute_mut<S>(&mut self, len: usize) -> Option<CudaViewMut<'_, S>> {
1751        (len * std::mem::size_of::<S>() <= self.len * std::mem::size_of::<T>()).then_some(
1752            CudaViewMut {
1753                ptr: self.cu_device_ptr,
1754                len,
1755                read: &self.read,
1756                write: &self.write,
1757                stream: &self.stream,
1758                marker: PhantomData,
1759            },
1760        )
1761    }
1762
1763    pub fn split_at(&self, mid: usize) -> (CudaView<'_, T>, CudaView<'_, T>) {
1764        self.as_view().split_at(mid)
1765    }
1766
1767    /// Fallible version of [CudaSlice::split_at]. Returns `None` if `mid > self.len`.
1768    pub fn try_split_at(&self, mid: usize) -> Option<(CudaView<'_, T>, CudaView<'_, T>)> {
1769        self.as_view().try_split_at(mid)
1770    }
1771
1772    /// Splits the [CudaSlice] into two at the given index, returning two [CudaViewMut] for the two halves.
1773    ///
1774    /// Panics if `mid > self.len`.
1775    ///
1776    /// This method can be used to create non-overlapping mutable views into a [CudaSlice].
1777    /// ```rust
1778    /// # use cudarc::driver::safe::{CudaContext, CudaSlice, CudaViewMut};
1779    /// # fn do_something(view: CudaViewMut<u8>, view2: CudaViewMut<u8>) {}
1780    /// # let ctx = CudaContext::new(0).unwrap();
1781    /// # let stream = ctx.default_stream();
1782    /// let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
1783    /// // split the slice into two non-overlapping, mutable views
1784    /// let (mut view1, mut view2) = slice.split_at_mut(50);
1785    /// do_something(view1, view2);
1786    /// ```
1787    pub fn split_at_mut(&mut self, mid: usize) -> (CudaViewMut<'_, T>, CudaViewMut<'_, T>) {
1788        self.try_split_at_mut(mid).unwrap()
1789    }
1790
1791    /// Fallible version of [CudaSlice::split_at_mut].
1792    ///
1793    /// Returns `None` if `mid > self.len`.
1794    pub fn try_split_at_mut(
1795        &mut self,
1796        mid: usize,
1797    ) -> Option<(CudaViewMut<'_, T>, CudaViewMut<'_, T>)> {
1798        let length = self.len;
1799        (mid <= length).then(|| {
1800            let a = CudaViewMut {
1801                ptr: self.cu_device_ptr,
1802                len: mid,
1803                read: &self.read,
1804                write: &self.write,
1805                stream: &self.stream,
1806                marker: PhantomData,
1807            };
1808            let b = CudaViewMut {
1809                ptr: self.cu_device_ptr + (mid * std::mem::size_of::<T>()) as u64,
1810                len: length - mid,
1811                read: &self.read,
1812                write: &self.write,
1813                stream: &self.stream,
1814                marker: PhantomData,
1815            };
1816            (a, b)
1817        })
1818    }
1819}
1820
impl<'a, T> CudaView<'a, T> {
    /// Creates a [CudaView] at the specified offset from the start of `self`.
    ///
    /// Panics if the range is invalid: `start > end`, or `end > self.len()`.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use cudarc::driver::safe::{CudaContext, CudaSlice, CudaView};
    /// # fn do_something(view: &CudaView<u8>) {}
    /// # let ctx = CudaContext::new(0).unwrap();
    /// # let stream = ctx.default_stream();
    /// let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
    /// let mut view = slice.slice(0..50);
    /// let mut view2 = view.slice(0..25);
    /// do_something(&view);
    /// ```
    pub fn slice(&self, bounds: impl RangeBounds<usize>) -> Self {
        self.try_slice(bounds).unwrap()
    }

    /// Fallible version of [CudaView::slice]. Returns `None` instead of
    /// panicking when the bounds are invalid.
    pub fn try_slice(&self, bounds: impl RangeBounds<usize>) -> Option<Self> {
        to_range(bounds, self.len).map(|(start, end)| self.resize(start, end))
    }

    /// Reinterprets the slice of memory into a different type. `len` is the number
    /// of elements of the new type `S` that are expected. If not enough bytes
    /// are allocated in `self` for the view, then this returns `None`.
    ///
    /// # Safety
    /// This is unsafe because the memory for the view may not be a valid
    /// interpretation for the type `S`.
    pub unsafe fn transmute<S>(&self, len: usize) -> Option<CudaView<'a, S>> {
        // Only a byte-size check is performed: `len` elements of `S` must fit
        // within the bytes spanned by `self`.
        (len * std::mem::size_of::<S>() <= self.len * std::mem::size_of::<T>()).then_some(
            CudaView {
                ptr: self.ptr,
                len,
                read: self.read,
                write: self.write,
                stream: self.stream,
                marker: PhantomData,
            },
        )
    }

    /// Splits this view in two at `mid`.
    ///
    /// Panics if `mid > self.len`.
    pub fn split_at(&self, mid: usize) -> (Self, Self) {
        self.try_split_at(mid).unwrap()
    }

    /// Fallible version of [CudaView::split_at].
    ///
    /// Returns `None` if `mid > self.len`.
    pub fn try_split_at(&self, mid: usize) -> Option<(Self, Self)> {
        (mid <= self.len()).then(|| (self.resize(0, mid), self.resize(mid, self.len)))
    }
}
1878
impl<'a, T> CudaViewMut<'a, T> {
    /// Creates an immutable [CudaView] at the specified offset from the start of `self`.
    ///
    /// Panics if the range is invalid: `start > end`, or `end > self.len()`.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use cudarc::driver::safe::{CudaContext, CudaSlice, CudaViewMut};
    /// # fn do_something(view: &mut CudaViewMut<u8>) {}
    /// # let ctx = CudaContext::new(0).unwrap();
    /// # let stream = ctx.default_stream();
    /// let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
    /// let mut view = slice.slice_mut(0..50);
    /// let mut view2 = view.slice_mut(0..25);
    /// do_something(&mut view2);
    /// ```
    ///
    /// One cannot slice twice into the same [CudaViewMut]:
    /// ```rust,compile_fail
    /// # use cudarc::driver::safe::{CudaContext, CudaSlice, CudaViewMut};
    /// # fn do_something(view: CudaViewMut<u8>, view2: CudaViewMut<u8>) {}
    /// # let ctx = CudaContext::new(0).unwrap();
    /// # let stream = ctx.default_stream();
    /// let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
    /// let mut view = slice.slice_mut(0..50);
    /// // cannot borrow `slice` again while `view` is alive
    /// let mut view1 = slice.slice_mut(0..25);
    /// let mut view2 = slice.slice_mut(25..50);
    /// do_something(view1, view2);
    /// ```
    /// If you need non-overlapping mutable views into a [CudaViewMut], you can use [CudaViewMut::split_at_mut()].
    pub fn slice<'b>(&'b self, bounds: impl RangeBounds<usize>) -> CudaView<'b, T> {
        self.try_slice(bounds).unwrap()
    }

    /// Fallible version of [CudaViewMut::slice]. Returns `None` instead of
    /// panicking when the bounds are invalid.
    pub fn try_slice<'b>(&'b self, bounds: impl RangeBounds<usize>) -> Option<CudaView<'b, T>> {
        to_range(bounds, self.len).map(move |(start, end)| self.as_view().resize(start, end))
    }

    /// Reinterprets the slice of memory into a different type. `len` is the number
    /// of elements of the new type `S` that are expected. If not enough bytes
    /// are allocated in `self` for the view, then this returns `None`.
    ///
    /// # Safety
    /// This is unsafe because the memory for the view may not be a valid
    /// interpretation for the type `S`.
    pub unsafe fn transmute<'b, S>(&'b self, len: usize) -> Option<CudaView<'b, S>> {
        // Only a byte-size check is performed: `len` elements of `S` must fit
        // within the bytes spanned by `self`.
        (len * std::mem::size_of::<S>() <= self.len * std::mem::size_of::<T>()).then_some(
            CudaView {
                ptr: self.ptr,
                len,
                read: self.read,
                write: self.write,
                stream: self.stream,
                marker: PhantomData,
            },
        )
    }

    /// Creates a [CudaViewMut] at the specified offset from the start of `self`.
    ///
    /// Panics if the range is invalid: `start > end`, or `end > self.len()`.
    pub fn slice_mut<'b>(&'b mut self, bounds: impl RangeBounds<usize>) -> CudaViewMut<'b, T> {
        self.try_slice_mut(bounds).unwrap()
    }

    /// Fallible version of [CudaViewMut::slice_mut]. Returns `None` instead of
    /// panicking when the bounds are invalid.
    pub fn try_slice_mut<'b>(
        &'b mut self,
        bounds: impl RangeBounds<usize>,
    ) -> Option<CudaViewMut<'b, T>> {
        to_range(bounds, self.len).map(|(start, end)| CudaViewMut {
            // Byte-offset the device pointer by `start` elements of `T`.
            ptr: self.ptr + (start * std::mem::size_of::<T>()) as u64,
            len: end - start,
            read: self.read,
            write: self.write,
            stream: self.stream,
            marker: PhantomData,
        })
    }

    /// Splits the [CudaViewMut] into two at the given index.
    ///
    /// Panics if `mid > self.len`.
    ///
    /// This method can be used to create non-overlapping mutable views into a [CudaViewMut].
    /// ```rust
    /// # use cudarc::driver::safe::{CudaContext, CudaSlice, CudaViewMut};
    /// # fn do_something(view: CudaViewMut<u8>, view2: CudaViewMut<u8>) {}
    /// # let ctx = CudaContext::new(0).unwrap();
    /// # let stream = ctx.default_stream();
    /// let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
    /// let mut view = slice.slice_mut(0..50);
    /// // split the view into two non-overlapping, mutable views
    /// let (mut view1, mut view2) = view.split_at_mut(25);
    /// do_something(view1, view2);
    /// ```
    pub fn split_at_mut<'b>(&'b mut self, mid: usize) -> (CudaViewMut<'b, T>, CudaViewMut<'b, T>) {
        self.try_split_at_mut(mid).unwrap()
    }

    /// Fallible version of [CudaViewMut::split_at_mut].
    ///
    /// Returns `None` if `mid > self.len`
    pub fn try_split_at_mut<'b>(
        &'b mut self,
        mid: usize,
    ) -> Option<(CudaViewMut<'b, T>, CudaViewMut<'b, T>)> {
        let length = self.len;
        (mid <= length).then(|| {
            // First half: elements [0, mid).
            let a = CudaViewMut {
                ptr: self.ptr,
                len: mid,
                read: self.read,
                write: self.write,
                stream: self.stream,
                marker: PhantomData,
            };
            // Second half: elements [mid, len), pointer advanced by `mid` elements.
            let b = CudaViewMut {
                ptr: self.ptr + (mid * std::mem::size_of::<T>()) as u64,
                len: length - mid,
                read: self.read,
                write: self.write,
                stream: self.stream,
                marker: PhantomData,
            };
            (a, b)
        })
    }

    /// Mutable version of [CudaViewMut::transmute]. `len` is the number
    /// of elements of the new type `S` that are expected. If not enough bytes
    /// are allocated in `self` for the view, then this returns `None`.
    ///
    /// # Safety
    /// This is unsafe because the memory for the view may not be a valid
    /// interpretation for the type `S`.
    pub unsafe fn transmute_mut<'b, S>(&'b mut self, len: usize) -> Option<CudaViewMut<'b, S>> {
        // Only a byte-size check is performed: `len` elements of `S` must fit
        // within the bytes spanned by `self`.
        (len * std::mem::size_of::<S>() <= self.len * std::mem::size_of::<T>()).then_some(
            CudaViewMut {
                ptr: self.ptr,
                len,
                read: self.read,
                write: self.write,
                stream: self.stream,
                marker: PhantomData,
            },
        )
    }
}
2030
2031pub(super) fn to_range(range: impl RangeBounds<usize>, len: usize) -> Option<(usize, usize)> {
2032    let start = match range.start_bound() {
2033        Bound::Included(&n) => n,
2034        Bound::Excluded(&n) => n + 1,
2035        Bound::Unbounded => 0,
2036    };
2037    let end = match range.end_bound() {
2038        Bound::Included(&n) => n + 1,
2039        Bound::Excluded(&n) => n,
2040        Bound::Unbounded => len,
2041    };
2042    (start <= end && end <= len).then_some((start, end))
2043}
2044
/// Wrapper around [sys::CUmodule]. Create with [CudaContext::load_module()].
///
/// Call [CudaModule::load_function] to load a [CudaFunction].
#[derive(Debug)]
pub struct CudaModule {
    // Raw driver module handle; unloaded when this struct is dropped.
    pub(crate) cu_module: sys::CUmodule,
    // Keeps the owning context alive for as long as the module exists.
    pub(crate) ctx: Arc<CudaContext>,
}
2053
// SAFETY(review): `CUmodule` is an opaque driver handle with no host-side
// interior mutability in this wrapper; the methods here bind the context
// before driver calls. Presumably the CUDA driver API itself is thread safe
// for module handles — TODO confirm against the driver documentation.
unsafe impl Send for CudaModule {}
unsafe impl Sync for CudaModule {}
2056
impl Drop for CudaModule {
    fn drop(&mut self) {
        // Bind the owning context first, then unload the module. Errors from
        // either step are recorded on the context rather than panicking,
        // since panicking in Drop is undesirable.
        self.ctx.record_err(self.ctx.bind_to_thread());
        self.ctx
            .record_err(unsafe { result::module::unload(self.cu_module) });
    }
}
2064
impl CudaContext {
    /// Dynamically load a compiled ptx into this context.
    ///
    /// - `ptx` contains the compiled ptx
    ///
    /// Binds this context to the calling thread, then dispatches on the
    /// `PtxKind` variant to the matching module-load wrapper.
    ///
    /// # Panics
    /// Panics if a `Src` payload contains an interior NUL byte, or if a
    /// `File` path is not valid UTF-8 (both are `unwrap()`ed below).
    #[cfg(feature = "nvrtc")]
    pub fn load_module(
        self: &Arc<Self>,
        ptx: crate::nvrtc::Ptx,
    ) -> Result<Arc<CudaModule>, result::DriverError> {
        self.bind_to_thread()?;

        let cu_module = match ptx.0 {
            // Pre-built image already in memory: pass the pointer through.
            crate::nvrtc::PtxKind::Image(image) => unsafe {
                result::module::load_data(image.as_ptr() as *const _)
            },
            // PTX source text: must be NUL-terminated before handing to the driver.
            crate::nvrtc::PtxKind::Src(src) => {
                let c_src = CString::new(src).unwrap();
                unsafe { result::module::load_data(c_src.as_ptr() as *const _) }
            }
            // Load directly from a file path on disk.
            crate::nvrtc::PtxKind::File(path) => {
                let name_c = CString::new(path.to_str().unwrap()).unwrap();
                result::module::load(name_c)
            }
            // Raw binary (e.g. cubin) already in memory.
            crate::nvrtc::PtxKind::Binary(data) => unsafe {
                result::module::load_data(data.as_ptr() as *const _)
            },
        }?;
        // The module holds an Arc to this context so the context outlives it.
        Ok(Arc::new(CudaModule {
            cu_module,
            ctx: self.clone(),
        }))
    }
}
2098
/// Wrapper around [sys::CUfunction]. Used by [CudaStream::launch_builder] to execute kernels.
#[derive(Debug, Clone)]
pub struct CudaFunction {
    // Raw driver function handle obtained from the module.
    pub(crate) cu_function: sys::CUfunction,
    // Held only to keep the module (and transitively its context) alive.
    #[allow(unused)]
    pub(crate) module: Arc<CudaModule>,
}
2106
// SAFETY(review): `CUfunction` is an opaque driver handle and this wrapper
// holds no host-side mutable state. Presumably driver calls on function
// handles are thread safe — TODO confirm against the driver documentation.
unsafe impl Send for CudaFunction {}
unsafe impl Sync for CudaFunction {}
2109
2110impl CudaModule {
2111    /// Loads a function from the loaded module with the given name.
2112    pub fn load_function(self: &Arc<Self>, fn_name: &str) -> Result<CudaFunction, DriverError> {
2113        let fn_name_c = CString::new(fn_name).unwrap();
2114        let cu_function = unsafe { result::module::get_function(self.cu_module, fn_name_c) }?;
2115        Ok(CudaFunction {
2116            cu_function,
2117            module: self.clone(),
2118        })
2119    }
2120
2121    /// Gets a global/constant symbol from the loaded module as a [CudaSlice<u8>].
2122    ///
2123    /// This can be used to access `__constant__` memory declared in CUDA kernels.
2124    /// The returned slice can be transmuted to the appropriate type via views.
2125    ///
2126    /// # Example
2127    ///
2128    /// ```ignore
2129    /// // In CUDA: __constant__ float my_const[4];
2130    /// let symbol = module.get_global("my_const", &stream)?;
2131    /// let mut symbol_view = symbol.as_view_mut();
2132    /// let mut symbol_f32 = unsafe { symbol_view.transmute_mut::<f32>(4).unwrap() };
2133    /// stream.memcpy_htod(&[1.0f32, 2.0, 3.0, 4.0], &mut symbol_f32)?;
2134    /// ```
2135    pub fn get_global<'a>(
2136        self: &'a Arc<Self>,
2137        name: &str,
2138        stream: &'a Arc<CudaStream>,
2139    ) -> Result<CudaViewMut<'a, u8>, DriverError> {
2140        let name_c =
2141            CString::new(name).map_err(|_| DriverError(sys::CUresult::CUDA_ERROR_INVALID_VALUE))?;
2142        let (cu_device_ptr, bytes) = unsafe { result::module::get_global(self.cu_module, name_c) }?;
2143        Ok(CudaViewMut {
2144            ptr: cu_device_ptr,
2145            len: bytes,
2146            read: &None,
2147            write: &None,
2148            stream,
2149            marker: PhantomData,
2150        })
2151    }
2152}
2153
2154impl CudaFunction {
2155    pub fn occupancy_available_dynamic_smem_per_block(
2156        &self,
2157        num_blocks: u32,
2158        block_size: u32,
2159    ) -> Result<usize, result::DriverError> {
2160        let mut dynamic_smem_size: usize = 0;
2161
2162        unsafe {
2163            sys::cuOccupancyAvailableDynamicSMemPerBlock(
2164                &mut dynamic_smem_size,
2165                self.cu_function,
2166                num_blocks as std::ffi::c_int,
2167                block_size as std::ffi::c_int,
2168            )
2169            .result()?
2170        };
2171
2172        Ok(dynamic_smem_size)
2173    }
2174
2175    pub fn occupancy_max_active_blocks_per_multiprocessor(
2176        &self,
2177        block_size: u32,
2178        dynamic_smem_size: usize,
2179        flags: Option<sys::CUoccupancy_flags_enum>,
2180    ) -> Result<u32, result::DriverError> {
2181        let mut num_blocks: std::ffi::c_int = 0;
2182        let flags = flags.unwrap_or(sys::CUoccupancy_flags_enum::CU_OCCUPANCY_DEFAULT);
2183
2184        unsafe {
2185            sys::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
2186                &mut num_blocks,
2187                self.cu_function,
2188                block_size as std::ffi::c_int,
2189                dynamic_smem_size,
2190                flags as std::ffi::c_uint,
2191            )
2192            .result()?
2193        };
2194
2195        Ok(num_blocks as u32)
2196    }
2197
2198    #[cfg(not(any(
2199        feature = "cuda-11070",
2200        feature = "cuda-11060",
2201        feature = "cuda-11050",
2202        feature = "cuda-11040"
2203    )))]
2204    pub fn occupancy_max_active_clusters(
2205        &self,
2206        config: crate::driver::LaunchConfig,
2207        stream: &CudaStream,
2208    ) -> Result<u32, result::DriverError> {
2209        let mut num_clusters: std::ffi::c_int = 0;
2210
2211        let cfg = sys::CUlaunchConfig {
2212            gridDimX: config.grid_dim.0,
2213            gridDimY: config.grid_dim.1,
2214            gridDimZ: config.grid_dim.2,
2215            blockDimX: config.block_dim.0,
2216            blockDimY: config.block_dim.1,
2217            blockDimZ: config.block_dim.2,
2218            sharedMemBytes: config.shared_mem_bytes,
2219            hStream: stream.cu_stream,
2220            attrs: std::ptr::null_mut(),
2221            numAttrs: 0,
2222        };
2223
2224        unsafe {
2225            sys::cuOccupancyMaxActiveClusters(&mut num_clusters, self.cu_function, &cfg).result()?
2226        };
2227
2228        Ok(num_clusters as u32)
2229    }
2230
2231    pub fn occupancy_max_potential_block_size(
2232        &self,
2233        block_size_to_dynamic_smem_size: extern "C" fn(block_size: std::ffi::c_int) -> usize,
2234        dynamic_smem_size: usize,
2235        block_size_limit: u32,
2236        flags: Option<sys::CUoccupancy_flags_enum>,
2237    ) -> Result<(u32, u32), result::DriverError> {
2238        let mut min_grid_size: std::ffi::c_int = 0;
2239        let mut block_size: std::ffi::c_int = 0;
2240        let flags = flags.unwrap_or(sys::CUoccupancy_flags_enum::CU_OCCUPANCY_DEFAULT);
2241
2242        unsafe {
2243            sys::cuOccupancyMaxPotentialBlockSizeWithFlags(
2244                &mut min_grid_size,
2245                &mut block_size,
2246                self.cu_function,
2247                Some(block_size_to_dynamic_smem_size),
2248                dynamic_smem_size,
2249                block_size_limit as std::ffi::c_int,
2250                flags as std::ffi::c_uint,
2251            )
2252            .result()?
2253        };
2254
2255        Ok((min_grid_size as u32, block_size as u32))
2256    }
2257
2258    #[cfg(not(any(
2259        feature = "cuda-11070",
2260        feature = "cuda-11060",
2261        feature = "cuda-11050",
2262        feature = "cuda-11040"
2263    )))]
2264    pub fn occupancy_max_potential_cluster_size(
2265        &self,
2266        config: crate::driver::LaunchConfig,
2267        stream: &CudaStream,
2268    ) -> Result<u32, result::DriverError> {
2269        let mut cluster_size: std::ffi::c_int = 0;
2270
2271        let cfg = sys::CUlaunchConfig {
2272            gridDimX: config.grid_dim.0,
2273            gridDimY: config.grid_dim.1,
2274            gridDimZ: config.grid_dim.2,
2275            blockDimX: config.block_dim.0,
2276            blockDimY: config.block_dim.1,
2277            blockDimZ: config.block_dim.2,
2278            sharedMemBytes: config.shared_mem_bytes,
2279            hStream: stream.cu_stream,
2280            attrs: std::ptr::null_mut(),
2281            numAttrs: 0,
2282        };
2283
2284        unsafe {
2285            sys::cuOccupancyMaxPotentialClusterSize(&mut cluster_size, self.cu_function, &cfg)
2286                .result()?
2287        };
2288
2289        Ok(cluster_size as u32)
2290    }
2291
2292    /// Get the value of a specific attribute of this [CudaFunction].
2293    ///
2294    /// See [CUDA docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b)
2295    pub fn get_attribute(
2296        &self,
2297        attribute: CUfunction_attribute_enum,
2298    ) -> Result<i32, result::DriverError> {
2299        self.module.ctx.bind_to_thread()?;
2300        unsafe { result::function::get_function_attribute(self.cu_function, attribute) }
2301    }
2302
2303    /// Get the number of registers used per thread.
2304    pub fn num_regs(&self) -> Result<i32, result::DriverError> {
2305        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_NUM_REGS)
2306    }
2307
2308    /// Get the size of statically-allocated shared memory in bytes.
2309    pub fn shared_size_bytes(&self) -> Result<i32, result::DriverError> {
2310        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
2311    }
2312
2313    /// Get the size of constant memory in bytes used by this function.
2314    pub fn const_size_bytes(&self) -> Result<i32, result::DriverError> {
2315        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES)
2316    }
2317
2318    /// Get the size of local memory in bytes used per thread.
2319    pub fn local_size_bytes(&self) -> Result<i32, result::DriverError> {
2320        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
2321    }
2322
2323    /// Get the maximum number of threads per block for this function.
2324    pub fn max_threads_per_block(&self) -> Result<i32, result::DriverError> {
2325        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
2326    }
2327
2328    /// Get the PTX virtual architecture version for which the function was compiled.
2329    pub fn ptx_version(&self) -> Result<i32, result::DriverError> {
2330        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_PTX_VERSION)
2331    }
2332
2333    /// Get the binary architecture version for which the function was compiled.
2334    pub fn binary_version(&self) -> Result<i32, result::DriverError> {
2335        self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_BINARY_VERSION)
2336    }
2337
2338    /// Set the value of a specific attribute of this [CudaFunction].
2339    pub fn set_attribute(
2340        &self,
2341        attribute: CUfunction_attribute_enum,
2342        value: i32,
2343    ) -> Result<(), result::DriverError> {
2344        unsafe { result::function::set_function_attribute(self.cu_function, attribute, value) }
2345    }
2346
2347    /// Set the cache config of this [CudaFunction].
2348    pub fn set_function_cache_config(
2349        &self,
2350        attribute: CUfunc_cache_enum,
2351    ) -> Result<(), result::DriverError> {
2352        unsafe { result::function::set_function_cache_config(self.cu_function, attribute) }
2353    }
2354}
2355
impl<T> CudaSlice<T> {
    /// Takes ownership of the underlying [sys::CUdeviceptr]. **It is up
    /// to the owner to free this value**.
    ///
    /// Drops the underlying host_buf if there is one.
    pub fn leak(self) -> sys::CUdeviceptr {
        // ManuallyDrop suppresses `CudaSlice`'s Drop impl, which would
        // otherwise free the device pointer we are about to hand out.
        let mut s = std::mem::ManuallyDrop::new(self);
        let ptr = s.cu_device_ptr;

        // Ensure pending operations are complete before resources are released.
        if let Some(read) = s.read.as_ref() {
            s.stream.ctx.record_err(s.stream.wait(read));
        }
        if let Some(write) = s.write.as_ref() {
            s.stream.ctx.record_err(s.stream.wait(write));
        }

        // Manually drop fields that own resources.
        // SAFETY: `s` is in a ManuallyDrop so its destructor never runs;
        // each field is therefore dropped exactly once, here, and never
        // accessed again before `s` goes out of scope.
        unsafe {
            std::ptr::drop_in_place(&mut s.read);
            std::ptr::drop_in_place(&mut s.write);
            std::ptr::drop_in_place(&mut s.stream);
        }

        ptr
    }
}
2383
impl CudaStream {
    /// Creates a [CudaSlice] from a [sys::CUdeviceptr]. Useful in conjunction with
    /// [`CudaSlice::leak()`].
    ///
    /// # Safety
    /// - `cu_device_ptr` must be a valid allocation
    /// - `cu_device_ptr` must space for `len * std::mem::size_of<T>()` bytes
    /// - The memory may not be valid for type `T`, so some sort of memset operation
    ///   should be called on the memory.
    pub unsafe fn upgrade_device_ptr<T>(
        self: &Arc<Self>,
        cu_device_ptr: sys::CUdeviceptr,
        len: usize,
    ) -> CudaSlice<T> {
        // When the context has event tracking enabled, allocate fresh
        // read/write events for this slice up front.
        // NOTE(review): event-creation failures panic here (`unwrap`) instead
        // of being returned — confirm this is acceptable for callers.
        let (read, write) = if self.ctx.is_event_tracking() {
            (
                Some(self.ctx.new_event(None).unwrap()),
                Some(self.ctx.new_event(None).unwrap()),
            )
        } else {
            (None, None)
        };
        CudaSlice {
            cu_device_ptr,
            len,
            read,
            write,
            // The slice keeps its stream alive via this Arc clone.
            stream: self.clone(),
            marker: PhantomData,
        }
    }
}
2416
#[cfg(test)]
mod tests {
    //! Tests for the safe CUDA driver wrappers: slice transmutes, cross-thread
    //! context usage, `Arc` reference counting, device memory accounting,
    //! device-to-device copies, leaked/upgraded pointers, pinned host memory,
    //! and primary vs non-primary context lifecycles.
    //!
    //! All tests require a CUDA-capable device at ordinal 0.

    use std::time::Instant;

    use super::*;

    /// `transmute`/`transmute_mut` must succeed only when the requested
    /// element count fits in the underlying allocation: 100 bytes holds at
    /// most 25 `f32` values, so 26 must yield `None`.
    #[test]
    fn test_transmutes() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let mut slice = stream.alloc_zeros::<u8>(100).unwrap();
        assert!(unsafe { slice.transmute::<f32>(25) }.is_some());
        assert!(unsafe { slice.transmute::<f32>(26) }.is_none());
        assert!(unsafe { slice.transmute_mut::<f32>(25) }.is_some());
        assert!(unsafe { slice.transmute_mut::<f32>(26) }.is_none());

        // The same size checks hold for an immutable view over the full buffer.
        {
            let view = slice.slice(0..100);
            assert!(unsafe { view.transmute::<f32>(25) }.is_some());
            assert!(unsafe { view.transmute::<f32>(26) }.is_none());
        }

        // ...and for a mutable view, including the `_mut` variants.
        {
            let mut view_mut = slice.slice_mut(0..100);
            assert!(unsafe { view_mut.transmute::<f32>(25) }.is_some());
            assert!(unsafe { view_mut.transmute::<f32>(26) }.is_none());
            assert!(unsafe { view_mut.transmute_mut::<f32>(25) }.is_some());
            assert!(unsafe { view_mut.transmute_mut::<f32>(26) }.is_none());
        }
    }

    /// A cloned `Arc<CudaContext>` can be moved to other threads; each thread
    /// binds the context to itself before allocating.
    #[test]
    fn test_threading() {
        let ctx1 = CudaContext::new(0).unwrap();
        let ctx2 = ctx1.clone();

        // Each closure returns `Result<_, DriverError>`, so `?` propagates
        // the bind error into the joined result.
        let thread1 = std::thread::spawn(move || {
            ctx1.bind_to_thread()?;
            ctx1.default_stream().alloc_zeros::<f32>(10)
        });
        let thread2 = std::thread::spawn(move || {
            ctx2.bind_to_thread()?;
            ctx2.default_stream().alloc_zeros::<f32>(10)
        });

        let _: crate::driver::CudaSlice<f32> = thread1.join().unwrap().unwrap();
        let _: crate::driver::CudaSlice<f32> = thread2.join().unwrap().unwrap();
    }

    /// A freshly constructed context holds exactly one strong reference.
    #[test]
    fn test_post_build_arc_count() {
        let ctx = CudaContext::new(0).unwrap();
        assert_eq!(Arc::strong_count(&ctx), 1);
    }

    /// Pins down the exact `Arc` strong counts through the lifecycle of a
    /// stream and an allocation, so accidental extra clones (or dropped
    /// references) in the wrappers are caught.
    #[test]
    fn test_post_alloc_arc_counts() {
        let ctx = CudaContext::new(0).unwrap();
        assert_eq!(Arc::strong_count(&ctx), 1);
        let stream = ctx.default_stream();
        // The stream keeps its own clone of the context.
        assert_eq!(Arc::strong_count(&ctx), 2);
        let t = stream.alloc_zeros::<f32>(1).unwrap();
        // Allocating bumps the context count by 2 and the stream count by 1.
        // NOTE(review): presumably the slice holds one direct context ref plus
        // one via its stream clone — confirm against CudaSlice's fields.
        assert_eq!(Arc::strong_count(&ctx), 4);
        assert_eq!(Arc::strong_count(&stream), 2);
        drop(t);
        // Dropping the slice returns the counts to their pre-alloc values.
        assert_eq!(Arc::strong_count(&ctx), 2);
        assert_eq!(Arc::strong_count(&stream), 1);
        drop(stream);
        assert_eq!(Arc::strong_count(&ctx), 1);
    }

    /// Device free-memory should shrink after an allocation and be fully
    /// restored after the slice is dropped and the context synchronized.
    /// Ignored by default: other concurrently running tests would perturb the
    /// device-wide free-memory numbers.
    #[test]
    #[ignore = "must be executed by itself"]
    fn test_post_alloc_memory() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();

        let (free1, total1) = ctx.mem_get_info().unwrap();

        let t = stream.clone_htod(&[0.0f32; 5]).unwrap();
        let (free2, total2) = ctx.mem_get_info().unwrap();
        assert_eq!(total1, total2);
        assert!(free2 < free1);

        drop(t);
        // Synchronize so any async free has actually completed before we
        // re-read the memory info.
        ctx.synchronize().unwrap();

        let (free3, total3) = ctx.mem_get_info().unwrap();
        assert_eq!(total2, total3);
        assert!(free3 > free2);
        assert_eq!(free3, free1);
    }

    /// Copies several small device buffers into consecutive mutable sub-views
    /// of one larger buffer via dtod copies, then verifies the assembled
    /// contents on the host.
    #[test]
    fn test_ctx_copy_to_views() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();

        let smalls = [
            stream.clone_htod(&[-1.0f32, -0.8]).unwrap(),
            stream.clone_htod(&[-0.6, -0.4]).unwrap(),
            stream.clone_htod(&[-0.2, 0.0]).unwrap(),
            stream.clone_htod(&[0.2, 0.4]).unwrap(),
            stream.clone_htod(&[0.6, 0.8]).unwrap(),
        ];
        let mut big = stream.alloc_zeros::<f32>(10).unwrap();

        // Write each small buffer at its running offset into `big`.
        let mut offset = 0;
        for small in smalls.iter() {
            let mut sub = big.slice_mut(offset..offset + small.len());
            stream.memcpy_dtod(small, &mut sub).unwrap();
            offset += small.len();
        }

        assert_eq!(
            stream.clone_dtoh(&big).unwrap(),
            [-1.0, -0.8, -0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6, 0.8]
        );
    }

    /// `leak()` hands back the raw device pointer without freeing; a later
    /// `upgrade_device_ptr` re-wraps it with a caller-chosen length — here
    /// first shrunk to 3 elements, then restored to the full 5.
    #[test]
    fn test_leak_and_upgrade() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();

        let a = stream.clone_htod(&[1.0f32, 2.0, 3.0, 4.0, 5.0]).unwrap();

        let ptr = a.leak();
        let b = unsafe { stream.upgrade_device_ptr::<f32>(ptr, 3) };
        assert_eq!(stream.clone_dtoh(&b).unwrap(), &[1.0, 2.0, 3.0]);

        let ptr = b.leak();
        let c = unsafe { stream.upgrade_device_ptr::<f32>(ptr, 5) };
        assert_eq!(stream.clone_dtoh(&c).unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    /// See https://github.com/chelsea0x3b/cudarc/issues/160
    ///
    /// A slice must be freed in the context it was allocated in, even when a
    /// different context is currently bound to the thread (and even when the
    /// owning context handle was dropped first).
    #[test]
    fn test_slice_is_freed_with_correct_context() {
        let ctx0 = CudaContext::new(0).unwrap();
        let slice = ctx0.default_stream().clone_htod(&[1.0; 10]).unwrap();
        let ctx1 = CudaContext::new(0).unwrap();
        ctx1.bind_to_thread().unwrap();
        drop(ctx0);
        drop(slice);
        drop(ctx1);
    }

    /// See https://github.com/chelsea0x3b/cudarc/issues/161
    ///
    /// Copies must run in the context that owns the data, not whichever
    /// context happened to be created/bound most recently.
    #[test]
    fn test_copy_uses_correct_context() {
        let ctx0 = CudaContext::new(0).unwrap();
        let _ctx1 = CudaContext::new(0).unwrap();
        let slice = ctx0.default_stream().clone_htod(&[1.0; 10]).unwrap();
        let _out = ctx0.default_stream().clone_dtoh(&slice).unwrap();
    }

    /// Round-trips data through pinned (page-locked) host memory:
    /// host slice -> pinned -> device -> host.
    #[test]
    fn test_htod_copy_pinned() {
        let truth = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let mut pinned = unsafe { ctx.alloc_pinned::<f32>(10) }.unwrap();
        pinned.as_mut_slice().unwrap().clone_from_slice(&truth);
        assert_eq!(pinned.as_slice().unwrap(), &truth);
        let dst = stream.clone_htod(&pinned).unwrap();
        let host = stream.clone_dtoh(&dst).unwrap();
        assert_eq!(&host, &truth);
    }

    /// Timing sanity check: host-to-device copies from pinned memory should
    /// beat copies from ordinary pageable memory.
    #[test]
    fn test_pinned_copy_is_faster() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.new_stream().unwrap();

        let n = 100_000;
        let n_samples = 5;
        let not_pinned = std::vec![0.0f32; n];

        // Time htod copies from pageable host memory, averaged over samples.
        let start = Instant::now();
        for _ in 0..n_samples {
            let _ = stream.clone_htod(&not_pinned).unwrap();
            stream.synchronize().unwrap();
        }
        let unpinned_elapsed = start.elapsed() / n_samples;

        let pinned = unsafe { ctx.alloc_pinned::<f32>(n) }.unwrap();

        // Time the same copies from pinned host memory.
        let start = Instant::now();
        for _ in 0..n_samples {
            let _ = stream.clone_htod(&pinned).unwrap();
            stream.synchronize().unwrap();
        }
        let pinned_elapsed = start.elapsed() / n_samples;

        // pinned memory transfer should be at least 1.5x faster (the margin
        // asserted below), though the actual speedup depends on the device
        assert!(
            pinned_elapsed.as_secs_f32() * 1.5 < unpinned_elapsed.as_secs_f32(),
            "{unpinned_elapsed:?} vs {pinned_elapsed:?}"
        );
    }

    /// `CudaContext::new` retains the device's primary context, so
    /// `is_primary()` must report true.
    #[test]
    fn test_primary_context_is_primary() {
        let ctx = CudaContext::new(0).unwrap();
        assert!(ctx.is_primary());
    }

    /// Helper to create a non-primary context for testing `from_raw_context`.
    /// Uses `cuCtxCreate_v4` (CUDA 12.050+) or `cuCtxCreate_v3` (CUDA 11.040–12.040).
    ///
    /// NOTE(review): the inner cfg below lists `cuda-13000`/`cuda-13010`, but
    /// this outer gate omits them, so those arms are currently unreachable
    /// from this helper — confirm whether 13.x was meant to be included here.
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
    ))]
    fn create_non_primary_context() -> (sys::CUdevice, sys::CUcontext) {
        result::init().unwrap();
        let cu_device = result::device::get(0).unwrap();

        // CUDA 12.050+ exposes cuCtxCreate_v4 (null exec-affinity params,
        // flags = 0).
        #[cfg(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010",
        ))]
        let cu_ctx = unsafe { result::ctx::create_v4(std::ptr::null_mut(), 0, cu_device) }
            .expect("cuCtxCreate_v4 failed");

        // Older toolkits fall back to cuCtxCreate_v3.
        #[cfg(not(any(
            feature = "cuda-12050",
            feature = "cuda-12060",
            feature = "cuda-12080",
            feature = "cuda-12090",
            feature = "cuda-13000",
            feature = "cuda-13010",
        )))]
        let cu_ctx =
            unsafe { result::ctx::create_v3(0, cu_device) }.expect("cuCtxCreate_v3 failed");

        assert!(!cu_ctx.is_null());
        (cu_device, cu_ctx)
    }

    /// Wrapping a raw non-primary context must yield a usable, non-primary
    /// `CudaContext` that is destroyed (not primary-released) on drop.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
    ))]
    fn test_from_raw_context_creates_and_destroys() {
        let (cu_device, cu_ctx) = create_non_primary_context();

        let ctx = unsafe { CudaContext::from_raw_context(0, cu_device, cu_ctx) }.unwrap();
        assert!(!ctx.is_primary());
        // Verify the context is bound and usable.
        ctx.bind_to_thread().unwrap();
        // Drop should call cuCtxDestroy_v2, not primary_ctx::release.
        drop(ctx);
    }

    /// A context wrapped via `from_raw_context` must also be usable from a
    /// different thread after `bind_to_thread`, including real htod/dtoh
    /// round-trips.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
    ))]
    fn test_from_raw_context_bind_to_thread() {
        let (cu_device, cu_ctx) = create_non_primary_context();

        let ctx = unsafe { CudaContext::from_raw_context(0, cu_device, cu_ctx) }.unwrap();

        // Verify bind_to_thread works from another thread.
        let ctx2 = ctx.clone();
        let handle = std::thread::spawn(move || {
            ctx2.bind_to_thread().unwrap();
            let stream = ctx2.default_stream();
            let data = stream.clone_htod(&[1.0f32, 2.0, 3.0]).unwrap();
            let result = stream.clone_dtoh(&data).unwrap();
            assert_eq!(result, std::vec![1.0f32, 2.0, 3.0]);
        });
        handle.join().unwrap();
    }

    /// `new_non_primary` must produce a non-primary context that binds and
    /// drops cleanly (drop path is cuCtxDestroy_v2 for non-primary contexts).
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010",
    ))]
    fn test_new_non_primary_creates_and_destroys() {
        let ctx = CudaContext::new_non_primary(0, 0).unwrap();
        assert!(!ctx.is_primary());
        ctx.bind_to_thread().unwrap();
        drop(ctx);
    }

    /// Basic htod/dtoh round-trip on a `new_non_primary` context.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010",
    ))]
    fn test_new_non_primary_htod_dtoh() {
        let ctx = CudaContext::new_non_primary(0, 0).unwrap();
        let stream = ctx.default_stream();
        let data = stream.clone_htod(&[1.0f32, 2.0, 3.0]).unwrap();
        let result = stream.clone_dtoh(&data).unwrap();
        assert_eq!(result, std::vec![1.0f32, 2.0, 3.0]);
    }

    /// A `new_non_primary` context must also work when cloned into and used
    /// from another thread.
    #[test]
    #[cfg(any(
        feature = "cuda-11040",
        feature = "cuda-11050",
        feature = "cuda-11060",
        feature = "cuda-11070",
        feature = "cuda-11080",
        feature = "cuda-12000",
        feature = "cuda-12010",
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
        feature = "cuda-13010",
    ))]
    fn test_new_non_primary_cross_thread() {
        let ctx = CudaContext::new_non_primary(0, 0).unwrap();
        let ctx2 = ctx.clone();
        let handle = std::thread::spawn(move || {
            ctx2.bind_to_thread().unwrap();
            let stream = ctx2.default_stream();
            let data = stream.clone_htod(&[4.0f32, 5.0, 6.0]).unwrap();
            let result = stream.clone_dtoh(&data).unwrap();
            assert_eq!(result, std::vec![4.0f32, 5.0, 6.0]);
        });
        handle.join().unwrap();
    }
}