//! Contract tests for the primitive operation programs of vyre 0.4.0, a GPU
//! compute intermediate representation with a standard operation library.
use vyre::ir::{self, BinOp, BufferAccess, DataType, Expr, Node, Program, UnOp};
use vyre::ops::primitive::float::f32_cos::F32Cos;
use vyre::ops::primitive::float::f32_sqrt::F32Sqrt;
use vyre::ops::primitive::{add, and, compare::eq, mul, not, or, popcount, shl, sub, xor};

/// Asserts the buffer layout shared by every two-input primitive program.
///
/// Expects exactly three buffers, in order: read-only inputs `a` (binding 0)
/// and `b` (binding 1), then the read-write output `out` (binding 2). The
/// shared entry-point contract is checked afterwards.
fn assert_binary_contract(program: &Program) {
    let layout = [
        ("a", 0, BufferAccess::ReadOnly),
        ("b", 1, BufferAccess::ReadOnly),
        ("out", 2, BufferAccess::ReadWrite),
    ];
    assert_eq!(program.buffers.len(), layout.len());
    for (buffer, (name, binding, access)) in program.buffers.iter().zip(layout) {
        assert_eq!(buffer.name, name);
        assert_eq!(buffer.binding, binding);
        assert_eq!(buffer.access, access);
    }
    assert_entry_contract(program);
}

/// Asserts the buffer layout of a one-input `f32` primitive program.
///
/// Expects exactly two `DataType::F32` buffers, in order: the read-only input
/// `a` (binding 0) and the read-write output `out` (binding 1). The shared
/// entry-point contract is checked afterwards.
fn assert_unary_f32_contract(program: &Program) {
    let layout = [
        ("a", 0, BufferAccess::ReadOnly, DataType::F32),
        ("out", 1, BufferAccess::ReadWrite, DataType::F32),
    ];
    assert_eq!(program.buffers.len(), layout.len());
    for (buffer, (name, binding, access, element)) in program.buffers.iter().zip(layout) {
        assert_eq!(buffer.name, name);
        assert_eq!(buffer.binding, binding);
        assert_eq!(buffer.access, access);
        assert_eq!(buffer.element, element);
    }
    assert_entry_contract(program);
}

/// Asserts the buffer layout shared by every one-input primitive program.
///
/// Expects exactly two buffers, in order: the read-only input `a`
/// (binding 0) and the read-write output `out` (binding 1). The element type
/// is not checked here. The shared entry-point contract is checked afterwards.
fn assert_unary_contract(program: &Program) {
    let layout = [
        ("a", 0, BufferAccess::ReadOnly),
        ("out", 1, BufferAccess::ReadWrite),
    ];
    assert_eq!(program.buffers.len(), layout.len());
    for (buffer, (name, binding, access)) in program.buffers.iter().zip(layout) {
        assert_eq!(buffer.name, name);
        assert_eq!(buffer.binding, binding);
        assert_eq!(buffer.access, access);
    }
    assert_entry_contract(program);
}

/// Checks the entry-point shape common to all primitive programs and returns
/// the program lowered to WGSL.
///
/// Asserts, in order:
/// - the workgroup size is `[64, 1, 1]`;
/// - `ir::validate` reports no errors;
/// - lowering to WGSL succeeds;
/// - the entry body has exactly two nodes:
///   - a `let idx` bound to invocation id axis 0;
///   - an `if` whose condition is a `<` comparison, whose `then` arm holds
///     exactly one node, and whose `otherwise` arm is empty.
fn assert_entry_contract(program: &Program) -> String {
    assert_eq!(program.workgroup_size, [64, 1, 1]);

    let errors = ir::validate(program);
    assert!(errors.is_empty(), "{errors:?}");
    let lowered =
        vyre::lower::wgsl::lower(program).expect("primitive program must lower to WGSL");

    assert_eq!(program.entry.len(), 2);
    let (index_binding, bounds_guard) = (&program.entry[0], &program.entry[1]);
    assert!(matches!(
        index_binding,
        Node::Let {
            name,
            value: Expr::InvocationId { axis: 0 },
        } if name == "idx"
    ));
    assert!(matches!(
        bounds_guard,
        Node::If {
            cond: Expr::BinOp { op: BinOp::Lt, .. },
            then,
            otherwise,
        } if then.len() == 1 && otherwise.is_empty()
    ));

    lowered
}

/// Returns the value expression written by the bounds-checked branch.
///
/// Panics unless `entry[1]` is an `If` whose first `then` node is a `Store`
/// into buffer `out` at the index variable `idx`.
fn stored_value(program: &Program) -> &Expr {
    let guarded = match &program.entry[1] {
        Node::If { then, .. } => then,
        _ => panic!("second entry node must be bounds check"),
    };
    match &guarded[0] {
        Node::Store {
            buffer,
            index: Expr::Var(index),
            value,
        } => {
            assert_eq!(buffer, "out");
            assert_eq!(index, "idx");
            value
        }
        _ => panic!("bounds check must store one output value"),
    }
}

