use std::collections::HashMap;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicI8, AtomicU64, Ordering};
/// Aggregated GPU timing statistics for a single profile label.
#[derive(Clone, Debug, Default)]
pub struct ProfileEntry {
    /// Number of samples recorded for this label.
    pub count: u64,
    /// Sum of all sample durations, in GPU nanoseconds (saturating).
    pub total_ns: u64,
    /// Smallest sample duration seen, in GPU nanoseconds.
    pub min_ns: u64,
    /// Largest sample duration seen, in GPU nanoseconds.
    pub max_ns: u64,
}
/// One recorded GPU dispatch, tagged with the command buffer it ran in.
#[derive(Clone, Debug)]
pub struct DispatchEntry {
    /// Label of the command buffer this dispatch belongs to.
    pub cb_label: String,
    /// Static name of the operation kind (e.g. "RmsNorm", "Sdpa").
    pub op_kind: &'static str,
    /// Position of this dispatch within its command buffer.
    pub dispatch_index: u32,
    /// Duration of the dispatch, in GPU nanoseconds.
    pub gpu_ns: u64,
    /// GPU timestamp (ns) at which the dispatch started.
    pub start_gpu_ns: u64,
    /// GPU timestamp (ns) at which the dispatch ended.
    pub end_gpu_ns: u64,
}
/// Process-wide aggregate profile table, keyed by label.
/// Lazily initialized on first access.
fn table() -> &'static Mutex<HashMap<String, ProfileEntry>> {
    static TABLE: OnceLock<Mutex<HashMap<String, ProfileEntry>>> = OnceLock::new();
    TABLE.get_or_init(Default::default)
}
/// Process-wide list of per-dispatch records, in recording order.
/// Lazily initialized on first access.
fn dispatch_table() -> &'static Mutex<Vec<DispatchEntry>> {
    static ENTRIES: OnceLock<Mutex<Vec<DispatchEntry>>> = OnceLock::new();
    ENTRIES.get_or_init(Default::default)
}
/// Folds one GPU-duration sample (nanoseconds) into the aggregate entry
/// for `label`, creating the entry on its first sample.
///
/// Count and total use saturating arithmetic so long sessions cannot
/// overflow. Profiling is best-effort: if the table lock is poisoned,
/// the sample is dropped rather than panicking.
pub fn record(label: &str, gpu_ns: u64) {
    let Ok(mut stats) = table().lock() else {
        return;
    };
    let entry = stats.entry(label.to_string()).or_default();
    // The first sample seeds `min_ns`; afterwards keep the running min.
    if entry.count == 0 || gpu_ns < entry.min_ns {
        entry.min_ns = gpu_ns;
    }
    entry.max_ns = entry.max_ns.max(gpu_ns);
    entry.count = entry.count.saturating_add(1);
    entry.total_ns = entry.total_ns.saturating_add(gpu_ns);
}
/// Appends one dispatch record to the global dispatch list.
/// Best-effort: the record is dropped if the lock is poisoned.
pub fn record_dispatch(entry: DispatchEntry) {
    let Ok(mut entries) = dispatch_table().lock() else {
        return;
    };
    entries.push(entry);
}
pub fn reset() {
if let Ok(mut t) = table().lock() {
t.clear();
}
if let Ok(mut t) = dispatch_table().lock() {
t.clear();
}
CLOCK_CPU_NS.store(0, Ordering::Relaxed);
CLOCK_GPU_TICKS.store(0, Ordering::Relaxed);
}
/// Snapshots the aggregate table as `(label, stats)` pairs, sorted by
/// total GPU time, largest first. Returns an empty list if the table
/// lock is poisoned.
pub fn dump() -> Vec<(String, ProfileEntry)> {
    let mut rows: Vec<(String, ProfileEntry)> = match table().lock() {
        Ok(t) => t
            .iter()
            .map(|(label, stats)| (label.clone(), stats.clone()))
            .collect(),
        Err(_) => Vec::new(),
    };
    // Stable sort, descending by total time.
    rows.sort_by_key(|(_, stats)| std::cmp::Reverse(stats.total_ns));
    rows
}
/// Snapshots the dispatch list grouped by command-buffer label.
///
/// Groups appear in first-seen order and entries within a group keep
/// their recording order. Returns an empty list if the lock is poisoned.
pub fn dump_dispatches() -> Vec<(String, Vec<DispatchEntry>)> {
    let snapshot: Vec<DispatchEntry> = match dispatch_table().lock() {
        Ok(entries) => entries.clone(),
        Err(_) => return Vec::new(),
    };
    // Group with an ordered Vec; a linear scan per entry is cheap for
    // the small number of distinct command buffers and preserves
    // first-seen order without a separate ordering list.
    let mut grouped: Vec<(String, Vec<DispatchEntry>)> = Vec::new();
    for entry in snapshot {
        match grouped.iter_mut().find(|(label, _)| *label == entry.cb_label) {
            Some((_, bucket)) => bucket.push(entry),
            None => {
                let label = entry.cb_label.clone();
                grouped.push((label, vec![entry]));
            }
        }
    }
    grouped
}
/// Whether aggregate command-buffer profiling is enabled.
///
/// Controlled by the presence of the `MLX_PROFILE_CB` environment
/// variable (any value counts as enabled). The first probe is cached
/// for the lifetime of the process.
pub fn is_enabled() -> bool {
    // -1 = not yet probed, 0 = disabled, 1 = enabled.
    static CACHED: AtomicI8 = AtomicI8::new(-1);
    match CACHED.load(Ordering::Relaxed) {
        0 => false,
        1 => true,
        _ => {
            let enabled = std::env::var("MLX_PROFILE_CB").is_ok();
            CACHED.store(enabled as i8, Ordering::Relaxed);
            enabled
        }
    }
}
/// Whether per-dispatch profiling is enabled.
///
/// Controlled by the presence of the `MLX_PROFILE_DISPATCH` environment
/// variable (any value counts as enabled). The first probe is cached
/// for the lifetime of the process.
pub fn is_dispatch_enabled() -> bool {
    // -1 = not yet probed, 0 = disabled, 1 = enabled.
    static CACHED: AtomicI8 = AtomicI8::new(-1);
    match CACHED.load(Ordering::Relaxed) {
        0 => false,
        1 => true,
        _ => {
            let enabled = std::env::var("MLX_PROFILE_DISPATCH").is_ok();
            CACHED.store(enabled as i8, Ordering::Relaxed);
            enabled
        }
    }
}
// Most recent clock-calibration sample: CPU time in nanoseconds and the
// matching GPU timestamp in ticks. Either value being 0 means "unset",
// and tick-to-ns conversion falls back to 1:1.
static CLOCK_CPU_NS: AtomicU64 = AtomicU64::new(0);
static CLOCK_GPU_TICKS: AtomicU64 = AtomicU64::new(0);
/// Stores a (CPU time in ns, GPU timestamp in ticks) calibration sample
/// used by `convert_gpu_ticks_to_ns` to derive the tick-to-ns scale.
///
/// The two values are independent relaxed atomic stores, so a concurrent
/// reader may observe a mixed pair from two different calls.
pub fn record_clock_pair(cpu_ns: u64, gpu_ticks: u64) {
    CLOCK_GPU_TICKS.store(gpu_ticks, Ordering::Relaxed);
    CLOCK_CPU_NS.store(cpu_ns, Ordering::Relaxed);
}
/// Converts a GPU tick count to nanoseconds using the most recently
/// recorded calibration pair.
///
/// Falls back to returning the ticks unchanged (1:1) when no usable
/// pair has been recorded, i.e. when either component is zero.
pub fn convert_gpu_ticks_to_ns(gpu_ticks: u64) -> u64 {
    let cpu_ns = CLOCK_CPU_NS.load(Ordering::Relaxed);
    let ticks = CLOCK_GPU_TICKS.load(Ordering::Relaxed);
    match (cpu_ns, ticks) {
        (0, _) | (_, 0) => gpu_ticks,
        (cpu, gpu) => {
            // Same association as before: compute the scale first, then
            // multiply, to keep float rounding identical.
            let ns_per_tick = cpu as f64 / gpu as f64;
            (gpu_ticks as f64 * ns_per_tick) as u64
        }
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Every test below mutates the same process-wide state (the aggregate
    /// table, the dispatch list, and the clock-calibration statics) via
    /// `reset()` / `record*()`. The default test harness runs `#[test]`
    /// functions on parallel threads, so without serialization the tests
    /// race each other (e.g. one test's `reset()` can clear entries between
    /// another test's `record()` and `dump()`), failing spuriously.
    /// Each test acquires this guard first.
    fn state_guard() -> std::sync::MutexGuard<'static, ()> {
        static GATE: Mutex<()> = Mutex::new(());
        // A panicking test poisons the mutex; recover the guard so the
        // remaining tests still run instead of cascading failures.
        GATE.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn record_dump_reset_cycle() {
        let _state = state_guard();
        reset();
        record("A", 100);
        record("A", 200);
        record("B", 50);
        let d = dump();
        assert_eq!(d.len(), 2);
        // "A" has the larger total, so it sorts first.
        assert_eq!(d[0].0, "A");
        assert_eq!(d[0].1.count, 2);
        assert_eq!(d[0].1.total_ns, 300);
        assert_eq!(d[0].1.min_ns, 100);
        assert_eq!(d[0].1.max_ns, 200);
        assert_eq!(d[1].0, "B");
        assert_eq!(d[1].1.count, 1);
        reset();
        assert!(dump().is_empty());
    }

    #[test]
    fn dispatch_record_dump_reset_cycle() {
        let _state = state_guard();
        reset();
        record_dispatch(DispatchEntry {
            cb_label: "layer.attn[0]".into(),
            op_kind: "RmsNorm",
            dispatch_index: 0,
            gpu_ns: 100,
            start_gpu_ns: 1_000,
            end_gpu_ns: 1_100,
        });
        record_dispatch(DispatchEntry {
            cb_label: "layer.attn[0]".into(),
            op_kind: "Sdpa",
            dispatch_index: 1,
            gpu_ns: 500,
            start_gpu_ns: 1_100,
            end_gpu_ns: 1_600,
        });
        record_dispatch(DispatchEntry {
            cb_label: "layer.ffn[0]".into(),
            op_kind: "Other",
            dispatch_index: 0,
            gpu_ns: 250,
            start_gpu_ns: 2_000,
            end_gpu_ns: 2_250,
        });
        let dumps = dump_dispatches();
        // Groups keep first-seen order; entries keep recording order.
        assert_eq!(dumps.len(), 2);
        assert_eq!(dumps[0].0, "layer.attn[0]");
        assert_eq!(dumps[0].1.len(), 2);
        assert_eq!(dumps[0].1[0].dispatch_index, 0);
        assert_eq!(dumps[0].1[0].op_kind, "RmsNorm");
        assert_eq!(dumps[0].1[1].dispatch_index, 1);
        assert_eq!(dumps[0].1[1].op_kind, "Sdpa");
        assert_eq!(dumps[1].0, "layer.ffn[0]");
        assert_eq!(dumps[1].1.len(), 1);
        reset();
        assert!(dump_dispatches().is_empty());
    }

    #[test]
    fn dispatch_dump_empty_when_no_entries() {
        let _state = state_guard();
        reset();
        assert!(dump_dispatches().is_empty());
    }

    #[test]
    fn convert_gpu_ticks_default_one_to_one() {
        let _state = state_guard();
        reset();
        assert_eq!(convert_gpu_ticks_to_ns(12_345), 12_345);
    }

    #[test]
    fn convert_gpu_ticks_with_recorded_pair() {
        let _state = state_guard();
        reset();
        record_clock_pair(2_000, 1_000);
        assert_eq!(convert_gpu_ticks_to_ns(500), 1_000);
        assert_eq!(convert_gpu_ticks_to_ns(0), 0);
    }

    #[test]
    fn convert_gpu_ticks_zero_pair_is_one_to_one() {
        let _state = state_guard();
        reset();
        record_clock_pair(0, 1_000);
        assert_eq!(convert_gpu_ticks_to_ns(7), 7);
        record_clock_pair(2_000, 0);
        assert_eq!(convert_gpu_ticks_to_ns(7), 7);
    }
}