ipfrs_tensorlogic/
memory_profiler.rs1use parking_lot::RwLock;
28use std::collections::HashMap;
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31
/// Aggregated measurements for one tracked operation.
///
/// Each call to [`MemoryStats::update`] folds in one completed tracking scope:
/// the memory growth observed over that scope (in bytes) and its elapsed time.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryStats {
    /// Number of completed tracking scopes.
    pub track_count: usize,
    /// Sum of per-scope memory growth in bytes (saturates at `usize::MAX`).
    pub total_bytes: usize,
    /// Largest memory growth observed in a single scope, in bytes.
    pub peak_bytes: usize,
    /// `total_bytes / track_count`, rounded down.
    pub avg_bytes: usize,
    /// Sum of per-scope elapsed time (saturates at `Duration::MAX`).
    pub total_duration: Duration,
    /// `total_duration / track_count`.
    pub avg_duration: Duration,
}

impl MemoryStats {
    /// Creates empty statistics (all counters zero).
    fn new() -> Self {
        Self::default()
    }

    /// Records one completed scope that grew memory by `bytes` and took `duration`.
    fn update(&mut self, bytes: usize, duration: Duration) {
        self.track_count += 1;
        // Saturate instead of panicking (debug) or wrapping (release): `usize` is
        // only 32 bits on some targets, where 4 GiB of cumulative growth across
        // many scopes is easily reached.
        self.total_bytes = self.total_bytes.saturating_add(bytes);
        self.peak_bytes = self.peak_bytes.max(bytes);
        self.total_duration = self.total_duration.saturating_add(duration);

        // `track_count` was just incremented, so it is at least 1 here.
        self.avg_bytes = self.total_bytes / self.track_count;
        self.avg_duration = Self::average_duration(self.total_duration, self.track_count);
    }

