use vyre::ir::{AtomicOp, BinOp, BufferDecl, DataType, Expr, Node, Program};
use vyre_conform::{reference::interp, spec::value::Value, specs::primitive::add};
/// Encodes `words` into one flat byte vector, four little-endian bytes per word.
fn bytes(words: &[u32]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(words.len() * 4);
    for &word in words {
        encoded.extend_from_slice(&word.to_le_bytes());
    }
    encoded
}
/// Decodes a little-endian byte buffer into `u32` words.
///
/// # Panics
///
/// Panics if `bytes.len()` is not a multiple of four. A ragged tail means the
/// buffer under inspection has an unexpected size. Silently dropping the tail
/// (which a bare `chunks_exact` does) could let a wrongly-sized output pass a
/// word-level comparison.
fn words(bytes: &[u8]) -> Vec<u32> {
    let chunks = bytes.chunks_exact(4);
    assert!(
        chunks.remainder().is_empty(),
        "byte buffer of length {} is not a whole number of u32 words",
        bytes.len()
    );
    chunks
        .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect()
}
/// Borrows the raw byte payload of a reference-interpreter output value.
///
/// # Panics
///
/// Panics if `value` is not `Value::Bytes`.
fn output_bytes(value: &Value) -> &[u8] {
    match value {
        Value::Bytes(payload) => payload.as_slice(),
        _ => panic!("reference output must be raw bytes"),
    }
}
/// Smoke test: the reference interpreter's `out[0] = a[0] + b[0]` must produce
/// exactly what the primitive `add` spec's CPU function produces, for operands
/// whose mathematical sum exceeds `u32::MAX`.
#[test]
fn reference_smoke_add_matches_cpu_reference() {
    // Single source of truth for the operands, so the CPU reference and the
    // interpreter inputs cannot drift apart.
    let (lhs, rhs) = (0xFFFF_FFFF_u32, 2_u32);
    let program = Program::new(
        vec![
            BufferDecl::read("a", 0, DataType::U32),
            BufferDecl::read("b", 1, DataType::U32),
            BufferDecl::read_write("out", 2, DataType::U32),
        ],
        [1, 1, 1],
        vec![Node::store(
            "out",
            Expr::u32(0),
            Expr::BinOp {
                op: BinOp::Add,
                left: Box::new(Expr::load("a", Expr::u32(0))),
                right: Box::new(Expr::load("b", Expr::u32(0))),
            },
        )],
    );
    // `cpu_fn` is fed both operands packed into one little-endian buffer.
    let expected = (add::spec().cpu_fn)(&bytes(&[lhs, rhs]));
    let outputs = interp::run(
        &program,
        &[
            Value::Bytes(bytes(&[lhs])),
            Value::Bytes(bytes(&[rhs])),
            Value::Bytes(bytes(&[0])),
        ],
    )
    .expect("reference interpreter should run add smoke program");
    // Only `out` is read_write, so it is the sole output.
    assert_eq!(outputs, vec![Value::Bytes(expected)]);
}
/// A workgroup whose invocation count does not fit in a `u32` must be rejected
/// with an error. The error text must name the overflow and carry a `Fix:` hint.
#[test]
fn overflowing_workgroup_size_is_structured_error() {
    // 65_536 * 65_536 * 1 == 2^32, one past the largest u32 value.
    let oversized = Program::new(vec![], [65_536, 65_536, 1], vec![]);
    let message = interp::run(&oversized, &[])
        .expect_err("overflowing workgroup must be rejected")
        .to_string();
    for fragment in ["overflows u32 invocation count", "Fix:"] {
        assert!(message.contains(fragment));
    }
}
/// An atomic exchange at an out-of-bounds index of a `U64` buffer must evaluate
/// to the declared element type's zero (eight zero bytes once cast to `Bytes`).
/// That zero is what gets stored into `out`.
#[test]
fn oob_atomic_returns_declared_element_zero() {
    let program = Program::new(
        vec![
            BufferDecl::read_write("empty", 0, DataType::U64),
            BufferDecl::read_write("out", 1, DataType::Bytes),
        ],
        [1, 1, 1],
        vec![Node::store(
            "out",
            Expr::u32(0),
            Expr::Cast {
                target: DataType::Bytes,
                value: Box::new(Expr::Atomic {
                    op: AtomicOp::Exchange,
                    buffer: "empty".to_string(),
                    index: Box::new(Expr::u32(0)),
                    expected: None,
                    value: Box::new(Expr::u32(1)),
                }),
            },
        )],
    );
    let outputs = interp::run(
        &program,
        &[Value::Bytes(vec![]), Value::Bytes(vec![0xAA; 8])],
    )
    .expect("OOB u64 atomic should return an eight-byte zero payload");
    // One output per read_write buffer, in declaration order:
    // [0] is `empty`, [1] is `out`.
    // The OOB exchange must not write into the zero-length buffer...
    assert!(output_bytes(&outputs[0]).is_empty());
    // ...and its zero result must overwrite every 0xAA byte of `out`.
    assert_eq!(output_bytes(&outputs[1]), &[0; 8]);
}
/// `0 && <atomic exchange on counter>` must short-circuit. The right-hand atomic
/// is never evaluated, so `counter` keeps its initial 7 and `out` receives 0.
#[test]
fn logical_ops_short_circuit_side_effecting_rhs() {
    let side_effecting_rhs = Expr::Atomic {
        op: AtomicOp::Exchange,
        buffer: "counter".to_string(),
        index: Box::new(Expr::u32(0)),
        expected: None,
        value: Box::new(Expr::u32(99)),
    };
    let conjunction = Expr::BinOp {
        op: BinOp::And,
        left: Box::new(Expr::u32(0)),
        right: Box::new(side_effecting_rhs),
    };
    let program = Program::new(
        vec![
            BufferDecl::read_write("counter", 0, DataType::U32),
            BufferDecl::read_write("out", 1, DataType::U32),
        ],
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), conjunction)],
    );
    let initial = [Value::Bytes(bytes(&[7])), Value::Bytes(bytes(&[0]))];
    let outputs = interp::run(&program, &initial)
        .expect("short-circuited RHS must not mutate counter");
    let counter = words(output_bytes(&outputs[0]));
    let out = words(output_bytes(&outputs[1]));
    assert_eq!(counter, vec![7]);
    assert_eq!(out, vec![0]);
}
/// Loading index 0 from a `DataType::Bytes` input must yield the buffer's full
/// contents, which the store then copies into `out`. The test name calls this a
/// stride-zero load.
#[test]
fn bytes_load_with_zero_stride_reads_buffer_contents() {
    let decls = vec![
        BufferDecl::read("input", 0, DataType::Bytes),
        BufferDecl::read_write("out", 1, DataType::Bytes),
    ];
    let copy_input = Node::store("out", Expr::u32(0), Expr::load("input", Expr::u32(0)));
    let program = Program::new(decls, [1, 1, 1], vec![copy_input]);
    let outputs = interp::run(
        &program,
        &[
            Value::Bytes(vec![1, 2, 3]),
            Value::Bytes(vec![0xAA, 0xAA, 0xAA]),
        ],
    )
    .expect("stride-zero byte load should be accessible");
    // `input` is read-only, so `out` is the only output.
    assert_eq!(output_bytes(&outputs[0]), &[1, 2, 3]);
}
/// The test stages `input` into the workgroup buffer `scratch`, then hits a
/// barrier, then copies `scratch` into `out`. The bytes must survive the round
/// trip.
///
/// NOTE(review): the `3` passed to `BufferDecl::workgroup` appears (from the test
/// name) to be the declared element count rather than a binding slot. Confirm
/// against `BufferDecl::workgroup`.
#[test]
fn workgroup_bytes_allocate_declared_count() {
    let stage_then_copy = vec![
        Node::store("scratch", Expr::u32(0), Expr::load("input", Expr::u32(0))),
        Node::Barrier,
        Node::store("out", Expr::u32(0), Expr::load("scratch", Expr::u32(0))),
    ];
    let program = Program::new(
        vec![
            BufferDecl::read("input", 0, DataType::Bytes),
            BufferDecl::read_write("out", 1, DataType::Bytes),
            BufferDecl::workgroup("scratch", 3, DataType::Bytes),
        ],
        [1, 1, 1],
        stage_then_copy,
    );
    // Host data is supplied only for `input` and `out`; `scratch` gets none.
    let host_inputs = [Value::Bytes(vec![4, 5, 6]), Value::Bytes(vec![0xAA; 3])];
    let outputs = interp::run(&program, &host_inputs)
        .expect("workgroup Bytes storage should allocate the declared count");
    assert_eq!(output_bytes(&outputs[0]), &[4, 5, 6]);
}
/// An out-of-bounds outer atomic must not evaluate its `value` operand. The
/// nested `AtomicOp::Add` on `other` therefore never runs, so `other` keeps its
/// initial 7 and the zero-length `empty` buffer stays empty.
#[test]
fn oob_atomic_short_circuits_nested_value_operand() {
    let nested_add = Expr::Atomic {
        op: AtomicOp::Add,
        buffer: "other".to_string(),
        index: Box::new(Expr::u32(0)),
        expected: None,
        value: Box::new(Expr::u32(1)),
    };
    let oob_outer_add = Expr::Atomic {
        op: AtomicOp::Add,
        buffer: "empty".to_string(),
        index: Box::new(Expr::u32(0)),
        expected: None,
        value: Box::new(nested_add),
    };
    let program = Program::new(
        vec![
            BufferDecl::read_write("empty", 0, DataType::U32),
            BufferDecl::read_write("other", 1, DataType::U32),
        ],
        [1, 1, 1],
        vec![Node::let_bind("ignored", oob_outer_add)],
    );
    let outputs = interp::run(
        &program,
        &[Value::Bytes(Vec::new()), Value::Bytes(bytes(&[7]))],
    )
    .expect("OOB outer atomic should not evaluate nested value operand");
    assert!(output_bytes(&outputs[0]).is_empty());
    assert_eq!(words(output_bytes(&outputs[1])), vec![7]);
}