1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! Multi-query SDPA — attends `n_query` query rows against a shared
//! K/V cache in a single dispatch. Used by Nemotron-Labs-Diffusion's
//! block-diffusion / self-speculation `forwardTokens`, where a whole
//! block of tokens is forwarded at once instead of one decode step at
//! a time.
//!
//! This is `ffai_sdpa_decode` generalised with a query dimension: one
//! threadgroup per (query, q_head), the same TPG=1024 online-softmax
//! cross-simdgroup reduction. Two attention modes are selected via the
//! `causal` uniform:
//!
//! - `causal == 0` — every query attends `[0, base_kv + n_query)`:
//! full / bidirectional over the cached prefix plus the whole
//! block (the diffusion-denoise pattern).
//! - `causal == 1` — query `r` attends `[0, base_kv + r + 1)`:
//! causal within the block, the prefix always fully visible
//! (the AR-verify / causal-commit pattern).
//!
//! ## DISPATCH INVARIANTS
//!
//! Reduction-mode kernel — STRICT threadgroup geometry, the same
//! machine-freeze hazard as `ffai_sdpa_decode`. Consumers MUST encode
//! these as preconditions in their wrappers.
//!
//! - **TPG = 1024 threads** (32 simdgroups × 32 lanes). Hard. A TPG
//! below 32 makes `n_simd = TPG / 32 = 0`, turning the K walk
//! `range(sg, n_kv, 0)` into an infinite GPU loop — the freeze.
//! - **`head_dim == 128`.** Each lane owns 4 consecutive Q/K/V
//! elements at `lane*4 + {0..3}`, indexed unconditionally.
//! - **Grid: 1 threadgroup per (query, q_head).** `tgid_x` ranges
//! `[0, n_q_heads * n_query)`; decoded `query = tgid / n_q_heads`,
//! `q_head = tgid % n_q_heads`. Wrapper dispatches
//! `grid = (n_q_heads * n_query * 1024, 1, 1)`, `tg = (1024, 1, 1)`.
//! - **`n_q_heads % heads_per_group == 0`** for integer GQA fan-out.
//! - **`base_kv + n_query <= kv_stride`** — the kernel never walks
//! past the cache's allocated depth.
//!
//! K/V cache layout `[n_kv_heads, kv_stride, head_dim]`; Q and `out`
//! layout `[n_query, n_q_heads, head_dim]`. Online softmax runs in
//! fp32 throughout (storage stays in T).
use metaltile::{bench_kernel, kernel};
// Multi-query scaled-dot-product attention kernel.
//
// Each threadgroup handles one (query, q_head) pair decoded from `tgid_x`.
// Every simdgroup streams a strided subset of the KV positions, keeping a
// private online-softmax state (running max, running sum, four output
// accumulators per lane). The per-simdgroup states are then merged through
// threadgroup memory and simdgroup 0 writes the final output row.
//
// Parameters:
// - `q`, `out`: indexed as `(query_idx * n_q_heads + q_head) * head_dim + d`.
// - `k`, `v`:   indexed as `(kv_head * kv_stride + t) * head_dim + d`.
// - `head_dim`: row length used for index arithmetic. Each lane touches only
//   elements `lane*4 .. lane*4+3`, so a simdgroup covers 4 * (lanes per
//   simdgroup) elements of a row.
// - `n_q_heads`: number of query heads; used to split `tgid_x`.
// - `base_kv`: number of cached prefix positions visible to every query.
// - `n_query`: number of query rows in the block.
// - `kv_stride`: number of KV positions between consecutive KV heads.
// - `heads_per_group`: query heads sharing one KV head
//   (`kv_head = q_head / heads_per_group`).
// - `causal`: exactly `1` selects the causal window; every other value
//   selects the full window (see `n_kv` below).
// - `scale`: multiplied into Q once, before any dot product.
//
// No bounds checks are performed in this body; the indices are used as
// computed.
#[bench_kernel(
op="sdpa",
subop="sdpa_multi",
class=GenericEmpty,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
pub fn ffai_sdpa_multi<T>(
q: Tensor<T>,
k: Tensor<T>,
v: Tensor<T>,
out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] n_q_heads: u32,
#[constexpr] base_kv: u32,
#[constexpr] n_query: u32,
#[constexpr] kv_stride: u32,
#[constexpr] heads_per_group: u32,
#[constexpr] causal: u32,
#[constexpr] scale: f32,
) {
// Decode the threadgroup index into (query row, query head); the query
// index is the slow-varying component.
let tg = tgid_x;
let query_idx = tg / n_q_heads;
let q_head = tg % n_q_heads;
// GQA: consecutive runs of `heads_per_group` query heads share a KV head.
let kv_head = q_head / heads_per_group;
let sg = simd_id;
let lane = simd_lane;
let ns = n_simd;
// KV positions this query attends. causal: prefix + the block up to
// and including this query. full: prefix + the entire block.
let n_kv = select(causal == 1u32, base_kv + query_idx + 1u32, base_kv + n_query);
// Threadgroup scratch. `tg_max` / `tg_sum` hold one slot per simdgroup
// (indexed by `sg` below). Each `tg_out*` buffer is indexed by
// `lane * (ns + 1) + sg`; 1056 = 32 * 33 covers that index for
// ns <= 32 and lane < 32.
// NOTE(review): element type is presumably f32 (fp32 values are stored
// into these buffers) — confirm against `threadgroup_alloc`.
threadgroup_alloc("tg_max", 32);
threadgroup_alloc("tg_sum", 32);
threadgroup_alloc("tg_out0", 1056);
threadgroup_alloc("tg_out1", 1056);
threadgroup_alloc("tg_out2", 1056);
threadgroup_alloc("tg_out3", 1056);
// Offset of this (query, head) row in `q`; reused for `out` at the end.
let q_off = (query_idx * n_q_heads + q_head) * head_dim;
// Offset of position 0 of this KV head in `k` / `v`.
let kv_head_base = kv_head * kv_stride * head_dim;
// First of the four consecutive head-dim elements owned by this lane.
let d0 = lane * 4u32;
// Pre-scale this lane's 4-element Q quartile once; K/V are streamed.
let q0 = load(q[q_off + d0]).cast::<f32>() * scale;
let q1 = load(q[q_off + d0 + 1u32]).cast::<f32>() * scale;
let q2 = load(q[q_off + d0 + 2u32]).cast::<f32>() * scale;
let q3 = load(q[q_off + d0 + 3u32]).cast::<f32>() * scale;
// Per-thread online-softmax state. Starting the max at -inf makes the
// first visited position yield factor = exp(-inf - score) = 0, which
// discards the (zero) initial accumulators.
let mut run_max = neg_infinity();
let mut run_sum = 0.0f32;
let mut o0 = 0.0f32;
let mut o1 = 0.0f32;
let mut o2 = 0.0f32;
let mut o3 = 0.0f32;
// Each simdgroup walks every ns-th KV position. simd_sum reduces the
// per-lane quartile dot product into the full score; online softmax
// updates the running (max, sum); V accumulates into fp32 registers.
// Pre-compute the kv VIDs before the loads so vectorize sees 4
// consecutive Op::Load (same constraint as ffai_sdpa_decode).
for _t in range(sg, n_kv, ns) {
// Element offset of KV position `_t` for this head, then this lane's
// four consecutive elements within that row. K and V share indices.
let base = kv_head_base + _t * head_dim;
let kv_idx = base + d0;
let kv0 = kv_idx;
let kv1 = kv_idx + 1u32;
let kv2 = kv_idx + 2u32;
let kv3 = kv_idx + 3u32;
let k0_raw = load(k[kv0]);
let k1_raw = load(k[kv1]);
let k2_raw = load(k[kv2]);
let k3_raw = load(k[kv3]);
let k0 = k0_raw.cast::<f32>();
let k1 = k1_raw.cast::<f32>();
let k2 = k2_raw.cast::<f32>();
let k3 = k3_raw.cast::<f32>();
// This lane's share of (scale * q) · k; simd_sum combines the lanes.
let partial = q0 * k0 + q1 * k1 + q2 * k2 + q3 * k3;
let score = simd_sum(partial);
// Online-softmax step: raise the running max if needed, rescale the
// previous sum by exp(old_max - new_max), and add this position's
// weight exp(score - new_max).
let new_max = select(score > run_max, score, run_max);
let factor = exp(run_max - new_max);
let weight = exp(score - new_max);
run_sum = run_sum * factor + weight;
run_max = new_max;
let v0_raw = load(v[kv0]);
let v1_raw = load(v[kv1]);
let v2_raw = load(v[kv2]);
let v3_raw = load(v[kv3]);
let v0 = v0_raw.cast::<f32>();
let v1 = v1_raw.cast::<f32>();
let v2 = v2_raw.cast::<f32>();
let v3 = v3_raw.cast::<f32>();
// Apply the same rescale to the output accumulators, then add the
// weighted V elements.
o0 = o0 * factor + weight * v0;
o1 = o1 * factor + weight * v1;
o2 = o2 * factor + weight * v2;
o3 = o3 * factor + weight * v3;
}
// ── Cross-simdgroup reduction: max + sum_exp ────────────────────
// Lane 0 publishes its simdgroup's (max, sum).
// NOTE(review): this relies on simd_sum returning the same total to
// every lane, so lane 0's state represents the whole simdgroup — confirm.
if lane == 0 {
threadgroup_store("tg_max", sg, run_max);
threadgroup_store("tg_sum", sg, run_sum);
}
threadgroup_barrier();
// Simdgroup 0 merges the per-simdgroup states: lane i reads slot i.
// Lanes with i >= ns contribute the neutral values (-inf for the max,
// 0 for the sum). Each partial sum is re-based from its own max onto
// the global max before being summed.
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
// Slot 0 of each buffer is reused to broadcast the merged values.
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
// ── Cross-simdgroup reduction: outputs ──────────────────────────
let g_max = threadgroup_load("tg_max", 0);
let g_sum = threadgroup_load("tg_sum", 0);
// Factor that re-bases this thread's accumulators onto the global max
// and divides by the global sum. Falls back to 0 when g_sum is not
// positive, so no division result is used in that case.
let rescale = select(g_sum > 0.0f32, exp(run_max - g_max) / g_sum, 0.0f32);
// Transpose-then-reduce with a +1 padded stride so adjacent lanes
// hit distinct threadgroup-memory banks (see ffai_sdpa_decode).
let stride = ns + 1u32;
let idx = lane * stride + sg;
threadgroup_store("tg_out0", idx, o0 * rescale);
threadgroup_store("tg_out1", idx, o1 * rescale);
threadgroup_store("tg_out2", idx, o2 * rescale);
threadgroup_store("tg_out3", idx, o3 * rescale);
threadgroup_barrier();
// Simdgroup 0 finishes the reduction: each lane sums, over all ns
// simdgroups, the normalized contributions for the same four head-dim
// elements it owns, then writes them to `out` cast back to T.
if sg == 0 {
let mut so0 = 0.0f32;
let mut so1 = 0.0f32;
let mut so2 = 0.0f32;
let mut so3 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
// Same row (`lane * stride`) as the stores above, column = simdgroup.
let ri = lane * stride + _g;
so0 = so0 + threadgroup_load("tg_out0", ri);
so1 = so1 + threadgroup_load("tg_out1", ri);
so2 = so2 + threadgroup_load("tg_out2", ri);
so3 = so3 + threadgroup_load("tg_out3", ri);
}
// `out` uses the same row offset as `q`.
let out_off = q_off + d0;
store(out[out_off], so0.cast::<T>());
store(out[out_off + 1u32], so1.cast::<T>());
store(out[out_off + 2u32], so2.cast::<T>());
store(out[out_off + 3u32], so3.cast::<T>());
}
}