vyre-reference 0.1.0

Pure-Rust CPU reference interpreter for vyre IR — byte-identical oracle for backend conformance and small-data fallback
Documentation
//! Top-level interpreter dispatch — the parity engine's executable specification.
//!
//! This module exists so every vyre IR program has one deterministic result
//! defined by pure Rust code. The conform gate treats the output of `run` as
//! the golden expected value and diffs it against the GPU backend's actual
//! dispatch output. Any byte-level divergence is a certified bug in the backend.

use std::collections::HashMap;

use vyre::ir::{BufferAccess, BufferDecl, Program};

use vyre::Error;

use crate::{
    eval_node,
    oob::Buffer,
    value::Value,
    workgroup::{self, Invocation, Memory},
};

/// Execute a vyre IR program on the pure Rust reference interpreter.
///
/// `inputs` are matched to every non-workgroup buffer declaration in
/// `Program::buffers` order. `ReadWrite` buffers consume an input value as
/// their initial contents and are returned as raw `Value::Bytes` in declaration
/// order after dispatch.
///
/// # Errors
///
/// Returns [`Error::Interp`] if IR validation fails, inputs are missing or
/// excess, workgroup size is zero, names are unresolved, operand types do not
/// match, unsupported IR is encountered, or float operations are requested
/// before the integer-only interpreter grows full float support.
pub fn run(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
    let validation_errors = vyre::ir::validate(program);
    if !validation_errors.is_empty() {
        let messages = validation_errors
            .into_iter()
            .map(|error| error.message().to_string())
            .collect::<Vec<_>>()
            .join("; ");
        return Err(Error::interp(format!(
            "program failed IR validation: {messages}. Fix: repair the Program before invoking the reference interpreter."
        )));
    }

    let Prepared {
        storage,
        output_names,
        max_elements,
    } = prepare_storage(program, inputs)?;
    execute_dispatch(program, storage, output_names, max_elements)
}

struct Prepared {
    storage: HashMap<String, Buffer>,
    output_names: Vec<String>,
    max_elements: u32,
}

fn prepare_storage(program: &Program, inputs: &[Value]) -> Result<Prepared, vyre::Error> {
    let mut storage = HashMap::new();
    let mut input_index = 0usize;
    let mut output_names = Vec::new();
    let mut max_elements = 1u32;

    for decl in program.buffers() {
        if decl.access() == BufferAccess::Workgroup {
            continue;
        }
        let value = inputs
            .get(input_index)
            .ok_or_else(|| Error::interp(format!(
                    "missing input for buffer `{}`. Fix: pass one Value for each non-workgroup buffer in Program::buffers order.",
                    decl.name()
            )))?;
        input_index += 1;

        let bytes = value.to_bytes();
        max_elements = max_elements.max(element_count(decl, bytes.len())?);
        if decl.access() == BufferAccess::ReadWrite {
            output_names.push(decl.name().to_string());
        }
        storage.insert(
            decl.name().to_string(),
            Buffer {
                bytes,
                element: decl.element(),
            },
        );
    }

    if input_index != inputs.len() {
        return Err(Error::interp(
            "unused input values supplied. Fix: pass exactly one Value per non-workgroup buffer declaration.",
        ));
    }

    Ok(Prepared {
        storage,
        output_names,
        max_elements,
    })
}

fn execute_dispatch(
    program: &Program,
    mut storage: HashMap<String, Buffer>,
    output_names: Vec<String>,
    max_elements: u32,
) -> Result<Vec<Value>, vyre::Error> {
    validate_workgroup_size(program)?;
    let invocations_per_workgroup = invocations_per_workgroup(program);
    let workgroup_count_x = max_elements.div_ceil(invocations_per_workgroup).max(1);

    for wg_x in 0..workgroup_count_x {
        run_workgroup(program, &mut storage, [wg_x, 0, 0])?;
    }

    output_names
        .into_iter()
        .map(|name| {
            storage
                .remove(&name)
                .map(|buffer| Value::Bytes(buffer.bytes))
                .ok_or_else(|| Error::interp(format!(
                        "missing output buffer `{name}` after dispatch. Fix: keep buffer declarations unique."
                )))
        })
        .collect()
}

fn validate_workgroup_size(program: &Program) -> Result<(), vyre::Error> {
    if program.workgroup_size().contains(&0) {
        return Err(Error::interp(
            "workgroup size contains zero. Fix: all dimensions must be >= 1.",
        ));
    }
    Ok(())
}

