Skip to main content

mlx_native/
metal_capture.rs

1//! Programmatic Metal Frame Capture wrapping (ADR-015 iter63 Part B).
2//!
3//! Mirrors llama.cpp's `GGML_METAL_CAPTURE_COMPUTE` env-driven capture
4//! pattern (`/opt/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m`
5//! lines 161-170 + 488-608).  Triggered by:
6//!
7//! ```bash
8//! METAL_CAPTURE_ENABLED=1 \
9//! MLX_METAL_CAPTURE=/path/to/output.gputrace \
10//!     cargo run --release --bin hf2q -- generate ...
11//! ```
12//!
13//! The resulting `.gputrace` document opens in Xcode → Performance →
14//! GPU for full GPU-timeline analysis (idle gaps, scheduling stalls,
15//! memory pressure visualization).
16//!
17//! ## One-shot semantics
18//!
//! Capture is one-shot per process: the FIRST [`MetalCapture::from_env`]
//! call that sees a non-empty `MLX_METAL_CAPTURE` consumes a
//! process-global `AtomicBool` latch and, if capture setup succeeds,
//! returns `Some(MetalCapture)`; every subsequent call returns `None`.
//! The latch is consumed even when setup fails, so a failed first
//! attempt is not retried within the same process.  This mirrors
//! llama.cpp's `capture_compute = 1; capture_compute--` countdown
//! semantics — the very first decode/prefill forward pass is captured,
//! and the module goes quiet thereafter.  Re-running capture requires a
//! fresh process.
27//!
28//! ## Permissions / entitlements
29//!
30//! From metal-rs 0.33 capturemanager.rs:75-79:
31//! > *Capture can be enabled by either:*
32//! > *1. Running from Xcode*
33//! > *2. Setting the environment variable `METAL_CAPTURE_ENABLED=1`*
34//! > *3. Adding an info.plist file containing the `MetalCaptureEnabled`
35//! >    key set to `YES`*
36//!
37//! For Cargo-built binaries (no Xcode), users MUST `export
38//! METAL_CAPTURE_ENABLED=1` alongside `MLX_METAL_CAPTURE=...`.
39//! Writing a `.gputrace` to disk requires no special entitlements
40//! (only `GpuTraceDocument`; the `DeveloperTools` destination would
41//! need Xcode running).
42//!
43//! [`MetalCapture::from_env`] is defensive: it checks
44//! `mgr.supports_destination(GpuTraceDocument)` first and inspects the
45//! `Result<(), String>` from `start_capture` (no panic).  On any
46//! failure path the function returns `None` after a one-shot stderr
47//! warning so callers can continue without capture.
48
49use std::sync::atomic::{AtomicBool, Ordering};
50
51use metal::{CaptureDescriptor, CaptureManager, CaptureScope, MTLCaptureDestination};
52
53use crate::MlxDevice;
54
/// Process-global one-shot latch.  `MetalCapture::from_env` atomically
/// flips it from `false` to `true` the first time it is called with a
/// non-empty `MLX_METAL_CAPTURE`; once set, every later `from_env` call
/// returns `None`, so a single bench run captures at most the first
/// forward pass (mirrors llama.cpp's countdown semantics).  The flip
/// happens before capture setup is attempted, so a failed setup also
/// consumes the latch.
static CAPTURE_CONSUMED: AtomicBool = AtomicBool::new(false);
60
/// A live programmatic capture session backed by an `MTLCaptureScope`.
///
/// Construct via [`MetalCapture::from_env`], which has already asked the
/// shared `CaptureManager` to start capturing by the time it returns.
/// Pair `begin` / `end` calls around the unit of work to capture (e.g. a
/// single forward pass).  `Drop` runs `end` defensively in case the
/// caller forgets, so a scope that is still open when the value is
/// dropped (including during a panic unwind) is closed and the capture
/// stopped.
pub struct MetalCapture {
    /// Capture scope created in `from_env` from the device's command
    /// queue; opened by `begin` and closed by `end`.
    scope: CaptureScope,
    /// Whether the scope is currently open (`begin_scope` called
    /// without a matching `end_scope`).  Initialized to `false` by
    /// `from_env`; set `true` by [`Self::begin`]; cleared by
    /// [`Self::end`] (which `Drop` also calls).
    started: bool,
    /// The trace output path exactly as read from `MLX_METAL_CAPTURE`.
    /// Passed to the capture descriptor in `from_env` and kept for the
    /// stderr report printed by `end()`.
    output_path: String,
}
77
78impl MetalCapture {
79    /// Initialize a capture from the env vars `MLX_METAL_CAPTURE` (output
80    /// path) and `METAL_CAPTURE_ENABLED` (Apple's framework-level
81    /// permission gate).
82    ///
83    /// Returns `Some(MetalCapture)` only when ALL of:
84    /// 1. `MLX_METAL_CAPTURE` is set to a non-empty path,
85    /// 2. The capture manager supports
86    ///    `MTLCaptureDestination::GpuTraceDocument`,
87    /// 3. `start_capture` succeeds (which requires
88    ///    `METAL_CAPTURE_ENABLED=1` or running under Xcode), and
89    /// 4. The process-global one-shot latch [`CAPTURE_CONSUMED`] has
90    ///    not yet flipped to `true`.
91    ///
92    /// On any failure path returns `None` after a one-shot stderr
93    /// warning describing the cause.  The caller is expected to
94    /// ignore the `None` and proceed without capture — never panic.
95    pub fn from_env(device: &MlxDevice) -> Option<Self> {
96        // 1. Env-var read.  Empty string treated as unset.
97        let path = match std::env::var("MLX_METAL_CAPTURE") {
98            Ok(s) if !s.is_empty() => s,
99            _ => return None,
100        };
101        // 4. One-shot latch (checked early so the warning fires only
102        // for the first consumer of the env, not every forward pass).
103        if CAPTURE_CONSUMED
104            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
105            .is_err()
106        {
107            return None;
108        }
109        // 2. Destination support check.
110        let mgr = CaptureManager::shared();
111        if !mgr.supports_destination(MTLCaptureDestination::GpuTraceDocument) {
112            eprintln!(
113                "[mlx-native] MLX_METAL_CAPTURE={} ignored: \
114                 GpuTraceDocument destination unsupported on this device",
115                path
116            );
117            return None;
118        }
119        // Build descriptor + scope.  We scope to the device's command
120        // queue so any CB enqueued through the existing
121        // CommandEncoder (which uses device.metal_queue() under the
122        // hood) is captured.
123        let scope = mgr.new_capture_scope_with_command_queue(device.metal_queue());
124        let descriptor = CaptureDescriptor::new();
125        descriptor.set_capture_scope(&scope);
126        descriptor.set_destination(MTLCaptureDestination::GpuTraceDocument);
127        descriptor.set_output_url(&path);
128        // 3. Start capture (returns Result<(), String>).  Failure
129        // typically means METAL_CAPTURE_ENABLED is unset.
130        match mgr.start_capture(&descriptor) {
131            Ok(()) => {
132                eprintln!(
133                    "[mlx-native] MTLCaptureManager: starting capture to {}",
134                    path
135                );
136                Some(Self {
137                    scope,
138                    started: false,
139                    output_path: path,
140                })
141            }
142            Err(e) => {
143                eprintln!(
144                    "[mlx-native] MLX_METAL_CAPTURE={} capture start failed: {} \
145                     (set METAL_CAPTURE_ENABLED=1?)",
146                    path, e
147                );
148                None
149            }
150        }
151    }
152
153    /// Open the capture scope.  Idempotent: a second `begin` without
154    /// an intervening `end` is a no-op so callers can wrap nested
155    /// forward passes safely.
156    ///
157    /// Call at the start of the unit of work to capture (e.g.
158    /// `GraphExecutor::begin` — see graph.rs wire-up).
159    pub fn begin(&mut self) {
160        if self.started {
161            return;
162        }
163        self.scope.begin_scope();
164        self.started = true;
165    }
166
167    /// Close the capture scope and stop the underlying
168    /// `MTLCaptureManager`.  After `end` returns the `.gputrace` file
169    /// is finalized and openable in Xcode.
170    ///
171    /// Idempotent.  Calling `end` without a preceding `begin` is a
172    /// no-op (defensive — ensures `Drop` after a panic between
173    /// construction and `begin` doesn't `endScope` an unopened scope).
174    pub fn end(&mut self) {
175        if !self.started {
176            return;
177        }
178        self.scope.end_scope();
179        CaptureManager::shared().stop_capture();
180        self.started = false;
181        eprintln!(
182            "[mlx-native] MTLCaptureManager: stopped (trace at {})",
183            self.output_path
184        );
185    }
186}
187
impl Drop for MetalCapture {
    /// Finalizes the trace if the caller never called `end` explicitly.
    fn drop(&mut self) {
        // Defensive flush: `end` only acts while the scope is open, so
        // this closes a forgotten scope, is a no-op after an explicit
        // `end()`, and is also a no-op if `begin` was never called.
        self.end();
    }
}
196
/// Test-only reset of the one-shot latch: stores `false` back into
/// `CAPTURE_CONSUMED` so the next `MetalCapture::from_env` call can run
/// its full path again.
///
/// The function is `pub` and only hidden from rustdoc via
/// `#[doc(hidden)]`; it is still compiled into non-test builds.
/// Production callers must NOT flip the latch manually — the
/// once-per-process semantics are part of the contract.  Cargo runs
/// each test binary in a fresh process, but unit tests inside the same
/// binary need this hook to exercise the `from_env` path more than
/// once.
#[doc(hidden)]
pub fn reset_capture_consumed_for_test() {
    CAPTURE_CONSUMED.store(false, Ordering::SeqCst);
}
207
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// With the variable absent, `from_env` must bail out at its very
    /// first step and yield no capture.
    #[test]
    fn from_env_returns_none_when_unset() {
        // SAFETY(review): mutating the process environment is only sound
        // while no other thread reads or writes it concurrently.  The
        // default parallel test runner does not guarantee that — confirm
        // or serialize the tests in this module.
        unsafe { std::env::remove_var("MLX_METAL_CAPTURE") };
        // Start from a fresh latch so only the env check decides the
        // outcome.
        reset_capture_consumed_for_test();

        let device = MlxDevice::new().expect("MlxDevice::new");
        let outcome = MetalCapture::from_env(&device);
        assert!(
            outcome.is_none(),
            "MLX_METAL_CAPTURE unset → from_env must return None"
        );
    }

    /// An empty value is treated exactly like an unset variable.
    #[test]
    fn from_env_returns_none_on_empty_string() {
        // SAFETY(review): same caveat as above — environment mutation
        // must not race with other threads touching the environment.
        unsafe { std::env::set_var("MLX_METAL_CAPTURE", "") };
        reset_capture_consumed_for_test();

        let device = MlxDevice::new().expect("device");
        let outcome = MetalCapture::from_env(&device);
        assert!(
            outcome.is_none(),
            "MLX_METAL_CAPTURE=\"\" → from_env must return None"
        );

        // Leave the environment as this test found it for later tests.
        // SAFETY(review): see the caveat on the `set_var` call above.
        unsafe { std::env::remove_var("MLX_METAL_CAPTURE") };
    }
}