// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
// Documentation
// Zstandard raw and RLE block decompression as a Category A IR composition.

use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec};

/// Input data types declared in the operation spec: two `U32` values.
pub const INPUTS: &[DataType] = &[DataType::U32, DataType::U32];

/// Output data types declared in the operation spec: two `U32` values.
pub const OUTPUTS: &[DataType] = &[DataType::U32, DataType::U32];

/// Algebraic laws declared in the operation spec: a single `Bounded` law whose
/// range is the whole `u32` domain (`0..=u32::MAX`).
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
    lo: 0,
    hi: u32::MAX,
}];

/// Maximum declared decompressed-output expansion accepted before GPU dispatch.
///
/// NOTE(review): not referenced by the IR built in this file; presumably
/// enforced by the host-side caller — confirm.
pub const MAX_OUTPUT_RATIO: u32 = 1024;

/// Workgroup size handed to `Program::new` for this composition: `[1, 1, 1]`.
pub const WORKGROUP_SIZE: [u32; 3] = [1, 1, 1];

/// Upper iteration bound shared by the outer per-block loop and by the
/// per-block raw-copy and RLE-fill loops.
pub const LOOP_LIMIT: u32 = 131_072;

/// Block-type field value (bits 1-2 of the 3-byte block header) for a raw
/// block, whose payload is copied to the output unchanged.
pub const BLOCK_RAW: u32 = 0;

/// Block-type field value for an RLE block: a single payload byte repeated
/// `z_block_size` times in the output.
pub const BLOCK_RLE: u32 = 1;

/// Block-type field value for a compressed block, which this program rejects
/// with [`STATUS_COMPRESSED_BLOCK_UNSUPPORTED`].
pub const BLOCK_COMPRESSED: u32 = 2;

/// Status code stored in `status[0]` when the input ends before a full 3-byte
/// block header, a raw payload byte, or the RLE byte could be read.
pub const STATUS_INPUT_TRUNCATED: u32 = 1;

/// Status code stored in `status[0]` when a raw or RLE block would write at or
/// past the output length given in `params[1]`.
pub const STATUS_OUTPUT_OVERFLOW: u32 = 2;

/// Status code stored in `status[0]` when a block of type
/// [`BLOCK_COMPRESSED`] is encountered.
pub const STATUS_COMPRESSED_BLOCK_UNSUPPORTED: u32 = 3;

/// Status code stored in `status[0]` when the block-type field is greater
/// than [`BLOCK_COMPRESSED`].
pub const STATUS_RESERVED_BLOCK_TYPE: u32 = 4;

/// Status code stored in `status[0]` when the block loop finishes while the
/// decoder is still active and unread input remains.
pub const STATUS_LOOP_LIMIT_EXCEEDED: u32 = 5;

/// Operation id of this composition; used both in the operation spec and as
/// the entry op id of the built program.
pub const OP_ID: &str = "compression.zstd_decompress";

/// Operation id surfaced when zstd compressed blocks require FSE/Huffman support.
///
/// Raw and RLE zstd blocks are implemented in this composition. Compressed
/// blocks require FSE and Huffman decoders, so this program records a hard
/// failure status for `compression.zstd.compressed_block` instead of silently
/// skipping or falling back to CPU execution.
pub const COMPRESSED_BLOCK_OP: &str = "compression.zstd.compressed_block";

/// Marker type for the Zstandard raw/RLE block decompression operation.
#[derive(Debug, Clone, Copy, Default)]
pub struct ZstdDecompress;

impl ZstdDecompress {
    /// Declarative operation specification, assembled from [`OP_ID`],
    /// [`INPUTS`], [`OUTPUTS`], [`LAWS`] and the program builder below.
    pub const SPEC: OpSpec =
        OpSpec::composition(OP_ID, INPUTS, OUTPUTS, LAWS, ZstdDecompress::program);

    /// Returns the canonical IR program, as produced by
    /// [`zstd_decompress_program`].
    #[must_use]
    pub fn program() -> Program {
        zstd_decompress_program()
    }
}

