Skip to main content

mlx_native/ops/
rope_multi.rs

1//! Multi-section Rotary Position Embedding with optional interleaved mode.
2//!
3//! Used by Qwen3.5 / Qwen3.6 full-attention layers (ADR-013 Decision 10).
4//! Both MROPE (`mode = 8`) and IMROPE (`mode = 40`) share a kernel; only the
5//! sector-to-axis mapping differs.
6//!
7//! # Spec (summary)
8//!
9//! For every pair `p ∈ [0, rope_dim/2)`:
10//! 1. `sector = p mod (s0 + s1 + s2 + s3)`
11//! 2. Pick axis based on `mode`:
12//!    * `Mrope`: contiguous sections — `sector ∈ [0, s0)` → axis 0, etc.
13//!    * `Imrope`: `sector % 3` cycling — `sector % 3 == 0 && sector < 3*s0`
14//!      → axis 0; `== 1 && sector < 3*s1` → axis 1; `== 2 && sector < 3*s2`
15//!      → axis 2; else axis 3.
16//! 3. `theta = position[axis] * freq_base^(-2p/rope_dim)`
17//! 4. Rotate pair `(x[p], x[p + head_dim/2])` by that angle (NeoX indexing).
18//!
19//! Pairs `p ≥ rope_dim/2` pass through unchanged (partial-rotary-factor).
20//!
21//! # Positions layout
22//!
23//! The `positions` buffer is an `int32` array of length `4 * seq_len`:
24//! first `seq_len` entries are the time-axis positions, next `seq_len` are
25//! the height-axis, then width, then extra. For Qwen3.5 text, all four
26//! axes are set to the token's 1D position.
27
28use metal::MTLSize;
29
30use crate::buffer::MlxBuffer;
31use crate::dtypes::DType;
32use crate::encoder::CommandEncoder;
33use crate::error::{MlxError, Result};
34use crate::kernel_registry::KernelRegistry;
35
36pub static ROPE_MULTI_SHADER_SOURCE: &str = include_str!("../shaders/rope_multi.metal");
37
38pub fn register(registry: &mut KernelRegistry) {
39    registry.register_source("rope_multi_f32", ROPE_MULTI_SHADER_SOURCE);
40    registry.register_source("rope_multi_bf16", ROPE_MULTI_SHADER_SOURCE);
41}
42
/// Selects how rotary pairs are assigned to position axes.
///
/// The discriminants are the values sent to the kernel as `mode_code`; they
/// match the ggml `GGML_ROPE_TYPE_*` enum.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RopeMultiMode {
    /// Standard multi-section RoPE: each axis owns one contiguous run of
    /// sectors.
    Mrope = 8,
    /// Interleaved multi-section RoPE: `sector % 3` cycles through 3 axes.
    /// This is the variant used by Qwen3.5 / Qwen3.6.
    Imrope = 40,
}
53
54/// Shape + config for a rope_multi dispatch.
55#[derive(Debug, Clone, Copy)]
56pub struct RopeMultiParams {
57    pub head_dim: u32,
58    pub rope_dim: u32, // must be <= head_dim; must be even
59    pub n_heads: u32,
60    pub seq_len: u32,
61    pub freq_base: f32,
62    pub mode: RopeMultiMode,
63    /// Section counts `[s0, s1, s2, s3]`. Sum should be `rope_dim / 2` for
64    /// full coverage; the kernel tolerates smaller sums (sector wraps).
65    pub sections: [u32; 4],
66}
67
68fn validate(
69    p: &RopeMultiParams,
70    input: &MlxBuffer,
71    output: &MlxBuffer,
72    positions: &MlxBuffer,
73) -> Result<()> {
74    if p.head_dim == 0 || p.rope_dim == 0 || p.n_heads == 0 || p.seq_len == 0 {
75        return Err(MlxError::InvalidArgument(
76            "rope_multi: head_dim, rope_dim, n_heads, seq_len must all be > 0".into(),
77        ));
78    }
79    if p.head_dim % 2 != 0 || p.rope_dim % 2 != 0 {
80        return Err(MlxError::InvalidArgument(
81            "rope_multi: head_dim and rope_dim must be even".into(),
82        ));
83    }
84    if p.rope_dim > p.head_dim {
85        return Err(MlxError::InvalidArgument(
86            "rope_multi: rope_dim must be <= head_dim".into(),
87        ));
88    }
89    if !p.freq_base.is_finite() || p.freq_base <= 0.0 {
90        return Err(MlxError::InvalidArgument(format!(
91            "rope_multi: freq_base must be finite and positive, got {}",
92            p.freq_base
93        )));
94    }
95
96    let n_rows = (p.seq_len as usize) * (p.n_heads as usize);
97    let elements = n_rows * (p.head_dim as usize);
98    if input.element_count() != elements {
99        return Err(MlxError::InvalidArgument(format!(
100            "rope_multi: input element count {} != seq_len({}) * n_heads({}) * head_dim({}) = {}",
101            input.element_count(),
102            p.seq_len,
103            p.n_heads,
104            p.head_dim,
105            elements
106        )));
107    }
108    if output.element_count() != elements {
109        return Err(MlxError::InvalidArgument(format!(
110            "rope_multi: output element count {} != {}",
111            output.element_count(),
112            elements
113        )));
114    }
115    if input.dtype() != output.dtype() {
116        return Err(MlxError::InvalidArgument(format!(
117            "rope_multi: input/output dtype mismatch {} vs {}",
118            input.dtype(),
119            output.dtype()
120        )));
121    }
122
123    let expected_positions = 4 * (p.seq_len as usize);
124    if positions.element_count() != expected_positions {
125        return Err(MlxError::InvalidArgument(format!(
126            "rope_multi: positions length {} != 4 * seq_len({}) = {}",
127            positions.element_count(),
128            p.seq_len,
129            expected_positions
130        )));
131    }
132    match positions.dtype() {
133        DType::I32 | DType::U32 => {}
134        other => {
135            return Err(MlxError::InvalidArgument(format!(
136                "rope_multi: positions must be i32 or u32 (got {})",
137                other
138            )));
139        }
140    }
141
142    Ok(())
143}
144
145/// Dispatch a rope_multi operation.
146///
147/// The caller must upload:
148/// - `params_buf`: float4 `[freq_base, head_dim, rope_dim, 0]`.
149/// - `rope_params_buf`: uint4 `[n_heads, mode_code, seq_len, 0]`. The
150///   `mode_code` is the `u32` underlying [`RopeMultiMode`].
151/// - `sections_buf`: uint4 `[s0, s1, s2, s3]`.
152/// - `positions`: int32 array of length `4 * seq_len`.
153///
154/// The helper [`build_rope_multi_buffers`] constructs all three small buffers
155/// in one call for callers that do not already keep them pooled.
156#[allow(clippy::too_many_arguments)]
157pub fn dispatch_rope_multi(
158    encoder: &mut CommandEncoder,
159    registry: &mut KernelRegistry,
160    device: &metal::DeviceRef,
161    input: &MlxBuffer,
162    output: &MlxBuffer,
163    positions: &MlxBuffer,
164    params_buf: &MlxBuffer,
165    rope_params_buf: &MlxBuffer,
166    sections_buf: &MlxBuffer,
167    p: RopeMultiParams,
168) -> Result<()> {
169    validate(&p, input, output, positions)?;
170
171    let kernel_name = match input.dtype() {
172        DType::F32 => "rope_multi_f32",
173        DType::BF16 => "rope_multi_bf16",
174        other => {
175            return Err(MlxError::InvalidArgument(format!(
176                "rope_multi: unsupported dtype {}",
177                other
178            )));
179        }
180    };
181
182    let pipeline = registry.get_pipeline(kernel_name, device)?;
183
184    let half_dim = p.head_dim / 2;
185    let n_rows = p.seq_len * p.n_heads;
186
187    // Grid: (half_dim, n_rows). Every thread writes a NeoX pair.
188    let grid = MTLSize::new(half_dim as u64, n_rows as u64, 1);
189
190    let tg_x = std::cmp::min(half_dim, 256).max(1);
191    let remain = (256u32 / tg_x).max(1);
192    let tg_y = std::cmp::min(n_rows, remain).max(1);
193    let tg = MTLSize::new(tg_x as u64, tg_y as u64, 1);
194
195    encoder.encode(
196        pipeline,
197        &[
198            (0, input),
199            (1, output),
200            (2, params_buf),
201            (3, positions),
202            (4, rope_params_buf),
203            (5, sections_buf),
204        ],
205        grid,
206        tg,
207    );
208
209    Ok(())
210}
211
212/// Convenience: build all three small parameter buffers given a [`RopeMultiParams`].
213///
214/// Returns `(params_buf, rope_params_buf, sections_buf)`.
215pub fn build_rope_multi_buffers(
216    device: &crate::MlxDevice,
217    p: RopeMultiParams,
218) -> Result<(MlxBuffer, MlxBuffer, MlxBuffer)> {
219    let mut params = device.alloc_buffer(4 * 4, DType::F32, vec![4])?;
220    {
221        let s = params.as_mut_slice::<f32>()?;
222        s[0] = p.freq_base;
223        s[1] = p.head_dim as f32;
224        s[2] = p.rope_dim as f32;
225        s[3] = 0.0;
226    }
227    let mut rope_params = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
228    {
229        let s = rope_params.as_mut_slice::<u32>()?;
230        s[0] = p.n_heads;
231        s[1] = p.mode as u32;
232        s[2] = p.seq_len;
233        s[3] = 0;
234    }
235    let mut sections = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
236    {
237        let s = sections.as_mut_slice::<u32>()?;
238        s[0] = p.sections[0];
239        s[1] = p.sections[1];
240        s[2] = p.sections[2];
241        s[3] = p.sections[3];
242    }
243    Ok((params, rope_params, sections))
244}