llama-cpp-sys-4 0.3.1

Low Level Bindings to llama.cpp
Documentation
#version 450

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require

#include "types.glsl"

// Push constants supplied by the host for one im2col dispatch.
// The member order and types form a binary layout shared with the host code;
// they must not be reordered here without changing the host side as well.
layout (push_constant) uniform parameter
{
    // Base device address of the destination; only read when BDA is enabled.
    BDA_STORAGE_T dst_addr;
    // Source element stride between batches / between input channels.
    uint batch_offset; uint offset_delta;
    // Not referenced in this shader's visible code (kept for layout).
    uint IC;
    // Input width / height (bounds check and source row stride).
    uint IW; uint IH;
    // Output width / height.
    uint OW; uint OH;
    // Kernel width / height.
    uint KW; uint KH;
    // Upper bound for the combined (batch, output-row) index walked in main();
    // presumably OH * batch count — confirm against the host.
    uint OH_batch;
    // Number of elements in one destination row (one output position);
    // presumably IC * KH * KW — confirm against the host.
    uint CHW;
    // Stride (x, y).
    int s0; int s1;
    // Padding (x, y), subtracted from the strided input coordinate.
    int p0; int p1;
    // Dilation (x, y), applied to the kernel offsets.
    int d0; int d1;
    // Not referenced in this shader's visible code (kept for layout).
    uint batch_IC;
} p;

// Workgroup size along x (specialization constant 0, also used for
// local_size_x below).
layout(constant_id = 0) const uint BLOCK_SIZE = 32;

// Elements handled per invocation within one 512-element chunk of a row.
// NOTE(review): a chunk is fully covered only when BLOCK_SIZE divides 512
// (BLOCK_SIZE * NUM_ITER == 512); other values would leave gaps in each
// chunk. Confirm the host only specializes such sizes.
const uint NUM_ITER = 512 / BLOCK_SIZE;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Source tensor (read) and destination tensor (written when BDA is disabled).
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

#if BDA
// Single-element reference type used to store directly to a device address
// when buffer device addresses are enabled.
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif

