#[cfg(debug_assertions)]
use core::cell::RefCell;
use core::sync::atomic::AtomicUsize;
use core::sync::atomic::Ordering;
#[cfg(debug_assertions)]
use std::collections::HashSet;
use crate::ffi::polyplug_host_alloc;
use crate::ffi::polyplug_host_free;
// Per-thread tracking counters. Being thread-local, each thread observes only the
// allocator traffic that went through `tracking_alloc` / `tracking_free` on that thread.
thread_local! {
// Allocation counter: written by `tracking_alloc`, zeroed by `TrackingAllocator::new`,
// read by `TrackingAllocator::alloc_count`.
static TLS_ALLOC_COUNT: AtomicUsize = const { AtomicUsize::new(0) };
// Free counter: written by `tracking_free`, zeroed by `TrackingAllocator::new`,
// read by `TrackingAllocator::free_count`.
static TLS_FREE_COUNT: AtomicUsize = const { AtomicUsize::new(0) };
}
// Debug builds only: the set of non-null addresses handed out by `tracking_alloc` on this
// thread that have not yet been passed to `tracking_free`. `tracking_free` consults it to
// detect a free of an address that is not currently live on this thread.
#[cfg(debug_assertions)]
thread_local! {
static TLS_LIVE_ADDRS: RefCell<HashSet<usize>> = RefCell::new(HashSet::new());
}
/// Allocates through the host allocator and records the allocation for leak tracking.
///
/// The per-thread allocation counter is bumped only when the host returns a non-null
/// pointer. A failed allocation leaves nothing to free, and `tracking_free` aborts in
/// debug builds on any address that is not in the live set, so counting a null result
/// would make `TrackingAllocator::assert_no_leaks` report a leak that can never be
/// balanced.
///
/// In debug builds a successful allocation's address is also inserted into this thread's
/// live-address set so that `tracking_free` can validate it later.
///
/// Returns whatever `polyplug_host_alloc` returned, including null on failure.
///
/// # Safety
///
/// `size` and `align` are forwarded unchanged to `polyplug_host_alloc`; the caller must
/// satisfy whatever contract that function imposes on them.
unsafe extern "C" fn tracking_alloc(size: usize, align: usize) -> *mut u8 {
    // SAFETY: the caller of this `unsafe fn` upholds the host allocator's contract for
    // `size` and `align`, and both are passed through unmodified.
    let ptr: *mut u8 = unsafe { polyplug_host_alloc(size, align) };
    if !ptr.is_null() {
        TLS_ALLOC_COUNT.with(|count| count.fetch_add(1, Ordering::SeqCst));
        #[cfg(debug_assertions)]
        TLS_LIVE_ADDRS.with(|live| {
            live.borrow_mut().insert(ptr as usize);
        });
    }
    ptr
}
/// Releases `ptr` through the host allocator, counting the call for leak tracking.
///
/// The per-thread free counter is bumped first. In debug builds the address is then
/// removed from this thread's live set; if it was not present (already freed, never handed
/// out by `tracking_alloc`, or handed out on another thread, since the set is thread-local)
/// a diagnostic is written to stderr and the process aborts before the host free runs.
///
/// # Safety
///
/// `ptr`, `size` and `align` are forwarded unchanged to `polyplug_host_free`; the caller
/// must satisfy whatever contract that function imposes on them.
unsafe extern "C" fn tracking_free(ptr: *mut u8, size: usize, align: usize) {
    TLS_FREE_COUNT.with(|count| count.fetch_add(1, Ordering::SeqCst));

    #[cfg(debug_assertions)]
    {
        let address: usize = ptr as usize;
        // `remove` reports whether the address was live; the set borrow ends with the
        // closure, before any diagnostic is printed.
        let was_live: bool = TLS_LIVE_ADDRS.with(|live| live.borrow_mut().remove(&address));
        if !was_live {
            eprintln!(
                "TrackingAllocator: double-free detected at address {:#x}",
                address
            );
            // Abort rather than panic: this is an `extern "C"` entry point.
            #[allow(clippy::std_instead_of_core)]
            std::process::abort();
        }
    }

    // SAFETY: the caller of this `unsafe fn` upholds the host allocator's contract; the
    // arguments are passed through unmodified.
    unsafe { polyplug_host_free(ptr, size, align) }
}
/// Zero-sized handle onto the calling thread's allocation-tracking state.
///
/// The counters live in thread-local statics rather than in the struct, so every instance
/// on a given thread reports the same numbers.
pub struct TrackingAllocator;
impl TrackingAllocator {
pub fn new() -> TrackingAllocator {
TLS_ALLOC_COUNT.with(|c| c.store(0, Ordering::SeqCst));
TLS_FREE_COUNT.with(|c| c.store(0, Ordering::SeqCst));
#[cfg(debug_assertions)]
TLS_LIVE_ADDRS.with(|s| s.borrow_mut().clear());
TrackingAllocator
}
pub fn alloc_fn(&self) -> unsafe extern "C" fn(usize, usize) -> *mut u8 {
tracking_alloc
}
pub fn free_fn(&self) -> unsafe extern "C" fn(*mut u8, usize, usize) {
tracking_free
}
pub fn alloc_count(&self) -> usize {
TLS_ALLOC_COUNT.with(|c| c.load(Ordering::SeqCst))
}
pub fn free_count(&self) -> usize {
TLS_FREE_COUNT.with(|c| c.load(Ordering::SeqCst))
}
pub fn assert_no_leaks(&self) {
let a: usize = self.alloc_count();
let f: usize = self.free_count();
if a != f {
panic!(
"TrackingAllocator: leak detected: {} allocs, {} frees",
a, f
);
}
}
}
impl Default for TrackingAllocator {
fn default() -> TrackingAllocator {
TrackingAllocator::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One allocation followed by its free must bump each counter once and leave the
    /// tracker balanced.
    #[test]
    fn tracking_allocator_counts_alloc_and_free() {
        let tracker: TrackingAllocator = TrackingAllocator::new();
        let alloc: unsafe extern "C" fn(usize, usize) -> *mut u8 = tracker.alloc_fn();
        let free: unsafe extern "C" fn(*mut u8, usize, usize) = tracker.free_fn();

        // SAFETY: plain 64-byte, align-1 request; presumed valid for the host allocator.
        let ptr: *mut u8 = unsafe { alloc(64, 1) };
        assert!(!ptr.is_null());
        assert_eq!(tracker.alloc_count(), 1);

        // SAFETY: `ptr` was just returned by `alloc` with the same size and alignment.
        unsafe { free(ptr, 64, 1) };
        assert_eq!(tracker.free_count(), 1);
        tracker.assert_no_leaks();
    }

    /// `assert_no_leaks` must panic while an allocation is outstanding and pass again once
    /// that allocation has been freed.
    #[test]
    fn assert_no_leaks_panics_on_mismatch() {
        let tracker: TrackingAllocator = TrackingAllocator::new();
        let alloc: unsafe extern "C" fn(usize, usize) -> *mut u8 = tracker.alloc_fn();
        let free: unsafe extern "C" fn(*mut u8, usize, usize) = tracker.free_fn();

        // SAFETY: plain 64-byte, align-1 request; presumed valid for the host allocator.
        let ptr: *mut u8 = unsafe { alloc(64, 1) };
        // Fail here as an ordinary assertion if the host could not allocate; otherwise a
        // null pointer would reach `tracking_free` below and abort the whole test process
        // in debug builds.
        assert!(!ptr.is_null());

        let outcome: std::thread::Result<()> =
            std::panic::catch_unwind(core::panic::AssertUnwindSafe(|| tracker.assert_no_leaks()));
        assert!(
            outcome.is_err(),
            "assert_no_leaks must panic while an allocation is outstanding"
        );

        // SAFETY: `ptr` was returned by `alloc` above with the same size and alignment.
        unsafe { free(ptr, 64, 1) };
        tracker.assert_no_leaks();
    }
}