use vyre::error::Error;
use vyre::ir::{
inline_calls, inline_calls_with_resolver, BufferDecl, DataType, Expr, Node, Program,
};
use vyre::lower::wgsl::lower;
/// Builds an `Expr::Call` expression that invokes the operation `op_id`
/// with the supplied argument expressions.
fn call(op_id: &str, args: Vec<Expr>) -> Expr {
    let op_id = op_id.to_owned();
    Expr::Call { op_id, args }
}
/// Resolver used by the tests: maps a few synthetic operation ids to callee programs.
///
/// * `"test.add"` — the two-input program from `binary_program(Expr::add)`.
/// * `"test.double"` — a one-input program whose stored value is a call to
///   `"test.add"` with the same load of `a[0]` passed twice.
/// * `"test.cycle"` — a program whose stored value is a call to `"test.cycle"` itself.
///
/// Every other id resolves to `None`.
fn custom_resolver(op_id: &str) -> Option<Program> {
    let program = match op_id {
        "test.add" => binary_program(Expr::add),
        "test.double" => {
            let buffers = vec![
                BufferDecl::read("a", 0, DataType::U32),
                BufferDecl::output("out", 1, DataType::U32),
            ];
            // Both arguments of the nested call read the same element.
            let operand = || Expr::load("a", Expr::u32(0));
            let doubled = call("test.add", vec![operand(), operand()]);
            Program::new(
                buffers,
                [1, 1, 1],
                vec![Node::store("out", Expr::u32(0), doubled)],
            )
        }
        "test.cycle" => {
            let buffers = vec![BufferDecl::output("out", 0, DataType::U32)];
            // The body refers back to the op being resolved.
            let self_call = call("test.cycle", Vec::new());
            Program::new(
                buffers,
                [1, 1, 1],
                vec![Node::store("out", Expr::u32(0), self_call)],
            )
        }
        _ => return None,
    };
    Some(program)
}
/// Builds a program with read buffers `a` and `b` and output buffer `out`
/// that stores `compute(a[0], b[0])` into `out[0]`.
fn binary_program(compute: fn(Expr, Expr) -> Expr) -> Program {
    let buffers = vec![
        BufferDecl::read("a", 0, DataType::U32),
        BufferDecl::read("b", 1, DataType::U32),
        BufferDecl::output("out", 2, DataType::U32),
    ];
    let lhs = Expr::load("a", Expr::u32(0));
    let rhs = Expr::load("b", Expr::u32(0));
    let body = vec![Node::store("out", Expr::u32(0), compute(lhs, rhs))];
    Program::new(buffers, [1, 1, 1], body)
}
/// Builds the calling program used by the tests: a single read-write buffer
/// `out` whose element 0 receives `value`.
fn caller(value: Expr) -> Program {
    let buffers = vec![BufferDecl::read_write("out", 0, DataType::U32)];
    let body = vec![Node::store("out", Expr::u32(0), value)];
    Program::new(buffers, [1, 1, 1], body)
}
/// Inlining a single `test.add` call must leave no calls behind, route the
/// final store through the `__vyre_inline_0_result` variable, and assign the
/// substituted sum `7 + 35` somewhere in the entry.
#[test]
fn single_level_call_inlines_to_hand_written_composition() {
    let add_call = call("test.add", vec![Expr::u32(7), Expr::u32(35)]);
    let program = caller(add_call);

    let inlined = inline_calls_with_resolver(&program, custom_resolver).expect("inline add");
    assert_no_calls(inlined.entry());

    let result_var = Expr::var("__vyre_inline_0_result");
    assert_eq!(last_store_value(&inlined), Some(&result_var));

    let expected_sum = Expr::add(Expr::u32(7), Expr::u32(35));
    assert!(contains_assign_value(inlined.entry(), &expected_sum));
}
/// `test.double` calls `test.add` internally; after inlining there must be no
/// residual call and the entry must assign `21 + 21`.
#[test]
fn nested_call_inlines_recursively() {
    let double_call = call("test.double", vec![Expr::u32(21)]);
    let program = caller(double_call);

    let inlined = inline_calls_with_resolver(&program, custom_resolver).expect("inline nested");
    assert_no_calls(inlined.entry());

    let expected_sum = Expr::add(Expr::u32(21), Expr::u32(21));
    assert!(contains_assign_value(inlined.entry(), &expected_sum));
}
/// A callee written as a vector kernel (workgroup size `[64, 1, 1]`, indexed by
/// `gid_x`, with its store guarded by a `buf_len` check) must still inline:
/// no calls remain and the entry assigns the sum of the two call arguments.
#[test]
fn vector_primitive_call_inlines_as_scalar_expression() {
    /// Resolves only `"test.vector_add"`; every other id yields `None`.
    fn vector_add_resolver(op_id: &str) -> Option<Program> {
        if op_id != "test.vector_add" {
            return None;
        }

        let buffers = vec![
            BufferDecl::read("a", 0, DataType::U32),
            BufferDecl::read("b", 1, DataType::U32),
            BufferDecl::output("out", 2, DataType::U32),
        ];
        let idx = || Expr::var("idx");
        // out[idx] = a[idx] + b[idx], executed only when idx < len(out).
        let guarded_store = Node::if_then(
            Expr::lt(idx(), Expr::buf_len("out")),
            vec![Node::store(
                "out",
                idx(),
                Expr::add(Expr::load("a", idx()), Expr::load("b", idx())),
            )],
        );
        let body = vec![Node::let_bind("idx", Expr::gid_x()), guarded_store];
        Some(Program::new(buffers, [64, 1, 1], body))
    }

    let vector_call = call("test.vector_add", vec![Expr::u32(7), Expr::u32(35)]);
    let program = caller(vector_call);

    let inlined = inline_calls_with_resolver(&program, vector_add_resolver)
        .expect("vector primitive call must inline as a scalar expression");
    assert_no_calls(inlined.entry());

    let expected_sum = Expr::add(Expr::u32(7), Expr::u32(35));
    assert!(contains_assign_value(inlined.entry(), &expected_sum));
}
/// A self-referential op must be rejected with `Error::InlineCycle` naming
/// the offending op id.
#[test]
fn cycle_returns_inline_cycle() {
    let cyclic_call = call("test.cycle", Vec::new());
    let program = caller(cyclic_call);

    let err = inline_calls_with_resolver(&program, custom_resolver).expect_err("cycle must fail");

    let reports_cycle = matches!(err, Error::InlineCycle { op_id } if op_id == "test.cycle");
    assert!(reports_cycle);
}
/// An op id the resolver does not know must produce `Error::InlineUnknownOp`
/// carrying that id, and the rendered error must include a "Fix: register" hint.
#[test]
fn unknown_op_returns_inline_unknown_op() {
    let missing_call = call("test.missing", Vec::new());
    let program = caller(missing_call);

    let err = inline_calls_with_resolver(&program, custom_resolver).expect_err("unknown must fail");

    // Match through a reference so `err` stays usable for the message check below.
    let names_missing_op =
        matches!(&err, Error::InlineUnknownOp { op_id } if op_id == "test.missing");
    assert!(names_missing_op);

    let rendered = err.to_string();
    assert!(rendered.contains("Fix: register"));
}
/// A program calling `hash.fnv1a32` must inline through the default
/// `inline_calls` entry point, and the lowered WGSL must not mention the
/// operation as a call symbol, by its dotted id, or via a `primitive_` prefix.
#[test]
fn lowered_hash_fnv1a_composition_has_no_operation_call_symbol() {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::Bytes),
        BufferDecl::read_write("out", 1, DataType::U32),
    ];
    let hash_call = call("hash.fnv1a32", Vec::new());
    let body = vec![Node::store("out", Expr::u32(0), hash_call)];
    let program = Program::new(buffers, [1, 1, 1], body);

    let inlined = inline_calls(&program).expect("hash op must inline");
    assert_no_calls(inlined.entry());

    // NOTE(review): this lowers the original `program`, not `inlined`. Presumably
    // `lower` performs inlining itself and that is what is being exercised here;
    // confirm against the lowering implementation.
    let wgsl = lower(&program).expect("inlined hash call must lower");

    let has_call_symbol = wgsl.contains("hash_fnv1a32(");
    assert!(!has_call_symbol);
    let has_dotted_id = wgsl.contains("hash.fnv1a32");
    assert!(!has_dotted_id);
    let has_primitive_prefix = wgsl.contains("primitive_");
    assert!(!has_primitive_prefix);
}
/// Returns the value expression of the last `Node::Store` found directly in
/// the program's entry node list, or `None` when there is no such store.
///
/// Only the top-level nodes are scanned; nested bodies are not searched.
fn last_store_value(program: &Program) -> Option<&Expr> {
    for node in program.entry().iter().rev() {
        if let Node::Store { value, .. } = node {
            return Some(value);
        }
    }
    None
}
/// Reports whether any `Node::Assign` in `nodes` has a value equal to
/// `expected`, searching recursively through `If` branches, `Loop` bodies and
/// `Block`s.
fn contains_assign_value(nodes: &[Node], expected: &Expr) -> bool {
    for node in nodes {
        let found = match node {
            Node::Assign { value, .. } => value == expected,
            Node::If {
                then, otherwise, ..
            } => {
                contains_assign_value(then, expected) || contains_assign_value(otherwise, expected)
            }
            Node::Loop { body, .. } => contains_assign_value(body, expected),
            Node::Block(body) => contains_assign_value(body, expected),
            _ => false,
        };
        if found {
            return true;
        }
    }
    false
}
/// Walks every node in `nodes`, including nested bodies, and panics through
/// `assert_expr_no_call` if any contained expression is still a call.
///
/// The match is intentionally exhaustive (no wildcard arm).
fn assert_no_calls(nodes: &[Node]) {
    nodes.iter().for_each(|node| match node {
        // Nodes that carry no expressions.
        Node::Return | Node::Barrier => {}
        Node::Block(body) => assert_no_calls(body),
        Node::Let { value, .. } => assert_expr_no_call(value),
        Node::Assign { value, .. } => assert_expr_no_call(value),
        Node::Store { index, value, .. } => {
            assert_expr_no_call(index);
            assert_expr_no_call(value);
        }
        Node::Loop { from, to, body, .. } => {
            assert_expr_no_call(from);
            assert_expr_no_call(to);
            assert_no_calls(body);
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            assert_expr_no_call(cond);
            assert_no_calls(then);
            assert_no_calls(otherwise);
        }
    });
}
/// Recursively checks that `expr` contains no `Expr::Call`.
///
/// # Panics
///
/// Panics with `residual call <op_id>` on the first call found, and with an
/// "unexpected expression variant" message for any variant not listed here.
fn assert_expr_no_call(expr: &Expr) {
    match expr {
        // Leaf expressions: nothing to recurse into.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. } => {}
        Expr::Call { op_id, .. } => panic!("residual call {op_id}"),
        // Single-child expressions.
        Expr::Load { index, .. } => assert_expr_no_call(index),
        Expr::UnOp { operand, .. } => assert_expr_no_call(operand),
        Expr::Cast { value, .. } => assert_expr_no_call(value),
        // Multi-child expressions, visited in field order.
        Expr::BinOp { left, right, .. } => {
            assert_expr_no_call(left);
            assert_expr_no_call(right);
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            assert_expr_no_call(cond);
            assert_expr_no_call(true_val);
            assert_expr_no_call(false_val);
        }
        Expr::Atomic {
            index,
            expected,
            value,
            ..
        } => {
            assert_expr_no_call(index);
            // `expected` is optional and is only checked when present.
            if let Some(comparand) = expected {
                assert_expr_no_call(comparand);
            }
            assert_expr_no_call(value);
        }
        unexpected => panic!("unexpected expression variant after inlining: {unexpected:?}"),
    }
}