// mlx_native/kernel_profile.rs
//! Per-command-buffer GPU timing accumulator for kernel-level profiling.
//!
//! Hf2q's `HF2Q_DECODE_PROFILE=1` instrumentation tracks CPU-side wall
//! clock per layer phase, but does not attribute time to specific GPU
//! kernel dispatches. The MoE dwq46 0.93× decode parity gap residual
//! (per ADR-012 §Optimize / Task #15) cannot be localized further
//! without per-CB (or per-dispatch) GPU timing.
//!
//! This module exposes a thread-safe accumulator keyed by string label.
//! Each labeled `commit_and_wait` records the CB's GPU wall-clock
//! (`MTLCommandBuffer.GPUEndTime - GPUStartTime`). At decode end,
//! `dump()` produces a sorted breakdown showing which labeled CB
//! contributed the most GPU time per token.
//!
//! Per-dispatch timing (using `MTLCounterSampleBuffer.sampleCounters`)
//! is a separate Metal API surface deferred to a future ADR; this
//! module establishes the per-CB ground truth first.

use std::collections::HashMap;
use std::sync::Mutex;
use std::sync::OnceLock;

/// Per-label accumulator entry.
///
/// All `_ns` fields are GPU wall-clock nanoseconds. `min_ns`/`max_ns`
/// are only meaningful once `count > 0`: the `Default` value is 0, and
/// `record` seeds `min_ns` from the first sample (via its `count == 0`
/// check) rather than relying on the default.
#[derive(Clone, Debug, Default)]
pub struct ProfileEntry {
    /// Number of times this label was recorded.
    pub count: u64,
    /// Total GPU wall-clock time in nanoseconds.
    pub total_ns: u64,
    /// Minimum observed GPU time in nanoseconds.
    pub min_ns: u64,
    /// Maximum observed GPU time in nanoseconds.
    pub max_ns: u64,
}

36fn table() -> &'static Mutex<HashMap<String, ProfileEntry>> {
37 static T: OnceLock<Mutex<HashMap<String, ProfileEntry>>> = OnceLock::new();
38 T.get_or_init(|| Mutex::new(HashMap::new()))
39}
40
41/// Record a labeled GPU duration.
42///
43/// Called by `CommandEncoder::commit_and_wait_labeled` after reading
44/// `MTLCommandBuffer.GPUEndTime - GPUStartTime`. Lock contention is
45/// negligible — the encoder serializes calls anyway.
46pub fn record(label: &str, gpu_ns: u64) {
47 if let Ok(mut t) = table().lock() {
48 let e = t.entry(label.to_string()).or_default();
49 if e.count == 0 || gpu_ns < e.min_ns {
50 e.min_ns = gpu_ns;
51 }
52 if gpu_ns > e.max_ns {
53 e.max_ns = gpu_ns;
54 }
55 e.count = e.count.saturating_add(1);
56 e.total_ns = e.total_ns.saturating_add(gpu_ns);
57 }
58}
59
60/// Reset the profile table. Typically called at start of decode.
61pub fn reset() {
62 if let Ok(mut t) = table().lock() {
63 t.clear();
64 }
65}
66
67/// Dump the profile table sorted by descending total_ns.
68///
69/// Returns `Vec<(label, entry)>` sorted by total time.
70pub fn dump() -> Vec<(String, ProfileEntry)> {
71 let mut v: Vec<(String, ProfileEntry)> = if let Ok(t) = table().lock() {
72 t.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
73 } else {
74 Vec::new()
75 };
76 v.sort_by(|a, b| b.1.total_ns.cmp(&a.1.total_ns));
77 v
78}
79
80/// Whether per-CB profiling is enabled via `MLX_PROFILE_CB=1`.
81///
82/// Cached in an atomic so the hot path is a single load.
83pub fn is_enabled() -> bool {
84 use std::sync::atomic::{AtomicI8, Ordering};
85 static CACHED: AtomicI8 = AtomicI8::new(-1);
86 let v = CACHED.load(Ordering::Relaxed);
87 if v >= 0 {
88 return v == 1;
89 }
90 let on = std::env::var("MLX_PROFILE_CB").is_ok();
91 CACHED.store(if on { 1 } else { 0 }, Ordering::Relaxed);
92 on
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn record_dump_reset_cycle() {
        reset();
        // Two samples under "A", one under "B".
        record("A", 100);
        record("A", 200);
        record("B", 50);

        let breakdown = dump();
        // "A" (300 ns total) must sort ahead of "B" (50 ns total).
        assert_eq!(breakdown.len(), 2);
        let (ref label_a, ref entry_a) = breakdown[0];
        assert_eq!(label_a, "A");
        assert_eq!(entry_a.count, 2);
        assert_eq!(entry_a.total_ns, 300);
        assert_eq!(entry_a.min_ns, 100);
        assert_eq!(entry_a.max_ns, 200);
        assert_eq!(breakdown[1].0, "B");
        assert_eq!(breakdown[1].1.count, 1);

        // reset() empties the table.
        reset();
        assert!(dump().is_empty());
    }
}