1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
//! Flash attention vector kernel dispatch for TurboQuant-compressed KV cache.
//!
//! Fork of `flash_attn_vec` that reads K and V from nibble-packed indices
//! + per-position norms, with inline scalar dequant from a register-resident
//! 16-element codebook. No centroid table buffer needed.
//!
//! Key differences from `flash_attn_vec`:
//! - NWG (workgroups per head) defaults to 16; the `HF2Q_TQ_NWG` environment
//! variable can force any value in 1-32. With NWG=1 the main kernel writes the
//! output directly and the reduce kernel is skipped.
//! - Caller handles FWHT: pre-rotates Q, post-rotates output (1× per head).
//! - Dequant is inline: codebook[nibble] * inv_sqrt(head_dim) * norm
//! - Zero scattered memory access — codebook fits in registers
use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// MSL source for the TQ flash attention vector kernel.
///
/// The text of `../shaders/flash_attn_vec_tq.metal` (path relative to this
/// source file) is embedded into the binary at compile time via `include_str!`.
/// [`register`] maps both the `dk256` and `dk512` kernel names to this one string.
pub static FLASH_ATTN_VEC_TQ_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_vec_tq.metal");
/// Register the TQ flash attention vector shader with `registry`.
///
/// The `dk256` and `dk512` kernel names share a single MSL source string
/// ([`FLASH_ATTN_VEC_TQ_SHADER_SOURCE`]), so each name is registered against it.
pub fn register(registry: &mut KernelRegistry) {
    for kernel_name in ["flash_attn_vec_tq_dk256", "flash_attn_vec_tq_dk512"] {
        registry.register_source(kernel_name, FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
    }
}
/// Parameters for the TQ flash attention vector kernel.
///
/// This is the host-side, caller-facing description of one dispatch.
/// `flash_attn_vec_tq` validates it and copies the fields into the
/// `#[repr(C)]` GPU struct before encoding the kernel.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecTqParams {
/// Number of query attention heads. Must be > 0.
pub num_heads: u32,
/// Number of key/value attention heads (GQA: may be < num_heads).
/// Must be > 0 and must divide `num_heads` evenly.
pub num_kv_heads: u32,
/// Dimension of each attention head (256 or 512; anything else is rejected).
pub head_dim: u32,
/// Current KV sequence length (number of valid positions).
/// Must be > 0 and <= `kv_capacity`.
pub kv_seq_len: u32,
/// KV cache capacity (stride between KV heads in positions).
pub kv_capacity: u32,
/// Attention score scaling factor (e.g. 1/sqrt(head_dim) or 1.0).
pub scale: f32,
/// Mask type: 0=none, 1=causal, 2=sliding_window.
/// NOTE(review): not range-checked on the host; the value is passed to the shader as-is.
pub mask_type: u32,
/// Sliding window size (only used when mask_type == 2).
pub sliding_window: u32,
/// Logit softcapping (0 = disabled).
pub softcap: f32,
/// Ring buffer start slot for sliding-window after wrap (ADR-009 Track 2).
///
/// Physical slot index of the chronologically OLDEST entry in the ring
/// buffer. Before wrap (`kv_seq_len < kv_capacity`): set to 0.
/// After wrap: set to `write_pos % capacity`.
///
/// The shader uses this to map physical slots to logical positions
/// for correct causal/sliding-window masking after wrap.
pub ring_start: u32,
/// iter-18 S2B: reciprocal of the D=512 per-block encoder scale factor.
/// Decoder applies: actual_blk_norm = stored_blk_norm / scale_factor_d512.
/// bare=1.0 (iter-16 control), sqrt256=16.0, sqrt512≈22.627.
/// D=256 path ignores this field.
/// Any value below 1e-6 (including 0.0 and negatives) is replaced with 1.0
/// (bare behavior) by `flash_attn_vec_tq` before it is sent to the GPU.
pub scale_factor_d512: f32,
}
/// GPU-side reduce params. Must match `FlashAttnVecReduceParams` in the MSL.
///
/// `#[repr(C)]` plus the `bytemuck::Pod` derive let the dispatcher hand this
/// struct to the reduce kernel as raw bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
/// Number of rows to reduce; the dispatcher sets this to `num_heads`.
nrows: u32,
}
/// GPU-side parameter struct. Must match the MSL `FlashAttnVecTqParams` exactly.
///
/// Twelve 4-byte fields, 48 bytes in total (asserted by the layout unit test).
/// Field order is part of the host/shader contract: do not reorder or insert
/// fields without changing the shader struct in lockstep.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecTqParamsGpu {
/// Number of query heads (host field `num_heads`).
n_heads: u32,
/// Number of key/value heads (host field `num_kv_heads`).
n_kv_heads: u32,
/// Head dimension; the host only dispatches 256 or 512.
head_dim: u32,
/// Number of valid KV positions.
kv_seq_len: u32,
/// KV cache capacity in positions.
kv_capacity: u32,
/// Attention score scaling factor.
scale: f32,
/// Mask type: 0=none, 1=causal, 2=sliding_window.
mask_type: u32,
/// Sliding window size (meaningful when mask_type == 2).
sliding_window: u32,
/// Logit softcapping (0 = disabled).
softcap: f32,
/// Number of workgroups per head chosen by the host for this dispatch.
nwg: u32,
/// Physical slot of the oldest ring-buffer entry. See FlashAttnVecTqParams.
ring_start: u32,
/// iter-18 S2B: reciprocal scale factor for D=512 dequant. See FlashAttnVecTqParams.
/// The host substitutes 1.0 for values below 1e-6 before filling this field.
scale_factor_d512: f32,
}
/// Validate TQ flash attention parameters.
fn validate_params(params: &FlashAttnVecTqParams) -> Result<()> {
if params.head_dim != 256 && params.head_dim != 512 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_vec_tq: head_dim must be 256 or 512, got {}",
params.head_dim
)));
}
if params.num_heads == 0 || params.num_kv_heads == 0 {
return Err(MlxError::InvalidArgument(
"flash_attn_vec_tq: num_heads and num_kv_heads must be > 0".into(),
));
}
if params.num_heads % params.num_kv_heads != 0 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_vec_tq: num_heads ({}) must be divisible by num_kv_heads ({})",
params.num_heads, params.num_kv_heads
)));
}
if params.kv_seq_len == 0 {
return Err(MlxError::InvalidArgument(
"flash_attn_vec_tq: kv_seq_len must be > 0".into(),
));
}
if params.kv_capacity < params.kv_seq_len {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_vec_tq: kv_capacity ({}) must be >= kv_seq_len ({})",
params.kv_capacity, params.kv_seq_len
)));
}
Ok(())
}
/// Choose the number of workgroups per head (NWG) for TQ SDPA.
///
/// The result is a fixed 16 unless overridden; `_kv_seq_len` is accepted but
/// currently does not influence the choice. Rationale recorded with the
/// original measurements on M5 Max: NWG=16 beat both NWG=1 and NWG=32 at every
/// tested length, NWG=32 pays more in reduce-kernel overhead than it gains in
/// parallelism, and NWG<16 starves the GPU at long context.
///
/// Override: set `HF2Q_TQ_NWG=N` to force a specific value (for benchmarking).
/// The override is honoured only when it parses as a `u32` in `1..=32`; any
/// other value (unset, non-numeric, out of range) falls back to 16.
fn compute_nwg(_kv_seq_len: u32) -> u32 {
    const DEFAULT_NWG: u32 = 16;

    std::env::var("HF2Q_TQ_NWG")
        .ok()
        .and_then(|raw| raw.parse::<u32>().ok())
        .filter(|forced| (1..=32).contains(forced))
        .unwrap_or(DEFAULT_NWG)
}
/// Dispatch TQ flash attention vector kernel on the GPU.
///
/// Encodes the main kernel with NWG workgroups per head, where NWG comes from
/// `compute_nwg` (16 by default, 1-32 via `HF2Q_TQ_NWG`). When NWG > 1 the
/// partial results go to `tmp` and a reduce kernel combines them into
/// `output`; when NWG == 1 the main kernel writes `output` directly and no
/// reduce kernel is encoded.
///
/// **FWHT is NOT done inside this kernel.** The caller must:
/// 1. Pre-rotate Q via `dispatch_fwht_f32` before calling this function
/// 2. Apply inverse FWHT to the output after this function returns
///
/// Doing FWHT per-workgroup would repeat it NWG× per head.
/// Keeping FWHT outside means it's done once per head regardless of NWG.
///
/// # Arguments
///
/// * `encoder` — Command encoder the kernel dispatches are recorded into.
/// * `registry` — Kernel registry used to look up the compute pipelines.
/// * `device` — Device whose Metal device is passed to the pipeline lookup.
/// * `q` — Query buffer `[num_heads, 1, head_dim]`, F32, **pre-rotated via FWHT**.
/// * `k_packed` / `v_packed` — Packed K and V index buffers (kernel args 2 and 4).
/// * `k_norms` / `v_norms` — K and V norm buffers (kernel args 3 and 5).
/// * `output` — Output buffer `[num_heads, 1, head_dim]`, F32, **in rotated domain**.
/// * `tmp` — Temporary buffer for NWG partial results. Buffer sizes are not
///   checked here; see `tmp_buffer_bytes` for the expected size.
/// * `params` — Dispatch parameters; validated before anything is encoded.
///
/// # Errors
///
/// Returns the `validate_params` error for unsupported parameters, and
/// propagates any error from `KernelRegistry::get_pipeline`.
#[allow(clippy::too_many_arguments)]
pub fn flash_attn_vec_tq(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
q: &MlxBuffer,
k_packed: &MlxBuffer,
k_norms: &MlxBuffer,
v_packed: &MlxBuffer,
v_norms: &MlxBuffer,
output: &MlxBuffer,
tmp: &MlxBuffer,
params: &FlashAttnVecTqParams,
) -> Result<()> {
validate_params(params)?;
let head_dim = params.head_dim;
let nwg = compute_nwg(params.kv_seq_len);
// Fall back to 1.0 (bare behavior) when scale_factor_d512 is below 1e-6, which
// covers 0.0 and negative values. Values in [1e-6, 1.0) are passed through
// unchanged, and so is NaN (the `<` comparison is false for NaN).
let effective_scale_d512 = if params.scale_factor_d512 < 1e-6 { 1.0_f32 } else { params.scale_factor_d512 };
// Host params copied field-for-field into the repr(C) struct the shader reads.
let gpu_params = FlashAttnVecTqParamsGpu {
n_heads: params.num_heads,
n_kv_heads: params.num_kv_heads,
head_dim: params.head_dim,
kv_seq_len: params.kv_seq_len,
kv_capacity: params.kv_capacity,
scale: params.scale,
mask_type: params.mask_type,
sliding_window: params.sliding_window,
softcap: params.softcap,
nwg,
ring_start: params.ring_start,
scale_factor_d512: effective_scale_d512,
};
// validate_params already restricts head_dim to 256/512; the fallback arm
// keeps this match total without panicking.
let kernel_name = match head_dim {
256 => "flash_attn_vec_tq_dk256",
512 => "flash_attn_vec_tq_dk512",
_ => return Err(MlxError::InvalidArgument(format!(
"flash_attn_vec_tq: unsupported head_dim {head_dim}"
))),
};
let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
// Shared memory size — same layout as flash_attn_vec.
// PK halfs (Q half4) + SH halfs (scratch) + 2*PV halfs (output float4)
let pk = pad2(head_dim as usize, 128);
let pv = pad2(head_dim as usize, 128);
let sh = 4 * 32; // 4 * C = 128 halfs
let shmem_halfs = pk + sh + 2 * pv;
let shmem_bytes = shmem_halfs * 2; // 2 bytes per half
// Tag for the reorder pass: SDPA is NOT reorderable.
encoder.set_op_kind(CapturedOpKind::Sdpa);
// Dispatch main kernel: (1 query, num_heads, NWG workgroups).
let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
let threadgroup_size = MTLSize::new(32, 1, 1); // 1 simdgroup of 32 threads
// NWG=1: write directly to output (no reduce needed).
// NWG>1: write to tmp, then reduce into output.
let dst_buf = if nwg == 1 { output } else { tmp };
encoder.encode_threadgroups_with_args_and_shared(
pipeline,
&[
(0, KernelArg::Bytes(as_bytes(&gpu_params))),
(1, KernelArg::Buffer(q)),
(2, KernelArg::Buffer(k_packed)),
(3, KernelArg::Buffer(k_norms)),
(4, KernelArg::Buffer(v_packed)),
(5, KernelArg::Buffer(v_norms)),
(6, KernelArg::Buffer(dst_buf)),
],
&[(0, shmem_bytes as u64)],
threadgroups,
threadgroup_size,
);
// --- Reduce kernel (NWG > 1 only) ---
if nwg > 1 {
// The reduce kernel reads the partials the main kernel wrote to `tmp`,
// so a barrier is placed between the two dispatches.
encoder.memory_barrier();
let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };
// NOTE(review): these reduce kernel names are not registered by this
// module's `register`; presumably the non-TQ flash_attn_vec module
// registers them — confirm.
let reduce_kernel = match head_dim {
256 => "flash_attn_vec_reduce_dk256",
512 => "flash_attn_vec_reduce_dk512",
// head_dim was already matched against 256/512 above.
_ => unreachable!(),
};
let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;
// One threadgroup per head, 32 * NWG threads in each.
let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);
encoder.encode_threadgroups_with_args(
reduce_pipeline,
&[
(0, KernelArg::Bytes(as_bytes(&reduce_params))),
(1, KernelArg::Buffer(tmp)),
(2, KernelArg::Buffer(output)),
(3, KernelArg::Bytes(as_bytes(&nwg))),
],
reduce_tg,
reduce_tg_size,
);
}
Ok(())
}
/// Size in bytes of the scratch buffer that TQ SDPA needs for partial results.
///
/// The size always assumes the largest permitted NWG (32), independent of the
/// NWG a given dispatch actually uses, so one allocation made at model load
/// time can be reused for every context length.
///
/// Layout accounted for: `num_heads * 32` slots of `head_dim + 2` `f32` values.
pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
    const MAX_NWG: usize = 32;

    // One slot per (head, workgroup) pair.
    let slot_count = num_heads as usize * MAX_NWG;
    // `head_dim` output floats plus 2 extra floats per slot
    // (presumably the softmax running max and sum — confirm against the shader).
    let floats_per_slot = head_dim as usize + 2;

    slot_count * floats_per_slot * std::mem::size_of::<f32>()
}
/// Round `x` up to the nearest multiple of `n`.
///
/// `n` must be a power of two: the rounding is done by masking off the low
/// bits, which only yields a multiple of `n` under that precondition.
fn pad2(x: usize, n: usize) -> usize {
    // Bits at and above log2(n) are kept; everything below is cleared.
    let keep_high_bits = !(n - 1);
    (x + n - 1) & keep_high_bits
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A parameter set that satisfies every rule in `validate_params`.
    /// Individual tests override single fields with struct-update syntax.
    fn valid_params() -> FlashAttnVecTqParams {
        FlashAttnVecTqParams {
            num_heads: 8,
            num_kv_heads: 4,
            head_dim: 256,
            kv_seq_len: 64,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
            scale_factor_d512: 1.0,
        }
    }

    #[test]
    fn test_validate_params_ok() {
        assert!(validate_params(&valid_params()).is_ok());
    }

    #[test]
    fn test_validate_params_head_dim_512_ok() {
        let p = FlashAttnVecTqParams {
            head_dim: 512,
            ..valid_params()
        };
        assert!(validate_params(&p).is_ok());
    }

    #[test]
    fn test_validate_params_bad_head_dim() {
        let p = FlashAttnVecTqParams {
            head_dim: 128,
            mask_type: 0,
            ..valid_params()
        };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_zero_heads() {
        let no_q_heads = FlashAttnVecTqParams {
            num_heads: 0,
            ..valid_params()
        };
        assert!(validate_params(&no_q_heads).is_err());

        // Must be rejected by the zero check, not reach the modulo.
        let no_kv_heads = FlashAttnVecTqParams {
            num_kv_heads: 0,
            ..valid_params()
        };
        assert!(validate_params(&no_kv_heads).is_err());
    }

    #[test]
    fn test_validate_params_heads_not_divisible() {
        let p = FlashAttnVecTqParams {
            num_heads: 8,
            num_kv_heads: 3,
            ..valid_params()
        };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_zero_kv_seq_len() {
        let p = FlashAttnVecTqParams {
            kv_seq_len: 0,
            ..valid_params()
        };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_capacity_bound() {
        // A completely full cache is valid...
        let full = FlashAttnVecTqParams {
            kv_seq_len: 1024,
            kv_capacity: 1024,
            ..valid_params()
        };
        assert!(validate_params(&full).is_ok());

        // ...but a sequence longer than the capacity is not.
        let overflowing = FlashAttnVecTqParams {
            kv_seq_len: 1025,
            kv_capacity: 1024,
            ..valid_params()
        };
        assert!(validate_params(&overflowing).is_err());
    }

    #[test]
    fn test_gpu_params_layout() {
        assert_eq!(
            std::mem::size_of::<FlashAttnVecTqParamsGpu>(),
            48, // 12 x u32/f32 = 48 bytes (iter-18: +scale_factor_d512)
        );
    }
}