vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
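// Portable GPU kernel for string_similarity.ngram_histogram.
//
// Emits each distinct n-gram of the input once, in ascending (unsigned,
// lexicographic) byte order, together with its number of occurrences.
//
// Bindings:
// 0 params: len_a   = input length in bytes
//           len_b   = n-gram length in bytes
//           param_c = maximum number of records to emit
//           param_d = output stride: u32 words reserved per record in output_grams
// 1 input_words: input bytes, packed little-endian into u32 words
// 2 output_grams: each distinct n-gram copied into its own param_d-word record
// 3 output_counts: u32 occurrence count per emitted n-gram
// 4 record_count: u32 number of records emitted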

// Binding 0 (`params`, read below as params.len_a / len_b / param_c / param_d)
// is not declared in this source.
// NOTE(review): presumably supplied by a prepended header or generated prelude — confirm.

// Input bytes packed into u32 words; read one byte at a time through byte_at().
@group(0) @binding(1) var<storage, read> input_words: array<u32>;
// Emitted n-grams: record r starts at word r * params.param_d.
@group(0) @binding(2) var<storage, read_write> output_grams: array<u32>;
// Occurrence count of each emitted n-gram, indexed by record number.
@group(0) @binding(3) var<storage, read_write> output_counts: array<u32>;
// Element 0 receives the number of records emitted (reset to 0 at kernel start).
@group(0) @binding(4) var<storage, read_write> record_count: array<u32>;

// Reads byte number `index` of the input, where input_words holds the bytes
// packed little-endian (byte 0 of each word in bits 0..7). Result is 0..255.
fn byte_at(index: u32) -> u32 {
    let bit_offset = 8u * (index % 4u);
    return extractBits(input_words[index / 4u], bit_offset, 8u);
}

// Compares the params.len_b-byte n-grams that start at input byte positions
// `left` and `right`, byte by byte as unsigned values.
// Returns -1 if `left` sorts first, 1 if `right` sorts first, 0 if they are equal.
fn compare_grams(left: u32, right: u32) -> i32 {
    for (var i = 0u; i < params.len_b; i += 1u) {
        let lhs = byte_at(left + i);
        let rhs = byte_at(right + i);
        if (lhs != rhs) {
            // First differing byte decides the order.
            return select(1i, -1i, lhs < rhs);
        }
    }
    return 0i;
}

// Writes the low 8 bits of `byte` into output_grams at byte position `offset`
// counted from word `base_word`, little-endian within each u32 word. The other
// three bytes of the target word are left as they were.
fn store_byte(base_word: u32, offset: u32, byte: u32) {
    let slot = base_word + offset / 4u;
    let bit_offset = 8u * (offset % 4u);
    output_grams[slot] = insertBits(output_grams[slot], byte, bit_offset, 8u);
}

// Copies the params.len_b-byte n-gram that starts at input byte `pos` into
// record `record` of output_grams.
//
// The record occupies params.param_d u32 words starting at word
// record * params.param_d. Bytes are packed little-endian, so gram byte k lands
// in bits 8*(k % 4).. of record word k / 4.
//
// Every word of the record is written whole:
//   - padding bytes after the gram are zeroed instead of keeping stale buffer
//     contents;
//   - gram bytes that do not fit in the stride are dropped instead of spilling
//     into the following record.
// Writing whole words also avoids a read-modify-write per byte.
fn store_gram(record: u32, pos: u32) {
    let base = record * params.param_d;
    var word = 0u;
    loop {
        if (word >= params.param_d) {
            break;
        }
        var word_value = 0u;
        var lane = 0u;
        loop {
            if (lane >= 4u) {
                break;
            }
            let offset = (word << 2u) + lane;
            if (offset < params.len_b) {
                word_value = word_value | (byte_at(pos + offset) << (lane << 3u));
            }
            lane = lane + 1u;
        }
        output_grams[base + word] = word_value;
        word = word + 1u;
    }
}

// Counts how many n-gram start positions in [0, gram_count) hold an n-gram
// byte-equal to the one starting at `pos`. This includes `pos` itself when
// pos < gram_count.
fn count_equal(pos: u32, gram_count: u32) -> u32 {
    var matches = 0u;
    for (var candidate = 0u; candidate < gram_count; candidate += 1u) {
        matches += select(0u, 1u, compare_grams(pos, candidate) == 0i);
    }
    return matches;
}

// Entry point. Builds a sorted histogram of the distinct params.len_b-byte
// n-grams found in the first params.len_a input bytes, using one invocation.
//
// Algorithm: each outer pass scans every n-gram start position and selects
// the smallest n-gram strictly greater than the one emitted by the previous
// pass. Records therefore come out distinct and in ascending byte order. The
// cost is about records * gram_count * len_b byte comparisons.
//
// Outputs:
//   - output_grams: record r receives its n-gram (via store_gram).
//   - output_counts[r]: the occurrence count of that n-gram.
//   - record_count[0]: the number of records, at most params.param_c.
@compute @workgroup_size(1, 1, 1)
fn string_similarity_ngram_histogram(@builtin(global_invocation_id) id: vec3<u32>) {
    // The algorithm is sequential, so exactly one invocation may run it.
    // Checking all three components, not just x, stops extra invocations from
    // racing on the outputs when the dispatch is larger than one in y or z.
    if (any(id != vec3<u32>(0u))) {
        return;
    }
    record_count[0] = 0u;

    // Emit nothing when:
    //   - the n-gram is empty;
    //   - the n-gram is longer than the input;
    //   - the stride cannot hold one n-gram, i.e. ceil(len_b / 4) > param_d,
    //     which would make consecutive records overwrite each other.
    // The stride test is written as ((len_b - 1) >> 2) >= param_d to avoid
    // overflow. len_b != 0 by then because `||` short-circuits, and this test
    // also covers param_d == 0.
    if (params.len_b == 0u || params.len_b > params.len_a || params.param_d == 0u
        || ((params.len_b - 1u) >> 2u) >= params.param_d) {
        return;
    }

    // Sentinel meaning "no position chosen yet".
    let unset = 0xffffffffu;
    let gram_count = params.len_a - params.len_b + 1u;
    var previous = unset;
    var emitted = 0u;
    loop {
        if (emitted >= params.param_c) {
            break;
        }
        // Find the smallest n-gram strictly after the previously emitted one.
        var best = unset;
        var cursor = 0u;
        loop {
            if (cursor >= gram_count) {
                break;
            }
            let after_previous = previous == unset || compare_grams(cursor, previous) > 0i;
            let before_best = best == unset || compare_grams(cursor, best) < 0i;
            if (after_previous && before_best) {
                best = cursor;
            }
            cursor = cursor + 1u;
        }
        if (best == unset) {
            // Every distinct n-gram has been emitted.
            break;
        }
        store_gram(emitted, best);
        output_counts[emitted] = count_equal(best, gram_count);
        previous = best;
        emitted = emitted + 1u;
    }
    record_count[0] = emitted;
}