// llama-cpp-sys-4 0.2.46 — low-level bindings to llama.cpp (documentation header).
// WGSL compute shader: per-expert gather pass for indirect (mixture-of-experts)
// matrix multiplication — collects which (token, expert-slot) pairs selected
// each expert, so the subsequent mul_mat can process one expert's rows densely.
enable f16;

// Uniform parameters for the mul_mat_id gather pass.
// NOTE(review): field order defines the uniform-buffer layout and must stay
// in sync with the host-side struct — do not reorder.
struct MulMatIdGatherParams {
    offset_ids: u32,        // element offset into `ids` where the id matrix begins

    n_expert: u32,          // total number of experts (one workgroup per expert)
    n_expert_used: u32,     // experts selected per token (columns of `ids`)
    n_tokens: u32,          // number of tokens (rows of `ids`)

    stride_ids_1: u32,      // element stride between consecutive token rows of `ids`
};

// Input: expert ids chosen for each token; indexed as
// ids[offset_ids + token * stride_ids_1 + slot], slot in [0, n_expert_used).
@group(0) @binding(0) var<storage, read_write> ids: array<i32>;        // [n_expert_used, n_tokens]
// Output: for each expert e, the expert-used slot of every matching pair is
// written densely starting at e * n_tokens.
@group(0) @binding(1) var<storage, read_write> global_gathered_expert_used: array<u32>; // [n_expert][n_tokens]
// Output: the token index of every matching pair, same layout as above.
@group(0) @binding(2) var<storage, read_write> global_gathered_tokens: array<u32>; // [n_expert][n_tokens]
// Output: number of (token, slot) pairs gathered for each expert.
@group(0) @binding(3) var<storage, read_write> gathered_count_ids: array<u32>; // [n_expert]

@group(0) @binding(4) var<uniform> params: MulMatIdGatherParams;

// Workgroup-shared match counter; doubles as the write cursor into this
// expert's output row.
var<workgroup> count:atomic<u32>;

// One workgroup per expert: scan the whole id matrix and compact the
// (token, slot) coordinates that selected this expert into a dense row of
// the gather buffers, recording the match count per expert.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>,
        @builtin(num_workgroups) num_wg: vec3<u32>) {

    let lane = local_id.x;
    // Linearize the 2D workgroup grid; one workgroup owns one expert.
    let expert_id = wg_id.x + wg_id.y * num_wg.x;

    // Uniform per workgroup (derived only from workgroup builtins), so all
    // invocations of a workgroup take the same path and the barriers below
    // remain in uniform control flow.
    if (expert_id >= params.n_expert) {
        return;
    }

    // Reset the shared write cursor for this expert.
    if (lane == 0u) {
        atomicStore(&count, 0u);
    }
    workgroupBarrier();

    // Workgroup-strided scan of all n_tokens * n_expert_used id entries:
    // lane k inspects entries k, k + WG_SIZE, k + 2*WG_SIZE, ...
    let total = params.n_expert_used * params.n_tokens;
    var idx = lane;
    while (idx < total) {
        let token = idx / params.n_expert_used;
        let slot = idx % params.n_expert_used;
        let selected = u32(ids[params.offset_ids + token * params.stride_ids_1 + slot]);
        if (selected == expert_id) {
            // Claim a unique position in this expert's output row.
            let pos = atomicAdd(&count, 1u);
            let dst = expert_id * params.n_tokens + pos;
            global_gathered_expert_used[dst] = slot;
            global_gathered_tokens[dst] = token;
        }
        idx += WG_SIZE;
    }

    workgroupBarrier();

    // Publish the final match count for this expert.
    if (lane == 0u) {
        gathered_count_ids[expert_id] = atomicLoad(&count);
    }
}