Skip to main content

mlx_native/ops/
sdpa_decode.rs

//! GPU SDPA decode kernel — F32 Q/K/V, multi-simdgroup tiled, single-token decode.
//!
//! The kernel divides the KV sequence across N_SG simdgroups, each of which scans
//! an independent KV chunk and produces a local (max, sum, unnormalized acc) triple.
//! Simdgroup 0 then merges all N_SG partial results using the log-sum-exp
//! combination rule and writes the final F32 output.
//!
//! # Constraints
//! - seq_len must be 1 (decode path only)
//! - head_dim must be a multiple of 32 (128, 256, 512 are supported)
//! - Q/K/V must be F32
//! - N_SG is not caller-supplied: it is chosen internally from kv_seq_len and
//!   is always 1, 2, or 4
13
14use metal::MTLSize;
15
16use crate::buffer::MlxBuffer;
17use crate::device::MlxDevice;
18use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
19use crate::error::{MlxError, Result};
20use crate::kernel_registry::KernelRegistry;
21
22/// Metal shader source.
23pub static SDPA_DECODE_SHADER_SOURCE: &str =
24    include_str!("../shaders/sdpa_decode.metal");
25
26/// Register `sdpa_decode` pipeline.
27pub fn register(registry: &mut KernelRegistry) {
28    registry.register_source("sdpa_decode", SDPA_DECODE_SHADER_SOURCE);
29}
30
31/// GPU-side params struct (must match `SdpaDecodeParams` in MSL exactly).
32#[repr(C)]
33#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
34struct SdpaDecodeParamsGpu {
35    n_heads:     u32,
36    n_kv_heads:  u32,
37    head_dim:    u32,
38    kv_seq_len:  u32,
39    kv_capacity: u32,
40    scale:       f32,
41    n_sg:        u32,
42}
43
/// Choose how many 32-thread simdgroups share the KV scan for one head.
///
/// Several simdgroups must exchange partial results through threadgroup
/// memory and a barrier. For short caches that fixed overhead outweighs the
/// parallelism gained, so the count grows with the cache length:
///
/// | `kv_seq_len` | simdgroups |
/// |--------------|------------|
/// | `0..32`      | 1          |
/// | `32..128`    | 2          |
/// | `128..`      | 4          |
fn select_n_sg(kv_seq_len: u32) -> u32 {
    match kv_seq_len {
        0..=31 => 1,
        32..=127 => 2,
        _ => 4,
    }
}
62
63/// Dispatch the tiled decode SDPA kernel.
64///
65/// # Constraints
66/// - `head_dim` must be a multiple of 32 (128, 256, 512 are supported).
67/// - Q layout: `[n_heads, head_dim]` F32 (seq_len=1, seq dim elided).
68/// - K/V layout: `[n_kv_heads, kv_capacity, head_dim]` F32.
69/// - Output layout: `[n_heads, head_dim]` F32.
70/// - kv_seq_len: number of valid positions in the KV cache (must be > 0).
71#[allow(clippy::too_many_arguments)]
72pub fn dispatch_sdpa_decode(
73    encoder:     &mut CommandEncoder,
74    registry:    &mut KernelRegistry,
75    device:      &MlxDevice,
76    q:           &MlxBuffer,
77    k:           &MlxBuffer,
78    v:           &MlxBuffer,
79    output:      &MlxBuffer,
80    n_heads:     u32,
81    n_kv_heads:  u32,
82    head_dim:    u32,
83    kv_seq_len:  u32,
84    kv_capacity: u32,
85    scale:       f32,
86) -> Result<()> {
87    if head_dim % 32 != 0 {
88        return Err(MlxError::InvalidArgument(format!(
89            "sdpa_decode: head_dim ({}) must be a multiple of 32", head_dim
90        )));
91    }
92    if kv_seq_len == 0 {
93        return Err(MlxError::InvalidArgument(
94            "sdpa_decode: kv_seq_len must be > 0".into(),
95        ));
96    }
97
98    let q_elems  = (n_heads    * head_dim) as usize;
99    let kv_elems = (n_kv_heads * kv_capacity * head_dim) as usize;
100    let o_elems  = q_elems;
101
102    macro_rules! chk {
103        ($buf:expr, $exp:expr, $name:literal) => {
104            if $buf.element_count() < $exp {
105                return Err(MlxError::InvalidArgument(format!(
106                    "sdpa_decode: {} too small ({} < {})", $name,
107                    $buf.element_count(), $exp
108                )));
109            }
110        };
111    }
112    chk!(q,      q_elems,  "Q");
113    chk!(k,      kv_elems, "K");
114    chk!(v,      kv_elems, "V");
115    chk!(output, o_elems,  "output");
116
117    let n_sg = select_n_sg(kv_seq_len);
118
119    let gpu_params = SdpaDecodeParamsGpu {
120        n_heads,
121        n_kv_heads,
122        head_dim,
123        kv_seq_len,
124        kv_capacity,
125        scale,
126        n_sg,
127    };
128
129    // Shared memory layout:
130    //   sg_max : [n_sg]          floats
131    //   sg_sum : [n_sg]          floats
132    //   sg_acc : [n_sg*head_dim] floats
133    // Total: 4 * n_sg * (head_dim + 2) bytes
134    let shmem_bytes: u64 = 4 * n_sg as u64 * (head_dim as u64 + 2);
135
136    let pipeline = registry.get_pipeline("sdpa_decode", device.metal_device())?;
137
138    // Grid: one TG per query head; TG has n_sg * 32 threads.
139    let threadgroups   = MTLSize::new(n_heads as u64, 1, 1);
140    let threadgroup_sz = MTLSize::new(n_sg as u64 * 32, 1, 1);
141
142    encoder.encode_threadgroups_with_args_and_shared(
143        pipeline,
144        &[
145            (0, KernelArg::Buffer(q)),
146            (1, KernelArg::Buffer(k)),
147            (2, KernelArg::Buffer(v)),
148            (3, KernelArg::Buffer(output)),
149            (4, KernelArg::Bytes(as_bytes(&gpu_params))),
150        ],
151        &[(0, shmem_bytes)],
152        threadgroups,
153        threadgroup_sz,
154    );
155
156    Ok(())
157}