#![allow(dead_code)]
use std::fmt;
/// Workgroup counts for a single indirect compute dispatch.
///
/// `#[repr(C)]` lays the three counts out as consecutive `u32`s (12 bytes), in
/// the same `x`, `y`, `z` order that [`IndirectDispatchArgs::to_bytes`] writes.
/// NOTE(review): presumably mirrors the GPU API's indirect-dispatch argument
/// struct (e.g. `VkDispatchIndirectCommand`) — confirm against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct IndirectDispatchArgs {
    /// Number of workgroups along X.
    pub x: u32,
    /// Number of workgroups along Y.
    pub y: u32,
    /// Number of workgroups along Z.
    pub z: u32,
}

impl IndirectDispatchArgs {
    /// Creates dispatch arguments with explicit counts on all three axes.
    pub fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// One-dimensional dispatch: `x` groups, with `y` and `z` fixed to 1.
    pub fn one_d(x: u32) -> Self {
        Self::new(x, 1, 1)
    }

    /// Two-dimensional dispatch: `x` by `y` groups, with `z` fixed to 1.
    pub fn two_d(x: u32, y: u32) -> Self {
        Self::new(x, y, 1)
    }

    /// Total number of workgroups, `x * y * z`, computed in `u64`.
    ///
    /// The product of three `u32`s can exceed `u64::MAX`. In that case the result
    /// saturates at `u64::MAX` instead of overflowing. A plain overflow would panic
    /// in debug builds and silently wrap in release builds.
    pub fn total_workgroups(&self) -> u64 {
        u64::from(self.x)
            .saturating_mul(u64::from(self.y))
            .saturating_mul(u64::from(self.z))
    }

    /// Serializes the counts as three little-endian `u32`s: `x`, then `y`, then `z`.
    pub fn to_bytes(&self) -> [u8; 12] {
        let mut buf = [0u8; 12];
        buf[0..4].copy_from_slice(&self.x.to_le_bytes());
        buf[4..8].copy_from_slice(&self.y.to_le_bytes());
        buf[8..12].copy_from_slice(&self.z.to_le_bytes());
        buf
    }

    /// Decodes arguments from the first 12 bytes of `bytes`, read as little-endian
    /// `x`, `y`, `z`. This is the inverse of [`Self::to_bytes`].
    ///
    /// Returns `None` if fewer than 12 bytes are supplied. Any bytes past the first
    /// 12 are ignored.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let encoded = bytes.get(..12)?;
        // `encoded` is exactly 12 bytes long, so every index below is in bounds.
        let word = |at: usize| {
            u32::from_le_bytes([encoded[at], encoded[at + 1], encoded[at + 2], encoded[at + 3]])
        };
        Some(Self::new(word(0), word(4), word(8)))
    }

    /// Returns `true` when every axis has at least one workgroup.
    pub fn is_valid(&self) -> bool {
        self.x > 0 && self.y > 0 && self.z > 0
    }
}

/// The smallest valid dispatch: a single workgroup (1x1x1).
impl Default for IndirectDispatchArgs {
    fn default() -> Self {
        Self { x: 1, y: 1, z: 1 }
    }
}

/// Formats as `Dispatch(XxYxZ)`, e.g. `Dispatch(4x8x2)`.
impl fmt::Display for IndirectDispatchArgs {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Dispatch({}x{}x{})", self.x, self.y, self.z)
    }
}
/// How a workload's extent maps onto dispatch axes when computing workgroup
/// counts with `compute_dispatch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchStrategy {
    /// 1D dispatch over a flat element count; only the X axis is used.
    Linear,
    /// 2D dispatch over a `tile_w` x `tile_h` extent (X and Y axes).
    Tiled2D { tile_w: u32, tile_h: u32 },
    /// 3D dispatch over a `vol_w` x `vol_h` x `vol_d` extent (X, Y and Z axes).
    Volumetric { vol_w: u32, vol_h: u32, vol_d: u32 },
}
/// Computes how many workgroups are needed to cover a workload, rounding up
/// on every dispatched axis.
///
/// * `element_count`: total number of elements. Only [`DispatchStrategy::Linear`]
///   reads it; the 2D and 3D strategies take their extents from the strategy
///   itself and ignore this argument.
/// * `workgroup_size`: the divisor applied to each dispatched axis. `Tiled2D` and
///   `Volumetric` use the same value for every axis.
///
/// An extent of 0 on an axis yields 0 groups on that axis. The result then fails
/// [`IndirectDispatchArgs::is_valid`].
///
/// # Panics
///
/// Panics if `workgroup_size` is zero.
pub fn compute_dispatch(
    element_count: u32,
    workgroup_size: u32,
    strategy: DispatchStrategy,
) -> IndirectDispatchArgs {
    assert!(
        workgroup_size > 0,
        "compute_dispatch: workgroup_size must be non-zero"
    );
    // Ceiling division without forming `extent + workgroup_size - 1`. That sum
    // overflows `u32` for extents near `u32::MAX`: it panics in debug builds and
    // wraps to a tiny group count in release builds.
    let groups = |extent: u32| extent / workgroup_size + u32::from(extent % workgroup_size != 0);
    match strategy {
        DispatchStrategy::Linear => IndirectDispatchArgs::one_d(groups(element_count)),
        DispatchStrategy::Tiled2D { tile_w, tile_h } => {
            IndirectDispatchArgs::two_d(groups(tile_w), groups(tile_h))
        }
        DispatchStrategy::Volumetric {
            vol_w,
            vol_h,
            vol_d,
        } => IndirectDispatchArgs::new(groups(vol_w), groups(vol_h), groups(vol_d)),
    }
}
/// A labelled CPU-side copy of one [`IndirectDispatchArgs`]. It also counts how
/// many times the arguments have been replaced via [`IndirectBuffer::update`].
pub struct IndirectBuffer {
    /// Current dispatch arguments.
    args: IndirectDispatchArgs,
    /// Name supplied at construction.
    label: String,
    /// Number of `update` calls made so far; starts at 0.
    generation: u64,
}

