use oxicuda_driver::device::Device;
use crate::error::LaunchError;
use crate::grid::Dim3;
/// Kernel launch configuration: grid and block dimensions plus the number of
/// bytes of shared memory to request per block.
///
/// `PartialEq` is derived so two configurations can be compared directly
/// (e.g. `assert_eq!(a, b)`), mirroring what [`Dim3`] already supports.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LaunchParams {
    /// Blocks per grid dimension (multiplied by `block` to get the thread count).
    pub grid: Dim3,
    /// Threads per block dimension.
    pub block: Dim3,
    /// Shared memory requested per block, in bytes; checked by `validate`
    /// against the device's per-block shared-memory limit.
    pub shared_mem_bytes: u32,
}
impl LaunchParams {
    /// Creates launch parameters from grid and block dimensions, requesting no
    /// shared memory.
    ///
    /// Both arguments accept anything convertible into [`Dim3`] (for example a
    /// bare `u32` or a tuple of `u32`s).
    #[inline]
    pub fn new(grid: impl Into<Dim3>, block: impl Into<Dim3>) -> Self {
        Self {
            grid: grid.into(),
            block: block.into(),
            shared_mem_bytes: 0,
        }
    }

    /// Returns these parameters with `shared_mem_bytes` replaced by `bytes`.
    #[inline]
    pub fn with_shared_mem(mut self, bytes: u32) -> Self {
        self.shared_mem_bytes = bytes;
        self
    }

    /// Starts a [`LaunchParamsBuilder`] with nothing set.
    #[inline]
    pub fn builder() -> LaunchParamsBuilder {
        LaunchParamsBuilder::default()
    }

    /// Total number of threads launched: the product of every grid and block
    /// component.
    ///
    /// Each component is widened to `u64` *before* multiplying. Multiplying
    /// the per-`Dim3` totals in `u32` first would overflow for grids with more
    /// than `u32::MAX` blocks. The result saturates at `u64::MAX` rather than
    /// wrapping.
    #[inline]
    pub fn total_threads(&self) -> u64 {
        [
            self.grid.x,
            self.grid.y,
            self.grid.z,
            self.block.x,
            self.block.y,
            self.block.z,
        ]
        .iter()
        .fold(1u64, |acc, &component| acc.saturating_mul(u64::from(component)))
    }

    /// Checks these parameters against the limits reported by `device`.
    ///
    /// # Errors
    ///
    /// - [`LaunchError::InvalidDimension`] if any grid or block component is 0,
    ///   or if a block component exceeds the device's per-dimension block limit.
    /// - [`LaunchError::BlockSizeExceedsLimit`] if the threads per block exceed
    ///   the device's `max_threads_per_block`.
    /// - [`LaunchError::GridSizeExceedsLimit`] if a grid component exceeds the
    ///   device's per-dimension grid limit.
    /// - [`LaunchError::SharedMemoryExceedsLimit`] if `shared_mem_bytes` exceeds
    ///   the device's per-block shared memory.
    /// - Any error returned while querying the device attributes.
    pub fn validate(&self, device: &Device) -> Result<(), Box<dyn std::error::Error>> {
        self.validate_inner(device)
    }

