use crate::{OverheadStats, ProfileEvent, TorshResult};
use parking_lot::Mutex;
use std::sync::Arc;
use std::time::Instant;
use torsh_core::TorshError;
#[derive(Debug, Clone)]
pub struct Profiler {
pub events: Vec<ProfileEvent>,
pub enabled: bool,
pub stack_traces_enabled: bool,
pub overhead_tracking_enabled: bool,
pub overhead_stats: OverheadStats,
pub start_time: Option<Instant>,
}
impl Profiler {
pub fn new() -> Self {
Self {
events: Vec::new(),
enabled: false,
stack_traces_enabled: false,
overhead_tracking_enabled: false,
overhead_stats: OverheadStats::default(),
start_time: None,
}
}
pub fn start(&mut self) {
self.enabled = true;
self.start_time = Some(Instant::now());
}
pub fn stop(&mut self) {
self.enabled = false;
}
pub fn clear(&mut self) {
self.events.clear();
self.overhead_stats = OverheadStats::default();
}
pub fn add_event(&mut self, mut event: ProfileEvent) {
if !self.enabled {
return;
}
let start_overhead = if self.overhead_tracking_enabled {
Some(Instant::now())
} else {
None
};
if let Some(profiler_start) = self.start_time {
event.start_us = profiler_start.elapsed().as_micros() as u64;
}
self.events.push(event);
if let Some(start) = start_overhead {
let overhead_ns = start.elapsed().as_nanos() as u64;
self.overhead_stats.add_event_time_ns += overhead_ns;
self.overhead_stats.add_event_count += 1;
self.overhead_stats.total_overhead_ns += overhead_ns;
}
}
pub fn get_stats(&self) -> (u64, u64, u64, f64, f64) {
let count = self.events.len() as u64;
let total_duration: u64 = self.events.iter().map(|e| e.duration_us).sum();
let avg_duration = if count > 0 {
total_duration as f64 / count as f64
} else {
0.0
};
let min_duration = self.events.iter().map(|e| e.duration_us).min().unwrap_or(0) as f64;
let max_duration = self.events.iter().map(|e| e.duration_us).max().unwrap_or(0) as f64;
(
count,
total_duration,
total_duration,
min_duration,
max_duration,
)
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
pub fn set_stack_traces_enabled(&mut self, enabled: bool) {
self.stack_traces_enabled = enabled;
}
pub fn are_stack_traces_enabled(&self) -> bool {
self.stack_traces_enabled
}
pub fn set_overhead_tracking_enabled(&mut self, enabled: bool) {
self.overhead_tracking_enabled = enabled;
}
pub fn is_overhead_tracking_enabled(&self) -> bool {
self.overhead_tracking_enabled
}
pub fn get_overhead_stats(&self) -> &OverheadStats {
&self.overhead_stats
}
pub fn reset_overhead_stats(&mut self) {
self.overhead_stats = OverheadStats::default();
}
pub fn events(&self) -> &[ProfileEvent] {
&self.events
}
pub fn event_count(&self) -> usize {
self.events.len()
}
pub fn total_duration_us(&self) -> u64 {
self.events.iter().map(|e| e.duration_us).sum()
}
}
impl Default for Profiler {
fn default() -> Self {
Self::new()
}
}
static GLOBAL_PROFILER: once_cell::sync::Lazy<Arc<Mutex<Profiler>>> =
once_cell::sync::Lazy::new(|| Arc::new(Mutex::new(Profiler::new())));
pub fn global_profiler() -> Arc<Mutex<Profiler>> {
GLOBAL_PROFILER.clone()
}
pub fn start_profiling() {
global_profiler().lock().start();
}
pub fn stop_profiling() {
global_profiler().lock().stop();
}
pub fn clear_global_events() {
global_profiler().lock().clear();
}
pub fn add_global_event(name: &str, category: &str, duration_us: u64, thread_id: usize) {
let event = ProfileEvent {
name: name.to_string(),
category: category.to_string(),
start_us: 0, duration_us,
thread_id,
operation_count: None,
flops: None,
bytes_transferred: None,
stack_trace: None,
};
global_profiler().lock().add_event(event);
}
pub fn get_global_stats() -> TorshResult<(u64, u64, u64, f64, f64)> {
Ok(global_profiler().lock().get_stats())
}
pub fn set_global_stack_traces_enabled(enabled: bool) {
global_profiler().lock().set_stack_traces_enabled(enabled);
}
pub fn are_global_stack_traces_enabled() -> bool {
global_profiler().lock().are_stack_traces_enabled()
}
pub fn set_global_overhead_tracking_enabled(enabled: bool) {
global_profiler()
.lock()
.set_overhead_tracking_enabled(enabled);
}
pub fn is_global_overhead_tracking_enabled() -> bool {
global_profiler().lock().is_overhead_tracking_enabled()
}
pub fn get_global_overhead_stats() -> OverheadStats {
global_profiler().lock().get_overhead_stats().clone()
}
pub fn reset_global_overhead_stats() {
global_profiler().lock().reset_overhead_stats();
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::time::Duration;
#[test]
fn test_profiler_lifecycle() {
let mut profiler = Profiler::new();
assert!(!profiler.is_enabled());
profiler.start();
assert!(profiler.is_enabled());
profiler.stop();
assert!(!profiler.is_enabled());
}
#[test]
fn test_event_collection() {
let mut profiler = Profiler::new();
profiler.start();
let event = ProfileEvent {
name: "test_event".to_string(),
category: "test".to_string(),
start_us: 0,
duration_us: 1000,
thread_id: 1,
operation_count: None,
flops: None,
bytes_transferred: None,
stack_trace: None,
};
profiler.add_event(event);
assert_eq!(profiler.event_count(), 1);
assert_eq!(profiler.total_duration_us(), 1000);
}
#[test]
#[ignore = "Flaky test - passes individually but may fail in full suite"]
fn test_overhead_tracking() {
let mut profiler = Profiler::new();
profiler.set_overhead_tracking_enabled(true);
profiler.start();
let event = ProfileEvent {
name: "test_overhead".to_string(),
category: "test".to_string(),
start_us: 0,
duration_us: 500,
thread_id: 1,
operation_count: None,
flops: None,
bytes_transferred: None,
stack_trace: None,
};
profiler.add_event(event);
let stats = profiler.get_overhead_stats();
assert_eq!(stats.add_event_count, 1);
assert!(stats.add_event_time_ns > 0);
assert!(stats.total_overhead_ns > 0);
}
#[test]
fn test_global_profiler() {
start_profiling();
add_global_event("global_test", "test", 2000, 1);
let stats = get_global_stats().expect("get global stats should succeed");
assert!(stats.0 > 0); assert!(stats.1 > 0);
clear_global_events();
stop_profiling();
}
#[test]
fn test_stack_trace_settings() {
let mut profiler = Profiler::new();
assert!(!profiler.are_stack_traces_enabled());
profiler.set_stack_traces_enabled(true);
assert!(profiler.are_stack_traces_enabled());
profiler.set_stack_traces_enabled(false);
assert!(!profiler.are_stack_traces_enabled());
}
#[test]
fn test_profiler_statistics() {
let mut profiler = Profiler::new();
profiler.start();
for i in 1..=5 {
let event = ProfileEvent {
name: format!("event_{}", i),
category: "test".to_string(),
start_us: 0,
duration_us: i * 100,
thread_id: 1,
operation_count: None,
flops: None,
bytes_transferred: None,
stack_trace: None,
};
profiler.add_event(event);
}
let (count, total, _, min, max) = profiler.get_stats();
assert_eq!(count, 5);
assert_eq!(total, 1500); assert_eq!(min, 100.0);
assert_eq!(max, 500.0);
}
}