Skip to main content

vyre_reference/
eval_expr.rs

//! Expression evaluator that gives the parity engine a pure-Rust ground truth
//! for every supported `Expr` variant.
//!
//! If a backend lowers `Expr::BinOp`, `Expr::Load`, or `Expr::Atomic` differently
//! from this evaluator, the conform gate reports the exact divergence. This module
//! exists so that IR semantics are defined by Rust code, not by whatever a WGSL
//! driver happens to emit.
8
9use vyre::ir::{AtomicOp, BinOp, BufferAccess, BufferDecl, DataType, Expr, Program, UnOp};
10
11use vyre::Error;
12
13use crate::{atomics, oob, value::Value, workgroup::Invocation, workgroup::Memory};
14
15/// Re-export the OOB-guarded buffer type used by storage operations.
16pub use crate::oob::Buffer;
17
/// Upper bound (64 MiB) on the total argument bytes packed for a single
/// `Expr::Call`; `eval_call` rejects any call whose packed input would exceed it.
const MAX_CALL_INPUT_BYTES: usize = 64 << 20;
19
20/// Evaluate an expression for one invocation.
21///
22/// # Errors
23///
24/// Returns [`Error::Interp`] on operand type errors, malformed atomic or call
25/// expressions, unimplemented variants, or float operands.
26pub fn eval(
27    expr: &Expr,
28    invocation: &mut Invocation<'_>,
29    memory: &mut Memory,
30    program: &Program,
31) -> Result<Value, vyre::Error> {
32    match expr {
33        Expr::LitU32(value) => eval_lit_u32(*value),
34        Expr::LitI32(value) => eval_lit_i32(*value),
35        Expr::LitBool(value) => eval_lit_bool(*value),
36        Expr::Var(name) => eval_var(name, invocation),
37        Expr::Load { buffer, index } => eval_load(buffer, index, invocation, memory, program),
38        Expr::BufLen { buffer } => eval_buf_len(buffer, memory, program),
39        Expr::InvocationId { axis } => eval_invocation_id(*axis, invocation),
40        Expr::WorkgroupId { axis } => eval_workgroup_id(*axis, invocation),
41        Expr::LocalId { axis } => eval_local_id(*axis, invocation),
42        Expr::BinOp { op, left, right } => {
43            eval_binop(op.clone(), left, right, invocation, memory, program)
44        }
45        Expr::UnOp { op, operand } => eval_unop(op.clone(), operand, invocation, memory, program),
46        Expr::Call { op_id, args } => eval_call(op_id, args, invocation, memory, program),
47        Expr::Select {
48            cond,
49            true_val,
50            false_val,
51        } => eval_select(cond, true_val, false_val, invocation, memory, program),
52        Expr::Cast { target, value } => {
53            eval_cast(target.clone(), value, invocation, memory, program)
54        }
55        Expr::Atomic {
56            op,
57            buffer,
58            index,
59            expected,
60            value,
61        } => eval_atomic(
62            op.clone(),
63            buffer,
64            index,
65            expected.as_deref(),
66            value,
67            invocation,
68            memory,
69            program,
70        ),
71        _ => Err(Error::interp(format!(
72            "unsupported IR `unknown Expr variant: {expr:?}`. Fix: update vyre-reference for the new vyre::ir variant."
73        ))),
74    }
75}
76
77/// Return a mutable buffer only when the program declares it writable.
78///
79/// # Errors
80///
81/// Returns [`Error::Interp`] if the buffer is read-only, uniform,
82/// or does not exist in the program declaration.
83pub fn buffer_mut<'a>(
84    memory: &'a mut Memory,
85    program: &Program,
86    name: &str,
87) -> Result<&'a mut Buffer, vyre::Error> {
88    let decl = buffer_decl(program, name)?;
89    match decl.access() {
90        BufferAccess::ReadWrite | BufferAccess::Workgroup => resolve_buffer_mut(memory, decl),
91        BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
92            "store target `{name}` is not writable. Fix: declare it ReadWrite or Workgroup."
93        ))),
94        _ => Err(Error::interp(format!(
95            "store target `{name}` uses an unsupported access mode. Fix: upgrade vyre-conform or use a supported BufferAccess."
96        ))),
97    }
98}
99
100fn eval_lit_u32(value: u32) -> Result<Value, vyre::Error> {
101    Ok(Value::U32(value))
102}
103
104fn eval_lit_i32(value: i32) -> Result<Value, vyre::Error> {
105    Ok(Value::I32(value))
106}
107
108fn eval_lit_bool(value: bool) -> Result<Value, vyre::Error> {
109    Ok(Value::Bool(value))
110}
111
112fn eval_var(name: &str, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
113    invocation.locals.get(name).cloned().ok_or_else(|| {
114        Error::interp(format!(
115            "reference to undeclared variable `{name}`. Fix: add a Let before this use."
116        ))
117    })
118}
119
120fn eval_load(
121    buffer: &str,
122    index: &Expr,
123    invocation: &mut Invocation<'_>,
124    memory: &mut Memory,
125    program: &Program,
126) -> Result<Value, vyre::Error> {
127    let idx = eval_to_index(index, "load index", invocation, memory, program)?;
128    Ok(oob::load(resolve_buffer(memory, program, buffer)?, idx))
129}
130
131fn eval_buf_len(buffer: &str, memory: &Memory, program: &Program) -> Result<Value, vyre::Error> {
132    Ok(Value::U32(resolve_buffer(memory, program, buffer)?.len()))
133}
134
135fn eval_invocation_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
136    axis_value(invocation.ids.global, axis)
137}
138
139fn eval_workgroup_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
140    axis_value(invocation.ids.workgroup, axis)
141}
142
143fn eval_local_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
144    axis_value(invocation.ids.local, axis)
145}
146
147fn eval_binop(
148    op: BinOp,
149    left: &Expr,
150    right: &Expr,
151    invocation: &mut Invocation<'_>,
152    memory: &mut Memory,
153    program: &Program,
154) -> Result<Value, vyre::Error> {
155    let left = eval(left, invocation, memory, program)?;
156    let right = eval(right, invocation, memory, program)?;
157    super::typed_ops::eval_binop(op, left, right)
158}
159
160fn eval_unop(
161    op: UnOp,
162    operand: &Expr,
163    invocation: &mut Invocation<'_>,
164    memory: &mut Memory,
165    program: &Program,
166) -> Result<Value, vyre::Error> {
167    let operand = eval(operand, invocation, memory, program)?;
168    super::typed_ops::eval_unop(op, operand)
169}
170
171fn eval_call(
172    op_id: &str,
173    args: &[Expr],
174    invocation: &mut Invocation<'_>,
175    memory: &mut Memory,
176    program: &Program,
177) -> Result<Value, vyre::Error> {
178    let spec = vyre::ops::registry::lookup(op_id).ok_or_else(|| Error::interp(format!(
179            "unsupported call `{op_id}`. Fix: register the op in core::ops::registry or inline the callee as IR."
180    )))?;
181    let expected = spec.inputs().len();
182    if args.len() != expected {
183        return Err(Error::interp(format!(
184            "call `{op_id}` received {} arguments but the primitive signature requires {expected}. Fix: pass exactly {expected} arguments.",
185            args.len()
186        )));
187    }
188    let mut input = Vec::new();
189    for (arg, declared_type) in args.iter().zip(spec.inputs()) {
190        let declared_width = declared_type.min_bytes();
191        let bytes = eval(arg, invocation, memory, program)?.to_bytes_width(declared_width);
192        let next_len = input
193            .len()
194            .checked_add(bytes.len())
195            .ok_or_else(|| Error::interp(format!(
196                    "call `{op_id}` input byte size overflows usize. Fix: reduce the argument count or byte payload size."
197            )))?;
198        if next_len > MAX_CALL_INPUT_BYTES {
199            return Err(Error::interp(format!(
200                "call `{op_id}` requires {next_len} input bytes, exceeding the {MAX_CALL_INPUT_BYTES}-byte reference budget. Fix: reduce call input size."
201            )));
202        }
203        input.extend_from_slice(&bytes);
204    }
205    let mut output = Vec::new();
206    match spec.compose() {
207        vyre::ops::Compose::Composition(build) => {
208            crate::flat_cpu::run_flat(&build().with_entry_op_id(spec.id()), &input, &mut output)?;
209        }
210        vyre::ops::Compose::Intrinsic(intrinsic) => {
211            intrinsic.cpu_fn()(&input, &mut output);
212        }
213        other => {
214            return Err(Error::interp(format!(
215                "Fix: vyre-reference does not yet implement compose-kind `{other:?}` for op `{}`. Either implement the CPU path for this compose variant in vyre-reference/src/eval_expr.rs, or route the caller through a different op.",
216                spec.id()
217            )));
218        }
219    }
220    Ok(spec_output_value(
221        spec.outputs().first().cloned().unwrap_or(DataType::Bytes),
222        &output,
223    ))
224}
225
226fn eval_select(
227    cond: &Expr,
228    true_val: &Expr,
229    false_val: &Expr,
230    invocation: &mut Invocation<'_>,
231    memory: &mut Memory,
232    program: &Program,
233) -> Result<Value, vyre::Error> {
234    let cond = eval(cond, invocation, memory, program)?.truthy();
235    let true_val = eval(true_val, invocation, memory, program)?;
236    let false_val = eval(false_val, invocation, memory, program)?;
237    Ok(if cond { true_val } else { false_val })
238}
239
240fn eval_cast(
241    target: DataType,
242    value: &Expr,
243    invocation: &mut Invocation<'_>,
244    memory: &mut Memory,
245    program: &Program,
246) -> Result<Value, vyre::Error> {
247    let value = eval(value, invocation, memory, program)?;
248    cast_value(target, &value)
249}
250
251fn eval_atomic(
252    op: AtomicOp,
253    buffer: &str,
254    index: &Expr,
255    expected: Option<&Expr>,
256    value: &Expr,
257    invocation: &mut Invocation<'_>,
258    memory: &mut Memory,
259    program: &Program,
260) -> Result<Value, vyre::Error> {
261    match (op.clone(), expected) {
262        (AtomicOp::CompareExchange, None) => {
263            return Err(Error::interp(
264                "compare-exchange atomic is missing expected value. Fix: set Expr::Atomic.expected for AtomicOp::CompareExchange.",
265            ));
266        }
267        (AtomicOp::CompareExchange, Some(_)) => {}
268        (_, Some(_)) => {
269            return Err(Error::interp(
270                "non-compare-exchange atomic includes an expected value. Fix: use Expr::Atomic.expected only with AtomicOp::CompareExchange.",
271            ));
272        }
273        (_, None) => {}
274    }
275    let idx = eval_to_index(index, "atomic index", invocation, memory, program)?;
276    let expected = expected
277        .map(|expr| {
278            eval(expr, invocation, memory, program)?.try_as_u32().ok_or_else(|| {
279                Error::interp(format!(
280                        "atomic expected value {expr:?} cannot be represented as u32. Fix: use a scalar u32-compatible argument."
281                ))
282            })
283        })
284        .transpose()?;
285    let value = eval(value, invocation, memory, program)?;
286    let value = value.try_as_u32().ok_or_else(|| {
287        Error::interp(
288            "atomic value cannot be represented as u32. Fix: use a scalar u32-compatible argument.",
289        )
290    })?;
291    let target = atomic_buffer_mut(memory, program, buffer)?;
292    let Some(old) = oob::atomic_load(target, idx) else {
293        return Ok(Value::U32(0));
294    };
295    let (old, new) = atomics::apply(op, old, expected, value)?;
296    oob::atomic_store(target, idx, new);
297    Ok(Value::U32(old))
298}
299
300fn eval_to_index(
301    index: &Expr,
302    context: &'static str,
303    invocation: &mut Invocation<'_>,
304    memory: &mut Memory,
305    program: &Program,
306) -> Result<u32, vyre::Error> {
307    let value = eval(index, invocation, memory, program)?;
308    value
309        .try_as_u32()
310        .ok_or_else(|| Error::interp(format!(
311                "{context} {value:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32.",
312        )))
313}
314
315fn resolve_buffer<'a>(
316    memory: &'a Memory,
317    program: &Program,
318    name: &str,
319) -> Result<&'a oob::Buffer, vyre::Error> {
320    let decl = buffer_decl(program, name)?;
321    if decl.access() == BufferAccess::Workgroup {
322        memory.workgroup.get(name)
323    } else {
324        memory.storage.get(name)
325    }
326    .ok_or_else(|| {
327        Error::interp(format!(
328            "missing buffer `{name}`. Fix: initialize all declared buffers."
329        ))
330    })
331}
332
333fn resolve_buffer_mut<'a>(
334    memory: &'a mut Memory,
335    decl: &BufferDecl,
336) -> Result<&'a mut oob::Buffer, vyre::Error> {
337    let name = decl.name();
338    if decl.access() == BufferAccess::Workgroup {
339        memory.workgroup.get_mut(name)
340    } else {
341        memory.storage.get_mut(name)
342    }
343    .ok_or_else(|| {
344        Error::interp(format!(
345            "missing buffer `{name}`. Fix: initialize all declared buffers."
346        ))
347    })
348}
349
350fn atomic_buffer_mut<'a>(
351    memory: &'a mut Memory,
352    program: &Program,
353    name: &str,
354) -> Result<&'a mut oob::Buffer, vyre::Error> {
355    let decl = buffer_decl(program, name)?;
356    match decl.access() {
357        BufferAccess::ReadWrite => resolve_buffer_mut(memory, decl),
358        BufferAccess::Workgroup => Err(Error::interp(format!(
359            "atomic target `{name}` is workgroup memory. Fix: atomics only support ReadWrite storage buffers."
360        ))),
361        BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
362            "atomic target `{name}` is not writable. Fix: atomics only support ReadWrite storage buffers."
363        ))),
364        _ => Err(Error::interp(format!(
365            "atomic target `{name}` uses an unsupported access mode. Fix: upgrade vyre-conform or use a supported BufferAccess."
366        ))),
367    }
368}
369
370fn buffer_decl<'a>(program: &'a Program, name: &str) -> Result<&'a BufferDecl, vyre::Error> {
371    program.buffer(name).ok_or_else(|| {
372        Error::interp(format!(
373            "unknown buffer `{name}`. Fix: declare it in Program::buffers."
374        ))
375    })
376}
377
378fn axis_value(values: [u32; 3], axis: u8) -> Result<Value, vyre::Error> {
379    values
380        .get(axis as usize)
381        .copied()
382        .map(Value::U32)
383        .ok_or_else(|| {
384            Error::interp(format!(
385                "invocation/workgroup ID axis {axis} out of range. Fix: use 0, 1, or 2."
386            ))
387        })
388}
389
390fn spec_output_value(ty: DataType, bytes: &[u8]) -> Value {
391    match ty {
392        DataType::U32 => Value::U32(read_u32_prefix(bytes)),
393        DataType::I32 => Value::I32(read_u32_prefix(bytes) as i32),
394        DataType::Bool => Value::Bool(read_u32_prefix(bytes) != 0),
395        DataType::U64 => Value::U64(read_u64_prefix(bytes)),
396        DataType::F32 => Value::Float(f32::from_bits(read_u32_prefix(bytes)) as f64),
397        DataType::Vec2U32 => Value::Bytes(read_fixed_prefix(bytes, 8)),
398        DataType::Vec4U32 => Value::Bytes(read_fixed_prefix(bytes, 16)),
399        DataType::Bytes => Value::Bytes(bytes.to_vec()),
400        _ => Value::Bytes(bytes.to_vec()),
401    }
402}
403
/// Return exactly `width` bytes: the first `width` bytes of `bytes`, padded
/// with zeros if `bytes` is shorter.
fn read_fixed_prefix(bytes: &[u8], width: usize) -> Vec<u8> {
    bytes
        .iter()
        .copied()
        .chain(std::iter::repeat(0u8))
        .take(width)
        .collect()
}
410
411fn cast_value(target: DataType, value: &Value) -> Result<Value, vyre::Error> {
412    match target {
413        // GPU parity: casting I32 -> U32 is a two's-complement bit
414        // reinterpretation (WGSL `u32(x)` where `x: i32`), not a lossless
415        // numeric conversion. The validator rejecting I32(-1) would force
416        // every IR author to pre-mask; that diverges from the WGSL backend
417        // and breaks ops like `neg` that hand back negative intermediaries.
418        DataType::U32 => match value {
419            Value::I32(v) => Ok(Value::U32(*v as u32)),
420            _ => value
421                .try_as_u32()
422                .map(Value::U32)
423                .ok_or_else(|| invalid_cast(target, value)),
424        },
425        DataType::I32 => match value {
426            Value::I32(value) => Ok(Value::I32(*value)),
427            _ => value
428                .try_as_u32()
429                .map(|value| Value::I32(value as i32))
430                .ok_or_else(|| invalid_cast(target, value)),
431        },
432        // Kimi audit #3: preserve the full u64 payload, not just the
433        // low word. Previously `u64::from(value.as_u32())` silently
434        // discarded the upper 32 bits of any U64 source value.
435        DataType::U64 => value
436            .try_as_u64()
437            .map(Value::U64)
438            .ok_or_else(|| invalid_cast(target, value)),
439        DataType::Bool => Ok(Value::Bool(value.truthy())),
440        DataType::Bytes => Ok(Value::Bytes(value.to_bytes())),
441        // Kimi audit #4 & #5: Vec2U32/Vec4U32 casts must preserve every
442        // component of the source, not just the first word. The prior
443        // `vec_bytes(value.as_u32(), N)` pipeline discarded the upper
444        // words of any multi-component source.
445        DataType::Vec2U32 => Ok(Value::Bytes(widen_to_words(value, 2))),
446        DataType::Vec4U32 => Ok(Value::Bytes(widen_to_words(value, 4))),
447        _ => Ok(Value::Bytes(value.to_bytes())),
448    }
449}
450
451fn invalid_cast(target: DataType, value: &Value) -> Error {
452    Error::interp(format!(
453        "cast to {target:?} cannot represent {value:?} losslessly. Fix: cast from an in-range scalar value."
454    ))
455}
456
457/// Widen `value` into `words` little-endian u32 slots, preserving
458/// every source byte up to `words * 4` and zero-filling the remainder.
459/// Scalar sources (U32/I32/Bool) sit in the first slot with the rest
460/// zeroed; multi-word sources (U64/Bytes/Vec*) copy all their bytes up
461/// to the target width.
462fn widen_to_words(value: &Value, words: usize) -> Vec<u8> {
463    let target_bytes = words * 4;
464    let mut bytes = value.to_bytes();
465    if bytes.len() > target_bytes {
466        bytes.truncate(target_bytes);
467    } else if bytes.len() < target_bytes {
468        bytes.resize(target_bytes, 0);
469    }
470    bytes
471}
472
473use super::ops::{read_u32_prefix, read_u64_prefix};