1use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6use crate::error::CudaRustError;
7
/// Category of a profiled runtime operation.
///
/// Used as the key under which `RuntimeProfiler` aggregates per-type statistics,
/// so it is hashable, comparable and cheap to copy.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum OperationType {
    ModuleLoad,
    ModuleCompile,
    KernelLaunch,
    MemoryTransfer,
    Synchronization,
    RuntimeInit,
    RuntimeShutdown,
    /// Caller-defined category distinguished by an arbitrary numeric tag.
    Custom(u32),
}

/// One completed, recorded operation.
#[derive(Debug, Clone)]
pub struct OperationEvent {
    /// Category the operation is aggregated under.
    pub operation_type: OperationType,
    /// Caller-supplied label for the operation.
    pub name: String,
    /// Instant at which the operation's timer was started.
    pub start_time: Instant,
    /// Time between `start_time` and the moment the operation was ended.
    pub duration: Duration,
    /// Free-form key/value annotations supplied when the operation ended.
    pub metadata: HashMap<String, String>,
}
30
31pub struct RuntimeProfiler {
33 events: Arc<Mutex<Vec<OperationEvent>>>,
34 operation_stats: Arc<Mutex<HashMap<OperationType, OperationStats>>>,
35 enabled: bool,
36 start_time: Instant,
37}
38
/// Aggregated timing statistics for every recorded operation of one operation type.
#[derive(Debug, Clone)]
pub struct OperationStats {
    /// Number of samples folded in so far.
    pub count: usize,
    /// Sum of all sample durations.
    pub total_time: Duration,
    /// Shortest sample seen (`Duration::MAX` until the first update).
    pub min_time: Duration,
    /// Longest sample seen.
    pub max_time: Duration,
    /// `total_time / count`, refreshed on every update.
    pub average_time: Duration,
}

impl OperationStats {
    /// Creates an empty accumulator.
    ///
    /// `min_time` starts at `Duration::MAX` so the first sample always replaces it.
    fn new() -> Self {
        Self {
            count: 0,
            total_time: Duration::ZERO,
            min_time: Duration::MAX,
            max_time: Duration::ZERO,
            average_time: Duration::ZERO,
        }
    }

    /// Folds one more sample into the statistics.
    fn update(&mut self, duration: Duration) {
        self.count += 1;
        self.total_time += duration;
        self.average_time = Self::mean(self.total_time, self.count);
        self.min_time = self.min_time.min(duration);
        self.max_time = self.max_time.max(duration);
    }

