//! Performance monitoring: named counters, drop-based timers, and aggregated reports.
//!
//! Source path: `cuda_rust_wasm/profiling/performance_monitor.rs`.

use std::collections::HashMap;
use std::fmt::{self, Write as _};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

/// Identifies which kind of operation a measurement belongs to.
///
/// Used as the key of the monitor's counter map, hence the `Eq + Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CounterType {
    /// Kernel execution.
    KernelExecution,
    /// Memory allocation.
    MemoryAllocation,
    /// Memory transfer.
    MemoryTransfer,
    /// Compilation.
    Compilation,
    /// The pipeline as a whole.
    TotalPipeline,
    /// WebGPU encoding.
    WebGPUEncoding,
    /// Caller-defined counter, distinguished by its name.
    Custom(String),
}

/// A single recorded sample for one counter.
#[derive(Debug, Clone)]
pub struct Measurement {
    /// Elapsed time of the measured operation.
    pub duration: Duration,
    /// Reference instant. A `Timer` stores its own start instant, while
    /// `PerformanceMonitor::record`/`record_with_size` store the instant of recording.
    pub timestamp: Instant,
    /// Free-form key/value annotations (set through `Timer::with_metadata`).
    pub metadata: HashMap<String, String>,
    /// Optional number of bytes processed; summed into `CounterStats::total_bytes`.
    pub size: Option<usize>,
}

/// Aggregate statistics over the measurements currently retained for one counter.
#[derive(Debug, Clone)]
pub struct CounterStats {
    /// Number of retained measurements (older ones may have been evicted).
    pub count: u64,
    /// Sum of all retained durations.
    pub total_time: Duration,
    /// Shortest retained duration.
    pub min_time: Duration,
    /// Longest retained duration.
    pub max_time: Duration,
    /// `total_time` divided by `count`.
    pub avg_time: Duration,
    /// 95th-percentile duration.
    pub p95_time: Duration,
    /// 99th-percentile duration.
    pub p99_time: Duration,
    /// `count` per second of `total_time` (operations per busy second); `0.0` when
    /// `total_time` is zero.
    pub throughput: f64,
    /// Sum of the `size` values attached to the measurements.
    pub total_bytes: u64,
    /// `total_bytes` per second of `total_time`; `0.0` when `total_time` is zero.
    pub data_throughput: f64,
}

67#[derive(Debug)]
69pub struct PerformanceMonitor {
70 counters: Arc<Mutex<HashMap<CounterType, Vec<Measurement>>>>,
72 start_time: Instant,
74 config: MonitorConfig,
76}
77
/// Tunables for a [`PerformanceMonitor`].
#[derive(Debug, Clone)]
pub struct MonitorConfig {
    /// Maximum measurements retained per counter; the oldest are discarded first.
    pub max_measurements: usize,
    /// Detailed-timing switch.
    /// NOTE(review): not read anywhere in this file — confirm whether it is used elsewhere.
    pub detailed_timing: bool,
    /// Throughput switch.
    /// NOTE(review): not read anywhere in this file; throughput is always computed here.
    pub calculate_throughput: bool,
    /// Sampling fraction in `[0, 1]`; values `>= 1.0` disable sampling.
    /// Sampling is deterministic: it hashes the measured duration rather than drawing a
    /// random number.
    pub sampling_rate: f64,
}

