1#![doc = include_str!("../README.md")]
2
3use dashmap::DashMap;
4#[allow(unused_imports)]
5use std::collections::HashMap;
6use std::{
7 alloc::{GlobalAlloc, Layout},
8 cell::Cell,
9 collections::BTreeMap,
10 fmt,
11 sync::atomic::{AtomicU32, AtomicUsize, Ordering},
12};
13
14#[cfg(feature = "backtrace")]
15mod backtrace_support;
16#[cfg(feature = "backtrace")]
17use backtrace_support::*;
18#[cfg(feature = "backtrace")]
19pub use backtrace_support::{BacktraceMetric, BacktraceReport, HashedBacktrace};
20
/// Monotonic counter handing out a dense, process-unique index to each
/// thread the first time it allocates (see the `THREAD_ID` thread-local);
/// the index is used to address `THREAD_STORE`.
static THREAD_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Maximum number of distinct threads that can be tracked.
/// NOTE(review): ids from `THREAD_ID_COUNTER` are never recycled, so a
/// process that creates more than this many threads would index
/// `THREAD_STORE` out of bounds — confirm against expected workloads.
const MAX_THREADS: usize = 256;
25
/// Per-allocation bookkeeping stored in `PTR_MAP`, keyed by the
/// allocation's address.
#[derive(Clone, Copy, Debug)]
struct PointerData {
    // Dense id (index into `THREAD_STORE`) of the thread that performed
    // the allocation; used in `dealloc` to attribute the free.
    alloc_thread_id: usize,
    // Hash of the backtrace captured at allocation time; keys into
    // `TRACE_MAP`.
    #[cfg(feature = "backtrace")]
    trace_hash: u64,
}
32
lazy_static::lazy_static! {
    /// Live allocations: address -> bookkeeping data. Entries are
    /// inserted in `alloc` and removed in `dealloc`.
    static ref PTR_MAP: DashMap<usize, PointerData> = DashMap::new();
    /// Aggregated statistics per captured backtrace, keyed by trace hash.
    #[cfg(feature = "backtrace")]
    static ref TRACE_MAP: DashMap<u64, TraceInfo> = DashMap::new();
}
40
/// Per-thread allocation counters; one slot per dense thread id in the
/// global `THREAD_STORE` array, mutated concurrently via atomics.
struct ThreadStore {
    // OS-level thread id (0 = not yet recorded); written lazily on the
    // first allocation when the `fs` feature is enabled.
    #[allow(dead_code)]
    tid: AtomicU32,
    // Total bytes allocated by this thread.
    alloc: AtomicUsize,
    // free[j] = bytes this thread has freed that were originally
    // allocated by the thread with dense id `j` (see `dealloc`).
    free: [AtomicUsize; MAX_THREADS],
}
48
/// Plain (non-atomic) mirror of `ThreadStore`, used only as a
/// `Copy`-able, const-constructible initializer that is transmuted into
/// the atomic form in `THREAD_STORE` below.
#[allow(dead_code)]
#[derive(Clone, Copy)]
struct ThreadStoreLocal {
    tid: u32,
    alloc: usize,
    free: [usize; MAX_THREADS],
}
59
/// Global table of per-thread counters, indexed by dense thread id.
// SAFETY: `ThreadStoreLocal` mirrors `ThreadStore` field-for-field, and
// `AtomicU32`/`AtomicUsize` are documented to have the same size and
// alignment as `u32`/`usize`, so an all-zero `ThreadStoreLocal` is a
// valid bit pattern for `ThreadStore`. The transmute is presumably here
// to sidestep const-initializing a large array of non-`Copy` atomics —
// TODO(review): confirm this couldn't use a `const` item initializer
// instead on the crate's MSRV.
static THREAD_STORE: [ThreadStore; MAX_THREADS] = unsafe {
    std::mem::transmute(
        [ThreadStoreLocal {
            tid: 0,
            alloc: 0,
            free: [0usize; MAX_THREADS],
        }; MAX_THREADS],
    )
};
70
thread_local! {
    /// Dense per-thread id, assigned on first use; indexes `THREAD_STORE`.
    /// NOTE(review): ids are never recycled, so spawning more than
    /// `MAX_THREADS` threads over the process lifetime would index out
    /// of bounds — confirm.
    static THREAD_ID: usize = THREAD_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    /// Reentrancy flag: `true` while the tracker itself is allocating,
    /// so those internal allocations bypass bookkeeping.
    static IN_ALLOC: Cell<bool> = Cell::new(false);
}
76
77fn enter_alloc<T>(func: impl FnOnce() -> T) -> T {
78 let current_value = IN_ALLOC.with(|x| x.get());
79 IN_ALLOC.with(|x| x.set(true));
80 let output = func();
81 IN_ALLOC.with(|x| x.set(current_value));
82 output
83}
84
/// Controls whether (and how much of) a backtrace is captured for each
/// allocation.
#[derive(Default, Clone, Copy, Debug, PartialEq)]
pub enum BacktraceMode {
    /// No backtrace capture (the default).
    #[default]
    None,
    /// Capture an abbreviated backtrace.
    /// NOTE(review): exact truncation semantics live in
    /// `backtrace_support`, not visible here.
    #[cfg(feature = "backtrace")]
    Short,
    /// Capture the full backtrace.
    #[cfg(feature = "backtrace")]
    Full,
}
97
/// A [`GlobalAlloc`] wrapper that records per-thread and (optionally)
/// per-backtrace allocation statistics while delegating the actual
/// memory management to `inner`.
pub struct AllocTrack<T: GlobalAlloc> {
    // The real allocator every request is forwarded to.
    inner: T,
    // Backtrace capture policy applied on every tracked allocation.
    backtrace: BacktraceMode,
}
103
104impl<T: GlobalAlloc> AllocTrack<T> {
105 pub const fn new(inner: T, backtrace: BacktraceMode) -> Self {
106 Self { inner, backtrace }
107 }
108}
/// Returns the OS-level thread id of the calling thread via the raw
/// `gettid` syscall.
///
/// # Safety
/// Wraps `libc::syscall`; the call takes no pointers and cannot fault.
/// NOTE(review): `SYS_gettid` is Linux-specific even though the cfg says
/// `unix` — other unix targets may fail to compile; confirm.
#[cfg(all(unix, feature = "fs"))]
#[inline(always)]
unsafe fn get_sys_tid() -> u32 {
    libc::syscall(libc::SYS_gettid) as u32
}
114
/// Returns the OS-level thread id of the calling thread via the Win32
/// `GetCurrentThreadId` API.
///
/// # Safety
/// The underlying API takes no arguments and cannot fail.
#[cfg(all(windows, feature = "fs"))]
#[inline(always)]
unsafe fn get_sys_tid() -> u32 {
    windows::Win32::System::Threading::GetCurrentThreadId()
}
120
121unsafe impl<T: GlobalAlloc> GlobalAlloc for AllocTrack<T> {
122 unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
123 if IN_ALLOC.with(|x| x.get()) {
124 return self.inner.alloc(layout);
125 }
126 enter_alloc(|| {
127 let size = layout.size();
128 let ptr = self.inner.alloc(layout);
129 let tid = THREAD_ID.with(|x| *x);
130 #[cfg(feature = "fs")]
131 if THREAD_STORE[tid].tid.load(Ordering::Relaxed) == 0 {
132 let os_tid = get_sys_tid();
133 THREAD_STORE[tid].tid.store(os_tid, Ordering::Relaxed);
134 }
135 THREAD_STORE[tid].alloc.fetch_add(size, Ordering::Relaxed);
136 #[cfg(feature = "backtrace")]
137 let trace = HashedBacktrace::capture(self.backtrace);
138 PTR_MAP.insert(
139 ptr as usize,
140 PointerData {
141 alloc_thread_id: tid,
142 #[cfg(feature = "backtrace")]
143 trace_hash: trace.hash(),
144 },
145 );
146 #[cfg(feature = "backtrace")]
147 if !matches!(self.backtrace, BacktraceMode::None) {
148 let mut trace_info = TRACE_MAP.entry(trace.hash()).or_insert_with(|| TraceInfo {
149 backtrace: trace,
150 allocated: 0,
151 freed: 0,
152 mode: self.backtrace,
153 allocations: 0,
154 });
155 trace_info.allocated += size as u64;
156 trace_info.allocations += 1;
157 }
158 ptr
159 })
160 }
161
162 unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
163 if IN_ALLOC.with(|x| x.get()) {
164 self.inner.dealloc(ptr, layout);
165 return;
166 }
167 enter_alloc(|| {
168 let size = layout.size();
169 let (_, target) = PTR_MAP.remove(&(ptr as usize)).expect("double free");
170 #[cfg(feature = "backtrace")]
171 if !matches!(self.backtrace, BacktraceMode::None) {
172 if let Some(mut info) = TRACE_MAP.get_mut(&target.trace_hash) {
173 info.freed += size as u64;
174 }
175 }
176 self.inner.dealloc(ptr, layout);
177 let tid = THREAD_ID.with(|x| *x);
178 THREAD_STORE[tid].free[target.alloc_thread_id].fetch_add(size, Ordering::SeqCst);
179 });
180 }
181}
182
/// Byte count rendered human-readably by its `Display` impl
/// (truncating integer division; units B / KB / MB).
pub struct Size(pub u64);

