#version 450
// 2-D pooling over an NCHW float tensor held in a single storage-buffer arena.
// One invocation per output element; the shader reads only gl_GlobalInvocationID.x,
// so dispatch 1-D with at least ceil(N*C*HO*WO / 64) workgroups (excess invocations
// exit early).
// kind selects the reduction: 2 = max, 1 = avg (ReduceOp::Max / ::Mean). Any other
// value accumulates a plain sum of the in-bounds taps with no division.
// Avg divides by the number of in-bounds taps, so padding never contributes.
layout(local_size_x = 64) in;
// Arena: input and output are addressed by element (float) offsets into data[].
// NOTE(review): the input and output ranges must not overlap — invocations read
// neighbouring input taps while other invocations write outputs, so in-place use races.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Push constants. Field order fixes the byte offsets the host writes; presumably
// mirrored by a host-side struct — keep the two in sync.
layout(push_constant) uniform PC {
uint nn; uint cc; uint hh; uint ww; // input N,C,H,W
uint ho; uint wo; // output H,W
uint kh; uint kw; uint sh; uint sw; uint ph; uint pw; // kernel H,W; stride H,W; leading pad H,W (pad only shifts the window origin)
uint x_off; uint out_off; // element offsets of the input tensor and the output tensor in data[]
uint kind; // 2 = max, 1 = avg
} pc;
// Computes one pooled output element.
// gid is decomposed in NCHW order (w fastest) into (n, c, ohi, owi). The window
// starts at (ohi*sh - ph, owi*sw - pw) and spans kh x kw input taps. Only in-bounds
// taps are read, so padding never contributes.
//   kind == 2u : max over in-bounds taps
//   kind == 1u : mean over in-bounds taps (divides by the in-bounds count)
//   otherwise  : sum over in-bounds taps (no division)
// A window with no in-bounds taps writes 0.0 for every kind. Without this, max
// pooling would leak its -FLT_MAX starting sentinel into the output.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = pc.nn * pc.cc * pc.ho * pc.wo;
    if (gid >= total) { return; }

    uint owi = gid % pc.wo;
    uint ohi = (gid / pc.wo) % pc.ho;
    uint c = (gid / (pc.wo * pc.ho)) % pc.cc;
    uint n = gid / (pc.wo * pc.ho * pc.cc);

    int h0 = int(ohi * pc.sh) - int(pc.ph);
    int w0 = int(owi * pc.sw) - int(pc.pw);

    // Clamp the window to the input once instead of bounds-testing every tap.
    // Taps are still visited h-ascending then w-ascending, so the float
    // accumulation order matches the per-tap-test formulation exactly.
    int h_begin = max(h0, 0);
    int h_end = min(h0 + int(pc.kh), int(pc.hh));
    int w_begin = max(w0, 0);
    int w_end = min(w0 + int(pc.kw), int(pc.ww));

    bool is_max = (pc.kind == 2u);
    // Start of the (n, c) input plane in the arena. This is loop-invariant.
    uint plane = pc.x_off + (n * pc.cc + c) * pc.hh * pc.ww;

    float acc = is_max ? -3.402823466e38 : 0.0;
    uint cnt = 0u;
    for (int h = h_begin; h < h_end; ++h) {
        uint row = plane + uint(h) * pc.ww;
        for (int w = w_begin; w < w_end; ++w) {
            float v = data[row + uint(w)];
            if (is_max) { acc = max(acc, v); } else { acc += v; }
            cnt++;
        }
    }

    if (cnt == 0u) {
        // Fully padded window: emit 0.0 for every kind (avg already did).
        acc = 0.0;
    } else if (pc.kind == 1u) {
        acc /= float(cnt);
    }
    data[pc.out_off + gid] = acc;
}