Skip to main content

baracuda_runtime/
launch_attr.rs

//! Extended launch attributes + `cudaLaunchKernelEx` (cluster launches,
//! programmatic stream serialization, preferred shmem carveout).
//!
//! Modern launches go through [`LaunchExBuilder`]:
//!
//! ```no_run
//! # use baracuda_runtime::launch_attr::LaunchExBuilder;
//! # use baracuda_runtime::{Stream, Library};
//! # let stream: Stream = todo!();
//! # let kernel = todo!();
//! # let mut args: [*mut core::ffi::c_void; 0] = [];
//! unsafe {
//!     LaunchExBuilder::new(&stream, (32, 1, 1), (256, 1, 1))
//!         .cluster_dim((2, 1, 1))
//!         .cooperative(true)
//!         .launch(kernel, &mut args)
//!         .unwrap();
//! }
//! ```
20
21use core::ffi::c_void;
22
23use baracuda_cuda_sys::runtime::runtime;
24use baracuda_cuda_sys::runtime::types::{
25    cudaLaunchAttribute, cudaLaunchAttributeID, cudaLaunchAttributeValue, cudaLaunchConfig_t, dim3,
26};
27
28use crate::error::{check, Result};
29use crate::launch::Dim3;
30use crate::module::Kernel;
31use crate::stream::Stream;
32
33/// Builder for `cudaLaunchKernelEx` — accepts up to ~14 attribute kinds.
34#[derive(Debug)]
35pub struct LaunchExBuilder<'s> {
36    config: cudaLaunchConfig_t,
37    attrs: Vec<cudaLaunchAttribute>,
38    _stream: &'s Stream,
39}
40
41impl<'s> LaunchExBuilder<'s> {
42    pub fn new(stream: &'s Stream, grid: impl Into<Dim3>, block: impl Into<Dim3>) -> Self {
43        let g: Dim3 = grid.into();
44        let b: Dim3 = block.into();
45        Self {
46            config: cudaLaunchConfig_t {
47                grid_dim: dim3::new(g.x, g.y, g.z),
48                block_dim: dim3::new(b.x, b.y, b.z),
49                dynamic_smem_bytes: 0,
50                stream: stream.as_raw(),
51                attrs: core::ptr::null_mut(),
52                num_attrs: 0,
53            },
54            attrs: Vec::new(),
55            _stream: stream,
56        }
57    }
58
59    pub fn dynamic_shared_memory(mut self, bytes: usize) -> Self {
60        self.config.dynamic_smem_bytes = bytes;
61        self
62    }
63
64    fn push(mut self, id: i32, val: cudaLaunchAttributeValue) -> Self {
65        self.attrs.push(cudaLaunchAttribute { id, _pad: 0, val });
66        self
67    }
68
69    /// Hopper cluster dimension (x, y, z) in blocks.
70    pub fn cluster_dim(self, dims: impl Into<Dim3>) -> Self {
71        let d: Dim3 = dims.into();
72        self.push(
73            cudaLaunchAttributeID::CLUSTER_DIMENSION,
74            cudaLaunchAttributeValue::cluster_dimension(d.x, d.y, d.z),
75        )
76    }
77
78    /// Enable a cooperative launch.
79    pub fn cooperative(self, enable: bool) -> Self {
80        self.push(
81            cudaLaunchAttributeID::COOPERATIVE,
82            cudaLaunchAttributeValue::cooperative(enable),
83        )
84    }
85
86    /// Assign a priority to this launch (overrides stream priority for
87    /// this kernel).
88    pub fn priority(self, prio: i32) -> Self {
89        self.push(
90            cudaLaunchAttributeID::PRIORITY,
91            cudaLaunchAttributeValue::priority(prio),
92        )
93    }
94
95    /// Push a raw attribute slot — escape hatch for IDs this builder
96    /// doesn't expose typed.
97    pub fn raw_attr(self, id: i32, val: cudaLaunchAttributeValue) -> Self {
98        self.push(id, val)
99    }
100
101    /// Execute the launch.
102    ///
103    /// # Safety
104    ///
105    /// `args` must match `kernel`'s C signature in count / order / types
106    /// exactly (the marshaling is bytewise).
107    pub unsafe fn launch(mut self, kernel: &Kernel, args: &mut [*mut c_void]) -> Result<()> {
108        if !self.attrs.is_empty() {
109            self.config.attrs = self.attrs.as_mut_ptr();
110            self.config.num_attrs = self.attrs.len() as core::ffi::c_uint;
111        }
112        let r = runtime()?;
113        let cu = r.cuda_launch_kernel_ex()?;
114        check(cu(
115            &self.config,
116            kernel.as_launch_ptr(),
117            if args.is_empty() {
118                core::ptr::null_mut()
119            } else {
120                args.as_mut_ptr()
121            },
122        ))
123    }
124}