    /// Implementation of [`LaunchParams::validate`]; checks run in a fixed
    /// order and the first failure is returned.
    fn validate_inner(&self, device: &Device) -> Result<(), Box<dyn std::error::Error>> {
        // Zero components are rejected before any device query, in this order.
        let components = [
            ("block.x", self.block.x),
            ("block.y", self.block.y),
            ("block.z", self.block.z),
            ("grid.x", self.grid.x),
            ("grid.y", self.grid.y),
            ("grid.z", self.grid.z),
        ];
        if let Some(&(dim, _)) = components.iter().find(|&&(_, value)| value == 0) {
            return Err(Box::new(LaunchError::InvalidDimension { dim, value: 0 }));
        }

        // Threads per block. The product of three u32 components can exceed
        // u32::MAX, so compute it widened and saturating. Only clamp it back
        // into u32 for the error payload.
        let max_threads = device.max_threads_per_block()? as u32;
        let block_total = u64::from(self.block.x)
            .saturating_mul(u64::from(self.block.y))
            .saturating_mul(u64::from(self.block.z));
        if block_total > u64::from(max_threads) {
            return Err(Box::new(LaunchError::BlockSizeExceedsLimit {
                requested: u32::try_from(block_total).unwrap_or(u32::MAX),
                max: max_threads,
            }));
        }

        // Per-dimension block limits.
        let (max_bx, max_by, max_bz) = device.max_block_dim()?;
        let block_limits = [
            ("block.x", self.block.x, max_bx as u32),
            ("block.y", self.block.y, max_by as u32),
            ("block.z", self.block.z, max_bz as u32),
        ];
        if let Some(&(dim, value, _)) = block_limits
            .iter()
            .find(|&&(_, value, max)| value > max)
        {
            return Err(Box::new(LaunchError::InvalidDimension { dim, value }));
        }

        // Per-dimension grid limits.
        let (max_gx, max_gy, max_gz) = device.max_grid_dim()?;
        let grid_limits = [
            (self.grid.x, max_gx as u32),
            (self.grid.y, max_gy as u32),
            (self.grid.z, max_gz as u32),
        ];
        if let Some(&(requested, max)) = grid_limits
            .iter()
            .find(|&&(requested, max)| requested > max)
        {
            return Err(Box::new(LaunchError::GridSizeExceedsLimit { requested, max }));
        }

        // Shared memory per block.
        let max_smem = device.max_shared_memory_per_block()? as u32;
        if self.shared_mem_bytes > max_smem {
            return Err(Box::new(LaunchError::SharedMemoryExceedsLimit {
                requested: self.shared_mem_bytes,
                max: max_smem,
            }));
        }

        Ok(())
    }
}
/// Incremental builder for [`LaunchParams`], created by `LaunchParams::builder()`.
///
/// `build()` substitutes `Dim3::x(1)` for any grid or block dimension that was
/// never set. Shared memory starts at 0 via the derived `Default`.
#[derive(Debug, Default)]
pub struct LaunchParamsBuilder {
    /// Grid dimensions; `None` until `grid` is called.
    grid: Option<Dim3>,
    /// Block dimensions; `None` until `block` is called.
    block: Option<Dim3>,
    /// Shared memory in bytes; 0 unless `shared_mem` is called.
    shared_mem_bytes: u32,
}
impl LaunchParamsBuilder {
    /// Records the grid dimensions; calling it again replaces the previous value.
    #[inline]
    pub fn grid(self, dim: impl Into<Dim3>) -> Self {
        Self {
            grid: Some(dim.into()),
            ..self
        }
    }

    /// Records the block dimensions; calling it again replaces the previous value.
    #[inline]
    pub fn block(self, dim: impl Into<Dim3>) -> Self {
        Self {
            block: Some(dim.into()),
            ..self
        }
    }

    /// Records how many bytes of shared memory to request per block.
    #[inline]
    pub fn shared_mem(self, bytes: u32) -> Self {
        Self {
            shared_mem_bytes: bytes,
            ..self
        }
    }

