hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
#version 450
// Batched tiled matmul: C[bt] = A[bt](m x k) * B[bt](k x n), row-major, contiguous.
// 16x16 shared-memory tiles so each global element is read once per tile instead of
// once per output cell (cuts global memory traffic ~TILE x vs the naive kernel).
//
// The workgroup is TILE x TILE invocations: each invocation stages exactly one element
// of the A tile and one of the B tile, and the inner product loop walks TILE entries.
// Keep local_size_x / local_size_y equal to TILE below.
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

// Inputs A, B and output C, batches stored back to back in each buffer.
// NOTE(review): set/binding numbers must match the host pipeline layout (not visible here).
layout(set = 0, binding = 0) readonly  buffer A { float a[]; };
layout(set = 0, binding = 1) readonly  buffer B { float b[]; };
layout(set = 0, binding = 2) writeonly buffer C { float c[]; };
// batch: number of matrix pairs; m, k, n: A is m x k, B is k x n, C is m x n.
layout(push_constant) uniform Pc { uint batch; uint m; uint k; uint n; };

// Tile edge length. The shared tiles are sized from it so their dimensions cannot
// drift out of sync with the tile loop bounds in main().
const uint TILE = 16u;
shared float As[TILE][TILE]; // current A tile: this group's rows x TILE columns of k
shared float Bs[TILE][TILE]; // current B tile: TILE rows of k x this group's columns

// One invocation per output cell C[bt][row][col] (x -> column, y -> row, z -> batch).
// The k dimension is walked in TILE-wide steps. On each step:
//   1. every invocation stages one element of A and one of B into shared memory,
//      writing 0.0 for positions outside the matrices so padding contributes nothing;
//   2. the group synchronizes;
//   3. each invocation accumulates a TILE-long partial dot product from the shared tiles.
//
// Every invocation, including ones whose (bt, row, col) is out of range, must reach
// both barrier() calls. Out-of-range work is therefore masked with predicates rather
// than skipped with an early return.
// NOTE(review): offsets use 32-bit uint arithmetic, so bt*m*k (etc.) wraps if a buffer
// holds 2^32 or more elements.
void main() {
    uint bt  = gl_GlobalInvocationID.z;
    uint row = gl_GlobalInvocationID.y;
    uint col = gl_GlobalInvocationID.x;
    uint lr  = gl_LocalInvocationID.y;
    uint lc  = gl_LocalInvocationID.x;

    // Workgroups dispatched beyond the last batch must not touch A/B/C at all,
    // not just skip the final store, or their loads would read out of bounds.
    bool inBatch = bt < batch;

    uint ao = bt * m * k;
    uint bo = bt * k * n;
    uint co = bt * m * n;

    float acc = 0.0;
    uint ntiles = (k + TILE - 1u) / TILE;
    for (uint t = 0u; t < ntiles; t++) {
        uint acol = t * TILE + lc;
        uint brow = t * TILE + lr;
        As[lr][lc] = (inBatch && row < m && acol < k) ? a[ao + row * k + acol] : 0.0;
        Bs[lr][lc] = (inBatch && brow < k && col < n) ? b[bo + brow * n + col] : 0.0;
        barrier(); // both tiles fully written before any invocation reads them
        for (uint kk = 0u; kk < TILE; kk++) {
            acc += As[lr][kk] * Bs[kk][lc];
        }
        barrier(); // all reads finished before the next step overwrites the tiles
    }
    if (inBatch && row < m && col < n) {
        c[co + row * n + col] = acc;
    }
}