use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::DType;
/// Text of the shader file `../shaders/flash_attn_prefill_blk.metal` (path
/// relative to this source file), embedded into the binary at compile time.
///
/// [`register`] hands this string to the kernel registry under the name
/// [`K_BLK_BF16`].
pub static FLASH_ATTN_PREFILL_BLK_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_prefill_blk.metal");
/// Kernel name under which the shader source is registered and by which
/// `dispatch_flash_attn_prefill_blk` requests its pipeline.
pub const K_BLK_BF16: &str = "flash_attn_prefill_blk_bf16";
/// Constant index paired with `BlkParams::bq` when the pipeline is requested.
///
/// NOTE(review): presumably a Metal function-constant index that has to equal
/// the one declared in the `.metal` file — confirm against the shader.
pub const FC_IDX_BQ: usize = 400;
/// Constant index paired with `BlkParams::bk` when the pipeline is requested.
///
/// NOTE(review): presumably a Metal function-constant index that has to equal
/// the one declared in the `.metal` file — confirm against the shader.
pub const FC_IDX_BK: usize = 401;
/// Not used by the dispatch code visible here (only asserted in the unit tests).
///
/// NOTE(review): presumably the function-constant index a consumer kernel uses
/// to opt in to reading the block buffer — confirm against its call sites.
pub const FC_IDX_HAS_BLK: usize = 303;
/// Makes the block kernel known to `registry`.
///
/// Associates the kernel name [`K_BLK_BF16`] with the embedded shader text
/// [`FLASH_ATTN_PREFILL_BLK_SHADER_SOURCE`] via `KernelRegistry::register_source`.
pub fn register(registry: &mut KernelRegistry) {
    let (kernel_name, shader_source) = (K_BLK_BF16, FLASH_ATTN_PREFILL_BLK_SHADER_SOURCE);
    registry.register_source(kernel_name, shader_source);
}
/// Parameter block handed to the kernel as raw bytes; the dispatch function
/// binds it at argument index 2.
///
/// `#[repr(C)]` with four `i32` fields yields a 16-byte layout (pinned by a
/// unit test). The `bytemuck` derives mark the struct as plain old data.
///
/// NOTE(review): field order and widths presumably have to mirror a struct
/// declared in the shader — confirm against the `.metal` file before editing.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct BlkParamsGpu {
// `BlkParams::seq_len_q` converted to `i32`.
seq_len_q: i32,
// `BlkParams::seq_len_k` converted to `i32`.
seq_len_k: i32,
// The dispatch function sets this to `seq_len_k`.
mask_row_stride: i32,
// Always written as 0; rounds the struct up to 16 bytes.
_pad: i32,
}
/// Host-side description of one block-kernel launch.
///
/// Every field must be non-zero: the sizing and dispatch functions in this
/// file return `MlxError::InvalidArgument` when any of them is 0.
#[derive(Debug, Clone, Copy)]
pub struct BlkParams {
/// Query-side sequence length. The dispatch function requires the mask to
/// hold at least `seq_len_q * seq_len_k` two-byte (BF16) elements.
pub seq_len_q: u32,
/// Key-side sequence length; also passed to the kernel as the mask row stride.
pub seq_len_k: u32,
/// Tile size along the query axis; the launch grid is
/// `ceil(seq_len_q / bq)` threadgroups tall.
pub bq: u32,
/// Tile size along the key axis; the launch grid is
/// `ceil(seq_len_k / bk)` threadgroups wide.
pub bk: u32,
}
/// Returns the byte length to allocate for the block buffer described by
/// `params`.
///
/// The buffer holds one byte per tile, i.e.
/// `ceil(seq_len_q / bq) * ceil(seq_len_k / bk)` bytes, rounded up to the next
/// multiple of 32. Because both tile counts are at least 1, the result is
/// never smaller than 32.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when
/// * either sequence length is 0,
/// * either tile dimension is 0,
/// * the tile-count product does not fit in `usize`, or
/// * rounding that product up to a multiple of 32 does not fit in `usize`.
pub fn blk_buffer_byte_len(params: &BlkParams) -> Result<usize> {
    if params.seq_len_q == 0 || params.seq_len_k == 0 {
        return Err(MlxError::InvalidArgument(
            "blk_buffer_byte_len: seq lengths must be > 0".into(),
        ));
    }
    if params.bq == 0 || params.bk == 0 {
        return Err(MlxError::InvalidArgument(
            "blk_buffer_byte_len: tile dimensions (bq, bk) must be > 0".into(),
        ));
    }

    let nq = params.seq_len_q.div_ceil(params.bq) as usize;
    let nk = params.seq_len_k.div_ceil(params.bk) as usize;
    let raw = nq.checked_mul(nk).ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "blk_buffer_byte_len: nq ({}) * nk ({}) overflows usize",
            nq, nk,
        ))
    })?;

    // The product can pass `checked_mul` yet sit within 31 of `usize::MAX`
    // (reachable when `usize` is 32 bits wide), so the round-up has to be
    // checked as well instead of using a bare `(raw + 31) & !31`.
    // `raw >= 1` here, hence the rounded value is already >= 32.
    raw.checked_next_multiple_of(32).ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "blk_buffer_byte_len: aligning {raw} bytes to 32 overflows usize"
        ))
    })
}
/// Allocates the block buffer for `params` on `device`.
///
/// The requested allocation is `blk_buffer_byte_len(params)` bytes, typed
/// `DType::U8`, with shape `[ceil(seq_len_q / bq), ceil(seq_len_k / bk)]`.
///
/// # Errors
///
/// Propagates any error from `blk_buffer_byte_len` (zero fields, overflow) and
/// from `MlxDevice::alloc_buffer`.
pub fn alloc_blk_buffer(device: &MlxDevice, params: &BlkParams) -> Result<MlxBuffer> {
    // Size first: this rejects zero tile dimensions before any division below.
    let padded_len = blk_buffer_byte_len(params)?;

    let tile_count = |len: u32, tile: u32| len.div_ceil(tile) as usize;
    let shape = vec![
        tile_count(params.seq_len_q, params.bq),
        tile_count(params.seq_len_k, params.bk),
    ];

    device.alloc_buffer(padded_len, DType::U8, shape)
}
/// Checks that no field of `params` is zero.
///
/// Fields are examined in the order `seq_len_q`, `seq_len_k`, `bq`, `bk`; the
/// first zero found determines the error message.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` naming the first zero field.
fn validate_params(params: &BlkParams) -> Result<()> {
    let checks: [(u32, &str); 4] = [
        (
            params.seq_len_q,
            "dispatch_flash_attn_prefill_blk: seq_len_q must be > 0",
        ),
        (
            params.seq_len_k,
            "dispatch_flash_attn_prefill_blk: seq_len_k must be > 0",
        ),
        (
            params.bq,
            "dispatch_flash_attn_prefill_blk: bq must be > 0",
        ),
        (
            params.bk,
            "dispatch_flash_attn_prefill_blk: bk must be > 0",
        ),
    ];

    match checks.iter().find(|check| check.0 == 0) {
        Some(&(_, message)) => Err(MlxError::InvalidArgument(message.into())),
        None => Ok(()),
    }
}
/// Validates the arguments and encodes one launch of the
/// `flash_attn_prefill_blk_bf16` kernel on `encoder`.
///
/// Bindings: `mask` at index 0, `blk_out` at index 1, and a 16-byte
/// `BlkParamsGpu` at index 2. The pipeline is requested from `registry` with
/// `bq` and `bk` supplied at constant indices [`FC_IDX_BQ`] and [`FC_IDX_BK`].
/// The launch grid is `ceil(seq_len_k / bk)` by `ceil(seq_len_q / bq)`
/// threadgroups of 32 threads each.
///
/// NOTE(review): presumably the kernel writes one byte per `bq x bk` mask tile
/// into `blk_out` — confirm against the shader.
///
/// # Arguments
///
/// * `mask` - must have dtype `DType::BF16` and hold at least
///   `seq_len_q * seq_len_k * 2` bytes.
/// * `blk_out` - must hold at least `ceil(seq_len_q / bq) * ceil(seq_len_k / bk)`
///   bytes.
/// * `params` - all fields must be non-zero and no larger than `i32::MAX`.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when a field of `params` is zero or
/// exceeds `i32::MAX`, when `mask` has the wrong dtype, when either buffer is
/// too small, or when a size computation overflows `usize`. Errors from the
/// pipeline lookup are propagated unchanged.
pub fn dispatch_flash_attn_prefill_blk(
    encoder: &mut CommandEncoder,
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    mask: &MlxBuffer,
    blk_out: &MlxBuffer,
    params: &BlkParams,
) -> Result<()> {
    // The kernel-side parameters are `i32`. A plain `as` cast would silently
    // turn values above `i32::MAX` into negative numbers, so convert with a
    // range check and surface the problem to the caller instead.
    fn to_gpu_i32(field: &str, value: u32) -> Result<i32> {
        i32::try_from(value).map_err(|_| {
            MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_blk: {field} ({value}) exceeds i32::MAX"
            ))
        })
    }

    validate_params(params)?;

    let seq_len_q = to_gpu_i32("seq_len_q", params.seq_len_q)?;
    let seq_len_k = to_gpu_i32("seq_len_k", params.seq_len_k)?;
    let bq = to_gpu_i32("bq", params.bq)?;
    let bk = to_gpu_i32("bk", params.bk)?;

    if mask.dtype() != DType::BF16 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_blk: mask buffer must be BF16, got {:?}",
            mask.dtype()
        )));
    }

    // Mask footprint: seq_len_q * seq_len_k elements, two bytes each.
    let ql = params.seq_len_q as usize;
    let kl = params.seq_len_k as usize;
    let mask_bytes_needed = ql
        .checked_mul(kl)
        .and_then(|n| n.checked_mul(2))
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "dispatch_flash_attn_prefill_blk: qL ({}) * kL ({}) * 2 overflows usize",
                ql, kl,
            ))
        })?;
    if mask.byte_len() < mask_bytes_needed {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_blk: mask buffer too small: \
             expected at least {mask_bytes_needed} bytes, got {}",
            mask.byte_len()
        )));
    }

    // Output footprint: one byte per tile.
    let nq = params.seq_len_q.div_ceil(params.bq) as usize;
    let nk = params.seq_len_k.div_ceil(params.bk) as usize;
    let blk_bytes_needed = nq.checked_mul(nk).ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_blk: nq ({}) * nk ({}) overflows usize",
            nq, nk,
        ))
    })?;
    if blk_out.byte_len() < blk_bytes_needed {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_flash_attn_prefill_blk: blk_out buffer too small: \
             expected at least {blk_bytes_needed} bytes, got {}",
            blk_out.byte_len()
        )));
    }

    let gpu_params = BlkParamsGpu {
        seq_len_q,
        seq_len_k,
        // Rows are laid out back to back, so the stride equals the row length.
        mask_row_stride: seq_len_k,
        _pad: 0,
    };

    let pipeline = registry.get_pipeline_with_constants(
        K_BLK_BF16,
        device.metal_device(),
        &[],
        &[(FC_IDX_BQ, bq), (FC_IDX_BK, bk)],
    )?;

    // One threadgroup per tile: key tiles along x, query tiles along y.
    let grid = MTLSize::new(nk as u64, nq as u64, 1);
    let tg_size = MTLSize::new(32, 1, 1);
    encoder.encode_threadgroups_with_args(
        pipeline,
        &[
            (0, KernelArg::Buffer(mask)),
            (1, KernelArg::Buffer(blk_out)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg_size,
    );
    Ok(())
}
// Unit tests for the host-side helpers. Nothing here touches a GPU device.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Shorthand so each case reads as `(seq_len_q, seq_len_k, bq, bk)`.
    fn blk(seq_len_q: u32, seq_len_k: u32, bq: u32, bk: u32) -> BlkParams {
        BlkParams {
            seq_len_q,
            seq_len_k,
            bq,
            bk,
        }
    }

    /// The byte-level parameter block must stay exactly four `i32`s wide.
    #[test]
    fn test_blk_params_gpu_size() {
        assert_eq!(std::mem::size_of::<BlkParamsGpu>(), 16);
    }

    /// Pins the constant indices so an accidental edit is caught in review.
    #[test]
    fn test_fc_indices_match_shader() {
        assert_eq!(FC_IDX_BQ, 400);
        assert_eq!(FC_IDX_BK, 401);
        assert_eq!(FC_IDX_HAS_BLK, 303);
    }

    /// 2455 x 2455 with 32 x 16 tiles: 77 * 154 = 11858 tiles.
    #[test]
    fn test_blk_buffer_byte_len_d256_gemma4() {
        let bytes = blk_buffer_byte_len(&blk(2455, 2455, 32, 16)).unwrap();
        assert!(bytes >= 11858, "must cover all 11858 tiles, got {bytes}");
        assert_eq!(bytes % 32, 0, "must be 32-byte aligned");
    }

    /// 2455 x 2455 with 8 x 8 tiles: 307 * 307 = 94249 tiles.
    #[test]
    fn test_blk_buffer_byte_len_d512_gemma4() {
        let bytes = blk_buffer_byte_len(&blk(2455, 2455, 8, 8)).unwrap();
        assert!(bytes >= 94249, "must cover all 94249 tiles, got {bytes}");
        assert_eq!(bytes % 32, 0, "must be 32-byte aligned");
    }

    /// A single tile still yields a full 32-byte allocation.
    #[test]
    fn test_blk_buffer_byte_len_minimum() {
        assert_eq!(blk_buffer_byte_len(&blk(1, 1, 32, 16)).unwrap(), 32);
    }

    /// Exact values on both sides of a 32-byte boundary.
    #[test]
    fn test_blk_buffer_byte_len_alignment_boundaries() {
        assert_eq!(blk_buffer_byte_len(&blk(32, 1, 1, 1)).unwrap(), 32);
        assert_eq!(blk_buffer_byte_len(&blk(33, 1, 1, 1)).unwrap(), 64);
        assert_eq!(blk_buffer_byte_len(&blk(2455, 2455, 32, 16)).unwrap(), 11872);
    }

    /// Every field is individually rejected when zero, including `bk`.
    #[test]
    fn test_blk_buffer_byte_len_zero_rejected() {
        assert!(blk_buffer_byte_len(&blk(0, 8, 32, 16)).is_err());
        assert!(blk_buffer_byte_len(&blk(8, 0, 32, 16)).is_err());
        assert!(blk_buffer_byte_len(&blk(8, 8, 0, 16)).is_err());
        assert!(blk_buffer_byte_len(&blk(8, 8, 32, 0)).is_err());
    }

    /// `validate_params` rejects a zero in any field and accepts otherwise.
    #[test]
    fn test_validate_params_zero_rejected() {
        assert!(validate_params(&blk(0, 8, 32, 16)).is_err());
        assert!(validate_params(&blk(8, 0, 32, 16)).is_err());
        assert!(validate_params(&blk(8, 8, 0, 16)).is_err());
        assert!(validate_params(&blk(8, 8, 32, 0)).is_err());
        assert!(validate_params(&blk(8, 8, 32, 16)).is_ok());
    }

    /// The kernel name is the lookup key shared with the registry.
    #[test]
    fn test_kernel_name_stable() {
        assert_eq!(K_BLK_BF16, "flash_attn_prefill_blk_bf16");
    }

    /// Registering the shader source on a fresh registry must not panic.
    #[test]
    fn test_register_does_not_panic() {
        let mut registry = KernelRegistry::new();
        register(&mut registry);
    }
}