ghostflow_core/
profiler.rs

1//! Profiling tools for performance analysis
2//!
3//! This module provides utilities for profiling tensor operations and model execution.
4
5use std::collections::HashMap;
6use std::time::{Duration, Instant};
7use std::sync::{Arc, Mutex};
8
9/// Operation profiler for tracking performance
10pub struct Profiler {
11    records: Arc<Mutex<Vec<ProfileRecord>>>,
12    enabled: bool,
13}
14
15impl Profiler {
16    /// Create a new profiler
17    pub fn new() -> Self {
18        Self {
19            records: Arc::new(Mutex::new(Vec::new())),
20            enabled: true,
21        }
22    }
23
24    /// Enable profiling
25    pub fn enable(&mut self) {
26        self.enabled = true;
27    }
28
29    /// Disable profiling
30    pub fn disable(&mut self) {
31        self.enabled = false;
32    }
33
34    /// Start profiling an operation
35    pub fn start(&self, name: &str) -> ProfileScope {
36        if !self.enabled {
37            return ProfileScope::disabled();
38        }
39
40        ProfileScope {
41            name: name.to_string(),
42            start: Instant::now(),
43            records: Some(Arc::clone(&self.records)),
44        }
45    }
46
47    /// Get all profile records
48    pub fn records(&self) -> Vec<ProfileRecord> {
49        self.records.lock().unwrap().clone()
50    }
51
52    /// Get summary statistics
53    pub fn summary(&self) -> ProfileSummary {
54        let records = self.records.lock().unwrap();
55        let mut op_stats: HashMap<String, OpStats> = HashMap::new();
56
57        for record in records.iter() {
58            let stats = op_stats.entry(record.name.clone()).or_insert_with(OpStats::default);
59            stats.count += 1;
60            stats.total_time += record.duration;
61            stats.min_time = stats.min_time.min(record.duration);
62            stats.max_time = stats.max_time.max(record.duration);
63        }
64
65        // Calculate averages
66        for stats in op_stats.values_mut() {
67            stats.avg_time = stats.total_time / stats.count as u32;
68        }
69
70        ProfileSummary {
71            total_operations: records.len(),
72            op_stats,
73        }
74    }
75
76    /// Clear all records
77    pub fn clear(&self) {
78        self.records.lock().unwrap().clear();
79    }
80
81    /// Print summary to stdout
82    pub fn print_summary(&self) {
83        let summary = self.summary();
84        println!("\n=== Profiler Summary ===");
85        println!("Total operations: {}", summary.total_operations);
86        println!("\nOperation Statistics:");
87        println!("{:<30} {:>10} {:>15} {:>15} {:>15}", 
88                 "Operation", "Count", "Total (ms)", "Avg (ms)", "Max (ms)");
89        println!("{}", "-".repeat(85));
90
91        let mut ops: Vec<_> = summary.op_stats.iter().collect();
92        ops.sort_by(|a, b| b.1.total_time.cmp(&a.1.total_time));
93
94        for (name, stats) in ops {
95            println!("{:<30} {:>10} {:>15.3} {:>15.3} {:>15.3}",
96                     name,
97                     stats.count,
98                     stats.total_time.as_secs_f64() * 1000.0,
99                     stats.avg_time.as_secs_f64() * 1000.0,
100                     stats.max_time.as_secs_f64() * 1000.0);
101        }
102        println!();
103    }
104}
105
106impl Default for Profiler {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
112/// Profile scope for RAII-style profiling
113pub struct ProfileScope {
114    name: String,
115    start: Instant,
116    records: Option<Arc<Mutex<Vec<ProfileRecord>>>>,
117}
118
119impl ProfileScope {
120    fn disabled() -> Self {
121        Self {
122            name: String::new(),
123            start: Instant::now(),
124            records: None,
125        }
126    }
127}
128
129impl Drop for ProfileScope {
130    fn drop(&mut self) {
131        if let Some(records) = &self.records {
132            let duration = self.start.elapsed();
133            records.lock().unwrap().push(ProfileRecord {
134                name: self.name.clone(),
135                duration,
136            });
137        }
138    }
139}
140
/// Single profile record: one completed measurement made by a profiling guard.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProfileRecord {
    /// Operation name passed to `Profiler::start`.
    pub name: String,
    /// Time elapsed (measured with `Instant`) between `Profiler::start` and the guard being dropped.
    pub duration: Duration,
}
147
/// Statistics for a single operation type, as aggregated by `Profiler::summary`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpStats {
    /// Number of recorded calls.
    pub count: usize,
    /// Sum of all call durations.
    pub total_time: Duration,
    /// Mean call duration (`total_time / count`).
    pub avg_time: Duration,
    /// Shortest call duration. `Duration::MAX` in a default (empty) value so the
    /// first running-minimum comparison always takes the sample.
    pub min_time: Duration,
    /// Longest call duration.
    pub max_time: Duration,
}

