use parking_lot::RwLock;
use std::collections::hash_map::RandomState;
use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::hash::{BuildHasher, Hash};
use std::io;
use std::sync::LazyLock;
use std::time::{Duration, Instant};
use thread_local::ThreadLocal;
/// Lazily created container holding one `RwLock<Profiler>` per thread.
///
/// The `profile_impl!` macro obtains the calling thread's entry with
/// `get_or(Default::default)`, `Guard::drop` looks it up with `get()`, and the
/// free functions `write` and `reset` visit every thread's entry via `iter()`.
pub static PROFILER: LazyLock<ThreadLocal<RwLock<Profiler>>> = LazyLock::new(ThreadLocal::new);
/// Single, lazily created hasher state used by `ScopeId::get_hash`.
///
/// Every parent hash is produced from this one state, which is what makes the
/// `parent_hash` values stored in different `ScopeId`s comparable with each
/// other when the report is assembled.
pub static RANDOM_STATE: LazyLock<RandomState> = LazyLock::new(RandomState::new);
/// Opens a profiling scope on the calling thread's `Profiler`.
///
/// Each arm expands to a single `let` statement at the call site. The returned
/// `Guard` is bound to `_profiling_scope_guard`, so the scope is left when
/// that binding is dropped. The thread's profiler is created with
/// `Default::default` on first use, and its write lock is a temporary of the
/// `let` statement, i.e. it is only held while `enter` / `enter_with_parent`
/// runs.
///
/// The expansion refers to `$crate::profiling::PROFILER`, so it expects this
/// module to be reachable as `crate::profiling`.
#[doc(hidden)]
#[macro_export]
macro_rules! profile_impl {
// `profile_impl!(name)`: enter `name` below the innermost open scope and
// discard the resulting `ScopeId`.
($name:expr) => {
let (_profiling_scope_guard, _) = $crate::profiling::PROFILER
.get_or(Default::default)
.write()
.enter($name);
};
// `profile_impl!(scope_id, name)`: as above, but also bind the new `ScopeId`
// to the caller-chosen identifier `scope_id`.
($scope_id:ident, $name:expr) => {
let (_profiling_scope_guard, $scope_id) = $crate::profiling::PROFILER
.get_or(Default::default)
.write()
.enter($name);
};
// `profile_impl!(name, parent = parent_id)`: enter `name` as a child of the
// scope identified by `parent_id` (passed by reference) instead of the
// innermost open scope.
($name:expr, parent = $parent_id:ident) => {
let (_profiling_scope_guard, _) = $crate::profiling::PROFILER
.get_or(Default::default)
.write()
.enter_with_parent($name, &$parent_id);
};
// `profile_impl!(scope_id, name, parent = parent_id)`: explicit parent and
// the new `ScopeId` bound to `scope_id`.
($scope_id:ident, $name:expr, parent = $parent_id:ident) => {
let (_profiling_scope_guard, $scope_id) = $crate::profiling::PROFILER
.get_or(Default::default)
.write()
.enter_with_parent($name, &$parent_id);
};
}
/// Timer handle for one active profiling scope.
///
/// It remembers when it was created; the `Drop` implementation uses that
/// instant to report how long the scope stayed open.
pub struct Guard {
    /// Instant at which the scope was entered.
    enter_time: Instant,
}

impl Guard {
    /// Creates a guard whose timer starts at the current instant.
    fn new() -> Self {
        let enter_time = Instant::now();
        Guard { enter_time }
    }
}
impl Drop for Guard {
    /// Reports the time elapsed since the guard was created to the dropping
    /// thread's profiler, which closes that profiler's innermost open scope.
    ///
    /// # Panics
    ///
    /// Panics if the dropping thread has no entry in `PROFILER`.
    fn drop(&mut self) {
        // Take the measurement first so the lookup and locking below are not
        // included in the recorded duration.
        let elapsed = self.enter_time.elapsed();
        let thread_profiler = PROFILER
            .get()
            .expect("Missing thread local profiler");
        thread_profiler.write().leave(elapsed);
    }
}
/// Accumulated timing statistics of one profiling scope.
#[derive(Clone, Debug)]
struct Scope {
    /// Label given when the scope was entered.
    name: &'static str,
    /// Number of completed enter/leave pairs.
    num_calls: usize,
    /// Total time spent inside the scope over all completed calls.
    duration_sum: Duration,
    /// Instant at which this record was created; reports are ordered by it.
    first_call: Instant,
}

