1use metal::MTLSize;
15
16use crate::buffer::MlxBuffer;
17use crate::device::MlxDevice;
18use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
19use crate::error::{MlxError, Result};
20use crate::kernel_registry::KernelRegistry;
21
/// Metal source for the ternary-quantized (TQ) flash-attention decode
/// shaders; registered under both the dk256 and dk512 kernel names (the two
/// head-dim specializations live in the same .metal file).
pub static FLASH_ATTN_VEC_TQ_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_vec_tq.metal");
25
26pub fn register(registry: &mut KernelRegistry) {
28 registry.register_source("flash_attn_vec_tq_dk256", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
29 registry.register_source("flash_attn_vec_tq_dk512", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
30}
31
/// Host-side dispatch parameters for the ternary-quantized flash-attention
/// decode kernel (see `flash_attn_vec_tq`). Validated by `validate_params`
/// before encoding.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecTqParams {
    /// Number of query heads; one threadgroup row is dispatched per head.
    pub num_heads: u32,
    /// Number of KV heads; must be > 0 and divide `num_heads` (GQA).
    pub num_kv_heads: u32,
    /// Head dimension; only 256 and 512 have compiled kernel variants.
    pub head_dim: u32,
    /// Number of valid entries currently in the KV cache; must be > 0.
    pub kv_seq_len: u32,
    /// Allocated KV-cache capacity in entries; must be >= `kv_seq_len`.
    pub kv_capacity: u32,
    /// Softmax scale applied to QK^T scores (forwarded to the shader).
    pub scale: f32,
    /// Mask selector forwarded to the shader; semantics are defined by the
    /// Metal source and not validated here.
    pub mask_type: u32,
    /// Sliding-window size forwarded to the shader (0 presumably disables
    /// windowing — TODO confirm against the .metal source).
    pub sliding_window: u32,
    /// Logit soft-cap value forwarded to the shader.
    pub softcap: f32,
    /// Ring-buffer start offset of the KV cache (forwarded to the shader;
    /// presumably an index into the `kv_capacity` slots — confirm in shader).
    pub ring_start: u32,
    /// Extra scale for the dk512 variant; values below 1e-6 are treated as
    /// "unset" and replaced with 1.0 at dispatch time.
    pub scale_factor_d512: f32,
}
69
/// GPU-side parameter block for the cross-partition reduce kernel (pass 2).
/// `#[repr(C)]` layout must match the corresponding struct in the Metal
/// shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
    /// Number of output rows to reduce; set to `num_heads` at dispatch.
    nrows: u32,
}
76
/// GPU-side parameter block for the pass-1 attention kernel. The
/// `#[repr(C)]` layout must match the corresponding struct in
/// `flash_attn_vec_tq.metal` exactly: twelve 4-byte fields, 48 bytes total,
/// no padding (checked by `test_gpu_params_layout`). Fields mirror
/// `FlashAttnVecTqParams` except for `nwg`, which is computed host-side.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecTqParamsGpu {
    n_heads: u32,
    n_kv_heads: u32,
    head_dim: u32,
    kv_seq_len: u32,
    kv_capacity: u32,
    scale: f32,
    mask_type: u32,
    sliding_window: u32,
    softcap: f32,
    /// Number of KV partitions (threadgroups along z) — see `compute_nwg`.
    nwg: u32,
    ring_start: u32,
    /// Already normalized host-side: near-zero inputs become 1.0.
    scale_factor_d512: f32,
}
95
96fn validate_params(params: &FlashAttnVecTqParams) -> Result<()> {
98 if params.head_dim != 256 && params.head_dim != 512 {
99 return Err(MlxError::InvalidArgument(format!(
100 "flash_attn_vec_tq: head_dim must be 256 or 512, got {}",
101 params.head_dim
102 )));
103 }
104 if params.num_heads == 0 || params.num_kv_heads == 0 {
105 return Err(MlxError::InvalidArgument(
106 "flash_attn_vec_tq: num_heads and num_kv_heads must be > 0".into(),
107 ));
108 }
109 if params.num_heads % params.num_kv_heads != 0 {
110 return Err(MlxError::InvalidArgument(format!(
111 "flash_attn_vec_tq: num_heads ({}) must be divisible by num_kv_heads ({})",
112 params.num_heads, params.num_kv_heads
113 )));
114 }
115 if params.kv_seq_len == 0 {
116 return Err(MlxError::InvalidArgument(
117 "flash_attn_vec_tq: kv_seq_len must be > 0".into(),
118 ));
119 }
120 if params.kv_capacity < params.kv_seq_len {
121 return Err(MlxError::InvalidArgument(format!(
122 "flash_attn_vec_tq: kv_capacity ({}) must be >= kv_seq_len ({})",
123 params.kv_capacity, params.kv_seq_len
124 )));
125 }
126 Ok(())
127}
128
/// Returns the number of KV partitions ("workgroups" along the dispatch z
/// axis) used by pass 1.
///
/// Defaults to 16. The `HF2Q_TQ_NWG` environment variable overrides it when
/// set to a value in `1..=32`; anything else is ignored. The variable is
/// read once and cached for the lifetime of the process, so later changes
/// have no effect.
///
/// `_kv_seq_len` is currently unused; the signature keeps it so a future
/// sequence-length-adaptive heuristic needs no call-site changes.
fn compute_nwg(_kv_seq_len: u32) -> u32 {
    use std::sync::OnceLock;

    // 0 encodes "no valid override". OnceLock replaces the previous
    // hand-rolled AtomicI32 sentinel, which could race and parse the env
    // var more than once; get_or_init guarantees exactly one initialization.
    static TQ_NWG_OVERRIDE: OnceLock<u32> = OnceLock::new();
    let override_nwg = *TQ_NWG_OVERRIDE.get_or_init(|| {
        std::env::var("HF2Q_TQ_NWG")
            .ok()
            .and_then(|s| s.parse::<u32>().ok())
            .filter(|n| (1..=32).contains(n))
            .unwrap_or(0)
    });
    if override_nwg > 0 { override_nwg } else { 16 }
}
156
/// Encodes the ternary-quantized (TQ) flash-attention decode pass (one query
/// token per head) onto `encoder`.
///
/// Pass 1 dispatches one 32-thread threadgroup per (head, KV-partition)
/// pair. With a single partition (`nwg == 1`) the kernel writes `output`
/// directly; otherwise each partition writes partials into `tmp` and a
/// second reduce kernel combines them into `output`.
///
/// # Errors
/// Returns `InvalidArgument` when `params` fails `validate_params`; pipeline
/// lookup errors are propagated from the registry.
#[allow(clippy::too_many_arguments)]
pub fn flash_attn_vec_tq(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    // Query activations for the current token.
    q: &MlxBuffer,
    // Packed ternary-quantized K cache and its companion norm/scale buffer.
    k_packed: &MlxBuffer,
    k_norms: &MlxBuffer,
    // Packed ternary-quantized V cache and its companion norm/scale buffer.
    v_packed: &MlxBuffer,
    v_norms: &MlxBuffer,
    output: &MlxBuffer,
    // Scratch for cross-partition partials; size with `tmp_buffer_bytes`.
    tmp: &MlxBuffer,
    params: &FlashAttnVecTqParams,
) -> Result<()> {
    validate_params(params)?;

    let head_dim = params.head_dim;
    // Number of KV partitions (threadgroups along z); env-overridable.
    let nwg = compute_nwg(params.kv_seq_len);

    // Near-zero scale factor means "unset": fall back to 1.0 (identity).
    let effective_scale_d512 = if params.scale_factor_d512 < 1e-6 { 1.0_f32 } else { params.scale_factor_d512 };
    let gpu_params = FlashAttnVecTqParamsGpu {
        n_heads: params.num_heads,
        n_kv_heads: params.num_kv_heads,
        head_dim: params.head_dim,
        kv_seq_len: params.kv_seq_len,
        kv_capacity: params.kv_capacity,
        scale: params.scale,
        mask_type: params.mask_type,
        sliding_window: params.sliding_window,
        softcap: params.softcap,
        nwg,
        ring_start: params.ring_start,
        scale_factor_d512: effective_scale_d512,
    };

    let kernel_name = match head_dim {
        256 => "flash_attn_vec_tq_dk256",
        512 => "flash_attn_vec_tq_dk512",
        // Unreachable in practice: validate_params already rejected other
        // head dims; kept as a defensive error rather than a panic.
        _ => return Err(MlxError::InvalidArgument(format!(
            "flash_attn_vec_tq: unsupported head_dim {head_dim}"
        ))),
    };
    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;

    // Threadgroup-memory budget. The `* 2` suggests it is sized in 16-bit
    // (half) units: one K tile and two V tiles each padded to a multiple of
    // 128, plus 4*32 elements of scratch. NOTE(review): confirm these match
    // the threadgroup arrays declared in flash_attn_vec_tq.metal.
    let pk = pad2(head_dim as usize, 128);
    let pv = pad2(head_dim as usize, 128);
    let sh = 4 * 32;
    let shmem_halfs = pk + sh + 2 * pv;
    let shmem_bytes = shmem_halfs * 2;
    encoder.set_op_kind(CapturedOpKind::Sdpa);

    // Grid: 1 x num_heads x nwg threadgroups, each a single 32-thread
    // simdgroup-sized group.
    let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
    let threadgroup_size = MTLSize::new(32, 1, 1);
    // Single partition: skip `tmp` and the reduce pass entirely.
    let dst_buf = if nwg == 1 { output } else { tmp };

    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(q)),
            (2, KernelArg::Buffer(k_packed)),
            (3, KernelArg::Buffer(k_norms)),
            (4, KernelArg::Buffer(v_packed)),
            (5, KernelArg::Buffer(v_norms)),
            (6, KernelArg::Buffer(dst_buf)),
        ],
        &[(0, shmem_bytes as u64)],
        threadgroups,
        threadgroup_size,
    );

    if nwg > 1 {
        // Make pass-1 partials in `tmp` visible before the reduce reads them.
        encoder.memory_barrier();

        let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };

        // NOTE(review): the reduce kernels are not registered by this
        // file's `register()`; they are assumed to be registered elsewhere.
        let reduce_kernel = match head_dim {
            256 => "flash_attn_vec_reduce_dk256",
            512 => "flash_attn_vec_reduce_dk512",
            _ => unreachable!(),
        };
        let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;

        // One threadgroup per head; 32 lanes per partition so a group spans
        // all `nwg` partials of its row.
        let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
        let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);

        encoder.encode_threadgroups_with_args(
            reduce_pipeline,
            &[
                (0, KernelArg::Bytes(as_bytes(&reduce_params))),
                (1, KernelArg::Buffer(tmp)),
                (2, KernelArg::Buffer(output)),
                (3, KernelArg::Bytes(as_bytes(&nwg))),
            ],
            reduce_tg,
            reduce_tg_size,
        );
    }

    Ok(())
}
284
/// Returns the size in bytes of the scratch buffer (`tmp`) required by
/// `flash_attn_vec_tq`, sized for the worst case of 32 KV partitions.
///
/// Each of the `num_heads` rows stores, per partition, `head_dim` output
/// accumulators plus two extra f32 values (running softmax statistics).
pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
    // Upper bound on partitions; compute_nwg never exceeds this.
    const MAX_NWG: usize = 32;
    let per_partial = head_dim as usize + 2;
    num_heads as usize * MAX_NWG * per_partial * std::mem::size_of::<f32>()
}
295
/// Rounds `x` up to the next multiple of `n`, where `n` must be a power of
/// two (e.g. `pad2(129, 128) == 256`).
fn pad2(x: usize, n: usize) -> usize {
    // The mask trick only rounds correctly for power-of-two `n`; any other
    // value would silently produce a wrong pad. Catch misuse in debug builds.
    debug_assert!(n.is_power_of_two(), "pad2: n must be a power of two");
    (x + n - 1) & !(n - 1)
}
300
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Baseline parameter set that satisfies every `validate_params` check.
    fn base_params() -> FlashAttnVecTqParams {
        FlashAttnVecTqParams {
            num_heads: 8,
            num_kv_heads: 4,
            head_dim: 256,
            kv_seq_len: 64,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type: 1,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
            scale_factor_d512: 1.0,
        }
    }

    #[test]
    fn test_validate_params_ok() {
        assert!(validate_params(&base_params()).is_ok());
    }

    #[test]
    fn test_validate_params_bad_head_dim() {
        // head_dim 128 has no compiled kernel variant and must be rejected.
        let params = FlashAttnVecTqParams {
            head_dim: 128,
            mask_type: 0,
            ..base_params()
        };
        assert!(validate_params(&params).is_err());
    }

    #[test]
    fn test_gpu_params_layout() {
        // Twelve 4-byte fields under #[repr(C)]: the Metal-side struct
        // mirror depends on this exact size.
        assert_eq!(std::mem::size_of::<FlashAttnVecTqParamsGpu>(), 48);
    }
}