use proptest::prelude::*;
use vyre::ir::*;
/// Strategy that yields one of the scalar element types `U32`, `I32` or `Bool`.
///
/// Each alternative is a constant (`Just`) and the three are unweighted.
pub fn arb_data_type() -> impl Strategy<Value = DataType> {
    let element_types = prop_oneof![
        Just(DataType::U32),
        Just(DataType::I32),
        Just(DataType::Bool),
    ];
    element_types
}
/// Strategy that yields either `BufferAccess::ReadOnly` or `BufferAccess::ReadWrite`.
pub fn arb_buffer_access() -> impl Strategy<Value = BufferAccess> {
    let access_modes = prop_oneof![
        Just(BufferAccess::ReadOnly),
        Just(BufferAccess::ReadWrite),
    ];
    access_modes
}
/// Strategy for a storage buffer declaration bound at `binding`.
///
/// The buffer is always named `buf_<binding>`; only the access mode and the
/// element type vary between generated values.
pub fn arb_buffer_decl(binding: u32) -> impl Strategy<Value = BufferDecl> {
    // The name depends only on `binding`, so it is formatted once up front
    // and borrowed by every invocation of the mapping closure.
    let label = format!("buf_{binding}");
    let varying_parts = (arb_buffer_access(), arb_data_type());
    varying_parts
        .prop_map(move |(access, element)| BufferDecl::storage(&label, binding, access, element))
}
pub fn arb_literal() -> impl Strategy<Value = Expr> {
prop_oneof![
(0u32..=65535).prop_map(Expr::LitU32),
any::<i32>().prop_map(Expr::LitI32),
any::<bool>().prop_map(Expr::LitBool),
]
}
/// Strategy that yields one of the binary operators used by the generators:
/// arithmetic (`Add`, `Sub`, `Mul`), bitwise (`BitAnd`, `BitOr`, `BitXor`),
/// shifts (`Shl`, `Shr`) and comparisons (`Eq`, `Lt`).
pub fn arb_binop() -> impl Strategy<Value = BinOp> {
    let operators = prop_oneof![
        // Arithmetic
        Just(BinOp::Add),
        Just(BinOp::Sub),
        Just(BinOp::Mul),
        // Bitwise
        Just(BinOp::BitAnd),
        Just(BinOp::BitOr),
        Just(BinOp::BitXor),
        // Shifts
        Just(BinOp::Shl),
        Just(BinOp::Shr),
        // Comparisons
        Just(BinOp::Eq),
        Just(BinOp::Lt),
    ];
    operators
}
/// Strategy that yields one of the unary operators `BitNot`, `Popcount`,
/// `Clz`, `Ctz` or `ReverseBits`.
pub fn arb_unop() -> impl Strategy<Value = UnOp> {
    let operators = prop_oneof![
        Just(UnOp::BitNot),
        Just(UnOp::Popcount),
        Just(UnOp::Clz),
        Just(UnOp::Ctz),
        Just(UnOp::ReverseBits),
    ];
    operators
}
pub fn arb_simple_expr() -> impl Strategy<Value = Expr> {
prop_oneof![
arb_literal(),
Just(Expr::gid_x()),
Just(Expr::InvocationId { axis: 0 }),
]
}
pub fn arb_expr(depth: u32) -> BoxedStrategy<Expr> {
if depth == 0 {
arb_simple_expr().boxed()
} else {
prop_oneof![
arb_simple_expr(),
(arb_binop(), arb_expr(depth - 1), arb_expr(depth - 1)).prop_map(
|(op, left, right)| Expr::BinOp {
op,
left: Box::new(left),
right: Box::new(right),
}
),
(arb_unop(), arb_expr(depth - 1)).prop_map(|(op, operand)| Expr::UnOp {
op,
operand: Box::new(operand),
}),
]
.boxed()
}
}
/// Strategy for a single statement node.
///
/// The alternatives are, in order:
/// 1. a `let` binding named `v0`..`v7` whose value comes from `arb_expr(2)`;
/// 2. only when `has_rw_buffer` is true, a store of an `arb_expr(1)` value
///    into `"buf_0"` at index `Expr::u32(0)`;
/// 3. `Node::Return`.
pub fn arb_node(has_rw_buffer: bool) -> BoxedStrategy<Node> {
    let binding = (0u32..8, arb_expr(2))
        .prop_map(|(i, value)| Node::let_bind(format!("v{i}"), value))
        .boxed();

    // `None` when the caller declares no writable buffer, so no store is offered.
    let store = has_rw_buffer.then(|| {
        arb_expr(1)
            .prop_map(|value| Node::store("buf_0", Expr::u32(0), value))
            .boxed()
    });

    let ret = Just(Node::Return).boxed();

    let choices: Vec<BoxedStrategy<Node>> = std::iter::once(binding)
        .chain(store)
        .chain(std::iter::once(ret))
        .collect();
    proptest::strategy::Union::new(choices).boxed()
}
/// Strategy for a whole `Program`.
///
/// * Buffers: between one and three declarations. The first is always
///   `"buf_0"` at binding 0 with `BufferAccess::ReadWrite` and an arbitrary
///   element type; the rest come from [`arb_buffer_decl`] for bindings
///   `1..count`.
/// * Workgroup size: one of `[1, 1, 1]`, `[64, 1, 1]`, `[256, 1, 1]`.
/// * Body: between one and seven nodes generated by `arb_node(true)`.
pub fn arb_program() -> impl Strategy<Value = Program> {
    let buffer_sets = (1u32..4).prop_flat_map(|total| {
        // Binding 0 is pinned to a read-write buffer named "buf_0".
        let primary = arb_data_type()
            .prop_map(|element| BufferDecl::storage("buf_0", 0, BufferAccess::ReadWrite, element))
            .boxed();
        let extras = (1..total).map(|binding| arb_buffer_decl(binding).boxed());

        // Fold the per-buffer strategies into a single strategy producing the
        // declarations as a `Vec`, in binding order.
        std::iter::once(primary).chain(extras).fold(
            Just(Vec::new()).boxed(),
            |so_far: BoxedStrategy<Vec<BufferDecl>>, next| {
                (so_far, next)
                    .prop_map(|(mut decls, decl)| {
                        decls.push(decl);
                        decls
                    })
                    .boxed()
            },
        )
    });

    let workgroups = prop_oneof![Just([1u32, 1, 1]), Just([64, 1, 1]), Just([256, 1, 1]),];
    let body_lengths = 1u32..8;

    (buffer_sets, workgroups, body_lengths).prop_flat_map(|(buffers, workgroup_size, len)| {
        // The body length was already drawn above, so request exactly that many nodes.
        let exact = len as usize;
        let body = proptest::collection::vec(arb_node(true), exact..=exact);
        (Just(buffers), Just(workgroup_size), body).prop_map(
            |(buffers, workgroup_size, entry)| Program::new(buffers, workgroup_size, entry),
        )
    })
}