1use metal::MTLSize;
15
16use crate::buffer::MlxBuffer;
17use crate::device::MlxDevice;
18use crate::encoder::{as_bytes, CapturedOpKind, CommandEncoder, KernelArg};
19use crate::error::{MlxError, Result};
20use crate::kernel_registry::KernelRegistry;
21
/// Metal shader source text, embedded at compile time from
/// `../shaders/flash_attn_vec_tq.metal` (path resolved relative to this source file).
///
/// `register` files this same text under both the `dk256` and `dk512` kernel names.
pub static FLASH_ATTN_VEC_TQ_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_vec_tq.metal");
25
26pub fn register(registry: &mut KernelRegistry) {
28 registry.register_source("flash_attn_vec_tq_dk256", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
29 registry.register_source("flash_attn_vec_tq_dk512", FLASH_ATTN_VEC_TQ_SHADER_SOURCE);
30}
31
/// Host-side parameters for one `flash_attn_vec_tq` dispatch.
///
/// The values are checked by `validate_params` and then copied field-for-field into the
/// parameter block handed to the kernel.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnVecTqParams {
    /// Number of heads. Must be non-zero and divisible by `num_kv_heads`; also used as the
    /// Y extent of the main dispatch grid and as the row count of the reduce pass.
    pub num_heads: u32,
    /// Number of key/value heads. Must be non-zero.
    pub num_kv_heads: u32,
    /// Per-head dimension. Only 256 and 512 are accepted.
    pub head_dim: u32,
    /// Key/value sequence length. Must be non-zero and no larger than `kv_capacity`.
    pub kv_seq_len: u32,
    /// Key/value capacity. Must be at least `kv_seq_len`.
    /// NOTE(review): presumably the allocated length of the KV cache — confirm.
    pub kv_capacity: u32,
    /// Scale factor forwarded unchanged to the kernel.
    /// NOTE(review): presumably the attention score scale — confirm against the shader.
    pub scale: f32,
    /// Mask selector forwarded unchanged to the kernel; not validated on the host.
    /// NOTE(review): the meaning of each value is defined by the shader — confirm.
    pub mask_type: u32,
    /// Sliding-window setting forwarded unchanged to the kernel; not validated on the host.
    pub sliding_window: u32,
    /// Soft-cap setting forwarded unchanged to the kernel; not validated on the host.
    pub softcap: f32,
}
54
/// Parameter block for the reduce kernel, passed as raw bytes at argument index 0.
///
/// `#[repr(C)]` with a single `u32`, so the block is exactly four bytes.
/// NOTE(review): the layout presumably has to match the shader-side struct — confirm.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecReduceParamsGpu {
    /// Row count; `flash_attn_vec_tq` sets this to `num_heads`.
    nrows: u32,
}
61
/// Parameter block for the main kernel, passed as raw bytes at argument index 0.
///
/// `#[repr(C)]` with ten 4-byte fields, 40 bytes in total (the unit test pins this size).
/// The first nine fields are copied from `FlashAttnVecTqParams`; `nwg` is chosen by
/// `compute_nwg`.
/// NOTE(review): field order and types presumably have to match the shader-side struct —
/// confirm before reordering or adding fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FlashAttnVecTqParamsGpu {
    // Copied from `FlashAttnVecTqParams::num_heads`.
    n_heads: u32,
    // Copied from `FlashAttnVecTqParams::num_kv_heads`.
    n_kv_heads: u32,
    // Copied from `FlashAttnVecTqParams::head_dim` (256 or 512 after validation).
    head_dim: u32,
    // Copied from `FlashAttnVecTqParams::kv_seq_len`.
    kv_seq_len: u32,
    // Copied from `FlashAttnVecTqParams::kv_capacity`.
    kv_capacity: u32,
    // Copied from `FlashAttnVecTqParams::scale`.
    scale: f32,
    // Copied from `FlashAttnVecTqParams::mask_type`.
    mask_type: u32,
    // Copied from `FlashAttnVecTqParams::sliding_window`.
    sliding_window: u32,
    // Copied from `FlashAttnVecTqParams::softcap`.
    softcap: f32,
    // Workgroup count along the grid's Z axis; the same value sizes the reduce pass.
    nwg: u32,
}
77
78fn validate_params(params: &FlashAttnVecTqParams) -> Result<()> {
80 if params.head_dim != 256 && params.head_dim != 512 {
81 return Err(MlxError::InvalidArgument(format!(
82 "flash_attn_vec_tq: head_dim must be 256 or 512, got {}",
83 params.head_dim
84 )));
85 }
86 if params.num_heads == 0 || params.num_kv_heads == 0 {
87 return Err(MlxError::InvalidArgument(
88 "flash_attn_vec_tq: num_heads and num_kv_heads must be > 0".into(),
89 ));
90 }
91 if params.num_heads % params.num_kv_heads != 0 {
92 return Err(MlxError::InvalidArgument(format!(
93 "flash_attn_vec_tq: num_heads ({}) must be divisible by num_kv_heads ({})",
94 params.num_heads, params.num_kv_heads
95 )));
96 }
97 if params.kv_seq_len == 0 {
98 return Err(MlxError::InvalidArgument(
99 "flash_attn_vec_tq: kv_seq_len must be > 0".into(),
100 ));
101 }
102 if params.kv_capacity < params.kv_seq_len {
103 return Err(MlxError::InvalidArgument(format!(
104 "flash_attn_vec_tq: kv_capacity ({}) must be >= kv_seq_len ({})",
105 params.kv_capacity, params.kv_seq_len
106 )));
107 }
108 Ok(())
109}
110
/// Picks the number of workgroups used per head.
///
/// The `HF2Q_TQ_NWG` environment variable overrides the choice when it parses as a `u32`
/// in the inclusive range 1..=32. A missing, unparsable, or out-of-range value falls back
/// to the default of 16. The sequence-length argument is currently ignored.
fn compute_nwg(_kv_seq_len: u32) -> u32 {
    const DEFAULT_NWG: u32 = 16;

    std::env::var("HF2Q_TQ_NWG")
        .ok()
        .and_then(|raw| raw.parse::<u32>().ok())
        .filter(|requested| (1..=32).contains(requested))
        .unwrap_or(DEFAULT_NWG)
}
129
130#[allow(clippy::too_many_arguments)]
147pub fn flash_attn_vec_tq(
148 encoder: &mut CommandEncoder,
149 registry: &mut KernelRegistry,
150 device: &MlxDevice,
151 q: &MlxBuffer,
152 k_packed: &MlxBuffer,
153 k_norms: &MlxBuffer,
154 v_packed: &MlxBuffer,
155 v_norms: &MlxBuffer,
156 output: &MlxBuffer,
157 tmp: &MlxBuffer,
158 params: &FlashAttnVecTqParams,
159) -> Result<()> {
160 validate_params(params)?;
161
162 let head_dim = params.head_dim;
163 let nwg = compute_nwg(params.kv_seq_len);
164
165 let gpu_params = FlashAttnVecTqParamsGpu {
166 n_heads: params.num_heads,
167 n_kv_heads: params.num_kv_heads,
168 head_dim: params.head_dim,
169 kv_seq_len: params.kv_seq_len,
170 kv_capacity: params.kv_capacity,
171 scale: params.scale,
172 mask_type: params.mask_type,
173 sliding_window: params.sliding_window,
174 softcap: params.softcap,
175 nwg,
176 };
177
178 let kernel_name = match head_dim {
179 256 => "flash_attn_vec_tq_dk256",
180 512 => "flash_attn_vec_tq_dk512",
181 _ => return Err(MlxError::InvalidArgument(format!(
182 "flash_attn_vec_tq: unsupported head_dim {head_dim}"
183 ))),
184 };
185 let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
186
187 let pk = pad2(head_dim as usize, 128);
190 let pv = pad2(head_dim as usize, 128);
191 let sh = 4 * 32; let shmem_halfs = pk + sh + 2 * pv;
193 let shmem_bytes = shmem_halfs * 2; encoder.set_op_kind(CapturedOpKind::Sdpa);
197
198 let threadgroups = MTLSize::new(1, params.num_heads as u64, nwg as u64);
200 let threadgroup_size = MTLSize::new(32, 1, 1); let dst_buf = if nwg == 1 { output } else { tmp };
205
206 encoder.encode_threadgroups_with_args_and_shared(
207 pipeline,
208 &[
209 (0, KernelArg::Bytes(as_bytes(&gpu_params))),
210 (1, KernelArg::Buffer(q)),
211 (2, KernelArg::Buffer(k_packed)),
212 (3, KernelArg::Buffer(k_norms)),
213 (4, KernelArg::Buffer(v_packed)),
214 (5, KernelArg::Buffer(v_norms)),
215 (6, KernelArg::Buffer(dst_buf)),
216 ],
217 &[(0, shmem_bytes as u64)],
218 threadgroups,
219 threadgroup_size,
220 );
221
222 if nwg > 1 {
224 encoder.memory_barrier();
225
226 let reduce_params = FlashAttnVecReduceParamsGpu { nrows: params.num_heads };
227
228 let reduce_kernel = match head_dim {
229 256 => "flash_attn_vec_reduce_dk256",
230 512 => "flash_attn_vec_reduce_dk512",
231 _ => unreachable!(),
232 };
233 let reduce_pipeline = registry.get_pipeline(reduce_kernel, device.metal_device())?;
234
235 let reduce_tg = MTLSize::new(params.num_heads as u64, 1, 1);
236 let reduce_tg_size = MTLSize::new(32 * nwg as u64, 1, 1);
237
238 encoder.encode_threadgroups_with_args(
239 reduce_pipeline,
240 &[
241 (0, KernelArg::Bytes(as_bytes(&reduce_params))),
242 (1, KernelArg::Buffer(tmp)),
243 (2, KernelArg::Buffer(output)),
244 (3, KernelArg::Bytes(as_bytes(&nwg))),
245 ],
246 reduce_tg,
247 reduce_tg_size,
248 );
249 }
250
251 Ok(())
252}
253
/// Returns a scratch-buffer size in bytes for the given head count and head dimension.
///
/// The size is `num_heads * 32 * (head_dim + 2)` `f32` values, where 32 is the largest
/// workgroup count `compute_nwg` can return, so the result does not depend on the
/// workgroup count actually chosen at dispatch time.
///
/// NOTE(review): presumably this sizes the `tmp` argument of `flash_attn_vec_tq`, with
/// the two extra floats per row holding per-workgroup reduction state — confirm against
/// the shader.
pub fn tmp_buffer_bytes(num_heads: u32, head_dim: u32) -> usize {
    const MAX_NWG: usize = 32;

    let partial_rows = num_heads as usize * MAX_NWG;
    let floats_per_row = head_dim as usize + 2;
    partial_rows * floats_per_row * std::mem::size_of::<f32>()
}
264
/// Rounds `x` up to the next multiple of `n` using a bit mask.
///
/// The mask trick yields a true multiple only when `n` is a power of two; the function
/// does not check this.
fn pad2(x: usize, n: usize) -> usize {
    let bumped = x + n - 1;
    let low_bits = n - 1;
    bumped & !low_bits
}
269
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Builds the shared fixture; the tests vary only `head_dim` and `mask_type`.
    fn fixture(head_dim: u32, mask_type: u32) -> FlashAttnVecTqParams {
        FlashAttnVecTqParams {
            num_heads: 8,
            num_kv_heads: 4,
            head_dim,
            kv_seq_len: 64,
            kv_capacity: 1024,
            scale: 1.0,
            mask_type,
            sliding_window: 0,
            softcap: 0.0,
        }
    }

    /// A supported head dimension with otherwise consistent values is accepted.
    #[test]
    fn test_validate_params_ok() {
        let accepted = fixture(256, 1);
        assert!(validate_params(&accepted).is_ok());
    }

    /// A head dimension other than 256 or 512 is rejected.
    #[test]
    fn test_validate_params_bad_head_dim() {
        let rejected = fixture(128, 0);
        assert!(validate_params(&rejected).is_err());
    }

    /// The GPU parameter block is ten 4-byte fields, i.e. 40 bytes.
    #[test]
    fn test_gpu_params_layout() {
        let actual = std::mem::size_of::<FlashAttnVecTqParamsGpu>();
        assert_eq!(actual, 40);
    }
}