// Portable GPU kernel for string_similarity.simhash64.
//
// Bindings:
// 0 params: len_a, len_b, param_c, param_d (only len_a, the input length in bytes, is read)
// 1 input_words: input bytes packed little-endian into u32 words
// 2 output_words: 64-bit SimHash as two u32 words, low word then high word
// Uniform parameter block bound at slot 0.
// Field names match the binding layout documented in the file header and the
// `params.len_a` access in the entry point. Only len_a (input length in bytes)
// is read by this kernel; the other three u32 fields are unused here but keep
// the 16-byte, four-u32 layout unchanged for the host.
struct Params {
    len_a: u32,
    len_b: u32,
    param_c: u32,
    param_d: u32,
}
// Unsigned 64-bit value carried as two u32 halves (WGSL has no u64 type).
// Field order matters: constructors below pass (lo, hi) positionally.
struct U64Parts { lo: u32, hi: u32 }
// 64-bit FNV-1a parameters split into 32-bit halves:
//   offset basis = 0xcbf29ce4_84222325
//   prime        = 0x00000100_000001b3
const FNV_OFFSET_HI = 0xcbf29ce4u;
const FNV_OFFSET_LO = 0x84222325u;
const FNV_PRIME_HI = 0x100u;
const FNV_PRIME_LO = 0x1b3u;
// Byte length of each overlapping gram; inputs shorter than this hash as a single gram.
const DEFAULT_GRAM = 4u;
// Group 0 resources; slot numbers follow the header comment.
@group(0) @binding(0)
var<uniform> params: Params;
@group(0) @binding(1)
var<storage, read> input_words: array<u32>;
@group(0) @binding(2)
var<storage, read_write> output_words: array<u32>;
// Returns byte `index` of the input, where bytes are packed little-endian
// (byte 0 in bits 0..7 of word 0). The result is in 0..255.
fn byte_at(index: u32) -> u32 {
    return extractBits(input_words[index / 4u], (index % 4u) * 8u, 8u);
}
// High 32 bits of the full 64-bit product a * b, built from 16-bit limbs
// (schoolbook multiply with explicit carry propagation).
fn mul_hi_u32(a: u32, b: u32) -> u32 {
    let a_lo = a & 0xffffu;
    let a_hi = a >> 16u;
    let b_lo = b & 0xffffu;
    let b_hi = b >> 16u;
    // Each partial sum is at most 0xffff0000, so none of these steps wraps.
    let t = a_hi * b_lo + ((a_lo * b_lo) >> 16u);
    let u = a_lo * b_hi + (t & 0xffffu);
    return a_hi * b_hi + (t >> 16u) + (u >> 16u);
}
// Multiplies a 64-bit hash by the 64-bit FNV prime, modulo 2^64.
// (hi*2^32 + lo) * (PH*2^32 + PL) mod 2^64
//   = lo*PL + (carry(lo*PL) + lo*PH + hi*PL) * 2^32   (hi*PH*2^64 vanishes)
fn fnv_mul(hash: U64Parts) -> U64Parts {
    let cross = hash.lo * FNV_PRIME_HI + hash.hi * FNV_PRIME_LO;
    let carry = mul_hi_u32(hash.lo, FNV_PRIME_LO);
    return U64Parts(hash.lo * FNV_PRIME_LO, carry + cross);
}
// 64-bit FNV-1a over input bytes [start, start + len).
// Each step XORs one byte into the low bits, then multiplies by the prime.
fn fnv1a64(start: u32, len: u32) -> U64Parts {
    var h = U64Parts(FNV_OFFSET_LO, FNV_OFFSET_HI);
    for (var i = 0u; i < len; i++) {
        h.lo ^= byte_at(start + i);
        h = fnv_mul(h);
    }
    return h;
}
// Returns bit `bit` (0 = least significant) of the 64-bit value as 0 or 1.
// Bits 0..31 come from `lo`, bits 32..63 from `hi`.
fn hash_bit(hash: U64Parts, bit: u32) -> u32 {
    let word = select(hash.hi, hash.lo, bit < 32u);
    return (word >> (bit & 31u)) & 1u;
}
// Computes a 64-bit SimHash of the first params.len_a input bytes.
//
// The input is split into overlapping grams of DEFAULT_GRAM bytes (or a single
// gram of the whole input if it is shorter). Each gram is hashed with 64-bit
// FNV-1a. For every output bit, each gram votes +1 if its hash has that bit
// set and -1 otherwise. The output bit is set when the vote total is >= 0
// (ties set the bit). Empty input yields 0. The result is written to
// output_words[0] (low word) and output_words[1] (high word).
//
// Only the invocation with id.x == 0 does any work.
@compute @workgroup_size(1, 1, 1)
fn string_similarity_simhash64(@builtin(global_invocation_id) id: vec3<u32>) {
    if (id.x != 0u) {
        return;
    }

    // Clamp the requested length to the bytes actually bound, so a host-side
    // length larger than the buffer cannot cause out-of-bounds reads.
    // (words << 2 would wrap only for >= 4 GiB buffers; saturate in that case.)
    let words = arrayLength(&input_words);
    let capacity = select(words << 2u, 0xffffffffu, words > 0x3fffffffu);
    let len = min(params.len_a, capacity);

    if (len == 0u) {
        output_words[0] = 0u;
        output_words[1] = 0u;
        return;
    }

    let gram_len = min(len, DEFAULT_GRAM);
    let gram_count = select(1u, len - DEFAULT_GRAM + 1u, len >= DEFAULT_GRAM);

    // Per-bit vote totals (function-scope vars are zero-initialized).
    // Each gram is hashed once and contributes to all 64 bits, instead of
    // rehashing every gram once per output bit.
    var votes: array<i32, 64>;
    for (var gram = 0u; gram < gram_count; gram++) {
        let h = fnv1a64(gram, gram_len);
        for (var bit = 0u; bit < 64u; bit++) {
            votes[bit] += select(-1i, 1i, hash_bit(h, bit) == 1u);
        }
    }

    var out = U64Parts(0u, 0u);
    for (var bit = 0u; bit < 64u; bit++) {
        // NOTE(review): ties (vote == 0) set the bit; some SimHash variants use > 0.
        if (votes[bit] >= 0i) {
            if (bit < 32u) {
                out.lo |= 1u << bit;
            } else {
                out.hi |= 1u << (bit - 32u);
            }
        }
    }

    output_words[0] = out.lo;
    output_words[1] = out.hi;
}