#version 450
// Nearest-neighbor 2× upsample on NCHW: out[n,c,2h+i,2w+j] = in[n,c,h,w] for i, j in {0, 1}.
// (Implemented natively rather than through the generic lowering, which decomposes
// it into a tiling concat (`[x;x]`) instead of an element-wise repeat.)
// Workgroup shape: 256 invocations along x (y and z are the default of 1).
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

// One storage buffer shared by input and output; each tensor is located in
// `data` by an element offset supplied through the push constants below.
layout(std430, binding = 0) buffer Arena {
    float data[];
};

// Dispatch parameters: input tensor dims (NCHW) plus element offsets into `data`.
layout(push_constant) uniform PC {
    uint nn;      // N: batch count
    uint cc;      // C: channel count
    uint hh;      // H: input height (output height is 2*H)
    uint ww;      // W: input width  (output width is 2*W)
    uint x_off;   // element offset of the input tensor within `data`
    uint out_off; // element offset of the output tensor within `data`
} pc;
void main() {
    // Each invocation produces exactly one output element; `idx` is its flat
    // NCHW index within the output tensor.
    const uint idx = gl_GlobalInvocationID.x;

    const uint out_w = 2u * pc.ww;
    const uint out_h = 2u * pc.hh;
    const uint plane = out_w * out_h;   // elements in one (n, c) output plane

    // Bounds check precedes every division/modulo, so zero-sized dims
    // (which make the element count 0) never reach a divide-by-zero.
    if (idx < pc.nn * pc.cc * plane) {
        // Unflatten idx = ((n * C + c) * out_h + y) * out_w + x.
        const uint x = idx % out_w;
        const uint y = (idx / out_w) % out_h;
        const uint c = (idx / plane) % pc.cc;
        const uint n = idx / (plane * pc.cc);

        // Nearest neighbor: output (y, x) samples input (y/2, x/2).
        const uint src_row = (n * pc.cc + c) * pc.hh + (y >> 1u);
        const uint src = src_row * pc.ww + (x >> 1u);

        data[pc.out_off + idx] = data[pc.x_off + src];
    }
}