baracuda_driver/launch.rs
1//! Kernel launch builder — the Rust equivalent of CUDA C's triple-chevron syntax.
2//!
3//! ```no_run
4//! # use baracuda_driver::{Context, Device, Stream, Module, DeviceBuffer};
5//! # fn example() -> baracuda_driver::Result<()> {
6//! # let device = Device::get(0)?;
7//! # let ctx = Context::new(&device)?;
8//! # let stream = Stream::new(&ctx)?;
9//! # let module = Module::load_ptx(&ctx, "")?;
10//! # let kernel = module.get_function("vector_add")?;
11//! # let mut d_c: DeviceBuffer<f32> = DeviceBuffer::new(&ctx, 1024)?;
12//! # let d_a: DeviceBuffer<f32> = DeviceBuffer::new(&ctx, 1024)?;
13//! # let d_b: DeviceBuffer<f32> = DeviceBuffer::new(&ctx, 1024)?;
14//! # let n = 1024u32;
15//! unsafe {
16//! kernel.launch()
17//! .grid((n.div_ceil(256), 1, 1))
18//! .block((256, 1, 1))
19//! .stream(&stream)
20//! .arg(&d_a.as_raw())
21//! .arg(&d_b.as_raw())
22//! .arg(&d_c.as_raw())
23//! .arg(&n)
24//! .launch()?;
25//! }
26//! # Ok(())
27//! # }
28//! ```
29
30use core::ffi::c_void;
31
32use baracuda_cuda_sys::{driver, CUstream};
33use baracuda_types::KernelArg;
34
35use crate::error::{check, Result};
36use crate::module::Function;
37use crate::stream::Stream;
38
/// Three-dimensional grid/block size.
///
/// Used for both the grid (blocks per axis) and the block (threads per
/// axis) of a kernel launch. Convert from a bare `u32`, a `(u32, u32)`
/// pair, or a `(u32, u32, u32)` triple; any axis that is left out is set
/// to 1.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Dim3 {
    /// Extent along the x axis.
    pub x: u32,
    /// Extent along the y axis.
    pub y: u32,
    /// Extent along the z axis.
    pub z: u32,
}

impl Default for Dim3 {
    /// Returns the unit extent `(1, 1, 1)`.
    ///
    /// This is deliberately not the all-zero value a derived `Default`
    /// would produce: 1 is what the `From` conversions use for omitted
    /// axes and what a fresh launch builder starts with.
    #[inline]
    fn default() -> Self {
        Self { x: 1, y: 1, z: 1 }
    }
}

impl From<u32> for Dim3 {
    /// One-dimensional extent: `x` as given, `y` and `z` set to 1.
    fn from(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }
}

impl From<(u32, u32)> for Dim3 {
    /// Two-dimensional extent: `x` and `y` as given, `z` set to 1.
    fn from((x, y): (u32, u32)) -> Self {
        Self { x, y, z: 1 }
    }
}

