use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::DType;
/// Metal source text of the prefill mask-fill shader, embedded into the binary
/// at compile time from `../shaders/flash_attn_prefill_mask.metal` (the path is
/// resolved relative to this source file by `include_str!`).
pub static FLASH_ATTN_PREFILL_MASK_SHADER_SOURCE: &str =
include_str!("../shaders/flash_attn_prefill_mask.metal");
/// Key under which the BF16 mask-fill kernel is registered by [`register`] and
/// later looked up when the pipeline is requested.
///
/// NOTE(review): presumably this is also the kernel function name inside the
/// Metal source — confirm against the shader file.
pub const K_FILL_BF16: &str = "flash_attn_prefill_mask_fill_bf16";
/// Registers the mask-fill shader source with `registry` under [`K_FILL_BF16`].
///
/// This only hands the source text to the registry; what the registry does
/// with it (and when) is decided by `KernelRegistry::register_source`.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = K_FILL_BF16;
    let shader_source = FLASH_ATTN_PREFILL_MASK_SHADER_SOURCE;
    registry.register_source(kernel_name, shader_source);
}
/// Host-side description of the attention mask produced by
/// `build_sdpa_mask_bf16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SdpaMaskParams {
/// Number of mask rows (first dimension of the allocated buffer's shape).
/// Must be non-zero; `build_sdpa_mask_bf16` rejects 0.
pub seq_len_q: u32,
/// Number of mask columns (second dimension of the allocated buffer's shape).
/// Must be non-zero; `build_sdpa_mask_bf16` rejects 0.
pub seq_len_k: u32,
/// Sliding-window size. `None` is sent to the kernel as `n_swa = -1`;
/// `Some(0)` is rejected; `Some(w)` is sent as `w`, clamped to `i32::MAX`.
pub window_size: Option<u32>,
/// Causal flag, sent to the kernel as `1` (true) or `0` (false).
pub causal: bool,
/// Forwarded unchanged to the kernel.
/// NOTE(review): presumably the absolute position of the first query row —
/// confirm against the shader.
pub q_abs_offset: u32,
}
/// Parameter block handed to the mask-fill kernel as raw bytes (via
/// `as_bytes`) at argument index 1.
///
/// `#[repr(C)]` with four 32-bit fields; the unit tests pin its size to
/// 16 bytes.
/// NOTE(review): the field order presumably has to match the parameter struct
/// declared in the Metal shader — confirm before reordering.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct MaskFillParamsGpu {
// Copied from `SdpaMaskParams::seq_len_k`.
seq_len_k: u32,
// Copied from `SdpaMaskParams::q_abs_offset`.
q_abs_offset: u32,
// Window size as a signed value; -1 encodes "no window" (`window_size == None`).
n_swa: i32,
// 1 when the mask is causal, 0 otherwise.
causal: u32,
}
/// Allocates a `[seq_len_q, seq_len_k]` BF16 buffer and encodes the mask-fill
/// kernel that populates it.
///
/// The kernel is dispatched with one threadgroup per query row
/// (`seq_len_q` threadgroups). Each threadgroup is `seq_len_k` rounded up to a
/// power of two and clamped to the range `[32, 256]` threads wide.
///
/// This function only *encodes* the dispatch on `encoder`; when the work is
/// actually executed depends on how the caller commits the encoder.
///
/// # Parameters
///
/// * `device` – used to allocate the mask buffer and to build the pipeline.
/// * `registry` – kernel registry the pipeline for [`K_FILL_BF16`] is fetched from.
/// * `encoder` – command encoder that receives the dispatch.
/// * `params` – mask description; see [`SdpaMaskParams`].
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when
/// * `params.seq_len_q` or `params.seq_len_k` is zero,
/// * `params.window_size` is `Some(0)`,
/// * the mask size in bytes (`seq_len_q * seq_len_k * 2`) does not fit in `usize`.
///
/// Errors from `device.alloc_buffer` and `registry.get_pipeline` are
/// propagated unchanged.
pub fn build_sdpa_mask_bf16(
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    encoder: &mut CommandEncoder,
    params: &SdpaMaskParams,
) -> Result<MlxBuffer> {
    /// Smallest threadgroup width used for the fill kernel.
    const MIN_TG_WIDTH: u32 = 32;
    /// Largest threadgroup width used for the fill kernel.
    const MAX_TG_WIDTH: u32 = 256;
    /// Size of one BF16 element in bytes.
    const BF16_BYTES: usize = 2;

    if params.seq_len_q == 0 {
        return Err(MlxError::InvalidArgument(
            "build_sdpa_mask_bf16: seq_len_q must be > 0".into(),
        ));
    }
    if params.seq_len_k == 0 {
        return Err(MlxError::InvalidArgument(
            "build_sdpa_mask_bf16: seq_len_k must be > 0".into(),
        ));
    }
    if params.window_size == Some(0) {
        return Err(MlxError::InvalidArgument(
            "build_sdpa_mask_bf16: window_size=Some(0) is not allowed \
             (llama.cpp treats n_swa=0 as undefined; pass None for \
             no-window / causal-only)".into(),
        ));
    }

    // u32 -> usize is lossless on 32- and 64-bit targets. Doing the whole size
    // computation with checked usize arithmetic avoids the silent truncation a
    // `u64 as usize` cast would cause where usize is narrower than 64 bits.
    let rows = params.seq_len_q as usize;
    let cols = params.seq_len_k as usize;
    let byte_len = rows
        .checked_mul(cols)
        .and_then(|elems| elems.checked_mul(BF16_BYTES))
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "build_sdpa_mask_bf16: mask size ({} elems × 2 B) overflows usize",
                // The product of two u32 values always fits in u64.
                u64::from(params.seq_len_q) * u64::from(params.seq_len_k)
            ))
        })?;

    let mask = device.alloc_buffer(byte_len, DType::BF16, vec![rows, cols])?;

    let fill_params = MaskFillParamsGpu {
        seq_len_k: params.seq_len_k,
        q_abs_offset: params.q_abs_offset,
        n_swa: match params.window_size {
            // -1 tells the kernel that no sliding window is applied.
            None => -1,
            // Clamp first so the cast to i32 can never wrap to a negative value.
            Some(w) => w.min(i32::MAX as u32) as i32,
        },
        causal: u32::from(params.causal),
    };

    let pipeline = registry.get_pipeline(K_FILL_BF16, device.metal_device())?;

    // Clamp to the maximum width *before* rounding up: `next_power_of_two`
    // overflows u32 for inputs above 2^31 (panic in debug, 0 in release).
    // For every input that does not overflow, the result is unchanged.
    let tg_x = params
        .seq_len_k
        .min(MAX_TG_WIDTH)
        .next_power_of_two()
        .max(MIN_TG_WIDTH);

    // One threadgroup per query row.
    let threadgroups = MTLSize::new(u64::from(params.seq_len_q), 1, 1);
    let tg_size = MTLSize::new(u64::from(tg_x), 1, 1);

    encoder.encode_threadgroups_with_args(
        pipeline,
        &[
            (0, KernelArg::Buffer(&mask)),
            (1, KernelArg::Bytes(as_bytes(&fill_params))),
        ],
        threadgroups,
        tg_size,
    );

    Ok(mask)
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Builds the GPU parameter block shared by the encoding tests; only the
    /// window value differs between them.
    fn causal_gpu_params(n_swa: i32) -> MaskFillParamsGpu {
        MaskFillParamsGpu {
            seq_len_k: 2048,
            q_abs_offset: 0,
            n_swa,
            causal: 1,
        }
    }

    /// The parameter block is four 32-bit fields, i.e. 16 bytes.
    #[test]
    fn test_mask_fill_params_gpu_size() {
        let size_in_bytes = std::mem::size_of::<MaskFillParamsGpu>();
        assert_eq!(size_in_bytes, 16);
    }

    /// A global (window-less) mask carries `n_swa = -1`.
    #[test]
    fn test_mask_fill_params_encoding_global() {
        let MaskFillParamsGpu { n_swa, causal, .. } = causal_gpu_params(-1);
        assert_eq!(n_swa, -1, "global mask encodes n_swa=-1");
        assert_eq!(causal, 1);
    }

    /// A sliding-window mask carries a positive `n_swa`.
    #[test]
    fn test_mask_fill_params_encoding_sliding() {
        let MaskFillParamsGpu { n_swa, .. } = causal_gpu_params(1024);
        assert_eq!(n_swa, 1024, "sliding mask encodes n_swa>0");
    }

    /// NOTE(review): this only checks that the field holds 0; it does not call
    /// `build_sdpa_mask_bf16`, so the rejection path itself is not exercised.
    #[test]
    fn test_reject_zero_seq_len_q() {
        let SdpaMaskParams { seq_len_q, .. } = SdpaMaskParams {
            seq_len_q: 0,
            seq_len_k: 8,
            window_size: None,
            causal: true,
            q_abs_offset: 0,
        };
        assert_eq!(seq_len_q, 0);
    }

    /// Runs `register` on a fresh registry and pins the kernel name constant.
    /// NOTE(review): the registry contents are not inspected afterwards.
    #[test]
    fn test_register_adds_kernel_name() {
        let mut kernels = KernelRegistry::new();
        register(&mut kernels);
        assert_eq!(K_FILL_BF16, "flash_attn_prefill_mask_fill_bf16");
    }
}