hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
#version 450
// Register-blocked batched GEMM, NT layout: C[bt](m x n) = A[bt](m x k) * W[bt](n x k)^T.
// W is a Linear weight in its natural row-major [n,k] layout, so no transpose copy is needed
// (the staged B tile reads Bs[k,n] = W[n,k]). Each thread holds a TM x TN register micro-tile.
// Required for the [[unroll]] loop attributes used in main().
#extension GL_EXT_control_flow_attributes : enable
// main() derives its micro-tile mapping from (BM/TM)*(BN/TN) invocations; keep this in sync
// with the tile constants below.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

// binding 0: A, `batch` consecutive dense row-major m x k matrices (batch stride m*k).
layout(set = 0, binding = 0) readonly  buffer A { float a[]; };
// binding 1: W, `batch` consecutive dense row-major n x k matrices (batch stride k*n).
layout(set = 0, binding = 1) readonly  buffer B { float b[]; };
// binding 2: C, `batch` consecutive dense row-major m x n matrices (batch stride m*n).
layout(set = 0, binding = 2) writeonly buffer C { float c[]; };
// Problem sizes. Member order defines the push-constant layout the host must write.
layout(push_constant) uniform Pc { uint batch; uint m; uint k; uint n; };

// Workgroup tile of C: BM rows x BN columns, reduced over k in slices of BK.
const uint BM = 64u;
const uint BN = 64u;
const uint BK = 8u;
// Per-invocation register micro-tile of C: TM rows x TN columns.
const uint TM = 4u;
const uint TN = 4u;
// threads = (BM/TM) * (BN/TN) = 16 * 16 = 256 (matches local_size_x)
// As: current BM x BK slice of A, row-major (As[r * BK + kk]).
shared float As[BM * BK];
// Bs: current BK x BN slice of W^T, row-major (Bs[kk * BN + col]).
shared float Bs[BK * BN];

// Computes one BM x BN tile of C for batch index gl_WorkGroupID.z.
// Grid mapping: gl_WorkGroupID.x selects the column tile (over n), .y the row tile
// (over m), .z the batch. Each invocation accumulates a TM x TN micro-tile in
// registers while the workgroup cooperatively stages a BM x BK slice of A and a
// BK x BN slice of W^T in shared memory, advancing BK steps along k per iteration.
// Out-of-range rows, columns and k indices are zero-padded on load and skipped on
// store, so m, n and k need not be multiples of the tile sizes.
void main() {
    uint bt = gl_WorkGroupID.z;
    // bt is the same for every invocation in this workgroup, so this exit is
    // workgroup-uniform and cannot leave other invocations waiting at barrier().
    // Doing it before any loads keeps surplus z workgroups from reading A/W past
    // the end of their buffers (ao/bo below would index beyond the last batch).
    if (bt >= batch) return;

    uint rowBase = gl_WorkGroupID.y * BM;
    uint colBase = gl_WorkGroupID.x * BN;
    uint tid = gl_LocalInvocationID.x;
    // Cooperative-load stride: the declared workgroup size rather than a repeated literal.
    uint nthreads = gl_WorkGroupSize.x;
    uint threadCol = tid % (BN / TN); // 0..15
    uint threadRow = tid / (BN / TN); // 0..15

    // Per-batch base offsets into the densely packed A (m x k), W (n x k) and C (m x n).
    uint ao = bt * m * k;
    uint bo = bt * k * n;
    uint co = bt * m * n;

    float acc[TM][TN];
    [[unroll]] for (uint i = 0u; i < TM; i++)
        [[unroll]] for (uint j = 0u; j < TN; j++)
            acc[i][j] = 0.0;

    uint ntiles = (k + BK - 1u) / BK;
    for (uint t = 0u; t < ntiles; t++) {
        uint kBase = t * BK;

        // Stage As[r][cc] = A[rowBase + r][kBase + cc], zero-padded outside m x k.
        for (uint i = tid; i < BM * BK; i += nthreads) {
            uint r = i / BK;
            uint cc = i % BK;
            uint gr = rowBase + r;
            uint gc = kBase + cc;
            As[i] = (gr < m && gc < k) ? a[ao + gr * k + gc] : 0.0;
        }
        // Stage Bs[r][cc] = W[colBase + cc][kBase + r] (a transposed slice of W),
        // zero-padded outside n x k. Adjacent invocations read W entries k apart.
        for (uint i = tid; i < BK * BN; i += nthreads) {
            uint r = i / BN;
            uint cc = i % BN;
            uint gr = kBase + r;
            uint gc = colBase + cc;
            Bs[i] = (gr < k && gc < n) ? b[bo + gc * k + gr] : 0.0;
        }
        // Wait until every invocation has finished staging before reading the tiles.
        barrier();

        [[unroll]] for (uint kk = 0u; kk < BK; kk++) {
            float aReg[TM];
            float bReg[TN];
            [[unroll]] for (uint i = 0u; i < TM; i++)
                aReg[i] = As[(threadRow * TM + i) * BK + kk];
            [[unroll]] for (uint j = 0u; j < TN; j++)
                bReg[j] = Bs[kk * BN + threadCol * TN + j];
            // Outer product of the A column and W^T row into the register micro-tile.
            [[unroll]] for (uint i = 0u; i < TM; i++)
                [[unroll]] for (uint j = 0u; j < TN; j++)
                    acc[i][j] += aReg[i] * bReg[j];
        }
        // Keep the next iteration's staging from overwriting tiles still being read.
        barrier();
    }

    // Write back the micro-tile, skipping elements that fall outside m x n.
    [[unroll]] for (uint i = 0u; i < TM; i++) {
        uint gr = rowBase + threadRow * TM + i;
        if (gr >= m) continue;
        [[unroll]] for (uint j = 0u; j < TN; j++) {
            uint gc = colBase + threadCol * TN + j;
            if (gc < n)
                c[co + gr * n + gc] = acc[i][j];
        }
    }
}