1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! Min-p (minimum-probability) logits filter for the sampling pipeline.
//!
//! Min-p sampling keeps every token whose probability is at least
//! `min_p` times the probability of the most-likely token, and masks
//! the rest:
//!
//! keep token i ⇔ P(i) ≥ min_p · P_max
//!
//! Working in logit space avoids a full softmax: the softmax normalizer
//! (and any constant shift applied to all logits) cancels in the ratio, so
//! `P(i) / P_max = exp(logit_i − logit_max)` and the keep test is simply
//! `exp(logit_i − logit_max) ≥ min_p`. The kernel finds the
//! row max with one threadgroup reduction, then masks every logit
//! below the cutoff to `-INFINITY` in a second pass. Downstream
//! `softmax_categorical_sample` sees `exp(-inf) = 0`, so masked tokens
//! contribute zero probability.
//!
//! This is the reduction-mode sibling of `logits_topk_mask`: top-K
//! needs a host-computed K-th-largest threshold, but min-p's cutoff is
//! defined purely by the row max, so the whole filter fits in one
//! self-contained GPU kernel — no host round-trip, no sort.
//!
//! Reduction-mode, generic over T; the max and the ratio are computed
//! in f32 so f16/bf16 logits don't drift. One threadgroup per row;
//! `n` is the vocab length, looped so any `n` works at any
//! (multiple-of-32) threadgroup size.
//!
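//! Caller contract: `0 < min_p < 1`. As `min_p → 0` nothing is masked;
//! as `min_p → 1` only the argmax (and exact ties) survive. A typical
//! serving value is 0.05–0.1. The kernel does not reject values outside
//! the contract: `min_p ≤ 0` keeps every finite logit, while `min_p > 1`
//! masks the entire row to `-INFINITY`, leaving nothing for the sampler.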
use metaltile::{bench_kernel, kernel};
/// Min-p logits mask: copies `inp` to `out`, replacing every logit whose
/// probability is below `min_p` times the row's top probability with
/// `-INFINITY`.
///
/// One threadgroup handles one row (selected by `program_id::<0>()`); the
/// row is the `n` contiguous logits starting at `row * n`. Each thread walks
/// the row starting at its `tid` and stepping by `lsize`, so `n` does not
/// have to match the threadgroup size.
///
/// Both passes compute in `f32`. Surviving logits are written back as the
/// `T` cast of their own `f32` widening, so they come out unchanged.
#[bench_kernel(
    op="logits_processors",
    subop="min_p_mask",
    class=GenericEmpty,
    tol=0.0,
    kernel_mode=Reduction,
)]
#[kernel]
pub fn logits_min_p_mask<T>(
    inp: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] n: u32,
    #[constexpr] min_p: f32,
) {
    // This threadgroup's row occupies the half-open span [row_start, row_end).
    let row_idx = program_id::<0>();
    let row_start = row_idx * n;
    let row_end = row_start + n;

    // Pass 1: every thread folds its strided share of the row into a private
    // maximum; the threadgroup reduction then yields the row-wide maximum.
    let mut thread_max = neg_infinity();
    for idx in range(row_start + tid, row_end, lsize) {
        let logit = load(inp[idx]).cast::<f32>();
        thread_max = max(thread_max, logit);
    }
    let row_max = reduce_max(thread_max);

    // Pass 2: exp(logit - row_max) is exactly P(i) / P_max, so it is compared
    // against min_p directly; anything under the cutoff is written as -inf.
    let neg_inf = neg_infinity();
    for idx in range(row_start + tid, row_end, lsize) {
        let logit = load(inp[idx]).cast::<f32>();
        let masked = select(exp(logit - row_max) >= min_p, logit, neg_inf);
        store(out[idx], masked.cast::<T>());
    }
}