/// Workgroup dimension limit exported by this module (currently 256).
///
/// NOTE(review): nothing in the visible part of this file reads this constant,
/// so whether it is meant to bound a single axis or a whole workgroup should be
/// confirmed against its callers.
pub const MAX_WORKGROUP_DIM: u32 = 256;
/// Number of threads along each axis of a single workgroup.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WorkgroupSize {
    /// Thread count along the X axis.
    pub x: u32,
    /// Thread count along the Y axis.
    pub y: u32,
    /// Thread count along the Z axis.
    pub z: u32,
}

impl WorkgroupSize {
    /// Builds an `x × 1 × 1` workgroup.
    #[allow(dead_code)]
    #[must_use]
    pub const fn linear(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Builds an `x × y × 1` workgroup.
    #[allow(dead_code)]
    #[must_use]
    pub const fn planar(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Builds an `x × y × z` workgroup. No validation is performed here; use
    /// [`WorkgroupSize::is_valid`] to check the result against a thread budget.
    #[allow(dead_code)]
    #[must_use]
    pub const fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Returns the number of threads in the workgroup (`x * y * z`).
    ///
    /// The product saturates at `u32::MAX` rather than overflowing, so an
    /// absurdly large workgroup neither panics (debug builds) nor wraps around
    /// to a small, plausible-looking number (release builds).
    #[allow(dead_code)]
    #[must_use]
    pub const fn thread_count(self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Returns `true` when every axis is at least 1 and the total thread count
    /// does not exceed `max_threads`.
    ///
    /// A workgroup whose exact thread count does not fit in a `u32` is always
    /// reported as invalid.
    #[allow(dead_code)]
    #[must_use]
    pub fn is_valid(self, max_threads: u32) -> bool {
        if self.x == 0 || self.y == 0 || self.z == 0 {
            return false;
        }
        // Checked arithmetic: an overflowing product can never be within budget,
        // even when the budget itself is `u32::MAX`.
        self.x
            .checked_mul(self.y)
            .and_then(|xy| xy.checked_mul(self.z))
            .is_some_and(|total| total <= max_threads)
    }
}

impl Default for WorkgroupSize {
    /// The default workgroup is 64 threads along X (`64 × 1 × 1`).
    fn default() -> Self {
        Self::linear(64)
    }
}
/// Number of workgroups to launch along each axis of a dispatch.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DispatchGrid {
    /// Workgroup count along the X axis.
    pub x: u32,
    /// Workgroup count along the Y axis.
    pub y: u32,
    /// Workgroup count along the Z axis.
    pub z: u32,
}

impl DispatchGrid {
    /// Builds a grid of `x × y × z` workgroups.
    #[allow(dead_code)]
    #[must_use]
    pub const fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Returns the total number of workgroups in the grid (`x * y * z`).
    ///
    /// The product of three `u32` values can exceed `u64::MAX`; in that case
    /// the result saturates at `u64::MAX` instead of panicking or wrapping.
    #[allow(dead_code)]
    #[must_use]
    pub const fn total_workgroups(self) -> u64 {
        // `as u64` is a lossless widening; `u64::from` is not usable in `const fn`.
        (self.x as u64)
            .saturating_mul(self.y as u64)
            .saturating_mul(self.z as u64)
    }

    /// Returns the total number of threads launched when every workgroup of
    /// this grid has size `wg`.
    ///
    /// The per-workgroup thread count is computed from `wg`'s axes directly in
    /// `u64`, so it is exact even when it would not fit in a `u32`. The final
    /// product saturates at `u64::MAX`.
    #[allow(dead_code)]
    #[must_use]
    pub const fn total_threads(self, wg: WorkgroupSize) -> u64 {
        let threads_per_group = (wg.x as u64)
            .saturating_mul(wg.y as u64)
            .saturating_mul(wg.z as u64);
        self.total_workgroups().saturating_mul(threads_per_group)
    }
}
#[allow(dead_code)]
#[must_use]
pub fn dispatch_1d(count: u32, wg_size: u32) -> DispatchGrid {
assert!(wg_size > 0, "wg_size must be > 0");
let x = count.div_ceil(wg_size);
DispatchGrid::new(x, 1, 1)
}
#[allow(dead_code)]
#[must_use]
pub fn dispatch_2d(width: u32, height: u32, wg_x: u32, wg_y: u32) -> DispatchGrid {
assert!(wg_x > 0 && wg_y > 0, "workgroup dims must be > 0");
let x = width.div_ceil(wg_x);
let y = height.div_ceil(wg_y);
DispatchGrid::new(x, y, 1)
}
#[allow(dead_code)]
#[must_use]
pub fn dispatch_3d(
width: u32,
height: u32,
depth: u32,
wg_x: u32,
wg_y: u32,
wg_z: u32,
) -> DispatchGrid {
assert!(
wg_x > 0 && wg_y > 0 && wg_z > 0,
"workgroup dims must be > 0"
);
DispatchGrid::new(
width.div_ceil(wg_x),
height.div_ceil(wg_y),
depth.div_ceil(wg_z),
)
}
/// Recommends a square, power-of-two 2-D workgroup for a thread budget.
///
/// The returned size is `side × side × 1`, where `side` is the largest power
/// of two with `side * side <= max_threads`.
///
/// A budget of zero cannot be honoured; `1 × 1 × 1` is returned in that case,
/// so callers that may pass zero should validate the result themselves.
///
/// The side is derived from the bit length of `max_threads` rather than by
/// repeated squaring in `u32`. Squaring overflows for budgets of `2^30` and
/// above, which would panic in builds with overflow checks and never terminate
/// in builds without them.
#[allow(dead_code)]
#[must_use]
pub fn recommend_2d_workgroup(max_threads: u32) -> WorkgroupSize {
    // side = 2^k with 2k <= floor(log2(max_threads)), i.e. k = ilog2 / 2.
    // `checked_ilog2` is `None` only for zero, which falls back to a side of 1.
    let side = max_threads
        .checked_ilog2()
        .map_or(1, |bits| 1u32 << (bits / 2));
    WorkgroupSize::planar(side, side)
}
/// Category tag attached to each barrier stored in a [`BarrierTracker`].
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BarrierKind {
    /// Tag for a read-after-write memory barrier.
    MemoryReadAfterWrite,
    /// Tag for an execution-only barrier.
    ExecutionOnly,
    /// Tag for a full barrier.
    Full,
}

/// A single barrier entry as stored by [`BarrierTracker`].
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct BarrierRecord {
    /// Position of this barrier in insertion order, counted from 0 since the
    /// tracker was created or last reset.
    pub index: u32,
    /// Category of the barrier.
    pub kind: BarrierKind,
    /// Optional caller-supplied label.
    pub label: Option<String>,
}

/// Append-only log of barriers, each stamped with a sequential index.
#[allow(dead_code)]
#[derive(Debug, Default)]
pub struct BarrierTracker {
    records: Vec<BarrierRecord>,
    next_index: u32,
}

impl BarrierTracker {
    /// Creates an empty tracker whose first record will get index 0.
    #[allow(dead_code)]
    #[must_use]
    pub fn new() -> Self {
        Self {
            records: Vec::new(),
            next_index: 0,
        }
    }

    /// Appends a barrier of the given `kind`, copying `label` if one is given,
    /// and advances the running index.
    #[allow(dead_code)]
    pub fn push(&mut self, kind: BarrierKind, label: Option<&str>) {
        let index = self.next_index;
        let label = label.map(str::to_owned);
        self.records.push(BarrierRecord { index, kind, label });
        self.next_index = index + 1;
    }

    /// Number of barriers currently recorded.
    #[allow(dead_code)]
    #[must_use]
    pub fn len(&self) -> usize {
        self.records.len()
    }

    /// Returns `true` when no barriers are recorded.
    #[allow(dead_code)]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// All recorded barriers, in insertion order.
    #[allow(dead_code)]
    #[must_use]
    pub fn records(&self) -> &[BarrierRecord] {
        self.records.as_slice()
    }

    /// Counts the recorded barriers whose kind equals `kind`.
    #[allow(dead_code)]
    #[must_use]
    pub fn count_of_kind(&self, kind: BarrierKind) -> usize {
        self.records
            .iter()
            .map(|record| usize::from(record.kind == kind))
            .sum()
    }

    /// Drops every record and restarts indexing at 0.
    #[allow(dead_code)]
    pub fn reset(&mut self) {
        self.next_index = 0;
        self.records.clear();
    }
}
/// A single dispatch entry as stored by [`DispatchTracker`].
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct DispatchRecord {
    /// Position of this dispatch in insertion order, counted from 0 since the
    /// tracker was created or last reset.
    pub index: u32,
    /// Caller-supplied identifier of the pipeline that was dispatched.
    pub pipeline_id: String,
    /// Workgroup grid used for the dispatch.
    pub grid: DispatchGrid,
    /// Workgroup size used for the dispatch.
    pub workgroup_size: WorkgroupSize,
}

/// Append-only log of dispatches, each stamped with a sequential index.
#[allow(dead_code)]
#[derive(Debug, Default)]
pub struct DispatchTracker {
    records: Vec<DispatchRecord>,
    next_index: u32,
}

impl DispatchTracker {
    /// Creates an empty tracker whose first record will get index 0.
    #[allow(dead_code)]
    #[must_use]
    pub fn new() -> Self {
        Self {
            records: Vec::new(),
            next_index: 0,
        }
    }

    /// Appends a dispatch of `pipeline_id` over `grid` with the given
    /// `workgroup_size`, and advances the running index.
    #[allow(dead_code)]
    pub fn push(
        &mut self,
        pipeline_id: impl Into<String>,
        grid: DispatchGrid,
        workgroup_size: WorkgroupSize,
    ) {
        let index = self.next_index;
        let record = DispatchRecord {
            index,
            pipeline_id: pipeline_id.into(),
            grid,
            workgroup_size,
        };
        self.records.push(record);
        self.next_index = index + 1;
    }

    /// Number of dispatches currently recorded.
    #[allow(dead_code)]
    #[must_use]
    pub fn len(&self) -> usize {
        self.records.len()
    }

    /// Returns `true` when no dispatches are recorded.
    #[allow(dead_code)]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Sum, over all recorded dispatches, of the thread count each one
    /// launches (`grid.total_threads(workgroup_size)`).
    #[allow(dead_code)]
    #[must_use]
    pub fn total_threads(&self) -> u64 {
        let mut threads = 0u64;
        for record in &self.records {
            threads += record.grid.total_threads(record.workgroup_size);
        }
        threads
    }

    /// All recorded dispatches, in insertion order.
    #[allow(dead_code)]
    #[must_use]
    pub fn records(&self) -> &[DispatchRecord] {
        self.records.as_slice()
    }

    /// Drops every record and restarts indexing at 0.
    #[allow(dead_code)]
    pub fn reset(&mut self) {
        self.next_index = 0;
        self.records.clear();
    }
}
/// How [`DataDrivenDispatch`] arranges workgroups into a grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataDispatchStrategy {
    /// A single row of workgroups: `ceil(elements / wg_x)` along X.
    Linear1D,
    /// The smallest square grid holding at least
    /// `ceil(elements / (wg_x * wg_y))` workgroups.
    Square2D,
    /// A fixed number of rows (Y) with as many columns (X) as needed.
    FixedRowCount {
        /// Requested number of rows; zero is treated as one.
        rows: u32,
    },
}

/// Computes and caches a dispatch grid from a runtime element count.
pub struct DataDrivenDispatch {
    wg_x: u32,
    wg_y: u32,
    strategy: DataDispatchStrategy,
    grid: Option<DispatchGrid>,
    last_element_count: u64,
}

impl DataDrivenDispatch {
    /// Creates a dispatcher for `wg_x × wg_y` workgroups using `strategy`.
    ///
    /// Workgroup dimensions of zero are clamped to 1. No grid is available
    /// until [`DataDrivenDispatch::prepare`] has been called.
    #[must_use]
    pub fn new(wg_x: u32, wg_y: u32, strategy: DataDispatchStrategy) -> Self {
        Self {
            wg_x: wg_x.max(1),
            wg_y: wg_y.max(1),
            strategy,
            grid: None,
            last_element_count: 0,
        }
    }

