llama-cpp-sys-4 0.3.2

Low Level Bindings to llama.cpp
Documentation
#version 450

#include "types.glsl"

// Input buffer (read-only). main() reads element
//   (oc * K + k) + t_in * K_OC
// i.e. the (oc, k) row index varies fastest and t_in selects the column.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};   // columns: [K_OC, T_in]
// Output buffer (write-only). main() writes element t_out + oc * T_out,
// i.e. t_out varies fastest.
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};  // output:  [T_out, OC]

// 256 invocations per workgroup along x; main() maps x to t_out and y to oc.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

// Push constants consumed by main(). Member order and types define the
// push-constant layout, so they must match whatever the host side pushes.
layout (push_constant) uniform parameter {
    uint32_t T_out;   // number of output positions; bound on invocation x and stride between oc rows in data_d
    uint32_t OC;      // number of output channels; bound on invocation y
    uint32_t K_OC;    // stride in data_a between consecutive t_in (presumably K * OC -- confirm against host code)
    uint32_t T_in;    // number of input positions; t_in is clamped to [0, T_in - 1]
    uint32_t K;       // number of taps per channel; row index into data_a is oc * K + k
    int32_t  stride;  // output positions advanced per t_in step (k = t_abs - t_in * stride); used as a divisor, assumed > 0
    int32_t  p0;      // offset added to t_out to obtain the position in the uncropped signal
} p;

// Read element `idx` of data_a and return it as a 32-bit float.
// Without DATA_A_BF16 a plain float() conversion is applied; with it, the
// stored value is widened to uint32_t and decoded through bf16_to_fp32().
float load_col(uint32_t idx) {
#if !defined(DATA_A_BF16)
    return float(data_a[idx]);
#else
    const uint32_t raw_bits = uint32_t(data_a[idx]);
    return bf16_to_fp32(raw_bits);
#endif
}

// Write float `v` into data_d[idx], converted to D_TYPE.
// Without DATA_A_BF16 the value is converted directly; with it, the value is
// first encoded through fp32_to_bf16().
// NOTE(review): the branch is selected by DATA_A_BF16 (the input-type macro),
// so this presumably relies on D_TYPE matching A_TYPE -- confirm.
void store_dst(uint32_t idx, float v) {
#if !defined(DATA_A_BF16)
    data_d[idx] = D_TYPE(v);
#else
    data_d[idx] = D_TYPE(fp32_to_bf16(v));
#endif
}

// Entry point: each invocation produces one output element (t_out, oc).
// It sums every data_a entry (oc, k, t_in) whose scatter position
// t_in * stride + k equals t_abs = t_out + p0, and stores the sum in data_d.
// Precondition: p.stride is used as a divisor and is assumed to be > 0.
void main() {
    const uint32_t t_out = gl_GlobalInvocationID.x;
    const uint32_t oc    = gl_GlobalInvocationID.y;
    // Invocations outside the output extent have nothing to write.
    if (t_out >= p.T_out || oc >= p.OC) return;

    const int32_t t_abs = int32_t(t_out) + p.p0; // absolute position in uncropped signal

    // Gather: only the ceil(K/stride) columns that scatter into t_abs, no modulo.
    // Lower bound: smallest t_in with k = t_abs - t_in * stride <= K - 1,
    // i.e. ceil((t_abs - K + 1) / stride), clamped to 0. A negative numerator
    // yields a quotient <= 0 under either rounding mode, so the clamp covers it.
    int32_t t_in_min = (t_abs - int32_t(p.K) + p.stride) / p.stride;
    if (t_in_min < 0) t_in_min = 0;
    // Upper bound: largest t_in with k >= 0, i.e. floor(t_abs / stride).
    // For t_abs < 0 no t_in >= 0 contributes; a truncating division would
    // give 0 there, making k negative and uint32_t(k) an out-of-bounds index.
    // Use -1 so the loop below is skipped and the output is 0.
    int32_t t_in_max = (t_abs >= 0) ? (t_abs / p.stride) : -1;
    if (t_in_max >= int32_t(p.T_in)) t_in_max = int32_t(p.T_in) - 1;

    float val = 0.0;
    for (int32_t t_in = t_in_min; t_in <= t_in_max; t_in++) {
        // Tap index inside the kernel; within [0, K) by construction of the bounds.
        int32_t k = t_abs - t_in * p.stride;
        // col layout: [K_OC, T_in], column index = oc * K + k
        uint32_t col_idx = (oc * p.K + uint32_t(k)) + uint32_t(t_in) * p.K_OC;
        val += load_col(col_idx);
    }

    // dst layout: [T_out, OC], element (t_out, oc) = t_out + oc * T_out
    store_dst(t_out + oc * p.T_out, val);
}