vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
// Portable GPU kernel for string_similarity.simhash64.
//
// Bindings:
// 0 params: input_len, reserved0, reserved1, reserved2 (four u32 words; see struct Params)
// 1 input_words: little-endian packed bytes
// 2 output_words: two u32 words, low then high

// Uniform parameter block bound at @group(0) @binding(0).
// Four consecutive u32 members.
struct Params {
    // NOTE(review): presumably the input length in bytes, since byte_at
    // addresses input_words one byte at a time; confirm against the host code.
    input_len: u32,
    // reserved0..reserved2 are not referenced by any code in this file.
    reserved0: u32,
    reserved1: u32,
    reserved2: u32,
}

// A 64-bit unsigned value carried as two 32-bit halves.
struct U64Parts {
    // Bits 0..31 of the value.
    lo: u32,
    // Bits 32..63 of the value.
    hi: u32,
}

// 64-bit FNV offset basis 0xcbf29ce484222325, split into low/high words.
// Used as the initial hash state in fnv1a64.
const FNV_OFFSET_LO: u32 = 0x84222325u;
const FNV_OFFSET_HI: u32 = 0xcbf29ce4u;
// 64-bit FNV prime 0x00000100000001b3, split into low/high words.
// Used as the multiplier in fnv_mul.
const FNV_PRIME_LO: u32 = 0x000001b3u;
const FNV_PRIME_HI: u32 = 0x00000100u;
// Window (n-gram) length in bytes used by the entry point when the input has
// at least this many bytes.
const DEFAULT_GRAM: u32 = 4u;

// Binding 0: uniform parameter block.
@group(0) @binding(0) var<uniform> params: Params;
// Binding 1: read-only input; byte_at unpacks individual bytes from these
// words, lowest byte of each word first.
@group(0) @binding(1) var<storage, read> input_words: array<u32>;
// Binding 2: output; the entry point writes elements 0 and 1 only.
@group(0) @binding(2) var<storage, read_write> output_words: array<u32>;

// Returns the byte at byte offset `index` of input_words as a u32 in 0..255.
// Four bytes are packed per word, with the lowest-addressed byte in the
// least significant bits.
fn byte_at(index: u32) -> u32 {
    // Byte position inside the word (0..3), converted to a bit offset.
    let bit_offset = 8u * (index % 4u);
    return (input_words[index / 4u] >> bit_offset) & 255u;
}

// Returns the upper 32 bits of the full 64-bit product a * b.
// The operands are split into 16-bit halves so that every partial product
// fits in a u32 without losing bits.
fn mul_hi_u32(a: u32, b: u32) -> u32 {
    let a_lo = a & 0xffffu;
    let a_hi = a >> 16u;
    let b_lo = b & 0xffffu;
    let b_hi = b >> 16u;

    let lo_lo = a_lo * b_lo;
    let hi_lo = a_hi * b_lo;
    let lo_hi = a_lo * b_hi;
    let hi_hi = a_hi * b_hi;

    // Carry out of bits 16..31 of the product into bit 32 and above.
    let carry = ((lo_lo >> 16u) + (hi_lo & 0xffffu) + (lo_hi & 0xffffu)) >> 16u;
    return hi_hi + (hi_lo >> 16u) + (lo_hi >> 16u) + carry;
}

// Multiplies a 64-bit value by the 64-bit FNV prime, modulo 2^64.
// The hi * FNV_PRIME_HI term is omitted because it only affects bits >= 64.
fn fnv_mul(hash: U64Parts) -> U64Parts {
    // Bits 32..63 of lo * FNV_PRIME_LO.
    let carry = mul_hi_u32(hash.lo, FNV_PRIME_LO);
    // Cross terms; both land in the high word and wrap modulo 2^32.
    let cross = (hash.lo * FNV_PRIME_HI) + (hash.hi * FNV_PRIME_LO);
    return U64Parts(hash.lo * FNV_PRIME_LO, carry + cross);
}

// Hashes `len` input bytes starting at byte offset `start`.
// Starting from the FNV offset basis, each byte is XORed into the low word
// and the state is then multiplied by the FNV prime (FNV-1a order).
// With len == 0 the offset basis is returned unchanged.
fn fnv1a64(start: u32, len: u32) -> U64Parts {
    var state = U64Parts(FNV_OFFSET_LO, FNV_OFFSET_HI);
    for (var i = 0u; i < len; i = i + 1u) {
        let mixed = U64Parts(state.lo ^ byte_at(start + i), state.hi);
        state = fnv_mul(mixed);
    }
    return state;
}

// Returns bit number `bit` of the 64-bit value `hash` as 0u or 1u.
// Bits 0..31 come from hash.lo, bits 32 and up from hash.hi.
fn hash_bit(hash: U64Parts, bit: u32) -> u32 {
    let in_low_word = bit < 32u;
    let word = select(hash.hi, hash.lo, in_low_word);
    // select evaluates both operands; the wrapped `bit - 32u` is discarded
    // whenever bit < 32.
    let shift = select(bit - 32u, bit, in_low_word);
    return (word >> shift) & 1u;
}

// Entry point: computes a 64-bit SimHash fingerprint of the input bytes and
// stores it in output_words[0] (low word) and output_words[1] (high word).
//
// Only the invocation with id.x == 0 does any work; every other invocation
// returns immediately. An input length of zero yields the fingerprint 0.
//
// The input is split into overlapping DEFAULT_GRAM-byte windows, one starting
// at every byte offset. An input shorter than DEFAULT_GRAM is treated as a
// single window covering the whole input. Each window is hashed once with
// fnv1a64 and casts a +1 / -1 vote on each of the 64 fingerprint bits.
// A bit is set when its vote total is >= 0, so a tie sets the bit.
@compute @workgroup_size(1, 1, 1)
fn string_similarity_simhash64(@builtin(global_invocation_id) id: vec3<u32>) {
    if (id.x != 0u) {
        return;
    }

    // Params declares the length as `input_len`; it has no `len_a` member.
    let input_len = params.input_len;
    if (input_len == 0u) {
        output_words[0] = 0u;
        output_words[1] = 0u;
        return;
    }

    let has_full_gram = input_len >= DEFAULT_GRAM;
    let gram_len = select(input_len, DEFAULT_GRAM, has_full_gram);
    // select evaluates both operands; the wrapped subtraction is discarded
    // whenever the input is shorter than DEFAULT_GRAM.
    let gram_count = select(1u, input_len - DEFAULT_GRAM + 1u, has_full_gram);

    // One vote counter per fingerprint bit. A function-scope variable with
    // no initializer starts at its zero value.
    var weights: array<i32, 64>;

    // Hash every window exactly once and spread its votes across all 64
    // counters, instead of re-hashing every window for each output bit.
    for (var gram = 0u; gram < gram_count; gram = gram + 1u) {
        let h = fnv1a64(gram, gram_len);
        for (var bit = 0u; bit < 64u; bit = bit + 1u) {
            weights[bit] = weights[bit] + select(-1i, 1i, hash_bit(h, bit) == 1u);
        }
    }

    // Collapse the vote totals into the two output words.
    var out = U64Parts(0u, 0u);
    for (var out_bit = 0u; out_bit < 64u; out_bit = out_bit + 1u) {
        if (weights[out_bit] >= 0i) {
            if (out_bit < 32u) {
                out.lo = out.lo | (1u << out_bit);
            } else {
                out.hi = out.hi | (1u << (out_bit - 32u));
            }
        }
    }
    output_words[0] = out.lo;
    output_words[1] = out.hi;
}