#version 450
// conv_transpose2d (transposed 2D convolution) compute shader.
// All tensors are flat float arrays indexed row-major by main():
//   input  [b, c_in,  ih, iw]
//   weight [c_in, c_out, kh, kw]
//   output [b, c_out, oh, ow]
// One invocation computes one output element; grouped convolution is not handled here.
// NOTE(review): oh/ow are supplied by the host rather than derived in the shader; presumably
// oh = (ih - 1) * stride - 2 * padding + dilation * (kh - 1) + 1 (+ any output padding), and
// likewise for ow -- confirm against the dispatching code.
// Workgroups are 64 invocations wide along x; main() only reads gl_GlobalInvocationID.x,
// so the dispatch is expected to be 1D with at least b*c_out*oh*ow invocations.
layout(local_size_x = 64) in;
// Input activations, [b, c_in, ih, iw].
layout(set = 0, binding = 0) readonly buffer In { float inp[]; };
// Filter weights, [c_in, c_out, kh, kw] (input-channel-major, as for transposed convolution).
layout(set = 0, binding = 1) readonly buffer W { float w[]; };
// Result, [b, c_out, oh, ow]; every element in range is written exactly once by main().
layout(set = 0, binding = 2) writeonly buffer Out { float o[]; };
// Problem dimensions and convolution hyper-parameters.
layout(push_constant) uniform Pc {
// Batch size, channel counts, and input/output spatial extents.
uint b_size; uint c_in; uint c_out; uint ih; uint iw; uint oh; uint ow;
// Kernel extent, plus padding/stride/dilation shared by both spatial axes.
// stride is used as an integer divisor in main(), so it must be nonzero.
uint kh; uint kw; uint padding; uint stride; uint dilation;
};
// Transposed 2D convolution, gather form: each invocation produces exactly one
// output element by summing every (input channel, kernel tap) pair that lands on
// it, so no atomics are needed (each output index is written by one invocation).
//
// For output row y, kernel row ky contributes from input row iy when
//     iy * stride + ky * dilation == y + padding,
// i.e. (y + padding - ky * dilation) must be non-negative, divisible by stride,
// and the quotient must be < ih. Columns follow the same rule with kx / iw.
// Accumulation order is ky (outer), kx, then input channel (inner).
void main() {
    uint idx = gl_GlobalInvocationID.x;
    if (idx >= b_size * c_out * oh * ow) {
        return; // invocations past the last output element do nothing
    }

    // Unflatten idx as row-major [b, c_out, oh, ow].
    uint plane = ow * oh;
    uint x = idx % ow;
    uint y = (idx / ow) % oh;
    uint oc_idx = (idx / plane) % c_out;
    uint n = idx / (plane * c_out);

    // Distance between consecutive input channels in the flat input / weight buffers.
    uint inChanStride = ih * iw;
    uint wChanStride = c_out * kh * kw;

    float sum = 0.0;
    for (uint ky = 0u; ky < kh; ++ky) {
        int ty = int(y + padding) - int(ky * dilation);
        // && short-circuits, so the modulo/divide only ever see non-negative offsets.
        bool rowOk = ty >= 0 && uint(ty) % stride == 0u && uint(ty) / stride < ih;
        if (!rowOk) {
            continue;
        }
        uint iy = uint(ty) / stride;

        for (uint kx = 0u; kx < kw; ++kx) {
            int tx = int(x + padding) - int(kx * dilation);
            bool colOk = tx >= 0 && uint(tx) % stride == 0u && uint(tx) / stride < iw;
            if (!colOk) {
                continue;
            }
            uint ix = uint(tx) / stride;

            // Offsets for input channel 0; advanced by one channel stride per step.
            uint inPos = ((n * c_in) * ih + iy) * iw + ix;
            uint wPos = (oc_idx * kh + ky) * kw + kx;
            for (uint ic = 0u; ic < c_in; ++ic) {
                sum += inp[inPos] * w[wPos];
                inPos += inChanStride;
                wPos += wChanStride;
            }
        }
    }

    o[((n * c_out + oc_idx) * oh + y) * ow + x] = sum;
}