llama-cpp-sys-4 0.2.51

Low Level Bindings to llama.cpp
Documentation
#include "common_decls.tmpl"
enable f16;

// Convolution kernel (filter) values; read through load_weight().
// The element type is chosen at preprocessing time by WEIGHT_F32 / WEIGHT_F16.
// Declared read_write, although this file only ever reads from it.
@group(0) @binding(0)
#if defined(WEIGHT_F32)
var<storage, read_write> weights: array<f32>;
#elif defined(WEIGHT_F16)
var<storage, read_write> weights: array<f16>;
#endif

// Input feature map; read through load_input().
// The element type is chosen at preprocessing time by INPUT_F32 / INPUT_F16.
// Declared read_write, although this file only ever reads from it.
@group(0) @binding(1)
#if defined(INPUT_F32)
var<storage, read_write> input: array<f32>;
#elif defined(INPUT_F16)
var<storage, read_write> input: array<f16>;
#endif

// Convolution result; written through store_output().
// The element type is chosen at preprocessing time by OUTPUT_F32 / OUTPUT_F16.
@group(0) @binding(2)
#if defined(OUTPUT_F32)
var<storage, read_write> output: array<f32>;
#elif defined(OUTPUT_F16)
var<storage, read_write> output: array<f16>;
#endif

// Uniform parameters describing one 2D convolution dispatch.
// NOTE(review): the field order presumably has to match the buffer the host
// uploads - confirm against the host-side code before reordering anything.
struct Params {
    // Base offsets added to every computed index into weights / input / output.
    offset_w: u32,
    offset_i: u32,
    offset_o: u32,

    // element strides
    // sw0..sw3 multiply (kw, kh, ic, oc) when indexing weights,
    // si0..si3 multiply (iw, ih, ic, n) when indexing input,
    // so0..so3 multiply (ow, oh, oc, n) when indexing output.
    sw0: u32, sw1: u32, sw2: u32, sw3: u32,
    si0: u32, si1: u32, si2: u32, si3: u32,
    so0: u32, so1: u32, so2: u32, so3: u32,

    // kernel dimensions: width, height, and the number of input channels summed over
    KW: u32, KH: u32, IC: u32,
    // input dimensions: width and height, used as the upper bounds for valid taps
    IW: u32, IH: u32,
    // output dimensions: width, height, output channels, batch size;
    // their product is the number of output elements to compute
    OW: u32, OH: u32, OC_out: u32, N_out: u32,

    // stride: s0 scales the output x coordinate, s1 the output y coordinate
    s0: u32, s1: u32,
    // padding: p0 is subtracted along x, p1 along y
    p0: u32, p1: u32,
    // dilation: d0 scales the kernel x index, d1 the kernel y index
    d0: u32, d1: u32,
};

// Uniform block holding the offsets, strides and shape of the current convolution.
@group(0) @binding(3)
var<uniform> params: Params;

// Reads weights[idx] and returns it as f32.
// The f32() value constructor widens an f16 element and is an identity for an
// f32 element, so one statement serves both the WEIGHT_F32 and WEIGHT_F16 builds.
fn load_weight(idx: u32) -> f32 {
    let element = weights[idx];
    return f32(element);
}

// Reads input[idx] and returns it as f32.
// The f32() value constructor widens an f16 element and is an identity for an
// f32 element, so one statement serves both the INPUT_F32 and INPUT_F16 builds.
fn load_input(idx: u32) -> f32 {
    let element = input[idx];
    return f32(element);
}

// Writes val to output[idx], narrowing to f16 first when the output buffer
// is built with OUTPUT_F16.
fn store_output(idx: u32, val: f32) {
    #if defined(OUTPUT_F32)
        let converted = val;
    #elif defined(OUTPUT_F16)
        let converted = f16(val);
    #endif
    output[idx] = converted;
}

// Returns ceil(x / y) for unsigned operands.
//
// The quotient is rounded up with a remainder check rather than the usual
// (x + y - 1) / y, because that sum wraps around in u32 when x is close to
// the maximum value and then yields a quotient that is far too small.
// y is expected to be non-zero.
fn ceil_div_u32(x: u32, y: u32) -> u32 {
    let quotient = x / y;
    let has_remainder = (x % y) != 0u;
    return quotient + select(0u, 1u, has_remainder);
}