    /// Mean of `total` over `count` samples.
    ///
    /// The previous `count as u32` silently truncated: at exactly 2^32 samples the
    /// divisor became 0 and `Duration / 0` panicked. This helper never truncates.
    fn mean(total: Duration, count: usize) -> Duration {
        match u32::try_from(count) {
            Ok(0) => Duration::ZERO,
            // Common case: exact `Duration / u32`, same result as before.
            Ok(n) => total / n,
            // `Duration` only divides by `u32`, so fall back to integer nanoseconds.
            Err(_) => {
                let nanos = total.as_nanos() / count as u128;
                Duration::from_nanos(u64::try_from(nanos).unwrap_or(u64::MAX))
            }
        }
    }
}
72
73impl Default for RuntimeProfiler {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl RuntimeProfiler {
80 pub fn new() -> Self {
81 Self {
82 events: Arc::new(Mutex::new(Vec::new())),
83 operation_stats: Arc::new(Mutex::new(HashMap::new())),
84 enabled: false,
85 start_time: Instant::now(),
86 }
87 }
88
89 pub fn enable(&mut self) {
90 self.enabled = true;
91 self.start_time = Instant::now();
92 }
93
94 pub fn disable(&mut self) {
95 self.enabled = false;
96 }
97
98 pub fn is_enabled(&self) -> bool {
99 self.enabled
100 }
101
102 pub fn start_operation(&self, operation_type: OperationType, name: &str) -> OperationTimer {
103 OperationTimer::new(
104 self.enabled,
105 operation_type,
106 name.to_string(),
107 Instant::now(),
108 )
109 }
110
111 pub fn end_operation(&self, timer: OperationTimer, metadata: HashMap<String, String>) {
112 if !self.enabled || !timer.enabled {
113 return;
114 }
115
116 let duration = timer.start_time.elapsed();
117
118 let event = OperationEvent {
119 operation_type: timer.operation_type,
120 name: timer.name,
121 start_time: timer.start_time,
122 duration,
123 metadata,
124 };
125
126 {
128 let mut events = self.events.lock().unwrap();
129 events.push(event);
130 }
131
132 {
134 let mut stats = self.operation_stats.lock().unwrap();
135 stats
136 .entry(timer.operation_type)
137 .or_insert_with(OperationStats::new)
138 .update(duration);
139 }
140 }
141
142 pub fn get_events(&self) -> Vec<OperationEvent> {
143 self.events.lock().unwrap().clone()
144 }
145
146 pub fn get_stats(&self) -> HashMap<OperationType, OperationStats> {
147 self.operation_stats.lock().unwrap().clone()
148 }
149
150 pub fn get_total_runtime(&self) -> Duration {
151 self.start_time.elapsed()
152 }
153
154 pub fn print_summary(&self) {
155 println!("\n========== RUNTIME PROFILING SUMMARY ==========");
156
157 let stats = self.get_stats();
158 let total_runtime = self.get_total_runtime();
159
160 println!("\nTotal Runtime: {total_runtime:?}");
161
162 let mut sorted_ops: Vec<_> = stats.iter().collect();
164 sorted_ops.sort_by(|a, b| b.1.total_time.cmp(&a.1.total_time));
165
166 println!("\nOperation Statistics:");
167 for (op_type, stat) in sorted_ops {
168 let percentage = (stat.total_time.as_secs_f64() / total_runtime.as_secs_f64()) * 100.0;
169
170 println!("\n{op_type:?}:");
171 println!(" Count: {}", stat.count);
172 println!(" Total time: {:?} ({:.1}%)", stat.total_time, percentage);
173 println!(" Average: {:?}", stat.average_time);
174 println!(" Min/Max: {:?} / {:?}", stat.min_time, stat.max_time);
175 }
176
177 self.print_timeline_analysis();
179
180 println!("==============================================\n");
181 }
182
183 fn print_timeline_analysis(&self) {
184 let events = self.get_events();
185 if events.is_empty() {
186 return;
187 }
188
189 println!("\nTimeline Analysis:");
190
191 let mut critical_path_time = Duration::ZERO;
193 let mut last_end_time = self.start_time;
194
195 for event in &events {
196 let event_end = event.start_time + event.duration;
197 if event.start_time >= last_end_time {
198 critical_path_time += event.duration;
199 last_end_time = event_end;
200 }
201 }
202
203 println!(" Critical path time: {critical_path_time:?}");
204 println!(" Parallelization efficiency: {:.1}%",
205 (critical_path_time.as_secs_f64() / self.get_total_runtime().as_secs_f64()) * 100.0
206 );
207
208 let mut longest_ops = events.clone();
210 longest_ops.sort_by(|a, b| b.duration.cmp(&a.duration));
211
212 println!("\n Longest operations:");
213 for (i, event) in longest_ops.iter().take(5).enumerate() {
214 println!(" {}. {} ({:?}): {:?}",
215 i + 1,
216 event.name,
217 event.operation_type,
218 event.duration
219 );
220 }
221 }
222
223 pub fn export_trace(&self, path: &str) -> Result<(), CudaRustError> {
224 use std::fs::File;
225 use std::io::Write;
226
227 let events = self.get_events();
228 let mut file = File::create(path)
229 .map_err(|e| CudaRustError::RuntimeError(format!("Failed to create file: {e}")))?;
230
231 writeln!(file, "[")
233 .map_err(|e| CudaRustError::RuntimeError(format!("Failed to write header: {e}")))?;
234
235 for (i, event) in events.iter().enumerate() {
236 let start_us = event.start_time.duration_since(self.start_time).as_micros();
237 let duration_us = event.duration.as_micros();
238
239 let trace_event = format!(
240 r#"{{
241 "name": "{}",
242 "cat": "{:?}",
243 "ph": "X",
244 "ts": {},
245 "dur": {},
246 "pid": 1,
247 "tid": 1,
248 "args": {{}}
249}}"#,
250 event.name,
251 event.operation_type,
252 start_us,
253 duration_us
254 );
255
256 if i < events.len() - 1 {
257 writeln!(file, "{trace_event},")
258 .map_err(|e| CudaRustError::RuntimeError(format!("Failed to write event: {e}")))?;
259 } else {
260 writeln!(file, "{trace_event}")
261 .map_err(|e| CudaRustError::RuntimeError(format!("Failed to write event: {e}")))?;
262 }
263 }
264
265 writeln!(file, "]")
266 .map_err(|e| CudaRustError::RuntimeError(format!("Failed to write footer: {e}")))?;
267
268 Ok(())
269 }
270
271 pub fn analyze_bottlenecks(&self) -> BottleneckAnalysis {
272 let stats = self.get_stats();
273 let total_runtime = self.get_total_runtime();
274
275 let mut time_by_operation: Vec<_> = stats.iter()
277 .map(|(op, stat)| (*op, stat.total_time))
278 .collect();
279 time_by_operation.sort_by(|a, b| b.1.cmp(&a.1));
280
281 let primary_bottleneck = time_by_operation.first()
282 .map(|(op, _)| *op)
283 .unwrap_or(OperationType::Custom(0));
284
285 let mut time_distribution = HashMap::new();
287 for (op, stat) in &stats {
288 let percentage = (stat.total_time.as_secs_f64() / total_runtime.as_secs_f64()) * 100.0;
289 time_distribution.insert(*op, percentage);
290 }
291
292 let mut high_variance_ops = Vec::new();
294 for (op, stat) in &stats {
295 if stat.count > 1 {
296 let range = stat.max_time.as_secs_f64() - stat.min_time.as_secs_f64();
297 let variance_ratio = range / stat.average_time.as_secs_f64();
298 if variance_ratio > 2.0 {
299 high_variance_ops.push((*op, variance_ratio));
300 }
301 }
302 }
303
304 BottleneckAnalysis {
305 primary_bottleneck,
306 time_distribution,
307 high_variance_operations: high_variance_ops,
308 total_runtime,
309 }
310 }
311
312 pub fn clear(&self) {
313 self.events.lock().unwrap().clear();
314 self.operation_stats.lock().unwrap().clear();
315 }
316}
317
318pub struct OperationTimer {
320 enabled: bool,
321 operation_type: OperationType,
322 name: String,
323 start_time: Instant,
324}
325
326impl OperationTimer {
327 fn new(enabled: bool, operation_type: OperationType, name: String, start_time: Instant) -> Self {
328 Self {
329 enabled,
330 operation_type,
331 name,
332 start_time,
333 }
334 }
335}
336
337#[derive(Debug, Clone)]
339pub struct BottleneckAnalysis {
340 pub primary_bottleneck: OperationType,
341 pub time_distribution: HashMap<OperationType, f64>,
342 pub high_variance_operations: Vec<(OperationType, f64)>,
343 pub total_runtime: Duration,
344}
345
346impl BottleneckAnalysis {
347 pub fn print_analysis(&self) {
348 println!("\n=== Bottleneck Analysis ===");
349 println!("Total runtime: {:?}", self.total_runtime);
350 println!("Primary bottleneck: {:?}", self.primary_bottleneck);
351
352 println!("\nTime distribution:");
353 let mut sorted_dist: Vec<_> = self.time_distribution.iter().collect();
354 sorted_dist.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
355
356 for (op, percentage) in sorted_dist {
357 println!(" {op:?}: {percentage:.1}%");
358 }
359
360 if !self.high_variance_operations.is_empty() {
361 println!("\nHigh variance operations:");
362 for (op, ratio) in &self.high_variance_operations {
363 println!(" {op:?}: {ratio:.1}x variance");
364 }
365 }
366 }
367}
368
369pub struct OptimizationSuggestions {
371 suggestions: Vec<Suggestion>,
372}
373
/// A single optimization recommendation.
///
/// Derives `PartialEq` so callers and tests can compare suggestions directly.
#[derive(Debug, Clone, PartialEq)]
pub struct Suggestion {
    /// How strongly the recommendation is worth acting on.
    pub severity: SuggestionSeverity,
    /// Area the recommendation targets.
    pub category: SuggestionCategory,
    /// Human-readable advice.
    pub message: String,
    /// Estimated improvement in percent, when an estimate is available.
    pub expected_improvement: Option<f64>,
}

