use std::alloc::{GlobalAlloc, Layout};
use std::any::type_name;
use std::cell::{Cell, OnceCell};
use std::fmt;
#[cfg(feature = "panic_on_next_alloc")]
use std::sync::atomic::AtomicBool;
use std::sync::atomic::{self, AtomicU64};
use std::sync::{Arc, LazyLock, Mutex};
use crate::ERR_POISONED_LOCK;
/// Allocation counters owned by a single thread.
///
/// All accesses use `Relaxed` ordering: the two fields are independent
/// monotonic counters that readers only ever aggregate into approximate
/// totals.
#[derive(Debug)]
pub(crate) struct PerThreadCounters {
    // Total bytes requested by allocations on this thread.
    bytes: AtomicU64,
    // Number of allocation calls made on this thread.
    count: AtomicU64,
}

impl PerThreadCounters {
    /// Creates a counter pair with both values starting at zero.
    #[inline]
    const fn new() -> Self {
        Self {
            count: AtomicU64::new(0),
            bytes: AtomicU64::new(0),
        }
    }

    /// Records a single allocation of `bytes` bytes.
    #[inline]
    pub(crate) fn register_allocation(&self, bytes: u64) {
        self.count.fetch_add(1, atomic::Ordering::Relaxed);
        self.bytes.fetch_add(bytes, atomic::Ordering::Relaxed);
    }

    /// Total bytes recorded so far.
    #[inline]
    pub(crate) fn bytes(&self) -> u64 {
        self.bytes.load(atomic::Ordering::Relaxed)
    }

    /// Total number of allocations recorded so far.
    #[inline]
    pub(crate) fn count(&self) -> u64 {
        self.count.load(atomic::Ordering::Relaxed)
    }
}
/// Global registry of every thread's counters. An entry is pushed the first
/// time a thread's counters are initialized; nothing in this file ever
/// removes entries, so counters from exited threads keep contributing to
/// `allocation_totals` and their `Arc`s live for the rest of the program.
static REGISTRY: LazyLock<Mutex<Vec<Arc<PerThreadCounters>>>> =
    LazyLock::new(|| Mutex::new(Vec::new()));

thread_local! {
    // Cached raw pointer into this thread's registry entry; written at most
    // once by `get_or_init_thread_counters`.
    static TLS_COUNTER_PTR: OnceCell<*const PerThreadCounters> = const { OnceCell::new() };
    // Re-entrancy flag: true while this thread is inside counter
    // initialization, so allocations performed by the initialization itself
    // are not tracked (see `track_allocation`).
    static TLS_INIT_GUARD: Cell<bool> = const { Cell::new(false) };
}
/// Returns this thread's counters, creating and registering them on first use.
///
/// The returned reference can be `'static` because the counters are owned by
/// an `Arc` stored in `REGISTRY`, which never removes entries, so the
/// allocation outlives the calling thread.
#[inline]
pub(crate) fn get_or_init_thread_counters() -> &'static PerThreadCounters {
    TLS_COUNTER_PTR.with(|cell| {
        if let Some(ptr) = cell.get() {
            // SAFETY: `ptr` was derived from an `Arc` that `REGISTRY` keeps
            // alive for the rest of the program, so the pointee is still
            // valid.
            return unsafe { &**ptr };
        }
        // Set the guard BEFORE allocating: `Arc::new` and `Vec::push` below
        // allocate, and `track_allocation` must not re-enter this function
        // while initialization is in progress.
        TLS_INIT_GUARD.set(true);
        let arc = Arc::new(PerThreadCounters::new());
        let ptr = Arc::as_ptr(&arc);
        REGISTRY.lock().expect(ERR_POISONED_LOCK).push(arc);
        _ = cell.set(ptr);
        TLS_INIT_GUARD.set(false);
        // SAFETY: the `Arc` was just moved into `REGISTRY`, which never drops
        // entries, so the pointee stays alive and the `'static` borrow is
        // sound.
        unsafe { &*ptr }
    })
}
/// A snapshot of allocation statistics aggregated across all registered
/// threads.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct AllocationTotals {
    /// Total allocated bytes.
    pub bytes: u64,
    /// Total number of allocations.
    pub count: u64,
}

