1use std::{
2 borrow::Cow,
3 ffi::CString,
4 fmt::{self, Display, Formatter},
5 marker::PhantomData,
6 mem::{ManuallyDrop, MaybeUninit, align_of, size_of},
7 ptr,
8 sync::Arc,
9};
10
11use singe_cuda_sys::driver;
12
13use crate::{
14 context::Context,
15 dim::Dim3,
16 error::{Error, Result},
17 graph::{ExecutableGraph, Graph, GraphNode},
18 kernel::{self, ModuleKernelHandle},
19 memory::{DeviceMemory, ManagedMemory},
20 stream::{GraphRecordable, Stream, StreamCaptureScope},
21 try_ffi,
22 types::{DeviceFunction, FunctionAttribute, SharedMemoryCarveout},
23 utility::{to_u32, to_u64},
24 view::{DeviceRepr, DeviceSlice, DeviceSliceMut, DeviceView, DeviceViewMut},
25};
26
bitflags::bitflags! {
    /// Flag set accepted by the `*_with_flags` occupancy queries on [`KernelFunction`].
    ///
    /// Every constant takes its bit value from the matching `driver::CUoccupancy_flags`
    /// variant.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct OccupancyFlags: u32 {
        /// Mirrors `CU_OCCUPANCY_DEFAULT`.
        const DEFAULT = driver::CUoccupancy_flags::CU_OCCUPANCY_DEFAULT as _;
        /// Mirrors `CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE`.
        const DISABLE_CACHING_OVERRIDE = driver::CUoccupancy_flags::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE as _;
    }
}
34
35impl Display for OccupancyFlags {
36 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
37 if self.is_empty() {
38 return Ok(());
39 }
40 let mut first = true;
41 let write_sep = |f: &mut Formatter<'_>, first: &mut bool, name: &str| -> fmt::Result {
42 if *first {
43 *first = false;
44 } else {
45 f.write_str(" | ")?;
46 }
47 f.write_str(name)
48 };
49
50 if self.contains(Self::DEFAULT) {
51 write_sep(f, &mut first, "CU_OCCUPANCY_DEFAULT")?;
52 }
53 if self.contains(Self::DISABLE_CACHING_OVERRIDE) {
54 write_sep(f, &mut first, "CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE")?;
55 }
56
57 Ok(())
58 }
59}
60
/// Snapshot of a kernel's function attributes, as assembled by [`KernelFunction::attributes`].
///
/// Each field holds the value reported for one `FunctionAttribute`. Size fields are cast to
/// `usize`; boolean fields are `true` when the reported value is non-zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FunctionAttributes {
    /// Value of `FunctionAttribute::SharedSizeBytes`.
    pub shared_size_bytes: usize,
    /// Value of `FunctionAttribute::ConstSizeBytes`.
    pub const_size_bytes: usize,
    /// Value of `FunctionAttribute::LocalSizeBytes`.
    pub local_size_bytes: usize,
    /// Value of `FunctionAttribute::MaxThreadsPerBlock`.
    pub max_threads_per_block: i32,
    /// Value of `FunctionAttribute::NumRegs`.
    pub num_regs: i32,
    /// Value of `FunctionAttribute::PtxVersion`.
    pub ptx_version: i32,
    /// Value of `FunctionAttribute::BinaryVersion`.
    pub binary_version: i32,
    /// Whether `FunctionAttribute::CacheModeCa` was reported as non-zero.
    pub cache_mode_ca: bool,
    /// Value of `FunctionAttribute::MaxDynamicSharedSizeBytes`.
    pub max_dynamic_shared_size_bytes: i32,
    /// Value of `FunctionAttribute::PreferredSharedMemoryCarveout`.
    pub preferred_shared_memory_carveout: i32,
    /// Whether `FunctionAttribute::ClusterSizeMustBeSet` was reported as non-zero.
    pub cluster_dim_must_be_set: bool,
    /// Value of `FunctionAttribute::RequiredClusterWidth`.
    pub required_cluster_width: i32,
    /// Value of `FunctionAttribute::RequiredClusterHeight`.
    pub required_cluster_height: i32,
    /// Value of `FunctionAttribute::RequiredClusterDepth`.
    pub required_cluster_depth: i32,
    /// Value of `FunctionAttribute::ClusterSchedulingPolicyPreference`.
    pub cluster_scheduling_policy_preference: i32,
    /// Whether `FunctionAttribute::NonPortableClusterSizeAllowed` was reported as non-zero.
    pub non_portable_cluster_size_allowed: bool,
}
80
/// Result of [`KernelFunction::occupancy_max_potential_block_size_with_flags`].
///
/// Both fields are the raw outputs of `cuOccupancyMaxPotentialBlockSizeWithFlags`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OccupancyMaxPotentialBlockSize {
    /// First output of the driver query (its `min_grid_size` out-parameter).
    pub min_grid_size: i32,
    /// Second output of the driver query (its `block_size` out-parameter).
    pub block_size: i32,
}
86
/// Launch shape handed to the cluster occupancy queries on [`KernelFunction`].
///
/// Fields are private; [`ClusterLaunchConfig::new`] validates them before construction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClusterLaunchConfig {
    // Copied into `CUlaunchConfig::gridDim{X,Y,Z}` by the occupancy queries.
    grid_dim: Dim3,
    // Copied into `CUlaunchConfig::blockDim{X,Y,Z}` by the occupancy queries.
    block_dim: Dim3,
    // Checked to fit in `u32` on construction; copied into `CUlaunchConfig::sharedMemBytes`.
    shared_memory_bytes: usize,
}
93
/// A driver module handle paired with the context it is used with.
///
/// Dropping an owning `Module` binds `ctx` and calls `cuModuleUnload` on the handle.
#[derive(Debug)]
pub struct Module {
    // Raw driver handle used for every module-level driver call.
    handle: driver::CUmodule,
    // Context bound before driver calls that need it (global lookup, unload).
    ctx: Arc<Context>,
    // When `false`, `Drop` leaves the handle alone (see `from_borrowed_raw`).
    owns_handle: bool,
}
100
/// Address and size of a module-level global, as returned by [`Module::global`].
///
/// The borrow of the module keeps this value from outliving the `Module` it was looked up in.
#[derive(Debug, Clone, Copy)]
pub struct Global<'a> {
    // Address reported by `cuModuleGetGlobal_v2`, stored as an untyped pointer.
    ptr: *mut (),
    // Size in bytes reported by `cuModuleGetGlobal_v2`.
    size: usize,
    // Held only for its lifetime; never read.
    _module: &'a Module,
}
107
/// A raw `CUtexref` handle tied to the lifetime of a [`Module`] borrow.
///
/// NOTE(review): no constructor is visible in this part of the file.
#[derive(Debug, Clone, Copy)]
pub struct TextureReference<'a> {
    // Raw driver handle returned by `as_raw`.
    handle: driver::CUtexref,
    // Held only for its lifetime; never read.
    _module: &'a Module,
}
113
/// A raw `CUsurfref` handle tied to the lifetime of a [`Module`] borrow.
///
/// NOTE(review): no constructor is visible in this part of the file.
#[derive(Debug, Clone, Copy)]
pub struct SurfaceReference<'a> {
    // Raw driver handle returned by `as_raw`.
    handle: driver::CUsurfref,
    // Held only for its lifetime; never read.
    _module: &'a Module,
}
119
/// Bytes of a module image, either borrowed from the caller or owned.
#[derive(Debug, Clone)]
pub struct ModuleImage<'a> {
    // Borrowed when built with `new`, owned when built with `from_vec` / `from_string`.
    data: Cow<'a, [u8]>,
}
124
/// A device function handle together with the [`Module`] it was looked up in.
#[derive(Debug)]
pub struct KernelFunction<'a> {
    // Function handle passed to the driver for launches and attribute/occupancy queries.
    handle: DeviceFunction,
    // Source module; its context is bound before driver calls.
    module: &'a Module,
}
130
/// A kernel launch that has been described but not yet issued.
///
/// Built by [`KernelFunction::launch_operation`]; its `GraphRecordable` implementation
/// issues the launch on the capture scope's stream.
#[derive(Debug)]
pub struct KernelLaunchOperation<'kernel, 'config, P> {
    // Kernel to launch.
    function: &'kernel KernelFunction<'kernel>,
    // Grid/block/shared-memory shape of the launch.
    config: &'config LaunchConfig,
    // Launch arguments, consumed when the operation is recorded.
    params: P,
}
137
/// Grid dimensions, block dimensions and dynamic shared memory size for a kernel launch.
///
/// Fields are private; the public constructors validate them before building a value.
#[derive(Debug, Clone)]
pub struct LaunchConfig {
    // Passed as the grid dimensions of `cuLaunchKernel`.
    grid_dim: Dim3,
    // Passed as the block dimensions of `cuLaunchKernel`.
    block_dim: Dim3,
    // Checked to fit in `u32` by the constructors; narrowed by `shared_memory_bytes_u32`.
    shared_memory_bytes: usize,
}
144
/// Ordered list of kernel launch arguments.
///
/// Arguments are either borrowed for `'a` (`arg`, `arg_mut`) or copied into storage owned
/// by this list (`owned_arg`).
pub struct KernelParameters<'a> {
    // One entry per kernel argument, in push order.
    arguments: Vec<KernelParameter<'a>>,
}
169
// Largest argument count whose pointer table is kept in a fixed array instead of a `Vec`
// (see `RawKernelPointers`).
const INLINE_KERNEL_ARGUMENTS: usize = 16;
// Largest owned argument size, in bytes, stored inline instead of boxed
// (see `OwnedKernelArgument::from_value`).
const INLINE_KERNEL_ARGUMENT_BYTES: usize = 16;
172
// Private module holding the supertrait of `KernelLaunchArgs`, so that only types in this
// file can implement that trait.
mod private {
    pub trait Sealed {}
}
176
/// A value that knows how to append itself to a [`KernelParameters`] list.
///
/// Used by [`KernelParameters::push`].
pub trait PushKernelArg {
    /// Consumes `self` and appends the corresponding argument to `params`.
    fn push_to<'a>(self, params: &mut KernelParameters<'a>);
}
184
/// Argument bundle accepted by the kernel launch entry points.
///
/// The trait is sealed through `private::Sealed`; implementations in this file cover
/// [`KernelParameters`], `&mut KernelParameters`, `()` and tuples of references.
pub trait KernelLaunchArgs<'a>: private::Sealed {
    /// Builds the table of argument pointers and passes it to `f`, returning `f`'s result.
    #[doc(hidden)]
    fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R;
}
193
/// Borrowed table of raw argument pointers, one entry per kernel argument.
#[doc(hidden)]
pub struct EncodedKernelArgs<'a> {
    // The pointer table; its base address is handed to `cuLaunchKernel`.
    pointers: &'a mut [*mut ()],
}
204
// Element of a tuple passed as launch arguments; implemented below for `&'a T` and
// `&'a mut T`.
trait KernelTupleArgument<'a> {
    // Converts the reference into the untyped pointer stored in the argument table.
    fn into_kernel_argument_ptr(self) -> *mut ();
}
208
// One entry of a `KernelParameters` list.
enum KernelParameter<'a> {
    // Pointer to a caller-owned value; `_marker` ties the entry to the borrow `'a`.
    Borrowed {
        ptr: *mut (),
        _marker: PhantomData<&'a ()>,
    },
    // Value copied into storage owned by the list.
    Owned(OwnedKernelArgument),
}
216
217impl fmt::Debug for KernelParameter<'_> {
218 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
219 match self {
220 Self::Borrowed { ptr, .. } => f.debug_tuple("Borrowed").field(ptr).finish(),
221 Self::Owned(value) => f.debug_tuple("Owned").field(value).finish(),
222 }
223 }
224}
225
// Storage for an argument value owned by the parameter list.
enum OwnedKernelArgument {
    // Value small enough (size and alignment) to live in the fixed inline buffer.
    Inline(InlineKernelArgument),
    // Any other value, heap-allocated behind a trait object.
    Boxed(Box<dyn KernelArgumentStorage>),
}
230
231impl fmt::Debug for OwnedKernelArgument {
232 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
233 match self {
234 Self::Inline(value) => f.debug_tuple("Inline").field(value).finish(),
235 Self::Boxed(_) => f.debug_tuple("Boxed").finish_non_exhaustive(),
236 }
237 }
238}
239
// Type-erased access to the address of a stored argument value.
trait KernelArgumentStorage {
    // Returns the address of the stored value as an untyped pointer.
    fn as_mut_ptr(&mut self) -> *mut ();
}
243
244impl<T> KernelArgumentStorage for T {
245 fn as_mut_ptr(&mut self) -> *mut () {
246 ptr::from_mut(self).cast()
247 }
248}
249
// Fixed-size, 16-byte-aligned buffer holding a small owned argument by value.
#[derive(Clone, Copy)]
#[repr(C, align(16))]
struct InlineKernelArgument {
    // Raw bytes of the stored value; bytes past the value's size stay uninitialized.
    bytes: [MaybeUninit<u8>; INLINE_KERNEL_ARGUMENT_BYTES],
}
255
256impl fmt::Debug for InlineKernelArgument {
257 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
258 f.debug_struct("InlineKernelArgument")
259 .finish_non_exhaustive()
260 }
261}
262
263impl fmt::Debug for KernelParameters<'_> {
264 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
265 f.debug_struct("KernelParameters")
266 .field("arguments", &self.arguments.len())
267 .finish()
268 }
269}
270
271impl Module {
272 pub unsafe fn from_raw(handle: driver::CUmodule, ctx: Arc<Context>) -> Result<Self> {
273 if handle.is_null() {
274 return Err(Error::NullHandle);
275 }
276
277 Ok(Self {
278 handle,
279 ctx,
280 owns_handle: true,
281 })
282 }
283
284 pub const unsafe fn from_borrowed_raw(handle: driver::CUmodule, ctx: Arc<Context>) -> Self {
285 Self {
286 handle,
287 ctx,
288 owns_handle: false,
289 }
290 }
291
292 pub fn function(&self, name: &str) -> Result<KernelFunction<'_>> {
301 unsafe {
302 let c_name = CString::new(name)?;
303 let mut function_handle = ptr::null_mut();
304 try_ffi!(driver::cuModuleGetFunction(
305 &raw mut function_handle,
306 self.handle,
307 c_name.as_ptr(),
308 ))?;
309 if function_handle.is_null() {
310 return Err(Error::NullHandle);
311 }
312 let function = DeviceFunction::from_raw(function_handle);
313 Ok(KernelFunction::from_raw(function, self))
314 }
315 }
316
317 pub fn function_count(&self) -> Result<usize> {
323 unsafe {
324 let mut count = 0;
325 try_ffi!(driver::cuModuleGetFunctionCount(
326 &raw mut count,
327 self.handle
328 ))?;
329 Ok(count as usize)
330 }
331 }
332
    /// Returns the raw module handle; ownership stays with `self`.
    pub const fn as_raw(&self) -> driver::CUmodule {
        self.handle
    }