impl Default for OpStats {
    /// Empty statistics: zero count and durations, except `min_time`, which starts at
    /// `Duration::MAX` as the identity for a running minimum.
    fn default() -> Self {
        Self {
            count: 0,
            total_time: Duration::ZERO,
            avg_time: Duration::ZERO,
            min_time: Duration::MAX,
            max_time: Duration::ZERO,
        }
    }
}
169
170/// Profile summary
171#[derive(Debug, Clone)]
172pub struct ProfileSummary {
173    pub total_operations: usize,
174    pub op_stats: HashMap<String, OpStats>,
175}
176
177/// Benchmark utility for measuring performance
178pub struct Benchmark {
179    name: String,
180    warmup_iterations: usize,
181    iterations: usize,
182}
183
184impl Benchmark {
185    /// Create a new benchmark
186    pub fn new(name: &str) -> Self {
187        Self {
188            name: name.to_string(),
189            warmup_iterations: 3,
190            iterations: 10,
191        }
192    }
193
194    /// Set warmup iterations
195    pub fn warmup(mut self, iterations: usize) -> Self {
196        self.warmup_iterations = iterations;
197        self
198    }
199
200    /// Set benchmark iterations
201    pub fn iterations(mut self, iterations: usize) -> Self {
202        self.iterations = iterations;
203        self
204    }
205
206    /// Run the benchmark
207    pub fn run<F>(&self, mut f: F) -> BenchmarkResult
208    where
209        F: FnMut(),
210    {
211        // Warmup
212        for _ in 0..self.warmup_iterations {
213            f();
214        }
215
216        // Benchmark
217        let mut times = Vec::with_capacity(self.iterations);
218        for _ in 0..self.iterations {
219            let start = Instant::now();
220            f();
221            times.push(start.elapsed());
222        }
223
224        BenchmarkResult::new(&self.name, times)
225    }
226}
227
/// Benchmark result: summary statistics over the timed (non-warmup) iterations.
///
/// When no timed iterations were recorded, every duration and `std_dev` is zero.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Benchmark name.
    pub name: String,
    /// Number of timed samples.
    pub iterations: usize,
    /// Sum of all samples.
    pub total_time: Duration,
    /// Arithmetic mean of the samples.
    pub mean_time: Duration,
    /// Median sample; for an even sample count, the midpoint of the two middle samples.
    pub median_time: Duration,
    /// Fastest sample.
    pub min_time: Duration,
    /// Slowest sample.
    pub max_time: Duration,
    /// Population standard deviation of the samples, in seconds.
    pub std_dev: f64,
}

impl BenchmarkResult {
    /// Build a result from raw per-iteration timings.
    ///
    /// `times` may be in any order; it is sorted in place. An empty `times` (e.g. from a
    /// benchmark configured with zero iterations) yields an all-zero result instead of
    /// panicking on a division by zero or an out-of-bounds index.
    fn new(name: &str, mut times: Vec<Duration>) -> Self {
        let iterations = times.len();
        if iterations == 0 {
            return Self {
                name: name.to_string(),
                iterations,
                total_time: Duration::ZERO,
                mean_time: Duration::ZERO,
                median_time: Duration::ZERO,
                min_time: Duration::ZERO,
                max_time: Duration::ZERO,
                std_dev: 0.0,
            };
        }

        // Equal durations are indistinguishable, so an unstable sort suffices.
        times.sort_unstable();

        let total_time: Duration = times.iter().sum();
        let mean_time = match u32::try_from(iterations) {
            Ok(n) => total_time / n,
            // `Duration` only divides by `u32`; for larger counts divide in nanoseconds.
            // The quotient fits in `u64` because the divisor exceeds `u32::MAX`.
            Err(_) => Duration::from_nanos((total_time.as_nanos() / iterations as u128) as u64),
        };

        let mid = iterations / 2;
        let median_time = if iterations % 2 == 0 {
            // Midpoint of the two central samples, written as lo + (hi - lo) / 2 so the
            // addition cannot overflow (the slice is sorted, so hi >= lo).
            let (lo, hi) = (times[mid - 1], times[mid]);
            lo + (hi - lo) / 2
        } else {
            times[mid]
        };
        let min_time = times[0];
        let max_time = times[iterations - 1];

        // Population variance (divides by n, not n - 1).
        let mean_secs = mean_time.as_secs_f64();
        let variance: f64 = times
            .iter()
            .map(|t| {
                let diff = t.as_secs_f64() - mean_secs;
                diff * diff
            })
            .sum::<f64>()
            / iterations as f64;
        let std_dev = variance.sqrt();

        Self {
            name: name.to_string(),
            iterations,
            total_time,
            mean_time,
            median_time,
            min_time,
            max_time,
            std_dev,
        }
    }

