vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Out-of-bounds access conformance enforcer.
//!
//! Verifies that a backend correctly implements the vyre OOB contract:
//! - OOB loads return zero.
//! - OOB stores are no-ops.
//! - OOB atomics return zero and are no-ops.

use vyre::VyreBackend;

use vyre::ir::{BufferDecl, DataType, Expr, Node, Program};

/// Outcome of running every out-of-bounds probe against one backend.
///
/// An empty `violations` list means the backend satisfied the whole OOB
/// contract.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OobReport {
    /// Every contract violation that was detected, in probe order.
    pub violations: Vec<OobViolation>,
}

/// One failed out-of-bounds probe.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OobViolation {
    /// The probe that failed.
    pub test: OobTest,
    /// Actionable diagnostic beginning with "Fix: ...".
    pub message: String,
}

/// Identifies which out-of-bounds probe produced a violation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OobTest {
    /// Probe for an out-of-bounds load.
    Load,
    /// Probe for an out-of-bounds store.
    Store,
    /// Probe for an out-of-bounds atomic add.
    AtomicAdd,
}

/// Enforce the vyre out-of-bounds contract on `backend`.
///
/// Constructs minimal IR programs with deliberately out-of-bounds indices and
/// dispatches them through the backend. Any deviation from the contract is
/// recorded as an `OobViolation` with an actionable "Fix: ..." message.
#[inline]
pub(crate) fn enforce_oob(backend: &dyn vyre::VyreBackend) -> OobReport {
    let mut violations = Vec::new();

    if let Err(msg) = check_oob_load(backend) {
        violations.push(OobViolation {
            test: OobTest::Load,
            message: msg,
        });
    }

    if let Err(msg) = check_oob_store(backend) {
        violations.push(OobViolation {
            test: OobTest::Store,
            message: msg,
        });
    }

    if let Err(msg) = check_oob_atomic(backend) {
        violations.push(OobViolation {
            test: OobTest::AtomicAdd,
            message: msg,
        });
    }

    OobReport { violations }
}

/// Probe: a load far outside the input buffer must read back as zero.
fn check_oob_load(backend: &dyn vyre::VyreBackend) -> Result<(), String> {
    const OUTPUT_BYTES: usize = 16;

    // Fill the real input with all-ones (four 0xFFFFFFFF words). A backend
    // that clamps the bad index onto a valid element then leaks 0xFFFFFFFF
    // instead of returning zero.
    let seeded_input = vec![0xFF; 16];

    let output = dispatch_exact(backend, &program_oob_load(), &[seeded_input], OUTPUT_BYTES)
        .map_err(|e| format!("Fix: backend dispatch failed for OOB load test: {e}"))?;

    match first_word(&output, "OOB load")? {
        0 => Ok(()),
        first_word => Err(format!(
            "Fix: OOB load returned 0x{first_word:X}, contract requires 0"
        )),
    }
}

/// Probe: a store far outside the output buffer must not change any byte.
///
/// The program stores 0xDEADBEEF at index 1_000_000 of a 4-word output
/// buffer. A conforming backend therefore returns the output exactly as it
/// was initialised.
///
/// # Errors
///
/// Returns a "Fix: ..." message in three cases:
/// - dispatch fails;
/// - the output has the wrong length;
/// - any byte differs from the buffer's initial fill.
fn check_oob_store(backend: &dyn vyre::VyreBackend) -> Result<(), String> {
    let program = program_oob_store();
    let input = vec![0u8; 16];
    let output_size = 16;

    let output = dispatch_exact(backend, &program, &[input], output_size)
        .map_err(|e| format!("Fix: backend dispatch failed for OOB store test: {e}"))?;
    assert_exact_output_len(&output, output_size, "OOB store")?;

    // The input handed to the backend is all zeros and never contains
    // `POISON_BYTE`. The atomic probe also requires never-written output
    // bytes to read back as zero. Requiring the poison sentinel alone would
    // therefore reject every zero-initialising backend that passes the
    // atomic probe, so accept a uniform fill of either value.
    //
    // The stored value 0xDEADBEEF contains neither 0x00 nor 0xCD. Any store
    // that was clamped or wrapped into the buffer still breaks the fill.
    let initial_fill = if output.first() == Some(&0) { 0 } else { POISON_BYTE };
    if let Some((i, &b)) = output
        .iter()
        .enumerate()
        .find(|&(_, &byte)| byte != initial_fill)
    {
        return Err(format!(
            "Fix: OOB store modified byte {i} to 0x{b:02X}, contract requires no modification of the output's initial fill (poison sentinel 0x{POISON_BYTE:02X} or zero)"
        ));
    }

    Ok(())
}

