Skip to main content

singe_cuda/
module.rs

1use std::{
2    borrow::Cow,
3    ffi::CString,
4    fmt::{self, Display, Formatter},
5    marker::PhantomData,
6    mem::{ManuallyDrop, MaybeUninit, align_of, size_of},
7    ptr,
8    sync::Arc,
9};
10
11use singe_cuda_sys::driver;
12
13use crate::{
14    context::Context,
15    dim::Dim3,
16    error::{Error, Result},
17    graph::{ExecutableGraph, Graph, GraphNode},
18    kernel::{self, ModuleKernelHandle},
19    memory::{DeviceMemory, ManagedMemory},
20    stream::{GraphRecordable, Stream, StreamCaptureScope},
21    try_ffi,
22    types::{DeviceFunction, FunctionAttribute, SharedMemoryCarveout},
23    utility::{to_u32, to_u64},
24    view::{DeviceRepr, DeviceSlice, DeviceSliceMut, DeviceView, DeviceViewMut},
25};
26
bitflags::bitflags! {
    /// Flags accepted by the `*_with_flags` occupancy queries on [`KernelFunction`].
    ///
    /// Each constant takes its value from the `CUoccupancy_flags` driver
    /// enumerator of the same name.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct OccupancyFlags: u32 {
        /// Mirrors `CU_OCCUPANCY_DEFAULT`.
        const DEFAULT = driver::CUoccupancy_flags::CU_OCCUPANCY_DEFAULT as _;
        /// Mirrors `CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE`.
        const DISABLE_CACHING_OVERRIDE = driver::CUoccupancy_flags::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE as _;
    }
}
34
35impl Display for OccupancyFlags {
36    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
37        if self.is_empty() {
38            return Ok(());
39        }
40        let mut first = true;
41        let write_sep = |f: &mut Formatter<'_>, first: &mut bool, name: &str| -> fmt::Result {
42            if *first {
43                *first = false;
44            } else {
45                f.write_str(" | ")?;
46            }
47            f.write_str(name)
48        };
49
50        if self.contains(Self::DEFAULT) {
51            write_sep(f, &mut first, "CU_OCCUPANCY_DEFAULT")?;
52        }
53        if self.contains(Self::DISABLE_CACHING_OVERRIDE) {
54            write_sep(f, &mut first, "CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE")?;
55        }
56
57        Ok(())
58    }
59}
60
/// Snapshot of a kernel's function attributes, as collected by
/// [`KernelFunction::attributes`].
///
/// Every field holds the value queried for the [`FunctionAttribute`] variant
/// named in its comment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FunctionAttributes {
    /// `FunctionAttribute::SharedSizeBytes`, cast to `usize`.
    pub shared_size_bytes: usize,
    /// `FunctionAttribute::ConstSizeBytes`, cast to `usize`.
    pub const_size_bytes: usize,
    /// `FunctionAttribute::LocalSizeBytes`, cast to `usize`.
    pub local_size_bytes: usize,
    /// `FunctionAttribute::MaxThreadsPerBlock`.
    pub max_threads_per_block: i32,
    /// `FunctionAttribute::NumRegs`.
    pub num_regs: i32,
    /// `FunctionAttribute::PtxVersion`.
    pub ptx_version: i32,
    /// `FunctionAttribute::BinaryVersion`.
    pub binary_version: i32,
    /// `true` when `FunctionAttribute::CacheModeCa` is non-zero.
    pub cache_mode_ca: bool,
    /// `FunctionAttribute::MaxDynamicSharedSizeBytes`.
    pub max_dynamic_shared_size_bytes: i32,
    /// `FunctionAttribute::PreferredSharedMemoryCarveout`.
    pub preferred_shared_memory_carveout: i32,
    /// `true` when `FunctionAttribute::ClusterSizeMustBeSet` is non-zero.
    pub cluster_dim_must_be_set: bool,
    /// `FunctionAttribute::RequiredClusterWidth`.
    pub required_cluster_width: i32,
    /// `FunctionAttribute::RequiredClusterHeight`.
    pub required_cluster_height: i32,
    /// `FunctionAttribute::RequiredClusterDepth`.
    pub required_cluster_depth: i32,
    /// `FunctionAttribute::ClusterSchedulingPolicyPreference`.
    pub cluster_scheduling_policy_preference: i32,
    /// `true` when `FunctionAttribute::NonPortableClusterSizeAllowed` is non-zero.
    pub non_portable_cluster_size_allowed: bool,
}
80
/// Result of [`KernelFunction::occupancy_max_potential_block_size_with_flags`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OccupancyMaxPotentialBlockSize {
    /// Minimum grid size written by `cuOccupancyMaxPotentialBlockSizeWithFlags`.
    pub min_grid_size: i32,
    /// Block size written by `cuOccupancyMaxPotentialBlockSizeWithFlags`.
    pub block_size: i32,
}
86
/// Launch shape accepted by [`KernelFunction::occupancy_max_potential_cluster_size`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClusterLaunchConfig {
    /// Grid dimensions, forwarded as `gridDimX`/`gridDimY`/`gridDimZ`.
    grid_dim: Dim3,
    /// Block dimensions, forwarded as `blockDimX`/`blockDimY`/`blockDimZ`.
    block_dim: Dim3,
    // NOTE(review): presumably the dynamic shared memory size in bytes; the
    // code that consumes this field is outside the reviewed section.
    shared_memory_bytes: usize,
}
93
/// A CUDA driver module handle together with the context it is used with.
#[derive(Debug)]
pub struct Module {
    /// Raw `CUmodule` handle passed to the driver calls.
    handle: driver::CUmodule,
    /// Context that is bound before looking up globals and before unloading.
    ctx: Arc<Context>,
    /// When `true`, dropping the module unloads `handle` with `cuModuleUnload`.
    owns_handle: bool,
}
100
/// Address and size of a named module global, as returned by [`Module::global`].
#[derive(Debug, Clone, Copy)]
pub struct Global<'a> {
    /// Address reported by `cuModuleGetGlobal_v2`, stored as an untyped pointer.
    ptr: *mut (),
    /// Size in bytes reported by `cuModuleGetGlobal_v2`.
    size: usize,
    /// Borrow of the owning module; never read, it only ties the lifetimes together.
    _module: &'a Module,
}
107
/// A raw `CUtexref` handle that cannot outlive the [`Module`] it borrows.
#[derive(Debug, Clone, Copy)]
pub struct TextureReference<'a> {
    /// Raw driver handle, exposed through `as_raw`.
    handle: driver::CUtexref,
    /// Borrow of the owning module; never read, it only ties the lifetimes together.
    _module: &'a Module,
}
113
/// A raw `CUsurfref` handle that cannot outlive the [`Module`] it borrows.
#[derive(Debug, Clone, Copy)]
pub struct SurfaceReference<'a> {
    /// Raw driver handle, exposed through `as_raw`.
    handle: driver::CUsurfref,
    /// Borrow of the owning module; never read, it only ties the lifetimes together.
    _module: &'a Module,
}
119
/// The bytes of a module image, either borrowed from the caller or owned.
#[derive(Debug, Clone)]
pub struct ModuleImage<'a> {
    /// Image bytes; borrowed when built with `new`, owned when built from a
    /// `Vec<u8>` or `String`.
    data: Cow<'a, [u8]>,
}
124
/// A kernel function handle paired with the [`Module`] it was looked up from.
#[derive(Debug)]
pub struct KernelFunction<'a> {
    /// Device function handle handed to the launch, attribute and occupancy calls.
    handle: DeviceFunction,
    /// Module this function borrows; its context is bound before driver calls.
    module: &'a Module,
}
130
/// A kernel, launch configuration and parameter set bundled together by
/// [`KernelFunction::launch_operation`].
#[derive(Debug)]
pub struct KernelLaunchOperation<'kernel, 'config, P> {
    /// Kernel to launch.
    function: &'kernel KernelFunction<'kernel>,
    /// Launch configuration borrowed for the lifetime of the operation.
    config: &'config LaunchConfig,
    /// Kernel parameters, stored by value.
    params: P,
}
137
/// Grid shape, block shape and dynamic shared memory size for a kernel launch.
#[derive(Debug, Clone)]
pub struct LaunchConfig {
    /// Grid dimensions; `x`, `y` and `z` are passed to `cuLaunchKernel`.
    grid_dim: Dim3,
    /// Block dimensions; `x`, `y` and `z` are passed to `cuLaunchKernel`.
    block_dim: Dim3,
    /// Amount of dynamic shared memory available to each thread block.
    shared_memory_bytes: usize,
}
144
/// Dynamically built CUDA kernel argument list.
///
/// Use this builder when the argument list depends on runtime conditions:
///
/// ```ignore
/// let mut params = KernelParameters::new();
/// params.arg(&input_ptr).arg(&output_ptr).push(len);
/// function.launch(&config, params)?;
/// ```
///
/// For fixed argument lists, passing a tuple of references directly to
/// [`KernelFunction::launch`] avoids building a dynamic list:
///
/// ```ignore
/// function.launch(&config, (&input_ptr, &mut output_ptr, &len))?;
/// ```
///
/// Borrowed arguments are tied to the lifetime of this list. Owned scalar and
/// pointer arguments pushed with [`Self::push`] or [`Self::owned_arg`] are stored
/// inline when they fit, so ordinary kernel launches do not allocate per
/// argument.
pub struct KernelParameters<'a> {
    /// Arguments held by this list; the `Debug` output reports only their count.
    arguments: Vec<KernelParameter<'a>>,
}
169
// NOTE(review): not referenced in the reviewed section; presumably the number of
// argument slots kept without heap allocation. Confirm against its users.
const INLINE_KERNEL_ARGUMENTS: usize = 16;
/// Byte capacity of one [`InlineKernelArgument`] buffer.
const INLINE_KERNEL_ARGUMENT_BYTES: usize = 16;
172
/// Private home of the `Sealed` supertrait required by [`KernelLaunchArgs`].
mod private {
    /// Marker supertrait; the enclosing module is private, so code outside this
    /// crate cannot name it.
    pub trait Sealed {}
}
176
/// Appends a value to a CUDA kernel parameter list.
///
/// Implementations convert Rust wrapper types into the value a CUDA kernel sees
/// at the ABI boundary, such as a scalar or device pointer.
pub trait PushKernelArg {
    /// Consumes `self` and appends its kernel-side representation to `params`.
    fn push_to<'a>(self, params: &mut KernelParameters<'a>);
}
184
/// Kernel launch arguments accepted by launch and graph-node APIs.
///
/// This sealed trait is implemented for [`KernelParameters`], `()`, and tuples
/// of shared or mutable references up to 16 elements.
pub trait KernelLaunchArgs<'a>: private::Sealed {
    /// Consumes the arguments, passes their encoded form to `f`, and returns
    /// whatever `f` returns.
    #[doc(hidden)]
    fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R;
}
193
/// Encoded CUDA kernel arguments for one launch or graph-node update call.
///
/// This is an implementation detail of [`KernelLaunchArgs`]. CUDA receives a
/// temporary array of pointers to encoded argument values; callers should use
/// [`KernelParameters`] or generated launch methods instead of constructing raw
/// argument arrays directly.
#[doc(hidden)]
pub struct EncodedKernelArgs<'a> {
    /// Borrowed array of untyped pointers to the encoded argument values.
    pointers: &'a mut [*mut ()],
}
204
/// Converts one launch argument into the untyped pointer form used for encoding.
///
/// NOTE(review): the implementors are outside the reviewed section; presumably
/// these are the shared and mutable references accepted in argument tuples.
trait KernelTupleArgument<'a> {
    /// Consumes `self` and returns an untyped pointer for it.
    fn into_kernel_argument_ptr(self) -> *mut ();
}
208
/// One entry of a [`KernelParameters`] list.
enum KernelParameter<'a> {
    /// An argument that lives outside the list.
    Borrowed {
        /// Untyped pointer to the argument value.
        ptr: *mut (),
        /// Ties the entry to the `'a` borrow without storing a reference.
        _marker: PhantomData<&'a ()>,
    },
    /// An argument whose value is stored by the list itself.
    Owned(OwnedKernelArgument),
}
216
217impl fmt::Debug for KernelParameter<'_> {
218    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
219        match self {
220            Self::Borrowed { ptr, .. } => f.debug_tuple("Borrowed").field(ptr).finish(),
221            Self::Owned(value) => f.debug_tuple("Owned").field(value).finish(),
222        }
223    }
224}
225
/// Storage for an argument value owned by a [`KernelParameters`] list.
enum OwnedKernelArgument {
    /// Value kept in a fixed-size inline buffer.
    Inline(InlineKernelArgument),
    /// Value kept behind a heap allocation.
    Boxed(Box<dyn KernelArgumentStorage>),
}
230
231impl fmt::Debug for OwnedKernelArgument {
232    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
233        match self {
234            Self::Inline(value) => f.debug_tuple("Inline").field(value).finish(),
235            Self::Boxed(_) => f.debug_tuple("Boxed").finish_non_exhaustive(),
236        }
237    }
238}
239
/// Type-erased access to the address of a stored argument value.
trait KernelArgumentStorage {
    /// Returns an untyped pointer to the stored value.
    fn as_mut_ptr(&mut self) -> *mut ();
}
243
244impl<T> KernelArgumentStorage for T {
245    fn as_mut_ptr(&mut self) -> *mut () {
246        ptr::from_mut(self).cast()
247    }
248}
249
/// Fixed-size, 16-byte-aligned buffer for an argument value stored inline.
#[derive(Clone, Copy)]
#[repr(C, align(16))]
struct InlineKernelArgument {
    /// Raw storage; `MaybeUninit` because not every byte has to be initialized.
    bytes: [MaybeUninit<u8>; INLINE_KERNEL_ARGUMENT_BYTES],
}
255
256impl fmt::Debug for InlineKernelArgument {
257    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
258        f.debug_struct("InlineKernelArgument")
259            .finish_non_exhaustive()
260    }
261}
262
263impl fmt::Debug for KernelParameters<'_> {
264    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
265        f.debug_struct("KernelParameters")
266            .field("arguments", &self.arguments.len())
267            .finish()
268    }
269}
270
271impl Module {
272    pub unsafe fn from_raw(handle: driver::CUmodule, ctx: Arc<Context>) -> Result<Self> {
273        if handle.is_null() {
274            return Err(Error::NullHandle);
275        }
276
277        Ok(Self {
278            handle,
279            ctx,
280            owns_handle: true,
281        })
282    }
283
284    pub const unsafe fn from_borrowed_raw(handle: driver::CUmodule, ctx: Arc<Context>) -> Self {
285        Self {
286            handle,
287            ctx,
288            owns_handle: false,
289        }
290    }
291
292    /// Returns the kernel function with the given name from the module.
293    ///
294    /// # Errors
295    ///
296    /// Returns [`crate::error::Status::NotFound`] if the module has no kernel function
297    /// named `name`. Also returns an error if `name` contains an interior NUL
298    /// byte, the module context cannot be bound, CUDA rejects the lookup, or a
299    /// previous asynchronous launch reports an error.
300    pub fn function(&self, name: &str) -> Result<KernelFunction<'_>> {
301        unsafe {
302            let c_name = CString::new(name)?;
303            let mut function_handle = ptr::null_mut();
304            try_ffi!(driver::cuModuleGetFunction(
305                &raw mut function_handle,
306                self.handle,
307                c_name.as_ptr(),
308            ))?;
309            if function_handle.is_null() {
310                return Err(Error::NullHandle);
311            }
312            let function = DeviceFunction::from_raw(function_handle);
313            Ok(KernelFunction::from_raw(function, self))
314        }
315    }
316
317    /// Returns the number of functions in this module.
318    ///
319    /// # Errors
320    ///
321    /// Returns an error if CUDA Driver cannot report the function count.
322    pub fn function_count(&self) -> Result<usize> {
323        unsafe {
324            let mut count = 0;
325            try_ffi!(driver::cuModuleGetFunctionCount(
326                &raw mut count,
327                self.handle
328            ))?;
329            Ok(count as usize)
330        }
331    }
332
333    pub const fn as_raw(&self) -> driver::CUmodule {
334        self.handle
335    }
336
337    /// Consumes the module and returns the raw CUDA module handle without
338    /// unloading it.
339    ///
340    /// The caller becomes responsible for eventually unloading the returned
341    /// handle with CUDA.
342    pub fn into_raw(self) -> driver::CUmodule {
343        let module = ManuallyDrop::new(self);
344        module.handle
345    }
346
347    /// Returns the base pointer and size of the global with the given name located in the module.
348    ///
349    /// The returned [`Global`] borrows this module, so the module remains loaded
350    /// for at least as long as the global reference is usable.
351    ///
352    /// # Errors
353    ///
354    /// Returns [`crate::error::Status::NotFound`] if the module has no global variable
355    /// named `name`. Also returns an error if `name` contains an interior NUL
356    /// byte, the module context cannot be bound, CUDA rejects the lookup, or a
357    /// previous asynchronous launch reports an error.
358    pub fn global(&self, name: &str) -> Result<Global<'_>> {
359        let c_name = CString::new(name)?;
360        let mut ptr = 0;
361        let mut size = 0;
362        self.ctx.bind()?;
363        unsafe {
364            try_ffi!(driver::cuModuleGetGlobal_v2(
365                &raw mut ptr,
366                &raw mut size,
367                self.handle,
368                c_name.as_ptr(),
369            ))?;
370        }
371        Ok(Global {
372            ptr: ptr as _,
373            size: size as _,
374            _module: self,
375        })
376    }
377}
378
379impl Drop for Module {
380    fn drop(&mut self) {
381        if !self.owns_handle {
382            return;
383        }
384
385        if let Err(err) = self.ctx.bind() {
386            #[cfg(debug_assertions)]
387            eprintln!("failed to bind context before unloading module: {err}");
388            return;
389        }
390
391        unsafe {
392            if let Err(err) = try_ffi!(driver::cuModuleUnload(self.handle)) {
393                #[cfg(debug_assertions)]
394                eprintln!("failed to unload cuda module: {err}");
395            }
396        }
397    }
398}
399
// SAFETY: CUDA modules are immutable after loading in this wrapper. Kernel/function
// lookups use shared references and CUDA owns internal synchronization.
unsafe impl Send for Module {}
// SAFETY: same reasoning as for `Send`; every lookup method takes `&self`.
unsafe impl Sync for Module {}
404
405impl<'a> ModuleImage<'a> {
406    pub const fn new(data: &'a [u8]) -> Self {
407        Self {
408            data: Cow::Borrowed(data),
409        }
410    }
411
412    pub fn from_vec(data: Vec<u8>) -> Self {
413        Self {
414            data: Cow::Owned(data),
415        }
416    }
417
418    pub fn from_string(data: String) -> Self {
419        Self::from_vec(data.into_bytes())
420    }
421
422    pub fn as_ptr(&self) -> *const () {
423        self.data.as_ptr().cast()
424    }
425
426    pub fn as_bytes(&self) -> &[u8] {
427        self.data.as_ref()
428    }
429}
430
431impl Global<'_> {
432    pub const fn as_ptr(&self) -> *mut () {
433        self.ptr
434    }
435
436    pub const fn byte_len(&self) -> usize {
437        self.size
438    }
439}
440
441impl TextureReference<'_> {
442    pub const fn as_raw(&self) -> driver::CUtexref {
443        self.handle
444    }
445}
446
447impl SurfaceReference<'_> {
448    pub const fn as_raw(&self) -> driver::CUsurfref {
449        self.handle
450    }
451}
452
453impl KernelFunction<'_> {
    /// Pairs a device function handle with the module it is used with.
    ///
    /// # Safety
    ///
    /// NOTE(review): no contract is stated in the original source. Presumably
    /// `handle` must refer to a function that belongs to `module`, since later
    /// calls bind `module`'s context before using the handle.
    pub const unsafe fn from_raw(handle: DeviceFunction, module: &Module) -> KernelFunction<'_> {
        KernelFunction { handle, module }
    }
