mlx_native/kernel_profile.rs

//! Per-command-buffer + per-dispatch GPU timing accumulator for kernel-level
//! profiling.
//!
//! Hf2q's `HF2Q_DECODE_PROFILE=1` instrumentation tracks CPU-side wall
//! clock per layer phase, but does not attribute time to specific GPU
//! kernel dispatches.  The MoE dwq46 0.93× decode parity gap residual
//! (per ADR-012 §Optimize / Task #15) cannot be localized further
//! without per-CB (or per-dispatch) GPU timing.
//!
//! This module exposes two thread-safe accumulators:
//!
//! * **Per-CB** (`MLX_PROFILE_CB=1`) — a HashMap keyed by string label.
//!   Each labeled `commit_and_wait` records the cb's GPU wall-clock
//!   (`MTLCommandBuffer.GPUEndTime - GPUStartTime`).
//! * **Per-dispatch** (`MLX_PROFILE_DISPATCH=1`, ADR-015 iter63) — a flat
//!   `Vec<DispatchEntry>` populated from encoder `sampleCounters` calls
//!   placed between `set_compute_pipeline_state` and `dispatch_threads`
//!   at every `encode*` site, with the samples resolved out of an
//!   `MTLCounterSampleBuffer`.  The dump groups entries by their owning
//!   `cb_label`, preserving insertion order within each group.
//!
//! At decode end, [`dump`] / [`dump_dispatches`] produce sorted
//! breakdowns showing which labeled cb (and which kernel within each cb)
//! contributed the most GPU time per token.
//!
//! ### Cross-validation (ADR-015 iter63 Risk R3)
//!
//! Per-dispatch numbers are **upper-bound serialized cost** — the
//! `withBarrier:YES` requirement on `sampleCountersInBuffer` serializes
//! the encoder under `MTLDispatchTypeConcurrent`.  The summed
//! per-dispatch time within a CB will therefore be ≥ the matching
//! `MLX_PROFILE_CB` total for that label.  Acceptable drift: ≤ 5%;
//! anything above 10% indicates a clock-domain or sampling bug.
//!
//! ### Apple Silicon caveat (new risk discovered during iter63 impl)
//!
//! Verified runtime: `AGXG17XFamilyComputeContext` (M-series, macOS 26)
//! supports counter sampling **only** at
//! `MTLCounterSamplingPoint::AtStageBoundary`, never
//! `AtDispatchBoundary`.  The latter is required for sampling between
//! dispatches inside a persistent compute encoder (which mlx-native
//! uses to amortize ~800 encoder create/end cycles per forward pass).
//! On such hardware, `MLX_PROFILE_DISPATCH=1` gracefully degrades to a
//! no-op plus a one-shot stderr warning; only the per-CB path
//! (`MLX_PROFILE_CB=1`) populates.  The instrumentation is
//! forward-compatible with AMD/Intel discrete GPUs and any future
//! Apple silicon that reports `AtDispatchBoundary` support.
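//!
//! ### Usage sketch
//!
//! A minimal end-to-end flow; the crate path and the decode call are
//! assumptions for illustration, not part of this module:
//!
//! ```ignore
//! use mlx_native::kernel_profile;   // crate path assumed
//!
//! kernel_profile::reset();          // clear tables + clock-pair cache
//! // ... run the decode loop with MLX_PROFILE_CB=1 in the environment ...
//! for (label, e) in kernel_profile::dump() {
//!     eprintln!("{label}: {} cbs, {} ns total", e.count, e.total_ns);
//! }
//! ```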

use std::collections::HashMap;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicI8, AtomicU64, Ordering};

/// Per-label accumulator entry.
#[derive(Clone, Debug, Default)]
pub struct ProfileEntry {
    /// Number of times this label was recorded.
    pub count: u64,
    /// Total GPU wall-clock time in nanoseconds.
    pub total_ns: u64,
    /// Minimum observed GPU time in nanoseconds.
    pub min_ns: u64,
    /// Maximum observed GPU time in nanoseconds.
    pub max_ns: u64,
}
/// One per-dispatch timing entry within a CB (ADR-015 iter63 Phase A).
///
/// Populated via encoder `sampleCounters` calls inserted between
/// pipeline binding and the actual `dispatch_threads` /
/// `dispatch_thread_groups` call inside every
/// [`crate::CommandEncoder::encode*`] method.  The raw GPU-tick samples
/// are resolved into ns before [`record_dispatch`] is called, using the
/// (cpu, gpu) pair captured at the most recent [`reset`] /
/// [`dump_dispatches`] boundary (see [`convert_gpu_ticks_to_ns`]).
#[derive(Clone, Debug)]
pub struct DispatchEntry {
    /// The cb_label that owned this dispatch (mirrors the per-CB table key).
    pub cb_label: String,
    /// Captured op kind ("RmsNorm", "Sdpa", "ElemMul", "ElemAdd", "Softmax",
    /// "Other").  See [`crate::CapturedOpKind::name`].
    pub op_kind: &'static str,
    /// 0-based ordinal within the CB.
    pub dispatch_index: u32,
    /// `end_gpu_ns - start_gpu_ns`, i.e. the wall-clock of this single
    /// dispatch on the GPU.  Already converted from raw GPU ticks.
    pub gpu_ns: u64,
    /// Start timestamp (ns since CPU epoch, after tick→ns conversion).
    pub start_gpu_ns: u64,
    /// End timestamp (ns since CPU epoch, after tick→ns conversion).
    pub end_gpu_ns: u64,
}

fn table() -> &'static Mutex<HashMap<String, ProfileEntry>> {
    static T: OnceLock<Mutex<HashMap<String, ProfileEntry>>> = OnceLock::new();
    T.get_or_init(|| Mutex::new(HashMap::new()))
}

fn dispatch_table() -> &'static Mutex<Vec<DispatchEntry>> {
    static T: OnceLock<Mutex<Vec<DispatchEntry>>> = OnceLock::new();
    T.get_or_init(|| Mutex::new(Vec::new()))
}

/// Record a labeled GPU duration.
///
/// Called by `CommandEncoder::commit_and_wait_labeled` after reading
/// `MTLCommandBuffer.GPUEndTime - GPUStartTime`.  Lock contention is
/// negligible — the encoder serializes calls anyway.
pub fn record(label: &str, gpu_ns: u64) {
    if let Ok(mut t) = table().lock() {
        let e = t.entry(label.to_string()).or_default();
        if e.count == 0 || gpu_ns < e.min_ns {
            e.min_ns = gpu_ns;
        }
        if gpu_ns > e.max_ns {
            e.max_ns = gpu_ns;
        }
        e.count = e.count.saturating_add(1);
        e.total_ns = e.total_ns.saturating_add(gpu_ns);
    }
}

/// Append a per-dispatch entry to the global dispatch table.
///
/// Called from
/// [`crate::CommandEncoder::resolve_dispatch_samples`] inside
/// `commit_and_wait_labeled` when `MLX_PROFILE_DISPATCH=1` is set.  The
/// caller has already converted raw GPU ticks to ns using the cached
/// scale factor (see [`record_clock_pair`]).
pub fn record_dispatch(entry: DispatchEntry) {
    if let Ok(mut t) = dispatch_table().lock() {
        t.push(entry);
    }
}

/// Reset the profile tables.  Typically called at the start of decode.
///
/// Clears both the per-CB table and the per-dispatch entries, and resets
/// the GPU↔CPU clock-pair cache so the next dump reads a fresh baseline.
pub fn reset() {
    if let Ok(mut t) = table().lock() {
        t.clear();
    }
    if let Ok(mut t) = dispatch_table().lock() {
        t.clear();
    }
    CLOCK_CPU_NS.store(0, Ordering::Relaxed);
    CLOCK_GPU_TICKS.store(0, Ordering::Relaxed);
}

/// Dump the per-CB profile table sorted by descending total_ns.
///
/// Returns `Vec<(label, entry)>` sorted by total time.
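///
/// A rendering sketch (crate path assumed for illustration):
///
/// ```ignore
/// for (label, e) in mlx_native::kernel_profile::dump() {
///     let avg_us = e.total_ns as f64 / e.count.max(1) as f64 / 1_000.0;
///     eprintln!("{label}: {}x, avg {avg_us:.1} us (min {} / max {} ns)",
///               e.count, e.min_ns, e.max_ns);
/// }
/// ```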
pub fn dump() -> Vec<(String, ProfileEntry)> {
    let mut v: Vec<(String, ProfileEntry)> = if let Ok(t) = table().lock() {
        t.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
    } else {
        Vec::new()
    };
    v.sort_by(|a, b| b.1.total_ns.cmp(&a.1.total_ns));
    v
}

/// Dump per-dispatch entries grouped by `cb_label`, preserving CB-arrival
/// order within each group.
///
/// Returns `Vec<(cb_label, Vec<DispatchEntry>)>`.  The outer ordering
/// follows the order in which each `cb_label` first appeared (i.e. the
/// chronological CB submission order); the inner ordering follows the
/// per-CB `dispatch_index` (insertion order from
/// [`record_dispatch`]).
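///
/// A consumption sketch (crate path assumed for illustration):
///
/// ```ignore
/// for (cb, dispatches) in mlx_native::kernel_profile::dump_dispatches() {
///     let cb_total: u64 = dispatches.iter().map(|d| d.gpu_ns).sum();
///     eprintln!("{cb}: {} dispatches, {cb_total} ns", dispatches.len());
///     for d in &dispatches {
///         eprintln!("  [{:>3}] {:<8} {} ns", d.dispatch_index, d.op_kind, d.gpu_ns);
///     }
/// }
/// ```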
pub fn dump_dispatches() -> Vec<(String, Vec<DispatchEntry>)> {
    let entries = if let Ok(t) = dispatch_table().lock() {
        t.clone()
    } else {
        return Vec::new();
    };
    // Group preserving first-appearance order of cb_label.
    let mut order: Vec<String> = Vec::new();
    let mut groups: HashMap<String, Vec<DispatchEntry>> = HashMap::new();
    for e in entries {
        let key = e.cb_label.clone();
        if !groups.contains_key(&key) {
            order.push(key.clone());
        }
        groups.entry(key).or_default().push(e);
    }
    order
        .into_iter()
        .map(|k| {
            let v = groups.remove(&k).unwrap_or_default();
            (k, v)
        })
        .collect()
}

/// Whether per-CB profiling is enabled via `MLX_PROFILE_CB=1`.
///
/// Cached in an atomic so the hot path is a single load.
pub fn is_enabled() -> bool {
    static CACHED: AtomicI8 = AtomicI8::new(-1);
    let v = CACHED.load(Ordering::Relaxed);
    if v >= 0 {
        return v == 1;
    }
    // Match the documented contract: the flag must be set to "1", not
    // merely present in the environment.
    let on = std::env::var("MLX_PROFILE_CB").is_ok_and(|s| s == "1");
    CACHED.store(if on { 1 } else { 0 }, Ordering::Relaxed);
    on
}

/// Whether per-dispatch profiling is enabled via `MLX_PROFILE_DISPATCH=1`.
///
/// Cached in an atomic so the hot path is a single load — same
/// gating semantics as [`is_enabled`].  When set, the per-CB profile
/// is also force-enabled (so cross-validation per Risk R3 is always
/// possible) — see [`is_enabled_or_dispatch`].
pub fn is_dispatch_enabled() -> bool {
    static CACHED: AtomicI8 = AtomicI8::new(-1);
    let v = CACHED.load(Ordering::Relaxed);
    if v >= 0 {
        return v == 1;
    }
    let on = std::env::var("MLX_PROFILE_DISPATCH").is_ok_and(|s| s == "1");
    CACHED.store(if on { 1 } else { 0 }, Ordering::Relaxed);
    on
}
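
/// Per-CB recording gate referenced from [`is_dispatch_enabled`]: true
/// when either `MLX_PROFILE_CB=1` or `MLX_PROFILE_DISPATCH=1` is set,
/// so dispatch profiling always has a per-CB table to cross-validate
/// against (Risk R3).  A minimal combinator matching the documented
/// force-enable semantics; call sites outside this module are assumed
/// to gate per-CB recording on it.
pub fn is_enabled_or_dispatch() -> bool {
    is_enabled() || is_dispatch_enabled()
}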

// --------------------------------------------------------------------
// GPU↔CPU clock-pair conversion (ADR-015 iter63 §A.6)
// --------------------------------------------------------------------
//
// `MTLCommonCounterSetTimestamp` returns "GPU time when the sample is
// taken" in **GPU ticks**, not nanoseconds.  Apple's
// `device.sampleTimestamps(cpu, gpu)` fills both clocks simultaneously,
// allowing us to derive a tick→ns scale factor.  (A single pair yields
// a scale factor only because both clocks count from the same
// boot-relative zero; with an unknown offset, two pairs would be needed
// to separate rate from offset.)
//
// On Apple silicon the GPU timebase is typically 1 tick = 1 ns (verified
// empirically; `mach_timebase_info` numer/denom is nominally 125/3, but
// the GPU side reports the same nanosecond domain).  We do NOT hardcode
// this — the pair is sampled once on first call and reused until
// `reset()` clears the cache.  `convert_gpu_ticks_to_ns` falls back to
// a 1:1 ratio if the pair has not been sampled yet (initial encoder
// activity before the first `dump_dispatches` snapshot).
//
// Storing the CPU/GPU snapshot in two AtomicU64 instead of a Mutex keeps
// the conversion lock-free on the per-dispatch resolve path.
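//
// Worked example (illustrative numbers): if `sampleTimestamps` returned
// cpu_ns = 2_000_000_000 and gpu_ticks = 1_000_000_000 at the same
// instant, the scale is 2.0 ns/tick, and a dispatch spanning 1_500 raw
// ticks resolves to 3_000 ns (cf. `convert_gpu_ticks_with_recorded_pair`
// in the tests below).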

static CLOCK_CPU_NS: AtomicU64 = AtomicU64::new(0);
static CLOCK_GPU_TICKS: AtomicU64 = AtomicU64::new(0);

/// Record a `(cpu_ns, gpu_ticks)` snapshot from
/// `MTLDevice.sampleTimestamps`.  Most recent snapshot wins.
///
/// Called from [`crate::CommandEncoder::resolve_dispatch_samples`] on
/// the first resolve after a [`reset`] (or at any other CB boundary if
/// the encoder chooses to refresh — both legal).
pub fn record_clock_pair(cpu_ns: u64, gpu_ticks: u64) {
    CLOCK_CPU_NS.store(cpu_ns, Ordering::Relaxed);
    CLOCK_GPU_TICKS.store(gpu_ticks, Ordering::Relaxed);
}

/// Convert a raw GPU tick value to ns using the most recent
/// `(cpu_ns, gpu_ticks)` pair, falling back to a 1:1 ratio when no
/// pair has been recorded yet.
///
/// The conversion is `ns = ticks * (cpu_ns / gpu_ticks)`, computed in
/// f64.  For realistic tick values (far below the ~2^53 f64 mantissa
/// limit) the rounding error is at most a few ns.
pub fn convert_gpu_ticks_to_ns(gpu_ticks: u64) -> u64 {
    let cpu = CLOCK_CPU_NS.load(Ordering::Relaxed);
    let gpu = CLOCK_GPU_TICKS.load(Ordering::Relaxed);
    if cpu == 0 || gpu == 0 {
        // No snapshot — use 1:1 (best-effort; Apple silicon ships a
        // nanosecond GPU timebase in practice).
        return gpu_ticks;
    }
    // Avoid u64 overflow: scale in f64 and truncate back.  At 6,000
    // dispatches per CB and ~10 ms per dispatch, tick values fit
    // comfortably in the f64 mantissa (~2^53).
    let scale = cpu as f64 / gpu as f64;
    (gpu_ticks as f64 * scale) as u64
}

#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// The tests below share the global tables and clock cache; serialize
    /// them so the default multi-threaded test runner cannot interleave
    /// one test's `reset()` with another's records.
    fn test_guard() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: Mutex<()> = Mutex::new(());
        LOCK.lock().unwrap_or_else(|e| e.into_inner())
    }

    #[test]
    fn record_dump_reset_cycle() {
        let _guard = test_guard();
        reset();
        record("A", 100);
        record("A", 200);
        record("B", 50);
        let d = dump();
        // Sorted by total_ns descending.
        assert_eq!(d.len(), 2);
        assert_eq!(d[0].0, "A");
        assert_eq!(d[0].1.count, 2);
        assert_eq!(d[0].1.total_ns, 300);
        assert_eq!(d[0].1.min_ns, 100);
        assert_eq!(d[0].1.max_ns, 200);
        assert_eq!(d[1].0, "B");
        assert_eq!(d[1].1.count, 1);
        reset();
        assert!(dump().is_empty());
    }

    #[test]
    fn dispatch_record_dump_reset_cycle() {
        let _guard = test_guard();
        reset();
        record_dispatch(DispatchEntry {
            cb_label: "layer.attn[0]".into(),
            op_kind: "RmsNorm",
            dispatch_index: 0,
            gpu_ns: 100,
            start_gpu_ns: 1_000,
            end_gpu_ns: 1_100,
        });
        record_dispatch(DispatchEntry {
            cb_label: "layer.attn[0]".into(),
            op_kind: "Sdpa",
            dispatch_index: 1,
            gpu_ns: 500,
            start_gpu_ns: 1_100,
            end_gpu_ns: 1_600,
        });
        record_dispatch(DispatchEntry {
            cb_label: "layer.ffn[0]".into(),
            op_kind: "Other",
            dispatch_index: 0,
            gpu_ns: 250,
            start_gpu_ns: 2_000,
            end_gpu_ns: 2_250,
        });
        let dumps = dump_dispatches();
        // Group order matches first appearance (attn, then ffn).
        assert_eq!(dumps.len(), 2);
        assert_eq!(dumps[0].0, "layer.attn[0]");
        assert_eq!(dumps[0].1.len(), 2);
        // Within-group order follows insertion order.
        assert_eq!(dumps[0].1[0].dispatch_index, 0);
        assert_eq!(dumps[0].1[0].op_kind, "RmsNorm");
        assert_eq!(dumps[0].1[1].dispatch_index, 1);
        assert_eq!(dumps[0].1[1].op_kind, "Sdpa");
        assert_eq!(dumps[1].0, "layer.ffn[0]");
        assert_eq!(dumps[1].1.len(), 1);
        reset();
        assert!(dump_dispatches().is_empty());
    }

    #[test]
    fn dispatch_dump_empty_when_no_entries() {
        let _guard = test_guard();
        reset();
        assert!(dump_dispatches().is_empty());
    }

    #[test]
    fn convert_gpu_ticks_default_one_to_one() {
        let _guard = test_guard();
        // After reset(), no pair → 1:1 fallback.
        reset();
        assert_eq!(convert_gpu_ticks_to_ns(12_345), 12_345);
    }

    #[test]
    fn convert_gpu_ticks_with_recorded_pair() {
        let _guard = test_guard();
        reset();
        // Suppose 1 GPU tick = 2 ns → cpu_ns / gpu_ticks = 2.0.
        record_clock_pair(2_000, 1_000);
        assert_eq!(convert_gpu_ticks_to_ns(500), 1_000);
        assert_eq!(convert_gpu_ticks_to_ns(0), 0);
    }

    #[test]
    fn convert_gpu_ticks_zero_pair_is_one_to_one() {
        let _guard = test_guard();
        reset();
        // An exactly-zero component acts as "unrecorded".
        record_clock_pair(0, 1_000);
        assert_eq!(convert_gpu_ticks_to_ns(7), 7);
        record_clock_pair(2_000, 0);
        assert_eq!(convert_gpu_ticks_to_ns(7), 7);
    }
}
382}