use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::stream::Stream;
use crate::kernel::{Kernel, KernelArgs};
use crate::params::LaunchParams;
/// Namespace type for launching kernels that must be co-resident on the device.
///
/// `CooperativeLaunch` carries no state: all functionality is exposed through
/// associated functions (for example `CooperativeLaunch::launch`). Because it is
/// zero-sized, it derives the common value traits so it can be freely copied,
/// defaulted, compared, and hashed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CooperativeLaunch;
impl CooperativeLaunch {
    /// Launches `kernel` after checking that its grid fits the occupancy limit.
    ///
    /// Validation happens before anything is submitted to `stream`:
    ///
    /// 1. Every grid and block dimension must be non-zero.
    /// 2. The block's thread count must fit in an `i32`, the parameter type of
    ///    the occupancy query.
    /// 3. The total number of blocks in the grid must not exceed the value
    ///    returned by the kernel function's `max_active_blocks_per_sm` query for
    ///    this block size and `params.shared_mem_bytes`.
    ///
    /// NOTE(review): step 3 compares the whole grid against a *per-SM*
    /// occupancy figure. CUDA's cooperative-launch limit is that figure
    /// multiplied by the device's multiprocessor count. This check is therefore
    /// stricter than the driver's and can reject grids the device could run.
    /// Confirm whether the SM count should be folded in.
    ///
    /// NOTE(review): the launch itself is dispatched via `Kernel::launch`.
    /// This file cannot confirm whether that issues a cooperative launch
    /// (`cuLaunchCooperativeKernel`) rather than an ordinary one.
    ///
    /// # Errors
    ///
    /// - [`CudaError::InvalidValue`] if any dimension is zero, or if the
    ///   block's thread count exceeds `i32::MAX`.
    /// - [`CudaError::CooperativeLaunchTooLarge`] if the grid has more blocks
    ///   than the occupancy limit described above.
    /// - Any error propagated from the occupancy query or from `Kernel::launch`.
    pub fn launch<A: KernelArgs>(
        kernel: &Kernel,
        params: &LaunchParams,
        stream: &Stream,
        args: &A,
    ) -> CudaResult<()> {
        if params.grid.x == 0
            || params.grid.y == 0
            || params.grid.z == 0
            || params.block.x == 0
            || params.block.y == 0
            || params.block.z == 0
        {
            return Err(CudaError::InvalidValue);
        }

        // Multiply in u64 from the individual dimensions, saturating rather
        // than overflowing. This way an oversized grid cannot wrap inside a
        // narrower integer type and slip past the limit check below.
        let total_blocks = (params.grid.x as u64)
            .saturating_mul(params.grid.y as u64)
            .saturating_mul(params.grid.z as u64);
        let block_threads = (params.block.x as u64)
            .saturating_mul(params.block.y as u64)
            .saturating_mul(params.block.z as u64);
        // The occupancy query takes an i32. Reject instead of letting `as` wrap
        // the value to a negative block size.
        let block_threads = i32::try_from(block_threads).map_err(|_| CudaError::InvalidValue)?;

        let max_per_sm = Self::max_active_blocks_inner(
            kernel,
            block_threads,
            params.shared_mem_bytes as usize,
        )?;
        // A negative count would be a driver anomaly. Treat it as "nothing
        // fits" rather than letting an `as` cast turn it into an enormous limit
        // that disables the check.
        let limit = u64::try_from(max_per_sm).unwrap_or(0);
        if total_blocks > limit {
            return Err(CudaError::CooperativeLaunchTooLarge);
        }
        kernel.launch(params, stream, args)
    }

    /// Returns the occupancy of `kernel` for blocks of `block_size` threads.
    ///
    /// The value comes from the kernel function's `max_active_blocks_per_sm`
    /// query, with `dynamic_smem` bytes of dynamic shared memory per block. A
    /// negative value from the query is reported as `0` rather than wrapping
    /// to a huge `u32`.
    ///
    /// # Errors
    ///
    /// - [`CudaError::InvalidValue`] if `block_size` exceeds `i32::MAX`.
    /// - Any error propagated from the occupancy query.
    pub fn max_active_blocks(
        kernel: &Kernel,
        block_size: u32,
        dynamic_smem: usize,
    ) -> CudaResult<u32> {
        let block_size = i32::try_from(block_size).map_err(|_| CudaError::InvalidValue)?;
        let blocks = Self::max_active_blocks_inner(kernel, block_size, dynamic_smem)?;
        Ok(u32::try_from(blocks).unwrap_or(0))
    }

