Skip to main content

baracuda_driver/
launch.rs

1//! Kernel launch builder — the Rust equivalent of CUDA C's triple-chevron syntax.
2//!
3//! ```no_run
4//! # use baracuda_driver::{Context, Device, Stream, Module, DeviceBuffer};
5//! # fn example() -> baracuda_driver::Result<()> {
6//! # let device = Device::get(0)?;
7//! # let ctx = Context::new(&device)?;
8//! # let stream = Stream::new(&ctx)?;
9//! # let module = Module::load_ptx(&ctx, "")?;
10//! # let kernel = module.get_function("vector_add")?;
11//! # let mut d_c: DeviceBuffer<f32> = DeviceBuffer::new(&ctx, 1024)?;
12//! # let d_a: DeviceBuffer<f32> = DeviceBuffer::new(&ctx, 1024)?;
13//! # let d_b: DeviceBuffer<f32> = DeviceBuffer::new(&ctx, 1024)?;
14//! # let n = 1024u32;
15//! unsafe {
16//!     kernel.launch()
17//!         .grid((n.div_ceil(256), 1, 1))
18//!         .block((256, 1, 1))
19//!         .stream(&stream)
20//!         .arg(&d_a.as_raw())
21//!         .arg(&d_b.as_raw())
22//!         .arg(&d_c.as_raw())
23//!         .arg(&n)
24//!         .launch()?;
25//! }
26//! # Ok(())
27//! # }
28//! ```
29
30use core::ffi::c_void;
31
32use baracuda_cuda_sys::{driver, CUstream};
33use baracuda_types::KernelArg;
34
35use crate::error::{check, Result};
36use crate::module::Function;
37use crate::stream::Stream;
38
/// Three-component size used for both the launch grid and the thread block.
///
/// Conversions are provided from a bare `u32` (x only), an `(x, y)` pair,
/// and an `(x, y, z)` triple; any component not supplied is set to 1.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Dim3 {
    /// Extent along the first (x) axis.
    pub x: u32,
    /// Extent along the second (y) axis.
    pub y: u32,
    /// Extent along the third (z) axis.
    pub z: u32,
}

/// One-dimensional shorthand: `y` and `z` become 1.
impl From<u32> for Dim3 {
    fn from(width: u32) -> Self {
        Dim3 {
            x: width,
            y: 1,
            z: 1,
        }
    }
}

/// Two-dimensional shorthand: `z` becomes 1.
impl From<(u32, u32)> for Dim3 {
    fn from(extent: (u32, u32)) -> Self {
        let (x, y) = extent;
        Dim3 { x, y, z: 1 }
    }
}