fn invocations_per_workgroup(program: &Program) -> u32 {
    program
        .workgroup_size()
        .iter()
        .copied()
        .fold(1u32, u32::saturating_mul)
        .max(1)
}

fn run_workgroup(
    program: &Program,
    storage: &mut HashMap<String, Buffer>,
    workgroup_id: [u32; 3],
) -> Result<(), vyre::Error> {
    let mut memory = Memory {
        storage: std::mem::take(storage),
        workgroup: workgroup::workgroup_memory(program)?,
    };
    let mut invocations = workgroup::create_invocations(program, workgroup_id)?;
    run_invocations(program, &mut memory, &mut invocations)?;
    *storage = memory.storage;
    Ok(())
}

fn run_invocations<'a>(
    program: &'a Program,
    memory: &mut Memory,
    invocations: &mut [Invocation<'a>],
) -> Result<(), vyre::Error> {
    while invocations.iter().any(|invocation| !invocation.done()) {
        let made_progress = step_round_robin(program, memory, invocations)?;
        verify_uniform_control_flow(invocations)?;
        if release_barrier_if_ready(invocations) {
            continue;
        }
        if !made_progress && live_waiting_count(invocations) > 0 {
            return Err(Error::interp(
                "program violates uniform-control-flow rule: not every live invocation reached the same barrier. Fix: move Barrier to uniform control flow.",
            ));
        }
    }
    Ok(())
}

fn step_round_robin<'a>(
    program: &'a Program,
    memory: &mut Memory,
    invocations: &mut [Invocation<'a>],
) -> Result<bool, vyre::Error> {
    let mut made_progress = false;
    for invocation in invocations {
        if invocation.done() || invocation.waiting_at_barrier {
            continue;
        }
        eval_node::step(invocation, memory, program)?;
        made_progress = true;
    }
    Ok(made_progress)
}

fn release_barrier_if_ready(invocations: &mut [Invocation<'_>]) -> bool {
    let active = invocations
        .iter()
        .filter(|invocation| !invocation.done())
        .count();
    let waiting = live_waiting_count(invocations);
    if active > 0 && active == waiting {
        for invocation in invocations {
            invocation.waiting_at_barrier = false;
        }
        true
    } else {
        false
    }
}

fn live_waiting_count(invocations: &[Invocation<'_>]) -> usize {
    invocations
        .iter()
        .filter(|invocation| !invocation.done() && invocation.waiting_at_barrier)
        .count()
}

fn verify_uniform_control_flow(invocations: &[Invocation<'_>]) -> Result<(), vyre::Error> {
    // Kimi audit finding #1: filter on `!done()` instead of
    // `!returned`. A finished invocation that exited normally
    // (`frames.is_empty()` but `returned == false`) still carries
    // stale `uniform_checks` entries from its own past branches.
    // Including them in the cross-invocation comparison produces
    // false uniform-control-flow violations when a second invocation
    // legitimately visits the same barrier with a different branch
    // condition.
    let mut observed: HashMap<usize, bool> = HashMap::new();
    for invocation in invocations.iter().filter(|invocation| !invocation.done()) {
        for (id, value) in &invocation.uniform_checks {
            if let Some(previous) = observed.insert(*id, *value) {
                if previous != *value {
                    return Err(Error::interp(
                        "program violates uniform-control-flow rule: Barrier appears inside an If whose condition differs across the workgroup. Fix: make the condition uniform or move Barrier outside the branch.",
                    ));
                }
            }
        }
    }
    Ok(())
}

fn element_count(decl: &BufferDecl, byte_len: usize) -> Result<u32, vyre::Error> {
    let stride = decl.element().min_bytes();
    if stride == 0 {
        return u32::try_from(byte_len).map_err(|_| Error::interp(format!(
                "buffer `{}` has {} bytes and cannot be indexed within u32 address space. Fix: shrink or split the invocation."
                , decl.name(),
                byte_len,
        )));
    }
    let elements = byte_len / stride;
    u32::try_from(elements).map_err(|_| Error::interp(format!(
            "buffer `{}` has {} bytes for stride {} and overflows u32 elements. Fix: shrink declaration footprint or split work.",
            decl.name(),
            byte_len,
            stride,
    )))
}