/// Build the canonical zstd raw/RLE block decompression IR program.
#[must_use]
pub fn zstd_decompress_program() -> Program {
    Program::new(
        vec![
            BufferDecl::read("input", 0, DataType::U32),
            BufferDecl::output("out", 1, DataType::U32),
            BufferDecl::read("params", 2, DataType::U32),
            BufferDecl::read_write("status", 3, DataType::U32),
        ],
        WORKGROUP_SIZE,
        vec![
            Node::let_bind("ip", Expr::u32(0)),
            Node::let_bind("op", Expr::u32(0)),
            Node::let_bind("is_active", Expr::u32(1)),
            Node::store("status", Expr::u32(0), Expr::u32(0)),
            Node::Loop {
                var: "block".to_string(),
                from: Expr::u32(0),
                to: Expr::u32(LOOP_LIMIT),
                body: vec![Node::if_then(active(), zstd_block_body())],
            },
            Node::if_then(
                and(active(), Expr::lt(Expr::var("ip"), input_len())),
                vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
            ),
            Node::store("status", Expr::u32(1), Expr::var("op")),
        ],
    )
    .with_entry_op_id(OP_ID)
}

pub fn zstd_block_body() -> Vec<Node> {
    vec![
        Node::let_bind("z_b0", Expr::u32(0)),
        Node::let_bind("z_b1", Expr::u32(0)),
        Node::let_bind("z_b2", Expr::u32(0)),
        Node::let_bind("z_header", Expr::u32(0)),
        Node::let_bind("z_last", Expr::u32(0)),
        Node::let_bind("z_block_type", Expr::u32(0)),
        Node::let_bind("z_block_size", Expr::u32(0)),
        read_header(),
        decode_raw_block(),
        decode_rle_block(),
        reject_compressed_block(),
        reject_reserved_block(),
        Node::if_then(
            and(active(), Expr::ne(Expr::var("z_last"), Expr::u32(0))),
            vec![Node::assign("is_active", Expr::u32(0))],
        ),
    ]
}

/// Build the node that parses the 3-byte block header at `ip`.
///
/// When three bytes are available (`ip + 2 < input length`) the bytes are
/// loaded into `z_b0..z_b2`, `ip` advances by three, and the 24-bit
/// little-endian value is split into `z_last` (bit 0), `z_block_type`
/// (bits 1-2) and `z_block_size` (remaining upper bits). Otherwise the
/// program fails with [`STATUS_INPUT_TRUNCATED`].
pub fn read_header() -> Node {
    let ip = || Expr::var("ip");
    let ip_plus = |offset: u32| Expr::add(Expr::var("ip"), Expr::u32(offset));
    let header = || Expr::var("z_header");

    let has_three_bytes = Expr::lt(ip_plus(2), input_len());

    // b0 | (b1 << 8) | (b2 << 16)
    let assembled = Expr::bitor(
        Expr::bitor(
            Expr::var("z_b0"),
            Expr::shl(Expr::var("z_b1"), Expr::u32(8)),
        ),
        Expr::shl(Expr::var("z_b2"), Expr::u32(16)),
    );

    let parse = vec![
        Node::assign("z_b0", read_byte("input", ip())),
        Node::assign("z_b1", read_byte("input", ip_plus(1))),
        Node::assign("z_b2", read_byte("input", ip_plus(2))),
        Node::assign("ip", ip_plus(3)),
        Node::assign("z_header", assembled),
        Node::assign("z_last", Expr::bitand(header(), Expr::u32(1))),
        Node::assign(
            "z_block_type",
            Expr::bitand(Expr::shr(header(), Expr::u32(1)), Expr::u32(3)),
        ),
        Node::assign("z_block_size", Expr::shr(header(), Expr::u32(3))),
    ];

    Node::if_then_else(has_three_bytes, parse, vec![fail(STATUS_INPUT_TRUNCATED)])
}

