oxicuda_launch/grid.rs
1//! Grid and block dimension types for kernel launch configuration.
2//!
3//! CUDA kernels are launched with a grid of thread blocks.
4//! Each block contains threads organized in up to 3 dimensions.
5//!
6//! # Dimension model
7//!
8//! The CUDA execution model uses a two-level hierarchy:
9//!
10//! - **Grid**: A collection of thread blocks, specified as up to 3D dimensions.
//! - **Block**: A collection of threads, also specified as up to 3D dimensions.
12//!
13//! Both are described by [`Dim3`], which defaults unused dimensions to 1.
14//!
15//! # Helper function
16//!
17//! The [`grid_size_for`] function computes the minimum grid size needed
18//! to cover a given number of elements with a given block size (ceiling
19//! division).
20
21use std::fmt;
22
23use oxicuda_driver::error::CudaResult;
24use oxicuda_driver::module::Function;
25
/// 3-dimensional size specification for grids and blocks.
///
/// A `Dim3` describes either the number of thread blocks in a grid or the
/// number of threads in a block. The convenience constructors fill every
/// dimension that is not supplied with 1.
///
/// # Examples
///
/// ```
/// use oxicuda_launch::Dim3;
///
/// // 1D: 256 threads
/// let block = Dim3::x(256);
/// assert_eq!(block.x, 256);
/// assert_eq!(block.y, 1);
/// assert_eq!(block.z, 1);
///
/// // 2D: 16x16 threads
/// let block = Dim3::xy(16, 16);
/// assert_eq!(block.total(), 256);
///
/// // 3D
/// let block = Dim3::new(8, 8, 4);
/// assert_eq!(block.total(), 256);
///
/// // From conversions
/// let block: Dim3 = 256u32.into();
/// assert_eq!(block, Dim3::x(256));
///
/// let block: Dim3 = (16u32, 16u32).into();
/// assert_eq!(block, Dim3::xy(16, 16));
///
/// let block: Dim3 = (8u32, 8u32, 4u32).into();
/// assert_eq!(block, Dim3::new(8, 8, 4));
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Dim3 {
    /// X dimension.
    pub x: u32,
    /// Y dimension (default 1).
    pub y: u32,
    /// Z dimension (default 1).
    pub z: u32,
}

impl Dim3 {
    /// Creates a new `Dim3` with explicit values for all three dimensions.
    #[inline]
    pub const fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Creates a 1-dimensional `Dim3` with the given X value.
    ///
    /// Y and Z are set to 1.
    #[inline]
    pub const fn x(x: u32) -> Self {
        Self::new(x, 1, 1)
    }

    /// Creates a 2-dimensional `Dim3` with the given X and Y values.
    ///
    /// Z is set to 1.
    #[inline]
    pub const fn xy(x: u32, y: u32) -> Self {
        Self::new(x, y, 1)
    }

