// alloc_track/lib.rs

#![doc = include_str!("../README.md")]

use dashmap::DashMap;
#[allow(unused_imports)]
use std::collections::HashMap;
use std::{
    alloc::{GlobalAlloc, Layout},
    cell::Cell,
    collections::BTreeMap,
    fmt,
    sync::atomic::{AtomicU32, AtomicUsize, Ordering},
};

#[cfg(feature = "backtrace")]
mod backtrace_support;
#[cfg(feature = "backtrace")]
use backtrace_support::*;
#[cfg(feature = "backtrace")]
pub use backtrace_support::{BacktraceMetric, BacktraceReport, HashedBacktrace};

/// Counter handing out the next internal thread id
static THREAD_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);

const MAX_THREADS: usize = 256;

#[derive(Clone, Copy, Debug)]
struct PointerData {
    alloc_thread_id: usize,
    #[cfg(feature = "backtrace")]
    trace_hash: u64,
}

lazy_static::lazy_static! {
    /// pointer -> metadata about its allocation
    static ref PTR_MAP: DashMap<usize, PointerData> = DashMap::new();
    // backtrace hash -> allocation statistics for that backtrace
    #[cfg(feature = "backtrace")]
    static ref TRACE_MAP: DashMap<u64, TraceInfo> = DashMap::new();
}

/// Representation of globally-accessible TLS
struct ThreadStore {
    #[allow(dead_code)]
    tid: AtomicU32,
    alloc: AtomicUsize,
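    /// Bytes freed by this thread, indexed by the internal id of the thread that allocated them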
    free: [AtomicUsize; MAX_THREADS],
}

/// A layout-compatible representation of `ThreadStore` that is `Copy`.
/// The layout of `AtomicX` is guaranteed to be identical to that of `X`.
/// Used to initialize the static array below.
#[allow(dead_code)]
#[derive(Clone, Copy)]
struct ThreadStoreLocal {
    tid: u32,
    alloc: usize,
    free: [usize; MAX_THREADS],
}

/// Custom pseudo-TLS implementation that allows safe global introspection
static THREAD_STORE: [ThreadStore; MAX_THREADS] = unsafe {
    std::mem::transmute(
        [ThreadStoreLocal {
            tid: 0,
            alloc: 0,
            free: [0usize; MAX_THREADS],
        }; MAX_THREADS],
    )
};

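// NOTE: internal thread ids are assigned on a thread's first allocation and never reused,
// so a process in which more than MAX_THREADS threads allocate over its lifetime will
// index THREAD_STORE out of bounds.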
thread_local! {
    static THREAD_ID: usize = THREAD_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    /// Used to avoid recursive alloc/dealloc calls for interior allocation
    static IN_ALLOC: Cell<bool> = Cell::new(false);
}

fn enter_alloc<T>(func: impl FnOnce() -> T) -> T {
    let current_value = IN_ALLOC.with(|x| x.get());
    IN_ALLOC.with(|x| x.set(true));
    let output = func();
    IN_ALLOC.with(|x| x.set(current_value));
    output
}

#[derive(Default, Clone, Copy, Debug, PartialEq)]
pub enum BacktraceMode {
    /// Report no backtraces
    #[default]
    None,
    /// Report backtraces with uninteresting frames removed (e.g. alloc_track and allocator internals)
    #[cfg(feature = "backtrace")]
    Short,
    /// Report the full backtrace
    #[cfg(feature = "backtrace")]
    Full,
}

/// Global memory allocator wrapper that can track per-thread and per-backtrace memory usage.
pub struct AllocTrack<T: GlobalAlloc> {
    inner: T,
    backtrace: BacktraceMode,
}

impl<T: GlobalAlloc> AllocTrack<T> {
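    /// Wrap `inner`, tracking allocations according to the given [`BacktraceMode`].
    ///
    /// A minimal installation sketch (assuming `std::alloc::System` as the wrapped
    /// allocator; any `GlobalAlloc` implementation works):
    ///
    /// ```no_run
    /// use alloc_track::{AllocTrack, BacktraceMode};
    /// use std::alloc::System;
    ///
    /// #[global_allocator]
    /// static GLOBAL: AllocTrack<System> = AllocTrack::new(System, BacktraceMode::None);
    ///
    /// fn main() {}
    /// ```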
    pub const fn new(inner: T, backtrace: BacktraceMode) -> Self {
        Self { inner, backtrace }
    }
}

#[cfg(all(unix, feature = "fs"))]
#[inline(always)]
unsafe fn get_sys_tid() -> u32 {
    libc::syscall(libc::SYS_gettid) as u32
}

#[cfg(all(windows, feature = "fs"))]
#[inline(always)]
unsafe fn get_sys_tid() -> u32 {
    windows::Win32::System::Threading::GetCurrentThreadId()
}

unsafe impl<T: GlobalAlloc> GlobalAlloc for AllocTrack<T> {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
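        // Allocations made while already inside the tracker (e.g. from DashMap, the
        // trace map, or backtrace capture) bypass the bookkeeping below to avoid
        // unbounded recursion.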
        if IN_ALLOC.with(|x| x.get()) {
            return self.inner.alloc(layout);
        }
        enter_alloc(|| {
            let size = layout.size();
            let ptr = self.inner.alloc(layout);
            let tid = THREAD_ID.with(|x| *x);
            #[cfg(feature = "fs")]
            if THREAD_STORE[tid].tid.load(Ordering::Relaxed) == 0 {
                let os_tid = get_sys_tid();
                THREAD_STORE[tid].tid.store(os_tid, Ordering::Relaxed);
            }
            THREAD_STORE[tid].alloc.fetch_add(size, Ordering::Relaxed);
            #[cfg(feature = "backtrace")]
            let trace = HashedBacktrace::capture(self.backtrace);
            PTR_MAP.insert(
                ptr as usize,
                PointerData {
                    alloc_thread_id: tid,
                    #[cfg(feature = "backtrace")]
                    trace_hash: trace.hash(),
                },
            );
            #[cfg(feature = "backtrace")]
            if !matches!(self.backtrace, BacktraceMode::None) {
                let mut trace_info = TRACE_MAP.entry(trace.hash()).or_insert_with(|| TraceInfo {
                    backtrace: trace,
                    allocated: 0,
                    freed: 0,
                    mode: self.backtrace,
                    allocations: 0,
                });
                trace_info.allocated += size as u64;
                trace_info.allocations += 1;
            }
            ptr
        })
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        if IN_ALLOC.with(|x| x.get()) {
            self.inner.dealloc(ptr, layout);
            return;
        }
        enter_alloc(|| {
            let size = layout.size();
            let (_, target) = PTR_MAP.remove(&(ptr as usize)).expect("double free");
            #[cfg(feature = "backtrace")]
            if !matches!(self.backtrace, BacktraceMode::None) {
                if let Some(mut info) = TRACE_MAP.get_mut(&target.trace_hash) {
                    info.freed += size as u64;
                }
            }
            self.inner.dealloc(ptr, layout);
            let tid = THREAD_ID.with(|x| *x);
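            // Credit the freed bytes to the current (freeing) thread's `free` slot for
            // the thread that originally allocated this pointer.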
            THREAD_STORE[tid].free[target.alloc_thread_id].fetch_add(size, Ordering::SeqCst);
        });
    }
}

/// Size display helper
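/// Formats with binary divisors and integer truncation, e.g. `Size(2048)` displays as "2 KB".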
pub struct Size(pub u64);

impl fmt::Display for Size {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.0 < 1024 {
            write!(f, "{} B", self.0)
        } else if self.0 < 1024 * 1024 {
            write!(f, "{} KB", self.0 / 1024)
        } else {
            write!(f, "{} MB", self.0 / 1024 / 1024)
        }
    }
}

/// Floating-point size display helper
pub struct SizeF64(pub f64);

impl fmt::Display for SizeF64 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.0 < 1024.0 {
            write!(f, "{:.01} B", self.0)
        } else if self.0 < 1024.0 * 1024.0 {
            write!(f, "{:.01} KB", self.0 / 1024.0)
        } else {
            write!(f, "{:.01} MB", self.0 / 1024.0 / 1024.0)
        }
    }
}

#[derive(Debug, Clone, Default)]
pub struct ThreadMetric {
    /// Total bytes allocated by this thread
    pub total_alloc: u64,
    /// Total bytes allocated by this thread that have since been freed (by any thread)
    pub total_did_free: u64,
    /// Total bytes freed by this thread, regardless of which thread allocated them
    pub total_freed: u64,
    /// Total bytes allocated by this thread that have not yet been freed
    pub current_used: u64,
    /// Bytes allocated by this thread that have been freed, keyed by the name of the freeing thread
    pub freed_by_others: BTreeMap<String, u64>,
}

impl fmt::Display for ThreadMetric {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "total_alloc: {}", Size(self.total_alloc))?;
        writeln!(f, "total_did_free: {}", Size(self.total_did_free))?;
        writeln!(f, "total_freed: {}", Size(self.total_freed))?;
        writeln!(f, "current_used: {}", Size(self.current_used))?;
        for (name, size) in &self.freed_by_others {
            writeln!(f, "freed by {}: {}", name, Size(*size))?;
        }
        Ok(())
    }
}

/// A comprehensive report of all thread allocation metrics
pub struct ThreadReport(pub BTreeMap<String, ThreadMetric>);

impl fmt::Display for ThreadReport {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (name, metric) in &self.0 {
            writeln!(f, "{name}:\n{metric}\n")?;
        }
        Ok(())
    }
}

/// Generate a memory usage report for backtraces, if enabled
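///
/// A minimal filtering sketch; the closure decides which recorded backtraces to keep
/// (here, all of them):
///
/// ```no_run
/// let _report = alloc_track::backtrace_report(|_trace, _metric| true);
/// ```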
#[cfg(feature = "backtrace")]
pub fn backtrace_report(
    filter: impl Fn(&crate::backtrace::Backtrace, &BacktraceMetric) -> bool,
) -> BacktraceReport {
    IN_ALLOC.with(|x| x.set(true));
    let mut out = vec![];
    for mut entry in TRACE_MAP.iter_mut() {
        let metric = BacktraceMetric {
            allocated: entry.allocated,
            freed: entry.freed,
            mode: entry.mode,
            allocations: entry.allocations,
        };
        if !filter(entry.backtrace.inner(), &metric) {
            continue;
        }
        entry.backtrace.inner_mut().resolve();
        out.push((entry.backtrace.clone(), metric));
    }
    out.sort_by_key(|x| x.1.allocated.saturating_sub(x.1.freed) as i64);
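    // Re-clone the results with tracking enabled so the returned report's buffers get
    // PTR_MAP entries: the caller frees the report outside the tracker, and `dealloc`
    // expects an entry for every pointer it processes. The untracked intermediate `out`
    // is dropped while tracking is suppressed so it never hits the bookkeeping path.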
    IN_ALLOC.with(|x| x.set(false));
    let out2 = out.clone();
    IN_ALLOC.with(|x| x.set(true));
    drop(out);
    IN_ALLOC.with(|x| x.set(false));
    BacktraceReport(out2)
}

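/// Collect OS thread ids and names for the current process from `/proc`.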
#[cfg(all(unix, feature = "fs"))]
fn os_tid_names() -> HashMap<u32, String> {
    let mut os_tid_names: HashMap<u32, String> = HashMap::new();
    for task in procfs::process::Process::myself().unwrap().tasks().unwrap() {
        let task = task.unwrap();
        os_tid_names.insert(
            task.tid as u32,
            std::fs::read_to_string(format!("/proc/{}/task/{}/comm", task.pid, task.tid))
                .unwrap()
                .trim()
                .to_string(),
        );
    }
    os_tid_names
}

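/// Collect OS thread ids and names for the current process via the ToolHelp snapshot API.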
#[cfg(all(windows, feature = "fs"))]
fn os_tid_names() -> HashMap<u32, String> {
    use std::alloc::alloc;
    use windows::Win32::Foundation::CloseHandle;
    use windows::Win32::System::Diagnostics::ToolHelp::{
        Thread32First, Thread32Next, THREADENTRY32,
    };
    let mut os_tid_names: HashMap<u32, String> = HashMap::new();
    unsafe {
        let process_id = windows::Win32::System::Threading::GetCurrentProcessId();
        let snapshot = windows::Win32::System::Diagnostics::ToolHelp::CreateToolhelp32Snapshot(
            windows::Win32::System::Diagnostics::ToolHelp::TH32CS_SNAPTHREAD,
            0,
        );
        if let Ok(snapshot) = snapshot {
            let mut thread_entry = alloc(Layout::new::<THREADENTRY32>()) as *mut THREADENTRY32;
            (*thread_entry).dwSize = std::mem::size_of::<THREADENTRY32>() as u32;
            if Thread32First(snapshot, thread_entry).as_bool() {
                loop {
                    if (*thread_entry).th32OwnerProcessID == process_id {
                        let thread_handle = windows::Win32::System::Threading::OpenThread(
                            windows::Win32::System::Threading::THREAD_QUERY_LIMITED_INFORMATION,
                            false,
                            (*thread_entry).th32ThreadID,
                        );
                        if let Ok(handle) = thread_handle {
                            let result =
                                windows::Win32::System::Threading::GetThreadDescription(handle);
                            if let Ok(str) = result {
                                os_tid_names.insert(
                                    (*thread_entry).th32ThreadID,
                                    str.to_string().unwrap_or("UTF-16 Error".to_string()),
                                );
                            } else {
                                os_tid_names
                                    .insert((*thread_entry).th32ThreadID, "unknown".to_string());
                            }
                            CloseHandle(handle);
                        }
                    }
                    if !Thread32Next(snapshot, thread_entry).as_bool() {
                        break;
                    }
                }
            }
            // Release the scratch THREADENTRY32 buffer allocated above.
            std::alloc::dealloc(thread_entry as *mut u8, Layout::new::<THREADENTRY32>());
            CloseHandle(snapshot);
        }
    }
    os_tid_names
}

/// Generate a per-thread memory usage report.
/// Note that the numbers are not a synchronized snapshot and may have slight timing skew.
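///
/// A minimal sketch of printing the report:
///
/// ```no_run
/// let report = alloc_track::thread_report();
/// println!("{report}");
/// ```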
pub fn thread_report() -> ThreadReport {
    #[cfg(feature = "fs")]
    let os_tid_names: HashMap<u32, String> = os_tid_names();
    #[cfg(feature = "fs")]
    let mut tid_names: HashMap<usize, &String> = HashMap::new();
    #[cfg(feature = "fs")]
    let get_tid_name = {
        for (i, thread) in THREAD_STORE.iter().enumerate() {
            let tid = thread.tid.load(Ordering::Relaxed);
            if tid == 0 {
                continue;
            }
            if let Some(name) = os_tid_names.get(&tid) {
                tid_names.insert(i, name);
            }
        }
        |id: usize| tid_names.get(&id).map(|x| &**x)
    };
    #[cfg(not(feature = "fs"))]
    let get_tid_name = { move |id: usize| Some(id.to_string()) };

    let mut metrics = BTreeMap::new();

    for (i, thread) in THREAD_STORE.iter().enumerate() {
        let Some(name) = get_tid_name(i) else {
            continue;
        };
        let alloced = thread.alloc.load(Ordering::Relaxed) as u64;
        let metric: &mut ThreadMetric = metrics.entry(name.into()).or_default();
        metric.total_alloc += alloced;

        let mut total_free: u64 = 0;
        for (j, thread2) in THREAD_STORE.iter().enumerate() {
            let Some(name) = get_tid_name(j) else {
                continue;
            };
            let freed = thread2.free[i].load(Ordering::Relaxed);
            if freed == 0 {
                continue;
            }
            total_free += freed as u64;
            *metric.freed_by_others.entry(name.into()).or_default() += freed as u64;
        }
        metric.total_did_free += total_free;
        metric.total_freed += thread
            .free
            .iter()
            .map(|x| x.load(Ordering::Relaxed) as u64)
            .sum::<u64>();
        metric.current_used += alloced.saturating_sub(total_free);
    }
    ThreadReport(metrics)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(all(any(unix, windows), feature = "fs"))]
    #[test]
    pub fn test_os_tid_names() {
        std::thread::Builder::new()
            .name("thread2".to_string())
            .spawn(move || {
                std::thread::sleep(std::time::Duration::from_secs(100));
            })
            .unwrap();

        let os_tid_names = os_tid_names();
        println!("{:?}", os_tid_names);
    }
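
    // Sanity check of the `Size` formatter, using values derived from its divisor logic above.
    #[test]
    pub fn test_size_display() {
        assert_eq!(Size(512).to_string(), "512 B");
        assert_eq!(Size(2048).to_string(), "2 KB");
        assert_eq!(Size(3 * 1024 * 1024).to_string(), "3 MB");
    }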
}