Skip to main content

baracuda_runtime/
launch.rs

1//! Kernel launch builder for the Runtime API.
2
3use core::ffi::c_void;
4
5use baracuda_cuda_sys::runtime::{cudaStream_t, runtime, types::dim3};
6use baracuda_types::KernelArg;
7
8use crate::error::{check, Result};
9use crate::module::Kernel;
10use crate::stream::Stream;
11
/// Three-component launch extent (grid or block size), mirroring
/// [`baracuda_driver::Dim3`].
///
/// Build one with a struct literal, or convert from a `u32` or a
/// 2-/3-tuple of `u32` (omitted axes become `1`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Dim3 {
    /// Extent along the x axis.
    pub x: u32,
    /// Extent along the y axis.
    pub y: u32,
    /// Extent along the z axis.
    pub z: u32,
}
19
20impl Dim3 {
21    #[inline]
22    fn to_sys(self) -> dim3 {
23        dim3::new(self.x, self.y, self.z)
24    }
25}
26
27impl From<u32> for Dim3 {
28    fn from(x: u32) -> Self {
29        Self { x, y: 1, z: 1 }
30    }
31}
32
33impl From<(u32, u32)> for Dim3 {
34    fn from((x, y): (u32, u32)) -> Self {
35        Self { x, y, z: 1 }
36    }
37}
38
39impl From<(u32, u32, u32)> for Dim3 {
40    fn from((x, y, z): (u32, u32, u32)) -> Self {
41        Self { x, y, z }
42    }
43}
44
45impl Kernel {
46    /// Start a kernel-launch builder for this kernel.
47    #[inline]
48    pub fn launch(&self) -> LaunchBuilder<'_> {
49        LaunchBuilder {
50            kernel: self,
51            grid: Dim3 { x: 1, y: 1, z: 1 },
52            block: Dim3 { x: 1, y: 1, z: 1 },
53            shared_mem_bytes: 0,
54            stream: None,
55            args: Vec::new(),
56        }
57    }
58}
59
60/// Builder produced by [`Kernel::launch`].
61#[must_use = "the launch builder does nothing until `.launch()` is called"]
62pub struct LaunchBuilder<'k> {
63    kernel: &'k Kernel,
64    grid: Dim3,
65    block: Dim3,
66    shared_mem_bytes: usize,
67    stream: Option<&'k Stream>,
68    args: Vec<*mut c_void>,
69}
70
71impl core::fmt::Debug for LaunchBuilder<'_> {
72    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
73        f.debug_struct("LaunchBuilder")
74            .field("grid", &self.grid)
75            .field("block", &self.block)
76            .field("shared_mem_bytes", &self.shared_mem_bytes)
77            .field("arg_count", &self.args.len())
78            .finish_non_exhaustive()
79    }
80}
81
82impl<'k> LaunchBuilder<'k> {
83    #[inline]
84    pub fn grid(mut self, grid: impl Into<Dim3>) -> Self {
85        self.grid = grid.into();
86        self
87    }
88
89    #[inline]
90    pub fn block(mut self, block: impl Into<Dim3>) -> Self {
91        self.block = block.into();
92        self
93    }
94
95    #[inline]
96    pub fn shared_mem_bytes(mut self, bytes: usize) -> Self {
97        self.shared_mem_bytes = bytes;
98        self
99    }
100
101    #[inline]
102    pub fn stream(mut self, stream: &'k Stream) -> Self {
103        self.stream = Some(stream);
104        self
105    }
106
107    #[inline]
108    pub fn arg<K: KernelArg>(mut self, arg: K) -> Self {
109        self.args.push(arg.as_kernel_arg_ptr());
110        self
111    }
112
113    /// Enqueue the kernel.
114    ///
115    /// # Safety
116    ///
117    /// Same rules as [`baracuda_driver::LaunchBuilder::launch`]: argument
118    /// types and order must match the kernel's C signature, referenced
119    /// device memory must stay valid for the duration of device execution,
120    /// and grid/block dims must be within device limits.
121    pub unsafe fn launch(mut self) -> Result<()> {
122        let r = runtime()?;
123        let cu = r.cuda_launch_kernel()?;
124        let stream_handle: cudaStream_t = self.stream.map_or(core::ptr::null_mut(), |s| s.as_raw());
125        let args_ptr = if self.args.is_empty() {
126            core::ptr::null_mut()
127        } else {
128            self.args.as_mut_ptr()
129        };
130        check(cu(
131            self.kernel.as_launch_ptr(),
132            self.grid.to_sys(),
133            self.block.to_sys(),
134            args_ptr,
135            self.shared_mem_bytes,
136            stream_handle,
137        ))
138    }
139
140    /// Launch as a cooperative kernel — grid-wide sync via
141    /// `cooperative_groups::this_grid()`. All blocks must fit resident
142    /// on the device simultaneously; use
143    /// [`crate::Kernel::max_active_blocks_per_multiprocessor`] to size
144    /// the grid.
145    ///
146    /// # Safety
147    ///
148    /// Same as [`launch`](Self::launch) plus the kernel must be
149    /// compiled with cooperative-groups support.
150    pub unsafe fn launch_cooperative(mut self) -> Result<()> {
151        let r = runtime()?;
152        let cu = r.cuda_launch_cooperative_kernel()?;
153        let stream_handle: cudaStream_t = self.stream.map_or(core::ptr::null_mut(), |s| s.as_raw());
154        let args_ptr = if self.args.is_empty() {
155            core::ptr::null_mut()
156        } else {
157            self.args.as_mut_ptr()
158        };
159        check(cu(
160            self.kernel.as_launch_ptr(),
161            self.grid.to_sys(),
162            self.block.to_sys(),
163            args_ptr,
164            self.shared_mem_bytes,
165            stream_handle,
166        ))
167    }
168}