/// Build the node that copies a raw block's payload from `input` to `out`.
///
/// For a block of type [`BLOCK_RAW`], up to `z_block_size` bytes are copied
/// one at a time, advancing `ip` and `op`. Running out of input fails with
/// [`STATUS_INPUT_TRUNCATED`]; running out of output space (with input still
/// available) fails with [`STATUS_OUTPUT_OVERFLOW`]. The loop is bounded by
/// [`LOOP_LIMIT`] iterations.
pub fn decode_raw_block() -> Node {
    // The block type cannot change while the block is being copied, so it is
    // tested once outside the loop (mirroring `decode_rle_block`). Blocks of
    // other types then skip the loop entirely instead of spinning through
    // `LOOP_LIMIT` no-op iterations.
    Node::if_then(
        and(
            active(),
            Expr::eq(Expr::var("z_block_type"), Expr::u32(BLOCK_RAW)),
        ),
        vec![Node::Loop {
            var: "raw_i".to_string(),
            from: Expr::u32(0),
            to: Expr::u32(LOOP_LIMIT),
            body: vec![Node::if_then(
                // `active()` stays in the per-iteration test because a failure
                // inside the loop must stop any further copying.
                and(
                    active(),
                    Expr::lt(Expr::var("raw_i"), Expr::var("z_block_size")),
                ),
                vec![Node::if_then_else(
                    and(
                        Expr::lt(Expr::var("ip"), input_len()),
                        Expr::lt(Expr::var("op"), output_len()),
                    ),
                    vec![
                        store_byte("out", Expr::var("op"), read_byte("input", Expr::var("ip"))),
                        Node::assign("ip", Expr::add(Expr::var("ip"), Expr::u32(1))),
                        Node::assign("op", Expr::add(Expr::var("op"), Expr::u32(1))),
                    ],
                    // Input still available means the output bound was hit.
                    vec![Node::if_then_else(
                        Expr::lt(Expr::var("ip"), input_len()),
                        vec![fail(STATUS_OUTPUT_OVERFLOW)],
                        vec![fail(STATUS_INPUT_TRUNCATED)],
                    )],
                )],
            )],
        }],
    )
}

/// Build the node that decodes an RLE block.
///
/// For a block of type [`BLOCK_RLE`], one byte is read from `input` into
/// `rle_byte`, `ip` advances by one, and [`rle_fill_loop`] writes the
/// repeated byte. If no input byte is available the program fails with
/// [`STATUS_INPUT_TRUNCATED`].
pub fn decode_rle_block() -> Node {
    let is_rle = Expr::eq(Expr::var("z_block_type"), Expr::u32(BLOCK_RLE));

    let fetch_and_fill = vec![
        Node::let_bind("rle_byte", read_byte("input", Expr::var("ip"))),
        Node::assign("ip", Expr::add(Expr::var("ip"), Expr::u32(1))),
        rle_fill_loop(),
    ];

    let decode = Node::if_then_else(
        Expr::lt(Expr::var("ip"), input_len()),
        fetch_and_fill,
        vec![fail(STATUS_INPUT_TRUNCATED)],
    );

    Node::if_then(and(active(), is_rle), vec![decode])
}

/// Build the loop that writes `rle_byte` to `out` `z_block_size` times.
///
/// Each iteration writes one byte at `op` and advances `op`. Hitting the
/// output length fails with [`STATUS_OUTPUT_OVERFLOW`]. The loop is bounded
/// by [`LOOP_LIMIT`] iterations.
pub fn rle_fill_loop() -> Node {
    let emit_one = vec![
        store_byte("out", Expr::var("op"), Expr::var("rle_byte")),
        Node::assign("op", Expr::add(Expr::var("op"), Expr::u32(1))),
    ];

    let emit_or_overflow = Node::if_then_else(
        Expr::lt(Expr::var("op"), output_len()),
        emit_one,
        vec![fail(STATUS_OUTPUT_OVERFLOW)],
    );

    let still_filling = and(
        active(),
        Expr::lt(Expr::var("rle_i"), Expr::var("z_block_size")),
    );

    Node::Loop {
        var: String::from("rle_i"),
        from: Expr::u32(0),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(still_filling, vec![emit_or_overflow])],
    }
}