    /// Forwards to the kernel function's `max_active_blocks_per_sm` query.
    ///
    /// This is the single place both `launch` and `max_active_blocks` obtain
    /// their occupancy figure from.
    fn max_active_blocks_inner(
        kernel: &Kernel,
        block_size: i32,
        dynamic_smem: usize,
    ) -> CudaResult<i32> {
        kernel
            .function()
            .max_active_blocks_per_sm(block_size, dynamic_smem)
    }

    /// Forwards to the kernel function's `optimal_block_size` query.
    ///
    /// `dynamic_smem` is the number of bytes of dynamic shared memory per block.
    ///
    /// NOTE(review): the meaning of the returned pair is defined by
    /// `optimal_block_size` on the function handle. It is presumably
    /// `(min_grid_size, block_size)` as in `cuOccupancyMaxPotentialBlockSize`;
    /// confirm against that implementation.
    ///
    /// # Errors
    ///
    /// Any error propagated from the underlying query.
    pub fn optimal_block_size(kernel: &Kernel, dynamic_smem: usize) -> CudaResult<(i32, i32)> {
        kernel.function().optimal_block_size(dynamic_smem)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grid::Dim3;

    // `CooperativeLaunch::launch` needs a live `Kernel` and `Stream`, which
    // cannot be built in a host-only test. These tests therefore pin
    // signatures and parameter construction rather than driver behaviour.

    #[test]
    fn cooperative_launch_struct_exists() {
        let _: fn(&Kernel, &LaunchParams, &Stream, &(u64, u32)) -> CudaResult<()> =
            CooperativeLaunch::launch;
    }

    #[test]
    fn max_active_blocks_signature_compiles() {
        let _: fn(&Kernel, u32, usize) -> CudaResult<u32> = CooperativeLaunch::max_active_blocks;
    }

    #[test]
    fn optimal_block_size_signature_compiles() {
        let _: fn(&Kernel, usize) -> CudaResult<(i32, i32)> = CooperativeLaunch::optimal_block_size;
    }

    /// Pins the private occupancy query that `launch` uses for its grid-size check.
    #[test]
    fn cooperative_max_blocks_constraint_signature() {
        let _: fn(&Kernel, i32, usize) -> CudaResult<i32> =
            CooperativeLaunch::max_active_blocks_inner;
    }

    /// `launch` rejects these params with `CudaError::InvalidValue` before any
    /// driver call. Exercising that path needs a real `Kernel`, so this only
    /// checks that the invalid input can be expressed.
    #[test]
    fn zero_grid_x_params_are_constructible() {
        let params = LaunchParams {
            grid: Dim3::new(0, 1, 1),
            block: Dim3::x(256),
            shared_mem_bytes: 0,
        };
        assert_eq!(params.grid.x, 0);
    }

    /// `launch` rejects these params with `CudaError::InvalidValue` before any
    /// driver call. Exercising that path needs a real `Kernel`, so this only
    /// checks that the invalid input can be expressed.
    #[test]
    fn zero_block_y_params_are_constructible() {
        let params = LaunchParams {
            grid: Dim3::x(4),
            block: Dim3::new(256, 0, 1),
            shared_mem_bytes: 0,
        };
        assert_eq!(params.block.y, 0);
    }

    #[test]
    fn cooperative_launch_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<CooperativeLaunch>();
    }

    #[test]
    fn cooperative_launch_is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<CooperativeLaunch>();
    }

    #[test]
    fn cooperative_dim3_total_nonzero() {
        let d = Dim3::new(4, 2, 1);
        assert_eq!(d.total(), 8);
    }

    #[test]
    fn cooperative_config_valid_fields() {
        let params = LaunchParams {
            grid: Dim3::new(4, 1, 1),
            block: Dim3::new(256, 1, 1),
            shared_mem_bytes: 1024,
        };
        assert_eq!(params.grid.x, 4);
        assert_eq!(params.block.x, 256);
        assert_eq!(params.shared_mem_bytes, 1024);
    }

    #[test]
    fn cooperative_debug_display() {
        let coop = CooperativeLaunch;
        let dbg = format!("{coop:?}");
        assert!(
            dbg.contains("CooperativeLaunch"),
            "Debug output must contain type name, got: {dbg}"
        );
    }
}