    /// Shorthand for a [`DataDispatchStrategy::Linear1D`] dispatcher with
    /// `wg_size × 1` workgroups.
    #[must_use]
    pub fn linear(wg_size: u32) -> Self {
        Self::new(wg_size, 1, DataDispatchStrategy::Linear1D)
    }

    /// Shorthand for a [`DataDispatchStrategy::Square2D`] dispatcher with
    /// `wg_x × wg_y` workgroups.
    #[must_use]
    pub fn square(wg_x: u32, wg_y: u32) -> Self {
        Self::new(wg_x, wg_y, DataDispatchStrategy::Square2D)
    }

    /// Computes the grid for `element_count` elements, stores it together with
    /// the element count, and returns it.
    ///
    /// Every grid axis is at least 1, even for zero elements.
    ///
    /// All sizing is done in `u64`, so element counts above `u32::MAX` are not
    /// truncated. When a required axis does not fit in a `u32`, that axis
    /// saturates at `u32::MAX`. The grid then covers fewer elements than
    /// requested, which callers can detect by comparing
    /// [`DataDrivenDispatch::covered_elements`] with
    /// [`DataDrivenDispatch::last_element_count`].
    pub fn prepare(&mut self, element_count: u64) -> DispatchGrid {
        self.last_element_count = element_count;

        let grid = match self.strategy {
            DataDispatchStrategy::Linear1D => {
                // Only `wg_x` determines the number of groups along X here.
                let groups = element_count.div_ceil(u64::from(self.wg_x)).max(1);
                DispatchGrid::new(Self::saturate_axis(groups), 1, 1)
            }
            DataDispatchStrategy::Square2D => {
                let groups = element_count
                    .div_ceil(self.threads_per_workgroup())
                    .max(1);
                let side = Self::ceil_sqrt(groups);
                DispatchGrid::new(side, side, 1)
            }
            DataDispatchStrategy::FixedRowCount { rows } => {
                let rows = rows.max(1);
                let groups = element_count
                    .div_ceil(self.threads_per_workgroup())
                    .max(1);
                let cols = groups.div_ceil(u64::from(rows));
                DispatchGrid::new(Self::saturate_axis(cols), rows, 1)
            }
        };

        self.grid = Some(grid);
        grid
    }

