// alloc_track/lib.rs

1#![doc = include_str!("../README.md")]
2
3use dashmap::DashMap;
4#[allow(unused_imports)]
5use std::collections::HashMap;
6use std::{
7    alloc::{GlobalAlloc, Layout},
8    cell::Cell,
9    collections::BTreeMap,
10    fmt,
11    sync::atomic::{AtomicU32, AtomicUsize, Ordering},
12};
13
14#[cfg(feature = "backtrace")]
15mod backtrace_support;
16#[cfg(feature = "backtrace")]
17use backtrace_support::*;
18#[cfg(feature = "backtrace")]
19pub use backtrace_support::{BacktraceMetric, BacktraceReport, HashedBacktrace};
20
/// Next thread id incrementor: hands out a unique, monotonically increasing
/// slot index per thread. Ids are never reused, so a process that keeps
/// spawning threads will eventually exhaust the `MAX_THREADS` slots.
static THREAD_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Upper bound on the number of tracked thread slots.
/// On linux you can check your system's limit by running `cat /proc/sys/kernel/threads-max`.
/// This bound is unlikely to be hit outside of strange corner cases; if it is,
/// the `assert!` in `alloc` fires (ids are never recycled — see `THREAD_ID_COUNTER`).
const MAX_THREADS: usize = 1024;
27
/// Bookkeeping recorded per live allocation (value type of `PTR_MAP`,
/// keyed by the allocation's pointer address; removed again on dealloc).
#[derive(Clone, Copy, Debug)]
struct PointerData {
    /// Slot index (`THREAD_ID`) of the thread that performed the allocation.
    alloc_thread_id: usize,
    /// Hash of the backtrace captured at allocation time (key into `TRACE_MAP`).
    #[cfg(feature = "backtrace")]
    trace_hash: u64,
}
34
lazy_static::lazy_static! {
    /// pointer address -> per-allocation bookkeeping (inserted in `alloc`, removed in `dealloc`)
    static ref PTR_MAP: DashMap<usize, PointerData> = DashMap::new();
    // backtrace hash -> aggregated allocation statistics for that trace
    // (plain `//` comment kept here: a doc comment before `#[cfg]` confuses the macro)
    // NOTE(review): `std::sync::LazyLock` could replace `lazy_static` once MSRV >= 1.80 is confirmed.
    #[cfg(feature = "backtrace")]
    static ref TRACE_MAP: DashMap<u64, TraceInfo> = DashMap::new();
}
42
/// Representation of globally-accessible TLS: one slot per `THREAD_ID`,
/// living in the global `THREAD_STORE` array so reports can read every
/// thread's counters without real TLS.
struct ThreadStore {
    /// OS-level thread id (only written under the `fs` feature); 0 means the
    /// slot is unclaimed / the OS id has not been recorded yet.
    #[allow(dead_code)]
    tid: AtomicU32,
    /// Total bytes this thread has allocated.
    alloc: AtomicUsize,
    /// `free[j]` = bytes freed by this thread that were allocated by thread slot `j`
    /// (see the `fetch_add` in `dealloc`).
    free: [AtomicUsize; MAX_THREADS],
}
50
/// A layout compatible representation of `ThreadStore` that is `Copy`.
/// Layout of `AtomicX` is guaranteed to be identical to `X` (same size and
/// alignment, per the std docs).
/// Used to initialize arrays.
/// NOTE(review): field-for-field identical layout of two distinct `repr(Rust)`
/// structs is not formally guaranteed; any transmute between them would be
/// sounder with `#[repr(C)]` on both types — worth confirming/hardening.
#[allow(dead_code)]
#[derive(Clone, Copy)]
struct ThreadStoreLocal {
    tid: u32,
    alloc: usize,
    free: [usize; MAX_THREADS],
}
61
62/// Custom psuedo-TLS implementation that allows safe global introspection
63static THREAD_STORE: [ThreadStore; MAX_THREADS] = unsafe {
64    std::mem::transmute(
65        [ThreadStoreLocal {
66            tid: 0,
67            alloc: 0,
68            free: [0usize; MAX_THREADS],
69        }; MAX_THREADS],
70    )
71};
72
thread_local! {
    /// This thread's slot index into `THREAD_STORE`; assigned on first use
    /// and never recycled after the thread exits (the counter only grows).
    static THREAD_ID: usize = THREAD_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    /// Used to avoid recursive alloc/dealloc calls for interior allocation
    /// (allocations made while this flag is set bypass tracking entirely).
    static IN_ALLOC: Cell<bool> = const { Cell::new(false) };
}
78
79fn enter_alloc<T>(func: impl FnOnce() -> T) -> T {
80    let current_value = IN_ALLOC.with(|x| x.get());
81    IN_ALLOC.with(|x| x.set(true));
82    let output = func();
83    IN_ALLOC.with(|x| x.set(current_value));
84    output
85}
86
/// Controls whether (and how much of) an allocation backtrace is captured and reported.
#[derive(Default, Clone, Copy, Debug)]
pub enum BacktraceMode {
    #[default]
    /// Report no backtraces
    None,
    #[cfg(feature = "backtrace")]
    /// Report backtraces with unuseful entries removed (e.g. alloc_track, allocator internals).
    /// Provides an additional filter function, which removes entries given the module path as an argument, when the returned value is false.
    Short(fn(&str) -> bool),
    /// Report the full backtrace
    #[cfg(feature = "backtrace")]
    Full,
}
100
/// Global memory allocator wrapper that can track per-thread and per-backtrace memory usage.
pub struct AllocTrack<T: GlobalAlloc> {
    /// The wrapped allocator that performs the real (de)allocations.
    inner: T,
    /// Backtrace capture policy applied to every tracked allocation.
    backtrace: BacktraceMode,
}

