use crate::trace_event::TraceEvent;
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use std::collections::hash_map::DefaultHasher;
use std::fs::{self, File};
use std::hash::{Hash, Hasher};
use std::io::Write;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
/// Process-wide profiler instance, created on first access and guarded by a mutex.
static PROFILER: Lazy<Mutex<Profiler>> = Lazy::new(|| {
    let idle_profiler = Profiler::new();
    Mutex::new(idle_profiler)
});
/// Writer for a single trace session at a time, plus frame bookkeeping used to
/// stop a session automatically after a frame count or time budget.
pub struct Profiler {
    // ---- state of the currently open session ----
    /// Name of the open session; `None` when no session is active.
    current_session: Option<String>,
    /// Destination file of the open session.
    file: Option<File>,
    /// Events written so far in this session; used to decide when a separating comma is needed.
    profile_count: usize,

    // ---- frame counter and automatic stop limits ----
    /// Number of `next_frame` calls counted so far.
    frame_count: usize,
    /// Optional frame budget for the session.
    max_frames: Option<usize>,
    /// Optional wall-time budget for the session, in milliseconds.
    max_duration_ms: Option<u64>,
    /// Moment the time budget is measured from.
    start_time: Option<Instant>,
}
impl Profiler {
    /// Creates an idle profiler: no open session, no frame limits, frame counter at zero.
    pub fn new() -> Self {
        Self {
            current_session: None,
            file: None,
            profile_count: 0,
            max_frames: None,
            frame_count: 0,
            max_duration_ms: None,
            start_time: None,
        }
    }

    /// Returns how many times [`Profiler::next_frame`] has been called since the most
    /// recent [`Profiler::begin_session`] (or since construction if none was started).
    pub fn frame_count(&self) -> usize {
        self.frame_count
    }

    /// Returns the process-wide profiler behind its mutex.
    pub fn get() -> &'static Mutex<Profiler> {
        &PROFILER
    }

    /// Starts a session like [`Profiler::begin_session`] that also ends itself from
    /// [`Profiler::next_frame`] once a limit is reached.
    ///
    /// * `max_frames` - end the session once this many frames have been advanced.
    /// * `max_duration_ms` - end the session on the first `next_frame` call made at
    ///   least this many milliseconds after the session started.
    ///
    /// Passing `None` for a limit disables it.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Profiler::begin_session`].
    pub fn begin_session_limited(
        &mut self,
        name: &str,
        filepath: &str,
        max_frames: Option<usize>,
        max_duration_ms: Option<u64>,
    ) {
        // `begin_session` already resets the frame counter and the start time.
        self.begin_session(name, filepath);
        self.max_frames = max_frames;
        self.max_duration_ms = max_duration_ms;
    }

    /// Opens a new trace session named `name` that writes to `filepath`.
    ///
    /// Any session already in progress is ended first. Parent directories of `filepath`
    /// are created as needed, and the file is created or truncated. The frame counter
    /// restarts at zero, and limits set by an earlier
    /// [`Profiler::begin_session_limited`] call are cleared.
    ///
    /// # Panics
    ///
    /// Panics if the parent directory or the trace file cannot be created, or if
    /// writing the JSON header fails.
    pub fn begin_session(&mut self, name: &str, filepath: &str) {
        if self.current_session.is_some() {
            self.end_session().unwrap();
        }
        let path = std::path::Path::new(filepath);
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).unwrap_or_else(|err| {
                panic!("failed to create directory {}: {}", parent.display(), err)
            });
        }
        let file = File::create(path)
            .unwrap_or_else(|err| panic!("failed to create trace file {}: {}", filepath, err));
        self.file = Some(file);
        self.current_session = Some(name.to_string());
        self.profile_count = 0;
        // Without this reset, limits left over from an earlier limited session would
        // make `next_frame` end this session prematurely.
        self.max_frames = None;
        self.max_duration_ms = None;
        self.frame_count = 0;
        self.start_time = Some(Instant::now());
        self.write_header();
    }

    /// Writes the closing `]}` of the trace JSON, closes the file, and clears the
    /// session state. Does nothing if no session is open.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`.
    ///
    /// # Panics
    ///
    /// Panics if writing the footer fails.
    pub fn end_session(&mut self) -> Result<(), String> {
        if self.current_session.is_none() {
            return Ok(());
        }
        self.write_footer();
        self.file = None;
        self.current_session = None;
        self.profile_count = 0;
        // Limits belong to the session that set them; drop them with it.
        self.max_frames = None;
        self.max_duration_ms = None;
        self.start_time = None;
        Ok(())
    }

    /// Advances the frame counter and ends the session if a limit configured by
    /// [`Profiler::begin_session_limited`] has been reached.
    pub fn next_frame(&mut self) {
        self.frame_count += 1;
        if let Some(max) = self.max_frames {
            if self.frame_count >= max {
                self.end_session().unwrap();
                return;
            }
        }
        if let (Some(max_ms), Some(start)) = (self.max_duration_ms, self.start_time) {
            // Compare in u128 so the elapsed milliseconds are never truncated.
            if start.elapsed().as_millis() >= u128::from(max_ms) {
                self.end_session().unwrap();
            }
        }
    }

    /// Appends `result` to the open trace file as one element of the `traceEvents` array.
    ///
    /// The written event's `args` are normalised to `{"frame", "thread_name"}`. A
    /// numeric `frame` already present on the event (for example, the frame a timer
    /// captured when it started) is preserved. Otherwise the profiler's current frame
    /// count is used. A missing or non-string `thread_name` becomes `"unnamed"`.
    /// Does nothing when no session is open.
    ///
    /// # Panics
    ///
    /// Panics if serialisation or writing to the trace file fails.
    fn write_profile(&mut self, result: &TraceEvent) {
        if self.current_session.is_none() {
            return;
        }
        if let Some(file) = &mut self.file {
            if self.profile_count > 0 {
                file.write_all(b",").unwrap();
            }
            self.profile_count += 1;
            let args = result.args.as_ref();
            let frame = args
                .and_then(|a| a.get("frame"))
                .and_then(|v| v.as_u64())
                .unwrap_or(self.frame_count as u64);
            let thread_name = args
                .and_then(|a| a.get("thread_name"))
                .and_then(|v| v.as_str())
                .unwrap_or("unnamed")
                .to_string();
            let mut event = result.clone();
            event.args = Some(serde_json::json!({
                "frame": frame,
                "thread_name": thread_name,
            }));
            let json = serde_json::to_string(&event).unwrap();
            file.write_all(json.as_bytes()).unwrap();
            file.flush().unwrap();
        }
    }

    /// Writes the opening `{"otherData": {},"traceEvents":[` of the trace JSON.
    fn write_header(&mut self) {
        if let Some(file) = &mut self.file {
            file.write_all(b"{\"otherData\": {},\"traceEvents\":[")
                .unwrap();
            file.flush().unwrap();
        }
    }

    /// Writes the closing `]}` that terminates the `traceEvents` array and the root object.
    fn write_footer(&mut self) {
        if let Some(file) = &mut self.file {
            file.write_all(b"]}").unwrap();
            file.flush().unwrap();
        }
    }
}
/// Scope timer that records one trace event when it is stopped or dropped.
///
/// Bind it to a named variable (`let _t = ProfilerTimer::new("work");`). A timer
/// that is created and immediately discarded measures almost nothing.
#[must_use = "a ProfilerTimer records when dropped; bind it to a named variable so it lives for the scope being measured"]
pub struct ProfilerTimer {
    /// Event name shown in the trace.
    name: String,
    /// Monotonic start instant, used to compute the duration.
    start_point: Instant,
    /// Wall-clock start time in microseconds since the Unix epoch.
    start_timestamp: u64,
    /// Profiler frame number captured when the timer was created.
    frame_index: usize,
    /// Name of the thread that created the timer.
    thread_name: String,
    /// Set once the event has been emitted, so it is never emitted twice.
    is_stopped: bool,
}
impl ProfilerTimer {
    /// Starts a timer for the scope called `name`.
    ///
    /// Snapshots the global profiler's current frame number and the calling thread's
    /// name (`"unnamed"` if the thread has none) so both can be attached to the event
    /// when the timer stops.
    pub fn new(name: &str) -> Self {
        let frame_index = {
            let profiler = Profiler::get().lock();
            profiler.frame_count()
        };
        let thread_name = match std::thread::current().name() {
            Some(label) => label.to_string(),
            None => "unnamed".to_string(),
        };
        Self {
            name: name.to_string(),
            start_point: Instant::now(),
            start_timestamp: timestamp_micros(),
            frame_index,
            thread_name,
            is_stopped: false,
        }
    }