    /// The grid computed by the most recent [`DataDrivenDispatch::prepare`]
    /// call, or `None` if `prepare` has not been called yet.
    #[must_use]
    pub fn grid(&self) -> Option<DispatchGrid> {
        self.grid
    }

    /// The element count passed to the most recent
    /// [`DataDrivenDispatch::prepare`] call (0 before the first call).
    #[must_use]
    pub fn last_element_count(&self) -> u64 {
        self.last_element_count
    }

    /// Number of elements the cached grid can process:
    /// `total workgroups * wg_x * wg_y`, saturating at `u64::MAX`.
    ///
    /// Returns 0 when no grid has been prepared.
    #[must_use]
    pub fn covered_elements(&self) -> u64 {
        self.grid.map_or(0, |grid| {
            grid.total_workgroups()
                .saturating_mul(self.threads_per_workgroup())
        })
    }

    /// Threads in one workgroup, widened so `wg_x * wg_y` cannot overflow.
    fn threads_per_workgroup(&self) -> u64 {
        u64::from(self.wg_x) * u64::from(self.wg_y)
    }

    /// Narrows a workgroup count to a grid axis, saturating at `u32::MAX`.
    fn saturate_axis(groups: u64) -> u32 {
        u32::try_from(groups).unwrap_or(u32::MAX)
    }