    /// Print the benchmark result to stdout, with times in milliseconds.
    pub fn print(&self) {
        println!("\n=== Benchmark: {} ===", self.name);
        println!("Iterations: {}", self.iterations);
        println!("Total time: {:.3} ms", self.total_time.as_secs_f64() * 1000.0);
        println!("Mean time:   {:.3} ms", self.mean_time.as_secs_f64() * 1000.0);
        println!("Median time: {:.3} ms", self.median_time.as_secs_f64() * 1000.0);
        println!("Min time:    {:.3} ms", self.min_time.as_secs_f64() * 1000.0);
        println!("Max time:    {:.3} ms", self.max_time.as_secs_f64() * 1000.0);
        println!("Std dev:     {:.3} ms", self.std_dev * 1000.0);
        println!();
    }
}
287
288/// Global profiler instance
289static mut GLOBAL_PROFILER: Option<Profiler> = None;
290
291/// Get the global profiler
292pub fn global_profiler() -> &'static Profiler {
293    unsafe {
294        GLOBAL_PROFILER.get_or_insert_with(Profiler::new)
295    }
296}
297
/// Profile a code block with the global profiler.
///
/// `profile!(name, { ... })` opens a scope called `name` on the global profiler,
/// evaluates the block, and records the elapsed time when the expansion's scope ends.
/// The macro evaluates to the block's value.
#[macro_export]
macro_rules! profile {
    ($name:expr, $body:block) => {{
        // Named (not `_`) so the guard lives until the end of this block.
        let _profile_guard = $crate::profiler::global_profiler().start($name);
        $body
    }};
}
306
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Block the current thread for `ms` milliseconds.
    fn pause(ms: u64) {
        thread::sleep(Duration::from_millis(ms));
    }

    #[test]
    fn test_profiler() {
        let profiler = Profiler::new();

        // Inner scope so the guard is dropped (and recorded) before inspection.
        {
            let _scope = profiler.start("test_op");
            pause(10);
        }

        let records = profiler.records();
        assert_eq!(records.len(), 1);
        let record = &records[0];
        assert_eq!(record.name, "test_op");
        assert!(record.duration >= Duration::from_millis(10));
    }

    #[test]
    fn test_profiler_summary() {
        let profiler = Profiler::new();

        for &(name, calls) in &[("op1", 5), ("op2", 3)] {
            for _ in 0..calls {
                let _scope = profiler.start(name);
                pause(1);
            }
        }

        let summary = profiler.summary();
        assert_eq!(summary.total_operations, 8);
        assert_eq!(summary.op_stats.len(), 2);
        assert_eq!(summary.op_stats["op1"].count, 5);
        assert_eq!(summary.op_stats["op2"].count, 3);
    }

    #[test]
    fn test_benchmark() {
        let bench = Benchmark::new("test").warmup(2).iterations(5);
        let result = bench.run(|| pause(1));

        assert_eq!(result.iterations, 5);
        assert!(result.mean_time >= Duration::from_millis(1));
    }

    #[test]
    fn test_disabled_profiler() {
        let mut profiler = Profiler::new();
        profiler.disable();

        {
            let _scope = profiler.start("test");
            pause(10);
        }

        assert!(profiler.records().is_empty());
    }
}