vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! Shared test support — proptest strategies for vyre IR.

use proptest::prelude::*;
use vyre::ir::*;

pub fn arb_data_type() -> impl Strategy<Value = DataType> {
    prop_oneof![
        Just(DataType::U32),
        Just(DataType::I32),
        Just(DataType::Bool),
    ]
}

pub fn arb_buffer_access() -> impl Strategy<Value = BufferAccess> {
    prop_oneof![Just(BufferAccess::ReadOnly), Just(BufferAccess::ReadWrite),]
}

/// Generate a storage buffer declaration named `buf_{binding}` at slot
/// `binding`, with a random access mode (see [`arb_buffer_access`]) and a
/// random element type (see [`arb_data_type`]).
pub fn arb_buffer_decl(binding: u32) -> impl Strategy<Value = BufferDecl> {
    // The name depends only on `binding`, so format it once and let the
    // mapping closure borrow it for every generated value.
    let buffer_name = format!("buf_{binding}");
    let parts = (arb_buffer_access(), arb_data_type());
    parts.prop_map(move |(mode, elem_ty)| {
        BufferDecl::storage(&buffer_name, binding, mode, elem_ty)
    })
}

/// Generate a literal leaf expression: a `u32`, `i32` or `bool` constant,
/// each kind chosen with equal weight.
pub fn arb_literal() -> impl Strategy<Value = Expr> {
    prop_oneof![
        // u32 literals are capped at 65535 (2^16 - 1), so a single `Add` or
        // `Mul` of two such literals stays below 2^32 and cannot overflow when
        // constant-folded (the aim is to avoid spurious naga compile errors).
        // This does NOT cover whole trees. `arb_expr(2)` can nest
        // multiplications such as (a * b) * (c * d), which can exceed u32::MAX.
        // `Sub` with a larger right operand underflows, and `Shl` by a
        // literal >= 32 may also be rejected. NOTE(review): confirm whether
        // those cases are expected to fail or should be filtered.
        (0u32..=65535).prop_map(Expr::LitU32),
        // i32 literals span the full range, so folded i32 arithmetic on them
        // can overflow as well.
        any::<i32>().prop_map(Expr::LitI32),
        any::<bool>().prop_map(Expr::LitBool),
    ]
}

pub fn arb_binop() -> impl Strategy<Value = BinOp> {
    prop_oneof![
        Just(BinOp::Add),
        Just(BinOp::Sub),
        Just(BinOp::Mul),
        Just(BinOp::BitAnd),
        Just(BinOp::BitOr),
        Just(BinOp::BitXor),
        Just(BinOp::Shl),
        Just(BinOp::Shr),
        Just(BinOp::Eq),
        Just(BinOp::Lt),
    ]
}

pub fn arb_unop() -> impl Strategy<Value = UnOp> {
    prop_oneof![
        Just(UnOp::BitNot),
        Just(UnOp::Popcount),
        Just(UnOp::Clz),
        Just(UnOp::Ctz),
        Just(UnOp::ReverseBits),
    ]
}

/// Generate a simple expression (no recursion, no buffer refs).
pub fn arb_simple_expr() -> impl Strategy<Value = Expr> {
    prop_oneof![
        arb_literal(),
        Just(Expr::gid_x()),
        Just(Expr::InvocationId { axis: 0 }),
    ]
}

/// Generate an expression up to a given depth.
/// Generate an expression tree with at most `depth` levels of operators.
///
/// At `depth == 0` only leaves from [`arb_simple_expr`] are produced.
/// Otherwise the result is one of three alternatives, with equal weight:
/// - a leaf;
/// - a binary operation (see [`arb_binop`]) over two sub-expressions of depth
///   `depth - 1`;
/// - a unary operation (see [`arb_unop`]) over one such sub-expression.
pub fn arb_expr(depth: u32) -> BoxedStrategy<Expr> {
    if depth == 0 {
        return arb_simple_expr().boxed();
    }

    // Build the child strategy once and share it. `BoxedStrategy` is
    // Arc-backed, so `clone` is a refcount bump. Calling `arb_expr(depth - 1)`
    // at every use site rebuilt the whole subtree three times per level,
    // which is O(3^depth) construction work.
    let child = arb_expr(depth - 1);

    prop_oneof![
        arb_simple_expr(),
        (arb_binop(), child.clone(), child.clone()).prop_map(|(op, left, right)| {
            Expr::BinOp {
                op,
                left: Box::new(left),
                right: Box::new(right),
            }
        }),
        (arb_unop(), child).prop_map(|(op, operand)| Expr::UnOp {
            op,
            operand: Box::new(operand),
        }),
    ]
    .boxed()
}