/// Probe: an out-of-bounds atomic add must return zero and write nothing.
///
/// The program writes the atomic's return value to word 0 of the output.
/// Word 0 must therefore read back as zero, and every later byte must stay
/// zero as well.
fn check_oob_atomic(backend: &dyn vyre::VyreBackend) -> Result<(), String> {
    const OUTPUT_BYTES: usize = 16;

    let output = dispatch_exact(
        backend,
        &program_oob_atomic_add(),
        &[vec![0u8; 16]],
        OUTPUT_BYTES,
    )
    .map_err(|e| format!("Fix: backend dispatch failed for OOB atomic test: {e}"))?;
    assert_exact_output_len(&output, OUTPUT_BYTES, "OOB atomic")?;

    let first_word = first_word(&output, "OOB atomic")?;
    if first_word != 0 {
        return Err(format!(
            "Fix: OOB atomic returned 0x{first_word:X}, contract requires 0"
        ));
    }

    // Bytes past word 0 are never written by the program.
    let stray = output
        .iter()
        .enumerate()
        .skip(4)
        .find(|&(_, &byte)| byte != 0);
    match stray {
        Some((i, &b)) => Err(format!(
            "Fix: OOB atomic modified byte {i} to 0x{b:02X}, contract requires no modification"
        )),
        None => Ok(()),
    }
}

/// Sentinel byte the OOB store probe treats as the untouched output fill.
const POISON_BYTE: u8 = 0xCD;

/// Decode the first little-endian `u32` of `output`.
///
/// Any bytes past the first four are ignored. `test` names the probe in the
/// error message returned when fewer than four bytes are available.
fn first_word(output: &[u8], test: &str) -> Result<u32, String> {
    match output.get(..4) {
        Some(head) => {
            let mut word = [0u8; 4];
            word.copy_from_slice(head);
            Ok(u32::from_le_bytes(word))
        }
        None => Err(format!(
            "Fix: backend returned {} bytes for {test}, expected at least 4",
            output.len()
        )),
    }
}

/// Require `output` to be exactly `expected` bytes long.
///
/// A truncated or padded buffer cannot prove the OOB contract, so any other
/// length yields a "Fix: ..." error naming the probe `test`.
fn assert_exact_output_len(output: &[u8], expected: usize, test: &str) -> Result<(), String> {
    let actual = output.len();
    if actual != expected {
        return Err(format!(
            "Fix: backend returned {actual} bytes for {test}, expected exactly {expected}; truncated or padded OOB output cannot prove the contract."
        ));
    }
    Ok(())
}

/// Dispatch `program` and return its first output buffer.
///
/// Before dispatch, the program's first read-write buffer is resized to hold
/// `output_size` bytes. The returned buffer must be exactly `output_size`
/// bytes long.
///
/// # Errors
///
/// Returns an error in three cases:
/// - the backend's own dispatch fails;
/// - the backend returns no output buffers;
/// - the first output buffer has the wrong length.
fn dispatch_exact(
    backend: &dyn vyre::VyreBackend,
    program: &Program,
    inputs: &[Vec<u8>],
    output_size: usize,
) -> Result<Vec<u8>, vyre::BackendError> {
    let sized = program_with_output_size(program, output_size);
    let config = vyre::DispatchConfig::default();

    let Some(output) = backend.dispatch(&sized, inputs, &config)?.into_iter().next() else {
        return Err(vyre::BackendError::new(
            "backend returned zero output buffers. Fix: return the OOB probe output as outputs[0].",
        ));
    };

    if output.len() == output_size {
        return Ok(output);
    }
    Err(vyre::BackendError::new(format!(
        "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the OOB probe output declaration.",
        output.len()
    )))
}

/// Clone `program` with its first read-write buffer sized to `output_size` bytes.
///
/// That buffer is also marked as an output. Its element count is
/// `ceil(output_size / 4)`, saturating at `u32::MAX`. Buffers after the first
/// read-write one are left untouched.
fn program_with_output_size(program: &Program, output_size: usize) -> Program {
    let mut buffers = program.buffers().to_vec();

    if let Some(target) = buffers
        .iter_mut()
        .find(|decl| decl.access == vyre::ir::BufferAccess::ReadWrite)
    {
        target.is_output = true;
        target.count = u32::try_from(output_size.div_ceil(4)).unwrap_or(u32::MAX);
    }

    Program::new(buffers, program.workgroup_size(), program.entry().to_vec())
}

/// Build the load probe: `output[0] = input[1_000_000]`.
///
/// It runs on a single invocation.
fn program_oob_load() -> Program {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::read_write("output", 1, DataType::U32),
    ];
    let out_of_range_read = Expr::load("input", Expr::u32(1_000_000));
    let entry = vec![Node::store("output", Expr::u32(0), out_of_range_read)];

    Program::new(buffers, [1, 1, 1], entry)
}

/// Build the store probe: `output[1_000_000] = 0xDEAD_BEEF`.
///
/// It runs on a single invocation.
fn program_oob_store() -> Program {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::read_write("output", 1, DataType::U32),
    ];
    let out_of_range_index = Expr::u32(1_000_000);
    let marker = Expr::u32(0xDEAD_BEEF);
    let entry = vec![Node::store("output", out_of_range_index, marker)];

    Program::new(buffers, [1, 1, 1], entry)
}

