use std::collections::HashMap;
use std::sync::Mutex;
use std::sync::OnceLock;
/// Aggregated timing statistics for a single profiling label.
///
/// The `_ns` fields hold whatever duration unit callers pass to `record` (by naming,
/// nanoseconds). `min_ns` / `max_ns` are only meaningful once `count > 0`; a
/// default-constructed entry reports zero in every field.
#[derive(Default, Debug, Clone)]
pub struct ProfileEntry {
    /// Number of samples folded into this entry.
    pub count: u64,
    /// Sum of all sample durations.
    pub total_ns: u64,
    /// Shortest sample observed so far.
    pub min_ns: u64,
    /// Longest sample observed so far.
    pub max_ns: u64,
}
/// Returns the process-wide label → statistics table, creating it empty on first access.
fn table() -> &'static Mutex<HashMap<String, ProfileEntry>> {
    static TABLE: OnceLock<Mutex<HashMap<String, ProfileEntry>>> = OnceLock::new();
    // `Mutex<HashMap<..>>::default()` is exactly `Mutex::new(HashMap::new())`.
    TABLE.get_or_init(Default::default)
}
/// Folds one timing sample `gpu_ns` into the aggregate statistics for `label`.
///
/// The entry is created on first use. `count` and `total_ns` saturate at `u64::MAX`
/// rather than wrapping. If the table's mutex is poisoned, the sample is silently dropped.
pub fn record(label: &str, gpu_ns: u64) {
    let Ok(mut guard) = table().lock() else {
        return;
    };
    let entry = guard.entry(label.to_string()).or_default();

    // A fresh entry has min_ns == 0, so the first sample must seed the minimum.
    let is_first_sample = entry.count == 0;
    if is_first_sample || gpu_ns < entry.min_ns {
        entry.min_ns = gpu_ns;
    }
    entry.max_ns = entry.max_ns.max(gpu_ns);

    entry.count = entry.count.saturating_add(1);
    entry.total_ns = entry.total_ns.saturating_add(gpu_ns);
}
/// Discards every accumulated entry.
///
/// If the table's mutex is poisoned, the table is left untouched.
pub fn reset() {
    let _ = table().lock().map(|mut entries| entries.clear());
}
/// Returns a snapshot of every recorded label together with its statistics.
///
/// Entries are ordered by `total_ns`, largest first. Equal totals are ordered by label
/// ascending, which makes the result deterministic even though the backing `HashMap`
/// iterates in an unspecified order. Returns an empty `Vec` if the table's mutex is
/// poisoned.
pub fn dump() -> Vec<(String, ProfileEntry)> {
    // The lock is held only while copying the entries out; sorting happens unlocked.
    let mut entries: Vec<(String, ProfileEntry)> = match table().lock() {
        Ok(guard) => guard
            .iter()
            .map(|(label, entry)| (label.clone(), entry.clone()))
            .collect(),
        Err(_) => Vec::new(),
    };
    // Labels are unique map keys, so this is a total order and an unstable sort is safe.
    entries.sort_unstable_by(|(a_label, a), (b_label, b)| {
        b.total_ns
            .cmp(&a.total_ns)
            .then_with(|| a_label.cmp(b_label))
    });
    entries
}
/// Reports whether command-buffer profiling is enabled via the `MLX_PROFILE_CB`
/// environment variable.
///
/// Only presence matters: any value enables profiling, including an empty string or a
/// non-UTF-8 value. The answer is computed on the first call and cached for the life of
/// the process, so later changes to the environment have no effect.
pub fn is_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    // `var_os` is used because `env::var` returns `Err` for a set-but-non-UTF-8 value,
    // which would wrongly read as "not set".
    *ENABLED.get_or_init(|| std::env::var_os("MLX_PROFILE_CB").is_some())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the full lifecycle: per-label aggregation, ordering of `dump` by
    /// descending total, and `reset` emptying the table.
    #[test]
    fn record_dump_reset_cycle() {
        reset();
        for (label, ns) in [("A", 100), ("A", 200), ("B", 50)] {
            record(label, ns);
        }

        let snapshot = dump();
        assert_eq!(snapshot.len(), 2);

        let (first_label, first) = &snapshot[0];
        assert_eq!(first_label, "A");
        assert_eq!(
            (first.count, first.total_ns, first.min_ns, first.max_ns),
            (2, 300, 100, 200)
        );

        let (second_label, second) = &snapshot[1];
        assert_eq!(second_label, "B");
        assert_eq!(second.count, 1);

        reset();
        assert!(dump().is_empty());
    }
}