/// Generate a simple statement node.
/// Generate a single statement for a program body.
///
/// Always available:
/// - a `let` binding of a depth-2 expression (see [`arb_expr`]) to one of the
///   names `v0` to `v7`;
/// - `Node::Return`.
///
/// When `has_rw_buffer` is true, a third option stores a depth-1 expression
/// into `buf_0` at `Expr::u32(0)` (presumably the index). The caller must then
/// guarantee that `buf_0` is writable.
///
/// All available alternatives have equal weight.
pub fn arb_node(has_rw_buffer: bool) -> BoxedStrategy<Node> {
    let bind = (0u32..8, arb_expr(2))
        .prop_map(|(slot, value)| Node::let_bind(format!("v{slot}"), value))
        .boxed();

    // Only offered when the caller promises a writable `buf_0`.
    let store = has_rw_buffer.then(|| {
        arb_expr(1)
            .prop_map(|value| Node::store("buf_0", Expr::u32(0), value))
            .boxed()
    });

    let ret: BoxedStrategy<Node> = Just(Node::Return).boxed();

    // Keep the order let -> store -> return; unions shrink toward earlier
    // alternatives.
    let alternatives: Vec<BoxedStrategy<Node>> = std::iter::once(bind)
        .chain(store)
        .chain(std::iter::once(ret))
        .collect();

    proptest::strategy::Union::new(alternatives).boxed()
}

/// Generate a complete valid Program.
/// Generate a complete `Program` for property tests.
///
/// Shape of each generated program:
/// - 1 to 3 storage buffers. `buf_0` (binding 0) is always `ReadWrite`,
///   because the store statements from [`arb_node`] always target `buf_0`.
///   Any further buffers (`buf_1`, `buf_2`) come from [`arb_buffer_decl`].
/// - A workgroup size of `[1, 1, 1]`, `[64, 1, 1]` or `[256, 1, 1]`.
/// - An entry body of 1 to 7 statements drawn from `arb_node(true)`.
///
/// NOTE(review): operands are not type-checked against each other or against
/// buffer element types. For example, `BinOp::Add` on a `LitBool`, or a bool
/// stored into a `U32` buffer, can be generated. Confirm whether callers
/// expect every generated program to pass validation.
pub fn arb_program() -> impl Strategy<Value = Program> {
    let buffers = (1u32..4).prop_flat_map(|count| {
        let read_write_buf_0 = arb_data_type()
            .prop_map(|element| {
                BufferDecl::storage("buf_0", 0, BufferAccess::ReadWrite, element)
            })
            .boxed();
        // Bindings 1..count follow in order. Fold the per-binding strategies
        // into one strategy that yields the whole list.
        (1..count).map(|binding| arb_buffer_decl(binding).boxed()).fold(
            read_write_buf_0.prop_map(|first| vec![first]).boxed(),
            |acc: BoxedStrategy<Vec<BufferDecl>>, next| {
                (acc, next)
                    .prop_map(|(mut list, decl)| {
                        list.push(decl);
                        list
                    })
                    .boxed()
            },
        )
    });

    let workgroup_size = prop_oneof![Just([1u32, 1, 1]), Just([64, 1, 1]), Just([256, 1, 1])];

    // Let `collection::vec` choose the length (1..8, i.e. 1 to 7 statements)
    // instead of drawing it separately and flat-mapping. This way the node
    // strategy is built once rather than per test case, and proptest can
    // shrink a failing program by dropping statements.
    let entry = proptest::collection::vec(arb_node(true), 1..8);

    (buffers, workgroup_size, entry).prop_map(|(buffers, workgroup_size, entry)| {
        Program::new(buffers, workgroup_size, entry)
    })
}