// oxicuda_launch/cluster.rs
1//! Thread block cluster configuration for Hopper+ GPUs (SM 9.0+).
2//!
3//! Thread block clusters are a new level of the CUDA execution hierarchy
4//! introduced with the NVIDIA Hopper architecture (compute capability 9.0).
5//! A cluster groups multiple thread blocks that can cooperate more
6//! efficiently via distributed shared memory and hardware-accelerated
7//! synchronisation.
8//!
9//! # Requirements
10//!
11//! - NVIDIA Hopper (H100) or later GPU (compute capability 9.0+).
12//! - CUDA driver version 12.0 or later.
13//! - The kernel must be compiled with cluster support.
14//!
15//! # Example
16//!
17//! ```rust,no_run
18//! # use oxicuda_launch::cluster::{ClusterDim, ClusterLaunchParams};
19//! # use oxicuda_launch::Dim3;
20//! let cluster_params = ClusterLaunchParams {
21//! grid: Dim3::x(16),
22//! block: Dim3::x(256),
23//! cluster: ClusterDim::new(2, 1, 1),
24//! shared_mem_bytes: 0,
25//! };
26//! assert_eq!(cluster_params.blocks_per_cluster(), 2);
27//! ```
28
29use oxicuda_driver::error::{CudaError, CudaResult};
30use oxicuda_driver::stream::Stream;
31
32use crate::grid::Dim3;
33use crate::kernel::{Kernel, KernelArgs};
34
35// ---------------------------------------------------------------------------
36// ClusterDim
37// ---------------------------------------------------------------------------
38
/// Cluster dimensions specifying how many thread blocks form one cluster.
///
/// Each dimension specifies the number of blocks in that direction of
/// the cluster. The total number of blocks in a cluster is
/// `x * y * z`.
///
/// Construction performs no validation; zero dimensions are rejected
/// later when the containing [`ClusterLaunchParams`] is validated.
///
/// # Constraints
///
/// - All dimensions must be non-zero.
/// - The total blocks per cluster must not exceed the hardware limit
///   (typically 8 or 16 for Hopper GPUs).
/// - The grid dimensions must be evenly divisible by the cluster
///   dimensions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClusterDim {
    /// Number of blocks in the X dimension of the cluster.
    pub x: u32,
    /// Number of blocks in the Y dimension of the cluster.
    pub y: u32,
    /// Number of blocks in the Z dimension of the cluster.
    pub z: u32,
}
61
62impl ClusterDim {
63 /// Creates a new cluster dimension.
64 #[inline]
65 pub fn new(x: u32, y: u32, z: u32) -> Self {
66 Self { x, y, z }
67 }
68
69 /// Creates a 1D cluster (only X dimension used).
70 #[inline]
71 pub fn x(x: u32) -> Self {
72 Self { x, y: 1, z: 1 }
73 }
74
75 /// Creates a 2D cluster.
76 #[inline]
77 pub fn xy(x: u32, y: u32) -> Self {
78 Self { x, y, z: 1 }
79 }
80
81 /// Returns the total number of blocks per cluster.
82 #[inline]
83 pub fn total(&self) -> u32 {
84 self.x * self.y * self.z
85 }
86
87 /// Validates that all dimensions are non-zero.
88 fn validate(&self) -> CudaResult<()> {
89 if self.x == 0 || self.y == 0 || self.z == 0 {
90 return Err(CudaError::InvalidValue);
91 }
92 Ok(())
93 }
94}
95
96impl std::fmt::Display for ClusterDim {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 write!(f, "ClusterDim({}x{}x{})", self.x, self.y, self.z)
99 }
100}
101
102// ---------------------------------------------------------------------------
103// ClusterLaunchParams
104// ---------------------------------------------------------------------------
105
/// Launch parameters including thread block cluster configuration.
///
/// Extends the standard grid/block configuration with a cluster
/// dimension. The grid dimensions must be evenly divisible by the
/// cluster dimensions; this invariant is checked by
/// [`ClusterLaunchParams::validate`], not at construction time.
#[derive(Debug, Clone, Copy)]
pub struct ClusterLaunchParams {
    /// Grid dimensions (number of thread blocks total, not clusters).
    pub grid: Dim3,
    /// Block dimensions (threads per block).
    pub block: Dim3,
    /// Cluster dimensions (blocks per cluster).
    pub cluster: ClusterDim,
    /// Dynamic shared memory per block in bytes.
    pub shared_mem_bytes: u32,
}
122
123impl ClusterLaunchParams {
124 /// Returns the total number of blocks per cluster.
125 #[inline]
126 pub fn blocks_per_cluster(&self) -> u32 {
127 self.cluster.total()
128 }
129
130 /// Returns the total number of clusters in the grid.
131 ///
132 /// This requires that the grid dimensions be evenly divisible by
133 /// the cluster dimensions.
134 ///
135 /// # Errors
136 ///
137 /// Returns [`CudaError::InvalidValue`] if the grid is not evenly
138 /// divisible by the cluster dimensions, or if any dimension is zero.
139 pub fn cluster_count(&self) -> CudaResult<u32> {
140 self.validate()?;
141 let cx = self.grid.x / self.cluster.x;
142 let cy = self.grid.y / self.cluster.y;
143 let cz = self.grid.z / self.cluster.z;
144 Ok(cx * cy * cz)
145 }
146
147 /// Validates the cluster launch parameters.
148 ///
149 /// Checks that:
150 /// - All grid, block, and cluster dimensions are non-zero.
151 /// - The grid dimensions are evenly divisible by the cluster dimensions.
152 ///
153 /// # Errors
154 ///
155 /// Returns [`CudaError::InvalidValue`] on any violation.
156 pub fn validate(&self) -> CudaResult<()> {
157 // Validate cluster dims
158 self.cluster.validate()?;
159
160 // Validate grid dims are non-zero
161 if self.grid.x == 0 || self.grid.y == 0 || self.grid.z == 0 {
162 return Err(CudaError::InvalidValue);
163 }
164 if self.block.x == 0 || self.block.y == 0 || self.block.z == 0 {
165 return Err(CudaError::InvalidValue);
166 }
167
168 // Grid must be divisible by cluster
169 if self.grid.x % self.cluster.x != 0
170 || self.grid.y % self.cluster.y != 0
171 || self.grid.z % self.cluster.z != 0
172 {
173 return Err(CudaError::InvalidValue);
174 }
175
176 Ok(())
177 }
178}
179
180// ---------------------------------------------------------------------------
181// cluster_launch
182// ---------------------------------------------------------------------------
183
184/// Launches a kernel with thread block cluster configuration.
185///
186/// On Hopper+ GPUs (compute capability 9.0+), this groups thread blocks
187/// into clusters for enhanced cooperation via distributed shared memory.
188///
189/// This function validates the cluster parameters and delegates to the
190/// standard kernel launch. On hardware that supports clusters natively,
191/// the CUDA driver would use `cuLaunchKernelEx` with cluster attributes.
192///
193/// # Parameters
194///
195/// * `kernel` — the kernel to launch.
196/// * `params` — cluster-aware launch parameters.
197/// * `stream` — the stream to launch on.
198/// * `args` — kernel arguments.
199///
200/// # Errors
201///
202/// Returns [`CudaError::InvalidValue`] if the parameters are invalid
203/// (zero dimensions, grid not divisible by cluster, etc.), or another
204/// error from the underlying kernel launch.
205pub fn cluster_launch<A: KernelArgs>(
206 kernel: &Kernel,
207 params: &ClusterLaunchParams,
208 stream: &Stream,
209 args: &A,
210) -> CudaResult<()> {
211 params.validate()?;
212
213 // Convert to standard launch params. The cluster dimension is
214 // a hint to the driver; the actual grid/block stays the same.
215 let launch_params = crate::params::LaunchParams {
216 grid: params.grid,
217 block: params.block,
218 shared_mem_bytes: params.shared_mem_bytes,
219 };
220
221 kernel.launch(&launch_params, stream, args)
222}
223
224// ---------------------------------------------------------------------------
225// Tests
226// ---------------------------------------------------------------------------
227
#[cfg(test)]
mod tests {
    // CPU-only unit tests: no GPU or CUDA driver is required — only the
    // parameter/validation logic is exercised here.
    use super::*;

