//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! Top-p (nucleus) logits filter for the sampling pipeline.
//!
//! Top-p sampling keeps the smallest set of most-likely tokens whose
//! cumulative probability reaches `top_p`, and masks the rest. The
//! reference definition sorts the probabilities in descending order and
//! walks the prefix until the running sum reaches `top_p`. Equivalently,
//! and without a sort, there is a probability cutoff `c` (the probability
//! of the last token that walk keeps) such that the kept set is exactly
//! `{ i : P(i) ≥ c }`, up to ties at `c`, which the threshold form keeps
//! together. This kernel finds `c` directly.
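A minimal CPU sketch of the two formulations, assuming the probabilities are already normalised. `top_p_sorted_keep` and `threshold_keep` are hypothetical helpers written for illustration, not functions from this crate. The second pair of inputs shows the tie case: the sorted walk keeps only one of two equal tokens, while the threshold form keeps both.

```rust
/// Reference top-p: sort descending, keep the shortest prefix whose running
/// sum reaches `top_p`. Returns a keep-mask in the original token order.
fn top_p_sorted_keep(probs: &[f32], top_p: f32) -> Vec<bool> {
    let mut order: Vec<usize> = (0..probs.len()).collect();
    // `sort_by` is stable, so equal probabilities keep their original order.
    order.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));

    let mut keep = vec![false; probs.len()];
    let mut running = 0.0f32;
    for &i in &order {
        keep[i] = true;
        running += probs[i];
        if running >= top_p {
            break; // smallest prefix whose mass reaches top_p
        }
    }
    keep
}

/// Threshold form: `c` is the smallest probability the sorted walk kept,
/// and the kept set is `{ i : P(i) >= c }`.
fn threshold_keep(probs: &[f32], top_p: f32) -> Vec<bool> {
    let walk = top_p_sorted_keep(probs, top_p);
    let mut c = f32::INFINITY;
    for (i, &p) in probs.iter().enumerate() {
        if walk[i] {
            c = c.min(p);
        }
    }
    probs.iter().map(|&p| p >= c).collect()
}
```

For `[0.1, 0.5, 0.15, 0.25]` with `top_p = 0.7` both forms keep the 0.5 and 0.25 tokens. For `[0.4, 0.3, 0.3]` with `top_p = 0.6` the sorted walk keeps two tokens and the threshold form keeps all three.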
//!
//! Working in logit space avoids materialising normalised probabilities.
//! Softmax is invariant under a constant shift, so subtracting the row
//! max gives the unnormalised weight `w_i = exp(logit_i − logit_max)`,
//! with `Z = Σ w_i` and `P(i) = w_i / Z`. The keep test `P(i) ≥ c`
//! becomes `w_i ≥ c·Z`, so the cutoff search runs entirely on
//! `w ∈ (0, 1]`.
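The shift and the division-free keep test can be checked on a small row. This is a plain-Rust sketch on `f32` slices; `weights`, `keep_by_prob` and `keep_by_weight` are hypothetical names, not this crate's API. At a value exactly on the boundary the two tests can round differently in floating point, so the example uses a cutoff well away from any token's probability.

```rust
/// Unnormalised weights `w_i = exp(logit_i - logit_max)` and `Z = sum w_i`,
/// accumulated in f32. The max token gets weight exactly 1.
fn weights(logits: &[f32]) -> (Vec<f32>, f32) {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let w: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let z: f32 = w.iter().sum();
    (w, z)
}

/// Keep test in probability space: P(i) = w_i / Z >= c.
fn keep_by_prob(w: &[f32], z: f32, c: f32) -> Vec<bool> {
    w.iter().map(|&wi| wi / z >= c).collect()
}

/// The same test without dividing each weight: w_i >= c * Z.
fn keep_by_weight(w: &[f32], z: f32, c: f32) -> Vec<bool> {
    let floor = c * z;
    w.iter().map(|&wi| wi >= floor).collect()
}
```

Shifting every logit by the same constant (for example adding 10) leaves the weights unchanged, which is why the max can be subtracted without affecting `P(i)`.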
//!
//! `w` is not sorted, so `c` is found by **bisection**. The kept mass
//! `S(t) = Σ_{w_i ≥ t} w_i` is monotonically non-increasing in `t`, so a
//! binary search on `t ∈ [0, 1]` converges on the largest `t` with
//! `S(t) ≥ top_p·Z`. 24 halvings pin `t` to an interval of width `2⁻²⁴`
//! (≈ 6e-8), which is typically much finer than the gap between adjacent
//! token weights near the cutoff; a weight that does land inside that
//! final interval is kept rather than masked. A final pass sets every
//! logit whose weight is below the converged floor to `-INFINITY`, so the
//! downstream `softmax_categorical_sample` sees `exp(-inf) = 0`.
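A single-threaded sketch of the bisection, assuming `f32` logits and the 24-step budget described above. It is not the GPU kernel: `top_p_mask_bisect` and `BISECT_STEPS` are hypothetical names, and the kept-mass sum is a sequential loop that recomputes `exp` on every step rather than a parallel reduction. It maintains the invariant `S(lo) ≥ top_p·Z`, so `lo` plays the role of the converged floor.

```rust
/// Number of halvings; the final interval has width 2^-24 (about 6e-8).
const BISECT_STEPS: u32 = 24;

/// Sketch of the top-p bisection mask: sets every logit whose weight is
/// below the converged floor to -inf.
fn top_p_mask_bisect(logits: &mut [f32], top_p: f32) {
    // Shift by the row max: weights w_i in (0, 1], Z accumulated in f32.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let z: f32 = logits.iter().map(|&l| (l - max).exp()).sum();
    let target = top_p * z;

    // Kept mass S(t) = sum of w_i with w_i >= t; non-increasing in t.
    let kept_mass = |t: f32| -> f32 {
        logits
            .iter()
            .map(|&l| (l - max).exp())
            .filter(|&w| w >= t)
            .sum()
    };

    // Invariant: S(lo) >= target. It holds at lo = 0 because S(0) = Z.
    let (mut lo, mut hi) = (0.0f32, 1.0f32);
    for _ in 0..BISECT_STEPS {
        let mid = 0.5 * (lo + hi);
        if kept_mass(mid) >= target {
            lo = mid; // still enough mass: the cutoff is at or above mid
        } else {
            hi = mid; // too little mass: the cutoff is below mid
        }
    }

    // Final pass: mask everything below the converged floor `lo`.
    for l in logits.iter_mut() {
        if (*l - max).exp() < lo {
            *l = f32::NEG_INFINITY;
        }
    }
}
```

With `top_p = 0.8` on logits `[2, 1, 0, -1]` (probabilities ≈ 0.64, 0.24, 0.09, 0.03), the first token alone falls short of the target, the first two reach it, and the other two are masked, matching the sorted reference. With `top_p = 0.1` the argmax alone already exceeds the target, so `lo` climbs to `1 − 2⁻²⁴` and only the argmax survives.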
//!
//! This is the iterative-search sibling of `logits_min_p_mask`. Min-p's
//! cutoff is a closed form of the row max (one reduction), but top-p's
//! cutoff depends on the whole mass profile, so it costs one reduction
//! per bisection step: 24 in total, in addition to computing the max and
//! `Z`. The whole filter is still a single self-contained GPU kernel:
//! one threadgroup per row, no host round-trip, no sort.
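For contrast, a sketch of min-p in the same logit space, assuming the common definition that keeps `P(i) ≥ min_p · P_max`; the exact semantics of `logits_min_p_mask` are not shown in this section, and `min_p_mask` is a hypothetical name. Because `P(i) / P_max = exp(logit_i − logit_max)`, the cutoff is a closed form of the row max and needs neither `Z` nor a search.

```rust
/// Min-p in logit space, assuming keep i <=> P(i) >= min_p * P_max.
/// Since P(i) / P_max = exp(logit_i - logit_max), this reduces to
/// logit_i >= logit_max + ln(min_p): one max-reduction, no Z, no search.
fn min_p_mask(logits: &mut [f32], min_p: f32) {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let floor = max + min_p.ln();
    for l in logits.iter_mut() {
        if *l < floor {
            *l = f32::NEG_INFINITY;
        }
    }
}
```

Top-p has no equivalent shortcut: whether a token survives depends on how much mass the tokens above it carry, which is what forces the per-step kept-mass reduction.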
//!
//! Reduction-mode and generic over `T`: the max, the partition function
//! and every kept-mass sum are accumulated in f32, so f16/bf16 logits do
//! not lose precision over long sums. `n` is the vocab length; the kernel
//! loops over the row, so any `n` works with any threadgroup size that is
//! a multiple of 32.
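A CPU emulation of the accumulation scheme, under stated assumptions: inputs are bf16 bit patterns, each of `tg` lanes strides over the row, and lane partials are combined per 32-lane SIMD group and then across groups. The two-level combine is an assumption about how the multiple-of-32 requirement is used, and `bf16_to_f32`, `f32_to_bf16`, `threadgroup_sum` and `naive_bf16_sum` are hypothetical helpers. The last function shows why the sums stay in f32: accumulating 1000 ones directly in bf16 stalls at 256.

```rust
/// Decode a bf16 bit pattern (the upper 16 bits of an f32) to f32.
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Encode f32 as bf16 by truncation (round toward zero); demo only.
fn f32_to_bf16(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

/// Emulates one threadgroup of `tg` lanes summing a bf16 row of any length.
/// Lane `tid` visits tid, tid + tg, tid + 2*tg, ..., accumulating in f32.
fn threadgroup_sum(row: &[u16], tg: usize) -> f32 {
    assert!(tg > 0 && tg % 32 == 0, "threadgroup size must be a multiple of 32");
    let partial: Vec<f32> = (0..tg)
        .map(|tid| row.iter().skip(tid).step_by(tg).map(|&b| bf16_to_f32(b)).sum::<f32>())
        .collect();
    // Assumed two-level combine: each 32-lane SIMD group, then the groups.
    let simd_totals: Vec<f32> = partial.chunks(32).map(|c| c.iter().sum::<f32>()).collect();
    simd_totals.iter().sum()
}

/// Counter-example: keep the running sum itself in bf16 (truncating).
fn naive_bf16_sum(row: &[u16]) -> f32 {
    let mut acc: u16 = 0; // bf16 bits of 0.0
    for &x in row {
        acc = f32_to_bf16(bf16_to_f32(acc) + bf16_to_f32(x));
    }
    bf16_to_f32(acc)
}
```

Because every lane skips ahead by `tg`, a row shorter than the threadgroup, or not a multiple of it, is covered without padding; idle lanes simply contribute 0.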
//!
//! Caller contract: `0 < top_p < 1`. As `top_p → 0` only the argmax
//! survives; as `top_p → 1` the masked tail shrinks toward nothing. A
//! typical serving value is 0.9–0.95.
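A sketch of a host-side guard for the contract, plus the two limits, using the sorted reference on a probability row that is already in descending order. `check_top_p` and `kept_count` are hypothetical helpers; this section does not say whether the kernel validates `top_p` itself.

```rust
/// Hypothetical host-side guard for the documented contract 0 < top_p < 1.
/// NaN fails both comparisons, so it is rejected as well.
fn check_top_p(top_p: f32) -> Result<f32, String> {
    if top_p > 0.0 && top_p < 1.0 {
        Ok(top_p)
    } else {
        Err(format!("top_p must satisfy 0 < top_p < 1, got {top_p}"))
    }
}

/// Number of tokens the sorted reference keeps; `probs_desc` must already
/// be sorted in descending order.
fn kept_count(probs_desc: &[f32], top_p: f32) -> usize {
    let mut running = 0.0f32;
    for (k, &p) in probs_desc.iter().enumerate() {
        running += p;
        if running >= top_p {
            return k + 1;
        }
    }
    probs_desc.len()
}
```

On `[0.5, 0.25, 0.125, 0.125]`, a tiny `top_p` keeps only the argmax, 0.7 keeps two tokens, and 0.999 keeps all four.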
use ;