vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
// LZ4 block decompression as a Category A IR composition.

use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec};

/// Input data types handed to `OpSpec::composition` for [`Lz4Decompress::SPEC`].
pub const INPUTS: &[DataType] = &[DataType::U32, DataType::U32];

/// Output data types handed to `OpSpec::composition` for [`Lz4Decompress::SPEC`].
pub const OUTPUTS: &[DataType] = &[DataType::U32, DataType::U32];

/// Algebraic laws declared for the operation: a single `Bounded` law whose
/// range is `0..=u32::MAX`.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
    lo: 0,
    hi: u32::MAX,
}];

/// Maximum declared decompressed-output expansion accepted before GPU dispatch.
///
/// NOTE(review): not referenced by the IR built in the code visible here;
/// presumably enforced by the host-side caller — confirm.
pub const MAX_OUTPUT_RATIO: u32 = 1024;

/// Workgroup dimensions passed to `Program::new` by [`lz4_decompress_program`].
pub const WORKGROUP_SIZE: [u32; 3] = [1, 1, 1];

/// Upper bound (`to`) of every `Node::Loop` built in this file, and the
/// threshold that literal and match lengths are compared against.
pub const LOOP_LIMIT: u32 = 65_536;

/// Status code written to `status[0]` when the read cursor reaches the input
/// length while more bytes are required (length-extension bytes, literal
/// bytes, or the second byte of a match offset).
pub const STATUS_INPUT_TRUNCATED: u32 = 1;

/// Status code written to `status[0]` when a literal or match copy would write
/// at or beyond the output length.
pub const STATUS_OUTPUT_OVERFLOW: u32 = 2;

/// Status code written to `status[0]` when a match offset is zero or larger
/// than the number of bytes produced so far.
pub const STATUS_INVALID_OFFSET: u32 = 3;

/// Status code written to `status[0]` when a bounded loop was not sufficient:
/// a length is still being extended after `LOOP_LIMIT` bytes, a length exceeds
/// `LOOP_LIMIT`, or input remains after `LOOP_LIMIT` sequences.
pub const STATUS_LOOP_LIMIT_EXCEEDED: u32 = 4;

/// Operation identifier used by [`Lz4Decompress::SPEC`] and as the entry op id
/// of the program built by [`lz4_decompress_program`].
pub const OP_ID: &str = "compression.lz4_decompress";

/// LZ4 block decompression operation.
///
/// A zero-sized marker type: it carries no state and only exposes the
/// associated [`Lz4Decompress::SPEC`] constant and [`Lz4Decompress::program`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Lz4Decompress;

impl Lz4Decompress {
    /// Declarative operation specification.
    ///
    /// Built with `OpSpec::composition` from [`OP_ID`], [`INPUTS`],
    /// [`OUTPUTS`], [`LAWS`] and [`Lz4Decompress::program`] as the program
    /// builder.
    pub const SPEC: OpSpec = OpSpec::composition(OP_ID, INPUTS, OUTPUTS, LAWS, Self::program);

    /// Build the canonical IR program.
    ///
    /// Delegates to [`lz4_decompress_program`].
    #[must_use]
    pub fn program() -> Program {
        lz4_decompress_program()
    }
}

/// Assemble the canonical IR program for LZ4 block decompression.
///
/// Buffers, by binding index:
/// - `0` `input`: read-only `U32` words from which compressed bytes are unpacked,
/// - `1` `out`: output `U32` words into which decompressed bytes are packed,
/// - `2` `params`: read-only; word 0 is the input length, word 1 the output length,
/// - `3` `status`: read-write; word 0 holds the status code, word 1 the final `op`.
///
/// The entry sequence initialises the byte cursors `ip`/`op` and the
/// `is_active` flag, clears `status[0]`, runs at most [`LOOP_LIMIT`] guarded
/// sequence decodes, reports [`STATUS_LOOP_LIMIT_EXCEEDED`] if decoding is
/// still active with input left over, and finally stores `op` in `status[1]`.
#[must_use]
pub fn lz4_decompress_program() -> Program {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::output("out", 1, DataType::U32),
        BufferDecl::read("params", 2, DataType::U32),
        BufferDecl::read_write("status", 3, DataType::U32),
    ];

    // Every iteration is guarded, so once `is_active` drops to zero the
    // remaining iterations do nothing.
    let sequence_loop = Node::Loop {
        var: String::from("sequence"),
        from: Expr::u32(0),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(active(), lz4_sequence_body())],
    };

    // Still active with unread input after the loop: the sequence budget ran out.
    let budget_exhausted = and(active(), Expr::lt(Expr::var("ip"), input_len()));

    let entry = vec![
        Node::let_bind("ip", Expr::u32(0)),
        Node::let_bind("op", Expr::u32(0)),
        Node::let_bind("is_active", Expr::u32(1)),
        Node::store("status", Expr::u32(0), Expr::u32(0)),
        sequence_loop,
        Node::if_then(budget_exhausted, vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)]),
        Node::store("status", Expr::u32(1), Expr::var("op")),
    ];

    Program::new(buffers, WORKGROUP_SIZE, entry).with_entry_op_id(OP_ID)
}

