Skip to main content

ronn_core/
profiling.rs

1//! Performance profiling infrastructure
2//!
3//! Provides detailed profiling of inference operations to identify bottlenecks.
4//! Minimal overhead when disabled, detailed insights when enabled.
5
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
/// Settings that control what the profiler records
#[derive(Debug, Clone)]
pub struct ProfileConfig {
    /// Master switch for profiling
    pub enabled: bool,
    /// Whether individual operators are profiled
    pub profile_ops: bool,
    /// Whether memory allocations are profiled
    pub profile_memory: bool,
    /// Whether data transfers are profiled
    pub profile_transfers: bool,
    /// Shortest duration, in microseconds, worth recording (noise filter)
    pub min_duration_us: u64,
}

impl Default for ProfileConfig {
    /// Profiling switched off, every sub-category flag on, 10 µs noise floor.
    fn default() -> Self {
        Self {
            min_duration_us: 10,
            profile_transfers: true,
            profile_memory: true,
            profile_ops: true,
            enabled: false,
        }
    }
}

impl ProfileConfig {
    /// Preset for development: everything enabled, 1 µs noise floor.
    pub fn development() -> Self {
        // The defaults already turn every sub-category flag on.
        Self {
            enabled: true,
            min_duration_us: 1,
            ..Self::default()
        }
    }

    /// Preset for production: operators only, and only those taking at
    /// least 100 µs, to keep overhead low.
    pub fn production() -> Self {
        Self {
            enabled: true,
            profile_memory: false,
            profile_transfers: false,
            min_duration_us: 100,
            ..Self::default()
        }
    }
}
60
/// A single profiling event
#[derive(Debug, Clone)]
pub struct ProfileEvent {
    /// Event name
    pub name: String,
    /// Event category (op, memory, transfer, etc.)
    pub category: String,
    /// Duration in microseconds (saturates at `u64::MAX`)
    pub duration_us: u64,
    /// Start timestamp
    pub timestamp: Instant,
    /// Additional metadata
    pub metadata: HashMap<String, String>,
}

impl ProfileEvent {
    /// Create a new profile event for a span that has just finished.
    ///
    /// `duration` is the measured length of the span. The start timestamp is
    /// derived by subtracting `duration` from the current instant; if that
    /// subtraction cannot be represented, the current instant is used instead.
    pub fn new(name: String, category: String, duration: Duration) -> Self {
        let now = Instant::now();
        Self {
            name,
            category,
            // `as_micros` yields a u128; saturate instead of silently truncating.
            duration_us: u64::try_from(duration.as_micros()).unwrap_or(u64::MAX),
            // The event is built after the span ended, so back-date the start.
            timestamp: now.checked_sub(duration).unwrap_or(now),
            metadata: HashMap::new(),
        }
    }