    /// Produces the final [`LaunchParams`].
    ///
    /// A grid or block that was never specified becomes the one-dimensional
    /// extent `Dim3::x(1)`.
    #[inline]
    pub fn build(self) -> LaunchParams {
        let Self {
            grid,
            block,
            shared_mem_bytes,
        } = self;
        let unit = Dim3::x(1);
        LaunchParams {
            grid: grid.unwrap_or(unit),
            block: block.unwrap_or(unit),
            shared_mem_bytes,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Function-pointer shape of `LaunchParams::validate`. Coercing the method
    /// to this type pins its public signature at compile time.
    type ValidateFn = fn(&LaunchParams, &Device) -> Result<(), Box<dyn std::error::Error>>;

    #[test]
    fn launch_params_new_basic() {
        let params = LaunchParams::new(4u32, 256u32);
        assert_eq!((params.grid, params.block), (Dim3::x(4), Dim3::x(256)));
        assert_eq!(params.shared_mem_bytes, 0);
    }

    #[test]
    fn launch_params_new_with_dim3() {
        let params = LaunchParams::new(Dim3::xy(4, 4), Dim3::xy(16, 16));
        assert_eq!((params.grid.total(), params.block.total()), (16, 256));
    }

    #[test]
    fn launch_params_new_with_tuples() {
        let params = LaunchParams::new((4u32, 4u32), (16u32, 16u32));
        assert_eq!(
            (params.grid, params.block),
            (Dim3::xy(4, 4), Dim3::xy(16, 16))
        );
    }

    #[test]
    fn launch_params_with_shared_mem() {
        let shared = LaunchParams::new(1u32, 256u32)
            .with_shared_mem(8192)
            .shared_mem_bytes;
        assert_eq!(shared, 8192);
    }

    #[test]
    fn launch_params_total_threads() {
        let cases = [
            (LaunchParams::new(4u32, 256u32), 1024u64),
            (LaunchParams::new(Dim3::xy(4, 4), Dim3::xy(16, 16)), 16 * 256),
        ];
        for (params, expected) in cases {
            assert_eq!(params.total_threads(), expected);
        }
    }

    #[test]
    fn launch_params_total_threads_large() {
        let params = LaunchParams::new(Dim3::xy(65535, 65535), Dim3::x(1024));
        assert_eq!(params.total_threads(), 65535u64 * 65535u64 * 1024u64);
    }

    #[test]
    fn builder_defaults() {
        let built = LaunchParams::builder().build();
        assert_eq!((built.grid, built.block), (Dim3::x(1), Dim3::x(1)));
        assert_eq!(built.shared_mem_bytes, 0);
    }

    #[test]
    fn builder_full() {
        let built = LaunchParams::builder()
            .grid(128u32)
            .block(256u32)
            .shared_mem(4096)
            .build();
        assert_eq!((built.grid, built.block), (Dim3::x(128), Dim3::x(256)));
        assert_eq!(built.shared_mem_bytes, 4096);
    }

    #[test]
    fn builder_partial_grid_only() {
        let built = LaunchParams::builder().grid(64u32).build();
        assert_eq!((built.grid, built.block), (Dim3::x(64), Dim3::x(1)));
    }

    #[test]
    fn builder_partial_block_only() {
        let built = LaunchParams::builder().block(512u32).build();
        assert_eq!((built.grid, built.block), (Dim3::x(1), Dim3::x(512)));
    }

    #[test]
    fn builder_with_tuple_dims() {
        let built = LaunchParams::builder()
            .grid((8u32, 8u32))
            .block((16u32, 16u32, 1u32))
            .build();
        assert_eq!(
            (built.grid, built.block),
            (Dim3::xy(8, 8), Dim3::new(16, 16, 1))
        );
    }

    #[test]
    fn validate_zero_block_x() {
        // NOTE(review): no `Device` is available here, so `validate` is not
        // actually invoked. This only checks the signature and that such
        // parameters are representable.
        let _validate_fn: ValidateFn = LaunchParams::validate;
        let params = LaunchParams {
            grid: Dim3::x(1),
            block: Dim3::new(0, 1, 1),
            shared_mem_bytes: 0,
        };
        assert_eq!(params.block.x, 0);
    }

    #[test]
    fn validate_zero_grid_z() {
        // NOTE(review): placeholder; `validate` needs a `Device` to run.
        let params = LaunchParams {
            grid: Dim3::new(1, 1, 0),
            block: Dim3::x(256),
            shared_mem_bytes: 0,
        };
        assert_eq!(params.grid.z, 0);
    }

    #[test]
    fn validate_signature_compiles() {
        let _: ValidateFn = LaunchParams::validate;
    }

    #[cfg(feature = "gpu-tests")]
    #[test]
    fn validate_with_real_device() {
        // Initialization failure is tolerated; the test is skipped if no device 0.
        let _ = oxicuda_driver::init();
        if let Ok(device) = Device::get(0) {
            let fits = LaunchParams::new(4u32, 256u32);
            assert!(fits.validate(&device).is_ok());
            // 1024 * 1024 threads per block is expected to exceed the device limit.
            let too_big = LaunchParams::new(1u32, Dim3::new(1024, 1024, 1));
            assert!(too_big.validate(&device).is_err());
        }
    }
}