impl Scope {
    /// Creates an empty record for `name`, stamped with the current instant.
    fn new(name: &'static str) -> Self {
        Self {
            first_call: Instant::now(),
            duration_sum: Duration::ZERO,
            num_calls: 0,
            name,
        }
    }

    /// Folds the statistics of `other` into `self`.
    ///
    /// Call counts and durations are added and the earlier `first_call` wins.
    /// Records with a different name are ignored.
    fn merge(&mut self, other: &Self) {
        if self.name != other.name {
            return;
        }
        self.num_calls += other.num_calls;
        self.duration_sum += other.duration_sum;
        self.first_call = self.first_call.min(other.first_call);
    }
}
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
pub struct ScopeId {
name: &'static str,
parent_hash: u64,
}
impl ScopeId {
fn get_hash(id: Option<&ScopeId>) -> u64 {
RANDOM_STATE.hash_one(id)
}
}
/// Timing data collected on one thread.
#[derive(Default)]
pub struct Profiler {
    /// Statistics of every scope entered so far, keyed by scope id.
    scopes: HashMap<ScopeId, Scope>,
    /// Ids of the currently open scopes, innermost last.
    scope_stack: Vec<ScopeId>,
    /// Ids of scopes that were entered while no other scope was open.
    roots: HashSet<ScopeId>,
}

impl Profiler {
    /// Discards all recorded scopes, the open-scope stack and the root set.
    pub fn reset(&mut self) {
        self.roots.clear();
        self.scope_stack.clear();
        self.scopes.clear();
    }

    /// Enters `name` as a child of the innermost open scope (or as a
    /// parentless scope when none is open).
    ///
    /// Returns the timing guard together with the id of the entered scope.
    pub fn enter(&mut self, name: &'static str) -> (Guard, ScopeId) {
        let innermost = self.scope_stack.last();
        let id = self.new_id(name, innermost);
        self.enter_with_id(name, id)
    }

    /// Enters `name` as a child of the explicitly given `parent`, regardless
    /// of which scopes are currently open.
    pub fn enter_with_parent(&mut self, name: &'static str, parent: &ScopeId) -> (Guard, ScopeId) {
        let id = self.new_id(name, Some(parent));
        self.enter_with_id(name, id)
    }

    /// Registers `id` (creating its record on first use), pushes it on the
    /// stack and starts a guard for it.
    fn enter_with_id(&mut self, name: &'static str, id: ScopeId) -> (Guard, ScopeId) {
        self.scopes.entry(id).or_insert_with(|| Scope::new(name));
        let opened_at_top_level = self.scope_stack.is_empty();
        if opened_at_top_level {
            self.roots.insert(id);
        }
        self.scope_stack.push(id);
        (Guard::new(), id)
    }

    /// Closes the innermost open scope and adds `duration` to its record.
    ///
    /// Does nothing when no scope is open or the record no longer exists.
    fn leave(&mut self, duration: Duration) {
        let Some(id) = self.scope_stack.pop() else {
            return;
        };
        let Some(scope) = self.scopes.get_mut(&id) else {
            return;
        };
        scope.num_calls += 1;
        scope.duration_sum += duration;
    }