#[test]
fn binary_primitive_programs_have_buffer_entry_and_operator_contracts() {
    // Each case carries a label so that a failing operator check inside the
    // loop names the primitive that broke, rather than only echoing the shared
    // assertion text for every case.
    let cases: [(&str, fn() -> Program, BinOp); 8] = [
        ("xor", xor::Xor::program, BinOp::BitXor),
        ("add", add::Add::program, BinOp::Add),
        ("shl", shl::Shl::program, BinOp::Shl),
        ("or", or::Or::program, BinOp::BitOr),
        ("and", and::And::program, BinOp::BitAnd),
        ("sub", sub::Sub::program, BinOp::Sub),
        ("mul", mul::Mul::program, BinOp::Mul),
        ("eq", eq::Eq::program, BinOp::Eq),
    ];

    for (label, build, expected) in cases {
        let program = build();
        assert_binary_contract(&program);
        assert!(
            matches!(stored_value(&program), Expr::BinOp { op, .. } if *op == expected),
            "{label}: stored value must be a BinOp with the expected operator"
        );
    }
}

#[test]
fn unary_primitive_programs_have_buffer_entry_and_operator_contracts() {
    // Labelled cases so a failure inside the loop identifies the primitive.
    let cases: [(&str, fn() -> Program, UnOp); 2] = [
        ("not", not::Not::program, UnOp::BitNot),
        ("popcount", popcount::Popcount::program, UnOp::Popcount),
    ];

    for (label, build, expected) in cases {
        let program = build();
        assert_unary_contract(&program);
        assert!(
            matches!(stored_value(&program), Expr::UnOp { op, .. } if *op == expected),
            "{label}: stored value must be a UnOp with the expected operator"
        );
    }
}

#[test]
fn add_program_validates_and_lowers_with_plus() {
    // The shared contract validates and lowers; the shader must use binary `+`.
    let lowered = assert_entry_contract(&add::Add::program());
    assert!(lowered.contains(" + "), "add WGSL must contain '+'");
}

#[test]
fn and_program_validates_and_lowers_with_ampersand() {
    // Spaces around `&` keep this from matching an address-of `&name`.
    let lowered = assert_entry_contract(&and::And::program());
    assert!(lowered.contains(" & "), "and WGSL must contain '&'");
}

#[test]
fn eq_program_validates_and_lowers_with_eqeq() {
    // The lowered shader must compare operands with `==`.
    let lowered = assert_entry_contract(&eq::Eq::program());
    assert!(lowered.contains(" == "), "eq WGSL must contain '=='");
}

#[test]
fn mul_program_validates_and_lowers_with_star() {
    // The lowered shader must multiply operands with binary `*`.
    let lowered = assert_entry_contract(&mul::Mul::program());
    assert!(lowered.contains(" * "), "mul WGSL must contain '*'");
}

#[test]
fn not_program_validates_and_lowers_with_tilde() {
    // Bitwise complement is a prefix operator, so no surrounding spaces are required.
    let lowered = assert_entry_contract(&not::Not::program());
    assert!(lowered.contains("~"), "not WGSL must contain '~'");
}

#[test]
fn or_program_validates_and_lowers_with_pipe() {
    // The lowered shader must combine operands with binary `|`.
    let lowered = assert_entry_contract(&or::Or::program());
    assert!(lowered.contains(" | "), "or WGSL must contain '|'");
}

#[test]
fn popcount_program_validates_and_lowers_with_count_one_bits() {
    // Popcount must lower to the WGSL builtin `countOneBits`.
    let shader = assert_entry_contract(&popcount::Popcount::program());
    let uses_builtin = shader.contains("countOneBits");
    assert!(uses_builtin, "popcount WGSL must contain 'countOneBits'");
}

#[test]
fn shl_program_validates_and_lowers_with_shl() {
    // The lowered shader must shift with binary `<<`.
    let shader = assert_entry_contract(&shl::Shl::program());
    assert!(shader.contains(" << "), "shl WGSL must contain '<<'");
}

#[test]
fn sub_program_validates_and_lowers_with_minus() {
    // Spaces around `-` distinguish binary subtraction from unary negation.
    let shader = assert_entry_contract(&sub::Sub::program());
    assert!(shader.contains(" - "), "sub WGSL must contain '-'");
}

#[test]
fn xor_program_validates_and_lowers_with_caret() {
    // The lowered shader must combine operands with binary `^`.
    let shader = assert_entry_contract(&xor::Xor::program());
    assert!(shader.contains(" ^ "), "xor WGSL must contain '^'");
}

#[test]
fn f32_cos_program_has_unary_f32_contract() {
    // F32 buffer layout plus a stored value that applies the `Cos` unary op.
    let program = F32Cos::program();
    assert_unary_f32_contract(&program);
    let written = stored_value(&program);
    assert!(matches!(written, Expr::UnOp { op: UnOp::Cos, .. }));
}

#[test]
fn f32_cos_program_validates_and_lowers_with_cos() {
    // The lowered shader must call the WGSL `cos` builtin.
    let shader = assert_entry_contract(&F32Cos::program());
    assert!(shader.contains("cos("), "f32_cos WGSL must contain 'cos('");
}

#[test]
fn f32_sqrt_program_validates_and_lowers_with_sqrt() {
    // The lowered shader must call the WGSL `sqrt` builtin.
    let shader = assert_entry_contract(&F32Sqrt::program());
    assert!(shader.contains("sqrt("), "f32_sqrt WGSL must contain 'sqrt('");
}