    /// Smallest `side` with `side * side >= value`, saturating at `u32::MAX`.
    fn ceil_sqrt(value: u64) -> u32 {
        let target = u128::from(value);
        // The float square root is only a starting estimate (a `u64` is not
        // always exactly representable as `f64`); the loops below correct it
        // with exact integer arithmetic.
        let mut root = (value as f64).sqrt() as u128;
        while root * root < target {
            root += 1;
        }
        while root > 0 && (root - 1) * (root - 1) >= target {
            root -= 1;
        }
        u32::try_from(root).unwrap_or(u32::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- WorkgroupSize ----

    /// An 8×8×1 workgroup contains 64 threads.
    #[test]
    fn test_workgroup_thread_count() {
        assert_eq!(WorkgroupSize::new(8, 8, 1).thread_count(), 64);
    }

    /// 64 threads fit a 1024-thread budget; 33×33 threads do not.
    #[test]
    fn test_workgroup_is_valid() {
        let fits = WorkgroupSize::linear(64);
        let too_big = WorkgroupSize::new(33, 33, 1);
        assert!(fits.is_valid(1024));
        assert!(!too_big.is_valid(1024));
    }

    // ---- dispatch_1d / dispatch_2d / dispatch_3d ----

    /// 256 items in groups of 64 need exactly 4×1×1 workgroups.
    #[test]
    fn test_dispatch_1d_exact() {
        assert_eq!(dispatch_1d(256, 64), DispatchGrid::new(4, 1, 1));
    }

    /// One item past a multiple of the group size adds a workgroup.
    #[test]
    fn test_dispatch_1d_rounds_up() {
        assert_eq!(dispatch_1d(257, 64).x, 5);
    }

    /// 1920×1080 with 16×16 groups: X divides evenly, Y rounds up to 68.
    #[test]
    fn test_dispatch_2d() {
        let grid = dispatch_2d(1920, 1080, 16, 16);
        assert_eq!((grid.x, grid.y), (120, 68));
    }

    /// An 8³ domain with 4³ groups needs 2 workgroups per axis.
    #[test]
    fn test_dispatch_3d() {
        assert_eq!(dispatch_3d(8, 8, 8, 4, 4, 4), DispatchGrid::new(2, 2, 2));
    }

    // ---- DispatchGrid ----

    #[test]
    fn test_total_workgroups() {
        assert_eq!(DispatchGrid::new(4, 4, 1).total_workgroups(), 16);
    }

    /// 4 workgroups of 64 threads launch 256 threads.
    #[test]
    fn test_total_threads() {
        let grid = DispatchGrid::new(2, 2, 1);
        assert_eq!(grid.total_threads(WorkgroupSize::planar(8, 8)), 256);
    }

    // ---- recommend_2d_workgroup ----

    #[test]
    fn test_recommend_2d_workgroup_within_limit() {
        let budget = 256;
        assert!(recommend_2d_workgroup(budget).thread_count() <= budget);
    }

    #[test]
    fn test_recommend_2d_workgroup_square() {
        let WorkgroupSize { x, y, .. } = recommend_2d_workgroup(1024);
        assert_eq!(x, y);
    }

    // ---- BarrierTracker ----

    #[test]
    fn test_barrier_tracker_push_and_count() {
        let mut tracker = BarrierTracker::new();
        tracker.push(BarrierKind::MemoryReadAfterWrite, Some("pre-blur"));
        tracker.push(BarrierKind::Full, None);
        assert_eq!(tracker.len(), 2);
        assert_eq!(tracker.count_of_kind(BarrierKind::Full), 1);
    }

    #[test]
    fn test_barrier_tracker_reset() {
        let mut tracker = BarrierTracker::new();
        tracker.push(BarrierKind::ExecutionOnly, None);
        tracker.reset();
        assert!(tracker.is_empty());
    }

    // ---- DispatchTracker ----

    /// 10×10 workgroups of 8×8 threads launch 6400 threads.
    #[test]
    fn test_dispatch_tracker_total_threads() {
        let mut tracker = DispatchTracker::new();
        let grid = DispatchGrid::new(10, 10, 1);
        let workgroup = WorkgroupSize::planar(8, 8);
        tracker.push("blur", grid, workgroup);
        assert_eq!(tracker.total_threads(), 6400);
    }

    #[test]
    fn test_dispatch_tracker_records_sequential_indices() {
        let mut tracker = DispatchTracker::new();
        for pipeline in ["a", "b"] {
            tracker.push(
                pipeline,
                DispatchGrid::new(1, 1, 1),
                WorkgroupSize::linear(64),
            );
        }
        let records = tracker.records();
        assert_eq!((records[0].index, records[1].index), (0, 1));
    }

    #[test]
    fn test_dispatch_tracker_reset() {
        let mut tracker = DispatchTracker::new();
        tracker.push("x", DispatchGrid::new(1, 1, 1), WorkgroupSize::linear(32));
        tracker.reset();
        assert!(tracker.is_empty());
        assert_eq!(tracker.total_threads(), 0);
    }

    // ---- DataDrivenDispatch ----

    #[test]
    fn test_data_driven_linear_exact() {
        let mut dispatch = DataDrivenDispatch::linear(64);
        assert_eq!(dispatch.prepare(128), DispatchGrid::new(2, 1, 1));
    }

    #[test]
    fn test_data_driven_linear_rounds_up() {
        let mut dispatch = DataDrivenDispatch::linear(64);
        assert_eq!(dispatch.prepare(65).x, 2);
    }

    /// Zero elements still produce at least one workgroup along X.
    #[test]
    fn test_data_driven_linear_zero_elements() {
        let mut dispatch = DataDrivenDispatch::linear(64);
        assert_eq!(dispatch.prepare(0).x, 1);
    }

    #[test]
    fn test_data_driven_square_covers_all_elements() {
        let mut dispatch = DataDrivenDispatch::square(8, 8);
        let _ = dispatch.prepare(500);
        assert!(dispatch.covered_elements() >= 500);
    }

    #[test]
    fn test_data_driven_square_grid_is_square() {
        let mut dispatch = DataDrivenDispatch::square(8, 8);
        let grid = dispatch.prepare(1024);
        assert_eq!(grid.x, grid.y);
    }

    /// 256 elements in 8×1 groups need 32 workgroups: 4 rows of 8 columns.
    #[test]
    fn test_data_driven_fixed_row_count() {
        let strategy = DataDispatchStrategy::FixedRowCount { rows: 4 };
        let mut dispatch = DataDrivenDispatch::new(8, 1, strategy);
        let grid = dispatch.prepare(256);
        assert_eq!(grid.y, 4);
        assert_eq!(grid.x, 8);
    }

    #[test]
    fn test_data_driven_grid_none_before_prepare() {
        let dispatch = DataDrivenDispatch::linear(32);
        assert!(dispatch.grid().is_none());
        assert_eq!(dispatch.covered_elements(), 0);
    }

    #[test]
    fn test_data_driven_last_element_count_stored() {
        let mut dispatch = DataDrivenDispatch::linear(16);
        let _ = dispatch.prepare(999);
        assert_eq!(dispatch.last_element_count(), 999);
    }

    #[test]
    fn test_data_driven_covered_elements_gte_last_count() {
        let element_count = 137_u64;
        let mut dispatch = DataDrivenDispatch::square(4, 4);
        let _ = dispatch.prepare(element_count);
        assert!(dispatch.covered_elements() >= element_count);
    }
}