Skip to main content

vyre_reference/execution/
node.rs

//! Statement executor that gives the parity engine a pure-Rust ground truth
//! for every `Node` variant.
//!
//! This module simulates the exact control-flow, memory, and barrier behavior
//! that a correct GPU backend must produce. Any divergence in `If`, `Loop`,
//! `Barrier`, or `Store` semantics is caught by the conform gate as a concrete
//! counterexample.
8
9use vyre::ir::{Expr, Node, Program};
10
11use crate::{
12    execution::expr as eval_expr,
13    execution::node_tree::{contains_barrier, node_id},
14    oob,
15    workgroup::{AsyncTransfer, Frame, Invocation, Memory},
16};
17use vyre::Error;
18
19/// Execute one scheduling step for an invocation.
20///
21/// # Errors
22///
23/// Returns [`Error::Interp`] for uniform-control-flow violations,
24/// out-of-bounds stores, malformed loops, or expression evaluation failures.
25pub fn step<'a>(
26    invocation: &mut Invocation<'a>,
27    memory: &mut Memory,
28    program: &'a Program,
29) -> Result<(), vyre::Error> {
30    if invocation.done() || invocation.waiting_at_barrier {
31        return Ok(());
32    }
33
34    loop {
35        let Some(frame) = invocation.frames_mut().pop() else {
36            return Ok(());
37        };
38        match frame {
39            Frame::Nodes {
40                nodes,
41                index,
42                scoped,
43            } => {
44                if step_nodes_frame(invocation, memory, program, nodes, index, scoped)? {
45                    return Ok(());
46                }
47            }
48            Frame::Loop {
49                var,
50                next,
51                to,
52                body,
53            } => step_loop_frame(invocation, var, next, to, body)?,
54        }
55    }
56}
57
58fn step_nodes_frame<'a>(
59    invocation: &mut Invocation<'a>,
60    memory: &mut Memory,
61    program: &'a Program,
62    nodes: &'a [Node],
63    index: usize,
64    scoped: bool,
65) -> Result<bool, vyre::Error> {
66    if index >= nodes.len() {
67        if scoped {
68            invocation.pop_scope();
69        }
70        return Ok(false);
71    }
72
73    invocation.frames_mut().push(Frame::Nodes {
74        nodes,
75        index: index + 1,
76        scoped,
77    });
78    execute_node(&nodes[index], invocation, memory, program)?;
79    Ok(true)
80}
81
82fn step_loop_frame<'a>(
83    invocation: &mut Invocation<'a>,
84    var: &'a str,
85    next: u32,
86    to: u32,
87    body: &'a [Node],
88) -> Result<(), vyre::Error> {
89    if next >= to {
90        return Ok(());
91    }
92    invocation.frames_mut().push(Frame::Loop {
93        var,
94        next: next.wrapping_add(1),
95        to,
96        body,
97    });
98    invocation.push_scope();
99    invocation.bind_loop_var(var, crate::value::Value::U32(next))?;
100    invocation.frames_mut().push(Frame::Nodes {
101        nodes: body,
102        index: 0,
103        scoped: true,
104    });
105    Ok(())
106}
107
108fn execute_node<'a>(
109    node: &'a Node,
110    invocation: &mut Invocation<'a>,
111    memory: &mut Memory,
112    program: &'a Program,
113) -> Result<(), vyre::Error> {
114    match node {
115        Node::Let { name, value } => eval_let(name, value, invocation, memory, program),
116        Node::Assign { name, value } => eval_assign(name, value, invocation, memory, program),
117        Node::Store {
118            buffer,
119            index,
120            value,
121        } => eval_store(buffer, index, value, invocation, memory, program),
122        Node::If {
123            cond,
124            then,
125            otherwise,
126        } => eval_if(cond, then, otherwise, node, invocation, memory, program),
127        Node::Loop {
128            var,
129            from,
130            to,
131            body,
132        } => eval_loop(var, from, to, body, invocation, memory, program),
133        Node::Return => eval_return(invocation),
134        Node::Block(nodes) => eval_block(nodes, invocation),
135        Node::Barrier { .. } => eval_barrier(invocation),
136        Node::IndirectDispatch {
137            count_buffer,
138            count_offset,
139        } => eval_indirect_dispatch(count_buffer, *count_offset, memory, program),
140        Node::AsyncLoad {
141            source,
142            destination,
143            offset,
144            size,
145            tag,
146        } => eval_async_load(
147            AsyncLoadEval {
148                source,
149                destination,
150                offset,
151                size,
152                tag,
153            },
154            invocation,
155            memory,
156            program,
157        ),
158        Node::AsyncStore {
159            source,
160            destination,
161            offset,
162            size,
163            tag,
164        } => eval_async_store(
165            AsyncStoreEval {
166                source,
167                destination,
168                offset,
169                size,
170                tag,
171            },
172            invocation,
173            memory,
174            program,
175        ),
176        Node::AsyncWait { tag } => eval_async_wait(tag, invocation, memory, program),
177        Node::Trap { address, tag } => {
178            let address = eval_expr::eval(address, invocation, memory, program)?
179                .try_as_u32()
180                .ok_or_else(|| {
181                    Error::interp(format!(
182                        "reference trap `{tag}` address is not a u32. Fix: pass a scalar u32 trap address."
183                    ))
184                })?;
185            Err(vyre::Error::interp(format!(
186                "reference dispatch trapped: address={address}, tag=`{tag}`. Fix: handle the trap condition or route this Program through a backend/runtime with replay support."
187            )))
188        }
189        Node::Resume { tag } => Err(vyre::Error::interp(format!(
190            "reference dispatch reached Resume `{tag}` without a replay runtime. Fix: lower Resume through a runtime-owned replay path before reference execution."
191        ))),
192        Node::AllReduce { buffer, group, .. } => Err(vyre::Error::interp(format!(
193            "reference dispatch reached AllReduce on buffer `{buffer}` for group {}. Fix: run this Program on a distributed backend with collective support or lower the single-rank collective before reference execution.",
194            group.as_u32()
195        ))),
196        Node::AllGather {
197            input,
198            output,
199            group,
200        } => Err(vyre::Error::interp(format!(
201            "reference dispatch reached AllGather `{input}` -> `{output}` for group {}. Fix: run this Program on a distributed backend with collective support or lower the single-rank collective before reference execution.",
202            group.as_u32()
203        ))),
204        Node::ReduceScatter {
205            input,
206            output,
207            group,
208            ..
209        } => Err(vyre::Error::interp(format!(
210            "reference dispatch reached ReduceScatter `{input}` -> `{output}` for group {}. Fix: run this Program on a distributed backend with collective support or lower the single-rank collective before reference execution.",
211            group.as_u32()
212        ))),
213        Node::Broadcast {
214            buffer,
215            root,
216            group,
217        } => Err(vyre::Error::interp(format!(
218            "reference dispatch reached Broadcast on buffer `{buffer}` from root {root} for group {}. Fix: run this Program on a distributed backend with collective support or lower the single-rank collective before reference execution.",
219            group.as_u32()
220        ))),
221        Node::Region { body, .. } => eval_block(body, invocation),
222        Node::Opaque(extension) => Err(vyre::Error::interp(format!(
223            "reference interpreter does not support opaque node extension `{}`/`{}`. Fix: provide a reference evaluator for this NodeExtension or lower it to core Node variants before evaluation.",
224            extension.extension_kind(),
225            extension.debug_identity()
226        ))),
227        _ => Err(vyre::Error::interp(
228            "reference interpreter encountered an unknown Node variant. Fix: update vyre-reference before executing this IR.",
229        )),
230    }
231}
232
233fn eval_let(
234    name: &str,
235    value: &Expr,
236    invocation: &mut Invocation<'_>,
237    memory: &mut Memory,
238    program: &Program,
239) -> Result<(), vyre::Error> {
240    let value = eval_expr::eval(value, invocation, memory, program)?;
241    invocation.bind(name, value)
242}
243
244fn eval_assign(
245    name: &str,
246    value: &Expr,
247    invocation: &mut Invocation<'_>,
248    memory: &mut Memory,
249    program: &Program,
250) -> Result<(), vyre::Error> {
251    let value = eval_expr::eval(value, invocation, memory, program)?;
252    invocation.assign(name, value)
253}
254
255fn eval_store(
256    buffer: &str,
257    index: &Expr,
258    value: &Expr,
259    invocation: &mut Invocation<'_>,
260    memory: &mut Memory,
261    program: &Program,
262) -> Result<(), vyre::Error> {
263    let index = eval_expr::eval(index, invocation, memory, program)?;
264    let index = index
265        .try_as_u32()
266        .ok_or_else(|| Error::interp(format!(
267                "store index {index:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
268        )))?;
269    let value = eval_expr::eval(value, invocation, memory, program)?;
270    let target = eval_expr::buffer_mut(memory, program, buffer)?;
271    oob::store(target, index, &value);
272    Ok(())
273}
274
275fn eval_indirect_dispatch(
276    count_buffer: &str,
277    count_offset: u64,
278    memory: &Memory,
279    program: &Program,
280) -> Result<(), vyre::Error> {
281    if count_offset % 4 != 0 {
282        return Err(Error::interp(format!(
283            "indirect dispatch offset {count_offset} is not 4-byte aligned. Fix: use a u32-aligned dispatch tuple."
284        )));
285    }
286    let decl = program.buffer(count_buffer).ok_or_else(|| {
287        Error::interp(format!(
288            "indirect dispatch references unknown buffer `{count_buffer}`. Fix: declare the count buffer before execution."
289        ))
290    })?;
291    let buffer = if decl.access() == vyre::ir::BufferAccess::Workgroup {
292        memory.workgroup.get(count_buffer)
293    } else {
294        memory.storage.get(count_buffer)
295    }
296    .ok_or_else(|| {
297        Error::interp(format!(
298            "indirect dispatch buffer `{count_buffer}` is missing. Fix: initialize the count buffer before execution."
299        ))
300    })?;
301    let required_end = count_offset.checked_add(12).ok_or_else(|| {
302        Error::interp(
303            "indirect dispatch byte range overflowed u64. Fix: shrink the count offset."
304                .to_string(),
305        )
306    })?;
307    let byte_len = buffer
308        .bytes
309        .read()
310        .map_err(|_| {
311            Error::interp(format!(
312                "indirect dispatch buffer `{count_buffer}` lock is poisoned. Fix: rebuild the interpreter memory state before execution."
313            ))
314        })?
315        .len();
316    if u64::try_from(byte_len).unwrap_or(u64::MAX) < required_end {
317        return Err(Error::interp(format!(
318            "indirect dispatch buffer `{count_buffer}` is too short for a 3-word dispatch tuple at byte offset {count_offset}. Fix: provide 12 readable bytes starting at that offset."
319        )));
320    }
321    Err(Error::interp(format!(
322        "Node::IndirectDispatch cannot execute in the sequential reference interpreter because dynamic indirect dispatch requires runtime queue scheduling. Fix: run this program on a backend/runtime that supports indirect dispatch or lower `{count_buffer}` at byte offset {count_offset} to a static workgroup grid before reference execution."
323    )))
324}
325
326struct AsyncLoadEval<'a> {
327    source: &'a str,
328    destination: &'a str,
329    offset: &'a Expr,
330    size: &'a Expr,
331    tag: &'a str,
332}
333
334struct AsyncStoreEval<'a> {
335    source: &'a str,
336    destination: &'a str,
337    offset: &'a Expr,
338    size: &'a Expr,
339    tag: &'a str,
340}
341
342fn eval_async_load(
343    request: AsyncLoadEval<'_>,
344    invocation: &mut Invocation<'_>,
345    memory: &mut Memory,
346    program: &Program,
347) -> Result<(), vyre::Error> {
348    let start = eval_byte_count(
349        request.offset,
350        "async load source offset",
351        invocation,
352        memory,
353        program,
354    )?;
355    let byte_count = eval_byte_count(request.size, "async load size", invocation, memory, program)?;
356    let payload = read_bytes(memory, program, request.source, start, byte_count)?;
357    ensure_writable_buffer(memory, program, request.destination)?;
358    invocation.begin_async(
359        request.tag,
360        AsyncTransfer::Copy {
361            destination: request.destination.into(),
362            start: 0,
363            payload,
364        },
365    )
366}
367
368fn eval_async_store(
369    request: AsyncStoreEval<'_>,
370    invocation: &mut Invocation<'_>,
371    memory: &mut Memory,
372    program: &Program,
373) -> Result<(), vyre::Error> {
374    let start = eval_byte_count(
375        request.offset,
376        "async store destination offset",
377        invocation,
378        memory,
379        program,
380    )?;
381    let byte_count = eval_byte_count(
382        request.size,
383        "async store size",
384        invocation,
385        memory,
386        program,
387    )?;
388    let payload = read_bytes(memory, program, request.source, 0, byte_count)?;
389    ensure_writable_buffer(memory, program, request.destination)?;
390    invocation.begin_async(
391        request.tag,
392        AsyncTransfer::Copy {
393            destination: request.destination.into(),
394            start,
395            payload,
396        },
397    )
398}
399
400fn eval_async_wait(
401    tag: &str,
402    invocation: &mut Invocation<'_>,
403    memory: &mut Memory,
404    program: &Program,
405) -> Result<(), vyre::Error> {
406    apply_async_transfer(invocation.finish_async(tag)?, memory, program)
407}
408
409fn eval_byte_count(
410    expr: &Expr,
411    label: &str,
412    invocation: &mut Invocation<'_>,
413    memory: &mut Memory,
414    program: &Program,
415) -> Result<usize, Error> {
416    let value = eval_expr::eval(expr, invocation, memory, program)?;
417    usize::try_from(value.try_as_u64().ok_or_else(|| {
418        Error::interp(format!(
419            "{label} cannot be represented as u64. Fix: use an in-range non-negative byte count."
420        ))
421    })?)
422    .map_err(|_| {
423        Error::interp(format!(
424            "{label} exceeds host usize. Fix: reduce the async transfer span."
425        ))
426    })
427}
428
429fn read_bytes(
430    memory: &Memory,
431    program: &Program,
432    source: &str,
433    start: usize,
434    byte_count: usize,
435) -> Result<Vec<u8>, Error> {
436    let buffer = resolve_buffer(memory, program, source)?;
437    let bytes = buffer
438        .bytes
439        .read()
440        .unwrap_or_else(|error| error.into_inner());
441    let mut payload = vec![0; byte_count];
442    if start < bytes.len() {
443        let available = (bytes.len() - start).min(byte_count);
444        payload[..available].copy_from_slice(&bytes[start..start + available]);
445    }
446    Ok(payload)
447}
448
449fn ensure_writable_buffer(memory: &mut Memory, program: &Program, name: &str) -> Result<(), Error> {
450    eval_expr::buffer_mut(memory, program, name).map(|_| ())
451}
452
453fn apply_async_transfer(
454    transfer: AsyncTransfer,
455    memory: &mut Memory,
456    program: &Program,
457) -> Result<(), Error> {
458    match transfer {
459        AsyncTransfer::Copy {
460            destination,
461            start,
462            payload,
463        } => {
464            let buffer = eval_expr::buffer_mut(memory, program, &destination)?;
465            let mut bytes = buffer
466                .bytes
467                .write()
468                .unwrap_or_else(|error| error.into_inner());
469            if start >= bytes.len() {
470                return Ok(());
471            }
472            let write_len = payload.len().min(bytes.len() - start);
473            bytes[start..start + write_len].copy_from_slice(&payload[..write_len]);
474            Ok(())
475        }
476    }
477}
478
479fn resolve_buffer<'a>(
480    memory: &'a Memory,
481    program: &Program,
482    name: &str,
483) -> Result<&'a oob::Buffer, Error> {
484    let decl = program.buffer(name).ok_or_else(|| {
485        Error::interp(format!(
486            "missing buffer declaration `{name}`. Fix: declare every async transfer buffer."
487        ))
488    })?;
489    if decl.access() == vyre::ir::BufferAccess::Workgroup {
490        memory.workgroup.get(name)
491    } else {
492        memory.storage.get(name)
493    }
494    .ok_or_else(|| {
495        Error::interp(format!(
496            "missing buffer `{name}`. Fix: initialize every declared async transfer buffer."
497        ))
498    })
499}
500
501fn eval_if<'a>(
502    cond: &Expr,
503    then: &'a [Node],
504    otherwise: &'a [Node],
505    node: &Node,
506    invocation: &mut Invocation<'a>,
507    memory: &mut Memory,
508    program: &Program,
509) -> Result<(), vyre::Error> {
510    let cond_value = eval_expr::eval(cond, invocation, memory, program)?.truthy();
511    if contains_barrier(then) || contains_barrier(otherwise) {
512        invocation.uniform_checks.push((node_id(node), cond_value));
513    }
514    let branch = if cond_value { then } else { otherwise };
515    invocation.push_scope();
516    invocation.frames_mut().push(Frame::Nodes {
517        nodes: branch,
518        index: 0,
519        scoped: true,
520    });
521    Ok(())
522}
523
524fn eval_loop<'a>(
525    var: &'a str,
526    from: &Expr,
527    to: &Expr,
528    body: &'a [Node],
529    invocation: &mut Invocation<'a>,
530    memory: &mut Memory,
531    program: &Program,
532) -> Result<(), vyre::Error> {
533    let from_value = eval_expr::eval(from, invocation, memory, program)?;
534    let to_value = eval_expr::eval(to, invocation, memory, program)?;
535    let from = from_value.try_as_u32().ok_or_else(|| {
536        Error::interp(format!(
537                "loop lower bound {from_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
538        ))
539    })?;
540    let to = to_value.try_as_u32().ok_or_else(|| Error::interp(format!(
541            "loop upper bound {to_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
542    )))?;
543    invocation.frames_mut().push(Frame::Loop {
544        var,
545        next: from,
546        to,
547        body,
548    });
549    Ok(())
550}
551
552fn eval_return(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
553    invocation.frames_mut().clear();
554    invocation.returned = true;
555    Ok(())
556}
557
558fn eval_block<'a>(nodes: &'a [Node], invocation: &mut Invocation<'a>) -> Result<(), vyre::Error> {
559    invocation.push_scope();
560    invocation.frames_mut().push(Frame::Nodes {
561        nodes,
562        index: 0,
563        scoped: true,
564    });
565    Ok(())
566}
567
568fn eval_barrier(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
569    invocation.waiting_at_barrier = true;
570    Ok(())
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oob::Buffer;
    use crate::workgroup::InvocationIds;
    use vyre::ir::{BufferDecl, DataType};

    /// Drive a single zero-id invocation of `program` until it reports done.
    fn run_program(program: &Program, memory: &mut Memory) -> Result<(), vyre::Error> {
        let mut invocation = Invocation::new(InvocationIds::ZERO, program.entry());
        loop {
            if invocation.done() {
                return Ok(());
            }
            step(&mut invocation, memory, program)?;
        }
    }

    /// Snapshot the contents of storage buffer `name`.
    fn bytes(memory: &Memory, name: &str) -> Vec<u8> {
        let buffer = memory.storage.get(name).expect("Fix: test buffer exists");
        let guard = buffer
            .bytes
            .read()
            .unwrap_or_else(|error| error.into_inner());
        guard.to_vec()
    }

    /// A load reads from the source offset and lands at the start of the
    /// destination.
    #[test]
    fn async_load_wait_copies_payload_into_destination() {
        let mut memory = Memory::empty()
            .with_storage(
                "src",
                Buffer::new(vec![10, 11, 12, 13, 14, 15, 16, 17], DataType::Bytes),
            )
            .with_storage("dst", Buffer::new(vec![0; 8], DataType::Bytes));

        let buffers = vec![
            BufferDecl::read("src", 0, DataType::Bytes).with_count(8),
            BufferDecl::output("dst", 1, DataType::Bytes).with_count(8),
        ];
        let body = vec![
            Node::async_load_ext("src", "dst", Expr::u32(2), Expr::u32(4), "copy"),
            Node::AsyncWait { tag: "copy".into() },
        ];
        let program = Program::wrapped(buffers, [1, 1, 1], body);

        run_program(&program, &mut memory).unwrap();

        let expected = vec![12, 13, 14, 15, 0, 0, 0, 0];
        assert_eq!(bytes(&memory, "dst"), expected);
    }

    /// A store reads from the start of the source and lands at the
    /// destination offset.
    #[test]
    fn async_store_wait_copies_payload_at_destination_offset() {
        let mut memory = Memory::empty()
            .with_storage("src", Buffer::new(vec![21, 22, 23, 24], DataType::Bytes))
            .with_storage("dst", Buffer::new(vec![0; 8], DataType::Bytes));

        let buffers = vec![
            BufferDecl::read("src", 0, DataType::Bytes).with_count(4),
            BufferDecl::output("dst", 1, DataType::Bytes).with_count(8),
        ];
        let body = vec![
            Node::async_store("src", "dst", Expr::u32(3), Expr::u32(4), "store"),
            Node::AsyncWait {
                tag: "store".into(),
            },
        ];
        let program = Program::wrapped(buffers, [1, 1, 1], body);

        run_program(&program, &mut memory).unwrap();

        let expected = vec![0, 0, 0, 21, 22, 23, 24, 0];
        assert_eq!(bytes(&memory, "dst"), expected);
    }
}