1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! AURA Flash Pass 2 — cross-block online-softmax merge.
//!
//! Reduces the `(o_partials, m_partials, l_partials)` tuples emitted
//! by `aura_flash_p1` (one tuple per (q_idx, block_idx) pair) into a
//! single `(o, m, l)` per q_idx, then writes the final attention
//! output `o / l` cast to the output element type `T` (bf16 in the
//! upstream kernel this was ported from).
//!
//! Port of `turbo_flash_pass2` from
//! `ekryski/mlx@alpha:mlx/backend/metal/kernels/turbo_quant.metal`.
//!
//! ## Layout
//!
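//! Inputs (all stored as `Tensor<T>`; every element is promoted to f32
//! when it is loaded for the merge):
//! - `o_partials [q_heads, num_blocks, dim]` — per-block partial output.
//! - `m_partials [q_heads, num_blocks]` — per-block max.
//! - `l_partials [q_heads, num_blocks]` — per-block sum_exp.
//!
//! Output:
//! - `output [q_heads, dim]` — same element type `T` as the inputs
//!   (fp32, fp16 or bf16).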
//!
//! ## Dispatch
//!
//! Reduction mode; threadgroup = (32, 1, 1) per q_idx. Each lane owns
//! `DIMS_PER_LANE = ceil(dim / 32)` output slots (the lane's stride-32
//! slice of `dim`), kept in a per-thread stack array. Cross-block
//! merge: replay `b_idx ∈ [0, num_blocks)`, rescaling `o[]` and `l`
//! by the standard online-softmax max-shift on each step.
//!
//! ## Output dtype
//!
//! The kernel writes the output directly in its element type `T`. With
//! `T = bf16` this matches the MLX upstream's choice. Accumulators stay
//! fp32; only the final write narrows. See the note in the upstream
//! file about Qwen3.5-9B `!!!!!` decoding regressions when the output
//! was fp32 + caller-side cast.
use metaltile::{bench_kernel, kernel};
use crate::bench_types::DType;
// Anonymous compile-time constant whose only job is to keep the `DType`
// import referenced; being named `_`, it can never be read elsewhere.
// The bf16-only shortlist was a legacy default: now that the partials and
// the output are all `Tensor<T>`, the kernels below handle fp32/fp16/bf16
// and no longer name a concrete `DType` themselves.
const _: DType = DType::F32;
// Generates one pass-2 merge kernel per invocation.
//
// Macro arguments:
// - `$name`          — identifier of the generated `pub fn`.
// - `$dims_per_lane` — number of f32 accumulator slots each lane keeps.
//                      The loops below only visit `d = lane + i * 32` for
//                      `i < $dims_per_lane`, so output dims at or beyond
//                      `$dims_per_lane * 32` are never written; callers
//                      must pick `$dims_per_lane >= ceil(dim / 32)`.
// - `$subop`         — string forwarded to `bench_kernel` as `subop`.
//
// Generated kernel arguments:
// - `o_partials` — indexed as `(q_idx * num_blocks + b) * dim + d`.
// - `m_partials` / `l_partials` — indexed as `q_idx * num_blocks + b`.
// - `output`     — written at `q_idx * dim + d`.
// - `dim`, `num_blocks` — `#[constexpr]` sizes used in the index math.
macro_rules! aura_flash_pass2_kernel {
($name:ident, $dims_per_lane:literal, $subop:literal) => {
#[bench_kernel(op="aura", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Reduction,)]
#[kernel]
pub fn $name<T>(
o_partials: Tensor<T>,
m_partials: Tensor<T>,
l_partials: Tensor<T>,
mut output: Tensor<T>,
#[constexpr] dim: u32,
#[constexpr] num_blocks: u32,
) {
// `tid` and `tgid_x` are supplied by the kernel DSL, not defined in
// this file. Per the module docs the dispatch is one 32-wide
// threadgroup per q_idx, so these are presumably the lane index in
// the threadgroup and the threadgroup index — TODO confirm.
let lane = tid;
let q_idx = tgid_x;
// Per-lane accumulators. `o` is the running (un-normalised) output
// slice held in a named array of `$dims_per_lane` f32 slots; `m_acc`
// and `l_acc` are the running max and running sum_exp scalars.
// Initial state is (o, m, l) = (0, -INF, 0).
stack_alloc("o", $dims_per_lane, "f32");
for i in range(0u32, $dims_per_lane, 1u32) {
stack_store("o", i, 0.0f32);
}
let mut m_acc = neg_infinity();
let mut l_acc = 0.0f32;
// Replay every block in index order, folding it into the running
// state with the online-softmax max-shift identity. Partials are
// promoted to f32 for the merge, so the arithmetic does not depend
// on the storage dtype `T`.
for b in range(0u32, num_blocks, 1u32) {
// Flat index of this (q_idx, block) pair in the m/l partial arrays.
let ml_idx = q_idx * num_blocks + b;
let block_m = load(m_partials[ml_idx]).cast::<f32>();
let block_l = load(l_partials[ml_idx]).cast::<f32>();
// Skip empty blocks (causal masking can leave some blocks with
// l=0); they contribute nothing to `o` or `l`.
// NOTE(review): presumably this also avoids `-INF - (-INF) = NaN`
// in `exp_old` when an empty block carries m = -INF while `m_acc`
// is still -INF — confirm against what pass 1 writes.
if block_l != 0.0f32 {
// New running max (assuming `select(cond, a, b)` yields `a` when
// `cond` holds, `b` otherwise).
let new_m = select(m_acc > block_m, m_acc, block_m);
// Rescale factors that bring the old state and this block onto
// the common reference `new_m`. For the first non-empty block
// `m_acc` is -INF, so `exp_old` is exp(-INF) and the (zero) old
// state is dropped.
let exp_old = exp(m_acc - new_m);
let exp_block = exp(block_m - new_m);
// Start of this (q_idx, block) row inside `o_partials`.
let partial_base = (q_idx * num_blocks + b) * dim;
for i in range(0u32, $dims_per_lane, 1u32) {
// Slot `i` of this lane maps to output dim `lane + 32 * i`.
let d = lane + i * 32u32;
// Guard the tail: `$dims_per_lane * 32` may exceed `dim`.
if d < dim {
let prev = stack_load("o", i);
let part = load(o_partials[partial_base + d]).cast::<f32>();
let scaled = prev * exp_old + part * exp_block;
stack_store("o", i, scaled);
}
}
// Same rescale for the sum_exp, then commit the new max. The
// max is updated last because `exp_old` was derived from the
// previous `m_acc`.
l_acc = l_acc * exp_old + block_l * exp_block;
m_acc = new_m;
}
}
// Final normalise + narrow-cast write. When every block was skipped
// `l_acc` is still 0; `inv_l` is then forced to 0 so the lane writes
// zeros instead of dividing by zero.
let inv_l = select(l_acc > 0.0f32, 1.0f32 / l_acc, 0.0f32);
for i in range(0u32, $dims_per_lane, 1u32) {
let d = lane + i * 32u32;
if d < dim {
let v = stack_load("o", i) * inv_l;
// The only narrowing step: f32 accumulator -> storage type `T`.
store(output[q_idx * dim + d], v.cast::<T>());
}
}
}
};
}
// One instantiation per supported `dim`. Arguments are, in order: the
// generated function name, `dims_per_lane = ceil(dim / 32)` (accumulator
// slots per lane), and the `subop` label handed to `bench_kernel`.
//
// dim 64 → 2 dims/lane
// dim 80 → 3 (3·32 = 96 ≥ 80; the kernel's `d < dim` guard skips the excess)
// dim 96 → 3
// dim 128 → 4
// dim 256 → 8
// dim 512 → 16
//
// NOTE(review): `dim` itself is passed separately as a `#[constexpr]`
// kernel argument, so the dispatcher presumably has to pair each
// instantiation with its matching `dim` — confirm at the call site.
aura_flash_pass2_kernel!(aura_flash_pass2_d64, 2u32, "flash_pass2_d64");
aura_flash_pass2_kernel!(aura_flash_pass2_d80, 3u32, "flash_pass2_d80");
aura_flash_pass2_kernel!(aura_flash_pass2_d96, 3u32, "flash_pass2_d96");
aura_flash_pass2_kernel!(aura_flash_pass2_d128, 4u32, "flash_pass2_d128");
aura_flash_pass2_kernel!(aura_flash_pass2_d256, 8u32, "flash_pass2_d256");
aura_flash_pass2_kernel!(aura_flash_pass2_d512, 16u32, "flash_pass2_d512");