//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! Logits-processor kernels for the sampling pipeline.
//!
//! Decode-form samplers (other than the bare-softmax fused
//! `softmax_categorical_sample`) compose a small set of in-place
//! transforms on the logits vector before the final categorical
//! draw. Pipeline shape: temperature → repetition penalty → top-k →
//! top-p (nucleus) → categorical sample. This file ships the first two
//! stages (temperature, repetition penalty).
//!
//! Semantic contracts:
//!
//! - **temperature**: `logits[i] /= temperature` (no-op at 1.0;
//!   small temperature sharpens toward argmax). Caller clamps to a
//!   positive floor before dispatch.
//! - **repetition penalty**: for each token id in `token_ids`,
//!   `v > 0 → v /= penalty`, `v ≤ 0 → v *= penalty`. Matches the
//!   HuggingFace `transformers.RepetitionPenaltyLogitsProcessor` and
//!   vLLM conventions. `penalty == 1.0` is a no-op.
//!
//! Top-k and top-p require a sort or quickselect pass — they live
//! in a follow-up kernel since the sort dispatch geometry doesn't
//! fit the simple one-thread-per-element shape these two use.
//!
//! Generic over the element type `T`. Values are upcast to f32 for the
//! arithmetic and converted back on store, so f16/bf16 logits lose no
//! extra precision in the temperature scale or the repetition-penalty
//! update. Output dtype matches input dtype.
use ;
// ── Temperature scaling ───────────────────────────────────────────────────
//
// Pure elementwise `logits[i] /= temperature`. Generic over `T`, one thread
// per vocab position. At `temperature == 1.0` the values are unchanged; at
// very small temperature it sharpens the distribution toward greedy argmax
// (the downstream `softmax_categorical_sample` handles the softmax itself).
//
// Caller contract: `temperature > 0`. A zero or negative temperature
// produces inf / sign-flipped logits — callers should clamp before
// dispatch (`max(temperature, 1e-5)` is the standard guard).
//
// ## DISPATCH INVARIANTS
//
// - **Mode: Grid3D.** One thread per vocab position.
// - **Grid: `[ceil(n / TPG), 1, 1]`, TG: `[TPG, 1, 1]`.** TPG = 256 is the
//   tested geometry; any value works since the kernel is pure elementwise
//   and uses no `threadgroup_*` / `simd_*` cooperation.
// - **`n = grid.x * tg.x`.** The caller is responsible for `n` covering
//   the full logits length. A thread with `program_id::<0>() >= n` would
//   read/write out of bounds, so the runtime must size the dispatch so the
//   total thread count exactly matches the logits length. Note that when
//   `n` is not a multiple of TPG, `ceil(n / TPG)` full threadgroups
//   overshoot `n`.
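//
// For cross-checking the kernel, a minimal CPU reference sketch of the
// contract above. This is not the GPU kernel: `temperature_scale_ref` and
// `threadgroups_for` are hypothetical names, the logits are assumed to be
// f32 already, and the `1e-5` floor is the guard suggested in the caller
// contract.

```rust
/// Hypothetical CPU reference for temperature scaling (not the GPU kernel).
/// Divides every logit by the temperature, in place.
fn temperature_scale_ref(logits: &mut [f32], temperature: f32) {
    // Assumed guard: clamp to the positive floor named in the caller contract
    // so a zero or negative temperature cannot yield inf or flipped signs.
    let t = temperature.max(1e-5);
    for v in logits.iter_mut() {
        *v /= t;
    }
}

/// Hypothetical helper: number of threadgroups needed to cover `n` elements
/// with `tpg` threads per group, i.e. `ceil(n / tpg)`.
/// When `n` is not a multiple of `tpg` the result covers more than `n`
/// threads, which the invariants above do not allow without trimming.
fn threadgroups_for(n: usize, tpg: usize) -> usize {
    (n + tpg - 1) / tpg
}
```

// Comparing the kernel output against `temperature_scale_ref` on the same
// input is one way to check the f32 path; f16/bf16 inputs would need a
// tolerance.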
// ── Repetition penalty ────────────────────────────────────────────────────
//
// Mutates the logits in place at every position appearing in `token_ids`,
// lowering each one to discourage repeats. Convention matches HuggingFace
// `transformers.RepetitionPenaltyLogitsProcessor`:
//
//     for tok in token_ids:
//         if logits[tok] > 0: logits[tok] /= penalty
//         else:               logits[tok] *= penalty
//
// `penalty == 1.0` is a no-op; `penalty > 1.0` discourages repeats;
// `penalty < 1.0` encourages repeats (rare).
//
// Dispatch: one thread per `token_ids` entry. The kernel reads
// `logits[token_ids[i]]`, updates it, and writes it back. With duplicate
// token ids the result is **non-deterministic**: threads racing on the
// same vocab slot perform unsynchronized read-modify-writes, so the
// penalty may land once or compound depending on scheduling. Callers MUST
// dedupe `token_ids` before dispatch (or accept last-writer-wins
// semantics); the kernel matches a sequential CPU pass *only* on a
// deduped input.
//
// ## DISPATCH INVARIANTS
//
// - **Mode: Grid3D.** One thread per `token_ids` entry.
// - **Grid / TG: `grid.x * tg.x == token_ids.len()`** — caller must size
// the dispatch to exactly the token-id count. TPG = 256 (or smaller for
// small contexts) is the tested geometry.
// - **No `threadgroup_*` / `simd_*` cooperation** — every thread is
// independent. The only invariant is the dedupe contract above.
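//
// A minimal CPU reference sketch of the repetition-penalty contract,
// including the dedupe step the kernel expects from its caller. This is
// not the GPU kernel: `repetition_penalty_ref` is a hypothetical name, the
// logits are assumed to be f32, token ids are assumed to be `u32`, and
// skipping out-of-range ids is an assumption (the source does not say how
// the kernel treats them).

```rust
/// Hypothetical CPU reference for the repetition penalty (not the GPU kernel).
/// Applies the penalty exactly once per distinct token id, in place.
fn repetition_penalty_ref(logits: &mut [f32], token_ids: &[u32], penalty: f32) {
    // Dedupe first, mirroring the caller contract: each vocab slot must be
    // touched by at most one update.
    let mut ids: Vec<u32> = token_ids.to_vec();
    ids.sort_unstable();
    ids.dedup();

    for &tok in &ids {
        // Assumption: ids outside the vocabulary are silently skipped.
        if let Some(v) = logits.get_mut(tok as usize) {
            if *v > 0.0 {
                *v /= penalty; // positive logit: shrink it
            } else {
                *v *= penalty; // zero or negative logit: push it further down
            }
        }
    }
}
```

// Because the ids are deduped before the loop, a repeated id is penalized
// once, which is the result the kernel is specified to produce on a
// deduped input.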