#version 450
// ArgMax / ArgMin along an axis (viewed as [outer, axis_len, inner]). Output is
// the f32-encoded index of the extreme element. op: 0 = argmax, 1 = argmin.
// Ties break to the smaller index (matches NumPy / the CPU reference).
// Workgroup is 256 invocations wide along x; main() uses the global x index
// as the flat output-element index, one invocation per output value.
layout(local_size_x = 256) in;
// Single float storage buffer at binding 0. main() both reads the input
// (starting at pc.in_off) and writes the result (starting at pc.out_off)
// through this same array.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Per-dispatch parameters. The input is addressed as [outer, axis_len, inner]
// and reduced over the middle dimension, giving outer * inner results.
layout(push_constant) uniform PC {
// Number of slices before the reduced axis.
uint outer;
// Number of elements along the reduced axis.
uint axis_len;
// Number of elements after the reduced axis; also the element stride between
// consecutive positions along the reduced axis.
uint inner;
// Index into data[] (in floats) of the first input element.
uint in_off;
// Index into data[] (in floats) of the first output element.
uint out_off;
// Reduction selector: 0 picks the maximum; main() treats every other value
// as "pick the minimum" (the documented argmin value is 1).
uint op;
} pc;
// Entry point. Each invocation owns one (outer, inner) output slot, scans the
// axis_len values of that slot in increasing index order, and stores the
// position of the extreme value as a float. Comparisons are strict, so an
// equal later value never replaces the earlier winner.
void main() {
    uint slot = gl_GlobalInvocationID.x;
    // Surplus invocations in the last workgroup have no output slot.
    if (slot >= pc.outer * pc.inner) {
        return;
    }

    // Split the flat slot index into its outer and inner coordinates.
    uint outer_idx = slot / pc.inner;
    uint inner_idx = slot % pc.inner;

    // Consecutive elements along the reduced axis sit `inner` floats apart,
    // so walk a running cursor instead of recomputing the index each step.
    uint stride = pc.inner;
    uint cursor = pc.in_off + outer_idx * pc.axis_len * stride + inner_idx;

    // Anything other than 0 selects the minimum.
    bool want_max = (pc.op == 0u);

    // Seed with element 0 of the axis, then test elements 1..axis_len-1.
    float extreme = data[cursor];
    uint extreme_at = 0u;
    uint pos = 1u;
    while (pos < pc.axis_len) {
        cursor += stride;
        float candidate = data[cursor];

        bool wins;
        if (want_max) {
            wins = candidate > extreme;
        } else {
            wins = candidate < extreme;
        }

        if (wins) {
            extreme = candidate;
            extreme_at = pos;
        }
        ++pos;
    }

    // The index is stored as a float because the buffer is float-typed.
    data[pc.out_off + slot] = float(extreme_at);
}