1use metal::MTLSize;
64
65use crate::buffer::MlxBuffer;
66use crate::device::MlxDevice;
67use crate::dtypes::DType;
68use crate::encoder::{CapturedOpKind, CommandEncoder, KernelArg, as_bytes};
69use crate::error::{MlxError, Result};
70use crate::kernel_registry::KernelRegistry;
71use crate::ops::flash_attn_prefill::{AttnMaskParamsGpu, AttnParamsGpu};
72
/// Metal shader source for the flash-attention training forward kernels,
/// embedded into the binary at compile time via `include_str!`.
///
/// [`register`] associates this source with every name in `ALL_KERNEL_NAMES`.
pub static FLASH_ATTN_TRAIN_FWD_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_train_fwd.metal");
78
// Forward kernel entry-point names. The `_d64` / `_d256` suffix matches the
// `head_dim` that the corresponding dispatcher in this file enforces.
// The `_boolmask` names are registered below but no dispatcher visible in this
// file selects them.
// NOTE(review): presumably these are variants taking a boolean mask — confirm
// against the shader source.
const K_BF16_D64: &str = "flash_attn_train_fwd_bf16_d64";
const K_BF16_D64_BOOLMASK: &str = "flash_attn_train_fwd_bf16_d64_boolmask";
const K_BF16_D256: &str = "flash_attn_train_fwd_bf16_d256";
const K_BF16_D256_BOOLMASK: &str = "flash_attn_train_fwd_bf16_d256_boolmask";