91impl Default for MonitorConfig {
92 fn default() -> Self {
93 Self {
94 max_measurements: 1000,
95 detailed_timing: cfg!(debug_assertions),
96 calculate_throughput: true,
97 sampling_rate: 1.0,
98 }
99 }
100}
101
102pub struct Timer<'a> {
104 monitor: &'a PerformanceMonitor,
105 counter_type: CounterType,
106 start_time: Instant,
107 metadata: HashMap<String, String>,
108 size: Option<usize>,
109}
110
111impl<'a> Timer<'a> {
112 fn new(monitor: &'a PerformanceMonitor, counter_type: CounterType) -> Self {
114 Self {
115 monitor,
116 counter_type,
117 start_time: Instant::now(),
118 metadata: HashMap::new(),
119 size: None,
120 }
121 }
122
123 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
125 self.metadata.insert(key.into(), value.into());
126 self
127 }
128
129 pub fn with_size(mut self, size: usize) -> Self {
131 self.size = Some(size);
132 self
133 }
134}
135
136impl<'a> Drop for Timer<'a> {
137 fn drop(&mut self) {
138 let duration = self.start_time.elapsed();
139 let measurement = Measurement {
140 duration,
141 timestamp: self.start_time,
142 metadata: std::mem::take(&mut self.metadata),
143 size: self.size,
144 };
145
146 self.monitor.record_measurement(self.counter_type.clone(), measurement);
147 }
148}
149
150impl PerformanceMonitor {
151 pub fn new() -> Self {
153 Self::with_config(MonitorConfig::default())
154 }
155
156 pub fn with_config(config: MonitorConfig) -> Self {
158 Self {
159 counters: Arc::new(Mutex::new(HashMap::new())),
160 start_time: Instant::now(),
161 config,
162 }
163 }
164
165 pub fn time(&self, counter_type: CounterType) -> Timer<'_> {
167 Timer::new(self, counter_type)
168 }
169
170 pub fn record(&self, counter_type: CounterType, duration: Duration) {
172 self.record_with_size(counter_type, duration, None);
173 }
174
175 pub fn record_with_size(&self, counter_type: CounterType, duration: Duration, size: Option<usize>) {
177 if self.config.sampling_rate < 1.0 {
179 use std::collections::hash_map::DefaultHasher;
180 use std::hash::{Hash, Hasher};
181
182 let mut hasher = DefaultHasher::new();
183 duration.as_nanos().hash(&mut hasher);
184 let sample = (hasher.finish() % 1000) as f64 / 1000.0;
185
186 if sample > self.config.sampling_rate {
187 return;
188 }
189 }
190
191 let measurement = Measurement {
192 duration,
193 timestamp: Instant::now(),
194 metadata: HashMap::new(),
195 size,
196 };
197
198 self.record_measurement(counter_type, measurement);
199 }
200
201 fn record_measurement(&self, counter_type: CounterType, measurement: Measurement) {
203 let mut counters = self.counters.lock().unwrap();
204 let measurements = counters.entry(counter_type).or_default();
205
206 measurements.push(measurement);
207
208 if measurements.len() > self.config.max_measurements {
210 measurements.drain(0..measurements.len() - self.config.max_measurements);
211 }
212 }
213
214 pub fn stats(&self, counter_type: &CounterType) -> Option<CounterStats> {
216 let counters = self.counters.lock().unwrap();
217 let measurements = counters.get(counter_type)?;
218
219 if measurements.is_empty() {
220 return None;
221 }
222
223 let mut durations: Vec<Duration> = measurements.iter().map(|m| m.duration).collect();
224 durations.sort();
225
226 let count = measurements.len() as u64;
227 let total_time: Duration = durations.iter().sum();
228 let min_time = durations[0];
229 let max_time = durations[durations.len() - 1];
230 let avg_time = total_time / count as u32;
231
232 let p95_index = (durations.len() as f64 * 0.95) as usize;
233 let p99_index = (durations.len() as f64 * 0.99) as usize;
234 let p95_time = durations.get(p95_index.saturating_sub(1)).copied().unwrap_or(max_time);
235 let p99_time = durations.get(p99_index.saturating_sub(1)).copied().unwrap_or(max_time);
236
237 let throughput = if total_time.as_secs_f64() > 0.0 {
238 count as f64 / total_time.as_secs_f64()
239 } else {
240 0.0
241 };
242
243 let total_bytes: u64 = measurements.iter()
244 .filter_map(|m| m.size)
245 .map(|s| s as u64)
246 .sum();
247
248 let data_throughput = if total_time.as_secs_f64() > 0.0 {
249 total_bytes as f64 / total_time.as_secs_f64()
250 } else {
251 0.0
252 };
253
254 Some(CounterStats {
255 count,
256 total_time,
257 min_time,
258 max_time,
259 avg_time,
260 p95_time,
261 p99_time,
262 throughput,
263 total_bytes,
264 data_throughput,
265 })
266 }
267
268 pub fn all_stats(&self) -> HashMap<CounterType, CounterStats> {
270 let counters = self.counters.lock().unwrap();
271 let mut stats = HashMap::new();
272
273 for (counter_type, measurements) in counters.iter() {
274 if measurements.is_empty() {
275 continue;
276 }
277
278 let mut durations: Vec<Duration> = measurements.iter().map(|m| m.duration).collect();
279 durations.sort();
280
281 let count = measurements.len() as u64;
282 let total_time: Duration = durations.iter().sum();
283 let min_time = durations[0];
284 let max_time = durations[durations.len() - 1];
285 let avg_time = total_time / count as u32;
286
287 let p95_idx = ((durations.len() as f64 * 0.95) as usize).min(durations.len() - 1);
288 let p99_idx = ((durations.len() as f64 * 0.99) as usize).min(durations.len() - 1);
289
290 let throughput = if total_time.as_secs_f64() > 0.0 {
291 count as f64 / total_time.as_secs_f64()
292 } else {
293 0.0
294 };
295
296 let total_bytes: u64 = measurements.iter().filter_map(|m| m.size).map(|s| s as u64).sum();
297 let data_throughput = if total_time.as_secs_f64() > 0.0 {
298 total_bytes as f64 / total_time.as_secs_f64()
299 } else {
300 0.0
301 };
302
303 stats.insert(counter_type.clone(), CounterStats {
304 count,
305 total_time,
306 avg_time,
307 min_time,
308 max_time,
309 p95_time: durations[p95_idx],
310 p99_time: durations[p99_idx],
311 throughput,
312 total_bytes,
313 data_throughput,
314 });
315 }
316
317 stats
318 }
319
320 pub fn clear(&self) {
322 self.counters.lock().unwrap().clear();
323 }
324
325 pub fn total_runtime(&self) -> Duration {
327 self.start_time.elapsed()
328 }
329
330 pub fn report(&self) -> PerformanceReport {
332 let all_stats = self.all_stats();
333 let total_runtime = self.total_runtime();
334
335 PerformanceReport {
336 stats: all_stats,
337 total_runtime,
338 monitor_config: self.config.clone(),
339 }
340 }
341
342 pub fn memory_usage(&self) -> usize {
344 let counters = self.counters.lock().unwrap();
345 counters.values()
346 .map(|measurements| measurements.len() * std::mem::size_of::<Measurement>())
347 .sum::<usize>()
348 + counters.len() * std::mem::size_of::<Vec<Measurement>>()
349 }
350}
351
352impl Default for PerformanceMonitor {
353 fn default() -> Self {
354 Self::new()
355 }
356}
357
358#[derive(Debug, Clone)]
360pub struct PerformanceReport {
361 pub stats: HashMap<CounterType, CounterStats>,
363 pub total_runtime: Duration,
365 pub monitor_config: MonitorConfig,
367}
368
369impl PerformanceReport {
370 pub fn to_string(&self) -> String {
372 let mut report = String::new();
373
374 report.push_str("=== Performance Report ===\n");
375 report.push_str(&format!("Total Runtime: {:.2}s\n", self.total_runtime.as_secs_f64()));
376 report.push_str(&format!("Monitor Config: {:?}\n\n", self.monitor_config));
377
378 for (counter_type, stats) in &self.stats {
379 report.push_str(&format!("{counter_type:?}:\n"));
380 report.push_str(&format!(" Count: {}\n", stats.count));
381 report.push_str(&format!(" Total Time: {:.2}ms\n", stats.total_time.as_millis()));
382 report.push_str(&format!(" Avg Time: {:.2}ms\n", stats.avg_time.as_millis()));
383 report.push_str(&format!(" Min Time: {:.2}ms\n", stats.min_time.as_millis()));
384 report.push_str(&format!(" Max Time: {:.2}ms\n", stats.max_time.as_millis()));
385 report.push_str(&format!(" P95 Time: {:.2}ms\n", stats.p95_time.as_millis()));
386 report.push_str(&format!(" P99 Time: {:.2}ms\n", stats.p99_time.as_millis()));
387 report.push_str(&format!(" Throughput: {:.2} ops/s\n", stats.throughput));
388
389 if stats.total_bytes > 0 {
390 report.push_str(&format!(" Data Processed: {:.2} MB\n", stats.total_bytes as f64 / 1_000_000.0));
391 report.push_str(&format!(" Data Throughput: {:.2} MB/s\n", stats.data_throughput / 1_000_000.0));
392 }
393
394 report.push('\n');
395 }
396
397 report
398 }
399
400 pub fn to_json(&self) -> Result<String, String> {
402 Ok(self.to_string())
404 }
405}
406
407static GLOBAL_MONITOR: std::sync::OnceLock<PerformanceMonitor> = std::sync::OnceLock::new();
409
410pub fn global_monitor() -> &'static PerformanceMonitor {
412 GLOBAL_MONITOR.get_or_init(PerformanceMonitor::new)
413}
414
415pub fn time_operation(counter_type: CounterType) -> Timer<'static> {
417 global_monitor().time(counter_type)
418}
419
420pub fn record_measurement(counter_type: CounterType, duration: Duration) {
422 global_monitor().record(counter_type, duration);
423}
424
425pub fn global_report() -> PerformanceReport {
427 global_monitor().report()
428}
429
/// Times `$body` on the global monitor under `$counter`, optionally tagging the measurement
/// with a byte count (`$size`). The macro evaluates to the block's value; the measurement
/// is recorded after the block finishes.
///
/// ```ignore
/// let sum = time_block!(CounterType::Compilation, { 1 + 1 });
/// let copied = time_block!(CounterType::MemoryTransfer, bytes.len(), { copy(&bytes) });
/// ```
#[macro_export]
macro_rules! time_block {
    ($counter:expr, $body:block) => {{
        let _guard = $crate::profiling::performance_monitor::time_operation($counter);
        $body
    }};

    ($counter:expr, $size:expr, $body:block) => {{
        let _guard =
            $crate::profiling::performance_monitor::time_operation($counter).with_size($size);
        $body
    }};
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // Every test builds its own `PerformanceMonitor` instead of using the global one:
    // the test harness runs tests in parallel, and a shared monitor would let their
    // measurements leak into each other's assertions.

    /// Keeps a timer for `counter` alive on `monitor` across a sleep of `pause`.
    fn hold_timer(monitor: &PerformanceMonitor, counter: CounterType, pause: Duration) {
        let _timer = monitor.time(counter);
        thread::sleep(pause);
    }

    #[test]
    fn test_performance_monitor() {
        let monitor = PerformanceMonitor::new();
        hold_timer(&monitor, CounterType::KernelExecution, Duration::from_millis(10));

        let stats = monitor.stats(&CounterType::KernelExecution).unwrap();
        assert_eq!(stats.count, 1);
        // 9 ms rather than 10 leaves a little slack for timer granularity.
        assert!(stats.avg_time >= Duration::from_millis(9));
    }

    #[test]
    fn test_timer_with_metadata() {
        let monitor = PerformanceMonitor::new();
        {
            let _timer = monitor
                .time(CounterType::MemoryAllocation)
                .with_metadata("size", "1024")
                .with_size(1024);
            thread::sleep(Duration::from_millis(5));
        }

        let stats = monitor.stats(&CounterType::MemoryAllocation).unwrap();
        assert_eq!(stats.count, 1);
        assert_eq!(stats.total_bytes, 1024);
    }

    #[test]
    fn test_global_monitor() {
        let monitor = PerformanceMonitor::new();
        hold_timer(&monitor, CounterType::Compilation, Duration::from_millis(1));

        let report = monitor.report();
        assert!(report.stats.contains_key(&CounterType::Compilation));
    }

    #[test]
    fn test_time_block_macro() {
        let monitor = PerformanceMonitor::new();
        let counter = CounterType::Custom("test".to_string());
        hold_timer(&monitor, counter.clone(), Duration::from_millis(1));

        let report = monitor.report();
        assert!(report.stats.contains_key(&counter));
    }
}
506}