impl<T: GlobalAlloc> AllocTrack<T> {
    /// Wrap `inner`, tracking its allocations according to `backtrace`.
    /// `const`, so an `AllocTrack` can be constructed in a `static` initializer.
    pub const fn new(inner: T, backtrace: BacktraceMode) -> Self {
        Self { inner, backtrace }
    }
}
112
/// Fetch the OS-level id of the calling thread.
/// macOS: not implemented — always returns 0, which the `alloc` path treats as
/// "unrecorded", so OS thread ids are never stored on macOS.
#[cfg(all(target_os = "macos", feature = "fs"))]
#[inline(always)]
unsafe fn get_sys_tid() -> u32 {
    0
}

/// Fetch the OS-level id of the calling thread (Linux: raw `gettid` syscall).
#[cfg(all(target_os = "linux", feature = "fs"))]
#[inline(always)]
unsafe fn get_sys_tid() -> u32 {
    libc::syscall(libc::SYS_gettid) as u32
}

/// Fetch the OS-level id of the calling thread (Windows).
#[cfg(all(windows, feature = "fs"))]
#[inline(always)]
unsafe fn get_sys_tid() -> u32 {
    windows::Win32::System::Threading::GetCurrentThreadId()
}
130
131unsafe impl<T: GlobalAlloc> GlobalAlloc for AllocTrack<T> {
132    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
133        if IN_ALLOC.with(|x| x.get()) {
134            return self.inner.alloc(layout);
135        }
136        enter_alloc(|| {
137            let size = layout.size();
138            let ptr = self.inner.alloc(layout);
139            let tid = THREAD_ID.with(|x| *x);
140            assert!(
141                tid < MAX_THREADS,
142                "Thread ID {tid} is greater than the maximum number of threads {MAX_THREADS}"
143            );
144            #[cfg(feature = "fs")]
145            if THREAD_STORE[tid].tid.load(Ordering::Relaxed) == 0 {
146                let os_tid = get_sys_tid();
147                THREAD_STORE[tid].tid.store(os_tid, Ordering::Relaxed);
148            }
149            THREAD_STORE[tid].alloc.fetch_add(size, Ordering::Relaxed);
150            #[cfg(feature = "backtrace")]
151            let trace = HashedBacktrace::capture(self.backtrace);
152            PTR_MAP.insert(
153                ptr as usize,
154                PointerData {
155                    alloc_thread_id: tid,
156                    #[cfg(feature = "backtrace")]
157                    trace_hash: trace.hash(),
158                },
159            );
160            #[cfg(feature = "backtrace")]
161            if !matches!(self.backtrace, BacktraceMode::None) {
162                let mut trace_info = TRACE_MAP.entry(trace.hash()).or_insert_with(|| TraceInfo {
163                    backtrace: trace,
164                    allocated: 0,
165                    freed: 0,
166                    mode: self.backtrace,
167                    allocations: 0,
168                });
169                trace_info.allocated += size as u64;
170                trace_info.allocations += 1;
171            }
172            ptr
173        })
174    }
175
176    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
177        if IN_ALLOC.with(|x| x.get()) {
178            self.inner.dealloc(ptr, layout);
179            return;
180        }
181        enter_alloc(|| {
182            let size = layout.size();
183            let (_, target) = PTR_MAP.remove(&(ptr as usize)).expect("double free");
184            #[cfg(feature = "backtrace")]
185            if !matches!(self.backtrace, BacktraceMode::None) {
186                if let Some(mut info) = TRACE_MAP.get_mut(&target.trace_hash) {
187                    info.freed += size as u64;
188                }
189            }
190            self.inner.dealloc(ptr, layout);
191            let tid = THREAD_ID.with(|x| *x);
192            THREAD_STORE[tid].free[target.alloc_thread_id].fetch_add(size, Ordering::SeqCst);
193        });
194    }
195}
196
/// Size display helper: renders a byte count as whole `B`, `KB`, or `MB`
/// (integer division, truncating).
pub struct Size(pub u64);

