#version 450
// upsample_bilinear2d: input [b, c, ih, iw] -> output [b, c, oh, ow], both contiguous (NCHW).
// scale_h/scale_w are computed host-side (PyTorch area_pixel logic) and passed in as push
// constants, together with the align_corners flag.
// Workgroup of 64 invocations along x; main() uses only gl_GlobalInvocationID.x.
layout(local_size_x = 64) in;
// Input tensor, read-only. main() indexes it as a contiguous [b, c, ih, iw] float array.
layout(set = 0, binding = 0) readonly buffer In { float inp[]; };
// Output tensor, write-only. main() writes it as a contiguous [b, c, oh, ow] float array.
layout(set = 0, binding = 1) writeonly buffer Out { float o[]; };
// Push constants: nine 32-bit scalars (36 bytes). The host must push them in exactly this
// order, since each scalar sits at a 4-byte-aligned offset in declaration order.
//   b_size, c, ih, iw, oh, ow : tensor dimensions.
//   align_corners             : nonzero selects src = scale * dst; zero selects the
//                               half-pixel mapping src = scale * (dst + 0.5) - 0.5.
//   scale_h, scale_w          : input-pixels-per-output-pixel step along h / w.
//                               NOTE(review): presumably (in-1)/(out-1) with align_corners,
//                               else in/out or 1/scale_factor as in PyTorch — confirm on host.
layout(push_constant) uniform Pc {
uint b_size; uint c; uint ih; uint iw; uint oh; uint ow;
uint align_corners; float scale_h; float scale_w;
};
// Resolves output coordinate `dst` along one axis of input length `in_size` into the two
// neighbouring input indices (i0, i1) and `lambda`, the weight given to i1.
// Mirrors PyTorch's area_pixel_compute_source_index (linear modes) followed by
// guard_index_and_lambda. Requires in_size >= 1; main() checks this before calling.
void source_coord(uint dst, uint in_size, float scale,
                  out uint i0, out uint i1, out float lambda) {
    float src;
    if (align_corners != 0u) {
        src = scale * float(dst);
    } else {
        // Half-pixel centres: this goes negative near the leading edge and is clamped below.
        src = scale * (float(dst) + 0.5) - 0.5;
    }
    src = max(src, 0.0);
    // Clamp the integer index to the last valid row/column. Without this, a host scale that
    // overshoots, or float rounding at the edge, can make floor(src) == in_size. The loads
    // in main() would then read past the end of the input plane.
    i0 = min(uint(floor(src)), in_size - 1u);
    i1 = min(i0 + 1u, in_size - 1u);
    // When i0 was clamped, i1 == i0, so the weight cannot affect the result.
    lambda = clamp(src - float(i0), 0.0, 1.0);
}

// One invocation per output element. gid enumerates the output in contiguous
// [b, c, oh, ow] order (ow fastest), so gid is also the output offset. The host must
// dispatch at least ceil(total / 64) workgroups along x; surplus invocations exit early.
// NOTE: all index arithmetic is 32-bit uint, so tensors with >= 2^32 elements are unsupported.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = b_size * c * oh * ow;
    if (gid >= total) { return; }

    uint oxw = gid % ow;
    uint oxh = (gid / ow) % oh;
    // Flattened (batch, channel) plane index, equal to bb * c + cc.
    uint plane = gid / (ow * oh);

    // PyTorch rejects empty spatial inputs. Here there is nothing to sample, and
    // `ih - 1u` / `iw - 1u` would wrap, so write 0 rather than read out of bounds.
    if (ih == 0u || iw == 0u) {
        o[gid] = 0.0;
        return;
    }

    uint h0; uint h1; float wh;
    uint w0; uint w1; float ww;
    source_coord(oxh, ih, scale_h, h0, h1, wh);
    source_coord(oxw, iw, scale_w, w0, w1, ww);

    uint base = plane * ih * iw;
    uint row0 = base + h0 * iw;
    uint row1 = base + h1 * iw;
    float v00 = inp[row0 + w0];
    float v01 = inp[row0 + w1];
    float v10 = inp[row1 + w0];
    float v11 = inp[row1 + w1];

    // Interpolate along w within each of the two rows, then along h between them.
    float top = v00 * (1.0 - ww) + v01 * ww;
    float bot = v10 * (1.0 - ww) + v11 * ww;
    o[gid] = top * (1.0 - wh) + bot * wh;
}