// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
// Documentation
// Number of element slots in the forest; union_find_find rejects any id >= this value.
const UNION_FIND_CAPACITY: u32 = 1024u;

// Opcodes accepted in UnionFindCommand.op.
const UNION_FIND_OP_FIND: u32 = 1u;
const UNION_FIND_OP_UNION: u32 = 2u;
const UNION_FIND_OP_CONNECTED: u32 = 3u;

// Status codes written to UnionFindResult.status.
const UNION_FIND_OK: u32 = 0u;
const UNION_FIND_OUT_OF_RANGE: u32 = 1u;
const UNION_FIND_ALREADY_UNIFIED: u32 = 2u;

// u32 encoding of booleans stored in UnionFindResult.value.
const UNION_FIND_FALSE: u32 = 0u;
const UNION_FIND_TRUE: u32 = 1u;

// Sentinel root reported for element ids outside [0, UNION_FIND_CAPACITY).
const UNION_FIND_NO_ROOT: u32 = 0xFFFFFFFFu;

// One entry of the input command stream consumed by workgroup_union_find_kernel.
// NOTE(review): field order/layout presumably mirrors a host-side struct — do not
// reorder fields without updating the host.
struct UnionFindCommand {
    // One of UNION_FIND_OP_FIND / UNION_FIND_OP_UNION / UNION_FIND_OP_CONNECTED.
    op: u32,
    // local_invocation_id.x of the lane that executes this command.
    lane: u32,
    // First element id (used by every opcode).
    a: u32,
    // Second element id (read only by UNION and CONNECTED).
    b: u32,
}

// Outcome of one command; union_find_results[i] corresponds to union_find_commands[i].
// NOTE(review): field order/layout presumably mirrors a host-side struct — do not
// reorder fields without updating the host.
struct UnionFindResult {
    // UNION_FIND_OK, UNION_FIND_OUT_OF_RANGE or UNION_FIND_ALREADY_UNIFIED.
    status: u32,
    // Root of `a` (UNION_FIND_NO_ROOT if unknown/out of range); after a successful
    // UNION this is the surviving root.
    root_a: u32,
    // Root of `b` (UNION_FIND_NO_ROOT when not applicable); after a successful
    // UNION this is also the surviving root.
    root_b: u32,
    // Op-specific payload: the root for FIND, UNION_FIND_TRUE/FALSE for UNION and
    // CONNECTED.
    value: u32,
}

// Input command stream, processed in index order by workgroup_union_find_kernel.
@group(0) @binding(0) var<storage, read> union_find_commands: array<UnionFindCommand>;
// union_find_results[i] receives the outcome of union_find_commands[i].
// NOTE(review): presumably the host sizes this at least as long as the command buffer.
@group(0) @binding(1) var<storage, read_write> union_find_results: array<UnionFindResult>;

// Parent links of the forest; union_find_parent[i] == i marks i as a root.
// Sized from UNION_FIND_CAPACITY (rather than a literal) so the storage can never
// drift from the bound that union_find_find checks element ids against.
var<workgroup> union_find_parent: array<u32, UNION_FIND_CAPACITY>;
// Union-by-rank height bound; only read and updated at roots.
var<workgroup> union_find_rank: array<u32, UNION_FIND_CAPACITY>;

// Returns the root of element `x`, compressing the path from `x` to that root.
//
// Returns UNION_FIND_NO_ROOT (and touches no memory) when x >= UNION_FIND_CAPACITY.
// Otherwise every node on the path from `x` whose parent is not already the root
// is re-pointed directly at the root.
fn union_find_find(x: u32) -> u32 {
    if (x >= UNION_FIND_CAPACITY) {
        return UNION_FIND_NO_ROOT;
    }

    // Pass 1: climb parent links until reaching a self-parented node.
    var root = x;
    var above = union_find_parent[root];
    while (above != root) {
        root = above;
        above = union_find_parent[root];
    }

    // Pass 2: walk the same path again, re-linking each node straight to the root.
    var node = x;
    var next = union_find_parent[node];
    while (next != root) {
        union_find_parent[node] = root;
        node = next;
        next = union_find_parent[node];
    }
    return root;
}

