mlx-native 0.7.0

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
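//! Multi-section Rotary Position Embedding with optional interleaved mode.
//!
//! Used by Qwen3.5 / Qwen3.6 full-attention layers (ADR-013 Decision 10).
//! MROPE (`mode = 8`), IMROPE (`mode = 40`) and vision RoPE (`mode = 24`,
//! see [`RopeMultiMode::Vision`]) share one kernel. MROPE and IMROPE differ
//! only in the sector-to-axis mapping; vision RoPE additionally restarts the
//! theta exponent at every section boundary and rotates every pair.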
//!
//! # Spec (summary)
//!
//! For every pair `p ∈ [0, rope_dim/2)`:
//! 1. `sector = p mod (s0 + s1 + s2 + s3)`
//! 2. Pick axis based on `mode`:
//!    * `Mrope`: contiguous sections — `sector ∈ [0, s0)` → axis 0, etc.
//!    * `Imrope`: `sector % 3` cycling — `sector % 3 == 0 && sector < 3*s0`
//!      → axis 0; `== 1 && sector < 3*s1` → axis 1; `== 2 && sector < 3*s2`
//!      → axis 2; else axis 3.
//! 3. `theta = position[axis] * freq_base^(-2p/rope_dim)`
//! 4. Rotate pair `(x[p], x[p + head_dim/2])` by that angle (NeoX indexing).
//!
//! Pairs `p ≥ rope_dim/2` pass through unchanged (partial-rotary-factor).
//!
//! # Positions layout
//!
//! The `positions` buffer is an `int32` (or `uint32`) array of length
//! `4 * seq_len`: the first `seq_len` entries are the time-axis positions,
//! the next `seq_len` are the height-axis, then width, then extra. For
//! Qwen3.5 text, all four axes are set to the token's 1D position.

use std::cell::RefCell;
use std::collections::HashMap;

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

/// Metal source for the `rope_multi` kernels, embedded at compile time from
/// `shaders/rope_multi.metal`. Both the `rope_multi_f32` and
/// `rope_multi_bf16` entry points are registered against this one source.
pub static ROPE_MULTI_SHADER_SOURCE: &str = include_str!("../shaders/rope_multi.metal");

/// Register every `rope_multi` kernel variant with `registry`.
///
/// The `f32` and `bf16` entry points are both compiled from
/// [`ROPE_MULTI_SHADER_SOURCE`]; only the kernel name differs.
pub fn register(registry: &mut KernelRegistry) {
    for kernel_name in ["rope_multi_f32", "rope_multi_bf16"] {
        registry.register_source(kernel_name, ROPE_MULTI_SHADER_SOURCE);
    }
}

/// MROPE variant. Wire-level values match the ggml `GGML_ROPE_TYPE_*` enum.
///
/// The `u32` discriminant is what gets uploaded as `mode_code` in the
/// `rope_params_buf` (`p.mode as u32` in [`build_rope_multi_buffers`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum RopeMultiMode {
    /// Standard multi-section RoPE; contiguous sections.
    ///
    /// With `sector = pair_idx % (s0 + s1 + s2 + s3)`, sectors in `[0, s0)`
    /// take their position from axis 0, the next `s1` sectors from axis 1,
    /// and so on.
    Mrope = 8,
    /// Interleaved multi-section RoPE; `sector % 3` cycles through 3 axes.
    /// Used by Qwen3.5 / Qwen3.6.
    ///
    /// Axis selection: `sector % 3 == 0 && sector < 3*s0` → axis 0;
    /// `== 1 && sector < 3*s1` → axis 1; `== 2 && sector < 3*s2` → axis 2;
    /// otherwise axis 3.
    Imrope = 40,
    /// Vision multi-section RoPE for ViT 2-D positions (Qwen3-VL ViT block).
    ///
    /// Mode value `24` matches `GGML_ROPE_TYPE_VISION` in
    /// `/opt/llama.cpp/ggml/include/ggml.h:253` and the per-section
    /// `[yyyyxxxx]` layout described at `ggml.h:1840-1846`.
    ///
    /// # Layout
    ///
    /// Only the first two section counts (`s0 = y`, `s1 = x`) are used; the
    /// last two are ignored. With `n_dims = head_dim / 2` and `sect_dims =
    /// s0 + s1 = n_dims`, the rotated pairs partition as:
    ///
    /// ```text
    /// pair_idx in [0,    s0)        -> axis 0 (y), local_p = pair_idx
    /// pair_idx in [s0,   s0 + s1)   -> axis 1 (x), local_p = pair_idx - s0
    /// ```
    ///
    /// # Per-section theta
    ///
    /// Unlike `Mrope` / `Imrope` which use a unified theta sequence across
    /// all sections (`theta = pos[axis] * freq_base^(-2*pair_idx/rope_dim)`),
    /// vision rope **restarts the theta exponent at every section boundary**:
    ///
    /// ```text
    /// theta_scale = freq_base^(-2 / n_dims)             where n_dims = head_dim/2
    /// theta       = pos[axis] * theta_scale^local_p
    /// ```
    ///
    /// `local_p` is the index of the pair *within its section*, not the
    /// global `pair_idx`. This per-section restart is what produces the
    /// `[0123][0123]` exponent pattern documented at `ggml.h:1845-1846`.
    ///
    /// # No partial-rotary tail
    ///
    /// The CPU reference at `ggml-cpu/ops.cpp:5860` calls
    /// `rotate_pairs(ne0, n_dims, ...)` (rotating *all* `head_dim/2` pairs)
    /// and at `:5866` skips the partial-rotary fill loop when `is_vision`,
    /// so the caller MUST supply `rope_dim == head_dim`.
    ///
    /// # Caller-side requirements
    ///
    /// - `rope_dim == head_dim` (validated; no partial-rotary support).
    /// - `sections[0] + sections[1] == head_dim / 2` (validated; the last
    ///   two sections are ignored but must be present in the buffer for
    ///   binary-layout uniformity with `Mrope` / `Imrope`).
    /// - `positions` buffer length still `4 * seq_len`; only axes 0 and 1
    ///   are read.
    Vision = 24,
}

/// Shape + config for a rope_multi dispatch.
///
/// The constraints listed on each field are checked at dispatch time; a
/// violation is reported as `MlxError::InvalidArgument`.
#[derive(Debug, Clone, Copy)]
pub struct RopeMultiParams {
    /// Elements per head row. Must be even and non-zero.
    pub head_dim: u32,
    /// Leading elements per head that are rotated; pairs at or beyond
    /// `rope_dim / 2` pass through unchanged (partial rotary). Must be even,
    /// non-zero and `<= head_dim`; `Vision` mode requires `== head_dim`.
    pub rope_dim: u32,
    /// Number of heads per token. Must be non-zero.
    pub n_heads: u32,
    /// Number of tokens. Must be non-zero; `positions` must hold
    /// `4 * seq_len` entries.
    pub seq_len: u32,
    /// RoPE frequency base. Must be finite and strictly positive.
    pub freq_base: f32,
    /// Sector-to-axis mapping / theta scheme; see [`RopeMultiMode`].
    pub mode: RopeMultiMode,
    /// Section counts `[s0, s1, s2, s3]`. Sum should be `rope_dim / 2` for
    /// full coverage; the kernel tolerates smaller sums (sector wraps).
    /// In `Vision` mode only `s0 + s1` is used and must equal `head_dim / 2`.
    pub sections: [u32; 4],
}

/// Check that a `rope_multi` dispatch is well-formed before encoding it.
///
/// Verifies, in order:
/// 1. `head_dim`, `rope_dim`, `n_heads`, `seq_len` are all non-zero,
///    `head_dim` and `rope_dim` are even, and `rope_dim <= head_dim`.
/// 2. `freq_base` is finite and strictly positive.
/// 3. The section counts are usable by `p.mode`:
///    * `Vision`: `rope_dim == head_dim` and
///      `sections[0] + sections[1] == head_dim / 2`;
///    * `Mrope` / `Imrope`: the sections do not sum to zero (the kernel
///      reduces the pair index modulo that sum).
/// 4. `input` and `output` each hold `seq_len * n_heads * head_dim`
///    elements and share a dtype.
/// 5. `positions` holds `4 * seq_len` elements of dtype `I32` or `U32`.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] describing the first violated
/// constraint.
fn validate(
    p: &RopeMultiParams,
    input: &MlxBuffer,
    output: &MlxBuffer,
    positions: &MlxBuffer,
) -> Result<()> {
    if p.head_dim == 0 || p.rope_dim == 0 || p.n_heads == 0 || p.seq_len == 0 {
        return Err(MlxError::InvalidArgument(
            "rope_multi: head_dim, rope_dim, n_heads, seq_len must all be > 0".into(),
        ));
    }
    if p.head_dim % 2 != 0 || p.rope_dim % 2 != 0 {
        return Err(MlxError::InvalidArgument(
            "rope_multi: head_dim and rope_dim must be even".into(),
        ));
    }
    if p.rope_dim > p.head_dim {
        return Err(MlxError::InvalidArgument(
            "rope_multi: rope_dim must be <= head_dim".into(),
        ));
    }
    if !p.freq_base.is_finite() || p.freq_base <= 0.0 {
        return Err(MlxError::InvalidArgument(format!(
            "rope_multi: freq_base must be finite and positive, got {}",
            p.freq_base
        )));
    }

    // Section arithmetic is widened to u64: sums of caller-supplied u32
    // counts could otherwise overflow (a panic in debug builds, a silent
    // wrap -- and therefore a bogus comparison -- in release builds).
    let sections = p.sections.map(u64::from);
    match p.mode {
        RopeMultiMode::Vision => {
            // Vision-mode requires every pair to rotate (no partial-rotary
            // tail -- see /opt/llama.cpp/ggml/src/ggml-cpu/ops.cpp:5803,5866)
            // and the first two section counts must sum to
            // n_dims = head_dim / 2 (last 2 ignored per ggml.h:1843-1846).
            if p.rope_dim != p.head_dim {
                return Err(MlxError::InvalidArgument(format!(
                    "rope_multi(Vision): rope_dim must equal head_dim (no partial-rotary tail in vision mode), got rope_dim={}, head_dim={}",
                    p.rope_dim, p.head_dim
                )));
            }
            let n_dims = p.head_dim / 2;
            let sect_sum = sections[0] + sections[1];
            if sect_sum != u64::from(n_dims) {
                return Err(MlxError::InvalidArgument(format!(
                    "rope_multi(Vision): sections[0] + sections[1] must equal head_dim/2 ({}), got {} + {} = {}",
                    n_dims, p.sections[0], p.sections[1], sect_sum
                )));
            }
        }
        RopeMultiMode::Mrope | RopeMultiMode::Imrope => {
            // The kernel computes `sector = pair % (s0 + s1 + s2 + s3)`; an
            // all-zero section array would be an integer modulo by zero.
            if sections.iter().sum::<u64>() == 0 {
                return Err(MlxError::InvalidArgument(
                    "rope_multi: sections must not all be zero (sector = pair % sum(sections))"
                        .into(),
                ));
            }
        }
    }

    let n_rows = (p.seq_len as usize) * (p.n_heads as usize);
    let elements = n_rows * (p.head_dim as usize);
    if input.element_count() != elements {
        return Err(MlxError::InvalidArgument(format!(
            "rope_multi: input element count {} != seq_len({}) * n_heads({}) * head_dim({}) = {}",
            input.element_count(),
            p.seq_len,
            p.n_heads,
            p.head_dim,
            elements
        )));
    }
    if output.element_count() != elements {
        return Err(MlxError::InvalidArgument(format!(
            "rope_multi: output element count {} != {}",
            output.element_count(),
            elements
        )));
    }
    if input.dtype() != output.dtype() {
        return Err(MlxError::InvalidArgument(format!(
            "rope_multi: input/output dtype mismatch {} vs {}",
            input.dtype(),
            output.dtype()
        )));
    }

    let expected_positions = 4 * (p.seq_len as usize);
    if positions.element_count() != expected_positions {
        return Err(MlxError::InvalidArgument(format!(
            "rope_multi: positions length {} != 4 * seq_len({}) = {}",
            positions.element_count(),
            p.seq_len,
            expected_positions
        )));
    }
    match positions.dtype() {
        DType::I32 | DType::U32 => {}
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "rope_multi: positions must be i32 or u32 (got {})",
                other
            )));
        }
    }

    Ok(())
}

/// Dispatch a rope_multi operation.
///
/// The caller must upload:
/// - `params_buf`: float4 `[freq_base, head_dim, rope_dim, 0]`.
/// - `rope_params_buf`: uint4 `[n_heads, mode_code, seq_len, 0]`. The
///   `mode_code` is the `u32` underlying [`RopeMultiMode`].
/// - `sections_buf`: uint4 `[s0, s1, s2, s3]`.
/// - `positions`: `int32` (or `uint32`) array of length `4 * seq_len`.
///
/// The helper [`build_rope_multi_buffers`] constructs all three small
/// parameter buffers in one call for callers that do not already keep them
/// pooled; [`dispatch_rope_multi_cached`] additionally reuses them across
/// calls.
///
/// Only `p` and the `input` / `output` / `positions` buffers are validated;
/// the contents of the three parameter buffers are bound as supplied.
///
/// # Errors
///
/// - [`MlxError::InvalidArgument`] if `p` or the buffer shapes / dtypes
///   fail validation, if `input` is neither `F32` nor `BF16`, or if
///   `seq_len * n_heads` does not fit in a `u32`.
/// - Any error returned by `KernelRegistry::get_pipeline`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope_multi(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    positions: &MlxBuffer,
    params_buf: &MlxBuffer,
    rope_params_buf: &MlxBuffer,
    sections_buf: &MlxBuffer,
    p: RopeMultiParams,
) -> Result<()> {
    // Upper bound on threads per threadgroup chosen for this kernel.
    const MAX_TG_THREADS: u32 = 256;

    validate(&p, input, output, positions)?;

    let kernel_name = match input.dtype() {
        DType::F32 => "rope_multi_f32",
        DType::BF16 => "rope_multi_bf16",
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "rope_multi: unsupported dtype {}",
                other
            )));
        }
    };

    let half_dim = p.head_dim / 2;
    // `validate` sizes buffers in usize, so a `seq_len * n_heads` above
    // u32::MAX can pass it; an unchecked u32 multiply here would then panic
    // (debug) or silently wrap (release) and shrink the grid.
    let n_rows = p.seq_len.checked_mul(p.n_heads).ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "rope_multi: seq_len({}) * n_heads({}) overflows u32",
            p.seq_len, p.n_heads
        ))
    })?;

    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // Grid: (half_dim, n_rows). Every thread writes a NeoX pair.
    let grid = MTLSize::new(u64::from(half_dim), u64::from(n_rows), 1);

    // Threadgroup: fill x with pairs (capped at MAX_TG_THREADS), then spend
    // whatever thread budget remains on rows.
    let tg_x = half_dim.clamp(1, MAX_TG_THREADS);
    let tg_y = n_rows.min((MAX_TG_THREADS / tg_x).max(1)).max(1);
    let tg = MTLSize::new(u64::from(tg_x), u64::from(tg_y), 1);

    encoder.encode(
        pipeline,
        &[
            (0, input),
            (1, output),
            (2, params_buf),
            (3, positions),
            (4, rope_params_buf),
            (5, sections_buf),
        ],
        grid,
        tg,
    );

    Ok(())
}

/// Pre-built triple of small parameter buffers for a `rope_multi` dispatch.
///
/// Held in the per-thread `ROPE_PACK_CACHE` so callers that issue
/// repeated dispatches with stable shape (the qwen35 / qwen36 decode hot
/// path: identical `head_dim`, `rope_dim`, `n_heads`, `seq_len=1`,
/// `freq_base`, `mode`, `sections` every step) skip the per-call
/// allocation triplet (~208 µs/token measured on M5 Max in
/// `cfa-20260426-adr015-wave2a-p3aprime`). Calls outside the decode path
/// (variable `seq_len`) populate one entry per distinct `seq_len` value
/// seen, then reuse it on re-encounter.
///
/// The buffers are filled by [`build_rope_multi_buffers`].
pub struct RopeMultiBufferPack {
    /// float4 `[freq_base, head_dim, rope_dim, 0]`.
    pub params_buf: MlxBuffer,
    /// uint4 `[n_heads, mode_code, seq_len, 0]`.
    pub rope_params_buf: MlxBuffer,
    /// uint4 `[s0, s1, s2, s3]`.
    pub sections_buf: MlxBuffer,
}

/// Cache key for the per-thread `ROPE_PACK_CACHE`.
///
/// Holds one field per input that determines the contents of the cached
/// parameter buffers. It also holds the address of the [`crate::MlxDevice`]
/// that allocated them, so two devices that are alive at the same time
/// never share entries.
///
/// NOTE(review): `device_ptr` is the address of the `MlxDevice` value, not
/// of the underlying Metal device object. A moved `MlxDevice` therefore
/// misses the cache, which is harmless but adds entries. A new `MlxDevice`
/// that lands at the address of a dropped one could match stale entries.
/// Clear the cache when tearing a device down.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct RopeMultiCacheKey {
    /// Address of the `MlxDevice` (see the note above).
    device_ptr: usize,
    head_dim: u32,
    rope_dim: u32,
    n_heads: u32,
    seq_len: u32,
    /// `freq_base` by raw bit pattern, since `f32` is neither `Eq` nor `Hash`.
    freq_base_bits: u32,
    /// `RopeMultiMode` discriminant.
    mode: u32,
    sections: [u32; 4],
}

impl RopeMultiCacheKey {
    fn from_params(device: &crate::MlxDevice, p: &RopeMultiParams) -> Self {
        Self {
            device_ptr: device as *const _ as usize,
            head_dim: p.head_dim,
            rope_dim: p.rope_dim,
            n_heads: p.n_heads,
            seq_len: p.seq_len,
            freq_base_bits: p.freq_base.to_bits(),
            mode: p.mode as u32,
            sections: p.sections,
        }
    }
}

thread_local! {
    /// Per-thread cache of pre-built `rope_multi` parameter buffers,
    /// keyed by [`RopeMultiCacheKey`].
    ///
    /// [`dispatch_rope_multi_cached`] builds an entry lazily on its first
    /// call for a given key. This module never evicts entries on its own:
    /// each distinct key keeps its three small buffers alive until
    /// [`clear_rope_pack_cache`] runs on this thread or the thread exits.
    /// Because the storage is thread-local, entries cached on one thread
    /// are invisible to, and not cleared by, other threads.
    static ROPE_PACK_CACHE: RefCell<HashMap<RopeMultiCacheKey, RopeMultiBufferPack>> =
        RefCell::new(HashMap::new());
}

/// Drop every entry from the calling thread's rope-multi pack cache.
///
/// Only the current thread's cache is affected; the cache is thread-local.
/// Call this between model loads and in test suites that swap mocked
/// devices. Entries are keyed on the `MlxDevice` address, so stale entries
/// are normally just dead weight. However, they hold their buffers until
/// cleared, and an address reused by a later device could match them.
pub fn clear_rope_pack_cache() {
    ROPE_PACK_CACHE.with(|cache| {
        let mut entries = cache.borrow_mut();
        entries.clear();
    });
}

/// Number of entries currently held in the calling thread's rope-multi pack
/// cache. Intended for diagnostics and tests only.
pub fn rope_pack_cache_len() -> usize {
    ROPE_PACK_CACHE.with(|cache| {
        let entries = cache.borrow();
        entries.len()
    })
}

/// Dispatch a `rope_multi` operation, reusing pre-built parameter
/// buffers from the per-thread cache.
///
/// Functionally equivalent to [`dispatch_rope_multi`] preceded by
/// [`build_rope_multi_buffers`], but the small param/rope_params/sections
/// buffers (3 × 16 bytes each) are built once per
/// `(device, head_dim, rope_dim, n_heads, seq_len, freq_base, mode,
/// sections)` tuple and reused on every subsequent call.  See
/// `docs/ADR-015-mlx-native-single-cb-decode.md` §"P3a' live profile pass"
/// rank-4 finding — `build_rope_multi_buffers` per-call alloc was
/// measured at 208 µs/token on the qwen3.6-27b-dwq46 dense-FFN-Q hot
/// path (16 FullAttn layers × 2 calls/layer = 32 calls/token, each
/// allocating 3 fresh `MlxBuffer`s via Mach IPC).  This helper closes
/// that residual.
///
/// Bit-exact to the per-call form: identical kernel, identical inputs,
/// only the small parameter triplet is sourced from the cache.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope_multi_cached(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &crate::MlxDevice,
    input: &MlxBuffer,
    output: &MlxBuffer,
    positions: &MlxBuffer,
    p: RopeMultiParams,
) -> Result<()> {
    let key = RopeMultiCacheKey::from_params(device, &p);
    ROPE_PACK_CACHE.with(|cell| {
        let mut map = cell.borrow_mut();
        if !map.contains_key(&key) {
            let (params_buf, rope_params_buf, sections_buf) =
                build_rope_multi_buffers(device, p)?;
            map.insert(
                key,
                RopeMultiBufferPack {
                    params_buf,
                    rope_params_buf,
                    sections_buf,
                },
            );
        }
        let pack = map
            .get(&key)
            .expect("inserted above if missing; cache is single-threaded");
        dispatch_rope_multi(
            encoder,
            registry,
            device.metal_device(),
            input,
            output,
            positions,
            &pack.params_buf,
            &pack.rope_params_buf,
            &pack.sections_buf,
            p,
        )
    })
}

/// Allocate and fill the three 16-byte parameter buffers that a
/// `rope_multi` dispatch binds, in the layout documented on
/// [`dispatch_rope_multi`].
///
/// Returns `(params_buf, rope_params_buf, sections_buf)`:
/// - `params_buf` (`F32`, shape `[4]`): `[freq_base, head_dim, rope_dim, 0.0]`
/// - `rope_params_buf` (`U32`, shape `[4]`): `[n_heads, mode as u32, seq_len, 0]`
/// - `sections_buf` (`U32`, shape `[4]`): `p.sections`
///
/// `p` is not validated here; [`dispatch_rope_multi`] validates the
/// [`RopeMultiParams`] it is given.
///
/// # Errors
///
/// Propagates failures from `MlxDevice::alloc_buffer` and
/// `MlxBuffer::as_mut_slice`.
pub fn build_rope_multi_buffers(
    device: &crate::MlxDevice,
    p: RopeMultiParams,
) -> Result<(MlxBuffer, MlxBuffer, MlxBuffer)> {
    let float_words = [p.freq_base, p.head_dim as f32, p.rope_dim as f32, 0.0];
    let mut params = device.alloc_buffer(4 * 4, DType::F32, vec![4])?;
    params.as_mut_slice::<f32>()?[..4].copy_from_slice(&float_words);

    let uint_words = [p.n_heads, p.mode as u32, p.seq_len, 0];
    let mut rope_params = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
    rope_params.as_mut_slice::<u32>()?[..4].copy_from_slice(&uint_words);

    let mut sections = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
    sections.as_mut_slice::<u32>()?[..4].copy_from_slice(&p.sections);

    Ok((params, rope_params, sections))
}