1#![doc = include_str!("../README.md")]
2
3use dashmap::DashMap;
4#[allow(unused_imports)]
5use std::collections::HashMap;
6use std::{
7 alloc::{GlobalAlloc, Layout},
8 cell::Cell,
9 collections::BTreeMap,
10 fmt,
11 sync::atomic::{AtomicU32, AtomicUsize, Ordering},
12};
13
14#[cfg(feature = "backtrace")]
15mod backtrace_support;
16#[cfg(feature = "backtrace")]
17use backtrace_support::*;
18#[cfg(feature = "backtrace")]
19pub use backtrace_support::{BacktraceMetric, BacktraceReport, HashedBacktrace};
20
/// Source of process-local thread ids; each thread takes the next value the
/// first time it touches the `THREAD_ID` thread-local.
static THREAD_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Upper bound on distinct tracked threads: sizes `THREAD_STORE` and each
/// per-thread `free` table; exceeding it aborts allocation with an assert.
const MAX_THREADS: usize = 1024;

/// Bookkeeping kept for every live tracked allocation (value type of
/// `PTR_MAP`, keyed by the pointer address).
#[derive(Clone, Copy, Debug)]
struct PointerData {
    // Process-local id of the thread that performed the allocation.
    alloc_thread_id: usize,
    // Hash of the captured allocation backtrace; key into `TRACE_MAP`.
    #[cfg(feature = "backtrace")]
    trace_hash: u64,
}
34
lazy_static::lazy_static! {
    /// Live allocations: pointer address -> bookkeeping for that allocation.
    static ref PTR_MAP: DashMap<usize, PointerData> = DashMap::new();
    /// Aggregated allocation statistics per backtrace hash.
    #[cfg(feature = "backtrace")]
    static ref TRACE_MAP: DashMap<u64, TraceInfo> = DashMap::new();
}
42
/// Per-thread counters; read and written across threads, hence atomics.
struct ThreadStore {
    // OS-level thread id of the owning thread; 0 means "not recorded yet"
    // (see the lazy store in `AllocTrack::alloc`).
    #[allow(dead_code)]
    tid: AtomicU32,
    // Total bytes this thread has allocated.
    alloc: AtomicUsize,
    // free[j] = bytes this thread freed that were allocated by thread j.
    free: [AtomicUsize; MAX_THREADS],
}
50
/// Plain-integer twin of [`ThreadStore`] (same fields, non-atomic types),
/// used only to construct the zeroed initial value of the static store.
#[allow(dead_code)]
#[derive(Clone, Copy)]
struct ThreadStoreLocal {
    tid: u32,
    alloc: usize,
    free: [usize; MAX_THREADS],
}
61
62static THREAD_STORE: [ThreadStore; MAX_THREADS] = unsafe {
64 std::mem::transmute(
65 [ThreadStoreLocal {
66 tid: 0,
67 alloc: 0,
68 free: [0usize; MAX_THREADS],
69 }; MAX_THREADS],
70 )
71};
72
thread_local! {
    /// Process-local id of this thread, assigned lazily on first use.
    static THREAD_ID: usize = THREAD_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    /// Re-entrancy guard: true while the tracker itself is allocating, so its
    /// own allocations are not tracked recursively.
    static IN_ALLOC: Cell<bool> = const { Cell::new(false) };
}
78
79fn enter_alloc<T>(func: impl FnOnce() -> T) -> T {
80 let current_value = IN_ALLOC.with(|x| x.get());
81 IN_ALLOC.with(|x| x.set(true));
82 let output = func();
83 IN_ALLOC.with(|x| x.set(current_value));
84 output
85}
86
/// Controls whether (and how) a backtrace is captured for each allocation.
#[derive(Default, Clone, Copy, Debug)]
pub enum BacktraceMode {
    /// No backtrace capture (the default).
    #[default]
    None,
    /// Capture a trimmed backtrace; the predicate receives a frame's symbol
    /// name. NOTE(review): whether `true` keeps or drops the frame is decided
    /// in `backtrace_support` — confirm there before relying on either.
    #[cfg(feature = "backtrace")]
    Short(fn(&str) -> bool),
    /// Capture the full, untrimmed backtrace.
    #[cfg(feature = "backtrace")]
    Full,
}
100
/// A [`GlobalAlloc`] wrapper that records per-thread (and optionally
/// per-backtrace) allocation statistics while delegating the actual memory
/// management to `inner`.
pub struct AllocTrack<T: GlobalAlloc> {
    // The real allocator every request is forwarded to.
    inner: T,
    // Backtrace capture policy applied on each tracked allocation.
    backtrace: BacktraceMode,
}