impl From<(u32, u32, u32)> for Dim3 {
    /// Fully specified three-dimensional extent.
    fn from((x, y, z): (u32, u32, u32)) -> Self {
        Self { x, y, z }
    }
}
64
65impl Function {
66 /// Start a kernel-launch builder for this function.
67 #[inline]
68 pub fn launch(&self) -> LaunchBuilder<'_> {
69 LaunchBuilder {
70 function: self,
71 grid: Dim3 { x: 1, y: 1, z: 1 },
72 block: Dim3 { x: 1, y: 1, z: 1 },
73 shared_mem_bytes: 0,
74 stream: None,
75 args: Vec::new(),
76 }
77 }
78}
79
80/// Builder produced by [`Function::launch`]. Call [`LaunchBuilder::launch`]
81/// to actually enqueue the kernel.
82#[must_use = "the launch builder does nothing until `launch()` is called"]
83pub struct LaunchBuilder<'f> {
84 function: &'f Function,
85 grid: Dim3,
86 block: Dim3,
87 shared_mem_bytes: u32,
88 stream: Option<&'f Stream>,
89 args: Vec<*mut c_void>,
90}
91
92impl core::fmt::Debug for LaunchBuilder<'_> {
93 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
94 f.debug_struct("LaunchBuilder")
95 .field("grid", &self.grid)
96 .field("block", &self.block)
97 .field("shared_mem_bytes", &self.shared_mem_bytes)
98 .field("arg_count", &self.args.len())
99 .finish_non_exhaustive()
100 }
101}
102
103impl<'f> LaunchBuilder<'f> {
104 /// Set the grid (number of blocks per axis).
105 #[inline]
106 pub fn grid(mut self, grid: impl Into<Dim3>) -> Self {
107 self.grid = grid.into();
108 self
109 }
110
111 /// Set the block (number of threads per block per axis).
112 #[inline]
113 pub fn block(mut self, block: impl Into<Dim3>) -> Self {
114 self.block = block.into();
115 self
116 }
117
118 /// Set the amount of dynamic shared memory (bytes). Defaults to 0.
119 #[inline]
120 pub fn shared_mem_bytes(mut self, bytes: u32) -> Self {
121 self.shared_mem_bytes = bytes;
122 self
123 }
124
125 /// Launch on the specified [`Stream`]. Defaults to the legacy null stream.
126 #[inline]
127 pub fn stream(mut self, stream: &'f Stream) -> Self {
128 self.stream = Some(stream);
129 self
130 }
131
132 /// Append one kernel argument. Pass `&value` for each kernel parameter:
133 /// the referent must remain alive until [`launch`](Self::launch) is
134 /// called. CUDA copies argument bytes at submission time, so the
135 /// values don't need to outlive the *device* execution — only the
136 /// submission.
137 #[inline]
138 pub fn arg<K: KernelArg>(mut self, arg: K) -> Self {
139 self.args.push(arg.as_kernel_arg_ptr());
140 self
141 }
142
143 /// Actually enqueue the kernel.
144 ///
145 /// # Safety
146 ///
147 /// The caller must ensure that:
148 ///
149 /// 1. The number, order, and types of arguments match what the kernel
150 /// expects. Baracuda cannot see the kernel's signature, so a
151 /// mismatch here causes undefined behavior (typically corrupted
152 /// output, a device fault, or silent memory corruption).
153 /// 2. Any pointer-typed argument (e.g. `DeviceBuffer::as_raw()`) is
154 /// live for the duration of kernel execution — use streams + events
155 /// to manage this.
156 /// 3. Grid and block dimensions are within the device's supported
157 /// limits (see [`crate::Device::attribute`]).
158 pub unsafe fn launch(mut self) -> Result<()> { unsafe {
159 let d = driver()?;
160 let cu = d.cu_launch_kernel()?;
161 let stream_handle: CUstream = self.stream.map_or(core::ptr::null_mut(), |s| s.as_raw());
162 let args_ptr = if self.args.is_empty() {
163 core::ptr::null_mut()
164 } else {
165 self.args.as_mut_ptr()
166 };
167 check(cu(
168 self.function.as_raw(),
169 self.grid.x,
170 self.grid.y,
171 self.grid.z,
172 self.block.x,
173 self.block.y,
174 self.block.z,
175 self.shared_mem_bytes,
176 stream_handle,
177 args_ptr,
178 core::ptr::null_mut(), // extras — unused; we always pass args via the kernel_params slot
179 ))
180 }}
181
182 /// Enqueue the kernel via `cuLaunchKernelEx` (CUDA 12.0+), letting the
183 /// caller attach launch attributes (cluster dims, programmatic stream
184 /// serialization, priority, …). Pass an empty slice when you just want
185 /// the modern launch entry point with no attributes.
186 ///
187 /// # Safety
188 ///
189 /// Same responsibilities as [`launch`](Self::launch). Attribute payloads
190 /// in `attributes` must be populated correctly per
191 /// [`baracuda_cuda_sys::types::CUlaunchAttributeID`] — invalid
192 /// attribute payloads cause undefined behavior on the device.
193 pub unsafe fn launch_ex(
194 mut self,
195 attributes: &mut [baracuda_cuda_sys::types::CUlaunchAttribute],
196 ) -> Result<()> { unsafe {
197 let d = driver()?;
198 let cu = d.cu_launch_kernel_ex()?;
199 let stream_handle: CUstream = self.stream.map_or(core::ptr::null_mut(), |s| s.as_raw());
200 let args_ptr = if self.args.is_empty() {
201 core::ptr::null_mut()
202 } else {
203 self.args.as_mut_ptr()
204 };
205 let config = baracuda_cuda_sys::types::CUlaunchConfig {
206 grid_dim_x: self.grid.x,
207 grid_dim_y: self.grid.y,
208 grid_dim_z: self.grid.z,
209 block_dim_x: self.block.x,
210 block_dim_y: self.block.y,
211 block_dim_z: self.block.z,
212 shared_mem_bytes: self.shared_mem_bytes,
213 stream: stream_handle,
214 attrs: if attributes.is_empty() {
215 core::ptr::null_mut()
216 } else {
217 attributes.as_mut_ptr()
218 },
219 num_attrs: attributes.len() as core::ffi::c_uint,
220 };
221 check(cu(
222 &config,
223 self.function.as_raw(),
224 args_ptr,
225 core::ptr::null_mut(),
226 ))
227 }}
228}