457
    /// Creates a stream operation that launches this kernel.
    ///
    /// The operation only stores the kernel, `config` and `params`; nothing is
    /// launched by this call.
    ///
    /// # Safety
    ///
    /// If this operation is recorded during stream capture, CUDA copies kernel
    /// argument values into the captured graph. For pointer arguments, only the
    /// pointer address is copied. The caller must ensure every copied pointer
    /// value remains valid for every captured graph execution that can use this
    /// operation, and mutable pointer arguments must remain exclusive for the
    /// work ordered by those graph launches.
    pub const unsafe fn launch_operation<'kernel, 'config, P>(
        &'kernel self,
        config: &'config LaunchConfig,
        params: P,
    ) -> KernelLaunchOperation<'kernel, 'config, P> {
        KernelLaunchOperation {
            function: self,
            config,
            params,
        }
    }
476
477    fn check_graph_context(&self, graph: &Graph) -> Result<()> {
478        if matches!(graph.context(), Some(ctx) if ctx != self.module.ctx.as_ref()) {
479            return Err(Error::GraphContextMismatch);
480        }
481        Ok(())
482    }
483
484    fn check_executable_graph_context(&self, executable: &ExecutableGraph) -> Result<()> {
485        if matches!(executable.context(), Some(ctx) if ctx != self.module.ctx.as_ref()) {
486            return Err(Error::GraphContextMismatch);
487        }
488        Ok(())
489    }
490
491    /// Invokes this kernel function on a grid of blocks.
492    /// Each block contains the threads specified by [`LaunchConfig::block_dim`].
493    ///
494    /// [`LaunchConfig::shared_memory_bytes`] sets the amount of dynamic shared memory available to each thread block.
495    ///
496    /// Kernel parameters are passed with [`KernelParameters`] or tuples of shared or mutable references.
497    ///
498    /// Launching the kernel invalidates the persistent function state set through the following deprecated APIs: [`sys::cuFuncSetBlockShape`](singe_cuda_sys::driver::cuFuncSetBlockShape), [`sys::cuFuncSetSharedSize`](singe_cuda_sys::driver::cuFuncSetSharedSize), [`sys::cuParamSetSize`](singe_cuda_sys::driver::cuParamSetSize), [`sys::cuParamSeti`](singe_cuda_sys::driver::cuParamSeti), [`sys::cuParamSetf`](singe_cuda_sys::driver::cuParamSetf), [`sys::cuParamSetv`](singe_cuda_sys::driver::cuParamSetv).
499    ///
500    /// The kernel must either have been compiled with toolchain version 3.2 or later so that it contains kernel parameter information, or have no kernel parameters.
501    /// If either of these conditions is not met, the launch returns [`crate::error::Status::InvalidImage`].
502    ///
503    /// # Errors
504    ///
505    /// Returns [`crate::error::Status::InvalidImage`] if the kernel parameter metadata
506    /// requirements above are not met. Also returns an error if the module
507    /// context cannot be bound, CUDA rejects the launch, or a previous
508    /// asynchronous launch reports an error.
509    pub fn launch<'a, P>(&self, config: &LaunchConfig, params: P) -> Result<()>
510    where
511        P: KernelLaunchArgs<'a>,
512    {
513        self.module.ctx.bind()?;
514        params.with_encoded_arguments(|mut arguments| unsafe {
515            try_ffi!(driver::cuLaunchKernel(
516                self.handle.as_raw(),
517                config.grid_dim().x,
518                config.grid_dim().y,
519                config.grid_dim().z,
520                config.block_dim().x,
521                config.block_dim().y,
522                config.block_dim().z,
523                config.shared_memory_bytes_u32(),
524                ptr::null_mut(),
525                arguments.as_mut_ptr().cast(),
526                ptr::null_mut(),
527            ))?;
528            Ok(())
529        })
530    }
531
532    /// Invokes this kernel function on a grid of blocks using the given stream.
533    /// Each block contains the threads specified by [`LaunchConfig::block_dim`].
534    ///
535    /// [`LaunchConfig::shared_memory_bytes`] sets the amount of dynamic shared memory available to each thread block.
536    ///
537    /// Kernel parameters are passed with [`KernelParameters`] or tuples of shared or mutable references.
538    ///
539    /// Launching the kernel invalidates the persistent function state set through the following deprecated APIs: [`sys::cuFuncSetBlockShape`](singe_cuda_sys::driver::cuFuncSetBlockShape), [`sys::cuFuncSetSharedSize`](singe_cuda_sys::driver::cuFuncSetSharedSize), [`sys::cuParamSetSize`](singe_cuda_sys::driver::cuParamSetSize), [`sys::cuParamSeti`](singe_cuda_sys::driver::cuParamSeti), [`sys::cuParamSetf`](singe_cuda_sys::driver::cuParamSetf), [`sys::cuParamSetv`](singe_cuda_sys::driver::cuParamSetv).
540    ///
541    /// The kernel must either have been compiled with toolchain version 3.2 or later so that it contains kernel parameter information, or have no kernel parameters.
542    /// If either of these conditions is not met, the launch returns [`crate::error::Status::InvalidImage`].
543    ///
544    /// # Errors
545    ///
546    /// Returns [`crate::error::Status::InvalidImage`] if the kernel parameter metadata
547    /// requirements above are not met. Also returns an error if `stream` belongs
548    /// to a different context, the module context cannot be bound, CUDA rejects
549    /// the launch, or a previous asynchronous launch reports an error.
550    pub fn launch_on<'a, P>(&self, config: &LaunchConfig, params: P, stream: &Stream) -> Result<()>
551    where
552        P: KernelLaunchArgs<'a>,
553    {
554        if stream.context() != self.module.ctx.as_ref() {
555            return Err(driver::CUresult::CUDA_ERROR_INVALID_CONTEXT.into());
556        }
557
558        self.module.ctx.bind()?;
559        params.with_encoded_arguments(|mut arguments| unsafe {
560            try_ffi!(driver::cuLaunchKernel(
561                self.handle.as_raw(),
562                config.grid_dim().x,
563                config.grid_dim().y,
564                config.grid_dim().z,
565                config.block_dim().x,
566                config.block_dim().y,
567                config.block_dim().z,
568                config.shared_memory_bytes_u32(),
569                stream.as_raw(),
570                arguments.as_mut_ptr().cast(),
571                ptr::null_mut(),
572            ))?;
573            Ok(())
574        })
575    }
576
    /// Adds this kernel to `graph` as a kernel node.
    ///
    /// The node is created by `Graph::add_kernel_node` with the given
    /// `dependencies`, `config` and `params`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::GraphContextMismatch`] if `graph` reports a context that
    /// differs from this kernel's module context. Any error from
    /// `Graph::add_kernel_node` is returned unchanged.
    ///
    /// # Safety
    ///
    /// CUDA copies each kernel argument value during this call. Non-pointer
    /// argument values may be borrowed from stack or temporary storage that
    /// outlives this call. If an argument value is a pointer, CUDA stores only
    /// the pointer address. The caller must ensure every copied pointer value
    /// remains valid for every graph instantiation, update, and launch that can
    /// execute the created node. Mutable pointer arguments must remain exclusive
    /// for the work ordered by those launches.
    pub unsafe fn add_to_graph<'a, P>(
        &self,
        graph: &mut Graph,
        dependencies: &[GraphNode],
        config: &LaunchConfig,
        params: P,
    ) -> Result<GraphNode>
    where
        P: KernelLaunchArgs<'a>,
    {
        self.check_graph_context(graph)?;
        // SAFETY: the pointer-validity requirements are forwarded to the caller
        // through this function's own safety contract.
        unsafe { graph.add_kernel_node(dependencies, self.handle, config, params) }
    }
