llama-cpp-sys-4 0.2.45

Low-level bindings to llama.cpp
Documentation
#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#if USE_SUBGROUP_CLUSTERED
#extension GL_KHR_shader_subgroup_clustered : enable
#endif
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif

// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0,
// so no bounds checking is needed.
layout(constant_id = 0) const uint S_V = 128;              // per-head state is S_V x S_V elements
layout(constant_id = 1) const uint KDA = 0;                // nonzero: per-channel gate (S_V gate values per token); zero: one scalar gate per token
layout(constant_id = 2) const uint SUBGROUP_SIZE = 32;     // also the workgroup x-size (see local_size_x_id = 2 below)
layout(constant_id = 3) const uint LANES_PER_COLUMN = 32;  // lanes that cooperate on a single state column

const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN;  // state columns handled by one workgroup per z-slice
const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN;          // state rows each lane keeps in registers

layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;

// All s* strides are in elements of FLOAT_TYPE (buffers are indexed directly with them).
layout(push_constant) uniform Parameters {
    uint H;             // number of heads
    uint n_tokens;      // tokens processed sequentially per sequence
    uint n_seqs;        // number of sequences
    uint s_off;         // element offset of the updated-state region inside data_dst
    uint sq1, sq2, sq3; // q (and k) strides: per-head, per-token, per-sequence
    uint sv1, sv2, sv3; // v strides: per-head, per-token, per-sequence
    uint sb1, sb2, sb3; // g/beta strides: per-head, per-token, per-sequence
    uint neq1, rq3;     // broadcast divisors for q: head index is head_id % neq1, seq index is seq_id / rq3
    float scale;        // scale applied to the attention output
};

layout(binding = 0) readonly  buffer QBuf     { FLOAT_TYPE data_q[];     };
layout(binding = 1) readonly  buffer KBuf     { FLOAT_TYPE data_k[];     };
layout(binding = 2) readonly  buffer VBuf     { FLOAT_TYPE data_v[];     };
layout(binding = 3) readonly  buffer GBuf     { FLOAT_TYPE data_g[];     };  // log-space gate; exp() is applied on load
layout(binding = 4) readonly  buffer BetaBuf  { FLOAT_TYPE data_beta[];  };
layout(binding = 5) readonly  buffer StateBuf { FLOAT_TYPE data_state[]; };  // input recurrent state
layout(binding = 6)           buffer DstBuf   { FLOAT_TYPE data_dst[];   };  // attention output, then updated state at s_off

#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED
// Scratch for the shared-memory fallback; one slot per lane of the
// (single-subgroup-wide) workgroup.
shared FLOAT_TYPE temp[SUBGROUP_SIZE];

// This does a reduction across groups of LANES_PER_COLUMN
// XOR-butterfly: after the loop, every lane of a LANES_PER_COLUMN group holds
// the group's sum. Assumes LANES_PER_COLUMN is a power of two and groups are
// aligned within the subgroup (the lane ^ s partner then stays in-group).
// Uses barrier(), so it must be called from workgroup-uniform control flow.
FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) {
    const uint lane = gl_SubgroupInvocationID;
    temp[lane] = partial;
    barrier();
    [[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) {
        FLOAT_TYPE other = temp[lane ^ s];
        barrier();            // all lanes have read their partner before anyone writes
        temp[lane] += other;
        barrier();            // all writes visible before the next round reads
    }
    const FLOAT_TYPE result = temp[lane];
    barrier();                // protect temp[] before the next reduction reuses it
    return result;
}
#endif

// Sum `partial` across each group of LANES_PER_COLUMN lanes.
// subgroupClusteredAdd requires a compile-time literal cluster size, so the
// supported sizes are enumerated explicitly; every comparison is against a
// specialization constant and folds away at pipeline-compile time, leaving a
// single reduction in the generated code.
FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) {
    if (LANES_PER_COLUMN == 1u) {
        return partial;                              // nothing to reduce
    }
#if USE_SUBGROUP_CLUSTERED
    if (LANES_PER_COLUMN == 2u) {
        return subgroupClusteredAdd(partial, 2u);
    }
    if (LANES_PER_COLUMN == 4u) {
        return subgroupClusteredAdd(partial, 4u);
    }
    if (LANES_PER_COLUMN == 8u) {
        return subgroupClusteredAdd(partial, 8u);
    }
    if (LANES_PER_COLUMN == 16u) {
        return subgroupClusteredAdd(partial, 16u);
    }
    if (LANES_PER_COLUMN == 32u) {
        return subgroupClusteredAdd(partial, 32u);
    }
    if (LANES_PER_COLUMN == 64u) {
        return subgroupClusteredAdd(partial, 64u);
    }
#endif
    // Fallback for sizes not handled above.
#if USE_SUBGROUP_ADD
    return subgroupAdd(partial);
#else
    return reduce_add_shmem(partial);
#endif
}