impl<T: GlobalAlloc> AllocTrack<T> {
    /// Wraps `inner` with tracking. `const` so the result can initialize a
    /// `#[global_allocator]` static.
    pub const fn new(inner: T, backtrace: BacktraceMode) -> Self {
        Self { inner, backtrace }
    }
}
112
/// OS-level id of the calling thread — macOS stub: always returns 0, which is
/// also the "unset" sentinel in `ThreadStore::tid`, so per-thread names are
/// effectively unavailable on macOS.
#[cfg(all(target_os = "macos", feature = "fs"))]
#[inline(always)]
unsafe fn get_sys_tid() -> u32 {
    0
}

/// OS-level id of the calling thread via the `gettid` syscall.
#[cfg(all(target_os = "linux", feature = "fs"))]
#[inline(always)]
unsafe fn get_sys_tid() -> u32 {
    libc::syscall(libc::SYS_gettid) as u32
}

/// OS-level id of the calling thread via `GetCurrentThreadId`.
#[cfg(all(windows, feature = "fs"))]
#[inline(always)]
unsafe fn get_sys_tid() -> u32 {
    windows::Win32::System::Threading::GetCurrentThreadId()
}
130
// SAFETY: all actual allocation is delegated to `self.inner`; this impl only
// adds bookkeeping around the inner allocator's pointers.
unsafe impl<T: GlobalAlloc> GlobalAlloc for AllocTrack<T> {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        // Re-entrant call (the tracker itself is allocating, e.g. a DashMap
        // insert): bypass all bookkeeping to avoid infinite recursion.
        if IN_ALLOC.with(|x| x.get()) {
            return self.inner.alloc(layout);
        }
        enter_alloc(|| {
            let size = layout.size();
            let ptr = self.inner.alloc(layout);
            // Process-local thread id; doubles as the THREAD_STORE index.
            let tid = THREAD_ID.with(|x| *x);
            assert!(
                tid < MAX_THREADS,
                "Thread ID {tid} is greater than the maximum number of threads {MAX_THREADS}"
            );
            // Record the OS thread id the first time this thread allocates so
            // reports can resolve it to a thread name.
            // NOTE(review): an OS tid of 0 means "unset" here, which assumes
            // real tids are never 0 — confirm per platform (the macOS stub
            // always returns 0, so the store repeats every allocation there).
            #[cfg(feature = "fs")]
            if THREAD_STORE[tid].tid.load(Ordering::Relaxed) == 0 {
                let os_tid = get_sys_tid();
                THREAD_STORE[tid].tid.store(os_tid, Ordering::Relaxed);
            }
            THREAD_STORE[tid].alloc.fetch_add(size, Ordering::Relaxed);
            #[cfg(feature = "backtrace")]
            let trace = HashedBacktrace::capture(self.backtrace);
            // Remember who allocated this pointer (and under which trace) so
            // dealloc can attribute the free correctly.
            PTR_MAP.insert(
                ptr as usize,
                PointerData {
                    alloc_thread_id: tid,
                    #[cfg(feature = "backtrace")]
                    trace_hash: trace.hash(),
                },
            );
            #[cfg(feature = "backtrace")]
            if !matches!(self.backtrace, BacktraceMode::None) {
                // Aggregate per-backtrace totals, deduplicated by trace hash.
                let mut trace_info = TRACE_MAP.entry(trace.hash()).or_insert_with(|| TraceInfo {
                    backtrace: trace,
                    allocated: 0,
                    freed: 0,
                    mode: self.backtrace,
                    allocations: 0,
                });
                trace_info.allocated += size as u64;
                trace_info.allocations += 1;
            }
            ptr
        })
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        // Re-entrant free (tracker internals): pass straight through.
        if IN_ALLOC.with(|x| x.get()) {
            self.inner.dealloc(ptr, layout);
            return;
        }
        enter_alloc(|| {
            let size = layout.size();
            // NOTE(review): this panics not only on genuine double frees but
            // also for any pointer that was allocated while IN_ALLOC was set
            // (such pointers were never inserted into PTR_MAP) — confirm this
            // pairing invariant holds for all callers.
            let (_, target) = PTR_MAP.remove(&(ptr as usize)).expect("double free");
            #[cfg(feature = "backtrace")]
            if !matches!(self.backtrace, BacktraceMode::None) {
                if let Some(mut info) = TRACE_MAP.get_mut(&target.trace_hash) {
                    info.freed += size as u64;
                }
            }
            self.inner.dealloc(ptr, layout);
            // Credit the free to (freeing thread, allocating thread) so the
            // report can show cross-thread frees.
            let tid = THREAD_ID.with(|x| *x);
            THREAD_STORE[tid].free[target.alloc_thread_id].fetch_add(size, Ordering::SeqCst);
        });
    }
}
196
/// Human-readable byte count backed by a `u64`. Uses integer division, so
/// values are truncated toward zero (e.g. 2047 bytes renders as "1 KB").
pub struct Size(pub u64);