pub fn lz4_sequence_body() -> Vec<Node> {
    vec![Node::if_then_else(
        Expr::lt(Expr::var("ip"), input_len()),
        vec![
            Node::let_bind("token", read_byte("input", Expr::var("ip"))),
            Node::assign("ip", Expr::add(Expr::var("ip"), Expr::u32(1))),
            Node::let_bind("lit_len", Expr::shr(Expr::var("token"), Expr::u32(4))),
            Node::let_bind("lit_ext", Expr::eq(Expr::var("lit_len"), Expr::u32(15))),
            extend_length("lit", "lit_len", "lit_ext"),
            Node::if_then(
                and(active(), Expr::var("lit_ext")),
                vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
            ),
            copy_literals(),
            Node::if_then(
                and(
                    active(),
                    Expr::gt(Expr::var("lit_len"), Expr::u32(LOOP_LIMIT)),
                ),
                vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
            ),
            Node::let_bind("offset", Expr::u32(0)),
            read_match_offset(),
            Node::if_then(
                active(),
                vec![
                    Node::let_bind(
                        "match_len",
                        Expr::add(
                            Expr::bitand(Expr::var("token"), Expr::u32(0x0f)),
                            Expr::u32(4),
                        ),
                    ),
                    Node::let_bind(
                        "match_ext",
                        Expr::eq(
                            Expr::bitand(Expr::var("token"), Expr::u32(0x0f)),
                            Expr::u32(15),
                        ),
                    ),
                    extend_length("match", "match_len", "match_ext"),
                    Node::if_then(
                        and(active(), Expr::var("match_ext")),
                        vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
                    ),
                    validate_match_offset(),
                    copy_match(),
                    Node::if_then(
                        and(
                            active(),
                            Expr::gt(Expr::var("match_len"), Expr::u32(LOOP_LIMIT)),
                        ),
                        vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
                    ),
                ],
            ),
        ],
        vec![Node::assign("is_active", Expr::u32(0))],
    )]
}

/// Build the bounded loop that consumes length-extension bytes.
///
/// * `prefix`: used to name the loop counter (`{prefix}_ext_i`) and the
///   temporary holding the byte just read (`{prefix}_ext_byte`).
/// * `len_var`: IR variable the extension bytes are added to.
/// * `flag_var`: IR variable that is non-zero while extension continues; it
///   is reassigned to "last byte equals 255" after every byte.
///
/// Each of the [`LOOP_LIMIT`] iterations acts only while decoding is active
/// and `flag_var` is set. Running out of input fails with
/// [`STATUS_INPUT_TRUNCATED`]. If `flag_var` is still set after the loop, the
/// caller is expected to treat that as a loop-limit failure.
pub fn extend_length(prefix: &'static str, len_var: &'static str, flag_var: &'static str) -> Node {
    let ext_byte = format!("{prefix}_ext_byte");

    let consume_byte = vec![
        Node::let_bind(&ext_byte, read_byte("input", Expr::var("ip"))),
        Node::assign("ip", Expr::add(Expr::var("ip"), Expr::u32(1))),
        Node::assign(len_var, Expr::add(Expr::var(len_var), Expr::var(&ext_byte))),
        Node::assign(flag_var, Expr::eq(Expr::var(&ext_byte), Expr::u32(255))),
    ];
    let input_exhausted = vec![fail(STATUS_INPUT_TRUNCATED)];

    let step = Node::if_then_else(
        Expr::lt(Expr::var("ip"), input_len()),
        consume_byte,
        input_exhausted,
    );
    let still_extending = and(active(), Expr::var(flag_var));

    Node::Loop {
        var: format!("{prefix}_ext_i"),
        from: Expr::u32(0),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(still_extending, vec![step])],
    }
}

