1use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::device::MlxDevice;
12use crate::encoder::{CapturedOpKind, CommandEncoder};
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15use crate::DType;
16
/// Metal shader source for the SDPA kernels, embedded into the binary at
/// compile time from `../shaders/sdpa.metal` (path relative to this file).
///
/// [`register`] hands this text to the kernel registry under the name `"sdpa"`.
pub static SDPA_SHADER_SOURCE: &str = include_str!("../shaders/sdpa.metal");
19
20pub fn register(registry: &mut KernelRegistry) {
24 registry.register_source("sdpa", SDPA_SHADER_SOURCE);
25}
26
/// Host-side parameters describing one scaled-dot-product-attention dispatch.
///
/// All dimension fields except `kv_capacity` must be non-zero; this is
/// enforced by `validate_params` before anything is encoded.
#[derive(Debug, Clone, Copy)]
pub struct SdpaParams {
    /// Number of heads in the Q and output buffers. Must be non-zero and an
    /// exact multiple of `n_kv_heads`.
    pub n_heads: u32,
    /// Number of heads in the K and V buffers. Must be non-zero.
    pub n_kv_heads: u32,
    /// Elements per head. Must be non-zero.
    pub head_dim: u32,
    /// Query sequence length. Must be non-zero; it also determines how many
    /// query tiles are dispatched.
    pub seq_len: u32,
    /// Key/value sequence length. Must be non-zero.
    pub kv_seq_len: u32,
    /// Scale factor forwarded unchanged to the shader.
    // NOTE(review): presumably the softmax scale (the unit test uses
    // 1/sqrt(head_dim)) — confirm against the shader.
    pub scale: f32,
    /// Number of positions per KV head used when sizing the K/V buffers.
    /// A value of `0` means "use `kv_seq_len`".
    // NOTE(review): looks like the per-head stride of a pre-allocated KV
    // cache — confirm against the shader's indexing.
    pub kv_capacity: u32,
}
54
/// `repr(C)` copy of [`SdpaParams`] that is byte-copied into a device buffer
/// (via `bytemuck::bytes_of`) and bound to the kernel at buffer index 4.
///
/// Seven 4-byte fields, 28 bytes total (asserted by `test_gpu_params_layout`).
// NOTE(review): field order and types presumably have to mirror the params
// struct declared in `sdpa.metal` — verify before reordering or adding fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct SdpaParamsGpu {
    n_heads: u32,
    n_kv_heads: u32,
    head_dim: u32,
    seq_len: u32,
    kv_seq_len: u32,
    scale: f32,
    // Always the resolved capacity: `sdpa` substitutes `kv_seq_len` when the
    // caller passed 0.
    kv_capacity: u32,
}
68
/// Query tile size: the threadgroup width used for the dispatch, and the
/// divisor (rounded up) that turns `seq_len` into the grid's third dimension.
// NOTE(review): presumably must match the tile size hard-coded in the shader — confirm.
const TILE_Q: u32 = 32;
72
73fn validate_params(params: &SdpaParams) -> Result<()> {
75 if params.head_dim == 0 {
76 return Err(MlxError::InvalidArgument(
77 "head_dim must be > 0".into(),
78 ));
79 }
80 if params.n_heads == 0 {
81 return Err(MlxError::InvalidArgument(
82 "n_heads must be > 0".into(),
83 ));
84 }
85 if params.n_kv_heads == 0 {
86 return Err(MlxError::InvalidArgument(
87 "n_kv_heads must be > 0".into(),
88 ));
89 }
90 if params.n_heads % params.n_kv_heads != 0 {
91 return Err(MlxError::InvalidArgument(format!(
92 "n_heads ({}) must be divisible by n_kv_heads ({})",
93 params.n_heads, params.n_kv_heads
94 )));
95 }
96 if params.seq_len == 0 {
97 return Err(MlxError::InvalidArgument(
98 "seq_len must be > 0".into(),
99 ));
100 }
101 if params.kv_seq_len == 0 {
102 return Err(MlxError::InvalidArgument(
103 "kv_seq_len must be > 0".into(),
104 ));
105 }
106 Ok(())
107}
108
109fn validate_buffer(buf: &MlxBuffer, name: &str, expected_elements: usize) -> Result<()> {
111 let expected_bytes = expected_elements * buf.dtype().size_of();
112 if buf.byte_len() < expected_bytes {
113 return Err(MlxError::InvalidArgument(format!(
114 "{name} buffer too small: expected at least {expected_bytes} bytes, got {}",
115 buf.byte_len()
116 )));
117 }
118 Ok(())
119}
120
121pub fn sdpa(
143 encoder: &mut CommandEncoder,
144 registry: &mut KernelRegistry,
145 device: &MlxDevice,
146 q: &MlxBuffer,
147 k: &MlxBuffer,
148 v: &MlxBuffer,
149 output: &MlxBuffer,
150 params: &SdpaParams,
151 batch_size: u32,
152) -> Result<()> {
153 validate_params(params)?;
154
155 let kv_cap = if params.kv_capacity == 0 { params.kv_seq_len } else { params.kv_capacity };
157
158 let q_elements = batch_size as usize
160 * params.n_heads as usize
161 * params.seq_len as usize
162 * params.head_dim as usize;
163 let kv_elements = batch_size as usize
165 * params.n_kv_heads as usize
166 * kv_cap as usize
167 * params.head_dim as usize;
168
169 validate_buffer(q, "Q", q_elements)?;
170 validate_buffer(k, "K", kv_elements)?;
171 validate_buffer(v, "V", kv_elements)?;
172 validate_buffer(output, "output", q_elements)?;
173
174 let params_gpu = SdpaParamsGpu {
176 n_heads: params.n_heads,
177 n_kv_heads: params.n_kv_heads,
178 head_dim: params.head_dim,
179 seq_len: params.seq_len,
180 kv_seq_len: params.kv_seq_len,
181 scale: params.scale,
182 kv_capacity: kv_cap,
183 };
184 let params_bytes = bytemuck::bytes_of(¶ms_gpu);
185 let mut params_buf = device.alloc_buffer(
186 params_bytes.len(),
187 DType::U8,
188 vec![params_bytes.len()],
189 )?;
190 {
191 let dst: &mut [u8] = params_buf.as_mut_slice()?;
192 dst[..params_bytes.len()].copy_from_slice(params_bytes);
193 }
194
195 let kernel_name = if q.dtype() == DType::BF16 { "sdpa_bf16" } else { "sdpa" };
198 let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;
199
200 let n_tiles = (params.seq_len + TILE_Q - 1) / TILE_Q;
203 let threadgroups = MTLSize::new(
204 batch_size as u64,
205 params.n_heads as u64,
206 n_tiles as u64,
207 );
208 let threadgroup_size = MTLSize::new(TILE_Q as u64, 1, 1);
209
210 encoder.set_op_kind(CapturedOpKind::Sdpa);
212
213 encoder.encode_threadgroups(
215 pipeline,
216 &[
217 (0, q),
218 (1, k),
219 (2, v),
220 (3, output),
221 (4, ¶ms_buf),
222 ],
223 threadgroups,
224 threadgroup_size,
225 );
226
227 Ok(())
228}
229
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Builds the fixture shared by the validation tests: 16 heads, 128-long
    /// Q and KV sequences, KV capacity 128. Only the fields that differ
    /// between tests are parameters.
    fn fixture(n_kv_heads: u32, head_dim: u32, scale: f32) -> SdpaParams {
        SdpaParams {
            n_heads: 16,
            n_kv_heads,
            head_dim,
            seq_len: 128,
            kv_seq_len: 128,
            scale,
            kv_capacity: 128,
        }
    }

    /// A well-formed parameter set is accepted.
    #[test]
    fn test_validate_params_ok() {
        let p = fixture(8, 256, 1.0 / (256.0_f32).sqrt());
        assert!(validate_params(&p).is_ok());
    }

    /// `head_dim == 0` is rejected as an invalid argument.
    #[test]
    fn test_validate_params_zero_head_dim() {
        let p = fixture(8, 0, 1.0);
        let outcome = validate_params(&p);
        assert!(matches!(outcome, Err(MlxError::InvalidArgument(_))));
    }

    /// 16 query heads cannot be split evenly across 7 KV heads.
    #[test]
    fn test_validate_params_bad_ratio() {
        let p = fixture(7, 256, 1.0);
        let outcome = validate_params(&p);
        assert!(matches!(outcome, Err(MlxError::InvalidArgument(_))));
    }

    /// The GPU-side struct is seven 4-byte fields with no padding.
    #[test]
    fn test_gpu_params_layout() {
        let gpu_struct_bytes = std::mem::size_of::<SdpaParamsGpu>();
        assert_eq!(gpu_struct_bytes, 28);
    }
}