#version 450
// conv2d. input [b,c_in,ih,iw], weight [c_out,c_in,kh,kw], out [b,c_out,oh,ow]. All contiguous.
layout(local_size_x = 64) in;
layout(set = 0, binding = 0) readonly buffer In { float inp[]; };
layout(set = 0, binding = 1) readonly buffer W { float w[]; };
layout(set = 0, binding = 2) writeonly buffer Out { float o[]; };
layout(push_constant) uniform Pc {
uint b_size; uint c_in; uint c_out; uint ih; uint iw; uint oh; uint ow;
uint kh; uint kw; uint padding; uint stride; uint dilation;
};
void main() {
uint gid = gl_GlobalInvocationID.x;
uint total = b_size * c_out * oh * ow;
if (gid >= total) { return; }
uint oxw = gid % ow;
uint oxh = (gid / ow) % oh;
uint oc = (gid / (ow * oh)) % c_out;
uint bb = gid / (ow * oh * c_out);
float acc = 0.0;
for (uint ic = 0u; ic < c_in; ic++) {
uint in_c = (bb * c_in + ic) * ih * iw;
uint w_c = ((oc * c_in + ic) * kh) * kw;
for (uint r = 0u; r < kh; r++) {
int iy = int(oxh * stride + r * dilation) - int(padding);
if (iy < 0 || iy >= int(ih)) { continue; }
for (uint c = 0u; c < kw; c++) {
int ix = int(oxw * stride + c * dilation) - int(padding);
if (ix < 0 || ix >= int(iw)) { continue; }
acc += inp[in_c + uint(iy) * iw + uint(ix)] * w[w_c + r * kw + c];
}
}
}
o[((bb * c_out + oc) * oh + oxh) * ow + oxw] = acc;
}