Skip to main content

oxicuda_launch/
grid.rs

1//! Grid and block dimension types for kernel launch configuration.
2//!
3//! CUDA kernels are launched with a grid of thread blocks.
4//! Each block contains threads organized in up to 3 dimensions.
5//!
6//! # Dimension model
7//!
8//! The CUDA execution model uses a two-level hierarchy:
9//!
10//! - **Grid**: A collection of thread blocks, specified as up to 3D dimensions.
11//! - **Block**: A collection of threads within a block, also up to 3D.
12//!
13//! Both are described by [`Dim3`], which defaults unused dimensions to 1.
14//!
15//! # Helper function
16//!
17//! The [`grid_size_for`] function computes the minimum grid size needed
18//! to cover a given number of elements with a given block size (ceiling
19//! division).
20
21use std::fmt;
22
23use oxicuda_driver::error::CudaResult;
24use oxicuda_driver::module::Function;
25
/// Three-component extent used for both grid and block sizes.
///
/// A kernel launch needs two of these: the number of thread blocks in the
/// grid and the number of threads in each block. The convenience
/// constructors [`Dim3::x`] and [`Dim3::xy`] set every dimension they do not
/// take to 1, and the `From` conversions for `u32`, `(u32, u32)` and
/// `(u32, u32, u32)` behave the same way.
///
/// # Examples
///
/// ```
/// use oxicuda_launch::Dim3;
///
/// // One-dimensional: 256 threads along X.
/// let block = Dim3::x(256);
/// assert_eq!((block.x, block.y, block.z), (256, 1, 1));
///
/// // Two-dimensional: a 16x16 tile.
/// let block = Dim3::xy(16, 16);
/// assert_eq!(block.total(), 256);
///
/// // Three-dimensional.
/// let block = Dim3::new(8, 8, 4);
/// assert_eq!(block.total(), 256);
///
/// // Conversions from plain integers and tuples.
/// let block: Dim3 = 256u32.into();
/// assert_eq!(block, Dim3::x(256));
///
/// let block: Dim3 = (16u32, 16u32).into();
/// assert_eq!(block, Dim3::xy(16, 16));
///
/// let block: Dim3 = (8u32, 8u32, 4u32).into();
/// assert_eq!(block, Dim3::new(8, 8, 4));
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Dim3 {
    /// Extent along X.
    pub x: u32,
    /// Extent along Y; 1 when built with [`Dim3::x`] or from a single `u32`.
    pub y: u32,
    /// Extent along Z; 1 unless given explicitly via [`Dim3::new`] or a 3-tuple.
    pub z: u32,
}
70
71impl Dim3 {
72    /// Creates a new `Dim3` with explicit values for all three dimensions.
73    #[inline]
74    pub fn new(x: u32, y: u32, z: u32) -> Self {
75        Self { x, y, z }
76    }
77
78    /// Creates a 1-dimensional `Dim3` with the given X value.
79    ///
80    /// Y and Z are set to 1.
81    #[inline]
82    pub fn x(x: u32) -> Self {
83        Self::new(x, 1, 1)
84    }
85
86    /// Creates a 2-dimensional `Dim3` with the given X and Y values.
87    ///
88    /// Z is set to 1.
89    #[inline]
90    pub fn xy(x: u32, y: u32) -> Self {
91        Self::new(x, y, 1)
92    }
93
94    /// Total number of elements (`x * y * z`).
95    ///
96    /// For a grid dimension, this is the total number of thread blocks.
97    /// For a block dimension, this is the total number of threads per block.
98    #[inline]
99    pub fn total(&self) -> u32 {
100        self.x * self.y * self.z
101    }
102}
103
104impl From<u32> for Dim3 {
105    /// Converts a single `u32` into a 1D `Dim3`.
106    #[inline]
107    fn from(x: u32) -> Self {
108        Self::x(x)
109    }
110}
111
112impl From<(u32, u32)> for Dim3 {
113    /// Converts a `(u32, u32)` tuple into a 2D `Dim3`.
114    #[inline]
115    fn from((x, y): (u32, u32)) -> Self {
116        Self::xy(x, y)
117    }
118}
119
120impl From<(u32, u32, u32)> for Dim3 {
121    /// Converts a `(u32, u32, u32)` tuple into a 3D `Dim3`.
122    #[inline]
123    fn from((x, y, z): (u32, u32, u32)) -> Self {
124        Self::new(x, y, z)
125    }
126}
127
128impl fmt::Display for Dim3 {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        if self.z != 1 {
131            write!(f, "({}, {}, {})", self.x, self.y, self.z)
132        } else if self.y != 1 {
133            write!(f, "({}, {})", self.x, self.y)
134        } else {
135            write!(f, "{}", self.x)
136        }
137    }
138}
139
140// ---------------------------------------------------------------------------
141// Occupancy-based auto grid sizing
142// ---------------------------------------------------------------------------
143
144/// Computes optimal grid and block dimensions for a 1D problem of `n` elements.
145///
146/// Queries the CUDA occupancy API to determine the block size that
147/// maximises multiprocessor occupancy for the given kernel function,
148/// then calculates the grid size needed to cover `n` work items.
149///
150/// Returns `(grid_dim, block_dim)` suitable for use with [`LaunchParams`](crate::LaunchParams).
151///
152/// # Errors
153///
154/// Returns a [`CudaError`](oxicuda_driver::CudaError) if the occupancy
155/// query fails (e.g., invalid function handle, driver not loaded).
156///
157/// # Examples
158///
159/// ```rust,no_run
160/// use oxicuda_launch::grid::auto_grid_for;
161/// # use oxicuda_driver::module::Module;
162///
163/// # let module: Module = todo!();
164/// let func = module.get_function("my_kernel")?;
165/// let (grid, block) = auto_grid_for(&func, 100_000)?;
166/// # Ok::<(), oxicuda_driver::error::CudaError>(())
167/// ```
168pub fn auto_grid_for(func: &Function, n: usize) -> CudaResult<(Dim3, Dim3)> {
169    let (_min_grid, optimal_block) = func.optimal_block_size(0)?;
170    let block_size = optimal_block as u32;
171    let grid_x = if n == 0 {
172        0
173    } else {
174        (n as u32).div_ceil(block_size)
175    };
176    Ok((Dim3::x(grid_x), Dim3::x(block_size)))
177}
178
179/// Computes optimal grid and block dimensions for a 2D problem.
180///
181/// Given a kernel function and problem dimensions `(width, height)`,
182/// this function determines a 2D block size and the corresponding
183/// grid dimensions.  The block is sized as a square (or near-square)
184/// tile whose total thread count respects the occupancy-optimal value.
185///
186/// Returns `(grid_dim, block_dim)` as 2D [`Dim3`] values.
187///
188/// # Errors
189///
190/// Returns a [`CudaError`](oxicuda_driver::CudaError) if the occupancy
191/// query fails.
192///
193/// # Examples
194///
195/// ```rust,no_run
196/// use oxicuda_launch::grid::auto_grid_2d;
197/// # use oxicuda_driver::module::Module;
198///
199/// # let module: Module = todo!();
200/// let func = module.get_function("my_kernel_2d")?;
201/// let (grid, block) = auto_grid_2d(&func, 1920, 1080)?;
202/// # Ok::<(), oxicuda_driver::error::CudaError>(())
203/// ```
204pub fn auto_grid_2d(func: &Function, width: usize, height: usize) -> CudaResult<(Dim3, Dim3)> {
205    let (_min_grid, optimal_block) = func.optimal_block_size(0)?;
206    let total = optimal_block as u32;
207
208    // Find a near-square block tile. Start from sqrt and round down to
209    // powers-of-two-friendly values.
210    let sqrt_approx = (total as f64).sqrt() as u32;
211    let block_x = nearest_power_of_two_le(sqrt_approx).max(1);
212    let block_y = (total / block_x).max(1);
213
214    let grid_x = if width == 0 {
215        0
216    } else {
217        (width as u32).div_ceil(block_x)
218    };
219    let grid_y = if height == 0 {
220        0
221    } else {
222        (height as u32).div_ceil(block_y)
223    };
224
225    Ok((Dim3::xy(grid_x, grid_y), Dim3::xy(block_x, block_y)))
226}
227
/// Returns the largest power of two that is less than or equal to `n`.
///
/// There is no power of two `<= 0`, so an input of 0 maps to 1.
fn nearest_power_of_two_le(n: u32) -> u32 {
    // `checked_ilog2` is the index of the highest set bit, or `None` for 0.
    match n.checked_ilog2() {
        Some(bit) => 1u32 << bit,
        None => 1,
    }
}
238
239// ---------------------------------------------------------------------------
240// Simple grid sizing helper
241// ---------------------------------------------------------------------------
242
/// Calculate the grid size needed to cover `n` elements with `block_size` threads.
///
/// Returns the ceiling of `n / block_size`: the smallest number of thread
/// blocks of `block_size` threads that together cover `n` work items. The
/// division is performed with [`u32::div_ceil`]. Unlike the textbook formula
/// `(n + block_size - 1) / block_size`, this cannot overflow when `n` is
/// close to `u32::MAX`.
///
/// Being a `const fn`, it can also size grids in constant contexts.
///
/// # Panics
///
/// Panics if `block_size` is zero.
///
/// # Examples
///
/// ```
/// use oxicuda_launch::grid_size_for;
///
/// assert_eq!(grid_size_for(1000, 256), 4);  // 4 * 256 = 1024 >= 1000
/// assert_eq!(grid_size_for(256, 256), 1);
/// assert_eq!(grid_size_for(257, 256), 2);
/// assert_eq!(grid_size_for(0, 256), 0);
/// assert_eq!(grid_size_for(1, 256), 1);
///
/// // Usable in constant contexts.
/// const BLOCKS: u32 = grid_size_for(1 << 20, 256);
/// assert_eq!(BLOCKS, 4096);
/// ```
#[inline]
#[must_use]
pub const fn grid_size_for(n: u32, block_size: u32) -> u32 {
    n.div_ceil(block_size)
}
269
270// ---------------------------------------------------------------------------
271// Tests
272// ---------------------------------------------------------------------------
273
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dim3_new() {
        let Dim3 { x, y, z } = Dim3::new(4, 5, 6);
        assert_eq!(x, 4);
        assert_eq!(y, 5);
        assert_eq!(z, 6);
    }

    #[test]
    fn dim3_x() {
        assert_eq!(Dim3::x(128), Dim3::new(128, 1, 1));
    }

    #[test]
    fn dim3_xy() {
        assert_eq!(Dim3::xy(16, 16), Dim3::new(16, 16, 1));
    }

    #[test]
    fn dim3_total() {
        let cases = [
            (Dim3::x(256), 256),
            (Dim3::xy(16, 16), 256),
            (Dim3::new(8, 8, 4), 256),
            (Dim3::new(1, 1, 1), 1),
        ];
        for (dim, expected) in cases {
            assert_eq!(dim.total(), expected, "total() of {dim}");
        }
    }

    #[test]
    fn dim3_from_u32() {
        let converted: Dim3 = 512u32.into();
        assert_eq!(converted, Dim3::x(512));
    }

    #[test]
    fn dim3_from_tuple2() {
        assert_eq!(Dim3::from((32u32, 8u32)), Dim3::xy(32, 8));
    }

    #[test]
    fn dim3_from_tuple3() {
        assert_eq!(Dim3::from((4u32, 4u32, 4u32)), Dim3::new(4, 4, 4));
    }

    #[test]
    fn dim3_display_1d() {
        assert_eq!(Dim3::x(256).to_string(), "256");
    }

    #[test]
    fn dim3_display_2d() {
        assert_eq!(Dim3::xy(16, 16).to_string(), "(16, 16)");
    }

    #[test]
    fn dim3_display_3d() {
        assert_eq!(Dim3::new(8, 8, 4).to_string(), "(8, 8, 4)");
    }

    #[test]
    fn dim3_eq_and_hash() {
        let set: std::collections::HashSet<Dim3> = std::iter::once(Dim3::x(256)).collect();
        assert!(set.contains(&Dim3::new(256, 1, 1)));
        assert!(!set.contains(&Dim3::x(128)));
    }

    #[test]
    fn grid_size_for_exact() {
        for (n, blocks) in [(256, 1), (512, 2)] {
            assert_eq!(grid_size_for(n, 256), blocks);
        }
    }

    #[test]
    fn grid_size_for_remainder() {
        for (n, blocks) in [(257, 2), (1000, 4), (1, 1)] {
            assert_eq!(grid_size_for(n, 256), blocks);
        }
    }

    #[test]
    fn grid_size_for_zero_elements() {
        assert_eq!(grid_size_for(0, 256), 0);
    }

    #[test]
    fn grid_size_for_one_block() {
        // A block exactly as large as the input always needs a single block.
        for n in [1, 100] {
            assert_eq!(grid_size_for(n, n), 1);
        }
    }

    #[test]
    fn nearest_power_of_two_le_values() {
        let expectations = [
            (0, 1),
            (1, 1),
            (2, 2),
            (3, 2),
            (4, 4),
            (5, 4),
            (16, 16),
            (17, 16),
            (255, 128),
            (256, 256),
        ];
        for (input, expected) in expectations {
            assert_eq!(
                super::nearest_power_of_two_le(input),
                expected,
                "input = {input}"
            );
        }
    }

    #[test]
    fn auto_grid_for_signature_compiles() {
        type AutoGrid1d = fn(
            &oxicuda_driver::module::Function,
            usize,
        ) -> oxicuda_driver::error::CudaResult<(Dim3, Dim3)>;
        let _f: AutoGrid1d = super::auto_grid_for;
    }

    #[test]
    fn auto_grid_2d_signature_compiles() {
        type AutoGrid2d = fn(
            &oxicuda_driver::module::Function,
            usize,
            usize,
        ) -> oxicuda_driver::error::CudaResult<(Dim3, Dim3)>;
        let _f: AutoGrid2d = super::auto_grid_2d;
    }

    #[cfg(feature = "gpu-tests")]
    #[test]
    fn auto_grid_for_with_real_kernel() {
        use std::sync::Arc;

        const PTX: &str = ".version 7.0\n.target sm_70\n.address_size 64\n.visible .entry test_kernel(.param .u32 n) { ret; }";

        oxicuda_driver::init().ok();
        // Without a device or a loadable module there is nothing to check.
        let Ok(dev) = oxicuda_driver::device::Device::get(0) else {
            return;
        };
        let _ctx = Arc::new(oxicuda_driver::context::Context::new(&dev).expect("ctx"));
        let Ok(module) = oxicuda_driver::module::Module::from_ptx(PTX) else {
            return;
        };
        let func = module.get_function("test_kernel").expect("func");
        let (grid, block) = super::auto_grid_for(&func, 10000).expect("auto_grid");
        assert!(grid.x > 0);
        assert!(block.x > 0);
    }
}