1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! GPU sampling kernels — softmax + categorical inverse-CDF walk used
//! by FFAI's `gpu-categorical` decode path (T > 0, no filters). The
//! greedy fast path uses `argmax` instead.
//!
//! Codegen-only. End-to-end sampling correctness lives in FFAI's
//! harness.
use metaltile::{bench_kernel, kernel};
// Threadgroup passes work over 256 threadgroup slots in 8 power-of-two
// stages: the max-pass is a halving reduction (stride 128 → 1) that folds
// all slots into slot 0, and the sum-exp pass is a Hillis-Steele inclusive
// scan (stride 1 → 128) whose final lane slot holds the total. These stages
// were originally hand-unrolled via `tg_max_step!` / `tg_sum_step!`
// declarative macros. The proc-macro does not expand inner `macro_rules!`,
// so the unrolled expansions silently produced no IR. They are now DSL
// `for` loops, which yield the same Metal output and survive the
// proc-macro intact.
// Softmax + categorical sample over a 1D logits tensor. Cooperative
// reduction (256 threads) for max-pass; combined chunked sum-exp +
// inclusive scan + parallel-prefix CDF walk for the categorical pick.
//
// Inputs:
//   inp            — logits [n]
//   out            — token id [1] (u32)
//   temperature_in — temperature [1] (f32, must be > 0)
//   uniform_in     — uniform draw in [0, 1) [1] (f32)
//
// Output is the smallest index `i` such that the cumulative softmax
// (in fp32) up to and including `i` is ≥ `uniform_in * sum_exp`.
// A draw of exactly 0 therefore selects index 0.
//
// Threadgroup size: both scratch arrays (`tg_max`, `tg_cdf`) hold 256
// slots and the reduce/scan loops run a fixed 8 stages. The dispatch
// must therefore use lsize ≤ 256. Smaller threadgroups are handled: the
// max-reduce never reads slots at or beyond lsize, since those are
// never written.
//
// Cost: vocab=152K on M5 Max ~563µs median (down from ~8370µs in the
// single-thread CDF walk version, measured via the 1000-iter dispatch
// loop in `tests/softmax_categorical_sample_perf.rs`). The ~15× speedup
// comes mostly from collapsing pass 3's O(n) walk. Lane lid owns a
// contiguous chunk = ceil(n/lsize) ≈ 594 positions. A Hillis-Steele
// inclusive scan turns per-lane chunk partials into per-lane cumulative
// bounds. The lane whose chunk contains `u * total` then walks its own
// chunk serially to find the exact index. The full-vocab serial walk
// (152K ops) is replaced by 1 × n/lsize chunk traversal per lane, an
// 8-stage scan, and 1 × n/lsize finalizing walk on the winning lane.
#[bench_kernel(
    op="sampling",
    subop="softmax_categorical_sample",
    class=GenericEmpty,
    tol=0.0,
    kernel_mode=Reduction,
)]
#[kernel]
pub fn softmax_categorical_sample<T>(
    inp: Tensor<T>,
    out: Tensor<u32>,
    temperature_in: Tensor<f32>,
    uniform_in: Tensor<f32>,
    #[constexpr] n: u32,
) {
    let lid = tid;
    let inv_t = 1.0f32 / load(temperature_in[0]);

    // ─── Pass 1: cooperative max reduce (strided) ───────────────────
    // Lane lid visits positions lid, lid + lsize, lid + 2*lsize, ...
    let mut local_max = neg_infinity();
    threadgroup_alloc("tg_max", 256);
    let n_iters = (n + lsize - 1u32) / lsize;
    for _r in range(0u32, n_iters, 1u32) {
        let pos = _r * lsize + lid;
        if pos < n {
            let v = load(inp[pos]).cast::<f32>() * inv_t;
            local_max = select(v > local_max, v, local_max);
        }
    }
    threadgroup_store("tg_max", lid, local_max);
    threadgroup_barrier();
    // 8-stage power-of-two halving max-reduction (stride 128 → 1).
    // Only slots [0, lsize) were stored above. A lane therefore folds in
    // its partner only when the partner is a live lane; otherwise it
    // would read uninitialized threadgroup memory. With lsize == 256 the
    // guard is always true and the reduction is unchanged.
    for _stage in range(0u32, 8u32, 1u32) {
        let stride = 128u32 >> _stage;
        let active = (lid < stride) & (lid + stride < lsize);
        if active {
            let ov = threadgroup_load("tg_max", lid + stride);
            let tv = threadgroup_load("tg_max", lid);
            threadgroup_store("tg_max", lid, select(ov > tv, ov, tv));
        }
        threadgroup_barrier();
    }
    let max_val = threadgroup_load("tg_max", 0u32);

    // ─── Combined pass 2+3: chunk-partial sum-exp → inclusive scan
    //                        → parallel-prefix CDF walk ─────────────
    //
    // Lane lid covers contiguous chunk [lo, hi). After the scan,
    // `total = tg_cdf[lsize-1]`, which replaces the previous standalone
    // sum-exp reduce.
    let chunk = (n + lsize - 1u32) / lsize;
    let lo = lid * chunk;
    let hi_raw = lo + chunk;
    let hi = select(hi_raw > n, n, hi_raw);
    let mut local_partial = 0.0f32;
    for j in range(lo, hi, 1u32) {
        if j < n {
            let v = load(inp[j]).cast::<f32>() * inv_t;
            local_partial = local_partial + exp(v - max_val);
        }
    }
    threadgroup_alloc("tg_cdf", 256);
    threadgroup_store("tg_cdf", lid, local_partial);
    threadgroup_barrier();
    // Hillis-Steele inclusive scan: 8 stages (stride 1 → 128).
    // Underflow-safe: lanes with lid < stride contribute 0 instead of
    // reading from negative indices. The neighbor is read before the
    // first barrier and the slot is written after it, so no lane
    // observes a value already updated in the same stage.
    for _stage in range(0u32, 8u32, 1u32) {
        let stride = 1u32 << _stage;
        let safe_neighbor = select(lid >= stride, lid - stride, lid);
        let raw = threadgroup_load("tg_cdf", safe_neighbor);
        let neighbor_val = select(lid >= stride, raw, 0.0f32);
        threadgroup_barrier();
        let cur = threadgroup_load("tg_cdf", lid);
        threadgroup_store("tg_cdf", lid, cur + neighbor_val);
        threadgroup_barrier();
    }
    let total = threadgroup_load("tg_cdf", lsize - 1u32);
    let target = load(uniform_in[0]) * total;
    let my_cum_end = threadgroup_load("tg_cdf", lid);
    let prev_cum = select(
        lid == 0u32,
        0.0f32,
        threadgroup_load("tg_cdf", select(lid > 0u32, lid - 1u32, lid)),
    );
    // Hit lane: target sits in (prev_cum, my_cum_end]. The strict lower
    // bound means exactly one lane fires at a boundary value.
    //
    // Lane 0's lower bound is opened to -inf. uniform_in may be exactly 0,
    // giving target == 0. Every prev_cum is ≥ 0, so with the strict bound
    // `prev_cum < 0` fails on every lane and `out` would never be written.
    // With the -inf bound lane 0 fires, while every other lane still fails
    // `prev_cum < target`, so the hit lane stays unique. For target > 0
    // lane 0's bound of 0 already passed, so behavior there is unchanged.
    let lower_bound = select(lid == 0u32, neg_infinity(), prev_cum);
    let is_hit = (lower_bound < target) & (target <= my_cum_end) & (lo < n);
    if is_hit {
        let mut cum = prev_cum;
        // Fallback: last position in the chunk. The serial re-walk below
        // accumulates onto prev_cum in a different order than the scan
        // produced my_cum_end, so fp rounding can leave cum just short of
        // target.
        let mut found_idx = hi - 1u32;
        let mut done = 0u32;
        for i in range(lo, hi, 1u32) {
            if i < n {
                let v = load(inp[i]).cast::<f32>() * inv_t;
                cum = cum + exp(v - max_val);
                // Latch the first index whose running sum reaches target.
                let hit_i = (cum >= target) & (done == 0u32);
                found_idx = select(hit_i, i, found_idx);
                done = select(hit_i, 1u32, done);
            }
        }
        store(out[0], found_idx);
    }
}