// Portable GPU kernel for string_similarity.ngram_histogram.
//
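// Bindings:
// 0 params: len_a (input length in bytes), len_b (n-gram length in bytes),
//   param_c (maximum number of records to emit), param_d (output stride per record, in u32 words)
// 1 input_words: input bytes packed little-endian, four bytes per u32 word
// 2 output_grams: each unique n-gram copied into its own slot of output_stride_words (param_d) packed u32 words
// 3 output_counts: u32 occurrence count per emitted n-gram
// 4 record_count: element 0 holds the u32 number of records emitted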
// NOTE(review): `params` (binding 0 per the header) is used by the functions below but is
// not declared in this part of the file — confirm where it is declared.

// Read-only input bytes, four per u32 word; read through byte_at().
@group(0) @binding(1) var<storage, read> input_words: array<u32>;
// Packed n-gram bytes for each emitted record; record r starts at word r * params.param_d.
@group(0) @binding(2) var<storage, read_write> output_grams: array<u32>;
// output_counts[r] receives the occurrence count of the n-gram stored in record r.
@group(0) @binding(3) var<storage, read_write> output_counts: array<u32>;
// Only element 0 is written: the number of records emitted by the kernel.
@group(0) @binding(4) var<storage, read_write> record_count: array<u32>;
// Returns the input byte at byte offset `index` (0..255).
// Bytes are packed four per u32 in `input_words`, lowest-order byte first.
fn byte_at(index: u32) -> u32 {
    let lane_bits = (index % 4u) * 8u;
    return (input_words[index / 4u] >> lane_bits) & 255u;
}
// Three-way, byte-wise lexicographic comparison of two n-grams of params.len_b bytes.
//   left, right: byte offsets in the input at which each n-gram starts.
// Returns -1 if the left gram sorts first, 1 if the right gram sorts first, 0 if equal.
fn compare_grams(left: u32, right: u32) -> i32 {
    for (var i = 0u; i < params.len_b; i = i + 1u) {
        let lhs = byte_at(left + i);
        let rhs = byte_at(right + i);
        if (lhs != rhs) {
            // The first differing byte decides the ordering.
            return select(1i, -1i, lhs < rhs);
        }
    }
    return 0i;
}
// Writes the low 8 bits of `byte` into the packed output at byte `offset`, counted from
// word `base_word` of `output_grams`. The other three bytes of the target word are kept
// (read-modify-write).
fn store_byte(base_word: u32, offset: u32, byte: u32) {
    let slot = base_word + offset / 4u;
    let lane_bits = (offset % 4u) * 8u;
    let kept = output_grams[slot] & ~(0xffu << lane_bits);
    output_grams[slot] = kept | ((byte & 0xffu) << lane_bits);
}
// Copies the params.len_b-byte n-gram that starts at input byte `pos` into output record
// `record`, whose slot begins at word `record * params.param_d` of `output_grams`.
//
// Each output word is assembled in a register and stored whole, so the padding bytes in
// the last word of the gram (when len_b is not a multiple of 4) are always zero instead
// of keeping whatever the buffer held before. This makes the packed output deterministic
// when the buffer is reused and avoids a storage read per byte.
// Words of the slot beyond ceil(len_b / 4) are not touched.
fn store_gram(record: u32, pos: u32) {
    let gram_len = params.len_b;
    if (gram_len == 0u) {
        return;
    }
    let base = record * params.param_d;
    // ceil(gram_len / 4), written so it cannot overflow u32.
    let word_count = ((gram_len - 1u) >> 2u) + 1u;
    for (var word = 0u; word < word_count; word = word + 1u) {
        let first = word << 2u;
        var value = 0u;
        for (var lane = 0u; lane < 4u; lane = lane + 1u) {
            if (first + lane >= gram_len) {
                break;
            }
            // byte_at() already masks to 8 bits; place the byte in its little-endian lane.
            value = value | (byte_at(pos + first + lane) << (lane << 3u));
        }
        output_grams[base + word] = value;
    }
}
// Counts how many of the start positions 0..gram_count-1 hold an n-gram that is
// byte-for-byte equal to the one starting at `pos` (the gram at `pos` itself is included
// when pos < gram_count).
fn count_equal(pos: u32, gram_count: u32) -> u32 {
    var matches = 0u;
    for (var start = 0u; start < gram_count; start = start + 1u) {
        if (compare_grams(pos, start) == 0i) {
            matches = matches + 1u;
        }
    }
    return matches;
}
// Entry point: emits the distinct n-grams of the input in ascending byte-wise order,
// each with its occurrence count, up to params.param_c records.
//
// Only invocation x == 0 does any work. Each round scans every start position for the
// smallest n-gram strictly greater than the one emitted in the previous round, stores it
// and its count, and repeats until no such gram is left or the record limit is hit.
// record_count[0] is set to 0 up front and to the number of emitted records at the end.
@compute @workgroup_size(1, 1, 1)
fn string_similarity_ngram_histogram(@builtin(global_invocation_id) id: vec3<u32>) {
    if (id.x != 0u) {
        return;
    }
    record_count[0] = 0u;

    // Reject configurations with no n-grams or no room per record.
    if (params.len_b == 0u) {
        return;
    }
    if (params.len_b > params.len_a) {
        return;
    }
    if (params.param_d == 0u) {
        return;
    }

    // Sentinel meaning "no position selected yet"; never a real start position because
    // start positions are always below gram_count.
    let no_position = 0xffffffffu;
    let gram_count = params.len_a - params.len_b + 1u;

    var previous = no_position;
    var emitted = 0u;
    while (emitted < params.param_c) {
        var best = no_position;
        for (var candidate = 0u; candidate < gram_count; candidate = candidate + 1u) {
            // Skip grams that do not sort strictly after the previously emitted one.
            if (previous != no_position && compare_grams(candidate, previous) <= 0i) {
                continue;
            }
            // Skip grams that do not sort strictly before the current best, so the
            // lowest start position wins among equal grams.
            if (best != no_position && compare_grams(candidate, best) >= 0i) {
                continue;
            }
            best = candidate;
        }
        if (best == no_position) {
            // Every distinct gram has been emitted.
            break;
        }
        store_gram(emitted, best);
        output_counts[emitted] = count_equal(best, gram_count);
        previous = best;
        emitted = emitted + 1u;
    }
    record_count[0] = emitted;
}