vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use super::{input_byte, probability_bucket, HISTOGRAM_BUCKETS};
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};

/// Build the canonical fixed-point Shannon entropy approximation program.
#[must_use]
pub fn entropy_program() -> Program {
    Program::new(
        vec![
            BufferDecl::read("input", 0, DataType::Bytes),
            BufferDecl::output("out", 1, DataType::U32),
            BufferDecl::read_write("histogram", 2, DataType::U32),
            BufferDecl::read("neg_log2_scaled", 3, DataType::U32),
        ],
        [HISTOGRAM_BUCKETS, 1, 1],
        vec![
            Node::let_bind("lane", Expr::LocalId { axis: 0 }),
            Node::if_then(
                Expr::lt(Expr::var("lane"), Expr::u32(HISTOGRAM_BUCKETS)),
                vec![Node::store("histogram", Expr::var("lane"), Expr::u32(0))],
            ),
            Node::if_then(
                Expr::eq(Expr::var("lane"), Expr::u32(0)),
                vec![Node::store("out", Expr::u32(0), Expr::u32(0))],
            ),
            Node::Barrier,
            Node::if_then(
                Expr::eq(Expr::var("lane"), Expr::u32(0)),
                vec![Node::loop_for(
                    "pos",
                    Expr::u32(0),
                    Expr::buf_len("input"),
                    vec![Node::let_bind(
                        "prior_count",
                        Expr::atomic_add("histogram", input_byte(), Expr::u32(1)),
                    )],
                )],
            ),
            Node::Barrier,
            Node::if_then(
                Expr::lt(Expr::var("lane"), Expr::u32(HISTOGRAM_BUCKETS)),
                vec![
                    Node::let_bind("count", Expr::load("histogram", Expr::var("lane"))),
                    Node::if_then(
                        Expr::bitand(
                            Expr::ne(Expr::var("count"), Expr::u32(0)),
                            Expr::ne(Expr::buf_len("input"), Expr::u32(0)),
                        ),
                        vec![
                            Node::let_bind("probability_bucket", probability_bucket()),
                            Node::let_bind(
                                "entropy_part",
                                Expr::div(
                                    Expr::mul(
                                        Expr::var("count"),
                                        Expr::load(
                                            "neg_log2_scaled",
                                            Expr::var("probability_bucket"),
                                        ),
                                    ),
                                    Expr::buf_len("input"),
                                ),
                            ),
                            Node::let_bind(
                                "prior_entropy",
                                Expr::atomic_add("out", Expr::u32(0), Expr::var("entropy_part")),
                            ),
                        ],
                    ),
                ],
            ),
        ],
    )
}