ipfrs_tensorlogic/
ffi_profiler.rs

1//! FFI Overhead Profiling
2//!
3//! This module provides utilities for profiling FFI call overhead and identifying
4//! performance bottlenecks in cross-language boundaries.
5
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12/// FFI call statistics
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct FfiCallStats {
15    /// Function name
16    pub name: String,
17    /// Total number of calls
18    pub call_count: u64,
19    /// Total time spent in calls
20    pub total_duration: Duration,
21    /// Minimum call duration
22    pub min_duration: Duration,
23    /// Maximum call duration
24    pub max_duration: Duration,
25    /// Average call duration
26    pub avg_duration: Duration,
27}
28
29impl FfiCallStats {
30    fn new(name: String) -> Self {
31        Self {
32            name,
33            call_count: 0,
34            total_duration: Duration::ZERO,
35            min_duration: Duration::MAX,
36            max_duration: Duration::ZERO,
37            avg_duration: Duration::ZERO,
38        }
39    }
40
41    fn record(&mut self, duration: Duration) {
42        self.call_count += 1;
43        self.total_duration += duration;
44        self.min_duration = self.min_duration.min(duration);
45        self.max_duration = self.max_duration.max(duration);
46        self.avg_duration = self.total_duration / self.call_count as u32;
47    }
48
49    /// Check if call overhead exceeds target
50    pub fn exceeds_target(&self, target_micros: u64) -> bool {
51        self.avg_duration.as_micros() > target_micros as u128
52    }
53
54    /// Get overhead percentage relative to target
55    pub fn overhead_percentage(&self, target_micros: u64) -> f64 {
56        let avg_micros = self.avg_duration.as_micros() as f64;
57        ((avg_micros - target_micros as f64) / target_micros as f64) * 100.0
58    }
59}
60
61/// FFI profiler for measuring call overhead
62pub struct FfiProfiler {
63    stats: Arc<RwLock<HashMap<String, FfiCallStats>>>,
64    enabled: Arc<RwLock<bool>>,
65}
66
67impl FfiProfiler {
68    /// Create a new FFI profiler
69    pub fn new() -> Self {
70        Self {
71            stats: Arc::new(RwLock::new(HashMap::new())),
72            enabled: Arc::new(RwLock::new(true)),
73        }
74    }
75
76    /// Enable profiling
77    pub fn enable(&self) {
78        *self.enabled.write() = true;
79    }
80
81    /// Disable profiling
82    pub fn disable(&self) {
83        *self.enabled.write() = false;
84    }
85
86    /// Check if profiling is enabled
87    pub fn is_enabled(&self) -> bool {
88        *self.enabled.read()
89    }
90
91    /// Start profiling a function call
92    pub fn start(&self, name: &str) -> FfiCallGuard {
93        FfiCallGuard {
94            name: name.to_string(),
95            start: Instant::now(),
96            profiler: self.clone(),
97        }
98    }
99
100    /// Record a call duration
101    fn record(&self, name: String, duration: Duration) {
102        if !self.is_enabled() {
103            return;
104        }
105
106        let mut stats = self.stats.write();
107        stats
108            .entry(name.clone())
109            .or_insert_with(|| FfiCallStats::new(name))
110            .record(duration);
111    }
112
113    /// Get statistics for a specific function
114    pub fn get_stats(&self, name: &str) -> Option<FfiCallStats> {
115        self.stats.read().get(name).cloned()
116    }
117
118    /// Get all statistics
119    pub fn get_all_stats(&self) -> Vec<FfiCallStats> {
120        self.stats.read().values().cloned().collect()
121    }
122
123    /// Reset all statistics
124    pub fn reset(&self) {
125        self.stats.write().clear();
126    }
127
128    /// Get statistics sorted by average duration
129    pub fn get_hotspots(&self) -> Vec<FfiCallStats> {
130        let mut stats = self.get_all_stats();
131        stats.sort_by(|a, b| b.avg_duration.cmp(&a.avg_duration));
132        stats
133    }
134
135    /// Get total overhead
136    pub fn total_overhead(&self) -> Duration {
137        self.stats.read().values().map(|s| s.total_duration).sum()
138    }
139
140    /// Generate profiling report
141    pub fn report(&self) -> ProfilingReport {
142        let stats = self.get_all_stats();
143        let total_calls: u64 = stats.iter().map(|s| s.call_count).sum();
144        let total_duration = self.total_overhead();
145
146        ProfilingReport {
147            total_calls,
148            total_duration,
149            function_stats: stats,
150        }
151    }
152}
153
154impl Default for FfiProfiler {
155    fn default() -> Self {
156        Self::new()
157    }
158}
159
160impl Clone for FfiProfiler {
161    fn clone(&self) -> Self {
162        Self {
163            stats: Arc::clone(&self.stats),
164            enabled: Arc::clone(&self.enabled),
165        }
166    }
167}
168
169/// RAII guard for profiling FFI calls
170pub struct FfiCallGuard {
171    name: String,
172    start: Instant,
173    profiler: FfiProfiler,
174}
175
176impl Drop for FfiCallGuard {
177    fn drop(&mut self) {
178        let duration = self.start.elapsed();
179        self.profiler.record(self.name.clone(), duration);
180    }
181}
182
183/// Profiling report
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct ProfilingReport {
186    /// Total number of FFI calls
187    pub total_calls: u64,
188    /// Total time spent in FFI calls
189    pub total_duration: Duration,
190    /// Per-function statistics
191    pub function_stats: Vec<FfiCallStats>,
192}
193
194impl ProfilingReport {
195    /// Print report to stdout
196    pub fn print(&self) {
197        println!("\n=== FFI Profiling Report ===");
198        println!("Total calls: {}", self.total_calls);
199        println!("Total duration: {:?}", self.total_duration);
200        println!("\nFunction statistics:");
201        println!(
202            "{:<30} {:>10} {:>15} {:>15} {:>15}",
203            "Function", "Calls", "Avg (μs)", "Min (μs)", "Max (μs)"
204        );
205        println!("{}", "-".repeat(85));
206
207        let mut sorted_stats = self.function_stats.clone();
208        sorted_stats.sort_by(|a, b| b.avg_duration.cmp(&a.avg_duration));
209
210        for stat in sorted_stats {
211            println!(
212                "{:<30} {:>10} {:>15.2} {:>15.2} {:>15.2}",
213                stat.name,
214                stat.call_count,
215                stat.avg_duration.as_micros() as f64,
216                stat.min_duration.as_micros() as f64,
217                stat.max_duration.as_micros() as f64,
218            );
219        }
220    }
221
222    /// Identify functions exceeding target overhead
223    pub fn identify_bottlenecks(&self, target_micros: u64) -> Vec<String> {
224        self.function_stats
225            .iter()
226            .filter(|s| s.exceeds_target(target_micros))
227            .map(|s| s.name.clone())
228            .collect()
229    }
230
231    /// Get overhead summary
232    pub fn summary(&self) -> OverheadSummary {
233        let avg_call_duration = if self.total_calls > 0 {
234            self.total_duration / self.total_calls as u32
235        } else {
236            Duration::ZERO
237        };
238
239        let max_duration = self
240            .function_stats
241            .iter()
242            .map(|s| s.max_duration)
243            .max()
244            .unwrap_or(Duration::ZERO);
245
246        OverheadSummary {
247            total_calls: self.total_calls,
248            total_duration: self.total_duration,
249            avg_call_duration,
250            max_call_duration: max_duration,
251            functions_profiled: self.function_stats.len(),
252        }
253    }
254}
255
256/// Overhead summary
257#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct OverheadSummary {
259    pub total_calls: u64,
260    pub total_duration: Duration,
261    pub avg_call_duration: Duration,
262    pub max_call_duration: Duration,
263    pub functions_profiled: usize,
264}
265
266impl OverheadSummary {
267    /// Check if average overhead meets target
268    pub fn meets_target(&self, target_micros: u64) -> bool {
269        self.avg_call_duration.as_micros() <= target_micros as u128
270    }
271}
272
273/// Global FFI profiler instance
274static GLOBAL_PROFILER: once_cell::sync::Lazy<FfiProfiler> =
275    once_cell::sync::Lazy::new(FfiProfiler::new);
276
277/// Get the global FFI profiler
278pub fn global_profiler() -> &'static FfiProfiler {
279    &GLOBAL_PROFILER
280}
281
/// Profile an FFI function call
///
/// Expands to a block that starts a guard with
/// `$crate::ffi_profiler::global_profiler().start($name)`, evaluates `$body`
/// and yields its value; the guard goes out of scope when the block ends.
/// A trailing comma after `$body` is accepted.
///
/// Illustrative use (`ffi_add` stands for any foreign function):
///
/// ```ignore
/// let sum = profile_ffi!("ffi_add", unsafe { ffi_add(1, 2) });
/// ```
#[macro_export]
macro_rules! profile_ffi {
    ($name:expr, $body:expr $(,)?) => {{
        let _guard = $crate::ffi_profiler::global_profiler().start($name);
        $body
    }};
}
290
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // Tests that exercise the guard itself sleep and only assert lower bounds,
    // which `thread::sleep` guarantees. Tests about aggregation and ordering
    // feed fixed durations through `record` instead, so they do not depend on
    // timer resolution or scheduler load.

    #[test]
    fn test_ffi_profiler_basic() {
        let profiler = FfiProfiler::new();

        // Profile a function
        {
            let _guard = profiler.start("test_function");
            thread::sleep(Duration::from_millis(10));
        }

        let stats = profiler.get_stats("test_function");
        assert!(stats.is_some());

        let stats = stats.unwrap();
        assert_eq!(stats.call_count, 1);
        assert!(stats.avg_duration >= Duration::from_millis(10));
    }

    #[test]
    fn test_multiple_calls() {
        let profiler = FfiProfiler::new();

        for _ in 0..5 {
            let _guard = profiler.start("multi_call");
            thread::sleep(Duration::from_millis(5));
        }

        let stats = profiler.get_stats("multi_call").unwrap();
        assert_eq!(stats.call_count, 5);
        assert!(stats.avg_duration >= Duration::from_millis(5));
    }

    #[test]
    fn test_enable_disable() {
        let profiler = FfiProfiler::new();

        profiler.disable();
        {
            let _guard = profiler.start("disabled");
        }

        assert!(profiler.get_stats("disabled").is_none());

        profiler.enable();
        {
            let _guard = profiler.start("enabled");
        }

        assert!(profiler.get_stats("enabled").is_some());
    }

    #[test]
    fn test_reset() {
        let profiler = FfiProfiler::new();

        profiler.record("test".to_string(), Duration::from_millis(5));
        assert!(profiler.get_stats("test").is_some());

        profiler.reset();
        assert!(profiler.get_stats("test").is_none());
    }

    #[test]
    fn test_hotspots() {
        let profiler = FfiProfiler::new();

        // Fixed durations: sleeping 1 ms vs 10 ms can come out in either
        // order on systems with a coarse timer.
        profiler.record("fast".to_string(), Duration::from_millis(1));
        profiler.record("slow".to_string(), Duration::from_millis(10));

        let hotspots = profiler.get_hotspots();
        assert_eq!(hotspots.len(), 2);
        assert_eq!(hotspots[0].name, "slow");
        assert_eq!(hotspots[1].name, "fast");
    }

    #[test]
    fn test_profiling_report() {
        let profiler = FfiProfiler::new();

        for i in 0..3 {
            profiler.record(format!("func_{}", i), Duration::from_millis(5));
        }

        let report = profiler.report();
        assert_eq!(report.total_calls, 3);
        assert_eq!(report.total_duration, Duration::from_millis(15));
        assert_eq!(report.function_stats.len(), 3);

        let summary = report.summary();
        assert_eq!(summary.total_calls, 3);
        assert_eq!(summary.functions_profiled, 3);
        assert_eq!(summary.avg_call_duration, Duration::from_millis(5));
        assert_eq!(summary.max_call_duration, Duration::from_millis(5));
    }

    #[test]
    fn test_exceeds_target() {
        let mut stats = FfiCallStats::new("test".to_string());
        stats.record(Duration::from_micros(500));

        assert!(!stats.exceeds_target(1000));
        assert!(stats.exceeds_target(100));
    }

    #[test]
    fn test_identify_bottlenecks() {
        let profiler = FfiProfiler::new();

        // Fixed durations: a 100 µs sleep may take longer than the 1 ms
        // target, which would wrongly flag "fast" as a bottleneck.
        profiler.record("fast".to_string(), Duration::from_micros(100));
        profiler.record("slow".to_string(), Duration::from_millis(2));

        let report = profiler.report();
        let bottlenecks = report.identify_bottlenecks(1000); // 1ms target

        assert!(bottlenecks.contains(&"slow".to_string()));
        assert!(!bottlenecks.contains(&"fast".to_string()));
    }

    #[test]
    fn test_global_profiler() {
        let profiler = global_profiler();

        profiler.reset(); // Clear any previous stats

        {
            let _guard = profiler.start("global_test");
            thread::sleep(Duration::from_millis(5));
        }

        let stats = profiler.get_stats("global_test");
        assert!(stats.is_some());
    }
}