    /// Add metadata to event, replacing any previous value stored under `key`
    #[must_use]
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }
}
94
95/// Profiler for recording performance events
96pub struct Profiler {
97    config: ProfileConfig,
98    events: Arc<Mutex<Vec<ProfileEvent>>>,
99    session_start: Instant,
100}
101
102impl Profiler {
103    /// Create a new profiler
104    pub fn new(config: ProfileConfig) -> Self {
105        Self {
106            config,
107            events: Arc::new(Mutex::new(Vec::new())),
108            session_start: Instant::now(),
109        }
110    }
111
112    /// Create a profiler with default config
113    pub fn default() -> Self {
114        Self::new(ProfileConfig::default())
115    }
116
117    /// Start profiling a named operation
118    ///
119    /// # Returns
120    ///
121    /// A ProfileScope that automatically records duration on drop
122    pub fn scope(&self, name: impl Into<String>, category: impl Into<String>) -> ProfileScope {
123        ProfileScope::new(self.clone(), name.into(), category.into())
124    }
125
126    /// Record an event
127    pub fn record(&self, event: ProfileEvent) {
128        if !self.config.enabled {
129            return;
130        }
131
132        if event.duration_us < self.config.min_duration_us {
133            return;
134        }
135
136        let mut events = self.events.lock().unwrap();
137        events.push(event);
138    }
139
140    /// Get all recorded events
141    pub fn events(&self) -> Vec<ProfileEvent> {
142        self.events.lock().unwrap().clone()
143    }
144
145    /// Clear all events
146    pub fn clear(&self) {
147        self.events.lock().unwrap().clear();
148    }
149
150    /// Generate profiling report
151    pub fn report(&self) -> ProfileReport {
152        let events = self.events();
153        ProfileReport::from_events(events, self.session_start.elapsed())
154    }
155
156    /// Check if profiling is enabled
157    pub fn is_enabled(&self) -> bool {
158        self.config.enabled
159    }
160}
161
162impl Clone for Profiler {
163    fn clone(&self) -> Self {
164        Self {
165            config: self.config.clone(),
166            events: Arc::clone(&self.events),
167            session_start: self.session_start,
168        }
169    }
170}
171
172/// RAII scope for automatic profiling
173///
174/// Records duration automatically when dropped
175pub struct ProfileScope {
176    profiler: Profiler,
177    name: String,
178    category: String,
179    start: Instant,
180    metadata: HashMap<String, String>,
181}
182
183impl ProfileScope {
184    fn new(profiler: Profiler, name: String, category: String) -> Self {
185        Self {
186            profiler,
187            name,
188            category,
189            start: Instant::now(),
190            metadata: HashMap::new(),
191        }
192    }
193
194    /// Add metadata to this scope
195    pub fn with_meta(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
196        self.metadata.insert(key.into(), value.into());
197        self
198    }
199}
200
201impl Drop for ProfileScope {
202    fn drop(&mut self) {
203        let duration = self.start.elapsed();
204        let mut event = ProfileEvent::new(self.name.clone(), self.category.clone(), duration);
205        event.metadata = self.metadata.clone();
206        self.profiler.record(event);
207    }
208}
209
210/// Profiling report with aggregated statistics
211#[derive(Debug, Clone)]
212pub struct ProfileReport {
213    /// Total session duration
214    pub total_duration: Duration,
215    /// Events grouped by category
216    pub by_category: HashMap<String, CategoryStats>,
217    /// Events grouped by name
218    pub by_name: HashMap<String, OperationStats>,
219    /// All events
220    pub events: Vec<ProfileEvent>,
221}
222
223/// Statistics for a category
224#[derive(Debug, Clone, serde::Serialize)]
225pub struct CategoryStats {
226    /// Number of events
227    pub count: usize,
228    /// Total time spent
229    pub total_us: u64,
230    /// Average time per event
231    pub avg_us: u64,
232    /// Minimum time
233    pub min_us: u64,
234    /// Maximum time
235    pub max_us: u64,
236    /// Percentage of total time
237    pub percentage: f64,
238}
239
240/// Statistics for a specific operation
241#[derive(Debug, Clone, serde::Serialize)]
242pub struct OperationStats {
243    /// Number of calls
244    pub count: usize,
245    /// Total time
246    pub total_us: u64,
247    /// Average time
248    pub avg_us: u64,
249    /// Minimum time
250    pub min_us: u64,
251    /// Maximum time
252    pub max_us: u64,
253    /// Standard deviation
254    pub std_dev_us: f64,
255}
256
257impl ProfileReport {
258    /// Create report from events
259    pub fn from_events(events: Vec<ProfileEvent>, total_duration: Duration) -> Self {
260        let total_us = total_duration.as_micros() as u64;
261
262        // Group by category
263        let mut by_category: HashMap<String, Vec<u64>> = HashMap::new();
264        for event in &events {
265            by_category
266                .entry(event.category.clone())
267                .or_insert_with(Vec::new)
268                .push(event.duration_us);
269        }
270
271        let category_stats: HashMap<String, CategoryStats> = by_category
272            .into_iter()
273            .map(|(cat, durations)| {
274                let count = durations.len();
275                let total: u64 = durations.iter().sum();
276                let min = *durations.iter().min().unwrap_or(&0);
277                let max = *durations.iter().max().unwrap_or(&0);
278                let avg = if count > 0 { total / count as u64 } else { 0 };
279                let percentage = if total_us > 0 {
280                    (total as f64 / total_us as f64) * 100.0
281                } else {
282                    0.0
283                };
284
285                (
286                    cat,
287                    CategoryStats {
288                        count,
289                        total_us: total,
290                        avg_us: avg,
291                        min_us: min,
292                        max_us: max,
293                        percentage,
294                    },
295                )
296            })
297            .collect();
298
299        // Group by name
300        let mut by_name: HashMap<String, Vec<u64>> = HashMap::new();
301        for event in &events {
302            by_name
303                .entry(event.name.clone())
304                .or_insert_with(Vec::new)
305                .push(event.duration_us);
306        }
307
308        let name_stats: HashMap<String, OperationStats> = by_name
309            .into_iter()
310            .map(|(name, durations)| {
311                let count = durations.len();
312                let total: u64 = durations.iter().sum();
313                let min = *durations.iter().min().unwrap_or(&0);
314                let max = *durations.iter().max().unwrap_or(&0);
315                let avg = if count > 0 { total / count as u64 } else { 0 };
316
317                // Calculate standard deviation
318                let variance: f64 = if count > 1 {
319                    durations
320                        .iter()
321                        .map(|&d| {
322                            let diff = d as f64 - avg as f64;
323                            diff * diff
324                        })
325                        .sum::<f64>()
326                        / (count - 1) as f64
327                } else {
328                    0.0
329                };
330                let std_dev = variance.sqrt();
331
332                (
333                    name,
334                    OperationStats {
335                        count,
336                        total_us: total,
337                        avg_us: avg,
338                        min_us: min,
339                        max_us: max,
340                        std_dev_us: std_dev,
341                    },
342                )
343            })
344            .collect();
345
346        Self {
347            total_duration,
348            by_category: category_stats,
349            by_name: name_stats,
350            events,
351        }
352    }
353
354    /// Print a human-readable report
355    pub fn print(&self) {
356        println!("\n=== Profiling Report ===");
357        println!(
358            "Total Duration: {:.2}ms\n",
359            self.total_duration.as_secs_f64() * 1000.0
360        );
361
362        println!("By Category:");
363        let mut categories: Vec<_> = self.by_category.iter().collect();
364        categories.sort_by(|a, b| b.1.total_us.cmp(&a.1.total_us));
365        for (cat, stats) in categories {
366            println!(
367                "  {}: {:.2}ms ({:.1}%) - {} calls, avg {:.2}µs",
368                cat,
369                stats.total_us as f64 / 1000.0,
370                stats.percentage,
371                stats.count,
372                stats.avg_us
373            );
374        }
375
376        println!("\nTop 10 Operations:");
377        let mut operations: Vec<_> = self.by_name.iter().collect();
378        operations.sort_by(|a, b| b.1.total_us.cmp(&a.1.total_us));
379        for (name, stats) in operations.iter().take(10) {
380            println!(
381                "  {}: {:.2}ms - {} calls, avg {:.2}µs ± {:.2}µs",
382                name,
383                stats.total_us as f64 / 1000.0,
384                stats.count,
385                stats.avg_us,
386                stats.std_dev_us
387            );
388        }
389
390        println!("\n=== End Report ===\n");
391    }
392
393    /// Export report as JSON
394    pub fn to_json(&self) -> serde_json::Value {
395        serde_json::json!({
396            "total_duration_ms": self.total_duration.as_secs_f64() * 1000.0,
397            "by_category": self.by_category,
398            "by_name": self.by_name,
399            "event_count": self.events.len(),
400        })
401    }
402}
403
404/// Global profiler instance
405static GLOBAL_PROFILER: std::sync::OnceLock<Profiler> = std::sync::OnceLock::new();
406
407/// Initialize global profiler
408pub fn init_profiler(config: ProfileConfig) {
409    let _ = GLOBAL_PROFILER.set(Profiler::new(config));
410}
411
412/// Get global profiler
413pub fn global_profiler() -> &'static Profiler {
414    GLOBAL_PROFILER.get_or_init(Profiler::default)
415}
416
/// Profile a scope with the global profiler
///
/// Expands to a `let` statement holding a scope guard, so the measurement
/// covers the rest of the block in which the macro is invoked. A trailing
/// comma after the category is accepted.
#[macro_export]
macro_rules! profile {
    ($name:expr, $category:expr $(,)?) => {
        let _scope = $crate::profiling::global_profiler().scope($name, $category);
    };
}
424
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// A profiler built from the default config reports itself as disabled.
    #[test]
    fn test_profiler_creation() {
        let config = ProfileConfig::default();
        assert!(!Profiler::new(config).is_enabled());
    }

    /// Dropping a scope records exactly one event covering the sleep.
    #[test]
    fn test_profile_scope() {
        let profiler = Profiler::new(ProfileConfig::development());

        let guard = profiler.scope("test_op", "test");
        thread::sleep(Duration::from_millis(10));
        drop(guard);

        let recorded = profiler.events();
        assert_eq!(recorded.len(), 1);
        let first = &recorded[0];
        assert_eq!(first.name, "test_op");
        assert!(first.duration_us >= 10_000); // At least 10ms
    }

    /// Five distinct operations in one category show up in the report.
    #[test]
    fn test_profiler_report() {
        let profiler = Profiler::new(ProfileConfig::development());

        // One scope per iteration; each is dropped at the end of the closure
        (0..5).for_each(|i| {
            let _scope = profiler.scope(format!("op_{}", i), "ops");
            thread::sleep(Duration::from_millis(1));
        });

        let report = profiler.report();
        assert_eq!(report.events.len(), 5);
        assert!(report.by_category.contains_key("ops"));
        assert_eq!(report.by_category["ops"].count, 5);
    }

    /// Events shorter than `min_duration_us` are dropped, longer ones kept.
    #[test]
    fn test_min_duration_filter() {
        let profiler = Profiler::new(ProfileConfig {
            enabled: true,
            min_duration_us: 1000, // 1ms minimum
            ..Default::default()
        });

        // Fast operation (should be filtered): the scope is dropped at once
        drop(profiler.scope("fast", "test"));

        // Slow operation (should be recorded)
        let slow = profiler.scope("slow", "test");
        thread::sleep(Duration::from_millis(2));
        drop(slow);

        let recorded = profiler.events();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].name, "slow");
    }

    /// Repeated calls to one operation aggregate into a single stats entry.
    #[test]
    fn test_report_statistics() {
        let profiler = Profiler::new(ProfileConfig::development());

        let mut remaining = 10;
        while remaining > 0 {
            let _scope = profiler.scope("test_op", "test");
            thread::sleep(Duration::from_millis(1));
            remaining -= 1;
        }

        let report = profiler.report();
        let stats = &report.by_name["test_op"];

        assert_eq!(stats.count, 10);
        assert!(stats.avg_us >= 1000); // At least 1ms average
        assert!(stats.min_us <= stats.max_us);
    }
}