Skip to main content

oxicuda_launch/
params.rs

1//! Kernel launch parameter configuration.
2//!
3//! This module provides [`LaunchParams`] and its builder for specifying
4//! the execution configuration of a GPU kernel launch: grid size,
5//! block size, and dynamic shared memory allocation.
6//!
7//! # Examples
8//!
9//! ```
10//! use oxicuda_launch::{LaunchParams, Dim3};
11//!
12//! // Direct construction
13//! let params = LaunchParams::new(Dim3::x(4), Dim3::x(256));
14//! assert_eq!(params.total_threads(), 1024);
15//!
16//! // With shared memory
17//! let params = LaunchParams::new(4u32, 256u32).with_shared_mem(4096);
18//! assert_eq!(params.shared_mem_bytes, 4096);
19//!
20//! // Builder pattern
21//! let params = LaunchParams::builder()
22//!     .grid(256u32)
23//!     .block(256u32)
24//!     .shared_mem(4096)
25//!     .build();
26//! assert_eq!(params.total_threads(), 256 * 256);
27//! ```
28
29use oxicuda_driver::device::Device;
30
31use crate::error::LaunchError;
32use crate::grid::Dim3;
33
34/// Parameters for a GPU kernel launch.
35///
36/// Specifies the execution configuration: grid size (number of blocks),
37/// block size (threads per block), and dynamic shared memory allocation.
38///
39/// # Examples
40///
41/// ```
42/// use oxicuda_launch::{LaunchParams, Dim3};
43///
44/// let params = LaunchParams::new(Dim3::x(256), Dim3::x(256));
45/// assert_eq!(params.grid, Dim3::x(256));
46/// assert_eq!(params.block, Dim3::x(256));
47/// assert_eq!(params.shared_mem_bytes, 0);
48/// ```
49#[derive(Debug, Clone, Copy)]
50pub struct LaunchParams {
51    /// Grid dimensions (number of thread blocks in each dimension).
52    pub grid: Dim3,
53    /// Block dimensions (number of threads per block in each dimension).
54    pub block: Dim3,
55    /// Dynamic shared memory allocation in bytes (default 0).
56    pub shared_mem_bytes: u32,
57}
58
59impl LaunchParams {
60    /// Creates new launch parameters with the given grid and block dimensions.
61    ///
62    /// Shared memory defaults to 0 bytes. Use [`with_shared_mem`](Self::with_shared_mem)
63    /// to specify dynamic shared memory.
64    ///
65    /// Both `grid` and `block` accept anything that converts to [`Dim3`],
66    /// including `u32`, `(u32, u32)`, and `(u32, u32, u32)`.
67    #[inline]
68    pub fn new(grid: impl Into<Dim3>, block: impl Into<Dim3>) -> Self {
69        Self {
70            grid: grid.into(),
71            block: block.into(),
72            shared_mem_bytes: 0,
73        }
74    }
75
76    /// Sets the dynamic shared memory allocation in bytes.
77    ///
78    /// Returns `self` for method chaining.
79    #[inline]
80    pub fn with_shared_mem(mut self, bytes: u32) -> Self {
81        self.shared_mem_bytes = bytes;
82        self
83    }
84
85    /// Returns a [`LaunchParamsBuilder`] for incremental configuration.
86    #[inline]
87    pub fn builder() -> LaunchParamsBuilder {
88        LaunchParamsBuilder::default()
89    }
90
91    /// Total number of threads in the launch (grid total * block total).
92    ///
93    /// Returns a `u64` to avoid overflow when grid and block totals
94    /// are both large `u32` values.
95    #[inline]
96    pub fn total_threads(&self) -> u64 {
97        self.grid.total() as u64 * self.block.total() as u64
98    }
99
100    /// Validates launch parameters against device hardware limits.
101    ///
102    /// Checks that:
103    /// - All block and grid dimensions are non-zero.
104    /// - The total threads per block does not exceed the device maximum.
105    /// - Each block dimension does not exceed its per-axis device maximum.
106    /// - Each grid dimension does not exceed its per-axis device maximum.
107    /// - The dynamic shared memory does not exceed the device maximum per block.
108    ///
109    /// # Errors
110    ///
111    /// Returns a [`LaunchError`] describing the first constraint violation
112    /// found, or a [`CudaError`](oxicuda_driver::CudaError) if device
113    /// attribute queries fail.
114    ///
115    /// # Examples
116    ///
117    /// ```rust,no_run
118    /// use oxicuda_launch::{LaunchParams, Dim3};
119    /// use oxicuda_driver::device::Device;
120    ///
121    /// oxicuda_driver::init()?;
122    /// let dev = Device::get(0)?;
123    /// let params = LaunchParams::new(256u32, 256u32);
124    /// params.validate(&dev)?;
125    /// # Ok::<(), Box<dyn std::error::Error>>(())
126    /// ```
127    pub fn validate(&self, device: &Device) -> Result<(), Box<dyn std::error::Error>> {
128        self.validate_inner(device)
129    }
130
131    /// Inner validation that queries device attributes and checks constraints.
132    fn validate_inner(&self, device: &Device) -> Result<(), Box<dyn std::error::Error>> {
133        // Check non-zero dimensions
134        if self.block.x == 0 {
135            return Err(Box::new(LaunchError::InvalidDimension {
136                dim: "block.x",
137                value: 0,
138            }));
139        }
140        if self.block.y == 0 {
141            return Err(Box::new(LaunchError::InvalidDimension {
142                dim: "block.y",
143                value: 0,
144            }));
145        }
146        if self.block.z == 0 {
147            return Err(Box::new(LaunchError::InvalidDimension {
148                dim: "block.z",
149                value: 0,
150            }));
151        }
152        if self.grid.x == 0 {
153            return Err(Box::new(LaunchError::InvalidDimension {
154                dim: "grid.x",
155                value: 0,
156            }));
157        }
158        if self.grid.y == 0 {
159            return Err(Box::new(LaunchError::InvalidDimension {
160                dim: "grid.y",
161                value: 0,
162            }));
163        }
164        if self.grid.z == 0 {
165            return Err(Box::new(LaunchError::InvalidDimension {
166                dim: "grid.z",
167                value: 0,
168            }));
169        }
170
171        // Query device limits
172        let max_threads = device.max_threads_per_block()? as u32;
173        let block_total = self.block.total();
174        if block_total > max_threads {
175            return Err(Box::new(LaunchError::BlockSizeExceedsLimit {
176                requested: block_total,
177                max: max_threads,
178            }));
179        }
180
181        // Per-axis block limits
182        let (max_bx, max_by, max_bz) = device.max_block_dim()?;
183        if self.block.x > max_bx as u32 {
184            return Err(Box::new(LaunchError::InvalidDimension {
185                dim: "block.x",
186                value: self.block.x,
187            }));
188        }
189        if self.block.y > max_by as u32 {
190            return Err(Box::new(LaunchError::InvalidDimension {
191                dim: "block.y",
192                value: self.block.y,
193            }));
194        }
195        if self.block.z > max_bz as u32 {
196            return Err(Box::new(LaunchError::InvalidDimension {
197                dim: "block.z",
198                value: self.block.z,
199            }));
200        }
201
202        // Per-axis grid limits
203        let (max_gx, max_gy, max_gz) = device.max_grid_dim()?;
204        if self.grid.x > max_gx as u32 {
205            return Err(Box::new(LaunchError::GridSizeExceedsLimit {
206                requested: self.grid.x,
207                max: max_gx as u32,
208            }));
209        }
210        if self.grid.y > max_gy as u32 {
211            return Err(Box::new(LaunchError::GridSizeExceedsLimit {
212                requested: self.grid.y,
213                max: max_gy as u32,
214            }));
215        }
216        if self.grid.z > max_gz as u32 {
217            return Err(Box::new(LaunchError::GridSizeExceedsLimit {
218                requested: self.grid.z,
219                max: max_gz as u32,
220            }));
221        }
222
223        // Shared memory limit
224        let max_smem = device.max_shared_memory_per_block()? as u32;
225        if self.shared_mem_bytes > max_smem {
226            return Err(Box::new(LaunchError::SharedMemoryExceedsLimit {
227                requested: self.shared_mem_bytes,
228                max: max_smem,
229            }));
230        }
231
232        Ok(())
233    }
234}
235
/// Builder for [`LaunchParams`].
///
/// Offers a fluent, by-value interface: each setter consumes the builder
/// and returns it, and [`build`](Self::build) produces the final
/// [`LaunchParams`]. A grid or block size that was never set falls back to
/// `Dim3::x(1)` at build time; shared memory falls back to 0 bytes.
///
/// Obtain a fresh builder with [`LaunchParams::builder`] or
/// `LaunchParamsBuilder::default()`.
///
/// # Examples
///
/// ```
/// use oxicuda_launch::{LaunchParams, Dim3};
///
/// let params = LaunchParams::builder()
///     .grid((4u32, 4u32))
///     .block(256u32)
///     .shared_mem(1024)
///     .build();
///
/// assert_eq!(params.grid, Dim3::xy(4, 4));
/// assert_eq!(params.block, Dim3::x(256));
/// assert_eq!(params.shared_mem_bytes, 1024);
/// ```
#[derive(Debug, Default)]
pub struct LaunchParamsBuilder {
    /// Grid dimensions; `None` until [`grid`](Self::grid) has been called.
    grid: Option<Dim3>,
    /// Block dimensions; `None` until [`block`](Self::block) has been called.
    block: Option<Dim3>,
    /// Dynamic shared memory in bytes (0 unless [`shared_mem`](Self::shared_mem) is called).
    shared_mem_bytes: u32,
}
265
266impl LaunchParamsBuilder {
267    /// Sets the grid dimensions (number of thread blocks).
268    ///
269    /// Accepts anything that converts to [`Dim3`].
270    #[inline]
271    pub fn grid(mut self, dim: impl Into<Dim3>) -> Self {
272        self.grid = Some(dim.into());
273        self
274    }
275
276    /// Sets the block dimensions (threads per block).
277    ///
278    /// Accepts anything that converts to [`Dim3`].
279    #[inline]
280    pub fn block(mut self, dim: impl Into<Dim3>) -> Self {
281        self.block = Some(dim.into());
282        self
283    }
284
285    /// Sets the dynamic shared memory allocation in bytes.
286    #[inline]
287    pub fn shared_mem(mut self, bytes: u32) -> Self {
288        self.shared_mem_bytes = bytes;
289        self
290    }
291
292    /// Builds the [`LaunchParams`].
293    ///
294    /// If grid or block dimensions were not set, they default to
295    /// `Dim3::x(1)` (a single block or a single thread).
296    #[inline]
297    pub fn build(self) -> LaunchParams {
298        LaunchParams {
299            grid: self.grid.unwrap_or(Dim3::x(1)),
300            block: self.block.unwrap_or(Dim3::x(1)),
301            shared_mem_bytes: self.shared_mem_bytes,
302        }
303    }
304}
305
306// ---------------------------------------------------------------------------
307// Tests
308// ---------------------------------------------------------------------------
309
#[cfg(test)]
mod tests {
    use super::*;