/// Full three-dimensional form; every component is taken as given.
impl From<(u32, u32, u32)> for Dim3 {
    fn from(extent: (u32, u32, u32)) -> Self {
        Dim3 {
            x: extent.0,
            y: extent.1,
            z: extent.2,
        }
    }
}
64
65impl Function {
66    /// Start a kernel-launch builder for this function.
67    #[inline]
68    pub fn launch(&self) -> LaunchBuilder<'_> {
69        LaunchBuilder {
70            function: self,
71            grid: Dim3 { x: 1, y: 1, z: 1 },
72            block: Dim3 { x: 1, y: 1, z: 1 },
73            shared_mem_bytes: 0,
74            stream: None,
75            args: Vec::new(),
76        }
77    }
78}
79
80/// Builder produced by [`Function::launch`]. Call [`LaunchBuilder::launch`]
81/// to actually enqueue the kernel.
82#[must_use = "the launch builder does nothing until `launch()` is called"]
83pub struct LaunchBuilder<'f> {
84    function: &'f Function,
85    grid: Dim3,
86    block: Dim3,
87    shared_mem_bytes: u32,
88    stream: Option<&'f Stream>,
89    args: Vec<*mut c_void>,
90}
91
92impl core::fmt::Debug for LaunchBuilder<'_> {
93    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
94        f.debug_struct("LaunchBuilder")
95            .field("grid", &self.grid)
96            .field("block", &self.block)
97            .field("shared_mem_bytes", &self.shared_mem_bytes)
98            .field("arg_count", &self.args.len())
99            .finish_non_exhaustive()
100    }
101}
102
103impl<'f> LaunchBuilder<'f> {
104    /// Set the grid (number of blocks per axis).
105    #[inline]
106    pub fn grid(mut self, grid: impl Into<Dim3>) -> Self {
107        self.grid = grid.into();
108        self
109    }
110
111    /// Set the block (number of threads per block per axis).
112    #[inline]
113    pub fn block(mut self, block: impl Into<Dim3>) -> Self {
114        self.block = block.into();
115        self
116    }
117
118    /// Set the amount of dynamic shared memory (bytes). Defaults to 0.
119    #[inline]
120    pub fn shared_mem_bytes(mut self, bytes: u32) -> Self {
121        self.shared_mem_bytes = bytes;
122        self
123    }
124
125    /// Launch on the specified [`Stream`]. Defaults to the legacy null stream.
126    #[inline]
127    pub fn stream(mut self, stream: &'f Stream) -> Self {
128        self.stream = Some(stream);
129        self
130    }
131
132    /// Append one kernel argument. Pass `&value` for each kernel parameter:
133    /// the referent must remain alive until [`launch`](Self::launch) is
134    /// called. CUDA copies argument bytes at submission time, so the
135    /// values don't need to outlive the *device* execution — only the
136    /// submission.
137    #[inline]
138    pub fn arg<K: KernelArg>(mut self, arg: K) -> Self {
139        self.args.push(arg.as_kernel_arg_ptr());
140        self
141    }
142
143    /// Actually enqueue the kernel.
144    ///
145    /// # Safety
146    ///
147    /// The caller must ensure that:
148    ///
149    /// 1. The number, order, and types of arguments match what the kernel
150    ///    expects. Baracuda cannot see the kernel's signature, so a
151    ///    mismatch here causes undefined behavior (typically corrupted
152    ///    output, a device fault, or silent memory corruption).
153    /// 2. Any pointer-typed argument (e.g. `DeviceBuffer::as_raw()`) is
154    ///    live for the duration of kernel execution — use streams + events
155    ///    to manage this.
156    /// 3. Grid and block dimensions are within the device's supported
157    ///    limits (see [`crate::Device::attribute`]).
158    pub unsafe fn launch(mut self) -> Result<()> { unsafe {
159        let d = driver()?;
160        let cu = d.cu_launch_kernel()?;
161        let stream_handle: CUstream = self.stream.map_or(core::ptr::null_mut(), |s| s.as_raw());
162        let args_ptr = if self.args.is_empty() {
163            core::ptr::null_mut()
164        } else {
165            self.args.as_mut_ptr()
166        };
167        check(cu(
168            self.function.as_raw(),
169            self.grid.x,
170            self.grid.y,
171            self.grid.z,
172            self.block.x,
173            self.block.y,
174            self.block.z,
175            self.shared_mem_bytes,
176            stream_handle,
177            args_ptr,
178            core::ptr::null_mut(), // extras — unused; we always pass args via the kernel_params slot
179        ))
180    }}
181
182    /// Enqueue the kernel via `cuLaunchKernelEx` (CUDA 12.0+), letting the
183    /// caller attach launch attributes (cluster dims, programmatic stream
184    /// serialization, priority, …). Pass an empty slice when you just want
185    /// the modern launch entry point with no attributes.
186    ///
187    /// # Safety
188    ///
189    /// Same responsibilities as [`launch`](Self::launch). Attribute payloads
190    /// in `attributes` must be populated correctly per
191    /// [`baracuda_cuda_sys::types::CUlaunchAttributeID`] — invalid
192    /// attribute payloads cause undefined behavior on the device.
193    pub unsafe fn launch_ex(
194        mut self,
195        attributes: &mut [baracuda_cuda_sys::types::CUlaunchAttribute],
196    ) -> Result<()> { unsafe {
197        let d = driver()?;
198        let cu = d.cu_launch_kernel_ex()?;
199        let stream_handle: CUstream = self.stream.map_or(core::ptr::null_mut(), |s| s.as_raw());
200        let args_ptr = if self.args.is_empty() {
201            core::ptr::null_mut()
202        } else {
203            self.args.as_mut_ptr()
204        };
205        let config = baracuda_cuda_sys::types::CUlaunchConfig {
206            grid_dim_x: self.grid.x,
207            grid_dim_y: self.grid.y,
208            grid_dim_z: self.grid.z,
209            block_dim_x: self.block.x,
210            block_dim_y: self.block.y,
211            block_dim_z: self.block.z,
212            shared_mem_bytes: self.shared_mem_bytes,
213            stream: stream_handle,
214            attrs: if attributes.is_empty() {
215                core::ptr::null_mut()
216            } else {
217                attributes.as_mut_ptr()
218            },
219            num_attrs: attributes.len() as core::ffi::c_uint,
220        };
221        check(cu(
222            &config,
223            self.function.as_raw(),
224            args_ptr,
225            core::ptr::null_mut(),
226        ))
227    }}
228}