#version 450
// conv1d in cross-correlation form (kernel taps are not flipped; no bias; no channel groups).
//   input  inp[b_size, c_in,  l_in  ]  contiguous
//   weight w  [c_out,  c_in,  k_size]  contiguous
//   output o  [b_size, c_out, l_out ]  contiguous
//   o[b, oc, x] = sum over ic in [0, c_in), k in [0, k_size) of
//                 inp[b, ic, x*stride + k*dilation - padding] * w[oc, ic, k]
//   Input positions outside [0, l_in) contribute zero (implicit zero padding).
// Work mapping: main() assigns flattened output indices
// (gid = (b*c_out + oc)*l_out + x) to invocations along x; workgroups are 64 invocations wide.
layout(local_size_x = 64) in;
// Storage buffers in descriptor set 0; every element is a 32-bit float.
layout(set = 0, binding = 0) readonly buffer In { float inp[]; };
layout(set = 0, binding = 1) readonly buffer W { float w[]; };
layout(set = 0, binding = 2) writeonly buffer Out { float o[]; };
// Push constants: nine uint fields. The host-side struct must declare them in exactly this order.
//   b_size   batch size
//   c_in     input channels
//   c_out    output channels
//   l_in     input length
//   l_out    output length (supplied by the host; the shader does not derive or validate it)
//   k_size   number of kernel taps
//   padding  offset subtracted from every input position (left zero padding). Positions at or
//            past l_in are also read as zero, so any right padding is implied by l_out.
//   stride   step between consecutive output positions, in input coordinates
//   dilation spacing between kernel taps, in input coordinates
// NOTE(review): l_out is presumably (l_in + 2*padding - dilation*(k_size-1) - 1) / stride + 1
// (symmetric padding) — confirm against the host code.
layout(push_constant) uniform Pc {
uint b_size; uint c_in; uint c_out; uint l_in; uint l_out;
uint k_size; uint padding; uint stride; uint dilation;
};
// Computes every output element
//   o[bb, oc, ox] = sum_{ic, k} inp[bb, ic, ox*stride + k*dilation - padding] * w[oc, ic, k]
// where taps whose input position lies outside [0, l_in) contribute zero.
//
// Output elements are addressed by the flattened index gid = (bb*c_out + oc)*l_out + ox.
// Each invocation starts at its global x index and advances by the total number of invocations
// in the dispatch. This keeps the shader correct even when the dispatch is smaller than
// b_size*c_out*l_out, e.g. when the needed workgroup count would exceed
// maxComputeWorkGroupCount[0] (guaranteed only >= 65535). When the dispatch covers every
// element, each invocation processes at most one element, exactly as a one-shot mapping would.
void main() {
    uint total = b_size * c_out * l_out;
    // Total invocations along x in this dispatch (gl_WorkGroupSize.x is the declared local size).
    uint step = gl_NumWorkGroups.x * gl_WorkGroupSize.x;

    for (uint gid = gl_GlobalInvocationID.x; gid < total; gid += step) {
        uint ox = gid % l_out;
        uint oc = (gid / l_out) % c_out;
        uint bb = gid / (l_out * c_out);

        // Input position of tap k is origin + k*dilation; origin does not depend on ic or k.
        int origin = int(ox * stride) - int(padding);

        float acc = 0.0;
        for (uint ic = 0u; ic < c_in; ic++) {
            uint in_base = (bb * c_in + ic) * l_in;
            uint w_base = (oc * c_in + ic) * k_size;
            for (uint k = 0u; k < k_size; k++) {
                int ix = origin + int(k * dilation);
                // Out-of-range taps read the implicit zero padding, i.e. they are skipped.
                if (ix >= 0 && ix < int(l_in)) {
                    acc += inp[in_base + uint(ix)] * w[w_base + k];
                }
            }
        }
        o[(bb * c_out + oc) * l_out + ox] = acc;

        // Stop before gid + step could wrap around 2^32 and revisit earlier elements.
        if (step >= total - gid) {
            break;
        }
    }
}