Skip to main content

mlx_native/ops/
rope_multi.rs

1//! Multi-section Rotary Position Embedding with optional interleaved mode.
2//!
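//! Used by Qwen3.5 / Qwen3.6 full-attention layers (ADR-013 Decision 10).
//! MROPE (`mode = 8`) and IMROPE (`mode = 40`) share a kernel; only the
//! sector-to-axis mapping differs. The Vision variant (`mode = 24`) is
//! dispatched through the same kernel but restarts the theta exponent at each
//! section boundary (see `RopeMultiMode::Vision`).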
6//!
7//! # Spec (summary)
8//!
9//! For every pair `p ∈ [0, rope_dim/2)`:
10//! 1. `sector = p mod (s0 + s1 + s2 + s3)`
11//! 2. Pick axis based on `mode`:
12//!    * `Mrope`: contiguous sections — `sector ∈ [0, s0)` → axis 0, etc.
13//!    * `Imrope`: `sector % 3` cycling — `sector % 3 == 0 && sector < 3*s0`
14//!      → axis 0; `== 1 && sector < 3*s1` → axis 1; `== 2 && sector < 3*s2`
15//!      → axis 2; else axis 3.
16//! 3. `theta = position[axis] * freq_base^(-2p/rope_dim)`
17//! 4. Rotate pair `(x[p], x[p + head_dim/2])` by that angle (NeoX indexing).
18//!
19//! Pairs `p ≥ rope_dim/2` pass through unchanged (partial-rotary-factor).
20//!
21//! # Positions layout
22//!
//! The `positions` buffer is an `int32` array of length `4 * seq_len`
//! (`u32` buffers are also accepted): the first `seq_len` entries are the
//! time-axis positions, the next `seq_len` the height axis, then width, then
//! extra. For Qwen3.5 text, all four axes are set to the token's 1D position.
27
28use std::cell::RefCell;
29use std::collections::HashMap;
30
31use metal::MTLSize;
32
33use crate::buffer::MlxBuffer;
34use crate::dtypes::DType;
35use crate::encoder::CommandEncoder;
36use crate::error::{MlxError, Result};
37use crate::kernel_registry::KernelRegistry;
38
39pub static ROPE_MULTI_SHADER_SOURCE: &str = include_str!("../shaders/rope_multi.metal");
40
41pub fn register(registry: &mut KernelRegistry) {
42    registry.register_source("rope_multi_f32", ROPE_MULTI_SHADER_SOURCE);
43    registry.register_source("rope_multi_bf16", ROPE_MULTI_SHADER_SOURCE);
44}
45
/// Selects how rotary pairs are assigned to position axes.
///
/// The discriminants are the wire values handed to the kernel and are kept
/// identical to ggml's `GGML_ROPE_TYPE_*` constants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum RopeMultiMode {
    /// Multi-section RoPE with contiguous sections: sector `[0, s0)` uses
    /// axis 0, `[s0, s0 + s1)` axis 1, and so on.
    Mrope = 8,
    /// Vision multi-section RoPE for 2-D ViT positions (Qwen3-VL ViT block).
    ///
    /// The value `24` is `GGML_ROPE_TYPE_VISION`
    /// (`/opt/llama.cpp/ggml/include/ggml.h:253`); the per-section
    /// `[yyyyxxxx]` layout is described at `ggml.h:1840-1846`.
    ///
    /// # Layout
    ///
    /// Let `n_dims = head_dim / 2`. Only `s0` (y) and `s1` (x) are used and
    /// `s0 + s1` must equal `n_dims`; `s2` and `s3` are ignored. Pairs map as:
    ///
    /// ```text
    /// pair_idx in [0,  s0)       -> axis 0 (y), local_p = pair_idx
    /// pair_idx in [s0, s0 + s1)  -> axis 1 (x), local_p = pair_idx - s0
    /// ```
    ///
    /// # Per-section theta
    ///
    /// `Mrope` and `Imrope` walk one theta sequence across all sections
    /// (`theta = pos[axis] * freq_base^(-2*pair_idx/rope_dim)`). Vision
    /// instead **restarts the exponent at every section boundary**:
    ///
    /// ```text
    /// theta_scale = freq_base^(-2 / n_dims)      where n_dims = head_dim/2
    /// theta       = pos[axis] * theta_scale^local_p
    /// ```
    ///
    /// Indexing by the within-section `local_p` rather than the global
    /// `pair_idx` is what yields the `[0123][0123]` exponent pattern of
    /// `ggml.h:1845-1846`.
    ///
    /// # No partial-rotary tail
    ///
    /// The ggml CPU reference rotates all `head_dim/2` pairs via
    /// `rotate_pairs(ne0, n_dims, ...)` (`ggml-cpu/ops.cpp:5860`) and skips
    /// the partial-rotary fill loop for vision (`:5866`), so the caller must
    /// supply `rope_dim == head_dim`.
    ///
    /// # Caller-side requirements
    ///
    /// - `rope_dim == head_dim` (validated; no partial-rotary support).
    /// - `sections[0] + sections[1] == head_dim / 2` (validated). `sections[2]`
    ///   and `sections[3]` are ignored but still occupy the sections buffer,
    ///   keeping its binary layout identical to `Mrope` / `Imrope`.
    /// - `positions` still holds `4 * seq_len` entries; only axes 0 and 1
    ///   are read.
    Vision = 24,
    /// Interleaved multi-section RoPE, as used by Qwen3.5 / Qwen3.6: the axis
    /// is chosen by `sector % 3`, falling back to axis 3 once the matching
    /// section bound (`sector < 3 * s_k`) no longer holds.
    Imrope = 40,
}
104
105/// Shape + config for a rope_multi dispatch.
106#[derive(Debug, Clone, Copy)]
107pub struct RopeMultiParams {
108    pub head_dim: u32,
109    pub rope_dim: u32, // must be <= head_dim; must be even
110    pub n_heads: u32,
111    pub seq_len: u32,
112    pub freq_base: f32,
113    pub mode: RopeMultiMode,
114    /// Section counts `[s0, s1, s2, s3]`. Sum should be `rope_dim / 2` for
115    /// full coverage; the kernel tolerates smaller sums (sector wraps).
116    pub sections: [u32; 4],
117}
118
119fn validate(
120    p: &RopeMultiParams,
121    input: &MlxBuffer,
122    output: &MlxBuffer,
123    positions: &MlxBuffer,
124) -> Result<()> {
125    if p.head_dim == 0 || p.rope_dim == 0 || p.n_heads == 0 || p.seq_len == 0 {
126        return Err(MlxError::InvalidArgument(
127            "rope_multi: head_dim, rope_dim, n_heads, seq_len must all be > 0".into(),
128        ));
129    }
130    if p.head_dim % 2 != 0 || p.rope_dim % 2 != 0 {
131        return Err(MlxError::InvalidArgument(
132            "rope_multi: head_dim and rope_dim must be even".into(),
133        ));
134    }
135    if p.rope_dim > p.head_dim {
136        return Err(MlxError::InvalidArgument(
137            "rope_multi: rope_dim must be <= head_dim".into(),
138        ));
139    }
140    if !p.freq_base.is_finite() || p.freq_base <= 0.0 {
141        return Err(MlxError::InvalidArgument(format!(
142            "rope_multi: freq_base must be finite and positive, got {}",
143            p.freq_base
144        )));
145    }
146    // Vision-mode requires every pair to rotate (no partial-rotary tail —
147    // see /opt/llama.cpp/ggml/src/ggml-cpu/ops.cpp:5803,5866) and the first
148    // two section counts must sum to n_dims = head_dim / 2 (last 2 ignored
149    // per ggml.h:1843-1846).
150    if p.mode == RopeMultiMode::Vision {
151        if p.rope_dim != p.head_dim {
152            return Err(MlxError::InvalidArgument(format!(
153                "rope_multi(Vision): rope_dim must equal head_dim (no partial-rotary tail in vision mode), got rope_dim={}, head_dim={}",
154                p.rope_dim, p.head_dim
155            )));
156        }
157        let n_dims = p.head_dim / 2;
158        let sect_sum = p.sections[0] + p.sections[1];
159        if sect_sum != n_dims {
160            return Err(MlxError::InvalidArgument(format!(
161                "rope_multi(Vision): sections[0] + sections[1] must equal head_dim/2 ({}), got {} + {} = {}",
162                n_dims, p.sections[0], p.sections[1], sect_sum
163            )));
164        }
165    }
166
167    let n_rows = (p.seq_len as usize) * (p.n_heads as usize);
168    let elements = n_rows * (p.head_dim as usize);
169    if input.element_count() != elements {
170        return Err(MlxError::InvalidArgument(format!(
171            "rope_multi: input element count {} != seq_len({}) * n_heads({}) * head_dim({}) = {}",
172            input.element_count(),
173            p.seq_len,
174            p.n_heads,
175            p.head_dim,
176            elements
177        )));
178    }
179    if output.element_count() != elements {
180        return Err(MlxError::InvalidArgument(format!(
181            "rope_multi: output element count {} != {}",
182            output.element_count(),
183            elements
184        )));
185    }
186    if input.dtype() != output.dtype() {
187        return Err(MlxError::InvalidArgument(format!(
188            "rope_multi: input/output dtype mismatch {} vs {}",
189            input.dtype(),
190            output.dtype()
191        )));
192    }
193
194    let expected_positions = 4 * (p.seq_len as usize);
195    if positions.element_count() != expected_positions {
196        return Err(MlxError::InvalidArgument(format!(
197            "rope_multi: positions length {} != 4 * seq_len({}) = {}",
198            positions.element_count(),
199            p.seq_len,
200            expected_positions
201        )));
202    }
203    match positions.dtype() {
204        DType::I32 | DType::U32 => {}
205        other => {
206            return Err(MlxError::InvalidArgument(format!(
207                "rope_multi: positions must be i32 or u32 (got {})",
208                other
209            )));
210        }
211    }
212
213    Ok(())
214}
215
216/// Dispatch a rope_multi operation.
217///
218/// The caller must upload:
219/// - `params_buf`: float4 `[freq_base, head_dim, rope_dim, 0]`.
220/// - `rope_params_buf`: uint4 `[n_heads, mode_code, seq_len, 0]`. The
221///   `mode_code` is the `u32` underlying [`RopeMultiMode`].
222/// - `sections_buf`: uint4 `[s0, s1, s2, s3]`.
223/// - `positions`: int32 array of length `4 * seq_len`.
224///
225/// The helper [`build_rope_multi_buffers`] constructs all three small buffers
226/// in one call for callers that do not already keep them pooled.
227#[allow(clippy::too_many_arguments)]
228pub fn dispatch_rope_multi(
229    encoder: &mut CommandEncoder,
230    registry: &mut KernelRegistry,
231    device: &metal::DeviceRef,
232    input: &MlxBuffer,
233    output: &MlxBuffer,
234    positions: &MlxBuffer,
235    params_buf: &MlxBuffer,
236    rope_params_buf: &MlxBuffer,
237    sections_buf: &MlxBuffer,
238    p: RopeMultiParams,
239) -> Result<()> {
240    validate(&p, input, output, positions)?;
241
242    let kernel_name = match input.dtype() {
243        DType::F32 => "rope_multi_f32",
244        DType::BF16 => "rope_multi_bf16",
245        other => {
246            return Err(MlxError::InvalidArgument(format!(
247                "rope_multi: unsupported dtype {}",
248                other
249            )));
250        }
251    };
252
253    let pipeline = registry.get_pipeline(kernel_name, device)?;
254
255    let half_dim = p.head_dim / 2;
256    let n_rows = p.seq_len * p.n_heads;
257
258    // Grid: (half_dim, n_rows). Every thread writes a NeoX pair.
259    let grid = MTLSize::new(half_dim as u64, n_rows as u64, 1);
260
261    let tg_x = std::cmp::min(half_dim, 256).max(1);
262    let remain = (256u32 / tg_x).max(1);
263    let tg_y = std::cmp::min(n_rows, remain).max(1);
264    let tg = MTLSize::new(tg_x as u64, tg_y as u64, 1);
265
266    encoder.encode(
267        pipeline,
268        &[
269            (0, input),
270            (1, output),
271            (2, params_buf),
272            (3, positions),
273            (4, rope_params_buf),
274            (5, sections_buf),
275        ],
276        grid,
277        tg,
278    );
279
280    Ok(())
281}
282
283/// Pre-built triple of small parameter buffers for a `rope_multi` dispatch.
284///
285/// Held in the per-thread [`ROPE_PACK_CACHE`] so callers that issue
286/// repeated dispatches with stable shape (the qwen35 / qwen36 decode hot
287/// path: identical `head_dim`, `rope_dim`, `n_heads`, `seq_len=1`,
288/// `freq_base`, `mode`, `sections` every step) skip the per-call
289/// allocation triplet (~208 µs/token measured on M5 Max in
290/// `cfa-20260426-adr015-wave2a-p3aprime`).  Decode-out-of-scope cases
291/// (variable `seq_len`) populate one entry per `seq_len` value seen,
292/// then reuse on re-encounter.
293pub struct RopeMultiBufferPack {
294    pub params_buf: MlxBuffer,
295    pub rope_params_buf: MlxBuffer,
296    pub sections_buf: MlxBuffer,
297}
298
299/// Cache key for [`ROPE_PACK_CACHE`].  Includes the [`MlxDevice`] pointer
300/// so two consecutive sessions with different devices (e.g. model swap)
301/// never share entries — required for correctness, not just isolation.
302#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
303struct RopeMultiCacheKey {
304    device_ptr: usize,
305    head_dim: u32,
306    rope_dim: u32,
307    n_heads: u32,
308    seq_len: u32,
309    freq_base_bits: u32,
310    mode: u32,
311    sections: [u32; 4],
312}
313
314impl RopeMultiCacheKey {
315    fn from_params(device: &crate::MlxDevice, p: &RopeMultiParams) -> Self {
316        Self {
317            device_ptr: device as *const _ as usize,
318            head_dim: p.head_dim,
319            rope_dim: p.rope_dim,
320            n_heads: p.n_heads,
321            seq_len: p.seq_len,
322            freq_base_bits: p.freq_base.to_bits(),
323            mode: p.mode as u32,
324            sections: p.sections,
325        }
326    }
327}
328
329thread_local! {
330    /// Per-thread cache of pre-built `rope_multi` parameter buffers,
331    /// keyed by [`RopeMultiCacheKey`].  Built lazily on first
332    /// [`dispatch_rope_multi_cached`] call for a given key, retained for
333    /// the thread's lifetime (cleared by [`clear_rope_pack_cache`] in
334    /// tests / on explicit model unload).
335    static ROPE_PACK_CACHE: RefCell<HashMap<RopeMultiCacheKey, RopeMultiBufferPack>> =
336        RefCell::new(HashMap::new());
337}
338
339/// Clear the thread-local rope-multi pack cache.
340///
341/// Useful between model loads (the cache key includes the device pointer
342/// so old entries can never be returned for a new device, but they leak
343/// memory until cleared) and in test suites that swap mocked devices.
344pub fn clear_rope_pack_cache() {
345    ROPE_PACK_CACHE.with(|cell| cell.borrow_mut().clear());
346}
347
348/// Inspect the current pack-cache size — diagnostic only.
349pub fn rope_pack_cache_len() -> usize {
350    ROPE_PACK_CACHE.with(|cell| cell.borrow().len())
351}
352
353/// Dispatch a `rope_multi` operation, reusing pre-built parameter
354/// buffers from the per-thread cache.
355///
356/// Functionally equivalent to [`dispatch_rope_multi`] preceded by
357/// [`build_rope_multi_buffers`], but the small param/rope_params/sections
358/// buffers (3 × 16 bytes each) are built once per
359/// `(device, head_dim, rope_dim, n_heads, seq_len, freq_base, mode,
360/// sections)` tuple and reused on every subsequent call.  See
361/// `docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live profile pass"
362/// rank-4 finding — `build_rope_multi_buffers` per-call alloc was
363/// measured at 208 µs/token on the qwen3.6-27b-dwq46 dense-FFN-Q hot
364/// path (16 FullAttn layers × 2 calls/layer = 32 calls/token, each
365/// allocating 3 fresh `MlxBuffer`s via Mach IPC).  This helper closes
366/// that residual.
367///
368/// Bit-exact to the per-call form: identical kernel, identical inputs,
369/// only the small parameter triplet is sourced from the cache.
370#[allow(clippy::too_many_arguments)]
371pub fn dispatch_rope_multi_cached(
372    encoder: &mut CommandEncoder,
373    registry: &mut KernelRegistry,
374    device: &crate::MlxDevice,
375    input: &MlxBuffer,
376    output: &MlxBuffer,
377    positions: &MlxBuffer,
378    p: RopeMultiParams,
379) -> Result<()> {
380    let key = RopeMultiCacheKey::from_params(device, &p);
381    ROPE_PACK_CACHE.with(|cell| {
382        let mut map = cell.borrow_mut();
383        if !map.contains_key(&key) {
384            let (params_buf, rope_params_buf, sections_buf) =
385                build_rope_multi_buffers(device, p)?;
386            map.insert(
387                key,
388                RopeMultiBufferPack {
389                    params_buf,
390                    rope_params_buf,
391                    sections_buf,
392                },
393            );
394        }
395        let pack = map
396            .get(&key)
397            .expect("inserted above if missing; cache is single-threaded");
398        dispatch_rope_multi(
399            encoder,
400            registry,
401            device.metal_device(),
402            input,
403            output,
404            positions,
405            &pack.params_buf,
406            &pack.rope_params_buf,
407            &pack.sections_buf,
408            p,
409        )
410    })
411}
412
413/// Convenience: build all three small parameter buffers given a [`RopeMultiParams`].
414///
415/// Returns `(params_buf, rope_params_buf, sections_buf)`.
416pub fn build_rope_multi_buffers(
417    device: &crate::MlxDevice,
418    p: RopeMultiParams,
419) -> Result<(MlxBuffer, MlxBuffer, MlxBuffer)> {
420    let mut params = device.alloc_buffer(4 * 4, DType::F32, vec![4])?;
421    {
422        let s = params.as_mut_slice::<f32>()?;
423        s[0] = p.freq_base;
424        s[1] = p.head_dim as f32;
425        s[2] = p.rope_dim as f32;
426        s[3] = 0.0;
427    }
428    let mut rope_params = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
429    {
430        let s = rope_params.as_mut_slice::<u32>()?;
431        s[0] = p.n_heads;
432        s[1] = p.mode as u32;
433        s[2] = p.seq_len;
434        s[3] = 0;
435    }
436    let mut sections = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
437    {
438        let s = sections.as_mut_slice::<u32>()?;
439        s[0] = p.sections[0];
440        s[1] = p.sections[1];
441        s[2] = p.sections[2];
442        s[3] = p.sections[3];
443    }
444    Ok((params, rope_params, sections))
445}