Skip to main content

mlx_native/ops/
sdpa.rs

//! Scaled dot-product attention (SDPA) host dispatch.
//!
//! Computes `softmax(scale * Q * K^T) * V` on the GPU using a fused Metal
//! compute kernel with causal masking.  `scale` comes from
//! `SdpaParams::scale` and is typically `1.0 / sqrt(head_dim)`.
//!
//! Supports grouped-query attention (GQA), where `n_heads` is a multiple of
//! `n_kv_heads`.
7
8use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::device::MlxDevice;
12use crate::encoder::{CapturedOpKind, CommandEncoder};
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15use crate::DType;
16
17/// MSL source for the SDPA kernel (embedded at compile time).
18pub static SDPA_SHADER_SOURCE: &str = include_str!("../shaders/sdpa.metal");
19
20/// Register SDPA shader source with the given kernel registry.
21///
22/// This must be called before dispatching any SDPA operations.
23pub fn register(registry: &mut KernelRegistry) {
24    registry.register_source("sdpa", SDPA_SHADER_SOURCE);
25}
26
/// Shape and head configuration for a single SDPA dispatch.
///
/// Every count is in elements or positions, never bytes.  `sdpa` validates
/// these values before it encodes any GPU work.
#[derive(Clone, Copy, Debug)]
pub struct SdpaParams {
    /// Query head count (16 for Gemma 4, for example).
    pub n_heads: u32,
    /// Key/value head count.  It can be smaller than `n_heads` under
    /// grouped-query attention (GQA).
    pub n_kv_heads: u32,
    /// Per-head feature dimension.
    pub head_dim: u32,
    /// Number of query positions.
    pub seq_len: u32,
    /// Number of key/value positions.  It can differ from `seq_len`, e.g. in decode mode.
    pub kv_seq_len: u32,
    /// Multiplier applied to the attention scores.  Usually
    /// `1.0 / sqrt(head_dim)`; models that normalise Q and K (such as Gemma 4)
    /// need `scale = 1.0`.
    pub scale: f32,
    /// Stride, in positions, between consecutive KV heads in the K/V buffers.
    ///
    /// - Pre-allocated cache larger than `kv_seq_len`: set this to the cache
    ///   capacity, so the kernel reads the correct memory offsets.
    /// - Tightly packed buffers: set it equal to `kv_seq_len`.
    ///
    /// A value of 0 is treated as `kv_seq_len` for backwards compatibility.
    pub kv_capacity: u32,
}
54
55/// GPU-side parameter struct layout.  Must match the MSL `SdpaParams` struct
56/// exactly (6 × u32 + 1 × f32 = 28 bytes, no padding).
57#[repr(C)]
58#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
59struct SdpaParamsGpu {
60    n_heads: u32,
61    n_kv_heads: u32,
62    head_dim: u32,
63    seq_len: u32,
64    kv_seq_len: u32,
65    scale: f32,
66    kv_capacity: u32,
67}
68
/// Query positions handled by one threadgroup; also the threadgroup width
/// used at dispatch.  Keep in sync with `TILE_Q` in the MSL shader.
const TILE_Q: u32 = 32;
72
73/// Validate SDPA parameters and return a descriptive error if invalid.
74fn validate_params(params: &SdpaParams) -> Result<()> {
75    if params.head_dim == 0 {
76        return Err(MlxError::InvalidArgument(
77            "head_dim must be > 0".into(),
78        ));
79    }
80    if params.n_heads == 0 {
81        return Err(MlxError::InvalidArgument(
82            "n_heads must be > 0".into(),
83        ));
84    }
85    if params.n_kv_heads == 0 {
86        return Err(MlxError::InvalidArgument(
87            "n_kv_heads must be > 0".into(),
88        ));
89    }
90    if params.n_heads % params.n_kv_heads != 0 {
91        return Err(MlxError::InvalidArgument(format!(
92            "n_heads ({}) must be divisible by n_kv_heads ({})",
93            params.n_heads, params.n_kv_heads
94        )));
95    }
96    if params.seq_len == 0 {
97        return Err(MlxError::InvalidArgument(
98            "seq_len must be > 0".into(),
99        ));
100    }
101    if params.kv_seq_len == 0 {
102        return Err(MlxError::InvalidArgument(
103            "kv_seq_len must be > 0".into(),
104        ));
105    }
106    Ok(())
107}
108
109/// Validate that a buffer has the expected byte length for the given shape.
110fn validate_buffer(buf: &MlxBuffer, name: &str, expected_elements: usize) -> Result<()> {
111    let expected_bytes = expected_elements * buf.dtype().size_of();
112    if buf.byte_len() < expected_bytes {
113        return Err(MlxError::InvalidArgument(format!(
114            "{name} buffer too small: expected at least {expected_bytes} bytes, got {}",
115            buf.byte_len()
116        )));
117    }
118    Ok(())
119}
120
121/// Dispatch scaled dot-product attention on the GPU.
122///
123/// Encodes a compute command into the provided `CommandEncoder` without
124/// committing.  The caller controls when to call `encoder.commit_and_wait()`.
125///
126/// # Arguments
127///
128/// * `encoder`  — Command encoder to record the dispatch into.
129/// * `registry` — Kernel registry for pipeline lookup/compilation.
130/// * `device`   — Metal device (needed for pipeline compilation and buffer allocation).
131/// * `q`        — Query buffer, shape `[batch, n_heads, seq_len, head_dim]`, dtype F32.
132/// * `k`        — Key buffer, shape `[batch, n_kv_heads, kv_seq_len, head_dim]`, dtype F32.
133/// * `v`        — Value buffer, shape `[batch, n_kv_heads, kv_seq_len, head_dim]`, dtype F32.
134/// * `output`   — Output buffer, same shape as Q, pre-allocated by caller.
135/// * `params`   — Attention parameters (head counts, dimensions, sequence lengths).
136/// * `batch_size` — Number of independent sequences in the batch.
137///
138/// # Errors
139///
140/// Returns `MlxError::InvalidArgument` for invalid parameters or mismatched
141/// buffer sizes.
142pub fn sdpa(
143    encoder: &mut CommandEncoder,
144    registry: &mut KernelRegistry,
145    device: &MlxDevice,
146    q: &MlxBuffer,
147    k: &MlxBuffer,
148    v: &MlxBuffer,
149    output: &MlxBuffer,
150    params: &SdpaParams,
151    batch_size: u32,
152) -> Result<()> {
153    validate_params(params)?;
154
155    // Resolve kv_capacity: 0 means "same as kv_seq_len" for backwards compat.
156    let kv_cap = if params.kv_capacity == 0 { params.kv_seq_len } else { params.kv_capacity };
157
158    // Validate buffer sizes.
159    let q_elements = batch_size as usize
160        * params.n_heads as usize
161        * params.seq_len as usize
162        * params.head_dim as usize;
163    // KV buffers are strided by kv_capacity, not kv_seq_len.
164    let kv_elements = batch_size as usize
165        * params.n_kv_heads as usize
166        * kv_cap as usize
167        * params.head_dim as usize;
168
169    validate_buffer(q, "Q", q_elements)?;
170    validate_buffer(k, "K", kv_elements)?;
171    validate_buffer(v, "V", kv_elements)?;
172    validate_buffer(output, "output", q_elements)?;
173
174    // Allocate a small buffer for the GPU-side params struct.
175    let params_gpu = SdpaParamsGpu {
176        n_heads: params.n_heads,
177        n_kv_heads: params.n_kv_heads,
178        head_dim: params.head_dim,
179        seq_len: params.seq_len,
180        kv_seq_len: params.kv_seq_len,
181        scale: params.scale,
182        kv_capacity: kv_cap,
183    };
184    let params_bytes = bytemuck::bytes_of(&params_gpu);
185    let mut params_buf = device.alloc_buffer(
186        params_bytes.len(),
187        DType::U8,
188        vec![params_bytes.len()],
189    )?;
190    {
191        let dst: &mut [u8] = params_buf.as_mut_slice()?;
192        dst[..params_bytes.len()].copy_from_slice(params_bytes);
193    }
194
195    // Get the compiled pipeline.
196    // Select kernel based on buffer dtype.
197    let kernel_name = if q.dtype() == DType::BF16 { "sdpa_bf16" } else { "sdpa" };
198    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
199
200    // Calculate dispatch grid.
201    // Threadgroups: (batch, n_heads, ceil(seq_len / TILE_Q))
202    let n_tiles = (params.seq_len + TILE_Q - 1) / TILE_Q;
203    let threadgroups = MTLSize::new(
204        batch_size as u64,
205        params.n_heads as u64,
206        n_tiles as u64,
207    );
208    let threadgroup_size = MTLSize::new(TILE_Q as u64, 1, 1);
209
210    // Tag for the reorder pass (Phase 4e.3): SDPA is NOT reorderable.
211    encoder.set_op_kind(CapturedOpKind::Sdpa);
212
213    // Encode the dispatch.
214    encoder.encode_threadgroups(
215        pipeline,
216        &[
217            (0, q),
218            (1, k),
219            (2, v),
220            (3, output),
221            (4, &params_buf),
222        ],
223        threadgroups,
224        threadgroup_size,
225    );
226
227    Ok(())
228}
229
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Valid GQA parameters (16 query heads over 8 KV heads, head_dim 256).
    /// Individual tests perturb one field at a time.
    fn base_params() -> SdpaParams {
        SdpaParams {
            n_heads: 16,
            n_kv_heads: 8,
            head_dim: 256,
            seq_len: 128,
            kv_seq_len: 128,
            scale: 1.0 / (256.0_f32).sqrt(),
            kv_capacity: 128,
        }
    }

    /// Assert that `p` is rejected with `InvalidArgument`.
    fn assert_invalid(p: &SdpaParams) {
        assert!(matches!(
            validate_params(p),
            Err(MlxError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_validate_params_ok() {
        assert!(validate_params(&base_params()).is_ok());
    }

    #[test]
    fn test_validate_params_kv_capacity_zero_or_larger_ok() {
        // 0 means "use kv_seq_len"; a larger capacity models a pre-allocated cache.
        let packed = SdpaParams { kv_capacity: 0, ..base_params() };
        assert!(validate_params(&packed).is_ok());
        let preallocated = SdpaParams { kv_seq_len: 100, kv_capacity: 128, ..base_params() };
        assert!(validate_params(&preallocated).is_ok());
    }

    #[test]
    fn test_validate_params_zero_head_dim() {
        assert_invalid(&SdpaParams { head_dim: 0, scale: 1.0, ..base_params() });
    }

    #[test]
    fn test_validate_params_zero_dimensions() {
        assert_invalid(&SdpaParams { n_heads: 0, ..base_params() });
        assert_invalid(&SdpaParams { n_kv_heads: 0, ..base_params() });
        assert_invalid(&SdpaParams { seq_len: 0, ..base_params() });
        assert_invalid(&SdpaParams { kv_seq_len: 0, ..base_params() });
    }

    #[test]
    fn test_validate_params_bad_ratio() {
        assert_invalid(&SdpaParams { n_kv_heads: 7, scale: 1.0, ..base_params() });
    }

    #[test]
    fn test_gpu_params_layout() {
        // Six u32 fields (kv_capacity included) plus one f32: 7 x 4 = 28 bytes.
        assert_eq!(std::mem::size_of::<SdpaParamsGpu>(), 28);
    }

    #[test]
    fn test_gpu_params_field_order() {
        // Pin each field's 4-byte slot in declaration order, which the MSL
        // struct must mirror exactly.
        let gpu = SdpaParamsGpu {
            n_heads: 1,
            n_kv_heads: 2,
            head_dim: 3,
            seq_len: 4,
            kv_seq_len: 5,
            scale: 0.5,
            kv_capacity: 7,
        };
        let words: [u32; 7] = bytemuck::cast(gpu);
        assert_eq!(words, [1, 2, 3, 4, 5, 0.5_f32.to_bits(), 7]);
    }
}