    /// The function-pointer shape `LaunchParams::validate` must coerce to.
    type ValidateFn = fn(&LaunchParams, &Device) -> Result<(), Box<dyn std::error::Error>>;

    // -- LaunchParams::new / with_shared_mem -------------------------------

    #[test]
    fn launch_params_new_basic() {
        let params = LaunchParams::new(4u32, 256u32);
        assert_eq!(params.grid, Dim3::x(4));
        assert_eq!(params.block, Dim3::x(256));
        assert_eq!(params.shared_mem_bytes, 0);
    }

    #[test]
    fn launch_params_new_with_dim3() {
        let grid = Dim3::xy(4, 4);
        let block = Dim3::xy(16, 16);
        let params = LaunchParams::new(grid, block);
        assert_eq!(params.grid.total(), 16);
        assert_eq!(params.block.total(), 256);
    }

    #[test]
    fn launch_params_new_with_tuples() {
        let params = LaunchParams::new((4u32, 4u32), (16u32, 16u32));
        assert_eq!(params.grid, Dim3::xy(4, 4));
        assert_eq!(params.block, Dim3::xy(16, 16));
    }

    #[test]
    fn launch_params_with_shared_mem() {
        let params = LaunchParams::new(1u32, 256u32).with_shared_mem(8192);
        assert_eq!(params.shared_mem_bytes, 8192);
    }