    #[test]
    fn cluster_dim_new() {
        let c = ClusterDim::new(2, 2, 1);
        assert_eq!(c.x, 2);
        assert_eq!(c.y, 2);
        assert_eq!(c.z, 1);
        assert_eq!(c.total(), 4);
    }

    #[test]
    fn cluster_dim_x() {
        // The 1D helper must default Y and Z to 1.
        let c = ClusterDim::x(4);
        assert_eq!(c.total(), 4);
        assert_eq!(c.y, 1);
        assert_eq!(c.z, 1);
    }

    #[test]
    fn cluster_dim_xy() {
        // The 2D helper must default Z to 1, so total == x * y.
        let c = ClusterDim::xy(2, 4);
        assert_eq!(c.total(), 8);
    }

    #[test]
    fn cluster_dim_display() {
        let c = ClusterDim::new(2, 1, 1);
        assert_eq!(format!("{c}"), "ClusterDim(2x1x1)");
    }

    #[test]
    fn cluster_dim_validate_zero() {
        let c = ClusterDim::new(0, 1, 1);
        assert!(c.validate().is_err());
    }

    #[test]
    fn cluster_launch_params_blocks_per_cluster() {
        let p = ClusterLaunchParams {
            grid: Dim3::x(16),
            block: Dim3::x(256),
            cluster: ClusterDim::new(2, 1, 1),
            shared_mem_bytes: 0,
        };
        assert_eq!(p.blocks_per_cluster(), 2);
    }

