vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::AlgebraicLaw;
use crate::ops::{OpSpec, U32_INPUTS, U32_OUTPUTS};

// WGSL lowering marker for `match.scatter`.
//
// No op-specific lowering is needed: the generic IR lowerer handles the
// match-result scatter composition built by `Scatter::program`.

impl Scatter {
    /// Declarative operation specification.
    pub const SCATTER: OpSpec = OpSpec::composition(
        "match.scatter",
        U32_INPUTS,
        U32_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Declarative operation specification.
    pub const SPEC: OpSpec = Self::SCATTER;

    /// Build the canonical IR program.
    #[must_use]
    pub fn program() -> Program {
        let match_index = Expr::var("match_index");
        let match_count = Expr::var("match_count");
        let row_base = Expr::var("row_base");
        let pattern_id = Expr::var("pattern_id");
        let ptrn_len = Expr::var("ptrn_len");
        let map_start = Expr::var("map_start");
        let map_count = Expr::var("map_count");
        let max_strings = Expr::var("max_strings");
        let max_cached_positions = Expr::var("max_cached_positions");
        let offset = Expr::var("offset");
        let rule_id = Expr::var("rule_id");
        let string_id = Expr::var("string_id");
        let word_idx = Expr::var("word_idx");
        let bit_mask = Expr::var("bit_mask");
        let bm_idx = Expr::var("bm_idx");
        let count_slot = Expr::var("count_slot");
        let new_count = Expr::var("new_count");
        let position_slot = Expr::var("position_slot");
        let row_start = Expr::var("row_start");
        let row_end = Expr::var("row_end");
        let entropy_bucket = Expr::var("entropy_bucket");

        Program::new(
            vec![
                BufferDecl::read("matches_buf", 0, DataType::U32),
                BufferDecl::read("pattern_to_rules", 1, DataType::U32),
                BufferDecl::read("rule_list", 2, DataType::U32),
                BufferDecl::read("string_local_ids", 3, DataType::U32),
                BufferDecl::output("rule_bitmaps", 4, DataType::U32),
                BufferDecl::read_write("rule_counts", 5, DataType::U32),
                BufferDecl::read_write("rule_positions", 6, DataType::U32),
                BufferDecl::uniform("params", 7, DataType::U32),
                BufferDecl::read_write("rule_match_aux", 8, DataType::U32),
                BufferDecl::read("prefix_brace_depths", 9, DataType::U32),
            ],
            WORKGROUP_SIZE,
            vec![
                Node::let_bind("match_index", Expr::gid_x()),
                Node::let_bind("match_count", Expr::load("params", Expr::u32(0))),
                Node::if_then(
                    Expr::lt(match_index.clone(), match_count.clone()),
                    vec![
                        Node::let_bind(
                            "row_base",
                            Expr::mul(match_index, Expr::u32(4)),
                        ),
                        Node::let_bind(
                            "pattern_id",
                            Expr::load("matches_buf", row_base.clone()),
                        ),
                        Node::let_bind(
                            "row_start",
                            Expr::load(
                                "matches_buf",
                                Expr::add(row_base.clone(), Expr::u32(1)),
                            ),
                        ),
                        Node::let_bind(
                            "row_end",
                            Expr::load(
                                "matches_buf",
                                Expr::add(row_base.clone(), Expr::u32(2)),
                            ),
                        ),
                        Node::let_bind(
                            "entropy_bucket",
                            Expr::load(
                                "matches_buf",
                                Expr::add(row_base, Expr::u32(3)),
                            ),
                        ),
                        Node::let_bind("ptrn_len", Expr::buf_len("pattern_to_rules")),
                        Node::if_then(
                            Expr::lt(
                                Expr::add(
                                    Expr::mul(pattern_id.clone(), Expr::u32(2)),
                                    Expr::u32(1),
                                ),
                                ptrn_len.clone(),
                            ),
                            vec![
                                Node::let_bind(
                                    "map_start",
                                    Expr::load(
                                        "pattern_to_rules",
                                        Expr::mul(pattern_id.clone(), Expr::u32(2)),
                                    ),
                                ),
                                Node::let_bind(
                                    "map_count",
                                    Expr::load(
                                        "pattern_to_rules",
                                        Expr::add(
                                            Expr::mul(pattern_id, Expr::u32(2)),
                                            Expr::u32(1),
                                        ),
                                    ),
                                ),
                                Node::let_bind(
                                    "max_strings",
                                    Expr::load("params", Expr::u32(1)),
                                ),
                                Node::let_bind(
                                    "max_cached_positions",
                                    Expr::load("params", Expr::u32(2)),
                                ),
                                Node::loop_for(
                                    "i",
                                    Expr::u32(0),
                                    map_count.clone(),
                                    vec![
                                        Node::let_bind(
                                            "offset",
                                            Expr::add(map_start.clone(), Expr::var("i")),
                                        ),
                                        Node::if_then(
                                            Expr::and(
                                                Expr::lt(
                                                    offset.clone(),
                                                    Expr::buf_len("rule_list"),
                                                ),
                                                Expr::lt(
                                                    offset.clone(),
                                                    Expr::buf_len("string_local_ids"),
                                                ),
                                            ),
                                            vec![
                                                Node::let_bind(
                                                    "rule_id",
                                                    Expr::load("rule_list", offset.clone()),
                                                ),
                                                Node::let_bind(
                                                    "string_id",
                                                    Expr::load(
                                                        "string_local_ids",
                                                        offset,
                                                    ),
                                                ),
                                                // CRITICAL: string_id < 256 (bitmap is 8 words)
                                                Node::if_then(
                                                    Expr::lt(
                                                        string_id.clone(),
                                                        Expr::u32(256),
                                                    ),
                                                    vec![
                                                        Node::let_bind(
                                                            "word_idx",
                                                            Expr::shr(
                                                                string_id.clone(),
                                                                Expr::u32(5),
                                                            ),
                                                        ),
                                                        Node::let_bind(
                                                            "bit_idx",
                                                            Expr::bitand(
                                                                string_id.clone(),
                                                                Expr::u32(31),
                                                            ),
                                                        ),
                                                        Node::let_bind(
                                                            "bit_mask",
                                                            Expr::shl(
                                                                Expr::u32(1),
                                                                Expr::var("bit_idx"),
                                                            ),
                                                        ),
                                                        Node::let_bind(
                                                            "bm_idx",
                                                            Expr::add(
                                                                Expr::mul(
                                                                    rule_id.clone(),
                                                                    Expr::u32(8),
                                                                ),
                                                                word_idx.clone(),
                                                            ),
                                                        ),
                                                        Node::if_then(
                                                            Expr::lt(
                                                                bm_idx.clone(),
                                                                Expr::buf_len(
                                                                    "rule_bitmaps",
                                                                ),
                                                            ),
                                                            vec![Node::let_bind(
                                                                "_old_bitmap",
                                                                Expr::atomic_or(
                                                                    "rule_bitmaps",
                                                                    bm_idx,
                                                                    bit_mask,
                                                                ),
                                                            )],
                                                        ),
                                                        Node::let_bind(
                                                            "count_slot",
                                                            Expr::add(
                                                                Expr::mul(
                                                                    rule_id.clone(),
                                                                    max_strings.clone(),
                                                                ),
                                                                string_id.clone(),
                                                            ),
                                                        ),
                                                        Node::if_then(
                                                            Expr::lt(
                                                                count_slot.clone(),
                                                                Expr::buf_len(
                                                                    "rule_counts",
                                                                ),
                                                            ),
                                                            vec![
                                                                Node::let_bind(
                                                                    "new_count",
                                                                    Expr::atomic_add(
                                                                        "rule_counts",
                                                                        count_slot,
                                                                        Expr::u32(1),
                                                                    ),
                                                                ),
                                                                Node::if_then(
                                                                    Expr::lt(
                                                                        new_count.clone(),
                                                                        max_cached_positions
                                                                            .clone(),
                                                                    ),
                                                                    vec![
                                                                        Node::let_bind(
                                                                            "position_slot",
                                                                            Expr::add(
                                                                                Expr::mul(
                                                                                    Expr::add(
                                                                                        Expr::mul(
                                                                                            rule_id,
                                                                                            max_strings,
                                                                                        ),
                                                                                        string_id,
                                                                                    ),
                                                                                    max_cached_positions,
                                                                                ),
                                                                                new_count.clone(),
                                                                            ),
                                                                        ),
                                                                        Node::if_then(
                                                                            Expr::and(
                                                                                Expr::lt(
                                                                                    position_slot.clone(),
                                                                                    Expr::buf_len("rule_positions"),
                                                                                ),
                                                                                Expr::lt(
                                                                                    Expr::add(
                                                                                        Expr::mul(
                                                                                            position_slot.clone(),
                                                                                            Expr::u32(2),
                                                                                        ),
                                                                                        Expr::u32(1),
                                                                                    ),
                                                                                    Expr::buf_len("rule_match_aux"),
                                                                                ),
                                                                            ),
                                                                            vec![
                                                                                Node::store(
                                                                                    "rule_positions",
                                                                                    position_slot.clone(),
                                                                                    row_start.clone(),
                                                                                ),
                                                                                Node::let_bind(
                                                                                    "depth",
                                                                                    Expr::select(
                                                                                        Expr::lt(
                                                                                            row_start.clone(),
                                                                                            Expr::buf_len("prefix_brace_depths"),
                                                                                        ),
                                                                                        Expr::load(
                                                                                            "prefix_brace_depths",
                                                                                            row_start.clone(),
                                                                                        ),
                                                                                        Expr::u32(0),
                                                                                    ),
                                                                                ),
                                                                                Node::let_bind(
                                                                                    "packed",
                                                                                    Expr::bitor(
                                                                                        Expr::shl(
                                                                                            entropy_bucket.clone(),
                                                                                            Expr::u32(24),
                                                                                        ),
                                                                                        Expr::bitand(
                                                                                            Expr::var("depth"),
                                                                                            Expr::u32(0x00FFFFFF),
                                                                                        ),
                                                                                    ),
                                                                                ),
                                                                                Node::store(
                                                                                    "rule_match_aux",
                                                                                    Expr::mul(
                                                                                        position_slot.clone(),
                                                                                        Expr::u32(2),
                                                                                    ),
                                                                                    Expr::sub(
                                                                                        row_end.clone(),
                                                                                        row_start,
                                                                                    ),
                                                                                ),
                                                                                Node::store(
                                                                                    "rule_match_aux",
                                                                                    Expr::add(
                                                                                        Expr::mul(
                                                                                            position_slot,
                                                                                            Expr::u32(2),
                                                                                        ),
                                                                                        Expr::u32(1),
                                                                                    ),
                                                                                    Expr::var("packed"),
                                                                                ),
                                                                            ],
                                                                        ),
                                                                    ],
                                                                ),
                                                            ],
                                                        ),
                                                    ],
                                                ),
                                            ],
                                        ),
                                    ],
                                ),
                            ],
                        ),
                    ],
                ),
            ],
        )
    }
}

/// Algebraic laws declared for `match.scatter`.
///
/// A single `Bounded` law whose range spans every `u32` value (`0..=u32::MAX`).
pub const LAWS: &[AlgebraicLaw] = &[
    AlgebraicLaw::Bounded { hi: u32::MAX, lo: 0 },
];

/// Match-to-rule scatter operation (`match.scatter`).
///
/// A zero-sized marker type. The operation is described by
/// [`Scatter::SPEC`] and its kernel is built by [`Scatter::program`].
/// All instances are interchangeable, so equality and hashing are trivial.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Scatter;

/// Module-level re-export of the `match.scatter` operation specification,
/// for registry use. Same value as `Scatter::SCATTER`, via its `SPEC` alias.
pub const SCATTER: OpSpec = Scatter::SPEC;

/// Workgroup dimensions `[x, y, z]` for the scatter kernel.
///
/// The workgroup is one-dimensional (256 invocations along x). The kernel
/// indexes matches by `gid_x` alone.
pub const WORKGROUP_SIZE: [u32; 3] = [256_u32, 1, 1];