    /// Total number of elements (`x * y * z`), saturating at `u32::MAX`.
    ///
    /// For a grid dimension, this is the total number of thread blocks.
    /// For a block dimension, this is the total number of threads per block.
    ///
    /// The product of three `u32` values can exceed `u32::MAX`. Rather than
    /// panicking (debug builds) or wrapping to a small, misleading value
    /// (release builds), the result is clamped to `u32::MAX`, so limit checks
    /// such as `dim.total() <= max` stay conservative.
    #[inline]
    pub const fn total(&self) -> u32 {
        // A zero in any dimension still yields 0: saturation only caps the
        // intermediate product, and anything times 0 is 0.
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }
}
103
104impl From<u32> for Dim3 {
105 /// Converts a single `u32` into a 1D `Dim3`.
106 #[inline]
107 fn from(x: u32) -> Self {
108 Self::x(x)
109 }
110}
111
112impl From<(u32, u32)> for Dim3 {
113 /// Converts a `(u32, u32)` tuple into a 2D `Dim3`.
114 #[inline]
115 fn from((x, y): (u32, u32)) -> Self {
116 Self::xy(x, y)
117 }
118}
119
120impl From<(u32, u32, u32)> for Dim3 {
121 /// Converts a `(u32, u32, u32)` tuple into a 3D `Dim3`.
122 #[inline]
123 fn from((x, y, z): (u32, u32, u32)) -> Self {
124 Self::new(x, y, z)
125 }
126}
127
128impl fmt::Display for Dim3 {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 if self.z != 1 {
131 write!(f, "({}, {}, {})", self.x, self.y, self.z)
132 } else if self.y != 1 {
133 write!(f, "({}, {})", self.x, self.y)
134 } else {
135 write!(f, "{}", self.x)
136 }
137 }
138}
139
140// ---------------------------------------------------------------------------
141// Occupancy-based auto grid sizing
142// ---------------------------------------------------------------------------
143
144/// Computes optimal grid and block dimensions for a 1D problem of `n` elements.
145///
146/// Queries the CUDA occupancy API to determine the block size that
147/// maximises multiprocessor occupancy for the given kernel function,
148/// then calculates the grid size needed to cover `n` work items.
149///
150/// Returns `(grid_dim, block_dim)` suitable for use with [`LaunchParams`](crate::LaunchParams).
151///
152/// # Errors
153///
154/// Returns a [`CudaError`](oxicuda_driver::CudaError) if the occupancy
155/// query fails (e.g., invalid function handle, driver not loaded).
156///
157/// # Examples
158///
159/// ```rust,no_run
160/// use oxicuda_launch::grid::auto_grid_for;
161/// # use oxicuda_driver::module::Module;
162///
163/// # let module: Module = todo!();
164/// let func = module.get_function("my_kernel")?;
165/// let (grid, block) = auto_grid_for(&func, 100_000)?;
166/// # Ok::<(), oxicuda_driver::error::CudaError>(())
167/// ```
168pub fn auto_grid_for(func: &Function, n: usize) -> CudaResult<(Dim3, Dim3)> {
169 let (_min_grid, optimal_block) = func.optimal_block_size(0)?;
170 let block_size = optimal_block as u32;
171 let grid_x = if n == 0 {
172 0
173 } else {
174 (n as u32).div_ceil(block_size)
175 };
176 Ok((Dim3::x(grid_x), Dim3::x(block_size)))
177}
178
179/// Computes optimal grid and block dimensions for a 2D problem.
180///
181/// Given a kernel function and problem dimensions `(width, height)`,
182/// this function determines a 2D block size and the corresponding
183/// grid dimensions. The block is sized as a square (or near-square)
184/// tile whose total thread count respects the occupancy-optimal value.
185///
186/// Returns `(grid_dim, block_dim)` as 2D [`Dim3`] values.
187///
188/// # Errors
189///
190/// Returns a [`CudaError`](oxicuda_driver::CudaError) if the occupancy
191/// query fails.
192///
193/// # Examples
194///
195/// ```rust,no_run
196/// use oxicuda_launch::grid::auto_grid_2d;
197/// # use oxicuda_driver::module::Module;
198///
199/// # let module: Module = todo!();
200/// let func = module.get_function("my_kernel_2d")?;
201/// let (grid, block) = auto_grid_2d(&func, 1920, 1080)?;
202/// # Ok::<(), oxicuda_driver::error::CudaError>(())
203/// ```
204pub fn auto_grid_2d(func: &Function, width: usize, height: usize) -> CudaResult<(Dim3, Dim3)> {
205 let (_min_grid, optimal_block) = func.optimal_block_size(0)?;
206 let total = optimal_block as u32;
207
208 // Find a near-square block tile. Start from sqrt and round down to
209 // powers-of-two-friendly values.
210 let sqrt_approx = (total as f64).sqrt() as u32;
211 let block_x = nearest_power_of_two_le(sqrt_approx).max(1);
212 let block_y = (total / block_x).max(1);
213
214 let grid_x = if width == 0 {
215 0
216 } else {
217 (width as u32).div_ceil(block_x)
218 };
219 let grid_y = if height == 0 {
220 0
221 } else {
222 (height as u32).div_ceil(block_y)
223 };
224
225 Ok((Dim3::xy(grid_x, grid_y), Dim3::xy(block_x, block_y)))
226}
227
/// Rounds `n` down to a power of two.
///
/// For any non-zero `n` the result is the largest power of two that does
/// not exceed `n`. As a special case, an input of 0 yields 1.
fn nearest_power_of_two_le(n: u32) -> u32 {
    match n {
        0 => 1,
        // `ilog2` is the index of the highest set bit, i.e. floor(log2(n)).
        _ => 1 << n.ilog2(),
    }
}
238
239// ---------------------------------------------------------------------------
240// Simple grid sizing helper
241// ---------------------------------------------------------------------------
242
/// Number of thread blocks required to cover `n` elements when every block
/// processes `block_size` of them.
///
/// This is the ceiling of `n / block_size`: all completely filled blocks,
/// plus one extra block if some elements are left over. An `n` of 0 needs
/// no blocks at all.
///
/// # Panics
///
/// Panics if `block_size` is zero.
///
/// # Examples
///
/// ```
/// use oxicuda_launch::grid_size_for;
///
/// assert_eq!(grid_size_for(1000, 256), 4); // 4 * 256 = 1024 >= 1000
/// assert_eq!(grid_size_for(256, 256), 1);
/// assert_eq!(grid_size_for(257, 256), 2);
/// assert_eq!(grid_size_for(0, 256), 0);
/// assert_eq!(grid_size_for(1, 256), 1);
/// ```
#[inline]
pub fn grid_size_for(n: u32, block_size: u32) -> u32 {
    let full_blocks = n / block_size;
    let has_partial_block = n % block_size != 0;
    full_blocks + u32::from(has_partial_block)
}
269
270// ---------------------------------------------------------------------------
271// Tests
272// ---------------------------------------------------------------------------
273
#[cfg(test)]
mod tests {
    use super::*;

    // --- Dim3 construction -------------------------------------------------

    #[test]
    fn dim3_new() {
        let Dim3 { x, y, z } = Dim3::new(4, 5, 6);
        assert_eq!((x, y, z), (4, 5, 6));
    }

