#version 450
// Top-K indices along the last axis (f32-encoded), sorted descending, ties
// broken by smaller index (torch.topk largest=True, sorted=True). One
// invocation per row; array-free selection that carries the previous pick so
// duplicate values still advance. Input [..., n] → output [..., k].
// 64 invocations per workgroup along x; main() maps each invocation's global
// x index to one row, so the dispatch must cover at least `rows` invocations.
layout(local_size_x = 64) in;
// Single float storage buffer that holds both the input rows (read at in_off)
// and the output indices (written at out_off).
layout(std430, binding = 0) buffer Arena { float data[]; };
// Per-dispatch parameters. Offsets are element indices into data[], not bytes.
layout(push_constant) uniform PC {
// Number of rows to process; invocations with a global x index >= rows exit.
uint rows;
// Number of elements in each input row (row stride on the input side).
uint n;
// Number of indices written per row (row stride on the output side).
uint k;
// Index in data[] of the first element of input row 0.
uint in_off;
// Index in data[] of the first output slot of row 0.
uint out_off;
} pc;
// Returns true when element (a, ai) comes strictly before element (b, bi) in
// the output order: NaN first, then value descending, then index ascending.
// Ranking NaN above every number mirrors torch.topk(largest=True); without it
// a NaN compares false against everything and is never selected, which makes
// later picks repeat index 0. Note that isnan() always returns false on
// implementations that do not support NaNs.
bool ranks_before(float a, uint ai, float b, uint bi) {
    bool a_nan = isnan(a);
    bool b_nan = isnan(b);
    if (a_nan != b_nan) { return a_nan; }
    if (!a_nan && a != b) { return a > b; }
    // Equal values (or both NaN): the smaller index wins the tie.
    return ai < bi;
}

// One invocation per row: performs k selection passes over the row's n
// elements. Each pass picks the highest-ranked element that comes strictly
// after the previous pick, and stores its index (converted to float) in
// data[out_off + row * k + s]. No per-row scratch array is used; only the
// previous pick (value, index) is carried between passes, so the cost is
// k * n reads per row.
//
// If a pass finds no candidate (only possible when k > n, or n == 0), the
// slot is written as index 0 and the carried previous pick is left untouched,
// so every surplus slot is deterministically 0.
void main() {
    uint row = gl_GlobalInvocationID.x;
    // Invocations past the last row (dispatch rounded up to 64) do nothing.
    if (row >= pc.rows) { return; }

    uint in_base = pc.in_off + row * pc.n;
    uint out_base = pc.out_off + row * pc.k;

    // Previous pick; only meaningful once have_prev is true.
    float prev_val = 0.0;
    uint prev_idx = 0u;
    bool have_prev = false;

    for (uint s = 0u; s < pc.k; s++) {
        // Best candidate of this pass; only meaningful once found is true.
        float best = 0.0;
        uint best_i = 0u;
        bool found = false;

        for (uint i = 0u; i < pc.n; i++) {
            float v = data[in_base + i];
            // Candidate must rank strictly after the previous pick so that
            // duplicate values still advance by index instead of repeating.
            bool after_prev = !have_prev || ranks_before(prev_val, prev_idx, v, i);
            if (after_prev && (!found || ranks_before(v, i, best, best_i))) {
                best = v;
                best_i = i;
                found = true;
            }
        }

        data[out_base + s] = float(best_i);

        // Only advance the carried state on a real pick; a miss must not
        // fabricate a "previous pick" that later passes would compare against.
        if (found) {
            prev_val = best;
            prev_idx = best_i;
            have_prev = true;
        }
    }
}