336
337 pub fn into_raw(self) -> driver::CUmodule {
343 let module = ManuallyDrop::new(self);
344 module.handle
345 }
346
347 pub fn global(&self, name: &str) -> Result<Global<'_>> {
359 let c_name = CString::new(name)?;
360 let mut ptr = 0;
361 let mut size = 0;
362 self.ctx.bind()?;
363 unsafe {
364 try_ffi!(driver::cuModuleGetGlobal_v2(
365 &raw mut ptr,
366 &raw mut size,
367 self.handle,
368 c_name.as_ptr(),
369 ))?;
370 }
371 Ok(Global {
372 ptr: ptr as _,
373 size: size as _,
374 _module: self,
375 })
376 }
377}
378
379impl Drop for Module {
380 fn drop(&mut self) {
381 if !self.owns_handle {
382 return;
383 }
384
385 if let Err(err) = self.ctx.bind() {
386 #[cfg(debug_assertions)]
387 eprintln!("failed to bind context before unloading module: {err}");
388 return;
389 }
390
391 unsafe {
392 if let Err(err) = try_ffi!(driver::cuModuleUnload(self.handle)) {
393 #[cfg(debug_assertions)]
394 eprintln!("failed to unload cuda module: {err}");
395 }
396 }
397 }
398}
399
// SAFETY: NOTE(review): these impls assert that a `Module` (raw `CUmodule` handle plus
// `Arc<Context>`) may be moved to and shared between threads. Nothing visible in this file
// establishes that — confirm against the driver's threading guarantees.
unsafe impl Send for Module {}
unsafe impl Sync for Module {}
404
405impl<'a> ModuleImage<'a> {
406 pub const fn new(data: &'a [u8]) -> Self {
407 Self {
408 data: Cow::Borrowed(data),
409 }
410 }
411
412 pub fn from_vec(data: Vec<u8>) -> Self {
413 Self {
414 data: Cow::Owned(data),
415 }
416 }
417
418 pub fn from_string(data: String) -> Self {
419 Self::from_vec(data.into_bytes())
420 }
421
422 pub fn as_ptr(&self) -> *const () {
423 self.data.as_ptr().cast()
424 }
425
426 pub fn as_bytes(&self) -> &[u8] {
427 self.data.as_ref()
428 }
429}
430
impl Global<'_> {
    /// Returns the address of the global as an untyped pointer.
    pub const fn as_ptr(&self) -> *mut () {
        self.ptr
    }

    /// Returns the size of the global in bytes.
    pub const fn byte_len(&self) -> usize {
        self.size
    }
}
440
impl TextureReference<'_> {
    /// Returns the raw texture reference handle.
    pub const fn as_raw(&self) -> driver::CUtexref {
        self.handle
    }
}
446
impl SurfaceReference<'_> {
    /// Returns the raw surface reference handle.
    pub const fn as_raw(&self) -> driver::CUsurfref {
        self.handle
    }
}
452
453impl KernelFunction<'_> {
    /// Pairs a function handle with the module it is assumed to come from.
    ///
    /// # Safety
    ///
    /// No check is performed here. NOTE(review): presumably `handle` must be a function of
    /// `module`, since later calls bind `module`'s context before using it — confirm.
    pub const unsafe fn from_raw(handle: DeviceFunction, module: &Module) -> KernelFunction<'_> {
        KernelFunction { handle, module }
    }
