Skip to main content

oxicuda_launch/
cluster.rs

1//! Thread block cluster configuration for Hopper+ GPUs (SM 9.0+).
2//!
3//! Thread block clusters are a new level of the CUDA execution hierarchy
4//! introduced with the NVIDIA Hopper architecture (compute capability 9.0).
5//! A cluster groups multiple thread blocks that can cooperate more
6//! efficiently via distributed shared memory and hardware-accelerated
7//! synchronisation.
8//!
9//! # Requirements
10//!
11//! - NVIDIA Hopper (H100) or later GPU (compute capability 9.0+).
12//! - CUDA driver version 12.0 or later.
13//! - The kernel must be compiled with cluster support.
14//!
15//! # Example
16//!
17//! ```rust,no_run
18//! # use oxicuda_launch::cluster::{ClusterDim, ClusterLaunchParams};
19//! # use oxicuda_launch::Dim3;
20//! let cluster_params = ClusterLaunchParams {
21//!     grid: Dim3::x(16),
22//!     block: Dim3::x(256),
23//!     cluster: ClusterDim::new(2, 1, 1),
24//!     shared_mem_bytes: 0,
25//! };
26//! assert_eq!(cluster_params.blocks_per_cluster(), 2);
27//! ```
28
29use oxicuda_driver::error::{CudaError, CudaResult};
30use oxicuda_driver::stream::Stream;
31
32use crate::grid::Dim3;
33use crate::kernel::{Kernel, KernelArgs};
34
35// ---------------------------------------------------------------------------
36// ClusterDim
37// ---------------------------------------------------------------------------
38
39/// Cluster dimensions specifying how many thread blocks form one cluster.
40///
41/// Each dimension specifies the number of blocks in that direction of
42/// the cluster. The total number of blocks in a cluster is
43/// `x * y * z`.
44///
45/// # Constraints
46///
47/// - All dimensions must be non-zero.
48/// - The total blocks per cluster must not exceed the hardware limit
49///   (typically 8 or 16 for Hopper GPUs).
50/// - The grid dimensions must be evenly divisible by the cluster
51///   dimensions.
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
53pub struct ClusterDim {
54    /// Number of blocks in the X dimension of the cluster.
55    pub x: u32,
56    /// Number of blocks in the Y dimension of the cluster.
57    pub y: u32,
58    /// Number of blocks in the Z dimension of the cluster.
59    pub z: u32,
60}
61
62impl ClusterDim {
63    /// Creates a new cluster dimension.
64    #[inline]
65    pub fn new(x: u32, y: u32, z: u32) -> Self {
66        Self { x, y, z }
67    }
68
69    /// Creates a 1D cluster (only X dimension used).
70    #[inline]
71    pub fn x(x: u32) -> Self {
72        Self { x, y: 1, z: 1 }
73    }
74
75    /// Creates a 2D cluster.
76    #[inline]
77    pub fn xy(x: u32, y: u32) -> Self {
78        Self { x, y, z: 1 }
79    }
80
81    /// Returns the total number of blocks per cluster.
82    #[inline]
83    pub fn total(&self) -> u32 {
84        self.x * self.y * self.z
85    }
86
87    /// Validates that all dimensions are non-zero.
88    fn validate(&self) -> CudaResult<()> {
89        if self.x == 0 || self.y == 0 || self.z == 0 {
90            return Err(CudaError::InvalidValue);
91        }
92        Ok(())
93    }
94}
95
96impl std::fmt::Display for ClusterDim {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        write!(f, "ClusterDim({}x{}x{})", self.x, self.y, self.z)
99    }
100}
101
102// ---------------------------------------------------------------------------
103// ClusterLaunchParams
104// ---------------------------------------------------------------------------
105
106/// Launch parameters including thread block cluster configuration.
107///
108/// Extends the standard grid/block configuration with a cluster
109/// dimension. The grid dimensions must be evenly divisible by the
110/// cluster dimensions.
111#[derive(Debug, Clone, Copy)]
112pub struct ClusterLaunchParams {
113    /// Grid dimensions (number of thread blocks total).
114    pub grid: Dim3,
115    /// Block dimensions (threads per block).
116    pub block: Dim3,
117    /// Cluster dimensions (blocks per cluster).
118    pub cluster: ClusterDim,
119    /// Dynamic shared memory per block in bytes.
120    pub shared_mem_bytes: u32,
121}
122
123impl ClusterLaunchParams {
124    /// Returns the total number of blocks per cluster.
125    #[inline]
126    pub fn blocks_per_cluster(&self) -> u32 {
127        self.cluster.total()
128    }
129
130    /// Returns the total number of clusters in the grid.
131    ///
132    /// This requires that the grid dimensions be evenly divisible by
133    /// the cluster dimensions.
134    ///
135    /// # Errors
136    ///
137    /// Returns [`CudaError::InvalidValue`] if the grid is not evenly
138    /// divisible by the cluster dimensions, or if any dimension is zero.
139    pub fn cluster_count(&self) -> CudaResult<u32> {
140        self.validate()?;
141        let cx = self.grid.x / self.cluster.x;
142        let cy = self.grid.y / self.cluster.y;
143        let cz = self.grid.z / self.cluster.z;
144        Ok(cx * cy * cz)
145    }
146
147    /// Validates the cluster launch parameters.
148    ///
149    /// Checks that:
150    /// - All grid, block, and cluster dimensions are non-zero.
151    /// - The grid dimensions are evenly divisible by the cluster dimensions.
152    ///
153    /// # Errors
154    ///
155    /// Returns [`CudaError::InvalidValue`] on any violation.
156    pub fn validate(&self) -> CudaResult<()> {
157        // Validate cluster dims
158        self.cluster.validate()?;
159
160        // Validate grid dims are non-zero
161        if self.grid.x == 0 || self.grid.y == 0 || self.grid.z == 0 {
162            return Err(CudaError::InvalidValue);
163        }
164        if self.block.x == 0 || self.block.y == 0 || self.block.z == 0 {
165            return Err(CudaError::InvalidValue);
166        }
167
168        // Grid must be divisible by cluster
169        if self.grid.x % self.cluster.x != 0
170            || self.grid.y % self.cluster.y != 0
171            || self.grid.z % self.cluster.z != 0
172        {
173            return Err(CudaError::InvalidValue);
174        }
175
176        Ok(())
177    }
178}
179
180// ---------------------------------------------------------------------------
181// cluster_launch
182// ---------------------------------------------------------------------------
183
184/// Launches a kernel with thread block cluster configuration.
185///
186/// On Hopper+ GPUs (compute capability 9.0+), this groups thread blocks
187/// into clusters for enhanced cooperation via distributed shared memory.
188///
189/// This function validates the cluster parameters and delegates to the
190/// standard kernel launch. On hardware that supports clusters natively,
191/// the CUDA driver would use `cuLaunchKernelEx` with cluster attributes.
192///
193/// # Parameters
194///
195/// * `kernel` — the kernel to launch.
196/// * `params` — cluster-aware launch parameters.
197/// * `stream` — the stream to launch on.
198/// * `args` — kernel arguments.
199///
200/// # Errors
201///
202/// Returns [`CudaError::InvalidValue`] if the parameters are invalid
203/// (zero dimensions, grid not divisible by cluster, etc.), or another
204/// error from the underlying kernel launch.
205pub fn cluster_launch<A: KernelArgs>(
206    kernel: &Kernel,
207    params: &ClusterLaunchParams,
208    stream: &Stream,
209    args: &A,
210) -> CudaResult<()> {
211    params.validate()?;
212
213    // Convert to standard launch params. The cluster dimension is
214    // a hint to the driver; the actual grid/block stays the same.
215    let launch_params = crate::params::LaunchParams {
216        grid: params.grid,
217        block: params.block,
218        shared_mem_bytes: params.shared_mem_bytes,
219    };
220
221    kernel.launch(&launch_params, stream, args)
222}
223
224// ---------------------------------------------------------------------------
225// Tests
226// ---------------------------------------------------------------------------
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: 256-thread 1D blocks and no dynamic shared memory,
    /// with the caller choosing the grid and cluster shape.
    fn params_with(grid: Dim3, cluster: ClusterDim) -> ClusterLaunchParams {
        ClusterLaunchParams {
            grid,
            block: Dim3::x(256),
            cluster,
            shared_mem_bytes: 0,
        }
    }

    #[test]
    fn cluster_dim_new() {
        let dim = ClusterDim::new(2, 2, 1);
        assert_eq!((dim.x, dim.y, dim.z), (2, 2, 1));
        assert_eq!(dim.total(), 4);
    }

    #[test]
    fn cluster_dim_x() {
        let dim = ClusterDim::x(4);
        assert_eq!(dim.total(), 4);
        assert_eq!((dim.y, dim.z), (1, 1));
    }

    #[test]
    fn cluster_dim_xy() {
        assert_eq!(ClusterDim::xy(2, 4).total(), 8);
    }

    #[test]
    fn cluster_dim_display() {
        assert_eq!(ClusterDim::new(2, 1, 1).to_string(), "ClusterDim(2x1x1)");
    }

    #[test]
    fn cluster_dim_validate_zero() {
        assert!(ClusterDim::new(0, 1, 1).validate().is_err());
    }

    #[test]
    fn cluster_launch_params_blocks_per_cluster() {
        let params = params_with(Dim3::x(16), ClusterDim::new(2, 1, 1));
        assert_eq!(params.blocks_per_cluster(), 2);
    }

    #[test]
    fn cluster_count_valid() {
        let params = params_with(Dim3::new(8, 4, 2), ClusterDim::new(2, 2, 1));
        // (8 / 2) * (4 / 2) * (2 / 1) clusters.
        assert_eq!(params.cluster_count().ok(), Some(4 * 2 * 2));
    }

    #[test]
    fn cluster_count_not_divisible() {
        let params = params_with(Dim3::x(7), ClusterDim::x(2));
        assert!(params.cluster_count().is_err());
    }

    #[test]
    fn validate_rejects_zero_block() {
        let params = ClusterLaunchParams {
            block: Dim3::new(0, 1, 1),
            ..params_with(Dim3::x(4), ClusterDim::x(2))
        };
        assert!(params.validate().is_err());
    }

    #[test]
    fn cluster_launch_signature_compiles() {
        let _: fn(&Kernel, &ClusterLaunchParams, &Stream, &(u32,)) -> CudaResult<()> =
            cluster_launch;
    }

    // ---------------------------------------------------------------------------
    // Quality gate tests (CPU-only)
    // ---------------------------------------------------------------------------

    #[test]
    fn cluster_dim_1x1x1_valid() {
        let dim = ClusterDim::new(1, 1, 1);
        assert_eq!((dim.x, dim.y, dim.z), (1, 1, 1));
        assert_eq!(dim.total(), 1);
        // The smallest possible cluster must pass validation.
        assert!(dim.validate().is_ok());
    }

    #[test]
    fn cluster_dim_2x2x2_valid() {
        let dim = ClusterDim::new(2, 2, 2);
        assert_eq!(dim.total(), 8);
        assert!(dim.validate().is_ok());
    }

    #[test]
    fn cluster_dim_8x1x1_valid() {
        // Eight blocks along a single axis is within hardware limits.
        let dim = ClusterDim::new(8, 1, 1);
        assert_eq!(dim.total(), 8);
        assert!(dim.validate().is_ok());
    }

    #[test]
    fn cluster_dim_zero_rejected() {
        // Zero-sized dims are constructible; validate() is what rejects them,
        // whichever axis is zero.
        let cases = [
            (
                ClusterDim::new(0, 1, 1),
                "ClusterDim with zero x must be rejected by validate()",
            ),
            (ClusterDim::new(1, 0, 1), "ClusterDim with zero y must fail"),
            (ClusterDim::new(1, 1, 0), "ClusterDim with zero z must fail"),
        ];
        for (dim, msg) in &cases {
            assert!(dim.validate().is_err(), "{}", msg);
        }
    }

    #[test]
    fn cluster_total_blocks_product() {
        // total() is the plain product x * y * z.
        assert_eq!(ClusterDim::new(3, 2, 4).total(), 3 * 2 * 4);
        assert_eq!(ClusterDim::new(1, 7, 2).total(), 7 * 2);
    }

    #[test]
    fn cluster_launch_params_contains_cluster_dim() {
        let params = params_with(Dim3::x(16), ClusterDim::new(2, 1, 1));
        // The cluster shape must round-trip through the public `cluster` field.
        assert_eq!(params.cluster, ClusterDim::new(2, 1, 1));
        assert_eq!(params.cluster.total(), 2);
    }
}