601
    /// Updates this kernel's parameters in an executable graph node.
    ///
    /// The update is performed by `ExecutableGraph::set_kernel_node_params` with
    /// the given `node`, `config` and `params`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::GraphContextMismatch`] if `executable` reports a context
    /// that differs from this kernel's module context. Any error from
    /// `ExecutableGraph::set_kernel_node_params` is returned unchanged.
    ///
    /// # Safety
    ///
    /// CUDA copies each kernel argument value during this call. Non-pointer
    /// argument values may be borrowed from stack or temporary storage that
    /// outlives this call. If an argument value is a pointer, CUDA stores only
    /// the pointer address. The caller must ensure every copied pointer value
    /// remains valid for every future launch that can execute `node`. Mutable
    /// pointer arguments must remain exclusive for the work ordered by those
    /// launches.
    pub unsafe fn set_graph_node_params<'a, P>(
        &self,
        executable: &mut ExecutableGraph,
        node: GraphNode,
        config: &LaunchConfig,
        params: P,
    ) -> Result<()>
    where
        P: KernelLaunchArgs<'a>,
    {
        self.check_executable_graph_context(executable)?;
        // SAFETY: the pointer-validity requirements are forwarded to the caller
        // through this function's own safety contract.
        unsafe { executable.set_kernel_node_params(node, self.handle, config, params) }
    }