// Merges the sets containing `a` and `b` using union by rank.
//
// Returns:
// - status UNION_FIND_OUT_OF_RANGE, value FALSE, when either id is out of range;
//   root_a/root_b carry the raw find results (NO_ROOT for the bad side).
// - status UNION_FIND_ALREADY_UNIFIED, value TRUE, when both ids share a root.
// - status UNION_FIND_OK, value TRUE, after linking; root_a and root_b both hold the
//   surviving root.
// The shallower root is attached under the deeper one. On equal ranks the numerically
// smaller root survives and its rank grows by one.
// Both finds run before validation, so an in-range argument is path-compressed even
// when the other argument is rejected.
fn union_find_union(a: u32, b: u32) -> UnionFindResult {
    let root_of_a = union_find_find(a);
    let root_of_b = union_find_find(b);

    let invalid = root_of_a == UNION_FIND_NO_ROOT || root_of_b == UNION_FIND_NO_ROOT;
    if (invalid) {
        return UnionFindResult(UNION_FIND_OUT_OF_RANGE, root_of_a, root_of_b, UNION_FIND_FALSE);
    }
    if (root_of_a == root_of_b) {
        return UnionFindResult(UNION_FIND_ALREADY_UNIFIED, root_of_a, root_of_b, UNION_FIND_TRUE);
    }

    let rank_of_a = union_find_rank[root_of_a];
    let rank_of_b = union_find_rank[root_of_b];
    let tied = rank_of_a == rank_of_b;
    // a's root is demoted when strictly shallower, or on a tie when it has the larger index.
    let demote_a = rank_of_a < rank_of_b || (tied && root_of_a > root_of_b);
    let survivor = select(root_of_a, root_of_b, demote_a);
    let absorbed = select(root_of_b, root_of_a, demote_a);

    if (tied) {
        union_find_rank[survivor] = union_find_rank[survivor] + 1u;
    }
    union_find_parent[absorbed] = survivor;
    return UnionFindResult(UNION_FIND_OK, survivor, survivor, UNION_FIND_TRUE);
}

// Invocations per workgroup. The cooperative initialization loop strides by this value
// and commands may only target lanes below it, so it is the single source of truth
// for the @workgroup_size x-dimension.
const UNION_FIND_WORKGROUP_SIZE: u32 = 64u;

// Executes the command stream against a union-find forest held in workgroup memory.
//
// All lanes first reset the forest cooperatively: every element becomes its own root
// with rank 0. Commands are then processed strictly in buffer order. For each command,
// only the lane named by `command.lane` runs it. A workgroupBarrier separates
// consecutive commands, so each command observes every parent/rank write made by the
// previous one. The result of command i is written to union_find_results[i].
//
// A command whose `lane` is not a valid invocation of this workgroup would otherwise be
// run by nobody, leaving its result slot unwritten. Lane 0 reports such commands as
// UNION_FIND_OUT_OF_RANGE with NO_ROOT roots and value FALSE.
//
// NOTE(review): the forest lives in workgroup memory. In a multi-workgroup dispatch,
// every workgroup builds its own forest and writes the same result slots. Presumably
// the host dispatches a single workgroup — confirm.
@compute @workgroup_size(UNION_FIND_WORKGROUP_SIZE, 1, 1)
fn workgroup_union_find_kernel(@builtin(local_invocation_id) local_id: vec3<u32>) {
    let lane = local_id.x;

    // Lane i resets elements i, i + W, i + 2W, ... so the whole capacity is covered.
    for (var slot = lane; slot < UNION_FIND_CAPACITY; slot = slot + UNION_FIND_WORKGROUP_SIZE) {
        union_find_parent[slot] = slot;
        union_find_rank[slot] = 0u;
    }
    workgroupBarrier();

    let command_count = arrayLength(&union_find_commands);
    // The trip count depends only on the buffer length, so every lane reaches each
    // workgroupBarrier below, as WGSL uniformity rules require.
    for (var command_index = 0u; command_index < command_count; command_index = command_index + 1u) {
        let command = union_find_commands[command_index];
        if (command.lane == lane) {
            // Opcodes other than FIND/UNION/CONNECTED fall through with status OK,
            // NO_ROOT roots and value 0.
            // NOTE(review): consider a dedicated error status for unknown opcodes.
            var result = UnionFindResult(UNION_FIND_OK, UNION_FIND_NO_ROOT, UNION_FIND_NO_ROOT, 0u);
            if (command.op == UNION_FIND_OP_FIND) {
                let root = union_find_find(command.a);
                let status = select(UNION_FIND_OK, UNION_FIND_OUT_OF_RANGE, root == UNION_FIND_NO_ROOT);
                result = UnionFindResult(status, root, UNION_FIND_NO_ROOT, root);
            } else if (command.op == UNION_FIND_OP_UNION) {
                result = union_find_union(command.a, command.b);
            } else if (command.op == UNION_FIND_OP_CONNECTED) {
                let ra = union_find_find(command.a);
                let rb = union_find_find(command.b);
                let status = select(
                    UNION_FIND_OK,
                    UNION_FIND_OUT_OF_RANGE,
                    ra == UNION_FIND_NO_ROOT || rb == UNION_FIND_NO_ROOT,
                );
                let connected = select(UNION_FIND_FALSE, UNION_FIND_TRUE, status == UNION_FIND_OK && ra == rb);
                result = UnionFindResult(status, ra, rb, connected);
            }
            union_find_results[command_index] = result;
        } else if (lane == 0u && command.lane >= UNION_FIND_WORKGROUP_SIZE) {
            // No invocation owns this command; report it instead of leaving the slot stale.
            union_find_results[command_index] = UnionFindResult(
                UNION_FIND_OUT_OF_RANGE,
                UNION_FIND_NO_ROOT,
                UNION_FIND_NO_ROOT,
                UNION_FIND_FALSE,
            );
        }
        // Make this command's workgroup-memory writes visible before the next one runs.
        workgroupBarrier();
    }
}