use vyre::ir::{BinOp, BufferDecl, DataType, Expr, Node, Program};
use vyre_conform::{reference::interp, spec::value::Value, specs::primitive};
/// Number of input cases generated and checked for each primitive spec; it also
/// sizes the zero-initialised output buffer (`CASES * 4` bytes).
const CASES: usize = 10_000;
/// Checks the reference interpreter against each primitive's `cpu_fn`.
///
/// For every spec returned by `primitive::specs()`, this builds a program that
/// applies the primitive to `CASES` deterministic input tuples of 4-byte words,
/// runs it through `interp::run`, and compares each 4-byte output word with the
/// bytes `cpu_fn` returns for the same input bytes.
///
/// # Panics
///
/// Panics (failing the test) when a spec is neither unary nor binary, when the
/// interpreter reports an error, when its first output is missing, is not
/// bytes, or is too short to hold every case, or when any case disagrees with
/// `cpu_fn`.
#[test]
fn reference_agrees_with_every_primitive_cpu_fn_for_witnessed_u32_inputs() {
    for spec in primitive::specs() {
        let arity = spec.signature.inputs.len();
        assert!(
            arity == 1 || arity == 2,
            "primitive agreement harness supports unary/binary u32 specs; {} has arity {arity}",
            spec.id
        );

        let input_words = witnessed_words(spec.id, arity, CASES);
        let input_bytes = words_to_bytes(&input_words);
        let output_init = vec![0u8; CASES * 4];
        let program = primitive_program(spec.id, arity);

        // The input bytes are cloned because they are sliced again below to feed `cpu_fn`.
        let outputs = interp::run(
            &program,
            &[Value::Bytes(input_bytes.clone()), Value::Bytes(output_init)],
        )
        .unwrap_or_else(|error| panic!("reference interpreter failed for {}: {error}", spec.id));

        // Validate the output shape up front so a malformed result names the
        // offending primitive instead of dying on an anonymous index panic.
        let Some(first_output) = outputs.first() else {
            panic!("reference interpreter returned no outputs for {}", spec.id);
        };
        let Value::Bytes(actual) = first_output else {
            panic!("reference output for {} was not bytes", spec.id);
        };
        assert!(
            actual.len() >= CASES * 4,
            "reference output for {} holds {} bytes; expected at least {}",
            spec.id,
            actual.len(),
            CASES * 4
        );

        for case in 0..CASES {
            // Each case consumes `arity` consecutive 4-byte input words and
            // produces exactly one 4-byte output word.
            let start = case * arity * 4;
            let end = start + arity * 4;
            let expected = (spec.cpu_fn)(&input_bytes[start..end]);
            let actual = &actual[case * 4..case * 4 + 4];
            assert_eq!(
                actual,
                expected.as_slice(),
                "reference interpreter disagrees with cpu_fn for {}\ncase: {case}\ninput_words: {:?}\ninput_bytes: {:02x?}\nexpected: {:02x?}\nactual: {:02x?}",
                spec.id,
                &input_words[case * arity..case * arity + arity],
                &input_bytes[start..end],
                expected,
                actual
            );
        }
    }
}
/// Builds the program used to exercise one primitive through the interpreter.
///
/// The program declares an `input` buffer (`BufferDecl::read`) and an `out`
/// buffer (`BufferDecl::read_write`), both of `DataType::U32`. Its single node
/// stores `Call(op_id, args)` into `out[gid_x]`, guarded by
/// `gid_x < buf_len("out")`. The `arity` arguments are loaded from consecutive
/// `input` elements starting at `gid_x * arity`.
///
/// * `op_id` - identifier of the primitive placed in the `Expr::Call`.
/// * `arity` - number of `input` elements loaded per invocation.
fn primitive_program(op_id: &'static str, arity: usize) -> Program {
    let invocation = Expr::gid_x();

    // Element index of the first operand. For a unary op it is the invocation
    // index itself, so no multiplication node is emitted.
    // NOTE(review): `arity as u32` would truncate for huge values; the caller
    // in this file only passes 1 or 2.
    let first_operand = match arity {
        1 => invocation.clone(),
        _ => Expr::BinOp {
            op: BinOp::Mul,
            left: Box::new(invocation.clone()),
            right: Box::new(Expr::u32(arity as u32)),
        },
    };

    // Load expression for the operand at `offset`; offset 0 reuses the base
    // index without wrapping it in an addition.
    let operand = |offset: usize| {
        let index = match offset {
            0 => first_operand.clone(),
            _ => Expr::BinOp {
                op: BinOp::Add,
                left: Box::new(first_operand.clone()),
                right: Box::new(Expr::u32(offset as u32)),
            },
        };
        Expr::load("input", index)
    };
    let args = (0..arity).map(operand).collect();

    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::read_write("out", 1, DataType::U32),
    ];

    // Only invocations whose index falls inside `out` perform the store.
    let in_bounds = Expr::lt(invocation.clone(), Expr::buf_len("out"));
    let call = Expr::Call {
        op_id: op_id.to_string(),
        args,
    };
    let body = vec![Node::if_then(
        in_bounds,
        vec![Node::store("out", invocation, call)],
    )];

    Program::new(buffers, [64, 1, 1], body)
}
/// Generates the deterministic input words for one primitive.
///
/// Returns `arity * cases` words laid out case by case. The generator state is
/// seeded from `fnv1a32(op_id)` XOR `0xA53C_9E17` and advanced once per word
/// with a wrapping multiply-add. Every fifth case (`case % 5 == 0`) emits the
/// selected boundary value unchanged; all other cases emit the rotated state
/// XORed with that boundary value.
fn witnessed_words(op_id: &str, arity: usize, cases: usize) -> Vec<u32> {
    const BOUNDARIES: [u32; 9] = [0, 1, 2, 31, 32, 33, 0x7FFF_FFFF, 0x8000_0000, u32::MAX];

    let mut state = fnv1a32(op_id.as_bytes()) ^ 0xA53C_9E17;
    let mut generated = Vec::with_capacity(arity * cases);

    // `lane` is `case + arg`; it selects both the boundary value and the
    // rotation amount for the word.
    let slots = (0..cases).flat_map(|case| (0..arity).map(move |arg| (case, case + arg)));

    generated.extend(slots.map(|(case, lane)| {
        // The state advances for every word, including the pure-boundary ones.
        state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
        let boundary = BOUNDARIES[lane % BOUNDARIES.len()];
        if case % 5 == 0 {
            boundary
        } else {
            state.rotate_left((lane & 31) as u32) ^ boundary
        }
    }));

    generated
}
/// Serialises `words` into one contiguous byte vector, four little-endian
/// bytes per word, preserving word order.
fn words_to_bytes(words: &[u32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(words.len() * 4);
    for word in words {
        bytes.extend_from_slice(&word.to_le_bytes());
    }
    bytes
}
/// Computes the 32-bit FNV-1a hash of `bytes`.
///
/// Starting from the offset basis, each byte is XORed into the hash, which is
/// then multiplied by the FNV prime with wrapping arithmetic. An empty slice
/// returns the offset basis unchanged.
fn fnv1a32(bytes: &[u8]) -> u32 {
    const OFFSET_BASIS: u32 = 0x811C_9DC5;
    const PRIME: u32 = 0x0100_0193;

    bytes.iter().fold(OFFSET_BASIS, |hash, &byte| {
        (hash ^ u32::from(byte)).wrapping_mul(PRIME)
    })
}