// Fills one im2col destination row.
//
// ow    - output x position.
// z_idx - combined index, decomposed below as oh = z_idx % OH and
//         batch_idx = z_idx / OH.
//
// The destination row starts at element ((batch_idx * OH + oh) * OW + ow) * CHW
// and holds CHW values. Element chw_idx of the row corresponds to
// (ic, ky, kx) with chw_idx = ic * KH * KW + ky * KW + kx. Its value is read
// from source channel ic at input coordinate
// (oh * s1 - p1 + ky * d1, ow * s0 - p0 + kx * d0). The value is zero when that
// coordinate lies outside the IH x IW input (the padding region).
//
// Work split: the row is processed in 512-element chunks. This workgroup takes
// chunk gl_WorkGroupID.x and then every gl_NumWorkGroups.x-th chunk after it.
// Within a chunk, invocation gidx handles up to NUM_ITER elements spaced
// BLOCK_SIZE apart.
void im2col(const uint ow, const uint z_idx) {
    const uint oh = z_idx % p.OH;
    const uint batch_idx = z_idx / p.OH;

    const uint gidx = gl_LocalInvocationID.x;
    const uint src_batch = batch_idx * p.batch_offset;
    // Computed in BDA_OFFSET_T (defined in types.glsl, not shown). This is
    // presumably so the row offset can be wider than 32 bits in BDA mode;
    // confirm.
    const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW;

    const uint KHKW = p.KH * p.KW;

    // Input coordinates of the (ky, kx) = (0, 0) kernel tap for this output
    // position. They can be negative because of padding.
    const int base_iw = int(ow * p.s0) - p.p0;
    const int base_ih = int(oh * p.s1) - p.p1;

    // Split one BLOCK_SIZE step along the row into (ic, ky, kx) increments,
    // so the loop below advances incrementally instead of doing a div/mod
    // per element.
    const uint delta_ic  = BLOCK_SIZE / KHKW;
    const uint delta_rem = BLOCK_SIZE % KHKW;

    const uint delta_ky  = delta_rem / p.KW;
    const uint delta_kx  = delta_rem % p.KW;

    const uint delta_ic_offset = delta_ic * p.offset_delta;

    // In BDA mode, precompute the row's base address and the address step per
    // iteration: BLOCK_SIZE elements. D_SIZE is presumably the destination
    // element size in bytes; confirm in types.glsl.
#if BDA
    const BDA_STORAGE_T base_dst_addr = p.dst_addr + D_SIZE * dst_row;
    const uint bda_step = D_SIZE * BLOCK_SIZE;
#endif

    uint wg_x = gl_WorkGroupID.x;
    do {
        const uint wg_offset = wg_x * 512;

        uint chw_idx = wg_offset + gidx;

        // Full decomposition once per chunk. Later elements in the chunk are
        // reached by the incremental updates at the end of the loop body.
        uint ic  = chw_idx / KHKW;
        uint rem = chw_idx % KHKW;

        uint ky  = rem / p.KW;
        uint kx  = rem % p.KW;

        uint ic_offset = src_batch + ic * p.offset_delta;

        // Initialize running pointer/index for the destination buffer
#if BDA
        BDA_STORAGE_T current_dst_addr = base_dst_addr + D_SIZE * chw_idx;
#else
        uint current_dst_idx = dst_row + chw_idx;
#endif

        [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) {
            // Past the end of the row. Remaining iterations and all later
            // chunks only use larger chw_idx, so this invocation is done.
            if (chw_idx >= p.CHW) {
                return;
            }

            const int iiw = base_iw + int(kx * p.d0);
            const int iih = base_ih + int(ky * p.d1);

            A_TYPE val = A_TYPE(0);
            // Negative coordinates become large values after uint(), so a
            // single unsigned comparison per axis rejects both padding sides.
            if (uint(iih) < p.IH && uint(iiw) < p.IW) {
                val = data_a[ic_offset + uint(iih) * p.IW + uint(iiw)];
            }

#if BDA
            D_ptr(current_dst_addr).d = D_TYPE(val);
            current_dst_addr += bda_step;
#else
            data_d[current_dst_idx] = D_TYPE(val);
            current_dst_idx += BLOCK_SIZE;
#endif

            // Advance (ic, ky, kx) by BLOCK_SIZE elements.
            chw_idx   += BLOCK_SIZE;
            ic_offset += delta_ic_offset;
            kx        += delta_kx;
            ky        += delta_ky;

            // Carry kx into ky. Both kx and delta_kx are < KW, so at most one
            // wrap is needed.
            uint kx_wrap = uint(kx >= p.KW);
            kx          -= kx_wrap * p.KW;
            ky          += kx_wrap;

            // Carry ky into ic. ky and delta_ky are < KH and the incoming
            // carry is at most 1, so ky < 2 * KH here and one wrap suffices.
            uint ky_wrap = uint(ky >= p.KH);
            ky          -= ky_wrap * p.KH;
            ic_offset   += ky_wrap * p.offset_delta;
        }

        // Move to the next chunk assigned to this workgroup.
        wg_x += gl_NumWorkGroups.x;
    } while (wg_x * 512 < p.CHW);
}

// Shader entry point.
//
// Walks every output column (y axis) and every combined (batch, output-row)
// index (z axis), calling im2col once per pair. The walk starts at this
// invocation's global id and strides by the dispatched workgroup count on each
// axis. Because local_size_y and local_size_z are 1, this covers each range
// exactly once for any dispatch size.
void main() {
    for (uint col = gl_GlobalInvocationID.y; col < p.OW; col += gl_NumWorkGroups.y) {
        for (uint row_batch = gl_GlobalInvocationID.z; row_batch < p.OH_batch; row_batch += gl_NumWorkGroups.z) {
            im2col(col, row_batch);
        }
    }
}