cuda_rust_wasm/profiling/runtime_profiler.rs

1//! Runtime performance profiling
2
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6use crate::error::CudaRustError;
7
/// Category label attached to every operation the runtime profiler records.
///
/// The profiler groups its aggregated statistics by this value (it is the key of the
/// per-type statistics map), which is why it is `Copy + Eq + Hash`. The variant names are
/// descriptive labels only; the profiler attaches no behavior to them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// Module loading.
    ModuleLoad,
    /// Module compilation.
    ModuleCompile,
    /// Kernel launch.
    KernelLaunch,
    /// Memory transfer.
    MemoryTransfer,
    /// Synchronization.
    Synchronization,
    /// Runtime initialization.
    RuntimeInit,
    /// Runtime shutdown.
    RuntimeShutdown,
    /// Caller-defined category distinguished by a numeric id.
    ///
    /// `Custom(0)` is also what the bottleneck analysis reports as the primary
    /// bottleneck when nothing has been recorded.
    Custom(u32),
}
20
21/// Runtime operation event
22#[derive(Debug, Clone)]
23pub struct OperationEvent {
24    pub operation_type: OperationType,
25    pub name: String,
26    pub start_time: Instant,
27    pub duration: Duration,
28    pub metadata: HashMap<String, String>,
29}
30
31/// Runtime profiler for tracking WASM runtime performance
32pub struct RuntimeProfiler {
33    events: Arc<Mutex<Vec<OperationEvent>>>,
34    operation_stats: Arc<Mutex<HashMap<OperationType, OperationStats>>>,
35    enabled: bool,
36    start_time: Instant,
37}
38
/// Aggregated timing statistics for one [`OperationType`].
///
/// The profiler creates an entry on the first recorded event of a type, so entries it
/// returns always have `count >= 1`.
#[derive(Debug, Clone)]
pub struct OperationStats {
    /// Number of recorded operations.
    pub count: usize,
    /// Sum of all recorded durations.
    pub total_time: Duration,
    /// Shortest recorded duration (`Duration::MAX` until the first sample).
    pub min_time: Duration,
    /// Longest recorded duration.
    pub max_time: Duration,
    /// `total_time / count`, truncated to whole nanoseconds and recomputed on every update.
    pub average_time: Duration,
}

impl OperationStats {
    /// Creates empty statistics.
    ///
    /// `min_time` starts at `Duration::MAX` so the first sample always replaces it.
    fn new() -> Self {
        Self {
            count: 0,
            total_time: Duration::ZERO,
            min_time: Duration::MAX,
            max_time: Duration::ZERO,
            average_time: Duration::ZERO,
        }
    }

    /// Folds one more sample into the statistics.
    fn update(&mut self, duration: Duration) {
        self.count += 1;
        self.total_time += duration;
        self.average_time = Self::mean(self.total_time, self.count);
        self.min_time = self.min_time.min(duration);
        self.max_time = self.max_time.max(duration);
    }

