use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::stream::Stream;
use crate::grid::Dim3;
use crate::kernel::{Kernel, KernelArgs};
/// Shape of a thread-block cluster, measured in thread blocks along each axis.
///
/// Every extent must be non-zero, and the total block count `x * y * z` must
/// fit in a `u32`; both conditions are checked by `validate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClusterDim {
    /// Blocks along the x axis.
    pub x: u32,
    /// Blocks along the y axis.
    pub y: u32,
    /// Blocks along the z axis.
    pub z: u32,
}

impl ClusterDim {
    /// Creates a cluster shape with explicit extents on all three axes.
    #[inline]
    pub fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Creates a one-dimensional cluster of `x` blocks (`y` and `z` are 1).
    #[inline]
    pub fn x(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Creates a two-dimensional cluster of `x` by `y` blocks (`z` is 1).
    #[inline]
    pub fn xy(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Total number of blocks in the cluster, `x * y * z`.
    ///
    /// Saturates at `u32::MAX` instead of overflowing (a plain `*` would
    /// panic in debug builds and wrap in release builds). Shapes whose true
    /// product exceeds `u32::MAX` are rejected by `validate`.
    #[inline]
    pub fn total(&self) -> u32 {
        // A zero factor still yields 0 even after an earlier saturation.
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Checks that the shape is usable.
    ///
    /// Returns `CudaError::InvalidValue` if any extent is zero, or if the
    /// block count `x * y * z` does not fit in a `u32`.
    fn validate(&self) -> CudaResult<()> {
        if self.x == 0 || self.y == 0 || self.z == 0 {
            return Err(CudaError::InvalidValue);
        }
        let fits = self
            .x
            .checked_mul(self.y)
            .and_then(|xy| xy.checked_mul(self.z))
            .is_some();
        if !fits {
            return Err(CudaError::InvalidValue);
        }
        Ok(())
    }
}

/// Formats as `ClusterDim(XxYxZ)`, e.g. `ClusterDim(2x1x1)`.
impl std::fmt::Display for ClusterDim {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ClusterDim({}x{}x{})", self.x, self.y, self.z)
    }
}
/// Launch configuration for a kernel launched with thread-block clusters.
#[derive(Debug, Clone, Copy)]
pub struct ClusterLaunchParams {
    /// Grid dimensions; each extent must be a multiple of the matching
    /// `cluster` extent.
    pub grid: Dim3,
    /// Block dimensions forwarded to the kernel launch.
    pub block: Dim3,
    /// Cluster shape, in blocks per axis.
    pub cluster: ClusterDim,
    /// Shared-memory byte count forwarded unchanged to the kernel launch.
    pub shared_mem_bytes: u32,
}

impl ClusterLaunchParams {
    /// Number of thread blocks in each cluster (`cluster.total()`).
    #[inline]
    pub fn blocks_per_cluster(&self) -> u32 {
        self.cluster.total()
    }

    /// Number of clusters that tile the grid.
    ///
    /// # Errors
    ///
    /// Returns `CudaError::InvalidValue` if `validate` rejects the
    /// configuration, or if the cluster count does not fit in a `u32`.
    pub fn cluster_count(&self) -> CudaResult<u32> {
        self.validate()?;
        // `validate` guarantees every grid extent is an exact multiple.
        let cx = self.grid.x / self.cluster.x;
        let cy = self.grid.y / self.cluster.y;
        let cz = self.grid.z / self.cluster.z;
        // The product of three u32 extents can exceed u32::MAX; report that
        // as an error instead of panicking (debug) or wrapping (release).
        cx.checked_mul(cy)
            .and_then(|cxy| cxy.checked_mul(cz))
            .ok_or(CudaError::InvalidValue)
    }

    /// Checks that the configuration is well-formed.
    ///
    /// # Errors
    ///
    /// Returns `CudaError::InvalidValue` if the cluster shape fails its own
    /// validation, if any grid or block extent is zero, or if a grid extent
    /// is not a multiple of the corresponding cluster extent.
    pub fn validate(&self) -> CudaResult<()> {
        self.cluster.validate()?;
        if self.grid.x == 0 || self.grid.y == 0 || self.grid.z == 0 {
            return Err(CudaError::InvalidValue);
        }
        if self.block.x == 0 || self.block.y == 0 || self.block.z == 0 {
            return Err(CudaError::InvalidValue);
        }
        if self.grid.x % self.cluster.x != 0
            || self.grid.y % self.cluster.y != 0
            || self.grid.z % self.cluster.z != 0
        {
            return Err(CudaError::InvalidValue);
        }
        Ok(())
    }
}
pub fn cluster_launch<A: KernelArgs>(
kernel: &Kernel,
params: &ClusterLaunchParams,
stream: &Stream,
args: &A,
) -> CudaResult<()> {
params.validate()?;
let launch_params = crate::params::LaunchParams {
grid: params.grid,
block: params.block,
shared_mem_bytes: params.shared_mem_bytes,
};
kernel.launch(&launch_params, stream, args)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds cluster launch parameters with no shared memory, which is all
    /// these tests need.
    fn launch_params(grid: Dim3, block: Dim3, cluster: ClusterDim) -> ClusterLaunchParams {
        ClusterLaunchParams {
            grid,
            block,
            cluster,
            shared_mem_bytes: 0,
        }
    }

    #[test]
    fn cluster_dim_new() {
        let dim = ClusterDim::new(2, 2, 1);
        assert_eq!((dim.x, dim.y, dim.z), (2, 2, 1));
        assert_eq!(dim.total(), 4);
    }

    #[test]
    fn cluster_dim_x() {
        let dim = ClusterDim::x(4);
        assert_eq!(dim.total(), 4);
        assert_eq!((dim.y, dim.z), (1, 1));
    }

    #[test]
    fn cluster_dim_xy() {
        assert_eq!(ClusterDim::xy(2, 4).total(), 8);
    }

    #[test]
    fn cluster_dim_display() {
        let rendered = ClusterDim::new(2, 1, 1).to_string();
        assert_eq!(rendered, "ClusterDim(2x1x1)");
    }

    #[test]
    fn cluster_dim_validate_zero() {
        assert!(ClusterDim::new(0, 1, 1).validate().is_err());
    }

    #[test]
    fn cluster_launch_params_blocks_per_cluster() {
        let p = launch_params(Dim3::x(16), Dim3::x(256), ClusterDim::new(2, 1, 1));
        assert_eq!(p.blocks_per_cluster(), 2);
    }

    #[test]
    fn cluster_count_valid() {
        let p = launch_params(Dim3::new(8, 4, 2), Dim3::x(256), ClusterDim::new(2, 2, 1));
        // (8 / 2) * (4 / 2) * (2 / 1) clusters.
        assert_eq!(p.cluster_count().ok(), Some(4 * 2 * 2));
    }

    #[test]
    fn cluster_count_not_divisible() {
        let p = launch_params(Dim3::x(7), Dim3::x(256), ClusterDim::x(2));
        assert!(p.cluster_count().is_err());
    }

    #[test]
    fn validate_rejects_zero_block() {
        let p = launch_params(Dim3::x(4), Dim3::new(0, 1, 1), ClusterDim::x(2));
        assert!(p.validate().is_err());
    }

    #[test]
    fn cluster_launch_signature_compiles() {
        // Coercing to a concrete fn-pointer type pins the public signature.
        let _pinned: fn(&Kernel, &ClusterLaunchParams, &Stream, &(u32,)) -> CudaResult<()> =
            cluster_launch;
    }

    #[test]
    fn cluster_dim_1x1x1_valid() {
        let unit = ClusterDim::new(1, 1, 1);
        assert_eq!((unit.x, unit.y, unit.z), (1, 1, 1));
        assert_eq!(unit.total(), 1);
        assert!(unit.validate().is_ok());
    }

    #[test]
    fn cluster_dim_2x2x2_valid() {
        let cube = ClusterDim::new(2, 2, 2);
        assert_eq!(cube.total(), 8);
        assert!(cube.validate().is_ok());
    }

    #[test]
    fn cluster_dim_8x1x1_valid() {
        let line = ClusterDim::new(8, 1, 1);
        assert_eq!(line.total(), 8);
        assert!(line.validate().is_ok());
    }

    #[test]
    fn cluster_dim_zero_rejected() {
        assert!(
            ClusterDim::new(0, 1, 1).validate().is_err(),
            "ClusterDim with zero x must be rejected by validate()"
        );
        assert!(
            ClusterDim::new(1, 0, 1).validate().is_err(),
            "ClusterDim with zero y must fail"
        );
        assert!(
            ClusterDim::new(1, 1, 0).validate().is_err(),
            "ClusterDim with zero z must fail"
        );
    }

    #[test]
    fn cluster_total_blocks_product() {
        assert_eq!(ClusterDim::new(3, 2, 4).total(), 3 * 2 * 4);
        assert_eq!(ClusterDim::new(1, 7, 2).total(), 7 * 2);
    }

    #[test]
    fn cluster_launch_params_contains_cluster_dim() {
        let p = launch_params(Dim3::x(16), Dim3::x(256), ClusterDim::new(2, 1, 1));
        assert_eq!((p.cluster.x, p.cluster.y, p.cluster.z), (2, 1, 1));
        assert_eq!(p.cluster.total(), 2);
    }
}