#version 450
// avg_pool2d. input [b,c,ih,iw], out [b,c,oh,ow]. window (kh,kw), stride (sh,sw), no padding.
// Workgroup: 64 invocations along x (1-D dispatch; only the x dimension is used by main).
layout(local_size_x = 64) in;
// Input tensor, indexed in main as contiguous row-major [b, c, ih, iw] floats.
layout(set = 0, binding = 0) readonly buffer In { float inp[]; };
// Output tensor, indexed in main as contiguous row-major [b, c, oh, ow] floats.
layout(set = 0, binding = 1) writeonly buffer Out { float o[]; };
// Shape / window / stride parameters, all 32-bit uint. The shader does not derive
// oh/ow itself. NOTE(review): the host presumably sets oh = (ih - kh) / sh + 1 and
// ow = (iw - kw) / sw + 1 (no padding) — confirm against the dispatching code.
layout(push_constant) uniform Pc {
uint b_size; uint c; uint ih; uint iw; uint oh; uint ow;
uint kh; uint kw; uint sh; uint sw;
};
// 2-D average pooling, no padding.
//
// Work distribution: invocation i processes flat output indices
// i, i + S, i + 2S, ... below b_size*c*oh*ow, where S is the total number of
// invocations along x (gl_NumWorkGroups.x * gl_WorkGroupSize.x). A dispatch that
// covers every output runs the loop body once per invocation. A smaller dispatch
// is also correct, which matters when the output count exceeds what
// maxComputeWorkGroupCount[0] * 64 can cover.
//
// The flat index decomposes as ((batch * c + channel) * oh + oy) * ow + ox,
// which is also the offset of that element in `o`.
//
// Window taps outside [0, ih) x [0, iw) are skipped, and the mean is taken over
// the taps actually read. When the window fits (the usual no-padding
// configuration), that count is exactly kh*kw. An empty window (kh or kw == 0)
// yields 0.0 instead of NaN.
//
// NOTE: all sizes and flat indices are 32-bit uint; tensors whose element count
// (or b_size*c*oh*ow) does not fit in 32 bits are not supported.
void main() {
    uint total = b_size * c * oh * ow;
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;

    for (uint idx = gl_GlobalInvocationID.x; idx < total; idx += stride) {
        uint ox = idx % ow;
        uint oy = (idx / ow) % oh;
        // batch * c + channel: selects one [ih, iw] input plane.
        uint plane = idx / (ow * oh);
        uint base = plane * ih * iw;

        // Window [y0, y1) x [x0, x1), clipped to the input so a mis-sized oh/ow
        // can never read into a neighbouring row/plane or past the buffer.
        uint y0 = oy * sh;
        uint x0 = ox * sw;
        uint y1 = min(y0 + kh, ih);
        uint x1 = min(x0 + kw, iw);

        float acc = 0.0;
        for (uint y = y0; y < y1; y++) {
            uint row = base + y * iw;
            for (uint x = x0; x < x1; x++) {
                acc += inp[row + x];
            }
        }

        uint rows = (y1 > y0) ? (y1 - y0) : 0u;
        uint cols = (x1 > x0) ? (x1 - x0) : 0u;
        uint count = rows * cols;
        o[idx] = (count > 0u) ? acc / float(count) : 0.0;
    }
}