use vyre::VyreBackend;
use vyre::ir::{BufferDecl, DataType, Expr, Node, Program};
/// Outcome of running every out-of-bounds probe against one backend.
///
/// An empty `violations` list means every probe passed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OobReport {
    /// One entry per failed probe, in the order the probes were run.
    pub violations: Vec<OobViolation>,
}
/// A single failed out-of-bounds probe.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OobViolation {
    /// Which probe failed.
    pub test: OobTest,
    /// Human-readable diagnostic; the probes in this module prefix it with `Fix:`.
    pub message: String,
}
/// Identifies which out-of-bounds probe produced a violation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OobTest {
    /// An out-of-bounds read of the input buffer must yield 0.
    Load,
    /// An out-of-bounds write must leave the output buffer unmodified.
    Store,
    /// An out-of-bounds `atomic_add` must return 0 and modify nothing.
    AtomicAdd,
}
#[inline]
pub(crate) fn enforce_oob(backend: &dyn vyre::VyreBackend) -> OobReport {
let mut violations = Vec::new();
if let Err(msg) = check_oob_load(backend) {
violations.push(OobViolation {
test: OobTest::Load,
message: msg,
});
}
if let Err(msg) = check_oob_store(backend) {
violations.push(OobViolation {
test: OobTest::Store,
message: msg,
});
}
if let Err(msg) = check_oob_atomic(backend) {
violations.push(OobViolation {
test: OobTest::AtomicAdd,
message: msg,
});
}
OobReport { violations }
}
fn check_oob_load(backend: &dyn vyre::VyreBackend) -> Result<(), String> {
let program = program_oob_load();
let input = vec![0xFF; 16]; let output_size = 16;
let output = dispatch_exact(backend, &program, &[input], output_size)
.map_err(|e| format!("Fix: backend dispatch failed for OOB load test: {e}"))?;
let first_word = first_word(&output, "OOB load")?;
if first_word != 0 {
return Err(format!(
"Fix: OOB load returned 0x{first_word:X}, contract requires 0"
));
}
Ok(())
}
/// Runs the out-of-bounds store probe (`program_oob_store`).
///
/// Requires every byte of the returned output buffer to equal `POISON_BYTE`.
///
/// # Errors
/// Returns a `Fix:`-prefixed message in three cases:
/// - dispatch fails;
/// - the output is not exactly 16 bytes;
/// - any byte differs from the sentinel (the first such byte is reported).
///
/// NOTE(review): nothing in `program_oob_store` writes the sentinel, so this
/// passes only if the backend itself initialises output buffers to
/// `POISON_BYTE`. `check_oob_atomic`, by contrast, requires untouched output
/// bytes to read back as 0. Confirm the backend output-initialisation
/// contract; one initialisation policy cannot satisfy both probes as written.
fn check_oob_store(backend: &dyn vyre::VyreBackend) -> Result<(), String> {
    const PROBE_BYTES: usize = 16;

    let probe = program_oob_store();
    let zeros = vec![0u8; PROBE_BYTES];
    let output = dispatch_exact(backend, &probe, &[zeros], PROBE_BYTES)
        .map_err(|e| format!("Fix: backend dispatch failed for OOB store test: {e}"))?;
    assert_exact_output_len(&output, PROBE_BYTES, "OOB store")?;

    let first_mismatch = output
        .iter()
        .copied()
        .enumerate()
        .find(|&(_, byte)| byte != POISON_BYTE);
    if let Some((i, b)) = first_mismatch {
        return Err(format!(
            "Fix: OOB store modified byte {i} to 0x{b:02X}, contract requires no modification of the poison sentinel"
        ));
    }
    Ok(())
}
/// Runs the out-of-bounds atomic-add probe (`program_oob_atomic_add`).
///
/// Word 0 of the output holds the value the out-of-bounds `atomic_add`
/// returned, and it must be 0. Every remaining byte must also read back as 0.
///
/// # Errors
/// Returns a `Fix:`-prefixed message in four cases:
/// - dispatch fails;
/// - the output is not exactly 16 bytes;
/// - the returned word is non-zero;
/// - any later byte is non-zero (the first such byte is reported).
fn check_oob_atomic(backend: &dyn vyre::VyreBackend) -> Result<(), String> {
    const PROBE_BYTES: usize = 16;

    let output = dispatch_exact(
        backend,
        &program_oob_atomic_add(),
        &[vec![0u8; PROBE_BYTES]],
        PROBE_BYTES,
    )
    .map_err(|e| format!("Fix: backend dispatch failed for OOB atomic test: {e}"))?;
    assert_exact_output_len(&output, PROBE_BYTES, "OOB atomic")?;

    let returned = first_word(&output, "OOB atomic")?;
    if returned != 0 {
        return Err(format!(
            "Fix: OOB atomic returned 0x{returned:X}, contract requires 0"
        ));
    }

    // Bytes after the returned word must read back as zero.
    match output
        .iter()
        .copied()
        .enumerate()
        .skip(4)
        .find(|&(_, byte)| byte != 0)
    {
        Some((i, b)) => Err(format!(
            "Fix: OOB atomic modified byte {i} to 0x{b:02X}, contract requires no modification"
        )),
        None => Ok(()),
    }
}
/// Sentinel byte that `check_oob_store` expects every byte of the OOB store
/// probe's output to read back as. Any other byte value is reported as a
/// modification caused by the out-of-bounds store.
const POISON_BYTE: u8 = 0xCD;
/// Decodes the first little-endian `u32` of a backend output buffer.
///
/// Bytes after the first four are ignored.
///
/// # Errors
/// Returns a `Fix:` message naming `test` when `output` holds fewer than
/// 4 bytes.
fn first_word(output: &[u8], test: &str) -> Result<u32, String> {
    match output.get(..4) {
        Some(&[b0, b1, b2, b3]) => Ok(u32::from_le_bytes([b0, b1, b2, b3])),
        _ => Err(format!(
            "Fix: backend returned {} bytes for {test}, expected at least 4",
            output.len()
        )),
    }
}
/// Checks that a probe's output is exactly `expected` bytes long.
///
/// A truncated or padded buffer cannot prove the out-of-bounds contract.
///
/// # Errors
/// Returns a `Fix:` message naming `test` and both lengths on mismatch.
fn assert_exact_output_len(output: &[u8], expected: usize, test: &str) -> Result<(), String> {
    let actual = output.len();
    if actual != expected {
        return Err(format!(
            "Fix: backend returned {actual} bytes for {test}, expected exactly {expected}; truncated or padded OOB output cannot prove the contract."
        ));
    }
    Ok(())
}
/// Dispatches `program` with its first read-write buffer resized to
/// `output_size` bytes and returns that first output buffer.
///
/// # Errors
/// Fails in three cases:
/// - the backend's `dispatch` fails (the error is propagated unchanged);
/// - the backend returns no output buffers;
/// - the first output buffer is not exactly `output_size` bytes.
fn dispatch_exact(
    backend: &dyn vyre::VyreBackend,
    program: &Program,
    inputs: &[Vec<u8>],
    output_size: usize,
) -> Result<Vec<u8>, vyre::BackendError> {
    let sized = program_with_output_size(program, output_size);
    let outputs = backend.dispatch(&sized, inputs, &vyre::DispatchConfig::default())?;

    // Only the first returned buffer is inspected; any others are dropped.
    let Some(output) = outputs.into_iter().next() else {
        return Err(vyre::BackendError::new(
            "backend returned zero output buffers. Fix: return the OOB probe output as outputs[0].",
        ));
    };

    if output.len() == output_size {
        return Ok(output);
    }
    Err(vyre::BackendError::new(format!(
        "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the OOB probe output declaration.",
        output.len()
    )))
}
/// Returns a copy of `program` whose first read-write buffer is marked as an
/// output.
///
/// That buffer's element count is set to `output_size` bytes rounded up to
/// whole 4-byte elements, saturating at `u32::MAX`. This assumes a 32-bit
/// element type, which every probe here declares (`DataType::U32`). If the
/// program has no read-write buffer, the copy is returned unchanged.
fn program_with_output_size(program: &Program, output_size: usize) -> Program {
    let mut buffers = program.buffers().to_vec();
    let first_read_write = buffers
        .iter_mut()
        .find(|decl| decl.access == vyre::ir::BufferAccess::ReadWrite);
    if let Some(decl) = first_read_write {
        decl.is_output = true;
        decl.count = u32::try_from(output_size.div_ceil(4)).unwrap_or(u32::MAX);
    }
    Program::new(buffers, program.workgroup_size(), program.entry().to_vec())
}
/// Builds the probe for `check_oob_load`.
///
/// A single invocation loads element 1_000_000 of `input`, far past its end,
/// and stores the result into `output[0]`.
fn program_oob_load() -> Program {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::read_write("output", 1, DataType::U32),
    ];
    let far_load = Expr::load("input", Expr::u32(1_000_000));
    let body = vec![Node::store("output", Expr::u32(0), far_load)];
    Program::new(buffers, [1, 1, 1], body)
}
/// Builds the probe for `check_oob_store`.
///
/// A single invocation stores `0xDEAD_BEEF` to element 1_000_000 of `output`,
/// far past its end.
fn program_oob_store() -> Program {
    let input = BufferDecl::read("input", 0, DataType::U32);
    let output = BufferDecl::read_write("output", 1, DataType::U32);
    let oob_store = Node::store("output", Expr::u32(1_000_000), Expr::u32(0xDEAD_BEEF));
    Program::new(vec![input, output], [1, 1, 1], vec![oob_store])
}
/// Builds the probe for `check_oob_atomic`.
///
/// A single invocation performs an out-of-bounds `atomic_add(output[OOB], 1)`
/// and stores the returned old value into `output[0]`.
fn program_oob_atomic_add() -> Program {
    // The OOB index must NOT be a multiple of the probe's 4-element output
    // buffer.
    //
    // The previous index, 1_000_000, is divisible by 4. A backend that
    // wrapped or masked the index into the buffer would increment output[0],
    // which is exactly the slot the store below overwrites with `old`, so the
    // violation disappeared.
    //
    // With 1_000_001, a wrapped or masked add lands in output[1] (bytes 4..8)
    // and a clamp-to-last add lands in output[3]. check_oob_atomic requires
    // both to read back as zero.
    //
    // NOTE(review): a backend that redirects OOB accesses to element 0 is
    // still masked by the store below; consider writing `old` to a slot no
    // redirect policy can reach.
    Program::new(
        vec![
            BufferDecl::read("input", 0, DataType::U32),
            BufferDecl::read_write("output", 1, DataType::U32),
        ],
        [1, 1, 1],
        vec![
            Node::let_bind(
                "old",
                Expr::atomic_add("output", Expr::u32(1_000_001), Expr::u32(1)),
            ),
            Node::store("output", Expr::u32(0), Expr::var("old")),
        ],
    )
}
/// Enforcement gate (id `oob_access`) that runs the out-of-bounds probes from
/// `enforce_oob` against the backend supplied in the enforcement context.
pub struct OobAccessEnforcer;

