#version 450
// Row-wise argmin over the last dim: out[row] = index of the minimum in row. Output u32, length rows.
// `inp` is indexed as a row-major rows x cols matrix: element (r, c) is read
// from inp[r * cols + c]. Each invocation processes exactly one row.
// Workgroup of 64 invocations along x; presumably the host dispatches
// ceil(rows / 64) workgroups — TODO confirm against the dispatch site.
layout(local_size_x = 64) in;
// Input matrix, flattened (rows * cols floats expected).
layout(set = 0, binding = 0) readonly buffer In { float inp[]; };
// Output: one u32 column index per row (rows entries expected).
layout(set = 0, binding = 1) writeonly buffer Out { uint o[]; };
// rows: number of rows (and of outputs); cols: length of each row.
layout(push_constant) uniform Pc { uint rows; uint cols; };
// One invocation per row: scans the row serially and writes the column index
// of its smallest element to o[row].
// - Ties resolve to the lowest index (strict `<` comparison).
// - NaN: the first NaN in the row is reported as the argmin, wherever it
//   appears (numpy.argmin-style propagation). A bare `v < best` loop would only
//   do this when the NaN sat at column 0, because any comparison with NaN is false.
// - Empty rows (cols == 0): writes 0 without reading `inp`, avoiding an
//   out-of-bounds read of inp[0].
// NOTE(review): `row * cols` is 32-bit unsigned math; rows * cols must fit in
// a uint or indices wrap.
// NOTE(review): neighbouring invocations read addresses `cols` floats apart
// (uncoalesced). For wide rows, consider a per-row workgroup/subgroup reduction.
void main() {
    uint row = gl_GlobalInvocationID.x;
    // Invocations beyond the last row (when the dispatch overshoots) do nothing.
    if (row >= rows) { return; }
    if (cols == 0u) {
        // No elements to compare; keep the previous output value (0) without reading.
        o[row] = 0u;
        return;
    }
    uint base = row * cols;
    float best = inp[base];
    uint bi = 0u;
    // Once `best` is NaN it is final (first NaN wins), so stop scanning.
    for (uint c = 1u; c < cols && !isnan(best); c++) {
        float v = inp[base + c];
        if (isnan(v) || v < best) { best = v; bi = c; }
    }
    o[row] = bi;
}