use core::{
alloc::{GlobalAlloc, Layout},
ops::{Deref, DerefMut},
};
use std::{
cell::UnsafeCell,
ptr::null_mut,
sync::atomic::{AtomicBool, Ordering},
};
/// A minimal test-and-test-and-set spin lock protecting a value of type `T`.
///
/// It has a `const` constructor so it can be used to initialise a `static` (as
/// `ALLOCATOR_DATA` does). It is not reentrant: locking it again on a thread that already
/// holds a guard spins forever.
struct Spinlock<T> {
    /// `true` while a [`SpinlockGuard`] holds the lock.
    lock: AtomicBool,
    /// The protected value; only accessed through a live [`SpinlockGuard`].
    data: UnsafeCell<T>,
}

/// Guard returned by [`Spinlock::lock`]; grants access to the protected value and
/// releases the lock when dropped.
struct SpinlockGuard<'a, T: 'a> {
    lock: &'a Spinlock<T>,
}

// SAFETY: at most one guard exists at a time, so sharing the lock between threads only
// ever hands exclusive access to `T` from one thread to another; that needs `T: Send`,
// the same bound `std::sync::Mutex` uses.
unsafe impl<T: Send> Sync for Spinlock<T> {}

// SAFETY: a shared `&SpinlockGuard<T>` only yields `&T` (through `Deref`), so sharing the
// guard between threads is sound exactly when `&T` may be shared, i.e. `T: Sync`.
// Without this impl the auto-derived bound would be `T: Send` (inherited from the
// `&Spinlock<T>` field), which would let two threads use e.g. a `Cell` concurrently.
unsafe impl<T: Sync> Sync for SpinlockGuard<'_, T> {}

impl<T> Spinlock<T> {
    /// Creates an unlocked spinlock holding `t`.
    pub const fn new(t: T) -> Spinlock<T> {
        Spinlock { lock: AtomicBool::new(false), data: UnsafeCell::new(t) }
    }

    /// Spins until the lock is acquired, then returns a guard giving exclusive access.
    #[inline]
    #[must_use]
    pub fn lock(&self) -> SpinlockGuard<'_, T> {
        loop {
            // `Acquire` on success pairs with the `Release` store in `unlock`, making the
            // previous holder's writes to `data` visible here.
            if self
                .lock
                .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
            {
                return SpinlockGuard { lock: self };
            }
            // Wait with plain loads until the lock looks free before retrying the
            // read-modify-write, to avoid contending on the cache line.
            while self.lock.load(Ordering::Relaxed) {
                std::hint::spin_loop();
            }
        }
    }

    /// Releases the lock.
    ///
    /// # Safety
    ///
    /// The caller must currently hold the lock and must not touch `data` afterwards; in
    /// this file that is only `SpinlockGuard::drop`. Unlocking otherwise would let another
    /// thread obtain a second, aliasing `&mut T`.
    #[inline]
    unsafe fn unlock(&self) {
        self.lock.store(false, Ordering::Release);
    }
}

impl<T> Deref for SpinlockGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: the guard holds the lock, so no other guard can be producing a `&mut T`.
        unsafe { &*self.lock.data.get() }
    }
}

impl<T> DerefMut for SpinlockGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: the guard holds the lock and `&mut self` rules out other borrows through
        // this guard, so this is the only live reference to the data.
        unsafe { &mut *self.lock.data.get() }
    }
}

impl<T> Drop for SpinlockGuard<'_, T> {
    fn drop(&mut self) {
        // SAFETY: this guard acquired the lock in `Spinlock::lock` and is dropped once.
        unsafe { self.lock.unlock() }
    }
}
/// Bookkeeping shared by every allocation routed through [`TrackingAllocator`].
///
/// The counters are signed: `current` is reset to zero by `start_tracking`, so freeing
/// memory that was allocated before tracking began can drive it below zero.
struct TrackingAllocatorData {
    /// Bytes attributed to the current session (allocations minus frees since the reset).
    current: isize,
    /// Highest value `current` has reached since the last reset.
    peak: isize,
    /// Byte limit checked against `peak`; `0` means "no limit".
    limit: isize,
    /// Optional callback run when an allocation is refused for exceeding `limit`.
    failure_handler: Option<Box<dyn Fn() + Send>>,
}

