use vyre::ir::{BufferDecl, DataType, Expr, Node, Program};
use vyre::validate;
fn main() {
let unknown_buffer_program = Program::new(
vec![
BufferDecl::read("a", 0, DataType::U32),
BufferDecl::read_write("out", 1, DataType::U32),
],
[64, 1, 1],
vec![
Node::let_bind("idx", Expr::gid_x()),
Node::if_then(
Expr::lt(Expr::var("idx"), Expr::buf_len("out")),
vec![Node::store(
"out",
Expr::var("idx"),
Expr::load("ghost", Expr::var("idx")),
)],
),
],
);
let errors = validate(&unknown_buffer_program);
println!("Errors for unknown buffer load:");
for error in &errors {
println!(" - {}", error.message());
}
assert!(
!errors.is_empty(),
"expected at least one validation error for an unknown buffer"
);
let first = errors[0].message();
assert!(
first.contains("load from unknown buffer `ghost`"),
"expected 'load from unknown buffer' error, got: {first}"
);
assert!(
first.contains("Fix: declare it in Program::buffers."),
"expected a 'Fix:' suggestion, got: {first}"
);
let type_mismatch_program = Program::new(
vec![
BufferDecl::read("a", 0, DataType::U32),
BufferDecl::read("b", 1, DataType::Bool),
BufferDecl::read_write("out", 2, DataType::U32),
],
[64, 1, 1],
vec![
Node::let_bind("idx", Expr::gid_x()),
Node::if_then(
Expr::lt(Expr::var("idx"), Expr::buf_len("out")),
vec![Node::store(
"out",
Expr::var("idx"),
Expr::bitxor(
Expr::load("a", Expr::var("idx")),
Expr::load("b", Expr::var("idx")),
),
)],
),
],
);
let errors = validate(&type_mismatch_program);
println!("\nErrors for type mismatch:");
for error in &errors {
println!(" - {}", error.message());
}
assert!(
errors.iter().any(|e| {
let m = e.message();
m.contains("binary operation") && m.contains("got `bool`")
}),
"expected a type mismatch error about `bool`, got: {errors:?}"
);
let lower_result = vyre::lower::wgsl::lower(&type_mismatch_program);
match lower_result {
Ok(wgsl) => {
println!(
"\nFINDING-EXAMPLES-003: lower() succeeded on invalid input ({} bytes). \
The lowerer should return Err for invalid programs but currently does not.",
wgsl.len()
);
}
Err(error) => {
println!("\nLower refused invalid program: {error}");
}
}
println!("\nAll error-message assertions passed.");
}