use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{CapturedOpKind, CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::ops::flash_attn_prefill::{AttnMaskParamsGpu, AttnParamsGpu};
/// Metal source for the training-mode flash-attention forward kernels,
/// embedded at compile time from `../shaders/flash_attn_train_fwd.metal`.
pub static FLASH_ATTN_TRAIN_FWD_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_train_fwd.metal");
// Names under which the forward kernels are registered and looked up.
// NOTE(review): the two `_boolmask` names are registered by `register` but no
// dispatcher visible in this section selects them — confirm they are used
// elsewhere or are intentionally reserved.
const K_BF16_D64: &str = "flash_attn_train_fwd_bf16_d64";
const K_BF16_D64_BOOLMASK: &str = "flash_attn_train_fwd_bf16_d64_boolmask";
const K_BF16_D256: &str = "flash_attn_train_fwd_bf16_d256";
const K_BF16_D256_BOOLMASK: &str = "flash_attn_train_fwd_bf16_d256_boolmask";
/// Every forward kernel name; `register` maps each one to the forward shader source.
const ALL_KERNEL_NAMES: &[&str] = &[
K_BF16_D64,
K_BF16_D64_BOOLMASK,
K_BF16_D256,
K_BF16_D256_BOOLMASK,
];
/// Registers the forward shader source with `registry` under every forward
/// kernel name.
///
/// Each entry of `ALL_KERNEL_NAMES` is associated with the same source text,
/// `FLASH_ATTN_TRAIN_FWD_SHADER_SOURCE`.
pub fn register(registry: &mut KernelRegistry) {
    let source = FLASH_ATTN_TRAIN_FWD_SHADER_SOURCE;
    for kernel in ALL_KERNEL_NAMES.iter().copied() {
        registry.register_source(kernel, source);
    }
}
// Tile and threadgroup geometry shared by the forward and backward dispatchers.
// NOTE(review): these presumably have to match the tile sizes compiled into the
// Metal shaders — confirm before changing any of them.
/// Block size along the query sequence; the grid's x extent is `q_seq_len.div_ceil(BQ)`.
const BQ: u32 = 32;
/// Block size along the key sequence; used to derive `nk`, `nk_aligned` and `kl_rem`.
const BK: u32 = 16;
/// Threadgroup extent in y (the threadgroup size is `(32, WM, WN)`).
const WM: u32 = 4;
/// Threadgroup extent in z (the threadgroup size is `(32, WM, WN)`).
const WN: u32 = 1;
/// Problem description shared by the forward and backward training dispatchers.
///
/// The dispatchers size and stride the buffers as contiguous
/// `[batch, n_q_heads, q_seq_len, head_dim]` for Q/O and
/// `[batch, n_kv_heads, k_seq_len, head_dim]` for K/V.
#[derive(Debug, Clone, Copy)]
pub struct FlashAttnTrainParams {
/// Batch size; must be > 0.
pub batch: u32,
/// Number of query heads; must be > 0 and divisible by `n_kv_heads`.
pub n_q_heads: u32,
/// Number of key/value heads; must be > 0.
pub n_kv_heads: u32,
/// Per-head feature dimension; each dispatcher requires one specific value (64 or 256).
pub head_dim: u32,
/// Query sequence length; must be > 0.
pub q_seq_len: u32,
/// Key/value sequence length; must be > 0.
pub k_seq_len: u32,
/// Scale value forwarded unchanged to the kernel parameter block.
pub scale: f32,
/// Forwarded to the pipeline as bool constant 301 (`do_causal`).
pub causal: bool,
}
/// Checks the shape parameters shared by the forward and backward dispatchers.
///
/// The checks run in a fixed order and the first failure is returned:
/// 1. `n_q_heads`, then `n_kv_heads`, must be non-zero;
/// 2. `n_q_heads` must be divisible by `n_kv_heads`;
/// 3. `q_seq_len`, `k_seq_len` and `batch` must be non-zero;
/// 4. every dimension must fit in an `i32`.
///
/// The last check exists because the dispatchers copy these fields into the
/// GPU parameter block with `as i32` casts; a value above `i32::MAX` would
/// silently become negative there instead of being rejected.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` describing the first violated rule.
fn validate_params(p: &FlashAttnTrainParams) -> Result<()> {
    if p.n_q_heads == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_train: n_q_heads must be > 0".into(),
        ));
    }
    if p.n_kv_heads == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_train: n_kv_heads must be > 0".into(),
        ));
    }
    if p.n_q_heads % p.n_kv_heads != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "flash_attn_train: n_q_heads ({}) must be divisible by n_kv_heads ({})",
            p.n_q_heads, p.n_kv_heads
        )));
    }
    if p.q_seq_len == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_train: q_seq_len must be > 0".into(),
        ));
    }
    if p.k_seq_len == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_train: k_seq_len must be > 0".into(),
        ));
    }
    if p.batch == 0 {
        return Err(MlxError::InvalidArgument(
            "flash_attn_train: batch must be > 0".into(),
        ));
    }

    // `n_kv_heads` needs no separate range check: it divides `n_q_heads`, so it
    // can never exceed it.
    let i32_limit = i32::MAX as u32;
    for &(value, name) in &[
        (p.batch, "batch"),
        (p.n_q_heads, "n_q_heads"),
        (p.head_dim, "head_dim"),
        (p.q_seq_len, "q_seq_len"),
        (p.k_seq_len, "k_seq_len"),
    ] {
        if value > i32_limit {
            return Err(MlxError::InvalidArgument(format!(
                "flash_attn_train: {name} ({value}) exceeds the i32 range of the GPU parameter block"
            )));
        }
    }
    Ok(())
}
/// Ensures `buf` holds at least `expected_elements` elements of its own dtype.
///
/// `name` is only used to label the buffer in the error message.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when the buffer is too small, or when
/// the required byte count does not fit in `usize`. The multiplication is
/// checked so that an overflowing size cannot wrap around to a small number
/// and let an undersized buffer pass.
fn validate_buffer_size(buf: &MlxBuffer, name: &str, expected_elements: usize) -> Result<()> {
    let elem_size = buf.dtype().size_of();
    let expected_bytes = match expected_elements.checked_mul(elem_size) {
        Some(bytes) => bytes,
        None => {
            return Err(MlxError::InvalidArgument(format!(
                "flash_attn_train: {name} buffer size overflows usize: \
                 {expected_elements} elements of {elem_size} bytes"
            )));
        }
    };
    if buf.byte_len() < expected_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "flash_attn_train: {name} buffer too small: expected at least \
             {expected_bytes} bytes, got {}",
            buf.byte_len()
        )));
    }
    Ok(())
}
/// Shared implementation behind the public forward dispatchers.
///
/// Validates `params` and every buffer (dtype and minimum byte length), fetches
/// the pipeline for `kernel_name` specialized with four bool constants, and
/// encodes a single dispatch on `encoder`.
///
/// Q and O are sized and strided as contiguous
/// `[batch, n_q_heads, q_seq_len, head_dim]`, K and V as contiguous
/// `[batch, n_kv_heads, k_seq_len, head_dim]`, and the optional mask as
/// contiguous `[batch, n_q_heads, q_seq_len, k_seq_len]`.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when `params.head_dim` differs from
/// `head_dim_expected`, when `validate_params` rejects `params`, when a buffer
/// has the wrong dtype (Q/K/V/O/mask must be BF16, `l_buf` must be F32), or when
/// a buffer is smaller than its shape requires. Errors from the pipeline lookup
/// are propagated unchanged.
#[allow(clippy::too_many_arguments)]
fn dispatch_inner(
encoder: &mut CommandEncoder,
device: &MlxDevice,
registry: &mut KernelRegistry,
q_buf: &MlxBuffer,
k_buf: &MlxBuffer,
v_buf: &MlxBuffer,
mask: Option<&MlxBuffer>,
o_buf: &mut MlxBuffer,
l_buf: &mut MlxBuffer,
params: &FlashAttnTrainParams,
kernel_name: &str,
head_dim_expected: u32,
) -> Result<()> {
// Each kernel variant handles exactly one head dimension.
if params.head_dim != head_dim_expected {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_train ({}): head_dim must be {head_dim_expected}, got {}",
kernel_name, params.head_dim
)));
}
validate_params(params)?;
// Dtype checks: Q, K, V and O must all be BF16.
for (buf, name) in &[(q_buf, "Q"), (k_buf, "K"), (v_buf, "V"), (o_buf as &MlxBuffer, "O")] {
if buf.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_train ({kernel_name}): {name} buffer must be BF16, got {:?}",
buf.dtype()
)));
}
}
// `l_buf` holds one F32 value per (batch, head, query row).
// NOTE(review): presumably the per-row log-sum-exp saved for the backward pass — confirm against the shader.
if l_buf.dtype() != DType::F32 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_train ({kernel_name}): L_out buffer must be F32, got {:?}",
l_buf.dtype()
)));
}
// NOTE(review): only BF16 masks are accepted here although `_boolmask` kernel
// names are registered in this file — confirm whether a bool-mask path is intended.
if let Some(m) = mask {
if m.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_train ({kernel_name}): mask buffer must be BF16, got {:?}",
m.dtype()
)));
}
}
// Element counts below are computed in `usize`.
let batch = params.batch as usize;
let h = params.n_q_heads as usize;
let h_kv = params.n_kv_heads as usize;
let ql = params.q_seq_len as usize;
let kl = params.k_seq_len as usize;
let d = params.head_dim as usize;
// Minimum sizes implied by the dense layouts described in the function docs.
validate_buffer_size(q_buf, "Q", batch * h * ql * d)?;
validate_buffer_size(k_buf, "K", batch * h_kv * kl * d)?;
validate_buffer_size(v_buf, "V", batch * h_kv * kl * d)?;
validate_buffer_size(o_buf, "O", batch * h * ql * d)?;
validate_buffer_size(l_buf, "L_out", batch * h * ql)?;
if let Some(m) = mask {
validate_buffer_size(m, "mask", batch * h * ql * kl)?;
}
// Tile bookkeeping: total tiles (rounded up), full tiles, and the size of the
// trailing partial tile along each sequence axis.
let nq = params.q_seq_len.div_ceil(BQ);
let nk = params.k_seq_len.div_ceil(BK);
let nq_aligned = params.q_seq_len / BQ;
let nk_aligned = params.k_seq_len / BK;
let ql_rem = params.q_seq_len % BQ;
let kl_rem = params.k_seq_len % BK;
let align_q = ql_rem == 0;
let align_k = kl_rem == 0;
let has_mask = mask.is_some();
let do_causal = params.causal;
// The pipeline is specialized per combination of these four flags; the numeric
// indices (200, 201, 300, 301) presumably match declarations in the shader — TODO confirm.
let pipeline = registry.get_pipeline_with_bool_constants(
kernel_name,
device.metal_device(),
&[
(200, align_q),
(201, align_k),
(300, has_mask),
(301, do_causal),
],
)?;
// Element strides in [batch, head, sequence] order for the dense layouts.
let q_seq_stride = d as i64;
let q_head_stride = (ql * d) as i64;
let q_batch_stride = (h * ql * d) as i64;
let kv_seq_stride = d as i64;
let kv_head_stride = (kl * d) as i64;
let kv_batch_stride = (h_kv * kl * d) as i64;
// Number of query heads per key/value head.
let gqa_factor = (params.n_q_heads / params.n_kv_heads) as i32;
let attn_params = AttnParamsGpu {
b: params.batch as i32,
h: params.n_q_heads as i32,
d: params.head_dim as i32,
ql: params.q_seq_len as i32,
kl: params.k_seq_len as i32,
gqa_factor,
scale: params.scale,
// Always 1.0 on this path; no caller-facing knob exists for it.
softcapping: 1.0_f32,
nq: nq as i32,
nk: nk as i32,
nq_aligned: nq_aligned as i32,
nk_aligned: nk_aligned as i32,
ql_rem: ql_rem as i32,
kl_rem: kl_rem as i32,
// Query offset is fixed at zero on this path.
ql_off: 0,
_pad: 0,
q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
// O shares Q's layout.
o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
};
// Grid extents: (query tiles, query heads, batch); threadgroup size (32, WM, WN).
let grid = MTLSize::new(nq as u64, params.n_q_heads as u64, params.batch as u64);
let tg_size = MTLSize::new(32, WM as u64, WN as u64);
encoder.set_op_kind(CapturedOpKind::Sdpa);
// The two branches differ only in the mask bindings (indices 5 and 6). Index 7
// is not bound in either branch; `l_buf` always goes to index 8.
if let Some(mask_buf) = mask {
// Mask strides in [batch, head, query row] order for a dense mask.
let m_batch_stride = (h * ql * kl) as i64;
let m_head_stride = (ql * kl) as i64;
let m_ql_stride = kl as i64;
let mask_params = AttnMaskParamsGpu {
m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
};
encoder.encode_threadgroups_with_args(
pipeline,
&[
(0, KernelArg::Buffer(q_buf)),
(1, KernelArg::Buffer(k_buf)),
(2, KernelArg::Buffer(v_buf)),
(3, KernelArg::Buffer(o_buf)),
(4, KernelArg::Bytes(as_bytes(&attn_params))),
(5, KernelArg::Bytes(as_bytes(&mask_params))),
(6, KernelArg::Buffer(mask_buf)),
(8, KernelArg::Buffer(l_buf)),
],
grid,
tg_size,
);
} else {
encoder.encode_threadgroups_with_args(
pipeline,
&[
(0, KernelArg::Buffer(q_buf)),
(1, KernelArg::Buffer(k_buf)),
(2, KernelArg::Buffer(v_buf)),
(3, KernelArg::Buffer(o_buf)),
(4, KernelArg::Bytes(as_bytes(&attn_params))),
(8, KernelArg::Buffer(l_buf)),
],
grid,
tg_size,
);
}
Ok(())
}
/// Encodes the BF16 training forward pass for `head_dim == 64`.
///
/// Forwards every argument to `dispatch_inner`, selecting the
/// `flash_attn_train_fwd_bf16_d64` kernel. `o_buf` and `l_buf` are the output
/// buffers; `mask` is optional.
///
/// # Errors
/// Returns whatever `dispatch_inner` returns, including
/// `MlxError::InvalidArgument` when `params.head_dim != 64`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_train_fwd_bf16_d64(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q_buf: &MlxBuffer,
    k_buf: &MlxBuffer,
    v_buf: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    o_buf: &mut MlxBuffer,
    l_buf: &mut MlxBuffer,
    params: &FlashAttnTrainParams,
) -> Result<()> {
    // Head dimension this kernel variant is built for.
    const HEAD_DIM: u32 = 64;
    let kernel = K_BF16_D64;
    dispatch_inner(
        encoder,
        device,
        registry,
        q_buf,
        k_buf,
        v_buf,
        mask,
        o_buf,
        l_buf,
        params,
        kernel,
        HEAD_DIM,
    )
}
/// Encodes the BF16 training forward pass for `head_dim == 256`.
///
/// Forwards every argument to `dispatch_inner`, selecting the
/// `flash_attn_train_fwd_bf16_d256` kernel. `o_buf` and `l_buf` are the output
/// buffers; `mask` is optional.
///
/// # Errors
/// Returns whatever `dispatch_inner` returns, including
/// `MlxError::InvalidArgument` when `params.head_dim != 256`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_train_fwd_bf16_d256(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q_buf: &MlxBuffer,
    k_buf: &MlxBuffer,
    v_buf: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    o_buf: &mut MlxBuffer,
    l_buf: &mut MlxBuffer,
    params: &FlashAttnTrainParams,
) -> Result<()> {
    // Head dimension this kernel variant is built for.
    const HEAD_DIM: u32 = 256;
    let kernel = K_BF16_D256;
    dispatch_inner(
        encoder,
        device,
        registry,
        q_buf,
        k_buf,
        v_buf,
        mask,
        o_buf,
        l_buf,
        params,
        kernel,
        HEAD_DIM,
    )
}
/// Returns the forward kernel name list (`ALL_KERNEL_NAMES`).
///
/// Public but hidden from the docs; the name indicates it is meant for tests.
#[doc(hidden)]
pub fn all_kernel_names_for_test() -> &'static [&'static str] {
ALL_KERNEL_NAMES
}
/// Metal source for the main backward kernels, embedded at compile time from
/// `../shaders/flash_attn_train_bwd.metal`. `register_bwd` also registers the
/// `f32_to_bf16_cast` kernel name against this source.
pub static FLASH_ATTN_TRAIN_BWD_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_train_bwd.metal");
/// Metal source for the backward "compute D" pre-pass, embedded at compile time
/// from `../shaders/flash_attn_train_bwd_compute_d.metal`.
pub static FLASH_ATTN_TRAIN_BWD_COMPUTE_D_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_train_bwd_compute_d.metal");
// Names under which the backward kernels are registered and looked up.
const K_BWD_COMPUTE_D: &str = "flash_attn_train_bwd_compute_d_bf16";
const K_BWD_D64: &str = "flash_attn_train_bwd_bf16_d64";
const K_BWD_D256: &str = "flash_attn_train_bwd_bf16_d256";
const K_F32_TO_BF16: &str = "f32_to_bf16_cast";
/// Every backward kernel name; the same four names are registered by `register_bwd`.
const ALL_BWD_KERNEL_NAMES: &[&str] = &[
K_BWD_COMPUTE_D,
K_BWD_D64,
K_BWD_D256,
K_F32_TO_BF16,
];
/// Registers the backward-pass shader sources with `registry`.
///
/// The compute-D kernel name maps to its dedicated source; the two main
/// backward kernel names and the F32-to-BF16 cast kernel name all map to
/// `FLASH_ATTN_TRAIN_BWD_SHADER_SOURCE`. Registration order is: compute-D,
/// d64, d256, cast.
pub fn register_bwd(registry: &mut KernelRegistry) {
    let main_source = FLASH_ATTN_TRAIN_BWD_SHADER_SOURCE;
    let registrations: [(&'static str, &'static str); 4] = [
        (K_BWD_COMPUTE_D, FLASH_ATTN_TRAIN_BWD_COMPUTE_D_SHADER_SOURCE),
        (K_BWD_D64, main_source),
        (K_BWD_D256, main_source),
        (K_F32_TO_BF16, main_source),
    ];
    for (kernel, source) in registrations.iter().copied() {
        registry.register_source(kernel, source);
    }
}
/// Returns the backward kernel name list (`ALL_BWD_KERNEL_NAMES`).
///
/// Public but hidden from the docs; the name indicates it is meant for tests.
#[doc(hidden)]
pub fn all_bwd_kernel_names_for_test() -> &'static [&'static str] {
ALL_BWD_KERNEL_NAMES
}
/// Parameter block passed by value (as raw bytes) to the compute-D kernel.
///
/// `#[repr(C)]` fixes the field order and layout.
/// NOTE(review): the layout presumably has to match a struct declared in
/// `flash_attn_train_bwd_compute_d.metal` — confirm before reordering fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ComputeDParams {
// Copied from `FlashAttnTrainParams::batch`.
batch: u32,
// Copied from `FlashAttnTrainParams::n_q_heads`.
n_q_heads: u32,
// Copied from `FlashAttnTrainParams::q_seq_len`.
q_seq_len: u32,
// Copied from `FlashAttnTrainParams::head_dim`.
head_dim: u32,
}
/// Encodes the backward "compute D" pre-pass.
///
/// Binds `o_buf` (index 0), `do_buf` (index 1), `d_out_buf` (index 2) and a
/// `ComputeDParams` block (index 3), then dispatches a grid of
/// `(q_seq_len, 1, batch * n_q_heads)` with a 1-D threadgroup whose width is
/// `head_dim` rounded up to a power of two and capped at 256.
///
/// NOTE(review): presumably computes the per-row `D = rowsum(dO * O)` term of
/// the flash-attention backward pass — confirm against the shader.
///
/// # Errors
/// Propagates any error from the pipeline lookup.
fn dispatch_compute_d(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    o_buf: &MlxBuffer,
    do_buf: &MlxBuffer,
    d_out_buf: &MlxBuffer,
    params: &FlashAttnTrainParams,
) -> Result<()> {
    let p = ComputeDParams {
        batch: params.batch,
        n_q_heads: params.n_q_heads,
        q_seq_len: params.q_seq_len,
        head_dim: params.head_dim,
    };
    let pipeline = registry.get_pipeline(K_BWD_COMPUTE_D, device)?;
    let tg_x = std::cmp::min(256, params.head_dim.next_power_of_two()) as u64;
    // Widen to u64 *before* multiplying: the product of two u32 values can
    // exceed u32::MAX, which would panic in debug builds and wrap in release.
    let grid_z = u64::from(params.batch) * u64::from(params.n_q_heads);
    let grid = MTLSize::new(params.q_seq_len as u64, 1, grid_z);
    let tg_size = MTLSize::new(tg_x, 1, 1);
    encoder.encode_threadgroups_with_args(
        pipeline,
        &[
            (0, KernelArg::Buffer(o_buf)),
            (1, KernelArg::Buffer(do_buf)),
            (2, KernelArg::Buffer(d_out_buf)),
            (3, KernelArg::Bytes(as_bytes(&p))),
        ],
        grid,
        tg_size,
    );
    Ok(())
}
fn dispatch_f32_to_bf16(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src: &MlxBuffer,
dst: &MlxBuffer,
n_elems: usize,
) -> Result<()> {
let pipeline = registry.get_pipeline(K_F32_TO_BF16, device)?;
let tg_x = std::cmp::min(256u64, n_elems as u64);
let n_groups = (n_elems as u64).div_ceil(tg_x);
let n_u32 = n_elems as u32;
encoder.encode_threadgroups_with_args(
pipeline,
&[
(0, KernelArg::Buffer(src)),
(1, KernelArg::Buffer(dst)),
(2, KernelArg::Bytes(as_bytes(&n_u32))),
],
MTLSize::new(n_groups, 1, 1),
MTLSize::new(tg_x, 1, 1),
);
Ok(())
}
/// Shared implementation behind the public backward dispatchers.
///
/// After validating `params` and every buffer, encodes five dispatches on
/// `encoder`, separated by memory barriers:
/// 1. the compute-D pre-pass (`o_buf`, `do_buf` -> scratch `d_vec_buf`);
/// 2. the main backward kernel `bwd_kernel_name`, writing F32 scratch
///    gradients for dQ, dK and dV;
/// 3. three F32-to-BF16 casts from the scratch gradients into `dq_buf`,
///    `dk_buf` and `dv_buf`.
///
/// Buffer layouts match the forward pass: Q/O/dO/dQ are sized as
/// `[batch, n_q_heads, q_seq_len, head_dim]`, K/V/dK/dV as
/// `[batch, n_kv_heads, k_seq_len, head_dim]`, `l_buf` as
/// `[batch, n_q_heads, q_seq_len]`.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when `params.head_dim` differs from
/// `head_dim_expected`, when `validate_params` rejects `params`, when a buffer
/// has the wrong dtype or is too small, or when a scratch allocation fails
/// (the allocation error is wrapped as `InvalidArgument`). Errors from pipeline
/// lookups are propagated unchanged.
#[allow(clippy::too_many_arguments)]
fn dispatch_bwd_inner(
encoder: &mut CommandEncoder,
device: &MlxDevice,
registry: &mut KernelRegistry,
q_buf: &MlxBuffer,
k_buf: &MlxBuffer,
v_buf: &MlxBuffer,
o_buf: &MlxBuffer,
l_buf: &MlxBuffer,
do_buf: &MlxBuffer,
mask: Option<&MlxBuffer>,
dq_buf: &mut MlxBuffer,
dk_buf: &mut MlxBuffer,
dv_buf: &mut MlxBuffer,
params: &FlashAttnTrainParams,
bwd_kernel_name: &str,
head_dim_expected: u32,
) -> Result<()> {
// Each kernel variant handles exactly one head dimension.
if params.head_dim != head_dim_expected {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_train_bwd ({bwd_kernel_name}): head_dim must be \
{head_dim_expected}, got {}",
params.head_dim
)));
}
validate_params(params)?;
// Inputs that must be BF16.
for (buf, name) in &[
(q_buf, "Q"),
(k_buf, "K"),
(v_buf, "V"),
(o_buf, "O"),
(do_buf, "dO"),
] {
if buf.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_train_bwd ({bwd_kernel_name}): {name} buffer must be BF16, \
got {:?}",
buf.dtype()
)));
}
}
// The saved forward `L` buffer must be F32.
for (buf, name) in &[(l_buf, "L")] {
if buf.dtype() != DType::F32 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_train_bwd ({bwd_kernel_name}): {name} buffer must be F32, \
got {:?}",
buf.dtype()
)));
}
}
// Gradient outputs must be BF16; they are filled by the final cast dispatches.
for (buf, name) in &[
(dq_buf as &MlxBuffer, "dQ"),
(dk_buf as &MlxBuffer, "dK"),
(dv_buf as &MlxBuffer, "dV"),
] {
if buf.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_train_bwd ({bwd_kernel_name}): {name} output buffer must be \
BF16, got {:?}",
buf.dtype()
)));
}
}
if let Some(m) = mask {
if m.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"flash_attn_train_bwd ({bwd_kernel_name}): mask buffer must be BF16, \
got {:?}",
m.dtype()
)));
}
}
// Element counts below are computed in `usize`.
let batch = params.batch as usize;
let h_q = params.n_q_heads as usize;
let h_kv = params.n_kv_heads as usize;
let ql = params.q_seq_len as usize;
let kl = params.k_seq_len as usize;
let d = params.head_dim as usize;
let q_elems = batch * h_q * ql * d;
let kv_elems = batch * h_kv * kl * d;
let l_elems = batch * h_q * ql;
validate_buffer_size(q_buf, "Q", q_elems)?;
validate_buffer_size(k_buf, "K", kv_elems)?;
validate_buffer_size(v_buf, "V", kv_elems)?;
validate_buffer_size(o_buf, "O", q_elems)?;
validate_buffer_size(l_buf, "L", l_elems)?;
validate_buffer_size(do_buf, "dO", q_elems)?;
validate_buffer_size(dq_buf, "dQ", q_elems)?;
validate_buffer_size(dk_buf, "dK", kv_elems)?;
validate_buffer_size(dv_buf, "dV", kv_elems)?;
if let Some(m) = mask {
validate_buffer_size(m, "mask", batch * h_q * ql * kl)?;
}
// F32 scratch buffers (4 bytes per element), allocated on every call.
// NOTE(review): these locals are dropped when this function returns, which may
// be before the encoded work executes — confirm the encoder/command buffer
// retains them. Also confirm whether the backward kernel needs them zeroed and
// whether `alloc_buffer` guarantees that.
let d_vec_buf = device
.alloc_buffer(l_elems * 4, DType::F32, vec![l_elems])
.map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc D_vec: {e}")))?;
let dq_f32_buf = device
.alloc_buffer(q_elems * 4, DType::F32, vec![q_elems])
.map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc dQ_f32: {e}")))?;
let dk_f32_buf = device
.alloc_buffer(kv_elems * 4, DType::F32, vec![kv_elems])
.map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc dK_f32: {e}")))?;
let dv_f32_buf = device
.alloc_buffer(kv_elems * 4, DType::F32, vec![kv_elems])
.map_err(|e| MlxError::InvalidArgument(format!("flash_attn_train_bwd: alloc dV_f32: {e}")))?;
// Tile bookkeeping, identical to the forward dispatcher.
let nq = params.q_seq_len.div_ceil(BQ);
let nk = params.k_seq_len.div_ceil(BK);
let nq_aligned = params.q_seq_len / BQ;
let nk_aligned = params.k_seq_len / BK;
let ql_rem = params.q_seq_len % BQ;
let kl_rem = params.k_seq_len % BK;
let align_q = ql_rem == 0;
let align_k = kl_rem == 0;
let has_mask = mask.is_some();
let do_causal = params.causal;
// Element strides in [batch, head, sequence] order for the dense layouts.
let q_seq_stride = d as i64;
let q_head_stride = (ql * d) as i64;
let q_batch_stride = (h_q * ql * d) as i64;
let kv_seq_stride = d as i64;
let kv_head_stride = (kl * d) as i64;
let kv_batch_stride = (h_kv * kl * d) as i64;
// Number of query heads per key/value head.
let gqa_factor = (params.n_q_heads / params.n_kv_heads) as i32;
let attn_params = AttnParamsGpu {
b: params.batch as i32,
h: params.n_q_heads as i32,
d: params.head_dim as i32,
ql: params.q_seq_len as i32,
kl: params.k_seq_len as i32,
gqa_factor,
scale: params.scale,
// Always 1.0 on this path, matching the forward dispatcher.
softcapping: 1.0_f32,
nq: nq as i32,
nk: nk as i32,
nq_aligned: nq_aligned as i32,
nk_aligned: nk_aligned as i32,
ql_rem: ql_rem as i32,
kl_rem: kl_rem as i32,
ql_off: 0,
_pad: 0,
q_strides: [q_batch_stride, q_head_stride, q_seq_stride],
k_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
v_strides: [kv_batch_stride, kv_head_stride, kv_seq_stride],
o_strides: [q_batch_stride, q_head_stride, q_seq_stride],
};
// Stage 1: compute-D pre-pass into `d_vec_buf`.
dispatch_compute_d(
encoder, registry, device.metal_device(),
o_buf, do_buf, &d_vec_buf, params,
)?;
// Barrier before the main kernel, which is given `d_vec_buf` at index 6.
encoder.memory_barrier();
// Stage 2: main backward kernel, specialized with the same four bool
// constants (200, 201, 300, 301) as the forward pass.
let bwd_pipeline = registry.get_pipeline_with_bool_constants(
bwd_kernel_name,
device.metal_device(),
&[
(200, align_q),
(201, align_k),
(300, has_mask),
(301, do_causal),
],
)?;
// Grid extents: (query tiles, query heads, batch); threadgroup size (32, WM, WN).
let grid = MTLSize::new(nq as u64, params.n_q_heads as u64, params.batch as u64);
let tg_size = MTLSize::new(32, WM as u64, WN as u64);
encoder.set_op_kind(CapturedOpKind::Sdpa);
// The two branches differ only in the mask bindings (indices 11 and 12).
// Index 3 is not bound in either branch; `o_buf` is only passed to the
// compute-D pre-pass above.
if let Some(mask_buf) = mask {
// Mask strides in [batch, head, query row] order for a dense mask.
let m_batch_stride = (h_q * ql * kl) as i64;
let m_head_stride = (ql * kl) as i64;
let m_ql_stride = kl as i64;
let mask_params = AttnMaskParamsGpu {
m_strides: [m_batch_stride, m_head_stride, m_ql_stride],
};
encoder.encode_threadgroups_with_args(
bwd_pipeline,
&[
(0, KernelArg::Buffer(q_buf)),
(1, KernelArg::Buffer(k_buf)),
(2, KernelArg::Buffer(v_buf)),
(4, KernelArg::Buffer(l_buf)),
(5, KernelArg::Buffer(do_buf)),
(6, KernelArg::Buffer(&d_vec_buf)),
(7, KernelArg::Buffer(&dq_f32_buf)),
(8, KernelArg::Buffer(&dk_f32_buf)),
(9, KernelArg::Buffer(&dv_f32_buf)),
(10, KernelArg::Bytes(as_bytes(&attn_params))),
(11, KernelArg::Bytes(as_bytes(&mask_params))),
(12, KernelArg::Buffer(mask_buf)),
],
grid,
tg_size,
);
} else {
encoder.encode_threadgroups_with_args(
bwd_pipeline,
&[
(0, KernelArg::Buffer(q_buf)),
(1, KernelArg::Buffer(k_buf)),
(2, KernelArg::Buffer(v_buf)),
(4, KernelArg::Buffer(l_buf)),
(5, KernelArg::Buffer(do_buf)),
(6, KernelArg::Buffer(&d_vec_buf)),
(7, KernelArg::Buffer(&dq_f32_buf)),
(8, KernelArg::Buffer(&dk_f32_buf)),
(9, KernelArg::Buffer(&dv_f32_buf)),
(10, KernelArg::Bytes(as_bytes(&attn_params))),
],
grid,
tg_size,
);
}
// Stage 3: cast the F32 scratch gradients into the caller's BF16 outputs.
// A barrier precedes each cast; the three casts read and write distinct buffers.
encoder.memory_barrier();
dispatch_f32_to_bf16(encoder, registry, device.metal_device(), &dq_f32_buf, dq_buf, q_elems)?;
encoder.memory_barrier();
dispatch_f32_to_bf16(encoder, registry, device.metal_device(), &dk_f32_buf, dk_buf, kv_elems)?;
encoder.memory_barrier();
dispatch_f32_to_bf16(encoder, registry, device.metal_device(), &dv_f32_buf, dv_buf, kv_elems)?;
Ok(())
}
/// Encodes the BF16 training backward pass for `head_dim == 64`.
///
/// Forwards every argument to `dispatch_bwd_inner`, selecting the
/// `flash_attn_train_bwd_bf16_d64` kernel. `dq_buf`, `dk_buf` and `dv_buf` are
/// the gradient outputs; `mask` is optional.
///
/// # Errors
/// Returns whatever `dispatch_bwd_inner` returns, including
/// `MlxError::InvalidArgument` when `params.head_dim != 64`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_train_bwd_bf16_d64(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q_buf: &MlxBuffer,
    k_buf: &MlxBuffer,
    v_buf: &MlxBuffer,
    o_buf: &MlxBuffer,
    l_buf: &MlxBuffer,
    do_buf: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    dq_buf: &mut MlxBuffer,
    dk_buf: &mut MlxBuffer,
    dv_buf: &mut MlxBuffer,
    params: &FlashAttnTrainParams,
) -> Result<()> {
    // Head dimension this kernel variant is built for.
    const HEAD_DIM: u32 = 64;
    let kernel = K_BWD_D64;
    dispatch_bwd_inner(
        encoder,
        device,
        registry,
        q_buf,
        k_buf,
        v_buf,
        o_buf,
        l_buf,
        do_buf,
        mask,
        dq_buf,
        dk_buf,
        dv_buf,
        params,
        kernel,
        HEAD_DIM,
    )
}
/// Encodes the BF16 training backward pass for `head_dim == 256`.
///
/// Forwards every argument to `dispatch_bwd_inner`, selecting the
/// `flash_attn_train_bwd_bf16_d256` kernel. `dq_buf`, `dk_buf` and `dv_buf` are
/// the gradient outputs; `mask` is optional.
///
/// # Errors
/// Returns whatever `dispatch_bwd_inner` returns, including
/// `MlxError::InvalidArgument` when `params.head_dim != 256`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_flash_attn_train_bwd_bf16_d256(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    q_buf: &MlxBuffer,
    k_buf: &MlxBuffer,
    v_buf: &MlxBuffer,
    o_buf: &MlxBuffer,
    l_buf: &MlxBuffer,
    do_buf: &MlxBuffer,
    mask: Option<&MlxBuffer>,
    dq_buf: &mut MlxBuffer,
    dk_buf: &mut MlxBuffer,
    dv_buf: &mut MlxBuffer,
    params: &FlashAttnTrainParams,
) -> Result<()> {
    // Head dimension this kernel variant is built for.
    const HEAD_DIM: u32 = 256;
    let kernel = K_BWD_D256;
    dispatch_bwd_inner(
        encoder,
        device,
        registry,
        q_buf,
        k_buf,
        v_buf,
        o_buf,
        l_buf,
        do_buf,
        mask,
        dq_buf,
        dk_buf,
        dv_buf,
        params,
        kernel,
        HEAD_DIM,
    )
}