    // -- LaunchParams::total_threads ---------------------------------------

    #[test]
    fn launch_params_total_threads() {
        let one_dimensional = LaunchParams::new(4u32, 256u32);
        assert_eq!(one_dimensional.total_threads(), 1024);

        let two_dimensional = LaunchParams::new(Dim3::xy(4, 4), Dim3::xy(16, 16));
        assert_eq!(two_dimensional.total_threads(), 16 * 256);
    }

    #[test]
    fn launch_params_total_threads_large() {
        // A 65535 x 65535 grid of 1024-thread blocks must not overflow.
        let params = LaunchParams::new(Dim3::xy(65535, 65535), Dim3::x(1024));
        let expected = 65535u64 * 65535u64 * 1024u64;
        assert_eq!(params.total_threads(), expected);
    }

    // -- LaunchParamsBuilder -----------------------------------------------

    #[test]
    fn builder_defaults() {
        let built = LaunchParams::builder().build();
        assert_eq!(built.grid, Dim3::x(1));
        assert_eq!(built.block, Dim3::x(1));
        assert_eq!(built.shared_mem_bytes, 0);
    }

    #[test]
    fn builder_full() {
        let builder = LaunchParams::builder()
            .grid(128u32)
            .block(256u32)
            .shared_mem(4096);
        let built = builder.build();
        assert_eq!(built.grid, Dim3::x(128));
        assert_eq!(built.block, Dim3::x(256));
        assert_eq!(built.shared_mem_bytes, 4096);
    }