impl fmt::Display for Size {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const KB: u64 = 1024;
        const MB: u64 = KB * 1024;
        let bytes = self.0;
        // Pick the largest unit below the value; MB is the ceiling.
        match bytes {
            b if b < KB => write!(f, "{b} B"),
            b if b < MB => write!(f, "{} KB", b / KB),
            b => write!(f, "{} MB", b / MB),
        }
    }
}
197
/// Fractional byte count rendered human-readably by its `Display` impl
/// (one decimal place; units B / KB / MB).
pub struct SizeF64(pub f64);

impl fmt::Display for SizeF64 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const KB: f64 = 1024.0;
        const MB: f64 = KB * 1024.0;
        let bytes = self.0;
        // Pick the largest unit below the value; MB is the ceiling.
        if bytes < KB {
            write!(f, "{bytes:.1} B")
        } else if bytes < MB {
            write!(f, "{:.1} KB", bytes / KB)
        } else {
            write!(f, "{:.1} MB", bytes / MB)
        }
    }
}
212
/// Aggregated allocation statistics for one named thread, as produced
/// by `thread_report`.
#[derive(Debug, Clone, Default)]
pub struct ThreadMetric {
    /// Total bytes allocated by this thread.
    pub total_alloc: u64,
    /// Bytes of this thread's allocations that have been freed (by any
    /// thread, including itself).
    pub total_did_free: u64,
    /// Bytes this thread itself has freed, regardless of which thread
    /// allocated them.
    pub total_freed: u64,
    /// Bytes allocated by this thread that are still live
    /// (`total_alloc` minus `total_did_free`, saturating).
    pub current_used: u64,
    /// Breakdown of `total_did_free` by the name of the freeing thread
    /// (this thread itself may appear as a key).
    pub freed_by_others: BTreeMap<String, u64>,
}
226
227impl fmt::Display for ThreadMetric {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 writeln!(f, "total_alloc: {}", Size(self.total_alloc))?;
230 writeln!(f, "total_did_free: {}", Size(self.total_did_free))?;
231 writeln!(f, "total_freed: {}", Size(self.total_freed))?;
232 writeln!(f, "current_used: {}", Size(self.current_used))?;
233 for (name, size) in &self.freed_by_others {
234 writeln!(f, "freed by {}: {}", name, Size(*size))?;
235 }
236 Ok(())
237 }
238}
239
240pub struct ThreadReport(pub BTreeMap<String, ThreadMetric>);
242
243impl fmt::Display for ThreadReport {
244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245 for (name, metric) in &self.0 {
246 writeln!(f, "{name}:\n{metric}\n")?;
247 }
248 Ok(())
249 }
250}
251
/// Builds a [`BacktraceReport`] of every recorded backtrace that passes
/// `filter`, with symbols resolved for the survivors.
///
/// Tracking is suppressed (via `IN_ALLOC`) while the trace map is
/// walked so the report's own allocations are not fed back into the
/// tracker.
#[cfg(feature = "backtrace")]
pub fn backtrace_report(
    filter: impl Fn(&crate::backtrace::Backtrace, &BacktraceMetric) -> bool,
) -> BacktraceReport {
    // Suppress tracking for all the work below.
    IN_ALLOC.with(|x| x.set(true));
    let mut out = vec![];
    for mut entry in TRACE_MAP.iter_mut() {
        let metric = BacktraceMetric {
            allocated: entry.allocated,
            freed: entry.freed,
            mode: entry.mode,
            allocations: entry.allocations,
        };
        if !filter(entry.backtrace.inner(), &metric) {
            continue;
        }
        // Symbol resolution is expensive; only done for entries that
        // passed the filter.
        entry.backtrace.inner_mut().resolve();
        out.push((entry.backtrace.clone(), metric));
    }
    // Ascending by net (allocated - freed) bytes per trace.
    out.sort_by_key(|x| x.1.allocated.saturating_sub(x.1.freed) as i64);
    IN_ALLOC.with(|x| x.set(false));
    // `out` was allocated with tracking suppressed, so it must also be
    // dropped with tracking suppressed (its pointers are absent from
    // PTR_MAP). Re-clone with tracking ENABLED so the value handed to
    // the caller is fully tracked, then drop the untracked original
    // under the flag.
    let out2 = out.clone();
    IN_ALLOC.with(|x| x.set(true));
    drop(out);
    IN_ALLOC.with(|x| x.set(false));
    BacktraceReport(out2)
}
280
/// Maps OS thread ids to thread names by reading
/// `/proc/<pid>/task/<tid>/comm` for every task of the current process.
///
/// Fix: every step used to be `unwrap`ed, so a thread exiting between
/// the task enumeration and the `comm` read (a routine race) would
/// abort the process. Failures are now skipped; an unreadable `/proc`
/// yields an empty map.
#[cfg(all(unix, feature = "fs"))]
fn os_tid_names() -> HashMap<u32, String> {
    let mut names: HashMap<u32, String> = HashMap::new();
    let Ok(me) = procfs::process::Process::myself() else {
        return names;
    };
    let Ok(tasks) = me.tasks() else {
        return names;
    };
    // `flatten` silently drops tasks that errored during enumeration.
    for task in tasks.flatten() {
        let comm_path = format!("/proc/{}/task/{}/comm", task.pid, task.tid);
        if let Ok(comm) = std::fs::read_to_string(comm_path) {
            names.insert(task.tid as u32, comm.trim().to_string());
        }
    }
    names
}
296
/// Maps OS thread ids to thread names using the Win32 ToolHelp snapshot
/// API (`Thread32First`/`Thread32Next` + `GetThreadDescription`).
///
/// Fix: the `THREADENTRY32` used to be heap-allocated with the raw
/// `alloc` API and never freed — one leak per call (plus an unchecked
/// null pointer). A zeroed stack value is sufficient: the API only
/// requires `dwSize` to be set before the first call.
#[cfg(all(windows, feature = "fs"))]
fn os_tid_names() -> HashMap<u32, String> {
    use windows::Win32::Foundation::CloseHandle;
    use windows::Win32::System::Diagnostics::ToolHelp::{
        Thread32First, Thread32Next, THREADENTRY32,
    };
    let mut os_tid_names: HashMap<u32, String> = HashMap::new();
    unsafe {
        let process_id = windows::Win32::System::Threading::GetCurrentProcessId();
        let snapshot = windows::Win32::System::Diagnostics::ToolHelp::CreateToolhelp32Snapshot(
            windows::Win32::System::Diagnostics::ToolHelp::TH32CS_SNAPTHREAD,
            0,
        );
        if let Ok(snapshot) = snapshot {
            // SAFETY: THREADENTRY32 is a plain C struct; all-zero is a
            // valid bit pattern. Only dwSize must be initialized.
            let mut entry: THREADENTRY32 = std::mem::zeroed();
            entry.dwSize = std::mem::size_of::<THREADENTRY32>() as u32;
            if Thread32First(snapshot, &mut entry).as_bool() {
                loop {
                    // The snapshot covers every thread on the system;
                    // keep only our own process's threads.
                    if entry.th32OwnerProcessID == process_id {
                        let thread_handle = windows::Win32::System::Threading::OpenThread(
                            windows::Win32::System::Threading::THREAD_QUERY_LIMITED_INFORMATION,
                            false,
                            entry.th32ThreadID,
                        );
                        if let Ok(handle) = thread_handle {
                            let result =
                                windows::Win32::System::Threading::GetThreadDescription(handle);
                            if let Ok(str) = result {
                                os_tid_names.insert(
                                    entry.th32ThreadID,
                                    str.to_string().unwrap_or("UTF-16 Error".to_string()),
                                );
                            } else {
                                os_tid_names.insert(entry.th32ThreadID, "unknown".to_string());
                            }
                            CloseHandle(handle);
                        }
                    }
                    if !Thread32Next(snapshot, &mut entry).as_bool() {
                        break;
                    }
                }
            }
            CloseHandle(snapshot);
        }
    }
    os_tid_names
}
347
/// Builds a per-thread allocation report from the global counters.
///
/// With the `fs` feature, thread names come from the OS (via
/// `os_tid_names`); slots whose OS tid was never recorded are skipped.
/// Without it, the dense thread id is used as the name.
pub fn thread_report() -> ThreadReport {
    #[cfg(feature = "fs")]
    let os_tid_names: HashMap<u32, String> = os_tid_names();
    #[cfg(feature = "fs")]
    let mut tid_names: HashMap<usize, &String> = HashMap::new();
    #[cfg(feature = "fs")]
    let get_tid_name = {
        // Build dense-id -> OS-name lookup (tid 0 marks an unused slot).
        for (i, thread) in THREAD_STORE.iter().enumerate() {
            let tid = thread.tid.load(Ordering::Relaxed);
            if tid == 0 {
                continue;
            }
            if let Some(name) = os_tid_names.get(&tid) {
                tid_names.insert(i, name);
            }
        }
        |id: usize| tid_names.get(&id).map(|x| &**x)
    };
    #[cfg(not(feature = "fs"))]
    let get_tid_name = { move |id: usize| Some(id.to_string()) };

    let mut metrics = BTreeMap::new();

    for (i, thread) in THREAD_STORE.iter().enumerate() {
        let Some(name) = get_tid_name(i) else {
            continue;
        };
        let alloced = thread.alloc.load(Ordering::Relaxed) as u64;
        // Several dense ids can map to the same OS name; their counters
        // accumulate into one shared metric.
        let metric: &mut ThreadMetric = metrics.entry(name.into()).or_default();
        metric.total_alloc += alloced;

        // total_free = bytes of thread `i`'s allocations freed by anyone:
        // thread2.free[i] counts frees performed by `thread2` of memory
        // that `i` allocated.
        let mut total_free: u64 = 0;
        for (j, thread2) in THREAD_STORE.iter().enumerate() {
            let Some(name) = get_tid_name(j) else {
                continue;
            };
            let freed = thread2.free[i].load(Ordering::Relaxed);
            if freed == 0 {
                continue;
            }
            total_free += freed as u64;
            *metric.freed_by_others.entry(name.into()).or_default() += freed as u64;
        }
        metric.total_did_free += total_free;
        // Bytes thread `i` itself freed, whoever allocated them.
        metric.total_freed += thread
            .free
            .iter()
            .map(|x| x.load(Ordering::Relaxed) as u64)
            .sum::<u64>();
        metric.current_used += alloced.saturating_sub(total_free);
    }
    ThreadReport(metrics)
}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: spawn a named thread, then snapshot OS thread names
    /// and print them (success = no panic).
    ///
    /// Fix: `os_tid_names` only exists with the `fs` feature on
    /// unix/windows, so the test must be gated the same way —
    /// previously this module failed to compile when building tests
    /// without that feature.
    #[cfg(all(feature = "fs", any(unix, windows)))]
    #[test]
    pub fn test_os_tid_names() {
        // Detached on purpose: the long sleep keeps the thread alive for
        // the duration of the snapshot; the test harness exits without
        // joining it.
        std::thread::Builder::new()
            .name("thread2".to_string())
            .spawn(move || {
                std::thread::sleep(std::time::Duration::from_secs(100));
            })
            .unwrap();

        let os_tid_names = os_tid_names();
        println!("{:?}", os_tid_names);
    }
}