1use metal::MTLSize;
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
/// Text of `../shaders/rope.metal`, embedded into the binary at compile time
/// by `include_str!`. [`register`] associates this same source with every
/// RoPE kernel name.
pub static ROPE_SHADER_SOURCE: &str = include_str!("../shaders/rope.metal");
17
18pub fn register(registry: &mut KernelRegistry) {
20 registry.register_source("rope_f32", ROPE_SHADER_SOURCE);
21 registry.register_source("rope_f16", ROPE_SHADER_SOURCE);
22 registry.register_source("rope_bf16", ROPE_SHADER_SOURCE);
23 registry.register_source("rope_neox_bf16", ROPE_SHADER_SOURCE);
24 registry.register_source("rope_neox_f32", ROPE_SHADER_SOURCE);
25}
26
27pub fn dispatch_rope(
48 encoder: &mut CommandEncoder,
49 registry: &mut KernelRegistry,
50 device: &metal::DeviceRef,
51 input: &MlxBuffer,
52 output: &MlxBuffer,
53 params_buf: &MlxBuffer,
54 positions_buf: &MlxBuffer,
55 seq_len: u32,
56 head_dim: u32,
57) -> Result<()> {
58 if head_dim % 2 != 0 {
59 return Err(MlxError::InvalidArgument(format!(
60 "RoPE head_dim must be even, got {}",
61 head_dim
62 )));
63 }
64 if head_dim == 0 || seq_len == 0 {
65 return Err(MlxError::InvalidArgument(
66 "RoPE head_dim and seq_len must be > 0".into(),
67 ));
68 }
69
70 let expected_elements = (seq_len as usize) * (head_dim as usize);
71 if input.element_count() != expected_elements {
72 return Err(MlxError::InvalidArgument(format!(
73 "RoPE input element count {} != seq_len({}) * head_dim({})",
74 input.element_count(),
75 seq_len,
76 head_dim
77 )));
78 }
79 if output.element_count() != expected_elements {
80 return Err(MlxError::InvalidArgument(format!(
81 "RoPE output element count {} != seq_len({}) * head_dim({})",
82 output.element_count(),
83 seq_len,
84 head_dim
85 )));
86 }
87
88 let kernel_name = match input.dtype() {
89 DType::F32 => "rope_f32",
90 DType::F16 => "rope_f16",
91 DType::BF16 => "rope_bf16",
92 _ => {
93 return Err(MlxError::InvalidArgument(format!(
94 "RoPE unsupported dtype: {}",
95 input.dtype()
96 )));
97 }
98 };
99
100 let pipeline = registry.get_pipeline(kernel_name, device)?;
101 let half_dim = head_dim / 2;
102
103 let tg_x = std::cmp::min(64, half_dim as u64);
106 let tg_y = std::cmp::min(4, seq_len as u64);
107
108 encoder.encode(
109 pipeline,
110 &[
111 (0, input),
112 (1, output),
113 (2, params_buf),
114 (3, positions_buf),
115 ],
116 MTLSize::new(half_dim as u64, seq_len as u64, 1),
117 MTLSize::new(tg_x, tg_y, 1),
118 );
119
120 Ok(())
121}
122
/// Small constant block that `dispatch_rope_neox_bf16` passes to the
/// `rope_neox_bf16` kernel as raw bytes at binding index 4.
///
/// `#[repr(C)]` gives the two `u32` fields a fixed, declaration-order layout
/// (8 bytes total).
// NOTE(review): this presumably has to mirror a matching struct declared in
// rope.metal — confirm before changing field order or size.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuRopeNeoxParams {
    /// Copied from the `n_heads` argument of the dispatch function.
    n_heads: u32,
    /// Always written as 0 by the dispatch code.
    // NOTE(review): looks like explicit padding to keep the struct at 8 bytes — confirm.
    _pad: u32,
}
132
133#[allow(clippy::too_many_arguments)]
157pub fn dispatch_rope_neox_bf16(
158 encoder: &mut CommandEncoder,
159 registry: &mut KernelRegistry,
160 device: &metal::DeviceRef,
161 input: &MlxBuffer,
162 output: &MlxBuffer,
163 params_buf: &MlxBuffer,
164 positions_buf: &MlxBuffer,
165 seq_len: u32,
166 n_heads: u32,
167 head_dim: u32,
168 rope_dim: u32,
169) -> Result<()> {
170 use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
171
172 if rope_dim % 2 != 0 {
173 return Err(MlxError::InvalidArgument(format!(
174 "RoPE neox rope_dim must be even, got {}",
175 rope_dim
176 )));
177 }
178 if rope_dim > head_dim {
179 return Err(MlxError::InvalidArgument(format!(
180 "RoPE neox rope_dim ({}) must be <= head_dim ({})",
181 rope_dim, head_dim
182 )));
183 }
184 if head_dim == 0 || seq_len == 0 || n_heads == 0 {
185 return Err(MlxError::InvalidArgument(
186 "RoPE neox head_dim, seq_len, and n_heads must be > 0".into(),
187 ));
188 }
189
190 let n_rows = (seq_len as usize) * (n_heads as usize);
191 let expected_elements = n_rows * (head_dim as usize);
192 if input.element_count() != expected_elements {
193 return Err(MlxError::InvalidArgument(format!(
194 "RoPE neox input element count {} != seq_len({}) * n_heads({}) * head_dim({})",
195 input.element_count(),
196 seq_len,
197 n_heads,
198 head_dim
199 )));
200 }
201 if output.element_count() != expected_elements {
202 return Err(MlxError::InvalidArgument(format!(
203 "RoPE neox output element count {} != seq_len({}) * n_heads({}) * head_dim({})",
204 output.element_count(),
205 seq_len,
206 n_heads,
207 head_dim
208 )));
209 }
210
211 let pipeline = registry.get_pipeline("rope_neox_bf16", device)?;
212 let half_rope = rope_dim / 2;
213
214 let gpu_rope_params = GpuRopeNeoxParams {
215 n_heads,
216 _pad: 0,
217 };
218
219 let tg_x = std::cmp::min(64, half_rope as u64);
221 let tg_y = std::cmp::min(4, n_rows as u64);
222
223 encode_with_args(
224 encoder,
225 pipeline,
226 &[
227 (0, KernelArg::Buffer(input)),
228 (1, KernelArg::Buffer(output)),
229 (2, KernelArg::Buffer(params_buf)),
230 (3, KernelArg::Buffer(positions_buf)),
231 (4, KernelArg::Bytes(as_bytes(&gpu_rope_params))),
232 ],
233 MTLSize::new(half_rope as u64, n_rows as u64, 1),
234 MTLSize::new(tg_x, tg_y, 1),
235 );
236
237 Ok(())
238}
239
/// Small constant block that `dispatch_rope_neox_f32` passes to the
/// `rope_neox_f32` kernel as raw bytes at binding index 4.
///
/// `#[repr(C)]` gives the two `u32` fields a fixed, declaration-order layout
/// (8 bytes total).
// NOTE(review): this presumably has to mirror a matching struct declared in
// rope.metal — confirm before changing field order or size.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuRopeNeoxF32Params {
    /// Copied from the `n_heads` argument of the dispatch function.
    n_heads: u32,
    /// 1 when the caller supplied a `freq_factors` buffer, 0 otherwise
    /// (set via `u32::from(freq_factors.is_some())`).
    has_freq_factors: u32,
}
249
250#[allow(clippy::too_many_arguments)]
281pub fn dispatch_rope_neox_f32(
282 encoder: &mut CommandEncoder,
283 registry: &mut KernelRegistry,
284 device: &metal::DeviceRef,
285 input: &MlxBuffer,
286 output: &MlxBuffer,
287 params_buf: &MlxBuffer,
288 positions_buf: &MlxBuffer,
289 freq_factors: Option<&MlxBuffer>,
290 seq_len: u32,
291 n_heads: u32,
292 head_dim: u32,
293 rope_dim: u32,
294) -> Result<()> {
295 use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
296
297 if rope_dim % 2 != 0 {
298 return Err(MlxError::InvalidArgument(format!(
299 "RoPE neox f32 rope_dim must be even, got {}",
300 rope_dim
301 )));
302 }
303 if rope_dim > head_dim {
304 return Err(MlxError::InvalidArgument(format!(
305 "RoPE neox f32 rope_dim ({}) must be <= head_dim ({})",
306 rope_dim, head_dim
307 )));
308 }
309 if head_dim == 0 || seq_len == 0 || n_heads == 0 {
310 return Err(MlxError::InvalidArgument(
311 "RoPE neox f32 head_dim, seq_len, and n_heads must be > 0".into(),
312 ));
313 }
314
315 let n_rows = (seq_len as usize) * (n_heads as usize);
316 let expected_elements = n_rows * (head_dim as usize);
317 if input.element_count() != expected_elements {
318 return Err(MlxError::InvalidArgument(format!(
319 "RoPE neox f32 input element count {} != seq_len({}) * n_heads({}) * head_dim({})",
320 input.element_count(),
321 seq_len,
322 n_heads,
323 head_dim
324 )));
325 }
326 if output.element_count() != expected_elements {
327 return Err(MlxError::InvalidArgument(format!(
328 "RoPE neox f32 output element count {} != seq_len({}) * n_heads({}) * head_dim({})",
329 output.element_count(),
330 seq_len,
331 n_heads,
332 head_dim
333 )));
334 }
335
336 let pipeline = registry.get_pipeline("rope_neox_f32", device)?;
337 let half_rope = rope_dim / 2;
338
339 let has_ff = freq_factors.is_some();
340 let gpu_rope_params = GpuRopeNeoxF32Params {
341 n_heads,
342 has_freq_factors: u32::from(has_ff),
343 };
344
345 let ff_buf = freq_factors.unwrap_or(input);
349
350 let tg_x = std::cmp::min(64, half_rope as u64);
352 let tg_y = std::cmp::min(4, n_rows as u64);
353
354 encode_with_args(
355 encoder,
356 pipeline,
357 &[
358 (0, KernelArg::Buffer(input)),
359 (1, KernelArg::Buffer(output)),
360 (2, KernelArg::Buffer(params_buf)),
361 (3, KernelArg::Buffer(positions_buf)),
362 (4, KernelArg::Bytes(as_bytes(&gpu_rope_params))),
363 (5, KernelArg::Buffer(ff_buf)),
364 ],
365 MTLSize::new(half_rope as u64, n_rows as u64, 1),
366 MTLSize::new(tg_x, tg_y, 1),
367 );
368
369 Ok(())
370}