use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::encoder::{CapturedOpKind, CommandEncoder};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::DType;
/// Text of `../shaders/sdpa.metal` (path relative to this source file),
/// embedded into the binary at compile time by `include_str!`.
pub static SDPA_SHADER_SOURCE: &str = include_str!("../shaders/sdpa.metal");
/// Registers [`SDPA_SHADER_SOURCE`] with `registry` under the key `"sdpa"`.
///
/// NOTE(review): presumably this has to run before [`sdpa`] asks the registry
/// for the `"sdpa"` / `"sdpa_bf16"` pipelines — confirm against
/// `KernelRegistry`.
pub fn register(registry: &mut KernelRegistry) {
registry.register_source("sdpa", SDPA_SHADER_SOURCE);
}
/// Host-side parameters for one [`sdpa`] dispatch.
///
/// `validate_params` rejects a zero `head_dim`, `n_heads`, `n_kv_heads`,
/// `seq_len` or `kv_seq_len`, and an `n_heads` that is not a multiple of
/// `n_kv_heads`.
#[derive(Debug, Clone, Copy)]
pub struct SdpaParams {
/// Head count used to size the Q and output buffers and the grid's second
/// dimension. Must be non-zero and divisible by `n_kv_heads`.
pub n_heads: u32,
/// Head count used to size the K and V buffers. Must be non-zero.
pub n_kv_heads: u32,
/// Number of elements per head per position. Must be non-zero.
pub head_dim: u32,
/// Sequence length used to size the Q and output buffers and to derive the
/// number of tiles dispatched. Must be non-zero.
pub seq_len: u32,
/// Key/value sequence length forwarded to the kernel. Must be non-zero.
/// Also used as the capacity when `kv_capacity` is 0.
pub kv_seq_len: u32,
/// Forwarded to the kernel unchanged.
/// NOTE(review): presumably the factor applied to the attention scores
/// (the tests use `1 / sqrt(head_dim)`) — confirm against the shader.
pub scale: f32,
/// Per-head position count used to size the K and V buffers; 0 means
/// "use `kv_seq_len`".
/// NOTE(review): presumably the allocated KV-cache capacity, i.e. the
/// per-head stride in positions — confirm against the shader.
pub kv_capacity: u32,
}
/// `#[repr(C)]` plain-old-data copy of [`SdpaParams`] whose raw bytes are
/// written into the parameter buffer bound at index 4 by [`sdpa`].
///
/// Seven 4-byte fields, 28 bytes in total (asserted by the
/// `test_gpu_params_layout` test). Unlike [`SdpaParams`], `kv_capacity` here
/// holds the resolved value: [`sdpa`] substitutes `kv_seq_len` when the caller
/// passed 0.
///
/// NOTE(review): field order and types presumably have to match the parameter
/// struct declared in `sdpa.metal` — confirm before reordering or adding fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct SdpaParamsGpu {
n_heads: u32,
n_kv_heads: u32,
head_dim: u32,
seq_len: u32,
kv_seq_len: u32,
scale: f32,
kv_capacity: u32,
}
/// Tile width along the query sequence: [`sdpa`] dispatches
/// `ceil(seq_len / TILE_Q)` tiles per head and uses `TILE_Q` threads per
/// threadgroup.
/// NOTE(review): presumably has to equal the tile size hard-coded in
/// `sdpa.metal` — confirm before changing.
const TILE_Q: u32 = 32;
/// Checks the shape-related fields of `params` before any GPU work is encoded.
///
/// Checks run in a fixed order and the first failure wins:
/// `head_dim`, `n_heads`, `n_kv_heads` non-zero, then `n_heads` divisible by
/// `n_kv_heads`, then `seq_len` and `kv_seq_len` non-zero.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] describing the first violated rule.
fn validate_params(params: &SdpaParams) -> Result<()> {
    // Shared guard for the "must be > 0" rules; `message` is the full error text.
    let require_nonzero = |value: u32, message: &str| -> Result<()> {
        if value == 0 {
            Err(MlxError::InvalidArgument(message.into()))
        } else {
            Ok(())
        }
    };

    require_nonzero(params.head_dim, "head_dim must be > 0")?;
    require_nonzero(params.n_heads, "n_heads must be > 0")?;
    require_nonzero(params.n_kv_heads, "n_kv_heads must be > 0")?;

    // Safe to take the remainder: `n_kv_heads` was just proven non-zero.
    let remainder = params.n_heads % params.n_kv_heads;
    if remainder != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "n_heads ({}) must be divisible by n_kv_heads ({})",
            params.n_heads, params.n_kv_heads
        )));
    }

    require_nonzero(params.seq_len, "seq_len must be > 0")?;
    require_nonzero(params.kv_seq_len, "kv_seq_len must be > 0")
}
/// Checks that `buf` is large enough to hold `expected_elements` elements of
/// its own dtype.
///
/// The required size is `expected_elements * buf.dtype().size_of()` bytes.
/// The product is computed with overflow checking so that an absurdly large
/// element count is reported as an error instead of wrapping to a small
/// number and letting an undersized buffer through.
///
/// `name` is only used to label the buffer in error messages.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] if the required byte count does not
/// fit in `usize`, or if `buf.byte_len()` is smaller than the required size.
fn validate_buffer(buf: &MlxBuffer, name: &str, expected_elements: usize) -> Result<()> {
    let element_size = buf.dtype().size_of();
    let expected_bytes = expected_elements
        .checked_mul(element_size)
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "{name} buffer size overflows usize: {expected_elements} elements of {element_size} bytes each"
            ))
        })?;

    if buf.byte_len() < expected_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "{name} buffer too small: expected at least {expected_bytes} bytes, got {}",
            buf.byte_len()
        )));
    }
    Ok(())
}
pub fn sdpa(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
q: &MlxBuffer,
k: &MlxBuffer,
v: &MlxBuffer,
output: &MlxBuffer,
params: &SdpaParams,
batch_size: u32,
) -> Result<()> {
validate_params(params)?;
let kv_cap = if params.kv_capacity == 0 { params.kv_seq_len } else { params.kv_capacity };
let q_elements = batch_size as usize
* params.n_heads as usize
* params.seq_len as usize
* params.head_dim as usize;
let kv_elements = batch_size as usize
* params.n_kv_heads as usize
* kv_cap as usize
* params.head_dim as usize;
validate_buffer(q, "Q", q_elements)?;
validate_buffer(k, "K", kv_elements)?;
validate_buffer(v, "V", kv_elements)?;
validate_buffer(output, "output", q_elements)?;
let params_gpu = SdpaParamsGpu {
n_heads: params.n_heads,
n_kv_heads: params.n_kv_heads,
head_dim: params.head_dim,
seq_len: params.seq_len,
kv_seq_len: params.kv_seq_len,
scale: params.scale,
kv_capacity: kv_cap,
};
let params_bytes = bytemuck::bytes_of(¶ms_gpu);
let mut params_buf = device.alloc_buffer(
params_bytes.len(),
DType::U8,
vec![params_bytes.len()],
)?;
{
let dst: &mut [u8] = params_buf.as_mut_slice()?;
dst[..params_bytes.len()].copy_from_slice(params_bytes);
}
let kernel_name = if q.dtype() == DType::BF16 { "sdpa_bf16" } else { "sdpa" };
let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
let n_tiles = (params.seq_len + TILE_Q - 1) / TILE_Q;
let threadgroups = MTLSize::new(
batch_size as u64,
params.n_heads as u64,
n_tiles as u64,
);
let threadgroup_size = MTLSize::new(TILE_Q as u64, 1, 1);
encoder.set_op_kind(CapturedOpKind::Sdpa);
encoder.encode_threadgroups(
pipeline,
&[
(0, q),
(1, k),
(2, v),
(3, output),
(4, ¶ms_buf),
],
threadgroups,
threadgroup_size,
);
Ok(())
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Parameter set that passes validation; each test overrides only the
    /// field it is exercising.
    fn baseline() -> SdpaParams {
        SdpaParams {
            n_heads: 16,
            n_kv_heads: 8,
            head_dim: 256,
            seq_len: 128,
            kv_seq_len: 128,
            scale: 1.0,
            kv_capacity: 128,
        }
    }

    /// True when validation fails specifically with `InvalidArgument`.
    fn is_invalid_argument(params: &SdpaParams) -> bool {
        matches!(validate_params(params), Err(MlxError::InvalidArgument(_)))
    }

    #[test]
    fn test_validate_params_ok() {
        let p = SdpaParams {
            scale: 1.0 / (256.0_f32).sqrt(),
            ..baseline()
        };
        assert!(validate_params(&p).is_ok());
    }

    #[test]
    fn test_validate_params_zero_head_dim() {
        let p = SdpaParams {
            head_dim: 0,
            ..baseline()
        };
        assert!(is_invalid_argument(&p));
    }

    #[test]
    fn test_validate_params_bad_ratio() {
        // 16 query heads cannot be split evenly across 7 KV heads.
        let p = SdpaParams {
            n_kv_heads: 7,
            ..baseline()
        };
        assert!(is_invalid_argument(&p));
    }

    #[test]
    fn test_gpu_params_layout() {
        // Seven 4-byte fields with no padding.
        let gpu_struct_bytes = std::mem::size_of::<SdpaParamsGpu>();
        assert_eq!(gpu_struct_bytes, 28);
    }
}