    /// Stops the timer and hands the finished event to the global profiler.
    ///
    /// The event is built with `TraceEvent::complete` from four values: the name, the
    /// wall-clock start (microseconds since the Unix epoch), the elapsed time in
    /// microseconds, and a numeric id for the current thread. `frame` and
    /// `thread_name` are stored in its `args`. Calls after the first do nothing.
    pub fn stop(&mut self) {
        // `replace` sets the flag and reports whether it had already been set.
        let already_stopped = std::mem::replace(&mut self.is_stopped, true);
        if already_stopped {
            return;
        }
        let elapsed_us = self.start_point.elapsed().as_micros() as u64;
        let mut trace = TraceEvent::complete(
            self.name.clone(),
            self.start_timestamp,
            elapsed_us,
            thread_id_u64(),
        );
        trace.args = Some(serde_json::json!({
            "frame": self.frame_index,
            "thread_name": self.thread_name,
        }));
        let mut profiler = Profiler::get().lock();
        profiler.write_profile(&trace);
    }
}
impl Drop for ProfilerTimer {
    /// Emits the timing event if it has not been emitted yet, so a timer bound to a
    /// scope records automatically when that scope exits.
    fn drop(&mut self) {
        // `stop` returns immediately when the timer was already stopped.
        self.stop();
    }
}
/// Returns the current wall-clock time in whole microseconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn timestamp_micros() -> u64 {
    let now = SystemTime::now();
    let since_epoch = now.duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_micros() as u64
}
/// Returns a numeric id for the calling thread, for use as a trace event's thread id.
///
/// The id is derived by hashing `std::thread::current().id()`, so it is stable for the
/// thread within one process run. The 64-bit hash is folded down to 32 bits because
/// trace viewers parse JSON numbers as IEEE-754 doubles, which cannot represent
/// integers above 2^53 exactly. A raw 64-bit hash would be silently rounded.
fn thread_id_u64() -> u64 {
    let thread_id = std::thread::current().id();
    let mut hasher = DefaultHasher::new();
    thread_id.hash(&mut hasher);
    let full = hasher.finish();
    // Mix the high half into the low half before truncating so no hash bits are wasted.
    (full ^ (full >> 32)) & u64::from(u32::MAX)
}