use crate::{core::profiler::global_profiler, ProfileEvent};
use backtrace::Backtrace;
use std::time::Instant;
/// Timer guard that records a [`ProfileEvent`] in the global profiler when it is dropped.
///
/// Bind it to a named variable (e.g. `let _guard = ScopeGuard::new("load");`) so it
/// lives until the end of the enclosing block; binding it to `_` drops it immediately.
#[derive(Debug)]
#[must_use = "a ScopeGuard measures until it is dropped; bind it to a named variable such as `_guard`"]
pub struct ScopeGuard {
    /// Event name reported to the profiler.
    name: String,
    /// Event category reported to the profiler (`"general"` when built via `ScopeGuard::new`).
    category: String,
    /// Creation time; the recorded duration is measured from this instant.
    start: Instant,
}
impl ScopeGuard {
    /// Begins timing a scope called `name` in the `"general"` category.
    pub fn new(name: &str) -> Self {
        Self::with_category(name, "general")
    }

    /// Begins timing a scope called `name` in the given `category`.
    ///
    /// Both labels are copied before the start timestamp is taken.
    pub fn with_category(name: &str, category: &str) -> Self {
        let name = name.to_owned();
        let category = category.to_owned();
        let start = Instant::now();
        Self {
            name,
            category,
            start,
        }
    }

    /// Time elapsed since this guard was created.
    pub fn elapsed(&self) -> std::time::Duration {
        let started_at = self.start;
        started_at.elapsed()
    }

