vyre-reference 0.4.1

Pure-Rust CPU reference interpreter for vyre IR — byte-identical oracle for backend conformance and small-data fallback
Documentation
use crate::{dual_impls::common, workgroup::Memory};
use vyre_primitives::PatternMatchDfa;

const HEADER_LEN: usize = 16;
const MAGIC: &[u8; 4] = b"VDFA";

impl common::ReferenceEvaluator for PatternMatchDfa {
    fn evaluate(&self, inputs: &[Memory]) -> Result<Memory, common::EvalError> {
        let haystack = common::one_input(inputs, "scan_dfa")?;
        let dfa = ParsedDfa::parse(&self.dfa)?;
        let mut state = dfa.start;
        let mut offsets = Vec::new();
        for (offset, byte) in haystack.iter().copied().enumerate() {
            state = dfa.step(state, byte)?;
            if dfa.accepts[state] {
                offsets.push(u32::try_from(offset).map_err(|_| {
                    common::EvalError::new(
                        "primitive `scan_dfa` offset exceeds u32. Fix: split haystacks before 4 GiB.",
                    )
                })?);
            }
        }
        Ok(common::write_u32s(offsets))
    }
}

struct ParsedDfa {
    start: usize,
    accepts: Vec<bool>,
    transitions: Vec<u32>,
}

impl ParsedDfa {
    fn parse(bytes: &[u8]) -> Result<Self, common::EvalError> {
        if bytes.len() < HEADER_LEN || &bytes[..4] != MAGIC {
            return Err(common::EvalError::new(
                "primitive `scan_dfa` expected VDFA header. Fix: encode magic, state_count, start, and accept_count.",
            ));
        }
        let state_count = read_u32_at(bytes, 4)? as usize;
        let start = read_u32_at(bytes, 8)? as usize;
        let accept_count = read_u32_at(bytes, 12)? as usize;
        if state_count == 0 || start >= state_count {
            return Err(common::EvalError::new(
                "primitive `scan_dfa` has invalid state count/start. Fix: provide at least one state and a valid start state.",
            ));
        }
        let accept_bytes = accept_count.checked_mul(4).ok_or_else(|| {
            common::EvalError::new(
                "primitive `scan_dfa` accept table size overflow. Fix: bound DFA state metadata.",
            )
        })?;
        let transition_start = HEADER_LEN + accept_bytes;
        let transition_words = state_count.checked_mul(256).ok_or_else(|| {
            common::EvalError::new(
                "primitive `scan_dfa` transition table size overflow. Fix: bound DFA state count.",
            )
        })?;
        let transition_bytes = transition_words.checked_mul(4).ok_or_else(|| {
            common::EvalError::new(
                "primitive `scan_dfa` transition byte size overflow. Fix: bound DFA state count.",
            )
        })?;
        if bytes.len() != transition_start + transition_bytes {
            return Err(common::EvalError::new(format!(
                "primitive `scan_dfa` byte length mismatch: got {}, expected {}. Fix: encode accept states followed by state_count*256 u32 transitions.",
                bytes.len(),
                transition_start + transition_bytes
            )));
        }
        let mut accepts = vec![false; state_count];
        for accept in 0..accept_count {
            let state = read_u32_at(bytes, HEADER_LEN + accept * 4)? as usize;
            if state >= state_count {
                return Err(common::EvalError::new(
                    "primitive `scan_dfa` accept state is out of range. Fix: keep accept states below state_count.",
                ));
            }
            accepts[state] = true;
        }
        let transitions = common::u32_words(&bytes[transition_start..], "scan_dfa")?;
        Ok(Self {
            start,
            accepts,
            transitions,
        })
    }

    fn step(&self, state: usize, byte: u8) -> Result<usize, common::EvalError> {
        let offset = state * 256 + usize::from(byte);
        let next = self.transitions[offset] as usize;
        if next >= self.accepts.len() {
            Err(common::EvalError::new(
                "primitive `scan_dfa` transition targets an out-of-range state. Fix: validate every transition target.",
            ))
        } else {
            Ok(next)
        }
    }
}

fn read_u32_at(bytes: &[u8], offset: usize) -> Result<u32, common::EvalError> {
    if offset + 4 > bytes.len() {
        return Err(common::EvalError::new(
            "primitive `scan_dfa` truncated u32 field. Fix: encode all header fields.",
        ));
    }
    Ok(u32::from_le_bytes([
        bytes[offset],
        bytes[offset + 1],
        bytes[offset + 2],
        bytes[offset + 3],
    ]))
}