ipfrs_tensorlogic/
memory_profiler.rs

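//! Memory profiling utilities for tracking per-operation memory usage and timing.
//!
//! This module provides tools for:
//! - Measuring how much process memory grows across a tracked scope (memory is
//!   sampled once when tracking starts and once when the guard is dropped)
//! - Aggregating those measurements and durations per named operation
//! - Spotting operations whose memory keeps growing, a possible sign of a leak
//! - Reporting the largest single-scope growth ("peak") recorded for each operation
//!
//! Process memory is only sampled on Linux (via `/proc`); on other platforms every
//! measurement is currently 0 bytes, so only track counts and durations are meaningful.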
8//!
9//! # Examples
10//!
11//! ```
12//! use ipfrs_tensorlogic::MemoryProfiler;
13//!
14//! let profiler = MemoryProfiler::new();
15//!
16//! {
17//!     let _guard = profiler.start_tracking("my_operation");
18//!     // Your operation here
19//!     let data = vec![0u8; 1024 * 1024]; // 1 MB allocation
20//!     drop(data);
21//! }
22//!
23//! let stats = profiler.get_stats("my_operation").unwrap();
24//! println!("Peak memory: {} bytes", stats.peak_bytes);
25//! ```
26
27use parking_lot::RwLock;
28use std::collections::HashMap;
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31
/// Memory usage statistics for a tracked operation.
///
/// Byte counts are the growth in process memory measured between the start and
/// the end of each tracked scope, not the sizes of individual allocations.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryStats {
    /// Number of times this operation was tracked
    pub track_count: usize,
    /// Total bytes recorded (cumulative across all tracks, saturating at `usize::MAX`)
    pub total_bytes: usize,
    /// Largest byte count recorded by any single track
    pub peak_bytes: usize,
    /// Average bytes per track
    pub avg_bytes: usize,
    /// Total duration tracked
    pub total_duration: Duration,
    /// Average duration per track
    pub avg_duration: Duration,
}

impl MemoryStats {
    /// Create an empty record with no tracks.
    fn new() -> Self {
        Self::default()
    }

    /// Fold one tracked measurement (`bytes` used over `duration`) into the aggregate.
    fn update(&mut self, bytes: usize, duration: Duration) {
        self.track_count += 1;
        // Saturate rather than overflow: a long-running process can accumulate
        // a very large cumulative byte count, and overflow would panic in debug.
        self.total_bytes = self.total_bytes.saturating_add(bytes);
        self.peak_bytes = self.peak_bytes.max(bytes);
        self.total_duration = self.total_duration.saturating_add(duration);

        // `track_count` was just incremented, so it is at least 1 and the
        // divisions below cannot be by zero.
        self.avg_bytes = self.total_bytes / self.track_count;
        // Divide in u128 nanoseconds instead of `Duration / u32`: casting the
        // count to u32 would truncate it (and divide by zero at 2^32 tracks).
        let avg_nanos = self.total_duration.as_nanos() / self.track_count as u128;
        self.avg_duration = Duration::from_nanos(u64::try_from(avg_nanos).unwrap_or(u64::MAX));
    }
}
73
74/// A guard that tracks memory usage for the duration of its lifetime
75pub struct MemoryTrackingGuard {
76    profiler: Arc<MemoryProfiler>,
77    operation: String,
78    start_time: Instant,
79    initial_memory: usize,
80}
81
82impl Drop for MemoryTrackingGuard {
83    fn drop(&mut self) {
84        let duration = self.start_time.elapsed();
85        let current_memory = self.profiler.get_current_memory_usage();
86        let bytes_used = current_memory.saturating_sub(self.initial_memory);
87
88        let mut stats = self.profiler.stats.write();
89        let entry = stats
90            .entry(self.operation.clone())
91            .or_insert_with(MemoryStats::new);
92        entry.update(bytes_used, duration);
93    }
94}
95
96/// Memory profiler for tracking allocations and usage
97pub struct MemoryProfiler {
98    stats: Arc<RwLock<HashMap<String, MemoryStats>>>,
99}
100
101impl MemoryProfiler {
102    /// Create a new memory profiler
103    pub fn new() -> Arc<Self> {
104        Arc::new(Self {
105            stats: Arc::new(RwLock::new(HashMap::new())),
106        })
107    }
108
109    /// Start tracking memory usage for an operation
110    ///
111    /// Returns a guard that will record statistics when dropped.
112    pub fn start_tracking(self: &Arc<Self>, operation: &str) -> MemoryTrackingGuard {
113        MemoryTrackingGuard {
114            profiler: Arc::clone(self),
115            operation: operation.to_string(),
116            start_time: Instant::now(),
117            initial_memory: self.get_current_memory_usage(),
118        }
119    }
120
121    /// Get statistics for a specific operation
122    pub fn get_stats(&self, operation: &str) -> Option<MemoryStats> {
123        self.stats.read().get(operation).cloned()
124    }
125
126    /// Get all tracked statistics
127    pub fn get_all_stats(&self) -> HashMap<String, MemoryStats> {
128        self.stats.read().clone()
129    }
130
131    /// Clear all statistics
132    pub fn clear(&self) {
133        self.stats.write().clear();
134    }
135
136    /// Get current memory usage in bytes
137    ///
138    /// This is a platform-specific approximation based on available system information.
139    #[cfg(target_os = "linux")]
140    fn get_current_memory_usage(&self) -> usize {
141        // On Linux, read from /proc/self/statm
142        if let Ok(contents) = std::fs::read_to_string("/proc/self/statm") {
143            if let Some(first) = contents.split_whitespace().next() {
144                if let Ok(pages) = first.parse::<usize>() {
145                    // Each page is typically 4096 bytes
146                    return pages * 4096;
147                }
148            }
149        }
150        0
151    }
152
153    #[cfg(not(target_os = "linux"))]
154    fn get_current_memory_usage(&self) -> usize {
155        // For non-Linux systems, we can't easily get RSS without platform-specific code
156        // Return 0 as a placeholder
157        0
158    }
159
160    /// Generate a memory profiling report
161    pub fn generate_report(&self) -> MemoryProfilingReport {
162        let stats = self.get_all_stats();
163        let total_operations = stats.len();
164        let total_tracked = stats.values().map(|s| s.track_count).sum();
165        let total_bytes: usize = stats.values().map(|s| s.total_bytes).sum();
166        let max_peak = stats.values().map(|s| s.peak_bytes).max().unwrap_or(0);
167
168        let mut operations: Vec<_> = stats.into_iter().collect();
169        operations.sort_by(|a, b| b.1.peak_bytes.cmp(&a.1.peak_bytes));
170
171        MemoryProfilingReport {
172            total_operations,
173            total_tracked,
174            total_bytes,
175            max_peak_bytes: max_peak,
176            operations,
177        }
178    }
179}
180
181impl Default for MemoryProfiler {
182    fn default() -> Self {
183        Self {
184            stats: Arc::new(RwLock::new(HashMap::new())),
185        }
186    }
187}
188
189/// A comprehensive memory profiling report
190#[derive(Debug)]
191pub struct MemoryProfilingReport {
192    /// Total number of distinct operations tracked
193    pub total_operations: usize,
194    /// Total number of tracking instances
195    pub total_tracked: usize,
196    /// Total bytes allocated across all operations
197    pub total_bytes: usize,
198    /// Maximum peak memory usage across all operations
199    pub max_peak_bytes: usize,
200    /// Operations sorted by peak memory usage (descending)
201    pub operations: Vec<(String, MemoryStats)>,
202}
203
204impl MemoryProfilingReport {
205    /// Print a formatted report to stdout
206    pub fn print(&self) {
207        println!("=== Memory Profiling Report ===");
208        println!("Total operations: {}", self.total_operations);
209        println!("Total tracks: {}", self.total_tracked);
210        println!(
211            "Total bytes: {} ({:.2} MB)",
212            self.total_bytes,
213            self.total_bytes as f64 / 1024.0 / 1024.0
214        );
215        println!(
216            "Max peak: {} ({:.2} MB)",
217            self.max_peak_bytes,
218            self.max_peak_bytes as f64 / 1024.0 / 1024.0
219        );
220        println!("\nTop memory-consuming operations:");
221        println!(
222            "{:<40} {:>12} {:>12} {:>12} {:>10}",
223            "Operation", "Tracks", "Peak", "Avg", "Avg Time"
224        );
225        println!(
226            "{:-<40} {:-<12} {:-<12} {:-<12} {:-<10}",
227            "", "", "", "", ""
228        );
229
230        for (i, (name, stats)) in self.operations.iter().enumerate().take(10) {
231            println!(
232                "{:<40} {:>12} {:>12} {:>12} {:>10?}",
233                if name.len() > 40 {
234                    format!("{}...", &name[..37])
235                } else {
236                    name.clone()
237                },
238                stats.track_count,
239                format_bytes(stats.peak_bytes),
240                format_bytes(stats.avg_bytes),
241                stats.avg_duration
242            );
243            if i >= 9 {
244                break;
245            }
246        }
247    }
248}
249
/// Render a byte count as a short human-readable string.
///
/// Uses binary (1024-based) units. Values below 1 KiB are printed as whole
/// bytes (`"512 B"`). Larger values get one decimal place and a `KB`, `MB` or
/// `GB` suffix; anything of 1 GiB or more is expressed in `GB`.
fn format_bytes(bytes: usize) -> String {
    const KIB: usize = 1 << 10;
    const MIB: usize = 1 << 20;
    const GIB: usize = 1 << 30;

    let value = bytes as f64;
    match bytes {
        n if n < KIB => format!("{} B", n),
        n if n < MIB => format!("{:.1} KB", value / 1024.0),
        n if n < GIB => format!("{:.1} MB", value / 1024.0 / 1024.0),
        _ => format!("{:.1} GB", value / 1024.0 / 1024.0 / 1024.0),
    }
}
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_profiler_basic() {
        let profiler = MemoryProfiler::new();

        {
            let _guard = profiler.start_tracking("test_operation");
            // Simulate some work
            std::thread::sleep(Duration::from_millis(10));
        }

        let stats = profiler
            .get_stats("test_operation")
            .expect("stats are recorded when the guard is dropped");
        assert_eq!(stats.track_count, 1);
        assert!(stats.total_duration >= Duration::from_millis(10));
    }

    #[test]
    fn test_memory_profiler_multiple_tracks() {
        let profiler = MemoryProfiler::new();

        for _ in 0..5 {
            let _guard = profiler.start_tracking("repeated_op");
            std::thread::sleep(Duration::from_millis(5));
        }

        let stats = profiler
            .get_stats("repeated_op")
            .expect("stats are recorded for every dropped guard");
        assert_eq!(stats.track_count, 5);
        assert!(stats.avg_duration >= Duration::from_millis(5));
    }

    #[test]
    fn test_memory_profiler_multiple_operations() {
        let profiler = MemoryProfiler::new();

        {
            let _guard1 = profiler.start_tracking("op1");
            std::thread::sleep(Duration::from_millis(5));
        }

        {
            let _guard2 = profiler.start_tracking("op2");
            std::thread::sleep(Duration::from_millis(10));
        }

        let all_stats = profiler.get_all_stats();
        assert_eq!(all_stats.len(), 2);
        assert!(all_stats.contains_key("op1"));
        assert!(all_stats.contains_key("op2"));
    }

    #[test]
    fn test_get_stats_unknown_operation() {
        let profiler = MemoryProfiler::new();
        assert!(profiler.get_stats("never_tracked").is_none());
    }

    #[test]
    fn test_memory_profiler_clear() {
        let profiler = MemoryProfiler::new();

        {
            let _guard = profiler.start_tracking("test");
        }

        assert_eq!(profiler.get_all_stats().len(), 1);

        profiler.clear();
        assert!(profiler.get_all_stats().is_empty());
    }

    #[test]
    fn test_memory_stats_update_averages() {
        let mut stats = MemoryStats::new();
        stats.update(100, Duration::from_millis(10));
        stats.update(300, Duration::from_millis(30));

        assert_eq!(stats.track_count, 2);
        assert_eq!(stats.total_bytes, 400);
        assert_eq!(stats.peak_bytes, 300);
        assert_eq!(stats.avg_bytes, 200);
        assert_eq!(stats.total_duration, Duration::from_millis(40));
        assert_eq!(stats.avg_duration, Duration::from_millis(20));
    }

    #[test]
    fn test_memory_profiler_report() {
        let profiler = MemoryProfiler::new();

        {
            let _guard = profiler.start_tracking("op1");
        }

        {
            let _guard = profiler.start_tracking("op2");
        }

        let report = profiler.generate_report();
        assert_eq!(report.total_operations, 2);
        assert_eq!(report.total_tracked, 2);
    }

    #[test]
    fn test_report_sorted_by_peak_descending() {
        let profiler = MemoryProfiler::new();
        {
            // Insert known measurements directly: sampled process memory is
            // platform-dependent (always 0 off Linux), so guards can't drive this.
            let mut small = MemoryStats::new();
            small.update(10, Duration::from_millis(1));
            let mut large = MemoryStats::new();
            large.update(1000, Duration::from_millis(1));

            let mut table = profiler.stats.write();
            table.insert("small".to_string(), small);
            table.insert("large".to_string(), large);
        }

        let report = profiler.generate_report();
        assert_eq!(report.total_operations, 2);
        assert_eq!(report.total_bytes, 1010);
        assert_eq!(report.max_peak_bytes, 1000);
        let names: Vec<&str> = report.operations.iter().map(|(n, _)| n.as_str()).collect();
        assert_eq!(names, ["large", "small"]);
    }

    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(512), "512 B");
        assert_eq!(format_bytes(1023), "1023 B");
        assert_eq!(format_bytes(1024), "1.0 KB");
        assert_eq!(format_bytes(1536), "1.5 KB");
        assert_eq!(format_bytes(1024 * 1024), "1.0 MB");
        assert_eq!(format_bytes(1024 * 1024 * 1024), "1.0 GB");
    }
}