use std::collections::{HashMap, HashSet};
use vyre::ir::{BufferAccess, Node, Program};
use vyre::Error;
use crate::{oob::Buffer, value::Value};
/// Upper bound, in bytes, on the combined size of all workgroup buffers that
/// `workgroup_memory` will allocate for one workgroup (64 MiB).
pub const MAX_WORKGROUP_BYTES: usize = 64 << 20;
/// Built-in identifiers of a single compute invocation, one `u32` per axis (x, y, z).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvocationIds {
    /// Global invocation id. `create_invocations` fills it per axis as
    /// `workgroup * workgroup_size + local`.
    pub global: [u32; 3],
    /// Id of the workgroup this invocation belongs to.
    pub workgroup: [u32; 3],
    /// Position of this invocation inside its workgroup.
    pub local: [u32; 3],
}

impl InvocationIds {
    /// Every component of every id set to zero.
    pub const ZERO: Self = {
        let origin = [0u32; 3];
        Self {
            global: origin,
            workgroup: origin,
            local: origin,
        }
    };
}
/// Buffer memory visible to the interpreter, split into a storage map and a
/// workgroup map, each mapping a `String` key to a `Buffer`.
#[derive(Debug)]
pub struct Memory {
// Storage buffers; keys are presumably buffer declaration names — confirm
// against the code that populates this map.
pub(crate) storage: HashMap<String, Buffer>,
// Workgroup buffers. `workgroup_memory` builds a map of exactly this type keyed
// by declaration name; presumably that is how this field is filled (confirm).
pub(crate) workgroup: HashMap<String, Buffer>,
}
/// Execution state of a single interpreted invocation: its ids, its local
/// variables and scopes, and an explicit stack of control frames.
pub struct Invocation<'a> {
/// Built-in ids of this invocation.
pub ids: InvocationIds,
// Currently bound local variables by name (written by `bind` / `assign`).
pub(crate) locals: HashMap<String, Value>,
// Names bound through `bind_loop_var`; `assign` rejects writes to these.
immutable: HashSet<String>,
// Stack of lexical scopes; each entry lists the names bound while it was the
// innermost scope, so `pop_scope` can unbind them.
scopes: Vec<Vec<String>>,
// Explicit execution stack; the invocation is `done()` once this is empty.
frames: Vec<Frame<'a>>,
/// When true, `done()` reports the invocation as finished regardless of frames.
pub returned: bool,
/// Not read in this file. Presumably set while the invocation is parked at a
/// barrier — TODO confirm against the scheduler.
pub waiting_at_barrier: bool,
/// Not read in this file. NOTE(review): the meaning of the `(usize, bool)` pairs
/// (likely a site index plus the observed condition for uniformity checking) must
/// be confirmed at the writer.
pub uniform_checks: Vec<(usize, bool)>,
}
/// One entry of an invocation's explicit execution stack.
#[non_exhaustive]
pub enum Frame<'a> {
/// A sequence of statements being executed in order.
Nodes {
// The statements of this frame.
nodes: &'a [Node],
// Cursor into `nodes`; `Invocation::new` starts the entry frame at 0.
index: usize,
// Presumably records whether entering this frame pushed a scope that must be
// popped on exit; the entry frame uses `false`. TODO confirm with the stepper.
scoped: bool,
},
/// An in-progress counted loop.
Loop {
// Name of the loop variable (see `Invocation::bind_loop_var`).
var: &'a str,
// Value the loop variable takes on the next iteration.
next: u32,
// Loop bound; whether it is inclusive or exclusive is decided by the stepper,
// which is not visible here — confirm.
to: u32,
// Statements executed on each iteration.
body: &'a [Node],
},
}
impl<'a> Invocation<'a> {
pub fn new(ids: InvocationIds, entry: &'a [Node]) -> Self {
Self {
ids,
locals: HashMap::new(),
immutable: HashSet::new(),
scopes: vec![Vec::new()],
frames: vec![Frame::Nodes {
nodes: entry,
index: 0,
scoped: false,
}],
returned: false,
waiting_at_barrier: false,
uniform_checks: Vec::new(),
}
}
pub fn done(&self) -> bool {
self.returned || self.frames.is_empty()
}
pub fn push_scope(&mut self) {
self.scopes.push(Vec::new());
}
pub fn pop_scope(&mut self) {
if let Some(names) = self.scopes.pop() {
for name in names {
self.locals.remove(&name);
self.immutable.remove(&name);
}
}
}
pub fn bind(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
if self.locals.contains_key(name) {
return Err(Error::interp(format!(
"duplicate local binding `{name}`. Fix: choose a unique local name; shadowing is not allowed."
)));
}
self.locals.insert(name.to_string(), value);
if let Some(scope) = self.scopes.last_mut() {
scope.push(name.to_string());
}
Ok(())
}
pub fn bind_loop_var(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
self.bind(name, value)?;
self.immutable.insert(name.to_string());
Ok(())
}
pub fn assign(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
if self.immutable.contains(name) {
return Err(Error::interp(format!(
"assignment to loop variable `{name}`. Fix: loop variables are immutable."
)));
}
let Some(slot) = self.locals.get_mut(name) else {
return Err(Error::interp(format!(
"assignment to undeclared variable `{name}`. Fix: add a Let before assigning it."
)));
};
*slot = value;
Ok(())
}
pub(crate) fn frames_mut(&mut self) -> &mut Vec<Frame<'a>> {
&mut self.frames
}
}
/// Creates one [`Invocation`] per local id of the program's workgroup size, for the
/// workgroup whose id is `workgroup`.
///
/// Invocations are produced with x varying fastest, then y, then z. Each one gets
/// `global = workgroup * workgroup_size + local` per axis and starts at `program.entry()`.
///
/// # Errors
///
/// Returns an interpreter error when:
/// - the invocation count `x * y * z` overflows `usize`, or
/// - any global id component overflows `u32`.
pub(crate) fn create_invocations(
    program: &Program,
    workgroup: [u32; 3],
) -> Result<Vec<Invocation<'_>>, vyre::Error> {
    // One global id component, rejecting u32 overflow instead of wrapping.
    let global_dim = |wgid: u32, size: u32, local: u32| {
        wgid
            .checked_mul(size)
            .and_then(|base| base.checked_add(local))
            .ok_or_else(|| Error::interp(
                "workgroup * dispatch dimensions overflow u32 global id. Fix: reduce workgroup id or workgroup size so each global_invocation_id component fits in u32.",
            ))
    };
    let [sx, sy, sz] = program.workgroup_size();
    // Multiplying the sizes in u32 panics in debug builds and silently wraps in
    // release builds. Compute the count in usize with overflow checking instead.
    let count = [sx, sy, sz]
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim as usize))
        .ok_or_else(|| Error::interp(
            "workgroup size product overflows usize invocation count. Fix: reduce workgroup_size so x * y * z fits in usize.",
        ))?;
    // The entry node list is the same for every invocation; fetch it once.
    let entry = program.entry();
    let mut invocations = Vec::with_capacity(count);
    for z in 0..sz {
        for y in 0..sy {
            for x in 0..sx {
                let local = [x, y, z];
                let global = [
                    global_dim(workgroup[0], sx, x)?,
                    global_dim(workgroup[1], sy, y)?,
                    global_dim(workgroup[2], sz, z)?,
                ];
                invocations.push(Invocation::new(
                    InvocationIds {
                        global,
                        workgroup,
                        local,
                    },
                    entry,
                ));
            }
        }
    }
    Ok(invocations)
}
pub(crate) fn workgroup_memory(program: &Program) -> Result<HashMap<String, Buffer>, vyre::Error> {
let mut workgroup = HashMap::new();
let mut allocated = 0usize;
for decl in program
.buffers()
.iter()
.filter(|decl| decl.access() == BufferAccess::Workgroup)
{
let element_size = decl.element().min_bytes();
let len = (decl.count() as usize)
.checked_mul(element_size)
.ok_or_else(|| Error::interp(format!(
"workgroup buffer `{}` byte size overflows usize. Fix: reduce count or element size.",
decl.name()
)))?;
allocated = allocated
.checked_add(len)
.ok_or_else(|| Error::interp(
"total workgroup memory byte size overflows usize. Fix: reduce workgroup buffer declarations.",
))?;
if allocated > MAX_WORKGROUP_BYTES {
return Err(Error::interp(format!(
"workgroup memory requires {allocated} bytes, exceeding the {MAX_WORKGROUP_BYTES}-byte reference budget. Fix: reduce workgroup buffer counts."
)));
}
workgroup.insert(
decl.name().to_string(),
Buffer {
bytes: vec![0; len],
element: decl.element(),
},
);
}
Ok(workgroup)
}