457
    /// Bundles this kernel, a launch configuration and arguments into a
    /// [`KernelLaunchOperation`] without launching anything.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not visible in this file; presumably `params` must be
    /// valid for this kernel when the operation is eventually recorded — confirm.
    pub const unsafe fn launch_operation<'kernel, 'config, P>(
        &'kernel self,
        config: &'config LaunchConfig,
        params: P,
    ) -> KernelLaunchOperation<'kernel, 'config, P> {
        KernelLaunchOperation {
            function: self,
            config,
            params,
        }
    }
476
477 fn check_graph_context(&self, graph: &Graph) -> Result<()> {
478 if matches!(graph.context(), Some(ctx) if ctx != self.module.ctx.as_ref()) {
479 return Err(Error::GraphContextMismatch);
480 }
481 Ok(())
482 }
483
484 fn check_executable_graph_context(&self, executable: &ExecutableGraph) -> Result<()> {
485 if matches!(executable.context(), Some(ctx) if ctx != self.module.ctx.as_ref()) {
486 return Err(Error::GraphContextMismatch);
487 }
488 Ok(())
489 }
490
491 pub fn launch<'a, P>(&self, config: &LaunchConfig, params: P) -> Result<()>
510 where
511 P: KernelLaunchArgs<'a>,
512 {
513 self.module.ctx.bind()?;
514 params.with_encoded_arguments(|mut arguments| unsafe {
515 try_ffi!(driver::cuLaunchKernel(
516 self.handle.as_raw(),
517 config.grid_dim().x,
518 config.grid_dim().y,
519 config.grid_dim().z,
520 config.block_dim().x,
521 config.block_dim().y,
522 config.block_dim().z,
523 config.shared_memory_bytes_u32(),
524 ptr::null_mut(),
525 arguments.as_mut_ptr().cast(),
526 ptr::null_mut(),
527 ))?;
528 Ok(())
529 })
530 }
531
532 pub fn launch_on<'a, P>(&self, config: &LaunchConfig, params: P, stream: &Stream) -> Result<()>
551 where
552 P: KernelLaunchArgs<'a>,
553 {
554 if stream.context() != self.module.ctx.as_ref() {
555 return Err(driver::CUresult::CUDA_ERROR_INVALID_CONTEXT.into());
556 }
557
558 self.module.ctx.bind()?;
559 params.with_encoded_arguments(|mut arguments| unsafe {
560 try_ffi!(driver::cuLaunchKernel(
561 self.handle.as_raw(),
562 config.grid_dim().x,
563 config.grid_dim().y,
564 config.grid_dim().z,
565 config.block_dim().x,
566 config.block_dim().y,
567 config.block_dim().z,
568 config.shared_memory_bytes_u32(),
569 stream.as_raw(),
570 arguments.as_mut_ptr().cast(),
571 ptr::null_mut(),
572 ))?;
573 Ok(())
574 })
575 }
576
    /// Adds a kernel node for this function to `graph`, after `dependencies`.
    ///
    /// # Errors
    ///
    /// Returns `Error::GraphContextMismatch` when `graph` reports a context different from
    /// this kernel's module context; otherwise forwards the result of
    /// `Graph::add_kernel_node`.
    ///
    /// # Safety
    ///
    /// Forwards to the unsafe `Graph::add_kernel_node`. NOTE(review): its contract is not
    /// visible here; the caller must uphold whatever that method requires — confirm.
    pub unsafe fn add_to_graph<'a, P>(
        &self,
        graph: &mut Graph,
        dependencies: &[GraphNode],
        config: &LaunchConfig,
        params: P,
    ) -> Result<GraphNode>
    where
        P: KernelLaunchArgs<'a>,
    {
        self.check_graph_context(graph)?;
        unsafe { graph.add_kernel_node(dependencies, self.handle, config, params) }
    }
