Skip to main content

cuda_rust_wasm/profiling/
performance_monitor.rs

1//! High-performance monitoring and profiling for WASM optimization
2//!
3//! This module provides comprehensive performance monitoring with minimal
4//! overhead, optimized for WASM environments.
5
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
/// Category under which a measurement is recorded.
///
/// Values are used as hash-map keys by `PerformanceMonitor`, which is why the
/// `Eq` and `Hash` derives are required.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub enum CounterType {
    /// Time spent executing a kernel.
    KernelExecution,
    /// Time spent allocating memory.
    MemoryAllocation,
    /// Time spent transferring memory.
    MemoryTransfer,
    /// Time spent compiling.
    Compilation,
    /// End-to-end pipeline time.
    TotalPipeline,
    /// Time spent encoding WebGPU commands.
    WebGPUEncoding,
    /// User-defined counter identified by name.
    Custom(String),
}
28
/// A single recorded timing sample.
#[derive(Clone, Debug)]
pub struct Measurement {
    /// How long the measured operation took.
    pub duration: Duration,
    /// Instant associated with the sample.
    ///
    /// NOTE(review): samples produced by a dropped `Timer` carry the instant the
    /// timer was started, while `PerformanceMonitor::record_with_size` stamps the
    /// instant the sample was recorded — confirm which meaning consumers expect.
    pub timestamp: Instant,
    /// Free-form key/value annotations attached to the sample.
    pub metadata: HashMap<String, String>,
    /// Optional size of the operation (e.g. bytes processed or thread count);
    /// summed into `CounterStats::total_bytes`.
    pub size: Option<usize>,
}
41
/// Summary statistics computed from the retained measurements of one counter.
#[derive(Clone, Debug)]
pub struct CounterStats {
    /// Number of measurements summarized.
    pub count: u64,
    /// Sum of all measured durations.
    pub total_time: Duration,
    /// Shortest measured duration.
    pub min_time: Duration,
    /// Longest measured duration.
    pub max_time: Duration,
    /// Mean duration (`total_time / count`).
    pub avg_time: Duration,
    /// 95th-percentile duration.
    pub p95_time: Duration,
    /// 99th-percentile duration.
    pub p99_time: Duration,
    /// Operations per second, computed as `count / total_time`
    /// (0.0 when `total_time` is zero).
    pub throughput: f64,
    /// Sum of the sizes attached to the measurements (bytes, by convention).
    pub total_bytes: u64,
    /// `total_bytes / total_time`, i.e. size units per second
    /// (0.0 when `total_time` is zero).
    pub data_throughput: f64,
}
66
67/// High-performance monitor with minimal overhead
68#[derive(Debug)]
69pub struct PerformanceMonitor {
70    /// Counters organized by type
71    counters: Arc<Mutex<HashMap<CounterType, Vec<Measurement>>>>,
72    /// Global start time
73    start_time: Instant,
74    /// Configuration
75    config: MonitorConfig,
76}
77
/// Tunables for a `PerformanceMonitor`.
#[derive(Clone, Debug)]
pub struct MonitorConfig {
    /// Upper bound on measurements retained per counter; the oldest are
    /// discarded first once the bound is exceeded.
    pub max_measurements: usize,
    /// Request detailed timing (may add overhead).
    ///
    /// NOTE(review): not read anywhere in this module as written — confirm
    /// whether code elsewhere relies on it.
    pub detailed_timing: bool,
    /// Request throughput calculation.
    ///
    /// NOTE(review): the statistics in this module compute throughput
    /// regardless of this flag — confirm the intended behavior.
    pub calculate_throughput: bool,
    /// Fraction of measurements to keep: 1.0 keeps everything, 0.1 keeps
    /// roughly 10%.
    pub sampling_rate: f64,
}
90
91impl Default for MonitorConfig {
92    fn default() -> Self {
93        Self {
94            max_measurements: 1000,
95            detailed_timing: cfg!(debug_assertions),
96            calculate_throughput: true,
97            sampling_rate: 1.0,
98        }
99    }
100}
101
102/// RAII timer for automatic measurement
103pub struct Timer<'a> {
104    monitor: &'a PerformanceMonitor,
105    counter_type: CounterType,
106    start_time: Instant,
107    metadata: HashMap<String, String>,
108    size: Option<usize>,
109}
110
111impl<'a> Timer<'a> {
112    /// Create a new timer
113    fn new(monitor: &'a PerformanceMonitor, counter_type: CounterType) -> Self {
114        Self {
115            monitor,
116            counter_type,
117            start_time: Instant::now(),
118            metadata: HashMap::new(),
119            size: None,
120        }
121    }
122
123    /// Add metadata to the measurement
124    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
125        self.metadata.insert(key.into(), value.into());
126        self
127    }
128
129    /// Set the operation size for throughput calculation
130    pub fn with_size(mut self, size: usize) -> Self {
131        self.size = Some(size);
132        self
133    }
134}
135
136impl<'a> Drop for Timer<'a> {
137    fn drop(&mut self) {
138        let duration = self.start_time.elapsed();
139        let measurement = Measurement {
140            duration,
141            timestamp: self.start_time,
142            metadata: std::mem::take(&mut self.metadata),
143            size: self.size,
144        };
145        
146        self.monitor.record_measurement(self.counter_type.clone(), measurement);
147    }
148}
149
150impl PerformanceMonitor {
151    /// Create a new performance monitor
152    pub fn new() -> Self {
153        Self::with_config(MonitorConfig::default())
154    }
155
156    /// Create a new performance monitor with custom configuration
157    pub fn with_config(config: MonitorConfig) -> Self {
158        Self {
159            counters: Arc::new(Mutex::new(HashMap::new())),
160            start_time: Instant::now(),
161            config,
162        }
163    }
164
165    /// Start timing an operation
166    pub fn time(&self, counter_type: CounterType) -> Timer<'_> {
167        Timer::new(self, counter_type)
168    }
169
170    /// Record a measurement manually
171    pub fn record(&self, counter_type: CounterType, duration: Duration) {
172        self.record_with_size(counter_type, duration, None);
173    }
174
175    /// Record a measurement with size information
176    pub fn record_with_size(&self, counter_type: CounterType, duration: Duration, size: Option<usize>) {
177        // Apply sampling
178        if self.config.sampling_rate < 1.0 {
179            use std::collections::hash_map::DefaultHasher;
180            use std::hash::{Hash, Hasher};
181            
182            let mut hasher = DefaultHasher::new();
183            duration.as_nanos().hash(&mut hasher);
184            let sample = (hasher.finish() % 1000) as f64 / 1000.0;
185            
186            if sample > self.config.sampling_rate {
187                return;
188            }
189        }
190
191        let measurement = Measurement {
192            duration,
193            timestamp: Instant::now(),
194            metadata: HashMap::new(),
195            size,
196        };
197
198        self.record_measurement(counter_type, measurement);
199    }
200
201    /// Record a measurement with metadata
202    fn record_measurement(&self, counter_type: CounterType, measurement: Measurement) {
203        let mut counters = self.counters.lock().unwrap();
204        let measurements = counters.entry(counter_type).or_default();
205        
206        measurements.push(measurement);
207        
208        // Limit memory usage by keeping only recent measurements
209        if measurements.len() > self.config.max_measurements {
210            measurements.drain(0..measurements.len() - self.config.max_measurements);
211        }
212    }
213
214    /// Get statistics for a counter type
215    pub fn stats(&self, counter_type: &CounterType) -> Option<CounterStats> {
216        let counters = self.counters.lock().unwrap();
217        let measurements = counters.get(counter_type)?;
218        
219        if measurements.is_empty() {
220            return None;
221        }
222
223        let mut durations: Vec<Duration> = measurements.iter().map(|m| m.duration).collect();
224        durations.sort();
225
226        let count = measurements.len() as u64;
227        let total_time: Duration = durations.iter().sum();
228        let min_time = durations[0];
229        let max_time = durations[durations.len() - 1];
230        let avg_time = total_time / count as u32;
231        
232        let p95_index = (durations.len() as f64 * 0.95) as usize;
233        let p99_index = (durations.len() as f64 * 0.99) as usize;
234        let p95_time = durations.get(p95_index.saturating_sub(1)).copied().unwrap_or(max_time);
235        let p99_time = durations.get(p99_index.saturating_sub(1)).copied().unwrap_or(max_time);
236
237        let throughput = if total_time.as_secs_f64() > 0.0 {
238            count as f64 / total_time.as_secs_f64()
239        } else {
240            0.0
241        };
242
243        let total_bytes: u64 = measurements.iter()
244            .filter_map(|m| m.size)
245            .map(|s| s as u64)
246            .sum();
247
248        let data_throughput = if total_time.as_secs_f64() > 0.0 {
249            total_bytes as f64 / total_time.as_secs_f64()
250        } else {
251            0.0
252        };
253
254        Some(CounterStats {
255            count,
256            total_time,
257            min_time,
258            max_time,
259            avg_time,
260            p95_time,
261            p99_time,
262            throughput,
263            total_bytes,
264            data_throughput,
265        })
266    }
267
268    /// Get all counter statistics
269    pub fn all_stats(&self) -> HashMap<CounterType, CounterStats> {
270        let counters = self.counters.lock().unwrap();
271        let mut stats = HashMap::new();
272
273        for (counter_type, measurements) in counters.iter() {
274            if measurements.is_empty() {
275                continue;
276            }
277
278            let mut durations: Vec<Duration> = measurements.iter().map(|m| m.duration).collect();
279            durations.sort();
280
281            let count = measurements.len() as u64;
282            let total_time: Duration = durations.iter().sum();
283            let min_time = durations[0];
284            let max_time = durations[durations.len() - 1];
285            let avg_time = total_time / count as u32;
286
287            let p95_idx = ((durations.len() as f64 * 0.95) as usize).min(durations.len() - 1);
288            let p99_idx = ((durations.len() as f64 * 0.99) as usize).min(durations.len() - 1);
289
290            let throughput = if total_time.as_secs_f64() > 0.0 {
291                count as f64 / total_time.as_secs_f64()
292            } else {
293                0.0
294            };
295
296            let total_bytes: u64 = measurements.iter().filter_map(|m| m.size).map(|s| s as u64).sum();
297            let data_throughput = if total_time.as_secs_f64() > 0.0 {
298                total_bytes as f64 / total_time.as_secs_f64()
299            } else {
300                0.0
301            };
302
303            stats.insert(counter_type.clone(), CounterStats {
304                count,
305                total_time,
306                avg_time,
307                min_time,
308                max_time,
309                p95_time: durations[p95_idx],
310                p99_time: durations[p99_idx],
311                throughput,
312                total_bytes,
313                data_throughput,
314            });
315        }
316
317        stats
318    }
319
320    /// Clear all measurements
321    pub fn clear(&self) {
322        self.counters.lock().unwrap().clear();
323    }
324
325    /// Get total runtime since monitor creation
326    pub fn total_runtime(&self) -> Duration {
327        self.start_time.elapsed()
328    }
329
330    /// Generate a performance report
331    pub fn report(&self) -> PerformanceReport {
332        let all_stats = self.all_stats();
333        let total_runtime = self.total_runtime();
334        
335        PerformanceReport {
336            stats: all_stats,
337            total_runtime,
338            monitor_config: self.config.clone(),
339        }
340    }
341
342    /// Get memory usage of the monitor itself
343    pub fn memory_usage(&self) -> usize {
344        let counters = self.counters.lock().unwrap();
345        counters.values()
346            .map(|measurements| measurements.len() * std::mem::size_of::<Measurement>())
347            .sum::<usize>()
348            + counters.len() * std::mem::size_of::<Vec<Measurement>>()
349    }
350}
351
352impl Default for PerformanceMonitor {
353    fn default() -> Self {
354        Self::new()
355    }
356}
357
358/// Performance report with comprehensive metrics
359#[derive(Debug, Clone)]
360pub struct PerformanceReport {
361    /// Statistics for each counter type
362    pub stats: HashMap<CounterType, CounterStats>,
363    /// Total runtime of the monitor
364    pub total_runtime: Duration,
365    /// Monitor configuration used
366    pub monitor_config: MonitorConfig,
367}
368
369impl PerformanceReport {
370    /// Generate a human-readable report
371    pub fn to_string(&self) -> String {
372        let mut report = String::new();
373        
374        report.push_str("=== Performance Report ===\n");
375        report.push_str(&format!("Total Runtime: {:.2}s\n", self.total_runtime.as_secs_f64()));
376        report.push_str(&format!("Monitor Config: {:?}\n\n", self.monitor_config));
377        
378        for (counter_type, stats) in &self.stats {
379            report.push_str(&format!("{counter_type:?}:\n"));
380            report.push_str(&format!("  Count: {}\n", stats.count));
381            report.push_str(&format!("  Total Time: {:.2}ms\n", stats.total_time.as_millis()));
382            report.push_str(&format!("  Avg Time: {:.2}ms\n", stats.avg_time.as_millis()));
383            report.push_str(&format!("  Min Time: {:.2}ms\n", stats.min_time.as_millis()));
384            report.push_str(&format!("  Max Time: {:.2}ms\n", stats.max_time.as_millis()));
385            report.push_str(&format!("  P95 Time: {:.2}ms\n", stats.p95_time.as_millis()));
386            report.push_str(&format!("  P99 Time: {:.2}ms\n", stats.p99_time.as_millis()));
387            report.push_str(&format!("  Throughput: {:.2} ops/s\n", stats.throughput));
388            
389            if stats.total_bytes > 0 {
390                report.push_str(&format!("  Data Processed: {:.2} MB\n", stats.total_bytes as f64 / 1_000_000.0));
391                report.push_str(&format!("  Data Throughput: {:.2} MB/s\n", stats.data_throughput / 1_000_000.0));
392            }
393            
394            report.push('\n');
395        }
396        
397        report
398    }
399
400    /// Export to JSON format
401    pub fn to_json(&self) -> Result<String, String> {
402        // For now, just return a simple string representation
403        Ok(self.to_string())
404    }
405}
406
407/// Global performance monitor instance
408static GLOBAL_MONITOR: std::sync::OnceLock<PerformanceMonitor> = std::sync::OnceLock::new();
409
410/// Get the global performance monitor
411pub fn global_monitor() -> &'static PerformanceMonitor {
412    GLOBAL_MONITOR.get_or_init(PerformanceMonitor::new)
413}
414
415/// Time an operation using the global monitor
416pub fn time_operation(counter_type: CounterType) -> Timer<'static> {
417    global_monitor().time(counter_type)
418}
419
420/// Record a measurement using the global monitor
421pub fn record_measurement(counter_type: CounterType, duration: Duration) {
422    global_monitor().record(counter_type, duration);
423}
424
425/// Get global performance report
426pub fn global_report() -> PerformanceReport {
427    global_monitor().report()
428}
429
/// Time a block of code with the global monitor and evaluate to the block's value.
///
/// `time_block!(kind, { ... })` records the block's elapsed time under `kind`.
/// `time_block!(kind, size, { ... })` additionally attaches `size` as the
/// operation size. The guard is dropped, and the measurement recorded, after the
/// block has been evaluated.
#[macro_export]
macro_rules! time_block {
    ($kind:expr, $body:block) => {{
        let _guard = $crate::profiling::performance_monitor::time_operation($kind);
        $body
    }};

    ($kind:expr, $bytes:expr, $body:block) => {{
        let _guard = $crate::profiling::performance_monitor::time_operation($kind)
            .with_size($bytes);
        $body
    }};
}
443
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_performance_monitor() {
        let monitor = PerformanceMonitor::new();

        // Test basic timing
        {
            let _timer = monitor.time(CounterType::KernelExecution);
            thread::sleep(Duration::from_millis(10));
        }

        let stats = monitor.stats(&CounterType::KernelExecution).unwrap();
        assert_eq!(stats.count, 1);
        assert!(stats.avg_time >= Duration::from_millis(9));
    }

    #[test]
    fn test_timer_with_metadata() {
        let monitor = PerformanceMonitor::new();

        {
            let _timer = monitor
                .time(CounterType::MemoryAllocation)
                .with_metadata("size", "1024")
                .with_size(1024);
            thread::sleep(Duration::from_millis(5));
        }

        let stats = monitor.stats(&CounterType::MemoryAllocation).unwrap();
        assert_eq!(stats.count, 1);
        assert_eq!(stats.total_bytes, 1024);
    }

    #[test]
    fn test_global_monitor() {
        // The global monitor is shared by every test in the process, so use a
        // counter name no other test records; that keeps the count assertion
        // independent of concurrently running tests.
        let counter = CounterType::Custom("tests::test_global_monitor".to_string());
        {
            let _timer = time_operation(counter.clone());
            thread::sleep(Duration::from_millis(1));
        }
        record_measurement(counter.clone(), Duration::from_millis(2));

        let report = global_report();
        let stats = report
            .stats
            .get(&counter)
            .expect("global monitor should have recorded the counter");
        assert_eq!(stats.count, 2);
    }

    #[test]
    fn test_custom_counter() {
        // NOTE(review): `time_block!` itself is not exercised here because its
        // expansion hard-codes `$crate::profiling::performance_monitor`, which
        // depends on where this module is mounted in the crate.
        let monitor = PerformanceMonitor::new();
        {
            let _timer = monitor.time(CounterType::Custom("test".to_string()));
            thread::sleep(Duration::from_millis(1));
        }

        let report = monitor.report();
        assert!(report.stats.contains_key(&CounterType::Custom("test".to_string())));
    }

    #[test]
    fn test_retention_limit_keeps_most_recent() {
        let monitor = PerformanceMonitor::with_config(MonitorConfig {
            max_measurements: 3,
            ..MonitorConfig::default()
        });
        for ms in 1..=5 {
            monitor.record_with_size(
                CounterType::MemoryTransfer,
                Duration::from_millis(ms),
                Some(8),
            );
        }

        let stats = monitor.stats(&CounterType::MemoryTransfer).unwrap();
        assert_eq!(stats.count, 3);
        assert_eq!(stats.min_time, Duration::from_millis(3));
        assert_eq!(stats.max_time, Duration::from_millis(5));
        assert_eq!(stats.total_bytes, 24);
    }

    #[test]
    fn test_clear_removes_all_counters() {
        let monitor = PerformanceMonitor::new();
        monitor.record(CounterType::TotalPipeline, Duration::from_millis(1));
        assert!(monitor.stats(&CounterType::TotalPipeline).is_some());

        monitor.clear();
        assert!(monitor.stats(&CounterType::TotalPipeline).is_none());
        assert!(monitor.all_stats().is_empty());
    }
}