1use metal::MTLSize;
16
17use crate::buffer::MlxBuffer;
18use crate::device::MlxDevice;
19use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
20use crate::error::{MlxError, Result};
21use crate::kernel_registry::KernelRegistry;
22
/// Metal source for the `flash_attn_vec_tq_dk256` / `flash_attn_vec_tq_dk512`
/// kernels, embedded into the binary at compile time via `include_str!`.
pub static FLASH_ATTN_VEC_TQ_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_vec_tq.metal");
26
27pub fn register(registry: &mut KernelRegistry) {
29 registry.register_source("flash_attn_vec_tq_dk256", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
30 registry.register_source("flash_attn_vec_tq_dk512", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
31}
32
/// Host-side parameters for [`flash_attn_vec_tq`].
///
/// These are checked by `validate_params` before any GPU work is encoded. They
/// are then copied field-for-field (plus the workgroup count) into the
/// `#[repr(C)]` block that is passed to the kernel.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecTqParams {
    /// Number of query heads. Must be > 0 and a multiple of `num_kv_heads`.
    /// It is also the row count the reduce pass combines.
    pub num_heads: u32,
    /// Number of key/value heads. Must be > 0. Presumably each KV head is shared
    /// by `num_heads / num_kv_heads` query heads (GQA); the grouping happens in
    /// the shader.
    pub num_kv_heads: u32,
    /// Per-head dimension. Only 256 and 512 are supported.
    pub head_dim: u32,
    /// Number of KV positions to attend over. Must be > 0 and <= `kv_capacity`.
    pub kv_seq_len: u32,
    /// Allocated KV positions. Must be >= `kv_seq_len`.
    /// NOTE(review): presumably the row stride of the K/V cache buffers —
    /// confirm in the shader.
    pub kv_capacity: u32,
    /// Attention logit scale, forwarded unchanged to the kernel.
    /// The caller chooses it, typically `1 / sqrt(head_dim)`.
    pub scale: f32,
    /// Mask selector forwarded unchanged to the kernel. Its encoding is defined
    /// by the shader and it is not validated on the host.
    pub mask_type: u32,
    /// Sliding-window length forwarded unchanged to the kernel. How it
    /// interacts with `mask_type` is defined by the shader — TODO confirm.
    pub sliding_window: u32,
    /// Logit soft-cap forwarded unchanged to the kernel. Presumably `0.0`
    /// disables capping — TODO confirm.
    pub softcap: f32,
}
55
/// GPU-side parameter block for the `flash_attn_vec_reduce_dk*` kernels
/// (buffer index 0 of the reduce dispatch).
///
/// It is `#[repr(C)]` + `Pod` so it can be passed as raw bytes via `as_bytes`.
/// The layout presumably mirrors the Metal-side struct; keep the two in sync.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
    /// Number of rows to reduce. `flash_attn_vec_tq` sets this to `num_heads`.
    nrows: u32,
}
62
/// GPU-side parameter block for the `flash_attn_vec_tq_dk*` kernels (buffer index 0).
///
/// It is `#[repr(C)]` + `Pod` so it can be passed as raw bytes via `as_bytes`.
/// It has ten 4-byte fields, 40 bytes in total (pinned by a unit test). The
/// field order is ABI and presumably mirrors the Metal-side struct; keep the two
/// in sync.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecTqParamsGpu {
    /// Copy of `FlashAttnVecTqParams::num_heads`.
    n_heads: u32,
    /// Copy of `FlashAttnVecTqParams::num_kv_heads`.
    n_kv_heads: u32,
    /// Copy of `FlashAttnVecTqParams::head_dim` (256 or 512).
    head_dim: u32,
    /// Copy of `FlashAttnVecTqParams::kv_seq_len`.
    kv_seq_len: u32,
    /// Copy of `FlashAttnVecTqParams::kv_capacity`.
    kv_capacity: u32,
    /// Copy of `FlashAttnVecTqParams::scale`.
    scale: f32,
    /// Copy of `FlashAttnVecTqParams::mask_type`.
    mask_type: u32,
    /// Copy of `FlashAttnVecTqParams::sliding_window`.
    sliding_window: u32,
    /// Copy of `FlashAttnVecTqParams::softcap`.
    softcap: f32,
    /// Workgroups per head. `flash_attn_vec_tq` always sets this to `NWG`, which
    /// is also the dispatch grid's z dimension.
    nwg: u32,
}
78
79fn validate_params(params: &FlashAttnVecTqParams) -> Result<()> {
81 if params.head_dim != 256 && params.head_dim != 512 {
82 return Err(MlxError::InvalidArgument(format!(
83 "flash_attn_vec_tq: head_dim must be 256 or 512, got {}",
84 params.head_dim
85 )));
86 }
87 if params.num_heads == 0 || params.num_kv_heads == 0 {
88 return Err(MlxError::InvalidArgument(
89 "flash_attn_vec_tq: num_heads and num_kv_heads must be > 0".into(),
90 ));
91 }
92 if params.num_heads % params.num_kv_heads != 0 {
93 return Err(MlxError::InvalidArgument(format!(
94 "flash_attn_vec_tq: num_heads ({}) must be divisible by num_kv_heads ({})",
95 params.num_heads, params.num_kv_heads
96 )));
97 }
98 if params.kv_seq_len == 0 {
99 return Err(MlxError::InvalidArgument(
100 "flash_attn_vec_tq: kv_seq_len must be > 0".into(),
101 ));
102 }
103 if params.kv_capacity < params.kv_seq_len {
104 return Err(MlxError::InvalidArgument(format!(
105 "flash_attn_vec_tq: kv_capacity ({}) must be >= kv_seq_len ({})",
106 params.kv_capacity, params.kv_seq_len
107 )));
108 }
109 Ok(())
110}
111
/// Number of workgroups dispatched per head in the main pass (the grid's z
/// dimension). It is also passed to the reduce kernel, sizes the reduce
/// threadgroup (`32 * NWG` threads) and sizes the `tmp` scratch buffer in
/// `tmp_buffer_bytes`.
const NWG: u32 = 32;
115
116#[allow(clippy::too_many_arguments)]
137pub fn flash_attn_vec_tq(
138 encoder: &mut CommandEncoder,
139 registry: &mut KernelRegistry,
140 device: &MlxDevice,
141 q: &MlxBuffer,
142 k_packed: &MlxBuffer,
143 k_norms: &MlxBuffer,
144 v_packed: &MlxBuffer,
145 v_norms: &MlxBuffer,
146 output: &MlxBuffer,
147 tmp: &MlxBuffer,
148 params: &FlashAttnVecTqParams,
149) -> Result<()> {
150 validate_params(params)?;
151
152 let head_dim = params.head_dim;
153 let nwg = NWG;
154
155 let gpu_params = FlashAttnVecTqParamsGpu {
156 n_heads: params.num_heads,
157 n_kv_heads: params.num_kv_heads,
158 head_dim: params.head_dim,
159 kv_seq_len: params.kv_seq_len,
160 kv_capacity: params.kv_capacity,
161 scale: params.scale,
162 mask_type: params.mask_type,
163 sliding_window: params.sliding_window,
164 softcap: params.softcap,
165 nwg,
166 };
167
168 let kernel_name = match head_dim {
169 256 => "flash_attn_vec_tq_dk256",
170 512 => "flash_attn_vec_tq_dk512",
171 _ => return Err(MlxError::InvalidArgument(format!(
172 "flash_attn_vec_tq: unsupported head_dim {head_dim}"
173 ))),
174 };
175 let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
176
177 let pk = pad2(head_dim as usize, 128);
180 let pv = pad2(head_dim as usize, 128);
181 let sh = 4 * 32; let shmem_halfs = pk + sh + 2 * pv;
183 let shmem_bytes = shmem_halfs * 2; encoder.set_op_kind(CapturedOpKind::Sdpa);
187
188 let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
190 let threadgroup_size = MTLSize::new(32, 1, 1); encoder.encode_threadgroups_with_args_and_shared(
193 pipeline,
194 &[
195 (0, KernelArg::Bytes(as_bytes(&gpu_params))),
196 (1, KernelArg::Buffer(q)),
197 (2, KernelArg::Buffer(k_packed)),
198 (3, KernelArg::Buffer(k_norms)),
199 (4, KernelArg::Buffer(v_packed)),
200 (5, KernelArg::Buffer(v_norms)),
201 (6, KernelArg::Buffer(tmp)),
202 ],
203 &[(0, shmem_bytes as u64)],
204 threadgroups,
205 threadgroup_size,
206 );
207
208 encoder.memory_barrier();
211
212 let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };
213
214 let reduce_kernel = match head_dim {
215 256 => "flash_attn_vec_reduce_dk256",
216 512 => "flash_attn_vec_reduce_dk512",
217 _ => unreachable!(),
218 };
219 let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;
220
221 let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
223 let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);
224
225 encoder.encode_threadgroups_with_args(
226 reduce_pipeline,
227 &[
228 (0, KernelArg::Bytes(as_bytes(&reduce_params))),
229 (1, KernelArg::Buffer(tmp)),
230 (2, KernelArg::Buffer(output)),
231 (3, KernelArg::Bytes(as_bytes(&nwg))),
232 ],
233 reduce_tg,
234 reduce_tg_size,
235 );
236
237 Ok(())
238}
239
240pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
244 let nrows = num_heads as usize;
245 let nwg = NWG as usize;
246 let dv = head_dim as usize;
247 (nrows * nwg * (dv + 2)) * std::mem::size_of::<f32>()
248}
249
/// Round `x` up to the next multiple of `n`.
///
/// `n` must be a non-zero power of two. The bit-mask trick silently returns a
/// wrong value for any other `n`, and `n == 0` underflows, so the precondition
/// is asserted in debug builds.
fn pad2(x: usize, n: usize) -> usize {
    debug_assert!(n.is_power_of_two(), "pad2: n must be a power of two, got {n}");
    (x + n - 1) & !(n - 1)
}
254
#[cfg(test)]
// Tests may unwrap/expect/panic freely: a failure should abort the test loudly.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A valid baseline configuration; each test perturbs only what it checks.
    fn base_params() -> FlashAttnVecTqParams {
        FlashAttnVecTqParams {
            num_heads: 8,
            num_kv_heads: 4,
            head_dim: 256,
            kv_seq_len: 64,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
        }
    }

    #[test]
    fn test_validate_params_ok() {
        assert!(validate_params(&base_params()).is_ok());
    }

    #[test]
    fn test_validate_params_ok_dk512() {
        let p = FlashAttnVecTqParams { head_dim: 512, ..base_params() };
        assert!(validate_params(&p).is_ok());
    }

    #[test]
    fn test_validate_params_ok_full_capacity() {
        let p = FlashAttnVecTqParams { kv_seq_len: 1024, kv_capacity: 1024, ..base_params() };
        assert!(validate_params(&p).is_ok());
    }

    #[test]
    fn test_validate_params_bad_head_dim() {
        let p = FlashAttnVecTqParams { head_dim: 128, mask_type: 0, ..base_params() };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_zero_heads() {
        let p = FlashAttnVecTqParams { num_heads: 0, ..base_params() };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_zero_kv_heads() {
        // Must be rejected before the divisibility check (which would divide by zero).
        let p = FlashAttnVecTqParams { num_kv_heads: 0, ..base_params() };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_heads_not_divisible() {
        let p = FlashAttnVecTqParams { num_heads: 6, num_kv_heads: 4, ..base_params() };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_empty_sequence() {
        let p = FlashAttnVecTqParams { kv_seq_len: 0, ..base_params() };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_validate_params_capacity_too_small() {
        let p = FlashAttnVecTqParams { kv_seq_len: 64, kv_capacity: 32, ..base_params() };
        assert!(validate_params(&p).is_err());
    }

    #[test]
    fn test_gpu_params_layout() {
        // Ten 4-byte fields with no padding; must match the shader-side struct.
        assert_eq!(std::mem::size_of::<FlashAttnVecTqParamsGpu>(), 40);
        assert_eq!(std::mem::align_of::<FlashAttnVecTqParamsGpu>(), 4);
    }
}