    #[test]
    fn cluster_count_valid() {
        let p = ClusterLaunchParams {
            grid: Dim3::new(8, 4, 2),
            block: Dim3::x(256),
            cluster: ClusterDim::new(2, 2, 1),
            shared_mem_bytes: 0,
        };
        let count = p.cluster_count();
        assert!(count.is_ok());
        // Per-axis cluster counts: 8/2 = 4, 4/2 = 2, 2/1 = 2 → 16 total.
        assert_eq!(count.ok(), Some(4 * 2 * 2));
    }

    #[test]
    fn cluster_count_not_divisible() {
        // 7 blocks cannot tile into clusters of 2 along X.
        let p = ClusterLaunchParams {
            grid: Dim3::x(7),
            block: Dim3::x(256),
            cluster: ClusterDim::x(2),
            shared_mem_bytes: 0,
        };
        assert!(p.cluster_count().is_err());
    }

    #[test]
    fn validate_rejects_zero_block() {
        let p = ClusterLaunchParams {
            grid: Dim3::x(4),
            block: Dim3::new(0, 1, 1),
            cluster: ClusterDim::x(2),
            shared_mem_bytes: 0,
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn cluster_launch_signature_compiles() {
        // Compile-time check only: pins the generic signature of
        // `cluster_launch` without requiring a device to run on.
        let _: fn(&Kernel, &ClusterLaunchParams, &Stream, &(u32,)) -> CudaResult<()> =
            cluster_launch;
    }

    // ---------------------------------------------------------------------------
    // Quality gate tests (CPU-only)
    // ---------------------------------------------------------------------------

    #[test]
    fn cluster_dim_1x1x1_valid() {
        let c = ClusterDim::new(1, 1, 1);
        assert_eq!(c.x, 1);
        assert_eq!(c.y, 1);
        assert_eq!(c.z, 1);
        assert_eq!(c.total(), 1);
        // validate() must succeed for a 1x1x1 cluster
        assert!(c.validate().is_ok());
    }

    #[test]
    fn cluster_dim_2x2x2_valid() {
        let c = ClusterDim::new(2, 2, 2);
        assert_eq!(c.total(), 8);
        assert!(c.validate().is_ok());
    }

    #[test]
    fn cluster_dim_8x1x1_valid() {
        // 8 blocks total matches the typical Hopper per-cluster limit
        // noted in the module docs (typically 8, up to 16).
        let c = ClusterDim::new(8, 1, 1);
        assert_eq!(c.total(), 8);
        assert!(c.validate().is_ok());
    }

    #[test]
    fn cluster_dim_zero_rejected() {
        // ClusterDim::new(0, 1, 1) is constructable but validate() must return Err
        let c = ClusterDim::new(0, 1, 1);
        assert!(
            c.validate().is_err(),
            "ClusterDim with zero x must be rejected by validate()"
        );
        // Also test zero in y and z
        let c_y = ClusterDim::new(1, 0, 1);
        assert!(c_y.validate().is_err(), "ClusterDim with zero y must fail");
        let c_z = ClusterDim::new(1, 1, 0);
        assert!(c_z.validate().is_err(), "ClusterDim with zero z must fail");
    }

    #[test]
    fn cluster_total_blocks_product() {
        // total() == x * y * z for arbitrary values
        let c = ClusterDim::new(3, 2, 4);
        assert_eq!(c.total(), 3 * 2 * 4);

        let c2 = ClusterDim::new(1, 7, 2);
        assert_eq!(c2.total(), 7 * 2);
    }

    #[test]
    fn cluster_launch_params_contains_cluster_dim() {
        let cluster = ClusterDim::new(2, 1, 1);
        let p = ClusterLaunchParams {
            grid: Dim3::x(16),
            block: Dim3::x(256),
            cluster,
            shared_mem_bytes: 0,
        };
        // The ClusterLaunchParams must expose a .cluster field with the right dims
        assert_eq!(p.cluster.x, 2);
        assert_eq!(p.cluster.y, 1);
        assert_eq!(p.cluster.z, 1);
        assert_eq!(p.cluster.total(), 2);
    }
}
389}