// llama-cpp-sys-4 0.2.46
//
// Low-level bindings to llama.cpp
// Documentation
#version 450

// Mask-optimization pre-pass: scans an f16 attention mask tile by tile and
// emits, per workgroup, a packed word of 2-bit codes that say whether each
// Br x Bc tile is entirely -inf, entirely zero, or neither.

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable

// Specialization constants (overridable by the host at pipeline creation).
// BLOCK_SIZE: invocations per workgroup (bound to local_size_x below).
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
// NUM_SUBGROUPS: number of per-subgroup partial slots in minsh/maxsh.
// NOTE(review): presumably BLOCK_SIZE / subgroup size; every slot is read by
// invocation 0, so the host must keep this consistent with the real subgroup
// layout — confirm on the host side.
layout (constant_id = 1) const uint NUM_SUBGROUPS = 4;
// Br: tile height (rows, i.e. along mask dim 1 / nem1).
layout (constant_id = 2) const uint Br = 32;
// Bc: tile width (columns, i.e. along mask dim 0 / nem0).
layout (constant_id = 3) const uint Bc = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Binding 0 is the f16 mask, viewed both as scalars and as f16vec4 (the two
// blocks alias the same buffer; the vec4 view is used on the fast path).
layout (binding = 0) readonly buffer A {float16_t data_a[];};
layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};
// Binding 1 receives one packed uint of 2-bit tile codes per workgroup.
layout (binding = 1) writeonly buffer D {uint data_d[];};

// nem0..nem2: mask extents (in elements) along dims 0..2; nem2 is also used to
//             split gl_WorkGroupID.z into the dim-2 / dim-3 indices.
// nbm1..nbm3: mask strides in f16 *elements* (they scale data_a indices).
// nbd1..nbd3: output strides in uint elements (they scale data_d indices).
layout (push_constant) uniform parameter {
    uint nem0;
    uint nem1;
    uint nem2;
    uint nbm1;
    uint nbm2;
    uint nbm3;
    uint nbd1;
    uint nbd2;
    uint nbd3;
};

// 2-bit tile classification codes, packed at bit offset 2*tile in data_d.
#define MASK_OPT_ALL_NEG_INF 1
#define MASK_OPT_ALL_ZERO 2

// Per-subgroup min/max partials, combined by invocation 0.
shared float minsh[NUM_SUBGROUPS];
shared float maxsh[NUM_SUBGROUPS];

// 0x7EFFFFFF is the bit pattern of FLT_MAX / 2. It is used as a finite
// min/max seed, and any value <= -FLT_MAX_OVER_2 (including -inf) is treated
// as "-inf" when classifying a tile.
float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);

// Folds one tile's per-invocation (min_v, max_v) partials across the whole
// workgroup. On invocation 0 it then ORs the tile's 2-bit code into `result`
// at bit offset 2*block_x:
//   MASK_OPT_ALL_NEG_INF - every in-bounds value is <= -FLT_MAX_OVER_2. This
//                          also fires for a tile with no in-bounds values,
//                          because max_v then keeps its -FLT_MAX_OVER_2 seed.
//   MASK_OPT_ALL_ZERO    - every in-bounds value is exactly 0.0.
// Only invocation 0's `result` is meaningful afterwards.
//
// Contains barrier(), so every invocation must call it in uniform control flow.
// NOTE(review): assumes the workgroup has exactly NUM_SUBGROUPS subgroups, so
// every minsh/maxsh slot read below has been written. Confirm that the host
// derives NUM_SUBGROUPS from the real subgroup size.
void reduce_tile(inout uint result, float min_v, float max_v, const uint block_x) {
    min_v = subgroupMin(min_v);
    max_v = subgroupMax(max_v);
    if (gl_SubgroupInvocationID == 0) {
        minsh[gl_SubgroupID] = min_v;
        maxsh[gl_SubgroupID] = max_v;
    }
    barrier();
    if (gl_LocalInvocationIndex == 0) {
        [[unroll]] for (uint s = 0; s < NUM_SUBGROUPS; ++s) {
            min_v = min(min_v, minsh[s]);
            max_v = max(max_v, maxsh[s]);
        }
        // Unsigned shift: for block_x == 15 the ALL_ZERO code lands in bit 31.
        const uint shift = 2u * block_x;
        if (max_v <= -FLT_MAX_OVER_2) {
            result |= uint(MASK_OPT_ALL_NEG_INF) << shift;
        }
        if (min_v == 0.0f && max_v == 0.0f) {
            result |= uint(MASK_OPT_ALL_ZERO) << shift;
        }
    }
    // Keeps the next tile's minsh/maxsh writes from racing invocation 0's reads.
    barrier();
}

