#version 450
// upsample_nearest2d. input [b,c,ih,iw], out [b,c,oh,ow].
// src idx = min(in-1, (out_idx*in)/out) per axis (matches hanzo-ml's floor(out*scale)).
layout(local_size_x = 64) in;
layout(set = 0, binding = 0) readonly buffer In { float inp[]; };
layout(set = 0, binding = 1) writeonly buffer Out { float o[]; };
layout(push_constant) uniform Pc { uint b_size; uint c; uint ih; uint iw; uint oh; uint ow; };
void main() {
uint gid = gl_GlobalInvocationID.x;
uint total = b_size * c * oh * ow;
if (gid >= total) { return; }
uint oxw = gid % ow;
uint oxh = (gid / ow) % oh;
uint cc = (gid / (ow * oh)) % c;
uint bb = gid / (ow * oh * c);
uint sy = min(ih - 1u, (oxh * ih) / oh);
uint sx = min(iw - 1u, (oxw * iw) / ow);
o[((bb * c + cc) * oh + oxh) * ow + oxw] = inp[((bb * c + cc) * ih + sy) * iw + sx];
}