601
    /// Updates the parameters of kernel node `node` in `executable` to launch this function
    /// with `config` and `params`.
    ///
    /// # Errors
    ///
    /// Returns `Error::GraphContextMismatch` when `executable` reports a context different
    /// from this kernel's module context; otherwise forwards the result of
    /// `ExecutableGraph::set_kernel_node_params`.
    ///
    /// # Safety
    ///
    /// Forwards to the unsafe `ExecutableGraph::set_kernel_node_params`. NOTE(review): its
    /// contract is not visible here; the caller must uphold whatever it requires — confirm.
    pub unsafe fn set_graph_node_params<'a, P>(
        &self,
        executable: &mut ExecutableGraph,
        node: GraphNode,
        config: &LaunchConfig,
        params: P,
    ) -> Result<()>
    where
        P: KernelLaunchArgs<'a>,
    {
        self.check_executable_graph_context(executable)?;
        unsafe { executable.set_kernel_node_params(node, self.handle, config, params) }
    }
626
    /// Returns the module this function was looked up in.
    pub const fn module(&self) -> &Module {
        self.module
    }
630
631 pub fn name(&self) -> Result<String> {
632 kernel::name::<ModuleKernelHandle>(self.module.ctx.as_ref(), self.handle.as_raw())
633 }
634
635 pub fn attribute(&self, attribute: FunctionAttribute) -> Result<i32> {
636 kernel::attribute::<ModuleKernelHandle>(
637 self.module.ctx.as_ref(),
638 self.handle.as_raw(),
639 attribute,
640 )
641 }
642
643 pub fn set_attribute(&self, attribute: FunctionAttribute, value: i32) -> Result<()> {
644 kernel::set_attribute::<ModuleKernelHandle>(
645 self.module.ctx.as_ref(),
646 self.handle.as_raw(),
647 attribute,
648 value,
649 )
650 }
651
    /// Sets `FunctionAttribute::MaxDynamicSharedSizeBytes` to `bytes`.
    pub fn set_max_dynamic_shared_memory_bytes(&self, bytes: i32) -> Result<()> {
        self.set_attribute(FunctionAttribute::MaxDynamicSharedSizeBytes, bytes)
    }
655
656 pub fn set_preferred_shared_memory_carveout(
657 &self,
658 carveout: SharedMemoryCarveout,
659 ) -> Result<()> {
660 self.set_attribute(
661 FunctionAttribute::PreferredSharedMemoryCarveout,
662 i32::from(carveout),
663 )
664 }
665
666 pub fn attributes(&self) -> Result<FunctionAttributes> {
667 Ok(FunctionAttributes {
668 shared_size_bytes: self.attribute(FunctionAttribute::SharedSizeBytes)? as usize,
669 const_size_bytes: self.attribute(FunctionAttribute::ConstSizeBytes)? as usize,
670 local_size_bytes: self.attribute(FunctionAttribute::LocalSizeBytes)? as usize,
671 max_threads_per_block: self.attribute(FunctionAttribute::MaxThreadsPerBlock)?,
672 num_regs: self.attribute(FunctionAttribute::NumRegs)?,
673 ptx_version: self.attribute(FunctionAttribute::PtxVersion)?,
674 binary_version: self.attribute(FunctionAttribute::BinaryVersion)?,
675 cache_mode_ca: self.attribute(FunctionAttribute::CacheModeCa)? != 0,
676 max_dynamic_shared_size_bytes: self
677 .attribute(FunctionAttribute::MaxDynamicSharedSizeBytes)?,
678 preferred_shared_memory_carveout: self
679 .attribute(FunctionAttribute::PreferredSharedMemoryCarveout)?,
680 cluster_dim_must_be_set: self.attribute(FunctionAttribute::ClusterSizeMustBeSet)? != 0,
681 required_cluster_width: self.attribute(FunctionAttribute::RequiredClusterWidth)?,
682 required_cluster_height: self.attribute(FunctionAttribute::RequiredClusterHeight)?,
683 required_cluster_depth: self.attribute(FunctionAttribute::RequiredClusterDepth)?,
684 cluster_scheduling_policy_preference: self
685 .attribute(FunctionAttribute::ClusterSchedulingPolicyPreference)?,
686 non_portable_cluster_size_allowed: self
687 .attribute(FunctionAttribute::NonPortableClusterSizeAllowed)?
688 != 0,
689 })
690 }
691
    /// Same as
    /// [`occupancy_max_active_blocks_per_multiprocessor_with_flags`](Self::occupancy_max_active_blocks_per_multiprocessor_with_flags)
    /// with `OccupancyFlags::DEFAULT`.
    pub fn occupancy_max_active_blocks_per_multiprocessor(
        &self,
        block_size: i32,
        dynamic_shared_memory_bytes: usize,
    ) -> Result<i32> {
        self.occupancy_max_active_blocks_per_multiprocessor_with_flags(
            block_size,
            dynamic_shared_memory_bytes,
            OccupancyFlags::DEFAULT,
        )
    }
