use vyre::ir::{BufferDecl, DataType as IrDataType, Expr, Node, Program};
use vyre_conform::{
reference::interp,
registry,
spec::value::Value,
types::{DataType as SpecDataType, OpSpec},
};
/// Number of generated ("witnessed") cases checked per spec by
/// `run_witnessed_u32_cases`.
const WITNESSED_CASES: usize = 10_000;
/// Checks, for every spec returned by `registry::all_specs`, that the
/// reference interpreter and the spec's `cpu_fn` produce identical bytes.
///
/// Each spec goes through the exhaustive 8-bit sweep (skipped when the
/// signature does not qualify) and then through the generated-input sweep.
/// An empty registry is treated as a failure rather than a vacuous pass.
#[test]
fn reference_run_bit_matches_cpu_fn_for_registered_specs() {
    let specs = registry::all_specs();
    if specs.is_empty() {
        panic!("Fix: registry::all_specs must expose the registered OpSpec set");
    }

    for spec in &specs {
        // Exhaustive sweep first, then the seeded sweep, for each spec in turn.
        run_exhaustive_u8_when_tractable(spec);
        run_witnessed_u32_cases(spec);
    }
}
/// Runs `spec` over every combination of 8-bit operand values, widened to
/// `u32`, when `supports_exhaustive_u8` accepts its signature.
///
/// Unary specs get 256 cases; all other accepted specs get the full
/// 256 x 256 grid with the first operand varying slowest. Specs that do not
/// qualify are silently skipped.
fn run_exhaustive_u8_when_tractable(spec: &OpSpec) {
    if supports_exhaustive_u8(spec) {
        // Every u8 value, zero-extended into a 32-bit word.
        let byte_values = || (u8::MIN..=u8::MAX).map(u32::from);

        let cases: Vec<Vec<u32>> = match spec.signature.inputs.len() {
            1 => byte_values().map(|a| vec![a]).collect(),
            _ => byte_values()
                .flat_map(|a| byte_values().map(move |b| vec![a, b]))
                .collect(),
        };

        assert_fixed_width_cases(spec, &cases, "exhaustive_u8");
    }
}
/// Runs `WITNESSED_CASES` generated cases for `spec`, seeded from its op id.
///
/// Fixed-width signatures are checked in one batch through
/// `assert_fixed_width_cases`; every other signature is checked one case at a
/// time through `assert_single_case`, with inputs produced by
/// `witnessed_variable_input`.
fn run_witnessed_u32_cases(spec: &OpSpec) {
    // The generator state is derived from the op id, so a spec always sees the
    // same input sequence.
    let mut seed = seed_from_op_id(spec.id);

    if !is_fixed_width_signature(spec) {
        for case in 0..WITNESSED_CASES {
            let input = witnessed_variable_input(spec, &mut seed, case);
            assert_single_case(spec, &input, "witnessed_u32", case);
        }
        return;
    }

    let arity = spec.signature.inputs.len();
    let mut cases = Vec::with_capacity(WITNESSED_CASES);
    for case in 0..WITNESSED_CASES {
        let mut words = Vec::with_capacity(arity);
        for arg in 0..arity {
            words.push(witnessed_word(&mut seed, case, arg));
        }
        cases.push(words);
    }
    assert_fixed_width_cases(spec, &cases, "witnessed_u32");
}
/// Returns `true` when `spec` can be swept exhaustively over 8-bit operands.
///
/// That requires one or two inputs, every input typed `U32` or `I32`, and an
/// output type for which `fixed_output_size` reports a size.
fn supports_exhaustive_u8(spec: &OpSpec) -> bool {
    let inputs = &spec.signature.inputs;

    if !matches!(inputs.len(), 1 | 2) {
        return false;
    }

    let has_other_input = inputs
        .iter()
        .any(|ty| !matches!(ty, SpecDataType::U32 | SpecDataType::I32));
    if has_other_input {
        return false;
    }

    fixed_output_size(&spec.signature.output).is_some()
}
/// Returns `true` when every input of `spec` maps to a fixed-width IR type
/// (per `fixed_input_type`) and the output has a known fixed size (per
/// `fixed_output_size`).
fn is_fixed_width_signature(spec: &OpSpec) -> bool {
    let signature = &spec.signature;

    for ty in signature.inputs.iter() {
        if fixed_input_type(ty).is_none() {
            return false;
        }
    }

    fixed_output_size(&signature.output).is_some()
}
/// Runs all `cases` for a fixed-width `spec` through the reference interpreter
/// in one batch and asserts the output bytes equal the concatenated `cpu_fn`
/// results.
///
/// * `spec`  - operation under test; its output type must have a fixed size.
/// * `cases` - one entry per case, holding one 32-bit word per input argument.
/// * `label` - name of the sweep, included in failure messages.
///
/// # Panics
///
/// Panics with a `Fix:` message when `interp::run` fails, returns no outputs,
/// returns a first output that is not `Value::Bytes`, or disagrees with
/// `cpu_fn`. Also panics if the output type has no fixed size or a case has
/// fewer words than the spec has inputs.
fn assert_fixed_width_cases(spec: &OpSpec, cases: &[Vec<u32>], label: &str) {
    let program = fixed_width_program(spec, cases.len());

    // One buffer per argument, holding that argument's value for every case.
    let mut inputs = Vec::with_capacity(spec.signature.inputs.len() + 1);
    for (arg_index, ty) in spec.signature.inputs.iter().enumerate() {
        let mut bytes = Vec::new();
        for case in cases {
            append_word_as_type(&mut bytes, case[arg_index], ty);
        }
        inputs.push(Value::Bytes(bytes));
    }

    // Zero-filled output buffer with room for one result per case.
    let output_size = fixed_output_size(&spec.signature.output).expect("checked by caller");
    inputs.push(Value::Bytes(vec![0; cases.len() * output_size]));

    let outputs = interp::run(&program, &inputs)
        .unwrap_or_else(|error| panic!("Fix: interp::run failed for {} {label}: {error}", spec.id));

    // `first()` instead of `[0]` so an empty result fails with the spec id
    // rather than a bare index-out-of-bounds panic.
    let actual = match outputs.first() {
        Some(Value::Bytes(bytes)) => bytes,
        Some(_) => panic!("Fix: interp::run returned non-byte output for {}", spec.id),
        None => panic!(
            "Fix: interp::run returned no outputs for {} {label}",
            spec.id
        ),
    };

    // Expected bytes: cpu_fn applied to each case's packed arguments, in case order.
    let expected = cases
        .iter()
        .flat_map(|case| {
            let mut packed = Vec::new();
            for (word, ty) in case.iter().zip(&spec.signature.inputs) {
                append_word_as_type(&mut packed, *word, ty);
            }
            (spec.cpu_fn)(&packed)
        })
        .collect::<Vec<_>>();

    assert_eq!(
        actual, &expected,
        "Fix: reference interpreter disagrees with cpu_fn for {} on {label}",
        spec.id
    );
}
/// Builds the batch program used for fixed-width specs.
///
/// The program declares one read buffer per input, named `arg{index}`, and a
/// read-write buffer named `out` declared after them. Its body stores
/// `Call(spec.id, loads at gid_x)` into `out[gid_x]`, guarded by
/// `gid_x < cases`. The `[64, 1, 1]` triple is passed straight to
/// `Program::new` (presumably the workgroup size — confirm against `vyre::ir`).
///
/// # Panics
///
/// Panics if any input or the output type has no fixed-width IR mapping;
/// callers are expected to have checked this.
fn fixed_width_program(spec: &OpSpec, cases: usize) -> Program {
    let signature = &spec.signature;
    let input_count = signature.inputs.len();
    let gid = Expr::gid_x();

    // Call arguments: each input buffer read at the invocation index.
    let mut args = Vec::with_capacity(input_count);
    for index in 0..input_count {
        args.push(Expr::load(&format!("arg{index}"), gid.clone()));
    }

    // Buffer declarations: inputs first, then the output buffer.
    let mut buffers = Vec::with_capacity(input_count + 1);
    for (index, ty) in signature.inputs.iter().enumerate() {
        let element = fixed_input_type(ty).expect("checked by caller");
        buffers.push(BufferDecl::read(
            &format!("arg{index}"),
            index as u32,
            element,
        ));
    }
    let out_slot = buffers.len() as u32;
    let out_element = fixed_input_type(&signature.output).expect("checked by caller");
    buffers.push(BufferDecl::read_write("out", out_slot, out_element));

    let call = Expr::Call {
        op_id: spec.id.to_string(),
        args,
    };
    let in_range = Expr::lt(gid.clone(), Expr::u32(cases as u32));
    let guarded_store = Node::if_then(in_range, vec![Node::store("out", gid, call)]);

    Program::new(buffers, [64, 1, 1], vec![guarded_store])
}
/// Runs one case for `spec` through the reference interpreter and asserts the
/// output bytes equal `cpu_fn` applied to the concatenated input.
///
/// * `spec`  - operation under test.
/// * `input` - one byte buffer per input argument.
/// * `label` - name of the sweep, included in failure messages.
/// * `case`  - case index, included in failure messages.
///
/// # Panics
///
/// Panics with a `Fix:` message when `interp::run` fails, returns no outputs,
/// returns a first output that is not `Value::Bytes`, or disagrees with
/// `cpu_fn`.
fn assert_single_case(spec: &OpSpec, input: &[Vec<u8>], label: &str, case: usize) {
    // cpu_fn runs first because its result length sizes the output buffer.
    let expected = (spec.cpu_fn)(&input.concat());
    let program = variable_width_program(spec, expected.len());

    let mut values = input.iter().cloned().map(Value::Bytes).collect::<Vec<_>>();
    values.push(Value::Bytes(vec![0; expected.len()]));

    let outputs = interp::run(&program, &values).unwrap_or_else(|error| {
        panic!(
            "Fix: interp::run failed for {} {label} case {case}: {error}",
            spec.id
        )
    });

    // `first()` instead of `[0]` so an empty result fails with the spec id and
    // case number rather than a bare index-out-of-bounds panic.
    let actual = match outputs.first() {
        Some(Value::Bytes(bytes)) => bytes,
        Some(_) => panic!("Fix: interp::run returned non-byte output for {}", spec.id),
        None => panic!(
            "Fix: interp::run returned no outputs for {} {label} case {case}",
            spec.id
        ),
    };

    assert_eq!(
        actual, &expected,
        "Fix: reference interpreter disagrees with cpu_fn for {} on {label} case {case}",
        spec.id
    );
}
/// Builds the single-case program used for specs that are not fixed-width.
///
/// The program declares one read buffer per input, named `arg{index}`, typed
/// by `input_type_or_bytes`, and a read-write buffer named `out` typed by
/// `output_type_or_bytes`. Its body is a single store of
/// `Call(spec.id, loads at index 0)` into `out[0]`. The `[1, 1, 1]` triple is
/// passed straight to `Program::new` (presumably the workgroup size — confirm
/// against `vyre::ir`).
fn variable_width_program(spec: &OpSpec, output_len: usize) -> Program {
    let signature = &spec.signature;
    let input_count = signature.inputs.len();

    // Call arguments: each input buffer read at element 0.
    let mut args = Vec::with_capacity(input_count);
    for index in 0..input_count {
        args.push(Expr::load(&format!("arg{index}"), Expr::u32(0)));
    }

    // Buffer declarations: inputs first, then the output buffer.
    let mut buffers = Vec::with_capacity(input_count + 1);
    for (index, ty) in signature.inputs.iter().enumerate() {
        buffers.push(BufferDecl::read(
            &format!("arg{index}"),
            index as u32,
            input_type_or_bytes(ty),
        ));
    }
    let out_slot = buffers.len() as u32;
    let out_element = output_type_or_bytes(&signature.output, output_len);
    buffers.push(BufferDecl::read_write("out", out_slot, out_element));

    let call = Expr::Call {
        op_id: spec.id.to_string(),
        args,
    };

    Program::new(
        buffers,
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), call)],
    )
}
/// Generates one byte buffer per input of `spec` for the given `case`,
/// advancing `seed` on every generated word.
///
/// `Bytes`, `Array` and `Tensor` inputs get a generated length in `0..=64`
/// followed by that many generated bytes (the low byte of each word). Every
/// other input is a single generated word encoded by `append_word_as_type`.
fn witnessed_variable_input(spec: &OpSpec, seed: &mut u32, case: usize) -> Vec<Vec<u8>> {
    let mut encoded = Vec::with_capacity(spec.signature.inputs.len());

    for (arg, ty) in spec.signature.inputs.iter().enumerate() {
        let mut bytes = Vec::new();
        let is_variable_length = matches!(
            ty,
            SpecDataType::Bytes | SpecDataType::Array { .. } | SpecDataType::Tensor
        );

        if is_variable_length {
            // The length is drawn first, then each byte from its own word.
            let len = (witnessed_word(seed, case, arg) as usize) % 65;
            for byte_index in 0..len {
                bytes.push(witnessed_word(seed, case, arg + byte_index) as u8);
            }
        } else {
            append_word_as_type(&mut bytes, witnessed_word(seed, case, arg), ty);
        }

        encoded.push(bytes);
    }

    encoded
}
/// Appends the little-endian encoding of `word`, shaped for `ty`, to `bytes`.
///
/// * `U32`, `I32`, `F32`: the 4 bytes of `word`.
/// * `U64`, `Vec2U32`: `word` followed by `word` rotated left by 17 bits.
/// * `Vec4U32`: four lanes, lane `i` being `word` rotated left by `7 * i` bits.
/// * `F16`, `BF16`: the low 16 bits of `word`.
/// * `Bytes`, `Array`, `Tensor`: nothing is appended.
///
/// # Panics
///
/// Panics for any other type, asking for an encoding to be added.
fn append_word_as_type(bytes: &mut Vec<u8>, word: u32, ty: &SpecDataType) {
    match ty {
        SpecDataType::U32 | SpecDataType::I32 | SpecDataType::F32 => {
            bytes.extend_from_slice(&word.to_le_bytes());
        }
        SpecDataType::U64 | SpecDataType::Vec2U32 => {
            let halves = [word, word.rotate_left(17)];
            for half in halves.iter() {
                bytes.extend_from_slice(&half.to_le_bytes());
            }
        }
        SpecDataType::Vec4U32 => {
            let mut rotation = 0u32;
            while rotation < 4 * 7 {
                bytes.extend_from_slice(&word.rotate_left(rotation).to_le_bytes());
                rotation += 7;
            }
        }
        SpecDataType::F16 | SpecDataType::BF16 => {
            // Deliberate truncation: only the low half of the word is kept.
            let low_half = word as u16;
            bytes.extend_from_slice(&low_half.to_le_bytes());
        }
        SpecDataType::Bytes | SpecDataType::Array { .. } | SpecDataType::Tensor => {}
        unsupported => {
            panic!("Fix: add reference parity input encoding for unsupported type {unsupported:?}")
        }
    }
}
/// Maps a spec data type to the fixed-width IR type used for its buffer.
///
/// Returns `None` for `F16`, `BF16`, `Bytes`, `Array` and `Tensor`. Note that
/// `F32` maps to `IrDataType::U32`, the same IR type as `U32`.
///
/// # Panics
///
/// Panics for any type not listed here, asking for a mapping to be added.
fn fixed_input_type(ty: &SpecDataType) -> Option<IrDataType> {
    let mapped = match ty {
        SpecDataType::I32 => IrDataType::I32,
        SpecDataType::U32 | SpecDataType::F32 => IrDataType::U32,
        SpecDataType::U64 => IrDataType::U64,
        SpecDataType::Vec2U32 => IrDataType::Vec2U32,
        SpecDataType::Vec4U32 => IrDataType::Vec4U32,
        SpecDataType::F16
        | SpecDataType::BF16
        | SpecDataType::Bytes
        | SpecDataType::Array { .. }
        | SpecDataType::Tensor => return None,
        unsupported => panic!(
            "Fix: add reference parity fixed-width mapping for unsupported type {unsupported:?}"
        ),
    };
    Some(mapped)
}
/// Returns `min_bytes()` of the IR type that `fixed_input_type` maps `ty` to,
/// or `None` when `ty` has no fixed-width mapping.
fn fixed_output_size(ty: &SpecDataType) -> Option<usize> {
    let ir = fixed_input_type(ty)?;
    Some(ir.min_bytes())
}
/// Returns the fixed-width IR type for `ty`, falling back to
/// `IrDataType::Bytes` when `fixed_input_type` has no mapping.
fn input_type_or_bytes(ty: &SpecDataType) -> IrDataType {
    match fixed_input_type(ty) {
        Some(ir) => ir,
        None => IrDataType::Bytes,
    }
}
/// Chooses the IR type for the output buffer.
///
/// An empty output (`output_len == 0`) is always `IrDataType::Bytes`, without
/// consulting `fixed_input_type`. Otherwise the fixed-width mapping of `ty` is
/// used, falling back to `IrDataType::Bytes` when there is none.
fn output_type_or_bytes(ty: &SpecDataType, output_len: usize) -> IrDataType {
    if output_len != 0 {
        if let Some(ir) = fixed_input_type(ty) {
            return ir;
        }
    }
    IrDataType::Bytes
}
/// Produces the next generated 32-bit word for (`case`, `arg`) and advances
/// `seed`.
///
/// `seed` is always stepped with a wrapping linear congruential update
/// (`seed * 1_664_525 + 1_013_904_223`). A boundary value is selected by
/// `(case + arg) % 11`. Every fourth case (`case % 4 == 0`) returns that
/// boundary value unchanged; all other cases return the new seed rotated left
/// by `(case + arg) & 31` bits, XORed with the boundary value.
fn witnessed_word(seed: &mut u32, case: usize, arg: usize) -> u32 {
    const BOUNDARIES: [u32; 11] = [
        0,
        1,
        2,
        7,
        31,
        32,
        33,
        0x7FFF_FFFF,
        0x8000_0000,
        0xFFFF_FFFE,
        u32::MAX,
    ];

    // The state advances on every call, including calls that return a raw
    // boundary value.
    let next = seed.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
    *seed = next;

    let mix = case + arg;
    let boundary = BOUNDARIES[mix % BOUNDARIES.len()];

    match case % 4 {
        0 => boundary,
        _ => next.rotate_left((mix & 31) as u32) ^ boundary,
    }
}
/// Hashes `op_id` into the initial generator seed using 32-bit FNV-1a:
/// starting from the offset basis, each byte is XORed in and the result is
/// multiplied (wrapping) by the FNV prime.
fn seed_from_op_id(op_id: &str) -> u32 {
    const OFFSET_BASIS: u32 = 0x811C_9DC5;
    const PRIME: u32 = 0x0100_0193;

    op_id.bytes().fold(OFFSET_BASIS, |hash, byte| {
        (hash ^ u32::from(byte)).wrapping_mul(PRIME)
    })
}