#version 450
// Last-axis reduction: out[o] = reduce_{r<R} in[o*R + r]. Non-last / multi-
// axis reduces are lowered to this form upstream (LowerNonLastAxisReduce).
// op: 0 sum 1 mean 2 max 3 min 4 prod.
// 256 invocations per workgroup along x; main() maps each invocation's global
// x index to one output row.
layout(local_size_x = 256) in;
// Single float buffer that main() both reads the input rows from and writes
// the reduced results to. All offsets below index this array, i.e. they are
// counted in float elements, not bytes.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Per-dispatch parameters.
layout(push_constant) uniform PC {
uint outer; // number of output rows; invocations with index >= outer do nothing
uint r; // reduced extent (last axis): number of consecutive inputs per row
uint in_off; // index in data[] of the first input element (row 0, element 0)
uint out_off; // index in data[] where the result for row 0 is written
uint op; // reduction selector: 0 sum, 1 mean, 2 max, 3 min, 4 prod
} pc;
// Reduces one row per invocation: reads the pc.r consecutive floats starting at
// data[pc.in_off + o * pc.r], combines them according to pc.op, and stores the
// result in data[pc.out_off + o], where o is the invocation's global x index.
//
// Results by op:
//   0 sum   sum of the row (0.0 for an empty row)
//   1 mean  sum / r (0.0 for an empty row; the division is skipped when r == 0)
//   2 max   largest element (-FLT_MAX for an empty row)
//   3 min   smallest element (+FLT_MAX for an empty row)
//   4 prod  product of the row (1.0 for an empty row)
//   other   0.0
void main() {
    uint o = gl_GlobalInvocationID.x;
    // Invocations past the last row have nothing to reduce.
    if (o >= pc.outer) { return; }

    uint n = pc.r;
    uint base = pc.in_off + o * n;
    float acc = 0.0;

    // The op is the same for every element, so branch on it once and run a
    // dedicated loop instead of re-dispatching on every iteration.
    if (pc.op == 2u) { // max
        // Seed from the first element rather than a finite sentinel so that a
        // row consisting only of -inf reduces to -inf, not -FLT_MAX. The
        // sentinel is kept only for the empty row.
        if (n == 0u) {
            acc = -3.402823466e38;
        } else {
            acc = data[base];
            for (uint i = 1u; i < n; i++) {
                acc = max(acc, data[base + i]);
            }
        }
    } else if (pc.op == 3u) { // min
        // Same reasoning as max: a row of only +inf must reduce to +inf.
        if (n == 0u) {
            acc = 3.402823466e38;
        } else {
            acc = data[base];
            for (uint i = 1u; i < n; i++) {
                acc = min(acc, data[base + i]);
            }
        }
    } else if (pc.op == 4u) { // prod
        acc = 1.0;
        for (uint i = 0u; i < n; i++) {
            acc *= data[base + i];
        }
    } else if (pc.op == 0u || pc.op == 1u) { // sum / mean
        for (uint i = 0u; i < n; i++) {
            acc += data[base + i];
        }
        // Mean divides the sum by the row length; guard against 0/0 on an
        // empty row so it yields 0.0 instead of NaN.
        if (pc.op == 1u && n > 0u) {
            acc /= float(n);
        }
    }
    // Any other op value leaves acc at 0.0.

    data[pc.out_off + o] = acc;
}