impl crate::enforce::EnforceGate for OobAccessEnforcer {
    fn id(&self) -> &'static str {
        "oob_access"
    }

    fn name(&self) -> &'static str {
        "oob_access"
    }

    /// Without a backend, emits one aggregate finding explaining that a
    /// backend is required. Otherwise, forwards every violation message to
    /// `finding_result`.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        match ctx.backend {
            Some(backend) => {
                let report = enforce_oob(backend);
                let mut messages = Vec::with_capacity(report.violations.len());
                for violation in report.violations {
                    messages.push(violation.message);
                }
                crate::enforce::finding_result(self.id(), messages)
            }
            None => vec![crate::enforce::aggregate_finding(
                self.id(),
                vec![
                    "oob_access: backend is required. Fix: provide a VyreBackend in EnforceCtx."
                        .to_string(),
                ],
            )],
        }
    }
}

/// Gate instance exported for registration.
pub const REGISTERED: OobAccessEnforcer = OobAccessEnforcer;
#[cfg(test)]
mod tests {
    use super::*;

    /// Mock backend that returns what a conforming backend is expected to
    /// return for each probe. Probes are recognised by the shape of their
    /// entry nodes.
    struct CorrectOobBackend;
    impl vyre::VyreBackend for CorrectOobBackend {
        fn id(&self) -> &'static str {
            "correct-oob-mock"
        }
        fn dispatch(
            &self,
            program: &vyre::Program,
            _inputs: &[Vec<u8>],
            _config: &vyre::DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
            let output = match program.entry.as_slice() {
                // OOB store probe: check_oob_store expects untouched output to
                // read back as the poison sentinel.
                [vyre::ir::Node::Store {
                    value: vyre::ir::Expr::LitU32(0xDEAD_BEEF),
                    ..
                }] => vec![POISON_BYTE; 16],
                // OOB load and OOB atomic probes: an all-zero output satisfies
                // check_oob_load and check_oob_atomic.
                _ => vec![0u8; 16],
            };
            Ok(vec![output])
        }
    }

    /// Mock backend that ignores the program and returns 0xFF in every byte.
    /// Every probe must flag it.
    struct NonZeroBackend;
    impl vyre::VyreBackend for NonZeroBackend {
        fn id(&self) -> &'static str {
            "nonzero-oob-mock"
        }
        fn dispatch(
            &self,
            _program: &vyre::Program,
            _inputs: &[Vec<u8>],
            _config: &vyre::DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
            Ok(vec![vec![0xFF; 16]])
        }
    }

    /// Mock backend that succeeds but returns no output buffers at all.
    struct NoOutputBackend;
    impl vyre::VyreBackend for NoOutputBackend {
        fn id(&self) -> &'static str {
            "no-output-oob-mock"
        }
        fn dispatch(
            &self,
            _program: &vyre::Program,
            _inputs: &[Vec<u8>],
            _config: &vyre::DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
            Ok(Vec::new())
        }
    }

    #[test]
    fn program_oob_load_declares_input_and_output() {
        let program = program_oob_load();
        assert_eq!(program.buffers().len(), 2);
        assert!(program.has_buffer("input"));
        assert!(program.has_buffer("output"));
    }

    #[test]
    fn enforce_oob_correct_backend_has_no_violations() {
        let report = enforce_oob(&CorrectOobBackend);
        assert!(
            report.violations.is_empty(),
            "expected no violations, got: {:?}",
            report.violations
        );
    }

    #[test]
    fn enforce_oob_nonzero_backend_fails_every_probe() {
        let report = enforce_oob(&NonZeroBackend);
        let failed: Vec<OobTest> = report.violations.iter().map(|v| v.test.clone()).collect();
        assert_eq!(
            failed,
            vec![OobTest::Load, OobTest::Store, OobTest::AtomicAdd]
        );
    }

    #[test]
    fn enforce_oob_reports_backend_without_outputs() {
        let report = enforce_oob(&NoOutputBackend);
        assert_eq!(report.violations.len(), 3);
        assert!(report
            .violations
            .iter()
            .all(|v| v.message.starts_with("Fix: backend dispatch failed")));
    }

    #[test]
    fn first_word_decodes_little_endian_and_rejects_short_output() {
        assert_eq!(first_word(&[0x78, 0x56, 0x34, 0x12, 0xAA], "t"), Ok(0x1234_5678));
        assert!(first_word(&[0u8; 3], "t").is_err());
    }

    #[test]
    fn assert_exact_output_len_rejects_truncated_or_padded_output() {
        assert!(assert_exact_output_len(&[0u8; 16], 16, "t").is_ok());
        assert!(assert_exact_output_len(&[0u8; 15], 16, "t").is_err());
        assert!(assert_exact_output_len(&[0u8; 17], 16, "t").is_err());
    }
}