    #[test]
    fn dim3_x() {
        assert_eq!(Dim3::x(128), Dim3::new(128, 1, 1));
    }

    #[test]
    fn dim3_xy() {
        assert_eq!(Dim3::xy(16, 16), Dim3::new(16, 16, 1));
    }

    #[test]
    fn dim3_total() {
        let cases = [
            (Dim3::x(256), 256),
            (Dim3::xy(16, 16), 256),
            (Dim3::new(8, 8, 4), 256),
            (Dim3::new(1, 1, 1), 1),
        ];
        for (dim, expected) in cases.iter().copied() {
            assert_eq!(dim.total(), expected);
        }
    }

    // --- Dim3 conversions --------------------------------------------------

    #[test]
    fn dim3_from_u32() {
        assert_eq!(Dim3::from(512u32), Dim3::x(512));
    }

    #[test]
    fn dim3_from_tuple2() {
        assert_eq!(Dim3::from((32u32, 8u32)), Dim3::xy(32, 8));
    }

    #[test]
    fn dim3_from_tuple3() {
        assert_eq!(Dim3::from((4u32, 4u32, 4u32)), Dim3::new(4, 4, 4));
    }

    // --- Dim3 formatting ---------------------------------------------------

    #[test]
    fn dim3_display_1d() {
        assert_eq!(Dim3::x(256).to_string(), "256");
    }

    #[test]
    fn dim3_display_2d() {
        assert_eq!(Dim3::xy(16, 16).to_string(), "(16, 16)");
    }

    #[test]
    fn dim3_display_3d() {
        assert_eq!(Dim3::new(8, 8, 4).to_string(), "(8, 8, 4)");
    }

    #[test]
    fn dim3_eq_and_hash() {
        use std::collections::HashSet;
        let seen: HashSet<Dim3> = std::iter::once(Dim3::x(256)).collect();
        assert!(seen.contains(&Dim3::new(256, 1, 1)));
        assert!(!seen.contains(&Dim3::x(128)));
    }

    // --- grid_size_for -----------------------------------------------------

    #[test]
    fn grid_size_for_exact() {
        for (n, blocks) in [(256, 1), (512, 2)].iter().copied() {
            assert_eq!(grid_size_for(n, 256), blocks);
        }
    }

    #[test]
    fn grid_size_for_remainder() {
        for (n, blocks) in [(257, 2), (1000, 4), (1, 1)].iter().copied() {
            assert_eq!(grid_size_for(n, 256), blocks);
        }
    }

    #[test]
    fn grid_size_for_zero_elements() {
        assert_eq!(grid_size_for(0, 256), 0);
    }

    #[test]
    fn grid_size_for_one_block() {
        for size in [1, 100].iter().copied() {
            assert_eq!(grid_size_for(size, size), 1);
        }
    }

    // --- nearest_power_of_two_le -------------------------------------------

    #[test]
    fn nearest_power_of_two_le_values() {
        let cases = [
            (0, 1),
            (1, 1),
            (2, 2),
            (3, 2),
            (4, 4),
            (5, 4),
            (16, 16),
            (17, 16),
            (255, 128),
            (256, 256),
        ];
        for (input, expected) in cases.iter().copied() {
            assert_eq!(nearest_power_of_two_le(input), expected);
        }
    }

    // --- auto grid helpers -------------------------------------------------

    // The two tests below only pin the public signatures; they succeed by
    // compiling and never call into the driver.

    #[test]
    fn auto_grid_for_signature_compiles() {
        type AutoGrid1d = fn(
            &oxicuda_driver::module::Function,
            usize,
        ) -> oxicuda_driver::error::CudaResult<(Dim3, Dim3)>;
        let _check: AutoGrid1d = auto_grid_for;
    }

    #[test]
    fn auto_grid_2d_signature_compiles() {
        type AutoGrid2d = fn(
            &oxicuda_driver::module::Function,
            usize,
            usize,
        ) -> oxicuda_driver::error::CudaResult<(Dim3, Dim3)>;
        let _check: AutoGrid2d = auto_grid_2d;
    }

    // Runs only with the `gpu-tests` feature; silently does nothing when no
    // device is available or the PTX module cannot be loaded.
    #[cfg(feature = "gpu-tests")]
    #[test]
    fn auto_grid_for_with_real_kernel() {
        use std::sync::Arc;

        oxicuda_driver::init().ok();
        let Ok(dev) = oxicuda_driver::device::Device::get(0) else {
            return;
        };
        let _ctx = Arc::new(oxicuda_driver::context::Context::new(&dev).expect("ctx"));
        let ptx = ".version 7.0\n.target sm_70\n.address_size 64\n.visible .entry test_kernel(.param .u32 n) { ret; }";
        let Ok(module) = oxicuda_driver::module::Module::from_ptx(ptx) else {
            return;
        };
        let func = module.get_function("test_kernel").expect("func");
        let (grid, block) = auto_grid_for(&func, 10000).expect("auto_grid");
        assert!(grid.x > 0);
        assert!(block.x > 0);
    }
}