llama-cpp-sys-4 0.2.45

Low Level Bindings to llama.cpp
Documentation
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

@group(0) @binding(1)
var<storage, read_write> idx_in: array<i32>;

@group(0) @binding(2)
var<storage, read_write> idx_out: array<i32>;

struct Params {
    offset_src: u32, // in elements
    offset_in: u32,  // in elements
    offset_out: u32, // in elements

    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_idx1: u32,
    stride_idx2: u32,
    stride_idx3: u32,

    stride_out1: u32,
    stride_out2: u32,
    stride_out3: u32,

    ne0: u32,
    ne1: u32,
    ne2: u32,

    top_k: u32,

    len: u32,
    nm: u32,
    nrows: u32
};

@group(0) @binding(3)
var<uniform> params: Params;

fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
    let a_val = src[row_base + u32(a_idx)];
    let b_val = src[row_base + u32(b_idx)];
#if ORDER == 0
    return a_val <= b_val;
#else
    return a_val >= b_val;
#endif
}

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(num_workgroups) num_wg: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let linear = wid.x + wid.y * num_wg.x;
    // guard against overprovisioned workgroups
    if (linear >= params.nm * params.nrows) {
        return;
    }

    let start = (linear % params.nm) * params.len * 2;
    let len0 = min(params.len, params.ne0 - start);
    let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
    let len1 = min(params.len, rem1);
    let total = len0 + len1;
    let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
    let k0 = lid.x * chunk;
    let k1 = min(min(k0 + chunk, total), params.top_k);
    // guard against overprovisioned threads
    if (k0 >= params.top_k || k0 >= total) {
        return;
    }

    var row = linear / params.nm;
    let i3 = row / (params.ne2 * params.ne1);
    row = row % (params.ne2 * params.ne1);
    let i2 = row / params.ne1;
    let i1 = row % params.ne1;

    let row_src = params.offset_src +
        i1 * params.stride_src1 +
        i2 * params.stride_src2 +
        i3 * params.stride_src3;

    let row_in = params.offset_in +
        i1 * params.stride_idx1 +
        i2 * params.stride_idx2 +
        i3 * params.stride_idx3;

    let row_out = params.offset_out +
        i1 * params.stride_out1 +
        i2 * params.stride_out2 +
        i3 * params.stride_out3;


    var low: u32 = select(0, k0 - len1, k0 > len1);
    var high: u32 = min(k0, len0);

    while (low < high) {
        let mid = (low + high) >> 1;
        let idx0 = idx_in[row_in + start + mid];
        let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];
        if (take_left(idx0, idx1, row_src)) {
            low = mid + 1;
        } else {
            high = mid;
        }
    }

    var i = low;
    var j = k0 - i;
    var k = k0;
    while (k < k1) {
        var take_l = false;
        if (i >= len0) {
            take_l = false;
        } else if (j >= len1) {
            take_l = true;
        } else {
            let idx0 = idx_in[row_in + start + i];
            let idx1 = idx_in[row_in + start + params.len + j];
            take_l = take_left(idx0, idx1, row_src);
        }

        let out_idx = select(
            idx_in[row_in + start + params.len + j],
            idx_in[row_in + start + i],
            take_l);
        idx_out[row_out + start + k] = out_idx;
        i = select(i, i + 1, take_l);
        j = select(j + 1, j, take_l);
        k += 1;
    }
}