Skip to main content

mlx_native/ops/
flash_attn_vec_tq.rs

1//! Flash attention vector kernel dispatch for TurboQuant-compressed KV cache.
2//!
3//! Fork of `flash_attn_vec` that reads K and V from nibble-packed indices
4//! + per-position norms, with inline scalar dequant from a register-resident
5//! 16-element codebook. No centroid table buffer needed.
6//!
7//! Key differences from `flash_attn_vec`:
//! - NWG (threadgroups per head) is fixed at 16 regardless of kv_seq_len;
//!   `HF2Q_TQ_NWG` can override it (1-32). With NWG=1 the reduce kernel is skipped.
10//! - Caller handles FWHT: pre-rotates Q, post-rotates output (1× per head).
11//! - Dequant is inline: codebook[nibble] * inv_sqrt(head_dim) * norm
12//! - Zero scattered memory access — codebook fits in registers
13
14use metal::MTLSize;
15
16use crate::buffer::MlxBuffer;
17use crate::device::MlxDevice;
18use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
19use crate::error::{MlxError, Result};
20use crate::kernel_registry::KernelRegistry;
21
22/// MSL source for the TQ flash attention vector kernel (embedded at compile time).
23pub static FLASH_ATTN_VEC_TQ_SHADER_SOURCE: &str =
24    include_str!("../shaders/flash_attn_vec_tq.metal");
25
26/// Register TQ flash attention vector shader source with the given kernel registry.
27pub fn register(registry: &mut KernelRegistry) {
28    registry.register_source("flash_attn_vec_tq_dk256", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
29    registry.register_source("flash_attn_vec_tq_dk512", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
30}
31
/// Parameters for the TQ flash attention vector kernel.
///
/// Checked by `validate_params` at the start of every `flash_attn_vec_tq` dispatch.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FlashAttnVecTqParams {
    /// Number of query attention heads.
    pub num_heads: u32,
    /// Number of key/value attention heads (GQA: may be < num_heads).
    pub num_kv_heads: u32,
    /// Dimension of each attention head (256 or 512).
    pub head_dim: u32,
    /// Current KV sequence length (number of valid positions).
    pub kv_seq_len: u32,
    /// KV cache capacity (stride between KV heads in positions).
    pub kv_capacity: u32,
    /// Attention score scaling factor (e.g. 1/sqrt(head_dim) or 1.0).
    pub scale: f32,
    /// Mask type: 0=none, 1=causal, 2=sliding_window.
    pub mask_type: u32,
    /// Sliding window size (only used when mask_type == 2).
    pub sliding_window: u32,
    /// Logit softcapping (0 = disabled).
    pub softcap: f32,
    /// Ring buffer start slot for sliding-window after wrap (ADR-009 Track 2).
    ///
    /// Physical slot index of the chronologically OLDEST entry in the ring
    /// buffer. Before wrap (`kv_seq_len < kv_capacity`): set to 0.
    /// After wrap: set to `write_pos % capacity`.
    ///
    /// The shader uses this to map physical slots to logical positions
    /// for correct causal/sliding-window masking after wrap.
    pub ring_start: u32,
    /// iter-18 S2B: D=512 per-block norm scale factor.
    ///
    /// The decoder recovers `actual_blk_norm = stored_blk_norm / scale_factor_d512`.
    /// Known settings: bare = 1.0 (iter-16 control), sqrt256 = 16.0,
    /// sqrt512 ≈ 22.627. The D=256 path ignores this field.
    ///
    /// This is a plain `f32`, not an `Option`. Pass `0.0` to get the default
    /// of 1.0 (bare behavior): the dispatcher substitutes 1.0 for any value
    /// below `1e-6`.
    pub scale_factor_d512: f32,
}
69
/// GPU-side reduce params. Must match `FlashAttnVecReduceParams` in the MSL.
///
/// Uploaded as argument 0 of the `flash_attn_vec_reduce_dk*` kernel when
/// `flash_attn_vec_tq` runs with NWG > 1. `#[repr(C)]` + `Pod` let it be sent
/// as raw bytes via `as_bytes`.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
    /// Number of rows to reduce; the dispatcher sets this to the query head count.
    nrows: u32,
}
76
/// GPU-side parameter struct. Must match the MSL `FlashAttnVecTqParams` exactly.
///
/// Field order, types and `#[repr(C)]` define the byte layout the shader reads
/// (48 bytes, pinned by `test_gpu_params_layout`). Built inside
/// `flash_attn_vec_tq` from [`FlashAttnVecTqParams`] plus the computed `nwg`,
/// and uploaded as argument 0 of the main kernel.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecTqParamsGpu {
    /// Query head count (`FlashAttnVecTqParams::num_heads`).
    n_heads: u32,
    /// KV head count (`FlashAttnVecTqParams::num_kv_heads`).
    n_kv_heads: u32,
    head_dim: u32,
    kv_seq_len: u32,
    kv_capacity: u32,
    scale: f32,
    mask_type: u32,
    sliding_window: u32,
    softcap: f32,
    /// Threadgroups per head for this dispatch (from `compute_nwg`).
    nwg: u32,
    /// Physical slot of the oldest ring-buffer entry; see `FlashAttnVecTqParams::ring_start`.
    ring_start: u32,
    /// iter-18 S2B: D=512 block-norm scale factor (see FlashAttnVecTqParams).
    /// The dispatcher replaces values below 1e-6 with 1.0 before upload.
    scale_factor_d512: f32,
}
95
96/// Validate TQ flash attention parameters.
97fn validate_params(params: &FlashAttnVecTqParams) -> Result<()> {
98    if params.head_dim != 256 && params.head_dim != 512 {
99        return Err(MlxError::InvalidArgument(format!(
100            "flash_attn_vec_tq: head_dim must be 256 or 512, got {}",
101            params.head_dim
102        )));
103    }
104    if params.num_heads == 0 || params.num_kv_heads == 0 {
105        return Err(MlxError::InvalidArgument(
106            "flash_attn_vec_tq: num_heads and num_kv_heads must be > 0".into(),
107        ));
108    }
109    if params.num_heads % params.num_kv_heads != 0 {
110        return Err(MlxError::InvalidArgument(format!(
111            "flash_attn_vec_tq: num_heads ({}) must be divisible by num_kv_heads ({})",
112            params.num_heads, params.num_kv_heads
113        )));
114    }
115    if params.kv_seq_len == 0 {
116        return Err(MlxError::InvalidArgument(
117            "flash_attn_vec_tq: kv_seq_len must be > 0".into(),
118        ));
119    }
120    if params.kv_capacity < params.kv_seq_len {
121        return Err(MlxError::InvalidArgument(format!(
122            "flash_attn_vec_tq: kv_capacity ({}) must be >= kv_seq_len ({})",
123            params.kv_capacity, params.kv_seq_len
124        )));
125    }
126    Ok(())
127}
128
/// Compute NWG (threadgroups per head) for TQ SDPA.
///
/// `_kv_seq_len` is currently ignored. NWG=16 was measured as optimal on M5
/// Max at every tested context length, beating both NWG=1 and NWG=32. NWG=32
/// pays more in reduce-kernel overhead than it gains in parallelism, and
/// NWG<16 starves the GPU at long context.
///
/// Override: set `HF2Q_TQ_NWG=N` with N in `1..=32` to force a value (for
/// benchmarking). The first call reads the variable and caches the outcome;
/// values that fail to parse or fall outside `1..=32` are ignored.
fn compute_nwg(_kv_seq_len: u32) -> u32 {
    // ADR-029 iter-175 Step 1az: cached parsed override (same pattern as
    // flash_attn_vec_tq_hb::compute_nwg from Step 1at).  This 4-bit TQ path
    // is NOT in production gemma4 (default codebook is 8-bit), but caching
    // here is consistent + benign + protects against future codebook flip.
    use std::sync::atomic::{AtomicI32, Ordering};

    const DEFAULT_NWG: u32 = 16;
    // Cache states: -1 = env not read yet, 0 = read but no valid override,
    // 1..=32 = the override value.
    static CACHED_TQ_NWG_4BIT: AtomicI32 = AtomicI32::new(-1);

    let cached = match CACHED_TQ_NWG_4BIT.load(Ordering::Relaxed) {
        unread if unread < 0 => {
            let override_nwg = std::env::var("HF2Q_TQ_NWG")
                .ok()
                .and_then(|raw| raw.parse::<u32>().ok())
                .filter(|n| (1..=32).contains(n))
                .map_or(0, |n| n as i32);
            CACHED_TQ_NWG_4BIT.store(override_nwg, Ordering::Relaxed);
            override_nwg
        }
        known => known,
    };

    if cached > 0 {
        cached as u32
    } else {
        DEFAULT_NWG
    }
}
156
157/// Dispatch TQ flash attention vector kernel on the GPU.
158///
159/// Dispatches NWG=32 workgroups per head, then a reduce kernel.
160///
161/// **FWHT is NOT done inside this kernel.** The caller must:
162/// 1. Pre-rotate Q via `dispatch_fwht_f32` before calling this function
163/// 2. Apply inverse FWHT to the output after this function returns
164///
165/// With NWG=32, doing FWHT per-workgroup would repeat it 32× per head.
166/// Keeping FWHT outside means it's done once per head regardless of NWG.
167///
168/// # Arguments
169///
170/// * `q`            — Query buffer `[num_heads, 1, head_dim]`, F32, **pre-rotated via FWHT**.
171/// * `output`       — Output buffer `[num_heads, 1, head_dim]`, F32, **in rotated domain**.
172/// * `tmp`          — Temporary buffer for NWG partial results.
173#[allow(clippy::too_many_arguments)]
174pub fn flash_attn_vec_tq(
175    encoder: &mut CommandEncoder,
176    registry: &mut KernelRegistry,
177    device: &MlxDevice,
178    q: &MlxBuffer,
179    k_packed: &MlxBuffer,
180    k_norms: &MlxBuffer,
181    v_packed: &MlxBuffer,
182    v_norms: &MlxBuffer,
183    output: &MlxBuffer,
184    tmp: &MlxBuffer,
185    params: &FlashAttnVecTqParams,
186) -> Result<()> {
187    validate_params(params)?;
188
189    let head_dim = params.head_dim;
190    let nwg = compute_nwg(params.kv_seq_len);
191
192    // Ensure scale_factor_d512 is always >= 1.0 (0.0 treated as 1.0 for safety).
193    let effective_scale_d512 = if params.scale_factor_d512 < 1e-6 { 1.0_f32 } else { params.scale_factor_d512 };
194    let gpu_params = FlashAttnVecTqParamsGpu {
195        n_heads: params.num_heads,
196        n_kv_heads: params.num_kv_heads,
197        head_dim: params.head_dim,
198        kv_seq_len: params.kv_seq_len,
199        kv_capacity: params.kv_capacity,
200        scale: params.scale,
201        mask_type: params.mask_type,
202        sliding_window: params.sliding_window,
203        softcap: params.softcap,
204        nwg,
205        ring_start: params.ring_start,
206        scale_factor_d512: effective_scale_d512,
207    };
208
209    let kernel_name = match head_dim {
210        256 => "flash_attn_vec_tq_dk256",
211        512 => "flash_attn_vec_tq_dk512",
212        _ => return Err(MlxError::InvalidArgument(format!(
213            "flash_attn_vec_tq: unsupported head_dim {head_dim}"
214        ))),
215    };
216    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
217
218    // Shared memory size — same layout as flash_attn_vec.
219    // PK halfs (Q half4) + SH halfs (scratch) + 2*PV halfs (output float4)
220    let pk = pad2(head_dim as usize, 128);
221    let pv = pad2(head_dim as usize, 128);
222    let sh = 4 * 32; // 4 * C = 128 halfs
223    let shmem_halfs = pk + sh + 2 * pv;
224    let shmem_bytes = shmem_halfs * 2; // 2 bytes per half
225
226    // Tag for the reorder pass: SDPA is NOT reorderable.
227    encoder.set_op_kind(CapturedOpKind::Sdpa);
228
229    // Dispatch main kernel: (1 query, num_heads, NWG workgroups).
230    let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
231    let threadgroup_size = MTLSize::new(32, 1, 1); // 1 simdgroup of 32 threads
232
233    // NWG=1: write directly to output (no reduce needed).
234    // NWG>1: write to tmp, then reduce into output.
235    let dst_buf = if nwg == 1 { output } else { tmp };
236
237    encoder.encode_threadgroups_with_args_and_shared(
238        pipeline,
239        &[
240            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
241            (1, KernelArg::Buffer(q)),
242            (2, KernelArg::Buffer(k_packed)),
243            (3, KernelArg::Buffer(k_norms)),
244            (4, KernelArg::Buffer(v_packed)),
245            (5, KernelArg::Buffer(v_norms)),
246            (6, KernelArg::Buffer(dst_buf)),
247        ],
248        &[(0, shmem_bytes as u64)],
249        threadgroups,
250        threadgroup_size,
251    );
252
253    // --- Reduce kernel (NWG > 1 only) ---
254    if nwg > 1 {
255        encoder.memory_barrier();
256
257        let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };
258
259        let reduce_kernel = match head_dim {
260            256 => "flash_attn_vec_reduce_dk256",
261            512 => "flash_attn_vec_reduce_dk512",
262            _ => unreachable!(),
263        };
264        let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;
265
266        let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
267        let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);
268
269        encoder.encode_threadgroups_with_args(
270            reduce_pipeline,
271            &[
272                (0, KernelArg::Bytes(as_bytes(&reduce_params))),
273                (1, KernelArg::Buffer(tmp)),
274                (2, KernelArg::Buffer(output)),
275                (3, KernelArg::Bytes(as_bytes(&nwg))),
276            ],
277            reduce_tg,
278            reduce_tg_size,
279        );
280    }
281
282    Ok(())
283}
284
/// Size in bytes of the scratch buffer `flash_attn_vec_tq` needs for NWG partials.
///
/// Always sized for the largest NWG `compute_nwg` can return (32), whatever
/// NWG a given dispatch actually uses. The buffer can therefore be allocated
/// once at model load time and reused for every context length.
pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
    const MAX_NWG: usize = 32;

    // One partial per (head, workgroup): `head_dim` values plus two extra f32 slots.
    let floats_per_partial = head_dim as usize + 2;
    let partials = num_heads as usize * MAX_NWG;
    partials * floats_per_partial * std::mem::size_of::<f32>()
}
295
/// Round `x` up to the next multiple of `n`.
///
/// `n` must be a non-zero power of two. The bit-mask trick below is only
/// correct in that case, so debug builds assert it. `x + n - 1` must not
/// overflow `usize`.
fn pad2(x: usize, n: usize) -> usize {
    debug_assert!(n.is_power_of_two(), "pad2: n must be a power of two, got {n}");
    (x + n - 1) & !(n - 1)
}
300
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A known-good parameter set; individual tests override single fields.
    fn base_params() -> FlashAttnVecTqParams {
        FlashAttnVecTqParams {
            num_heads: 8,
            num_kv_heads: 4,
            head_dim: 256,
            kv_seq_len: 64,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
            scale_factor_d512: 1.0,
        }
    }

    #[test]
    fn test_validate_params_ok() {
        assert!(validate_params(&base_params()).is_ok());
        let d512 = FlashAttnVecTqParams { head_dim: 512, ..base_params() };
        assert!(validate_params(&d512).is_ok());
    }

    #[test]
    fn test_validate_params_bad_head_dim() {
        let p = FlashAttnVecTqParams { head_dim: 128, mask_type: 0, ..base_params() };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_bad_heads() {
        for (num_heads, num_kv_heads) in [(0, 4), (8, 0), (8, 3)] {
            let p = FlashAttnVecTqParams { num_heads, num_kv_heads, ..base_params() };
            assert!(
                validate_params(&p).is_err(),
                "heads {num_heads}/{num_kv_heads} should be rejected"
            );
        }
    }

    #[test]
    fn test_validate_params_seq_len_vs_capacity() {
        let empty = FlashAttnVecTqParams { kv_seq_len: 0, ..base_params() };
        assert!(validate_params(&empty).is_err());

        let overflow = FlashAttnVecTqParams { kv_seq_len: 2048, kv_capacity: 1024, ..base_params() };
        assert!(validate_params(&overflow).is_err());

        let full = FlashAttnVecTqParams { kv_seq_len: 1024, kv_capacity: 1024, ..base_params() };
        assert!(validate_params(&full).is_ok());
    }

    #[test]
    fn test_gpu_params_layout() {
        assert_eq!(
            std::mem::size_of::<FlashAttnVecTqParamsGpu>(),
            48, // 12 x u32/f32 = 48 bytes (iter-18: +scale_factor_d512)
        );
        assert_eq!(std::mem::size_of::<FlashAttnVecReduceParamsGpu>(), 4);
    }

    #[test]
    fn test_tmp_buffer_bytes() {
        // 32 workgroups x (head_dim + 2) f32 values per head.
        assert_eq!(tmp_buffer_bytes(8, 256), 8 * 32 * 258 * 4);
        assert_eq!(tmp_buffer_bytes(0, 512), 0);
    }

    #[test]
    fn test_pad2() {
        assert_eq!(pad2(0, 128), 0);
        assert_eq!(pad2(1, 128), 128);
        assert_eq!(pad2(256, 128), 256);
        assert_eq!(pad2(257, 128), 384);
    }
}