Skip to main content

mlx_native/ops/
rope.rs

1//! Rotary Position Embedding (RoPE) GPU dispatch.
2//!
3//! Applies rotation in pairs of elements using cos/sin of
4//! `position * theta^(-2i/d)`.  Gemma 4 uses theta=10000 for sliding
5//! attention layers and theta=1000000 for global attention layers.
6
7use metal::MTLSize;
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
15/// MSL source for the RoPE kernels (embedded at compile time).
16pub static ROPE_SHADER_SOURCE: &str = include_str!("../shaders/rope.metal");
17
18/// Register RoPE shader sources with the given kernel registry.
19pub fn register(registry: &mut KernelRegistry) {
20    registry.register_source("rope_f32", ROPE_SHADER_SOURCE);
21    registry.register_source("rope_f16", ROPE_SHADER_SOURCE);
22    registry.register_source("rope_bf16", ROPE_SHADER_SOURCE);
23    registry.register_source("rope_neox_bf16", ROPE_SHADER_SOURCE);
24    registry.register_source("rope_neox_f32", ROPE_SHADER_SOURCE);
25}
26
27/// Dispatch a RoPE operation on the GPU.
28///
29/// # Arguments
30///
31/// * `encoder`      - Command encoder to record the dispatch into.
32/// * `registry`     - Kernel registry (must have RoPE sources registered).
33/// * `device`       - Metal device for pipeline compilation.
34/// * `input`        - Input buffer of shape `[seq_len, head_dim]` (f32 or f16).
35/// * `output`       - Output buffer (same dtype and shape as input).
36/// * `params_buf`   - Params buffer containing `[theta, head_dim, 0, 0]` as f32.
37/// * `positions_buf` - Positions buffer containing `[pos_0, pos_1, ...]` as u32.
38/// * `seq_len`      - Number of sequence positions.
39/// * `head_dim`     - Dimension of each head (must be even).
40///
41/// # Errors
42///
43/// Returns `MlxError::InvalidArgument` if:
44/// - Input dtype is not f32 or f16.
45/// - head_dim is not even.
46/// - Input and output element counts do not match.
47pub fn dispatch_rope(
48    encoder: &mut CommandEncoder,
49    registry: &mut KernelRegistry,
50    device: &metal::DeviceRef,
51    input: &MlxBuffer,
52    output: &MlxBuffer,
53    params_buf: &MlxBuffer,
54    positions_buf: &MlxBuffer,
55    seq_len: u32,
56    head_dim: u32,
57) -> Result<()> {
58    if head_dim % 2 != 0 {
59        return Err(MlxError::InvalidArgument(format!(
60            "RoPE head_dim must be even, got {}",
61            head_dim
62        )));
63    }
64    if head_dim == 0 || seq_len == 0 {
65        return Err(MlxError::InvalidArgument(
66            "RoPE head_dim and seq_len must be > 0".into(),
67        ));
68    }
69
70    let expected_elements = (seq_len as usize) * (head_dim as usize);
71    if input.element_count() != expected_elements {
72        return Err(MlxError::InvalidArgument(format!(
73            "RoPE input element count {} != seq_len({}) * head_dim({})",
74            input.element_count(),
75            seq_len,
76            head_dim
77        )));
78    }
79    if output.element_count() != expected_elements {
80        return Err(MlxError::InvalidArgument(format!(
81            "RoPE output element count {} != seq_len({}) * head_dim({})",
82            output.element_count(),
83            seq_len,
84            head_dim
85        )));
86    }
87
88    let kernel_name = match input.dtype() {
89        DType::F32 => "rope_f32",
90        DType::F16 => "rope_f16",
91        DType::BF16 => "rope_bf16",
92        _ => {
93            return Err(MlxError::InvalidArgument(format!(
94                "RoPE unsupported dtype: {}",
95                input.dtype()
96            )));
97        }
98    };
99
100    let pipeline = registry.get_pipeline(kernel_name, device)?;
101    let half_dim = head_dim / 2;
102
103    // Grid: (half_dim, seq_len) — one thread per pair per position
104    // Threadgroup: use a reasonable size for the pair dimension
105    let tg_x = std::cmp::min(64, half_dim as u64);
106    let tg_y = std::cmp::min(4, seq_len as u64);
107
108    encoder.encode(
109        pipeline,
110        &[
111            (0, input),
112            (1, output),
113            (2, params_buf),
114            (3, positions_buf),
115        ],
116        MTLSize::new(half_dim as u64, seq_len as u64, 1),
117        MTLSize::new(tg_x, tg_y, 1),
118    );
119
120    Ok(())
121}
122
123/// GPU params for the neox RoPE kernel's auxiliary params buffer.
124///
125/// Must match the uint array in `rope_neox_bf16` buffer(4).
126#[repr(C)]
127#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
128struct GpuRopeNeoxParams {
129    n_heads: u32,
130    _pad: u32,
131}
132
133/// Dispatch a Neox/split-convention RoPE operation on the GPU (bf16 only).
134///
135/// The Neox convention pairs `(d[i], d[i + half_rope_dim])` instead of
136/// `(d[2i], d[2i+1])`.  Supports partial rotary where only the first
137/// `rope_dim` dimensions are rotated.
138///
139/// # Arguments
140///
141/// * `encoder`       - Command encoder to record the dispatch into.
142/// * `registry`      - Kernel registry (must have rope_neox_bf16 registered).
143/// * `device`        - Metal device for pipeline compilation.
144/// * `input`         - Input buffer of shape `[seq_len * n_heads, head_dim]` (bf16).
145/// * `output`        - Output buffer (same shape and dtype as input).
146/// * `params_buf`    - Params buffer containing `[theta, head_dim, rope_dim, 0]` as f32.
147/// * `positions_buf` - Positions buffer containing `[pos_0, pos_1, ...]` as u32 (length = seq_len).
148/// * `seq_len`       - Number of sequence positions.
149/// * `n_heads`       - Number of attention heads.
150/// * `head_dim`      - Dimension of each head.
151/// * `rope_dim`      - Number of dimensions to rotate (must be even, <= head_dim).
152///
153/// # Errors
154///
155/// Returns `MlxError::InvalidArgument` if parameters are invalid.
156#[allow(clippy::too_many_arguments)]
157pub fn dispatch_rope_neox_bf16(
158    encoder: &mut CommandEncoder,
159    registry: &mut KernelRegistry,
160    device: &metal::DeviceRef,
161    input: &MlxBuffer,
162    output: &MlxBuffer,
163    params_buf: &MlxBuffer,
164    positions_buf: &MlxBuffer,
165    seq_len: u32,
166    n_heads: u32,
167    head_dim: u32,
168    rope_dim: u32,
169) -> Result<()> {
170    use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
171
172    if rope_dim % 2 != 0 {
173        return Err(MlxError::InvalidArgument(format!(
174            "RoPE neox rope_dim must be even, got {}",
175            rope_dim
176        )));
177    }
178    if rope_dim > head_dim {
179        return Err(MlxError::InvalidArgument(format!(
180            "RoPE neox rope_dim ({}) must be <= head_dim ({})",
181            rope_dim, head_dim
182        )));
183    }
184    if head_dim == 0 || seq_len == 0 || n_heads == 0 {
185        return Err(MlxError::InvalidArgument(
186            "RoPE neox head_dim, seq_len, and n_heads must be > 0".into(),
187        ));
188    }
189
190    let n_rows = (seq_len as usize) * (n_heads as usize);
191    let expected_elements = n_rows * (head_dim as usize);
192    if input.element_count() != expected_elements {
193        return Err(MlxError::InvalidArgument(format!(
194            "RoPE neox input element count {} != seq_len({}) * n_heads({}) * head_dim({})",
195            input.element_count(),
196            seq_len,
197            n_heads,
198            head_dim
199        )));
200    }
201    if output.element_count() != expected_elements {
202        return Err(MlxError::InvalidArgument(format!(
203            "RoPE neox output element count {} != seq_len({}) * n_heads({}) * head_dim({})",
204            output.element_count(),
205            seq_len,
206            n_heads,
207            head_dim
208        )));
209    }
210
211    let pipeline = registry.get_pipeline("rope_neox_bf16", device)?;
212    let half_rope = rope_dim / 2;
213
214    let gpu_rope_params = GpuRopeNeoxParams {
215        n_heads,
216        _pad: 0,
217    };
218
219    // Grid: (half_rope, n_rows) — one thread per pair per row
220    let tg_x = std::cmp::min(64, half_rope as u64);
221    let tg_y = std::cmp::min(4, n_rows as u64);
222
223    encode_with_args(
224        encoder,
225        pipeline,
226        &[
227            (0, KernelArg::Buffer(input)),
228            (1, KernelArg::Buffer(output)),
229            (2, KernelArg::Buffer(params_buf)),
230            (3, KernelArg::Buffer(positions_buf)),
231            (4, KernelArg::Bytes(as_bytes(&gpu_rope_params))),
232        ],
233        MTLSize::new(half_rope as u64, n_rows as u64, 1),
234        MTLSize::new(tg_x, tg_y, 1),
235    );
236
237    Ok(())
238}
239
240/// GPU params for the neox f32 RoPE kernel with freq_factors support.
241///
242/// Must match the uint array in `rope_neox_f32` buffer(4).
243#[repr(C)]
244#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
245struct GpuRopeNeoxF32Params {
246    n_heads: u32,
247    has_freq_factors: u32,
248}
249
250/// Dispatch a Neox/split-convention RoPE operation on the GPU (f32) with
251/// optional freq_factors support.
252///
253/// The Neox convention pairs `(d[i], d[i + half_rope_dim])` instead of
254/// `(d[2i], d[2i+1])`.  Supports partial rotary where only the first
255/// `rope_dim` dimensions are rotated.
256///
257/// When `freq_factors` is `Some`, each pair's base frequency is divided by
258/// `freq_factors[pair_idx]`.  Gemma 4's global attention layers use this to
259/// mask out rotation for certain dimensions (freq_factor=1e30 -> identity).
260///
261/// # Arguments
262///
263/// * `encoder`       - Command encoder to record the dispatch into.
264/// * `registry`      - Kernel registry (must have rope_neox_f32 registered).
265/// * `device`        - Metal device for pipeline compilation.
266/// * `input`         - Input buffer of shape `[seq_len * n_heads, head_dim]` (f32).
267/// * `output`        - Output buffer (same shape and dtype as input).
268/// * `params_buf`    - Params buffer containing `[theta, head_dim, rope_dim, 0]` as f32.
269/// * `positions_buf` - Positions buffer containing `[pos_0, pos_1, ...]` as u32 (length = seq_len).
270/// * `freq_factors`  - Optional freq_factors buffer of shape `[rope_dim/2]` (f32).
271///                     Pass `None` for standard RoPE (equivalent to all-ones).
272/// * `seq_len`       - Number of sequence positions.
273/// * `n_heads`       - Number of attention heads.
274/// * `head_dim`      - Dimension of each head.
275/// * `rope_dim`      - Number of dimensions to rotate (must be even, <= head_dim).
276///
277/// # Errors
278///
279/// Returns `MlxError::InvalidArgument` if parameters are invalid.
280#[allow(clippy::too_many_arguments)]
281pub fn dispatch_rope_neox_f32(
282    encoder: &mut CommandEncoder,
283    registry: &mut KernelRegistry,
284    device: &metal::DeviceRef,
285    input: &MlxBuffer,
286    output: &MlxBuffer,
287    params_buf: &MlxBuffer,
288    positions_buf: &MlxBuffer,
289    freq_factors: Option<&MlxBuffer>,
290    seq_len: u32,
291    n_heads: u32,
292    head_dim: u32,
293    rope_dim: u32,
294) -> Result<()> {
295    use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
296
297    if rope_dim % 2 != 0 {
298        return Err(MlxError::InvalidArgument(format!(
299            "RoPE neox f32 rope_dim must be even, got {}",
300            rope_dim
301        )));
302    }
303    if rope_dim > head_dim {
304        return Err(MlxError::InvalidArgument(format!(
305            "RoPE neox f32 rope_dim ({}) must be <= head_dim ({})",
306            rope_dim, head_dim
307        )));
308    }
309    if head_dim == 0 || seq_len == 0 || n_heads == 0 {
310        return Err(MlxError::InvalidArgument(
311            "RoPE neox f32 head_dim, seq_len, and n_heads must be > 0".into(),
312        ));
313    }
314
315    let n_rows = (seq_len as usize) * (n_heads as usize);
316    let expected_elements = n_rows * (head_dim as usize);
317    if input.element_count() != expected_elements {
318        return Err(MlxError::InvalidArgument(format!(
319            "RoPE neox f32 input element count {} != seq_len({}) * n_heads({}) * head_dim({})",
320            input.element_count(),
321            seq_len,
322            n_heads,
323            head_dim
324        )));
325    }
326    if output.element_count() != expected_elements {
327        return Err(MlxError::InvalidArgument(format!(
328            "RoPE neox f32 output element count {} != seq_len({}) * n_heads({}) * head_dim({})",
329            output.element_count(),
330            seq_len,
331            n_heads,
332            head_dim
333        )));
334    }
335
336    let pipeline = registry.get_pipeline("rope_neox_f32", device)?;
337    let half_rope = rope_dim / 2;
338
339    let has_ff = freq_factors.is_some();
340    let gpu_rope_params = GpuRopeNeoxF32Params {
341        n_heads,
342        has_freq_factors: u32::from(has_ff),
343    };
344
345    // When no freq_factors buffer is provided, bind the input buffer as a
346    // harmless dummy at buffer(5) — Metal requires all declared buffers to
347    // be bound. The shader checks has_freq_factors before reading.
348    let ff_buf = freq_factors.unwrap_or(input);
349
350    // Grid: (half_rope, n_rows) — one thread per pair per row
351    let tg_x = std::cmp::min(64, half_rope as u64);
352    let tg_y = std::cmp::min(4, n_rows as u64);
353
354    encode_with_args(
355        encoder,
356        pipeline,
357        &[
358            (0, KernelArg::Buffer(input)),
359            (1, KernelArg::Buffer(output)),
360            (2, KernelArg::Buffer(params_buf)),
361            (3, KernelArg::Buffer(positions_buf)),
362            (4, KernelArg::Bytes(as_bytes(&gpu_rope_params))),
363            (5, KernelArg::Buffer(ff_buf)),
364        ],
365        MTLSize::new(half_rope as u64, n_rows as u64, 1),
366        MTLSize::new(tg_x, tg_y, 1),
367    );
368
369    Ok(())
370}