1use std::collections::HashMap;
6use std::time::Instant;
7use serde::{Serialize, Deserialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct AllocationRecord {
12 pub name: String,
14 pub size: usize,
16 pub timestamp: u64,
18 pub freed: bool,
20}
21
22#[derive(Debug, Clone, Default, Serialize, Deserialize)]
24pub struct MemoryStats {
25 pub current_usage: usize,
27 pub peak_usage: usize,
29 pub total_allocated: usize,
31 pub total_freed: usize,
33 pub allocation_count: usize,
35 pub deallocation_count: usize,
37 pub per_name_usage: HashMap<String, usize>,
39}
40
41#[derive(Debug)]
43pub struct MemoryProfiler {
44 allocations: HashMap<String, Vec<AllocationRecord>>,
46 current_usage: usize,
48 peak_usage: usize,
50 total_allocated: usize,
52 total_freed: usize,
54 allocation_count: usize,
56 deallocation_count: usize,
58 start_time: Instant,
60}
61
62impl Default for MemoryProfiler {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl MemoryProfiler {
69 pub fn new() -> Self {
71 Self {
72 allocations: HashMap::new(),
73 current_usage: 0,
74 peak_usage: 0,
75 total_allocated: 0,
76 total_freed: 0,
77 allocation_count: 0,
78 deallocation_count: 0,
79 start_time: Instant::now(),
80 }
81 }
82
83 pub fn record_alloc(&mut self, name: &str, bytes: usize) {
85 let timestamp = self.start_time.elapsed().as_micros() as u64;
86
87 let record = AllocationRecord {
88 name: name.to_string(),
89 size: bytes,
90 timestamp,
91 freed: false,
92 };
93
94 self.allocations
95 .entry(name.to_string())
96 .or_insert_with(Vec::new)
97 .push(record);
98
99 self.current_usage += bytes;
100 self.total_allocated += bytes;
101 self.allocation_count += 1;
102
103 if self.current_usage > self.peak_usage {
104 self.peak_usage = self.current_usage;
105 }
106 }
107
108 pub fn record_free(&mut self, name: &str, bytes: usize) {
110 if let Some(records) = self.allocations.get_mut(name) {
111 for record in records.iter_mut() {
113 if !record.freed && record.size == bytes {
114 record.freed = true;
115 break;
116 }
117 }
118 }
119
120 self.current_usage = self.current_usage.saturating_sub(bytes);
121 self.total_freed += bytes;
122 self.deallocation_count += 1;
123 }
124
125 pub fn current_usage(&self) -> usize {
127 self.current_usage
128 }
129
130 pub fn peak_usage(&self) -> usize {
132 self.peak_usage
133 }
134
135 pub fn total_allocated(&self) -> usize {
137 self.total_allocated
138 }
139
140 pub fn total_freed(&self) -> usize {
142 self.total_freed
143 }
144
145 pub fn stats(&self) -> MemoryStats {
147 let mut per_name_usage = HashMap::new();
148
149 for (name, records) in &self.allocations {
150 let active_bytes: usize = records
151 .iter()
152 .filter(|r| !r.freed)
153 .map(|r| r.size)
154 .sum();
155 if active_bytes > 0 {
156 per_name_usage.insert(name.clone(), active_bytes);
157 }
158 }
159
160 MemoryStats {
161 current_usage: self.current_usage,
162 peak_usage: self.peak_usage,
163 total_allocated: self.total_allocated,
164 total_freed: self.total_freed,
165 allocation_count: self.allocation_count,
166 deallocation_count: self.deallocation_count,
167 per_name_usage,
168 }
169 }
170
171 pub fn leaks(&self) -> Vec<AllocationRecord> {
173 self.allocations
174 .values()
175 .flatten()
176 .filter(|r| !r.freed)
177 .cloned()
178 .collect()
179 }
180
181 pub fn reset(&mut self) {
183 self.allocations.clear();
184 self.current_usage = 0;
185 self.peak_usage = 0;
186 self.total_allocated = 0;
187 self.total_freed = 0;
188 self.allocation_count = 0;
189 self.deallocation_count = 0;
190 self.start_time = Instant::now();
191 }
192
193 pub fn format_bytes(bytes: usize) -> String {
195 const KB: usize = 1024;
196 const MB: usize = KB * 1024;
197 const GB: usize = MB * 1024;
198
199 if bytes >= GB {
200 format!("{:.2} GB", bytes as f64 / GB as f64)
201 } else if bytes >= MB {
202 format!("{:.2} MB", bytes as f64 / MB as f64)
203 } else if bytes >= KB {
204 format!("{:.2} KB", bytes as f64 / KB as f64)
205 } else {
206 format!("{} B", bytes)
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allocation_tracking() {
        let mut profiler = MemoryProfiler::new();

        profiler.record_alloc("tensor_a", 1024);
        assert_eq!(profiler.current_usage(), 1024);

        profiler.record_alloc("tensor_b", 2048);
        assert_eq!(profiler.current_usage(), 3072);
        assert_eq!(profiler.peak_usage(), 3072);
    }

    #[test]
    fn test_deallocation() {
        let mut profiler = MemoryProfiler::new();

        profiler.record_alloc("tensor", 1000);
        profiler.record_free("tensor", 1000);

        assert_eq!(profiler.current_usage(), 0);
        assert_eq!(profiler.peak_usage(), 1000);
    }

    #[test]
    fn test_leak_detection() {
        let mut profiler = MemoryProfiler::new();

        profiler.record_alloc("leak1", 100);
        profiler.record_alloc("leak2", 200);
        profiler.record_alloc("freed", 300);
        profiler.record_free("freed", 300);

        let leaks = profiler.leaks();
        assert_eq!(leaks.len(), 2);
        assert!(leaks.iter().all(|r| r.name != "freed"));
    }

    #[test]
    fn test_free_matches_by_size() {
        let mut profiler = MemoryProfiler::new();

        profiler.record_alloc("t", 100);
        profiler.record_alloc("t", 200);
        profiler.record_free("t", 200);

        let leaks = profiler.leaks();
        assert_eq!(leaks.len(), 1);
        assert_eq!(leaks[0].size, 100);
    }

    #[test]
    fn test_free_saturates_current_usage() {
        let mut profiler = MemoryProfiler::new();

        profiler.record_alloc("t", 10);
        profiler.record_free("untracked", 100);

        assert_eq!(profiler.current_usage(), 0);
        assert_eq!(profiler.total_freed(), 100);
    }

    #[test]
    fn test_stats_snapshot() {
        let mut profiler = MemoryProfiler::new();

        profiler.record_alloc("a", 100);
        profiler.record_alloc("a", 50);
        profiler.record_alloc("b", 10);
        profiler.record_free("b", 10);

        let stats = profiler.stats();
        assert_eq!(stats.current_usage, 150);
        assert_eq!(stats.peak_usage, 160);
        assert_eq!(stats.total_allocated, 160);
        assert_eq!(stats.total_freed, 10);
        assert_eq!(stats.allocation_count, 3);
        assert_eq!(stats.deallocation_count, 1);
        assert_eq!(stats.per_name_usage.get("a"), Some(&150));
        assert!(!stats.per_name_usage.contains_key("b"));
    }

    #[test]
    fn test_reset_clears_everything() {
        let mut profiler = MemoryProfiler::new();

        profiler.record_alloc("a", 100);
        profiler.record_free("a", 100);
        profiler.record_alloc("b", 5);
        profiler.reset();

        let stats = profiler.stats();
        assert_eq!(stats.current_usage, 0);
        assert_eq!(stats.peak_usage, 0);
        assert_eq!(stats.total_allocated, 0);
        assert_eq!(stats.total_freed, 0);
        assert_eq!(stats.allocation_count, 0);
        assert_eq!(stats.deallocation_count, 0);
        assert!(stats.per_name_usage.is_empty());
        assert!(profiler.leaks().is_empty());
    }

    #[test]
    fn test_format_bytes() {
        assert_eq!(MemoryProfiler::format_bytes(0), "0 B");
        assert_eq!(MemoryProfiler::format_bytes(500), "500 B");
        assert_eq!(MemoryProfiler::format_bytes(1023), "1023 B");
        assert_eq!(MemoryProfiler::format_bytes(1024), "1.00 KB");
        assert_eq!(MemoryProfiler::format_bytes(1536), "1.50 KB");
        assert_eq!(MemoryProfiler::format_bytes(1024 * 1024), "1.00 MB");
        assert_eq!(MemoryProfiler::format_bytes(1024 * 1024 * 1024), "1.00 GB");
    }
}