// mlx_native/ops/flash_attn_vec_tq.rs
//! Flash attention vector kernel dispatch for TurboQuant-compressed KV cache.
//!
//! Fork of `flash_attn_vec` that reads K and V from nibble-packed indices
//! + per-position norms, with inline scalar dequant from a register-resident
//! 16-element codebook. No centroid table buffer needed.
//!
//! Key differences from `flash_attn_vec`:
//! - NWG=32 workgroups per head (same parallelism as the F16 SDPA path),
//!   with partial results combined by the shared `flash_attn_vec_reduce`
//!   kernel after a memory barrier
//! - FWHT rotation of Q and inverse-FWHT of output are FUSED into the kernel
//!   (no standalone FWHT dispatches or barriers needed by the caller)
//! - Dequant is inline: codebook[nibble] * inv_sqrt(head_dim) * norm
//! - Zero scattered memory access — codebook fits in registers
15use metal::MTLSize;
16
17use crate::buffer::MlxBuffer;
18use crate::device::MlxDevice;
19use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
20use crate::error::{MlxError, Result};
21use crate::kernel_registry::KernelRegistry;
22
/// MSL source for the TQ flash attention vector kernel (embedded at compile time).
///
/// A single source file backs both head-dim variants: `register` installs it
/// under the `flash_attn_vec_tq_dk256` and `flash_attn_vec_tq_dk512` names.
pub static FLASH_ATTN_VEC_TQ_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_vec_tq.metal");
26
27/// Register TQ flash attention vector shader source with the given kernel registry.
28pub fn register(registry: &mut KernelRegistry) {
29    registry.register_source("flash_attn_vec_tq_dk256", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
30    registry.register_source("flash_attn_vec_tq_dk512", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
31}
32
/// Parameters for the TQ flash attention vector kernel.
///
/// Host-side view; converted to the `#[repr(C)]` `FlashAttnVecTqParamsGpu`
/// struct at dispatch time. Validated by `validate_params` before encoding.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecTqParams {
    /// Number of query attention heads.
    pub num_heads: u32,
    /// Number of key/value attention heads (GQA: may be < num_heads, but must
    /// divide num_heads evenly).
    pub num_kv_heads: u32,
    /// Dimension of each attention head (256 or 512).
    pub head_dim: u32,
    /// Current KV sequence length (number of valid positions).
    pub kv_seq_len: u32,
    /// KV cache capacity (stride between KV heads in positions). Must be
    /// >= `kv_seq_len`.
    pub kv_capacity: u32,
    /// Attention score scaling factor (e.g. 1/sqrt(head_dim) or 1.0).
    pub scale: f32,
    /// Mask type: 0=none, 1=causal, 2=sliding_window.
    pub mask_type: u32,
    /// Sliding window size (only used when mask_type == 2).
    pub sliding_window: u32,
    /// Logit softcapping (0 = disabled).
    pub softcap: f32,
}
55
/// GPU-side reduce params. Must match `FlashAttnVecReduceParams` in the MSL.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
    // Number of rows to reduce; set to `num_heads` at dispatch time.
    nrows: u32,
}
62
/// GPU-side parameter struct. Must match the MSL `FlashAttnVecTqParams` exactly.
///
/// `#[repr(C)]`: field order and sizes must mirror the Metal-side struct
/// byte-for-byte. 10 four-byte fields = 40 bytes, pinned by
/// `test_gpu_params_layout`. Fields mirror `FlashAttnVecTqParams` plus `nwg`.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecTqParamsGpu {
    n_heads: u32,
    n_kv_heads: u32,
    head_dim: u32,
    kv_seq_len: u32,
    kv_capacity: u32,
    scale: f32,
    mask_type: u32,
    sliding_window: u32,
    softcap: f32,
    // Workgroups per head (always the host-side NWG constant).
    nwg: u32,
}
78
79/// Validate TQ flash attention parameters.
80fn validate_params(params: &FlashAttnVecTqParams) -> Result<()> {
81    if params.head_dim != 256 && params.head_dim != 512 {
82        return Err(MlxError::InvalidArgument(format!(
83            "flash_attn_vec_tq: head_dim must be 256 or 512, got {}",
84            params.head_dim
85        )));
86    }
87    if params.num_heads == 0 || params.num_kv_heads == 0 {
88        return Err(MlxError::InvalidArgument(
89            "flash_attn_vec_tq: num_heads and num_kv_heads must be > 0".into(),
90        ));
91    }
92    if params.num_heads % params.num_kv_heads != 0 {
93        return Err(MlxError::InvalidArgument(format!(
94            "flash_attn_vec_tq: num_heads ({}) must be divisible by num_kv_heads ({})",
95            params.num_heads, params.num_kv_heads
96        )));
97    }
98    if params.kv_seq_len == 0 {
99        return Err(MlxError::InvalidArgument(
100            "flash_attn_vec_tq: kv_seq_len must be > 0".into(),
101        ));
102    }
103    if params.kv_capacity < params.kv_seq_len {
104        return Err(MlxError::InvalidArgument(format!(
105            "flash_attn_vec_tq: kv_capacity ({}) must be >= kv_seq_len ({})",
106            params.kv_capacity, params.kv_seq_len
107        )));
108    }
109    Ok(())
110}
111
/// Number of workgroups per head for TQ SDPA. Must match the F16 SDPA's NWG
/// for similar parallelism. Each workgroup processes kv_seq_len/NWG chunks.
///
/// Also fixes the reduce kernel's threadgroup width (32 * NWG threads) and
/// the partials-buffer sizing in `tmp_buffer_bytes`.
const NWG: u32 = 32;
115
/// Dispatch TQ flash attention vector kernel on the GPU.
///
/// Dispatches NWG=32 workgroups per head (same as F16 SDPA), then a reduce
/// kernel to combine partial results. FWHT is applied per-workgroup before
/// writing partials; since FWHT is linear, the reduce output is already in
/// the original (un-rotated) domain.
///
/// # Arguments
///
/// * `encoder`      — Command encoder to record dispatches into.
/// * `registry`     — Kernel registry for pipeline lookup/compilation.
/// * `device`       — Metal device.
/// * `q`            — Query buffer `[num_heads, 1, head_dim]`, F32.
/// * `k_packed`     — Nibble-packed K indices `[num_kv_heads, kv_capacity, head_dim/2]`, U8.
/// * `k_norms`      — Per-position K norms `[num_kv_heads, kv_capacity]`, F32.
/// * `v_packed`     — Nibble-packed V indices `[num_kv_heads, kv_capacity, head_dim/2]`, U8.
/// * `v_norms`      — Per-position V norms `[num_kv_heads, kv_capacity]`, F32.
/// * `output`       — Output buffer `[num_heads, 1, head_dim]`, F32, pre-allocated.
/// * `tmp`          — Temporary buffer for NWG partial results. Size from `tmp_buffer_bytes()`.
/// * `params`       — TQ flash attention parameters.
///
/// # Errors
///
/// Returns `InvalidArgument` when `params` fails validation, or whatever
/// error `KernelRegistry::get_pipeline` surfaces for pipeline compilation.
#[allow(clippy::too_many_arguments)]
pub fn flash_attn_vec_tq(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    q: &MlxBuffer,
    k_packed: &MlxBuffer,
    k_norms: &MlxBuffer,
    v_packed: &MlxBuffer,
    v_norms: &MlxBuffer,
    output: &MlxBuffer,
    tmp: &MlxBuffer,
    params: &FlashAttnVecTqParams,
) -> Result<()> {
    validate_params(params)?;

    let head_dim = params.head_dim;
    let nwg = NWG;

    // Mirror the host params into the #[repr(C)] GPU struct, plus `nwg`.
    let gpu_params = FlashAttnVecTqParamsGpu {
        n_heads: params.num_heads,
        n_kv_heads: params.num_kv_heads,
        head_dim: params.head_dim,
        kv_seq_len: params.kv_seq_len,
        kv_capacity: params.kv_capacity,
        scale: params.scale,
        mask_type: params.mask_type,
        sliding_window: params.sliding_window,
        softcap: params.softcap,
        nwg,
    };

    // The `_` arm is unreachable after validate_params() (which already
    // rejects head_dim != 256/512); kept as a defensive error path.
    let kernel_name = match head_dim {
        256 => "flash_attn_vec_tq_dk256",
        512 => "flash_attn_vec_tq_dk512",
        _ => return Err(MlxError::InvalidArgument(format!(
            "flash_attn_vec_tq: unsupported head_dim {head_dim}"
        ))),
    };
    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;

    // Shared memory size — same layout as flash_attn_vec.
    // PK halfs (Q half4) + SH halfs (scratch) + 2*PV halfs (output float4)
    let pk = pad2(head_dim as usize, 128);
    let pv = pad2(head_dim as usize, 128);
    let sh = 4 * 32; // 4 * C = 128 halfs
    let shmem_halfs = pk + sh + 2 * pv;
    let shmem_bytes = shmem_halfs * 2; // 2 bytes per half

    // Tag for the reorder pass: SDPA is NOT reorderable.
    encoder.set_op_kind(CapturedOpKind::Sdpa);

    // Dispatch main kernel: (1 query, num_heads, NWG workgroups).
    let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
    let threadgroup_size = MTLSize::new(32, 1, 1); // 1 simdgroup of 32 threads

    // Buffer binding indices must line up with the MSL kernel's [[buffer(n)]]
    // signature: params, q, k_packed, k_norms, v_packed, v_norms, tmp.
    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(q)),
            (2, KernelArg::Buffer(k_packed)),
            (3, KernelArg::Buffer(k_norms)),
            (4, KernelArg::Buffer(v_packed)),
            (5, KernelArg::Buffer(v_norms)),
            (6, KernelArg::Buffer(tmp)),
        ],
        &[(0, shmem_bytes as u64)],
        threadgroups,
        threadgroup_size,
    );

    // --- Reduce kernel ---
    // Barrier: reduce reads `tmp` written by the main dispatch above.
    encoder.memory_barrier();

    let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };

    // SAFETY of unreachable!(): same head_dim already matched successfully above.
    let reduce_kernel = match head_dim {
        256 => "flash_attn_vec_reduce_dk256",
        512 => "flash_attn_vec_reduce_dk512",
        _ => unreachable!(),
    };
    let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;

    // Grid: (num_heads, 1, 1), Threadgroup: (32*NWG, 1, 1)
    // NOTE(review): 32 * NWG = 1024 threads, the common Metal per-threadgroup
    // limit — confirm the reduce pipeline's maxTotalThreadsPerThreadgroup
    // allows this on all target GPUs.
    let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
    let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);

    encoder.encode_threadgroups_with_args(
        reduce_pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&reduce_params))),
            (1, KernelArg::Buffer(tmp)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&nwg))),
        ],
        reduce_tg,
        reduce_tg_size,
    );

    Ok(())
}
239
240/// Compute the size in bytes of the temporary buffer needed for TQ SDPA.
241///
242/// Same formula as F16 SDPA: stores NWG partial output vectors + S/M values.
243pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
244    let nrows = num_heads as usize;
245    let nwg = NWG as usize;
246    let dv = head_dim as usize;
247    (nrows * nwg * (dv + 2)) * std::mem::size_of::<f32>()
248}
249
/// Pad `x` up to the next multiple of `n`.
///
/// `n` must be a nonzero power of two: the branch-free round-up below relies
/// on `n - 1` being an all-ones low-bit mask. Checked in debug builds.
fn pad2(x: usize, n: usize) -> usize {
    debug_assert!(
        n.is_power_of_two(),
        "pad2: n must be a nonzero power of two, got {n}"
    );
    (x + n - 1) & !(n - 1)
}
254
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Baseline valid parameter set shared by the validation tests.
    fn base_params() -> FlashAttnVecTqParams {
        FlashAttnVecTqParams {
            num_heads: 8,
            num_kv_heads: 4,
            head_dim: 256,
            kv_seq_len: 64,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
        }
    }

    #[test]
    fn test_validate_params_ok() {
        assert!(validate_params(&base_params()).is_ok());
    }

    #[test]
    fn test_validate_params_bad_head_dim() {
        // head_dim 128 is outside the supported {256, 512} set.
        let p = FlashAttnVecTqParams {
            head_dim: 128,
            mask_type: 0,
            ..base_params()
        };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_gpu_params_layout() {
        // 10 x u32/f32 fields, #[repr(C)], no padding => 40 bytes.
        assert_eq!(std::mem::size_of::<FlashAttnVecTqParamsGpu>(), 40);
    }
}
299}