hanzo-ml 0.10.2

Minimalist ML framework.
Documentation
#version 450
// 2-D convolution, one invocation per output element. No bias and no channel groups.
// Tensor shapes (all densely packed, last index fastest):
//   input  [b, c_in,  ih, iw]
//   weight [c_out, c_in, kh, kw]
//   output [b, c_out, oh, ow]
// padding, stride and dilation are single values applied to both spatial axes.
layout(local_size_x = 64) in;

layout(set = 0, binding = 0) readonly buffer In {
    float inp[];   // input activations
};
layout(set = 0, binding = 1) readonly buffer W {
    float w[];     // convolution kernel
};
layout(set = 0, binding = 2) writeonly buffer Out {
    float o[];     // convolution result
};

// Member order defines the push-constant byte layout; keep it in sync with the host.
layout(push_constant) uniform Pc {
    uint b_size;    // batch size
    uint c_in;      // input channels
    uint c_out;     // output channels
    uint ih;        // input height
    uint iw;        // input width
    uint oh;        // output height (supplied by the host; not validated here)
    uint ow;        // output width  (supplied by the host; not validated here)
    uint kh;        // kernel height
    uint kw;        // kernel width
    uint padding;   // zero-padding added on each side of both spatial axes
    uint stride;    // step between neighbouring output positions, both axes
    uint dilation;  // spacing between kernel taps, both axes
};
// Each invocation produces one output element.
// 1. Decode the flat invocation id into (batch, out-channel, out-row, out-col).
// 2. Sum input * weight over every input channel and every kernel tap that lands
//    inside the input.
// Taps that fall into the padding region are skipped, i.e. they contribute zero.
void main() {
    uint idx = gl_GlobalInvocationID.x;
    uint plane = oh * ow;

    // Invocations past the last output element exit without writing, e.g. when the
    // dispatch is rounded up to whole workgroups.
    if (idx >= b_size * c_out * plane) { return; }

    uint col = idx % ow;
    uint row = (idx / ow) % oh;
    uint oc  = (idx / plane) % c_out;
    uint n   = idx / (plane * c_out);

    // Top-left input coordinate of this output's receptive field.
    // It can be negative when padding > 0.
    int y_origin = int(row * stride) - int(padding);
    int x_origin = int(col * stride) - int(padding);

    float sum = 0.0;
    for (uint ci = 0u; ci < c_in; ++ci) {
        uint in_plane = (n * c_in + ci) * ih * iw;
        uint w_plane  = (oc * c_in + ci) * kh * kw;
        for (uint ky = 0u; ky < kh; ++ky) {
            int y = y_origin + int(ky * dilation);
            if (y >= 0 && y < int(ih)) {
                uint in_row = in_plane + uint(y) * iw;
                uint w_row  = w_plane + ky * kw;
                for (uint kx = 0u; kx < kw; ++kx) {
                    int x = x_origin + int(kx * dilation);
                    if (x >= 0 && x < int(iw)) {
                        sum += inp[in_row + uint(x)] * w[w_row + kx];
                    }
                }
            }
        }
    }

    o[((n * c_out + oc) * oh + row) * ow + col] = sum;
}