vyre-reference 0.1.0

Pure-Rust CPU reference interpreter for vyre IR — byte-identical oracle for backend conformance and small-data fallback
Documentation
//! Statement executor that gives the parity engine a pure-Rust ground truth
//! for every `Node` variant.
//!
//! This module simulates the exact control-flow, memory, and barrier behavior
//! that a correct GPU backend must produce. Any divergence in `If`, `Loop`,
//! `Barrier`, or `Store` semantics is caught by the conform gate as a concrete
//! counterexample.

use vyre::ir::{Expr, Node, Program};

use crate::{
    eval_expr, oob,
    workgroup::{Frame, Invocation, Memory},
};
use vyre::Error;

/// Execute one scheduling step for an invocation.
///
/// A step pops frames off the invocation's frame stack until exactly one
/// statement has been executed or the stack is empty. Two kinds of frame are
/// consumed without counting as a statement:
/// - exhausted statement lists, which only release their scope;
/// - loop frames, which only schedule the next iteration.
///
/// A single step may therefore pass through several frames. Invocations for
/// which `done()` reports true, or that are parked at a barrier, are left
/// untouched.
///
/// # Errors
///
/// Returns an error when any of the following happens:
/// - expression evaluation fails;
/// - a store index or a loop bound cannot be represented as `u32`
///   (reported through [`Error::Interp`] messages);
/// - `eval_expr::buffer_mut` rejects a store's target buffer;
/// - binding or assigning a variable fails.
///
/// Stores whose index lies past the end of the buffer are handed to
/// `oob::store`, and its outcome is not propagated by this function.
/// Branch divergence around barriers is only *recorded* in
/// `invocation.uniform_checks`. This function does not itself raise
/// uniform-control-flow errors; presumably the workgroup driver checks
/// those records.
pub fn step<'a>(
    invocation: &mut Invocation<'a>,
    memory: &mut Memory,
    program: &'a Program,
) -> Result<(), vyre::Error> {
    // Finished or barrier-parked invocations make no progress this step.
    if invocation.done() || invocation.waiting_at_barrier {
        return Ok(());
    }

    loop {
        let Some(frame) = invocation.frames_mut().pop() else {
            // Nothing left to run.
            return Ok(());
        };
        match frame {
            Frame::Nodes {
                nodes,
                index,
                scoped,
            } => {
                // `true` means one statement actually executed: the step is over.
                if step_nodes_frame(invocation, memory, program, nodes, index, scoped)? {
                    return Ok(());
                }
            }
            // Loop bookkeeping never executes a statement itself; keep unwinding.
            Frame::Loop {
                var,
                next,
                to,
                body,
            } => step_loop_frame(invocation, var, next, to, body)?,
        }
    }
}

/// Advance a statement-list frame by at most one statement.
///
/// Returns `Ok(true)` when a statement was executed. Returns `Ok(false)` when
/// the list was already exhausted. In that case the scope the frame opened
/// (if `scoped`) is popped and nothing else happens.
fn step_nodes_frame<'a>(
    invocation: &mut Invocation<'a>,
    memory: &mut Memory,
    program: &'a Program,
    nodes: &'a [Node],
    index: usize,
    scoped: bool,
) -> Result<bool, vyre::Error> {
    let Some(current) = nodes.get(index) else {
        // End of the list: release the scope this frame owns, if any.
        if scoped {
            invocation.pop_scope();
        }
        return Ok(false);
    };

    // Re-queue the remainder of the list *before* executing, so that any
    // frames the statement pushes (branches, blocks, loops) run first.
    let remainder = Frame::Nodes {
        nodes,
        index: index + 1,
        scoped,
    };
    invocation.frames_mut().push(remainder);

    execute_node(current, invocation, memory, program)?;
    Ok(true)
}

