#version 450
// upsample_nearest1d: contiguous input [b, c, l_in] -> contiguous output [b, c, l_out].
// Output position ox reads source index min(l_in - 1, floor(ox * l_in / l_out)).
// Workgroup of 64 invocations along X; main() only uses the X dimension of the
// dispatch, so a 1D dispatch is expected.
layout(local_size_x = 64) in;
// Input tensor (read-only), indexed row-major as [b][c][l] with b_size * c * l_in floats.
layout(set = 0, binding = 0) readonly buffer In { float inp[]; };
// Output tensor (write-only), same row-major layout with b_size * c * l_out floats.
layout(set = 0, binding = 1) writeonly buffer Out { float o[]; };
// Tensor dimensions; the host must push these four uints in declaration order.
layout(push_constant) uniform Pc { uint b_size; uint c; uint l_in; uint l_out; };
// Computes floor(((hi << 32) | lo) / d) for a 64-bit dividend given as two
// 32-bit halves. Preconditions: d != 0 and hi < d (so the quotient fits in
// 32 bits). Restoring long division, one quotient bit per iteration; only used
// when the 32-bit product in main() would otherwise overflow.
uint udiv64by32(uint hi, uint lo, uint d) {
    uint rem = hi;
    uint q = 0u;
    for (int i = 31; i >= 0; --i) {
        // The top bit of rem is shifted out below; if it was set, the true
        // 33-bit partial remainder is >= 2^32 > d, so d must be subtracted.
        bool carry = (rem & 0x80000000u) != 0u;
        rem = (rem << 1) | ((lo >> uint(i)) & 1u);
        q <<= 1;
        if (carry || rem >= d) {
            rem -= d; // wraps correctly mod 2^32 in the carry case
            q |= 1u;
        }
    }
    return q;
}

// Nearest-neighbour 1D upsampling: contiguous [b_size, c, l_in] input ->
// contiguous [b_size, c, l_out] output, with
//   src = min(l_in - 1, floor(ox * l_in / l_out)).
// Each invocation walks output elements with a stride equal to the total number
// of dispatched X invocations. A dispatch of ceil(total / 64) groups therefore
// behaves as one element per invocation. A smaller dispatch (e.g. one capped by
// maxComputeWorkGroupCount[0]) still covers every element.
void main() {
    uint total = b_size * c * l_out;
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
    uint gid = gl_GlobalInvocationID.x;
    while (gid < total) {
        // Flat output index gid == row * l_out + ox, with row == bb * c + cc.
        uint row = gid / l_out;
        uint ox = gid - row * l_out;
        if (l_in == 0u) {
            // Empty input: nothing to sample; avoid the l_in - 1u underflow and
            // the out-of-bounds read it would cause.
            o[gid] = 0.0;
        } else {
            // Exact 64-bit product ox * l_in; a plain 32-bit multiply wraps once
            // l_out * l_in exceeds 2^32. Since ox < l_out, hi < l_out, which
            // satisfies udiv64by32's precondition.
            uint hi;
            uint lo;
            umulExtended(ox, l_in, hi, lo);
            uint sx = (hi == 0u) ? lo / l_out : udiv64by32(hi, lo, l_out);
            sx = min(l_in - 1u, sx);
            o[gid] = inp[row * l_in + sx];
        }
        uint next = gid + stride;
        if (next < gid) { break; } // stop instead of wrapping around uint
        gid = next;
    }
}