impl fmt::Display for Size {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const KB: u64 = 1024;
        const MB: u64 = KB * 1024;
        match self.0 {
            b if b < KB => write!(f, "{} B", b),
            b if b < MB => write!(f, "{} KB", b / KB),
            b => write!(f, "{} MB", b / KB / KB),
        }
    }
}
211
/// Human-readable byte count for fractional values; always shows one decimal
/// place in the selected unit.
pub struct SizeF64(pub f64);

impl fmt::Display for SizeF64 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const KB: f64 = 1024.0;
        match self.0 {
            b if b < KB => write!(f, "{:.01} B", b),
            b if b < KB * KB => write!(f, "{:.01} KB", b / KB),
            b => write!(f, "{:.01} MB", b / KB / KB),
        }
    }
}
226
/// Allocation statistics for one named thread, as computed by `thread_report`.
#[derive(Debug, Clone, Default)]
pub struct ThreadMetric {
    /// Total bytes this thread has allocated.
    pub total_alloc: u64,
    /// Bytes of this thread's allocations that have been freed (by any thread).
    pub total_did_free: u64,
    /// Bytes this thread has freed, regardless of which thread allocated them.
    pub total_freed: u64,
    /// Outstanding bytes: `total_alloc` minus frees of this thread's
    /// allocations (saturating).
    pub current_used: u64,
    /// Frees of this thread's allocations, keyed by the name of the freeing
    /// thread (this thread itself is included under its own name).
    pub freed_by_others: BTreeMap<String, u64>,
}
240
impl fmt::Display for ThreadMetric {
    /// Renders one `label: size` line per counter (human-readable via
    /// [`Size`]), followed by one line per freeing thread.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "total_alloc: {}", Size(self.total_alloc))?;
        writeln!(f, "total_did_free: {}", Size(self.total_did_free))?;
        writeln!(f, "total_freed: {}", Size(self.total_freed))?;
        writeln!(f, "current_used: {}", Size(self.current_used))?;
        for (name, size) in &self.freed_by_others {
            writeln!(f, "freed by {}: {}", name, Size(*size))?;
        }
        Ok(())
    }
}
253
254pub struct ThreadReport(pub BTreeMap<String, ThreadMetric>);
256
257impl fmt::Display for ThreadReport {
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 for (name, metric) in &self.0 {
260 writeln!(f, "{name}:\n{metric}\n")?;
261 }
262 Ok(())
263 }
264}
265
/// Builds a report of allocation metrics grouped by (hashed) backtrace.
///
/// `filter` receives each backtrace and its metric and decides whether the
/// entry is included; included backtraces are symbol-resolved in place.
/// Entries are sorted by net outstanding bytes (allocated - freed), ascending.
#[cfg(feature = "backtrace")]
pub fn backtrace_report(
    filter: impl Fn(&crate::backtrace::Backtrace, &BacktraceMetric) -> bool,
) -> BacktraceReport {
    // Disable self-tracking while the report's own buffers are allocated.
    // NOTE(review): unlike the allocator paths this does not use enter_alloc,
    // so it is not panic-safe and assumes the flag was previously false.
    IN_ALLOC.with(|x| x.set(true));
    let mut out = vec![];
    for mut entry in TRACE_MAP.iter_mut() {
        let metric = BacktraceMetric {
            allocated: entry.allocated,
            freed: entry.freed,
            mode: entry.mode,
            allocations: entry.allocations,
        };
        if !filter(entry.backtrace.inner(), &metric) {
            continue;
        }
        // Resolve symbols only for entries that passed the filter.
        entry.backtrace.inner_mut().resolve();
        out.push((entry.backtrace.clone(), metric));
    }
    out.sort_by_key(|x| x.1.allocated.saturating_sub(x.1.freed) as i64);
    IN_ALLOC.with(|x| x.set(false));
    // `out`'s buffers were allocated with tracking disabled, so they are not
    // in PTR_MAP and MUST also be dropped with tracking disabled (dealloc
    // panics "double free" on unknown pointers). The clone below runs with
    // tracking enabled so the caller can drop the returned report normally.
    let out2 = out.clone();
    IN_ALLOC.with(|x| x.set(true));
    drop(out);
    IN_ALLOC.with(|x| x.set(false));
    BacktraceReport(out2)
}
294
/// macOS stub: thread names are not collected on this platform, so the map
/// is always empty and `thread_report` finds no named threads.
#[cfg(all(target_os = "macos", feature = "fs"))]
fn os_tid_names() -> HashMap<u32, String> {
    HashMap::new()
}
299
/// Maps OS thread id -> thread name ("comm") for every task of this process,
/// read from procfs.
///
/// Tasks can exit between the `/proc/.../task` enumeration and the read of
/// their `comm` file; the previous version `unwrap()`ed that racy I/O and
/// could panic mid-report. Failed reads (and failed task entries) are now
/// simply skipped, and an unreadable procfs yields an empty map.
#[cfg(all(target_os = "linux", feature = "fs"))]
fn os_tid_names() -> HashMap<u32, String> {
    let mut os_tid_names: HashMap<u32, String> = HashMap::new();
    let Ok(process) = procfs::process::Process::myself() else {
        return os_tid_names;
    };
    let Ok(tasks) = process.tasks() else {
        return os_tid_names;
    };
    // `flatten()` drops the per-task errors (task vanished mid-iteration).
    for task in tasks.flatten() {
        let comm_path = format!("/proc/{}/task/{}/comm", task.pid, task.tid);
        // The comm file disappears if the thread exits; skip, don't panic.
        if let Ok(comm) = std::fs::read_to_string(comm_path) {
            os_tid_names.insert(task.tid as u32, comm.trim().to_string());
        }
    }
    os_tid_names
}
315
/// Maps OS thread id -> thread name for the current process via the ToolHelp
/// snapshot API. Threads whose description cannot be queried are recorded as
/// "unknown"; UTF-16 conversion failures as "UTF-16 Error".
#[cfg(all(windows, feature = "fs"))]
fn os_tid_names() -> HashMap<u32, String> {
    use windows::Win32::Foundation::CloseHandle;
    use windows::Win32::System::Diagnostics::ToolHelp::{
        Thread32First, Thread32Next, THREADENTRY32,
    };
    let mut os_tid_names: HashMap<u32, String> = HashMap::new();
    unsafe {
        let process_id = windows::Win32::System::Threading::GetCurrentProcessId();
        let snapshot = windows::Win32::System::Diagnostics::ToolHelp::CreateToolhelp32Snapshot(
            windows::Win32::System::Diagnostics::ToolHelp::TH32CS_SNAPTHREAD,
            0,
        );
        if let Ok(snapshot) = snapshot {
            // Fix: the entry used to be heap-allocated with a raw
            // `std::alloc::alloc` that was never deallocated (leaked on every
            // call) and left every field except `dwSize` uninitialized. A
            // zeroed stack value is leak-free and fully initialized.
            // SAFETY: THREADENTRY32 is a plain all-integer C struct, valid
            // when zeroed.
            let mut thread_entry: THREADENTRY32 = std::mem::zeroed();
            thread_entry.dwSize = std::mem::size_of::<THREADENTRY32>() as u32;
            if Thread32First(snapshot, &mut thread_entry).as_bool() {
                loop {
                    // Only inspect threads belonging to this process.
                    if thread_entry.th32OwnerProcessID == process_id {
                        let thread_handle = windows::Win32::System::Threading::OpenThread(
                            windows::Win32::System::Threading::THREAD_QUERY_LIMITED_INFORMATION,
                            false,
                            thread_entry.th32ThreadID,
                        );
                        if let Ok(handle) = thread_handle {
                            let result =
                                windows::Win32::System::Threading::GetThreadDescription(handle);
                            if let Ok(str) = result {
                                os_tid_names.insert(
                                    thread_entry.th32ThreadID,
                                    str.to_string().unwrap_or("UTF-16 Error".to_string()),
                                );
                            } else {
                                os_tid_names
                                    .insert(thread_entry.th32ThreadID, "unknown".to_string());
                            }
                            CloseHandle(handle);
                        }
                    }
                    if !Thread32Next(snapshot, &mut thread_entry).as_bool() {
                        break;
                    }
                }
            }
            CloseHandle(snapshot);
        }
    }
    os_tid_names
}
366
/// Builds per-thread allocation metrics from the global `THREAD_STORE`.
///
/// With the `fs` feature, threads are keyed by their OS thread name and
/// slots without a resolvable name are skipped. Without `fs`, the numeric
/// process-local thread id is used as the name.
/// NOTE(review): in the non-`fs` path every one of the `MAX_THREADS` slots
/// gets a name, so the report contains all-zero entries for never-used
/// slots — confirm that is intended.
pub fn thread_report() -> ThreadReport {
    #[cfg(feature = "fs")]
    let os_tid_names: HashMap<u32, String> = os_tid_names();
    #[cfg(feature = "fs")]
    let mut tid_names: HashMap<usize, &String> = HashMap::new();
    #[cfg(feature = "fs")]
    let get_tid_name = {
        // Resolve each slot's recorded OS tid (0 = slot never used) to the
        // thread's name; slots with unknown tids simply get no entry.
        for (i, thread) in THREAD_STORE.iter().enumerate() {
            let tid = thread.tid.load(Ordering::Relaxed);
            if tid == 0 {
                continue;
            }
            if let Some(name) = os_tid_names.get(&tid) {
                tid_names.insert(i, name);
            }
        }
        |id: usize| tid_names.get(&id).map(|x| &**x)
    };
    #[cfg(not(feature = "fs"))]
    let get_tid_name = { move |id: usize| Some(id.to_string()) };

    let mut metrics = BTreeMap::new();

    for (i, thread) in THREAD_STORE.iter().enumerate() {
        let Some(name) = get_tid_name(i) else {
            continue;
        };
        let alloced = thread.alloc.load(Ordering::Relaxed) as u64;
        let metric: &mut ThreadMetric = metrics.entry(name.into()).or_default();
        metric.total_alloc += alloced;

        // Sum every thread's frees of thread i's allocations (j == i, i.e.
        // self-frees, are included).
        let mut total_free: u64 = 0;
        for (j, thread2) in THREAD_STORE.iter().enumerate() {
            let Some(name) = get_tid_name(j) else {
                continue;
            };
            let freed = thread2.free[i].load(Ordering::Relaxed);
            if freed == 0 {
                continue;
            }
            total_free += freed as u64;
            *metric.freed_by_others.entry(name.into()).or_default() += freed as u64;
        }
        // Bytes of i's allocations that anyone has freed.
        metric.total_did_free += total_free;
        // Bytes thread i itself freed, regardless of who allocated them.
        metric.total_freed += thread
            .free
            .iter()
            .map(|x| x.load(Ordering::Relaxed) as u64)
            .sum::<u64>();
        metric.current_used += alloced.saturating_sub(total_free);
    }
    ThreadReport(metrics)
}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: spawn a named thread, then check that `os_tid_names` runs
    /// and its result can be printed. The spawned thread stays alive (long
    /// detached sleep) so its name is visible to the snapshot.
    ///
    /// Gated on the `fs` feature: `os_tid_names` only exists with that
    /// feature, so without the gate `cargo test --no-default-features`
    /// (or any build lacking `fs`) failed to compile this module.
    #[test]
    #[cfg(feature = "fs")]
    pub fn test_os_tid_names() {
        std::thread::Builder::new()
            .name("thread2".to_string())
            .spawn(move || {
                std::thread::sleep(std::time::Duration::from_secs(100));
            })
            .unwrap();

        let os_tid_names = os_tid_names();
        println!("{:?}", os_tid_names);
    }
}