    /// Divides `total` by `count` without narrowing `count` to `u32`.
    ///
    /// `Duration / u32` would need `count as u32`, which silently truncates past
    /// `u32::MAX` samples (and divides by zero at exact multiples of 2^32). Dividing the
    /// `u128` nanosecond total instead is exact (floor) for every `usize` count.
    /// Returns `Duration::ZERO` when `count` is 0.
    fn mean(total: Duration, count: usize) -> Duration {
        if count == 0 {
            return Duration::ZERO;
        }
        // usize -> u128 is a lossless widening.
        let avg_nanos = total.as_nanos() / count as u128;
        // avg_nanos <= total.as_nanos(), so the whole seconds fit in u64 like `total`'s do.
        let secs = (avg_nanos / 1_000_000_000) as u64;
        let nanos = (avg_nanos % 1_000_000_000) as u32;
        Duration::new(secs, nanos)
    }
}
72
73impl Default for RuntimeProfiler {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl RuntimeProfiler {
80    pub fn new() -> Self {
81        Self {
82            events: Arc::new(Mutex::new(Vec::new())),
83            operation_stats: Arc::new(Mutex::new(HashMap::new())),
84            enabled: false,
85            start_time: Instant::now(),
86        }
87    }
88
89    pub fn enable(&mut self) {
90        self.enabled = true;
91        self.start_time = Instant::now();
92    }
93
94    pub fn disable(&mut self) {
95        self.enabled = false;
96    }
97
98    pub fn is_enabled(&self) -> bool {
99        self.enabled
100    }
101
102    pub fn start_operation(&self, operation_type: OperationType, name: &str) -> OperationTimer {
103        OperationTimer::new(
104            self.enabled,
105            operation_type,
106            name.to_string(),
107            Instant::now(),
108        )
109    }
110
111    pub fn end_operation(&self, timer: OperationTimer, metadata: HashMap<String, String>) {
112        if !self.enabled || !timer.enabled {
113            return;
114        }
115
116        let duration = timer.start_time.elapsed();
117        
118        let event = OperationEvent {
119            operation_type: timer.operation_type,
120            name: timer.name,
121            start_time: timer.start_time,
122            duration,
123            metadata,
124        };
125
126        // Record event
127        {
128            let mut events = self.events.lock().unwrap();
129            events.push(event);
130        }
131
132        // Update statistics
133        {
134            let mut stats = self.operation_stats.lock().unwrap();
135            stats
136                .entry(timer.operation_type)
137                .or_insert_with(OperationStats::new)
138                .update(duration);
139        }
140    }
141
142    pub fn get_events(&self) -> Vec<OperationEvent> {
143        self.events.lock().unwrap().clone()
144    }
145
146    pub fn get_stats(&self) -> HashMap<OperationType, OperationStats> {
147        self.operation_stats.lock().unwrap().clone()
148    }
149
150    pub fn get_total_runtime(&self) -> Duration {
151        self.start_time.elapsed()
152    }
153
154    pub fn print_summary(&self) {
155        println!("\n========== RUNTIME PROFILING SUMMARY ==========");
156        
157        let stats = self.get_stats();
158        let total_runtime = self.get_total_runtime();
159        
160        println!("\nTotal Runtime: {total_runtime:?}");
161        
162        // Sort operations by total time
163        let mut sorted_ops: Vec<_> = stats.iter().collect();
164        sorted_ops.sort_by(|a, b| b.1.total_time.cmp(&a.1.total_time));
165        
166        println!("\nOperation Statistics:");
167        for (op_type, stat) in sorted_ops {
168            let percentage = (stat.total_time.as_secs_f64() / total_runtime.as_secs_f64()) * 100.0;
169            
170            println!("\n{op_type:?}:");
171            println!("  Count: {}", stat.count);
172            println!("  Total time: {:?} ({:.1}%)", stat.total_time, percentage);
173            println!("  Average: {:?}", stat.average_time);
174            println!("  Min/Max: {:?} / {:?}", stat.min_time, stat.max_time);
175        }
176        
177        // Timeline analysis
178        self.print_timeline_analysis();
179        
180        println!("==============================================\n");
181    }
182
183    fn print_timeline_analysis(&self) {
184        let events = self.get_events();
185        if events.is_empty() {
186            return;
187        }
188        
189        println!("\nTimeline Analysis:");
190        
191        // Find critical path
192        let mut critical_path_time = Duration::ZERO;
193        let mut last_end_time = self.start_time;
194        
195        for event in &events {
196            let event_end = event.start_time + event.duration;
197            if event.start_time >= last_end_time {
198                critical_path_time += event.duration;
199                last_end_time = event_end;
200            }
201        }
202        
203        println!("  Critical path time: {critical_path_time:?}");
204        println!("  Parallelization efficiency: {:.1}%", 
205            (critical_path_time.as_secs_f64() / self.get_total_runtime().as_secs_f64()) * 100.0
206        );
207        
208        // Find longest operations
209        let mut longest_ops = events.clone();
210        longest_ops.sort_by(|a, b| b.duration.cmp(&a.duration));
211        
212        println!("\n  Longest operations:");
213        for (i, event) in longest_ops.iter().take(5).enumerate() {
214            println!("    {}. {} ({:?}): {:?}",
215                i + 1,
216                event.name,
217                event.operation_type,
218                event.duration
219            );
220        }
221    }
222
223    pub fn export_trace(&self, path: &str) -> Result<(), CudaRustError> {
224        use std::fs::File;
225        use std::io::Write;
226
227        let events = self.get_events();
228        let mut file = File::create(path)
229            .map_err(|e| CudaRustError::RuntimeError(format!("Failed to create file: {e}")))?;
230
231        // Write Chrome Tracing Format
232        writeln!(file, "[")
233            .map_err(|e| CudaRustError::RuntimeError(format!("Failed to write header: {e}")))?;
234
235        for (i, event) in events.iter().enumerate() {
236            let start_us = event.start_time.duration_since(self.start_time).as_micros();
237            let duration_us = event.duration.as_micros();
238            
239            let trace_event = format!(
240                r#"{{
241    "name": "{}",
242    "cat": "{:?}",
243    "ph": "X",
244    "ts": {},
245    "dur": {},
246    "pid": 1,
247    "tid": 1,
248    "args": {{}}
249}}"#,
250                event.name,
251                event.operation_type,
252                start_us,
253                duration_us
254            );
255            
256            if i < events.len() - 1 {
257                writeln!(file, "{trace_event},")
258                    .map_err(|e| CudaRustError::RuntimeError(format!("Failed to write event: {e}")))?;
259            } else {
260                writeln!(file, "{trace_event}")
261                    .map_err(|e| CudaRustError::RuntimeError(format!("Failed to write event: {e}")))?;
262            }
263        }
264
265        writeln!(file, "]")
266            .map_err(|e| CudaRustError::RuntimeError(format!("Failed to write footer: {e}")))?;
267
268        Ok(())
269    }
270
271    pub fn analyze_bottlenecks(&self) -> BottleneckAnalysis {
272        let stats = self.get_stats();
273        let total_runtime = self.get_total_runtime();
274        
275        // Find operations that take most time
276        let mut time_by_operation: Vec<_> = stats.iter()
277            .map(|(op, stat)| (*op, stat.total_time))
278            .collect();
279        time_by_operation.sort_by(|a, b| b.1.cmp(&a.1));
280        
281        let primary_bottleneck = time_by_operation.first()
282            .map(|(op, _)| *op)
283            .unwrap_or(OperationType::Custom(0));
284        
285        // Calculate time distribution
286        let mut time_distribution = HashMap::new();
287        for (op, stat) in &stats {
288            let percentage = (stat.total_time.as_secs_f64() / total_runtime.as_secs_f64()) * 100.0;
289            time_distribution.insert(*op, percentage);
290        }
291        
292        // Find operations with high variance
293        let mut high_variance_ops = Vec::new();
294        for (op, stat) in &stats {
295            if stat.count > 1 {
296                let range = stat.max_time.as_secs_f64() - stat.min_time.as_secs_f64();
297                let variance_ratio = range / stat.average_time.as_secs_f64();
298                if variance_ratio > 2.0 {
299                    high_variance_ops.push((*op, variance_ratio));
300                }
301            }
302        }
303        
304        BottleneckAnalysis {
305            primary_bottleneck,
306            time_distribution,
307            high_variance_operations: high_variance_ops,
308            total_runtime,
309        }
310    }
311
312    pub fn clear(&self) {
313        self.events.lock().unwrap().clear();
314        self.operation_stats.lock().unwrap().clear();
315    }
316}
317
318/// Timer for runtime operations
319pub struct OperationTimer {
320    enabled: bool,
321    operation_type: OperationType,
322    name: String,
323    start_time: Instant,
324}
325
326impl OperationTimer {
327    fn new(enabled: bool, operation_type: OperationType, name: String, start_time: Instant) -> Self {
328        Self {
329            enabled,
330            operation_type,
331            name,
332            start_time,
333        }
334    }
335}
336
337/// Bottleneck analysis results
338#[derive(Debug, Clone)]
339pub struct BottleneckAnalysis {
340    pub primary_bottleneck: OperationType,
341    pub time_distribution: HashMap<OperationType, f64>,
342    pub high_variance_operations: Vec<(OperationType, f64)>,
343    pub total_runtime: Duration,
344}
345
346impl BottleneckAnalysis {
347    pub fn print_analysis(&self) {
348        println!("\n=== Bottleneck Analysis ===");
349        println!("Total runtime: {:?}", self.total_runtime);
350        println!("Primary bottleneck: {:?}", self.primary_bottleneck);
351        
352        println!("\nTime distribution:");
353        let mut sorted_dist: Vec<_> = self.time_distribution.iter().collect();
354        sorted_dist.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
355        
356        for (op, percentage) in sorted_dist {
357            println!("  {op:?}: {percentage:.1}%");
358        }
359        
360        if !self.high_variance_operations.is_empty() {
361            println!("\nHigh variance operations:");
362            for (op, ratio) in &self.high_variance_operations {
363                println!("  {op:?}: {ratio:.1}x variance");
364            }
365        }
366    }
367}
368
369/// Performance optimization suggestions
370pub struct OptimizationSuggestions {
371    suggestions: Vec<Suggestion>,
372}
373
374#[derive(Debug, Clone)]
375pub struct Suggestion {
376    pub severity: SuggestionSeverity,
377    pub category: SuggestionCategory,
378    pub message: String,
379    pub expected_improvement: Option<f64>,
380}
381
/// Urgency of a [`Suggestion`].
///
/// Variants are declared in ascending order, so the derived `Ord` ranks
/// `Low < Medium < High` and suggestions can be sorted or filtered by severity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SuggestionSeverity {
    Low,
    Medium,
    High,
}
388
/// Area of the system a [`Suggestion`] targets.
///
/// The variants are labels only. `Eq + Hash` allow suggestions to be grouped or
/// filtered by category.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SuggestionCategory {
    MemoryOptimization,
    KernelOptimization,
    RuntimeOptimization,
    Parallelization,
}
396
397impl OptimizationSuggestions {
398    pub fn analyze(profiler: &RuntimeProfiler) -> Self {
399        let mut suggestions = Vec::new();
400        let analysis = profiler.analyze_bottlenecks();
401        
402        // Check for module loading bottleneck
403        if let Some(percentage) = analysis.time_distribution.get(&OperationType::ModuleLoad) {
404            if *percentage > 20.0 {
405                suggestions.push(Suggestion {
406                    severity: SuggestionSeverity::High,
407                    category: SuggestionCategory::RuntimeOptimization,
408                    message: "Module loading takes >20% of runtime. Consider caching compiled modules.".to_string(),
409                    expected_improvement: Some(percentage * 0.8),
410                });
411            }
412        }
413        
414        // Check for compilation bottleneck
415        if let Some(percentage) = analysis.time_distribution.get(&OperationType::ModuleCompile) {
416            if *percentage > 30.0 {
417                suggestions.push(Suggestion {
418                    severity: SuggestionSeverity::High,
419                    category: SuggestionCategory::RuntimeOptimization,
420                    message: "Compilation takes >30% of runtime. Use pre-compiled WASM modules.".to_string(),
421                    expected_improvement: Some(percentage * 0.9),
422                });
423            }
424        }
425        
426        // Check for memory transfer bottleneck
427        if let Some(percentage) = analysis.time_distribution.get(&OperationType::MemoryTransfer) {
428            if *percentage > 40.0 {
429                suggestions.push(Suggestion {
430                    severity: SuggestionSeverity::High,
431                    category: SuggestionCategory::MemoryOptimization,
432                    message: "Memory transfers dominate runtime. Consider unified memory or reducing transfers.".to_string(),
433                    expected_improvement: Some(percentage * 0.5),
434                });
435            }
436        }
437        
438        Self { suggestions }
439    }
440    
441    pub fn print_suggestions(&self) {
442        if self.suggestions.is_empty() {
443            println!("\nNo optimization suggestions found.");
444            return;
445        }
446        
447        println!("\n=== Optimization Suggestions ===");
448        
449        for (i, suggestion) in self.suggestions.iter().enumerate() {
450            println!("\n{}. {:?} - {:?}", i + 1, suggestion.severity, suggestion.category);
451            println!("   {}", suggestion.message);
452            if let Some(improvement) = suggestion.expected_improvement {
453                println!("   Expected improvement: {improvement:.1}%");
454            }
455        }
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Records one operation of kind `op` that lasts (at least) `millis` milliseconds.
    fn record_sleep(profiler: &RuntimeProfiler, op: OperationType, name: &str, millis: u64) {
        let timer = profiler.start_operation(op, name);
        std::thread::sleep(Duration::from_millis(millis));
        profiler.end_operation(timer, HashMap::new());
    }

    #[test]
    fn test_runtime_profiler() {
        let mut profiler = RuntimeProfiler::new();
        profiler.enable();

        record_sleep(&profiler, OperationType::ModuleLoad, "test_module", 10);
        record_sleep(&profiler, OperationType::KernelLaunch, "test_kernel", 5);

        let stats = profiler.get_stats();
        assert_eq!(stats.len(), 2);
        for op in [OperationType::ModuleLoad, OperationType::KernelLaunch] {
            assert_eq!(stats[&op].count, 1);
        }
    }

    #[test]
    fn test_bottleneck_analysis() {
        let mut profiler = RuntimeProfiler::new();
        profiler.enable();

        // Ten 10 ms transfers dwarf a single 5 ms kernel launch.
        (0..10).for_each(|_| {
            record_sleep(&profiler, OperationType::MemoryTransfer, "transfer", 10)
        });
        record_sleep(&profiler, OperationType::KernelLaunch, "kernel", 5);

        let analysis = profiler.analyze_bottlenecks();
        assert_eq!(analysis.primary_bottleneck, OperationType::MemoryTransfer);
    }
}