626
    /// Returns the module this kernel function borrows.
    pub const fn module(&self) -> &Module {
        self.module
    }
630
    /// Returns this kernel's name, as reported by `kernel::name` for the
    /// module's context and this function handle.
    ///
    /// # Errors
    ///
    /// Returns whatever error `kernel::name` reports.
    pub fn name(&self) -> Result<String> {
        kernel::name::<ModuleKernelHandle>(self.module.ctx.as_ref(), self.handle.as_raw())
    }
634
    /// Queries a single function attribute of this kernel through
    /// `kernel::attribute`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `kernel::attribute` reports.
    pub fn attribute(&self, attribute: FunctionAttribute) -> Result<i32> {
        kernel::attribute::<ModuleKernelHandle>(
            self.module.ctx.as_ref(),
            self.handle.as_raw(),
            attribute,
        )
    }
642
    /// Sets a single function attribute of this kernel to `value` through
    /// `kernel::set_attribute`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `kernel::set_attribute` reports.
    pub fn set_attribute(&self, attribute: FunctionAttribute, value: i32) -> Result<()> {
        kernel::set_attribute::<ModuleKernelHandle>(
            self.module.ctx.as_ref(),
            self.handle.as_raw(),
            attribute,
            value,
        )
    }
651
    /// Sets [`FunctionAttribute::MaxDynamicSharedSizeBytes`] to `bytes`.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Self::set_attribute`] reports.
    pub fn set_max_dynamic_shared_memory_bytes(&self, bytes: i32) -> Result<()> {
        self.set_attribute(FunctionAttribute::MaxDynamicSharedSizeBytes, bytes)
    }
655
656    pub fn set_preferred_shared_memory_carveout(
657        &self,
658        carveout: SharedMemoryCarveout,
659    ) -> Result<()> {
660        self.set_attribute(
661            FunctionAttribute::PreferredSharedMemoryCarveout,
662            i32::from(carveout),
663        )
664    }
665
666    pub fn attributes(&self) -> Result<FunctionAttributes> {
667        Ok(FunctionAttributes {
668            shared_size_bytes: self.attribute(FunctionAttribute::SharedSizeBytes)? as usize,
669            const_size_bytes: self.attribute(FunctionAttribute::ConstSizeBytes)? as usize,
670            local_size_bytes: self.attribute(FunctionAttribute::LocalSizeBytes)? as usize,
671            max_threads_per_block: self.attribute(FunctionAttribute::MaxThreadsPerBlock)?,
672            num_regs: self.attribute(FunctionAttribute::NumRegs)?,
673            ptx_version: self.attribute(FunctionAttribute::PtxVersion)?,
674            binary_version: self.attribute(FunctionAttribute::BinaryVersion)?,
675            cache_mode_ca: self.attribute(FunctionAttribute::CacheModeCa)? != 0,
676            max_dynamic_shared_size_bytes: self
677                .attribute(FunctionAttribute::MaxDynamicSharedSizeBytes)?,
678            preferred_shared_memory_carveout: self
679                .attribute(FunctionAttribute::PreferredSharedMemoryCarveout)?,
680            cluster_dim_must_be_set: self.attribute(FunctionAttribute::ClusterSizeMustBeSet)? != 0,
681            required_cluster_width: self.attribute(FunctionAttribute::RequiredClusterWidth)?,
682            required_cluster_height: self.attribute(FunctionAttribute::RequiredClusterHeight)?,
683            required_cluster_depth: self.attribute(FunctionAttribute::RequiredClusterDepth)?,
684            cluster_scheduling_policy_preference: self
685                .attribute(FunctionAttribute::ClusterSchedulingPolicyPreference)?,
686            non_portable_cluster_size_allowed: self
687                .attribute(FunctionAttribute::NonPortableClusterSizeAllowed)?
688                != 0,
689        })
690    }
691
    /// Returns the maximum number of active blocks per streaming multiprocessor
    /// using [`OccupancyFlags::DEFAULT`].
    ///
    /// Shorthand for
    /// [`Self::occupancy_max_active_blocks_per_multiprocessor_with_flags`].
    ///
    /// # Errors
    ///
    /// Returns whatever error the flagged variant reports.
    pub fn occupancy_max_active_blocks_per_multiprocessor(
        &self,
        block_size: i32,
        dynamic_shared_memory_bytes: usize,
    ) -> Result<i32> {
        self.occupancy_max_active_blocks_per_multiprocessor_with_flags(
            block_size,
            dynamic_shared_memory_bytes,
            OccupancyFlags::DEFAULT,
        )
    }