impl TrackingAllocatorData {
    /// Resets both counters, installs `limit` (0 = unlimited) and swaps in
    /// `failure_handler`.
    ///
    /// The previous handler is dropped only after `guard` is released: dropping a boxed
    /// closure may deallocate, and if this allocator is the global one that deallocation
    /// would spin forever on the lock `guard` still holds.
    fn start_tracking(
        mut guard: SpinlockGuard<Self>,
        limit: isize,
        failure_handler: Option<Box<dyn Fn() + Send>>,
    ) {
        guard.current = 0;
        guard.peak = 0;
        guard.limit = limit;
        let previous = core::mem::replace(&mut guard.failure_handler, failure_handler);
        drop(guard); // unlock first ...
        drop(previous); // ... then free the old handler
    }

    /// Clears the limit and the handler and returns the peak seen in this session.
    ///
    /// As in `start_tracking`, the removed handler is dropped after the lock is released.
    fn end_tracking(mut guard: SpinlockGuard<Self>) -> isize {
        guard.limit = 0;
        let (peak, previous) = (guard.peak, guard.failure_handler.take());
        drop(guard);
        drop(previous);
        peak
    }

    /// Adds `alloc` bytes (negative for frees/shrinks) to `current` and updates `peak`.
    ///
    /// Returns `None` when the request is within the limit (the guard is released here),
    /// or `Some(guard)` with the lock still held when a non-zero limit is exceeded by
    /// `peak`, so the caller can react while holding the lock.
    #[inline]
    fn track_and_check_limits(
        mut guard: SpinlockGuard<Self>,
        alloc: isize,
    ) -> Option<SpinlockGuard<Self>> {
        guard.current += alloc;
        guard.peak = guard.peak.max(guard.current);
        if guard.limit != 0 && guard.peak > guard.limit {
            Some(guard)
        } else {
            None
        }
    }
}

/// Process-wide tracking state shared by all `TrackingAllocator` instances.
static ALLOCATOR_DATA: Spinlock<TrackingAllocatorData> = Spinlock::new(TrackingAllocatorData {
    current: 0,
    peak: 0,
    limit: 0,
    failure_handler: None,
});
/// A [`GlobalAlloc`] wrapper that forwards to the inner allocator `A` while counting the
/// bytes it hands out, optionally refusing allocations above a byte limit.
///
/// All instances share one process-wide set of counters (`ALLOCATOR_DATA`).
pub struct TrackingAllocator<A: GlobalAlloc>(pub A);

impl<A: GlobalAlloc> TrackingAllocator<A> {
    /// Starts a tracking session.
    ///
    /// Resets the current and peak byte counts, installs `limit` and replaces any
    /// previously installed `failure_handler`.
    ///
    /// `None` and `Some(0)` both mean "no limit". With a limit set, a request that
    /// pushes the session's peak above it is refused: the handler, if any, is called and
    /// null is returned. Because the peak never goes down within a session, every later
    /// request is refused as well.
    ///
    /// # Safety
    ///
    /// NOTE(review): the original code does not state why this is `unsafe`. From the
    /// visible code, the handler runs inside the allocator while its non-reentrant lock
    /// is held. It therefore presumably must not allocate or free memory, which would
    /// deadlock. It must also not unwind, since `GlobalAlloc` forbids unwinding.
    /// TODO confirm the intended contract.
    pub unsafe fn start_tracking(
        &self,
        limit: Option<isize>,
        failure_handler: Option<Box<dyn Fn() + Send>>,
    ) {
        let guard = ALLOCATOR_DATA.lock();
        let limit = match limit {
            Some(bytes) => bytes,
            None => 0,
        };
        TrackingAllocatorData::start_tracking(guard, limit, failure_handler);
    }

