vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! Base64 decode operation as an IR composition.

/// CPU-independent IR kernel for contiguous RFC 4648 base64 runs.
pub mod kernel {
    use super::super::shared::{and, in_range, BYTE_BOUNDED_LAWS};
    use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
    use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};

    // Buffer names, shared by the buffer declarations, bounds checks, loads and stores.
    const INPUT: &str = "input";
    const OUTPUT: &str = "out";

    // IR variable names for one quad. `program` binds all of them; `byte0`..`byte2` read
    // the decoded `*6` ones. Defining them once keeps binder and readers in sync.
    // `*8` hold the raw input bytes, `*6` the decoded 6-bit values.
    const A8: &str = "a8";
    const B8: &str = "b8";
    const C8: &str = "c8";
    const D8: &str = "d8";
    const A6: &str = "a6";
    const B6: &str = "b6";
    const C6: &str = "c6";
    const D6: &str = "d6";

    /// GPU region decoder for contiguous RFC 4648 base64 runs.
    ///
    /// Invocation `idx` decodes the 4 input bytes starting at `4 * idx` into the 3 output
    /// bytes starting at `3 * idx`. A quad is decoded only if all 4 input bytes and all
    /// 3 output bytes are in bounds; otherwise that invocation writes nothing.
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Base64Decode;

    impl Base64Decode {
        /// Declarative operation specification.
        pub const SPEC: OpSpec = OpSpec::composition(
            "decode.base64",
            BYTES_TO_BYTES_INPUTS,
            BYTES_TO_BYTES_OUTPUTS,
            BYTE_BOUNDED_LAWS,
            Self::program,
        );

        /// Build the canonical IR program.
        ///
        /// Declares `input` (binding 0, read) and `out` (binding 1, output) byte buffers
        /// and runs with a `[64, 1, 1]` workgroup, one invocation per quad.
        #[must_use]
        pub fn program() -> Program {
            let idx = Expr::var("idx");
            // First input byte of this invocation's quad, and first output byte of its
            // triple.
            let in_base = Expr::mul(idx.clone(), Expr::u32(4));
            let out_idx = Expr::mul(idx, Expr::u32(3));

            let a = Expr::load(INPUT, in_base.clone());
            let b = Expr::load(INPUT, Expr::add(in_base.clone(), Expr::u32(1)));
            let c = Expr::load(INPUT, Expr::add(in_base.clone(), Expr::u32(2)));
            let d = Expr::load(INPUT, Expr::add(in_base.clone(), Expr::u32(3)));

            // Decode only when the whole quad is readable and the whole triple is writable.
            let in_bounds = and(
                Expr::lt(Expr::add(in_base, Expr::u32(3)), Expr::buf_len(INPUT)),
                Expr::lt(Expr::add(out_idx.clone(), Expr::u32(2)), Expr::buf_len(OUTPUT)),
            );

            Program::new(
                vec![
                    BufferDecl::read(INPUT, 0, DataType::Bytes),
                    BufferDecl::output(OUTPUT, 1, DataType::Bytes),
                ],
                [64, 1, 1],
                vec![
                    Node::let_bind("idx", Expr::gid_x()),
                    Node::if_then(
                        in_bounds,
                        vec![
                            // Bind each raw byte once: `decode_value` clones its argument
                            // into many sub-expressions, so passing a load directly would
                            // duplicate that load at every use.
                            Node::let_bind(A8, a),
                            Node::let_bind(B8, b),
                            Node::let_bind(C8, c),
                            Node::let_bind(D8, d),
                            Node::let_bind(A6, decode_value(Expr::var(A8))),
                            Node::let_bind(B6, decode_value(Expr::var(B8))),
                            Node::let_bind(C6, decode_value(Expr::var(C8))),
                            Node::let_bind(D6, decode_value(Expr::var(D8))),
                            Node::store(OUTPUT, out_idx.clone(), byte0()),
                            Node::store(OUTPUT, Expr::add(out_idx.clone(), Expr::u32(1)), byte1()),
                            Node::store(OUTPUT, Expr::add(out_idx, Expr::u32(2)), byte2()),
                        ],
                    ),
                ],
            )
        }
    }

    /// Decode one base64 alphabet byte to a 6-bit value.
    ///
    /// Maps `A`-`Z` to 0-25, `a`-`z` to 26-51, `0`-`9` to 52-61, `+` to 62 and `/` to 63.
    /// Any other byte, including the `=` padding character, yields 0; no error is
    /// signalled.
    ///
    /// `byte` is cloned into several sub-expressions, so pass a cheap expression such as
    /// a variable rather than a load.
    #[must_use]
    pub fn decode_value(byte: Expr) -> Expr {
        let upper = in_range(byte.clone(), b'A', b'Z');
        let lower = in_range(byte.clone(), b'a', b'z');
        let digit = in_range(byte.clone(), b'0', b'9');
        let plus = Expr::eq(byte.clone(), Expr::u32(u32::from(b'+')));
        let slash = Expr::eq(byte.clone(), Expr::u32(u32::from(b'/')));
        Expr::select(
            upper,
            Expr::sub(byte.clone(), Expr::u32(u32::from(b'A'))),
            Expr::select(
                lower,
                Expr::add(
                    Expr::sub(byte.clone(), Expr::u32(u32::from(b'a'))),
                    Expr::u32(26),
                ),
                Expr::select(
                    digit,
                    Expr::add(Expr::sub(byte, Expr::u32(u32::from(b'0'))), Expr::u32(52)),
                    Expr::select(
                        plus,
                        Expr::u32(62),
                        // Fallback for every byte outside the alphabet.
                        Expr::select(slash, Expr::u32(63), Expr::u32(0)),
                    ),
                ),
            ),
        )
    }

    /// First decoded byte of a quad: all 6 bits of `a6`, then the top 2 bits of `b6`.
    ///
    /// Reads IR variables bound by [`Base64Decode::program`]; only meaningful inside that
    /// program after those bindings.
    #[must_use]
    pub fn byte0() -> Expr {
        Expr::bitor(
            Expr::shl(Expr::var(A6), Expr::u32(2)),
            Expr::shr(Expr::var(B6), Expr::u32(4)),
        )
    }

    /// Second decoded byte of a quad: the low 4 bits of `b6`, then the top 4 bits of `c6`.
    ///
    /// Reads IR variables bound by [`Base64Decode::program`]; only meaningful inside that
    /// program after those bindings.
    #[must_use]
    pub fn byte1() -> Expr {
        Expr::bitor(
            Expr::shl(Expr::bitand(Expr::var(B6), Expr::u32(0x0f)), Expr::u32(4)),
            Expr::shr(Expr::var(C6), Expr::u32(2)),
        )
    }

    /// Third decoded byte of a quad: the low 2 bits of `c6`, then all 6 bits of `d6`.
    ///
    /// Reads IR variables bound by [`Base64Decode::program`]; only meaningful inside that
    /// program after those bindings.
    #[must_use]
    pub fn byte2() -> Expr {
        Expr::bitor(
            Expr::shl(Expr::bitand(Expr::var(C6), Expr::u32(0x03)), Expr::u32(6)),
            Expr::var(D6),
        )
    }
}

pub use kernel::Base64Decode;