use vyre::ir::{self, BinOp, BufferAccess, DataType, Expr, Node, Program, UnOp};
use vyre::ops::primitive::float::f32_cos::F32Cos;
use vyre::ops::primitive::float::f32_sqrt::F32Sqrt;
use vyre::ops::primitive::{add, and, compare::eq, mul, not, or, popcount, shl, sub, xor};
/// Asserts the buffer layout shared by two-input primitives, then the common
/// entry-point contract.
///
/// Expected layout, in binding order:
/// read-only `a` at binding 0, read-only `b` at binding 1, read-write `out` at binding 2.
fn assert_binary_contract(program: &Program) {
    let expected = [
        ("a", 0, BufferAccess::ReadOnly),
        ("b", 1, BufferAccess::ReadOnly),
        ("out", 2, BufferAccess::ReadWrite),
    ];
    assert_eq!(program.buffers.len(), expected.len());
    // Same per-buffer assertion order as a hand-unrolled check: name, binding, access.
    for (buffer, (name, binding, access)) in program.buffers.iter().zip(expected) {
        assert_eq!(buffer.name, name);
        assert_eq!(buffer.binding, binding);
        assert_eq!(buffer.access, access);
    }
    assert_entry_contract(program);
}
/// Asserts the buffer layout of a single-input `f32` primitive, then the common
/// entry-point contract.
///
/// Expected layout: read-only `a` at binding 0 and read-write `out` at binding 1.
/// Both buffers must have element type `DataType::F32`.
fn assert_unary_f32_contract(program: &Program) {
    let expected = [
        ("a", 0, BufferAccess::ReadOnly),
        ("out", 1, BufferAccess::ReadWrite),
    ];
    assert_eq!(program.buffers.len(), expected.len());
    for (buffer, (name, binding, access)) in program.buffers.iter().zip(expected) {
        assert_eq!(buffer.name, name);
        assert_eq!(buffer.binding, binding);
        assert_eq!(buffer.access, access);
        assert_eq!(buffer.element, DataType::F32);
    }
    assert_entry_contract(program);
}
/// Asserts the buffer layout of a single-input primitive, then the common
/// entry-point contract.
///
/// Expected layout: read-only `a` at binding 0 and read-write `out` at binding 1.
/// Element types are not checked here.
fn assert_unary_contract(program: &Program) {
    assert_eq!(program.buffers.len(), 2);

    let input = &program.buffers[0];
    assert_eq!(input.name, "a");
    assert_eq!(input.binding, 0);
    assert_eq!(input.access, BufferAccess::ReadOnly);

    let output = &program.buffers[1];
    assert_eq!(output.name, "out");
    assert_eq!(output.binding, 1);
    assert_eq!(output.access, BufferAccess::ReadWrite);

    assert_entry_contract(program);
}
/// Asserts the entry-point shape shared by every primitive and returns the
/// program lowered to WGSL.
///
/// Checks, in order:
/// - the workgroup size is `[64, 1, 1]`;
/// - `ir::validate` reports no errors;
/// - WGSL lowering succeeds;
/// - the entry body has exactly two nodes:
///   - `let idx = InvocationId { axis: 0 }`;
///   - an `If` whose condition is a `BinOp::Lt`, with a one-node `then` branch
///     and an empty `otherwise` branch.
///
/// # Panics
/// Panics if any of the checks above fails.
fn assert_entry_contract(program: &Program) -> String {
    assert_eq!(program.workgroup_size, [64, 1, 1]);

    let errors = ir::validate(program);
    assert!(errors.is_empty(), "{errors:?}");

    let wgsl = vyre::lower::wgsl::lower(program).expect("primitive program must lower to WGSL");

    assert_eq!(program.entry.len(), 2);
    // Indexing is safe: the length was just pinned to 2.
    let (head, guard) = (&program.entry[0], &program.entry[1]);

    let binds_idx = matches!(
        head,
        Node::Let {
            name,
            value: Expr::InvocationId { axis: 0 },
        } if name == "idx"
    );
    assert!(binds_idx);

    let is_guarded_single_store = matches!(
        guard,
        Node::If {
            cond: Expr::BinOp { op: BinOp::Lt, .. },
            then,
            otherwise,
        } if then.len() == 1 && otherwise.is_empty()
    );
    assert!(is_guarded_single_store);

    wgsl
}
/// Returns the expression stored by the entry's second node.
///
/// The second node must be an `If` whose first `then` node is
/// `Store { buffer: "out", index: Var("idx"), value }`. This function returns
/// `value`.
///
/// # Panics
/// Panics in any of these cases:
/// - the second entry node is not an `If`;
/// - its first `then` node is not a `Store` indexed by a plain variable;
/// - the buffer is not `out` or the index variable is not `idx`.
fn stored_value(program: &Program) -> &Expr {
    let then = match &program.entry[1] {
        Node::If { then, .. } => then,
        _ => panic!("second entry node must be bounds check"),
    };
    match &then[0] {
        Node::Store {
            buffer,
            index: Expr::Var(index),
            value,
        } => {
            assert_eq!(buffer, "out");
            assert_eq!(index, "idx");
            value
        }
        _ => panic!("bounds check must store one output value"),
    }
}
/// Every two-input primitive must have three things:
/// - the `a`/`b`/`out` buffer layout;
/// - the shared entry-point contract;
/// - a store of its own `BinOp`.
///
/// Each case carries a primitive name. All cases share one assertion site, so
/// the name in the failure message identifies which primitive broke.
#[test]
fn binary_primitive_programs_have_buffer_entry_and_operator_contracts() {
    // (primitive name for diagnostics, program builder, operator the store must use)
    let cases: [(&str, fn() -> Program, BinOp); 8] = [
        ("xor", xor::Xor::program, BinOp::BitXor),
        ("add", add::Add::program, BinOp::Add),
        ("shl", shl::Shl::program, BinOp::Shl),
        ("or", or::Or::program, BinOp::BitOr),
        ("and", and::And::program, BinOp::BitAnd),
        ("sub", sub::Sub::program, BinOp::Sub),
        ("mul", mul::Mul::program, BinOp::Mul),
        ("eq", eq::Eq::program, BinOp::Eq),
    ];
    for (name, build, expected) in cases {
        let program = build();
        assert_binary_contract(&program);
        assert!(
            matches!(
                stored_value(&program),
                Expr::BinOp { op, .. } if *op == expected
            ),
            "{name} program must store the result of its own binary operator"
        );
    }
}
/// Every single-input integer primitive must have three things:
/// - the `a`/`out` buffer layout;
/// - the shared entry-point contract;
/// - a store of its own `UnOp`.
///
/// Each case carries a primitive name. All cases share one assertion site, so
/// the name in the failure message identifies which primitive broke.
#[test]
fn unary_primitive_programs_have_buffer_entry_and_operator_contracts() {
    // (primitive name for diagnostics, program builder, operator the store must use)
    let cases: [(&str, fn() -> Program, UnOp); 2] = [
        ("not", not::Not::program, UnOp::BitNot),
        ("popcount", popcount::Popcount::program, UnOp::Popcount),
    ];
    for (name, build, expected) in cases {
        let program = build();
        assert_unary_contract(&program);
        assert!(
            matches!(
                stored_value(&program),
                Expr::UnOp { op, .. } if *op == expected
            ),
            "{name} program must store the result of its own unary operator"
        );
    }
}
/// `Add` passes the shared entry contract, and its WGSL uses the `+` operator.
#[test]
fn add_program_validates_and_lowers_with_plus() {
    let lowered = assert_entry_contract(&add::Add::program());
    assert!(lowered.contains(" + "), "add WGSL must contain '+'");
}
/// `And` passes the shared entry contract, and its WGSL uses the `&` operator.
#[test]
fn and_program_validates_and_lowers_with_ampersand() {
    let lowered = assert_entry_contract(&and::And::program());
    assert!(lowered.contains(" & "), "and WGSL must contain '&'");
}
/// `Eq` passes the shared entry contract, and its WGSL uses the `==` operator.
#[test]
fn eq_program_validates_and_lowers_with_eqeq() {
    let lowered = assert_entry_contract(&eq::Eq::program());
    assert!(lowered.contains(" == "), "eq WGSL must contain '=='");
}
/// `Mul` passes the shared entry contract, and its WGSL uses the `*` operator.
#[test]
fn mul_program_validates_and_lowers_with_star() {
    let lowered = assert_entry_contract(&mul::Mul::program());
    assert!(lowered.contains(" * "), "mul WGSL must contain '*'");
}
/// `Not` passes the shared entry contract, and its WGSL contains a `~`.
#[test]
fn not_program_validates_and_lowers_with_tilde() {
    let lowered = assert_entry_contract(&not::Not::program());
    assert!(lowered.contains("~"), "not WGSL must contain '~'");
}
/// `Or` passes the shared entry contract, and its WGSL uses the `|` operator.
#[test]
fn or_program_validates_and_lowers_with_pipe() {
    let lowered = assert_entry_contract(&or::Or::program());
    assert!(lowered.contains(" | "), "or WGSL must contain '|'");
}
/// `Popcount` passes the shared entry contract, and its WGSL calls `countOneBits`.
#[test]
fn popcount_program_validates_and_lowers_with_count_one_bits() {
    let lowered = assert_entry_contract(&popcount::Popcount::program());
    let has_builtin = lowered.contains("countOneBits");
    assert!(has_builtin, "popcount WGSL must contain 'countOneBits'");
}
/// `Shl` passes the shared entry contract, and its WGSL uses the `<<` operator.
#[test]
fn shl_program_validates_and_lowers_with_shl() {
    let lowered = assert_entry_contract(&shl::Shl::program());
    assert!(lowered.contains(" << "), "shl WGSL must contain '<<'");
}
/// `Sub` passes the shared entry contract, and its WGSL uses the `-` operator.
#[test]
fn sub_program_validates_and_lowers_with_minus() {
    let lowered = assert_entry_contract(&sub::Sub::program());
    assert!(lowered.contains(" - "), "sub WGSL must contain '-'");
}
/// `Xor` passes the shared entry contract, and its WGSL uses the `^` operator.
#[test]
fn xor_program_validates_and_lowers_with_caret() {
    let lowered = assert_entry_contract(&xor::Xor::program());
    assert!(lowered.contains(" ^ "), "xor WGSL must contain '^'");
}
/// `F32Cos` has the `f32` single-input buffer layout and stores a `UnOp::Cos`.
#[test]
fn f32_cos_program_has_unary_f32_contract() {
    let program = F32Cos::program();
    assert_unary_f32_contract(&program);
    let stored = stored_value(&program);
    assert!(matches!(stored, Expr::UnOp { op: UnOp::Cos, .. }));
}
/// `F32Cos` passes the shared entry contract, and its WGSL calls `cos(`.
#[test]
fn f32_cos_program_validates_and_lowers_with_cos() {
    let lowered = assert_entry_contract(&F32Cos::program());
    assert!(lowered.contains("cos("), "f32_cos WGSL must contain 'cos('");
}
/// `F32Sqrt` passes the shared entry contract, and its WGSL calls `sqrt(`.
#[test]
fn f32_sqrt_program_validates_and_lowers_with_sqrt() {
    let lowered = assert_entry_contract(&F32Sqrt::program());
    assert!(lowered.contains("sqrt("), "f32_sqrt WGSL must contain 'sqrt('");
}