/// Every forward kernel name that [`register`] maps to the forward shader source.
const ALL_KERNEL_NAMES: &[&str] = &[
    K_BF16_D64,
    K_BF16_D64_BOOLMASK,
    K_BF16_D256,
    K_BF16_D256_BOOLMASK,
];
92
93pub fn register(registry: &mut KernelRegistry) {
99 for &name in ALL_KERNEL_NAMES {
100 registry.register_source(name, FLASH_ATTN_TRAIN_FWD_SHADER_SOURCE);
101 }
102}
103
/// Query tile size: `q_seq_len` is divided by this to obtain the grid's X
/// extent (`nq`) and the `nq_aligned` / `ql_rem` kernel parameters.
const BQ: u32 = 32;
/// Key tile size: `k_seq_len` is divided by this to obtain the `nk`,
/// `nk_aligned` and `kl_rem` kernel parameters.
const BK: u32 = 16;
/// Y extent of the threadgroup size `(32, WM, WN)` used by the attention dispatches.
// NOTE(review): presumably the simdgroup count along the query dimension
// expected by the shader — confirm against the .metal source.
const WM: u32 = 4;
/// Z extent of the threadgroup size `(32, WM, WN)` used by the attention dispatches.
const WN: u32 = 1;
111
/// Shape and behaviour parameters for one flash-attention training dispatch,
/// shared by the forward and backward entry points in this file.
///
/// The dispatchers derive contiguous strides from these fields, treating Q/O
/// as `[batch, n_q_heads, q_seq_len, head_dim]` and K/V as
/// `[batch, n_kv_heads, k_seq_len, head_dim]`.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnTrainParams {
    /// Batch size; rejected by validation when 0.
    pub batch: u32,
    /// Number of query heads; must be > 0 and divisible by `n_kv_heads`.
    pub n_q_heads: u32,
    /// Number of key/value heads; must be > 0.
    pub n_kv_heads: u32,
    /// Per-head feature dimension; each dispatcher requires one specific value
    /// (64 or 256 for the entry points in this file).
    pub head_dim: u32,
    /// Query sequence length; must be > 0.
    pub q_seq_len: u32,
    /// Key/value sequence length; must be > 0.
    pub k_seq_len: u32,
    /// Forwarded unchanged to the kernel as `AttnParamsGpu::scale`.
    // NOTE(review): presumably the softmax scale (commonly 1/sqrt(head_dim)) — confirm.
    pub scale: f32,
    /// Forwarded as boolean function constant 301 when the pipeline is built.
    pub causal: bool,
}
140
141fn validate_params(p: &FlashAttnTrainParams) -> Result<()> {
144 if p.n_q_heads == 0 {
145 return Err(MlxError::InvalidArgument(
146 "flash_attn_train: n_q_heads must be > 0".into(),
147 ));
148 }
149 if p.n_kv_heads == 0 {
150 return Err(MlxError::InvalidArgument(
151 "flash_attn_train: n_kv_heads must be > 0".into(),
152 ));
153 }
154 if p.n_q_heads % p.n_kv_heads != 0 {
155 return Err(MlxError::InvalidArgument(format!(
156 "flash_attn_train: n_q_heads ({}) must be divisible by n_kv_heads ({})",
157 p.n_q_heads, p.n_kv_heads
158 )));
159 }
160 if p.q_seq_len == 0 {
161 return Err(MlxError::InvalidArgument(
162 "flash_attn_train: q_seq_len must be > 0".into(),
163 ));
164 }
165 if p.k_seq_len == 0 {
166 return Err(MlxError::InvalidArgument(
167 "flash_attn_train: k_seq_len must be > 0".into(),
168 ));
169 }
170 if p.batch == 0 {
171 return Err(MlxError::InvalidArgument(
172 "flash_attn_train: batch must be > 0".into(),
173 ));
174 }
175 Ok(())
176}
177
178fn validate_buffer_size(buf: &MlxBuffer, name: &str, expected_elements: usize) -> Result<()> {
179 let expected_bytes = expected_elements * buf.dtype().size_of();
180 if buf.byte_len() < expected_bytes {
181 return Err(MlxError::InvalidArgument(format!(
182 "flash_attn_train: {name} buffer too small: expected at least \
183 {expected_bytes} bytes, got {}",
184 buf.byte_len()
185 )));
186 }
187 Ok(())
188}
189
190#[allow(clippy::too_many_arguments)]
197fn dispatch_inner(
198 encoder: &mut CommandEncoder,
199 device: &MlxDevice,
200 registry: &mut KernelRegistry,
201 q_buf: &MlxBuffer,
202 k_buf: &MlxBuffer,
203 v_buf: &MlxBuffer,
204 mask: Option<&MlxBuffer>,
205 o_buf: &MlxBuffer,
206 l_buf: &MlxBuffer,
207 params: &FlashAttnTrainParams,
208 kernel_name: &str,
209 head_dim_expected: u32,
210) -> Result<()> {
211 if params.head_dim != head_dim_expected {
213 return Err(MlxError::InvalidArgument(format!(
214 "flash_attn_train ({}): head_dim must be {head_dim_expected}, got {}",
215 kernel_name, params.head_dim
216 )));
217 }
218
219 validate_params(params)?;
220
221 for (buf, name) in &[(q_buf, "Q"), (k_buf, "K"), (v_buf, "V"), (o_buf as &MlxBuffer, "O")] {
223 if buf.dtype() != DType::BF16 {
224 return Err(MlxError::InvalidArgument(format!(
225 "flash_attn_train ({kernel_name}): {name} buffer must be BF16, got {:?}",
226 buf.dtype()
227 )));
228 }
229 }
230 if l_buf.dtype() != DType::F32 {
231 return Err(MlxError::InvalidArgument(format!(
232 "flash_attn_train ({kernel_name}): L_out buffer must be F32, got {:?}",
233 l_buf.dtype()
234 )));
235 }
236 if let Some(m) = mask {
237 if m.dtype() != DType::BF16 {
238 return Err(MlxError::InvalidArgument(format!(
239 "flash_attn_train ({kernel_name}): mask buffer must be BF16, got {:?}",
240 m.dtype()
241 )));
242 }
243 }
244
245 let batch = params.batch as usize;
247 let h = params.n_q_heads as usize;
248 let h_kv = params.n_kv_heads as usize;
249 let ql = params.q_seq_len as usize;
250 let kl = params.k_seq_len as usize;
251 let d = params.head_dim as usize;
252
253 validate_buffer_size(q_buf, "Q", batch * h * ql * d)?;
254 validate_buffer_size(k_buf, "K", batch * h_kv * kl * d)?;
255 validate_buffer_size(v_buf, "V", batch * h_kv * kl * d)?;
256 validate_buffer_size(o_buf, "O", batch * h * ql * d)?;
257 validate_buffer_size(l_buf, "L_out", batch * h * ql)?;
258 if let Some(m) = mask {
259 validate_buffer_size(m, "mask", batch * h * ql * kl)?;
260 }
261
262 let nq = params.q_seq_len.div_ceil(BQ);
264 let nk = params.k_seq_len.div_ceil(BK);
265 let nq_aligned = params.q_seq_len / BQ;
266 let nk_aligned = params.k_seq_len / BK;
267 let ql_rem = params.q_seq_len % BQ;
268 let kl_rem = params.k_seq_len % BK;
269
270 let align_q = ql_rem == 0;
271 let align_k = kl_rem == 0;
272 let has_mask = mask.is_some();
273 let do_causal = params.causal;
274
275 let pipeline = registry.get_pipeline_with_bool_constants(
277 kernel_name,
278 device.metal_device(),
279 &[
280 (200, align_q),
281 (201, align_k),
282 (300, has_mask),
283 (301, do_causal),
284 ],
285 )?;
286
287 let q_seq_stride = d as i64;
289 let q_head_stride = (ql * d) as i64;
290 let q_batch_stride = (h * ql * d) as i64;
291
292 let kv_seq_stride = d as i64;
293 let kv_head_stride = (kl * d) as i64;
294 let kv_batch_stride = (h_kv * kl * d) as i64;
295
296 let gqa_factor = (params.n_q_heads / params.n_kv_heads) as i32;
297
298 let attn_params = AttnParamsGpu {
299 b: params.batch as i32,
300 h: params.n_q_heads as i32,
301 d: params.head_dim as i32,
302 ql: params.q_seq_len as i32,
303 kl: params.k_seq_len as i32,
304 gqa_factor,
305 scale: params.scale,
306 softcapping: 1.0_f32,
307 nq: nq as i32,
308 nk: nk as i32,
309 nq_aligned: nq_aligned as i32,
310 nk_aligned: nk_aligned as i32,
311 ql_rem: ql_rem as i32,
312 kl_rem: kl_rem as i32,
313 ql_off: 0,
314 _pad: 0,
315 q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
316 k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
317 v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
318 o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
319 };
320
321 let grid = MTLSize::new(nq as u64, params.n_q_heads as u64, params.batch as u64);
325 let tg_size = MTLSize::new(32, WM as u64, WN as u64);
326
327 encoder.set_op_kind(CapturedOpKind::Sdpa);
329
330 if let Some(mask_buf) = mask {
331 let m_batch_stride = (h * ql * kl) as i64;
333 let m_head_stride = (ql * kl) as i64;
334 let m_ql_stride = kl as i64;
335
336 let mask_params = AttnMaskParamsGpu {
337 m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
338 };
339
340 encoder.encode_threadgroups_with_args(
341 pipeline,
342 &[
343 (0, KernelArg::Buffer(q_buf)),
344 (1, KernelArg::Buffer(k_buf)),
345 (2, KernelArg::Buffer(v_buf)),
346 (3, KernelArg::Buffer(o_buf)),
347 (4, KernelArg::Bytes(as_bytes(&attn_params))),
348 (5, KernelArg::Bytes(as_bytes(&mask_params))),
349 (6, KernelArg::Buffer(mask_buf)),
350 (8, KernelArg::Buffer(l_buf)),
352 ],
353 grid,
354 tg_size,
355 );
356 } else {
357 encoder.encode_threadgroups_with_args(
358 pipeline,
359 &[
360 (0, KernelArg::Buffer(q_buf)),
361 (1, KernelArg::Buffer(k_buf)),
362 (2, KernelArg::Buffer(v_buf)),
363 (3, KernelArg::Buffer(o_buf)),
364 (4, KernelArg::Bytes(as_bytes(&attn_params))),
365 (8, KernelArg::Buffer(l_buf)),
367 ],
368 grid,
369 tg_size,
370 );
371 }
372
373 Ok(())
374}
375
376#[allow(clippy::too_many_arguments)]
396pub fn dispatch_flash_attn_train_fwd_bf16_d64(
397 encoder: &mut CommandEncoder,
398 device: &MlxDevice,
399 registry: &mut KernelRegistry,
400 q_buf: &MlxBuffer,
401 k_buf: &MlxBuffer,
402 v_buf: &MlxBuffer,
403 mask: Option<&MlxBuffer>,
404 o_buf: &MlxBuffer,
405 l_buf: &MlxBuffer,
406 params: &FlashAttnTrainParams,
407) -> Result<()> {
408 dispatch_inner(
409 encoder, device, registry,
410 q_buf, k_buf, v_buf, mask, o_buf, l_buf,
411 params, K_BF16_D64, 64,
412 )
413}
414
415#[allow(clippy::too_many_arguments)]
424pub fn dispatch_flash_attn_train_fwd_bf16_d256(
425 encoder: &mut CommandEncoder,
426 device: &MlxDevice,
427 registry: &mut KernelRegistry,
428 q_buf: &MlxBuffer,
429 k_buf: &MlxBuffer,
430 v_buf: &MlxBuffer,
431 mask: Option<&MlxBuffer>,
432 o_buf: &MlxBuffer,
433 l_buf: &MlxBuffer,
434 params: &FlashAttnTrainParams,
435) -> Result<()> {
436 dispatch_inner(
437 encoder, device, registry,
438 q_buf, k_buf, v_buf, mask, o_buf, l_buf,
439 params, K_BF16_D256, 256,
440 )
441}
442
/// Returns the forward kernel names that [`register`] registers.
///
/// Hidden from the public docs; the name indicates it exists for tests.
#[doc(hidden)]
pub fn all_kernel_names_for_test() -> &'static [&'static str] {
    ALL_KERNEL_NAMES
}
454
/// Metal shader source for the main backward kernels, embedded at compile
/// time. [`register_bwd`] registers it for the D=64 and D=256 backward kernels
/// and for the f32→bf16 cast kernel.
pub static FLASH_ATTN_TRAIN_BWD_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_train_bwd.metal");

