vyre-reference 0.1.0

Pure-Rust CPU reference interpreter for vyre IR — byte-identical oracle for backend conformance and small-data fallback
Documentation
//! Workgroup simulation — the parity engine's model of invocation coordination.
//!
//! GPU backends must reproduce the exact barrier synchronization, shared-memory
//! layout, and invocation-ID arithmetic that this module defines. The conform gate
//! compares GPU dispatch output against this deterministic CPU simulation; any
//! divergence in control flow uniformity or workgroup memory semantics is a bug.

use std::collections::{HashMap, HashSet};

use vyre::ir::{BufferAccess, Node, Program};

use vyre::Error;

use crate::{oob::Buffer, value::Value};

/// Upper bound, in bytes, on the combined size of all workgroup (shared) buffers
/// the reference interpreter allocates for one workgroup: 64 MiB.
pub const MAX_WORKGROUP_BYTES: usize = 64 << 20;

/// Builtin coordinates that identify a single compute invocation.
///
/// Every field is a per-axis `[x, y, z]` triple.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct InvocationIds {
    /// Position of the invocation across the whole dispatch
    /// (`workgroup * workgroup_size + local` on each axis).
    pub global: [u32; 3],
    /// Position of the invocation's workgroup within the dispatch.
    pub workgroup: [u32; 3],
    /// Position of the invocation inside its own workgroup.
    pub local: [u32; 3],
}

impl InvocationIds {
    /// All-zero ids; convenient for examples and unit tests.
    pub const ZERO: Self = Self {
        global: [0; 3],
        workgroup: [0; 3],
        local: [0; 3],
    };
}

/// Buffers an executing workgroup can reach.
///
/// `storage` holds the storage buffers. `workgroup` holds the shared buffers of
/// the workgroup currently being simulated, keyed by declaration name (see
/// `workgroup_memory`).
#[derive(Debug)]
pub struct Memory {
    /// Storage buffers; presumably keyed by buffer name (confirm with the dispatcher).
    pub(crate) storage: HashMap<String, Buffer>,
    /// Workgroup-shared buffers for the current workgroup.
    pub(crate) workgroup: HashMap<String, Buffer>,
}

/// Execution state of one invocation, which may be running or paused
/// (for example at a barrier).
pub struct Invocation<'a> {
    /// Builtin ids (global / workgroup / local) of this invocation.
    pub ids: InvocationIds,
    /// Live local bindings, flat across all open lexical scopes.
    pub(crate) locals: HashMap<String, Value>,
    /// Names of bindings that may not be reassigned (loop variables).
    immutable: HashSet<String>,
    /// Names declared in each open lexical scope, innermost last.
    scopes: Vec<Vec<String>>,
    /// Continuation stack driving execution.
    frames: Vec<Frame<'a>>,
    /// Set once the invocation has executed `return`.
    pub returned: bool,
    /// Set while the invocation is parked at a barrier.
    pub waiting_at_barrier: bool,
    /// Uniform-if observations recorded for branches that contain a barrier.
    pub uniform_checks: Vec<(usize, bool)>,
}

/// One entry of an invocation's continuation stack.
#[non_exhaustive]
pub enum Frame<'a> {
    /// Walks a slice of nodes in order.
    Nodes {
        /// The node slice being walked.
        nodes: &'a [Node],
        /// Index of the next node to execute.
        index: usize,
        /// When true, finishing this frame also closes a lexical scope.
        scoped: bool,
    },
    /// A bounded loop over a `u32` induction variable.
    Loop {
        /// Name of the loop variable.
        var: &'a str,
        /// Induction value for the next iteration.
        next: u32,
        /// Upper bound, exclusive.
        to: u32,
        /// Nodes executed once per iteration.
        body: &'a [Node],
    },
}

impl<'a> Invocation<'a> {
    /// Create an invocation at the start of the entry point.
    pub fn new(ids: InvocationIds, entry: &'a [Node]) -> Self {
        Self {
            ids,
            locals: HashMap::new(),
            immutable: HashSet::new(),
            scopes: vec![Vec::new()],
            frames: vec![Frame::Nodes {
                nodes: entry,
                index: 0,
                scoped: false,
            }],
            returned: false,
            waiting_at_barrier: false,
            uniform_checks: Vec::new(),
        }
    }

    /// Return true when no further execution can occur.
    pub fn done(&self) -> bool {
        self.returned || self.frames.is_empty()
    }

    /// Push a lexical scope.
    ///
    ///
    /// ```rust,no_run
    /// use vyre_reference::workgroup::{Invocation, InvocationIds};
    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
    /// invocation.push_scope();
    /// ```
    pub fn push_scope(&mut self) {
        self.scopes.push(Vec::new());
    }