703
704    /// Returns the maximum number of active blocks per streaming multiprocessor.
705    ///
706    /// `flags` controls how special cases are handled.
707    /// The valid flags are:
708    ///
709    /// * [`OccupancyFlags::DEFAULT`], which maintains the default behavior as [`sys::cuOccupancyMaxActiveBlocksPerMultiprocessor`](singe_cuda_sys::driver::cuOccupancyMaxActiveBlocksPerMultiprocessor);
710    ///
711    /// * [`OccupancyFlags::DISABLE_CACHING_OVERRIDE`], which suppresses the default behavior on platforms where global caching affects occupancy.
712    ///   On such platforms, if caching
713    ///   is enabled, but per-block SM resource usage would result in zero occupancy, the occupancy calculator will calculate the occupancy
714    ///   as if caching is disabled.
715    ///   Setting [`OccupancyFlags::DISABLE_CACHING_OVERRIDE`] makes the occupancy calculator return 0 in such cases.
716    ///   More information can be found about this feature in the "Unified
717    ///   L1/Texture Cache" section of the Maxwell tuning guide.
718    ///
719    /// For context-less kernels queried via [`Library::kernel`](crate::library::Library::kernel).
720    /// Here, this wrapper uses the current context for calculations.
721    ///
722    /// # Errors
723    ///
724    /// Returns an error if the module context cannot be bound, CUDA rejects the
725    /// occupancy query, or a previous asynchronous launch reports an error.
726    pub fn occupancy_max_active_blocks_per_multiprocessor_with_flags(
727        &self,
728        block_size: i32,
729        dynamic_shared_memory_bytes: usize,
730        flags: OccupancyFlags,
731    ) -> Result<i32> {
732        self.module.ctx.bind()?;
733        let dynamic_shared_memory_bytes =
734            validate_dynamic_shared_memory_bytes(dynamic_shared_memory_bytes)?;
735        let mut blocks = 0;
736        unsafe {
737            try_ffi!(
738                driver::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
739                    &raw mut blocks,
740                    self.handle.as_raw(),
741                    block_size,
742                    dynamic_shared_memory_bytes,
743                    flags.bits(),
744                )
745            )?;
746        }
747        Ok(blocks)
748    }
749
750    /// Returns dynamic shared memory available per block when launching `num_blocks` blocks on a streaming multiprocessor.
751    ///
752    /// The returned value is the maximum size of dynamic shared memory that allows `num_blocks` blocks per streaming multiprocessor.
753    ///
754    /// For context-less kernels queried via [`Library::kernel`](crate::library::Library::kernel).
755    /// Here, this wrapper uses the current context for calculations.
756    ///
757    /// # Errors
758    ///
759    /// Returns an error if the module context cannot be bound, CUDA rejects the
760    /// occupancy query, or a previous asynchronous launch reports an error.
761    pub fn occupancy_available_dynamic_shared_memory_per_block(
762        &self,
763        num_blocks: i32,
764        block_size: i32,
765    ) -> Result<usize> {
766        self.module.ctx.bind()?;
767        let mut bytes = 0;
768        unsafe {
769            try_ffi!(driver::cuOccupancyAvailableDynamicSMemPerBlock(
770                &raw mut bytes,
771                self.handle.as_raw(),
772                num_blocks,
773                block_size,
774            ))?;
775        }
776        Ok(bytes as usize)
777    }
778
    /// Returns the suggested minimum grid size and block size using
    /// [`OccupancyFlags::DEFAULT`].
    ///
    /// Shorthand for [`Self::occupancy_max_potential_block_size_with_flags`].
    ///
    /// # Errors
    ///
    /// Returns whatever error the flagged variant reports.
    pub fn occupancy_max_potential_block_size(
        &self,
        dynamic_shared_memory_bytes: usize,
        block_size_limit: i32,
    ) -> Result<OccupancyMaxPotentialBlockSize> {
        self.occupancy_max_potential_block_size_with_flags(
            dynamic_shared_memory_bytes,
            block_size_limit,
            OccupancyFlags::DEFAULT,
        )
    }
790
791    /// An extended version of [`sys::cuOccupancyMaxPotentialBlockSize`](singe_cuda_sys::driver::cuOccupancyMaxPotentialBlockSize).
792    /// In addition to arguments passed to [`sys::cuOccupancyMaxPotentialBlockSize`](singe_cuda_sys::driver::cuOccupancyMaxPotentialBlockSize), [`KernelFunction::occupancy_max_potential_block_size_with_flags`] also takes `flags`.
793    ///
794    /// `flags` controls how special cases are handled.
795    /// The valid flags are:
796    ///
797    /// * [`OccupancyFlags::DEFAULT`], which maintains the default behavior as [`sys::cuOccupancyMaxPotentialBlockSize`](singe_cuda_sys::driver::cuOccupancyMaxPotentialBlockSize);
798    ///
799    /// * [`OccupancyFlags::DISABLE_CACHING_OVERRIDE`], which suppresses the default behavior on platforms where global caching affects occupancy.
800    ///   On such platforms, the launch
801    ///   configurations that produce maximal occupancy might not support global caching.
802    ///   Setting [`OccupancyFlags::DISABLE_CACHING_OVERRIDE`] guarantees that the produced launch configuration is global caching compatible at a potential cost of occupancy.
803    ///   More
804    ///   information can be found about this feature in the "Unified L1/Texture Cache" section of the Maxwell tuning guide.
805    ///
806    /// For context-less kernels queried via [`Library::kernel`](crate::library::Library::kernel).
807    /// Here, this wrapper uses the current context for calculations.
808    ///
809    /// # Errors
810    ///
811    /// Returns an error if the module context cannot be bound, CUDA rejects the
812    /// occupancy query, or a previous asynchronous launch reports an error.
813    pub fn occupancy_max_potential_block_size_with_flags(
814        &self,
815        dynamic_shared_memory_bytes: usize,
816        block_size_limit: i32,
817        flags: OccupancyFlags,
818    ) -> Result<OccupancyMaxPotentialBlockSize> {
819        self.module.ctx.bind()?;
820        let dynamic_shared_memory_bytes =
821            validate_dynamic_shared_memory_bytes(dynamic_shared_memory_bytes)?;
822        let mut min_grid_size = 0;
823        let mut block_size = 0;
824        unsafe {
825            try_ffi!(driver::cuOccupancyMaxPotentialBlockSizeWithFlags(
826                &raw mut min_grid_size,
827                &raw mut block_size,
828                self.handle.as_raw(),
829                None,
830                dynamic_shared_memory_bytes,
831                block_size_limit,
832                flags.bits(),
833            ))?;
834        }
835        Ok(OccupancyMaxPotentialBlockSize {
836            min_grid_size,
837            block_size,
838        })
839    }
840
841    /// Given this kernel and launch configuration, returns the maximum cluster size.
842    ///
843    /// The cluster dimensions in `config` are ignored.
844    /// If the kernel has a required cluster size set, the returned value reflects the required cluster size.
845    ///
846    /// By default this returns a value that is portable on future hardware.
847    /// A higher value may be returned if the kernel function allows non-portable cluster sizes.
848    ///
849    /// Respects the compile-time launch bounds.
850    ///
    /// This query also works for context-less kernels obtained via [`Library::kernel`](crate::library::Library::kernel).
    /// For those kernels, this wrapper uses the current context for calculations.
853    ///
854    /// # Errors
855    ///
856    /// Returns an error if the module context cannot be bound, CUDA rejects the
857    /// occupancy query, or a previous asynchronous launch reports an error.
858    pub fn occupancy_max_potential_cluster_size(&self, config: ClusterLaunchConfig) -> Result<i32> {
859        self.module.ctx.bind()?;
860        let mut cluster_size = 0;
861        let config = driver::CUlaunchConfig {
862            gridDimX: config.grid_dim().x,
863            gridDimY: config.grid_dim().y,
864            gridDimZ: config.grid_dim().z,
865            blockDimX: config.block_dim().x,
866            blockDimY: config.block_dim().y,
867            blockDimZ: config.block_dim().z,
868            sharedMemBytes: config.shared_memory_bytes_u32(),
869            hStream: ptr::null_mut(),
870            attrs: ptr::null_mut(),
871            numAttrs: 0,
872        };
873        unsafe {
874            try_ffi!(driver::cuOccupancyMaxPotentialClusterSize(
875                &raw mut cluster_size,
876                self.handle.as_raw(),
877                &raw const config,
878            ))?;
879        }
880        Ok(cluster_size)
881    }
882
883    /// Given this kernel and launch configuration, returns the maximum number of clusters that could co-exist on the target device.
884    ///
885    /// If the kernel already has a required cluster size set, the cluster size from `config` must either be unspecified or match the required size.
886    /// Without required sizes, the cluster size must be specified in `config`; otherwise this method returns an error.
887    ///
888    /// Various kernel function attributes may affect occupancy calculation.
889    /// Runtime environment may affect how the hardware schedules the clusters, so the calculated occupancy is not guaranteed to be achievable.
890    ///
    /// This query also works for context-less kernels obtained via [`Library::kernel`](crate::library::Library::kernel).
    /// For those kernels, this wrapper uses the current context for calculations.