703
704 pub fn occupancy_max_active_blocks_per_multiprocessor_with_flags(
727 &self,
728 block_size: i32,
729 dynamic_shared_memory_bytes: usize,
730 flags: OccupancyFlags,
731 ) -> Result<i32> {
732 self.module.ctx.bind()?;
733 let dynamic_shared_memory_bytes =
734 validate_dynamic_shared_memory_bytes(dynamic_shared_memory_bytes)?;
735 let mut blocks = 0;
736 unsafe {
737 try_ffi!(
738 driver::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
739 &raw mut blocks,
740 self.handle.as_raw(),
741 block_size,
742 dynamic_shared_memory_bytes,
743 flags.bits(),
744 )
745 )?;
746 }
747 Ok(blocks)
748 }
749
750 pub fn occupancy_available_dynamic_shared_memory_per_block(
762 &self,
763 num_blocks: i32,
764 block_size: i32,
765 ) -> Result<usize> {
766 self.module.ctx.bind()?;
767 let mut bytes = 0;
768 unsafe {
769 try_ffi!(driver::cuOccupancyAvailableDynamicSMemPerBlock(
770 &raw mut bytes,
771 self.handle.as_raw(),
772 num_blocks,
773 block_size,
774 ))?;
775 }
776 Ok(bytes as usize)
777 }
778
    /// Same as
    /// [`occupancy_max_potential_block_size_with_flags`](Self::occupancy_max_potential_block_size_with_flags)
    /// with `OccupancyFlags::DEFAULT`.
    pub fn occupancy_max_potential_block_size(
        &self,
        dynamic_shared_memory_bytes: usize,
        block_size_limit: i32,
    ) -> Result<OccupancyMaxPotentialBlockSize> {
        self.occupancy_max_potential_block_size_with_flags(
            dynamic_shared_memory_bytes,
            block_size_limit,
            OccupancyFlags::DEFAULT,
        )
    }
790
791 pub fn occupancy_max_potential_block_size_with_flags(
814 &self,
815 dynamic_shared_memory_bytes: usize,
816 block_size_limit: i32,
817 flags: OccupancyFlags,
818 ) -> Result<OccupancyMaxPotentialBlockSize> {
819 self.module.ctx.bind()?;
820 let dynamic_shared_memory_bytes =
821 validate_dynamic_shared_memory_bytes(dynamic_shared_memory_bytes)?;
822 let mut min_grid_size = 0;
823 let mut block_size = 0;
824 unsafe {
825 try_ffi!(driver::cuOccupancyMaxPotentialBlockSizeWithFlags(
826 &raw mut min_grid_size,
827 &raw mut block_size,
828 self.handle.as_raw(),
829 None,
830 dynamic_shared_memory_bytes,
831 block_size_limit,
832 flags.bits(),
833 ))?;
834 }
835 Ok(OccupancyMaxPotentialBlockSize {
836 min_grid_size,
837 block_size,
838 })
839 }
840
841 pub fn occupancy_max_potential_cluster_size(&self, config: ClusterLaunchConfig) -> Result<i32> {
859 self.module.ctx.bind()?;
860 let mut cluster_size = 0;
861 let config = driver::CUlaunchConfig {
862 gridDimX: config.grid_dim().x,
863 gridDimY: config.grid_dim().y,
864 gridDimZ: config.grid_dim().z,
865 blockDimX: config.block_dim().x,
866 blockDimY: config.block_dim().y,
867 blockDimZ: config.block_dim().z,
868 sharedMemBytes: config.shared_memory_bytes_u32(),
869 hStream: ptr::null_mut(),
870 attrs: ptr::null_mut(),
871 numAttrs: 0,
872 };
873 unsafe {
874 try_ffi!(driver::cuOccupancyMaxPotentialClusterSize(
875 &raw mut cluster_size,
876 self.handle.as_raw(),
877 &raw const config,
878 ))?;
879 }
880 Ok(cluster_size)
881 }
882
883 pub fn occupancy_max_active_clusters(&self, config: ClusterLaunchConfig) -> Result<i32> {
900 self.module.ctx.bind()?;
901 let mut clusters = 0;
902 let config = driver::CUlaunchConfig {
903 gridDimX: config.grid_dim().x,
904 gridDimY: config.grid_dim().y,
905 gridDimZ: config.grid_dim().z,
906 blockDimX: config.block_dim().x,
907 blockDimY: config.block_dim().y,
908 blockDimZ: config.block_dim().z,
909 sharedMemBytes: config.shared_memory_bytes_u32(),
910 hStream: ptr::null_mut(),
911 attrs: ptr::null_mut(),
912 numAttrs: 0,
913 };
914 unsafe {
915 try_ffi!(driver::cuOccupancyMaxActiveClusters(
916 &raw mut clusters,
917 self.handle.as_raw(),
918 &raw const config,
919 ))?;
920 }
921 Ok(clusters)
922 }
923
    /// Returns a copy of the wrapped `DeviceFunction` handle.
    pub const fn as_raw(&self) -> DeviceFunction {
        self.handle
    }
