mlx_native/ops/
sdpa_decode.rs1use metal::MTLSize;
15
16use crate::buffer::MlxBuffer;
17use crate::device::MlxDevice;
18use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
19use crate::error::{MlxError, Result};
20use crate::kernel_registry::KernelRegistry;
21
22pub static SDPA_DECODE_SHADER_SOURCE: &str =
24 include_str!("../shaders/sdpa_decode.metal");
25
26pub fn register(registry: &mut KernelRegistry) {
28 registry.register_source("sdpa_decode", SDPA_DECODE_SHADER_SOURCE);
29}
30
31#[repr(C)]
33#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
34struct SdpaDecodeParamsGpu {
35 n_heads: u32,
36 n_kv_heads: u32,
37 head_dim: u32,
38 kv_seq_len: u32,
39 kv_capacity: u32,
40 scale: f32,
41 n_sg: u32,
42}
43
/// Picks the number of 32-thread groups per threadgroup from the KV length.
///
/// The mapping is:
/// - `0..32` gives 1,
/// - `32..128` gives 2,
/// - `128` and above gives 4.
fn select_n_sg(kv_seq_len: u32) -> u32 {
    match kv_seq_len {
        0..=31 => 1,
        32..=127 => 2,
        _ => 4,
    }
}
62
63#[allow(clippy::too_many_arguments)]
72pub fn dispatch_sdpa_decode(
73 encoder: &mut CommandEncoder,
74 registry: &mut KernelRegistry,
75 device: &MlxDevice,
76 q: &MlxBuffer,
77 k: &MlxBuffer,
78 v: &MlxBuffer,
79 output: &MlxBuffer,
80 n_heads: u32,
81 n_kv_heads: u32,
82 head_dim: u32,
83 kv_seq_len: u32,
84 kv_capacity: u32,
85 scale: f32,
86) -> Result<()> {
87 if head_dim % 32 != 0 {
88 return Err(MlxError::InvalidArgument(format!(
89 "sdpa_decode: head_dim ({}) must be a multiple of 32", head_dim
90 )));
91 }
92 if kv_seq_len == 0 {
93 return Err(MlxError::InvalidArgument(
94 "sdpa_decode: kv_seq_len must be > 0".into(),
95 ));
96 }
97
98 let q_elems = (n_heads * head_dim) as usize;
99 let kv_elems = (n_kv_heads * kv_capacity * head_dim) as usize;
100 let o_elems = q_elems;
101
102 macro_rules! chk {
103 ($buf:expr, $exp:expr, $name:literal) => {
104 if $buf.element_count() < $exp {
105 return Err(MlxError::InvalidArgument(format!(
106 "sdpa_decode: {} too small ({} < {})", $name,
107 $buf.element_count(), $exp
108 )));
109 }
110 };
111 }
112 chk!(q, q_elems, "Q");
113 chk!(k, kv_elems, "K");
114 chk!(v, kv_elems, "V");
115 chk!(output, o_elems, "output");
116
117 let n_sg = select_n_sg(kv_seq_len);
118
119 let gpu_params = SdpaDecodeParamsGpu {
120 n_heads,
121 n_kv_heads,
122 head_dim,
123 kv_seq_len,
124 kv_capacity,
125 scale,
126 n_sg,
127 };
128
129 let shmem_bytes: u64 = 4 * n_sg as u64 * (head_dim as u64 + 2);
135
136 let pipeline = registry.get_pipeline("sdpa_decode", device.metal_device())?;
137
138 let threadgroups = MTLSize::new(n_heads as u64, 1, 1);
140 let threadgroup_sz = MTLSize::new(n_sg as u64 * 32, 1, 1);
141
142 encoder.encode_threadgroups_with_args_and_shared(
143 pipeline,
144 &[
145 (0, KernelArg::Buffer(q)),
146 (1, KernelArg::Buffer(k)),
147 (2, KernelArg::Buffer(v)),
148 (3, KernelArg::Buffer(output)),
149 (4, KernelArg::Bytes(as_bytes(&gpu_params))),
150 ],
151 &[(0, shmem_bytes)],
152 threadgroups,
153 threadgroup_sz,
154 );
155
156 Ok(())
157}