// Portable GPU kernel for string_matching.substring_find_first.
//
// Bindings:
// 0 params: haystack_len, needle_len, reserved0, reserved1 (u32 each)
// 1 haystack_words: haystack bytes, little-endian packed four per u32
// 2 needle_words: needle bytes, little-endian packed four per u32
// 3 result: atomic u32 holding the first match offset; the host initializes it to 0xffffffff
// Uniform parameters read from binding 0: four consecutive u32 values in this order.
// Lengths are byte counts (the kernel indexes the packed streams per byte).
// reserved0/reserved1 are not read by this kernel; presumably padding so the
// uniform block is 16 bytes — confirm against the host-side buffer layout.
struct Params {
haystack_len: u32, // number of haystack bytes in haystack_words
needle_len: u32, // number of needle bytes in needle_words
reserved0: u32,
reserved1: u32,
}
// Value of `result` when no match exists. It is not referenced by the code below;
// it documents the value the host must store in `result` before dispatch.
const NOT_FOUND: u32 = 0xffffffffu;
// Binding 0: lengths of the haystack and needle (see Params).
@group(0) @binding(0) var<uniform> params: Params;
// Binding 1: haystack bytes packed four per u32, little-endian
// (byte i is bits 8*(i%4)..8*(i%4)+7 of word i/4).
@group(0) @binding(1) var<storage, read> haystack_words: array<u32>;
// Binding 2: needle bytes, packed the same way as the haystack.
@group(0) @binding(2) var<storage, read> needle_words: array<u32>;
// Binding 3: smallest matching start offset, lowered with atomicMin by the entry
// point; stays at its initial value (NOT_FOUND) when nothing matches.
@group(0) @binding(3) var<storage, read_write> result: atomic<u32>;
// Extract byte `index` of a little-endian packed byte stream, given `word`,
// the u32 that contains it (element `index >> 2` of the stream).
//
// The word is passed by value instead of as a `ptr<storage, array<u32>, read>`:
// storage-address-space pointer parameters require the WGSL language extension
// `unrestricted_pointer_parameters`, which not every implementation exposes.
// Passing a value keeps the kernel portable.
fn packed_byte(word: u32, index: u32) -> u32 {
    // Byte lane within the word (0..3) converted to a bit shift (0, 8, 16, 24).
    let shift = (index & 3u) << 3u;
    return (word >> shift) & 0xffu;
}

// Byte `index` of the haystack (binding 1), as a value in 0..255.
fn haystack_byte(index: u32) -> u32 {
    return packed_byte(haystack_words[index >> 2u], index);
}

// Byte `index` of the needle (binding 2), as a value in 0..255.
fn needle_byte(index: u32) -> u32 {
    return packed_byte(needle_words[index >> 2u], index);
}
// Return true when the needle occurs in the haystack beginning at byte `start`.
//
// Compares bytes left to right and returns false at the first mismatch.
// The caller must ensure start + params.needle_len <= params.haystack_len so
// every haystack byte read lies inside the haystack. An empty needle
// (needle_len == 0) trivially matches.
fn candidate_matches(start: u32) -> bool {
    let needle_len = params.needle_len;
    for (var offset = 0u; offset < needle_len; offset = offset + 1u) {
        if (haystack_byte(start + offset) != needle_byte(offset)) {
            return false;
        }
    }
    return true;
}
// Entry point: invocation `id.x` tests whether the needle starts at haystack byte `id.x`.
//
// The smallest matching start offset is written into `result` with atomicMin, so
// `result` must hold 0xffffffff before dispatch. It keeps that value when there is
// no match. An empty needle yields offset 0.
// Only id.x is used. The host is assumed to dispatch at least
// haystack_len - needle_len + 1 invocations along x; confirm against the host code.
@compute @workgroup_size(256, 1, 1)
fn string_matching_substring_find_first(@builtin(global_invocation_id) id: vec3<u32>) {
    let haystack_len = params.haystack_len;
    let needle_len = params.needle_len;

    // An empty needle matches at offset 0.
    if (needle_len == 0u) {
        atomicMin(&result, 0u);
        return;
    }
    // A needle longer than the haystack cannot match; this check also keeps
    // `haystack_len - needle_len` below from wrapping around.
    if (needle_len > haystack_len) {
        return;
    }

    let start = id.x;
    // The last start offset at which the whole needle still fits.
    if (start > haystack_len - needle_len) {
        return;
    }
    // Early out if an equal or earlier match has already been recorded. This read
    // is only an optimization; the atomicMin below guarantees the minimum.
    if (start >= atomicLoad(&result)) {
        return;
    }
    if (candidate_matches(start)) {
        atomicMin(&result, start);
    }
}