Skip to main content

mlx_native/ops/
flash_attn_vec_tq.rs

1//! Flash attention vector kernel dispatch for TurboQuant-compressed KV cache.
2//!
3//! Fork of `flash_attn_vec` that reads K and V from nibble-packed indices
4//! + per-position norms, with inline scalar dequant from a register-resident
5//! 16-element codebook. No centroid table buffer needed.
6//!
7//! Key differences from `flash_attn_vec`:
//! - NWG (workgroups per head) defaults to 16 and can be overridden within
//!   1-32 via `HF2Q_TQ_NWG`. With NWG=1 the reduce kernel is skipped entirely.
10//! - Caller handles FWHT: pre-rotates Q, post-rotates output (1× per head).
11//! - Dequant is inline: codebook[nibble] * inv_sqrt(head_dim) * norm
12//! - Zero scattered memory access — codebook fits in registers
13
14use metal::MTLSize;
15
16use crate::buffer::MlxBuffer;
17use crate::device::MlxDevice;
18use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
19use crate::error::{MlxError, Result};
20use crate::kernel_registry::KernelRegistry;
21
/// MSL source for the TQ flash attention vector kernel (embedded at compile time).
///
/// The text is pulled in with `include_str!` from
/// `../shaders/flash_attn_vec_tq.metal`; [`register`] registers this same
/// source string under both the `dk256` and `dk512` kernel names.
pub static FLASH_ATTN_VEC_TQ_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_vec_tq.metal");
25
26/// Register TQ flash attention vector shader source with the given kernel registry.
27pub fn register(registry: &mut KernelRegistry) {
28    registry.register_source("flash_attn_vec_tq_dk256", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
29    registry.register_source("flash_attn_vec_tq_dk512", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
30}
31
/// Parameters for the TQ flash attention vector kernel.
///
/// This is the host-side description of one dispatch. `flash_attn_vec_tq`
/// validates it and copies the fields into the `#[repr(C)]` GPU struct
/// (adding the chosen NWG) before encoding the kernel.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecTqParams {
    /// Number of query attention heads. Must be > 0.
    pub num_heads: u32,
    /// Number of key/value attention heads (GQA: may be < num_heads).
    /// Must be > 0 and divide `num_heads` evenly.
    pub num_kv_heads: u32,
    /// Dimension of each attention head (256 or 512; anything else is rejected).
    pub head_dim: u32,
    /// Current KV sequence length (number of valid positions). Must be > 0.
    pub kv_seq_len: u32,
    /// KV cache capacity (stride between KV heads in positions).
    /// Must be >= `kv_seq_len`.
    pub kv_capacity: u32,
    /// Attention score scaling factor (e.g. 1/sqrt(head_dim) or 1.0).
    pub scale: f32,
    /// Mask type: 0=none, 1=causal, 2=sliding_window.
    pub mask_type: u32,
    /// Sliding window size (only used when mask_type == 2).
    pub sliding_window: u32,
    /// Logit softcapping (0 = disabled).
    pub softcap: f32,
    /// Ring buffer start slot for sliding-window after wrap (ADR-009 Track 2).
    ///
    /// Physical slot index of the chronologically OLDEST entry in the ring
    /// buffer. Before wrap (`kv_seq_len < kv_capacity`): set to 0.
    /// After wrap: set to `write_pos % capacity`.
    ///
    /// The shader uses this to map physical slots to logical positions
    /// for correct causal/sliding-window masking after wrap.
    pub ring_start: u32,
    /// iter-18 S2B: reciprocal of the D=512 per-block encoder scale factor.
    /// Decoder applies: actual_blk_norm = stored_blk_norm / scale_factor_d512.
    /// bare=1.0 (iter-16 control), sqrt256=16.0, sqrt512≈22.627.
    /// D=256 path ignores this field.
    /// Values below 1e-6 (including 0.0) are replaced with 1.0 (bare behavior)
    /// by `flash_attn_vec_tq` before the parameters are sent to the GPU.
    pub scale_factor_d512: f32,
}
69
/// GPU-side reduce params. Must match `FlashAttnVecReduceParams` in the MSL.
///
/// `#[repr(C)]` plus the `bytemuck` derives let the struct be passed to the
/// reduce kernel as raw bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
    // Number of rows to reduce; the dispatch code sets this to `num_heads`.
    nrows: u32,
}
76
/// GPU-side parameter struct. Must match the MSL `FlashAttnVecTqParams` exactly.
///
/// Twelve 4-byte fields (48 bytes total); the field order is part of the
/// contract with the shader, so do not reorder. Most fields are copied
/// verbatim from the host-side `FlashAttnVecTqParams`.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecTqParamsGpu {
    // Number of query heads (host field: `num_heads`).
    n_heads: u32,
    // Number of key/value heads (host field: `num_kv_heads`).
    n_kv_heads: u32,
    // Per-head dimension (256 or 512 after validation).
    head_dim: u32,
    // Number of valid KV positions.
    kv_seq_len: u32,
    // KV cache capacity in positions.
    kv_capacity: u32,
    // Attention score scaling factor.
    scale: f32,
    // Mask type selector (0=none, 1=causal, 2=sliding_window).
    mask_type: u32,
    // Sliding window size (used when mask_type == 2).
    sliding_window: u32,
    // Logit softcapping value (0 = disabled).
    softcap: f32,
    // Number of workgroups per head chosen at dispatch time; not present in
    // the host-side struct.
    nwg: u32,
    // Physical ring-buffer slot of the oldest entry.
    ring_start: u32,
    /// iter-18 S2B: reciprocal scale factor for D=512 dequant. See FlashAttnVecTqParams.
    scale_factor_d512: f32,
}
95
96/// Validate TQ flash attention parameters.
97fn validate_params(params: &FlashAttnVecTqParams) -> Result<()> {
98    if params.head_dim != 256 && params.head_dim != 512 {
99        return Err(MlxError::InvalidArgument(format!(
100            "flash_attn_vec_tq: head_dim must be 256 or 512, got {}",
101            params.head_dim
102        )));
103    }
104    if params.num_heads == 0 || params.num_kv_heads == 0 {
105        return Err(MlxError::InvalidArgument(
106            "flash_attn_vec_tq: num_heads and num_kv_heads must be > 0".into(),
107        ));
108    }
109    if params.num_heads % params.num_kv_heads != 0 {
110        return Err(MlxError::InvalidArgument(format!(
111            "flash_attn_vec_tq: num_heads ({}) must be divisible by num_kv_heads ({})",
112            params.num_heads, params.num_kv_heads
113        )));
114    }
115    if params.kv_seq_len == 0 {
116        return Err(MlxError::InvalidArgument(
117            "flash_attn_vec_tq: kv_seq_len must be > 0".into(),
118        ));
119    }
120    if params.kv_capacity < params.kv_seq_len {
121        return Err(MlxError::InvalidArgument(format!(
122            "flash_attn_vec_tq: kv_capacity ({}) must be >= kv_seq_len ({})",
123            params.kv_capacity, params.kv_seq_len
124        )));
125    }
126    Ok(())
127}
128
/// Choose NWG (the number of workgroups per head) for TQ SDPA.
///
/// The default is a fixed 16, which was measured on M5 Max to outperform both
/// NWG=1 and NWG=32 at all tested context lengths: NWG=32 adds reduce-kernel
/// overhead that outweighs its parallelism gain, and NWG<16 starves the GPU
/// at long context. The `_kv_seq_len` argument is currently unused.
///
/// Override: set `HF2Q_TQ_NWG=N` to force a specific value (for
/// benchmarking). Only values that parse as an integer in `1..=32` are
/// honored; anything else falls back to the default.
fn compute_nwg(_kv_seq_len: u32) -> u32 {
    const DEFAULT_NWG: u32 = 16;

    std::env::var("HF2Q_TQ_NWG")
        .ok()
        .and_then(|raw| raw.parse::<u32>().ok())
        .filter(|forced| (1..=32).contains(forced))
        .unwrap_or(DEFAULT_NWG)
}
147
148/// Dispatch TQ flash attention vector kernel on the GPU.
149///
150/// Dispatches NWG=32 workgroups per head, then a reduce kernel.
151///
152/// **FWHT is NOT done inside this kernel.** The caller must:
153/// 1. Pre-rotate Q via `dispatch_fwht_f32` before calling this function
154/// 2. Apply inverse FWHT to the output after this function returns
155///
156/// With NWG=32, doing FWHT per-workgroup would repeat it 32× per head.
157/// Keeping FWHT outside means it's done once per head regardless of NWG.
158///
159/// # Arguments
160///
161/// * `q`            — Query buffer `[num_heads, 1, head_dim]`, F32, **pre-rotated via FWHT**.
162/// * `output`       — Output buffer `[num_heads, 1, head_dim]`, F32, **in rotated domain**.
163/// * `tmp`          — Temporary buffer for NWG partial results.
164#[allow(clippy::too_many_arguments)]
165pub fn flash_attn_vec_tq(
166    encoder: &mut CommandEncoder,
167    registry: &mut KernelRegistry,
168    device: &MlxDevice,
169    q: &MlxBuffer,
170    k_packed: &MlxBuffer,
171    k_norms: &MlxBuffer,
172    v_packed: &MlxBuffer,
173    v_norms: &MlxBuffer,
174    output: &MlxBuffer,
175    tmp: &MlxBuffer,
176    params: &FlashAttnVecTqParams,
177) -> Result<()> {
178    validate_params(params)?;
179
180    let head_dim = params.head_dim;
181    let nwg = compute_nwg(params.kv_seq_len);
182
183    // Ensure scale_factor_d512 is always >= 1.0 (0.0 treated as 1.0 for safety).
184    let effective_scale_d512 = if params.scale_factor_d512 < 1e-6 { 1.0_f32 } else { params.scale_factor_d512 };
185    let gpu_params = FlashAttnVecTqParamsGpu {
186        n_heads: params.num_heads,
187        n_kv_heads: params.num_kv_heads,
188        head_dim: params.head_dim,
189        kv_seq_len: params.kv_seq_len,
190        kv_capacity: params.kv_capacity,
191        scale: params.scale,
192        mask_type: params.mask_type,
193        sliding_window: params.sliding_window,
194        softcap: params.softcap,
195        nwg,
196        ring_start: params.ring_start,
197        scale_factor_d512: effective_scale_d512,
198    };
199
200    let kernel_name = match head_dim {
201        256 => "flash_attn_vec_tq_dk256",
202        512 => "flash_attn_vec_tq_dk512",
203        _ => return Err(MlxError::InvalidArgument(format!(
204            "flash_attn_vec_tq: unsupported head_dim {head_dim}"
205        ))),
206    };
207    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
208
209    // Shared memory size — same layout as flash_attn_vec.
210    // PK halfs (Q half4) + SH halfs (scratch) + 2*PV halfs (output float4)
211    let pk = pad2(head_dim as usize, 128);
212    let pv = pad2(head_dim as usize, 128);
213    let sh = 4 * 32; // 4 * C = 128 halfs
214    let shmem_halfs = pk + sh + 2 * pv;
215    let shmem_bytes = shmem_halfs * 2; // 2 bytes per half
216
217    // Tag for the reorder pass: SDPA is NOT reorderable.
218    encoder.set_op_kind(CapturedOpKind::Sdpa);
219
220    // Dispatch main kernel: (1 query, num_heads, NWG workgroups).
221    let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
222    let threadgroup_size = MTLSize::new(32, 1, 1); // 1 simdgroup of 32 threads
223
224    // NWG=1: write directly to output (no reduce needed).
225    // NWG>1: write to tmp, then reduce into output.
226    let dst_buf = if nwg == 1 { output } else { tmp };
227
228    encoder.encode_threadgroups_with_args_and_shared(
229        pipeline,
230        &[
231            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
232            (1, KernelArg::Buffer(q)),
233            (2, KernelArg::Buffer(k_packed)),
234            (3, KernelArg::Buffer(k_norms)),
235            (4, KernelArg::Buffer(v_packed)),
236            (5, KernelArg::Buffer(v_norms)),
237            (6, KernelArg::Buffer(dst_buf)),
238        ],
239        &[(0, shmem_bytes as u64)],
240        threadgroups,
241        threadgroup_size,
242    );
243
244    // --- Reduce kernel (NWG > 1 only) ---
245    if nwg > 1 {
246        encoder.memory_barrier();
247
248        let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };
249
250        let reduce_kernel = match head_dim {
251            256 => "flash_attn_vec_reduce_dk256",
252            512 => "flash_attn_vec_reduce_dk512",
253            _ => unreachable!(),
254        };
255        let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;
256
257        let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
258        let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);
259
260        encoder.encode_threadgroups_with_args(
261            reduce_pipeline,
262            &[
263                (0, KernelArg::Bytes(as_bytes(&reduce_params))),
264                (1, KernelArg::Buffer(tmp)),
265                (2, KernelArg::Buffer(output)),
266                (3, KernelArg::Bytes(as_bytes(&nwg))),
267            ],
268            reduce_tg,
269            reduce_tg_size,
270        );
271    }
272
273    Ok(())
274}
275
/// Size in bytes of the temporary buffer needed for TQ SDPA.
///
/// The buffer is sized for the maximum NWG of 32 regardless of the NWG that
/// is actually used, so it can be allocated once and reused for every
/// context length.
///
/// Each (head, workgroup) pair gets `head_dim + 2` `f32` slots. The two
/// extra slots are presumably the per-workgroup softmax statistics — TODO
/// confirm against the shader.
pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
    const MAX_NWG: usize = 32;

    let floats_per_partial = head_dim as usize + 2;
    let total_floats = num_heads as usize * MAX_NWG * floats_per_partial;
    total_floats * std::mem::size_of::<f32>()
}
286
/// Round `x` up to the next multiple of `n`.
///
/// `n` must be a power of two: the result is computed by adding `n - 1` and
/// then clearing the low bits with a mask. A value that is already a
/// multiple of `n` is returned unchanged.
fn pad2(x: usize, n: usize) -> usize {
    let bumped = x + n - 1;
    let keep_high_bits = !(n - 1);
    bumped & keep_high_bits
}
291
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A parameter set that passes validation; tests override single fields.
    fn valid_params() -> FlashAttnVecTqParams {
        FlashAttnVecTqParams {
            num_heads: 8,
            num_kv_heads: 4,
            head_dim: 256,
            kv_seq_len: 64,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
            scale_factor_d512: 1.0,
        }
    }

    /// A well-formed parameter set is accepted.
    #[test]
    fn test_validate_params_ok() {
        let params = valid_params();
        assert!(validate_params(&params).is_ok());
    }

    /// Head dimensions other than 256/512 are rejected.
    #[test]
    fn test_validate_params_bad_head_dim() {
        let params = FlashAttnVecTqParams {
            head_dim: 128,
            mask_type: 0,
            ..valid_params()
        };
        assert!(validate_params(&params).is_err());
    }

    /// The GPU struct must stay at 12 x u32/f32 = 48 bytes
    /// (iter-18: +scale_factor_d512) to match the MSL definition.
    #[test]
    fn test_gpu_params_layout() {
        let actual_size = std::mem::size_of::<FlashAttnVecTqParamsGpu>();
        assert_eq!(actual_size, 48);
    }
}