vyre 0.1.0

GPU bytecode condition engine
Documentation
/// Generate the pass-2 scatter shader.
///
/// Returns WGSL source for a compute entry point `main` that walks the first
/// `params.x` rows of `matches_buf`. The shader ignores `gid` and loops over
/// every row itself, so the host presumably dispatches a single invocation.
/// Each match row's `pattern_id` indexes `pattern_to_rules`, whose
/// `(start, count)` pair selects entries of `rule_list` / `string_local_ids`.
/// For every `(rule_id, string_id)` pair selected this way, the shader:
///
/// * sets bit `string_id` in the rule's bitmap (8 `u32` words per rule);
/// * increments the `(rule_id, string_id)` counter in `rule_counts`, strided by
///   `params.y` (`max_strings`) per rule;
/// * for the first `max_cached_positions` hits of that pair, stores the match
///   `start` in `rule_positions` and `end - start` in `rule_lengths`.
///
/// Entries whose `string_id` does not fit the bitmap, or is not below
/// `max_strings`, or whose computed slot lies outside its buffer, are skipped
/// for the affected buffer. A malformed entry therefore cannot overwrite
/// another rule's data.
///
/// `max_cached_positions` is baked into the shader source as a `u32` literal.
pub fn build_scatter_shader(max_cached_positions: u32) -> String {
    // Number of u32 words in each rule's bitmap. This caps the string ids that
    // can be flagged per rule at BITMAP_WORDS_PER_RULE * 32.
    const BITMAP_WORDS_PER_RULE: u32 = 8;
    format!(
        r#"
struct MatchRow {{
    pattern_id: u32,
    start: u32,
    end: u32,
    _pad: u32,
}};

@group(0) @binding(0) var<storage, read> matches_buf: array<MatchRow>;
@group(0) @binding(1) var<storage, read> pattern_to_rules: array<vec2<u32>>;
@group(0) @binding(2) var<storage, read> rule_list: array<u32>;
@group(0) @binding(3) var<storage, read> string_local_ids: array<u32>;
@group(0) @binding(4) var<storage, read_write> rule_bitmaps: array<atomic<u32>>;
@group(0) @binding(5) var<storage, read_write> rule_counts: array<atomic<u32>>;
@group(0) @binding(6) var<storage, read_write> rule_positions: array<atomic<u32>>;
@group(0) @binding(7) var<uniform> params: vec4<u32>;
@group(0) @binding(8) var<storage, read_write> rule_lengths: array<atomic<u32>>;

fn bitmap_index(rule_id: u32, word_idx: u32) -> u32 {{
    return rule_id * {bitmap_words}u + word_idx;
}}

fn count_index(rule_id: u32, string_id: u32, max_strings: u32) -> u32 {{
    return rule_id * max_strings + string_id;
}}

fn pos_index(rule_id: u32, string_id: u32, slot: u32, max_strings: u32) -> u32 {{
    return ((rule_id * max_strings + string_id) * {max_cached_positions}u) + slot;
}}

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    // Never read past the end of matches_buf, even if params.x is stale.
    let match_count = min(params.x, arrayLength(&matches_buf));
    let max_strings = params.y;
    for (var match_index = 0u; match_index < match_count; match_index = match_index + 1u) {{
        let row = matches_buf[match_index];
        if (row.pattern_id >= arrayLength(&pattern_to_rules)) {{
            continue;
        }}
        let mapping = pattern_to_rules[row.pattern_id];
        let map_start = mapping.x;
        let map_count = mapping.y;
        for (var i = 0u; i < map_count; i = i + 1u) {{
            let offset = map_start + i;
            if (offset >= arrayLength(&rule_list) || offset >= arrayLength(&string_local_ids)) {{
                continue;
            }}
            let rule_id = rule_list[offset];
            let string_id = string_local_ids[offset];

            // A word index past the per-rule bitmap width would set a bit in
            // the next rule's bitmap, so only flag ids that fit.
            let word_idx = string_id / 32u;
            let bit_idx = string_id % 32u;
            if (word_idx < {bitmap_words}u) {{
                let bitmap_slot = bitmap_index(rule_id, word_idx);
                if (bitmap_slot < arrayLength(&rule_bitmaps)) {{
                    atomicOr(&rule_bitmaps[bitmap_slot], 1u << bit_idx);
                }}
            }}

            // Counts and positions are strided by max_strings per rule; an id
            // at or past that stride would alias the next rule's slots.
            if (string_id >= max_strings) {{
                continue;
            }}
            let count_slot = count_index(rule_id, string_id, max_strings);
            if (count_slot >= arrayLength(&rule_counts)) {{
                continue;
            }}
            let previous = atomicAdd(&rule_counts[count_slot], 1u);
            if (previous < {max_cached_positions}u) {{
                let position_slot = pos_index(rule_id, string_id, previous, max_strings);
                if (position_slot < arrayLength(&rule_positions) && position_slot < arrayLength(&rule_lengths)) {{
                    atomicStore(&rule_positions[position_slot], row.start);
                    atomicStore(&rule_lengths[position_slot], row.end - row.start);
                }}
            }}
        }}
    }}
}}
"#,
        bitmap_words = BITMAP_WORDS_PER_RULE,
    )
}