#version 450
// conv_transpose1d. input [b,c_in,l_in], weight [c_in,c_out,k], out [b,c_out,l_out].
// out[b,oc,ox] = sum over ic,k where (ox+padding-k*dilation) == ix*stride, ix in [0,l_in).
// Workgroup size: 64 invocations along x. main() uses only the x component of
// the global invocation ID, one invocation per output element.
layout(local_size_x = 64) in;
// Input activations, read-only. Indexed by main() as (b * c_in + ic) * l_in + ix.
layout(set = 0, binding = 0) readonly buffer In { float inp[]; };
// Kernel weights, read-only. Indexed by main() as (ic * c_out + oc) * k_size + k.
layout(set = 0, binding = 1) readonly buffer W { float w[]; };
// Result tensor, write-only. Indexed by main() as (b * c_out + oc) * l_out + ox.
layout(set = 0, binding = 2) writeonly buffer Out { float o[]; };
// Shape and convolution parameters supplied as push constants:
//   b_size, c_in, c_out - batch size, input channel count, output channel count
//   l_in, l_out         - input and output lengths along the spatial axis
//   k_size              - number of kernel taps
//   padding             - added to the output position before the tap offset is subtracted
//   stride              - divisor mapping that shifted position back to an input index
//   dilation            - multiplier applied to the tap index
// NOTE(review): stride is used as a divisor in main(); presumably the host
// guarantees stride >= 1 - confirm against the caller.
layout(push_constant) uniform Pc {
uint b_size; uint c_in; uint c_out; uint l_in; uint l_out;
uint k_size; uint padding; uint stride; uint dilation;
};
// Entry point. Each invocation computes one element of the flattened output
// o[(b * c_out + oc) * l_out + ox]; invocations whose index lies past the end
// of the output return without writing.
void main() {
    uint idx = gl_GlobalInvocationID.x;
    if (idx >= b_size * c_out * l_out) {
        return;
    }

    // Unflatten the invocation index into (batch, output channel, output position).
    uint pos = idx % l_out;
    uint chan = (idx / l_out) % c_out;
    uint batch = idx / (l_out * c_out);

    float accum = 0.0;
    for (uint tap = 0u; tap < k_size; ++tap) {
        // Position of this tap's contribution on the stride-upsampled input axis.
        int shifted = int(pos + padding) - int(tap * dilation);

        // Only taps that land at a non-negative multiple of the stride map
        // back to a real input sample.
        if (shifted >= 0 && (uint(shifted) % stride) == 0u) {
            uint src = uint(shifted) / stride;
            if (src < l_in) {
                // Accumulate over every input channel for this (tap, src) pair.
                for (uint ch = 0u; ch < c_in; ++ch) {
                    uint in_idx = (batch * c_in + ch) * l_in + src;
                    uint w_idx = (ch * c_out + chan) * k_size + tap;
                    accum += inp[in_idx] * w[w_idx];
                }
            }
        }
    }

    o[(batch * c_out + chan) * l_out + pos] = accum;
}