llama-cpp-sys-4 0.3.1

Low Level Bindings to llama.cpp
Documentation
enable f16;

struct MulMatIdGatherParams {
    offset_ids: u32,

    n_expert: u32,
    n_expert_used: u32,
    n_tokens: u32,

    stride_ids_1: u32,
};

@group(0) @binding(0) var<storage, read_write> ids: array<i32>;        // [n_expert_used, n_tokens]
@group(0) @binding(1) var<storage, read_write> global_gathered_expert_used: array<u32>; // [n_expert][n_tokens]
@group(0) @binding(2) var<storage, read_write> global_gathered_tokens: array<u32>; // [n_expert][n_tokens]
@group(0) @binding(3) var<storage, read_write> gathered_count_ids: array<u32>; // [n_expert]

@group(0) @binding(4) var<uniform> params: MulMatIdGatherParams;

var<workgroup> count:atomic<u32>;

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>) {

    let thread_id = local_id.x;
    let own_expert = wg_id.x; // the expert assigned to this workgroup

    if (thread_id == 0u) {
        atomicStore(&count, 0);
    }

    workgroupBarrier();

    for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) {
        let row = i / params.n_expert_used;
        let col = i % params.n_expert_used;
        let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]);
        if (own_expert == expert) {
            let pos = atomicAdd(&count, 1u);
            let gathered_id = own_expert * params.n_tokens + pos;
            global_gathered_expert_used[gathered_id] = col;
            global_gathered_tokens[gathered_id] = row;
        }
    }

    workgroupBarrier();

    if (thread_id == 0u) {
        gathered_count_ids[own_expert] = atomicLoad(&count);
    }
}