vyre 0.2.0

GPU bytecode condition engine
Documentation
/// Generate the pass-2 scatter shader.
///
/// Returns WGSL source for a compute entry point `main` in which each
/// invocation handles one row of `matches_buf` (`params.x` is the number of
/// valid rows). For every `(rule_id, string_id)` pair that the row's pattern
/// maps to, the shader:
///
/// * sets bit `string_id` in the rule's bitmap (8 `u32` words per rule),
/// * increments the hit counter at `rule_id * params.y + string_id`,
/// * for the first `max_cached_positions` hits of that counter, stores the
///   match start in `rule_positions` and the match length in `rule_lengths`.
///
/// `max_cached_positions` is baked into the shader text as a `u32` literal,
/// so a different cache depth requires generating a new shader.
///
/// The generated code skips, rather than writes, whenever an id read from the
/// mapping tables falls outside the per-rule layout or a computed slot lies
/// beyond the bound buffer, so malformed tables cannot alias another rule's
/// slots. A row whose `end` precedes its `start` is recorded with length 0
/// instead of a wrapped-around `u32`.
#[must_use]
pub fn build_scatter_shader(max_cached_positions: u32) -> String {
    format!(
        r#"
struct MatchRow {{
    pattern_id: u32,
    start: u32,
    end: u32,
    _pad: u32,
}};

@group(0) @binding(0) var<storage, read> matches_buf: array<MatchRow>;
@group(0) @binding(1) var<storage, read> pattern_to_rules: array<vec2<u32>>;
@group(0) @binding(2) var<storage, read> rule_list: array<u32>;
@group(0) @binding(3) var<storage, read> string_local_ids: array<u32>;
@group(0) @binding(4) var<storage, read_write> rule_bitmaps: array<atomic<u32>>;
@group(0) @binding(5) var<storage, read_write> rule_counts: array<atomic<u32>>;
@group(0) @binding(6) var<storage, read_write> rule_positions: array<atomic<u32>>;
@group(0) @binding(7) var<uniform> params: vec4<u32>;
@group(0) @binding(8) var<storage, read_write> rule_lengths: array<atomic<u32>>;

fn bitmap_index(rule_id: u32, word_idx: u32) -> u32 {{
    return rule_id * 8u + word_idx;
}}

fn count_index(rule_id: u32, string_id: u32, max_strings: u32) -> u32 {{
    return rule_id * max_strings + string_id;
}}

fn pos_index(rule_id: u32, string_id: u32, slot: u32, max_strings: u32) -> u32 {{
    return ((rule_id * max_strings + string_id) * {max_cached_positions}u) + slot;
}}

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let match_count = params.x;
    let match_index = gid.x;
    if (match_index >= match_count) {{
        return;
    }}
    let row = matches_buf[match_index];
    if (row.pattern_id >= arrayLength(&pattern_to_rules)) {{
        return;
    }}
    let mapping = pattern_to_rules[row.pattern_id];
    let map_start = mapping.x;
    let map_count = mapping.y;
    let max_strings = params.y;
    // Inverted rows would wrap around in u32 arithmetic; record them as empty.
    let match_length = max(row.end, row.start) - row.start;
    for (var i = 0u; i < map_count; i = i + 1u) {{
        let offset = map_start + i;
        if (offset >= arrayLength(&rule_list) || offset >= arrayLength(&string_local_ids)) {{
            continue;
        }}
        let rule_id = rule_list[offset];
        let string_id = string_local_ids[offset];

        // A string id at or past the per-rule stride would alias the next rule's slots.
        if (string_id >= max_strings) {{
            continue;
        }}

        let word_idx = string_id / 32u;
        let bit_idx = string_id % 32u;
        let bitmap_slot = bitmap_index(rule_id, word_idx);
        // Each rule owns exactly 8 bitmap words; never spill into a neighbour or past the buffer.
        if (word_idx < 8u && bitmap_slot < arrayLength(&rule_bitmaps)) {{
            atomicOr(&rule_bitmaps[bitmap_slot], 1u << bit_idx);
        }}

        let count_slot = count_index(rule_id, string_id, max_strings);
        if (count_slot >= arrayLength(&rule_counts)) {{
            continue;
        }}
        let previous = atomicAdd(&rule_counts[count_slot], 1u);
        if (previous < {max_cached_positions}u) {{
            let position_slot = pos_index(rule_id, string_id, previous, max_strings);
            if (position_slot < arrayLength(&rule_positions) && position_slot < arrayLength(&rule_lengths)) {{
                atomicStore(&rule_positions[position_slot], row.start);
                atomicStore(&rule_lengths[position_slot], match_length);
            }}
        }}
    }}
}}
"#
    )
}