/// Metal shader source for the backward "compute D" pre-pass kernel, embedded
/// at compile time and registered separately by [`register_bwd`].
pub static FLASH_ATTN_TRAIN_BWD_COMPUTE_D_SHADER_SOURCE: &str =
    include_str!("../shaders/flash_attn_train_bwd_compute_d.metal");
485
// Backward kernel entry-point names.
/// Pre-pass kernel dispatched by `dispatch_compute_d`.
const K_BWD_COMPUTE_D: &str = "flash_attn_train_bwd_compute_d_bf16";
/// Main backward kernel used when `head_dim == 64`.
const K_BWD_D64: &str = "flash_attn_train_bwd_bf16_d64";
/// Main backward kernel used when `head_dim == 256`.
const K_BWD_D256: &str = "flash_attn_train_bwd_bf16_d256";
/// Element-wise cast kernel dispatched by `dispatch_f32_to_bf16`.
const K_F32_TO_BF16: &str = "f32_to_bf16_cast";

/// Every backward kernel name registered by [`register_bwd`].
const ALL_BWD_KERNEL_NAMES: &[&str] = &[
    K_BWD_COMPUTE_D,
    K_BWD_D64,
    K_BWD_D256,
    K_F32_TO_BF16,
];
498
499pub fn register_bwd(registry: &mut KernelRegistry) {
504 registry.register_source(K_BWD_COMPUTE_D, FLASH_ATTN_TRAIN_BWD_COMPUTE_D_SHADER_SOURCE);
505 for &name in &[K_BWD_D64, K_BWD_D256, K_F32_TO_BF16] {
506 registry.register_source(name, FLASH_ATTN_TRAIN_BWD_SHADER_SOURCE);
507 }
508}
509
/// Returns the backward kernel names listed in `ALL_BWD_KERNEL_NAMES`.
///
/// Hidden from the public docs; the name indicates it exists for tests.
#[doc(hidden)]
pub fn all_bwd_kernel_names_for_test() -> &'static [&'static str] {
    ALL_BWD_KERNEL_NAMES
}
515
/// Parameter block passed by value (as raw bytes, binding index 3) to the
/// compute-D kernel by `dispatch_compute_d`.
///
/// `#[repr(C)]` fixes the field order and layout so the bytes can be read on
/// the GPU side.
// NOTE(review): the field order must match the struct declared in the
// compute-D shader — confirm against the .metal source.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ComputeDParams {
    /// Copied from `FlashAttnTrainParams::batch`.
    batch: u32,
    /// Copied from `FlashAttnTrainParams::n_q_heads`.
    n_q_heads: u32,
    /// Copied from `FlashAttnTrainParams::q_seq_len`.
    q_seq_len: u32,
    /// Copied from `FlashAttnTrainParams::head_dim`.
    head_dim: u32,
}
527
528fn dispatch_compute_d(
535 encoder: &mut CommandEncoder,
536 registry: &mut KernelRegistry,
537 device: &metal::DeviceRef,
538 o_buf: &MlxBuffer,
539 do_buf: &MlxBuffer,
540 d_out_buf: &MlxBuffer,
541 params: &FlashAttnTrainParams,
542) -> Result<()> {
543 let p = ComputeDParams {
544 batch: params.batch,
545 n_q_heads: params.n_q_heads,
546 q_seq_len: params.q_seq_len,
547 head_dim: params.head_dim,
548 };
549
550 let pipeline = registry.get_pipeline(K_BWD_COMPUTE_D, device)?;
551
552 let tg_x = std::cmp::min(256, params.head_dim.next_power_of_two()) as u64;
553 let grid = MTLSize::new(
554 params.q_seq_len as u64,
555 1,
556 (params.batch * params.n_q_heads) as u64,
557 );
558 let tg_size = MTLSize::new(tg_x, 1, 1);
559
560 encoder.encode_threadgroups_with_args(
561 pipeline,
562 &[
563 (0, KernelArg::Buffer(o_buf)),
564 (1, KernelArg::Buffer(do_buf)),
565 (2, KernelArg::Buffer(d_out_buf)),
566 (3, KernelArg::Bytes(as_bytes(&p))),
567 ],
568 grid,
569 tg_size,
570 );
571
572 Ok(())
573}
574
575fn dispatch_f32_to_bf16(
582 encoder: &mut CommandEncoder,
583 registry: &mut KernelRegistry,
584 device: &metal::DeviceRef,
585 src: &MlxBuffer,
586 dst: &MlxBuffer,
587 n_elems: usize,
588) -> Result<()> {
589 let pipeline = registry.get_pipeline(K_F32_TO_BF16, device)?;
590 let tg_x = std::cmp::min(256u64, n_elems as u64);
591 let n_groups = (n_elems as u64).div_ceil(tg_x);
592 let n_u32 = n_elems as u32;
593 encoder.encode_threadgroups_with_args(
594 pipeline,
595 &[
596 (0, KernelArg::Buffer(src)),
597 (1, KernelArg::Buffer(dst)),
598 (2, KernelArg::Bytes(as_bytes(&n_u32))),
599 ],
600 MTLSize::new(n_groups, 1, 1),
601 MTLSize::new(tg_x, 1, 1),
602 );
603 Ok(())
604}
605
/// Shared implementation behind the public backward dispatchers.
///
/// Encodes three stages on `encoder`, separated by memory barriers:
/// 1. the compute-D pre-pass over `o_buf` and `do_buf`, writing a temporary
///    F32 vector;
/// 2. the main backward kernel `bwd_kernel_name`, writing temporary F32
///    gradient buffers for dQ, dK and dV;
/// 3. three f32→bf16 casts from those temporaries into `dq_buf`, `dk_buf`
///    and `dv_buf`.
///
/// # Arguments
/// * `q_buf`, `k_buf`, `v_buf`, `o_buf`, `do_buf` - BF16 inputs.
/// * `l_buf` - F32 input of at least `batch * n_q_heads * q_seq_len` elements.
/// * `mask` - optional BF16 buffer of at least
///   `batch * n_q_heads * q_seq_len * k_seq_len` elements.
/// * `dq_buf`, `dk_buf`, `dv_buf` - BF16 outputs sized like Q, K and V.
/// * `bwd_kernel_name` - registered main backward kernel to run.
/// * `head_dim_expected` - the only `params.head_dim` this kernel accepts.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` on a head-dim mismatch, invalid
/// parameters, a wrong dtype, an undersized buffer, or a failed temporary
/// allocation; propagates errors from the pipeline lookups.
#[allow(clippy::too_many_arguments)]
fn dispatch_bwd_inner(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q_buf: &MlxBuffer,
    k_buf: &MlxBuffer,
    v_buf: &MlxBuffer,
    o_buf: &MlxBuffer,
    l_buf: &MlxBuffer,
    do_buf: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    dq_buf: &MlxBuffer,
    dk_buf: &MlxBuffer,
    dv_buf: &MlxBuffer,
    params: &FlashAttnTrainParams,
    bwd_kernel_name: &str,
    head_dim_expected: u32,
) -> Result<()> {
    // Each kernel is tied to one head dimension; reject anything else first.
    if params.head_dim != head_dim_expected {
        return Err(MlxError::InvalidArgument(format!(
            "flash_attn_train_bwd ({bwd_kernel_name}): head_dim must be \
             {head_dim_expected}, got {}",
            params.head_dim
        )));
    }

    validate_params(params)?;

    // --- dtype checks: BF16 inputs, F32 L, BF16 gradient outputs, BF16 mask ---
    for (buf, name) in &[
        (q_buf, "Q"),
        (k_buf, "K"),
        (v_buf, "V"),
        (o_buf, "O"),
        (do_buf, "dO"),
    ] {
        if buf.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "flash_attn_train_bwd ({bwd_kernel_name}): {name} buffer must be BF16, \
                 got {:?}",
                buf.dtype()
            )));
        }
    }
    for (buf, name) in &[(l_buf, "L")] {
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "flash_attn_train_bwd ({bwd_kernel_name}): {name} buffer must be F32, \
                 got {:?}",
                buf.dtype()
            )));
        }
    }
    for (buf, name) in &[
        (dq_buf as &MlxBuffer, "dQ"),
        (dk_buf as &MlxBuffer, "dK"),
        (dv_buf as &MlxBuffer, "dV"),
    ] {
        if buf.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "flash_attn_train_bwd ({bwd_kernel_name}): {name} output buffer must be \
                 BF16, got {:?}",
                buf.dtype()
            )));
        }
    }
    if let Some(m) = mask {
        if m.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "flash_attn_train_bwd ({bwd_kernel_name}): mask buffer must be BF16, \
                 got {:?}",
                m.dtype()
            )));
        }
    }

    // Dimensions widened to usize for element-count and stride arithmetic.
    let batch = params.batch as usize;
    let h_q = params.n_q_heads as usize;
    let h_kv = params.n_kv_heads as usize;
    let ql = params.q_seq_len as usize;
    let kl = params.k_seq_len as usize;
    let d = params.head_dim as usize;

    // Element counts of Q-shaped, K/V-shaped and per-row tensors.
    let q_elems = batch * h_q * ql * d;
    let kv_elems = batch * h_kv * kl * d;
    let l_elems = batch * h_q * ql;

    // --- minimum-size checks for every caller-provided buffer ---
    validate_buffer_size(q_buf, "Q", q_elems)?;
    validate_buffer_size(k_buf, "K", kv_elems)?;
    validate_buffer_size(v_buf, "V", kv_elems)?;
    validate_buffer_size(o_buf, "O", q_elems)?;
    validate_buffer_size(l_buf, "L", l_elems)?;
    validate_buffer_size(do_buf, "dO", q_elems)?;
    validate_buffer_size(dq_buf, "dQ", q_elems)?;
    validate_buffer_size(dk_buf, "dK", kv_elems)?;
    validate_buffer_size(dv_buf, "dV", kv_elems)?;
    if let Some(m) = mask {
        validate_buffer_size(m, "mask", batch * h_q * ql * kl)?;
    }

    // Temporary F32 buffers (4 bytes per element): the compute-D output and
    // the F32 gradients that are later cast into the BF16 outputs.
    // NOTE(review): these are locals dropped when this function returns, while
    // the work is only *encoded* here — confirm the encoder / command buffer
    // keeps them alive until the GPU has finished.
    // NOTE(review): the gradient temporaries are not explicitly zeroed here; if
    // the backward kernel accumulates into them, this relies on `alloc_buffer`
    // returning zero-filled memory — confirm.
    let d_vec_buf = device
        .alloc_buffer(l_elems * 4, DType::F32, vec![l_elems])
        .map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc D_vec: {e}")))?;
    let dq_f32_buf = device
        .alloc_buffer(q_elems * 4, DType::F32, vec![q_elems])
        .map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc dQ_f32: {e}")))?;
    let dk_f32_buf = device
        .alloc_buffer(kv_elems * 4, DType::F32, vec![kv_elems])
        .map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc dK_f32: {e}")))?;
    let dv_f32_buf = device
        .alloc_buffer(kv_elems * 4, DType::F32, vec![kv_elems])
        .map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc dV_f32: {e}")))?;

    // Tile counts (rounded up), whole-tile counts, and trailing partial-tile sizes.
    let nq = params.q_seq_len.div_ceil(BQ);
    let nk = params.k_seq_len.div_ceil(BK);
    let nq_aligned = params.q_seq_len / BQ;
    let nk_aligned = params.k_seq_len / BK;
    let ql_rem = params.q_seq_len % BQ;
    let kl_rem = params.k_seq_len % BK;

    // Values for the boolean function constants 200, 201, 300 and 301 below.
    let align_q = ql_rem == 0;
    let align_k = kl_rem == 0;
    let has_mask = mask.is_some();
    let do_causal = params.causal;

    // Contiguous [batch, head, seq] strides, in elements.
    let q_seq_stride = d as i64;
    let q_head_stride = (ql * d) as i64;
    let q_batch_stride = (h_q * ql * d) as i64;
    let kv_seq_stride = d as i64;
    let kv_head_stride = (kl * d) as i64;
    let kv_batch_stride = (h_kv * kl * d) as i64;
    // Query heads per KV head; divisibility was checked by `validate_params`.
    let gqa_factor = (params.n_q_heads / params.n_kv_heads) as i32;

    let attn_params = AttnParamsGpu {
        b: params.batch as i32,
        h: params.n_q_heads as i32,
        d: params.head_dim as i32,
        ql: params.q_seq_len as i32,
        kl: params.k_seq_len as i32,
        gqa_factor,
        scale: params.scale,
        softcapping: 1.0_f32,
        nq: nq as i32,
        nk: nk as i32,
        nq_aligned: nq_aligned as i32,
        nk_aligned: nk_aligned as i32,
        ql_rem: ql_rem as i32,
        kl_rem: kl_rem as i32,
        ql_off: 0,
        _pad: 0,
        q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
        k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
        o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
    };

    // Stage 1: compute-D pre-pass. The barrier orders its write to `d_vec_buf`
    // before the main kernel, which binds that buffer at index 6.
    dispatch_compute_d(
        encoder, registry, device.metal_device(),
        o_buf, do_buf, &d_vec_buf, params,
    )?;
    encoder.memory_barrier();

    // Stage 2: main backward kernel, specialised on the four boolean constants.
    let bwd_pipeline = registry.get_pipeline_with_bool_constants(
        bwd_kernel_name,
        device.metal_device(),
        &[
            (200, align_q),
            (201, align_k),
            (300, has_mask),
            (301, do_causal),
        ],
    )?;

    // One grid entry per (query tile, query head, batch element).
    let grid = MTLSize::new(nq as u64, params.n_q_heads as u64, params.batch as u64);
    let tg_size = MTLSize::new(32, WM as u64, WN as u64);

    // Set after the compute-D dispatch, immediately before the main kernel.
    encoder.set_op_kind(CapturedOpKind::Sdpa);

    // Binding index 3 is not used here: `o_buf` is only bound in the
    // compute-D pre-pass above.
    if let Some(mask_buf) = mask {
        // Contiguous [batch, head, q_row] mask strides, in elements.
        let m_batch_stride = (h_q * ql * kl) as i64;
        let m_head_stride = (ql * kl) as i64;
        let m_ql_stride = kl as i64;
        let mask_params = AttnMaskParamsGpu {
            m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
        };
        encoder.encode_threadgroups_with_args(
            bwd_pipeline,
            &[
                (0, KernelArg::Buffer(q_buf)),
                (1, KernelArg::Buffer(k_buf)),
                (2, KernelArg::Buffer(v_buf)),
                (4, KernelArg::Buffer(l_buf)),
                (5, KernelArg::Buffer(do_buf)),
                (6, KernelArg::Buffer(&d_vec_buf)),
                (7, KernelArg::Buffer(&dq_f32_buf)),
                (8, KernelArg::Buffer(&dk_f32_buf)),
                (9, KernelArg::Buffer(&dv_f32_buf)),
                (10, KernelArg::Bytes(as_bytes(&attn_params))),
                (11, KernelArg::Bytes(as_bytes(&mask_params))),
                (12, KernelArg::Buffer(mask_buf)),
            ],
            grid,
            tg_size,
        );
    } else {
        // Without a mask, indices 11 and 12 are left unbound.
        encoder.encode_threadgroups_with_args(
            bwd_pipeline,
            &[
                (0, KernelArg::Buffer(q_buf)),
                (1, KernelArg::Buffer(k_buf)),
                (2, KernelArg::Buffer(v_buf)),
                (4, KernelArg::Buffer(l_buf)),
                (5, KernelArg::Buffer(do_buf)),
                (6, KernelArg::Buffer(&d_vec_buf)),
                (7, KernelArg::Buffer(&dq_f32_buf)),
                (8, KernelArg::Buffer(&dk_f32_buf)),
                (9, KernelArg::Buffer(&dv_f32_buf)),
                (10, KernelArg::Bytes(as_bytes(&attn_params))),
            ],
            grid,
            tg_size,
        );
    }
    // Orders the F32 gradient writes before the casts that read them.
    encoder.memory_barrier();

    // Stage 3: cast the F32 gradients into the caller's BF16 output buffers.
    dispatch_f32_to_bf16(encoder, registry, device.metal_device(), &dq_f32_buf, dq_buf, q_elems)?;
    encoder.memory_barrier();
    dispatch_f32_to_bf16(encoder, registry, device.metal_device(), &dk_f32_buf, dk_buf, kv_elems)?;
    encoder.memory_barrier();
    dispatch_f32_to_bf16(encoder, registry, device.metal_device(), &dv_f32_buf, dv_buf, kv_elems)?;

    Ok(())
}
862
863#[allow(clippy::too_many_arguments)]
890pub fn dispatch_flash_attn_train_bwd_bf16_d64(
891 encoder: &mut CommandEncoder,
892 device: &MlxDevice,
893 registry: &mut KernelRegistry,
894 q_buf: &MlxBuffer,
895 k_buf: &MlxBuffer,
896 v_buf: &MlxBuffer,
897 o_buf: &MlxBuffer,
898 l_buf: &MlxBuffer,
899 do_buf: &MlxBuffer,
900 mask: Option<&MlxBuffer>,
901 dq_buf: &MlxBuffer,
902 dk_buf: &MlxBuffer,
903 dv_buf: &MlxBuffer,
904 params: &FlashAttnTrainParams,
905) -> Result<()> {
906 dispatch_bwd_inner(
907 encoder, device, registry,
908 q_buf, k_buf, v_buf, o_buf, l_buf, do_buf, mask,
909 dq_buf, dk_buf, dv_buf,
910 params, K_BWD_D64, 64,
911 )
912}
913
914#[allow(clippy::too_many_arguments)]
923pub fn dispatch_flash_attn_train_bwd_bf16_d256(
924 encoder: &mut CommandEncoder,
925 device: &MlxDevice,
926 registry: &mut KernelRegistry,
927 q_buf: &MlxBuffer,
928 k_buf: &MlxBuffer,
929 v_buf: &MlxBuffer,
930 o_buf: &MlxBuffer,
931 l_buf: &MlxBuffer,
932 do_buf: &MlxBuffer,
933 mask: Option<&MlxBuffer>,
934 dq_buf: &MlxBuffer,
935 dk_buf: &MlxBuffer,
936 dv_buf: &MlxBuffer,
937 params: &FlashAttnTrainParams,
938) -> Result<()> {
939 dispatch_bwd_inner(
940 encoder, device, registry,
941 q_buf, k_buf, v_buf, o_buf, l_buf, do_buf, mask,
942 dq_buf, dk_buf, dv_buf,
943 params, K_BWD_D256, 256,
944 )
945}