// Smallest kernel index k for which base + k * step is no longer negative.
// A non-negative base needs no skipping, so the answer is 0; otherwise it is
// the distance to zero divided by step, rounded up.
fn first_valid_k(base: i32, step: u32) -> u32 {
    if (base < 0) {
        let distance_to_zero = u32(-base);
        return ceil_div_u32(distance_to_zero, step);
    }

    return 0u;
}

// One-past-the-last usable kernel index: the smallest k with
// base + k * step >= limit, capped at k_max. Valid k lie in [0, end_valid_k).
// Returns 0 when base is already at or beyond limit.
fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 {
    let room = i32(limit) - base;
    if (room > 0) {
        let steps_until_limit = ceil_div_u32(u32(room), step);
        return min(k_max, steps_until_limit);
    }

    return 0u;
}

// Direct 2D convolution: each invocation produces exactly one output element.
//
// Kernel layout: [KW, KH, IC, ..]
// Input layout:  [IW, IH, .., ..]
// Output layout: [OW, OH, OC, N]
//
// The dispatch is flattened from (x, y) into a single linear output index;
// invocations whose index falls past the last output element do nothing.
@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>
) {
    // One row of the dispatch grid covers num_wg.x * WG_SIZE invocations.
    let row_width = num_wg.x * u32(WG_SIZE);
    let flat_id = gid.x + row_width * gid.y;

    let total_outputs = params.OW * params.OH * params.OC_out * params.N_out;
    if (flat_id >= total_outputs) {
        return;
    }

    // Split the linear index into (n, oc, oh, ow), with ow varying fastest.
    let plane_size = params.OH * params.OW;
    let batch_size = params.OC_out * plane_size;

    let n = flat_id / batch_size;
    let within_batch = flat_id % batch_size;
    let oc = within_batch / plane_size;
    let within_plane = within_batch % plane_size;
    let oh = within_plane / params.OW;
    let ow = within_plane % params.OW;

    // Input coordinate hit by kernel tap (0, 0); negative inside the padding.
    let x_origin = i32(ow * params.s0) - i32(params.p0);
    let y_origin = i32(oh * params.s1) - i32(params.p1);

    // Clip the kernel window to taps that land inside the input, once up front,
    // so the inner loops need no per-tap bounds checks.
    let kw_lo = first_valid_k(x_origin, params.d0);
    let kw_hi = end_valid_k(x_origin, params.d0, params.IW, params.KW);
    let kh_lo = first_valid_k(y_origin, params.d1);
    let kh_hi = end_valid_k(y_origin, params.d1, params.IH, params.KH);

    let dst_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3;

    // The whole receptive field lies in the padding: the result is zero.
    if (kw_lo >= kw_hi || kh_lo >= kh_hi) {
        store_output(dst_idx, 0.0);
        return;
    }

    let w_filter_base = params.offset_w + oc * params.sw3;
    let in_image_base = params.offset_i + n * params.si3;

    var acc: f32 = 0.0;
    for (var ic: u32 = 0u; ic < params.IC; ic++) {
        let w_channel_base = w_filter_base + ic * params.sw2;
        let in_channel_base = in_image_base + ic * params.si2;

        for (var kh: u32 = kh_lo; kh < kh_hi; kh++) {
            // Non-negative thanks to the clipping above, so the cast is safe.
            let ih = u32(y_origin + i32(kh * params.d1));
            let w_row = w_channel_base + kh * params.sw1;
            let in_row = in_channel_base + ih * params.si1;

            for (var kw: u32 = kw_lo; kw < kw_hi; kw++) {
                let iw = u32(x_origin + i32(kw * params.d0));
                let w_tap = w_row + kw * params.sw0;
                let in_tap = in_row + iw * params.si0;
                acc += load_weight(w_tap) * load_input(in_tap);
            }
        }
    }

    store_output(dst_idx, acc);
}