1use metal::MTLSize;
15
16use crate::buffer::MlxBuffer;
17use crate::device::MlxDevice;
18use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
19use crate::error::{MlxError, Result};
20use crate::kernel_registry::KernelRegistry;
21
/// Metal Shading Language source for the ternary-quantized ("tq")
/// flash-attention decode kernels, embedded at compile time.
pub static FLASH_ATTN_VEC_TQ_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_vec_tq.metal");
25
26pub fn register(registry: &mut KernelRegistry) {
28 registry.register_source("flash_attn_vec_tq_dk256", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
29 registry.register_source("flash_attn_vec_tq_dk512", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
30}
31
/// Host-side launch parameters for the ternary-quantized flash-attention
/// decode kernel (`flash_attn_vec_tq`).
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecTqParams {
    /// Number of query heads; must be a positive multiple of `num_kv_heads`.
    pub num_heads: u32,
    /// Number of key/value heads (grouped-query attention); must be > 0.
    pub num_kv_heads: u32,
    /// Per-head dimension; only 256 and 512 are supported.
    pub head_dim: u32,
    /// Number of valid KV positions; must be > 0.
    pub kv_seq_len: u32,
    /// Allocated KV-cache length; must be >= `kv_seq_len`.
    pub kv_capacity: u32,
    /// Scale applied to attention scores. Chosen by the caller and not
    /// validated here (typically 1/sqrt(head_dim) — caller's responsibility).
    pub scale: f32,
    /// Mask selector forwarded verbatim to the shader.
    /// NOTE(review): per-value meaning is defined in
    /// flash_attn_vec_tq.metal — confirm there before changing.
    pub mask_type: u32,
    /// Sliding-window size forwarded verbatim to the shader
    /// (semantics defined by the shader; 0 used in tests).
    pub sliding_window: u32,
    /// Logit soft-cap value forwarded verbatim to the shader
    /// (semantics defined by the shader; 0.0 used in tests).
    pub softcap: f32,
    /// Start offset into the KV cache; presumably the head of a ring
    /// buffer — forwarded to the shader without validation.
    pub ring_start: u32,
}
63
/// GPU-side parameter block for the cross-workgroup reduce pass.
/// `#[repr(C)]` so the layout matches the corresponding struct in the
/// Metal shader; `Pod`/`Zeroable` allow passing it as raw bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
    // Number of rows to reduce; set to `num_heads` by `flash_attn_vec_tq`.
    nrows: u32,
}
70
/// GPU-side parameter block for the main attention kernel, populated from
/// `FlashAttnVecTqParams` plus the computed workgroup count `nwg`.
/// `#[repr(C)]` so the layout matches the struct in the Metal shader:
/// eleven 4-byte fields, 44 bytes total (pinned by `test_gpu_params_layout`).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecTqParamsGpu {
    n_heads: u32,
    n_kv_heads: u32,
    head_dim: u32,
    kv_seq_len: u32,
    kv_capacity: u32,
    scale: f32,
    mask_type: u32,
    sliding_window: u32,
    softcap: f32,
    // Number of workgroups splitting the KV axis (from `compute_nwg`).
    nwg: u32,
    ring_start: u32,
}
87
88fn validate_params(params: &FlashAttnVecTqParams) -> Result<()> {
90 if params.head_dim != 256 && params.head_dim != 512 {
91 return Err(MlxError::InvalidArgument(format!(
92 "flash_attn_vec_tq: head_dim must be 256 or 512, got {}",
93 params.head_dim
94 )));
95 }
96 if params.num_heads == 0 || params.num_kv_heads == 0 {
97 return Err(MlxError::InvalidArgument(
98 "flash_attn_vec_tq: num_heads and num_kv_heads must be > 0".into(),
99 ));
100 }
101 if params.num_heads % params.num_kv_heads != 0 {
102 return Err(MlxError::InvalidArgument(format!(
103 "flash_attn_vec_tq: num_heads ({}) must be divisible by num_kv_heads ({})",
104 params.num_heads, params.num_kv_heads
105 )));
106 }
107 if params.kv_seq_len == 0 {
108 return Err(MlxError::InvalidArgument(
109 "flash_attn_vec_tq: kv_seq_len must be > 0".into(),
110 ));
111 }
112 if params.kv_capacity < params.kv_seq_len {
113 return Err(MlxError::InvalidArgument(format!(
114 "flash_attn_vec_tq: kv_capacity ({}) must be >= kv_seq_len ({})",
115 params.kv_capacity, params.kv_seq_len
116 )));
117 }
118 Ok(())
119}
120
/// Number of workgroups used to split the KV axis for each head.
///
/// Returns 16 by default. The `HF2Q_TQ_NWG` environment variable overrides
/// it when set to an integer in `1..=32`; any unset, unparsable, or
/// out-of-range value falls back to the default. The KV length argument is
/// currently unused — the split is fixed rather than adaptive.
fn compute_nwg(_kv_seq_len: u32) -> u32 {
    const DEFAULT_NWG: u32 = 16;
    std::env::var("HF2Q_TQ_NWG")
        .ok()
        .and_then(|v| v.parse::<u32>().ok())
        // (1..=32).contains is the idiomatic range check
        // (clippy::manual_range_contains).
        .filter(|n| (1..=32).contains(n))
        .unwrap_or(DEFAULT_NWG)
}
139
140#[allow(clippy::too_many_arguments)]
157pub fn flash_attn_vec_tq(
158 encoder: &mut CommandEncoder,
159 registry: &mut KernelRegistry,
160 device: &MlxDevice,
161 q: &MlxBuffer,
162 k_packed: &MlxBuffer,
163 k_norms: &MlxBuffer,
164 v_packed: &MlxBuffer,
165 v_norms: &MlxBuffer,
166 output: &MlxBuffer,
167 tmp: &MlxBuffer,
168 params: &FlashAttnVecTqParams,
169) -> Result<()> {
170 validate_params(params)?;
171
172 let head_dim = params.head_dim;
173 let nwg = compute_nwg(params.kv_seq_len);
174
175 let gpu_params = FlashAttnVecTqParamsGpu {
176 n_heads: params.num_heads,
177 n_kv_heads: params.num_kv_heads,
178 head_dim: params.head_dim,
179 kv_seq_len: params.kv_seq_len,
180 kv_capacity: params.kv_capacity,
181 scale: params.scale,
182 mask_type: params.mask_type,
183 sliding_window: params.sliding_window,
184 softcap: params.softcap,
185 nwg,
186 ring_start: params.ring_start,
187 };
188
189 let kernel_name = match head_dim {
190 256 => "flash_attn_vec_tq_dk256",
191 512 => "flash_attn_vec_tq_dk512",
192 _ => return Err(MlxError::InvalidArgument(format!(
193 "flash_attn_vec_tq: unsupported head_dim {head_dim}"
194 ))),
195 };
196 let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
197
198 let pk = pad2(head_dim as usize, 128);
201 let pv = pad2(head_dim as usize, 128);
202 let sh = 4 * 32; let shmem_halfs = pk + sh + 2 * pv;
204 let shmem_bytes = shmem_halfs * 2; encoder.set_op_kind(CapturedOpKind::Sdpa);
208
209 let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
211 let threadgroup_size = MTLSize::new(32, 1, 1); let dst_buf = if nwg == 1 { output } else { tmp };
216
217 encoder.encode_threadgroups_with_args_and_shared(
218 pipeline,
219 &[
220 (0, KernelArg::Bytes(as_bytes(&gpu_params))),
221 (1, KernelArg::Buffer(q)),
222 (2, KernelArg::Buffer(k_packed)),
223 (3, KernelArg::Buffer(k_norms)),
224 (4, KernelArg::Buffer(v_packed)),
225 (5, KernelArg::Buffer(v_norms)),
226 (6, KernelArg::Buffer(dst_buf)),
227 ],
228 &[(0, shmem_bytes as u64)],
229 threadgroups,
230 threadgroup_size,
231 );
232
233 if nwg > 1 {
235 encoder.memory_barrier();
236
237 let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };
238
239 let reduce_kernel = match head_dim {
240 256 => "flash_attn_vec_reduce_dk256",
241 512 => "flash_attn_vec_reduce_dk512",
242 _ => unreachable!(),
243 };
244 let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;
245
246 let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
247 let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);
248
249 encoder.encode_threadgroups_with_args(
250 reduce_pipeline,
251 &[
252 (0, KernelArg::Bytes(as_bytes(&reduce_params))),
253 (1, KernelArg::Buffer(tmp)),
254 (2, KernelArg::Buffer(output)),
255 (3, KernelArg::Bytes(as_bytes(&nwg))),
256 ],
257 reduce_tg,
258 reduce_tg_size,
259 );
260 }
261
262 Ok(())
263}
264
/// Size in bytes of the scratch buffer required by `flash_attn_vec_tq`.
///
/// Sized for the worst case of 32 workgroups: one slot per (head, workgroup)
/// pair holding `head_dim` partial-output floats plus two extra scalars
/// (presumably the running max and sum for the softmax reduction — confirm
/// against the shader).
pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
    const MAX_NWG: usize = 32;
    const EXTRA_SCALARS: usize = 2;
    let floats = num_heads as usize * MAX_NWG * (head_dim as usize + EXTRA_SCALARS);
    floats * std::mem::size_of::<f32>()
}
275
/// Round `x` up to the next multiple of `n`. `n` must be a power of two
/// (the bitmask trick below relies on it; callers pass 128).
fn pad2(x: usize, n: usize) -> usize {
    let mask = n - 1;
    (x + mask) & !mask
}
280
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Baseline parameter set that `validate_params` accepts.
    fn valid_params() -> FlashAttnVecTqParams {
        FlashAttnVecTqParams {
            num_heads: 8,
            num_kv_heads: 4,
            head_dim: 256,
            kv_seq_len: 64,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
        }
    }

    #[test]
    fn test_validate_params_ok() {
        assert!(validate_params(&valid_params()).is_ok());
    }

    #[test]
    fn test_validate_params_bad_head_dim() {
        let bad = FlashAttnVecTqParams {
            head_dim: 128,
            mask_type: 0,
            ..valid_params()
        };
        assert!(validate_params(&bad).is_err());
    }

    #[test]
    fn test_gpu_params_layout() {
        // Eleven 4-byte fields under #[repr(C)]: 44 bytes, no padding.
        assert_eq!(std::mem::size_of::<FlashAttnVecTqParamsGpu>(), 44);
    }
}