cuda_rust_wasm/profiling/
performance_monitor.rs1use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
/// Key identifying which counter a measurement is recorded under.
///
/// This is the key type of `PerformanceMonitor`'s counter map, which is why it derives
/// `Eq` and `Hash`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum CounterType {
    KernelExecution,
    MemoryAllocation,
    MemoryTransfer,
    Compilation,
    TotalPipeline,
    WebGPUEncoding,
    /// Caller-named counter; different names are different keys.
    Custom(String),
}
28
/// A single timing sample recorded for one counter.
#[derive(Clone, Debug)]
pub struct Measurement {
    /// How long the measured operation took.
    pub duration: Duration,
    /// Instant associated with the sample: a `Timer` stores its start instant, while the
    /// direct `record*` calls store the moment of recording.
    pub timestamp: Instant,
    /// Free-form key/value annotations attached to the sample.
    pub metadata: HashMap<String, String>,
    /// Optional byte count; summed into `CounterStats::total_bytes`.
    pub size: Option<usize>,
}
41
/// Aggregate statistics over the samples currently retained for one counter.
#[derive(Clone, Debug)]
pub struct CounterStats {
    /// Number of retained samples.
    pub count: u64,
    /// Sum of all sample durations.
    pub total_time: Duration,
    /// Shortest sample.
    pub min_time: Duration,
    /// Longest sample.
    pub max_time: Duration,
    /// Mean sample duration (`total_time / count`).
    pub avg_time: Duration,
    /// Approximate 95th-percentile sample duration.
    pub p95_time: Duration,
    /// Approximate 99th-percentile sample duration.
    pub p99_time: Duration,
    /// Samples per second of accumulated `total_time` (0 when `total_time` is zero).
    pub throughput: f64,
    /// Sum of the sample sizes that were provided.
    pub total_bytes: u64,
    /// `total_bytes` per second of accumulated `total_time` (0 when `total_time` is zero).
    pub data_throughput: f64,
}
66
67#[derive(Debug)]
69pub struct PerformanceMonitor {
70 counters: Arc<Mutex<HashMap<CounterType, Vec<Measurement>>>>,
72 start_time: Instant,
74 config: MonitorConfig,
76}
77
/// Tuning knobs for a [`PerformanceMonitor`].
#[derive(Clone, Debug)]
pub struct MonitorConfig {
    /// Maximum number of samples kept per counter; the oldest are evicted first.
    pub max_measurements: usize,
    /// Defaults to `cfg!(debug_assertions)`.
    /// NOTE(review): not read by the monitor code in this file; confirm intended use.
    pub detailed_timing: bool,
    /// Defaults to `true`.
    /// NOTE(review): not read by the monitor code in this file; confirm intended use.
    pub calculate_throughput: bool,
    /// Fraction of measurements to keep; values `>= 1.0` disable sampling.
    pub sampling_rate: f64,
}
90
91impl Default for MonitorConfig {
92 fn default() -> Self {
93 Self {
94 max_measurements: 1000,
95 detailed_timing: cfg!(debug_assertions),
96 calculate_throughput: true,
97 sampling_rate: 1.0,
98 }
99 }
100}
101
102pub struct Timer<'a> {
104 monitor: &'a PerformanceMonitor,
105 counter_type: CounterType,
106 start_time: Instant,
107 metadata: HashMap<String, String>,
108 size: Option<usize>,
109}
110
111impl<'a> Timer<'a> {
112 fn new(monitor: &'a PerformanceMonitor, counter_type: CounterType) -> Self {
114 Self {
115 monitor,
116 counter_type,
117 start_time: Instant::now(),
118 metadata: HashMap::new(),
119 size: None,
120 }
121 }
122
123 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
125 self.metadata.insert(key.into(), value.into());
126 self
127 }
128
129 pub fn with_size(mut self, size: usize) -> Self {
131 self.size = Some(size);
132 self
133 }
134}
135
136impl<'a> Drop for Timer<'a> {
137 fn drop(&mut self) {
138 let duration = self.start_time.elapsed();
139 let measurement = Measurement {
140 duration,
141 timestamp: self.start_time,
142 metadata: std::mem::take(&mut self.metadata),
143 size: self.size,
144 };
145
146 self.monitor.record_measurement(self.counter_type.clone(), measurement);
147 }
148}
149
150impl PerformanceMonitor {
151 pub fn new() -> Self {
153 Self::with_config(MonitorConfig::default())
154 }
155
156 pub fn with_config(config: MonitorConfig) -> Self {
158 Self {
159 counters: Arc::new(Mutex::new(HashMap::new())),
160 start_time: Instant::now(),
161 config,
162 }
163 }
164
165 pub fn time(&self, counter_type: CounterType) -> Timer<'_> {
167 Timer::new(self, counter_type)
168 }
169
170 pub fn record(&self, counter_type: CounterType, duration: Duration) {
172 self.record_with_size(counter_type, duration, None);
173 }
174
175 pub fn record_with_size(&self, counter_type: CounterType, duration: Duration, size: Option<usize>) {
177 if self.config.sampling_rate < 1.0 {
179 use std::collections::hash_map::DefaultHasher;
180 use std::hash::{Hash, Hasher};
181
182 let mut hasher = DefaultHasher::new();
183 duration.as_nanos().hash(&mut hasher);
184 let sample = (hasher.finish() % 1000) as f64 / 1000.0;
185
186 if sample > self.config.sampling_rate {
187 return;
188 }
189 }
190
191 let measurement = Measurement {
192 duration,
193 timestamp: Instant::now(),
194 metadata: HashMap::new(),
195 size,
196 };
197
198 self.record_measurement(counter_type, measurement);
199 }
200
201 fn record_measurement(&self, counter_type: CounterType, measurement: Measurement) {
203 let mut counters = self.counters.lock().unwrap();
204 let measurements = counters.entry(counter_type).or_default();
205
206 measurements.push(measurement);
207
208 if measurements.len() > self.config.max_measurements {
210 measurements.drain(0..measurements.len() - self.config.max_measurements);
211 }
212 }
213
214 pub fn stats(&self, counter_type: &CounterType) -> Option<CounterStats> {
216 let counters = self.counters.lock().unwrap();
217 let measurements = counters.get(counter_type)?;
218
219 if measurements.is_empty() {
220 return None;
221 }
222
223 let mut durations: Vec<Duration> = measurements.iter().map(|m| m.duration).collect();
224 durations.sort();
225
226 let count = measurements.len() as u64;
227 let total_time: Duration = durations.iter().sum();
228 let min_time = durations[0];
229 let max_time = durations[durations.len() - 1];
230 let avg_time = total_time / count as u32;
231
232 let p95_index = (durations.len() as f64 * 0.95) as usize;
233 let p99_index = (durations.len() as f64 * 0.99) as usize;
234 let p95_time = durations.get(p95_index.saturating_sub(1)).copied().unwrap_or(max_time);
235 let p99_time = durations.get(p99_index.saturating_sub(1)).copied().unwrap_or(max_time);
236
237 let throughput = if total_time.as_secs_f64() > 0.0 {
238 count as f64 / total_time.as_secs_f64()
239 } else {
240 0.0
241 };
242
243 let total_bytes: u64 = measurements.iter()
244 .filter_map(|m| m.size)
245 .map(|s| s as u64)
246 .sum();
247
248 let data_throughput = if total_time.as_secs_f64() > 0.0 {
249 total_bytes as f64 / total_time.as_secs_f64()
250 } else {
251 0.0
252 };
253
254 Some(CounterStats {
255 count,
256 total_time,
257 min_time,
258 max_time,
259 avg_time,
260 p95_time,
261 p99_time,
262 throughput,
263 total_bytes,
264 data_throughput,
265 })
266 }
267
268 pub fn all_stats(&self) -> HashMap<CounterType, CounterStats> {
270 let counters = self.counters.lock().unwrap();
271 let mut stats = HashMap::new();
272
273 for counter_type in counters.keys() {
274 if let Some(counter_stats) = self.stats(counter_type) {
275 stats.insert(counter_type.clone(), counter_stats);
276 }
277 }
278
279 stats
280 }
281
282 pub fn clear(&self) {
284 self.counters.lock().unwrap().clear();
285 }
286
287 pub fn total_runtime(&self) -> Duration {
289 self.start_time.elapsed()
290 }
291
292 pub fn report(&self) -> PerformanceReport {
294 let all_stats = self.all_stats();
295 let total_runtime = self.total_runtime();
296
297 PerformanceReport {
298 stats: all_stats,
299 total_runtime,
300 monitor_config: self.config.clone(),
301 }
302 }
303
304 pub fn memory_usage(&self) -> usize {
306 let counters = self.counters.lock().unwrap();
307 counters.values()
308 .map(|measurements| measurements.len() * std::mem::size_of::<Measurement>())
309 .sum::<usize>()
310 + counters.len() * std::mem::size_of::<Vec<Measurement>>()
311 }
312}
313
314impl Default for PerformanceMonitor {
315 fn default() -> Self {
316 Self::new()
317 }
318}
319
320#[derive(Debug, Clone)]
322pub struct PerformanceReport {
323 pub stats: HashMap<CounterType, CounterStats>,
325 pub total_runtime: Duration,
327 pub monitor_config: MonitorConfig,
329}
330
331impl PerformanceReport {
332 pub fn to_string(&self) -> String {
334 let mut report = String::new();
335
336 report.push_str("=== Performance Report ===\n");
337 report.push_str(&format!("Total Runtime: {:.2}s\n", self.total_runtime.as_secs_f64()));
338 report.push_str(&format!("Monitor Config: {:?}\n\n", self.monitor_config));
339
340 for (counter_type, stats) in &self.stats {
341 report.push_str(&format!("{counter_type:?}:\n"));
342 report.push_str(&format!(" Count: {}\n", stats.count));
343 report.push_str(&format!(" Total Time: {:.2}ms\n", stats.total_time.as_millis()));
344 report.push_str(&format!(" Avg Time: {:.2}ms\n", stats.avg_time.as_millis()));
345 report.push_str(&format!(" Min Time: {:.2}ms\n", stats.min_time.as_millis()));
346 report.push_str(&format!(" Max Time: {:.2}ms\n", stats.max_time.as_millis()));
347 report.push_str(&format!(" P95 Time: {:.2}ms\n", stats.p95_time.as_millis()));
348 report.push_str(&format!(" P99 Time: {:.2}ms\n", stats.p99_time.as_millis()));
349 report.push_str(&format!(" Throughput: {:.2} ops/s\n", stats.throughput));
350
351 if stats.total_bytes > 0 {
352 report.push_str(&format!(" Data Processed: {:.2} MB\n", stats.total_bytes as f64 / 1_000_000.0));
353 report.push_str(&format!(" Data Throughput: {:.2} MB/s\n", stats.data_throughput / 1_000_000.0));
354 }
355
356 report.push('\n');
357 }
358
359 report
360 }
361
362 pub fn to_json(&self) -> Result<String, String> {
364 Ok(self.to_string())
366 }
367}
368
369static GLOBAL_MONITOR: std::sync::OnceLock<PerformanceMonitor> = std::sync::OnceLock::new();
371
372pub fn global_monitor() -> &'static PerformanceMonitor {
374 GLOBAL_MONITOR.get_or_init(PerformanceMonitor::new)
375}
376
377pub fn time_operation(counter_type: CounterType) -> Timer<'static> {
379 global_monitor().time(counter_type)
380}
381
382pub fn record_measurement(counter_type: CounterType, duration: Duration) {
384 global_monitor().record(counter_type, duration);
385}
386
387pub fn global_report() -> PerformanceReport {
389 global_monitor().report()
390}
391
/// Times a block on the global monitor and evaluates to the block's value.
///
/// `time_block!(counter, { ... })` records the block's duration under `counter`.
/// `time_block!(counter, bytes, { ... })` also attaches `bytes` as the measurement size.
#[macro_export]
macro_rules! time_block {
    ($counter:expr, $body:block) => {{
        // The guard lives until the end of this outer block, so it spans `$body`.
        let _scope_timer = $crate::profiling::performance_monitor::time_operation($counter);
        $body
    }};

    ($counter:expr, $bytes:expr, $body:block) => {{
        let _scope_timer =
            $crate::profiling::performance_monitor::time_operation($counter).with_size($bytes);
        $body
    }};
}
405
#[cfg(test)]
mod tests {
    //! Unit tests for the performance monitor. The timing-based tests use `sleep`.
    //! The deterministic tests feed explicit durations through `record*`, so their
    //! expected statistics are exact.
    use super::*;
    use std::thread;

