use vyre::ir::{Expr, Node, Program};
use crate::{
eval_expr, oob,
workgroup::{Frame, Invocation, Memory},
};
use vyre::Error;
/// Advances `invocation` by at most one executed node.
///
/// Nothing happens when the invocation already reports `done()` or is parked
/// at a barrier. Otherwise frames are popped from the invocation's frame stack
/// until either one node has been executed or the stack is empty. Loop frames
/// and exhausted node frames are bookkeeping only and do not end the step.
///
/// # Errors
///
/// Propagates any error produced while processing a frame.
pub fn step<'a>(
    invocation: &mut Invocation<'a>,
    memory: &mut Memory,
    program: &'a Program,
) -> Result<(), vyre::Error> {
    let idle = invocation.done() || invocation.waiting_at_barrier;
    if idle {
        return Ok(());
    }
    loop {
        // Take the top frame by value so the handlers are free to push new frames.
        let popped = invocation.frames_mut().pop();
        match popped {
            // Empty stack: there is nothing left to run.
            None => return Ok(()),
            Some(Frame::Loop {
                var,
                next,
                to,
                body,
            }) => {
                step_loop_frame(invocation, var, next, to, body)?;
            }
            Some(Frame::Nodes {
                nodes,
                index,
                scoped,
            }) => {
                let executed =
                    step_nodes_frame(invocation, memory, program, nodes, index, scoped)?;
                if executed {
                    return Ok(());
                }
            }
        }
    }
}
/// Processes a popped `Frame::Nodes` frame.
///
/// Returns `Ok(true)` when the node at `index` was executed and `Ok(false)`
/// when the frame was already exhausted. An exhausted frame with `scoped` set
/// pops one scope from the invocation on its way out.
///
/// # Errors
///
/// Propagates any error returned by `execute_node`.
fn step_nodes_frame<'a>(
    invocation: &mut Invocation<'a>,
    memory: &mut Memory,
    program: &'a Program,
    nodes: &'a [Node],
    index: usize,
    scoped: bool,
) -> Result<bool, vyre::Error> {
    let Some(current) = nodes.get(index) else {
        // Past the last node: close the scope this frame owned, if any.
        if scoped {
            invocation.pop_scope();
        }
        return Ok(false);
    };
    // Push the continuation before running the node, so that any frames the
    // node pushes end up above it on the stack.
    let continuation = Frame::Nodes {
        nodes,
        index: index + 1,
        scoped,
    };
    invocation.frames_mut().push(continuation);
    execute_node(current, invocation, memory, program)?;
    Ok(true)
}
/// Processes a popped `Frame::Loop` frame.
///
/// While `next < to`, this re-pushes the loop frame with the counter advanced
/// by one, opens a scope, binds `var` to the current counter as a `U32` value
/// and pushes the body as a scoped node frame. Once `next >= to` the frame is
/// simply dropped.
///
/// # Errors
///
/// Propagates any error returned by `bind_loop_var`.
fn step_loop_frame<'a>(
    invocation: &mut Invocation<'a>,
    var: &'a str,
    next: u32,
    to: u32,
    body: &'a [Node],
) -> Result<(), vyre::Error> {
    if next < to {
        // `next < to` holds here, so the increment cannot actually wrap.
        let following = next.wrapping_add(1);
        invocation.frames_mut().push(Frame::Loop {
            var,
            next: following,
            to,
            body,
        });

        invocation.push_scope();
        let counter = crate::value::Value::U32(next);
        invocation.bind_loop_var(var, counter)?;

        // The body frame is marked scoped, matching the scope pushed above.
        let iteration = Frame::Nodes {
            nodes: body,
            index: 0,
            scoped: true,
        };
        invocation.frames_mut().push(iteration);
    }
    Ok(())
}
/// Runs a single IR node by forwarding it to the handler for its variant.
///
/// The match is deliberately exhaustive with no wildcard arm, so adding a new
/// `Node` variant is a compile error here until it gets a handler.
///
/// # Errors
///
/// Propagates whatever error the selected handler returns.
fn execute_node<'a>(
    node: &'a Node,
    invocation: &mut Invocation<'a>,
    memory: &mut Memory,
    program: &'a Program,
) -> Result<(), vyre::Error> {
    match node {
        // Control-state changes that carry no operands.
        Node::Return => eval_return(invocation),
        Node::Barrier => eval_barrier(invocation),

        // Variable and buffer writes.
        Node::Let {
            name: binding,
            value: init,
        } => eval_let(binding, init, invocation, memory, program),
        Node::Assign {
            name: target,
            value: rhs,
        } => eval_assign(target, rhs, invocation, memory, program),
        Node::Store {
            buffer: buffer_name,
            index: slot,
            value: rhs,
        } => eval_store(buffer_name, slot, rhs, invocation, memory, program),

        // Structured control flow.
        Node::Block(children) => eval_block(children, invocation),
        Node::If {
            cond: condition,
            then: on_true,
            otherwise: on_false,
        } => eval_if(condition, on_true, on_false, node, invocation, memory, program),
        Node::Loop {
            var: counter,
            from: lower,
            to: upper,
            body: children,
        } => eval_loop(counter, lower, upper, children, invocation, memory, program),
    }
}
/// Evaluates `value` and binds the result to `name` on the invocation.
///
/// # Errors
///
/// Propagates errors from expression evaluation and from `Invocation::bind`.
fn eval_let(
    name: &str,
    value: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<(), vyre::Error> {
    // The expression is fully evaluated before the binding is created.
    let initial = eval_expr::eval(value, invocation, memory, program)?;
    invocation.bind(name, initial)
}
/// Evaluates `value` and assigns the result to the variable `name`.
///
/// # Errors
///
/// Propagates errors from expression evaluation and from `Invocation::assign`.
fn eval_assign(
    name: &str,
    value: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<(), vyre::Error> {
    // The right-hand side is fully evaluated before the variable is touched.
    let updated = eval_expr::eval(value, invocation, memory, program)?;
    invocation.assign(name, updated)
}
fn eval_store(
buffer: &str,
index: &Expr,
value: &Expr,
invocation: &mut Invocation<'_>,
memory: &mut Memory,
program: &Program,
) -> Result<(), vyre::Error> {
let index = eval_expr::eval(index, invocation, memory, program)?;
let index = index
.try_as_u32()
.ok_or_else(|| Error::interp(format!(
"store index {index:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
)))?;
let value = eval_expr::eval(value, invocation, memory, program)?;
let target = eval_expr::buffer_mut(memory, program, buffer)?;
oob::store(target, index, &value);
Ok(())
}
/// Evaluates the condition of an `If` node and schedules the chosen branch.
///
/// The selected branch is pushed as a scoped node frame after a fresh scope is
/// opened. When either branch contains a barrier, the pair of this node's
/// identity and the condition outcome is appended to
/// `invocation.uniform_checks` (presumably so divergent control flow around
/// barriers can be detected by the caller; confirm against its consumer).
///
/// # Errors
///
/// Propagates errors from evaluating `cond`.
fn eval_if<'a>(
    cond: &Expr,
    then: &'a [Node],
    otherwise: &'a [Node],
    node: &Node,
    invocation: &mut Invocation<'a>,
    memory: &mut Memory,
    program: &Program,
) -> Result<(), vyre::Error> {
    let cond_value = eval_expr::eval(cond, invocation, memory, program)?.truthy();

    // Both branches are inspected, not just the one that will run.
    let guards_barrier = contains_barrier(then) || contains_barrier(otherwise);
    if guards_barrier {
        let record = (node_id(node), cond_value);
        invocation.uniform_checks.push(record);
    }

    let taken = Frame::Nodes {
        nodes: if cond_value { then } else { otherwise },
        index: 0,
        scoped: true,
    };
    invocation.push_scope();
    invocation.frames_mut().push(taken);
    Ok(())
}
/// Evaluates the bounds of a `Loop` node and pushes a loop frame for it.
///
/// Both bound expressions are evaluated before either is converted to `u32`,
/// so an evaluation error in `to` is reported ahead of a conversion error in
/// `from`. The pushed frame starts its counter at the lower bound.
///
/// # Errors
///
/// Propagates errors from evaluating the bounds and returns an interpreter
/// error when a bound cannot be represented as `u32`.
fn eval_loop<'a>(
    var: &'a str,
    from: &Expr,
    to: &Expr,
    body: &'a [Node],
    invocation: &mut Invocation<'a>,
    memory: &mut Memory,
    program: &Program,
) -> Result<(), vyre::Error> {
    let from_value = eval_expr::eval(from, invocation, memory, program)?;
    let to_value = eval_expr::eval(to, invocation, memory, program)?;

    let first = from_value.try_as_u32().ok_or_else(|| {
        Error::interp(format!(
            "loop lower bound {from_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
        ))
    })?;
    let limit = to_value.try_as_u32().ok_or_else(|| {
        Error::interp(format!(
            "loop upper bound {to_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
        ))
    })?;

    let pending = Frame::Loop {
        var,
        next: first,
        to: limit,
        body,
    };
    invocation.frames_mut().push(pending);
    Ok(())
}
/// Handles a `Return` node: discards every pending frame and sets the
/// invocation's `returned` flag.
///
/// Always returns `Ok(())`.
fn eval_return(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
    // Emptying the frame stack leaves nothing for later steps to pop.
    let pending = invocation.frames_mut();
    pending.clear();
    invocation.returned = true;
    Ok(())
}
/// Handles a `Block` node: opens a new scope and schedules the block's nodes
/// as a scoped node frame starting at the first node.
///
/// Always returns `Ok(())`.
fn eval_block<'a>(nodes: &'a [Node], invocation: &mut Invocation<'a>) -> Result<(), vyre::Error> {
    let block_frame = Frame::Nodes {
        nodes,
        index: 0,
        scoped: true,
    };
    // The scope opened here is paired with the `scoped: true` flag above.
    invocation.push_scope();
    invocation.frames_mut().push(block_frame);
    Ok(())
}
/// Handles a `Barrier` node by setting the invocation's `waiting_at_barrier`
/// flag.
///
/// `step` returns early while this flag is set; nothing in this function
/// clears it again. Always returns `Ok(())`.
fn eval_barrier(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
    invocation.waiting_at_barrier = true;
    Ok(())
}
fn contains_barrier(nodes: &[Node]) -> bool {
nodes.iter().any(node_contains_barrier)
}
/// Reports whether `node` is a barrier or has one nested anywhere inside it.
///
/// `If` nodes are searched through both branches; `Loop` and `Block` nodes
/// through their bodies. The match has no wildcard arm, so a new `Node`
/// variant must be classified here before the crate compiles.
fn node_contains_barrier(node: &Node) -> bool {
    match node {
        Node::Barrier => true,
        Node::Block(body) => contains_barrier(body),
        Node::Loop { body, .. } => contains_barrier(body),
        Node::If {
            then, otherwise, ..
        } => then
            .iter()
            .chain(otherwise.iter())
            .any(node_contains_barrier),
        // Leaf statements cannot hold nested nodes.
        Node::Return | Node::Store { .. } | Node::Assign { .. } | Node::Let { .. } => false,
    }
}
fn node_id(node: &Node) -> usize {
std::ptr::from_ref(node).addr()
}