/// Build the atomic probe.
///
/// It performs `old = atomic_add(output[1_000_000], 1)`, then
/// `output[0] = old`, on a single invocation.
fn program_oob_atomic_add() -> Program {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::read_write("output", 1, DataType::U32),
    ];
    let out_of_range_add = Expr::atomic_add("output", Expr::u32(1_000_000), Expr::u32(1));
    let entry = vec![
        Node::let_bind("old", out_of_range_add),
        Node::store("output", Expr::u32(0), Expr::var("old")),
    ];

    Program::new(buffers, [1, 1, 1], entry)
}

/// Registry entry for `oob_access` enforcement.
pub struct OobAccessEnforcer;

impl crate::enforce::EnforceGate for OobAccessEnforcer {
    fn id(&self) -> &'static str {
        "oob_access"
    }

    fn name(&self) -> &'static str {
        "oob_access"
    }

    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        let Some(backend) = ctx.backend else {
            return vec![crate::enforce::aggregate_finding(
                self.id(),
                vec![
                    "oob_access: backend is required. Fix: provide a VyreBackend in EnforceCtx."
                        .to_string(),
                ],
            )];
        };
        let report = enforce_oob(backend);
        let messages = report
            .violations
            .into_iter()
            .map(|violation| violation.message)
            .collect::<Vec<_>>();
        crate::enforce::finding_result(self.id(), messages)
    }
}

/// Auto-registered `oob_access` enforcer.
pub const REGISTERED: OobAccessEnforcer = OobAccessEnforcer;

#[cfg(test)]
mod tests {
    use super::*;

    /// Mock backend that answers each probe program with contract-conforming output.
    struct CorrectOobBackend;

    impl vyre::VyreBackend for CorrectOobBackend {
        fn id(&self) -> &'static str {
            "correct-oob-mock"
        }

        fn dispatch(
            &self,
            program: &vyre::Program,
            _inputs: &[Vec<u8>],
            _config: &vyre::DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
            let output = match program.entry.as_slice() {
                // OOB store: store(output, 1000000, 0xDEAD_BEEF)
                [vyre::ir::Node::Store {
                    value: vyre::ir::Expr::LitU32(0xDEAD_BEEF),
                    ..
                }] => vec![POISON_BYTE; 16],
                // OOB atomic: let old = atomic_add(...); store(output, 0, old)
                [vyre::ir::Node::Let { .. }, vyre::ir::Node::Store { .. }] => vec![0u8; 16],
                // OOB load: store(output, 0, load(input, 1000000))
                _ => vec![0u8; 16],
            };
            Ok(vec![output])
        }
    }

    /// Mock backend that ignores bounds entirely: every probe reads back all-ones.
    struct LeakyOobBackend;

    impl vyre::VyreBackend for LeakyOobBackend {
        fn id(&self) -> &'static str {
            "leaky-oob-mock"
        }

        fn dispatch(
            &self,
            _program: &vyre::Program,
            _inputs: &[Vec<u8>],
            _config: &vyre::DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
            Ok(vec![vec![0xFF; 16]])
        }
    }

    /// Mock backend whose dispatch succeeds but returns no output buffers.
    struct EmptyOutputBackend;

    impl vyre::VyreBackend for EmptyOutputBackend {
        fn id(&self) -> &'static str {
            "empty-output-mock"
        }

        fn dispatch(
            &self,
            _program: &vyre::Program,
            _inputs: &[Vec<u8>],
            _config: &vyre::DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
            Ok(Vec::new())
        }
    }

    #[test]
    fn oob_programs_declare_input_and_output() {
        for program in &[
            program_oob_load(),
            program_oob_store(),
            program_oob_atomic_add(),
        ] {
            assert_eq!(program.buffers().len(), 2);
            assert!(program.has_buffer("input"));
            assert!(program.has_buffer("output"));
        }
    }

    #[test]
    fn enforce_oob_correct_backend_has_no_violations() {
        let report = enforce_oob(&CorrectOobBackend);
        assert!(
            report.violations.is_empty(),
            "expected no violations, got: {:?}",
            report.violations
        );
    }

    #[test]
    fn enforce_oob_leaky_backend_reports_every_probe() {
        let report = enforce_oob(&LeakyOobBackend);
        let failed: Vec<OobTest> = report.violations.iter().map(|v| v.test.clone()).collect();
        assert_eq!(failed, vec![OobTest::Load, OobTest::Store, OobTest::AtomicAdd]);
        assert!(
            report.violations.iter().all(|v| v.message.starts_with("Fix:")),
            "every violation must be actionable, got: {:?}",
            report.violations
        );
    }

    #[test]
    fn enforce_oob_missing_output_buffer_reports_dispatch_failures() {
        let report = enforce_oob(&EmptyOutputBackend);
        assert_eq!(report.violations.len(), 3);
        assert!(
            report
                .violations
                .iter()
                .all(|v| v.message.starts_with("Fix: backend dispatch failed for OOB ")),
            "expected dispatch-failure messages, got: {:?}",
            report.violations
        );
    }
}