impl AllocationTotals {
    /// An all-zero snapshot, usable as an accumulator seed.
    #[inline]
    pub(crate) const fn zero() -> Self {
        AllocationTotals { bytes: 0, count: 0 }
    }
}
#[inline]
pub(crate) fn allocation_totals() -> AllocationTotals {
let reg = REGISTRY.lock().expect(ERR_POISONED_LOCK);
let mut totals = AllocationTotals::zero();
for c in reg.iter() {
totals.bytes = totals.bytes.wrapping_add(c.bytes());
totals.count = totals.count.wrapping_add(c.count());
}
totals
}
/// One-shot trap flag: when set, the next allocation through `Allocator`
/// panics (a testing/debugging aid).
#[cfg(feature = "panic_on_next_alloc")]
static PANIC_ON_NEXT_ALLOCATION: AtomicBool = AtomicBool::new(false);

/// Arms (`true`) or disarms (`false`) the panic-on-next-allocation trap.
///
/// The trap is one-shot: the first allocation that observes it clears it
/// before panicking.
#[cfg(feature = "panic_on_next_alloc")]
pub fn panic_on_next_alloc(enabled: bool) {
    PANIC_ON_NEXT_ALLOCATION.store(enabled, atomic::Ordering::Relaxed);
}

/// Panics if the trap is armed, atomically clearing it first so that exactly
/// one allocation fails.
#[cfg(feature = "panic_on_next_alloc")]
fn check_and_panic_if_enabled() {
    #[expect(
        clippy::manual_assert,
        reason = "We need to atomically swap the flag, not just check it"
    )]
    if PANIC_ON_NEXT_ALLOCATION.swap(false, atomic::Ordering::Relaxed) {
        panic!("Memory allocation attempted while panic-on-next-allocation was enabled");
    }
}

/// No-op stand-in when the `panic_on_next_alloc` feature is disabled, so the
/// allocator hot path compiles to nothing extra.
#[cfg(not(feature = "panic_on_next_alloc"))]
#[inline]
fn check_and_panic_if_enabled() {}
/// Records an allocation of `size` bytes against the current thread's
/// counters — unless this thread is currently inside counter initialization,
/// in which case the allocation is deliberately ignored to break recursion.
fn track_allocation(size: usize) {
    let size_u64 = u64::try_from(size).expect("usize always fits into u64");
    TLS_INIT_GUARD.with(|guard| {
        if !guard.get() {
            get_or_init_thread_counters().register_allocation(size_u64);
        }
    });
}
/// Test helper: bumps this thread's counters by arbitrary deltas without
/// going through the real allocation path.
#[cfg(test)]
pub(crate) fn register_fake_allocation(bytes: u64, count: u64) {
    let counters = get_or_init_thread_counters();
    for (delta, counter) in [(bytes, &counters.bytes), (count, &counters.count)] {
        // Skip the atomic RMW entirely when there is nothing to add.
        if delta != 0 {
            counter.fetch_add(delta, atomic::Ordering::Relaxed);
        }
    }
}
/// A `GlobalAlloc` wrapper that records every allocation it forwards to the
/// underlying allocator `A`.
pub struct Allocator<A: GlobalAlloc> {
    // The allocator that actually services requests.
    inner: A,
}