    /// Builds the id of a scope called `name` below `parent`.
    fn new_id(&self, name: &'static str, parent: Option<&ScopeId>) -> ScopeId {
        let parent_hash = ScopeId::get_hash(parent);
        ScopeId { name, parent_hash }
    }
}
/// Writes the report line of `current` followed, recursively, by the lines of
/// all scopes whose parent is `current`.
///
/// # Arguments
///
/// * `out` - destination of the report.
/// * `sorted_scopes` - every known scope; children are printed in the order
///   in which they appear in this slice.
/// * `current` - the scope to print.
/// * `parent_duration` - runtime the printed percentage is relative to; with
///   `None` the scope is compared against its own runtime.
/// * `depth` - nesting level; one space of indentation is written per level.
/// * `is_parallel` - when `true`, the percentage is prefixed with `≈`.
///
/// A scope that has not completed a call yet (for example one that is still
/// open while the report is written) has zero calls and zero runtime; its
/// average and percentage are printed as `0.00` instead of `NaN`.
///
/// # Errors
///
/// Returns the first I/O error reported by `out`.
fn write_recursively<W: io::Write>(
    out: &mut W,
    sorted_scopes: &[(ScopeId, Scope)],
    current: &(ScopeId, Scope),
    parent_duration: Option<Duration>,
    depth: usize,
    is_parallel: bool,
) -> io::Result<()> {
    let (id, scope) = current;
    for _ in 0..depth {
        write!(out, " ")?;
    }

    let duration_sum_secs = scope.duration_sum.as_secs_f64();
    let parent_duration_secs = parent_duration.map_or(duration_sum_secs, |t| t.as_secs_f64());
    // Both ratios below would be 0/0 (NaN) or x/0 (inf) for scopes without a
    // completed call or with a zero-length reference duration.
    let percent = if parent_duration_secs > 0.0 {
        duration_sum_secs / parent_duration_secs * 100.0
    } else {
        0.0
    };
    let average_millis = if scope.num_calls > 0 {
        duration_sum_secs * 1000.0 / (scope.num_calls as f64)
    } else {
        0.0
    };
    writeln!(
        out,
        "{}: {}{:3.2}%, {:>4.2}ms avg, {} {} (total: {:.3}s)",
        scope.name,
        if is_parallel { "≈" } else { "" },
        percent,
        average_millis,
        scope.num_calls,
        if scope.num_calls == 1 { "call" } else { "calls" },
        duration_sum_secs
    )?;

    // Look the children up once; they are needed for the runtime sum and for
    // the recursion.
    let current_hash = ScopeId::get_hash(Some(id));
    let children: Vec<&(ScopeId, Scope)> = sorted_scopes
        .iter()
        .filter(|entry| entry.0.parent_hash == current_hash)
        .collect();
    let children_runtime: Duration = children.iter().map(|child| child.1.duration_sum).sum();

    // Children that together took longer than their parent are flagged as
    // parallel; their percentages are then relative to the sum of the
    // children instead of the parent's own runtime.
    let children_parallel = children_runtime > scope.duration_sum;
    let own_runtime = if children_parallel {
        children_runtime
    } else {
        scope.duration_sum
    };

    for child in children {
        write_recursively(
            out,
            sorted_scopes,
            child,
            Some(own_runtime),
            depth + 1,
            children_parallel,
        )?;
    }
    Ok(())
}
/// Merges the statistics of every thread's profiler and writes the combined
/// report to `out`.
///
/// Scopes with the same id are merged across threads. Each scope that was
/// recorded as a root and has no parent starts a tree in the report; trees
/// and their children are ordered by the time the scope was first seen.
///
/// # Errors
///
/// Returns the first I/O error reported by `out`.
pub fn write<W: io::Write>(out: &mut W) -> io::Result<()> {
    let mut merged: HashMap<ScopeId, Scope> = HashMap::new();
    let mut root_ids: HashSet<ScopeId> = HashSet::new();

    for cell in PROFILER.iter() {
        let thread_profiler = cell.read();
        root_ids.extend(thread_profiler.roots.iter());
        for (id, scope) in thread_profiler.scopes.iter() {
            match merged.get_mut(id) {
                Some(existing) => existing.merge(scope),
                None => {
                    merged.insert(*id, scope.clone());
                }
            }
        }
    }

    // Roots opened with an explicit parent are printed below that parent, so
    // only ids hashed against "no parent" start a tree of their own.
    let top_level_hash = ScopeId::get_hash(None);
    let mut sorted_roots: Vec<(ScopeId, Scope)> = root_ids
        .into_iter()
        .filter(|id| id.parent_hash == top_level_hash)
        .filter_map(|id| merged.get(&id).map(|scope| (id, scope.clone())))
        .collect();
    sorted_roots.sort_unstable_by_key(|entry| entry.1.first_call);

    let mut sorted_scopes: Vec<(ScopeId, Scope)> = merged.into_iter().collect();
    sorted_scopes.sort_unstable_by_key(|entry| entry.1.first_call);

    for root in sorted_roots.iter() {
        write_recursively(out, &sorted_scopes, root, None, 0, false)?;
    }
    Ok(())
}
/// Renders the profiling report produced by `write` into a `String`.
///
/// # Errors
///
/// Returns the I/O error from `write`, or a UTF-8 error if the written bytes
/// are not valid UTF-8.
pub fn write_to_string() -> Result<String, Box<dyn Error>> {
    let mut bytes: Vec<u8> = Vec::new();
    write(&mut bytes)?;
    let report = String::from_utf8(bytes)?;
    Ok(report)
}
/// Clears the recorded data of every thread's profiler.
///
/// Each profiler is write-locked only for the duration of its own reset.
pub fn reset() {
    PROFILER.iter().for_each(|cell| cell.write().reset());
}