#version 450
// im2col (rows layout): x [N,Cin,H,W] → [N*Ho*Wo, Cin*kH*kW]. out[row, col] =
// x[n, c, ho*sH-pH+kh*dH, wo*sW-pW+kw*dW] (0 when out of bounds), with
// row = (n*Ho+ho)*Wo+wo and col = (c*kH+kh)*kW+kw.
// Workgroups of 256 invocations along X.
layout(local_size_x = 256) in;
// Single storage buffer holding both the input tensor x and the im2col
// output; pc.x_off and pc.out_off are float-element offsets into data[].
// NOTE(review): the two regions presumably must not overlap — confirm with host.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Push constants. Member order and types define the byte layout the host
// must write, so members must not be reordered or retyped.
layout(push_constant) uniform PC {
// nn, cin, hh, ww: input batch, channels, height, width (N, Cin, H, W).
uint nn; uint cin; uint hh; uint ww;
// ho, wo: output spatial size (Ho, Wo); presumably derived by the host from
// the convolution parameters — the shader does not validate them.
uint ho; uint wo;
// kh, kw: kernel size; sh, sw: stride; ph, pw: padding; dh, dw: dilation.
uint kh; uint kw; uint sh; uint sw; uint ph; uint pw; uint dh; uint dw;
// Float-element offsets into data[] of the input tensor and the output matrix.
uint x_off; uint out_off;
} pc;
// im2col entry point. For every flat output index i = row * cols + col
// (rows = N*Ho*Wo, cols = Cin*kH*kW) it writes data[pc.out_off + i] with the
// input element x[n, c, ho*sH-pH+kh*dH, wo*sW-pW+kw*dW], or 0.0 when that
// coordinate falls in the padding (outside [0,H) x [0,W)).
//
// Invocations walk the flat output in steps of the number of invocations
// dispatched along X (grid-stride loop). A dispatch smaller than the output,
// e.g. one capped at maxComputeWorkGroupCount[0], therefore still fills every
// element. When the host dispatches at least ceil(rows*cols / 256) groups in
// X, each invocation processes at most one element, exactly as a
// one-thread-per-element kernel would.
void main() {
    uint cols = pc.cin * pc.kh * pc.kw;
    uint rows = pc.nn * pc.ho * pc.wo;
    uint total = rows * cols;

    // Loop-invariant divisors for decomposing row and col.
    uint out_plane = pc.ho * pc.wo;
    uint kernel_area = pc.kh * pc.kw;

    // Number of invocations along X in this dispatch.
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;

    // total == 0 (any zero dimension) skips the loop, so the divisions below
    // never see a zero divisor.
    for (uint gid = gl_GlobalInvocationID.x; gid < total; gid += stride) {
        uint row = gid / cols;
        uint col = gid % cols;

        // row = (n*Ho + ho)*Wo + wo
        uint owi = row % pc.wo;
        uint ohi = (row / pc.wo) % pc.ho;
        uint n = row / out_plane;

        // col = (c*kH + kh)*kW + kw
        uint kwi = col % pc.kw;
        uint khi = (col / pc.kw) % pc.kh;
        uint c = col / kernel_area;

        // Input coordinate; signed because padding can push it below 0.
        int h = int(ohi * pc.sh) - int(pc.ph) + int(khi * pc.dh);
        int w = int(owi * pc.sw) - int(pc.pw) + int(kwi * pc.dw);

        float v = 0.0;
        if (h >= 0 && h < int(pc.hh) && w >= 0 && w < int(pc.ww)) {
            v = data[pc.x_off + ((n * pc.cin + c) * pc.hh + uint(h)) * pc.ww + uint(w)];
        }
        data[pc.out_off + gid] = v;

        // The next index (gid + stride) would be >= total. Exit here so that
        // gid + stride can never wrap past 2^32 and restart the loop.
        if (total - gid <= stride) { break; }
    }
}