// hanzo-ml 0.10.2
//
// Minimalist ML framework.
// Documentation
#version 450
// Register-blocked coopmat matmul, NT layout: C(m x n,f32) = A(m x k,f16) * W(n x k,f16)^T.
// W is a Linear weight in natural row-major [n,k]; the B fragment loads it column-major (stride k)
// so no transpose copy is needed. A load and the rest match bmm_coopmat_rb.
// One subgroup computes an RM x RN grid of 16x16 output tiles, holding RM*RN accumulator
// fragments in registers. Per K-step it loads RM A-fragments + RN B-fragments and issues
// RM*RN MulAdds, so each loaded fragment is reused RN (resp. RM) times -- far higher
// arithmetic intensity than the naive 1-tile kernel (which reloads A/B for every tile).
// m, n, k must be multiples of 16 (caller guarantees); partial RM x RN grids at the matrix
// edge are handled by per-fragment bounds checks (the conditions are subgroup-uniform).
#extension GL_KHR_cooperative_matrix : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : enable

// 64 invocations per workgroup. main() returns immediately for every subgroup except
// subgroup 0, so only that subgroup's invocations issue cooperative-matrix work.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// A: f16 left operand. main() indexes it as [batch][m][k], row-major with row stride k.
layout(set = 0, binding = 0) readonly  buffer A { float16_t a[]; };
// B: f16 weight W. main() indexes it as [batch][n][k] and loads fragments column-major
// with stride k, which yields W^T without a transposed copy.
layout(set = 0, binding = 1) readonly  buffer B { float16_t b[]; };
// C: f32 output. main() writes it as [batch][m][n], row-major with row stride n.
layout(set = 0, binding = 2) writeonly buffer C { float     c[]; };
// Problem sizes in elements. `batch` is not read by main(); the batch index is taken from
// gl_WorkGroupID.z (presumably the host dispatches `batch` workgroups in z -- confirm).
layout(push_constant) uniform Pc { uint batch; uint m; uint k; uint n; };

// Tile edge length; matches the 16x16 fragment dimensions used in main().
const uint T = 16u;
const uint RM = 4u; // tiles per workgroup along output rows (m); x16 elements
const uint RN = 4u; // tiles per workgroup along output cols (n); x16 elements

// Entry point: subgroup 0 of each workgroup accumulates an RM x RN block of 16x16 output
// tiles of C = A * W^T for batch gl_WorkGroupID.z, then stores the in-range tiles.
void main() {
    // Only the first subgroup of the workgroup does any work; the others leave before
    // issuing a single load, multiply-add or store.
    if (gl_SubgroupID != 0u) {
        return;
    }

    // Element offsets of this batch's A, W and C matrices.
    uint batchIdx = gl_WorkGroupID.z;
    uint aBase = batchIdx * m * k;
    uint wBase = batchIdx * k * n;
    uint cBase = batchIdx * m * n;

    // Top-left tile of this workgroup's block, and the matrix extents, all in units of T.
    uint tileRow0 = gl_WorkGroupID.y * RM;
    uint tileCol0 = gl_WorkGroupID.x * RN;
    uint tileRows = m / T;
    uint tileCols = n / T;
    uint kSteps = k / T;

    // One f32 accumulator fragment per output tile, all starting from zero.
    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> sum[RM][RN];
    [[unroll]] for (uint r = 0u; r < RM; ++r) {
        [[unroll]] for (uint s = 0u; s < RN; ++s) {
            sum[r][s] = coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);
        }
    }

    for (uint step = 0u; step < kSteps; ++step) {
        uint kOff = step * T; // first k index covered by this step
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> fragA[RM];
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> fragW[RN];

        // A fragments: 16 rows of A starting at tileRow * T, row stride k.
        [[unroll]] for (uint r = 0u; r < RM; ++r) {
            uint tileRow = tileRow0 + r;
            if (tileRow < tileRows) {
                coopMatLoad(fragA[r], a, aBase + tileRow * T * k + kOff, k, gl_CooperativeMatrixLayoutRowMajor);
            }
        }

        // W fragments: 16 rows of W starting at tileCol * T, read column-major with
        // stride k so that each row of W becomes a column of the B operand.
        [[unroll]] for (uint s = 0u; s < RN; ++s) {
            uint tileCol = tileCol0 + s;
            if (tileCol < tileCols) {
                coopMatLoad(fragW[s], b, wBase + tileCol * T * k + kOff, k, gl_CooperativeMatrixLayoutColumnMajor);
            }
        }

        // Accumulate every in-range (row, col) tile; fragments that were skipped above
        // are never read because the same bounds guard the multiply-add.
        [[unroll]] for (uint r = 0u; r < RM; ++r) {
            bool rowLive = tileRow0 + r < tileRows;
            [[unroll]] for (uint s = 0u; s < RN; ++s) {
                bool colLive = tileCol0 + s < tileCols;
                if (rowLive && colLive) {
                    sum[r][s] = coopMatMulAdd(fragA[r], fragW[s], sum[r][s]);
                }
            }
        }
    }

    // Write each in-range accumulator tile to C, row-major with row stride n.
    [[unroll]] for (uint r = 0u; r < RM; ++r) {
        uint tileRow = tileRow0 + r;
        [[unroll]] for (uint s = 0u; s < RN; ++s) {
            uint tileCol = tileCol0 + s;
            if (tileRow < tileRows && tileCol < tileCols) {
                coopMatStore(sum[r][s], c, cBase + tileRow * T * n + tileCol * T, n, gl_CooperativeMatrixLayoutRowMajor);
            }
        }
    }
}