use vyre::ir::Program;
use crate::workgroup::{InvocationIds, MAX_WORKGROUP_BYTES};
#[derive(Debug, Clone, Copy)]
pub struct SequentialWorkgroup {
pub size: [u32; 3],
}
impl SequentialWorkgroup {
#[must_use]
pub fn for_program(program: &Program) -> Self {
Self {
size: program.workgroup_size(),
}
}
#[must_use]
pub fn invocation_count(&self) -> u32 {
self.size[0]
.saturating_mul(self.size[1])
.saturating_mul(self.size[2])
}
pub fn invocations(&self, workgroup_id: [u32; 3]) -> impl Iterator<Item = InvocationIds> {
let [sx, sy, sz] = self.size;
let wg = workgroup_id;
(0..sz).flat_map(move |lz| {
(0..sy).flat_map(move |ly| {
(0..sx).map(move |lx| InvocationIds {
global: [
wg[0].saturating_mul(sx).saturating_add(lx),
wg[1].saturating_mul(sy).saturating_add(ly),
wg[2].saturating_mul(sz).saturating_add(lz),
],
workgroup: wg,
local: [lx, ly, lz],
})
})
})
}
}
pub const MAX_SHARED_BYTES: usize = MAX_WORKGROUP_BYTES;
#[cfg(test)]
mod tests {
use super::*;
use vyre::ir::{BufferDecl, DataType, Node, Program};
fn trivial_program(size: [u32; 3]) -> Program {
Program::wrapped(
vec![BufferDecl::output("out", 0, DataType::U32)],
size,
vec![Node::let_bind("idx", vyre::ir::Expr::gid_x())],
)
}
#[test]
fn invocation_count_is_product() {
let wg = SequentialWorkgroup::for_program(&trivial_program([4, 2, 1]));
assert_eq!(wg.invocation_count(), 8);
}
#[test]
fn invocation_order_is_canonical() {
let wg = SequentialWorkgroup { size: [2, 2, 1] };
let ids: Vec<_> = wg.invocations([0, 0, 0]).map(|i| i.local).collect();
assert_eq!(ids, vec![[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]);
}
#[test]
fn invocation_globals_offset_by_workgroup() {
let wg = SequentialWorkgroup { size: [2, 1, 1] };
let ids: Vec<_> = wg.invocations([3, 0, 0]).map(|i| i.global).collect();
assert_eq!(ids, vec![[6, 0, 0], [7, 0, 0]]);
}
}