// mlx_native/ops/flash_attn_vec.rs

//! Flash attention vector kernel dispatch — SIMD-vectorized decode-path SDPA.
//!
//! Ported from llama.cpp's `flash_attn_ext_vec` kernel. This replaces the naive
//! SDPA kernel with a workgroup-parallel implementation that splits the KV cache
//! across `nwg` workgroups, each computing partial softmax results, then a
//! reduce kernel combines them.
//!
//! This kernel is optimized for the decode path (seq_len=1). Q is F32; K/V may
//! be F32 or F16 (Phase 4a F16-KV variants).

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::DType;

/// MSL source for the flash attention vector kernel (embedded at compile time).
///
/// A single source file contains every variant registered in [`register`];
/// the registry compiles the requested entry point on demand.
pub static FLASH_ATTN_VEC_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_vec.metal");

23/// Register flash attention vector shader source with the given kernel registry.
24pub fn register(registry: &mut KernelRegistry) {
25    registry.register_source("flash_attn_vec_dk256", FLASH_ATTN_VEC_SHADER_SOURCE);
26    registry.register_source("flash_attn_vec_dk512", FLASH_ATTN_VEC_SHADER_SOURCE);
27    registry.register_source("flash_attn_vec_reduce_dk256", FLASH_ATTN_VEC_SHADER_SOURCE);
28    registry.register_source("flash_attn_vec_reduce_dk512", FLASH_ATTN_VEC_SHADER_SOURCE);
29    // F16 KV variants (Phase 4a)
30    registry.register_source("flash_attn_vec_f16kv_dk256", FLASH_ATTN_VEC_SHADER_SOURCE);
31    registry.register_source("flash_attn_vec_f16kv_dk512", FLASH_ATTN_VEC_SHADER_SOURCE);
32}
33
/// Parameters for the flash attention vector kernel.
///
/// Host-side, validated parameters; [`flash_attn_vec`] converts these into the
/// `#[repr(C)]` GPU struct that is passed to the shader as inline bytes.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecParams {
    /// Number of query attention heads.
    pub num_heads: u32,
    /// Number of key/value attention heads (GQA: may be < num_heads).
    /// Must evenly divide `num_heads`.
    pub num_kv_heads: u32,
    /// Dimension of each attention head (256 or 512 — the only supported sizes).
    pub head_dim: u32,
    /// Current KV sequence length (number of valid positions).
    pub kv_seq_len: u32,
    /// KV cache capacity (stride between KV heads in positions).
    /// Must be >= `kv_seq_len`.
    pub kv_capacity: u32,
    /// Attention score scaling factor (e.g. 1/sqrt(head_dim) or 1.0).
    pub scale: f32,
    /// Mask type: 0=none, 1=causal, 2=sliding_window.
    pub mask_type: u32,
    /// Sliding window size (only used when mask_type == 2).
    pub sliding_window: u32,
    /// Logit softcapping (0 = disabled).
    pub softcap: f32,
}

/// GPU-side parameter struct. Must match the MSL `FlashAttnVecParams` exactly:
/// same field order and types. Every field is 4 bytes, so the `#[repr(C)]`
/// layout is 40 contiguous bytes with no padding (pinned by
/// `test_gpu_params_layout`). Do not reorder or retype fields without updating
/// the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecParamsGpu {
    n_heads: u32,
    n_kv_heads: u32,
    head_dim: u32,
    kv_seq_len: u32,
    kv_capacity: u32,
    scale: f32,
    mask_type: u32,
    sliding_window: u32,
    softcap: f32,
    // Number of workgroups the KV cache is split across (see `NWG`);
    // GPU-only — not part of the public `FlashAttnVecParams`.
    nwg: u32,
}

/// GPU-side reduce params. Must match MSL `FlashAttnVecReduceParams` exactly
/// (single u32, passed to the reduce kernel as inline bytes).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
    // Number of output rows to reduce — one per query head.
    nrows: u32,
}

/// Number of workgroups to split the KV cache across.
/// llama.cpp uses 32 as default. Must be <= 32 (one per SIMD lane in reduce).
/// Also baked into [`tmp_buffer_bytes`] — the tmp buffer holds one partial
/// result per workgroup per head.
const NWG: u32 = 32;