/// Schedule one iteration of a counted loop over the half-open range
/// `[next, to)`.
///
/// If the range is non-empty, the frames are pushed in this order:
/// 1. a continuation `Frame::Loop` starting at `next + 1`;
/// 2. a fresh scope in which `var` is bound to `next`;
/// 3. a scoped frame for `body`.
///
/// Because the body frame is on top, the body runs before the continuation
/// resumes. Once `next` reaches `to`, the (already popped) loop frame is
/// simply not re-pushed.
fn step_loop_frame<'a>(
    invocation: &mut Invocation<'a>,
    var: &'a str,
    next: u32,
    to: u32,
    body: &'a [Node],
) -> Result<(), vyre::Error> {
    if next < to {
        let continuation = Frame::Loop {
            var,
            // `next < to <= u32::MAX`, so this increment never actually wraps.
            next: next.wrapping_add(1),
            to,
            body,
        };
        invocation.frames_mut().push(continuation);

        // Each iteration gets its own scope holding the induction variable.
        invocation.push_scope();
        invocation.bind_loop_var(var, crate::value::Value::U32(next))?;

        let body_frame = Frame::Nodes {
            nodes: body,
            index: 0,
            scoped: true,
        };
        invocation.frames_mut().push(body_frame);
    }
    Ok(())
}

/// Dispatch a single statement to its handler.
///
/// The match is exhaustive on purpose: adding a `Node` variant must fail to
/// compile here until the reference interpreter defines its semantics.
fn execute_node<'a>(
    node: &'a Node,
    invocation: &mut Invocation<'a>,
    memory: &mut Memory,
    program: &'a Program,
) -> Result<(), vyre::Error> {
    match node {
        // Control statements that only touch invocation state.
        Node::Barrier => eval_barrier(invocation),
        Node::Return => eval_return(invocation),
        Node::Block(children) => eval_block(children, invocation),
        // Statements that evaluate expressions.
        Node::Let { name, value } => eval_let(name, value, invocation, memory, program),
        Node::Assign { name, value } => eval_assign(name, value, invocation, memory, program),
        Node::Store {
            buffer,
            index,
            value,
        } => eval_store(buffer, index, value, invocation, memory, program),
        // `eval_if` also receives the node itself, to derive a stable id for
        // uniformity records.
        Node::If {
            cond,
            then,
            otherwise,
        } => eval_if(cond, then, otherwise, node, invocation, memory, program),
        Node::Loop {
            var,
            from,
            to,
            body,
        } => eval_loop(var, from, to, body, invocation, memory, program),
    }
}

