use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Category under which measurements are grouped.
///
/// Values of this type are used as the key of the monitor's per-counter
/// measurement map, hence the `Eq` + `Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CounterType {
/// Presumably time spent executing kernels — usage is not shown in this file.
KernelExecution,
/// Presumably time spent allocating memory — usage is not shown in this file.
MemoryAllocation,
/// Presumably time spent moving data — usage is not shown in this file.
MemoryTransfer,
/// Presumably compilation time — usage is not shown in this file.
Compilation,
/// Presumably end-to-end pipeline time — usage is not shown in this file.
TotalPipeline,
/// Presumably WebGPU command-encoding time — usage is not shown in this file.
WebGPUEncoding,
/// Caller-defined category identified by an arbitrary string.
Custom(String),
}
/// A single recorded sample for one counter.
#[derive(Debug, Clone)]
pub struct Measurement {
/// How long the measured operation took.
pub duration: Duration,
/// Instant associated with the sample. In this file, a `Timer` stores the
/// instant it was started, while `record_with_size` stores the instant at
/// which the measurement was recorded.
pub timestamp: Instant,
/// Free-form key/value annotations attached by the caller.
pub metadata: HashMap<String, String>,
/// Optional size associated with the operation; the statistics code sums
/// these into `CounterStats::total_bytes`.
pub size: Option<usize>,
}
/// Aggregate statistics computed over the stored measurements of one counter.
#[derive(Debug, Clone)]
pub struct CounterStats {
/// Number of measurements the statistics were computed from.
pub count: u64,
/// Sum of all measured durations.
pub total_time: Duration,
/// Shortest measured duration.
pub min_time: Duration,
/// Longest measured duration.
pub max_time: Duration,
/// `total_time` divided by `count`.
pub avg_time: Duration,
/// 95th-percentile duration, taken from the sorted durations.
pub p95_time: Duration,
/// 99th-percentile duration, taken from the sorted durations.
pub p99_time: Duration,
/// `count` divided by `total_time` in seconds (0.0 when `total_time` is zero).
pub throughput: f64,
/// Sum of the `size` values of all measurements that carry one.
pub total_bytes: u64,
/// `total_bytes` divided by `total_time` in seconds (0.0 when `total_time` is zero).
pub data_throughput: f64,
}
/// Collects timing measurements grouped by [`CounterType`] and derives
/// statistics from them.
#[derive(Debug)]
pub struct PerformanceMonitor {
/// Measurement history per counter, guarded by a mutex so that recording
/// works through a shared `&self` reference.
counters: Arc<Mutex<HashMap<CounterType, Vec<Measurement>>>>,
/// Instant at which this monitor was constructed; basis for `total_runtime`.
start_time: Instant,
/// Settings supplied at construction time.
config: MonitorConfig,
}
/// Tunable settings for a `PerformanceMonitor`.
#[derive(Debug, Clone)]
pub struct MonitorConfig {
    /// Upper bound on the number of measurements retained per counter; once
    /// exceeded, the oldest entries are discarded first.
    pub max_measurements: usize,
    /// Flag for detailed timing. It is not consulted anywhere in this file.
    pub detailed_timing: bool,
    /// Flag for throughput calculation. It is not consulted anywhere in this file.
    pub calculate_throughput: bool,
    /// Sampling fraction consulted by the monitor when recording; at `1.0`
    /// or above nothing is dropped.
    pub sampling_rate: f64,
}