    /// Divides `total` by a non-zero `count` without truncating `count` to `u32`.
    ///
    /// `Duration` only implements division by `u32`; a plain `count as u32` would
    /// wrap (to 0 at 2^32, which panics), so larger counts go through nanoseconds.
    fn average_duration(total: Duration, count: usize) -> Duration {
        match u32::try_from(count) {
            Ok(n) => total / n,
            Err(_) => {
                // `usize -> u128` is a lossless widening.
                let nanos = total.as_nanos() / count as u128;
                // `nanos <= total.as_nanos()`, so the seconds part fits in `u64`.
                Duration::new(
                    (nanos / 1_000_000_000) as u64,
                    (nanos % 1_000_000_000) as u32,
                )
            }
        }
    }
}
73
74pub struct MemoryTrackingGuard {
76 profiler: Arc<MemoryProfiler>,
77 operation: String,
78 start_time: Instant,
79 initial_memory: usize,
80}
81
82impl Drop for MemoryTrackingGuard {
83 fn drop(&mut self) {
84 let duration = self.start_time.elapsed();
85 let current_memory = self.profiler.get_current_memory_usage();
86 let bytes_used = current_memory.saturating_sub(self.initial_memory);
87
88 let mut stats = self.profiler.stats.write();
89 let entry = stats
90 .entry(self.operation.clone())
91 .or_insert_with(MemoryStats::new);
92 entry.update(bytes_used, duration);
93 }
94}
95
96pub struct MemoryProfiler {
98 stats: Arc<RwLock<HashMap<String, MemoryStats>>>,
99}
100
101impl MemoryProfiler {
102 pub fn new() -> Arc<Self> {
104 Arc::new(Self {
105 stats: Arc::new(RwLock::new(HashMap::new())),
106 })
107 }
108
109 pub fn start_tracking(self: &Arc<Self>, operation: &str) -> MemoryTrackingGuard {
113 MemoryTrackingGuard {
114 profiler: Arc::clone(self),
115 operation: operation.to_string(),
116 start_time: Instant::now(),
117 initial_memory: self.get_current_memory_usage(),
118 }
119 }
120
121 pub fn get_stats(&self, operation: &str) -> Option<MemoryStats> {
123 self.stats.read().get(operation).cloned()
124 }
125
126 pub fn get_all_stats(&self) -> HashMap<String, MemoryStats> {
128 self.stats.read().clone()
129 }
130
131 pub fn clear(&self) {
133 self.stats.write().clear();
134 }
135
136 #[cfg(target_os = "linux")]
140 fn get_current_memory_usage(&self) -> usize {
141 if let Ok(contents) = std::fs::read_to_string("/proc/self/statm") {
143 if let Some(first) = contents.split_whitespace().next() {
144 if let Ok(pages) = first.parse::<usize>() {
145 return pages * 4096;
147 }
148 }
149 }
150 0
151 }
152
153 #[cfg(not(target_os = "linux"))]
154 fn get_current_memory_usage(&self) -> usize {
155 0
158 }
159
160 pub fn generate_report(&self) -> MemoryProfilingReport {
162 let stats = self.get_all_stats();
163 let total_operations = stats.len();
164 let total_tracked = stats.values().map(|s| s.track_count).sum();
165 let total_bytes: usize = stats.values().map(|s| s.total_bytes).sum();
166 let max_peak = stats.values().map(|s| s.peak_bytes).max().unwrap_or(0);
167
168 let mut operations: Vec<_> = stats.into_iter().collect();
169 operations.sort_by(|a, b| b.1.peak_bytes.cmp(&a.1.peak_bytes));
170
171 MemoryProfilingReport {
172 total_operations,
173 total_tracked,
174 total_bytes,
175 max_peak_bytes: max_peak,
176 operations,
177 }
178 }
179}
180
181impl Default for MemoryProfiler {
182 fn default() -> Self {
183 Self {
184 stats: Arc::new(RwLock::new(HashMap::new())),
185 }
186 }
187}
188
189#[derive(Debug)]
191pub struct MemoryProfilingReport {
192 pub total_operations: usize,
194 pub total_tracked: usize,
196 pub total_bytes: usize,
198 pub max_peak_bytes: usize,
200 pub operations: Vec<(String, MemoryStats)>,
202}
203
204impl MemoryProfilingReport {
205 pub fn print(&self) {
207 println!("=== Memory Profiling Report ===");
208 println!("Total operations: {}", self.total_operations);
209 println!("Total tracks: {}", self.total_tracked);
210 println!(
211 "Total bytes: {} ({:.2} MB)",
212 self.total_bytes,
213 self.total_bytes as f64 / 1024.0 / 1024.0
214 );
215 println!(
216 "Max peak: {} ({:.2} MB)",
217 self.max_peak_bytes,
218 self.max_peak_bytes as f64 / 1024.0 / 1024.0
219 );
220 println!("\nTop memory-consuming operations:");
221 println!(
222 "{:<40} {:>12} {:>12} {:>12} {:>10}",
223 "Operation", "Tracks", "Peak", "Avg", "Avg Time"
224 );
225 println!(
226 "{:-<40} {:-<12} {:-<12} {:-<12} {:-<10}",
227 "", "", "", "", ""
228 );
229
230 for (i, (name, stats)) in self.operations.iter().enumerate().take(10) {
231 println!(
232 "{:<40} {:>12} {:>12} {:>12} {:>10?}",
233 if name.len() > 40 {
234 format!("{}...", &name[..37])
235 } else {
236 name.clone()
237 },
238 stats.track_count,
239 format_bytes(stats.peak_bytes),
240 format_bytes(stats.avg_bytes),
241 stats.avg_duration
242 );
243 if i >= 9 {
244 break;
245 }
246 }
247 }
248}
249
/// Renders a byte count with a binary (1024-based) unit suffix.
///
/// Counts below 1 KiB are shown as an exact integer (`"512 B"`); larger values
/// get one decimal place in KB, MB, or GB (`"1.5 KB"`). GB is the largest unit.
fn format_bytes(bytes: usize) -> String {
    const KIB: usize = 1 << 10;
    const MIB: usize = 1 << 20;
    const GIB: usize = 1 << 30;

    let as_float = bytes as f64;
    match bytes {
        n if n < KIB => format!("{} B", n),
        n if n < MIB => format!("{:.1} KB", as_float / KIB as f64),
        n if n < GIB => format!("{:.1} MB", as_float / MIB as f64),
        _ => format!("{:.1} GB", as_float / GIB as f64),
    }
}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `work` while a guard for `operation` is alive, then drops the guard.
    fn run_tracked(profiler: &Arc<MemoryProfiler>, operation: &str, work: impl FnOnce()) {
        let _guard = profiler.start_tracking(operation);
        work();
    }

    fn sleep_ms(ms: u64) {
        std::thread::sleep(Duration::from_millis(ms));
    }

    #[test]
    fn test_memory_profiler_basic() {
        let profiler = MemoryProfiler::new();
        run_tracked(&profiler, "test_operation", || sleep_ms(10));

        let stats = profiler
            .get_stats("test_operation")
            .expect("stats should be recorded for test_operation");
        assert_eq!(stats.track_count, 1);
        assert!(stats.total_duration >= Duration::from_millis(10));
    }

    #[test]
    fn test_memory_profiler_multiple_tracks() {
        let profiler = MemoryProfiler::new();
        for _ in 0..5 {
            run_tracked(&profiler, "repeated_op", || sleep_ms(5));
        }

        let stats = profiler
            .get_stats("repeated_op")
            .expect("stats should be recorded for repeated_op");
        assert_eq!(stats.track_count, 5);
        assert!(stats.avg_duration >= Duration::from_millis(5));
    }

    #[test]
    fn test_memory_profiler_multiple_operations() {
        let profiler = MemoryProfiler::new();
        run_tracked(&profiler, "op1", || sleep_ms(5));
        run_tracked(&profiler, "op2", || sleep_ms(10));

        let all_stats = profiler.get_all_stats();
        assert_eq!(all_stats.len(), 2);
        for op in ["op1", "op2"] {
            assert!(all_stats.contains_key(op), "missing stats for {}", op);
        }
    }

    #[test]
    fn test_memory_profiler_clear() {
        let profiler = MemoryProfiler::new();
        run_tracked(&profiler, "test", || {});
        assert_eq!(profiler.get_all_stats().len(), 1);

        profiler.clear();
        assert!(profiler.get_all_stats().is_empty());
    }

    #[test]
    fn test_memory_profiler_report() {
        let profiler = MemoryProfiler::new();
        for op in ["op1", "op2"] {
            run_tracked(&profiler, op, || {});
        }

        let report = profiler.generate_report();
        assert_eq!(report.total_operations, 2);
        assert_eq!(report.total_tracked, 2);
    }

    #[test]
    fn test_format_bytes() {
        let cases = [
            (512, "512 B"),
            (1024, "1.0 KB"),
            (1024 * 1024, "1.0 MB"),
            (1024 * 1024 * 1024, "1.0 GB"),
        ];
        for (bytes, expected) in cases {
            assert_eq!(format_bytes(bytes), expected);
        }
    }
}