pub fn copy_literals() -> Node {
    Node::Loop {
        var: "lit_i".to_string(),
        from: Expr::u32(0),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(
            and(active(), Expr::lt(Expr::var("lit_i"), Expr::var("lit_len"))),
            vec![Node::if_then_else(
                and(
                    Expr::lt(Expr::var("ip"), input_len()),
                    Expr::lt(Expr::var("op"), output_len()),
                ),
                vec![
                    store_byte("out", Expr::var("op"), read_byte("input", Expr::var("ip"))),
                    Node::assign("ip", Expr::add(Expr::var("ip"), Expr::u32(1))),
                    Node::assign("op", Expr::add(Expr::var("op"), Expr::u32(1))),
                ],
                vec![Node::if_then_else(
                    Expr::lt(Expr::var("ip"), input_len()),
                    vec![fail(STATUS_OUTPUT_OVERFLOW)],
                    vec![fail(STATUS_INPUT_TRUNCATED)],
                )],
            )],
        )],
    }
}

pub fn read_match_offset() -> Node {
    Node::if_then(
        active(),
        vec![Node::if_then_else(
            Expr::eq(Expr::var("ip"), input_len()),
            vec![Node::assign("is_active", Expr::u32(0))],
            vec![Node::if_then_else(
                Expr::lt(Expr::add(Expr::var("ip"), Expr::u32(1)), input_len()),
                vec![
                    Node::let_bind("offset_lo", read_byte("input", Expr::var("ip"))),
                    Node::let_bind(
                        "offset_hi",
                        read_byte("input", Expr::add(Expr::var("ip"), Expr::u32(1))),
                    ),
                    Node::assign(
                        "offset",
                        Expr::bitor(
                            Expr::var("offset_lo"),
                            Expr::shl(Expr::var("offset_hi"), Expr::u32(8)),
                        ),
                    ),
                    Node::assign("ip", Expr::add(Expr::var("ip"), Expr::u32(2))),
                ],
                vec![fail(STATUS_INPUT_TRUNCATED)],
            )],
        )],
    )
}

/// Build the node that rejects an unusable match offset.
///
/// While decoding is active, an `offset` of zero or an `offset` greater than
/// `op` (the number of bytes produced so far) fails with
/// [`STATUS_INVALID_OFFSET`].
pub fn validate_match_offset() -> Node {
    let offset_is_zero = Expr::eq(Expr::var("offset"), Expr::u32(0));
    let reaches_before_start = Expr::gt(Expr::var("offset"), Expr::var("op"));

    let reject = Node::if_then(
        or(offset_is_zero, reaches_before_start),
        vec![fail(STATUS_INVALID_OFFSET)],
    );

    Node::if_then(active(), vec![reject])
}

/// Build the bounded loop that copies `match_len` bytes within `out`.
///
/// Each iteration acts only while decoding is active and
/// `match_i < match_len`. It reads the byte at `op - offset`, stores it at
/// `op` and advances `op` by one, so each copied byte is available to later
/// iterations. Reaching the output length fails with
/// [`STATUS_OUTPUT_OVERFLOW`].
pub fn copy_match() -> Node {
    let source_pos = Expr::sub(Expr::var("op"), Expr::var("offset"));
    let copy_one = vec![
        Node::let_bind("match_byte", read_byte("out", source_pos)),
        store_byte("out", Expr::var("op"), Expr::var("match_byte")),
        Node::assign("op", Expr::add(Expr::var("op"), Expr::u32(1))),
    ];

    let room_in_output = Expr::lt(Expr::var("op"), output_len());
    let step = Node::if_then_else(
        room_in_output,
        copy_one,
        vec![fail(STATUS_OUTPUT_OVERFLOW)],
    );

    let match_pending = and(
        active(),
        Expr::lt(Expr::var("match_i"), Expr::var("match_len")),
    );

    Node::Loop {
        var: String::from("match_i"),
        from: Expr::u32(0),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(match_pending, vec![step])],
    }
}