impl Default for MonitorConfig {
    /// Returns the stock configuration: 1000 measurements per counter,
    /// detailed timing only in builds with debug assertions, throughput
    /// calculation enabled, and no sampling (rate `1.0`).
    fn default() -> Self {
        let detailed_timing = cfg!(debug_assertions);
        MonitorConfig {
            max_measurements: 1000,
            detailed_timing,
            calculate_throughput: true,
            sampling_rate: 1.0,
        }
    }
}
/// Scope guard that measures the time between its creation and its drop and
/// reports the result to the [`PerformanceMonitor`] it borrows.
///
/// The guard must be bound to a variable (for example `let _timer = ...;`);
/// an unbound guard is dropped immediately and would record a near-zero
/// duration, which is why the type is marked `#[must_use]`.
#[must_use = "a Timer records its measurement when dropped; bind it to a variable for the scope being timed"]
pub struct Timer<'a> {
    /// Monitor that receives the measurement on drop.
    monitor: &'a PerformanceMonitor,
    /// Counter the measurement is filed under.
    counter_type: CounterType,
    /// Instant at which the timer was created.
    start_time: Instant,
    /// Annotations accumulated through `with_metadata`.
    metadata: HashMap<String, String>,
    /// Optional size set through `with_size`.
    size: Option<usize>,
}
impl<'a> Timer<'a> {
fn new(monitor: &'a PerformanceMonitor, counter_type: CounterType) -> Self {
Self {
monitor,
counter_type,
start_time: Instant::now(),
metadata: HashMap::new(),
size: None,
}
}
pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
pub fn with_size(mut self, size: usize) -> Self {
self.size = Some(size);
self
}
}
impl<'a> Drop for Timer<'a> {
    /// Stops the clock and hands the finished measurement to the monitor.
    ///
    /// The measurement's timestamp is the instant the timer was started.
    fn drop(&mut self) {
        let elapsed = self.start_time.elapsed();
        // Move the annotations out instead of cloning them; the timer is
        // going away, so leaving an empty map behind is fine.
        let metadata = std::mem::take(&mut self.metadata);
        let counter = self.counter_type.clone();
        self.monitor.record_measurement(
            counter,
            Measurement {
                duration: elapsed,
                timestamp: self.start_time,
                metadata,
                size: self.size,
            },
        );
    }
}
impl PerformanceMonitor {
    /// Creates a monitor with [`MonitorConfig::default`].
    pub fn new() -> Self {
        Self::with_config(MonitorConfig::default())
    }

    /// Creates an empty monitor with the given configuration. The runtime
    /// clock reported by [`total_runtime`](Self::total_runtime) starts now.
    pub fn with_config(config: MonitorConfig) -> Self {
        Self {
            counters: Arc::new(Mutex::new(HashMap::new())),
            start_time: Instant::now(),
            config,
        }
    }

