mlx_native/
encoder_session.rs

//! [`EncoderSession`] — D3 Per-Stage Fence encoder abstraction (ADR-019 Phase 0b).
//!
//! `EncoderSession` lifts [`CommandEncoder`] into a session-aware shell that
//! carries semantic *stage* metadata across the lifetime of one or more
//! logical transformer stages (e.g. `"layer.full_attn.stage1"` →
//! `"layer.full_attn.stage2"`). Phase 0b-A delivered the bare struct +
//! single-stage lifecycle methods. Phase 0b-B (this file) adds the
//! [`MTLSharedEvent`](metal::SharedEvent) inter-CB ordering primitives D3
//! needs:
//!
//! - [`EncoderSession::fence_stage`] — encode signal-event(N+1) on the
//!   current CB, commit non-blocking, increment the per-session monotonic
//!   counter.
//! - [`EncoderSession::reset_for_next_stage`] — open a fresh CB on the
//!   same queue and (when a fence is pending) encode wait-event(N) on the
//!   new CB so its GPU work blocks until the prior fenced CB completes.
//! - [`EncoderSession::add_to_residency_set`] /
//!   [`EncoderSession::remove_from_residency_set`] — public delegation
//!   surface for the single residency set owned by
//!   [`MlxDevice`](crate::MlxDevice). `EncoderSession` does NOT own a
//!   separate set; it routes calls through the Arc clone the inner
//!   [`CommandEncoder`] already holds.
//!
//! Phase 0b-C will broaden label propagation (per-substage labels +
//! xctrace MST round-trip). iter89e2-B leaves Phase 0b-A's existing
//! per-session label semantics intact.
//!
//! Production callers MUST stay on [`crate::CommandEncoder`] until Phase
//! 2 (FA-path D3 stage migration) wires `forward_gpu.rs` to consume
//! `EncoderSession`. The struct is feature-flagged behind
//! `HF2Q_ENCODER_SESSION=1` (default OFF) and is constructed only via
//! [`MlxDevice::encoder_session`](crate::MlxDevice::encoder_session).
//!
//! # Lifecycle (iter89e2-B — multi-stage chaining)
//!
//! ```text
//!                   MlxDevice::encoder_session()
//!                                |
//!                                v
//!                          +-----------+
//!                          | Empty     |  no CB or encoder open yet
//!                          +-----------+
//!                                |
//!                          first dispatch
//!                          (via inner CommandEncoder)
//!                                |
//!                                v
//!                          +-----------+
//!                          | Encoding  |  CB open, persistent compute encoder open
//!                          +-----------+
//!                            |       |
//!                            |       +---fence_stage(label)----+
//!                            |                                 |
//!                            |                                 v
//!                            |                          +-----------+
//!                            |                          | Fenced    |  signal encoded; CB submitted
//!                            |                          +-----------+
//!                            |                                 |
//!                  commit_stage()                  reset_for_next_stage()
//!                  commit_and_wait()                            |
//!                            |                                 v
//!                            v                          (loop back to Encoding
//!                      +-----------+                    on next dispatch — wait
//!                      | Drained   |                    is encoded automatically
//!                      +-----------+                    on the new CB)
//!                            |
//!                          Drop
//! ```
//!
//! `fence_stage` collapses the design-doc's separate Encoding→Fenced→
//! Committed transitions into a single "submit-with-fence" call: the
//! compute encoder is ended, the signal is encoded on the current CB at
//! `event_value+1`, the CB is committed non-blocking, and the per-session
//! monotonic counter is incremented. The session is `drained` until
//! [`EncoderSession::reset_for_next_stage`] rotates the inner
//! [`CommandEncoder`]'s command buffer to a fresh CB on the same queue
//! and (when a fence is pending) encodes the matching wait at
//! `event_value` on the new CB.
//!
//! # Risk register fence preservation (F1-F12 from ADR-019)
//!
//! - **F1 — persistent compute encoder per CB**: ADOPTED unchanged.
//!   `EncoderSession` borrows `&mut CommandEncoder` via
//!   [`EncoderSession::encoder`]; every dispatch reuses the same
//!   lazy-opened encoder per CB. Each stage CB still has exactly one
//!   persistent compute encoder.
//! - **F2 — iter58b residency-rescission**: PRESERVED. `commit_stage`,
//!   `commit_and_wait`, and `fence_stage` all delegate to the inner
//!   encoder, which calls `flush_residency_pending()` at every commit
//!   boundary (`encoder.rs:1842, 2004`). `reset_for_next_stage` does NOT
//!   re-flush — staged add/remove operations between stages flush at
//!   the next commit on the new CB. The single residency set is owned
//!   by [`MlxDevice`](crate::MlxDevice) (single-set invariant per
//!   ADR-019:467). Multi-stage chaining DOES widen the in-flight CB
//!   window — dropping a buffer between stage 1's `fence_stage` and
//!   stage 2's `commit_*` stages a remove-allocation that flushes at
//!   stage 2's commit, while stage 1's CB may still be GPU-pipelined.
//!   Under retained-refs (default), the prior CB's ARC retains keep the
//!   underlying Metal buffer alive across the residency-set demotion;
//!   the GPU completes safely. Under `MLX_UNRETAINED_REFS=1` (NOT
//!   enabled in Phase 0b), caller-owned arenas remain the only
//!   structural mitigation — same contract as the existing async-commit
//!   path. The adversarial F2 test (see
//!   `/opt/mlx-native/tests/encoder_session_multistage.rs`) explicitly
//!   exercises this window.
//! - **F11 — zero-init alloc_buffer**: INVARIANT. `EncoderSession` does
//!   not allocate buffers; the zero-init contract on
//!   `MlxDevice::alloc_buffer` is unchanged.
//! - **F12 — `HF2Q_FORCE_SERIAL_DISPATCH` falsification probe**: PRESERVED.
//!   The probe lives in `CommandEncoder::get_or_create_encoder` and is
//!   re-read every time a fresh CB lazily opens its compute encoder
//!   (every `reset_for_next_stage` rotation). Both pre- and post-fence
//!   CBs honor the env var.
//! - F3, F4, F5, F6, F7, F8, F9, F10 are out of scope for iter89e2-B
//!   (forward-path phases 1-4 territory) — `EncoderSession` is purely
//!   structural and does not touch any forward path.
//!
//! # Feature gate
//!
//! [`MlxDevice::encoder_session`](crate::MlxDevice::encoder_session)
//! returns `Ok(None)` when the `HF2Q_ENCODER_SESSION` env var is unset
//! or holds any value other than `"1"` (default OFF). When set to `"1"`
//! it returns `Ok(Some(EncoderSession))`. Production code paths in hf2q
//! consume `device.command_encoder()` (returns plain `CommandEncoder`) so
//! the gate is a no-op in default builds — zero behavior change.

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::Result;
use crate::residency::ResidencySet;

/// Cached `HF2Q_ENCODER_SESSION` decision.
///
/// Identical pattern to `auto_barrier_enabled` / `unretained_refs_enabled`
/// in `encoder.rs` — `OnceLock` so the env-read happens exactly once per
/// process, and the per-call cost is a single atomic load. Declared at
/// module scope so the gate is observable from both
/// [`EncoderSession::env_enabled`] (the public introspection helper) and
/// [`MlxDevice::encoder_session`](crate::MlxDevice::encoder_session)
/// (the factory site).
fn encoder_session_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("HF2Q_ENCODER_SESSION")
            .map(|v| v == "1")
            .unwrap_or(false)
    })
}
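
// Illustrative sketch (not part of the original file): the parsing rule that
// `encoder_session_enabled` applies, factored into a pure helper so it can be
// reasoned about without touching the process environment or the `OnceLock`
// cache. The helper name `gate_value_enables` is hypothetical. Only the exact
// string "1" turns the gate on; an unset variable and values such as "true",
// "0" or " 1" leave it off.
#[allow(dead_code)]
fn gate_value_enables(value: Option<&str>) -> bool {
    // Same shape as the closure above: compare against "1", default to false.
    value.map(|v| v == "1").unwrap_or(false)
}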

/// Session-level wrapper around a [`CommandEncoder`] for one or more
/// logical transformer stages.
///
/// See module docs for lifecycle and fence preservation. iter89e2-B scope:
/// multi-stage chaining via [`MTLSharedEvent`](metal::SharedEvent), residency
/// delegation surface, and the matching test cohort. Phase 0b-C will
/// broaden label propagation; Phase 2+ will wire this struct into the
/// production forward path.
///
/// # Thread safety
///
/// `EncoderSession` is `Send` because [`CommandEncoder`] is `Send` (the
/// existing unsafe impl at `encoder.rs:613-619`), `String`/`u64`/`bool`
/// are `Send`, [`metal::Device`] is `Send + Sync` (foreign_obj_type! at
/// metal-rs 0.33 lib.rs:179), and [`metal::SharedEvent`] is `Send + Sync`
/// for the same reason. It is NOT `Sync` — exclusive ownership during
/// dispatch encoding is the same contract as the inner [`CommandEncoder`].
pub struct EncoderSession {
    /// Inner command encoder. Carries `cmd_buf`, the persistent
    /// `active_encoder`, the `queue` clone (read by
    /// [`CommandEncoder::reset_command_buffer`]), the residency-set
    /// flush hook, capture-mode IR, the auto-barrier `MemRanges`
    /// tracker, the iter63 sample buffer, and the iter16 `last_label`
    /// history. All dispatch operations flow through here.
    ///
    /// INVARIANT: `inner` is in a consistent state at every public API
    /// boundary. Drops cleanly via `CommandEncoder::Drop` which calls
    /// `end_active_encoder()` (Metal-asserts on a CB dropped with an
    /// unended encoder).
    inner: CommandEncoder,

    /// Owned clone of the originating [`metal::Device`].
    ///
    /// iter89e2-B: held so [`Self::fence_stage`] can lazily allocate an
    /// [`metal::SharedEvent`] on first call without threading a
    /// `&MlxDevice` through every call site. metal-rs 0.33's `Device`
    /// is `Send + Sync` (foreign_obj_type! lib.rs:179), so adding this
    /// field preserves the existing unsafe `Send` impl on
    /// [`EncoderSession`] declared below.
    device: metal::Device,

    /// Lazily-allocated [`MTLSharedEvent`](metal::SharedEvent) backing
    /// the per-session monotonic stage fence.
    ///
    /// `None` until the first [`Self::fence_stage`] call. Once
    /// allocated, the same event is reused across every fence in this
    /// session — the value half of the (event, value) pair carries the
    /// monotonic identity. Cost is one ObjC alloc + autorelease per
    /// session lifetime; subsequent fences reuse the same event.
    event: Option<metal::SharedEvent>,

    /// Per-session monotonic fence counter.
    ///
    /// Mirrors `ggml_metal_event::value` at
    /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m:941`.
    /// [`Self::fence_stage`] increments it (signal = current+1, then
    /// store current+1); [`Self::reset_for_next_stage`] reads it (wait =
    /// current). Starts at 0; bumps to 1 on first fence. The Nth fenced
    /// CB signals value N, and the CB opened after it waits on value N,
    /// so its GPU work is gated until that signal lands.
    event_value: u64,

    /// Human-readable stage label for xctrace MST attribution.
    ///
    /// Set by [`Self::begin_stage`] and by the `Some` arm of
    /// [`Self::fence_stage`]'s `label` parameter. Empty by default.
    /// When non-empty, [`Self::commit_stage`], [`Self::commit_and_wait`],
    /// and [`Self::fence_stage`] all delegate to the inner encoder's
    /// `commit_labeled` / `commit_and_wait_labeled` path, which
    /// propagates the label to `MTLCommandBuffer.label` and
    /// `MTLComputeCommandEncoder.label` via `apply_labels` at
    /// `encoder.rs:1968-1986`.
    ///
    /// Cleared by [`Self::reset_for_next_stage`] so each chained stage
    /// starts with a fresh label slot — the caller calls `begin_stage`
    /// (or passes `Some(label)` to the next `fence_stage`) per stage.
    stage_label: String,

    /// Latch flipped to `true` after a `commit_stage` / `commit_and_wait`
    /// / `fence_stage` call.
    ///
    /// Used to enforce the one-CB-per-state contract: an `EncoderSession`
    /// in the `Drained` (or `Fenced`) state must call
    /// [`Self::reset_for_next_stage`] before further dispatches encode
    /// onto a new CB. Calling `commit_*` twice without an intervening
    /// reset is a logic error — we surface it as a no-op rather than a
    /// panic so the session remains drop-safe.
    drained: bool,

    /// Whether the most recent commit was a [`Self::fence_stage`] call.
    ///
    /// When `true`, [`Self::reset_for_next_stage`] encodes an
    /// `encodeWaitForEvent` on the new CB at `event_value`. Cleared by
    /// `reset_for_next_stage` so a subsequent `commit_stage` (no fence)
    /// does not spuriously emit a wait on the next reset.
    fence_pending: bool,
}

// SAFETY: `EncoderSession` is `Send` provided that:
// 1. `CommandEncoder` is `Send` (existing unsafe impl at encoder.rs:606,
//    Apple documents that command buffers / encoders may be encoded
//    from any thread provided exclusive ownership).
// 2. `metal::Device` is `Send + Sync` via foreign_obj_type!
//    (metal-0.33.0/src/lib.rs:179).
// 3. `metal::SharedEvent` is `Send + Sync` via foreign_obj_type!
//    (same site — the macro emits `unsafe type ...: Sync + Send`
//    for every type, including SharedEvent in sync.rs:36-40).
// 4. `String`, `u64`, `bool` are `Send`.
// All four hold. `EncoderSession` does NOT add any non-Send fields in
// iter89e2-B beyond `metal::Device` + `Option<metal::SharedEvent>` +
// `u64` + `bool`, all already validated.
unsafe impl Send for EncoderSession {}

impl EncoderSession {
    /// Construct a new session over a fresh `CommandEncoder`.
    ///
    /// Returns `Err` if the underlying `CommandEncoder::new_with_residency`
    /// fails (currently impossible past metal-rs 0.33's
    /// `new_command_buffer`, but the `Result` is preserved for
    /// future-proofing against driver-side allocation failures).
    ///
    /// # Crate-internal
    ///
    /// `pub(crate)` because the public construction surface is
    /// [`MlxDevice::encoder_session`](crate::MlxDevice::encoder_session),
    /// which threads the env-gate. Direct construction from outside
    /// `mlx-native` would bypass the `HF2Q_ENCODER_SESSION` flag, which
    /// is the wrong layering.
    pub(crate) fn new(
        device: &metal::DeviceRef,
        queue: &metal::CommandQueue,
        residency_set: Option<ResidencySet>,
    ) -> Result<Self> {
        Ok(Self {
            inner: CommandEncoder::new_with_residency(queue, residency_set)?,
            device: device.to_owned(),
            event: None,
            event_value: 0,
            stage_label: String::new(),
            drained: false,
            fence_pending: false,
        })
    }

    /// Whether `HF2Q_ENCODER_SESSION=1` is set in the process environment.
    ///
    /// Public introspection helper for hf2q-side dispatch wrappers that
    /// need to choose between the legacy `command_encoder()` path and the
    /// new `encoder_session()` path. Cached on first read via `OnceLock`
    /// so the per-call cost is a single atomic load.
    #[inline]
    pub fn env_enabled() -> bool {
        encoder_session_enabled()
    }

    /// Set the semantic stage label.
    ///
    /// The label propagates to `MTLCommandBuffer.label` and (when an
    /// encoder is active) `MTLComputeCommandEncoder.label` at the next
    /// `commit_stage` / `commit_and_wait` / `fence_stage` call, enabling
    /// xctrace MST attribution per ADR-015 iter16. Calling `begin_stage`
    /// does NOT itself touch any Metal object — it only stores the
    /// string.
    ///
    /// Idempotent: calling `begin_stage` multiple times before commit
    /// overwrites the previous label with the latest value, matching
    /// the existing `apply_labels` semantic at `encoder.rs:1980-1985`
    /// (the `last_label` field is overwritten on every labeled commit).
    pub fn begin_stage(&mut self, label: &str) {
        self.stage_label.clear();
        self.stage_label.push_str(label);
    }

    /// Borrow the inner [`CommandEncoder`] for dispatch encoding.
    ///
    /// All dispatch APIs (`encode`, `encode_threadgroups`,
    /// `encode_with_args`, `dispatch_tracked_*`, `memory_barrier`,
    /// `start_capture` / `take_capture`, etc.) live on
    /// [`CommandEncoder`]; `EncoderSession` adds a stage-aware commit
    /// surface on top of them. Use this accessor inside the dispatch
    /// loop, then call one of [`Self::commit_stage`] /
    /// [`Self::commit_and_wait`] / [`Self::fence_stage`] at the stage
    /// boundary.
    ///
    /// # Caller contract
    ///
    /// Do NOT call `inner.commit*` methods directly through this
    /// borrow. Use the session's commit surface so the stage label
    /// propagates and the drained-latch / fence state stay consistent.
    /// Calling the inner commit bypasses these — it is not unsafe (no
    /// UB risk) but it makes the session state inconsistent with what
    /// it has actually committed.
    #[inline]
    pub fn encoder(&mut self) -> &mut CommandEncoder {
        &mut self.inner
    }

    /// Commit the stage's command buffer non-blocking (no fence).
    ///
    /// Delegates to `CommandEncoder::commit_labeled` (when a label is
    /// set) or `CommandEncoder::commit` (when not). Both end the
    /// persistent compute encoder, flush the residency-set pending
    /// staging (`flush_residency_pending` at `encoder.rs:2004`), and
    /// hand the CB to the GPU without blocking the CPU.
    ///
    /// The session enters the `Drained` state. To chain into another
    /// stage on the same session, call [`Self::reset_for_next_stage`]
    /// — that opens a fresh CB and (if a fence was pending from a prior
    /// `fence_stage`) encodes the matching wait. After
    /// `commit_stage` (no fence), `reset_for_next_stage` does NOT emit
    /// a wait — the CBs are merely sequenced by the Metal queue's FIFO
    /// dispatch order.
    ///
    /// # Errors
    ///
    /// Returns `Ok(())` unconditionally — `CommandEncoder::commit` and
    /// `CommandEncoder::commit_labeled` are infallible (they hand the
    /// CB to Metal without waiting for completion; errors surface only
    /// at `wait_until_completed`). The `Result` is preserved for
    /// symmetry with [`Self::commit_and_wait`] and for future-proofing.
    pub fn commit_stage(&mut self) -> Result<()> {
        if self.drained {
            return Ok(());
        }
        self.drained = true;
        self.fence_pending = false;
        if self.stage_label.is_empty() {
            self.inner.commit();
        } else {
            // Take a snapshot of the label so we don't borrow `self`
            // both immutably (for the label) and mutably (for inner)
            // — clones a small String, fine for stage-boundary cost.
            let label = self.stage_label.clone();
            self.inner.commit_labeled(&label);
        }
        Ok(())
    }

    /// Commit the stage's command buffer and block until GPU completion.
    ///
    /// Delegates to `CommandEncoder::commit_and_wait_labeled` (when a
    /// label is set) or `CommandEncoder::commit_and_wait` (when not).
    /// Required at K-batch boundaries (F7) and at output-head CPU reads
    /// (F6). Increments `SYNC_COUNT` exactly once per call (matches
    /// `encoder.rs:1845`).
    ///
    /// The session enters the `Drained` state with NO fence pending —
    /// blocking commit fully drains the GPU, so the next stage (after
    /// [`Self::reset_for_next_stage`]) needs no wait-event.
    ///
    /// # Errors
    ///
    /// Returns `MlxError::CommandBufferError` if the GPU reports an
    /// error after wait — propagated from `CommandEncoder`.
    pub fn commit_and_wait(&mut self) -> Result<()> {
        if self.drained {
            return Ok(());
        }
        self.drained = true;
        self.fence_pending = false;
        if self.stage_label.is_empty() {
            self.inner.commit_and_wait()
        } else {
            let label = self.stage_label.clone();
            self.inner.commit_and_wait_labeled(&label)
        }
    }

    /// Encode a stage-fence signal on the current CB and commit non-blocking.
    ///
    /// This is the D3 multi-stage building block: the stage's final
    /// CB-level op is `encodeSignalEvent:value:` at `event_value + 1`.
    /// The CB is then committed and `event_value + 1` is stored back
    /// into `event_value`, so the next stage's
    /// `encodeWaitForEvent:value:` blocks on that same value.
    /// The session enters the `Fenced` (drained-with-fence-pending)
    /// state; [`Self::reset_for_next_stage`] rotates the inner CB and
    /// emits the matching wait.
    ///
    /// # Lazy event allocation
    ///
    /// On the first call, allocates the per-session
    /// [`MTLSharedEvent`](metal::SharedEvent) via
    /// [`metal::DeviceRef::new_shared_event`]
    /// (`/Users/robert/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metal-0.33.0/src/device.rs:2063`).
    /// Subsequent calls reuse the same event — the monotonic
    /// `event_value` carries the per-fence identity. This matches the
    /// llama.cpp pattern at
    /// `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m:944-958`.
    ///
    /// # Label
    ///
    /// `label`'s `Some(value)` arm overwrites `stage_label` and
    /// propagates via `commit_labeled`'s `apply_labels` chain — same as
    /// calling [`Self::begin_stage`] before this. `None` keeps any
    /// previously-set `begin_stage` label intact.
    ///
    /// # Counter semantics
    ///
    /// Bumps `SYNC_COUNT` zero times (non-blocking). Bumps
    /// `CMD_BUF_COUNT` zero times (no new CB allocated here —
    /// `reset_for_next_stage` does that). Increments `event_value` by
    /// exactly 1.
    ///
    /// # Errors
    ///
    /// Returns `Ok(())` unconditionally for the same reason
    /// [`Self::commit_stage`] does.
    pub fn fence_stage(&mut self, label: Option<&str>) -> Result<()> {
        if self.drained {
            return Ok(());
        }
        // Apply the label argument before committing so commit_labeled
        // (called below) propagates the latest value to the CB. Note
        // that the encoder.rs:1968 apply_labels writes to the active
        // compute encoder iff one is open — at this point one IS open
        // (we have not yet ended it), so the encoder picks up the label
        // before end_encoding fires. After end_encoding the CB still
        // has its label set (set on the CB itself, not the encoder).
        if let Some(l) = label {
            self.stage_label.clear();
            self.stage_label.push_str(l);
        }

        // Lazy-alloc the SharedEvent on first fence in this session.
        // metal::DeviceRef::new_shared_event lives at
        // /Users/robert/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metal-0.33.0/src/device.rs:2063.
        if self.event.is_none() {
            self.event = Some(self.device.new_shared_event());
        }

        // Sequence: end-active-encoder + encodeSignalEvent (CB-level) +
        // residency-flush + cmd_buf.commit, all inside the inner
        // helper. This preserves F1 (encoder is ended exactly once per
        // CB), F2 (residency-flush still fires at the commit boundary),
        // and matches llama.cpp's pattern at
        // `/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m:944-950`.
        let new_value = self.event_value + 1;
        let event_ref: &metal::SharedEventRef = self
            .event
            .as_ref()
            .expect("event allocated immediately above this borrow")
            .as_ref();
        // Deref-coerce SharedEventRef -> EventRef via the
        // ParentType = Event chain in metal-0.33.0/src/sync.rs:36-40.
        let label_opt: Option<&str> = if self.stage_label.is_empty() {
            None
        } else {
            Some(self.stage_label.as_str())
        };
        self.inner
            .fence_signal_and_commit(event_ref, new_value, label_opt);

        self.event_value = new_value;
        self.drained = true;
        self.fence_pending = true;
        Ok(())
    }

    /// Open a fresh command buffer on the same queue and (when a fence
    /// is pending) encode the matching wait on the new CB.
    ///
    /// This is the second half of the multi-stage chaining primitive.
    /// After [`Self::fence_stage`] (or [`Self::commit_stage`] /
    /// [`Self::commit_and_wait`]) has put the session in the `Drained`
    /// state, callers invoke this to start the next stage's CB. The
    /// session transitions back to `Encoding` with the fresh CB open;
    /// the compute encoder stays closed until the next dispatch
    /// lazy-opens it.
    ///
    /// # Wait-event encoding
    ///
    /// If [`Self::fence_stage`] was the most recent commit, this
    /// method encodes `encodeWaitForEvent:value:event_value` on the
    /// freshly-allocated CB before returning. The new CB's GPU work
    /// blocks until the prior CB's signal lands at the same value.
    /// After [`Self::commit_stage`] / [`Self::commit_and_wait`] (no
    /// fence), no wait is encoded — Metal's queue-FIFO sequencing is
    /// the implicit ordering primitive.
    ///
    /// # State machine
    ///
    /// | Before | After |
    /// |---|---|
    /// | Drained (no fence) | Encoding (new CB, no wait) |
    /// | Fenced (fence pending) | Encoding (new CB, wait encoded) |
    /// | Encoding (not drained) | no-op (returns Ok) |
    ///
    /// The not-drained case is intentionally a no-op rather than a
    /// panic: it keeps the session drop-safe under unusual call
    /// sequences (e.g. test scaffolding that calls reset speculatively).
    ///
    /// # Counter semantics
    ///
    /// Bumps `CMD_BUF_COUNT` by exactly 1 (the new CB). Does NOT bump
    /// `SYNC_COUNT` (no commit/wait happens here).
    ///
    /// # Errors
    ///
    /// Returns `Ok(())` unconditionally. Future error paths (e.g.
    /// queue-side allocation failure on `new_command_buffer`) would
    /// surface here.
    pub fn reset_for_next_stage(&mut self) -> Result<()> {
        if !self.drained {
            return Ok(());
        }

        // Snapshot the wait-event metadata BEFORE rotating cmd_buf so
        // we encode the wait on the NEW CB.
        let wait_metadata = if self.fence_pending {
            self.event
                .as_ref()
                .map(|ev| (ev.clone(), self.event_value))
        } else {
            None
        };

        self.inner.reset_command_buffer();

        if let Some((event, value)) = wait_metadata {
            // Deref-coerce SharedEventRef → EventRef via the
            // ParentType = Event chain in metal-0.33.0/src/sync.rs:36-40.
            let event_ref: &metal::EventRef = event.as_ref();
            self.inner.encode_wait_for_event(event_ref, value);
        }

        self.drained = false;
        self.fence_pending = false;
        self.stage_label.clear();
        Ok(())
    }

    /// Add a buffer to the device-level residency set.
    ///
    /// Delegates to the inner encoder's [`ResidencySet::add_allocation`]
    /// (the same Arc clone the device, the encoder, and every other
    /// concurrent encoder shares — single-set invariant per ADR-019:467).
    /// The actual `[set commit]` is deferred until the next
    /// `commit_stage` / `commit_and_wait` / `fence_stage`, which all
    /// route through `flush_residency_pending`.
    ///
    /// Returns `false` and is a no-op when the device booted without
    /// a residency set (HF2Q_NO_RESIDENCY=1, macOS<15, or
    /// `MlxError::DeviceNotFound` test paths).
    ///
    /// # Use case
    ///
    /// Caller holds an [`MlxBuffer`] not previously registered (e.g.
    /// from a pool, slice_view, or external interop) and wants the GPU
    /// pages hinted as resident before the stage's first dispatch.
    /// `MlxDevice::alloc_buffer` already auto-registers — this method
    /// is the explicit hook for the residual cases.
    pub fn add_to_residency_set(&self, buffer: &MlxBuffer) -> bool {
        match self.inner.residency_set() {
            Some(set) => {
                set.add_allocation(buffer.metal_buffer());
                true
            }
            None => false,
        }
    }

    /// Remove a buffer from the device-level residency set.
    ///
    /// Mirror of [`Self::add_to_residency_set`]. Stages a deferred
    /// `removeAllocation:` that flushes at the next commit boundary.
    /// Returns `false` and no-ops when no residency set is active.
    ///
    /// # F2 caveat
    ///
    /// Removing a buffer that the in-flight CB still references is the
    /// iter58b residency-rescission class. Under retained-refs
    /// (default), the CB's ARC retain keeps the underlying Metal page
    /// alive; the residency-set demotion only affects the resident-hint
    /// (a perf knob, not a safety knob). Under `MLX_UNRETAINED_REFS=1`
    /// (NOT enabled in Phase 0b), the caller-owned arena contract is
    /// the only structural mitigation.
    pub fn remove_from_residency_set(&self, buffer: &MlxBuffer) -> bool {
        match self.inner.residency_set() {
            Some(set) => {
                set.remove_allocation(buffer.metal_buffer());
                true
            }
            None => false,
        }
    }

    /// Whether the session has been committed (any commit path).
    ///
    /// Test-and-introspection helper. Production code should use the
    /// explicit `reset_for_next_stage` cycle to chain stages rather
    /// than polling this field.
    #[inline]
    pub fn is_drained(&self) -> bool {
        self.drained
    }

    /// Whether a fence is pending (most recent commit was `fence_stage`).
    ///
    /// Test-and-introspection helper for verifying the multi-stage
    /// state machine. Cleared by the next `reset_for_next_stage`.
    /// `commit_stage` and `commit_and_wait` also clear it, but only
    /// when they run on a non-drained session; on a drained session
    /// they return early and leave it untouched.
    #[inline]
    pub fn is_fence_pending(&self) -> bool {
        self.fence_pending
    }

    /// The current monotonic fence value.
    ///
    /// Returns 0 before the first `fence_stage`; otherwise returns the
    /// most recently signaled value. Mirrors the semantics of
    /// `ggml_metal_event::value` — a fence at value `N` means signal
    /// `N` is in flight (or completed) and any subsequent waiters at
    /// `N` will be unblocked.
    #[inline]
    pub fn fence_value(&self) -> u64 {
        self.event_value
    }

    /// Whether a [`MTLSharedEvent`](metal::SharedEvent) has been allocated
    /// in this session.
    ///
    /// Returns `false` until the first `fence_stage`; `true` afterwards.
    /// Test helper for verifying lazy-allocation behavior.
    #[inline]
    pub fn has_event(&self) -> bool {
        self.event.is_some()
    }

    /// Borrow the underlying Metal command buffer.
    ///
    /// Mirrors [`CommandEncoder::metal_command_buffer`]. Used by
    /// label-propagation tests and by callers that need to call
    /// `wait_until_completed` after a non-blocking `commit_stage` /
    /// `fence_stage`.
    #[inline]
    pub fn metal_command_buffer(&self) -> &metal::CommandBuffer {
        self.inner.metal_command_buffer()
    }
}
682
683impl Drop for EncoderSession {
    /// Drain the inner [`CommandEncoder`] safely on drop.
    ///
    /// # F2 residency-rescission preservation (load-bearing)
    ///
    /// Drop scenarios across the multi-stage state machine:
    ///
    /// 1. **Drained (no fence)** — `commit_stage` / `commit_and_wait`
    ///    already ran. `inner.flush_residency_pending()` was already
    ///    called; the GPU has the CB (and may already have completed
    ///    it under `commit_and_wait`). `CommandEncoder::Drop` runs and
    ///    calls `end_active_encoder()`, which is a no-op because
    ///    `commit*` already ended the encoder. Safe.
    ///
    /// 2. **Fenced (fence pending)** — `fence_stage` already ran. The
    ///    signal-event has been encoded onto the prior CB and the CB
    ///    has been submitted non-blocking. The session never opened a
    ///    new CB (no `reset_for_next_stage` call), so `cmd_buf` still
    ///    points at the FENCED CB. `CommandEncoder::Drop` runs and
    ///    `end_active_encoder()` is a no-op (the encoder was ended
    ///    inside `fence_signal_and_commit`). The submitted CB executes
    ///    on the GPU normally: the signal lands, the value is
    ///    observable to any external `waitUntilSignaledValue:` consumer
    ///    (none in iter89e2-B), and the next allocation/CB on the same
    ///    residency set will see the bumped pending flag flushed at its
    ///    commit boundary. The fence event itself is dropped with
    ///    `event` (an `Option<SharedEvent>`); ARC drop releases it.
    ///
    /// 3. **Encoding (uncommitted)** — caller created the session,
    ///    optionally encoded dispatches, then dropped without calling
    ///    any `commit_*`. `CommandEncoder::Drop` ends the active
    ///    compute encoder cleanly (`encoder.rs:2057-2063`). The
    ///    `cmd_buf` is dropped without ever being committed, so Metal
    ///    discards the encoded work. **No residency-remove is staged**
    ///    because no buffers were registered as freed during this
    ///    session (the F2 race requires a buffer drop staging a remove
    ///    that a later `flush_pending` commits before the in-flight CB
    ///    finishes; here no commit ever happens). The residency set's
    ///    pending state persists into the next encoder, which is the
    ///    correct behavior.
    ///
    /// 4. **Empty** — no dispatches encoded. `active_encoder` is null;
    ///    `CommandEncoder::Drop`'s `end_active_encoder()` is a no-op.
    ///    Safe.
    ///
    /// We deliberately do NOT call `wait_until_completed` here for the
    /// committed-but-not-waited cases (scenario 1 with `commit_stage`
    /// or scenario 2 with `fence_stage`). Under retained-refs mode
    /// (default, `MLX_UNRETAINED_REFS=0`), the in-flight CB holds ARC
    /// retains on every bound buffer, so the GPU completes safely after
    /// the session drops. Under `MLX_UNRETAINED_REFS=1` (NOT enabled in
    /// Phase 0b), the caller-owned arena contract is the only
    /// structural mitigation, the same as the existing async-`commit()`
    /// path at `encoder.rs:2014-2022`.
    ///
    /// In short: `Drop` does no extra work; the inner `CommandEncoder`'s
    /// own `Drop` is the entire safety story. `metal::SharedEvent` drops
    /// via its `foreign_obj_type!` ARC release.
    fn drop(&mut self) {
        // The actual end-encoder call lives in `CommandEncoder::Drop`,
        // which fires automatically when `self.inner` goes out of scope
        // here. The `event` field's ObjC release fires via the
        // `foreign_obj_type!` ARC release. No additional work is needed;
        // see the case analysis in the doc comment above for the F2
        // fence preservation argument.
    }
}
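
// Illustrative sketch only, NOT part of the crate's API: a Metal-free model of
// the flags that `is_drained`, `is_fence_pending`, `fence_value`, and
// `has_event` report, as described in the doc comments above.
//
// Assumptions (hypothetical, chosen for illustration): the name
// `FenceStateModel`, the method bodies, the `Option<u64>` return value of
// `reset_for_next_stage`, and the idea that `reset_for_next_stage` clears the
// drained flag. The real methods talk to Metal and may differ in detail.
#[allow(dead_code)]
#[derive(Debug, Default)]
struct FenceStateModel {
    drained: bool,
    fence_pending: bool,
    event_value: u64,
    has_event: bool,
}

#[allow(dead_code)]
impl FenceStateModel {
    /// Models `fence_stage`: allocate the event lazily, signal value N+1,
    /// and commit without blocking. Returns the signaled value.
    fn fence_stage(&mut self) -> u64 {
        self.has_event = true;
        self.event_value += 1;
        self.fence_pending = true;
        self.drained = true;
        self.event_value
    }

    /// Models `commit_stage` / `commit_and_wait`: commit with no fence,
    /// which clears any pending-fence flag.
    fn commit_stage(&mut self) {
        self.fence_pending = false;
        self.drained = true;
    }

    /// Models `reset_for_next_stage`: open a fresh CB. Returns the value
    /// the new CB would wait on when a fence was pending, else `None`.
    fn reset_for_next_stage(&mut self) -> Option<u64> {
        let wait_on = self.fence_pending.then_some(self.event_value);
        self.fence_pending = false;
        self.drained = false;
        wait_on
    }
}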