893    ///
894    /// # Errors
895    ///
896    /// Returns an error if the module context cannot be bound, `config` does
897    /// not specify a valid cluster size for this kernel, CUDA rejects the
898    /// occupancy query, or a previous asynchronous launch reports an error.
899    pub fn occupancy_max_active_clusters(&self, config: ClusterLaunchConfig) -> Result<i32> {
900        self.module.ctx.bind()?;
901        let mut clusters = 0;
902        let config = driver::CUlaunchConfig {
903            gridDimX: config.grid_dim().x,
904            gridDimY: config.grid_dim().y,
905            gridDimZ: config.grid_dim().z,
906            blockDimX: config.block_dim().x,
907            blockDimY: config.block_dim().y,
908            blockDimZ: config.block_dim().z,
909            sharedMemBytes: config.shared_memory_bytes_u32(),
910            hStream: ptr::null_mut(),
911            attrs: ptr::null_mut(),
912            numAttrs: 0,
913        };
914        unsafe {
915            try_ffi!(driver::cuOccupancyMaxActiveClusters(
916                &raw mut clusters,
917                self.handle.as_raw(),
918                &raw const config,
919            ))?;
920        }
921        Ok(clusters)
922    }
923
    /// Returns a copy of the [`DeviceFunction`] handle stored in this value.
    pub const fn as_raw(&self) -> DeviceFunction {
        self.handle
    }