    /// Starts a [`Timer`] for `counter_type`; the elapsed time is recorded
    /// into this monitor when the returned guard is dropped.
    pub fn time(&self, counter_type: CounterType) -> Timer<'_> {
        Timer::new(self, counter_type)
    }

    /// Records an externally measured `duration` without a size.
    pub fn record(&self, counter_type: CounterType, duration: Duration) {
        self.record_with_size(counter_type, duration, None);
    }

    /// Records an externally measured `duration`, optionally tagged with the
    /// number of bytes (or other size unit) the operation handled.
    ///
    /// The measurement is subject to `sampling_rate`; when it is sampled out
    /// nothing is stored.
    pub fn record_with_size(
        &self,
        counter_type: CounterType,
        duration: Duration,
        size: Option<usize>,
    ) {
        if !self.should_sample(duration) {
            return;
        }
        let measurement = Measurement {
            duration,
            timestamp: Instant::now(),
            metadata: HashMap::new(),
            size,
        };
        self.store_measurement(counter_type, measurement);
    }

    /// Entry point used by [`Timer`] on drop: applies the same sampling
    /// decision as [`record_with_size`](Self::record_with_size) and stores the
    /// measurement if it is kept.
    fn record_measurement(&self, counter_type: CounterType, measurement: Measurement) {
        if self.should_sample(measurement.duration) {
            self.store_measurement(counter_type, measurement);
        }
    }

    /// Decides whether a measurement should be kept under the configured
    /// `sampling_rate`.
    ///
    /// Rates of `1.0` or more (and NaN) keep everything; rates of `0.0` or
    /// less keep nothing. In between, a pseudo-random value in `[0, 1)` is
    /// derived from a process-wide sequence number mixed with the duration,
    /// so that repeated identical durations are not all kept or all dropped.
    fn should_sample(&self, duration: Duration) -> bool {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        use std::sync::atomic::{AtomicUsize, Ordering};

        static SAMPLE_SEQUENCE: AtomicUsize = AtomicUsize::new(0);

        let rate = self.config.sampling_rate;
        if rate >= 1.0 || rate.is_nan() {
            return true;
        }
        let mut hasher = DefaultHasher::new();
        SAMPLE_SEQUENCE.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
        duration.as_nanos().hash(&mut hasher);
        let sample = (hasher.finish() % 1000) as f64 / 1000.0;
        // Strict comparison: exactly `rate * 1000` of the 1000 buckets pass,
        // and a rate of 0.0 admits nothing.
        sample < rate
    }

    /// Appends `measurement` to its counter and trims the history to
    /// `max_measurements`, discarding the oldest entries first.
    fn store_measurement(&self, counter_type: CounterType, measurement: Measurement) {
        let mut counters = self.lock_counters();
        let measurements = counters.entry(counter_type).or_default();
        measurements.push(measurement);
        let excess = measurements
            .len()
            .saturating_sub(self.config.max_measurements);
        if excess > 0 {
            measurements.drain(..excess);
        }
    }

    /// Locks the counter map, recovering the data if another thread panicked
    /// while holding the lock.
    ///
    /// Recovering (rather than unwrapping) matters because this is reached
    /// from `Timer::drop`: a panic there during unwinding would abort the
    /// process.
    fn lock_counters(
        &self,
    ) -> std::sync::MutexGuard<'_, HashMap<CounterType, Vec<Measurement>>> {
        self.counters
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns the value at the given percentile of an ascending, non-empty
    /// slice using the nearest-rank method (`rank = ceil(percent / 100 * n)`).
    fn percentile(sorted: &[Duration], percent: usize) -> Duration {
        let rank = (sorted.len() * percent + 99) / 100;
        sorted[rank.clamp(1, sorted.len()) - 1]
    }

    /// Computes aggregate statistics for one counter's measurements, or
    /// `None` when there are none.
    fn summarize(measurements: &[Measurement]) -> Option<CounterStats> {
        if measurements.is_empty() {
            return None;
        }
        let mut durations: Vec<Duration> = measurements.iter().map(|m| m.duration).collect();
        durations.sort_unstable();

        let len = durations.len();
        let count = len as u64;
        let total_time: Duration = durations.iter().sum();

        // Average via nanoseconds so the divisor is never truncated to u32.
        let avg_nanos = total_time.as_nanos() / len as u128;
        let avg_time = Duration::new(
            (avg_nanos / 1_000_000_000) as u64,
            (avg_nanos % 1_000_000_000) as u32,
        );

        let total_bytes: u64 = measurements
            .iter()
            .filter_map(|m| m.size)
            .map(|s| s as u64)
            .sum();

        let total_secs = total_time.as_secs_f64();
        let (throughput, data_throughput) = if total_secs > 0.0 {
            (count as f64 / total_secs, total_bytes as f64 / total_secs)
        } else {
            (0.0, 0.0)
        };

        Some(CounterStats {
            count,
            total_time,
            min_time: durations[0],
            max_time: durations[len - 1],
            avg_time,
            p95_time: Self::percentile(&durations, 95),
            p99_time: Self::percentile(&durations, 99),
            throughput,
            total_bytes,
            data_throughput,
        })
    }

    /// Returns statistics for `counter_type`, or `None` when nothing has been
    /// recorded for it.
    pub fn stats(&self, counter_type: &CounterType) -> Option<CounterStats> {
        let counters = self.lock_counters();
        counters
            .get(counter_type)
            .map(Vec::as_slice)
            .and_then(Self::summarize)
    }

    /// Returns statistics for every counter that has at least one stored
    /// measurement. Values are computed exactly as in [`stats`](Self::stats).
    pub fn all_stats(&self) -> HashMap<CounterType, CounterStats> {
        let counters = self.lock_counters();
        counters
            .iter()
            .filter_map(|(counter_type, measurements)| {
                Self::summarize(measurements).map(|stats| (counter_type.clone(), stats))
            })
            .collect()
    }

    /// Discards all stored measurements for all counters.
    pub fn clear(&self) {
        self.lock_counters().clear();
    }

    /// Time elapsed since this monitor was constructed.
    pub fn total_runtime(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Captures the current statistics, total runtime and configuration in a
    /// [`PerformanceReport`].
    pub fn report(&self) -> PerformanceReport {
        let all_stats = self.all_stats();
        let total_runtime = self.total_runtime();
        PerformanceReport {
            stats: all_stats,
            total_runtime,
            monitor_config: self.config.clone(),
        }
    }

    /// Rough estimate, in bytes, of the memory held by stored measurements.
    ///
    /// Counts one `Measurement` struct per stored entry plus one `Vec` header
    /// per counter. Heap data owned by metadata maps, spare vector capacity
    /// and the map's own overhead are not included.
    pub fn memory_usage(&self) -> usize {
        let counters = self.lock_counters();
        let measurement_bytes: usize = counters
            .values()
            .map(|measurements| measurements.len() * std::mem::size_of::<Measurement>())
            .sum();
        let vec_header_bytes = counters.len() * std::mem::size_of::<Vec<Measurement>>();
        measurement_bytes + vec_header_bytes
    }
}
impl Default for PerformanceMonitor {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of a monitor's statistics, as produced by `PerformanceMonitor::report`.
#[derive(Debug, Clone)]
pub struct PerformanceReport {
/// Statistics per counter at the time the report was taken.
pub stats: HashMap<CounterType, CounterStats>,
/// Time between the monitor's construction and the creation of this report.
pub total_runtime: Duration,
/// Copy of the configuration the monitor was running with.
pub monitor_config: MonitorConfig,
}
impl PerformanceReport {
pub fn to_string(&self) -> String {
let mut report = String::new();
report.push_str("=== Performance Report ===\n");
report.push_str(&format!("Total Runtime: {:.2}s\n", self.total_runtime.as_secs_f64()));
report.push_str(&format!("Monitor Config: {:?}\n\n", self.monitor_config));
for (counter_type, stats) in &self.stats {
report.push_str(&format!("{counter_type:?}:\n"));
report.push_str(&format!(" Count: {}\n", stats.count));
report.push_str(&format!(" Total Time: {:.2}ms\n", stats.total_time.as_millis()));
report.push_str(&format!(" Avg Time: {:.2}ms\n", stats.avg_time.as_millis()));
report.push_str(&format!(" Min Time: {:.2}ms\n", stats.min_time.as_millis()));
report.push_str(&format!(" Max Time: {:.2}ms\n", stats.max_time.as_millis()));
report.push_str(&format!(" P95 Time: {:.2}ms\n", stats.p95_time.as_millis()));
report.push_str(&format!(" P99 Time: {:.2}ms\n", stats.p99_time.as_millis()));
report.push_str(&format!(" Throughput: {:.2} ops/s\n", stats.throughput));
if stats.total_bytes > 0 {
report.push_str(&format!(" Data Processed: {:.2} MB\n", stats.total_bytes as f64 / 1_000_000.0));
report.push_str(&format!(" Data Throughput: {:.2} MB/s\n", stats.data_throughput / 1_000_000.0));
}
report.push('\n');
}
report
}
pub fn to_json(&self) -> Result<String, String> {
Ok(self.to_string())
}
}
/// Lazily created process-wide monitor backing the free functions below.
static GLOBAL_MONITOR: std::sync::OnceLock<PerformanceMonitor> = std::sync::OnceLock::new();

/// Returns the process-wide monitor, constructing it with the default
/// configuration on first access.
pub fn global_monitor() -> &'static PerformanceMonitor {
    GLOBAL_MONITOR.get_or_init(PerformanceMonitor::default)
}
/// Starts a timer for `counter_type` on the global monitor; the measurement
/// is recorded when the returned guard is dropped.
pub fn time_operation(counter_type: CounterType) -> Timer<'static> {
    let monitor = global_monitor();
    monitor.time(counter_type)
}
/// Records an externally measured `duration` for `counter_type` on the
/// global monitor.
pub fn record_measurement(counter_type: CounterType, duration: Duration) {
    let monitor = global_monitor();
    monitor.record(counter_type, duration);
}
/// Builds a [`PerformanceReport`] from the global monitor's current state.
pub fn global_report() -> PerformanceReport {
    let monitor = global_monitor();
    monitor.report()
}
/// Runs a block while a global-monitor timer is alive, and evaluates to the
/// block's value.
///
/// Forms:
/// - `time_block!(counter, { ... })`
/// - `time_block!(counter, size, { ... })` — additionally tags the measurement
///   with `size` via `with_size`.
///
/// NOTE(review): the expansion refers to
/// `$crate::profiling::performance_monitor::time_operation`, so it assumes this
/// file is mounted at that module path — confirm against the crate layout.
#[macro_export]
macro_rules! time_block {
// Counter only: the timer lives until the end of the generated block.
($counter_type:expr, $block:block) => {{
let _timer = $crate::profiling::performance_monitor::time_operation($counter_type);
$block
}};
// Counter plus size attached to the measurement.
($counter_type:expr, $size:expr, $block:block) => {{
let _timer = $crate::profiling::performance_monitor::time_operation($counter_type).with_size($size);
$block
}};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// A scoped timer records exactly one measurement of at least the slept time.
    #[test]
    fn test_performance_monitor() {
        let monitor = PerformanceMonitor::new();
        {
            let _timer = monitor.time(CounterType::KernelExecution);
            thread::sleep(Duration::from_millis(10));
        }
        let stats = monitor.stats(&CounterType::KernelExecution).unwrap();
        assert_eq!(stats.count, 1);
        assert!(stats.avg_time >= Duration::from_millis(9));
    }

    /// Size set on a timer shows up in the byte total.
    #[test]
    fn test_timer_with_metadata() {
        let monitor = PerformanceMonitor::new();
        {
            let _timer = monitor
                .time(CounterType::MemoryAllocation)
                .with_metadata("size", "1024")
                .with_size(1024);
            thread::sleep(Duration::from_millis(5));
        }
        let stats = monitor.stats(&CounterType::MemoryAllocation).unwrap();
        assert_eq!(stats.count, 1);
        assert_eq!(stats.total_bytes, 1024);
    }

    /// The free functions operate on the shared global monitor.
    #[test]
    fn test_global_monitor() {
        // A test-specific counter keeps this independent of other tests that
        // may touch the global monitor concurrently.
        let counter = CounterType::Custom("test_global_monitor".to_string());
        record_measurement(counter.clone(), Duration::from_millis(1));
        {
            let _timer = time_operation(counter.clone());
        }
        let report = global_report();
        assert!(report.stats.contains_key(&counter));
        let stats = global_monitor()
            .stats(&counter)
            .expect("measurements were just recorded for this counter");
        assert!(stats.count >= 2);
    }

    /// Custom counters are usable as keys.
    ///
    /// NOTE(review): this does not invoke `time_block!` itself because the
    /// macro expands to a hard-coded module path that cannot be confirmed
    /// from this file.
    #[test]
    fn test_time_block_macro() {
        let monitor = PerformanceMonitor::new();
        {
            let _timer = monitor.time(CounterType::Custom("test".to_string()));
            thread::sleep(Duration::from_millis(1));
        }
        let report = monitor.report();
        assert!(report
            .stats
            .contains_key(&CounterType::Custom("test".to_string())));
    }

    /// Once the per-counter cap is exceeded, the oldest measurements are dropped.
    #[test]
    fn test_max_measurements_keeps_newest() {
        let monitor = PerformanceMonitor::with_config(MonitorConfig {
            max_measurements: 3,
            ..MonitorConfig::default()
        });
        for ms in 1..=5u64 {
            monitor.record(CounterType::MemoryTransfer, Duration::from_millis(ms));
        }
        let stats = monitor.stats(&CounterType::MemoryTransfer).unwrap();
        assert_eq!(stats.count, 3);
        assert_eq!(stats.min_time, Duration::from_millis(3));
        assert_eq!(stats.max_time, Duration::from_millis(5));
        assert_eq!(stats.total_time, Duration::from_millis(12));
        assert_eq!(stats.avg_time, Duration::from_millis(4));
    }

    /// A fresh monitor has no statistics for any counter.
    #[test]
    fn test_stats_absent_for_unused_counter() {
        let monitor = PerformanceMonitor::new();
        assert!(monitor.stats(&CounterType::WebGPUEncoding).is_none());
        assert!(monitor.all_stats().is_empty());
        assert_eq!(monitor.memory_usage(), 0);
    }

    /// `clear` removes everything that was recorded.
    #[test]
    fn test_clear_discards_measurements() {
        let monitor = PerformanceMonitor::new();
        monitor.record_with_size(
            CounterType::TotalPipeline,
            Duration::from_millis(2),
            Some(64),
        );
        assert!(monitor.stats(&CounterType::TotalPipeline).is_some());
        monitor.clear();
        assert!(monitor.stats(&CounterType::TotalPipeline).is_none());
    }
}