mlx_native/mem_ranges.rs

//! Dataflow-driven barrier inference (port of llama.cpp `mem_ranges`).
//!
//! ADR-015 iter37 — framework-side complement to iter21's hand-audited
//! barrier fix at `gpu_full_attn.rs:1856`.
//!
//! # Purpose
//!
//! When a Metal `MTLComputeCommandEncoder` is created with
//! `MTLDispatchTypeConcurrent` (mlx-native's default since iter8e —
//! [`encoder::CommandEncoder::get_or_create_encoder`](crate::encoder)),
//! every dispatch can execute in parallel with every other dispatch in
//! the same encoder unless separated by a memory barrier.  The encoder
//! does not infer dataflow on its own — the caller must hand-place
//! `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]` between
//! every read-after-write (RAW), write-after-read (WAR), or
//! write-after-write (WAW) pair.
//!
//! Hand-audited barrier placement is correct but fragile: iter21 found
//! one missing producer→consumer edge (`sigmoid_gate_multiply` →
//! `linear_projection wo`) that had escaped review for months because
//! the diverged-output bug it caused only surfaced under specific
//! sequence-length × sample-count combinations.  Any future kernel
//! sequence built without rigorous review is subject to the same
//! class of bug.
//!
//! `MemRanges` ports llama.cpp's mem_ranges algorithm
//! (`/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp`) so
//! callers describe each dispatch's read and write buffer regions and
//! the framework auto-emits a barrier exactly when the new dispatch's
//! ranges overlap a previously-recorded range.  This makes
//! iter21-class bugs structurally impossible at the framework boundary.
//!
//! # Algorithm
//!
//! Verbatim port of `ggml_mem_ranges_check` + `ggml_mem_ranges_add`
//! (lines 124-185 of `ggml-metal-common.cpp`):
//!
//! * A range is `(buffer_id, p0, p1, role∈{Src,Dst})`.
//! * Two ranges in different buffers can never conflict.
//! * Two `Src` ranges in the same buffer never conflict (read-read OK).
//! * A new `Src` overlapping an existing `Dst` is a RAW conflict.
//! * A new `Dst` overlapping any existing range (Src or Dst) is a
//!   WAR/WAW conflict.
//! * Overlap test: `new.p0 < existing.p1 && new.p1 >= existing.p0`
//!   (matches llama.cpp byte-for-byte at line 138).
//! * On conflict, the caller emits a `memoryBarrier` and `reset()`s
//!   the cumulative state, then records the new dispatch's ranges
//!   (see the sketch below).
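//!
//! A minimal sketch of these rules on hand-built ranges (the addresses
//! and buffer ids are made up for illustration, and the module path in
//! the `use` is assumed from this file's location):
//!
//! ```ignore
//! use mlx_native::mem_ranges::{BufferRange, MemRangeRole};
//!
//! // An earlier dispatch wrote bytes [0x1000, 0x1100) of buffer 7.
//! let prev_write =
//!     BufferRange { buf_id: 7, p0: 0x1000, p1: 0x1100, role: MemRangeRole::Dst };
//!
//! // A later read of [0x1080, 0x1200) in the same buffer is RAW: conflict.
//! let read_same =
//!     BufferRange { buf_id: 7, p0: 0x1080, p1: 0x1200, role: MemRangeRole::Src };
//! assert!(read_same.conflicts_with(&prev_write));
//!
//! // The same read against a different buffer can never conflict.
//! let read_other = BufferRange { buf_id: 8, ..read_same };
//! assert!(!read_other.conflicts_with(&prev_write));
//!
//! // Two reads of the same bytes are always concurrent-safe.
//! let prev_read = BufferRange { role: MemRangeRole::Src, ..prev_write };
//! assert!(!read_same.conflicts_with(&prev_read));
//! ```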
//!
//! # mlx-native specifics
//!
//! llama.cpp keys ranges by `tensor->buffer` (the backend buffer
//! handle) plus `tensor->data` (the element pointer inside that
//! buffer). mlx-native uses
//! [`MlxBuffer::metal_buffer`](crate::buffer::MlxBuffer::metal_buffer)
//! as the `(MTLBuffer*) -> usize` buffer-id and
//! [`MlxBuffer::contents_ptr`](crate::buffer::MlxBuffer::contents_ptr)
//! as the start address.  These are stable across the encoder lifetime
//! because hf2q's per-decode-token `MlxBufferPool` keeps ARC clones
//! alive for the entire CB.  Different `slice_view`s of the same
//! parent buffer share `metal_buffer()` on purpose, so a write to
//! `parent[0..N]` is checked against a read of `parent[N..2N]`; with
//! the current extent arithmetic that check is conservative for
//! disjoint sibling slices (see the
//! `slices_of_same_parent_conservative` test below).
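//!
//! A small sketch of the keying (mirroring the slice usage in the unit
//! tests at the bottom of this file; `MlxDevice::new`, `alloc_buffer`,
//! and `slice_view` are used exactly as those tests use them):
//!
//! ```ignore
//! let device = MlxDevice::new()?;
//! let parent = device.alloc_buffer(1024, DType::F32, vec![256])?;
//! let lo = parent.slice_view(0, 128);   // first half
//! let hi = parent.slice_view(512, 128); // second half
//!
//! let w_lo = BufferRange::from_buffer(&lo, MemRangeRole::Dst);
//! let r_hi = BufferRange::from_buffer(&hi, MemRangeRole::Src);
//!
//! // Both slices are keyed by the parent's MTLBuffer pointer...
//! assert_eq!(w_lo.buf_id, r_hi.buf_id);
//! // ...but start at different kernel-visible addresses.
//! assert!(w_lo.p0 < r_hi.p0);
//! // With the conservative extent arithmetic the sibling slices are
//! // treated as potentially aliasing, so the write/read pair conflicts.
//! assert!(r_hi.conflicts_with(&w_lo));
//! ```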
//!
//! # Why same-buffer-only
//!
//! Different `MTLBuffer`s never alias: each Metal buffer is its own
//! allocation, so two ranges in different buffers can be dismissed
//! without any overlap arithmetic.  This is both correct and cheap.
//! A typical decode token has ~1500 dispatches against ~30-50 distinct
//! buffers, so the same-buffer filter reduces most per-range
//! comparisons to a single buffer-id equality test; only the short
//! list of ranges sharing the candidate's buffer pays for the overlap
//! check.
//!
//! # Per iter37 envelope: env-gated, opt-in
//!
//! `MemRanges` is dormant unless the caller explicitly threads it
//! through a dispatch via [`CommandEncoder::dispatch_tracked`].  The
//! existing `encode*`/`memory_barrier()` API is unchanged, so iter37
//! ships with **zero behavioral diff in production** until callers
//! migrate.  Migration of the qwen35 forward path is iter38+ scope.

use crate::buffer::MlxBuffer;
use metal::foreign_types::ForeignType;

/// Whether a recorded range was read by a dispatch (`Src`) or written
/// by a dispatch (`Dst`).  Mirrors `ggml_mem_range_type` in
/// `ggml-metal-common.h:14-17`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MemRangeRole {
    /// Dispatch reads this range.
    Src,
    /// Dispatch writes this range.
    Dst,
}

/// A buffer region recorded for dataflow tracking.
///
/// Mirrors `struct ggml_mem_range` in `ggml-metal-common.cpp:10-17`.
#[derive(Clone, Copy, Debug)]
pub struct BufferRange {
    /// Backing `metal::Buffer` pointer cast to `usize`.  Stable across
    /// the encoder lifetime as long as the `MlxBuffer`'s ARC clone
    /// outlives the CB (see `CommandEncoder::new_with_residency`
    /// caller contract).
    pub buf_id: usize,
    /// Start byte address (`contents_ptr() + byte_offset` for
    /// `MlxBuffer`).  Used for overlap arithmetic.
    pub p0: u64,
    /// End byte address (start + byte extent).  llama.cpp uses
    /// `tensor->data + ggml_backend_buft_get_alloc_size(tensor)` —
    /// for mlx-native we use the buffer's `byte_len()` minus
    /// `byte_offset()`, which equals the slice extent in bytes.
    pub p1: u64,
    /// Whether this range is read or written by the recording dispatch.
    pub role: MemRangeRole,
}

impl BufferRange {
    /// Build a [`BufferRange`] from an [`MlxBuffer`] and a role.
    ///
    /// Uses `metal_buffer().as_ptr() as usize` as the buffer-id (so two
    /// `slice_view`s of the same parent share a `buf_id`, which is
    /// the intended behavior — a slice write must barrier against a
    /// sibling-slice read of the same parent).
    ///
    /// The `(p0, p1)` range covers the addressable extent the kernel
    /// can reach: `[contents_ptr + byte_offset,
    ///   contents_ptr + byte_offset + (byte_len - byte_offset))`.
    /// For non-slice buffers `byte_offset == 0` and the range covers
    /// the full allocation.  For slices the range currently covers
    /// everything from the slice's start to the end of the parent
    /// allocation (a conservative upper bound, matching llama.cpp's
    /// `alloc_size`; see the `slices_of_same_parent_conservative` test).
    #[inline]
    pub fn from_buffer(buf: &MlxBuffer, role: MemRangeRole) -> Self {
        let buf_id = buf.metal_buffer().as_ptr() as usize;
        // `contents_ptr` already points at the buffer's base; mlx-native
        // applies `byte_offset` only at bind-site (`set_buffer`).  The
        // overlap arithmetic must use the slice's *kernel-visible*
        // address window, so we add the offset explicitly here.
        let base = buf.contents_ptr() as u64;
        let p0 = base + buf.byte_offset();
        // `byte_len()` returns the underlying allocation length, so
        // the slice extent is `(allocation_len - offset)`.
        let extent = (buf.byte_len() as u64).saturating_sub(buf.byte_offset());
        let p1 = p0 + extent;
        Self {
            buf_id,
            p0,
            p1,
            role,
        }
    }

    /// Whether `self` and `other` conflict, using the same overlap
    /// arithmetic llama.cpp applies at `ggml-metal-common.cpp:138`.
    ///
    /// Returns `false` for cross-buffer pairs (different `buf_id`) and
    /// for src-vs-src pairs (read-read is always concurrent-safe).
    #[inline]
    pub fn conflicts_with(&self, other: &BufferRange) -> bool {
        if self.buf_id != other.buf_id {
            return false;
        }
        if self.role == MemRangeRole::Src && other.role == MemRangeRole::Src {
            return false;
        }
        // llama.cpp: `mr.p0 < cmp.p1 && mr.p1 >= cmp.p0`
        self.p0 < other.p1 && self.p1 >= other.p0
    }
}

/// Cumulative dataflow state for a sequence of concurrent dispatches.
///
/// Direct port of `struct ggml_mem_ranges` in
/// `ggml-metal-common.cpp:19-23`.  The state is reset every time a
/// barrier is emitted; between barriers, all recorded dispatches are
/// considered to run concurrently and their R/W ranges accumulate.
pub struct MemRanges {
    ranges: Vec<BufferRange>,
    /// Total checks performed (diagnostic).
    checks: u64,
    /// Number of `check_dispatch()` calls that returned `false` (i.e.
    /// forced a barrier).  `total_dispatches - barriers_forced` == the
    /// barriers elided relative to an unconditional barrier-per-dispatch
    /// pattern.
    barriers_forced: u64,
}

impl Default for MemRanges {
    fn default() -> Self {
        Self::new()
    }
}

impl MemRanges {
    /// New empty state.  Pre-allocates capacity matching llama.cpp's
    /// `reserve(256)` (line 28).
    pub fn new() -> Self {
        Self {
            ranges: Vec::with_capacity(256),
            checks: 0,
            barriers_forced: 0,
        }
    }

    /// Drop all recorded ranges (called after emitting a barrier).
    /// Mirrors `ggml_mem_ranges_reset`.
    #[inline]
    pub fn reset(&mut self) {
        self.ranges.clear();
    }

    /// Number of currently-recorded ranges (diagnostic).
    #[inline]
    pub fn len(&self) -> usize {
        self.ranges.len()
    }

    /// Whether the cumulative state is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.ranges.is_empty()
    }

    /// Number of [`Self::check_dispatch`] calls performed since
    /// construction (diagnostic, monotone).
    #[inline]
    pub fn checks(&self) -> u64 {
        self.checks
    }

    /// Number of [`Self::check_dispatch`] calls that returned `false`,
    /// forcing a barrier (diagnostic, monotone).  When tracking is
    /// enabled at every dispatch, `total_dispatches - barriers_forced`
    /// == barriers elided versus the unconditional-barrier baseline.
    #[inline]
    pub fn barriers_forced(&self) -> u64 {
        self.barriers_forced
    }

    /// Push a single range onto the cumulative state without checking.
    /// Single-range counterpart to [`Self::add_dispatch`]; public so
    /// unit tests can construct adversarial states.
    #[inline]
    pub fn push(&mut self, range: BufferRange) {
        self.ranges.push(range);
    }

    /// Record a dispatch's read-buffer ranges + write-buffer ranges.
    ///
    /// Mirrors `ggml_mem_ranges_add(tensor)` at
    /// `ggml-metal-common.cpp:114-122`: pushes one Src range per
    /// `tensor->src[i]` and one Dst range for `tensor` itself.
    ///
    /// Caller is expected to have already invoked
    /// [`Self::check_dispatch`] and emitted a barrier on conflict; the
    /// barrier-emit + `reset()` is the responsibility of the
    /// integration site (typically `CommandEncoder`).
    pub fn add_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
        for r in reads {
            self.ranges
                .push(BufferRange::from_buffer(r, MemRangeRole::Src));
        }
        for w in writes {
            self.ranges
                .push(BufferRange::from_buffer(w, MemRangeRole::Dst));
        }
    }

    /// Check whether a candidate dispatch can run concurrently with
    /// the recorded state.
    ///
    /// Returns `true` iff none of the candidate's reads or writes
    /// conflict with any recorded range.  Exactly mirrors
    /// `ggml_mem_ranges_check(tensor)` at `ggml-metal-common.cpp:175-185`:
    /// each src is checked against existing ranges, then the dst is
    /// checked against existing ranges.
    ///
    /// Increments [`Self::checks`].  On `false` return, also
    /// increments [`Self::barriers_forced`] — so the diagnostic
    /// counter is accurate even when callers ignore the return value.
    pub fn check_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) -> bool {
        self.checks += 1;
        for r in reads {
            let candidate = BufferRange::from_buffer(r, MemRangeRole::Src);
            for existing in &self.ranges {
                if candidate.conflicts_with(existing) {
                    self.barriers_forced += 1;
                    return false;
                }
            }
        }
        for w in writes {
            let candidate = BufferRange::from_buffer(w, MemRangeRole::Dst);
            for existing in &self.ranges {
                if candidate.conflicts_with(existing) {
                    self.barriers_forced += 1;
                    return false;
                }
            }
        }
        true
    }

    /// Combined check + add.  Returns `true` if the dispatch was added
    /// to the current concurrent group (no conflict, no barrier
    /// needed); returns `false` if the caller must emit a barrier and
    /// `reset()` before adding the dispatch's ranges.
    ///
    /// On `false` return, the caller's responsibility is:
    /// 1. Emit the underlying `memoryBarrierWithScope:` on the live
    ///    encoder.
    /// 2. Call [`Self::reset`].
    /// 3. Call [`Self::add_dispatch`] with the same `reads`/`writes`
    ///    to seed the new concurrent group.
    ///
    /// This mirrors the call pattern at `ggml-metal-ops.cpp:220-225`.
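    ///
    /// A sketch of that flow at an integration site (illustrative
    /// only: `dispatches`, `dispatch.buffer_ranges()`,
    /// `dispatch.encode()`, and `encoder.memory_barrier()` are
    /// hypothetical stand-ins for whatever the caller actually has):
    ///
    /// ```ignore
    /// let mut mr = MemRanges::new();
    /// for dispatch in dispatches {
    ///     // Hypothetical helper returning this dispatch's read/write buffers.
    ///     let (reads, writes) = dispatch.buffer_ranges();
    ///     if !mr.check_and_record(&reads, &writes) {
    ///         // Conflict: serialize against everything recorded so far,
    ///         // then seed a fresh concurrent group with this dispatch.
    ///         encoder.memory_barrier();
    ///         mr.reset();
    ///         mr.add_dispatch(&reads, &writes);
    ///     }
    ///     dispatch.encode(&encoder);
    /// }
    /// ```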
    pub fn check_and_record(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) -> bool {
        let ok = self.check_dispatch(reads, writes);
        if ok {
            self.add_dispatch(reads, writes);
        }
        // On !ok the caller will reset+add per the contract above.
        ok
    }
}

#[cfg(test)]
mod tests {
    //! Unit tests for [`MemRanges`].
    //!
    //! These tests exercise the address arithmetic and
    //! overlap-detection logic without encoding any GPU work: they
    //! construct `MlxBuffer`s via `MlxDevice::alloc_buffer`, which
    //! does allocate real Metal buffers but does not require any GPU
    //! commands to be encoded or executed.  Each test is bounded to a
    //! handful of small allocations.
    use super::*;
    use crate::{DType, MlxDevice};

    fn dev() -> MlxDevice {
        MlxDevice::new().expect("MlxDevice::new failed")
    }

    /// Two reads of the same buffer must NOT conflict (RAR concurrent).
    #[test]
    fn read_read_same_buffer_no_conflict() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        // First dispatch: read a, write nothing.
        let ok1 = mr.check_and_record(&[&a], &[]);
        assert!(ok1, "first dispatch always ok");
        // Second dispatch: read a again — must be concurrent.
        let ok2 = mr.check_and_record(&[&a], &[]);
        assert!(ok2, "RAR same-buffer must not conflict");
        assert_eq!(mr.barriers_forced(), 0);
    }

    /// Read-after-write same buffer MUST conflict (RAW barrier needed).
    #[test]
    fn raw_same_buffer_conflicts() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        // First dispatch writes a.
        assert!(mr.check_and_record(&[], &[&a]));
        // Second dispatch reads a — must conflict.
        let ok = mr.check_and_record(&[&a], &[]);
        assert!(!ok, "RAW same-buffer must force barrier");
        assert_eq!(mr.barriers_forced(), 1);
    }

    /// Write-after-read same buffer MUST conflict (WAR barrier needed).
    #[test]
    fn war_same_buffer_conflicts() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        assert!(mr.check_and_record(&[&a], &[]));
        let ok = mr.check_and_record(&[], &[&a]);
        assert!(!ok, "WAR same-buffer must force barrier");
        assert_eq!(mr.barriers_forced(), 1);
    }

    /// Write-after-write same buffer MUST conflict (WAW barrier needed).
    #[test]
    fn waw_same_buffer_conflicts() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        assert!(mr.check_and_record(&[], &[&a]));
        let ok = mr.check_and_record(&[], &[&a]);
        assert!(!ok, "WAW same-buffer must force barrier");
        assert_eq!(mr.barriers_forced(), 1);
    }

    /// Cross-buffer reads/writes never conflict regardless of role.
    /// The candidate dispatch's ranges are compared against every
    /// recorded range, but pairs in different buffers are dismissed
    /// early in `BufferRange::conflicts_with`.
    #[test]
    fn different_buffers_never_conflict() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let b = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let c = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        // dispatch1: write a — records (a, Dst).
        assert!(mr.check_and_record(&[], &[&a]));
        // dispatch2: read+write b — disjoint from a, ok.
        assert!(mr.check_and_record(&[&b], &[&b]));
        // dispatch3: read c — disjoint from a and b, ok.  Critically,
        // we do NOT read `a` here because that would be RAW against
        // dispatch1's write — a real conflict, not a same-buffer
        // false positive.
        assert!(mr.check_and_record(&[&c], &[]));
        assert_eq!(mr.barriers_forced(), 0);
    }

    /// Reset clears state and lets a previously-conflicting dispatch
    /// be recorded.  Mirrors the post-barrier flow.
    #[test]
    fn reset_clears_state() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        assert!(mr.check_and_record(&[], &[&a]));
        // Would conflict with the recorded write…
        assert!(!mr.check_and_record(&[&a], &[]));
        // …unless we reset first (simulating a barrier emission).
        mr.reset();
        assert!(mr.check_and_record(&[&a], &[]));
        // After reset, two reads in a row are still non-conflicting.
        assert!(mr.check_and_record(&[&a], &[]));
        assert_eq!(mr.barriers_forced(), 1);
    }

    /// Disjoint slices of the same parent: today the algorithm is
    /// conservative (a slice access is treated as touching everything
    /// from the slice's start to the end of the parent allocation),
    /// matching llama.cpp's `alloc_size` upper bound.  This documents
    /// the behavior so future iterations can tighten it intentionally.
    #[test]
    fn slices_of_same_parent_conservative() {
        let d = dev();
        // 256 floats; carve into two halves.
        let parent = d.alloc_buffer(1024, DType::F32, vec![256]).unwrap();
        let lo = parent.slice_view(0, 128);
        let hi = parent.slice_view(512, 128);
        let mut mr = MemRanges::new();
        assert!(mr.check_and_record(&[], &[&lo]));
        // hi is a disjoint half but conservatively conflicts because
        // the lo write's recorded range covers
        //   [parent + 0, parent + parent.byte_len()) and the hi range
        //   starts at `parent + 512` which falls inside that window.
        // The conservative answer is *correct* (a barrier is safe even
        // if not necessary).  Tightening the slice arithmetic to use
        // the slice's own extent only is a future iteration.
        let ok = mr.check_and_record(&[], &[&hi]);
        assert!(!ok, "slice WAW currently conservative — see docstring");
    }

    /// Sequential pattern: A=write x, B=read x, C=write y, D=read y.
    /// Expect exactly 2 forced barriers (B vs A, D vs C).
    #[test]
    fn sequential_pattern_two_barriers() {
        let d = dev();
        let x = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let y = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        // A: write x.
        assert!(mr.check_and_record(&[], &[&x]));
        // B: read x — conflict.
        assert!(!mr.check_dispatch(&[&x], &[]));
        mr.reset();
        mr.add_dispatch(&[&x], &[]);
        // C: write y — different buffer, concurrent OK.
        assert!(mr.check_and_record(&[], &[&y]));
        // D: read y — conflict (against C's write).
        assert!(!mr.check_dispatch(&[&y], &[]));
        mr.reset();
        mr.add_dispatch(&[&y], &[]);
        assert_eq!(mr.barriers_forced(), 2);
    }

    /// `BufferRange::conflicts_with` is symmetric for Src/Dst ranges
    /// covering the same region.
    #[test]
    fn conflict_is_symmetric() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let r_src = BufferRange::from_buffer(&a, MemRangeRole::Src);
        let r_dst = BufferRange::from_buffer(&a, MemRangeRole::Dst);
        assert!(r_src.conflicts_with(&r_dst));
        assert!(r_dst.conflicts_with(&r_src));
    }
}