vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::AlgebraicLaw;
use crate::ops::{OpSpec, BYTES_TO_U32_OUTPUTS, U32_INPUTS};

// WGSL lowering marker for `match.dfa_scan`.
//
// No special per-op lowering is needed. The normal IR lowerer handles the
// DFA scan composition.

/// DFA scan operation.
#[derive(Debug, Clone, Copy, Default)]
pub struct DfaScan;

impl DfaScan {
    /// Declarative operation specification.
    pub const SPEC: OpSpec =
        OpSpec::composition(OP_ID, U32_INPUTS, BYTES_TO_U32_OUTPUTS, LAWS, Self::program);

    /// Build the canonical IR program.
    #[must_use]
    pub fn program() -> Program {
        let pos = Expr::var("pos");
        let state = Expr::var("state");
        let pattern_id = Expr::var("pattern_id");
        Program::new(
            vec![
                BufferDecl::read("input_bytes", 0, DataType::U32),
                BufferDecl::read("transitions", 1, DataType::U32),
                BufferDecl::read("accept_map", 2, DataType::U32),
                BufferDecl::output("matches", 3, DataType::U32),
                BufferDecl::read_write("match_count", 4, DataType::U32),
                BufferDecl::read("params", 5, DataType::U32),
                BufferDecl::read("pattern_lengths", 6, DataType::U32),
            ],
            WORKGROUP_SIZE,
            vec![
                Node::let_bind("worker", Expr::gid_x()),
                Node::if_then(
                    Expr::eq(Expr::var("worker"), Expr::u32(0)),
                    vec![
                        Node::let_bind("state", Expr::u32(0)),
                        Node::let_bind("pos", Expr::u32(0)),
                        Node::Loop {
                            var: "step".to_string(),
                            from: Expr::u32(0),
                            to: Expr::load("params", Expr::u32(0)),
                            body: vec![
                                Node::let_bind("byte_val", read_byte(pos.clone())),
                                Node::assign(
                                    "state",
                                    transition(state.clone(), Expr::var("byte_val")),
                                ),
                                Node::let_bind(
                                    "pattern_id",
                                    Expr::load("accept_map", state.clone()),
                                ),
                                Node::if_then(
                                    Expr::ne(pattern_id.clone(), Expr::u32(NO_MATCH)),
                                    vec![report_match(
                                        pattern_id,
                                        Expr::add(pos.clone(), Expr::u32(1)),
                                    )],
                                ),
                                Node::assign("pos", Expr::add(pos, Expr::u32(1))),
                            ],
                        },
                    ],
                ),
            ],
        )
        .with_entry_op_id(OP_ID)
    }
}

/// Algebraic laws declared for `match.dfa_scan`.
///
/// The single `Bounded` law spans the entire `u32` range (`0..=u32::MAX`).
/// Any `u32` value therefore satisfies it trivially. Assumption: `Bounded`
/// constrains the op's result values; confirm against `AlgebraicLaw`.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
    lo: 0,
    hi: u32::MAX,
}];

/// Sentinel value marking an absent pattern (or output-link) entry.
///
/// In `accept_map`, a state whose entry equals this value accepts no pattern.
pub const NO_MATCH: u32 = u32::MAX;

/// Stable identifier of the DFA scan operation.
///
/// Used as the op id in [`DfaScan::SPEC`] and as the entry op id of the
/// program built by [`DfaScan::program`].
pub const OP_ID: &str = "match.dfa_scan";

/// Build an expression that reads the byte at offset `pos` of `input_bytes`.
///
/// `input_bytes` is a `u32` buffer holding four bytes per word. Byte `pos`
/// lives in word `pos >> 2`, at bit offset `(pos & 3) * 8`, so the
/// lowest-addressed byte occupies the least-significant bits (little-endian
/// packing).
///
/// The resulting expression is masked to `0..=255`.
#[must_use]
pub fn read_byte(pos: Expr) -> Expr {
    // Four bytes per u32 word.
    let word = Expr::load("input_bytes", Expr::shr(pos.clone(), Expr::u32(2)));
    // Bit position of the requested byte inside that word.
    let shift = Expr::mul(Expr::bitand(pos, Expr::u32(3)), Expr::u32(8));
    Expr::bitand(Expr::shr(word, shift), Expr::u32(0xff))
}

/// Build a block that appends a `(pattern_id, start, end)` triple to `matches`.
///
/// A slot index is reserved with an atomic add of 1 on `match_count`. The
/// result is used as the slot index, so it is assumed to be the pre-increment
/// value, as with WGSL `atomicAdd`.
///
/// The triple is stored at words `3 * idx .. 3 * idx + 3`, but only when `idx`
/// is below the capacity in `params[2]`. The counter is incremented even when
/// the triple is dropped. It therefore ends up holding the total number of
/// matches found, which may exceed the capacity.
///
/// `end_pos` is stored as-is. `DfaScan::program` passes one past the last
/// matched byte. The start is computed as
/// `end_pos - pattern_lengths[pattern_id]`.
///
/// NOTE(review): the subtraction is unchecked. It presumably wraps if a
/// pattern length exceeds `end_pos`, so `pattern_lengths` must be consistent
/// with the automaton.
///
/// The returned block declares the IR locals `pattern_len` and `match_idx`.
#[must_use]
pub fn report_match(pattern_id: Expr, end_pos: Expr) -> Node {
    // Base word index of the reserved triple (three u32 words per match).
    let slot_base = || Expr::mul(Expr::var("match_idx"), Expr::u32(3));

    Node::Block(vec![
        Node::let_bind(
            "pattern_len",
            Expr::load("pattern_lengths", pattern_id.clone()),
        ),
        Node::let_bind(
            "match_idx",
            Expr::atomic_add("match_count", Expr::u32(0), Expr::u32(1)),
        ),
        // Drop the write, but keep counting, once the output buffer is full.
        Node::if_then(
            Expr::lt(Expr::var("match_idx"), Expr::load("params", Expr::u32(2))),
            vec![
                Node::store("matches", slot_base(), pattern_id),
                Node::store(
                    "matches",
                    Expr::add(slot_base(), Expr::u32(1)),
                    Expr::sub(end_pos.clone(), Expr::var("pattern_len")),
                ),
                Node::store("matches", Expr::add(slot_base(), Expr::u32(2)), end_pos),
            ],
        ),
    ])
}

/// Build an expression that reads the successor of `state` on input `byte`.
///
/// The `transitions` table is read with 256 entries per state. The entry for
/// `(state, byte)` is at index `state * 256 + byte`.
///
/// `byte` is expected to be in `0..=255`, as produced by [`read_byte`]. Larger
/// values would index into the next state's row.
#[must_use]
pub fn transition(state: Expr, byte: Expr) -> Expr {
    let row_start = Expr::mul(state, Expr::u32(256));
    Expr::load("transitions", Expr::add(row_start, byte))
}

/// Workgroup size used by the DFA scanner: one invocation per workgroup.
///
/// The scan itself is sequential and runs only on the invocation whose global
/// x id is 0.
pub const WORKGROUP_SIZE: [u32; 3] = [1; 3];