// Number of slots in the shared forest; valid element ids are 0 .. UNION_FIND_CAPACITY - 1.
const UNION_FIND_CAPACITY = 1024u;

// Opcodes accepted in UnionFindCommand.op.
const UNION_FIND_OP_FIND = 1u;
const UNION_FIND_OP_UNION = 2u;
const UNION_FIND_OP_CONNECTED = 3u;

// Status codes written to UnionFindResult.status.
const UNION_FIND_OK = 0u;
const UNION_FIND_OUT_OF_RANGE = 1u;
const UNION_FIND_ALREADY_UNIFIED = 2u;

// Boolean encoding used in UnionFindResult.value.
const UNION_FIND_FALSE = 0u;
const UNION_FIND_TRUE = 1u;

// Sentinel root reported for ids outside the forest.
const UNION_FIND_NO_ROOT = 0xffffffffu;
// One request read from the command buffer. Field order is the buffer layout; do not reorder.
struct UnionFindCommand {
    // One of the UNION_FIND_OP_* opcodes.
    op: u32,
    // Local invocation index (local_invocation_id.x) that executes this command.
    lane: u32,
    // First element id.
    a: u32,
    // Second element id (read by the UNION and CONNECTED opcodes).
    b: u32,
}
// One record written to the result buffer at the index of the command that produced it.
// Field order is the buffer layout; do not reorder.
struct UnionFindResult {
    // One of the UNION_FIND_OK / UNION_FIND_OUT_OF_RANGE / UNION_FIND_ALREADY_UNIFIED codes.
    status: u32,
    // Root associated with operand a (UNION_FIND_NO_ROOT if unavailable).
    root_a: u32,
    // Root associated with operand b (UNION_FIND_NO_ROOT if unavailable).
    root_b: u32,
    // FIND: the root itself; UNION / CONNECTED: UNION_FIND_TRUE or UNION_FIND_FALSE.
    value: u32,
}
// Command stream, processed in index order by workgroup_union_find_kernel.
@group(0) @binding(0) var<storage, read> union_find_commands: array<UnionFindCommand>;
// One result record per command, written at the same index as the command.
// NOTE(review): assumes the host sizes this buffer to at least arrayLength(union_find_commands) — confirm.
@group(0) @binding(1) var<storage, read_write> union_find_results: array<UnionFindResult>;
// Shared disjoint-set forest (parent links and union-by-rank ranks). Both arrays are sized from
// UNION_FIND_CAPACITY so the bound check in union_find_find always matches the real array extent.
var<workgroup> union_find_parent: array<u32, UNION_FIND_CAPACITY>;
var<workgroup> union_find_rank: array<u32, UNION_FIND_CAPACITY>;
// Return the root of x's set, compressing every link on the path from x straight to that root.
// Ids outside the forest yield UNION_FIND_NO_ROOT and leave the forest untouched.
fn union_find_find(x: u32) -> u32 {
    if (x >= UNION_FIND_CAPACITY) {
        return UNION_FIND_NO_ROOT;
    }

    // First pass: climb parent links until reaching a self-parented node.
    var top = x;
    while (union_find_parent[top] != top) {
        top = union_find_parent[top];
    }

    // Second pass: re-walk the same path, pointing each node directly at the root.
    var node = x;
    while (union_find_parent[node] != top) {
        let next = union_find_parent[node];
        union_find_parent[node] = top;
        node = next;
    }

    return top;
}
// Merge the sets containing a and b using union by rank.
// On a rank tie the lower root index becomes the parent and its rank grows by one.
// Returns:
//   OUT_OF_RANGE     (roots as found, value FALSE) if either id is outside the forest;
//   ALREADY_UNIFIED  (both roots, value TRUE) if a and b already share a root;
//   OK               (new shared root in both root fields, value TRUE) after linking.
fn union_find_union(a: u32, b: u32) -> UnionFindResult {
    let leader_a = union_find_find(a);
    let leader_b = union_find_find(b);

    let missing = leader_a == UNION_FIND_NO_ROOT || leader_b == UNION_FIND_NO_ROOT;
    if (missing) {
        return UnionFindResult(UNION_FIND_OUT_OF_RANGE, leader_a, leader_b, UNION_FIND_FALSE);
    }
    if (leader_a == leader_b) {
        return UnionFindResult(UNION_FIND_ALREADY_UNIFIED, leader_a, leader_b, UNION_FIND_TRUE);
    }

    let rank_a = union_find_rank[leader_a];
    let rank_b = union_find_rank[leader_b];
    // a's root wins (becomes the parent) when it is strictly taller, or on a tie when it has the
    // smaller index; this keeps the outcome deterministic regardless of operand order.
    let a_wins = rank_a > rank_b || (rank_a == rank_b && leader_a < leader_b);
    let winner = select(leader_b, leader_a, a_wins);
    let loser = select(leader_a, leader_b, a_wins);

    // Only a tie can make the merged tree taller.
    if (rank_a == rank_b) {
        union_find_rank[winner] = union_find_rank[winner] + 1u;
    }
    union_find_parent[loser] = winner;

    return UnionFindResult(UNION_FIND_OK, winner, winner, UNION_FIND_TRUE);
}
// Invocations per workgroup; also the stride of the cooperative forest initialisation.
const UNION_FIND_WORKGROUP_SIZE: u32 = 64u;