927}
928
929unsafe impl<'a, P> GraphRecordable for KernelLaunchOperation<'_, '_, P>
930where
931 P: KernelLaunchArgs<'a>,
932{
933 type Output = ();
934
935 fn record(self, scope: &StreamCaptureScope<'_>) -> Result<Self::Output> {
936 self.function
937 .launch_on(self.config, self.params, scope.stream())
938 }
939}
940
941impl LaunchConfig {
942 pub fn new(grid_dim: Dim3, block_dim: Dim3, shared_memory_bytes: usize) -> Result<Self> {
943 validate_dim3(grid_dim, "grid_dim")?;
944 validate_dim3(block_dim, "block_dim")?;
945 validate_shared_memory_bytes(shared_memory_bytes)?;
946 Ok(Self::from_validated(
947 grid_dim,
948 block_dim,
949 shared_memory_bytes,
950 ))
951 }
952
    // Assembles a configuration without any checks; callers validate the values first.
    const fn from_validated(grid_dim: Dim3, block_dim: Dim3, shared_memory_bytes: usize) -> Self {
        Self {
            grid_dim,
            block_dim,
            shared_memory_bytes,
        }
    }
960
    /// Returns the grid dimensions.
    pub const fn grid_dim(&self) -> Dim3 {
        self.grid_dim
    }
964
    /// Returns the block dimensions.
    pub const fn block_dim(&self) -> Dim3 {
        self.block_dim
    }
968
    /// Returns the dynamic shared memory size in bytes.
    pub const fn shared_memory_bytes(&self) -> usize {
        self.shared_memory_bytes
    }
972
    // Narrowing cast for the driver API. The public constructors and
    // `with_shared_memory_bytes` pass the value through `to_u32` before storing it, and the
    // grid helpers store 0, so the cast is not expected to truncate.
    pub(crate) const fn shared_memory_bytes_u32(&self) -> u32 {
        self.shared_memory_bytes as u32
    }
976
977 pub fn with_shared_memory_bytes(mut self, shared_memory_bytes: usize) -> Result<Self> {
978 validate_shared_memory_bytes(shared_memory_bytes)?;
979 self.shared_memory_bytes = shared_memory_bytes;
980 Ok(self)
981 }
982
983 pub fn try_for_1d_grid(element_count: usize, block_size: usize) -> Result<Self> {
984 validate_block_dimension(block_size, "block_size")?;
985 let grid_size = element_count.div_ceil(block_size);
986
987 validate_grid_dimension(grid_size, "grid_size")?;
988
989 Ok(Self::from_validated(
990 Dim3::new(to_u32(grid_size, "grid_size")?, 1, 1),
991 Dim3::new(to_u32(block_size, "block_size")?, 1, 1),
992 0,
993 ))
994 }
995
    /// Panicking variant of [`LaunchConfig::try_for_1d_grid`].
    ///
    /// # Panics
    ///
    /// Panics when `try_for_1d_grid` returns an error.
    pub fn for_1d_grid(element_count: usize, block_size: usize) -> Self {
        Self::try_for_1d_grid(element_count, block_size)
            .expect("invalid 1d cuda launch configuration")
    }
1000
    /// Alias for [`LaunchConfig::try_for_1d_grid`].
    pub fn try_for_num_elems(element_count: usize, block_size: usize) -> Result<Self> {
        Self::try_for_1d_grid(element_count, block_size)
    }
1004
    /// Panicking variant of [`LaunchConfig::try_for_num_elems`].
    ///
    /// # Panics
    ///
    /// Panics when `try_for_num_elems` returns an error.
    pub fn for_num_elems(element_count: usize, block_size: usize) -> Self {
        Self::try_for_num_elems(element_count, block_size)
            .expect("invalid cuda launch configuration")
    }
1009
1010 pub fn try_for_2d_grid(
1011 width: usize,
1012 height: usize,
1013 block_width: usize,
1014 block_height: usize,
1015 ) -> Result<Self> {
1016 validate_block_dimension(block_width, "block_width")?;
1017 validate_block_dimension(block_height, "block_height")?;
1018 let grid_x = width.div_ceil(block_width);
1019 let grid_y = height.div_ceil(block_height);
1020 validate_grid_dimension(grid_x, "grid_x")?;
1021 validate_grid_dimension(grid_y, "grid_y")?;
1022
1023 Ok(Self::from_validated(
1024 Dim3::new(to_u32(grid_x, "grid_x")?, to_u32(grid_y, "grid_y")?, 1),
1025 Dim3::new(
1026 to_u32(block_width, "block_width")?,
1027 to_u32(block_height, "block_height")?,
1028 1,
1029 ),
1030 0,
1031 ))
1032 }
1033
    /// Panicking variant of [`LaunchConfig::try_for_2d_grid`].
    ///
    /// # Panics
    ///
    /// Panics when `try_for_2d_grid` returns an error.
    pub fn for_2d_grid(
        width: usize,
        height: usize,
        block_width: usize,
        block_height: usize,
    ) -> Self {
        Self::try_for_2d_grid(width, height, block_width, block_height)
            .expect("invalid 2d cuda launch configuration")
    }
1043
1044 pub fn try_for_3d_grid(
1045 width: usize,
1046 height: usize,
1047 depth: usize,
1048 block_width: usize,
1049 block_height: usize,
1050 block_depth: usize,
1051 ) -> Result<Self> {
1052 validate_block_dimension(block_width, "block_width")?;
1053 validate_block_dimension(block_height, "block_height")?;
1054 validate_block_dimension(block_depth, "block_depth")?;
1055 let grid_x = width.div_ceil(block_width);
1056 let grid_y = height.div_ceil(block_height);
1057 let grid_z = depth.div_ceil(block_depth);
1058 validate_grid_dimension(grid_x, "grid_x")?;
1059 validate_grid_dimension(grid_y, "grid_y")?;
1060 validate_grid_dimension(grid_z, "grid_z")?;
1061
1062 Ok(Self::from_validated(
1063 Dim3::new(
1064 to_u32(grid_x, "grid_x")?,
1065 to_u32(grid_y, "grid_y")?,
1066 to_u32(grid_z, "grid_z")?,
1067 ),
1068 Dim3::new(
1069 to_u32(block_width, "block_width")?,
1070 to_u32(block_height, "block_height")?,
1071 to_u32(block_depth, "block_depth")?,
1072 ),
1073 0,
1074 ))
1075 }
1076
    /// Panicking variant of [`LaunchConfig::try_for_3d_grid`].
    ///
    /// # Panics
    ///
    /// Panics when `try_for_3d_grid` returns an error.
    pub fn for_3d_grid(
        width: usize,
        height: usize,
        depth: usize,
        block_width: usize,
        block_height: usize,
        block_depth: usize,
    ) -> Self {
        Self::try_for_3d_grid(width, height, depth, block_width, block_height, block_depth)
            .expect("invalid 3d cuda launch configuration")
    }
1088}
1089
1090impl ClusterLaunchConfig {
1091 pub fn new(grid_dim: Dim3, block_dim: Dim3, shared_memory_bytes: usize) -> Result<Self> {
1092 validate_dim3(grid_dim, "grid_dim")?;
1093 validate_dim3(block_dim, "block_dim")?;
1094 validate_shared_memory_bytes(shared_memory_bytes)?;
1095 Ok(Self {
1096 grid_dim,
1097 block_dim,
1098 shared_memory_bytes,
1099 })
1100 }
1101
1102 pub const fn grid_dim(&self) -> Dim3 {
1103 self.grid_dim
1104 }
1105
1106 pub const fn block_dim(&self) -> Dim3 {
1107 self.block_dim
1108 }
1109
1110 pub const fn shared_memory_bytes(&self) -> usize {
1111 self.shared_memory_bytes
1112 }
1113
1114 pub(crate) const fn shared_memory_bytes_u32(&self) -> u32 {
1115 self.shared_memory_bytes as u32
1116 }
1117
1118 pub fn with_shared_memory_bytes(mut self, shared_memory_bytes: usize) -> Result<Self> {
1119 validate_shared_memory_bytes(shared_memory_bytes)?;
1120 self.shared_memory_bytes = shared_memory_bytes;
1121 Ok(self)
1122 }
1123}
1124
1125fn validate_dim3(value: Dim3, name: &str) -> Result<()> {
1126 validate_grid_dimension(value.x as usize, &format!("{name}.x"))?;
1127 validate_grid_dimension(value.y as usize, &format!("{name}.y"))?;
1128 validate_grid_dimension(value.z as usize, &format!("{name}.z"))?;
1129 Ok(())
1130}
1131
1132fn validate_grid_dimension(value: usize, name: &str) -> Result<()> {
1133 if value == 0 {
1134 return Err(Error::ZeroValue {
1135 name: name.to_owned(),
1136 });
1137 }
1138 Ok(())
1139}
1140
1141fn validate_block_dimension(value: usize, name: &str) -> Result<()> {
1142 if value == 0 {
1143 return Err(Error::ZeroValue {
1144 name: name.to_owned(),
1145 });
1146 }
1147 Ok(())
1148}
1149
// Converts a shared memory size to `u32` via `to_u32`, labelling any error with
// "shared_memory_bytes".
fn validate_shared_memory_bytes(value: usize) -> Result<u32> {
    to_u32(value, "shared_memory_bytes")
}
1153
// Converts a dynamic shared memory size to `u64` via `to_u64`, labelling any error with
// "dynamic_shared_memory_bytes".
fn validate_dynamic_shared_memory_bytes(value: usize) -> Result<u64> {
    to_u64(value, "dynamic_shared_memory_bytes")
}
1157
1158impl<'a> KernelParameters<'a> {
1159 pub const fn new() -> Self {
1160 Self {
1161 arguments: Vec::new(),
1162 }
1163 }
1164
1165 pub fn arg<T: 'a>(&mut self, value: &'a T) -> &mut Self {
1166 self.arguments.push(KernelParameter::Borrowed {
1167 ptr: ptr::from_ref(value).cast_mut().cast::<()>(),
1168 _marker: PhantomData,
1169 });
1170 self
1171 }
1172
1173 pub fn arg_mut<T: 'a>(&mut self, value: &'a mut T) -> &mut Self {
1174 self.arguments.push(KernelParameter::Borrowed {
1175 ptr: ptr::from_mut(value).cast::<()>(),
1176 _marker: PhantomData,
1177 });
1178 self
1179 }
1180
1181 pub fn owned_arg<T: Copy + 'static>(&mut self, value: T) -> &mut Self {
1187 let value = OwnedKernelArgument::from_value(value);
1188 self.arguments.push(KernelParameter::Owned(value));
1189 self
1190 }
1191
1192 pub fn push<A: PushKernelArg>(&mut self, arg: A) -> &mut Self {
1193 arg.push_to(self);
1194 self
1195 }
1196
1197 pub fn device_slice<T: DeviceRepr, S: DeviceSlice<T> + ?Sized>(
1198 &mut self,
1199 slice: &S,
1200 ) -> &mut Self {
1201 self.owned_arg(slice.as_device_ptr())
1204 }
1205
1206 pub fn device_slice_mut<T: DeviceRepr, S: DeviceSliceMut<T> + ?Sized>(
1207 &mut self,
1208 slice: &mut S,
1209 ) -> &mut Self {
1210 self.owned_arg(slice.as_device_mut_ptr())
1211 }
1212
1213 fn raw_pointers(&mut self) -> RawKernelPointers {
1214 RawKernelPointers::from_parameters(self.arguments.as_mut_slice())
1215 }
1216}
1217
1218impl<'a> KernelParameter<'a> {
1219 fn as_mut_ptr(&mut self) -> *mut () {
1220 match self {
1221 Self::Borrowed { ptr, .. } => *ptr,
1222 Self::Owned(value) => value.as_mut_ptr(),
1223 }
1224 }
1225}
1226
1227impl OwnedKernelArgument {
1228 fn from_value<T: Copy + 'static>(value: T) -> Self {
1229 if size_of::<T>() <= INLINE_KERNEL_ARGUMENT_BYTES
1230 && align_of::<T>() <= align_of::<InlineKernelArgument>()
1231 {
1232 Self::Inline(InlineKernelArgument::from_value(value))
1233 } else {
1234 Self::Boxed(Box::new(value))
1235 }
1236 }
1237
1238 fn as_mut_ptr(&mut self) -> *mut () {
1239 match self {
1240 Self::Inline(value) => value.as_mut_ptr(),
1241 Self::Boxed(value) => value.as_mut().as_mut_ptr(),
1242 }
1243 }
1244}
1245
1246impl InlineKernelArgument {
1247 fn from_value<T: Copy>(value: T) -> Self {
1248 let mut storage = Self {
1249 bytes: [MaybeUninit::uninit(); INLINE_KERNEL_ARGUMENT_BYTES],
1250 };
1251 unsafe {
1252 ptr::write(storage.as_mut_ptr().cast::<T>(), value);
1253 }
1254 storage
1255 }
1256
1257 fn as_mut_ptr(&mut self) -> *mut () {
1258 self.bytes.as_mut_ptr().cast()
1259 }
1260}
1261
// Table of raw argument pointers built from a parameter list.
enum RawKernelPointers {
    // Fixed-size table used for up to `INLINE_KERNEL_ARGUMENTS` arguments; only the first
    // `len` entries are meaningful, the rest stay null.
    Inline {
        pointers: [*mut (); INLINE_KERNEL_ARGUMENTS],
        len: usize,
    },
    // Heap-allocated table used for longer argument lists.
    Heap(Vec<*mut ()>),
}
1269
1270impl RawKernelPointers {
1271 fn from_parameters(parameters: &mut [KernelParameter<'_>]) -> Self {
1272 if parameters.len() <= INLINE_KERNEL_ARGUMENTS {
1273 let mut pointers = [ptr::null_mut(); INLINE_KERNEL_ARGUMENTS];
1274 for (dst, parameter) in pointers.iter_mut().zip(&mut *parameters) {
1275 *dst = parameter.as_mut_ptr();
1276 }
1277 Self::Inline {
1278 pointers,
1279 len: parameters.len(),
1280 }
1281 } else {
1282 Self::Heap(
1283 parameters
1284 .iter_mut()
1285 .map(KernelParameter::as_mut_ptr)
1286 .collect(),
1287 )
1288 }
1289 }
1290
1291 fn as_mut_slice(&mut self) -> &mut [*mut ()] {
1292 match self {
1293 Self::Inline { pointers, len } => &mut pointers[..*len],
1294 Self::Heap(pointers) => pointers.as_mut_slice(),
1295 }
1296 }
1297}
1298
impl EncodedKernelArgs<'_> {
    // Base address of the pointer table, in the form passed to the driver's launch call.
    pub(crate) fn as_mut_ptr(&mut self) -> *mut *mut () {
        self.pointers.as_mut_ptr()
    }
}
1304
1305impl<'a> KernelLaunchArgs<'a> for KernelParameters<'a> {
1306 fn with_encoded_arguments<R>(mut self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
1307 let mut pointers = self.raw_pointers();
1308 f(EncodedKernelArgs {
1309 pointers: pointers.as_mut_slice(),
1310 })
1311 }
1312}
1313
// Marks `KernelParameters` as an allowed `KernelLaunchArgs` implementor.
impl private::Sealed for KernelParameters<'_> {}
1315
1316impl<'a> KernelLaunchArgs<'a> for &mut KernelParameters<'a> {
1317 fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
1318 let mut pointers = self.raw_pointers();
1319 f(EncodedKernelArgs {
1320 pointers: pointers.as_mut_slice(),
1321 })
1322 }
1323}
1324
// Marks `&mut KernelParameters` as an allowed `KernelLaunchArgs` implementor.
impl private::Sealed for &mut KernelParameters<'_> {}
1326
1327impl<'a> KernelLaunchArgs<'a> for () {
1328 fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
1329 let mut pointers: [*mut (); 0] = [];
1330 f(EncodedKernelArgs {
1331 pointers: &mut pointers,
1332 })
1333 }
1334}
1335
// Marks `()` as an allowed `KernelLaunchArgs` implementor.
impl private::Sealed for () {}
1337
// Implements `Sealed` and `KernelLaunchArgs` for one tuple arity. The identifiers passed to
// the macro serve both as the tuple's type parameters and as the names of the destructured
// elements (hence the `non_snake_case` allowance).
macro_rules! impl_kernel_arguments_for_tuple {
    ($($arg:ident),+ $(,)?) => {
        impl<'a, $($arg),+> private::Sealed for ($($arg,)+)
        where
            $($arg: KernelTupleArgument<'a>,)+
        {
        }

        impl<'a, $($arg),+> KernelLaunchArgs<'a> for ($($arg,)+)
        where
            $($arg: KernelTupleArgument<'a>,)+
        {
            fn with_encoded_arguments<R>(self, f: impl FnOnce(EncodedKernelArgs<'_>) -> R) -> R {
                #[allow(non_snake_case)]
                let ($($arg,)+) = self;
                // One pointer per tuple element, in tuple order.
                let mut pointers = [
                    $($arg.into_kernel_argument_ptr(),)+
                ];
                f(EncodedKernelArgs {
                    pointers: &mut pointers,
                })
            }
        }
    };
}
1363
1364impl<'a, T: 'a> KernelTupleArgument<'a> for &'a T {
1365 fn into_kernel_argument_ptr(self) -> *mut () {
1366 ptr::from_ref(self).cast_mut().cast()
1367 }
1368}
1369
1370impl<'a, T: 'a> KernelTupleArgument<'a> for &'a mut T {
1371 fn into_kernel_argument_ptr(self) -> *mut () {
1372 ptr::from_mut(self).cast()
1373 }
1374}
1375
// Tuple launch arguments are supported for arities 1 through 16.
impl_kernel_arguments_for_tuple!(A);
impl_kernel_arguments_for_tuple!(A, B);
impl_kernel_arguments_for_tuple!(A, B, C);
impl_kernel_arguments_for_tuple!(A, B, C, D);
impl_kernel_arguments_for_tuple!(A, B, C, D, E);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M, N);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O);
impl_kernel_arguments_for_tuple!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P);
1392
// Implements `PushKernelArg` for each listed type by pushing the value as an owned
// argument.
macro_rules! impl_push_scalar {
    ($($ty:ty),* $(,)?) => {
        $(
            impl PushKernelArg for $ty {
                fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
                    params.owned_arg(self);
                }
            }
        )*
    };
}
1404
1405impl_push_scalar!(
1406 u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64,
1407);
1408
1409impl<T: DeviceRepr> PushKernelArg for &DeviceMemory<T> {
1410 fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
1411 params.device_slice(self);
1412 }
1413}
1414
1415impl<T: DeviceRepr> PushKernelArg for &mut DeviceMemory<T> {
1416 fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
1417 params.device_slice_mut(self);
1418 }
1419}
1420
1421impl<T: DeviceRepr> PushKernelArg for &ManagedMemory<T> {
1422 fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
1423 params.device_slice(self);
1424 }
1425}
1426
1427impl<T: DeviceRepr> PushKernelArg for &mut ManagedMemory<T> {
1428 fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
1429 params.device_slice_mut(self);
1430 }
1431}
1432
1433impl<T: DeviceRepr> PushKernelArg for DeviceView<'_, T> {
1434 fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
1435 params.owned_arg(self.as_ptr());
1436 }
1437}
1438
1439impl<T: DeviceRepr> PushKernelArg for &DeviceView<'_, T> {
1440 fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
1441 params.owned_arg(self.as_device_ptr());
1442 }
1443}
1444
1445impl<T: DeviceRepr> PushKernelArg for &DeviceViewMut<'_, T> {
1446 fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
1447 params.owned_arg(self.as_device_ptr());
1448 }
1449}
1450
1451impl<T: DeviceRepr> PushKernelArg for &mut DeviceViewMut<'_, T> {
1452 fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
1453 params.owned_arg(self.as_device_mut_ptr());
1454 }
1455}
1456
1457impl Default for KernelParameters<'_> {
1458 fn default() -> Self {
1459 Self::new()
1460 }
1461}
1462
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-word payload used to exercise the boxed storage variant.
    #[derive(Clone, Copy)]
    #[repr(C)]
    struct LargeArgument {
        words: [u64; 3],
    }

    /// A boxed owned argument must report the address of the value inside the
    /// box as its argument pointer.
    #[test]
    fn boxed_owned_kernel_argument_points_to_inner_value() {
        let payload = LargeArgument { words: [1, 2, 3] };
        let mut owned = OwnedKernelArgument::from_value(payload);
        assert!(matches!(owned, OwnedKernelArgument::Boxed(_)));

        let inner_ptr = if let OwnedKernelArgument::Boxed(boxed) = &mut owned {
            boxed.as_mut().as_mut_ptr()
        } else {
            unreachable!()
        };

        assert_eq!(owned.as_mut_ptr(), inner_ptr);
    }

    /// Both the 1-D helper and the general constructor refuse a zero grid.
    #[test]
    fn launch_config_rejects_zero_grid_dimensions() {
        let from_1d = LaunchConfig::try_for_1d_grid(0, 128).unwrap_err();
        assert!(matches!(from_1d, Error::ZeroValue { name } if name == "grid_size"));

        let empty_grid = Dim3::new(0, 1, 1);
        let block = Dim3::new(128, 1, 1);
        let from_3d = LaunchConfig::new(empty_grid, block, 0).unwrap_err();
        assert!(matches!(from_3d, Error::ZeroValue { name } if name == "grid_dim.x"));
    }

    /// One byte more than `u32::MAX` of shared memory is out of range.
    #[test]
    fn launch_config_rejects_invalid_shared_memory_size() {
        let too_many_bytes = u32::MAX as usize + 1;
        let base = LaunchConfig::try_for_1d_grid(1, 128).unwrap();

        let rejected = base.with_shared_memory_bytes(too_many_bytes).unwrap_err();
        assert!(matches!(rejected, Error::OutOfRange { name } if name == "shared_memory_bytes"));
    }

    /// The largest accepted shared-memory size round-trips through both the
    /// `usize` and the `u32` accessor.
    #[test]
    fn launch_config_exposes_checked_shared_memory_u32() {
        let limit = u32::MAX;
        let base = LaunchConfig::try_for_1d_grid(1, 128).unwrap();
        let config = base.with_shared_memory_bytes(limit as usize).unwrap();

        assert_eq!(config.shared_memory_bytes(), limit as usize);
        assert_eq!(config.shared_memory_bytes_u32(), limit);
    }

    /// Dynamic shared-memory sizes for occupancy queries are widened to
    /// `u64`, so both extremes of `usize` are accepted.
    #[test]
    fn occupancy_dynamic_shared_memory_uses_checked_driver_width() {
        let cases: [(usize, u64); 2] = [(0, 0), (usize::MAX, usize::MAX as u64)];

        for (bytes, widened) in cases {
            assert_eq!(validate_dynamic_shared_memory_bytes(bytes).unwrap(), widened);
        }
    }

    /// A valid cluster launch configuration reports back the values it was
    /// built from.
    #[test]
    fn cluster_launch_config_uses_checked_construction() {
        let config =
            ClusterLaunchConfig::new(Dim3::new(1, 1, 1), Dim3::new(32, 1, 1), 0).unwrap();

        let expected_grid = Dim3::new(1, 1, 1);
        let expected_block = Dim3::new(32, 1, 1);

        assert_eq!(config.grid_dim(), expected_grid);
        assert_eq!(config.block_dim(), expected_block);
        assert_eq!(config.shared_memory_bytes(), 0);
    }
}