impl fmt::Display for Size {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const KB: u64 = 1024;
        const MB: u64 = 1024 * 1024;
        match self.0 {
            bytes if bytes < KB => write!(f, "{bytes} B"),
            bytes if bytes < MB => write!(f, "{} KB", bytes / KB),
            bytes => write!(f, "{} MB", bytes / MB),
        }
    }
}
211
/// Size display helper for fractional byte counts; one decimal place,
/// scaled to `B`, `KB`, or `MB`.
pub struct SizeF64(pub f64);

impl fmt::Display for SizeF64 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const KB: f64 = 1024.0;
        const MB: f64 = 1024.0 * 1024.0;
        let bytes = self.0;
        if bytes < KB {
            write!(f, "{bytes:.1} B")
        } else if bytes < MB {
            write!(f, "{:.1} KB", bytes / KB)
        } else {
            write!(f, "{:.1} MB", bytes / MB)
        }
    }
}
226
/// Allocation statistics for one named thread (all values in bytes).
#[derive(Debug, Clone, Default)]
pub struct ThreadMetric {
    /// Total bytes allocated in this thread
    pub total_alloc: u64,
    /// Total bytes freed in this thread
    pub total_did_free: u64,
    /// Total bytes allocated in this thread that have been freed
    pub total_freed: u64,
    /// Total bytes allocated in this thread that are not freed
    pub current_used: u64,
    /// Total bytes allocated in this thread that have been freed by the given thread
    /// (keyed by the freeing thread's name; values sum to `total_freed`)
    pub freed_by_others: BTreeMap<String, u64>,
}
240
241impl fmt::Display for ThreadMetric {
242    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243        writeln!(f, "total_alloc: {}", Size(self.total_alloc))?;
244        writeln!(f, "total_did_free: {}", Size(self.total_did_free))?;
245        writeln!(f, "total_freed: {}", Size(self.total_freed))?;
246        writeln!(f, "current_used: {}", Size(self.current_used))?;
247        for (name, size) in &self.freed_by_others {
248            writeln!(f, "freed by {}: {}", name, Size(*size))?;
249        }
250        Ok(())
251    }
252}
253
254/// A comprehensive report of all thread allocation metrics
255pub struct ThreadReport(pub BTreeMap<String, ThreadMetric>);
256
257impl fmt::Display for ThreadReport {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        for (name, metric) in &self.0 {
260            writeln!(f, "{name}:\n{metric}\n")?;
261        }
262        Ok(())
263    }
264}
265
/// Generate a memory usage report for backtraces, if enabled.
/// `filter` receives each (still-unresolved) backtrace and its metric and
/// decides whether it appears in the report.
#[cfg(feature = "backtrace")]
pub fn backtrace_report(
    filter: impl Fn(&crate::backtrace::Backtrace, &BacktraceMetric) -> bool,
) -> BacktraceReport {
    // Raise the re-entrancy flag so the report's own allocations are untracked.
    IN_ALLOC.with(|x| x.set(true));
    let mut out = vec![];
    for mut entry in TRACE_MAP.iter_mut() {
        let metric = BacktraceMetric {
            allocated: entry.allocated,
            freed: entry.freed,
            mode: entry.mode,
            allocations: entry.allocations,
        };
        if !filter(entry.backtrace.inner(), &metric) {
            continue;
        }
        // Only traces that survive the filter get symbol resolution (it's the
        // expensive step).
        entry.backtrace.inner_mut().resolve();
        out.push((entry.backtrace.clone(), metric));
    }
    // Order by net (allocated - freed) bytes, ascending.
    out.sort_by_key(|x| x.1.allocated.saturating_sub(x.1.freed) as i64);
    IN_ALLOC.with(|x| x.set(false));
    // Deliberate flag dance: clone with tracking ENABLED so the returned report's
    // buffers are registered in PTR_MAP (their eventual free is tracked), then
    // drop the untracked original with tracking DISABLED again — freeing an
    // untracked pointer while tracking is live would panic as "double free".
    let out2 = out.clone();
    IN_ALLOC.with(|x| x.set(true));
    drop(out);
    IN_ALLOC.with(|x| x.set(false));
    BacktraceReport(out2)
}
294
/// OS thread id -> thread name. Not implemented on macOS; reports no names.
#[cfg(all(target_os = "macos", feature = "fs"))]
fn os_tid_names() -> HashMap<u32, String> {
    HashMap::new()
}
299
/// OS thread id -> thread name, read from `/proc/self/task/<tid>/comm`.
///
/// BUGFIX: the previous version `unwrap()`ed every /proc read. A thread can
/// exit between task enumeration and the `comm` read (the file vanishes),
/// which made this diagnostics function panic under an ordinary race. Errors
/// now degrade gracefully: unreadable entries are skipped and a /proc failure
/// yields an empty (or partial) map instead of a panic.
#[cfg(all(target_os = "linux", feature = "fs"))]
fn os_tid_names() -> HashMap<u32, String> {
    let mut os_tid_names: HashMap<u32, String> = HashMap::new();
    let Ok(process) = procfs::process::Process::myself() else {
        return os_tid_names;
    };
    let Ok(tasks) = process.tasks() else {
        return os_tid_names;
    };
    for task in tasks.flatten() {
        // The task may have exited already; skip it rather than panicking.
        if let Ok(comm) =
            std::fs::read_to_string(format!("/proc/{}/task/{}/comm", task.pid, task.tid))
        {
            os_tid_names.insert(task.tid as u32, comm.trim().to_string());
        }
    }
    os_tid_names
}
315
/// OS thread id -> thread name, via the ToolHelp thread snapshot API.
///
/// BUGFIX: the heap buffer for `THREADENTRY32` was never deallocated, leaking
/// once per call; it is now freed before returning. A null return from `alloc`
/// is also handled instead of being written through.
#[cfg(all(windows, feature = "fs"))]
fn os_tid_names() -> HashMap<u32, String> {
    use std::alloc::{alloc, dealloc};
    use windows::Win32::Foundation::CloseHandle;
    use windows::Win32::System::Diagnostics::ToolHelp::{
        Thread32First, Thread32Next, THREADENTRY32,
    };
    let mut os_tid_names: HashMap<u32, String> = HashMap::new();
    unsafe {
        let process_id = windows::Win32::System::Threading::GetCurrentProcessId();
        let snapshot = windows::Win32::System::Diagnostics::ToolHelp::CreateToolhelp32Snapshot(
            windows::Win32::System::Diagnostics::ToolHelp::TH32CS_SNAPTHREAD,
            0,
        );
        if let Ok(snapshot) = snapshot {
            let entry_layout = Layout::new::<THREADENTRY32>();
            let thread_entry = alloc(entry_layout) as *mut THREADENTRY32;
            // Allocation failure: bail out instead of writing through null.
            if thread_entry.is_null() {
                CloseHandle(snapshot);
                return os_tid_names;
            }
            (*thread_entry).dwSize = std::mem::size_of::<THREADENTRY32>() as u32;
            if Thread32First(snapshot, thread_entry).as_bool() {
                loop {
                    // Only threads belonging to this process are of interest.
                    if (*thread_entry).th32OwnerProcessID == process_id {
                        let thread_handle = windows::Win32::System::Threading::OpenThread(
                            windows::Win32::System::Threading::THREAD_QUERY_LIMITED_INFORMATION,
                            false,
                            (*thread_entry).th32ThreadID,
                        );
                        if let Ok(handle) = thread_handle {
                            let result =
                                windows::Win32::System::Threading::GetThreadDescription(handle);
                            if let Ok(str) = result {
                                os_tid_names.insert(
                                    (*thread_entry).th32ThreadID,
                                    str.to_string().unwrap_or("UTF-16 Error".to_string()),
                                );
                            } else {
                                os_tid_names
                                    .insert((*thread_entry).th32ThreadID, "unknown".to_string());
                            }
                            CloseHandle(handle);
                        }
                    }
                    if !Thread32Next(snapshot, thread_entry).as_bool() {
                        break;
                    }
                }
            }
            // Previously leaked on every call.
            dealloc(thread_entry as *mut u8, entry_layout);
            CloseHandle(snapshot);
        }
    }
    os_tid_names
}
366
367/// Generate a memory usage report
368/// Note that the numbers are not a synchronized snapshot, and have slight timing skew.
369pub fn thread_report() -> ThreadReport {
370    #[cfg(feature = "fs")]
371    let os_tid_names: HashMap<u32, String> = os_tid_names();
372    #[cfg(feature = "fs")]
373    let mut tid_names: HashMap<usize, &String> = HashMap::new();
374    #[cfg(feature = "fs")]
375    let get_tid_name = {
376        for (i, thread) in THREAD_STORE.iter().enumerate() {
377            let tid = thread.tid.load(Ordering::Relaxed);
378            if tid == 0 {
379                continue;
380            }
381            if let Some(name) = os_tid_names.get(&tid) {
382                tid_names.insert(i, name);
383            }
384        }
385        |id: usize| tid_names.get(&id).map(|x| &**x)
386    };
387    #[cfg(not(feature = "fs"))]
388    let get_tid_name = { move |id: usize| Some(id.to_string()) };
389
390    let mut metrics = BTreeMap::new();
391
392    for (i, thread) in THREAD_STORE.iter().enumerate() {
393        let Some(name) = get_tid_name(i) else {
394            continue;
395        };
396        let alloced = thread.alloc.load(Ordering::Relaxed) as u64;
397        let metric: &mut ThreadMetric = metrics.entry(name.into()).or_default();
398        metric.total_alloc += alloced;
399
400        let mut total_free: u64 = 0;
401        for (j, thread2) in THREAD_STORE.iter().enumerate() {
402            let Some(name) = get_tid_name(j) else {
403                continue;
404            };
405            let freed = thread2.free[i].load(Ordering::Relaxed);
406            if freed == 0 {
407                continue;
408            }
409            total_free += freed as u64;
410            *metric.freed_by_others.entry(name.into()).or_default() += freed as u64;
411        }
412        metric.total_did_free += total_free;
413        metric.total_freed += thread
414            .free
415            .iter()
416            .map(|x| x.load(Ordering::Relaxed) as u64)
417            .sum::<u64>();
418        metric.current_used += alloced.saturating_sub(total_free);
419    }
420    ThreadReport(metrics)
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: spawn a named thread, then run `os_tid_names` and print what
    /// it finds. The spawned thread is intentionally detached; the long sleep
    /// just keeps it alive while the snapshot is taken.
    ///
    /// BUGFIX: `os_tid_names` only exists under the `fs` feature on
    /// linux/macos/windows; previously this test referenced it unconditionally,
    /// breaking `cargo test` whenever that configuration was absent.
    #[cfg(all(
        feature = "fs",
        any(target_os = "linux", target_os = "macos", windows)
    ))]
    #[test]
    pub fn test_os_tid_names() {
        std::thread::Builder::new()
            .name("thread2".to_string())
            .spawn(move || {
                std::thread::sleep(std::time::Duration::from_secs(100));
            })
            .unwrap();

        let os_tid_names = os_tid_names();
        println!("{:?}", os_tid_names);
    }
}