use std::sync::atomic::{AtomicBool, Ordering};
use metal::{CaptureDescriptor, CaptureManager, CaptureScope, MTLCaptureDestination};
use crate::MlxDevice;
/// One-shot latch: flipped to `true` by the first [`MetalCapture::from_env`]
/// call that sees a non-empty `MLX_METAL_CAPTURE`, so that at most one capture
/// attempt is made per process. Tests re-arm it via
/// [`reset_capture_consumed_for_test`].
static CAPTURE_CONSUMED: AtomicBool = AtomicBool::new(false);
/// Programmatic Metal GPU capture controlled by the `MLX_METAL_CAPTURE`
/// environment variable.
///
/// Obtain one with [`MetalCapture::from_env`], then bracket the GPU work to
/// record with [`MetalCapture::begin`] and [`MetalCapture::end`]. Dropping the
/// value runs `end()`.
pub struct MetalCapture {
    /// Capture scope created on the device's command queue and registered as
    /// the capture object of the descriptor handed to `start_capture`.
    scope: CaptureScope,
    /// Whether `begin_scope` has been called without a matching `end_scope`.
    started: bool,
    /// Whether `start_capture` succeeded and `stop_capture` has not been
    /// called yet. Tracked separately from `started` so that a capture that
    /// was started but whose scope was never begun is still stopped by
    /// `end()` (and therefore by `Drop`).
    capturing: bool,
    /// Trace output path exactly as read from `MLX_METAL_CAPTURE`; after
    /// construction it is only used in log messages.
    output_path: String,
}

impl MetalCapture {
    /// Starts a `GpuTraceDocument` capture when `MLX_METAL_CAPTURE` names an
    /// output path.
    ///
    /// Returns `None`, without starting anything, when:
    /// - the variable is unset, empty, or not valid Unicode;
    /// - an earlier call in this process already claimed the one-shot
    ///   capture slot;
    /// - the `GpuTraceDocument` destination is unsupported (a message is
    ///   printed to stderr);
    /// - `MTLCaptureManager` refuses to start the capture (a message is
    ///   printed to stderr).
    ///
    /// The one-shot slot is claimed *before* the support and start checks, so
    /// a failed attempt is not retried by later calls.
    pub fn from_env(device: &MlxDevice) -> Option<Self> {
        let path = match std::env::var("MLX_METAL_CAPTURE") {
            Ok(s) if !s.is_empty() => s,
            _ => return None,
        };
        // Only the first caller with a non-empty path proceeds; everyone else
        // (including every caller after a failed attempt) gets `None`.
        if CAPTURE_CONSUMED
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            return None;
        }
        let mgr = CaptureManager::shared();
        if !mgr.supports_destination(MTLCaptureDestination::GpuTraceDocument) {
            eprintln!(
                "[mlx-native] MLX_METAL_CAPTURE={} ignored: \
                 GpuTraceDocument destination unsupported on this device",
                path
            );
            return None;
        }
        let scope = mgr.new_capture_scope_with_command_queue(device.metal_queue());
        let descriptor = CaptureDescriptor::new();
        descriptor.set_capture_scope(&scope);
        descriptor.set_destination(MTLCaptureDestination::GpuTraceDocument);
        descriptor.set_output_url(&path);
        match mgr.start_capture(&descriptor) {
            Ok(()) => {
                eprintln!(
                    "[mlx-native] MTLCaptureManager: starting capture to {}",
                    path
                );
                Some(Self {
                    scope,
                    started: false,
                    // The capture is now active even though the scope has not
                    // been begun; `end()` must stop it either way.
                    capturing: true,
                    output_path: path,
                })
            }
            Err(e) => {
                eprintln!(
                    "[mlx-native] MLX_METAL_CAPTURE={} capture start failed: {} \
                     (set METAL_CAPTURE_ENABLED=1?)",
                    path, e
                );
                None
            }
        }
    }

    /// Begins the capture scope. This is a no-op if the scope is already
    /// begun.
    pub fn begin(&mut self) {
        if self.started {
            return;
        }
        self.scope.begin_scope();
        self.started = true;
    }

    /// Ends the capture scope (if it was begun) and stops the capture (if it
    /// is still active), logging the trace path to stderr when stopping.
    ///
    /// Idempotent: calls after the first one do nothing. Stopping does not
    /// depend on `begin()` having been called, so a capture that was started
    /// by `from_env` but never begun is not left running.
    pub fn end(&mut self) {
        if self.started {
            self.scope.end_scope();
            self.started = false;
        }
        if self.capturing {
            CaptureManager::shared().stop_capture();
            self.capturing = false;
            eprintln!(
                "[mlx-native] MTLCaptureManager: stopped (trace at {})",
                self.output_path
            );
        }
    }
}
/// Runs the same teardown as an explicit [`MetalCapture::end`] call when the
/// value goes out of scope.
impl Drop for MetalCapture {
    fn drop(&mut self) {
        Self::end(self);
    }
}
#[doc(hidden)]
pub fn reset_capture_consumed_for_test() {
CAPTURE_CONSUMED.store(false, Ordering::SeqCst);
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module that mutate `MLX_METAL_CAPTURE`
    /// and `CAPTURE_CONSUMED`; `cargo test` otherwise runs them on parallel
    /// threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering from poisoning so that one failing
    /// test does not cascade into spurious failures of the others.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn from_env_returns_none_when_unset() {
        let _env = env_guard();
        // SAFETY: `ENV_LOCK` serializes environment mutation among the tests
        // of this module. NOTE(review): threads outside this module (other
        // tests, or native code reading the environment) are not covered by
        // the lock; confirm no concurrent env access or run single-threaded.
        unsafe { std::env::remove_var("MLX_METAL_CAPTURE") };
        reset_capture_consumed_for_test();
        let device = MlxDevice::new().expect("MlxDevice::new");
        assert!(
            MetalCapture::from_env(&device).is_none(),
            "MLX_METAL_CAPTURE unset → from_env must return None"
        );
    }

    #[test]
    fn from_env_returns_none_on_empty_string() {
        let _env = env_guard();
        reset_capture_consumed_for_test();
        let device = MlxDevice::new().expect("device");
        // SAFETY: see `from_env_returns_none_when_unset`; `ENV_LOCK` is held.
        unsafe { std::env::set_var("MLX_METAL_CAPTURE", "") };
        let capture = MetalCapture::from_env(&device);
        // Restore the environment before asserting so a failure cannot leak
        // the empty value into later tests.
        // SAFETY: `ENV_LOCK` is still held.
        unsafe { std::env::remove_var("MLX_METAL_CAPTURE") };
        assert!(
            capture.is_none(),
            "MLX_METAL_CAPTURE=\"\" → from_env must return None"
        );
    }
}