// One workgroup handles one (head, sequence) pair; each LANES_PER_COLUMN-lane
// group owns one column `col` of the S_V x S_V recurrent state, sharded as
// ROWS_PER_LANE rows per lane. Tokens are walked sequentially; per token:
//   kv        = sum_r  g[r] * S[r][col] * k[r]
//   delta     = (v[col] - kv) * beta
//   S[r][col] = g[r] * S[r][col] + k[r] * delta
//   out[col]  = scale * sum_r  S[r][col] * q[r]
// i.e. a gated delta-rule recurrence (summary derived from the code below;
// NOTE(review): confirm against the ggml op this shader implements).
// The final state is written to data_dst at offset s_off.
void main() {
    const uint head_id = gl_WorkGroupID.x;
    const uint seq_id = gl_WorkGroupID.y;
    // Position within the subgroup: which lane of the column group, and which
    // state column this group services (z dispatch covers S_V / COLS_PER_WG slices).
    const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN;
    const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN);

    // Broadcast indices for q/k (q may have fewer heads/sequences than the state).
    const uint iq1 = head_id % neq1;
    const uint iq3 = seq_id / rq3;

    const uint state_size = S_V * S_V;
    const uint state_base = (seq_id * H + head_id) * state_size;

    // Load this lane's shard of the state column into registers.
    FLOAT_TYPE s_shard[ROWS_PER_LANE];
    [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
        s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
    }

    // Output offset for token 0; advanced by S_V * H per token below.
    uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;

    for (uint t = 0; t < n_tokens; t++) {
        const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
        const uint k_off = q_off;   // k shares q's layout/strides
        const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
        const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
        const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);

        // This lane's rows of k and q for the current token.
        FLOAT_TYPE k_reg[ROWS_PER_LANE];
        FLOAT_TYPE q_reg[ROWS_PER_LANE];
        [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
            const uint i = r * LANES_PER_COLUMN + lane;
            k_reg[r] = FLOAT_TYPE(data_k[k_off + i]);
            q_reg[r] = FLOAT_TYPE(data_q[q_off + i]);
        }

        // Decay factors: either one scalar gate broadcast over all rows
        // (KDA == 0) or a per-channel gate vector of S_V values (KDA != 0).
        // Gates are stored in log space, hence the exp().
        FLOAT_TYPE g_exp[ROWS_PER_LANE];
        if (KDA == 0) {
            const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
            [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
                g_exp[r] = g_val;
            }
        } else {
            const uint g_base = gb_off * S_V;
            [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
                const uint i = r * LANES_PER_COLUMN + lane;
                g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i]));
            }
        }

        const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);

        // kv = k . (decayed state column): each lane sums its rows, then the
        // column group reduces (uniform control flow, so barrier() in the
        // shmem fallback is safe).
        FLOAT_TYPE kv_shard = 0.0;
        [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
            kv_shard += g_exp[r] * s_shard[r] * k_reg[r];
        }
        FLOAT_TYPE kv_col = reduce_partial(kv_shard);

        // Delta-rule error term, scaled by beta.
        FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;

        // Fused: decay + rank-1 state update, and the q . state output dot
        // product in the same pass over the rows.
        FLOAT_TYPE attn_partial = 0.0;
        [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
            s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col;
            attn_partial += s_shard[r] * q_reg[r];
        }
        FLOAT_TYPE attn_col = reduce_partial(attn_partial);

        // After the reduction every lane of the group holds the sum; only
        // lane 0 writes the output element.
        if (lane == 0) {
            data_dst[attn_off + col] = attn_col * scale;
        }

        attn_off += S_V * H;
    }

    // Persist the updated state shard for the next invocation of the op.
    [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
        data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
    }
}