    /// Name that will be attached to the recorded event.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Category that will be attached to the recorded event.
    pub fn category(&self) -> &str {
        self.category.as_str()
    }
}
impl Drop for ScopeGuard {
    /// Records the finished scope as a [`ProfileEvent`] in the global profiler.
    ///
    /// The profiler lock is taken twice. The first time is brief and only snapshots
    /// the stack-trace and overhead-tracking flags. The second time updates the
    /// overhead statistics and appends the event. Stack-trace capture runs in between
    /// with the lock released, so a slow capture does not block other threads.
    fn drop(&mut self) {
        // Measure first so the bookkeeping below is not billed to the scope.
        let duration = self.start.elapsed();
        let thread_id = get_thread_id();
        let profiler_arc = global_profiler();

        // Snapshot configuration and release the lock immediately.
        let (stack_traces_enabled, overhead_tracking_enabled) = {
            let profiler = profiler_arc.lock();
            (
                profiler.are_stack_traces_enabled(),
                profiler.is_overhead_tracking_enabled(),
            )
        };

        // Potentially expensive; deliberately done without holding the profiler lock.
        let (stack_trace, stack_trace_overhead_ns) =
            match (stack_traces_enabled, overhead_tracking_enabled) {
                (false, _) => (None, 0),
                (true, true) => capture_stack_trace_with_overhead(),
                (true, false) => (capture_stack_trace(), 0),
            };

        let event = ProfileEvent {
            // The guard is being destroyed, so move the labels out instead of cloning them.
            name: std::mem::take(&mut self.name),
            category: std::mem::take(&mut self.category),
            // NOTE(review): no start timestamp is tracked here, so this is always 0 —
            // confirm consumers of `start_us` expect that.
            start_us: 0,
            duration_us: duration.as_micros() as u64,
            thread_id,
            operation_count: None,
            flops: None,
            bytes_transferred: None,
            stack_trace,
        };

        let mut profiler = profiler_arc.lock();
        // Re-read the flag: it may have changed while the lock was released.
        if profiler.is_overhead_tracking_enabled() && stack_trace_overhead_ns > 0 {
            profiler.overhead_stats.stack_trace_time_ns += stack_trace_overhead_ns;
            profiler.overhead_stats.stack_trace_count += 1;
            profiler.overhead_stats.total_overhead_ns += stack_trace_overhead_ns;
        }
        profiler.add_event(event);
    }
}
/// Returns a numeric id for the current thread.
///
/// `ThreadId` exposes no stable integer accessor, so this keeps only the ASCII
/// digits of its `Debug` text and parses them. It falls back to 0 when no digits
/// are found or the number does not fit in `usize`.
fn get_thread_id() -> usize {
    let thread_id = std::thread::current().id();
    let debug_text = format!("{thread_id:?}");
    let digits: String = debug_text.chars().filter(char::is_ascii_digit).collect();
    digits.parse().unwrap_or(0)
}
/// Captures the current call stack as `Debug`-formatted text.
///
/// Returns a trace only when `debug_assertions` is enabled; otherwise returns `None`.
fn capture_stack_trace() -> Option<String> {
    if !cfg!(debug_assertions) {
        return None;
    }
    let trace = Backtrace::new();
    Some(format!("{:?}", trace))
}
/// Captures a stack trace (see `capture_stack_trace`) and reports how long the
/// capture took, in nanoseconds, as the second tuple element.
fn capture_stack_trace_with_overhead() -> (Option<String>, u64) {
    let timer = Instant::now();
    let trace = capture_stack_trace();
    let spent_ns = timer.elapsed().as_nanos() as u64;
    (trace, spent_ns)
}
/// Runs `func` inside a `ScopeGuard` named `name` (default category) and returns its result.
///
/// The guard is dropped after `func` returns, so the recorded duration covers the whole call.
pub fn profile_function<F, R>(name: &str, func: F) -> R
where
    F: FnOnce() -> R,
{
    let _timer = ScopeGuard::new(name);
    func()
}
/// Runs `func` inside a `ScopeGuard` named `name` in `category` and returns its result.
///
/// The guard is dropped after `func` returns, so the recorded duration covers the whole call.
pub fn profile_function_with_category<F, R>(name: &str, category: &str, func: F) -> R
where
    F: FnOnce() -> R,
{
    let _timer = ScopeGuard::with_category(name, category);
    func()
}
pub struct MetricsScope {
guard: ScopeGuard,
operation_count: Option<u64>,
flops: Option<u64>,
bytes_transferred: Option<u64>,
}
impl MetricsScope {
pub fn new(name: &str) -> Self {
Self {
guard: ScopeGuard::new(name),
operation_count: None,
flops: None,
bytes_transferred: None,
}
}
pub fn with_category(name: &str, category: &str) -> Self {
Self {
guard: ScopeGuard::with_category(name, category),
operation_count: None,
flops: None,
bytes_transferred: None,
}
}
pub fn set_operation_count(&mut self, count: u64) {
self.operation_count = Some(count);
}
pub fn set_flops(&mut self, flops: u64) {
self.flops = Some(flops);
}
pub fn set_bytes_transferred(&mut self, bytes: u64) {
self.bytes_transferred = Some(bytes);
}
pub fn add_operations(&mut self, count: u64) {
self.operation_count = Some(self.operation_count.unwrap_or(0) + count);
}
pub fn add_flops(&mut self, flops: u64) {
self.flops = Some(self.flops.unwrap_or(0) + flops);
}
pub fn add_bytes(&mut self, bytes: u64) {
self.bytes_transferred = Some(self.bytes_transferred.unwrap_or(0) + bytes);
}
pub fn metrics(&self) -> (Option<u64>, Option<u64>, Option<u64>) {
(self.operation_count, self.flops, self.bytes_transferred)
}
}
impl Drop for MetricsScope {
fn drop(&mut self) {
let duration = self.guard.start.elapsed();
let thread_id = get_thread_id();
let profiler_arc = global_profiler();
let (stack_trace, stack_trace_overhead_ns) = {
let profiler = profiler_arc.lock();
if profiler.are_stack_traces_enabled() {
if profiler.is_overhead_tracking_enabled() {
capture_stack_trace_with_overhead()
} else {
(capture_stack_trace(), 0)
}
} else {
(None, 0)
}
};
let event = ProfileEvent {
name: self.guard.name.clone(),
category: self.guard.category.clone(),
start_us: 0, duration_us: duration.as_micros() as u64,
thread_id,
operation_count: self.operation_count,
flops: self.flops,
bytes_transferred: self.bytes_transferred,
stack_trace,
};
{
let mut profiler = profiler_arc.lock();
if profiler.is_overhead_tracking_enabled() && stack_trace_overhead_ns > 0 {
profiler.overhead_stats.stack_trace_time_ns += stack_trace_overhead_ns;
profiler.overhead_stats.stack_trace_count += 1;
profiler.overhead_stats.total_overhead_ns += stack_trace_overhead_ns;
}
profiler.add_event(event);
}
}
}
/// Times the remainder of the enclosing block with a hidden `ScopeGuard` binding.
///
/// `profile_scope!(name, category)` sets the category explicitly;
/// `profile_scope!(name)` uses the guard's default category.
#[macro_export]
macro_rules! profile_scope {
    ($name:expr, $category:expr) => {
        let _guard = $crate::core::scope::ScopeGuard::with_category($name, $category);
    };
    ($name:expr) => {
        let _guard = $crate::core::scope::ScopeGuard::new($name);
    };
}
/// Evaluates a closure under profiling and yields its result.
///
/// `profile_function!(name, category, func)` forwards to
/// `profile_function_with_category`; `profile_function!(name, func)` forwards to
/// `profile_function`.
#[macro_export]
macro_rules! profile_function {
    ($name:expr, $category:expr, $func:expr) => {
        $crate::core::scope::profile_function_with_category($name, $category, $func)
    };
    ($name:expr, $func:expr) => {
        $crate::core::scope::profile_function($name, $func)
    };
}
/// Creates a `MetricsScope`; bind the result to a named variable to keep it alive.
///
/// `profile_metrics!(name, category)` sets the category explicitly;
/// `profile_metrics!(name)` uses the default category.
#[macro_export]
macro_rules! profile_metrics {
    ($name:expr, $category:expr) => {
        $crate::core::scope::MetricsScope::with_category($name, $category)
    };
    ($name:expr) => {
        $crate::core::scope::MetricsScope::new($name)
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::profiler::{
        clear_global_events, get_global_stats, start_profiling, stop_profiling,
    };
    use std::thread;
    use std::time::Duration;

    // NOTE(review): every test that calls `start_profiling`/`stop_profiling` touches the
    // same global profiler. Under the default parallel test runner, one test's
    // `clear_global_events` or `stop_profiling` may interleave with another test's
    // scope. Confirm these tests are serialized, or isolate the state.

    /// Turns the global profiler on and discards previously recorded events.
    fn begin_session() {
        start_profiling();
        clear_global_events();
    }

    /// Sleeps for `ms` milliseconds so the enclosing scope has a measurable duration.
    fn pause(ms: u64) {
        thread::sleep(Duration::from_millis(ms));
    }

    #[test]
    fn test_scope_guard_basic() {
        begin_session();
        {
            let _scope = ScopeGuard::new("test_scope");
            pause(10);
        }
        let stats = get_global_stats().expect("get global stats should succeed");
        assert!(stats.0 > 0);
        stop_profiling();
    }

    #[test]
    fn test_scope_guard_with_category() {
        begin_session();
        {
            let _scope = ScopeGuard::with_category("test_scope", "testing");
            pause(5);
        }
        let stats = get_global_stats().expect("get global stats should succeed");
        assert!(stats.0 > 0);
        stop_profiling();
    }

    #[test]
    fn test_profile_function() {
        begin_session();
        let answer = profile_function("test_function", || {
            pause(5);
            42
        });
        assert_eq!(answer, 42);
        let stats = get_global_stats().expect("get global stats should succeed");
        assert!(stats.0 > 0);
        stop_profiling();
    }

    #[test]
    fn test_metrics_scope() {
        begin_session();
        {
            let mut metrics = MetricsScope::new("test_metrics");
            metrics.set_operation_count(100);
            metrics.set_flops(500);
            metrics.set_bytes_transferred(1024);
            pause(5);
            let (ops, flops, bytes) = metrics.metrics();
            assert_eq!(ops, Some(100));
            assert_eq!(flops, Some(500));
            assert_eq!(bytes, Some(1024));
        }
        let stats = get_global_stats().expect("get global stats should succeed");
        assert!(stats.0 > 0);
        stop_profiling();
    }

    #[test]
    fn test_metrics_scope_accumulation() {
        let mut metrics = MetricsScope::new("test_accumulation");
        for ops in [50, 75] {
            metrics.add_operations(ops);
        }
        for flops in [100, 200] {
            metrics.add_flops(flops);
        }
        for bytes in [512, 256] {
            metrics.add_bytes(bytes);
        }
        assert_eq!(metrics.metrics(), (Some(125), Some(300), Some(768)));
    }

    #[test]
    fn test_profile_scope_macro() {
        begin_session();
        {
            profile_scope!("macro_test");
            pause(5);
        }
        {
            profile_scope!("macro_test_with_category", "testing");
            pause(5);
        }
        let stats = get_global_stats().expect("get global stats should succeed");
        assert!(stats.0 >= 2);
        stop_profiling();
    }

    #[test]
    fn test_thread_id_extraction() {
        let main_id = get_thread_id();
        let worker_id = thread::spawn(get_thread_id)
            .join()
            .expect("join should succeed");
        assert_ne!(main_id, worker_id);
        assert!(main_id > 0);
        assert!(worker_id > 0);
    }
}