impl IndirectBuffer {
    /// Creates a buffer holding `IndirectDispatchArgs::default()` at generation 0.
    pub fn new(label: &str) -> Self {
        Self::with_args(label, IndirectDispatchArgs::default())
    }

    /// Creates a buffer holding `args` at generation 0.
    pub fn with_args(label: &str, args: IndirectDispatchArgs) -> Self {
        let label = String::from(label);
        Self {
            args,
            label,
            generation: 0,
        }
    }

    /// Replaces the stored arguments and advances the generation counter by one.
    pub fn update(&mut self, args: IndirectDispatchArgs) {
        self.args = args;
        self.generation = self.generation + 1;
    }

    /// Returns a copy of the current arguments.
    pub fn args(&self) -> IndirectDispatchArgs {
        let current = self.args;
        current
    }

    /// Returns the label given at construction.
    pub fn label(&self) -> &str {
        self.label.as_str()
    }

    /// Returns how many times `update` has been called.
    pub fn generation(&self) -> u64 {
        let count = self.generation;
        count
    }

    /// Length of the serialized arguments, which is always 12 bytes.
    pub fn size_bytes(&self) -> usize {
        self.to_bytes().len()
    }

    /// Serializes the current arguments; see `IndirectDispatchArgs::to_bytes`.
    pub fn to_bytes(&self) -> [u8; 12] {
        IndirectDispatchArgs::to_bytes(&self.args)
    }
}
/// Checks every axis of `args` against a per-dimension workgroup-count limit.
///
/// Axes are checked in X, Y, Z order, and the first axis over the limit is
/// reported. A count exactly equal to `max_per_dimension` is accepted.
///
/// # Errors
///
/// Returns a message naming the offending axis, its count and the limit.
pub fn validate_dispatch_limits(
    args: &IndirectDispatchArgs,
    max_per_dimension: u32,
) -> Result<(), String> {
    for (axis, count) in [("X", args.x), ("Y", args.y), ("Z", args.z)] {
        if count > max_per_dimension {
            return Err(format!(
                "{axis} workgroup count {count} exceeds limit {max_per_dimension}"
            ));
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_args_new() {
        let args = IndirectDispatchArgs::new(4, 8, 2);
        assert_eq!(args.x, 4);
        assert_eq!(args.y, 8);
        assert_eq!(args.z, 2);
    }

    #[test]
    fn test_dispatch_args_one_d() {
        let args = IndirectDispatchArgs::one_d(16);
        assert_eq!(args.x, 16);
        assert_eq!(args.y, 1);
        assert_eq!(args.z, 1);
    }

    #[test]
    fn test_dispatch_args_two_d() {
        assert_eq!(IndirectDispatchArgs::two_d(3, 5), IndirectDispatchArgs::new(3, 5, 1));
    }

    #[test]
    fn test_total_workgroups() {
        let args = IndirectDispatchArgs::new(4, 8, 2);
        assert_eq!(args.total_workgroups(), 64);
    }

    #[test]
    fn test_to_from_bytes_roundtrip() {
        let original = IndirectDispatchArgs::new(123, 456, 789);
        let bytes = original.to_bytes();
        let restored = IndirectDispatchArgs::from_bytes(&bytes)
            .expect("deserialization from bytes should succeed");
        assert_eq!(original, restored);
    }

    #[test]
    fn test_to_bytes_is_little_endian() {
        let bytes = IndirectDispatchArgs::new(1, 0x0102, 0).to_bytes();
        assert_eq!(bytes, [1, 0, 0, 0, 0x02, 0x01, 0, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_from_bytes_too_short() {
        assert!(IndirectDispatchArgs::from_bytes(&[0u8; 8]).is_none());
        // One byte short of a full record is still rejected.
        assert!(IndirectDispatchArgs::from_bytes(&[0u8; 11]).is_none());
        assert!(IndirectDispatchArgs::from_bytes(&[]).is_none());
    }

    #[test]
    fn test_from_bytes_ignores_trailing_bytes() {
        let expected = IndirectDispatchArgs::new(1, 2, 3);
        let mut bytes = expected.to_bytes().to_vec();
        bytes.extend_from_slice(&[0xFF; 4]);
        assert_eq!(IndirectDispatchArgs::from_bytes(&bytes), Some(expected));
    }

    #[test]
    fn test_is_valid() {
        assert!(IndirectDispatchArgs::new(1, 1, 1).is_valid());
        assert!(!IndirectDispatchArgs::new(0, 1, 1).is_valid());
        assert!(!IndirectDispatchArgs::new(1, 0, 1).is_valid());
        assert!(!IndirectDispatchArgs::new(1, 1, 0).is_valid());
    }

    #[test]
    fn test_display() {
        let args = IndirectDispatchArgs::new(4, 8, 2);
        assert_eq!(format!("{args}"), "Dispatch(4x8x2)");
    }

    #[test]
    fn test_compute_dispatch_linear() {
        let args = compute_dispatch(1000, 64, DispatchStrategy::Linear);
        assert_eq!(args.x, 16);
        assert_eq!(args.y, 1);
        assert_eq!(args.z, 1);
    }

    #[test]
    fn test_compute_dispatch_linear_exact_multiple() {
        // An exact multiple must not round up to an extra group.
        let args = compute_dispatch(1024, 64, DispatchStrategy::Linear);
        assert_eq!(args, IndirectDispatchArgs::one_d(16));
    }

    #[test]
    fn test_compute_dispatch_linear_zero_elements() {
        let args = compute_dispatch(0, 64, DispatchStrategy::Linear);
        assert_eq!(args.x, 0);
        assert!(!args.is_valid());
    }

    #[test]
    fn test_compute_dispatch_tiled() {
        let args = compute_dispatch(
            0,
            16,
            DispatchStrategy::Tiled2D {
                tile_w: 1920,
                tile_h: 1080,
            },
        );
        // 1920 / 16 = 120 exactly; 1080 / 16 = 67.5, which rounds up to 68.
        assert_eq!(args.x, 120);
        assert_eq!(args.y, 68);
        assert_eq!(args.z, 1);
    }

    #[test]
    fn test_compute_dispatch_volumetric() {
        let args = compute_dispatch(
            0,
            8,
            DispatchStrategy::Volumetric {
                vol_w: 64,
                vol_h: 64,
                vol_d: 32,
            },
        );
        assert_eq!(args.x, 8);
        assert_eq!(args.y, 8);
        assert_eq!(args.z, 4);
    }

    #[test]
    fn test_indirect_buffer_new() {
        let buf = IndirectBuffer::new("test_buf");
        assert_eq!(buf.label(), "test_buf");
        assert_eq!(buf.args(), IndirectDispatchArgs::default());
        assert_eq!(buf.generation(), 0);
        assert_eq!(buf.size_bytes(), 12);
    }

    #[test]
    fn test_indirect_buffer_with_args() {
        let args = IndirectDispatchArgs::new(2, 3, 4);
        let buf = IndirectBuffer::with_args("seeded", args);
        assert_eq!(buf.label(), "seeded");
        assert_eq!(buf.args(), args);
        assert_eq!(buf.generation(), 0);
        assert_eq!(buf.to_bytes(), args.to_bytes());
    }

    #[test]
    fn test_indirect_buffer_update() {
        let mut buf = IndirectBuffer::new("buf");
        buf.update(IndirectDispatchArgs::new(10, 20, 30));
        assert_eq!(buf.args().x, 10);
        assert_eq!(buf.generation(), 1);
        buf.update(IndirectDispatchArgs::one_d(5));
        assert_eq!(buf.generation(), 2);
    }

    #[test]
    fn test_validate_dispatch_limits_ok() {
        let args = IndirectDispatchArgs::new(100, 100, 100);
        assert!(validate_dispatch_limits(&args, 65535).is_ok());
    }

    #[test]
    fn test_validate_dispatch_limits_at_limit_is_ok() {
        let args = IndirectDispatchArgs::new(65535, 65535, 65535);
        assert!(validate_dispatch_limits(&args, 65535).is_ok());
    }

    #[test]
    fn test_validate_dispatch_limits_exceeded() {
        let args = IndirectDispatchArgs::new(70000, 1, 1);
        assert!(validate_dispatch_limits(&args, 65535).is_err());
    }

    #[test]
    fn test_validate_dispatch_limits_names_axis() {
        let y_over = IndirectDispatchArgs::new(1, 70000, 1);
        assert_eq!(
            validate_dispatch_limits(&y_over, 65535),
            Err("Y workgroup count 70000 exceeds limit 65535".to_string())
        );
        let z_over = IndirectDispatchArgs::new(1, 1, 70000);
        assert_eq!(
            validate_dispatch_limits(&z_over, 65535),
            Err("Z workgroup count 70000 exceeds limit 65535".to_string())
        );
    }

    #[test]
    fn test_default_dispatch_args() {
        let args = IndirectDispatchArgs::default();
        assert_eq!(args.x, 1);
        assert_eq!(args.y, 1);
        assert_eq!(args.z, 1);
        assert!(args.is_valid());
    }
}