/// Importance of a [`Suggestion`].
///
/// Ordered `Low < Medium < High` by declaration order, so suggestions can be
/// sorted or filtered by severity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SuggestionSeverity {
    Low,
    Medium,
    High,
}

/// Area targeted by a [`Suggestion`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SuggestionCategory {
    MemoryOptimization,
    KernelOptimization,
    RuntimeOptimization,
    Parallelization,
}
396
397impl OptimizationSuggestions {
398 pub fn analyze(profiler: &RuntimeProfiler) -> Self {
399 let mut suggestions = Vec::new();
400 let analysis = profiler.analyze_bottlenecks();
401
402 if let Some(percentage) = analysis.time_distribution.get(&OperationType::ModuleLoad) {
404 if *percentage > 20.0 {
405 suggestions.push(Suggestion {
406 severity: SuggestionSeverity::High,
407 category: SuggestionCategory::RuntimeOptimization,
408 message: "Module loading takes >20% of runtime. Consider caching compiled modules.".to_string(),
409 expected_improvement: Some(percentage * 0.8),
410 });
411 }
412 }
413
414 if let Some(percentage) = analysis.time_distribution.get(&OperationType::ModuleCompile) {
416 if *percentage > 30.0 {
417 suggestions.push(Suggestion {
418 severity: SuggestionSeverity::High,
419 category: SuggestionCategory::RuntimeOptimization,
420 message: "Compilation takes >30% of runtime. Use pre-compiled WASM modules.".to_string(),
421 expected_improvement: Some(percentage * 0.9),
422 });
423 }
424 }
425
426 if let Some(percentage) = analysis.time_distribution.get(&OperationType::MemoryTransfer) {
428 if *percentage > 40.0 {
429 suggestions.push(Suggestion {
430 severity: SuggestionSeverity::High,
431 category: SuggestionCategory::MemoryOptimization,
432 message: "Memory transfers dominate runtime. Consider unified memory or reducing transfers.".to_string(),
433 expected_improvement: Some(percentage * 0.5),
434 });
435 }
436 }
437
438 Self { suggestions }
439 }
440
441 pub fn print_suggestions(&self) {
442 if self.suggestions.is_empty() {
443 println!("\nNo optimization suggestions found.");
444 return;
445 }
446
447 println!("\n=== Optimization Suggestions ===");
448
449 for (i, suggestion) in self.suggestions.iter().enumerate() {
450 println!("\n{}. {:?} - {:?}", i + 1, suggestion.severity, suggestion.category);
451 println!(" {}", suggestion.message);
452 if let Some(improvement) = suggestion.expected_improvement {
453 println!(" Expected improvement: {improvement:.1}%");
454 }
455 }
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a profiler that is already recording.
    fn enabled_profiler() -> RuntimeProfiler {
        let mut profiler = RuntimeProfiler::new();
        profiler.enable();
        profiler
    }

    /// Records one operation of type `op` that sleeps for about `millis` milliseconds.
    fn record(profiler: &RuntimeProfiler, op: OperationType, name: &str, millis: u64) {
        let timer = profiler.start_operation(op, name);
        std::thread::sleep(Duration::from_millis(millis));
        profiler.end_operation(timer, HashMap::new());
    }

    #[test]
    fn test_runtime_profiler() {
        let profiler = enabled_profiler();
        record(&profiler, OperationType::ModuleLoad, "test_module", 10);
        record(&profiler, OperationType::KernelLaunch, "test_kernel", 5);

        let stats = profiler.get_stats();
        assert_eq!(stats.len(), 2);
        for op in [OperationType::ModuleLoad, OperationType::KernelLaunch] {
            assert_eq!(stats[&op].count, 1);
        }
    }

    #[test]
    fn test_bottleneck_analysis() {
        let profiler = enabled_profiler();
        (0..10).for_each(|_| record(&profiler, OperationType::MemoryTransfer, "transfer", 10));
        record(&profiler, OperationType::KernelLaunch, "kernel", 5);

        let analysis = profiler.analyze_bottlenecks();
        assert_eq!(analysis.primary_bottleneck, OperationType::MemoryTransfer);
    }
}