//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! Top-K filter — masking variant.
//!
//! The full top-K filter pipeline is:
//!
//! 1. Find the K-th largest logit value: `threshold = logits[argpartition(logits, -K)[-K]]`
//! 2. For every logit, if `logit >= threshold` keep it; else set to `-inf`
//!
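//! A host-side sketch of both steps in plain Rust, as a CPU reference only.
//! `top_k_threshold` and `mask_below` are hypothetical helpers, not part of
//! this crate, and `select_nth_unstable_by` (average O(n) selection) stands in
//! for `argpartition`.
//!
//! ```rust
//! // Hypothetical helper (not this crate's API): the K-th largest logit.
//! // Panics unless 1 <= k <= logits.len().
//! fn top_k_threshold(logits: &[f32], k: usize) -> f32 {
//!     assert!(k >= 1 && k <= logits.len(), "k must be in 1..=len");
//!     let mut scratch = logits.to_vec();
//!     // Partial sort in descending order: afterwards index k - 1 holds the
//!     // value a full descending sort would put there, i.e. the K-th largest.
//!     let (_, kth, _) = scratch.select_nth_unstable_by(k - 1, |a, b| b.total_cmp(a));
//!     *kth
//! }
//!
//! // Hypothetical CPU mirror of the mask kernel: keep `logit >= threshold`,
//! // otherwise write -inf.
//! fn mask_below(logits: &mut [f32], threshold: f32) {
//!     for x in logits.iter_mut() {
//!         *x = if *x >= threshold { *x } else { f32::NEG_INFINITY };
//!     }
//! }
//! ```
//!
//! For `[1.0, 5.0, 3.0, 4.0, 2.0]` with K = 2 the threshold is `4.0`, and only
//! `5.0` and `4.0` survive the mask.
//!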
//! Step 1 is a selection (partial sort). At typical serving K (50, 100) and
//! Qwen-scale vocab (~152K), the per-call GPU cost is dominated by Metal
//! command-buffer overhead rather than arithmetic, so a CPU argpartition plus
//! threshold pass takes roughly the same wall-clock time as a GPU select
//! kernel and needs one fewer dispatch. This file therefore ships only the GPU
//! mask kernel and leaves threshold computation to the caller. A future PR can
//! add a GPU-side selection kernel once serving batch sizes are large enough
//! for a single fused dispatch to pull ahead.
//!
//! Caller contract:
//! - Compute `threshold` = the K-th largest value (descending) on the host.
//! - Pass it as the constexpr `threshold` parameter.
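//! - Logits below `threshold` are replaced with `-INFINITY`, the standard
//! sentinel: downstream softmax sees `exp(-inf) = 0`, so the filtered tokens
//! contribute zero probability. Because the keep test is `>=`, every logit
//! tied with `threshold` survives, so more than K tokens can remain.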
//!
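//! A minimal sketch of why the `-INFINITY` sentinel is safe downstream,
//! assuming a conventional max-subtracted softmax on the host. `softmax` here
//! is a hypothetical helper, not this crate's API.
//!
//! ```rust
//! // Hypothetical helper: numerically stable softmax over masked logits.
//! // Needs at least one finite logit, which holds when K >= 1 and the
//! // unmasked inputs are finite.
//! fn softmax(logits: &[f32]) -> Vec<f32> {
//!     // Subtract the max so the largest exponent is exp(0) = 1.
//!     let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
//!     // Masked entries give exp(-inf - max) = exp(-inf) = 0.
//!     let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
//!     let sum: f32 = exps.iter().sum();
//!     exps.iter().map(|&e| e / sum).collect()
//! }
//! ```
//!
//! For `[-inf, 0.0, -inf, 0.0]` this returns `[0.0, 0.5, 0.0, 0.5]`.
//!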
//! Generic over `T`. Dispatched as Grid3D with one thread per vocab position.
//!
//! ## DISPATCH INVARIANTS
//!
//! - **Mode: Grid3D.** One thread per vocab position.
//! - **Grid: `[ceil(n / TPG), 1, 1]`, TG: `[TPG, 1, 1]`** (TPG = 256 is the
//! tested geometry; the kernel is pure elementwise so any TPG works).
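//! - **`n = grid.x * tg.x`.** The caller sizes the dispatch so the total
//! thread count exactly matches the vocab length. With the `ceil` grid above
//! this holds only when TPG divides `n`; otherwise up to TPG - 1 threads
//! launch past `n` and would read/write out of bounds, so the runtime must
//! not overshoot.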
//! - **No `threadgroup_*` / `simd_*` cooperation** — every thread is
//! independent. The only invariant is the threshold semantic above.
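//!
//! The sizing arithmetic can be checked on the host. A minimal sketch:
//! `threadgroups` and `overshoot` are hypothetical helpers, and the vocab
//! sizes in the note below are illustrative only.
//!
//! ```rust
//! // Hypothetical helper: threadgroup count `ceil(n / tpg)` for a 1-D grid.
//! fn threadgroups(n: usize, tpg: usize) -> usize {
//!     assert!(tpg > 0, "threads per group must be nonzero");
//!     (n + tpg - 1) / tpg
//! }
//!
//! // Hypothetical helper: threads launched past `n` by the ceil-rounded grid.
//! // Zero means `n == grid.x * tg.x` holds exactly.
//! fn overshoot(n: usize, tpg: usize) -> usize {
//!     threadgroups(n, tpg) * tpg - n
//! }
//! ```
//!
//! With TPG = 256, `n = 152_064` gives 594 groups and zero overshoot, while
//! `n = 151_936` also gives 594 groups but launches 128 threads past `n`.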
use ;