927}
928
929unsafe impl<'a, P> GraphRecordable for KernelLaunchOperation<'_, '_, P>
930where
931    P: KernelLaunchArgs<'a>,
932{
933    type Output = ();
934
935    fn record(self, scope: &StreamCaptureScope<'_>) -> Result<Self::Output> {
936        self.function
937            .launch_on(self.config, self.params, scope.stream())
938    }
939}
940
941impl LaunchConfig {
942    pub fn new(grid_dim: Dim3, block_dim: Dim3, shared_memory_bytes: usize) -> Result<Self> {
943        validate_dim3(grid_dim, "grid_dim")?;
944        validate_dim3(block_dim, "block_dim")?;
945        validate_shared_memory_bytes(shared_memory_bytes)?;
946        Ok(Self::from_validated(
947            grid_dim,
948            block_dim,
949            shared_memory_bytes,
950        ))
951    }
952
953    const fn from_validated(grid_dim: Dim3, block_dim: Dim3, shared_memory_bytes: usize) -> Self {
954        Self {
955            grid_dim,
956            block_dim,
957            shared_memory_bytes,
958        }
959    }
960
961    pub const fn grid_dim(&self) -> Dim3 {
962        self.grid_dim
963    }
964
965    pub const fn block_dim(&self) -> Dim3 {
966        self.block_dim
967    }
968
969    pub const fn shared_memory_bytes(&self) -> usize {
970        self.shared_memory_bytes
971    }
972
973    pub(crate) const fn shared_memory_bytes_u32(&self) -> u32 {
974        self.shared_memory_bytes as u32
975    }
976
977    pub fn with_shared_memory_bytes(mut self, shared_memory_bytes: usize) -> Result<Self> {
978        validate_shared_memory_bytes(shared_memory_bytes)?;
979        self.shared_memory_bytes = shared_memory_bytes;
980        Ok(self)
981    }
982
983    pub fn try_for_1d_grid(element_count: usize, block_size: usize) -> Result<Self> {
984        validate_block_dimension(block_size, "block_size")?;
985        let grid_size = element_count.div_ceil(block_size);
986
987        validate_grid_dimension(grid_size, "grid_size")?;
988
989        Ok(Self::from_validated(
990            Dim3::new(to_u32(grid_size, "grid_size")?, 1, 1),
991            Dim3::new(to_u32(block_size, "block_size")?, 1, 1),
992            0,
993        ))
994    }
995
996    pub fn for_1d_grid(element_count: usize, block_size: usize) -> Self {
997        Self::try_for_1d_grid(element_count, block_size)
998            .expect("invalid 1d cuda launch configuration")
999    }
1000
1001    pub fn try_for_num_elems(element_count: usize, block_size: usize) -> Result<Self> {
1002        Self::try_for_1d_grid(element_count, block_size)
1003    }
1004
1005    pub fn for_num_elems(element_count: usize, block_size: usize) -> Self {
1006        Self::try_for_num_elems(element_count, block_size)
1007            .expect("invalid cuda launch configuration")
1008    }
1009
1010    pub fn try_for_2d_grid(
1011        width: usize,
1012        height: usize,
1013        block_width: usize,
1014        block_height: usize,
1015    ) -> Result<Self> {
1016        validate_block_dimension(block_width, "block_width")?;
1017        validate_block_dimension(block_height, "block_height")?;
1018        let grid_x = width.div_ceil(block_width);
1019        let grid_y = height.div_ceil(block_height);
1020        validate_grid_dimension(grid_x, "grid_x")?;
1021        validate_grid_dimension(grid_y, "grid_y")?;
1022
1023        Ok(Self::from_validated(
1024            Dim3::new(to_u32(grid_x, "grid_x")?, to_u32(grid_y, "grid_y")?, 1),
1025            Dim3::new(
1026                to_u32(block_width, "block_width")?,
1027                to_u32(block_height, "block_height")?,
1028                1,
1029            ),
1030            0,
1031        ))
1032    }
1033
1034    pub fn for_2d_grid(
1035        width: usize,
1036        height: usize,
1037        block_width: usize,
1038        block_height: usize,
1039    ) -> Self {
1040        Self::try_for_2d_grid(width, height, block_width, block_height)
1041            .expect("invalid 2d cuda launch configuration")
1042    }
1043
1044    pub fn try_for_3d_grid(
1045        width: usize,
1046        height: usize,
1047        depth: usize,
1048        block_width: usize,
1049        block_height: usize,
1050        block_depth: usize,
1051    ) -> Result<Self> {
1052        validate_block_dimension(block_width, "block_width")?;
1053        validate_block_dimension(block_height, "block_height")?;
1054        validate_block_dimension(block_depth, "block_depth")?;
1055        let grid_x = width.div_ceil(block_width);
1056        let grid_y = height.div_ceil(block_height);
1057        let grid_z = depth.div_ceil(block_depth);
1058        validate_grid_dimension(grid_x, "grid_x")?;
1059        validate_grid_dimension(grid_y, "grid_y")?;
1060        validate_grid_dimension(grid_z, "grid_z")?;
1061
1062        Ok(Self::from_validated(
1063            Dim3::new(
1064                to_u32(grid_x, "grid_x")?,
1065                to_u32(grid_y, "grid_y")?,
1066                to_u32(grid_z, "grid_z")?,
1067            ),
1068            Dim3::new(
1069                to_u32(block_width, "block_width")?,
1070                to_u32(block_height, "block_height")?,
1071                to_u32(block_depth, "block_depth")?,
1072            ),
1073            0,
1074        ))
1075    }
1076
1077    pub fn for_3d_grid(
1078        width: usize,
1079        height: usize,
1080        depth: usize,
1081        block_width: usize,
1082        block_height: usize,
1083        block_depth: usize,
1084    ) -> Self {
1085        Self::try_for_3d_grid(width, height, depth, block_width, block_height, block_depth)
1086            .expect("invalid 3d cuda launch configuration")
1087    }
1088}
1089
1090impl ClusterLaunchConfig {
1091    pub fn new(grid_dim: Dim3, block_dim: Dim3, shared_memory_bytes: usize) -> Result<Self> {
1092        validate_dim3(grid_dim, "grid_dim")?;
1093        validate_dim3(block_dim, "block_dim")?;
1094        validate_shared_memory_bytes(shared_memory_bytes)?;
1095        Ok(Self {
1096            grid_dim,
1097            block_dim,
1098            shared_memory_bytes,
1099        })
1100    }
1101
1102    pub const fn grid_dim(&self) -> Dim3 {
1103        self.grid_dim
1104    }
1105
1106    pub const fn block_dim(&self) -> Dim3 {
1107        self.block_dim
1108    }
1109
1110    pub const fn shared_memory_bytes(&self) -> usize {
1111        self.shared_memory_bytes
1112    }
1113
1114    pub(crate) const fn shared_memory_bytes_u32(&self) -> u32 {
1115        self.shared_memory_bytes as u32
1116    }
1117
1118    pub fn with_shared_memory_bytes(mut self, shared_memory_bytes: usize) -> Result<Self> {
1119        validate_shared_memory_bytes(shared_memory_bytes)?;
1120        self.shared_memory_bytes = shared_memory_bytes;
1121        Ok(self)
1122    }
1123}
1124
1125fn validate_dim3(value: Dim3, name: &str) -> Result<()> {
1126    validate_grid_dimension(value.x as usize, &format!("{name}.x"))?;
1127    validate_grid_dimension(value.y as usize, &format!("{name}.y"))?;
1128    validate_grid_dimension(value.z as usize, &format!("{name}.z"))?;
1129    Ok(())
1130}
1131
1132fn validate_grid_dimension(value: usize, name: &str) -> Result<()> {
1133    if value == 0 {
1134        return Err(Error::ZeroValue {
1135            name: name.to_owned(),
1136        });
1137    }
1138    Ok(())
1139}
1140
1141fn validate_block_dimension(value: usize, name: &str) -> Result<()> {
1142    if value == 0 {
1143        return Err(Error::ZeroValue {
1144            name: name.to_owned(),
1145        });
1146    }
1147    Ok(())
1148}
1149
1150fn validate_shared_memory_bytes(value: usize) -> Result<u32> {
1151    to_u32(value, "shared_memory_bytes")
1152}
1153
1154fn validate_dynamic_shared_memory_bytes(value: usize) -> Result<u64> {
1155    to_u64(value, "dynamic_shared_memory_bytes")
1156}
1157
1158impl<'a> KernelParameters<'a> {
1159    pub const fn new() -> Self {
1160        Self {
1161            arguments: Vec::new(),
1162        }
1163    }
1164
1165    pub fn arg<T: 'a>(&mut self, value: &'a T) -> &mut Self {
1166        self.arguments.push(KernelParameter::Borrowed {
1167            ptr: ptr::from_ref(value).cast_mut().cast::<()>(),
1168            _marker: PhantomData,
1169        });
1170        self
1171    }
1172
1173    pub fn arg_mut<T: 'a>(&mut self, value: &'a mut T) -> &mut Self {
1174        self.arguments.push(KernelParameter::Borrowed {
1175            ptr: ptr::from_mut(value).cast::<()>(),
1176            _marker: PhantomData,
1177        });
1178        self
1179    }
1180
1181    /// Pushes a copied kernel argument whose storage is owned by this list.
1182    ///
1183    /// Small scalar and pointer values are stored inline. Larger values fall
1184    /// back to heap storage, while keeping the argument pointee stable until
1185    /// CUDA has copied it during launch or graph-node creation.
1186    pub fn owned_arg<T: Copy + 'static>(&mut self, value: T) -> &mut Self {
1187        let value = OwnedKernelArgument::from_value(value);
1188        self.arguments.push(KernelParameter::Owned(value));
1189        self
1190    }
1191
1192    pub fn push<A: PushKernelArg>(&mut self, arg: A) -> &mut Self {
1193        arg.push_to(self);
1194        self
1195    }
1196
1197    pub fn device_slice<T: DeviceRepr, S: DeviceSlice<T> + ?Sized>(
1198        &mut self,
1199        slice: &S,
1200    ) -> &mut Self {
1201        // Kernels take the device address for slice-like wrappers; length is a
1202        // separate scalar argument when the kernel needs it.
1203        self.owned_arg(slice.as_device_ptr())
1204    }
1205
1206    pub fn device_slice_mut<T: DeviceRepr, S: DeviceSliceMut<T> + ?Sized>(
1207        &mut self,
1208        slice: &mut S,
1209    ) -> &mut Self {
1210        self.owned_arg(slice.as_device_mut_ptr())
1211    }
1212
1213    fn raw_pointers(&mut self) -> RawKernelPointers {
1214        RawKernelPointers::from_parameters(self.arguments.as_mut_slice())
1215    }
1216}
1217
1218impl<'a> KernelParameter<'a> {
1219    fn as_mut_ptr(&mut self) -> *mut () {
1220        match self {
1221            Self::Borrowed { ptr, .. } => *ptr,
1222            Self::Owned(value) => value.as_mut_ptr(),
1223        }
1224    }
1225}
1226
1227impl OwnedKernelArgument {
1228    fn from_value<T: Copy + 'static>(value: T) -> Self {
1229        if size_of::<T>() <= INLINE_KERNEL_ARGUMENT_BYTES
1230            && align_of::<T>() <= align_of::<InlineKernelArgument>()
1231        {
1232            Self::Inline(InlineKernelArgument::from_value(value))
1233        } else {
1234            Self::Boxed(Box::new(value))
1235        }
1236    }
1237
1238    fn as_mut_ptr(&mut self) -> *mut () {
1239        match self {
1240            Self::Inline(value) => value.as_mut_ptr(),
1241            Self::Boxed(value) => value.as_mut().as_mut_ptr(),
1242        }
1243    }
1244}
1245
1246impl InlineKernelArgument {
1247    fn from_value<T: Copy>(value: T) -> Self {
1248        let mut storage = Self {
1249            bytes: [MaybeUninit::uninit(); INLINE_KERNEL_ARGUMENT_BYTES],
1250        };
1251        unsafe {
1252            ptr::write(storage.as_mut_ptr().cast::<T>(), value);
1253        }
1254        storage
1255    }
1256
1257    fn as_mut_ptr(&mut self) -> *mut () {
1258        self.bytes.as_mut_ptr().cast()
1259    }
1260}
1261
/// Table of type-erased argument pointers built from a parameter list.
enum RawKernelPointers {
    /// Fixed-size table used when the argument count is at most
    /// `INLINE_KERNEL_ARGUMENTS`; only the first `len` entries are meaningful.
    Inline {
        pointers: [*mut (); INLINE_KERNEL_ARGUMENTS],
        len: usize,
    },
    /// Heap-allocated table used for longer argument lists.
    Heap(Vec<*mut ()>),
}
1269
1270impl RawKernelPointers {
1271    fn from_parameters(parameters: &mut [KernelParameter<'_>]) -> Self {
1272        if parameters.len() <= INLINE_KERNEL_ARGUMENTS {
1273            let mut pointers = [ptr::null_mut(); INLINE_KERNEL_ARGUMENTS];
1274            for (dst, parameter) in pointers.iter_mut().zip(&mut *parameters) {
1275                *dst = parameter.as_mut_ptr();
1276            }
1277            Self::Inline {
1278                pointers,
1279                len: parameters.len(),
1280            }
1281        } else {
1282            Self::Heap(
1283                parameters
1284                    .iter_mut()
1285                    .map(KernelParameter::as_mut_ptr)
1286                    .collect(),
1287            )
1288        }
1289    }
1290
1291    fn as_mut_slice(&mut self) -> &mut [*mut ()] {
1292        match self {
1293            Self::Inline { pointers, len } => &mut pointers[..*len],
1294            Self::Heap(pointers) => pointers.as_mut_slice(),
1295        }
1296    }
1297}
1298
1299impl EncodedKernelArgs<'_> {
1300    pub(crate) fn as_mut_ptr(&mut self) -> *mut *mut () {
1301        self.pointers.as_mut_ptr()
1302    }
1303}
1304
1305impl<'a> KernelLaunchArgs<'a> for KernelParameters<'a> {
1306    fn with_encoded_arguments<R>(mut self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
1307        let mut pointers = self.raw_pointers();
1308        f(EncodedKernelArgs {
1309            pointers: pointers.as_mut_slice(),
1310        })
1311    }
1312}
1313
// Marker impl of the private `Sealed` trait for owned parameter lists.
impl private::Sealed for KernelParameters<'_> {}
1315
1316impl<'a> KernelLaunchArgs<'a> for &mut KernelParameters<'a> {
1317    fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
1318        let mut pointers = self.raw_pointers();
1319        f(EncodedKernelArgs {
1320            pointers: pointers.as_mut_slice(),
1321        })
1322    }
1323}
1324
// Marker impl of the private `Sealed` trait for mutably borrowed parameter lists.
impl private::Sealed for &mut KernelParameters<'_> {}
1326
1327impl<'a> KernelLaunchArgs<'a> for () {
1328    fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
1329        let mut pointers: [*mut (); 0] = [];
1330        f(EncodedKernelArgs {
1331            pointers: &mut pointers,
1332        })
1333    }
1334}
1335
// Marker impl of the private `Sealed` trait for the empty argument list.
impl private::Sealed for () {}
1337
// Implements `private::Sealed` and `KernelLaunchArgs` for a tuple whose
// elements all implement `KernelTupleArgument`.
//
// Each identifier passed to the macro is used both as a type parameter and as
// the binding name for the matching tuple element.
macro_rules! impl_kernel_arguments_for_tuple {
    ($($arg:ident),+ $(,)?) => {
        impl<'a, $($arg),+> private::Sealed for ($($arg,)+)
        where
            $($arg: KernelTupleArgument<'a>,)+
        {
        }

        impl<'a, $($arg),+> KernelLaunchArgs<'a> for ($($arg,)+)
        where
            $($arg: KernelTupleArgument<'a>,)+
        {
            fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
                // The bindings reuse the upper-case type parameter names.
                #[allow(non_snake_case)]
                let ($($arg,)+) = self;
                // One erased pointer per tuple element, in tuple order.
                let mut pointers = [
                    $($arg.into_kernel_argument_ptr(),)+
                ];
                f(EncodedKernelArgs {
                    pointers: &mut pointers,
                })
            }
        }
    };
}
1363
1364impl<'a, T: 'a> KernelTupleArgument<'a> for &'a T {
1365    fn into_kernel_argument_ptr(self) -> *mut () {
1366        ptr::from_ref(self).cast_mut().cast()
1367    }
1368}
1369
1370impl<'a, T: 'a> KernelTupleArgument<'a> for &'a mut T {
1371    fn into_kernel_argument_ptr(self) -> *mut () {
1372        ptr::from_mut(self).cast()
1373    }
1374}
1375
// Tuple argument lists are supported for 1 through 16 elements.
impl_kernel_arguments_for_tuple!(A);
impl_kernel_arguments_for_tuple!(A, B);
impl_kernel_arguments_for_tuple!(A, B, C);
impl_kernel_arguments_for_tuple!(A, B, C, D);
impl_kernel_arguments_for_tuple!(A, B, C, D, E);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M, N);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P);
1392
// Implements `PushKernelArg` for each listed type by copying the value into
// the parameter list with `owned_arg`.
macro_rules! impl_push_scalar {
    ($($ty:ty),* $(,)?) => {
        $(
            impl PushKernelArg for $ty {
                fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
                    params.owned_arg(self);
                }
            }
        )*
    };
}
1404
// Integer and floating-point primitives that can be pushed by value.
impl_push_scalar!(
    u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64,
);
1408
/// Pushes the device address of the memory via `device_slice`.
impl<T: DeviceRepr> PushKernelArg for &DeviceMemory<T> {
    fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
        params.device_slice(self);
    }
}
1414
/// Pushes the mutable device address of the memory via `device_slice_mut`.
impl<T: DeviceRepr> PushKernelArg for &mut DeviceMemory<T> {
    fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
        params.device_slice_mut(self);
    }
}
1420
/// Pushes the device address of the managed memory via `device_slice`.
impl<T: DeviceRepr> PushKernelArg for &ManagedMemory<T> {
    fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
        params.device_slice(self);
    }
}
1426
/// Pushes the mutable device address of the managed memory via `device_slice_mut`.
impl<T: DeviceRepr> PushKernelArg for &mut ManagedMemory<T> {
    fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
        params.device_slice_mut(self);
    }
}
1432
/// Pushes the value returned by the view's `as_ptr` as an owned argument.
// NOTE(review): the by-reference impl uses `as_device_ptr` instead; confirm
// both accessors yield the same address.
impl<T: DeviceRepr> PushKernelArg for DeviceView<'_, T> {
    fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
        params.owned_arg(self.as_ptr());
    }
}
1438
/// Pushes the value returned by the view's `as_device_ptr` as an owned argument.
impl<T: DeviceRepr> PushKernelArg for &DeviceView<'_, T> {
    fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
        params.owned_arg(self.as_device_ptr());
    }
}
1444
/// Pushes the value returned by the view's `as_device_ptr` as an owned argument.
impl<T: DeviceRepr> PushKernelArg for &DeviceViewMut<'_, T> {
    fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
        params.owned_arg(self.as_device_ptr());
    }
}
1450
/// Pushes the value returned by the view's `as_device_mut_ptr` as an owned argument.
impl<T: DeviceRepr> PushKernelArg for &mut DeviceViewMut<'_, T> {
    fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
        params.owned_arg(self.as_device_mut_ptr());
    }
}
1456
/// The default parameter list is the empty list produced by `new`.
impl Default for KernelParameters<'_> {
    fn default() -> Self {
        Self::new()
    }
}
1462
#[cfg(test)]
mod tests {
    use super::*;

