mlx_native/mem_ranges.rs
//! Dataflow-driven barrier inference (port of llama.cpp `mem_ranges`).
//!
//! ADR-015 iter37 — framework-side complement to iter21's hand-audited
//! barrier fix at `gpu_full_attn.rs:1856`.
//!
//! # Purpose
//!
//! When a Metal `MTLComputeCommandEncoder` is created with
//! `MTLDispatchTypeConcurrent` (mlx-native's default since iter8e —
//! [`encoder::CommandEncoder::get_or_create_encoder`](crate::encoder)),
//! every dispatch can execute in parallel with every other dispatch in
//! the same encoder unless separated by a memory barrier. The encoder
//! does not infer dataflow on its own — the caller must hand-place
//! `[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]` between
//! every read-after-write (RAW), write-after-read (WAR), or
//! write-after-write (WAW) pair.
//!
//! Hand-audited barrier placement is correct but fragile: iter21 found
//! one missing producer→consumer edge (`sigmoid_gate_multiply` →
//! `linear_projection wo`) that had escaped review for months because
//! the diverged-output bug it caused only surfaced under specific
//! sequence-length × sample-count combinations. Any future kernel
//! sequence built without rigorous review is subject to the same
//! class of bug.
//!
//! `MemRanges` ports llama.cpp's mem_ranges algorithm
//! (`/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp`) so
//! callers describe each dispatch's read and write buffer regions and
//! the framework auto-emits a barrier exactly when the new dispatch's
//! ranges overlap a previously-recorded range. This makes
//! iter21-class bugs structurally impossible at the framework boundary.
//!
//! # Algorithm
//!
//! Verbatim port of `ggml_mem_ranges_check` + `ggml_mem_ranges_add`
//! (lines 124-185 of `ggml-metal-common.cpp`):
//!
//! * A range is `(buffer_id, p0, p1, role∈{Src,Dst})`.
//! * Two ranges in different buffers can never conflict.
//! * Two `Src` ranges in the same buffer never conflict (read-read OK).
//! * A new `Src` overlapping an existing `Dst` is a RAW conflict.
//! * A new `Dst` overlapping any existing range (Src or Dst) is a
//!   WAR/WAW conflict.
//! * Overlap test: `new.p0 < existing.p1 && new.p1 >= existing.p0`
//!   (matches llama.cpp byte-for-byte at line 138).
//! * On conflict, the caller emits a `memoryBarrier` and `reset()`s
//!   the cumulative state, then records the new dispatch's ranges.
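//!
//! A minimal illustration of these rules (struct literals are used here
//! purely for exposition; real callers build ranges via
//! [`BufferRange::from_buffer`]):
//!
//! ```ignore
//! // Dispatch 1 wrote bytes [0, 64) of buffer #1.
//! let write = BufferRange { buf_id: 1, p0: 0, p1: 64, role: MemRangeRole::Dst };
//! // Dispatch 2 reads bytes [32, 96) of the same buffer: RAW, so a barrier is needed.
//! let read = BufferRange { buf_id: 1, p0: 32, p1: 96, role: MemRangeRole::Src };
//! assert!(read.conflicts_with(&write));
//! // The same byte window in a *different* buffer can never conflict.
//! let elsewhere = BufferRange { buf_id: 2, p0: 32, p1: 96, role: MemRangeRole::Src };
//! assert!(!elsewhere.conflicts_with(&write));
//! ```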
//!
//! # mlx-native specifics
//!
//! llama.cpp keys ranges by `tensor->buffer` (the backend buffer
//! handle) plus `tensor->data` (the element pointer inside that
//! buffer). mlx-native uses
//! [`MlxBuffer::metal_buffer`](crate::buffer::MlxBuffer::metal_buffer)
//! as the `(MTLBuffer*) -> usize` buffer-id and
//! [`MlxBuffer::contents_ptr`](crate::buffer::MlxBuffer::contents_ptr)
//! as the start address. These are stable across the encoder lifetime
//! because hf2q's per-decode-token `MlxBufferPool` keeps ARC clones
//! alive for the entire CB. Different `slice_view`s of the same
//! parent buffer share `metal_buffer()` (intentional: a write to
//! `parent[0..N]` must barrier against a read of `parent[N..2N]`
//! only when the two slices alias).
//!
//! # Why same-buffer-only
//!
//! Different `MTLBuffer`s never alias — Metal's address space is
//! per-buffer. Skipping the overlap check on cross-buffer pairs is
//! both correct and a major perf win: a typical decode token has
//! ~1500 dispatches against ~30-50 distinct buffers, so the
//! same-buffer filter keeps each per-dispatch check proportional to
//! the short list of ranges in *one* buffer rather than to every
//! recorded range.
//!
//! # Per iter37 envelope: env-gated, opt-in
//!
//! `MemRanges` is dormant unless the caller explicitly threads it
//! through a dispatch via [`CommandEncoder::dispatch_tracked`]. The
//! existing `encode*`/`memory_barrier()` API is unchanged, so iter37
//! ships with **zero behavioral diff in production** until callers
//! migrate. Migration of the qwen35 forward path is iter38+ scope.
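//!
//! # Usage sketch
//!
//! A hand-rolled integration looks roughly like the following. This is a
//! sketch, not the `dispatch_tracked` implementation: `dispatches`,
//! `dispatch.buffers()`, and `encoder.encode(...)` are placeholders, and
//! `encoder.memory_barrier()` stands for whatever emits
//! `memoryBarrierWithScope:MTLBarrierScopeBuffers` on the live encoder.
//!
//! ```ignore
//! let mut mem = MemRanges::new();
//! for dispatch in dispatches {
//!     let (reads, writes) = dispatch.buffers(); // (&[&MlxBuffer], &[&MlxBuffer])
//!     if !mem.check_and_record(reads, writes) {
//!         // Conflict with the current concurrent group: emit a barrier,
//!         // then start a fresh group seeded with this dispatch's ranges.
//!         encoder.memory_barrier();
//!         mem.reset();
//!         mem.add_dispatch(reads, writes);
//!     }
//!     encoder.encode(dispatch);
//! }
//! ```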

use crate::buffer::MlxBuffer;
use metal::foreign_types::ForeignType;

/// Whether a recorded range was read by a dispatch (`Src`) or written
/// by a dispatch (`Dst`). Mirrors `ggml_mem_range_type` in
/// `ggml-metal-common.h:14-17`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MemRangeRole {
    /// Dispatch reads this range.
    Src,
    /// Dispatch writes this range.
    Dst,
}

/// A buffer region recorded for dataflow tracking.
///
/// Mirrors `struct ggml_mem_range` in `ggml-metal-common.cpp:10-17`.
#[derive(Clone, Copy, Debug)]
pub struct BufferRange {
    /// Backing `metal::Buffer` pointer cast to `usize`. Stable across
    /// the encoder lifetime as long as the `MlxBuffer`'s ARC clone
    /// outlives the CB (see `CommandEncoder::new_with_residency`
    /// caller contract).
    pub buf_id: usize,
    /// Start byte address (`contents_ptr() + byte_offset` for
    /// `MlxBuffer`). Used for overlap arithmetic.
    pub p0: u64,
    /// End byte address (start + extent in bytes). llama.cpp uses
    /// `tensor->data + ggml_backend_buft_get_alloc_size(tensor)` —
    /// for mlx-native we use the buffer's `byte_len()` minus
    /// `byte_offset()`.
    pub p1: u64,
    /// Whether this range is read or written by the recording dispatch.
    pub role: MemRangeRole,
}

impl BufferRange {
    /// Build a [`BufferRange`] from an [`MlxBuffer`] and a role.
    ///
    /// Uses `metal_buffer().as_ptr() as usize` as the buffer-id (so two
    /// `slice_view`s of the same parent share a `buf_id`, which is
    /// the intended behavior — a slice write must barrier against a
    /// sibling-slice read of the same parent).
    ///
    /// The `(p0, p1)` range covers the addressable extent the kernel
    /// can reach: `[contents_ptr + byte_offset,
    /// contents_ptr + byte_offset + (byte_len - byte_offset))`.
    /// For non-slice buffers `byte_offset == 0` and the range covers
    /// the full allocation. For slices the range runs from the slice
    /// start to the end of the parent allocation, a conservative
    /// over-approximation of the slice region (see the
    /// `slices_of_same_parent_conservative` test); llama.cpp's
    /// equivalent window is `tensor->data ..
    /// tensor->data + alloc_size`.
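    ///
    /// A worked sketch of the arithmetic, assuming the
    /// `slice_view(byte_offset, element_count)` reading used by this
    /// module's unit tests; `device` is a placeholder `MlxDevice` and
    /// the addresses are illustrative:
    ///
    /// ```ignore
    /// // Hypothetical 256-float parent allocation (1024 bytes).
    /// let parent = device.alloc_buffer(1024, DType::F32, vec![256])?;
    /// let hi = parent.slice_view(512, 128); // back half: bytes 512..1024
    /// let r = BufferRange::from_buffer(&hi, MemRangeRole::Dst);
    /// // p0 = contents_ptr + 512, p1 = p0 + (byte_len - 512) = contents_ptr + 1024.
    /// assert_eq!(r.p1 - r.p0, 512);
    /// ```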
    #[inline]
    pub fn from_buffer(buf: &MlxBuffer, role: MemRangeRole) -> Self {
        let buf_id = buf.metal_buffer().as_ptr() as usize;
        // `contents_ptr` already points at the buffer's base; mlx-native
        // applies `byte_offset` only at bind-site (`set_buffer`). The
        // overlap arithmetic must use the slice's *kernel-visible*
        // address window, so we add the offset explicitly here.
        let base = buf.contents_ptr() as u64;
        let p0 = base + buf.byte_offset();
        // `byte_len()` returns the underlying allocation length, so the
        // extent recorded here runs from the slice start to the end of
        // the allocation.
        let extent = (buf.byte_len() as u64).saturating_sub(buf.byte_offset());
        let p1 = p0 + extent;
        Self {
            buf_id,
            p0,
            p1,
            role,
        }
    }

    /// Whether `self` and `other` overlap by the same arithmetic
    /// llama.cpp uses at `ggml-metal-common.cpp:138`.
    ///
    /// Returns `false` for cross-buffer pairs (different `buf_id`) and
    /// for src-vs-src pairs (read-read is always concurrent-safe).
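    ///
    /// A small sketch of the role rules (the field values are
    /// illustrative struct literals, not taken from a real buffer):
    ///
    /// ```ignore
    /// let read = BufferRange { buf_id: 7, p0: 0, p1: 256, role: MemRangeRole::Src };
    /// let read2 = BufferRange { role: MemRangeRole::Src, ..read };
    /// assert!(!read.conflicts_with(&read2)); // read-read: always concurrent-safe
    /// let write = BufferRange { role: MemRangeRole::Dst, ..read };
    /// assert!(write.conflicts_with(&read2)); // write over the same bytes: hazard
    /// ```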
    #[inline]
    pub fn conflicts_with(&self, other: &BufferRange) -> bool {
        if self.buf_id != other.buf_id {
            return false;
        }
        if self.role == MemRangeRole::Src && other.role == MemRangeRole::Src {
            return false;
        }
        // llama.cpp: `mr.p0 < cmp.p1 && mr.p1 >= cmp.p0`
        self.p0 < other.p1 && self.p1 >= other.p0
    }
}

/// Cumulative dataflow state for a sequence of concurrent dispatches.
///
/// Direct port of `struct ggml_mem_ranges` in
/// `ggml-metal-common.cpp:19-23`. The state is reset every time a
/// barrier is emitted; between barriers, all recorded dispatches are
/// considered to run concurrently and their R/W ranges accumulate.
pub struct MemRanges {
    ranges: Vec<BufferRange>,
    /// Total checks performed (diagnostic).
    checks: u64,
    /// Number of `check_dispatch()` calls that returned `false` (i.e.
    /// forced a barrier). `checks - barriers_forced` is the number of
    /// barriers elided relative to an unconditional
    /// barrier-per-dispatch pattern.
    barriers_forced: u64,
}

impl Default for MemRanges {
    fn default() -> Self {
        Self::new()
    }
}

impl MemRanges {
    /// New empty state. Pre-allocates capacity matching llama.cpp's
    /// `reserve(256)` (line 28).
    pub fn new() -> Self {
        Self {
            ranges: Vec::with_capacity(256),
            checks: 0,
            barriers_forced: 0,
        }
    }

    /// Drop all recorded ranges (called after emitting a barrier).
    /// Mirrors `ggml_mem_ranges_reset`.
    #[inline]
    pub fn reset(&mut self) {
        self.ranges.clear();
    }

    /// Number of currently-recorded ranges (diagnostic).
    #[inline]
    pub fn len(&self) -> usize {
        self.ranges.len()
    }

    /// Whether the cumulative state is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.ranges.is_empty()
    }

    /// Number of [`Self::check_dispatch`] calls performed since
    /// construction (diagnostic, monotone).
    #[inline]
    pub fn checks(&self) -> u64 {
        self.checks
    }

    /// Number of [`Self::check_dispatch`] calls that returned `false`,
    /// forcing a barrier (diagnostic, monotone). When tracking is
    /// enabled at every dispatch, `checks() - barriers_forced()` is the
    /// number of barriers elided versus an unconditional-barrier
    /// baseline.
    #[inline]
    pub fn barriers_forced(&self) -> u64 {
        self.barriers_forced
    }

    /// Push a single range onto the cumulative state without checking,
    /// mirroring what [`Self::add_dispatch`] does for each read/write
    /// buffer. Public so unit tests can construct adversarial states.
    #[inline]
    pub fn push(&mut self, range: BufferRange) {
        self.ranges.push(range);
    }

    /// Record a dispatch's read-buffer ranges + write-buffer ranges.
    ///
    /// Mirrors `ggml_mem_ranges_add(tensor)` at
    /// `ggml-metal-common.cpp:114-122`: pushes one Src range per
    /// `tensor->src[i]` and one Dst range for `tensor` itself.
    ///
    /// Caller is expected to have already invoked
    /// [`Self::check_dispatch`] and emitted a barrier on conflict; the
    /// barrier-emit + `reset()` is the responsibility of the
    /// integration site (typically `CommandEncoder`).
    pub fn add_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
        for r in reads {
            self.ranges
                .push(BufferRange::from_buffer(r, MemRangeRole::Src));
        }
        for w in writes {
            self.ranges
                .push(BufferRange::from_buffer(w, MemRangeRole::Dst));
        }
    }

    /// Check whether a candidate dispatch can run concurrently with
    /// the recorded state.
    ///
    /// Returns `true` iff none of the candidate's reads or writes
    /// conflict with any recorded range. Exactly mirrors
    /// `ggml_mem_ranges_check(tensor)` at `ggml-metal-common.cpp:175-185`:
    /// each src is checked against existing ranges, then the dst is
    /// checked against existing ranges.
    ///
    /// Increments [`Self::checks`]. On `false` return, also
    /// increments [`Self::barriers_forced`] — so the diagnostic
    /// counter is accurate even when callers ignore the return value.
    pub fn check_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) -> bool {
        self.checks += 1;
        for r in reads {
            let candidate = BufferRange::from_buffer(r, MemRangeRole::Src);
            for existing in &self.ranges {
                if candidate.conflicts_with(existing) {
                    self.barriers_forced += 1;
                    return false;
                }
            }
        }
        for w in writes {
            let candidate = BufferRange::from_buffer(w, MemRangeRole::Dst);
            for existing in &self.ranges {
                if candidate.conflicts_with(existing) {
                    self.barriers_forced += 1;
                    return false;
                }
            }
        }
        true
    }

    /// Combined check + add. Returns `true` if the dispatch was added
    /// to the current concurrent group (no conflict, no barrier
    /// needed); returns `false` if the caller must emit a barrier and
    /// `reset()` before adding the dispatch's ranges.
    ///
    /// On a `false` return, the caller's responsibility is:
    /// 1. Emit the underlying `memoryBarrierWithScope:` on the live
    ///    encoder.
    /// 2. Call [`Self::reset`].
    /// 3. Call [`Self::add_dispatch`] with the same `reads`/`writes`
    ///    to seed the new concurrent group.
    ///
    /// This mirrors the call pattern at `ggml-metal-ops.cpp:220-225`.
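    ///
    /// A sketch of the contract, where `a` and `b` stand for two
    /// distinct `MlxBuffer` allocations:
    ///
    /// ```ignore
    /// let mut mem = MemRanges::new();
    /// assert!(mem.check_and_record(&[], &[&a]));   // write a: recorded, concurrent
    /// assert!(mem.check_and_record(&[&b], &[&b])); // other buffer: still concurrent
    /// assert!(!mem.check_and_record(&[&a], &[]));  // RAW on a: barrier + reset + add
    /// assert_eq!(mem.barriers_forced(), 1);
    /// ```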
    pub fn check_and_record(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) -> bool {
        let ok = self.check_dispatch(reads, writes);
        if ok {
            self.add_dispatch(reads, writes);
        }
        // On !ok the caller will reset+add per the contract above.
        ok
    }
}

#[cfg(test)]
mod tests {
    //! Unit tests for [`MemRanges`].
    //!
    //! These are CPU-only tests that exercise the address arithmetic
    //! and overlap-detection logic without encoding any GPU work: they
    //! construct `MlxBuffer`s via `MlxDevice::alloc_buffer`, which
    //! does allocate real Metal buffers but does not require any GPU
    //! commands to be encoded or executed. Each test is bounded to a
    //! handful of small allocations.
    use super::*;
    use crate::{DType, MlxDevice};

    fn dev() -> MlxDevice {
        MlxDevice::new().expect("MlxDevice::new failed")
    }

    /// Two reads of the same buffer must NOT conflict (RAR concurrent).
    #[test]
    fn read_read_same_buffer_no_conflict() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        // First dispatch: read a, write nothing.
        let ok1 = mr.check_and_record(&[&a], &[]);
        assert!(ok1, "first dispatch always ok");
        // Second dispatch: read a again — must be concurrent.
        let ok2 = mr.check_and_record(&[&a], &[]);
        assert!(ok2, "RAR same-buffer must not conflict");
        assert_eq!(mr.barriers_forced(), 0);
    }

    /// Read-after-write same buffer MUST conflict (RAW barrier needed).
    #[test]
    fn raw_same_buffer_conflicts() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        // First dispatch writes a.
        assert!(mr.check_and_record(&[], &[&a]));
        // Second dispatch reads a — must conflict.
        let ok = mr.check_and_record(&[&a], &[]);
        assert!(!ok, "RAW same-buffer must force barrier");
        assert_eq!(mr.barriers_forced(), 1);
    }

    /// Write-after-read same buffer MUST conflict (WAR barrier needed).
    #[test]
    fn war_same_buffer_conflicts() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        assert!(mr.check_and_record(&[&a], &[]));
        let ok = mr.check_and_record(&[], &[&a]);
        assert!(!ok, "WAR same-buffer must force barrier");
        assert_eq!(mr.barriers_forced(), 1);
    }

    /// Write-after-write same buffer MUST conflict (WAW barrier needed).
    #[test]
    fn waw_same_buffer_conflicts() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        assert!(mr.check_and_record(&[], &[&a]));
        let ok = mr.check_and_record(&[], &[&a]);
        assert!(!ok, "WAW same-buffer must force barrier");
        assert_eq!(mr.barriers_forced(), 1);
    }

    /// Cross-buffer reads/writes never conflict regardless of role.
    /// The candidate dispatch's ranges are checked only against
    /// recorded ranges in the SAME buffer; ranges in disjoint
    /// buffers are skipped early in `BufferRange::conflicts_with`.
    #[test]
    fn different_buffers_never_conflict() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let b = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let c = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        // dispatch1: write a — records (a, Dst).
        assert!(mr.check_and_record(&[], &[&a]));
        // dispatch2: read+write b — disjoint from a, ok.
        assert!(mr.check_and_record(&[&b], &[&b]));
        // dispatch3: read c — disjoint from a and b, ok. Critically,
        // we do NOT read `a` here because that would be RAW against
        // dispatch1's write — a real conflict, not a same-buffer
        // false positive.
        assert!(mr.check_and_record(&[&c], &[]));
        assert_eq!(mr.barriers_forced(), 0);
    }

    /// Reset clears state and lets a previously-conflicting dispatch
    /// be recorded. Mirrors the post-barrier flow.
    #[test]
    fn reset_clears_state() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        assert!(mr.check_and_record(&[], &[&a]));
        // Would conflict with the recorded write…
        assert!(!mr.check_and_record(&[&a], &[]));
        // …unless we reset first (simulating a barrier emission).
        mr.reset();
        assert!(mr.check_and_record(&[&a], &[]));
        // After reset, two reads in a row are still non-conflicting.
        assert!(mr.check_and_record(&[&a], &[]));
        assert_eq!(mr.barriers_forced(), 1);
    }

    /// Disjoint slices of the same parent: today the algorithm is
    /// conservative, recording a slice write as covering everything
    /// from the slice start to the end of the parent allocation,
    /// matching llama.cpp's `alloc_size` upper bound. This documents
    /// the behavior so future iterations can tighten it intentionally.
    #[test]
    fn slices_of_same_parent_conservative() {
        let d = dev();
        // 256 floats; carve into two halves.
        let parent = d.alloc_buffer(1024, DType::F32, vec![256]).unwrap();
        let lo = parent.slice_view(0, 128);
        let hi = parent.slice_view(512, 128);
        let mut mr = MemRanges::new();
        assert!(mr.check_and_record(&[], &[&lo]));
        // hi is a disjoint half but conservatively conflicts because
        // the lo write's recorded range covers
        // [parent + 0, parent + parent.byte_len()) and the hi range
        // starts at `parent + 512`, which falls inside that window.
        // The conservative answer is *correct* (a barrier is safe even
        // if not necessary). Tightening the slice arithmetic to use
        // the slice's own extent is a future iteration.
        let ok = mr.check_and_record(&[], &[&hi]);
        assert!(!ok, "slice WAW currently conservative — see docstring");
    }

    /// Sequential pattern: A=write x, B=read x, C=write y, D=read y.
    /// Expect exactly 2 forced barriers (B vs A, D vs C).
    #[test]
    fn sequential_pattern_two_barriers() {
        let d = dev();
        let x = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let y = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let mut mr = MemRanges::new();
        // A: write x.
        assert!(mr.check_and_record(&[], &[&x]));
        // B: read x — conflict.
        assert!(!mr.check_dispatch(&[&x], &[]));
        mr.reset();
        mr.add_dispatch(&[&x], &[]);
        // C: write y — different buffer, concurrent OK.
        assert!(mr.check_and_record(&[], &[&y]));
        // D: read y — conflict (against C's write).
        assert!(!mr.check_dispatch(&[&y], &[]));
        mr.reset();
        mr.add_dispatch(&[&y], &[]);
        assert_eq!(mr.barriers_forced(), 2);
    }

    /// `BufferRange::conflicts_with` is symmetric for a Src/Dst pair
    /// over the same region.
    #[test]
    fn conflict_is_symmetric() {
        let d = dev();
        let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
        let r_src = BufferRange::from_buffer(&a, MemRangeRole::Src);
        let r_dst = BufferRange::from_buffer(&a, MemRangeRole::Dst);
        assert!(r_src.conflicts_with(&r_dst));
        assert!(r_dst.conflicts_with(&r_src));
    }
}