hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
#version 450
// Batched matmul on the matrix cores via VK_KHR_cooperative_matrix.
// C[bt] (m x n, f32) = A[bt] (m x k, f16) * B[bt] (k x n, f16), row-major contiguous.
// One subgroup computes one 16x16 output tile, looping over K in 16-wide steps. Requires
// m, n, k to be multiples of 16 (the caller falls back to the tiled kernel otherwise).
// coopmat<> type, coopMatLoad / coopMatMulAdd / coopMatStore.
#extension GL_KHR_cooperative_matrix : require
// gl_ScopeSubgroup, used as the coopmat scope parameter.
#extension GL_KHR_memory_scope_semantics : require
// gl_SubgroupID, used to let only subgroup 0 do the work.
#extension GL_KHR_shader_subgroup_basic : require
// float16_t scalar type.
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
// float16_t elements inside storage buffers A and B.
#extension GL_EXT_shader_16bit_storage : require

// 64 lanes covers an AMD wave (wave32 -> 2 subgroups, wave64 -> 1); only subgroup 0 works the
// tile, so this is correct for either subgroup size without a specialization constant.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Operands as flat arrays holding the batched matrices back to back. main() indexes batch
// bt at element offsets bt*m*k (A), bt*k*n (B) and bt*m*n (C), so the buffers are assumed to
// hold at least batch*m*k halves, batch*k*n halves and batch*m*n floats respectively.
layout(set = 0, binding = 0) readonly  buffer A { float16_t a[]; };
layout(set = 0, binding = 1) readonly  buffer B { float16_t b[]; };
layout(set = 0, binding = 2) writeonly buffer C { float     c[]; };
// Problem dimensions, in element counts. Field order is (batch, m, k, n) -- note k comes
// before n -- and must match the host-side push-constant data byte for byte.
// NOTE(review): `batch` is presumably the dispatch's z extent; confirm against the host code.
layout(push_constant) uniform Pc { uint batch; uint m; uint k; uint n; };

// Tile edge in elements. Must equal the 16x16 cooperative-matrix shape spelled out in main().
const uint T = 16u;

// Computes one 16x16 tile of C[bt] per workgroup.
//   gl_WorkGroupID.x -> output tile column (units of T)
//   gl_WorkGroupID.y -> output tile row    (units of T)
//   gl_WorkGroupID.z -> batch index bt
// The tile is accumulated in f32 over ktiles = k / T steps of A-tile * B-tile. Any K tail
// (k % 16 != 0) is silently dropped, which is why the caller must guarantee multiples of 16.
// Offsets are 32-bit element indices.
// NOTE(review): this assumes batch*m*k, batch*k*n and batch*m*n all fit in a uint.
void main() {
    if (gl_SubgroupID != 0u) {
        return; // uniform per subgroup: only subgroup 0 computes/stores this workgroup's tile
    }
    uint bt   = gl_WorkGroupID.z;
    uint trow = gl_WorkGroupID.y; // output tile row (units of 16)
    uint tcol = gl_WorkGroupID.x; // output tile col (units of 16)

    // Defensive bounds check: skip workgroups whose full 16x16 tile would fall outside
    // C[bt], or whose batch index is past the end (e.g. an over-provisioned dispatch).
    // Without this, such a workgroup would read A/B and store C out of bounds.
    // The condition depends only on the workgroup ID and push constants, so the branch is
    // uniform across the subgroup and the cooperative-matrix ops below stay in uniform
    // control flow. For dispatches meeting the documented preconditions it never triggers.
    if (bt >= batch || (trow + 1u) * T > m || (tcol + 1u) * T > n) {
        return;
    }

    // Element offsets of this batch's A, B and C matrices within the flat buffers.
    uint ao = bt * m * k;
    uint bo = bt * k * n;
    uint co = bt * m * n;

    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> acc =
        coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);

    uint ktiles = k / T;
    for (uint kt = 0u; kt < ktiles; kt++) {
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> ma;
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> mb;
        // A tile: rows [trow*T, trow*T+16), cols [kt*T, kt*T+16); row stride k.
        coopMatLoad(ma, a, ao + (trow * T) * k + kt * T, k, gl_CooperativeMatrixLayoutRowMajor);
        // B tile: rows [kt*T, kt*T+16), cols [tcol*T, tcol*T+16); row stride n.
        coopMatLoad(mb, b, bo + (kt * T) * n + tcol * T, n, gl_CooperativeMatrixLayoutRowMajor);
        acc = coopMatMulAdd(ma, mb, acc);
    }
    // C tile: rows [trow*T, trow*T+16), cols [tcol*T, tcol*T+16); row stride n.
    coopMatStore(acc, c, co + (trow * T) * n + tcol * T, n, gl_CooperativeMatrixLayoutRowMajor);
}