use vyre::ir::{BinOp, BufferDecl, DataType, Expr, Node, Program};
/// Number of pseudo-random operand pairs (witnesses) generated for each
/// operation in the catalog returned by `wire_cases`.
const WIRE_FORMAT_EQ_CASES_PER_OP: usize = 128;
/// One failed witness of the wire-format equivalence check.
///
/// Produced either by a codec step failing (encode, decode, structural
/// comparison, re-encode) or by the backend outputs of the original and the
/// round-tripped program not both succeeding with equal bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WireFormatEquivViolation {
/// Catalog name of the binary operation under test (e.g. `"Add"`).
pub op_name: String,
/// The two `u32` operands of the witness, each encoded little-endian.
pub input_bytes: Vec<u8>,
/// Lowercase hex dump of the original program's output; `None` when the
/// dispatch failed or the violation was raised before dispatching.
pub original_output: Option<String>,
/// Lowercase hex dump of the round-tripped program's output; `None` when the
/// dispatch failed or the violation was raised before dispatching.
pub roundtripped_output: Option<String>,
/// Human-readable description of the failure, including a suggested fix.
pub message: String,
}
/// Summary returned by `enforce_wire_format_equivalence`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WireFormatEquivReport {
/// Every witness that did not pass, in the order it was encountered.
pub violations: Vec<WireFormatEquivViolation>,
/// Total number of witnesses attempted (incremented once per witness).
pub tested: usize,
/// Number of witnesses for which no violation was produced.
pub passed: usize,
}
/// One entry of the canonical operation catalog: a binary operator paired
/// with the name used for it in violation reports.
struct WireCase {
// Label copied into `WireFormatEquivViolation::op_name` and into messages.
op_name: &'static str,
// Operator handed to `build_ir_program` for every witness of this case.
binop: BinOp,
}
fn wire_cases() -> [WireCase; 10] {
[
WireCase {
op_name: "Add",
binop: BinOp::Add,
},
WireCase {
op_name: "Sub",
binop: BinOp::Sub,
},
WireCase {
op_name: "Mul",
binop: BinOp::Mul,
},
WireCase {
op_name: "Div",
binop: BinOp::Div,
},
WireCase {
op_name: "BitAnd",
binop: BinOp::BitAnd,
},
WireCase {
op_name: "BitOr",
binop: BinOp::BitOr,
},
WireCase {
op_name: "BitXor",
binop: BinOp::BitXor,
},
WireCase {
op_name: "Eq",
binop: BinOp::Eq,
},
WireCase {
op_name: "Ne",
binop: BinOp::Ne,
},
WireCase {
op_name: "Shl",
binop: BinOp::Shl,
},
]
}
/// Runs the wire-format equivalence check for every catalog operation.
///
/// For each operation, `WIRE_FORMAT_EQ_CASES_PER_OP` deterministic operand
/// pairs are generated. Each pair is baked into a small IR program which is
/// then verified by `check_case`. A program that cannot be built is recorded
/// as a violation rather than aborting the run.
///
/// Returns a report with the number of witnesses tested, the number that
/// passed, and every violation collected.
#[inline]
pub(crate) fn enforce_wire_format_equivalence(
    backend: &dyn vyre::VyreBackend,
) -> WireFormatEquivReport {
    let mut report = WireFormatEquivReport {
        violations: Vec::new(),
        tested: 0,
        passed: 0,
    };

    for (case_index, case) in wire_cases().into_iter().enumerate() {
        // The catalog index occupies the high 32 bits of the seed and the
        // witness number the low bits, so every (case, witness) seed is unique.
        let seed_base = (case_index as u64) << 32;

        for witness in 0..WIRE_FORMAT_EQ_CASES_PER_OP {
            report.tested += 1;

            let (a, b) = u32_pair_from_seed(seed_base | witness as u64);
            let input_bytes = u32s_to_bytes([a, b]);

            let outcome = match build_ir_program(case.binop.clone(), a, b) {
                Ok(program) => check_case(backend, &case, &program, &input_bytes, witness),
                Err(err) => Some(codec_violation(
                    &case,
                    &input_bytes,
                    format!("failed to build canonical {} program: {err}", case.op_name),
                )),
            };

            match outcome {
                Some(violation) => report.violations.push(violation),
                None => report.passed += 1,
            }
        }
    }

    report
}
/// Verifies a single witness program and returns a violation if any step fails.
///
/// Steps, in order:
/// 1. encode `program` with `to_wire`;
/// 2. decode the bytes with `Program::from_wire`;
/// 3. require the decoded program to equal the original;
/// 4. require the decoded program to encode again without error;
/// 5. dispatch both programs on `backend` and require both to succeed with
///    identical 4-byte outputs.
///
/// Returns `None` when every step succeeds.
fn check_case(
    backend: &dyn vyre::VyreBackend,
    case: &WireCase,
    program: &Program,
    input_bytes: &[u8],
    witness: usize,
) -> Option<WireFormatEquivViolation> {
    // Shorthand for violations raised before any backend dispatch happens.
    let codec_failure = |message: String| Some(codec_violation(case, input_bytes, message));

    let original_bytes = match program.to_wire() {
        Ok(encoded) => encoded,
        Err(err) => {
            return codec_failure(format!(
                "to_wire failed for witness {witness}: {err}. Fix: serialize a valid IR Program."
            ));
        }
    };

    let decoded = match Program::from_wire(&original_bytes) {
        Ok(decoded) => decoded,
        Err(err) => {
            return codec_failure(format!(
                "from_wire failed for witness {witness}: {err}. Fix: keep Program::to_wire and Program::from_wire compatible."
            ));
        }
    };

    if decoded != *program {
        return codec_failure(format!(
            "wire round-trip changed IR structure for witness {witness}. Fix: preserve every Program field in the VIR0 codec."
        ));
    }

    // Only the success of the second encoding is checked; its bytes are discarded.
    match decoded.to_wire() {
        Ok(_) => {}
        Err(err) => {
            return codec_failure(format!(
                "re-encoding round-tripped IR failed for witness {witness}: {err}. Fix: decode only encodable Program values."
            ));
        }
    }

    // Both dispatches receive the same single input buffer.
    let inputs = [input_bytes.to_vec()];
    let original_output = dispatch_exact(backend, program, &inputs, 4);
    let roundtripped_output = dispatch_exact(backend, &decoded, &inputs, 4);

    let outputs_agree = matches!(
        (&original_output, &roundtripped_output),
        (Ok(lhs), Ok(rhs)) if lhs == rhs
    );
    if outputs_agree {
        return None;
    }

    // Any other combination (differing bytes, or at least one dispatch error)
    // is reported as a divergence.
    Some(WireFormatEquivViolation {
        op_name: case.op_name.to_string(),
        input_bytes: input_bytes.to_vec(),
        original_output: original_output.as_ref().ok().map(|bytes| hex_dump(bytes)),
        roundtripped_output: roundtripped_output
            .as_ref()
            .ok()
            .map(|bytes| hex_dump(bytes)),
        message: format!(
            "wire-format semantic divergence for {} witness {}: original_path={}, roundtripped_path={}. Fix: make Program::from_wire(to_wire(p)) preserve backend-visible semantics.",
            case.op_name,
            witness,
            result_label(&original_output),
            result_label(&roundtripped_output),
        ),
    })
}
/// Dispatches `program` on `backend` and returns its first output buffer,
/// insisting that the buffer is exactly `output_size` bytes long.
///
/// The program is first rewritten by `program_with_output_size` so that its
/// output buffer is sized for `output_size` bytes.
///
/// # Errors
/// Propagates any backend dispatch error, and returns a `BackendError` when
/// the backend yields no output buffers or a first buffer of the wrong length.
fn dispatch_exact(
    backend: &dyn vyre::VyreBackend,
    program: &Program,
    inputs: &[Vec<u8>],
    output_size: usize,
) -> Result<Vec<u8>, vyre::BackendError> {
    let sized_program = program_with_output_size(program, output_size);
    let config = vyre::DispatchConfig::default();
    let outputs = backend.dispatch(&sized_program, inputs, &config)?;

    // Only the first buffer matters; any further buffers are dropped.
    let Some(first) = outputs.into_iter().next() else {
        return Err(vyre::BackendError::new(
            "backend returned zero output buffers. Fix: return the wire-format probe output as outputs[0].",
        ));
    };

    match first.len() {
        actual if actual == output_size => Ok(first),
        actual => Err(vyre::BackendError::new(format!(
            "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the wire-format probe output declaration.",
            actual
        ))),
    }
}
/// Returns a copy of `program` whose first `ReadWrite` buffer is flagged as
/// the output and resized to hold `output_size` bytes.
///
/// The element count is `output_size` divided by four, rounded up, and
/// saturates at `u32::MAX` if it does not fit in a `u32`. Workgroup size and
/// entry nodes are carried over unchanged. If no `ReadWrite` buffer exists the
/// buffers are copied as-is.
fn program_with_output_size(program: &Program, output_size: usize) -> Program {
    let element_count = u32::try_from(output_size.div_ceil(4)).unwrap_or(u32::MAX);

    let mut buffers = program.buffers().to_vec();
    let first_read_write = buffers
        .iter_mut()
        .find(|buffer| buffer.access == vyre::ir::BufferAccess::ReadWrite);
    if let Some(target) = first_read_write {
        target.is_output = true;
        target.count = element_count;
    }

    Program::new(buffers, program.workgroup_size(), program.entry().to_vec())
}
/// Builds a violation for a failure that occurred before any backend output
/// was available, so both output fields are left as `None`.
fn codec_violation(
    case: &WireCase,
    input_bytes: &[u8],
    message: String,
) -> WireFormatEquivViolation {
    let op_name = case.op_name.to_owned();
    let input_bytes = input_bytes.to_owned();
    WireFormatEquivViolation {
        op_name,
        input_bytes,
        original_output: None,
        roundtripped_output: None,
        message,
    }
}
/// Renders a dispatch result for inclusion in a violation message:
/// `ok(<hex>)` for a successful output, or the error's display text otherwise.
fn result_label(result: &Result<Vec<u8>, vyre::BackendError>) -> String {
    result.as_ref().map_or_else(
        |err| err.to_string(),
        |bytes| format!("ok({})", hex_dump(bytes)),
    )
}
/// Builds the canonical probe program for one witness.
///
/// The program declares a single read-write `U32` buffer named `"results"`,
/// uses a `[1, 1, 1]` workgroup, stores `a <binop> b` (both operands baked in
/// as constants) at index 0 of that buffer, and returns.
///
/// # Errors
/// Returns the message from `binary_expr` when `binop` is not supported.
fn build_ir_program(binop: BinOp, a: u32, b: u32) -> Result<Program, String> {
    let value = binary_expr(binop, a, b)?;
    let buffers = vec![BufferDecl::read_write("results", 0, DataType::U32)];
    let entry = vec![Node::store("results", Expr::u32(0), value), Node::Return];
    Ok(Program::new(buffers, [1, 1, 1], entry))
}
/// Builds the expression `a <binop> b` with both operands as `u32` constants.
///
/// # Errors
/// Returns a descriptive message for any operator outside the ten handled
/// explicitly here.
fn binary_expr(binop: BinOp, a: u32, b: u32) -> Result<Expr, String> {
    let (lhs, rhs) = (Expr::u32(a), Expr::u32(b));
    Ok(match binop {
        BinOp::Add => Expr::add(lhs, rhs),
        BinOp::Sub => Expr::sub(lhs, rhs),
        BinOp::Mul => Expr::mul(lhs, rhs),
        BinOp::Div => Expr::div(lhs, rhs),
        BinOp::BitAnd => Expr::bitand(lhs, rhs),
        BinOp::BitOr => Expr::bitor(lhs, rhs),
        BinOp::BitXor => Expr::bitxor(lhs, rhs),
        BinOp::Eq => Expr::eq(lhs, rhs),
        BinOp::Ne => Expr::ne(lhs, rhs),
        BinOp::Shl => Expr::shl(lhs, rhs),
        // Operators not in the catalog are rejected rather than guessed at.
        _ => {
            return Err(format!(
                "unsupported canonical binary operation {binop:?}. Fix: keep the wire_format_eq catalog limited to explicitly handled BinOp variants."
            ));
        }
    })
}
/// Derives a deterministic operand pair from `seed`.
///
/// Each operand is reduced to its low 15 bits and incremented, so both values
/// lie in `1..=32_768` and are never zero.
fn u32_pair_from_seed(seed: u64) -> (u32, u32) {
    // Keep the low 15 bits, then shift the range from 0..=32_767 to 1..=32_768.
    let into_operand = |raw: u32| (raw & 0x7fff) + 1;

    let first = split_mix_u32(seed);
    // The second operand comes from the seed advanced by a fixed odd constant.
    let second = split_mix_u32(seed.wrapping_add(0x9e37_79b9_7f4a_7c15));

    (into_operand(first), into_operand(second))
}
/// Hashes `seed` into a `u32` using a SplitMix64-style mixing sequence
/// (add an increment, two xor-shift-multiply rounds, a final xor-shift) and
/// then folds the 64-bit result by adding its high and low halves with
/// wrapping.
fn split_mix_u32(seed: u64) -> u32 {
    const INCREMENT: u64 = 0x9e37_79b9_7f4a_7c15;
    const MULTIPLIER_A: u64 = 0xbf58_476d_1ce4_e5b9;
    const MULTIPLIER_B: u64 = 0x94d0_49bb_1331_11eb;

    let state = seed.wrapping_add(INCREMENT);
    let state = (state ^ (state >> 30)).wrapping_mul(MULTIPLIER_A);
    let state = (state ^ (state >> 27)).wrapping_mul(MULTIPLIER_B);
    let mixed = state ^ (state >> 31);

    let high = (mixed >> 32) as u32;
    // Truncation to the low 32 bits is intentional.
    let low = mixed as u32;
    high.wrapping_add(low)
}
/// Serializes the two values as consecutive little-endian `u32`s (8 bytes).
fn u32s_to_bytes(values: [u32; 2]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(values.len() * core::mem::size_of::<u32>());
    for value in values {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
/// Formats `bytes` as a contiguous lowercase hexadecimal string, two digits
/// per byte and no separators.
fn hex_dump(bytes: &[u8]) -> String {
    use core::fmt::Write;

    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut text, byte| {
            // The write result is discarded, as in the loop form.
            let _ = write!(text, "{byte:02x}");
            text
        })
}
/// Enforcement gate that runs the wire-format equivalence check against the
/// backend supplied in the enforcement context.
pub struct WireFormatEquivalenceEnforcer;
impl crate::enforce::EnforceGate for WireFormatEquivalenceEnforcer {
    /// Stable identifier of this gate.
    fn id(&self) -> &'static str {
        "wire_format_equivalence"
    }

    /// Display name of this gate; identical to the identifier.
    fn name(&self) -> &'static str {
        "wire_format_equivalence"
    }

    /// Runs the equivalence check and converts every violation message into
    /// findings. Without a backend in `ctx`, a single finding explaining that
    /// a backend is required is returned instead.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        match ctx.backend {
            None => {
                let missing_backend = vec!["wire_format_equivalence: backend is required. Fix: provide a VyreBackend in EnforceCtx.".to_string()];
                vec![crate::enforce::aggregate_finding(self.id(), missing_backend)]
            }
            Some(backend) => {
                let messages: Vec<String> = enforce_wire_format_equivalence(backend)
                    .violations
                    .into_iter()
                    .map(|violation| violation.message)
                    .collect();
                crate::enforce::finding_result(self.id(), messages)
            }
        }
    }
}
/// Ready-made instance of the gate.
// NOTE(review): presumably consumed by the enforce-gate registry elsewhere in
// the crate — confirm against the registration site.
pub const REGISTERED: WireFormatEquivalenceEnforcer = WireFormatEquivalenceEnforcer;
#[cfg(test)]
mod tests {
    use super::*;

    /// The catalog must keep exactly ten canonical operations.
    #[test]
    fn canonical_catalog_has_ten_entries() {
        let catalog = wire_cases();
        assert_eq!(catalog.len(), 10);
    }

    /// The same seed always yields the same pair, and both operands stay
    /// within `1..=32_768`.
    #[test]
    fn prng_is_deterministic_and_bounded() {
        let (a, b) = u32_pair_from_seed(42);
        assert_eq!((a, b), u32_pair_from_seed(42));

        let bounds = 1..=32_768;
        assert!(bounds.contains(&a));
        assert!(bounds.contains(&b));
    }

    /// Encoding then decoding a canonical program must reproduce it exactly.
    #[test]
    fn canonical_program_round_trips_structurally() {
        let original =
            build_ir_program(BinOp::Add, 7, 9).expect("canonical test program should build");
        let encoded = original
            .to_wire()
            .expect("canonical test program should encode");
        let decoded = Program::from_wire(&encoded).expect("canonical test program should decode");
        assert_eq!(decoded, original);
    }
}