use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
/// Workgroup dimensions `[x, y, z]` passed to `Program::new` by every builder in this module.
///
/// All three axes are 1.
pub const WORKGROUP_SIZE: [u32; 3] = [1; 3];
/// Builds a program that folds the whole `input` buffer of `u32` values into one result.
///
/// The program declares two `DataType::U32` buffers:
/// `input` (via `BufferDecl::read`, index 0) and `out` (via `BufferDecl::output`, index 1).
///
/// Generated body:
/// 1. `acc` is bound to `initial`.
/// 2. For every `i` in `0..buf_len("input")`, `acc` is replaced by `step(acc, input[i])`.
///    `step` receives the running accumulator first and the loaded element second.
/// 3. If `out` has at least one element, `acc` is stored into `out[0]`.
pub(crate) fn reduce_u32_program(initial: u32, step: fn(Expr, Expr) -> Expr) -> Program {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::output("out", 1, DataType::U32),
    ];

    // acc := initial
    let seed_acc = Node::let_bind("acc", Expr::u32(initial));

    // for i in 0..len(input) { acc := step(acc, input[i]) }
    let lower = Expr::u32(0);
    let upper = Expr::buf_len("input");
    let fold_step = Node::assign(
        "acc",
        step(Expr::var("acc"), Expr::load("input", Expr::var("i"))),
    );
    let fold_loop = Node::loop_for("i", lower, upper, vec![fold_step]);

    // Only write the result when `out` actually has a slot 0.
    let out_has_slot = Expr::lt(Expr::u32(0), Expr::buf_len("out"));
    let emit = Node::store("out", Expr::u32(0), Expr::var("acc"));
    let write_back = Node::if_then(out_has_slot, vec![emit]);

    let body = vec![seed_acc, fold_loop, write_back];
    Program::new(buffers, WORKGROUP_SIZE, body)
}
/// Builds a program that folds the whole `input` buffer of booleans into one result.
///
/// This is the `DataType::Bool` counterpart of the u32 reduction builder. It declares
/// `input` (via `BufferDecl::read`, index 0) and `out` (via `BufferDecl::output`, index 1),
/// both of type `DataType::Bool`.
///
/// Generated body:
/// 1. `acc` is bound to the boolean literal `initial`.
/// 2. For every `i` in `0..buf_len("input")`, `acc` is replaced by `step(acc, input[i])`.
///    `step` receives the accumulator first and the element second.
/// 3. If `out` is non-empty, `acc` is stored into `out[0]`.
pub(crate) fn reduce_bool_program(initial: bool, step: fn(Expr, Expr) -> Expr) -> Program {
    let input_decl = BufferDecl::read("input", 0, DataType::Bool);
    let out_decl = BufferDecl::output("out", 1, DataType::Bool);

    let mut nodes = Vec::new();

    // acc := initial
    nodes.push(Node::let_bind("acc", Expr::bool(initial)));

    // Sequential fold over every element of `input`.
    let first_index = Expr::u32(0);
    let input_len = Expr::buf_len("input");
    let combined = step(Expr::var("acc"), Expr::load("input", Expr::var("i")));
    nodes.push(Node::loop_for(
        "i",
        first_index,
        input_len,
        vec![Node::assign("acc", combined)],
    ));

    // Guard the store so a zero-length `out` is never indexed.
    let guard = Expr::lt(Expr::u32(0), Expr::buf_len("out"));
    let store_acc = Node::store("out", Expr::u32(0), Expr::var("acc"));
    nodes.push(Node::if_then(guard, vec![store_acc]));

    Program::new(vec![input_decl, out_decl], WORKGROUP_SIZE, nodes)
}
/// Builds a program that finds the index of the "best" `u32` element of `input`.
///
/// It declares `input` (via `BufferDecl::read`, index 0) and `out` (via
/// `BufferDecl::output`, index 1), both `DataType::U32`.
///
/// Generated body:
/// - `best_index` starts at `u32::MAX`, which marks "no element taken yet".
///   `best_value` starts at 0, but that value is never compared before the first element
///   is taken, because the marker check wins on the first iteration.
/// - For each `i` in `0..buf_len("input")`, `value = input[i]` becomes the new best
///   (`best_value := value`, `best_index := i`) when either nothing has been taken yet or
///   `is_better(value, best_value)` holds. `is_better` receives the candidate first and
///   the current best second. With a strict comparison, ties keep the earliest index.
/// - If `out` is non-empty, `best_index` is stored into `out[0]`. For an empty `input`,
///   that stored value is `u32::MAX`.
pub(crate) fn arg_u32_program(is_better: fn(Expr, Expr) -> Expr) -> Program {
    // "Nothing selected yet" marker for `best_index`; also the result for empty input.
    const UNSET: u32 = u32::MAX;

    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::output("out", 1, DataType::U32),
    ];

    let init_index = Node::let_bind("best_index", Expr::u32(UNSET));
    let init_value = Node::let_bind("best_value", Expr::u32(0));

    // Loop bounds: i in 0..len(input)
    let from = Expr::u32(0);
    let to = Expr::buf_len("input");

    // Per-iteration body: load the candidate, then adopt it if it wins.
    let load_candidate = Node::let_bind("value", Expr::load("input", Expr::var("i")));
    let nothing_taken = Expr::eq(Expr::var("best_index"), Expr::u32(UNSET));
    let beats_current = is_better(Expr::var("value"), Expr::var("best_value"));
    let adopt_candidate = Node::if_then(
        Expr::or(nothing_taken, beats_current),
        vec![
            Node::assign("best_value", Expr::var("value")),
            Node::assign("best_index", Expr::var("i")),
        ],
    );
    let scan = Node::loop_for("i", from, to, vec![load_candidate, adopt_candidate]);

    // Publish the winning index into out[0] when `out` has room for it.
    let out_non_empty = Expr::lt(Expr::u32(0), Expr::buf_len("out"));
    let publish = Node::if_then(
        out_non_empty,
        vec![Node::store("out", Expr::u32(0), Expr::var("best_index"))],
    );

    Program::new(
        buffers,
        WORKGROUP_SIZE,
        vec![init_index, init_value, scan, publish],
    )
}