vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
// Shared IR builders for byte-scale reduction operations.

use crate::ir::{BufferDecl, DataType, Expr, Node, Program};


/// Workgroup size used by the program builders in this file.
///
/// Each builder below passes this value as the second argument to
/// `Program::new`. All three components are `1`.
///
/// NOTE(review): presumably the components are the x/y/z workgroup
/// dimensions, and a single invocation is used because the generated bodies
/// walk the whole `input` buffer in one loop — confirm against `Program::new`.
pub const WORKGROUP_SIZE: [u32; 3] = [1, 1, 1];

/// Builds a `Program` that folds the `U32` buffer `input` into one value.
///
/// The generated program declares `input` (read, binding 0) and `out`
/// (output, binding 1), both `DataType::U32`, and contains three nodes:
///
/// 1. a `let_bind` of `acc` to `initial`;
/// 2. a `loop_for` over `i` with bounds `0` and `buf_len("input")` whose body
///    assigns `acc = step(acc, input[i])`;
/// 3. an `if_then` guarded by `0 < buf_len("out")` that stores `acc` at
///    index `0` of `out`.
///
/// * `initial` - starting value of the accumulator.
/// * `step` - builds the combining expression; it receives the accumulator
///   variable first and the loaded element second, and is called exactly once
///   while the program is being built.
pub(crate) fn reduce_u32_program(initial: u32, step: fn(Expr, Expr) -> Expr) -> Program {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::output("out", 1, DataType::U32),
    ];

    // Accumulator starts at the caller-provided seed.
    let seed = Node::let_bind("acc", Expr::u32(initial));

    // Loop bounds are built before the body, matching the order in which the
    // nested call would have evaluated them.
    let first = Expr::u32(0);
    let limit = Expr::buf_len("input");
    let running = Expr::var("acc");
    let element = Expr::load("input", Expr::var("i"));
    let fold_body = vec![Node::assign("acc", step(running, element))];
    let fold = Node::loop_for("i", first, limit, fold_body);

    // Only write the result when `out` has room for slot 0.
    let has_slot = Expr::lt(Expr::u32(0), Expr::buf_len("out"));
    let write_back = Node::if_then(
        has_slot,
        vec![Node::store("out", Expr::u32(0), Expr::var("acc"))],
    );

    Program::new(buffers, WORKGROUP_SIZE, vec![seed, fold, write_back])
}

/// Builds a `Program` that folds the `Bool` buffer `input` into one value.
///
/// The generated program declares `input` (read, binding 0) and `out`
/// (output, binding 1), both `DataType::Bool`, and contains three nodes:
///
/// 1. a `let_bind` of `acc` to `initial`;
/// 2. a `loop_for` over `i` with bounds `0` and `buf_len("input")` whose body
///    assigns `acc = step(acc, input[i])`;
/// 3. an `if_then` guarded by `0 < buf_len("out")` that stores `acc` at
///    index `0` of `out`.
///
/// * `initial` - starting value of the accumulator.
/// * `step` - builds the combining expression; it receives the accumulator
///   variable first and the loaded element second, and is called exactly once
///   while the program is being built.
pub(crate) fn reduce_bool_program(initial: bool, step: fn(Expr, Expr) -> Expr) -> Program {
    let declarations = vec![
        BufferDecl::read("input", 0, DataType::Bool),
        BufferDecl::output("out", 1, DataType::Bool),
    ];

    // Accumulator starts at the caller-provided seed.
    let init_acc = Node::let_bind("acc", Expr::bool(initial));

    // Bounds first, then the body, so construction order is unchanged.
    let lower = Expr::u32(0);
    let upper = Expr::buf_len("input");
    let current = Expr::var("acc");
    let loaded = Expr::load("input", Expr::var("i"));
    let update = Node::assign("acc", step(current, loaded));
    let scan = Node::loop_for("i", lower, upper, vec![update]);

    // Only write the result when `out` has room for slot 0.
    let out_nonempty = Expr::lt(Expr::u32(0), Expr::buf_len("out"));
    let emit = Node::if_then(
        out_nonempty,
        vec![Node::store("out", Expr::u32(0), Expr::var("acc"))],
    );

    Program::new(declarations, WORKGROUP_SIZE, vec![init_acc, scan, emit])
}

/// Builds a `Program` that records the index of the "best" element of the
/// `U32` buffer `input`.
///
/// The generated program declares `input` (read, binding 0) and `out`
/// (output, binding 1), both `DataType::U32`. Its body:
///
/// 1. binds `best_index` to `u32::MAX` and `best_value` to `0`;
/// 2. runs a `loop_for` over `i` with bounds `0` and `buf_len("input")`;
///    each iteration binds `value = input[i]` and, when
///    `best_index == u32::MAX || is_better(value, best_value)`, assigns
///    `best_value = value` and then `best_index = i`;
/// 3. stores `best_index` at index `0` of `out`, guarded by
///    `0 < buf_len("out")`.
///
/// `u32::MAX` doubles as the "nothing selected yet" marker, so if the update
/// branch never runs the stored index remains `u32::MAX`.
///
/// * `is_better` - builds the comparison expression; it receives the
///   candidate value first and the current best value second, and is called
///   exactly once while the program is being built.
pub(crate) fn arg_u32_program(is_better: fn(Expr, Expr) -> Expr) -> Program {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::output("out", 1, DataType::U32),
    ];

    let init_index = Node::let_bind("best_index", Expr::u32(u32::MAX));
    let init_value = Node::let_bind("best_value", Expr::u32(0));

    // Bounds first, then the body pieces, in their original construction order.
    let start = Expr::u32(0);
    let end = Expr::buf_len("input");
    let read_value = Node::let_bind("value", Expr::load("input", Expr::var("i")));

    // Take the candidate when nothing has been selected yet, or when the
    // caller's predicate prefers it over the current best.
    let unset = Expr::eq(Expr::var("best_index"), Expr::u32(u32::MAX));
    let preferred = is_better(Expr::var("value"), Expr::var("best_value"));
    let adopt = vec![
        Node::assign("best_value", Expr::var("value")),
        Node::assign("best_index", Expr::var("i")),
    ];
    let consider = Node::if_then(Expr::or(unset, preferred), adopt);

    let search = Node::loop_for("i", start, end, vec![read_value, consider]);

    // Only write the result when `out` has room for slot 0.
    let has_slot = Expr::lt(Expr::u32(0), Expr::buf_len("out"));
    let publish = Node::if_then(
        has_slot,
        vec![Node::store("out", Expr::u32(0), Expr::var("best_index"))],
    );

    Program::new(
        buffers,
        WORKGROUP_SIZE,
        vec![init_index, init_value, search, publish],
    )
}