/// Execute `Node::Let`: evaluate `value`, then hand the result to
/// `Invocation::bind` under `name`.
///
/// Errors from evaluation or from `bind` are returned unchanged.
fn eval_let(
    name: &str,
    value: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<(), vyre::Error> {
    let evaluated = eval_expr::eval(value, invocation, memory, program)?;
    invocation.bind(name, evaluated)
}

/// Execute `Node::Assign`: evaluate `value`, then hand the result to
/// `Invocation::assign` under `name`.
///
/// Errors from evaluation or from `assign` are returned unchanged.
fn eval_assign(
    name: &str,
    value: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<(), vyre::Error> {
    let evaluated = eval_expr::eval(value, invocation, memory, program)?;
    invocation.assign(name, evaluated)
}

fn eval_store(
    buffer: &str,
    index: &Expr,
    value: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<(), vyre::Error> {
    let index = eval_expr::eval(index, invocation, memory, program)?;
    let index = index
        .try_as_u32()
        .ok_or_else(|| Error::interp(format!(
                "store index {index:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
        )))?;
    let value = eval_expr::eval(value, invocation, memory, program)?;
    let target = eval_expr::buffer_mut(memory, program, buffer)?;
    oob::store(target, index, &value);
    Ok(())
}

/// Execute `Node::If`: evaluate `cond` and schedule the selected branch in a
/// new scope.
///
/// When either branch can reach a `Node::Barrier`, the pair
/// `(node_id(node), taken)` is appended to `invocation.uniform_checks`.
/// This records the decision so that divergence across the workgroup can be
/// checked; that check is not performed in this function.
fn eval_if<'a>(
    cond: &Expr,
    then: &'a [Node],
    otherwise: &'a [Node],
    node: &Node,
    invocation: &mut Invocation<'a>,
    memory: &mut Memory,
    program: &Program,
) -> Result<(), vyre::Error> {
    let taken = eval_expr::eval(cond, invocation, memory, program)?.truthy();

    let guards_barrier = contains_barrier(then) || contains_barrier(otherwise);
    if guards_barrier {
        invocation.uniform_checks.push((node_id(node), taken));
    }

    let selected = if taken { then } else { otherwise };
    invocation.push_scope();
    let branch_frame = Frame::Nodes {
        nodes: selected,
        index: 0,
        scoped: true,
    };
    invocation.frames_mut().push(branch_frame);
    Ok(())
}

fn eval_loop<'a>(
    var: &'a str,
    from: &Expr,
    to: &Expr,
    body: &'a [Node],
    invocation: &mut Invocation<'a>,
    memory: &mut Memory,
    program: &Program,
) -> Result<(), vyre::Error> {
    let from_value = eval_expr::eval(from, invocation, memory, program)?;
    let to_value = eval_expr::eval(to, invocation, memory, program)?;
    let from = from_value.try_as_u32().ok_or_else(|| {
        Error::interp(format!(
                "loop lower bound {from_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
        ))
    })?;
    let to = to_value.try_as_u32().ok_or_else(|| Error::interp(format!(
            "loop upper bound {to_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
    )))?;
    invocation.frames_mut().push(Frame::Loop {
        var,
        next: from,
        to,
        body,
    });
    Ok(())
}

/// Execute `Node::Return`: discard every pending frame and mark the
/// invocation as returned.
///
/// Open scopes are not popped; once the frame stack is empty, nothing else
/// runs for this invocation.
fn eval_return(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
    let pending = invocation.frames_mut();
    pending.clear();
    invocation.returned = true;
    Ok(())
}

/// Execute `Node::Block`: open a new scope and schedule `nodes` in it.
///
/// The frame is marked `scoped`, so the scope is popped once the list is
/// exhausted.
fn eval_block<'a>(nodes: &'a [Node], invocation: &mut Invocation<'a>) -> Result<(), vyre::Error> {
    let block_frame = Frame::Nodes {
        nodes,
        index: 0,
        scoped: true,
    };
    invocation.push_scope();
    invocation.frames_mut().push(block_frame);
    Ok(())
}

/// Execute `Node::Barrier`: park the invocation.
///
/// While `waiting_at_barrier` is set, `step` returns immediately without
/// progress. Clearing the flag is not done in this module; presumably the
/// workgroup scheduler releases all invocations together.
fn eval_barrier(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
    let parked = true;
    invocation.waiting_at_barrier = parked;
    Ok(())
}

/// Whether any statement in `nodes` may reach a [`Node::Barrier`], scanning
/// child statement lists recursively with an exhaustive [`Node`] match.
fn contains_barrier(nodes: &[Node]) -> bool {
    nodes.iter().any(node_contains_barrier)
}

/// Report whether a single statement is, or structurally contains, a
/// barrier.
///
/// The match is exhaustive with no wildcard, so a new `Node` variant forces
/// an explicit decision here.
fn node_contains_barrier(node: &Node) -> bool {
    match node {
        Node::Barrier => true,
        // Compound statements: look inside every child list.
        Node::Block(children) => contains_barrier(children),
        Node::Loop { body, .. } => contains_barrier(body),
        Node::If {
            then, otherwise, ..
        } => contains_barrier(then) || contains_barrier(otherwise),
        // Leaf statements never synchronize.
        Node::Return | Node::Store { .. } | Node::Assign { .. } | Node::Let { .. } => false,
    }
}

fn node_id(node: &Node) -> usize {
    std::ptr::from_ref(node).addr()
}