84/// Validate flash attention parameters.
85fn validate_params(params: &FlashAttnVecParams) -> Result<()> {
86    if params.head_dim != 256 && params.head_dim != 512 {
87        return Err(MlxError::InvalidArgument(format!(
88            "flash_attn_vec: head_dim must be 256 or 512, got {}",
89            params.head_dim
90        )));
91    }
92    if params.num_heads == 0 || params.num_kv_heads == 0 {
93        return Err(MlxError::InvalidArgument(
94            "flash_attn_vec: num_heads and num_kv_heads must be > 0".into(),
95        ));
96    }
97    if params.num_heads % params.num_kv_heads != 0 {
98        return Err(MlxError::InvalidArgument(format!(
99            "flash_attn_vec: num_heads ({}) must be divisible by num_kv_heads ({})",
100            params.num_heads, params.num_kv_heads
101        )));
102    }
103    if params.kv_seq_len == 0 {
104        return Err(MlxError::InvalidArgument(
105            "flash_attn_vec: kv_seq_len must be > 0".into(),
106        ));
107    }
108    if params.kv_capacity < params.kv_seq_len {
109        return Err(MlxError::InvalidArgument(format!(
110            "flash_attn_vec: kv_capacity ({}) must be >= kv_seq_len ({})",
111            params.kv_capacity, params.kv_seq_len
112        )));
113    }
114    Ok(())
115}
116
117/// Dispatch flash attention vector kernel on the GPU.
118///
119/// This dispatches two Metal compute passes:
120/// 1. The main kernel with NWG workgroups per head computing partial results
121/// 2. The reduce kernel combining results from all workgroups
122///
123/// # Arguments
124///
125/// * `encoder`  — Command encoder to record dispatches into.
126/// * `registry` — Kernel registry for pipeline lookup/compilation.
127/// * `device`   — Metal device for buffer allocation.
128/// * `q`        — Query buffer `[num_heads, 1, head_dim]`, F32.
129/// * `k`        — Key buffer `[num_kv_heads, kv_capacity, head_dim]`, F32.
130/// * `v`        — Value buffer `[num_kv_heads, kv_capacity, head_dim]`, F32.
131/// * `output`   — Output buffer `[num_heads, 1, head_dim]`, F32, pre-allocated.
132/// * `tmp`      — Temporary buffer for workgroup partial results, pre-allocated.
133///                Size: `num_heads * nwg * (head_dim + 2) * sizeof(f32)` bytes.
134/// * `params`   — Flash attention parameters.
135pub fn flash_attn_vec(
136    encoder: &mut CommandEncoder,
137    registry: &mut KernelRegistry,
138    device: &MlxDevice,
139    q: &MlxBuffer,
140    k: &MlxBuffer,
141    v: &MlxBuffer,
142    output: &MlxBuffer,
143    tmp: &MlxBuffer,
144    params: &FlashAttnVecParams,
145) -> Result<()> {
146    validate_params(params)?;
147
148    let head_dim = params.head_dim;
149    let nwg = NWG;
150
151    // Build GPU params.
152    let gpu_params = FlashAttnVecParamsGpu {
153        n_heads: params.num_heads,
154        n_kv_heads: params.num_kv_heads,
155        head_dim: params.head_dim,
156        kv_seq_len: params.kv_seq_len,
157        kv_capacity: params.kv_capacity,
158        scale: params.scale,
159        mask_type: params.mask_type,
160        sliding_window: params.sliding_window,
161        softcap: params.softcap,
162        nwg,
163    };
164    // Select kernel by head dimension and KV dtype.
165    // F16 KV: K/V buffers are half-precision, halving bandwidth (Phase 4a).
166    let kv_is_f16 = k.dtype() == DType::F16;
167    let kernel_name = match (head_dim, kv_is_f16) {
168        (256, false) => "flash_attn_vec_dk256",
169        (512, false) => "flash_attn_vec_dk512",
170        (256, true)  => "flash_attn_vec_f16kv_dk256",
171        (512, true)  => "flash_attn_vec_f16kv_dk512",
172        _ => unreachable!(), // validated above
173    };
174    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
175
176    // Shared memory size.
177    // Layout: PK halfs (Q) + SH halfs (scratch) + 2*PV halfs (output as float4)
178    // PK = PAD2(head_dim, 128), PV = PAD2(head_dim, 128)
179    // SH = 4 * C = 128 halfs
180    let pk = pad2(head_dim as usize, 128);
181    let pv = pad2(head_dim as usize, 128);
182    let sh = 4 * 32; // 4 * C = 128 halfs
183    let shmem_halfs = pk + sh + 2 * pv;
184    let shmem_bytes = shmem_halfs * 2; // 2 bytes per half
185
186    // Tag for the reorder pass (Phase 4e.3): SDPA is NOT reorderable.
187    encoder.set_op_kind(CapturedOpKind::Sdpa);
188
189    // Dispatch main kernel.
190    // Grid: (1 query, num_heads, nwg)
191    let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
192    let threadgroup_size = MTLSize::new(32, 1, 1); // 1 simdgroup of 32 threads
193
194    // Pass params as inline bytes — no Metal buffer allocation.
195    // This eliminates 1 [MTLDevice newBufferWithLength:] per SDPA call (30/token).
196    encoder.encode_threadgroups_with_args_and_shared(
197        pipeline,
198        &[
199            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
200            (1, KernelArg::Buffer(q)),
201            (2, KernelArg::Buffer(k)),
202            (3, KernelArg::Buffer(v)),
203            (4, KernelArg::Buffer(tmp)),
204        ],
205        &[(0, shmem_bytes as u64)],
206        threadgroups,
207        threadgroup_size,
208    );
209
210    // --- Reduce kernel ---
211    // Only needed when NWG > 1.
212    // Barrier: reduce reads `tmp` written by the main dispatch above.
213    // With MTLDispatchTypeConcurrent, both dispatches would run simultaneously
214    // without this barrier, causing the reduce to read stale/partial `tmp` data.
215    if nwg > 1 {
216        encoder.memory_barrier();
217        let reduce_params = FlashAttnVecReduceParamsGpu {
218            nrows: params.num_heads,
219        };
220
221        let reduce_kernel = match head_dim {
222            256 => "flash_attn_vec_reduce_dk256",
223            512 => "flash_attn_vec_reduce_dk512",
224            _ => unreachable!(),
225        };
226        let reduce_pipeline =
227            registry.get_pipeline(reduce_kernel, device.metal_device())?;
228
229        // Grid: (num_heads, 1, 1), Threadgroup: (32*NWG, 1, 1)
230        let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
231        let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);
232
233        // Annotate reads/writes for the capture graph so the reorder pass
234        // (Phase 4e.3) can track data dependencies on the reduce kernel.
235        {
236            let read_ranges = vec![
237                {
238                    let s = tmp.contents_ptr() as usize;
239                    (s, s + tmp.byte_len())
240                },
241            ];
242            let write_ranges = vec![
243                {
244                    let s = output.contents_ptr() as usize;
245                    (s, s + output.byte_len())
246                },
247            ];
248            encoder.set_pending_buffer_ranges(read_ranges, write_ranges);
249        }
250
251        // Pass params as inline bytes — no Metal buffer allocation.
252        // This eliminates 2 [MTLDevice newBufferWithLength:] per SDPA call (60/token).
253        encoder.encode_threadgroups_with_args(
254            reduce_pipeline,
255            &[
256                (0, KernelArg::Bytes(as_bytes(&reduce_params))),
257                (1, KernelArg::Buffer(tmp)),
258                (2, KernelArg::Buffer(output)),
259                (3, KernelArg::Bytes(as_bytes(&nwg))),
260            ],
261            reduce_tg,
262            reduce_tg_size,
263        );
264    }
265
266    Ok(())
267}
268
269/// Compute the size in bytes of the temporary buffer needed for flash_attn_vec.
270///
271/// The temp buffer stores partial results from NWG workgroups:
272/// - `nrows * head_dim * NWG` floats for the partial output vectors
273/// - `nrows * 2 * NWG` floats for the S and M values
274pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
275    let nrows = num_heads as usize;
276    let nwg = NWG as usize;
277    let dv = head_dim as usize;
278    // DV * NWG floats per row for output, plus 2 * NWG floats per row for S/M.
279    (nrows * nwg * (dv + 2)) * std::mem::size_of::<f32>()
280}
281
/// Round `x` up to the next multiple of `n` (`n` must be a power of 2).
fn pad2(x: usize, n: usize) -> usize {
    // `n` is a power of two, so `n - 1` masks the low bits (the remainder).
    let rem = x & (n - 1);
    if rem == 0 {
        x
    } else {
        x + (n - rem)
    }
}

#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Baseline valid parameter set shared by the validation tests.
    fn base_params() -> FlashAttnVecParams {
        FlashAttnVecParams {
            num_heads: 16,
            num_kv_heads: 8,
            head_dim: 256,
            kv_seq_len: 100,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
        }
    }

    #[test]
    fn test_validate_params_ok() {
        assert!(validate_params(&base_params()).is_ok());
    }

    #[test]
    fn test_validate_params_bad_head_dim() {
        // head_dim 128 is unsupported and must be rejected.
        let p = FlashAttnVecParams {
            head_dim: 128,
            mask_type: 0,
            ..base_params()
        };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_gpu_params_layout() {
        // 10 x u32/f32 fields, #[repr(C)], no padding = 40 bytes.
        let expected = 10 * std::mem::size_of::<u32>();
        assert_eq!(std::mem::size_of::<FlashAttnVecParamsGpu>(), expected);
    }

    #[test]
    fn test_tmp_buffer_size() {
        // 16 heads, dk256, nwg=32: 16 * 32 * (256 + 2) * 4 bytes.
        assert_eq!(tmp_buffer_bytes(16, 256), 16 * 32 * 258 * 4);
    }
}
339}