    // Three `u64` words (24 bytes); the first test asserts that a value of
    // this type takes the boxed storage path.
    #[derive(Clone, Copy)]
    #[repr(C)]
    struct LargeArgument {
        words: [u64; 3],
    }

    // The pointer reported for a boxed argument must be the one obtained by
    // going through the box to the inner value.
    #[test]
    fn boxed_owned_kernel_argument_points_to_inner_value() {
        let mut argument = OwnedKernelArgument::from_value(LargeArgument { words: [1, 2, 3] });
        assert!(matches!(argument, OwnedKernelArgument::Boxed(_)));

        let expected = match &mut argument {
            OwnedKernelArgument::Boxed(value) => value.as_mut().as_mut_ptr(),
            OwnedKernelArgument::Inline(_) => unreachable!(),
        };

        assert_eq!(argument.as_mut_ptr(), expected);
    }

    // Zero elements yield a zero-sized grid, and an explicit zero grid
    // component is rejected with its qualified name.
    #[test]
    fn launch_config_rejects_zero_grid_dimensions() {
        let error = LaunchConfig::try_for_1d_grid(0, 128).unwrap_err();
        assert!(matches!(error, Error::ZeroValue { name } if name == "grid_size"));

        let error = LaunchConfig::new(Dim3::new(0, 1, 1), Dim3::new(128, 1, 1), 0).unwrap_err();
        assert!(matches!(error, Error::ZeroValue { name } if name == "grid_dim.x"));
    }

    // One byte past `u32::MAX` must be rejected as out of range.
    #[test]
    fn launch_config_rejects_invalid_shared_memory_size() {
        let error = LaunchConfig::try_for_1d_grid(1, 128)
            .unwrap()
            .with_shared_memory_bytes(u32::MAX as usize + 1)
            .unwrap_err();
        assert!(matches!(error, Error::OutOfRange { name } if name == "shared_memory_bytes"));
    }

    // The largest accepted size must survive the narrowing accessor intact.
    #[test]
    fn launch_config_exposes_checked_shared_memory_u32() {
        let config = LaunchConfig::try_for_1d_grid(1, 128)
            .unwrap()
            .with_shared_memory_bytes(u32::MAX as usize)
            .unwrap();

        assert_eq!(config.shared_memory_bytes(), u32::MAX as usize);
        assert_eq!(config.shared_memory_bytes_u32(), u32::MAX);
    }

    // Both ends of the `usize` range must convert to the same `u64` value.
    #[test]
    fn occupancy_dynamic_shared_memory_uses_checked_driver_width() {
        assert_eq!(validate_dynamic_shared_memory_bytes(0).unwrap(), 0);
        assert_eq!(
            validate_dynamic_shared_memory_bytes(usize::MAX).unwrap(),
            usize::MAX as u64
        );
    }

    // A valid cluster configuration must report back the values it was built from.
    #[test]
    fn cluster_launch_config_uses_checked_construction() {
        let config = ClusterLaunchConfig::new(Dim3::new(1, 1, 1), Dim3::new(32, 1, 1), 0).unwrap();

        assert_eq!(config.grid_dim(), Dim3::new(1, 1, 1));
        assert_eq!(config.block_dim(), Dim3::new(32, 1, 1));
        assert_eq!(config.shared_memory_bytes(), 0);
    }
}