use vyre::ir::{AtomicOp, BinOp, BufferAccess, BufferDecl, DataType, Expr, Program, UnOp};
use vyre::Error;
use crate::{atomics, oob, value::Value, workgroup::Invocation, workgroup::Memory};
pub use crate::oob::Buffer;
/// Largest total argument payload, in bytes, that `eval_call` will assemble for a single
/// call before returning an error (64 MiB).
const MAX_CALL_INPUT_BYTES: usize = 64 * 1024 * 1024;
/// Evaluates one IR expression for the given invocation and returns its value.
///
/// Every handled `Expr` variant is forwarded to a dedicated `eval_*` helper in this
/// module; compound expressions recurse back into `eval` through those helpers.
///
/// # Errors
///
/// Propagates whatever error the selected helper produces, and returns an interpreter
/// error for any `Expr` variant that has no arm here.
pub fn eval(
    expr: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<Value, vyre::Error> {
    match expr {
        // Literals and local variable reads.
        Expr::LitU32(lit) => eval_lit_u32(*lit),
        Expr::LitI32(lit) => eval_lit_i32(*lit),
        Expr::LitBool(lit) => eval_lit_bool(*lit),
        Expr::Var(ident) => eval_var(ident, invocation),

        // Buffer reads and buffer metadata.
        Expr::Load {
            buffer: source,
            index: at,
        } => eval_load(source, at, invocation, memory, program),
        Expr::BufLen { buffer: source } => eval_buf_len(source, memory, program),

        // Per-invocation identifiers.
        Expr::InvocationId { axis: which } => eval_invocation_id(*which, invocation),
        Expr::WorkgroupId { axis: which } => eval_workgroup_id(*which, invocation),
        Expr::LocalId { axis: which } => eval_local_id(*which, invocation),

        // Operators, calls, selection and casts.
        Expr::BinOp {
            op: operator,
            left: lhs,
            right: rhs,
        } => eval_binop(operator.clone(), lhs, rhs, invocation, memory, program),
        Expr::UnOp {
            op: operator,
            operand: inner,
        } => eval_unop(operator.clone(), inner, invocation, memory, program),
        Expr::Call {
            op_id: callee,
            args: call_args,
        } => eval_call(callee, call_args, invocation, memory, program),
        Expr::Select {
            cond: condition,
            true_val: when_true,
            false_val: when_false,
        } => eval_select(condition, when_true, when_false, invocation, memory, program),
        Expr::Cast {
            target: to,
            value: inner,
        } => eval_cast(to.clone(), inner, invocation, memory, program),

        // Atomic read-modify-write on a buffer element.
        Expr::Atomic {
            op: operator,
            buffer: target,
            index: at,
            expected: compare,
            value: operand,
        } => eval_atomic(
            operator.clone(),
            target,
            at,
            compare.as_deref(),
            operand,
            invocation,
            memory,
            program,
        ),

        // Any variant without an arm above is reported instead of being ignored.
        _ => Err(Error::interp(format!(
            "unsupported IR `unknown Expr variant: {expr:?}`. Fix: update vyre-reference for the new vyre::ir variant."
        ))),
    }
}
pub fn buffer_mut<'a>(
memory: &'a mut Memory,
program: &Program,
name: &str,
) -> Result<&'a mut Buffer, vyre::Error> {
let decl = buffer_decl(program, name)?;
match decl.access() {
BufferAccess::ReadWrite | BufferAccess::Workgroup => resolve_buffer_mut(memory, decl),
BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
"store target `{name}` is not writable. Fix: declare it ReadWrite or Workgroup."
))),
_ => Err(Error::interp(format!(
"store target `{name}` uses an unsupported access mode. Fix: upgrade vyre-conform or use a supported BufferAccess."
))),
}
}
/// Wraps a `u32` literal as `Value::U32`; never fails.
fn eval_lit_u32(value: u32) -> Result<Value, vyre::Error> {
    let literal = Value::U32(value);
    Ok(literal)
}
/// Wraps an `i32` literal as `Value::I32`; never fails.
fn eval_lit_i32(value: i32) -> Result<Value, vyre::Error> {
    let literal = Value::I32(value);
    Ok(literal)
}
/// Wraps a boolean literal as `Value::Bool`; never fails.
fn eval_lit_bool(value: bool) -> Result<Value, vyre::Error> {
    let literal = Value::Bool(value);
    Ok(literal)
}
fn eval_var(name: &str, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
invocation.locals.get(name).cloned().ok_or_else(|| {
Error::interp(format!(
"reference to undeclared variable `{name}`. Fix: add a Let before this use."
))
})
}
/// Evaluates `index`, resolves `buffer`, and reads the element through `oob::load`.
///
/// The index expression is evaluated before the buffer is looked up, so an index
/// error is reported ahead of a missing-buffer error.
///
/// # Errors
///
/// Fails when the index cannot be evaluated or represented as `u32`, or when the
/// buffer cannot be resolved.
fn eval_load(
    buffer: &str,
    index: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<Value, vyre::Error> {
    let position = eval_to_index(index, "load index", invocation, memory, program)?;
    let source = resolve_buffer(memory, program, buffer)?;
    Ok(oob::load(source, position))
}
/// Returns the resolved buffer's `len()` as a `Value::U32`.
///
/// # Errors
///
/// Fails when the buffer is undeclared or has no backing entry in `memory`.
fn eval_buf_len(buffer: &str, memory: &Memory, program: &Program) -> Result<Value, vyre::Error> {
    let target = resolve_buffer(memory, program, buffer)?;
    Ok(Value::U32(target.len()))
}
/// Returns component `axis` of `invocation.ids.global`, or an error for an axis
/// outside `0..3`.
fn eval_invocation_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
    let global = invocation.ids.global;
    axis_value(global, axis)
}
/// Returns component `axis` of `invocation.ids.workgroup`, or an error for an axis
/// outside `0..3`.
fn eval_workgroup_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
    let workgroup = invocation.ids.workgroup;
    axis_value(workgroup, axis)
}
/// Returns component `axis` of `invocation.ids.local`, or an error for an axis
/// outside `0..3`.
fn eval_local_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
    let local = invocation.ids.local;
    axis_value(local, axis)
}
/// Evaluates both operands, left first, then applies `op` via `typed_ops::eval_binop`.
///
/// # Errors
///
/// Propagates an error from either operand (the right operand is not evaluated if the
/// left one fails) or from the typed operator implementation.
fn eval_binop(
    op: BinOp,
    left: &Expr,
    right: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<Value, vyre::Error> {
    let lhs = eval(left, invocation, memory, program)?;
    let rhs = eval(right, invocation, memory, program)?;
    super::typed_ops::eval_binop(op, lhs, rhs)
}
/// Evaluates the operand, then applies `op` via `typed_ops::eval_unop`.
///
/// # Errors
///
/// Propagates an error from the operand or from the typed operator implementation.
fn eval_unop(
    op: UnOp,
    operand: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<Value, vyre::Error> {
    let input = eval(operand, invocation, memory, program)?;
    super::typed_ops::eval_unop(op, input)
}
/// Evaluates a call to a registered op and decodes its first declared output.
///
/// Steps, in order:
/// 1. Look the op up in `vyre::ops::registry`.
/// 2. Check the argument count against the op's declared inputs.
/// 3. Evaluate each argument, encode it at its declared byte width, and append it to a
///    single input byte buffer, enforcing `MAX_CALL_INPUT_BYTES` before each append.
/// 4. Run the op's CPU path (a composition through `flat_cpu::run_flat`, or an
///    intrinsic's `cpu_fn`).
/// 5. Decode the output bytes as the op's first declared output type, or as
///    `DataType::Bytes` when the op declares no outputs.
///
/// # Errors
///
/// Fails for an unregistered op, an argument-count mismatch, an argument that fails to
/// evaluate, an input payload that overflows `usize` or exceeds the budget, a failing
/// composition run, or a compose kind with no arm here.
fn eval_call(
    op_id: &str,
    args: &[Expr],
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<Value, vyre::Error> {
    let Some(spec) = vyre::ops::registry::lookup(op_id) else {
        return Err(Error::interp(format!(
            "unsupported call `{op_id}`. Fix: register the op in core::ops::registry or inline the callee as IR."
        )));
    };

    let expected = spec.inputs().len();
    let received = args.len();
    if received != expected {
        return Err(Error::interp(format!(
            "call `{op_id}` received {} arguments but the primitive signature requires {expected}. Fix: pass exactly {expected} arguments.",
            received
        )));
    }

    // Arguments are serialized back-to-back in declaration order.
    let mut input = Vec::new();
    for (arg, declared_type) in args.iter().zip(spec.inputs()) {
        let width = declared_type.min_bytes();
        let encoded = eval(arg, invocation, memory, program)?.to_bytes_width(width);

        let next_len = match input.len().checked_add(encoded.len()) {
            Some(total) => total,
            None => {
                return Err(Error::interp(format!(
                    "call `{op_id}` input byte size overflows usize. Fix: reduce the argument count or byte payload size."
                )));
            }
        };
        // The budget is checked before the bytes are appended.
        if next_len > MAX_CALL_INPUT_BYTES {
            return Err(Error::interp(format!(
                "call `{op_id}` requires {next_len} input bytes, exceeding the {MAX_CALL_INPUT_BYTES}-byte reference budget. Fix: reduce call input size."
            )));
        }
        input.extend_from_slice(&encoded);
    }

    let mut output = Vec::new();
    match spec.compose() {
        vyre::ops::Compose::Composition(build) => {
            let callee = build().with_entry_op_id(spec.id());
            crate::flat_cpu::run_flat(&callee, &input, &mut output)?;
        }
        vyre::ops::Compose::Intrinsic(intrinsic) => {
            intrinsic.cpu_fn()(&input, &mut output);
        }
        other => {
            return Err(Error::interp(format!(
                "Fix: vyre-reference does not yet implement compose-kind `{other:?}` for op `{}`. Either implement the CPU path for this compose variant in vyre-reference/src/eval_expr.rs, or route the caller through a different op.",
                spec.id()
            )));
        }
    }

    let result_type = spec.outputs().first().cloned().unwrap_or(DataType::Bytes);
    Ok(spec_output_value(result_type, &output))
}
/// Evaluates a select expression: returns `true_val` when `cond` is truthy, otherwise
/// `false_val`.
///
/// Both branch expressions are always evaluated, in the order `cond`, `true_val`,
/// `false_val`; only the choice of which result to return depends on the condition.
///
/// # Errors
///
/// Propagates the first error from any of the three sub-expressions.
fn eval_select(
    cond: &Expr,
    true_val: &Expr,
    false_val: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<Value, vyre::Error> {
    let take_true = eval(cond, invocation, memory, program)?.truthy();
    let when_true = eval(true_val, invocation, memory, program)?;
    let when_false = eval(false_val, invocation, memory, program)?;
    if take_true {
        Ok(when_true)
    } else {
        Ok(when_false)
    }
}
/// Evaluates `value` and converts the result to `target` with `cast_value`.
///
/// # Errors
///
/// Propagates an evaluation error, or the cast error when the value cannot be
/// converted.
fn eval_cast(
    target: DataType,
    value: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<Value, vyre::Error> {
    eval(value, invocation, memory, program).and_then(|evaluated| cast_value(target, &evaluated))
}
fn eval_atomic(
op: AtomicOp,
buffer: &str,
index: &Expr,
expected: Option<&Expr>,
value: &Expr,
invocation: &mut Invocation<'_>,
memory: &mut Memory,
program: &Program,
) -> Result<Value, vyre::Error> {
match (op.clone(), expected) {
(AtomicOp::CompareExchange, None) => {
return Err(Error::interp(
"compare-exchange atomic is missing expected value. Fix: set Expr::Atomic.expected for AtomicOp::CompareExchange.",
));
}
(AtomicOp::CompareExchange, Some(_)) => {}
(_, Some(_)) => {
return Err(Error::interp(
"non-compare-exchange atomic includes an expected value. Fix: use Expr::Atomic.expected only with AtomicOp::CompareExchange.",
));
}
(_, None) => {}
}
let idx = eval_to_index(index, "atomic index", invocation, memory, program)?;
let expected = expected
.map(|expr| {
eval(expr, invocation, memory, program)?.try_as_u32().ok_or_else(|| {
Error::interp(format!(
"atomic expected value {expr:?} cannot be represented as u32. Fix: use a scalar u32-compatible argument."
))
})
})
.transpose()?;
let value = eval(value, invocation, memory, program)?;
let value = value.try_as_u32().ok_or_else(|| {
Error::interp(
"atomic value cannot be represented as u32. Fix: use a scalar u32-compatible argument.",
)
})?;
let target = atomic_buffer_mut(memory, program, buffer)?;
let Some(old) = oob::atomic_load(target, idx) else {
return Ok(Value::U32(0));
};
let (old, new) = atomics::apply(op, old, expected, value)?;
oob::atomic_store(target, idx, new);
Ok(Value::U32(old))
}
fn eval_to_index(
index: &Expr,
context: &'static str,
invocation: &mut Invocation<'_>,
memory: &mut Memory,
program: &Program,
) -> Result<u32, vyre::Error> {
let value = eval(index, invocation, memory, program)?;
value
.try_as_u32()
.ok_or_else(|| Error::interp(format!(
"{context} {value:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32.",
)))
}
fn resolve_buffer<'a>(
memory: &'a Memory,
program: &Program,
name: &str,
) -> Result<&'a oob::Buffer, vyre::Error> {
let decl = buffer_decl(program, name)?;
if decl.access() == BufferAccess::Workgroup {
memory.workgroup.get(name)
} else {
memory.storage.get(name)
}
.ok_or_else(|| {
Error::interp(format!(
"missing buffer `{name}`. Fix: initialize all declared buffers."
))
})
}
fn resolve_buffer_mut<'a>(
memory: &'a mut Memory,
decl: &BufferDecl,
) -> Result<&'a mut oob::Buffer, vyre::Error> {
let name = decl.name();
if decl.access() == BufferAccess::Workgroup {
memory.workgroup.get_mut(name)
} else {
memory.storage.get_mut(name)
}
.ok_or_else(|| {
Error::interp(format!(
"missing buffer `{name}`. Fix: initialize all declared buffers."
))
})
}
fn atomic_buffer_mut<'a>(
memory: &'a mut Memory,
program: &Program,
name: &str,
) -> Result<&'a mut oob::Buffer, vyre::Error> {
let decl = buffer_decl(program, name)?;
match decl.access() {
BufferAccess::ReadWrite => resolve_buffer_mut(memory, decl),
BufferAccess::Workgroup => Err(Error::interp(format!(
"atomic target `{name}` is workgroup memory. Fix: atomics only support ReadWrite storage buffers."
))),
BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
"atomic target `{name}` is not writable. Fix: atomics only support ReadWrite storage buffers."
))),
_ => Err(Error::interp(format!(
"atomic target `{name}` uses an unsupported access mode. Fix: upgrade vyre-conform or use a supported BufferAccess."
))),
}
}
fn buffer_decl<'a>(program: &'a Program, name: &str) -> Result<&'a BufferDecl, vyre::Error> {
program.buffer(name).ok_or_else(|| {
Error::interp(format!(
"unknown buffer `{name}`. Fix: declare it in Program::buffers."
))
})
}
fn axis_value(values: [u32; 3], axis: u8) -> Result<Value, vyre::Error> {
values
.get(axis as usize)
.copied()
.map(Value::U32)
.ok_or_else(|| {
Error::interp(format!(
"invocation/workgroup ID axis {axis} out of range. Fix: use 0, 1, or 2."
))
})
}
/// Decodes an op's raw output bytes into a `Value` according to the declared type.
///
/// - `U32`, `I32`, `Bool`, `F32` are decoded from `read_u32_prefix(bytes)`: used
///   directly, reinterpreted as `i32`, compared against zero, or reinterpreted as `f32`
///   bits and widened to `f64`.
/// - `U64` is decoded from `read_u64_prefix(bytes)`.
/// - `Vec2U32` / `Vec4U32` become exactly 8 / 16 bytes, truncated or zero-padded.
/// - `Bytes` and every other type return a copy of `bytes` unchanged.
fn spec_output_value(ty: DataType, bytes: &[u8]) -> Value {
    // Deferred so the word reader only runs for the arms that need it.
    let word = || read_u32_prefix(bytes);
    match ty {
        DataType::U32 => Value::U32(word()),
        DataType::I32 => Value::I32(word() as i32),
        DataType::Bool => Value::Bool(word() != 0),
        DataType::U64 => Value::U64(read_u64_prefix(bytes)),
        DataType::F32 => Value::Float(f64::from(f32::from_bits(word()))),
        DataType::Vec2U32 => Value::Bytes(read_fixed_prefix(bytes, 8)),
        DataType::Vec4U32 => Value::Bytes(read_fixed_prefix(bytes, 16)),
        // `DataType::Bytes` and any type without a dedicated arm pass through as-is.
        _ => Value::Bytes(bytes.to_vec()),
    }
}
/// Returns exactly `width` bytes: the leading bytes of `bytes`, truncated when the
/// input is longer than `width` and zero-padded at the end when it is shorter.
fn read_fixed_prefix(bytes: &[u8], width: usize) -> Vec<u8> {
    let mut fixed: Vec<u8> = bytes.iter().copied().take(width).collect();
    // Pads short inputs with zeros; a no-op when `width` bytes were already taken.
    fixed.resize(width, 0);
    fixed
}
/// Converts `value` to the `target` data type.
///
/// - `U32`: an `I32` is reinterpreted with `as u32`; anything else must convert through
///   `try_as_u32`.
/// - `I32`: an `I32` is returned as-is; anything else must convert through
///   `try_as_u32` and is then reinterpreted with `as i32`.
/// - `U64`: must convert through `try_as_u64`.
/// - `Bool`: uses `truthy()`.
/// - `Vec2U32` / `Vec4U32`: the value's bytes, truncated or zero-padded to 2 / 4 words.
/// - `Bytes` and every other target: the value's bytes unchanged.
///
/// NOTE(review): `DataType::F32` has no arm of its own, so it takes the byte fallback —
/// confirm that is intended.
///
/// # Errors
///
/// Returns the `invalid_cast` error when a `U32`, `I32`, or `U64` target cannot
/// represent the value.
fn cast_value(target: DataType, value: &Value) -> Result<Value, vyre::Error> {
    match target {
        DataType::U32 => {
            // Signed inputs are bit-reinterpreted rather than range-checked.
            if let Value::I32(signed) = value {
                return Ok(Value::U32(*signed as u32));
            }
            match value.try_as_u32() {
                Some(word) => Ok(Value::U32(word)),
                None => Err(invalid_cast(target, value)),
            }
        }
        DataType::I32 => {
            if let Value::I32(signed) = value {
                return Ok(Value::I32(*signed));
            }
            match value.try_as_u32() {
                Some(word) => Ok(Value::I32(word as i32)),
                None => Err(invalid_cast(target, value)),
            }
        }
        DataType::U64 => match value.try_as_u64() {
            Some(wide) => Ok(Value::U64(wide)),
            None => Err(invalid_cast(target, value)),
        },
        DataType::Bool => Ok(Value::Bool(value.truthy())),
        DataType::Vec2U32 => Ok(Value::Bytes(widen_to_words(value, 2))),
        DataType::Vec4U32 => Ok(Value::Bytes(widen_to_words(value, 4))),
        // `DataType::Bytes` and any target without a dedicated arm get the raw bytes.
        _ => Ok(Value::Bytes(value.to_bytes())),
    }
}
fn invalid_cast(target: DataType, value: &Value) -> Error {
Error::interp(format!(
"cast to {target:?} cannot represent {value:?} losslessly. Fix: cast from an in-range scalar value."
))
}
/// Returns the bytes of `value` forced to exactly `words * 4` bytes: extra trailing
/// bytes are dropped and missing bytes are filled with zeros.
fn widen_to_words(value: &Value, words: usize) -> Vec<u8> {
    let target_bytes = words * 4;
    let mut bytes = value.to_bytes();
    // `resize` shrinks when the vector is too long and zero-fills when it is too short.
    bytes.resize(target_bytes, 0);
    bytes
}
use super::ops::{read_u32_prefix, read_u64_prefix};