1use metal::MTLSize;
11
12use crate::buffer::MlxBuffer;
13use crate::device::MlxDevice;
14use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
15use crate::error::{MlxError, Result};
16use crate::kernel_registry::KernelRegistry;
17use crate::DType;
18
/// Metal source text for the flash-attention "vec" kernels, embedded at compile time from
/// `../shaders/flash_attn_vec.metal` (the path is resolved relative to this source file).
///
/// [`register`] associates this same source with every kernel entry-point name used by
/// this module.
pub static FLASH_ATTN_VEC_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_vec.metal");
22
23pub fn register(registry: &mut KernelRegistry) {
25 registry.register_source("flash_attn_vec_dk256", FLASH_ATTN_VEC_SHADER_SOURCE);
26 registry.register_source("flash_attn_vec_dk512", FLASH_ATTN_VEC_SHADER_SOURCE);
27 registry.register_source("flash_attn_vec_reduce_dk256", FLASH_ATTN_VEC_SHADER_SOURCE);
28 registry.register_source("flash_attn_vec_reduce_dk512", FLASH_ATTN_VEC_SHADER_SOURCE);
29 registry.register_source("flash_attn_vec_f16kv_dk256", FLASH_ATTN_VEC_SHADER_SOURCE);
31 registry.register_source("flash_attn_vec_f16kv_dk512", FLASH_ATTN_VEC_SHADER_SOURCE);
32}
33
/// Caller-facing configuration for [`flash_attn_vec`].
///
/// The shape-related fields are checked by `validate_params` before any GPU work is
/// encoded. Every field is then copied unchanged into the parameter block handed to the
/// main kernel.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecParams {
    /// Number of attention heads (presumably query heads). Must be non-zero and a multiple
    /// of `num_kv_heads`. Also used as the row count of the reduce pass.
    pub num_heads: u32,
    /// Number of key/value heads. Must be non-zero.
    pub num_kv_heads: u32,
    /// Per-head dimension. Only 256 and 512 are accepted.
    pub head_dim: u32,
    /// Number of KV positions to attend over. Must be non-zero.
    pub kv_seq_len: u32,
    /// Capacity of the KV storage, in positions. Must be at least `kv_seq_len`.
    pub kv_capacity: u32,
    /// Scale factor forwarded to the kernel (presumably applied to the attention scores;
    /// confirm against the shader).
    pub scale: f32,
    /// Mask selector forwarded to the kernel. The meaning of each value is defined by the
    /// shader and is not validated here.
    pub mask_type: u32,
    /// Sliding-window size forwarded to the kernel; interpretation is defined by the shader.
    pub sliding_window: u32,
    /// Soft-cap value forwarded to the kernel; interpretation is defined by the shader.
    pub softcap: f32,
}
56
/// Parameter block sent as raw bytes to argument slot 0 of the main flash-attention kernel.
///
/// `#[repr(C)]` with ten 4-byte fields, 40 bytes in total (pinned by a unit test in this
/// file). The field order presumably has to mirror the parameter struct declared in the
/// Metal shader, so do not reorder or resize fields without checking the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecParamsGpu {
    // The first nine fields are copied one-to-one from `FlashAttnVecParams`.
    n_heads: u32,
    n_kv_heads: u32,
    head_dim: u32,
    kv_seq_len: u32,
    kv_capacity: u32,
    scale: f32,
    mask_type: u32,
    sliding_window: u32,
    softcap: f32,
    // Filled from the module constant `NWG`, not from caller input.
    nwg: u32,
}
72
/// Parameter block sent as raw bytes to argument slot 0 of the reduce kernel.
///
/// `#[repr(C)]` with a single `u32`. The layout presumably mirrors a struct declared in the
/// Metal shader, so confirm against the shader before changing it.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
    // `flash_attn_vec` sets this to `num_heads`.
    nrows: u32,
}
79
/// Workgroup count used throughout this module (presumably the number of workgroups the
/// KV range is split across per head; confirm against the shader).
///
/// It is used as the Z dimension of the main dispatch grid, written into the `nwg` field
/// of the GPU parameter block, multiplied by 32 to size the reduce threadgroup, and used
/// by `tmp_buffer_bytes` to size the scratch buffer.
const NWG: u32 = 32;
83
84fn validate_params(params: &FlashAttnVecParams) -> Result<()> {
86 if params.head_dim != 256 && params.head_dim != 512 {
87 return Err(MlxError::InvalidArgument(format!(
88 "flash_attn_vec: head_dim must be 256 or 512, got {}",
89 params.head_dim
90 )));
91 }
92 if params.num_heads == 0 || params.num_kv_heads == 0 {
93 return Err(MlxError::InvalidArgument(
94 "flash_attn_vec: num_heads and num_kv_heads must be > 0".into(),
95 ));
96 }
97 if params.num_heads % params.num_kv_heads != 0 {
98 return Err(MlxError::InvalidArgument(format!(
99 "flash_attn_vec: num_heads ({}) must be divisible by num_kv_heads ({})",
100 params.num_heads, params.num_kv_heads
101 )));
102 }
103 if params.kv_seq_len == 0 {
104 return Err(MlxError::InvalidArgument(
105 "flash_attn_vec: kv_seq_len must be > 0".into(),
106 ));
107 }
108 if params.kv_capacity < params.kv_seq_len {
109 return Err(MlxError::InvalidArgument(format!(
110 "flash_attn_vec: kv_capacity ({}) must be >= kv_seq_len ({})",
111 params.kv_capacity, params.kv_seq_len
112 )));
113 }
114 Ok(())
115}
116
117pub fn flash_attn_vec(
136 encoder: &mut CommandEncoder,
137 registry: &mut KernelRegistry,
138 device: &MlxDevice,
139 q: &MlxBuffer,
140 k: &MlxBuffer,
141 v: &MlxBuffer,
142 output: &MlxBuffer,
143 tmp: &MlxBuffer,
144 params: &FlashAttnVecParams,
145) -> Result<()> {
146 validate_params(params)?;
147
148 let head_dim = params.head_dim;
149 let nwg = NWG;
150
151 let gpu_params = FlashAttnVecParamsGpu {
153 n_heads: params.num_heads,
154 n_kv_heads: params.num_kv_heads,
155 head_dim: params.head_dim,
156 kv_seq_len: params.kv_seq_len,
157 kv_capacity: params.kv_capacity,
158 scale: params.scale,
159 mask_type: params.mask_type,
160 sliding_window: params.sliding_window,
161 softcap: params.softcap,
162 nwg,
163 };
164 let kv_is_f16 = k.dtype() == DType::F16;
167 let kernel_name = match (head_dim, kv_is_f16) {
168 (256, false) => "flash_attn_vec_dk256",
169 (512, false) => "flash_attn_vec_dk512",
170 (256, true) => "flash_attn_vec_f16kv_dk256",
171 (512, true) => "flash_attn_vec_f16kv_dk512",
172 _ => unreachable!(), };
174 let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
175
176 let pk = pad2(head_dim as usize, 128);
181 let pv = pad2(head_dim as usize, 128);
182 let sh = 4 * 32; let shmem_halfs = pk + sh + 2 * pv;
184 let shmem_bytes = shmem_halfs * 2; encoder.set_op_kind(CapturedOpKind::Sdpa);
188
189 let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
192 let threadgroup_size = MTLSize::new(32, 1, 1); encoder.encode_threadgroups_with_args_and_shared(
197 pipeline,
198 &[
199 (0, KernelArg::Bytes(as_bytes(&gpu_params))),
200 (1, KernelArg::Buffer(q)),
201 (2, KernelArg::Buffer(k)),
202 (3, KernelArg::Buffer(v)),
203 (4, KernelArg::Buffer(tmp)),
204 ],
205 &[(0, shmem_bytes as u64)],
206 threadgroups,
207 threadgroup_size,
208 );
209
210 if nwg > 1 {
216 encoder.memory_barrier();
217 let reduce_params = FlashAttnVecReduceParamsGpu {
218 nrows: params.num_heads,
219 };
220
221 let reduce_kernel = match head_dim {
222 256 => "flash_attn_vec_reduce_dk256",
223 512 => "flash_attn_vec_reduce_dk512",
224 _ => unreachable!(),
225 };
226 let reduce_pipeline =
227 registry.get_pipeline(reduce_kernel, device.metal_device())?;
228
229 let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
231 let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);
232
233 {
236 let read_ranges = vec![
237 {
238 let s = tmp.contents_ptr() as usize;
239 (s, s + tmp.byte_len())
240 },
241 ];
242 let write_ranges = vec![
243 {
244 let s = output.contents_ptr() as usize;
245 (s, s + output.byte_len())
246 },
247 ];
248 encoder.set_pending_buffer_ranges(read_ranges, write_ranges);
249 }
250
251 encoder.encode_threadgroups_with_args(
254 reduce_pipeline,
255 &[
256 (0, KernelArg::Bytes(as_bytes(&reduce_params))),
257 (1, KernelArg::Buffer(tmp)),
258 (2, KernelArg::Buffer(output)),
259 (3, KernelArg::Bytes(as_bytes(&nwg))),
260 ],
261 reduce_tg,
262 reduce_tg_size,
263 );
264 }
265
266 Ok(())
267}
268
269pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
275 let nrows = num_heads as usize;
276 let nwg = NWG as usize;
277 let dv = head_dim as usize;
278 (nrows * nwg * (dv + 2)) * std::mem::size_of::<f32>()
280}
281
/// Rounds `x` up to the next multiple of `n` using a bit mask.
///
/// The mask trick only yields a true multiple when `n` is a power of two; every call in
/// this file passes 128.
fn pad2(x: usize, n: usize) -> usize {
    let bumped = x + n - 1;
    let low_bits = n - 1;
    bumped & !low_bits
}
286
// Unit tests for the host-side helpers. Nothing here touches the GPU.
#[cfg(test)]
// Test code is allowed to unwrap and panic; a panic is how a test reports failure.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A parameter set that passes every check in `validate_params`.
    fn valid_params() -> FlashAttnVecParams {
        FlashAttnVecParams {
            num_heads: 16,
            num_kv_heads: 8,
            head_dim: 256,
            kv_seq_len: 100,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
        }
    }

    #[test]
    fn test_validate_params_ok() {
        assert!(validate_params(&valid_params()).is_ok());
    }

    #[test]
    fn test_validate_params_accepts_head_dim_512() {
        let p = FlashAttnVecParams {
            head_dim: 512,
            ..valid_params()
        };
        assert!(validate_params(&p).is_ok());
    }

    #[test]
    fn test_validate_params_bad_head_dim() {
        let p = FlashAttnVecParams {
            head_dim: 128,
            mask_type: 0,
            ..valid_params()
        };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_zero_heads() {
        let no_heads = FlashAttnVecParams {
            num_heads: 0,
            ..valid_params()
        };
        assert!(validate_params(&no_heads).is_err());

        // Must be reported as an error rather than panicking on `%` by zero.
        let no_kv_heads = FlashAttnVecParams {
            num_kv_heads: 0,
            ..valid_params()
        };
        assert!(validate_params(&no_kv_heads).is_err());
    }

    #[test]
    fn test_validate_params_heads_not_divisible() {
        let p = FlashAttnVecParams {
            num_heads: 16,
            num_kv_heads: 6,
            ..valid_params()
        };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_zero_kv_seq_len() {
        let p = FlashAttnVecParams {
            kv_seq_len: 0,
            ..valid_params()
        };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_capacity_bounds() {
        let too_small = FlashAttnVecParams {
            kv_seq_len: 100,
            kv_capacity: 99,
            ..valid_params()
        };
        assert!(validate_params(&too_small).is_err());

        // A completely full cache is allowed.
        let exactly_full = FlashAttnVecParams {
            kv_seq_len: 100,
            kv_capacity: 100,
            ..valid_params()
        };
        assert!(validate_params(&exactly_full).is_ok());
    }

    #[test]
    fn test_gpu_params_layout() {
        // Ten 4-byte fields.
        assert_eq!(std::mem::size_of::<FlashAttnVecParamsGpu>(), 40);
        // A single u32.
        assert_eq!(std::mem::size_of::<FlashAttnVecReduceParamsGpu>(), 4);
    }

    #[test]
    fn test_tmp_buffer_size() {
        // heads * NWG * (head_dim + 2) * size_of::<f32>()
        assert_eq!(tmp_buffer_bytes(16, 256), 16 * 32 * 258 * 4);
        assert_eq!(tmp_buffer_bytes(8, 512), 8 * 32 * 514 * 4);
    }

    #[test]
    fn test_pad2_rounds_up_to_multiple() {
        assert_eq!(pad2(0, 128), 0);
        assert_eq!(pad2(1, 128), 128);
        assert_eq!(pad2(128, 128), 128);
        assert_eq!(pad2(129, 128), 256);
        assert_eq!(pad2(256, 128), 256);
        assert_eq!(pad2(512, 128), 512);
    }
}