    /// Pop a lexical scope and remove bindings declared in it.
    ///
    ///
    /// ```rust,no_run
    /// use vyre_reference::workgroup::{Invocation, InvocationIds};
    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
    /// invocation.pop_scope();
    /// ```
    pub fn pop_scope(&mut self) {
        if let Some(names) = self.scopes.pop() {
            for name in names {
                self.locals.remove(&name);
                self.immutable.remove(&name);
            }
        }
    }

    /// Bind a mutable local.
    ///
    ///
    /// ```rust,no_run
    /// use vyre_reference::{value::Value, workgroup::{Invocation, InvocationIds}};
    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
    /// invocation.bind("example", Value::U32(1)).unwrap();
    /// ```
    pub fn bind(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
        if self.locals.contains_key(name) {
            return Err(Error::interp(format!(
                "duplicate local binding `{name}`. Fix: choose a unique local name; shadowing is not allowed."
            )));
        }
        self.locals.insert(name.to_string(), value);
        if let Some(scope) = self.scopes.last_mut() {
            scope.push(name.to_string());
        }
        Ok(())
    }

    /// Bind an immutable loop variable.
    ///
    ///
    /// ```rust,no_run
    /// use vyre_reference::{value::Value, workgroup::{Invocation, InvocationIds}};
    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
    /// invocation.bind_loop_var("example", Value::U32(1)).unwrap();
    /// ```
    pub fn bind_loop_var(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
        self.bind(name, value)?;
        self.immutable.insert(name.to_string());
        Ok(())
    }

    /// Assign an existing mutable local.
    pub fn assign(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
        if self.immutable.contains(name) {
            return Err(Error::interp(format!(
                "assignment to loop variable `{name}`. Fix: loop variables are immutable."
            )));
        }
        let Some(slot) = self.locals.get_mut(name) else {
            return Err(Error::interp(format!(
                "assignment to undeclared variable `{name}`. Fix: add a Let before assigning it."
            )));
        };
        *slot = value;
        Ok(())
    }

    pub(crate) fn frames_mut(&mut self) -> &mut Vec<Frame<'a>> {
        &mut self.frames
    }
}

pub(crate) fn create_invocations(
    program: &Program,
    workgroup: [u32; 3],
) -> Result<Vec<Invocation<'_>>, vyre::Error> {
    let global_dim = |wgid: u32, size: u32, local: u32| {
        wgid
            .checked_mul(size)
            .and_then(|base| base.checked_add(local))
            .ok_or_else(|| Error::interp(
                "workgroup * dispatch dimensions overflow u32 global id. Fix: reduce workgroup id or workgroup size so each global_invocation_id component fits in u32.",
            ))
    };
    let [sx, sy, sz] = program.workgroup_size();
    let mut invocations = Vec::with_capacity((sx * sy * sz) as usize);
    for z in 0..sz {
        for y in 0..sy {
            for x in 0..sx {
                let local = [x, y, z];
                let global = [
                    global_dim(workgroup[0], sx, x)?,
                    global_dim(workgroup[1], sy, y)?,
                    global_dim(workgroup[2], sz, z)?,
                ];
                invocations.push(Invocation::new(
                    InvocationIds {
                        global,
                        workgroup,
                        local,
                    },
                    program.entry(),
                ));
            }
        }
    }
    Ok(invocations)
}

pub(crate) fn workgroup_memory(program: &Program) -> Result<HashMap<String, Buffer>, vyre::Error> {
    let mut workgroup = HashMap::new();
    let mut allocated = 0usize;
    for decl in program
        .buffers()
        .iter()
        .filter(|decl| decl.access() == BufferAccess::Workgroup)
    {
        let element_size = decl.element().min_bytes();
        let len = (decl.count() as usize)
            .checked_mul(element_size)
            .ok_or_else(|| Error::interp(format!(
                    "workgroup buffer `{}` byte size overflows usize. Fix: reduce count or element size.",
                    decl.name()
            )))?;
        allocated = allocated
            .checked_add(len)
            .ok_or_else(|| Error::interp(
                "total workgroup memory byte size overflows usize. Fix: reduce workgroup buffer declarations.",
            ))?;
        if allocated > MAX_WORKGROUP_BYTES {
            return Err(Error::interp(format!(
                "workgroup memory requires {allocated} bytes, exceeding the {MAX_WORKGROUP_BYTES}-byte reference budget. Fix: reduce workgroup buffer counts."
            )));
        }
        workgroup.insert(
            decl.name().to_string(),
            Buffer {
                bytes: vec![0; len],
                element: decl.element(),
            },
        );
    }
    Ok(workgroup)
}