Skip to main content

mlx_native/ops/
rope_multi.rs

1//! Multi-section Rotary Position Embedding with optional interleaved mode.
2//!
3//! Used by Qwen3.5 / Qwen3.6 full-attention layers (ADR-013 Decision 10).
4//! Both MROPE (`mode = 8`) and IMROPE (`mode = 40`) share a kernel; only the
5//! sector-to-axis mapping differs.
6//!
7//! # Spec (summary)
8//!
9//! For every pair `p ∈ [0, rope_dim/2)`:
//! 1. `sector = p mod (s0 + s1 + s2 + s3)` (so the section sum must be non-zero)
11//! 2. Pick axis based on `mode`:
12//!    * `Mrope`: contiguous sections — `sector ∈ [0, s0)` → axis 0, etc.
13//!    * `Imrope`: `sector % 3` cycling — `sector % 3 == 0 && sector < 3*s0`
14//!      → axis 0; `== 1 && sector < 3*s1` → axis 1; `== 2 && sector < 3*s2`
15//!      → axis 2; else axis 3.
16//! 3. `theta = position[axis] * freq_base^(-2p/rope_dim)`
17//! 4. Rotate pair `(x[p], x[p + head_dim/2])` by that angle (NeoX indexing).
18//!
19//! Pairs `p ≥ rope_dim/2` pass through unchanged (partial-rotary-factor).
20//!
21//! # Positions layout
22//!
//! The `positions` buffer is an `int32` (or `uint32`) array of length `4 * seq_len`:
24//! first `seq_len` entries are the time-axis positions, next `seq_len` are
25//! the height-axis, then width, then extra. For Qwen3.5 text, all four
26//! axes are set to the token's 1D position.
27
28use std::cell::RefCell;
29use std::collections::HashMap;
30
31use metal::MTLSize;
32
33use crate::buffer::MlxBuffer;
34use crate::dtypes::DType;
35use crate::encoder::CommandEncoder;
36use crate::error::{MlxError, Result};
37use crate::kernel_registry::KernelRegistry;
38
39pub static ROPE_MULTI_SHADER_SOURCE: &str = include_str!("../shaders/rope_multi.metal");
40
41pub fn register(registry: &mut KernelRegistry) {
42    registry.register_source("rope_multi_f32", ROPE_MULTI_SHADER_SOURCE);
43    registry.register_source("rope_multi_bf16", ROPE_MULTI_SHADER_SOURCE);
44}
45
/// Which multi-section RoPE variant a dispatch runs.
///
/// The discriminant is uploaded to the kernel as the `mode_code` word, and
/// the values mirror ggml's `GGML_ROPE_TYPE_*` enum.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeMultiMode {
    /// Plain multi-section RoPE: sections are laid out contiguously.
    Mrope = 8,
    /// Interleaved multi-section RoPE: `sector % 3` rotates through the
    /// first three axes. This is the variant Qwen3.5 / Qwen3.6 use.
    Imrope = 40,
}
56
57/// Shape + config for a rope_multi dispatch.
58#[derive(Debug, Clone, Copy)]
59pub struct RopeMultiParams {
60    pub head_dim: u32,
61    pub rope_dim: u32, // must be <= head_dim; must be even
62    pub n_heads: u32,
63    pub seq_len: u32,
64    pub freq_base: f32,
65    pub mode: RopeMultiMode,
66    /// Section counts `[s0, s1, s2, s3]`. Sum should be `rope_dim / 2` for
67    /// full coverage; the kernel tolerates smaller sums (sector wraps).
68    pub sections: [u32; 4],
69}
70
71fn validate(
72    p: &RopeMultiParams,
73    input: &MlxBuffer,
74    output: &MlxBuffer,
75    positions: &MlxBuffer,
76) -> Result<()> {
77    if p.head_dim == 0 || p.rope_dim == 0 || p.n_heads == 0 || p.seq_len == 0 {
78        return Err(MlxError::InvalidArgument(
79            "rope_multi: head_dim, rope_dim, n_heads, seq_len must all be > 0".into(),
80        ));
81    }
82    if p.head_dim % 2 != 0 || p.rope_dim % 2 != 0 {
83        return Err(MlxError::InvalidArgument(
84            "rope_multi: head_dim and rope_dim must be even".into(),
85        ));
86    }
87    if p.rope_dim > p.head_dim {
88        return Err(MlxError::InvalidArgument(
89            "rope_multi: rope_dim must be <= head_dim".into(),
90        ));
91    }
92    if !p.freq_base.is_finite() || p.freq_base <= 0.0 {
93        return Err(MlxError::InvalidArgument(format!(
94            "rope_multi: freq_base must be finite and positive, got {}",
95            p.freq_base
96        )));
97    }
98
99    let n_rows = (p.seq_len as usize) * (p.n_heads as usize);
100    let elements = n_rows * (p.head_dim as usize);
101    if input.element_count() != elements {
102        return Err(MlxError::InvalidArgument(format!(
103            "rope_multi: input element count {} != seq_len({}) * n_heads({}) * head_dim({}) = {}",
104            input.element_count(),
105            p.seq_len,
106            p.n_heads,
107            p.head_dim,
108            elements
109        )));
110    }
111    if output.element_count() != elements {
112        return Err(MlxError::InvalidArgument(format!(
113            "rope_multi: output element count {} != {}",
114            output.element_count(),
115            elements
116        )));
117    }
118    if input.dtype() != output.dtype() {
119        return Err(MlxError::InvalidArgument(format!(
120            "rope_multi: input/output dtype mismatch {} vs {}",
121            input.dtype(),
122            output.dtype()
123        )));
124    }
125
126    let expected_positions = 4 * (p.seq_len as usize);
127    if positions.element_count() != expected_positions {
128        return Err(MlxError::InvalidArgument(format!(
129            "rope_multi: positions length {} != 4 * seq_len({}) = {}",
130            positions.element_count(),
131            p.seq_len,
132            expected_positions
133        )));
134    }
135    match positions.dtype() {
136        DType::I32 | DType::U32 => {}
137        other => {
138            return Err(MlxError::InvalidArgument(format!(
139                "rope_multi: positions must be i32 or u32 (got {})",
140                other
141            )));
142        }
143    }
144
145    Ok(())
146}
147
148/// Dispatch a rope_multi operation.
149///
150/// The caller must upload:
151/// - `params_buf`: float4 `[freq_base, head_dim, rope_dim, 0]`.
152/// - `rope_params_buf`: uint4 `[n_heads, mode_code, seq_len, 0]`. The
153///   `mode_code` is the `u32` underlying [`RopeMultiMode`].
154/// - `sections_buf`: uint4 `[s0, s1, s2, s3]`.
155/// - `positions`: int32 array of length `4 * seq_len`.
156///
157/// The helper [`build_rope_multi_buffers`] constructs all three small buffers
158/// in one call for callers that do not already keep them pooled.
159#[allow(clippy::too_many_arguments)]
160pub fn dispatch_rope_multi(
161    encoder: &mut CommandEncoder,
162    registry: &mut KernelRegistry,
163    device: &metal::DeviceRef,
164    input: &MlxBuffer,
165    output: &MlxBuffer,
166    positions: &MlxBuffer,
167    params_buf: &MlxBuffer,
168    rope_params_buf: &MlxBuffer,
169    sections_buf: &MlxBuffer,
170    p: RopeMultiParams,
171) -> Result<()> {
172    validate(&p, input, output, positions)?;
173
174    let kernel_name = match input.dtype() {
175        DType::F32 => "rope_multi_f32",
176        DType::BF16 => "rope_multi_bf16",
177        other => {
178            return Err(MlxError::InvalidArgument(format!(
179                "rope_multi: unsupported dtype {}",
180                other
181            )));
182        }
183    };
184
185    let pipeline = registry.get_pipeline(kernel_name, device)?;
186
187    let half_dim = p.head_dim / 2;
188    let n_rows = p.seq_len * p.n_heads;
189
190    // Grid: (half_dim, n_rows). Every thread writes a NeoX pair.
191    let grid = MTLSize::new(half_dim as u64, n_rows as u64, 1);
192
193    let tg_x = std::cmp::min(half_dim, 256).max(1);
194    let remain = (256u32 / tg_x).max(1);
195    let tg_y = std::cmp::min(n_rows, remain).max(1);
196    let tg = MTLSize::new(tg_x as u64, tg_y as u64, 1);
197
198    encoder.encode(
199        pipeline,
200        &[
201            (0, input),
202            (1, output),
203            (2, params_buf),
204            (3, positions),
205            (4, rope_params_buf),
206            (5, sections_buf),
207        ],
208        grid,
209        tg,
210    );
211
212    Ok(())
213}
214
215/// Pre-built triple of small parameter buffers for a `rope_multi` dispatch.
216///
217/// Held in the per-thread [`ROPE_PACK_CACHE`] so callers that issue
218/// repeated dispatches with stable shape (the qwen35 / qwen36 decode hot
219/// path: identical `head_dim`, `rope_dim`, `n_heads`, `seq_len=1`,
220/// `freq_base`, `mode`, `sections` every step) skip the per-call
221/// allocation triplet (~208 µs/token measured on M5 Max in
222/// `cfa-20260426-adr015-wave2a-p3aprime`).  Decode-out-of-scope cases
223/// (variable `seq_len`) populate one entry per `seq_len` value seen,
224/// then reuse on re-encounter.
225pub struct RopeMultiBufferPack {
226    pub params_buf: MlxBuffer,
227    pub rope_params_buf: MlxBuffer,
228    pub sections_buf: MlxBuffer,
229}
230
231/// Cache key for [`ROPE_PACK_CACHE`].  Includes the [`MlxDevice`] pointer
232/// so two consecutive sessions with different devices (e.g. model swap)
233/// never share entries — required for correctness, not just isolation.
234#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
235struct RopeMultiCacheKey {
236    device_ptr: usize,
237    head_dim: u32,
238    rope_dim: u32,
239    n_heads: u32,
240    seq_len: u32,
241    freq_base_bits: u32,
242    mode: u32,
243    sections: [u32; 4],
244}
245
246impl RopeMultiCacheKey {
247    fn from_params(device: &crate::MlxDevice, p: &RopeMultiParams) -> Self {
248        Self {
249            device_ptr: device as *const _ as usize,
250            head_dim: p.head_dim,
251            rope_dim: p.rope_dim,
252            n_heads: p.n_heads,
253            seq_len: p.seq_len,
254            freq_base_bits: p.freq_base.to_bits(),
255            mode: p.mode as u32,
256            sections: p.sections,
257        }
258    }
259}
260
261thread_local! {
262    /// Per-thread cache of pre-built `rope_multi` parameter buffers,
263    /// keyed by [`RopeMultiCacheKey`].  Built lazily on first
264    /// [`dispatch_rope_multi_cached`] call for a given key, retained for
265    /// the thread's lifetime (cleared by [`clear_rope_pack_cache`] in
266    /// tests / on explicit model unload).
267    static ROPE_PACK_CACHE: RefCell<HashMap<RopeMultiCacheKey, RopeMultiBufferPack>> =
268        RefCell::new(HashMap::new());
269}
270
271/// Clear the thread-local rope-multi pack cache.
272///
273/// Useful between model loads (the cache key includes the device pointer
274/// so old entries can never be returned for a new device, but they leak
275/// memory until cleared) and in test suites that swap mocked devices.
276pub fn clear_rope_pack_cache() {
277    ROPE_PACK_CACHE.with(|cell| cell.borrow_mut().clear());
278}
279
280/// Inspect the current pack-cache size — diagnostic only.
281pub fn rope_pack_cache_len() -> usize {
282    ROPE_PACK_CACHE.with(|cell| cell.borrow().len())
283}
284
285/// Dispatch a `rope_multi` operation, reusing pre-built parameter
286/// buffers from the per-thread cache.
287///
288/// Functionally equivalent to [`dispatch_rope_multi`] preceded by
289/// [`build_rope_multi_buffers`], but the small param/rope_params/sections
290/// buffers (3 × 16 bytes each) are built once per
291/// `(device, head_dim, rope_dim, n_heads, seq_len, freq_base, mode,
292/// sections)` tuple and reused on every subsequent call.  See
293/// `docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live profile pass"
294/// rank-4 finding — `build_rope_multi_buffers` per-call alloc was
295/// measured at 208 µs/token on the qwen3.6-27b-dwq46 dense-FFN-Q hot
296/// path (16 FullAttn layers × 2 calls/layer = 32 calls/token, each
297/// allocating 3 fresh `MlxBuffer`s via Mach IPC).  This helper closes
298/// that residual.
299///
300/// Bit-exact to the per-call form: identical kernel, identical inputs,
301/// only the small parameter triplet is sourced from the cache.
302#[allow(clippy::too_many_arguments)]
303pub fn dispatch_rope_multi_cached(
304    encoder: &mut CommandEncoder,
305    registry: &mut KernelRegistry,
306    device: &crate::MlxDevice,
307    input: &MlxBuffer,
308    output: &MlxBuffer,
309    positions: &MlxBuffer,
310    p: RopeMultiParams,
311) -> Result<()> {
312    let key = RopeMultiCacheKey::from_params(device, &p);
313    ROPE_PACK_CACHE.with(|cell| {
314        let mut map = cell.borrow_mut();
315        if !map.contains_key(&key) {
316            let (params_buf, rope_params_buf, sections_buf) =
317                build_rope_multi_buffers(device, p)?;
318            map.insert(
319                key,
320                RopeMultiBufferPack {
321                    params_buf,
322                    rope_params_buf,
323                    sections_buf,
324                },
325            );
326        }
327        let pack = map
328            .get(&key)
329            .expect("inserted above if missing; cache is single-threaded");
330        dispatch_rope_multi(
331            encoder,
332            registry,
333            device.metal_device(),
334            input,
335            output,
336            positions,
337            &pack.params_buf,
338            &pack.rope_params_buf,
339            &pack.sections_buf,
340            p,
341        )
342    })
343}
344
345/// Convenience: build all three small parameter buffers given a [`RopeMultiParams`].
346///
347/// Returns `(params_buf, rope_params_buf, sections_buf)`.
348pub fn build_rope_multi_buffers(
349    device: &crate::MlxDevice,
350    p: RopeMultiParams,
351) -> Result<(MlxBuffer, MlxBuffer, MlxBuffer)> {
352    let mut params = device.alloc_buffer(4 * 4, DType::F32, vec![4])?;
353    {
354        let s = params.as_mut_slice::<f32>()?;
355        s[0] = p.freq_base;
356        s[1] = p.head_dim as f32;
357        s[2] = p.rope_dim as f32;
358        s[3] = 0.0;
359    }
360    let mut rope_params = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
361    {
362        let s = rope_params.as_mut_slice::<u32>()?;
363        s[0] = p.n_heads;
364        s[1] = p.mode as u32;
365        s[2] = p.seq_len;
366        s[3] = 0;
367    }
368    let mut sections = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
369    {
370        let s = sections.as_mut_slice::<u32>()?;
371        s[0] = p.sections[0];
372        s[1] = p.sections[1];
373        s[2] = p.sections[2];
374        s[3] = p.sections[3];
375    }
376    Ok((params, rope_params, sections))
377}