use std::path::PathBuf;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use metal::{CaptureDescriptor, CaptureManager, Device, MTLCaptureDestination};
/// Decode-mode capture gate: set to `true` by [`decode_begin_token`] once the
/// configured decode token is reached, read by
/// [`GpuCaptureConfig::decode_start`] / [`GpuCaptureConfig::decode_stop`],
/// and reset to `false` in [`stop`].
static DECODE_ARMED: AtomicBool = AtomicBool::new(false);
/// Which part of inference a GPU capture should cover.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CaptureMode {
    /// Capture a window of consecutive layers within one prefill chunk.
    Prefill {
        /// Prefill chunk index the window belongs to.
        at_chunk: usize,
        /// Layer index at which `prefill_start` fires.
        at_layer: usize,
        /// Window length in layers; `prefill_stop` fires at layer
        /// `at_layer + n_layers` (saturating).
        n_layers: usize,
    },
    /// Capture one layer during a single decode step.
    Decode {
        /// Layer index at which `decode_start` fires; `decode_stop` fires at
        /// the following layer.
        at_layer: usize,
        /// Decode token index (zero-based count of `decode_begin_token`
        /// calls) on which capture is armed.
        at_token: usize,
    },
}

/// GPU capture settings, normally produced from the environment by `config()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuCaptureConfig {
    /// Destination handed to Metal as the capture output URL.
    pub path: PathBuf,
    /// Which window of inference to capture.
    pub mode: CaptureMode,
}
impl GpuCaptureConfig {
    /// Whether prefill capture should begin here: `true` only in prefill mode,
    /// for the configured chunk, at its first captured layer.
    pub fn prefill_start(&self, chunk_idx: usize, layer_idx: usize) -> bool {
        matches!(
            self.mode,
            CaptureMode::Prefill { at_chunk, at_layer, .. }
                if at_chunk == chunk_idx && at_layer == layer_idx
        )
    }

    /// Whether prefill capture should end here: `true` only in prefill mode,
    /// for the configured chunk, at layer `at_layer + n_layers` (saturating).
    pub fn prefill_stop(&self, chunk_idx: usize, layer_idx: usize) -> bool {
        matches!(
            self.mode,
            CaptureMode::Prefill { at_chunk, at_layer, n_layers }
                if at_chunk == chunk_idx && at_layer.saturating_add(n_layers) == layer_idx
        )
    }

    /// Whether decode capture should begin at `layer_idx`: requires decode
    /// mode, the configured layer, and the `DECODE_ARMED` flag to be set.
    pub fn decode_start(&self, layer_idx: usize) -> bool {
        let at_target = matches!(
            self.mode,
            CaptureMode::Decode { at_layer, .. } if at_layer == layer_idx
        );
        at_target && DECODE_ARMED.load(Ordering::Relaxed)
    }

    /// Whether decode capture should end at `layer_idx`: requires decode mode,
    /// the layer right after the configured one (saturating), and the
    /// `DECODE_ARMED` flag to be set.
    pub fn decode_stop(&self, layer_idx: usize) -> bool {
        let past_target = matches!(
            self.mode,
            CaptureMode::Decode { at_layer, .. } if at_layer.saturating_add(1) == layer_idx
        );
        past_target && DECODE_ARMED.load(Ordering::Relaxed)
    }
}
/// Count one decode token and arm decode-mode capture when the configured
/// token is reached.
///
/// Intended to be called once per decoded token. The call made after
/// `at_token` earlier calls (zero-based token `at_token`) sets the
/// `DECODE_ARMED` flag and returns `true`; every other call returns `false`.
/// When capture is not configured, or is configured in prefill mode, nothing
/// is ever armed and the result is always `false`.
pub fn decode_begin_token() -> bool {
    // `None` means "no decode capture configured": never arm.
    static COUNTDOWN: OnceLock<Option<AtomicUsize>> = OnceLock::new();
    let countdown = COUNTDOWN.get_or_init(|| match config().map(|cfg| &cfg.mode) {
        Some(CaptureMode::Decode { at_token, .. }) => Some(AtomicUsize::new(*at_token)),
        _ => None,
    });
    let Some(counter) = countdown else {
        return false;
    };
    // `fetch_sub` wraps on underflow, so the call that observes 0 leaves the
    // counter at `usize::MAX`; in practice the trigger therefore fires once.
    if counter.fetch_sub(1, Ordering::Relaxed) != 0 {
        return false;
    }
    DECODE_ARMED.store(true, Ordering::Relaxed);
    if let Some(CaptureMode::Decode { at_layer, at_token }) = config().map(|cfg| &cfg.mode) {
        eprintln!(
            "[gpu_capture] token {at_token} reached, \
             arming capture at layer {at_layer}",
        );
    }
    true
}
/// Lazily read the GPU-capture configuration from the environment.
///
/// The environment is read on the first call only; later calls return the
/// cached result. Returns `None` (capture disabled) when
/// `MOEFLUX_GPU_CAPTURE_PATH` is unset or empty, when `METAL_CAPTURE_ENABLED`
/// is not exactly `1`, or when `MOEFLUX_GPU_CAPTURE_MODE` is neither
/// `prefill` nor `decode`.
///
/// Variables read:
/// - `MOEFLUX_GPU_CAPTURE_MODE`: `prefill` (default) or `decode`.
/// - `MOEFLUX_GPU_CAPTURE_CHUNK`: prefill chunk, default 0.
/// - `MOEFLUX_GPU_CAPTURE_LAYER`: first captured layer (both modes), default 0.
/// - `MOEFLUX_GPU_CAPTURE_N_LAYERS`: prefill window length, default 2, at least 1.
/// - `MOEFLUX_GPU_CAPTURE_TOKEN`: decode token index, default 10.
///
/// An integer variable that is set but does not parse falls back to its
/// default, with a warning on stderr.
pub fn config() -> Option<&'static GpuCaptureConfig> {
    static CFG: OnceLock<Option<GpuCaptureConfig>> = OnceLock::new();
    CFG.get_or_init(read_env_config).as_ref()
}

