use vyre::ir::{AtomicOp, BinOp, BufferDecl, DataType, Expr, Node, Program};
use vyre_conform::{reference::interp, spec::value::Value};
/// Serialises `u32` words into one little-endian byte buffer, preserving
/// word order (word `i` occupies bytes `4*i .. 4*i + 4`).
fn bytes(words: &[u32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(words.len() * 4);
    for word in words {
        out.extend_from_slice(&word.to_le_bytes());
    }
    out
}
/// Decodes a little-endian byte buffer into `u32` words.
///
/// Callers compare the decoded words against expected values. A buffer whose
/// length is not a whole number of words means the output size itself is
/// wrong, and dropping the tail would hide that.
///
/// # Panics
///
/// Panics if `bytes.len()` is not a multiple of four. The trailing bytes are
/// not silently discarded.
fn words(bytes: &[u8]) -> Vec<u32> {
    let chunks = bytes.chunks_exact(4);
    assert!(
        chunks.remainder().is_empty(),
        "byte buffer of length {} is not a whole number of u32 words",
        bytes.len()
    );
    chunks
        .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect()
}
/// Borrows the raw byte payload of an interpreter output value.
///
/// # Panics
///
/// Panics if `value` is not the `Value::Bytes` variant.
fn output_bytes(value: &Value) -> &[u8] {
    if let Value::Bytes(payload) = value {
        return payload;
    }
    panic!("reference output must be raw bytes");
}
#[test]
fn binop_i32_negative_adds_without_u32_coercion() {
    // NOTE(review): -1 + 1 yields the same bit pattern under wrapping u32
    // arithmetic. This test therefore guards against the interpreter
    // rejecting signed operands, not against a wrong numeric result.
    let sum = Expr::add(Expr::i32(-1), Expr::i32(1));
    let out_decl = BufferDecl::read_write("out", 0, DataType::I32);
    let program = Program::new(
        vec![out_decl],
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), sum)],
    );
    let initial = [Value::Bytes(vec![0; 4])];
    let outputs =
        interp::run(&program, &initial).expect("i32 addition must not coerce through u32");
    let expected = 0i32.to_le_bytes();
    assert_eq!(output_bytes(&outputs[0]), expected);
}
#[test]
fn binop_u64_preserves_upper_word() {
    // The left operand lives entirely in the upper 32 bits, so a 32-bit
    // truncation anywhere in the add would lose it.
    let lhs = 0x1_0000_0000u64;
    let rhs = 1u64;
    let input: Vec<u8> = [lhs, rhs].iter().flat_map(|v| v.to_le_bytes()).collect();
    let total = Expr::add(
        Expr::load("input", Expr::u32(0)),
        Expr::load("input", Expr::u32(1)),
    );
    let program = Program::new(
        vec![
            BufferDecl::read("input", 0, DataType::U64),
            BufferDecl::read_write("out", 1, DataType::U64),
        ],
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), total)],
    );
    let buffers = [Value::Bytes(input), Value::Bytes(vec![0; 8])];
    let outputs =
        interp::run(&program, &buffers).expect("u64 addition must preserve all 64 bits");
    assert_eq!(output_bytes(&outputs[0]), (lhs + rhs).to_le_bytes());
}
#[test]
fn unop_negate_i32_min_wraps_without_panic() {
    // -(i32::MIN) overflows i32. Under wrapping semantics it stays i32::MIN.
    let negated = Expr::negate(Expr::i32(i32::MIN));
    let program = Program::new(
        vec![BufferDecl::read_write("out", 0, DataType::I32)],
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), negated)],
    );
    let initial = [Value::Bytes(vec![0; 4])];
    let outputs = interp::run(&program, &initial)
        .expect("i32::MIN negation must use wrapping semantics");
    let expected = i32::MIN.wrapping_neg().to_le_bytes();
    assert_eq!(output_bytes(&outputs[0]), expected);
}
#[test]
fn executes_vector_add_with_oob_invocations_ignored_by_program_guard() {
    let lhs = [1, 2, 3, 4, 5, 6];
    let rhs = [10, 20, 30, 40, 50, 60];
    let i = Expr::var("idx");
    // Only invocations whose index is below BufLen("out") touch memory.
    let in_bounds = Expr::lt(i.clone(), Expr::buf_len("out"));
    let sum = Expr::add(Expr::load("a", i.clone()), Expr::load("b", i.clone()));
    let program = Program::new(
        vec![
            BufferDecl::read("a", 0, DataType::U32),
            BufferDecl::read("b", 1, DataType::U32),
            BufferDecl::read_write("out", 2, DataType::U32),
        ],
        [4, 1, 1],
        vec![
            Node::let_bind("idx", Expr::gid_x()),
            Node::if_then(in_bounds, vec![Node::store("out", i, sum)]),
        ],
    );
    let buffers = [
        Value::Bytes(bytes(&lhs)),
        Value::Bytes(bytes(&rhs)),
        Value::Bytes(bytes(&[0; 6])),
    ];
    let outputs = interp::run(&program, &buffers).expect("interpreter should execute vector add");
    let expected: Vec<u32> = lhs.iter().zip(&rhs).map(|(x, y)| x + y).collect();
    assert_eq!(words(output_bytes(&outputs[0])), expected);
}
#[test]
fn enforces_oob_load_store_and_atomic_contract() {
    // Index 99 is far past both buffers (input: 1 word, out: 2 words).
    let far = || Expr::u32(99);
    let oob_atomic = Expr::Atomic {
        op: AtomicOp::Add,
        buffer: "out".to_string(),
        index: Box::new(far()),
        expected: None,
        value: Box::new(Expr::u32(1)),
    };
    let body = vec![
        // The OOB load result overwrites out[0] (initially 55).
        Node::store("out", Expr::u32(0), Expr::load("input", far())),
        // The OOB store targets a nonexistent slot.
        Node::store("out", far(), Expr::u32(123)),
        // The OOB atomic's return value overwrites out[1] (initially 66).
        Node::store("out", Expr::u32(1), oob_atomic),
    ];
    let program = Program::new(
        vec![
            BufferDecl::read("input", 0, DataType::U32),
            BufferDecl::read_write("out", 1, DataType::U32),
        ],
        [1, 1, 1],
        body,
    );
    let buffers = [Value::Bytes(bytes(&[7])), Value::Bytes(bytes(&[55, 66]))];
    let outputs = interp::run(&program, &buffers).expect("OOB operations are defined");
    // Both OOB reads yield zero and the buffer keeps its two words.
    assert_eq!(words(output_bytes(&outputs[0])), vec![0, 0]);
}
#[test]
fn atomics_are_sequentially_consistent_in_round_robin_order() {
    // Each invocation records the pre-increment counter value at out[gid].
    let ticket = Expr::atomic_add("counter", Expr::u32(0), Expr::u32(1));
    let program = Program::new(
        vec![
            BufferDecl::read_write("counter", 0, DataType::U32),
            BufferDecl::read_write("out", 1, DataType::U32),
        ],
        [4, 1, 1],
        vec![Node::store("out", Expr::gid_x(), ticket)],
    );
    let counter = Value::Bytes(bytes(&[0]));
    let slots = Value::Bytes(bytes(&[0; 4]));
    let outputs = interp::run(&program, &[counter, slots]).expect("atomic add should execute");
    assert_eq!(words(output_bytes(&outputs[0])), vec![4]);
    assert_eq!(
        words(output_bytes(&outputs[1])),
        (0..4).collect::<Vec<u32>>()
    );
}
#[test]
fn barrier_synchronizes_workgroup_memory_per_workgroup() {
    let lane = Expr::LocalId { axis: 0 };
    let global = Expr::gid_x();
    // Each invocation reads the scratch slot written by its partner lane (lane ^ 1).
    let partner = Expr::BinOp {
        op: BinOp::BitXor,
        left: Box::new(lane.clone()),
        right: Box::new(Expr::u32(1)),
    };
    let decls = vec![
        BufferDecl::read_write("out", 0, DataType::U32),
        BufferDecl::workgroup("scratch", 2, DataType::U32),
    ];
    let body = vec![
        Node::store("scratch", lane, global.clone()),
        Node::Barrier,
        Node::store("out", global, Expr::load("scratch", partner)),
    ];
    let program = Program::new(decls, [2, 1, 1], body);
    let initial = [Value::Bytes(bytes(&[0; 4]))];
    let outputs = interp::run(&program, &initial)
        .expect("barrier should synchronize local workgroup memory");
    // Pairs swap their global ids within each workgroup.
    assert_eq!(words(output_bytes(&outputs[0])), vec![1, 0, 3, 2]);
}
#[test]
fn detects_non_uniform_barrier_condition() {
    // Only lane 0 enters the branch that contains the barrier.
    let only_lane_zero = Expr::eq(Expr::LocalId { axis: 0 }, Expr::u32(0));
    let divergent = Node::if_then(
        only_lane_zero,
        vec![Node::Barrier, Node::store("out", Expr::u32(0), Expr::u32(1))],
    );
    let program = Program::new(
        vec![BufferDecl::read_write("out", 0, DataType::U32)],
        [2, 1, 1],
        vec![divergent],
    );
    let initial = [Value::Bytes(bytes(&[0; 2]))];
    let message = interp::run(&program, &initial)
        .expect_err("divergent barrier must fail")
        .to_string();
    for needle in ["program violates uniform-control-flow rule", "Fix:"] {
        assert!(message.contains(needle));
    }
}
#[test]
fn completed_invocations_do_not_poison_later_barrier_uniformity() {
    let lane = || Expr::LocalId { axis: 0 };
    // Lane 0 skips the branch and finishes. Lane 1 then reaches the barrier.
    let guarded = Node::if_then(
        Expr::eq(lane(), Expr::u32(1)),
        vec![Node::Barrier, Node::store("out", lane(), Expr::u32(9))],
    );
    let program = Program::new(
        vec![BufferDecl::read_write("out", 0, DataType::U32)],
        [2, 1, 1],
        vec![guarded],
    );
    let initial = [Value::Bytes(bytes(&[0, 0]))];
    let outputs = interp::run(&program, &initial)
        .expect("finished invocation checks must not poison the live barrier");
    assert_eq!(words(output_bytes(&outputs[0])), vec![0, 9]);
}
#[test]
fn loop_iteration_uniform_checks_are_cleared_at_barriers() {
    // Both arms contain a barrier, and which arm runs depends on the loop variable.
    let is_first_pass = Expr::eq(Expr::var("i"), Expr::u32(0));
    let per_iteration = vec![Node::if_then_else(
        is_first_pass,
        vec![Node::Barrier],
        vec![Node::Barrier],
    )];
    let program = Program::new(
        vec![BufferDecl::read_write("out", 0, DataType::U32)],
        [1, 1, 1],
        vec![
            Node::loop_for("i", Expr::u32(0), Expr::u32(2), per_iteration),
            Node::store("out", Expr::u32(0), Expr::u32(7)),
        ],
    );
    let initial = [Value::Bytes(bytes(&[0]))];
    let outputs = interp::run(&program, &initial)
        .expect("uniform loop barriers must not retain prior iteration checks");
    assert_eq!(words(output_bytes(&outputs[0])), vec![7]);
}
#[test]
fn cast_to_u64_preserves_upper_word() {
    // Distinct, non-zero upper and lower halves expose any 32-bit truncation.
    let payload = 0x1234_5678_9ABC_DEF0u64.to_le_bytes();
    let widened = Expr::Cast {
        target: DataType::U64,
        value: Box::new(Expr::load("input", Expr::u32(0))),
    };
    let program = Program::new(
        vec![
            BufferDecl::read("input", 0, DataType::U64),
            BufferDecl::read_write("out", 1, DataType::U64),
        ],
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), widened)],
    );
    let buffers = [Value::Bytes(payload.to_vec()), Value::Bytes(vec![0; 8])];
    let outputs =
        interp::run(&program, &buffers).expect("u64 cast should preserve the full payload");
    assert_eq!(output_bytes(&outputs[0]), &payload[..]);
}
#[test]
fn cast_to_vec2_preserves_all_lanes() {
    let expected = bytes(&[0x1111_2222, 0x3333_4444]);
    let as_vec2 = Expr::Cast {
        target: DataType::Vec2U32,
        value: Box::new(Expr::load("input", Expr::u32(0))),
    };
    let decls = vec![
        BufferDecl::read("input", 0, DataType::Vec2U32),
        BufferDecl::read_write("out", 1, DataType::Vec2U32),
    ];
    let program = Program::new(
        decls,
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), as_vec2)],
    );
    let buffers = [Value::Bytes(expected.clone()), Value::Bytes(vec![0; 8])];
    let outputs = interp::run(&program, &buffers).expect("vec2 cast should preserve both lanes");
    assert_eq!(output_bytes(&outputs[0]), expected.as_slice());
}
#[test]
fn cast_to_vec4_preserves_all_lanes() {
    let expected = bytes(&[1, 2, 3, 4]);
    let as_vec4 = Expr::Cast {
        target: DataType::Vec4U32,
        value: Box::new(Expr::load("input", Expr::u32(0))),
    };
    let decls = vec![
        BufferDecl::read("input", 0, DataType::Vec4U32),
        BufferDecl::read_write("out", 1, DataType::Vec4U32),
    ];
    let program = Program::new(
        decls,
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), as_vec4)],
    );
    let buffers = [Value::Bytes(expected.clone()), Value::Bytes(vec![0; 16])];
    let outputs = interp::run(&program, &buffers).expect("vec4 cast should preserve all lanes");
    assert_eq!(output_bytes(&outputs[0]), expected.as_slice());
}
#[test]
fn eval_call_passes_full_bytes_to_cpu_reference() {
    // The call must see all six input bytes ("414243") to produce "ABC".
    let hex_call = Expr::Call {
        op_id: "primitive.encoding.hex".to_string(),
        args: vec![Expr::load("input", Expr::u32(0))],
    };
    let program = Program::new(
        vec![
            BufferDecl::read("input", 0, DataType::Bytes),
            BufferDecl::read_write("out", 1, DataType::Bytes),
        ],
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), hex_call)],
    );
    let encoded = b"414243".to_vec();
    let sentinel = vec![0xAA; 3];
    let outputs = interp::run(&program, &[Value::Bytes(encoded), Value::Bytes(sentinel)])
        .expect("primitive.encoding.hex call should receive the full byte buffer");
    assert_eq!(output_bytes(&outputs[0]), b"ABC");
}
#[test]
fn atomic_index_uses_declared_element_stride() {
    // Element 0 must stay untouched. Element 1 goes from 5 to 6 only if index 1
    // is scaled by the 8-byte U64 stride.
    let untouched = 0xAAAA_BBBB_CCCC_DDDDu64;
    let encode = |elems: [u64; 2]| -> Vec<u8> {
        elems.iter().flat_map(|e| e.to_le_bytes()).collect()
    };
    let input = encode([untouched, 5]);
    let expected = encode([untouched, 6]);
    let bump = Expr::Atomic {
        op: AtomicOp::Add,
        buffer: "wide".to_string(),
        index: Box::new(Expr::u32(1)),
        expected: None,
        value: Box::new(Expr::u32(1)),
    };
    let program = Program::new(
        vec![BufferDecl::read_write("wide", 0, DataType::U64)],
        [1, 1, 1],
        vec![Node::let_bind("old", bump)],
    );
    let outputs = interp::run(&program, &[Value::Bytes(input)])
        .expect("atomic add on second u64 element should execute");
    assert_eq!(output_bytes(&outputs[0]), expected.as_slice());
}
#[test]
fn buflen_counts_trailing_partial_element() {
    // Five bytes: one whole u32 followed by a one-byte tail.
    let five_bytes = vec![1, 2, 3, 4, 5];
    let body = vec![
        Node::store("out", Expr::u32(0), Expr::buf_len("input")),
        Node::store("out", Expr::u32(1), Expr::load("input", Expr::u32(1))),
    ];
    let program = Program::new(
        vec![
            BufferDecl::read("input", 0, DataType::U32),
            BufferDecl::read_write("out", 1, DataType::U32),
        ],
        [1, 1, 1],
        body,
    );
    let buffers = [Value::Bytes(five_bytes), Value::Bytes(bytes(&[0, 0]))];
    let outputs = interp::run(&program, &buffers)
        .expect("partial trailing element should count toward BufLen");
    // BufLen reports 2 elements, yet loading the partial element at index 1 yields 0.
    assert_eq!(words(output_bytes(&outputs[0])), vec![2, 0]);
}