use std::collections::HashMap;
use vyre::ir::{BufferAccess, BufferDecl, Program};
use vyre::Error;
use crate::{
eval_node,
oob::Buffer,
value::Value,
workgroup::{self, Invocation, Memory},
};
/// Runs `program` on the reference interpreter and returns its output buffers.
///
/// The program is first checked with `vyre::ir::validate`. `inputs` is then bound to
/// the program's buffers by `prepare_storage`, and the dispatch is executed by
/// `execute_dispatch`.
///
/// # Errors
///
/// Returns an interpreter error when validation reports at least one problem (all
/// messages are joined with `"; "`), or when storage preparation or dispatch fails.
pub fn run(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
    let problems = vyre::ir::validate(program);
    if !problems.is_empty() {
        // Flatten every validation problem into one human-readable line.
        let mut rendered = Vec::new();
        for problem in problems {
            rendered.push(problem.message().to_string());
        }
        let messages = rendered.join("; ");
        return Err(Error::interp(format!(
            "program failed IR validation: {messages}. Fix: repair the Program before invoking the reference interpreter."
        )));
    }

    let prepared = prepare_storage(program, inputs)?;
    execute_dispatch(
        program,
        prepared.storage,
        prepared.output_names,
        prepared.max_elements,
    )
}
/// Interpreter state assembled by `prepare_storage` and consumed by `run`.
struct Prepared {
/// Buffers keyed by declaration name; workgroup buffers are not included.
storage: HashMap<String, Buffer>,
/// Names of the `ReadWrite` buffers, in declaration order.
output_names: Vec<String>,
/// Largest element count seen across the bound buffers, never less than 1.
max_elements: u32,
}
fn prepare_storage(program: &Program, inputs: &[Value]) -> Result<Prepared, vyre::Error> {
let mut storage = HashMap::new();
let mut input_index = 0usize;
let mut output_names = Vec::new();
let mut max_elements = 1u32;
for decl in program.buffers() {
if decl.access() == BufferAccess::Workgroup {
continue;
}
let value = inputs
.get(input_index)
.ok_or_else(|| Error::interp(format!(
"missing input for buffer `{}`. Fix: pass one Value for each non-workgroup buffer in Program::buffers order.",
decl.name()
)))?;
input_index += 1;
let bytes = value.to_bytes();
max_elements = max_elements.max(element_count(decl, bytes.len())?);
if decl.access() == BufferAccess::ReadWrite {
output_names.push(decl.name().to_string());
}
storage.insert(
decl.name().to_string(),
Buffer {
bytes,
element: decl.element(),
},
);
}
if input_index != inputs.len() {
return Err(Error::interp(
"unused input values supplied. Fix: pass exactly one Value per non-workgroup buffer declaration.",
));
}
Ok(Prepared {
storage,
output_names,
max_elements,
})
}
fn execute_dispatch(
program: &Program,
mut storage: HashMap<String, Buffer>,
output_names: Vec<String>,
max_elements: u32,
) -> Result<Vec<Value>, vyre::Error> {
validate_workgroup_size(program)?;
let invocations_per_workgroup = invocations_per_workgroup(program);
let workgroup_count_x = max_elements.div_ceil(invocations_per_workgroup).max(1);
for wg_x in 0..workgroup_count_x {
run_workgroup(program, &mut storage, [wg_x, 0, 0])?;
}
output_names
.into_iter()
.map(|name| {
storage
.remove(&name)
.map(|buffer| Value::Bytes(buffer.bytes))
.ok_or_else(|| Error::interp(format!(
"missing output buffer `{name}` after dispatch. Fix: keep buffer declarations unique."
)))
})
.collect()
}
fn validate_workgroup_size(program: &Program) -> Result<(), vyre::Error> {
if program.workgroup_size().contains(&0) {
return Err(Error::interp(
"workgroup size contains zero. Fix: all dimensions must be >= 1.",
));
}
Ok(())
}
/// Returns the product of the workgroup dimensions.
///
/// The multiplication saturates at `u32::MAX` instead of overflowing, and the result
/// is clamped to at least 1.
fn invocations_per_workgroup(program: &Program) -> u32 {
    let mut total = 1u32;
    for &dimension in program.workgroup_size().iter() {
        total = total.saturating_mul(dimension);
    }
    total.max(1)
}
/// Executes a single workgroup against the shared buffer storage.
///
/// Fresh workgroup memory and a fresh set of invocations are created for
/// `workgroup_id`. The caller's `storage` map is moved into a `Memory` for the
/// duration of the run and moved back before returning.
///
/// `storage` is handed back on every path: setup that can fail runs before the map
/// is taken, and the map is restored before the execution result is propagated, so
/// an error never leaves the caller holding an emptied map.
///
/// # Errors
///
/// Propagates failures from workgroup-memory creation, invocation creation, and
/// invocation execution.
fn run_workgroup(
    program: &Program,
    storage: &mut HashMap<String, Buffer>,
    workgroup_id: [u32; 3],
) -> Result<(), vyre::Error> {
    // Do all fallible setup while `storage` is still untouched.
    let workgroup = workgroup::workgroup_memory(program)?;
    let mut invocations = workgroup::create_invocations(program, workgroup_id)?;

    let mut memory = Memory {
        storage: std::mem::take(storage),
        workgroup,
    };
    let outcome = run_invocations(program, &mut memory, &mut invocations);

    // Return the buffers to the caller whether or not execution succeeded.
    *storage = memory.storage;
    outcome
}
/// Drives all invocations of one workgroup until every one of them is done.
///
/// Each round steps the runnable invocations once, checks the recorded uniform
/// control-flow observations, and releases the barrier when every live invocation is
/// waiting at it.
///
/// # Errors
///
/// Propagates stepping and uniformity errors, and reports a barrier violation when a
/// round makes no progress while live invocations are still waiting.
fn run_invocations<'a>(
    program: &'a Program,
    memory: &mut Memory,
    invocations: &mut [Invocation<'a>],
) -> Result<(), vyre::Error> {
    loop {
        if invocations.iter().all(|invocation| invocation.done()) {
            return Ok(());
        }

        let made_progress = step_round_robin(program, memory, invocations)?;
        verify_uniform_control_flow(invocations)?;

        if release_barrier_if_ready(invocations) {
            continue;
        }

        // NOTE(review): a round without progress means every live invocation is
        // parked at a barrier, which the release above already handles, so this
        // guard looks unreachable as written — confirm before relying on it.
        let stalled = !made_progress && live_waiting_count(invocations) > 0;
        if stalled {
            return Err(Error::interp(
                "program violates uniform-control-flow rule: not every live invocation reached the same barrier. Fix: move Barrier to uniform control flow.",
            ));
        }
    }
}
/// Advances every runnable invocation by one `eval_node::step`, in slice order.
///
/// An invocation is runnable when it is neither done nor waiting at a barrier.
/// Returns `true` when at least one invocation was stepped.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `eval_node::step`.
fn step_round_robin<'a>(
    program: &'a Program,
    memory: &mut Memory,
    invocations: &mut [Invocation<'a>],
) -> Result<bool, vyre::Error> {
    let mut stepped_any = false;
    for invocation in invocations.iter_mut() {
        let runnable = !invocation.done() && !invocation.waiting_at_barrier;
        if runnable {
            eval_node::step(invocation, memory, program)?;
            stepped_any = true;
        }
    }
    Ok(stepped_any)
}
/// Releases the barrier once every invocation that is not done is waiting at it.
///
/// On release the `waiting_at_barrier` flag is cleared on all invocations and `true`
/// is returned. Returns `false`, changing nothing, when no invocation is live or when
/// some live invocation has not reached the barrier.
fn release_barrier_if_ready(invocations: &mut [Invocation<'_>]) -> bool {
    let live = invocations
        .iter()
        .filter(|invocation| !invocation.done())
        .count();
    if live == 0 || live != live_waiting_count(invocations) {
        return false;
    }

    invocations
        .iter_mut()
        .for_each(|invocation| invocation.waiting_at_barrier = false);
    true
}
/// Counts the invocations that are not done and are waiting at a barrier.
fn live_waiting_count(invocations: &[Invocation<'_>]) -> usize {
    let mut waiting = 0usize;
    for invocation in invocations {
        if !invocation.done() && invocation.waiting_at_barrier {
            waiting += 1;
        }
    }
    waiting
}
/// Checks that live invocations agree on every recorded uniform-check value.
///
/// Each entry of an invocation's `uniform_checks` pairs an id with a boolean. The
/// most recently seen value for each id is remembered, and the next value recorded
/// under the same id must match it. Invocations that are done are ignored.
///
/// # Errors
///
/// Returns an interpreter error as soon as two values recorded for the same id differ.
fn verify_uniform_control_flow(invocations: &[Invocation<'_>]) -> Result<(), vyre::Error> {
    let mut seen: HashMap<usize, bool> = HashMap::new();
    for invocation in invocations {
        if invocation.done() {
            continue;
        }
        for (id, value) in &invocation.uniform_checks {
            // `insert` hands back whatever was previously stored under this id.
            match seen.insert(*id, *value) {
                Some(earlier) if earlier != *value => {
                    return Err(Error::interp(
                        "program violates uniform-control-flow rule: Barrier appears inside an If whose condition differs across the workgroup. Fix: make the condition uniform or move Barrier outside the branch.",
                    ));
                }
                _ => {}
            }
        }
    }
    Ok(())
}
fn element_count(decl: &BufferDecl, byte_len: usize) -> Result<u32, vyre::Error> {
let stride = decl.element().min_bytes();
if stride == 0 {
return u32::try_from(byte_len).map_err(|_| Error::interp(format!(
"buffer `{}` has {} bytes and cannot be indexed within u32 address space. Fix: shrink or split the invocation."
, decl.name(),
byte_len,
)));
}
let elements = byte_len / stride;
u32::try_from(elements).map_err(|_| Error::interp(format!(
"buffer `{}` has {} bytes for stride {} and overflows u32 elements. Fix: shrink declaration footprint or split work.",
decl.name(),
byte_len,
stride,
)))
}