/// Build the capture configuration from the environment (uncached).
fn read_env_config() -> Option<GpuCaptureConfig> {
    // An empty path is treated as "not set" rather than handed to Metal.
    let path = std::env::var("MOEFLUX_GPU_CAPTURE_PATH")
        .ok()
        .filter(|p| !p.is_empty())?;
    if std::env::var("METAL_CAPTURE_ENABLED").ok().as_deref() != Some("1") {
        eprintln!(
            "[gpu_capture] MOEFLUX_GPU_CAPTURE_PATH is set but \
             METAL_CAPTURE_ENABLED is not 1 — Metal will reject \
             startCapture. Set METAL_CAPTURE_ENABLED=1 to enable."
        );
        return None;
    }
    let mode_str =
        std::env::var("MOEFLUX_GPU_CAPTURE_MODE").unwrap_or_else(|_| "prefill".into());
    let mode = match mode_str.as_str() {
        "prefill" => CaptureMode::Prefill {
            at_chunk: env_usize("MOEFLUX_GPU_CAPTURE_CHUNK", 0),
            at_layer: env_usize("MOEFLUX_GPU_CAPTURE_LAYER", 0),
            // `.max(1)` keeps the stop layer (`at_layer + n_layers`) strictly
            // after the start layer.
            n_layers: env_usize("MOEFLUX_GPU_CAPTURE_N_LAYERS", 2).max(1),
        },
        "decode" => CaptureMode::Decode {
            at_layer: env_usize("MOEFLUX_GPU_CAPTURE_LAYER", 0),
            at_token: env_usize("MOEFLUX_GPU_CAPTURE_TOKEN", 10),
        },
        other => {
            eprintln!(
                "[gpu_capture] MOEFLUX_GPU_CAPTURE_MODE={other:?} \
                 unrecognised; valid: prefill | decode."
            );
            return None;
        }
    };
    Some(GpuCaptureConfig {
        path: PathBuf::from(path),
        mode,
    })
}

/// Read env var `name` as a `usize`, returning `default` when it is unset and
/// warning (then returning `default`) when it is set but not a valid integer.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Err(_) => default,
        Ok(raw) => raw.trim().parse().unwrap_or_else(|_| {
            eprintln!(
                "[gpu_capture] {name}={raw:?} is not a valid non-negative \
                 integer; using default {default}"
            );
            default
        }),
    }
}
/// Begin a Metal GPU capture of `device`, writing a GPU trace document to
/// `cfg.path`.
///
/// If a capture is already running this only logs and returns. Whatever
/// already exists at `cfg.path` (file or directory) is removed first. Errors
/// are reported on stderr rather than returned.
pub fn start(device: &Device, cfg: &GpuCaptureConfig) {
    let manager = CaptureManager::shared();
    if manager.is_capturing() {
        eprintln!("[gpu_capture] already capturing; skipping start");
        return;
    }
    // Clear a previous run's output first; presumably Metal will not write
    // over an existing trace at this path (NOTE(review): confirm).
    remove_previous_capture(&cfg.path);

    let desc = CaptureDescriptor::new();
    desc.set_capture_device(device);
    desc.set_destination(MTLCaptureDestination::GpuTraceDocument);
    desc.set_output_url(&cfg.path);
    match manager.start_capture(&desc) {
        Ok(()) => {
            let window = match &cfg.mode {
                CaptureMode::Prefill { at_chunk, at_layer, n_layers } => format!(
                    "prefill chunk={at_chunk} layers={at_layer}..{}",
                    // Same saturating arithmetic as `prefill_stop`, so the
                    // label can never overflow.
                    at_layer.saturating_add(*n_layers),
                ),
                CaptureMode::Decode { at_layer, .. } => {
                    format!("decode layer={at_layer}")
                }
            };
            eprintln!(
                "[gpu_capture] started → {} ({window})",
                cfg.path.display(),
            );
        }
        Err(e) => eprintln!("[gpu_capture] start_capture failed: {e}"),
    }
}

/// Best-effort removal of whatever exists at `path` (directory or file);
/// a failure is logged and otherwise ignored.
fn remove_previous_capture(path: &std::path::Path) {
    // `symlink_metadata` does not follow links, so a symlink at `path` is
    // removed itself rather than its target.
    let Ok(meta) = std::fs::symlink_metadata(path) else {
        return;
    };
    let removed = if meta.is_dir() {
        std::fs::remove_dir_all(path)
    } else {
        std::fs::remove_file(path)
    };
    if let Err(e) = removed {
        eprintln!(
            "[gpu_capture] could not remove existing {}: {e}",
            path.display()
        );
    }
}
/// Stop the in-progress Metal capture, if any, and disarm decode capture.
///
/// Safe to call when nothing is being captured. The decode arm flag is cleared
/// unconditionally: if `start` failed to begin a capture, leaving the flag set
/// would make `decode_start` fire again (and `start` be retried) on every
/// later decode token.
pub fn stop() {
    DECODE_ARMED.store(false, Ordering::Relaxed);
    let manager = CaptureManager::shared();
    if !manager.is_capturing() {
        return;
    }
    manager.stop_capture();
    eprintln!("[gpu_capture] stopped");
}