// Execute one command against the shared forest and build its result record.
fn union_find_execute_command(command: UnionFindCommand) -> UnionFindResult {
    if (command.op == UNION_FIND_OP_FIND) {
        let root = union_find_find(command.a);
        let status = select(UNION_FIND_OK, UNION_FIND_OUT_OF_RANGE, root == UNION_FIND_NO_ROOT);
        return UnionFindResult(status, root, UNION_FIND_NO_ROOT, root);
    }
    if (command.op == UNION_FIND_OP_UNION) {
        return union_find_union(command.a, command.b);
    }
    if (command.op == UNION_FIND_OP_CONNECTED) {
        let ra = union_find_find(command.a);
        let rb = union_find_find(command.b);
        let status = select(
            UNION_FIND_OK,
            UNION_FIND_OUT_OF_RANGE,
            ra == UNION_FIND_NO_ROOT || rb == UNION_FIND_NO_ROOT,
        );
        let connected = select(UNION_FIND_FALSE, UNION_FIND_TRUE, status == UNION_FIND_OK && ra == rb);
        return UnionFindResult(status, ra, rb, connected);
    }
    // NOTE(review): unrecognised opcodes report UNION_FIND_OK with no roots and value 0 (unchanged
    // behaviour); consider a dedicated error status if the host should detect them.
    return UnionFindResult(UNION_FIND_OK, UNION_FIND_NO_ROOT, UNION_FIND_NO_ROOT, 0u);
}

// Initialise the shared forest cooperatively, then execute the command stream strictly in order:
// for each command exactly one invocation (the one named by command.lane) runs it and writes the
// result at the command's index, and a barrier separates consecutive commands so every command
// observes the forest left by the previous one.
// Commands whose lane does not exist in this workgroup are answered by lane 0 with
// UNION_FIND_OUT_OF_RANGE instead of leaving their result slot unwritten.
// NOTE(review): each workgroup owns a private forest but all write the same result indices, so this
// presumably expects a single-workgroup dispatch — confirm with the host code.
@compute @workgroup_size(UNION_FIND_WORKGROUP_SIZE, 1, 1)
fn workgroup_union_find_kernel(@builtin(local_invocation_id) local_id: vec3<u32>) {
    let lane = local_id.x;

    // Every slot starts as its own singleton root with rank 0.
    for (var slot = lane; slot < UNION_FIND_CAPACITY; slot += UNION_FIND_WORKGROUP_SIZE) {
        union_find_parent[slot] = slot;
        union_find_rank[slot] = 0u;
    }
    workgroupBarrier();

    let command_count = arrayLength(&union_find_commands);
    // Loop bound is uniform, so the barrier below stays in uniform control flow.
    for (var command_index = 0u; command_index < command_count; command_index += 1u) {
        let command = union_find_commands[command_index];
        let lane_exists = command.lane < UNION_FIND_WORKGROUP_SIZE;
        // Route commands addressed to a non-existent lane to lane 0 so their slot is still written.
        let owner = select(0u, command.lane, lane_exists);
        if (owner == lane) {
            var result = UnionFindResult(
                UNION_FIND_OUT_OF_RANGE,
                UNION_FIND_NO_ROOT,
                UNION_FIND_NO_ROOT,
                UNION_FIND_FALSE,
            );
            if (lane_exists) {
                result = union_find_execute_command(command);
            }
            union_find_results[command_index] = result;
        }
        // Publish this command's forest updates before the next command runs.
        workgroupBarrier();
    }
}