Skip to main content

mlx_native/
encoder_session.rs

1//! [`EncoderSession`] — D3 Per-Stage Fence encoder abstraction (ADR-019 Phase 0b).
2//!
3//! `EncoderSession` lifts [`CommandEncoder`] into a session-aware shell that
4//! carries semantic *stage* metadata across the lifetime of one or more
5//! logical transformer stages (e.g. `"layer.full_attn.stage1"` →
6//! `"layer.full_attn.stage2"`). Phase 0b-A delivered the bare struct +
7//! single-stage lifecycle methods. Phase 0b-B (this file) adds the
8//! [`MTLSharedEvent`](metal::SharedEvent) inter-CB ordering primitives D3
9//! needs:
10//!
11//! - [`Self::fence_stage`] — encode signal-event(N+1) on the current CB,
12//!   commit non-blocking, increment the per-session monotonic counter.
13//! - [`Self::reset_for_next_stage`] — open a fresh CB on the same queue
14//!   and (when a fence is active) encode wait-event(N) on the new CB so
15//!   its GPU work blocks until the prior fenced CB completes.
16//! - [`Self::add_to_residency_set`] / [`Self::remove_from_residency_set`]
17//!   — public delegation surface for the single residency set owned by
18//!   [`MlxDevice`](crate::MlxDevice). EncoderSession does NOT own a
19//!   separate set; it routes calls through the Arc clone the inner
20//!   [`CommandEncoder`] already holds.
21//!
22//! Phase 0b-C will broaden label propagation (per-substage labels +
23//! xctrace MST round-trip). iter89e2-B leaves Phase 0b-A's existing
24//! per-session label semantics intact.
25//!
26//! Production callers MUST stay on [`crate::CommandEncoder`] until Phase
27//! 2 (FA-path D3 stage migration) wires `forward_gpu.rs` to consume
28//! `EncoderSession`. The struct is feature-flagged behind
29//! `HF2Q_ENCODER_SESSION=1` (default OFF) and is constructed only via
30//! [`MlxDevice::encoder_session`](crate::MlxDevice::encoder_session).
31//!
32//! # Lifecycle (iter89e2-B — multi-stage chaining)
33//!
34//! ```text
35//!                   MlxDevice::encoder_session()
36//!                                |
37//!                                v
38//!                          +-----------+
39//!                          | Empty     |  no CB or encoder open yet
40//!                          +-----------+
41//!                                |
42//!                          first dispatch
43//!                          (via inner CommandEncoder)
44//!                                |
45//!                                v
46//!                          +-----------+
47//!                          | Encoding  |  CB open, persistent compute encoder open
48//!                          +-----------+
49//!                            |       |
50//!                            |       +---fence_stage(label)----+
51//!                            |                                 |
52//!                            |                                 v
53//!                            |                          +-----------+
54//!                            |                          | Fenced    |  signal encoded; CB submitted
55//!                            |                          +-----------+
56//!                            |                                 |
57//!                  commit_stage()                  reset_for_next_stage()
58//!                  commit_and_wait()                            |
59//!                            |                                 v
60//!                            v                          (loop back to Encoding
61//!                      +-----------+                    on next dispatch — wait
62//!                      | Drained   |                    is encoded automatically
63//!                      +-----------+                    on the new CB)
64//!                            |
65//!                          Drop
66//! ```
67//!
68//! `fence_stage` collapses the design-doc's separate Encoding→Fenced→
69//! Committed transitions into a single "submit-with-fence" call: the
70//! signal is encoded on the current CB at `event_value+1`, the encoder
71//! is ended, the CB is committed non-blocking, and the per-session
72//! monotonic counter is incremented. The session is `drained` until
73//! [`Self::reset_for_next_stage`] rotates the inner [`CommandEncoder`]'s
74//! command buffer to a fresh CB on the same queue and (when the event is
75//! present) encodes the matching wait at `event_value` on the new CB.
76//!
77//! # Risk register fence preservation (F1-F12 from ADR-019)
78//!
79//! - **F1 — persistent compute encoder per CB**: ADOPTED unchanged.
80//!   `EncoderSession` borrows `&mut CommandEncoder` via [`Self::encoder`];
81//!   every dispatch reuses the same lazy-opened encoder per CB. Each
82//!   stage CB still has exactly one persistent compute encoder.
83//! - **F2 — iter58b residency-rescission**: PRESERVED. `commit_stage`,
84//!   `commit_and_wait`, and `fence_stage` all delegate to the inner
85//!   encoder, which calls `flush_residency_pending()` at every commit
86//!   boundary (`encoder.rs:1842, 2004`). `reset_for_next_stage` does NOT
87//!   re-flush — staged add/remove operations between stages flush at
88//!   the next commit on the new CB. The single residency set is owned
89//!   by [`MlxDevice`](crate::MlxDevice) (single-set invariant per
90//!   ADR-019:467). Multi-stage chaining DOES widen the in-flight CB
91//!   window — dropping a buffer between stage 1's `fence_stage` and
92//!   stage 2's `commit_*` stages a remove-allocation that flushes at
93//!   stage 2's commit, while stage 1's CB may still be GPU-pipelined.
94//!   Under retained-refs (default), the prior CB's ARC retains keep the
95//!   underlying Metal buffer alive across the residency-set demotion;
96//!   the GPU completes safely. Under `MLX_UNRETAINED_REFS=1` (NOT
97//!   enabled in Phase 0b), caller-owned arenas remain the only
98//!   structural mitigation — same contract as the existing async-commit
99//!   path. The adversarial F2 test (see
100//!   `/opt/mlx-native/tests/encoder_session_multistage.rs`) explicitly
101//!   exercises this window.
102//! - **F11 — zero-init alloc_buffer**: INVARIANT. `EncoderSession` does
103//!   not allocate buffers; the zero-init contract on
104//!   `MlxDevice::alloc_buffer` is unchanged.
105//! - **F12 — `HF2Q_FORCE_SERIAL_DISPATCH` falsification probe**: PRESERVED.
106//!   The probe lives in `CommandEncoder::get_or_create_encoder` and is
107//!   re-read every time a fresh CB lazily opens its compute encoder
108//!   (every `reset_for_next_stage` rotation). Both pre- and post-fence
109//!   CBs honor the env var.
110//! - F3, F4, F5, F6, F7, F8, F9, F10 are out of scope for iter89e2-B
111//!   (forward-path phases 1-4 territory) — `EncoderSession` is purely
112//!   structural and does not touch any forward path.
113//!
114//! # Feature gate
115//!
116//! [`MlxDevice::encoder_session`] returns `Ok(None)` when the
117//! `HF2Q_ENCODER_SESSION` env var is unset (default). When set to `"1"`
118//! it returns `Ok(Some(EncoderSession))`. Production code paths in hf2q
119//! consume `device.command_encoder()` (returns plain `CommandEncoder`) so
120//! the gate is a no-op in default builds — zero behavior change.
121
122use crate::buffer::MlxBuffer;
123use crate::encoder::CommandEncoder;
124use crate::error::Result;
125use crate::residency::ResidencySet;
126
127/// Cached `HF2Q_ENCODER_SESSION` decision.
128///
129/// Identical pattern to `auto_barrier_enabled` / `unretained_refs_enabled`
130/// in `encoder.rs` — `OnceLock` so the env-read happens exactly once per
131/// process, and the per-call cost is a single atomic load. Declared at
132/// module scope so the gate is observable from both
133/// [`EncoderSession::env_enabled`] (the public introspection helper) and
134/// [`MlxDevice::encoder_session`] (the factory site).
135fn encoder_session_enabled() -> bool {
136    use std::sync::OnceLock;
137    static FLAG: OnceLock<bool> = OnceLock::new();
138    *FLAG.get_or_init(|| {
139        std::env::var("HF2Q_ENCODER_SESSION")
140            .map(|v| v == "1")
141            .unwrap_or(false)
142    })
143}
144
145/// Session-level wrapper around a [`CommandEncoder`] for one or more
146/// logical transformer stages.
147///
148/// See module docs for lifecycle and fence preservation. iter89e2-B scope:
149/// multi-stage chaining via [`MTLSharedEvent`](metal::SharedEvent), residency
150/// delegation surface, and the matching test cohort. Phase 0b-C will
151/// broaden label propagation; Phase 2+ will wire this struct into the
152/// production forward path.
153///
154/// # Thread safety
155///
156/// `EncoderSession` is `Send` because [`CommandEncoder`] is `Send` (the
157/// existing unsafe impl at `encoder.rs:613-619`), `String`/`u64`/`bool`
158/// are `Send`, [`metal::Device`] is `Send + Sync` (foreign_obj_type! at
159/// metal-rs 0.33 lib.rs:179), and [`metal::SharedEvent`] is `Send + Sync`
160/// for the same reason. It is NOT `Sync` — exclusive ownership during
161/// dispatch encoding is the same contract as the inner [`CommandEncoder`].
162pub struct EncoderSession {
163    /// Inner command encoder. Carries `cmd_buf`, the persistent
164    /// `active_encoder`, the `queue` clone (read by
165    /// [`CommandEncoder::reset_command_buffer`]), the residency-set
166    /// flush hook, capture-mode IR, the auto-barrier `MemRanges`
167    /// tracker, the iter63 sample buffer, and the iter16 `last_label`
168    /// history. All dispatch operations flow through here.
169    ///
170    /// INVARIANT: `inner` is in a consistent state at every public API
171    /// boundary. Drops cleanly via `CommandEncoder::Drop` which calls
172    /// `end_active_encoder()` (Metal-asserts on a CB dropped with an
173    /// unended encoder).
174    inner: CommandEncoder,
175
176    /// Owned clone of the originating [`metal::Device`].
177    ///
178    /// iter89e2-B: held so [`Self::fence_stage`] can lazily allocate an
179    /// [`metal::SharedEvent`] on first call without threading a
180    /// `&MlxDevice` through every call site. metal-rs 0.33's `Device`
181    /// is `Send + Sync` (foreign_obj_type! lib.rs:179), so adding this
182    /// field preserves the existing unsafe `Send` impl on
183    /// [`EncoderSession`] declared below.
184    device: metal::Device,
185
186    /// Lazily-allocated [`MTLSharedEvent`](metal::SharedEvent) backing
187    /// the per-session monotonic stage fence.
188    ///
189    /// `None` until the first [`Self::fence_stage`] call. Once
190    /// allocated, the same event is reused across every fence in this
191    /// session — the value half of the (event, value) pair carries the
192    /// monotonic identity. Cost is one ObjC alloc + autorelease per
193    /// session lifetime; subsequent fences reuse the same event.
194    event: Option<metal::SharedEvent>,
195
196    /// Per-session monotonic fence counter.
197    ///
198    /// Mirrors `ggml_metal_event::value` at
199    /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m:941`.
200    /// [`Self::fence_stage`] post-increments (signal = current+1, then
201    /// store current+1); [`Self::reset_for_next_stage`] reads (wait =
202    /// current). Starts at 0; bumps to 1 on first fence; CB N waits on
203    /// value N to gate after CB N's signal lands.
204    event_value: u64,
205
206    /// Human-readable stage label for xctrace MST attribution.
207    ///
208    /// Set by [`Self::begin_stage`] and by the `Some` arm of
209    /// [`Self::fence_stage`]'s `label` parameter. Empty by default.
210    /// When non-empty, [`Self::commit_stage`], [`Self::commit_and_wait`],
211    /// and [`Self::fence_stage`] all delegate to the inner encoder's
212    /// `commit_labeled` / `commit_and_wait_labeled` path, which
213    /// propagates the label to `MTLCommandBuffer.label` and
214    /// `MTLComputeCommandEncoder.label` via `apply_labels` at
215    /// `encoder.rs:1968-1986`.
216    ///
217    /// Cleared by [`Self::reset_for_next_stage`] so each chained stage
218    /// starts with a fresh label slot — the caller calls `begin_stage`
219    /// (or passes `Some(label)` to the next `fence_stage`) per stage.
220    stage_label: String,
221
222    /// Latch flipped to `true` after a `commit_stage` / `commit_and_wait`
223    /// / `fence_stage` call.
224    ///
225    /// Used to enforce the one-CB-per-state contract: a `EncoderSession`
226    /// in the `Drained` (or `Fenced`) state must call
227    /// [`Self::reset_for_next_stage`] before further dispatches encode
228    /// onto a new CB. Calling `commit_*` twice without an intervening
229    /// reset is a logic error — we surface it as a no-op rather than a
230    /// panic so the session remains drop-safe.
231    drained: bool,
232
233    /// Whether the most recent commit was a [`Self::fence_stage`] call.
234    ///
235    /// When `true`, [`Self::reset_for_next_stage`] encodes an
236    /// `encodeWaitForEvent` on the new CB at `event_value`. Cleared by
237    /// `reset_for_next_stage` so a subsequent `commit_stage` (no fence)
238    /// does not spuriously emit a wait on the next reset.
239    fence_pending: bool,
240
241    /// Per-session count of `encodeWaitForEvent` calls actually emitted
242    /// inside [`Self::reset_for_next_stage`].
243    ///
244    /// Symmetric counterpart to `event_value` (the signal-side high-water
245    /// mark) — `wait_count` is the wait-side scoreboard. Bumped exactly
246    /// once each time `reset_for_next_stage` finds `fence_pending == true`
247    /// and routes through `inner.encode_wait_for_event`. Read-only via
248    /// [`Self::wait_count`]; never mutated by control flow (introspection
249    /// only — does NOT widen F1/F2/F11/F12 windows).
250    ///
251    /// iter90b §2 H1b proof: the multi-stage chain test asserts this
252    /// equals `(num_stages - 1)` for an N-stage chain (one wait per
253    /// reset; the first stage's CB never had a prior signal to wait on).
254    wait_count: u64,
255
256    /// Value of the most recent `encodeWaitForEvent` actually emitted
257    /// inside [`Self::reset_for_next_stage`].
258    ///
259    /// Mirrors the relationship between `event_value` (signal-side) and
260    /// the value passed to `inner.encode_wait_for_event`. Starts at 0
261    /// (no wait yet emitted); each successful wait sets this to the
262    /// `value` argument. Read-only via [`Self::wait_value`]; pure
263    /// introspection (does NOT widen F1/F2/F11/F12 windows).
264    ///
265    /// iter90b §2 H1b proof: after a `fence_stage(N)` followed by
266    /// `reset_for_next_stage()`, this MUST equal N (the wait-side
267    /// matches the signal we just signaled).
268    last_wait_value: u64,
269}
270
271// SAFETY: `EncoderSession` is `Send` provided that:
272// 1. `CommandEncoder` is `Send` (existing unsafe impl at encoder.rs:606,
273//    Apple documents that command buffers / encoders may be encoded
274//    from any thread provided exclusive ownership).
275// 2. `metal::Device` is `Send + Sync` via foreign_obj_type!
276//    (metal-0.33.0/src/lib.rs:179).
277// 3. `metal::SharedEvent` is `Send + Sync` via foreign_obj_type!
278//    (same site — the macro emits `unsafe type ...: Sync + Send`
279//    for every type, including SharedEvent in sync.rs:36-40).
280// 4. `String`, `u64`, `bool` are `Send`.
281// All five hold. `EncoderSession` does NOT add any non-Send fields in
282// iter89e2-B beyond `metal::Device` + `Option<metal::SharedEvent>` +
283// `u64` + `bool`, all already validated.
284unsafe impl Send for EncoderSession {}
285
286impl EncoderSession {
287    /// Construct a new session over a fresh `CommandEncoder`.
288    ///
289    /// Returns `Err` if the underlying `CommandEncoder::new_with_residency`
290    /// fails (currently impossible past metal-rs 0.33's
291    /// `new_command_buffer`, but the `Result` is preserved for
292    /// future-proofing against driver-side allocation failures).
293    ///
294    /// # Crate-internal
295    ///
296    /// `pub(crate)` because the public construction surface is
297    /// [`MlxDevice::encoder_session`](crate::MlxDevice::encoder_session),
298    /// which threads the env-gate. Direct construction from outside
299    /// `mlx-native` would bypass the `HF2Q_ENCODER_SESSION` flag, which
300    /// is the wrong layering.
301    pub(crate) fn new(
302        device: &metal::DeviceRef,
303        queue: &metal::CommandQueue,
304        residency_set: Option<ResidencySet>,
305    ) -> Result<Self> {
306        Ok(Self {
307            inner: CommandEncoder::new_with_residency(queue, residency_set)?,
308            device: device.to_owned(),
309            event: None,
310            event_value: 0,
311            stage_label: String::new(),
312            drained: false,
313            fence_pending: false,
314            wait_count: 0,
315            last_wait_value: 0,
316        })
317    }
318
319    /// Whether `HF2Q_ENCODER_SESSION=1` is set in the process environment.
320    ///
321    /// Public introspection helper for hf2q-side dispatch wrappers that
322    /// need to choose between the legacy `command_encoder()` path and the
323    /// new `encoder_session()` path. Cached on first read via `OnceLock`
324    /// so the per-call cost is a single atomic load.
325    #[inline]
326    pub fn env_enabled() -> bool {
327        encoder_session_enabled()
328    }
329
330    /// Set the semantic stage label.
331    ///
332    /// The label propagates to `MTLCommandBuffer.label` and (when an
333    /// encoder is active) `MTLComputeCommandEncoder.label` at the next
334    /// `commit_stage` / `commit_and_wait` / `fence_stage` call, enabling
335    /// xctrace MST attribution per ADR-015 iter16. Calling `begin_stage`
336    /// does NOT itself touch any Metal object — it only stores the
337    /// string.
338    ///
339    /// Idempotent: calling `begin_stage` multiple times before commit
340    /// overwrites the previous label with the latest value, matching
341    /// the existing `apply_labels` semantic at `encoder.rs:1980-1985`
342    /// (the `last_label` field is overwritten on every labeled commit).
343    pub fn begin_stage(&mut self, label: &str) {
344        self.stage_label.clear();
345        self.stage_label.push_str(label);
346    }
347
348    /// Borrow the inner [`CommandEncoder`] for dispatch encoding.
349    ///
350    /// All dispatch APIs (`encode`, `encode_threadgroups`,
351    /// `encode_with_args`, `dispatch_tracked_*`, `memory_barrier`,
352    /// `start_capture` / `take_capture`, etc.) live on
353    /// [`CommandEncoder`]; `EncoderSession` adds a stage-aware commit
354    /// surface on top of them. Use this accessor inside the dispatch
355    /// loop, then call one of [`Self::commit_stage`] /
356    /// [`Self::commit_and_wait`] / [`Self::fence_stage`] at the stage
357    /// boundary.
358    ///
359    /// # Caller contract
360    ///
361    /// Do NOT call `inner.commit*` methods directly through this
362    /// borrow. Use the session's commit surface so the stage label
363    /// propagates and the drained-latch / fence state stay consistent.
364    /// Calling the inner commit bypasses these — it is not unsafe (no
365    /// UB risk) but it makes the session state inconsistent with what
366    /// it has actually committed.
367    #[inline]
368    pub fn encoder(&mut self) -> &mut CommandEncoder {
369        &mut self.inner
370    }
371
372    /// Commit the stage's command buffer non-blocking (no fence).
373    ///
374    /// Delegates to `CommandEncoder::commit_labeled` (when a label is
375    /// set) or `CommandEncoder::commit` (when not). Both end the
376    /// persistent compute encoder, flush the residency-set pending
377    /// staging (`flush_residency_pending` at `encoder.rs:2004`), and
378    /// hand the CB to the GPU without blocking the CPU.
379    ///
380    /// The session enters the `Drained` state. To chain into another
381    /// stage on the same session, call [`Self::reset_for_next_stage`]
382    /// — that opens a fresh CB and (if a fence was pending from a prior
383    /// `fence_stage`) encodes the matching wait. After
384    /// `commit_stage` (no fence), `reset_for_next_stage` does NOT emit
385    /// a wait — the CBs are merely sequenced by the Metal queue's FIFO
386    /// dispatch order.
387    ///
388    /// # Errors
389    ///
390    /// Returns `Ok(())` unconditionally — `CommandEncoder::commit` and
391    /// `CommandEncoder::commit_labeled` are infallible (they hand the
392    /// CB to Metal without waiting for completion; errors surface only
393    /// at `wait_until_completed`). The `Result` is preserved for
394    /// symmetry with [`Self::commit_and_wait`] and for future-proofing.
395    pub fn commit_stage(&mut self) -> Result<()> {
396        if self.drained {
397            return Ok(());
398        }
399        self.drained = true;
400        self.fence_pending = false;
401        if self.stage_label.is_empty() {
402            self.inner.commit();
403        } else {
404            // Take a snapshot of the label so we don't borrow `self`
405            // both immutably (for the label) and mutably (for inner)
406            // — clones a small String, fine for stage-boundary cost.
407            let label = self.stage_label.clone();
408            self.inner.commit_labeled(&label);
409        }
410        Ok(())
411    }
412
413    /// Commit the stage's command buffer and block until GPU completion.
414    ///
415    /// Delegates to `CommandEncoder::commit_and_wait_labeled` (when a
416    /// label is set) or `CommandEncoder::commit_and_wait` (when not).
417    /// Required at K-batch boundaries (F7) and at output-head CPU reads
418    /// (F6). Increments `SYNC_COUNT` exactly once per call (matches
419    /// `encoder.rs:1845`).
420    ///
421    /// The session enters the `Drained` state with NO fence pending —
422    /// blocking commit fully drains the GPU, so the next stage (after
423    /// [`Self::reset_for_next_stage`]) needs no wait-event.
424    ///
425    /// # Errors
426    ///
427    /// Returns `MlxError::CommandBufferError` if the GPU reports an
428    /// error after wait — propagated from `CommandEncoder`.
429    ///
430    /// ADR-015 iter94 Task #2 — fail-loud contract.  iter93 final-report
431    /// §"Root-cause hypothesis" point 5 noted that under
432    /// `MLX_UNRETAINED_REFS=1` + `HF2Q_ENCODER_SESSION=1` + `K>1`, the
433    /// session appeared to silently absorb a `MTLCommandBufferStatus::
434    /// Error` and produce deterministic-but-wrong tokens.  By code
435    /// reading, the tail-expression `self.inner.commit_and_wait()`
436    /// already returns the inner error (commit_and_wait at
437    /// encoder.rs:1852 explicitly matches on `cmd_buf.status()`).  This
438    /// re-shape converts the implicit propagation into an explicit `?`
439    /// chain so future maintainers cannot accidentally swallow the
440    /// error by inserting a `let _ = inner.commit_and_wait();` or
441    /// adding fall-through logic between the inner call and the
442    /// function return.  Latched `drained = true` happens BEFORE the
443    /// inner call so a panicking unwind through Drop sees the same
444    /// drained-state contract.
445    pub fn commit_and_wait(&mut self) -> Result<()> {
446        if self.drained {
447            return Ok(());
448        }
449        self.drained = true;
450        self.fence_pending = false;
451        let result = if self.stage_label.is_empty() {
452            self.inner.commit_and_wait()
453        } else {
454            let label = self.stage_label.clone();
455            self.inner.commit_and_wait_labeled(&label)
456        };
457        // Explicit `?`-style propagation: any `Err` from the inner
458        // commit_and_wait MUST surface to the caller.  This is the
459        // iter94 Task #2 fail-loud guarantee — silent absorption here
460        // would replicate the iter93 §"Root-cause hypothesis" point 5
461        // failure mode (deterministic-but-wrong outputs at the triple
462        // combo).  The extra `?` is a no-op codegen-wise vs the prior
463        // tail-expression form but documents intent and is unit-tested
464        // by `test_commit_and_wait_propagates_inner_cb_error`.
465        result?;
466        Ok(())
467    }
468
469    /// Encode a stage-fence signal on the current CB and commit non-blocking.
470    ///
471    /// This is the D3 multi-stage building block: the prior stage's
472    /// final CB-level op is `encodeSignalEvent:value:value+1`, where
473    /// `value+1` is then both stored in `event_value` (so the next
474    /// stage's `encodeWaitForEvent:value:` blocks on it) and committed.
475    /// The session enters the `Fenced` (drained-with-fence-pending)
476    /// state; [`Self::reset_for_next_stage`] rotates the inner CB and
477    /// emits the matching wait.
478    ///
479    /// # Lazy event allocation
480    ///
481    /// On the first call, allocates the per-session
482    /// [`MTLSharedEvent`](metal::SharedEvent) via
483    /// [`metal::DeviceRef::new_shared_event`]
484    /// (`/Users/robert/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metal-0.33.0/src/device.rs:2063`).
485    /// Subsequent calls reuse the same event — the monotonic
486    /// `event_value` carries the per-fence identity. This matches the
487    /// llama.cpp pattern at
488    /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m:944-958`.
489    ///
490    /// # Label
491    ///
492    /// `label`'s `Some(value)` arm overwrites `stage_label` and
493    /// propagates via `commit_labeled`'s `apply_labels` chain — same as
494    /// calling [`Self::begin_stage`] before this. `None` keeps any
495    /// previously-set `begin_stage` label intact.
496    ///
497    /// # Counter semantics
498    ///
499    /// Bumps `SYNC_COUNT` zero times (non-blocking). Bumps
500    /// `CMD_BUF_COUNT` zero times (no new CB allocated here —
501    /// `reset_for_next_stage` does that). Increments `event_value` by
502    /// exactly 1.
503    ///
504    /// # Errors
505    ///
506    /// Returns `Ok(())` unconditionally for the same reason
507    /// [`Self::commit_stage`] does.
508    pub fn fence_stage(&mut self, label: Option<&str>) -> Result<()> {
509        if self.drained {
510            return Ok(());
511        }
512        // Apply the label argument before committing so commit_labeled
513        // (called below) propagates the latest value to the CB. Note
514        // that the encoder.rs:1968 apply_labels writes to the active
515        // compute encoder iff one is open — at this point one IS open
516        // (we have not yet ended it), so the encoder picks up the label
517        // before end_encoding fires. After end_encoding the CB still
518        // has its label set (set on the CB itself, not the encoder).
519        if let Some(l) = label {
520            self.stage_label.clear();
521            self.stage_label.push_str(l);
522        }
523
524        // Lazy-alloc the SharedEvent on first fence in this session.
525        // metal::DeviceRef::new_shared_event lives at
526        // /Users/robert/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metal-0.33.0/src/device.rs:2063.
527        if self.event.is_none() {
528            self.event = Some(self.device.new_shared_event());
529        }
530
531        // Sequence: end-active-encoder + encodeSignalEvent (CB-level) +
532        // residency-flush + cmd_buf.commit, all inside the inner
533        // helper. This preserves F1 (encoder is ended exactly once per
534        // CB), F2 (residency-flush still fires at the commit boundary),
535        // and matches llama.cpp's pattern at
536        // `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m:944-950`.
537        let new_value = self.event_value + 1;
538        let event_ref: &metal::SharedEventRef = self
539            .event
540            .as_ref()
541            .expect("event allocated immediately above this borrow")
542            .as_ref();
543        // Deref-coerce SharedEventRef -> EventRef via the
544        // ParentType = Event chain in metal-0.33.0/src/sync.rs:36-40.
545        let label_opt: Option<&str> = if self.stage_label.is_empty() {
546            None
547        } else {
548            Some(self.stage_label.as_str())
549        };
550        self.inner
551            .fence_signal_and_commit(event_ref, new_value, label_opt);
552
553        self.event_value = new_value;
554        self.drained = true;
555        self.fence_pending = true;
556        Ok(())
557    }
558
559    /// Open a fresh command buffer on the same queue and (when a fence
560    /// is pending) encode the matching wait on the new CB.
561    ///
562    /// This is the second half of the multi-stage chaining primitive.
563    /// After [`Self::fence_stage`] (or [`Self::commit_stage`] /
564    /// [`Self::commit_and_wait`]) has put the session in the `Drained`
565    /// state, callers invoke this to start the next stage's CB. The
566    /// session transitions back to `Encoding` (no CB or compute encoder
567    /// open until the next dispatch lazy-opens them).
568    ///
569    /// # Wait-event encoding
570    ///
571    /// If [`Self::fence_stage`] was the most recent commit, this
572    /// method encodes `encodeWaitForEvent:value:event_value` on the
573    /// freshly-allocated CB before returning. The new CB's GPU work
574    /// blocks until the prior CB's signal lands at the same value.
575    /// After [`Self::commit_stage`] / [`Self::commit_and_wait`] (no
576    /// fence), no wait is encoded — Metal's queue-FIFO sequencing is
577    /// the implicit ordering primitive.
578    ///
579    /// # State machine
580    ///
581    /// | Before | After |
582    /// |---|---|
583    /// | Drained (no fence) | Encoding (new CB, no wait) |
584    /// | Fenced (fence pending) | Encoding (new CB, wait encoded) |
585    /// | Encoding (not drained) | no-op (returns Ok) |
586    ///
587    /// The not-drained case is intentionally a no-op rather than a
588    /// panic: it keeps the session drop-safe under unusual call
589    /// sequences (e.g. test scaffolding that calls reset speculatively).
590    ///
591    /// # Counter semantics
592    ///
593    /// Bumps `CMD_BUF_COUNT` by exactly 1 (the new CB). Does NOT bump
594    /// `SYNC_COUNT` (no commit/wait happens here).
595    ///
596    /// # Errors
597    ///
598    /// Returns `Ok(())` unconditionally. Future error paths (e.g.
599    /// queue-side allocation failure on `new_command_buffer`) would
600    /// surface here.
601    pub fn reset_for_next_stage(&mut self) -> Result<()> {
602        if !self.drained {
603            return Ok(());
604        }
605
606        // Snapshot the wait-event metadata BEFORE rotating cmd_buf so
607        // we encode the wait on the NEW CB.
608        let wait_metadata = if self.fence_pending {
609            self.event
610                .as_ref()
611                .map(|ev| (ev.clone(), self.event_value))
612        } else {
613            None
614        };
615
616        self.inner.reset_command_buffer();
617
618        if let Some((event, value)) = wait_metadata {
619            // Deref-coerce SharedEventRef → EventRef via the
620            // ParentType = Event chain in metal-0.33.0/src/sync.rs:36-40.
621            let event_ref: &metal::EventRef = event.as_ref();
622            self.inner.encode_wait_for_event(event_ref, value);
623            // iter90b §2 H1b — track the wait-event for introspection.
624            // Bump scoreboard ONLY after the wait actually encoded
625            // (mirrors the signal-side discipline: `event_value` is
626            // updated AFTER `fence_signal_and_commit` returns). These
627            // fields are pure read-only observability — they do NOT
628            // alter F1 (encoder lazy-open), F2 (residency-flush), F11
629            // (alloc_buffer zero-init), or F12 (force-serial-dispatch).
630            self.wait_count += 1;
631            self.last_wait_value = value;
632        }
633
634        self.drained = false;
635        self.fence_pending = false;
636        self.stage_label.clear();
637        Ok(())
638    }
639
640    /// Add a buffer to the device-level residency set.
641    ///
642    /// Delegates to the inner encoder's [`ResidencySet::add_allocation`]
643    /// (the same Arc clone the device, the encoder, and every other
644    /// concurrent encoder shares — single-set invariant per ADR-019:467).
645    /// The actual `[set commit]` is deferred until the next
646    /// `commit_stage` / `commit_and_wait` / `fence_stage`, which all
647    /// route through `flush_residency_pending`.
648    ///
649    /// Returns `false` and is a no-op when the device booted without
650    /// a residency set (HF2Q_NO_RESIDENCY=1, macOS<15, or
651    /// `MlxError::DeviceNotFound` test paths).
652    ///
653    /// # Use case
654    ///
655    /// Caller holds an [`MlxBuffer`] not previously registered (e.g.
656    /// from a pool, slice_view, or external interop) and wants the GPU
657    /// pages hinted as resident before the stage's first dispatch.
658    /// `MlxDevice::alloc_buffer` already auto-registers — this method
659    /// is the explicit hook for the residual cases.
660    pub fn add_to_residency_set(&self, buffer: &MlxBuffer) -> bool {
661        match self.inner.residency_set() {
662            Some(set) => {
663                set.add_allocation(buffer.metal_buffer());
664                true
665            }
666            None => false,
667        }
668    }
669
670    /// Remove a buffer from the device-level residency set.
671    ///
672    /// Mirror of [`Self::add_to_residency_set`]. Stages a deferred
673    /// `removeAllocation:` that flushes at the next commit boundary.
674    /// Returns `false` and no-ops when no residency set is active.
675    ///
676    /// # F2 caveat
677    ///
678    /// Removing a buffer that the in-flight CB still references is the
679    /// iter58b residency-rescission class. Under retained-refs
680    /// (default), the CB's ARC retain keeps the underlying Metal page
681    /// alive; the residency-set demotion only affects the resident-hint
682    /// (a perf knob, not a safety knob). Under `MLX_UNRETAINED_REFS=1`
683    /// (NOT enabled in Phase 0b), the caller-owned arena contract is
684    /// the only structural mitigation.
685    pub fn remove_from_residency_set(&self, buffer: &MlxBuffer) -> bool {
686        match self.inner.residency_set() {
687            Some(set) => {
688                set.remove_allocation(buffer.metal_buffer());
689                true
690            }
691            None => false,
692        }
693    }
694
695    /// Whether the session has been committed (any commit path).
696    ///
697    /// Test-and-introspection helper. Production code should use the
698    /// explicit `reset_for_next_stage` cycle to chain stages rather
699    /// than polling this field.
700    #[inline]
701    pub fn is_drained(&self) -> bool {
702        self.drained
703    }
704
705    /// Whether a fence is pending (most recent commit was `fence_stage`).
706    ///
707    /// Test-and-introspection helper for verifying the multi-stage
708    /// state machine. Cleared by the next `reset_for_next_stage` /
709    /// `commit_stage` / `commit_and_wait`.
710    #[inline]
711    pub fn is_fence_pending(&self) -> bool {
712        self.fence_pending
713    }
714
715    /// The current monotonic fence value.
716    ///
717    /// Returns 0 before the first `fence_stage`; otherwise returns the
718    /// most recently signaled value. Mirrors the semantics of
719    /// `ggml_metal_event::value` — a fence at value `N` means signal
720    /// `N` is in flight (or completed) and any subsequent waiters at
721    /// `N` will be unblocked.
722    #[inline]
723    pub fn fence_value(&self) -> u64 {
724        self.event_value
725    }
726
727    /// Whether a [`MTLSharedEvent`](metal::SharedEvent) has been allocated
728    /// in this session.
729    ///
730    /// Returns `false` until the first `fence_stage`; `true` afterwards.
731    /// Test helper for verifying lazy-allocation behavior.
732    #[inline]
733    pub fn has_event(&self) -> bool {
734        self.event.is_some()
735    }
736
737    /// The most recent value passed to `encode_wait_for_event` inside
738    /// [`Self::reset_for_next_stage`].
739    ///
740    /// Returns 0 until the first `reset_for_next_stage` actually emits
741    /// a wait (i.e. the prior commit was [`Self::fence_stage`], not
742    /// [`Self::commit_stage`] / [`Self::commit_and_wait`]). After a
743    /// `fence_stage(N)` followed by `reset_for_next_stage()`, this MUST
744    /// equal `N` — the wait-side scoreboard mirrors the signal-side
745    /// [`Self::fence_value`].
746    ///
747    /// iter90b §2 H1b proof helper: makes the wait-event encoding
748    /// observable from a Rust test without xctrace.
749    ///
750    /// # Risk register
751    ///
752    /// Pure read-only introspection. Reads a `u64` field updated under
753    /// `&mut self` exclusively (no concurrent mutation possible —
754    /// `EncoderSession` is `!Sync`). Does NOT widen F1/F2/F11/F12.
755    #[inline]
756    pub fn wait_value(&self) -> u64 {
757        self.last_wait_value
758    }
759
760    /// Cumulative count of `encode_wait_for_event` calls actually
761    /// emitted inside [`Self::reset_for_next_stage`] in this session.
762    ///
763    /// Bumped exactly once per `reset_for_next_stage` call that finds
764    /// `fence_pending == true` — i.e. once per "fence + reset" pair.
765    /// `commit_stage` / `commit_and_wait` followed by
766    /// `reset_for_next_stage` does NOT bump this (no wait emitted —
767    /// Metal queue FIFO is the implicit ordering primitive in that
768    /// case).
769    ///
770    /// For an N-stage chain (N fences + (N-1) resets), this returns
771    /// `N - 1` after the last reset. The Nth (terminal) fence is
772    /// drained by the caller via `metal_command_buffer().wait_until_completed()`
773    /// or by a subsequent `commit_and_wait`, neither of which emits an
774    /// additional wait.
775    ///
776    /// iter90b §2 H1b proof helper: paired with [`Self::wait_value`] to
777    /// make the wait-event side of the multi-stage chain observable.
778    ///
779    /// # Risk register
780    ///
781    /// Same as [`Self::wait_value`] — pure read-only introspection over
782    /// a `u64` field updated under `&mut self` exclusively. Does NOT
783    /// widen F1/F2/F11/F12.
784    #[inline]
785    pub fn wait_count(&self) -> u64 {
786        self.wait_count
787    }
788
789    /// Borrow the underlying Metal command buffer.
790    ///
791    /// Mirrors [`CommandEncoder::metal_command_buffer`]. Used by
792    /// label-propagation tests and by callers that need to call
793    /// `wait_until_completed` after a non-blocking `commit_stage` /
794    /// `fence_stage`.
795    #[inline]
796    pub fn metal_command_buffer(&self) -> &metal::CommandBuffer {
797        self.inner.metal_command_buffer()
798    }
799}
800
801impl Drop for EncoderSession {
802    /// Drain the inner [`CommandEncoder`] safely on drop.
803    ///
804    /// # F2 residency-rescission preservation (load-bearing)
805    ///
806    /// Drop scenarios across the multi-stage state machine:
807    ///
808    /// 1. **Drained (no fence)** — `commit_stage` / `commit_and_wait`
809    ///    already ran. `inner.flush_residency_pending()` was already
810    ///    called; the GPU has the CB (and may already have completed
811    ///    it under `commit_and_wait`). `CommandEncoder::Drop` runs and
812    ///    calls `end_active_encoder()`, which is a no-op because
813    ///    `commit*` already ended the encoder. Safe.
814    ///
815    /// 2. **Fenced (fence pending)** — `fence_stage` already ran. The
816    ///    signal-event has been encoded onto the prior CB and the CB
817    ///    has been submitted non-blocking. The session never opened a
818    ///    new CB (no `reset_for_next_stage` call), so `cmd_buf` still
819    ///    points at the FENCED CB. `CommandEncoder::Drop` runs and
820    ///    end_active_encoder is a no-op (encoder was ended inside
821    ///    `fence_signal_and_commit`). The submitted CB executes on the
822    ///    GPU normally — the signal lands, the value is observable to
823    ///    any external `waitUntilSignaledValue:` consumer (none in
824    ///    iter89e2-B), and the next allocation/CB on the same residency
825    ///    set will see the bumped pending flag flushed at its commit
826    ///    boundary. The fence event itself is dropped with `event` (an
827    ///    Option<SharedEvent>); ARC drop releases it.
828    ///
829    /// 3. **Encoding (uncommitted)** — caller created the session,
830    ///    optionally encoded dispatches, then dropped without calling
831    ///    any `commit_*`. `CommandEncoder::Drop` ends the active
832    ///    compute encoder cleanly (`encoder.rs:2057-2063`). The
833    ///    `cmd_buf` is dropped without ever being committed — Metal
834    ///    discards the encoded work. **No residency-remove is staged**
835    ///    because no buffers were registered as freed during this
836    ///    session (the F2 race requires a buffer drop staging a remove
837    ///    that a later `flush_pending` commits before the in-flight CB
838    ///    finishes; here no commit ever happens). The residency-set's
839    ///    pending state persists into the next encoder; correct.
840    ///
841    /// 4. **Empty** — no dispatches encoded. `active_encoder` is null;
842    ///    `CommandEncoder::Drop`'s `end_active_encoder` is a no-op.
843    ///    Safe.
844    ///
845    /// We deliberately do NOT call `wait_until_completed` here for the
846    /// committed-but-not-waited case (scenarios 1 with `commit_stage`
847    /// or 2 with `fence_stage`). Under retained-refs mode (default —
848    /// `MLX_UNRETAINED_REFS=0`), the in-flight CB holds ARC retains on
849    /// every bound buffer, so the GPU completes safely after the
850    /// session drops. Under `MLX_UNRETAINED_REFS=1` (NOT enabled in
851    /// Phase 0b), the caller-owned-arena contract is the only
852    /// structural mitigation — same as the existing async-`commit()`
853    /// path at `encoder.rs:2014-2022`.
854    ///
855    /// In short: `Drop` does no extra work; the inner `CommandEncoder`'s
856    /// own Drop is the entire safety story. `metal::SharedEvent` drops
857    /// via its foreign_obj_type! ARC release.
858    fn drop(&mut self) {
859        // The actual end-encoder call lives in `CommandEncoder::Drop`,
860        // which fires automatically when `self.inner` goes out of scope
861        // here. The `event` field's ObjC release fires via
862        // foreign_obj_type! ARC. No additional work needed — see this
863        // docstring's case analysis above for the F2 fence preservation
864        // argument.
865    }
866}