// mlx_native/ops/flash_attn_vec_tq.rs
1//! Flash attention vector kernel dispatch for TurboQuant-compressed KV cache.
2//!
3//! Fork of `flash_attn_vec` that reads K and V from nibble-packed indices
4//! + per-position norms, with inline scalar dequant from a register-resident
5//! 16-element codebook. No centroid table buffer needed.
6//!
7//! Key differences from `flash_attn_vec`:
//! - NWG defaults to 16 (override via HF2Q_TQ_NWG, range 1-32). NWG=1
//!   writes output directly and skips the reduce kernel entirely.
10//! - Caller handles FWHT: pre-rotates Q, post-rotates output (1× per head).
11//! - Dequant is inline: codebook[nibble] * inv_sqrt(head_dim) * norm
12//! - Zero scattered memory access — codebook fits in registers
13
14use metal::MTLSize;
15
16use crate::buffer::MlxBuffer;
17use crate::device::MlxDevice;
18use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
19use crate::error::{MlxError, Result};
20use crate::kernel_registry::KernelRegistry;
21
/// MSL source for the TQ flash attention vector kernel (embedded at compile time).
///
/// A single source string backs both registry entries (`…_dk256` / `…_dk512`,
/// see `register`); the variant is selected by kernel name at pipeline time.
pub static FLASH_ATTN_VEC_TQ_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_vec_tq.metal");
25
26/// Register TQ flash attention vector shader source with the given kernel registry.
27pub fn register(registry: &mut KernelRegistry) {
28    registry.register_source("flash_attn_vec_tq_dk256", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
29    registry.register_source("flash_attn_vec_tq_dk512", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
30}
31
/// Parameters for the TQ flash attention vector kernel.
///
/// Host-side (ergonomic) parameter struct; converted to the `#[repr(C)]`
/// `FlashAttnVecTqParamsGpu` at dispatch time. Validated by `validate_params`.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecTqParams {
    /// Number of query attention heads.
    pub num_heads: u32,
    /// Number of key/value attention heads (GQA: may be < num_heads).
    /// Must evenly divide `num_heads`.
    pub num_kv_heads: u32,
    /// Dimension of each attention head (256 or 512 — the only kernel variants).
    pub head_dim: u32,
    /// Current KV sequence length (number of valid positions). Must be > 0.
    pub kv_seq_len: u32,
    /// KV cache capacity (stride between KV heads in positions).
    /// Must be >= `kv_seq_len`.
    pub kv_capacity: u32,
    /// Attention score scaling factor (e.g. 1/sqrt(head_dim) or 1.0).
    pub scale: f32,
    /// Mask type: 0=none, 1=causal, 2=sliding_window.
    pub mask_type: u32,
    /// Sliding window size (only used when mask_type == 2).
    pub sliding_window: u32,
    /// Logit softcapping (0 = disabled).
    pub softcap: f32,
}
54
/// GPU-side reduce params. Must match `FlashAttnVecReduceParams` in the MSL.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
    // Number of rows to reduce — set to num_heads (one partial row set per head).
    nrows: u32,
}
61
/// GPU-side parameter struct. Must match the MSL `FlashAttnVecTqParams` exactly.
///
/// Field order and sizes are part of the GPU ABI: 10 four-byte fields = 40
/// bytes, pinned by `test_gpu_params_layout`. Do not reorder or insert fields
/// without updating the MSL side.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecTqParamsGpu {
    n_heads: u32,
    n_kv_heads: u32,
    head_dim: u32,
    kv_seq_len: u32,
    kv_capacity: u32,
    scale: f32,
    mask_type: u32,
    sliding_window: u32,
    softcap: f32,
    // Number of workgroups cooperating per head (see compute_nwg).
    nwg: u32,
}
77
78/// Validate TQ flash attention parameters.
79fn validate_params(params: &FlashAttnVecTqParams) -> Result<()> {
80    if params.head_dim != 256 && params.head_dim != 512 {
81        return Err(MlxError::InvalidArgument(format!(
82            "flash_attn_vec_tq: head_dim must be 256 or 512, got {}",
83            params.head_dim
84        )));
85    }
86    if params.num_heads == 0 || params.num_kv_heads == 0 {
87        return Err(MlxError::InvalidArgument(
88            "flash_attn_vec_tq: num_heads and num_kv_heads must be > 0".into(),
89        ));
90    }
91    if params.num_heads % params.num_kv_heads != 0 {
92        return Err(MlxError::InvalidArgument(format!(
93            "flash_attn_vec_tq: num_heads ({}) must be divisible by num_kv_heads ({})",
94            params.num_heads, params.num_kv_heads
95        )));
96    }
97    if params.kv_seq_len == 0 {
98        return Err(MlxError::InvalidArgument(
99            "flash_attn_vec_tq: kv_seq_len must be > 0".into(),
100        ));
101    }
102    if params.kv_capacity < params.kv_seq_len {
103        return Err(MlxError::InvalidArgument(format!(
104            "flash_attn_vec_tq: kv_capacity ({}) must be >= kv_seq_len ({})",
105            params.kv_capacity, params.kv_seq_len
106        )));
107    }
108    Ok(())
109}
110
/// Compute NWG (workgroups per head) for TQ SDPA.
///
/// NWG=16 is optimal across both short and long context on M5 Max
/// (measured: outperforms both NWG=1 and NWG=32 at all tested lengths).
/// NWG=32 adds reduce kernel overhead that outweighs its parallelism gain.
/// NWG<16 starves the GPU at long context.
///
/// Override: set HF2Q_TQ_NWG=N to force a specific value (for benchmarking).
fn compute_nwg(_kv_seq_len: u32) -> u32 {
    // Env override, accepted only if it parses and lies in the kernel's
    // supported 1..=32 range; anything else falls through to the default.
    std::env::var("HF2Q_TQ_NWG")
        .ok()
        .and_then(|raw| raw.parse::<u32>().ok())
        .filter(|n| (1..=32).contains(n))
        .unwrap_or(16)
}
129
/// Dispatch TQ flash attention vector kernel on the GPU.
///
/// Dispatches NWG workgroups per head (NWG is chosen by `compute_nwg`,
/// range 1-32), followed by a reduce kernel whenever NWG > 1.
///
/// **FWHT is NOT done inside this kernel.** The caller must:
/// 1. Pre-rotate Q via `dispatch_fwht_f32` before calling this function
/// 2. Apply inverse FWHT to the output after this function returns
///
/// With NWG > 1, doing FWHT per-workgroup would repeat it NWG× per head.
/// Keeping FWHT outside means it's done once per head regardless of NWG.
///
/// # Arguments
///
/// * `q`            — Query buffer `[num_heads, 1, head_dim]`, F32, **pre-rotated via FWHT**.
/// * `k_packed`     — Nibble-packed K indices of the TQ-compressed cache.
/// * `k_norms`      — Per-position K norms used by the inline dequant.
/// * `v_packed`     — Nibble-packed V indices.
/// * `v_norms`      — Per-position V norms.
/// * `output`       — Output buffer `[num_heads, 1, head_dim]`, F32, **in rotated domain**.
/// * `tmp`          — Temporary buffer for NWG partial results (size via `tmp_buffer_bytes`).
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` for invalid params (see `validate_params`)
/// and propagates pipeline-compilation errors from the registry.
#[allow(clippy::too_many_arguments)]
pub fn flash_attn_vec_tq(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    q: &MlxBuffer,
    k_packed: &MlxBuffer,
    k_norms: &MlxBuffer,
    v_packed: &MlxBuffer,
    v_norms: &MlxBuffer,
    output: &MlxBuffer,
    tmp: &MlxBuffer,
    params: &FlashAttnVecTqParams,
) -> Result<()> {
    validate_params(params)?;

    let head_dim = params.head_dim;
    let nwg = compute_nwg(params.kv_seq_len);

    // Repack into the #[repr(C)] struct the MSL side reads byte-for-byte.
    let gpu_params = FlashAttnVecTqParamsGpu {
        n_heads: params.num_heads,
        n_kv_heads: params.num_kv_heads,
        head_dim: params.head_dim,
        kv_seq_len: params.kv_seq_len,
        kv_capacity: params.kv_capacity,
        scale: params.scale,
        mask_type: params.mask_type,
        sliding_window: params.sliding_window,
        softcap: params.softcap,
        nwg,
    };

    // head_dim is validated above, but match defensively anyway.
    let kernel_name = match head_dim {
        256 => "flash_attn_vec_tq_dk256",
        512 => "flash_attn_vec_tq_dk512",
        _ => return Err(MlxError::InvalidArgument(format!(
            "flash_attn_vec_tq: unsupported head_dim {head_dim}"
        ))),
    };
    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;

    // Shared memory size — same layout as flash_attn_vec.
    // PK halfs (Q half4) + SH halfs (scratch) + 2*PV halfs (output float4)
    let pk = pad2(head_dim as usize, 128);
    let pv = pad2(head_dim as usize, 128);
    let sh = 4 * 32; // 4 * C = 128 halfs
    let shmem_halfs = pk + sh + 2 * pv;
    let shmem_bytes = shmem_halfs * 2; // 2 bytes per half

    // Tag for the reorder pass: SDPA is NOT reorderable.
    encoder.set_op_kind(CapturedOpKind::Sdpa);

    // Dispatch main kernel: (1 query, num_heads, NWG workgroups).
    let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
    let threadgroup_size = MTLSize::new(32, 1, 1); // 1 simdgroup of 32 threads

    // NWG=1: write directly to output (no reduce needed).
    // NWG>1: write to tmp, then reduce into output.
    let dst_buf = if nwg == 1 { output } else { tmp };

    // Buffer indices 0-6 must match the kernel's [[buffer(i)]] bindings.
    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(q)),
            (2, KernelArg::Buffer(k_packed)),
            (3, KernelArg::Buffer(k_norms)),
            (4, KernelArg::Buffer(v_packed)),
            (5, KernelArg::Buffer(v_norms)),
            (6, KernelArg::Buffer(dst_buf)),
        ],
        &[(0, shmem_bytes as u64)],
        threadgroups,
        threadgroups_size_marker: // (see next line)
        threadgroup_size,
    );

    // --- Reduce kernel (NWG > 1 only) ---
    if nwg > 1 {
        // Partials in `tmp` must be fully written before the reduce reads them.
        encoder.memory_barrier();

        let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };

        // NOTE(review): these are the generic flash_attn_vec reduce kernels
        // (no `_tq` suffix) — presumably the partial-result layout is shared
        // between the TQ and non-TQ paths; confirm against the MSL.
        let reduce_kernel = match head_dim {
            256 => "flash_attn_vec_reduce_dk256",
            512 => "flash_attn_vec_reduce_dk512",
            _ => unreachable!(),
        };
        let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;

        // One threadgroup per head; 32 threads per contributing workgroup.
        let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
        let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);

        encoder.encode_threadgroups_with_args(
            reduce_pipeline,
            &[
                (0, KernelArg::Bytes(as_bytes(&reduce_params))),
                (1, KernelArg::Buffer(tmp)),
                (2, KernelArg::Buffer(output)),
                (3, KernelArg::Bytes(as_bytes(&nwg))),
            ],
            reduce_tg,
            reduce_tg_size,
        );
    }

    Ok(())
}
253
/// Compute the size in bytes of the temporary buffer needed for TQ SDPA.
///
/// Sized for max NWG=32 regardless of the NWG actually dispatched — the buffer
/// is allocated once at model load time and reused for all context lengths.
pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
    // One partial row per (head, workgroup): head_dim f32 values plus two
    // extra f32 slots per row (softmax running stats — see the MSL kernel).
    const MAX_NWG: usize = 32;
    let row_elems = head_dim as usize + 2;
    num_heads as usize * MAX_NWG * row_elems * std::mem::size_of::<f32>()
}
264
/// Pad `x` up to the next multiple of `n` (`n` must be a power of 2).
fn pad2(x: usize, n: usize) -> usize {
    // The mask trick below silently produces a wrong result for
    // non-power-of-two n — enforce the documented precondition in debug builds.
    debug_assert!(n.is_power_of_two(), "pad2: n must be a power of 2, got {n}");
    (x + n - 1) & !(n - 1)
}
269
270#[cfg(test)]
271#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
272mod tests {
273    use super::*;
274
275    #[test]
276    fn test_validate_params_ok() {
277        let p = FlashAttnVecTqParams {
278            num_heads: 8,
279            num_kv_heads: 4,
280            head_dim: 256,
281            kv_seq_len: 64,
282            kv_capacity: 1024,
283            scale: 1.0,
284            mask_type: 1,
285            sliding_window: 0,
286            softcap: 0.0,
287        };
288        assert!(validate_params(&p).is_ok());
289    }
290
291    #[test]
292    fn test_validate_params_bad_head_dim() {
293        let p = FlashAttnVecTqParams {
294            num_heads: 8,
295            num_kv_heads: 4,
296            head_dim: 128,
297            kv_seq_len: 64,
298            kv_capacity: 1024,
299            scale: 1.0,
300            mask_type: 0,
301            sliding_window: 0,
302            softcap: 0.0,
303        };
304        assert!(validate_params(&p).is_err());
305    }
306
307    #[test]
308    fn test_gpu_params_layout() {
309        assert_eq!(
310            std::mem::size_of::<FlashAttnVecTqParamsGpu>(),
311            40, // 10 x u32/f32 = 40 bytes
312        );
313    }
314}