    /// Ends the tracking session: clears the limit and the failure handler and returns the
    /// peak byte count observed since `start_tracking`.
    pub fn end_tracking(&self) -> isize {
        let guard = ALLOCATOR_DATA.lock();
        TrackingAllocatorData::end_tracking(guard)
    }
}
/// Cold path for a request refused because the tracking limit is exceeded.
///
/// It undoes the request's accounting, runs the failure handler (if any) and returns null.
///
/// `delta` is the byte change that `track_and_check_limits` just added to `current`. It is
/// subtracted again because the refused memory never becomes live, so no `dealloc` will
/// ever arrive to balance it. `peak` is left untouched, so once the limit has tripped it
/// stays tripped for the rest of the tracking session.
///
/// The handler is called while `guard` still holds the non-reentrant allocator lock. It
/// must therefore not allocate or free memory, and it must not unwind out of the allocator.
#[cold]
#[inline(never)]
unsafe fn fail_allocation(mut guard: SpinlockGuard<TrackingAllocatorData>, delta: isize) -> *mut u8 {
    guard.current -= delta;
    if let Some(failure_handler) = &guard.failure_handler {
        failure_handler()
    }
    null_mut()
}

/// Undoes the accounting for a request that passed the limit check but that the inner
/// allocator could not satisfy (it returned null), so no memory became live.
#[cold]
#[inline(never)]
fn untrack_failed(delta: isize) {
    let mut guard = ALLOCATOR_DATA.lock();
    guard.current -= delta;
}

/// Each request is first accounted in `ALLOCATOR_DATA` and checked against the limit.
/// Only accepted requests reach the inner allocator, which is called after the lock has
/// been released.
///
/// A request that ends up returning null has its accounting undone, so `current` keeps
/// matching the bytes actually live. That covers both refusal by the limit and failure in
/// the inner allocator.
///
/// Casting sizes to `isize` is lossless: `Layout` sizes never exceed `isize::MAX`, and the
/// `realloc` contract requires the same of `new_size`.
unsafe impl<A: GlobalAlloc> GlobalAlloc for TrackingAllocator<A> {
    #[inline]
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        let delta = layout.size() as isize;
        let guard = ALLOCATOR_DATA.lock();
        if let Some(guard) = TrackingAllocatorData::track_and_check_limits(guard, delta) {
            return fail_allocation(guard, delta);
        }
        let ptr = self.0.alloc(layout);
        if ptr.is_null() {
            untrack_failed(delta);
        }
        ptr
    }

    #[inline]
    unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
        let delta = layout.size() as isize;
        let guard = ALLOCATOR_DATA.lock();
        if let Some(guard) = TrackingAllocatorData::track_and_check_limits(guard, delta) {
            return fail_allocation(guard, delta);
        }
        let ptr = self.0.alloc_zeroed(layout);
        if ptr.is_null() {
            untrack_failed(delta);
        }
        ptr
    }

    #[inline]
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        let guard = ALLOCATOR_DATA.lock();
        // Frees are always honoured, so the limit verdict is ignored. Dropping the returned
        // guard (if any) releases the lock before the inner allocator is called.
        drop(TrackingAllocatorData::track_and_check_limits(guard, -(layout.size() as isize)));
        self.0.dealloc(ptr, layout)
    }

    #[inline]
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        let delta = (new_size as isize) - (layout.size() as isize);
        let guard = ALLOCATOR_DATA.lock();
        if let Some(guard) = TrackingAllocatorData::track_and_check_limits(guard, delta) {
            return fail_allocation(guard, delta);
        }
        let new_ptr = self.0.realloc(ptr, layout, new_size);
        if new_ptr.is_null() {
            // On failure the original block stays live at its old size.
            untrack_failed(delta);
        }
        new_ptr
    }
}