    #[test]
    fn test_performance_monitor() {
        let monitor = PerformanceMonitor::new();

        {
            let _timer = monitor.time(CounterType::KernelExecution);
            thread::sleep(Duration::from_millis(10));
        }

        let stats = monitor.stats(&CounterType::KernelExecution).unwrap();
        assert_eq!(stats.count, 1);
        assert!(stats.avg_time >= Duration::from_millis(9));
    }

    #[test]
    fn test_timer_with_metadata() {
        let monitor = PerformanceMonitor::new();

        {
            let _timer = monitor
                .time(CounterType::MemoryAllocation)
                .with_metadata("size", "1024")
                .with_size(1024);
            thread::sleep(Duration::from_millis(5));
        }

        let stats = monitor.stats(&CounterType::MemoryAllocation).unwrap();
        assert_eq!(stats.count, 1);
        assert_eq!(stats.total_bytes, 1024);
    }

    #[test]
    fn test_global_monitor() {
        {
            let _timer = time_operation(CounterType::Compilation);
            thread::sleep(Duration::from_millis(1));
        }

        let report = global_report();
        assert!(report.stats.contains_key(&CounterType::Compilation));
    }

    #[test]
    fn test_time_block_macro() {
        time_block!(CounterType::Custom("test".to_string()), {
            thread::sleep(Duration::from_millis(1));
        });

        let report = global_report();
        assert!(report.stats.contains_key(&CounterType::Custom("test".to_string())));
    }

