/// Builds the WGSL source of the compute shader that scatters pattern matches
/// into per-rule result buffers.
///
/// `max_cached_positions` is baked into the shader text as a literal: it is the
/// number of position/length slots reserved for each (rule, string) pair.
/// Matches beyond that many for a pair are still counted in `rule_counts`, but
/// their position and length are not stored.
///
/// Bindings used by the generated shader (all in group 0):
///
/// | binding | name               | contents                                                        |
/// |---------|--------------------|-----------------------------------------------------------------|
/// | 0       | `matches_buf`      | match rows: `pattern_id`, `start`, `end`, padding               |
/// | 1       | `pattern_to_rules` | per pattern: `(start, count)` range into bindings 2 and 3        |
/// | 2       | `rule_list`        | rule id for each mapping entry                                  |
/// | 3       | `string_local_ids` | string id for each mapping entry (parallel to `rule_list`)       |
/// | 4       | `rule_bitmaps`     | 8 words (256 bits) per rule, one bit per matched string id       |
/// | 5       | `rule_counts`      | hit counter at `rule_id * max_strings + string_id`               |
/// | 6       | `rule_positions`   | cached match start offsets, `max_cached_positions` per counter   |
/// | 7       | `params`           | `x` = number of match rows, `y` = `max_strings`; `z`/`w` unread  |
/// | 8       | `rule_lengths`     | cached match lengths, same indexing as `rule_positions`          |
///
/// The shader walks every match serially in a single invocation. Every write is
/// guarded so that malformed input (an oversized match count, a string id
/// outside the bitmap or counter range, an inverted `start`/`end` pair, or
/// undersized output buffers) is skipped rather than landing in another rule's
/// slots.
pub fn build_scatter_shader(max_cached_positions: u32) -> String {
    format!(
        r#"
struct MatchRow {{
    pattern_id: u32,
    start: u32,
    end: u32,
    _pad: u32,
}};

@group(0) @binding(0) var<storage, read> matches_buf: array<MatchRow>;
@group(0) @binding(1) var<storage, read> pattern_to_rules: array<vec2<u32>>;
@group(0) @binding(2) var<storage, read> rule_list: array<u32>;
@group(0) @binding(3) var<storage, read> string_local_ids: array<u32>;
@group(0) @binding(4) var<storage, read_write> rule_bitmaps: array<atomic<u32>>;
@group(0) @binding(5) var<storage, read_write> rule_counts: array<atomic<u32>>;
@group(0) @binding(6) var<storage, read_write> rule_positions: array<atomic<u32>>;
@group(0) @binding(7) var<uniform> params: vec4<u32>;
@group(0) @binding(8) var<storage, read_write> rule_lengths: array<atomic<u32>>;

// Each rule owns 8 consecutive bitmap words, i.e. 256 string bits.
fn bitmap_index(rule_id: u32, word_idx: u32) -> u32 {{
    return rule_id * 8u + word_idx;
}}

// One hit counter per (rule, string) pair, laid out rule-major.
fn count_index(rule_id: u32, string_id: u32, max_strings: u32) -> u32 {{
    return rule_id * max_strings + string_id;
}}

// Each counter owns a fixed-size run of cached position/length slots.
fn pos_index(rule_id: u32, string_id: u32, slot: u32, max_strings: u32) -> u32 {{
    return ((rule_id * max_strings + string_id) * {max_cached_positions}u) + slot;
}}

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    // The loop below covers every match by itself, so any invocation other
    // than the first would only repeat the work and inflate the counters.
    if (gid.x != 0u || gid.y != 0u || gid.z != 0u) {{
        return;
    }}

    // Never trust the uniform count beyond what the bound buffer can hold.
    let match_count = min(params.x, arrayLength(&matches_buf));
    let max_strings = params.y;

    for (var match_index = 0u; match_index < match_count; match_index = match_index + 1u) {{
        let row = matches_buf[match_index];
        if (row.pattern_id >= arrayLength(&pattern_to_rules)) {{
            continue;
        }}
        let mapping = pattern_to_rules[row.pattern_id];
        let map_start = mapping.x;
        let map_count = mapping.y;
        // An inverted span would wrap to a huge u32; record it as empty instead.
        let match_len = select(0u, row.end - row.start, row.end >= row.start);

        for (var i = 0u; i < map_count; i = i + 1u) {{
            let offset = map_start + i;
            if (offset >= arrayLength(&rule_list) || offset >= arrayLength(&string_local_ids)) {{
                continue;
            }}
            let rule_id = rule_list[offset];
            let string_id = string_local_ids[offset];

            // A word index of 8 or more would land in the next rule's bitmap.
            let word_idx = string_id / 32u;
            let bit_idx = string_id % 32u;
            let bitmap_slot = bitmap_index(rule_id, word_idx);
            if (word_idx < 8u && bitmap_slot < arrayLength(&rule_bitmaps)) {{
                atomicOr(&rule_bitmaps[bitmap_slot], 1u << bit_idx);
            }}

            // A string id at or past max_strings would alias another rule's counter.
            if (string_id >= max_strings) {{
                continue;
            }}
            let count_slot = count_index(rule_id, string_id, max_strings);
            if (count_slot >= arrayLength(&rule_counts)) {{
                continue;
            }}

            // The pre-increment count doubles as the cache slot for this hit.
            let previous = atomicAdd(&rule_counts[count_slot], 1u);
            if (previous < {max_cached_positions}u) {{
                let position_slot = pos_index(rule_id, string_id, previous, max_strings);
                if (position_slot < arrayLength(&rule_positions) && position_slot < arrayLength(&rule_lengths)) {{
                    atomicStore(&rule_positions[position_slot], row.start);
                    atomicStore(&rule_lengths[position_slot], match_len);
                }}
            }}
        }}
    }}
}}
"#
    )
}