#[cfg_attr(coverage_nightly, coverage(off))]
impl<A: GlobalAlloc> fmt::Debug for Allocator<A> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `A` is not required to implement `Debug`, so the inner allocator is
        // rendered as an opaque placeholder.
        let mut builder = f.debug_struct(type_name::<Self>());
        builder.field("inner", &"<allocator>");
        builder.finish()
    }
}
impl Allocator<std::alloc::System> {
    /// Convenience constructor wrapping the operating-system allocator.
    #[must_use]
    #[inline]
    #[cfg_attr(coverage_nightly, coverage(off))]
    pub const fn system() -> Self {
        Self { inner: std::alloc::System }
    }
}
impl<A: GlobalAlloc> Allocator<A> {
    /// Wraps `allocator` so that every request routed through the wrapper is
    /// counted before being forwarded.
    #[must_use]
    #[inline]
    #[cfg_attr(coverage_nightly, coverage(off))]
    pub const fn new(allocator: A) -> Self {
        Self { inner: allocator }
    }
}
// SAFETY: every method delegates directly to `self.inner`, which itself
// implements `GlobalAlloc`; layout/pointer arguments are forwarded unchanged,
// so the inner allocator's contract is preserved.
unsafe impl<A: GlobalAlloc> GlobalAlloc for Allocator<A> {
    #[inline]
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        check_and_panic_if_enabled();
        // The request is counted before forwarding, i.e. even if `inner`
        // then returns null.
        track_allocation(layout.size());
        // SAFETY: the caller upholds the `GlobalAlloc` contract for `layout`.
        unsafe { self.inner.alloc(layout) }
    }

    #[inline]
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        // Deallocations are not tracked — the counters are allocation-only.
        // SAFETY: the caller guarantees `ptr` was allocated by this allocator
        // with `layout`.
        unsafe { self.inner.dealloc(ptr, layout) }
    }

    #[inline]
    unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
        check_and_panic_if_enabled();
        track_allocation(layout.size());
        // SAFETY: the caller upholds the `GlobalAlloc` contract for `layout`.
        unsafe { self.inner.alloc_zeroed(layout) }
    }

    #[inline]
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        check_and_panic_if_enabled();
        // The full `new_size` is recorded, not the delta from `layout.size()`
        // — a realloc counts as a fresh allocation of the new size.
        track_allocation(new_size);
        // SAFETY: the caller guarantees `ptr`/`layout` describe a live
        // allocation from this allocator and that `new_size` is valid per the
        // `GlobalAlloc` contract.
        unsafe { self.inner.realloc(ptr, layout, new_size) }
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::panic::{RefUnwindSafe, UnwindSafe};
    use std::thread;

    use super::*;

    // Compile-time guarantees about the wrapper's auto traits.
    static_assertions::assert_impl_all!(Allocator<std::alloc::System>: Send, Sync);
    static_assertions::assert_impl_all!(PerThreadCounters: Send, Sync);
    static_assertions::assert_impl_all!(
        Allocator<std::alloc::System>: UnwindSafe, RefUnwindSafe
    );

    #[cfg(feature = "panic_on_next_alloc")]
    #[test]
    fn panic_on_next_alloc_can_be_enabled_and_disabled() {
        assert!(!PANIC_ON_NEXT_ALLOCATION.load(atomic::Ordering::Relaxed));
        panic_on_next_alloc(true);
        assert!(PANIC_ON_NEXT_ALLOCATION.load(atomic::Ordering::Relaxed));
        panic_on_next_alloc(false);
        assert!(!PANIC_ON_NEXT_ALLOCATION.load(atomic::Ordering::Relaxed));
    }

    /// Every spawned thread registers fake allocations; the global totals
    /// must grow by at least what they contributed (other allocations in the
    /// process may add more, hence `>=`).
    #[test]
    fn concurrent_threads_register_and_totals_reflect_all() {
        const THREADS: usize = 4;
        const BYTES_PER_THREAD: u64 = 100;
        const COUNT_PER_THREAD: u64 = 10;

        let before = allocation_totals();
        let workers: Vec<_> = (0..THREADS)
            .map(|_| thread::spawn(|| register_fake_allocation(BYTES_PER_THREAD, COUNT_PER_THREAD)))
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        let after = allocation_totals();

        assert!(after.bytes.wrapping_sub(before.bytes) >= THREADS as u64 * BYTES_PER_THREAD);
        assert!(after.count.wrapping_sub(before.count) >= THREADS as u64 * COUNT_PER_THREAD);
    }

    /// A reader may call `allocation_totals` concurrently with writers; once
    /// everything has joined, the totals reflect every write.
    #[test]
    fn concurrent_register_and_read_totals() {
        const WRITER_THREADS: usize = 4;
        const ALLOCS_PER_WRITER: u64 = 10;
        const BYTES_PER_ALLOC: u64 = 50;

        let before = allocation_totals();
        let reader = thread::spawn(|| {
            for _ in 0..20 {
                let _totals = allocation_totals();
            }
        });
        let writers: Vec<_> = (0..WRITER_THREADS)
            .map(|_| {
                thread::spawn(|| {
                    for _ in 0..ALLOCS_PER_WRITER {
                        register_fake_allocation(BYTES_PER_ALLOC, 1);
                    }
                })
            })
            .collect();
        for writer in writers {
            writer.join().unwrap();
        }
        reader.join().unwrap();

        let grown = allocation_totals().bytes.wrapping_sub(before.bytes);
        assert!(grown >= WRITER_THREADS as u64 * ALLOCS_PER_WRITER * BYTES_PER_ALLOC);
    }
}