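// vyre 0.4.0 — GPU compute intermediate representation with a standard operation library.
//
// Integration tests for `Expr::Call` inlining (`inline_calls`, `inline_calls_with_resolver`)
// and WGSL lowering.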
use vyre::error::Error;
use vyre::ir::{
    inline_calls, inline_calls_with_resolver, BufferDecl, DataType, Expr, Node, Program,
};
use vyre::lower::wgsl::lower;

/// Builds an `Expr::Call` to the operation `op_id` with the given argument expressions.
fn call(op_id: &str, args: Vec<Expr>) -> Expr {
    let op_id = String::from(op_id);
    Expr::Call { op_id, args }
}

/// Test resolver mapping a handful of synthetic op ids to their defining programs.
///
/// Returns `None` for any op id it does not know, which the inliner reports as an
/// unknown operation.
fn custom_resolver(op_id: &str) -> Option<Program> {
    let resolved = match op_id {
        // out[0] = a[0] + b[0]
        "test.add" => binary_program(Expr::add),
        // out[0] = test.add(a[0], a[0]); the body itself contains a call, so inlining
        // it exercises the recursive path.
        "test.double" => {
            let load_a = || Expr::load("a", Expr::u32(0));
            Program::new(
                vec![
                    BufferDecl::read("a", 0, DataType::U32),
                    BufferDecl::output("out", 1, DataType::U32),
                ],
                [1, 1, 1],
                vec![Node::store(
                    "out",
                    Expr::u32(0),
                    call("test.add", vec![load_a(), load_a()]),
                )],
            )
        }
        // Calls itself, so inlining it can never terminate on its own.
        "test.cycle" => Program::new(
            vec![BufferDecl::output("out", 0, DataType::U32)],
            [1, 1, 1],
            vec![Node::store(
                "out",
                Expr::u32(0),
                call("test.cycle", Vec::new()),
            )],
        ),
        _ => return None,
    };
    Some(resolved)
}

/// Builds a single-invocation program that stores `compute(a[0], b[0])` into `out[0]`.
///
/// Buffers: `a` (read, binding 0), `b` (read, binding 1), `out` (output, binding 2), all `U32`.
fn binary_program(compute: fn(Expr, Expr) -> Expr) -> Program {
    let buffers = vec![
        BufferDecl::read("a", 0, DataType::U32),
        BufferDecl::read("b", 1, DataType::U32),
        BufferDecl::output("out", 2, DataType::U32),
    ];
    let lhs = Expr::load("a", Expr::u32(0));
    let rhs = Expr::load("b", Expr::u32(0));
    let body = vec![Node::store("out", Expr::u32(0), compute(lhs, rhs))];
    Program::new(buffers, [1, 1, 1], body)
}

/// Wraps `value` in a minimal program that writes it to `out[0]` (a read-write `U32` buffer).
fn caller(value: Expr) -> Program {
    let out_buffer = BufferDecl::read_write("out", 0, DataType::U32);
    let store_first = Node::store("out", Expr::u32(0), value);
    Program::new(vec![out_buffer], [1, 1, 1], vec![store_first])
}

#[test]
fn single_level_call_inlines_to_hand_written_composition() {
    // After inlining `test.add(7, 35)`, the store should read the inliner's result variable,
    // and that variable should have been assigned the literal sum expression.
    let program = caller(call("test.add", vec![Expr::u32(7), Expr::u32(35)]));
    let inlined = inline_calls_with_resolver(&program, custom_resolver).expect("inline add");
    let result_var = Expr::var("__vyre_inline_0_result");
    let expected_sum = Expr::add(Expr::u32(7), Expr::u32(35));

    assert_no_calls(inlined.entry());
    assert_eq!(last_store_value(&inlined), Some(&result_var));
    assert!(contains_assign_value(inlined.entry(), &expected_sum));
}

#[test]
fn nested_call_inlines_recursively() {
    // `test.double` itself calls `test.add`; both levels must disappear.
    let doubled = caller(call("test.double", vec![Expr::u32(21)]));
    let inlined =
        inline_calls_with_resolver(&doubled, custom_resolver).expect("inline nested");
    let twenty_one_twice = Expr::add(Expr::u32(21), Expr::u32(21));

    assert_no_calls(inlined.entry());
    assert!(contains_assign_value(inlined.entry(), &twenty_one_twice));
}

#[test]
fn vector_primitive_call_inlines_as_scalar_expression() {
    /// Resolves only `test.vector_add`: a 64-wide, bounds-guarded element-wise
    /// `out[idx] = a[idx] + b[idx]` keyed on `gid_x`.
    fn resolve_vector_add(op_id: &str) -> Option<Program> {
        if op_id != "test.vector_add" {
            return None;
        }
        let idx = || Expr::var("idx");
        let element_sum = Expr::add(Expr::load("a", idx()), Expr::load("b", idx()));
        let guarded_store = Node::if_then(
            Expr::lt(idx(), Expr::buf_len("out")),
            vec![Node::store("out", idx(), element_sum)],
        );
        Some(Program::new(
            vec![
                BufferDecl::read("a", 0, DataType::U32),
                BufferDecl::read("b", 1, DataType::U32),
                BufferDecl::output("out", 2, DataType::U32),
            ],
            [64, 1, 1],
            vec![Node::let_bind("idx", Expr::gid_x()), guarded_store],
        ))
    }

    let program = caller(call("test.vector_add", vec![Expr::u32(7), Expr::u32(35)]));
    let inlined = inline_calls_with_resolver(&program, resolve_vector_add)
        .expect("vector primitive call must inline as a scalar expression");
    let scalar_sum = Expr::add(Expr::u32(7), Expr::u32(35));

    assert_no_calls(inlined.entry());
    assert!(contains_assign_value(inlined.entry(), &scalar_sum));
}

/// A self-recursive operation must be rejected with `Error::InlineCycle` naming the op.
#[test]
fn cycle_returns_inline_cycle() {
    let program = caller(call("test.cycle", Vec::new()));
    let err = inline_calls_with_resolver(&program, custom_resolver).expect_err("cycle must fail");
    // Bind `op_id` by reference so `matches!` does not consume `err`. The error can then be
    // printed in the failure message instead of producing a bare "assertion failed".
    assert!(
        matches!(err, Error::InlineCycle { ref op_id } if op_id == "test.cycle"),
        "expected InlineCycle for test.cycle, got {err:?}"
    );
}

#[test]
fn unknown_op_returns_inline_unknown_op() {
    // An op id the resolver returns `None` for must surface as `InlineUnknownOp`,
    // and the rendered message must carry the actionable "Fix: register" hint.
    let missing = caller(call("test.missing", Vec::new()));
    let err = inline_calls_with_resolver(&missing, custom_resolver)
        .expect_err("unknown must fail");
    let names_missing_op =
        matches!(err, Error::InlineUnknownOp { ref op_id } if op_id == "test.missing");
    assert!(names_missing_op);
    let rendered = err.to_string();
    assert!(rendered.contains("Fix: register"));
}

/// Calling the built-in `hash.fnv1a32` op must inline completely. Lowering the result to
/// WGSL must leave no trace of an operation call.
#[test]
fn lowered_hash_fnv1a_composition_has_no_operation_call_symbol() {
    let program = Program::new(
        vec![
            BufferDecl::read("input", 0, DataType::Bytes),
            BufferDecl::read_write("out", 1, DataType::U32),
        ],
        [1, 1, 1],
        vec![Node::store(
            "out",
            Expr::u32(0),
            call("hash.fnv1a32", Vec::new()),
        )],
    );

    let inlined = inline_calls(&program).expect("hash op must inline");
    assert_no_calls(inlined.entry());

    // Lower the inliner's output itself; previously only the original `program` was lowered,
    // so the IR checked above was never fed to the WGSL backend. The original program is
    // still lowered too, so the end-to-end path stays covered.
    let lowered = [
        ("inlined", lower(&inlined).expect("inlined hash call must lower")),
        ("original", lower(&program).expect("hash call program must lower")),
    ];
    for (label, wgsl) in &lowered {
        assert!(
            !wgsl.contains("hash_fnv1a32("),
            "{label} WGSL still contains a hash_fnv1a32( call"
        );
        assert!(
            !wgsl.contains("hash.fnv1a32"),
            "{label} WGSL still references hash.fnv1a32"
        );
        assert!(
            !wgsl.contains("primitive_"),
            "{label} WGSL still references a primitive_ symbol"
        );
    }
}

/// Returns the value expression of the last top-level `Node::Store` in the program's entry,
/// or `None` if the entry has no top-level store. Nested bodies are not searched.
fn last_store_value(program: &Program) -> Option<&Expr> {
    for node in program.entry().iter().rev() {
        if let Node::Store { value, .. } = node {
            return Some(value);
        }
    }
    None
}

/// Reports whether any `Node::Assign` in `nodes` assigns exactly `expected`.
///
/// Recurses into `If` branches, `Loop` bodies and `Block`s, and stops at the first match.
fn contains_assign_value(nodes: &[Node], expected: &Expr) -> bool {
    for node in nodes {
        let found = match node {
            Node::Assign { value, .. } => value == expected,
            Node::If {
                then, otherwise, ..
            } => {
                contains_assign_value(then, expected) || contains_assign_value(otherwise, expected)
            }
            Node::Loop { body, .. } | Node::Block(body) => contains_assign_value(body, expected),
            _ => false,
        };
        if found {
            return true;
        }
    }
    false
}

/// Panics if any expression reachable from `nodes` is still an `Expr::Call`.
///
/// The match is deliberately exhaustive over `Node`, so a newly added statement kind fails
/// to compile here instead of being skipped silently.
fn assert_no_calls(nodes: &[Node]) {
    nodes.iter().for_each(|node| match node {
        Node::Return | Node::Barrier => {}
        Node::Block(body) => assert_no_calls(body),
        Node::Let { value, .. } | Node::Assign { value, .. } => assert_expr_no_call(value),
        Node::Store { index, value, .. } => {
            assert_expr_no_call(index);
            assert_expr_no_call(value);
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            assert_expr_no_call(cond);
            assert_no_calls(then);
            assert_no_calls(otherwise);
        }
        Node::Loop { from, to, body, .. } => {
            assert_expr_no_call(from);
            assert_expr_no_call(to);
            assert_no_calls(body);
        }
    });
}

/// Recursively panics if `expr` is, or contains, an `Expr::Call`.
///
/// Any variant not listed here also panics, so an unfamiliar expression kind produced by the
/// inliner is flagged rather than assumed call-free.
fn assert_expr_no_call(expr: &Expr) {
    match expr {
        // Leaves: nothing nested to inspect.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. } => {}
        Expr::Call { op_id, .. } => panic!("residual call {op_id}"),
        Expr::Load { index: inner, .. } => assert_expr_no_call(inner),
        Expr::UnOp { operand: inner, .. } => assert_expr_no_call(inner),
        Expr::Cast { value: inner, .. } => assert_expr_no_call(inner),
        Expr::BinOp { left, right, .. } => {
            assert_expr_no_call(left);
            assert_expr_no_call(right);
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            assert_expr_no_call(cond);
            assert_expr_no_call(true_val);
            assert_expr_no_call(false_val);
        }
        Expr::Atomic {
            index,
            expected,
            value,
            ..
        } => {
            assert_expr_no_call(index);
            if let Some(compare_to) = expected.as_ref() {
                assert_expr_no_call(compare_to);
            }
            assert_expr_no_call(value);
        }
        unexpected => panic!("unexpected expression variant after inlining: {unexpected:?}"),
    }
}