/// Build the node that fails with [`STATUS_COMPRESSED_BLOCK_UNSUPPORTED`]
/// when the current block type equals [`BLOCK_COMPRESSED`] and decoding is
/// still active.
pub fn reject_compressed_block() -> Node {
    let is_compressed = Expr::eq(Expr::var("z_block_type"), Expr::u32(BLOCK_COMPRESSED));
    let guard = and(active(), is_compressed);
    Node::if_then(guard, vec![fail(STATUS_COMPRESSED_BLOCK_UNSUPPORTED)])
}

/// Build the node that fails with [`STATUS_RESERVED_BLOCK_TYPE`] when the
/// current block type is greater than [`BLOCK_COMPRESSED`] and decoding is
/// still active.
pub fn reject_reserved_block() -> Node {
    let is_reserved = Expr::gt(Expr::var("z_block_type"), Expr::u32(BLOCK_COMPRESSED));
    let guard = and(active(), is_reserved);
    Node::if_then(guard, vec![fail(STATUS_RESERVED_BLOCK_TYPE)])
}

/// Build an expression that extracts the byte at byte position `pos` from a
/// buffer of 32-bit words.
///
/// The word index is `pos >> 2`; the byte is selected by shifting the loaded
/// word right by `(pos & 3) * 8` bits and masking with `0xff`.
pub fn read_byte(buffer: &'static str, pos: Expr) -> Expr {
    let bit_offset = Expr::mul(Expr::bitand(pos.clone(), Expr::u32(3)), Expr::u32(8));
    let containing_word = Expr::load(buffer, Expr::shr(pos, Expr::u32(2)));
    let shifted = Expr::shr(containing_word, bit_offset);
    Expr::bitand(shifted, Expr::u32(0xff))
}

/// Build a node that writes `byte` at byte position `pos` into a buffer of
/// 32-bit words, as a load / modify / store of the containing word.
///
/// The word at index `pos >> 2` is loaded, the 8-bit lane at bit offset
/// `(pos & 3) * 8` is cleared, and the low 8 bits of `byte` are OR-ed into
/// that lane before the word is stored back.
pub fn store_byte(buffer: &'static str, pos: Expr, byte: Expr) -> Node {
    let bit_offset = Expr::mul(Expr::bitand(pos.clone(), Expr::u32(3)), Expr::u32(8));
    let word = Expr::shr(pos, Expr::u32(2));
    let lane_mask = Expr::shl(Expr::u32(0xff), bit_offset.clone());

    // Existing word with the target lane zeroed.
    let kept = Expr::bitand(
        Expr::load(buffer, word.clone()),
        Expr::bitnot(lane_mask.clone()),
    );
    // New byte moved into the target lane and confined to it.
    let placed = Expr::bitand(
        Expr::shl(Expr::bitand(byte, Expr::u32(0xff)), bit_offset),
        lane_mask,
    );

    Node::store(buffer, word, Expr::bitor(kept, placed))
}

/// Build a block that records `code` in `status[0]` and clears `is_active`,
/// so every later node guarded by [`active`] is skipped.
pub fn fail(code: u32) -> Node {
    let record_code = Node::store("status", Expr::u32(0), Expr::u32(code));
    let deactivate = Node::assign("is_active", Expr::u32(0));
    Node::Block(vec![record_code, deactivate])
}

/// Expression loading the input length from `params[0]`.
pub fn input_len() -> Expr {
    let slot = Expr::u32(0);
    Expr::load("params", slot)
}

/// Expression loading the output length from `params[1]`.
pub fn output_len() -> Expr {
    let slot = Expr::u32(1);
    Expr::load("params", slot)
}

/// Expression that is true while the `is_active` variable is non-zero, i.e.
/// until a last block has been handled or a failure has been recorded.
pub fn active() -> Expr {
    let flag = Expr::var("is_active");
    Expr::ne(flag, Expr::u32(0))
}

/// Combine two condition expressions with a bitwise AND.
pub fn and(left: Expr, right: Expr) -> Expr {
    let (lhs, rhs) = (left, right);
    Expr::bitand(lhs, rhs)
}