// Vectorized scan of the 16 horizontally adjacent Br x Bc tiles owned by
// workgroup (i0, i1) in mask slice (i2, i3). It reads four f16 values per
// load and accumulates each tile's code into `result` via reduce_tile().
//
// Preconditions (checked by main before choosing this path):
//   - Bc % 4 == 0 and (Br * Bc) % (BLOCK_SIZE * 4) == 0, so each tile splits
//     evenly into vec4 groups across the workgroup;
//   - nbm1, nbm2 and nbm3 are multiples of 4, so base / 4 addresses exactly
//     elements base..base+3;
//   - nem1 % Br == 0. NOTE(review): rows are not bounds-checked here, which
//     assumes the host dispatches at most nem1 / Br workgroups in y.
// need_bounds_check: when true, column groups not entirely below nem0 skip the
// vec4 load and fall back to per-element reads.
void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) {
    const uint tid = gl_LocalInvocationIndex;
    // Element offset of this (i2, i3) slice; invariant across all tiles.
    const uint slice_off = i2 * nbm2 + i3 * nbm3;

    [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
        float min_v = FLT_MAX_OVER_2;
        float max_v = -FLT_MAX_OVER_2;
        [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
            // Each invocation handles one 4-column group of one tile row.
            const uint j0 = ((i + tid) % (Bc / 4)) * 4 + (i0 * 16 + block_x) * Bc;
            const uint j1 = (i + tid) / (Bc / 4) + i1 * Br;
            const uint base = j0 + j1 * nbm1 + slice_off;

            if (!need_bounds_check || j0 + 3 < nem0) {
                const vec4 f = vec4(data_av4[base / 4]);
                [[unroll]] for (uint c = 0; c < 4; ++c) {
                    min_v = min(min_v, f[c]);
                    max_v = max(max_v, f[c]);
                }
            } else {
                // Partially in-bounds group: read each column individually.
                // (Previously this re-read element j0 for every c.)
                [[unroll]] for (uint c = 0; c < 4; ++c) {
                    if (j0 + c < nem0) {
                        const float f = float(data_a[base + c]);
                        min_v = min(min_v, f);
                        max_v = max(max_v, f);
                    }
                }
            }
        }
        reduce_tile(result, min_v, max_v, block_x);
    }
}

// For each Br x Bc block of the mask (input) buffer, read all values and check
// whether the block is all -inf or all zero. Write out a two-bit code saying
// which it is (or zero for neither). Each workgroup processes 16 tiles and
// writes out a 32-bit result mask.
//
// TODO: This is a lot of work per workgroup, might make sense to split this into
// more workgroups in the future.
void main() {
    // Each workgroup handles 16 horizontally adjacent tiles of one tile row:
    // x selects the 16-tile column span, y the tile row, and z the flattened
    // (dim-2, dim-3) slice.
    const uint tid = gl_LocalInvocationIndex;
    const uint i0 = gl_WorkGroupID.x;
    const uint i1 = gl_WorkGroupID.y;
    const uint i2 = gl_WorkGroupID.z % nem2;
    const uint i3 = gl_WorkGroupID.z / nem2;

    uint result = 0;

    // Fast path using f16vec4 loads. It is valid only when:
    //   - tiles never straddle the mask edge;
    //   - every tile splits evenly across the workgroup in vec4 units;
    //   - every 4-column group starts on a 4-element boundary.
    //     This needs Bc % 4 == 0 and 4-aligned element strides.
    //     Otherwise data_av4[(...)/4] would truncate to the wrong vec4.
    // NOTE(review): the vec4 view also assumes the binding-0 buffer offset
    // itself is suitably aligned; that is a host-side guarantee.
    if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
        ((Br * Bc) % (BLOCK_SIZE * 4)) == 0 && (Bc % 4) == 0 &&
        (nbm1 % 4) == 0 && (nbm2 % 4) == 0 && (nbm3 % 4) == 0) {
        // Two call sites with constant arguments let each inlined copy be
        // specialized; bounds checks are only needed when this workgroup's
        // 16-tile span runs past nem0.
        if ((i0 + 1) * 16 * Bc <= nem0) {
            loadvec4(result, i0, i1, i2, i3, false);
        } else {
            loadvec4(result, i0, i1, i2, i3, true);
        }
    } else {
        // General path: one element per invocation per iteration, fully bounds-checked.
        const uint slice_off = i2 * nbm2 + i3 * nbm3;
        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
            float min_v = FLT_MAX_OVER_2;
            float max_v = -FLT_MAX_OVER_2;
            [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {
                // Last iteration may overshoot the tile when BLOCK_SIZE doesn't divide Br * Bc.
                if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) {
                    continue;
                }
                const uint j0 = (i + tid) % Bc + (i0 * 16 + block_x) * Bc;
                const uint j1 = (i + tid) / Bc + i1 * Br;

                if (j0 < nem0 && j1 < nem1) {
                    const float f = float(data_a[j0 + j1 * nbm1 + slice_off]);
                    min_v = min(min_v, f);
                    max_v = max(max_v, f);
                }
            }
            reduce_tile(result, min_v, max_v, block_x);
        }
    }

    if (tid == 0) {
        data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;
    }
}