    /// Explicit durations make count/total/min/max/avg/bytes exactly predictable.
    #[test]
    fn test_recorded_stats_are_exact() {
        let monitor = PerformanceMonitor::new();
        for ms in 1..=4 {
            monitor.record_with_size(
                CounterType::MemoryTransfer,
                Duration::from_millis(ms),
                Some(256),
            );
        }

        let stats = monitor.stats(&CounterType::MemoryTransfer).unwrap();
        assert_eq!(stats.count, 4);
        assert_eq!(stats.total_time, Duration::from_millis(10));
        assert_eq!(stats.min_time, Duration::from_millis(1));
        assert_eq!(stats.max_time, Duration::from_millis(4));
        assert_eq!(stats.avg_time, Duration::from_micros(2500));
        assert_eq!(stats.total_bytes, 1024);
    }

    /// Only the newest `max_measurements` samples are retained.
    #[test]
    fn test_max_measurements_keeps_newest() {
        let monitor = PerformanceMonitor::with_config(MonitorConfig {
            max_measurements: 3,
            ..MonitorConfig::default()
        });
        for ms in 1..=10 {
            monitor.record(CounterType::WebGPUEncoding, Duration::from_millis(ms));
        }

        let stats = monitor.stats(&CounterType::WebGPUEncoding).unwrap();
        assert_eq!(stats.count, 3);
        assert_eq!(stats.min_time, Duration::from_millis(8));
        assert_eq!(stats.max_time, Duration::from_millis(10));
    }

    /// Unknown counters and counters removed by `clear` report no statistics.
    #[test]
    fn test_unknown_and_cleared_counters_have_no_stats() {
        let monitor = PerformanceMonitor::new();
        assert!(monitor.stats(&CounterType::TotalPipeline).is_none());

        monitor.record(CounterType::TotalPipeline, Duration::from_millis(1));
        assert!(monitor.stats(&CounterType::TotalPipeline).is_some());

        monitor.clear();
        assert!(monitor.stats(&CounterType::TotalPipeline).is_none());
    }
}
462}