    #[test]
    fn builder_partial_grid_only() {
        let built = LaunchParams::builder().grid(64u32).build();
        assert_eq!(built.grid, Dim3::x(64));
        assert_eq!(built.block, Dim3::x(1));
    }

    #[test]
    fn builder_partial_block_only() {
        let built = LaunchParams::builder().block(512u32).build();
        assert_eq!(built.grid, Dim3::x(1));
        assert_eq!(built.block, Dim3::x(512));
    }

    #[test]
    fn builder_with_tuple_dims() {
        let built = LaunchParams::builder()
            .grid((8u32, 8u32))
            .block((16u32, 16u32, 1u32))
            .build();
        assert_eq!(built.grid, Dim3::xy(8, 8));
        assert_eq!(built.block, Dim3::new(16, 16, 1));
    }

    // -- LaunchParams::validate --------------------------------------------
    //
    // `validate` needs a `Device`, so the tests below only construct the
    // offending parameters and check that the method has the expected
    // signature; they do not call it.

    #[test]
    fn validate_zero_block_x() {
        let params = LaunchParams {
            grid: Dim3::x(1),
            block: Dim3::new(0, 1, 1),
            shared_mem_bytes: 0,
        };
        // Compile-time check only: the method exists with this signature.
        let _validate_fn: ValidateFn = LaunchParams::validate;
        assert_eq!(params.block.x, 0);
    }

    #[test]
    fn validate_zero_grid_z() {
        let params = LaunchParams {
            grid: Dim3::new(1, 1, 0),
            block: Dim3::x(256),
            shared_mem_bytes: 0,
        };
        assert_eq!(params.grid.z, 0);
    }

    #[test]
    fn validate_signature_compiles() {
        // Coercion to the alias fails to compile if the signature drifts.
        let _: ValidateFn = LaunchParams::validate;
    }

    #[cfg(feature = "gpu-tests")]
    #[test]
    fn validate_with_real_device() {
        oxicuda_driver::init().ok();
        // Without a usable device there is nothing to check.
        let dev = match Device::get(0) {
            Ok(dev) => dev,
            Err(_) => return,
        };

        let within_limits = LaunchParams::new(4u32, 256u32);
        assert!(within_limits.validate(&dev).is_ok());

        // Too many threads per block.
        let oversized_block = LaunchParams::new(1u32, Dim3::new(1024, 1024, 1));
        assert!(oversized_block.validate(&dev).is_err());
    }
}