/// Build the expression that extracts the byte at byte position `pos` from
/// the word buffer `buffer`.
///
/// The word at index `pos >> 2` is loaded, shifted right by `(pos & 3) * 8`
/// bits and masked with `0xff`.
pub fn read_byte(buffer: &'static str, pos: Expr) -> Expr {
    let lane = Expr::bitand(pos.clone(), Expr::u32(3));
    let containing_word = Expr::load(buffer, Expr::shr(pos, Expr::u32(2)));
    let aligned = Expr::shr(containing_word, Expr::mul(lane, Expr::u32(8)));
    Expr::bitand(aligned, Expr::u32(0xff))
}

/// Build the node that writes `byte` at byte position `pos` of the word
/// buffer `buffer`.
///
/// The store is a read-modify-write of the word at index `pos >> 2`: the
/// target byte lane (`0xff << ((pos & 3) * 8)`) is cleared in the loaded
/// word, then OR-ed with `byte & 0xff` shifted into that lane.
pub fn store_byte(buffer: &'static str, pos: Expr, byte: Expr) -> Node {
    let word_index = Expr::shr(pos.clone(), Expr::u32(2));
    let lane_shift = Expr::mul(Expr::bitand(pos, Expr::u32(3)), Expr::u32(8));
    let lane_mask = Expr::shl(Expr::u32(0xff), lane_shift.clone());

    // Existing word with the target lane zeroed out.
    let current_word = Expr::load(buffer, word_index.clone());
    let preserved = Expr::bitand(current_word, Expr::bitnot(lane_mask.clone()));

    // New byte moved into the lane; the final mask keeps it inside the lane.
    let low_byte = Expr::bitand(byte, Expr::u32(0xff));
    let placed = Expr::bitand(Expr::shl(low_byte, lane_shift), lane_mask);

    Node::store(buffer, word_index, Expr::bitor(preserved, placed))
}

/// Build the block that records a failure and halts decoding.
///
/// `code` is stored in `status[0]` and `is_active` is set to zero, which
/// disables every later node guarded by [`active`].
pub fn fail(code: u32) -> Node {
    let record_code = Node::store("status", Expr::u32(0), Expr::u32(code));
    let deactivate = Node::assign("is_active", Expr::u32(0));
    Node::Block(vec![record_code, deactivate])
}

/// Expression loading the input length from word 0 of the `params` buffer.
///
/// Callers in this file compare it against the byte cursor `ip`.
pub fn input_len() -> Expr {
    const INPUT_LEN_SLOT: u32 = 0;
    Expr::load("params", Expr::u32(INPUT_LEN_SLOT))
}

/// Expression loading the output length from word 1 of the `params` buffer.
///
/// Callers in this file compare it against the byte cursor `op`.
pub fn output_len() -> Expr {
    const OUTPUT_LEN_SLOT: u32 = 1;
    Expr::load("params", Expr::u32(OUTPUT_LEN_SLOT))
}

/// Expression that is true while decoding should continue, i.e. while the
/// `is_active` variable is non-zero.
pub fn active() -> Expr {
    let flag = Expr::var("is_active");
    let cleared = Expr::u32(0);
    Expr::ne(flag, cleared)
}

/// Combine two condition expressions into a conjunction.
///
/// Implemented with `Expr::bitand`, i.e. a bitwise AND of the two operands.
/// NOTE(review): this presumably relies on condition expressions evaluating
/// to 0/1 in the IR — confirm against the `Expr` comparison semantics.
pub fn and(left: Expr, right: Expr) -> Expr {
    Expr::bitand(left, right)
}

/// Combine two condition expressions into a disjunction.
///
/// Implemented with `Expr::bitor`, i.e. a bitwise OR of the two operands.
/// NOTE(review): this presumably relies on condition expressions evaluating
/// to 0/1 in the IR — confirm against the `Expr` comparison semantics.
pub fn or(left: Expr, right: Expr) -> Expr {
    Expr::bitor(left, right)
}