use core::cell::Cell;
use core::sync::atomic::{AtomicI64, Ordering};
use std::thread_local;
/// Process-wide running total; each thread adds its unflushed delta here when
/// its counters are flushed.
static GLOBAL_ALLOCATED: AtomicI64 = AtomicI64::new(0);
/// Largest positive value of the global total observed by `update_global_peak`.
static GLOBAL_PEAK: AtomicI64 = AtomicI64::new(0);
/// Threshold used until `set_thread_flush_threshold` overrides it (1024 * 1024).
const DEFAULT_THREAD_FLUSH_THRESHOLD: i64 = 1024 * 1024;
/// Magnitude of unflushed per-thread change that triggers an automatic flush.
/// Values `<= 0` disable automatic flushing.
static THREAD_FLUSH_THRESHOLD: AtomicI64 = AtomicI64::new(DEFAULT_THREAD_FLUSH_THRESHOLD);
thread_local! {
/// Per-thread allocation counters; const-initialized to all zeros.
static THREAD_COUNTERS: ThreadAllocationCounters = const { ThreadAllocationCounters::new() };
/// Set to `true` whenever this thread publishes a non-zero delta to the global
/// total; read and cleared by `take_thread_flushed_since_check_flag`.
static THREAD_FLUSHED_SINCE_CHECK: Cell<bool> = const { Cell::new(false) };
}
/// Snapshot of the process-wide allocation counters.
///
/// Both fields are produced from signed counters and clamped at zero, so they
/// are never negative.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct GlobalAllocationStats {
/// Global running total at snapshot time (0 if the signed total was negative).
pub allocated: u64,
/// Highest global total recorded so far (0 if none was positive).
pub peak: u64,
}
/// Snapshot of the calling thread's allocation counters.
///
/// Values are signed: the thread total is decremented on every recorded free
/// without a floor at zero, so it can become negative.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ThreadAllocationStats {
/// Net total recorded on this thread (allocations minus frees).
pub allocated: i64,
/// Highest value `allocated` reached immediately after a recorded allocation.
pub peak: i64,
}
/// Thread-local bookkeeping behind `THREAD_COUNTERS`.
///
/// Uses `Cell` fields so it can be updated through a shared reference.
struct ThreadAllocationCounters {
/// Net total recorded on this thread (saturating arithmetic).
allocated: Cell<i64>,
/// Highest value of `allocated` seen right after an allocation.
peak: Cell<i64>,
/// Value of `allocated` at the last flush; the difference between the two is
/// the delta not yet published to the global total.
last_flushed: Cell<i64>,
}
impl ThreadAllocationCounters {
    /// Creates a counter set with every value at zero.
    ///
    /// `const` so it can be used as a `thread_local!` const initializer.
    const fn new() -> Self {
        Self {
            allocated: Cell::new(0),
            peak: Cell::new(0),
            last_flushed: Cell::new(0),
        }
    }

    /// Records an allocation of `size` on this thread.
    ///
    /// The running total saturates at `i64::MAX`, the thread peak is raised if
    /// the new total exceeds it, and an automatic flush is attempted.
    /// A `size` of zero is ignored.
    fn on_alloc(&self, size: i64) {
        if size == 0 {
            return;
        }
        let new_total = self.allocated.get().saturating_add(size);
        self.allocated.set(new_total);
        self.peak.set(self.peak.get().max(new_total));
        self.flush_if_threshold_exceeded();
    }

    /// Records a free of `size` on this thread.
    ///
    /// The running total saturates at `i64::MIN` and may go negative; the peak
    /// is left untouched. A `size` of zero is ignored.
    fn on_free(&self, size: i64) {
        if size == 0 {
            return;
        }
        let new_total = self.allocated.get().saturating_sub(size);
        self.allocated.set(new_total);
        self.flush_if_threshold_exceeded();
    }

    /// Returns the current thread total and thread peak.
    fn snapshot(&self) -> ThreadAllocationStats {
        ThreadAllocationStats {
            allocated: self.allocated.get(),
            peak: self.peak.get(),
        }
    }

    /// Publishes the change since the previous flush to `GLOBAL_ALLOCATED`.
    ///
    /// Returns `Some((global_total, global_peak))` when a non-zero delta was
    /// published (and marks `THREAD_FLUSHED_SINCE_CHECK`), or `None` when there
    /// was nothing to publish.
    fn flush(&self) -> Option<(i64, i64)> {
        let thread_total = self.allocated.get();
        let last_flushed_total = self.last_flushed.replace(thread_total);
        // Wrapping on purpose: `fetch_add` wraps as well, so publishing the
        // modular difference keeps the global total exact modulo 2^64 and
        // never panics under overflow checks, even at the saturation limits.
        let delta = thread_total.wrapping_sub(last_flushed_total);
        if delta == 0 {
            return None;
        }
        // `fetch_add` returns the previous value; recompute the stored value
        // with the same wrapping semantics the atomic itself used.
        let global_total = GLOBAL_ALLOCATED
            .fetch_add(delta, Ordering::Relaxed)
            .wrapping_add(delta);
        let peak = update_global_peak(global_total);
        THREAD_FLUSHED_SINCE_CHECK.with(|flag| flag.set(true));
        Some((global_total, peak))
    }

    /// Flushes when the unflushed change has reached `THREAD_FLUSH_THRESHOLD`
    /// in either direction. A threshold `<= 0` disables automatic flushing.
    fn flush_if_threshold_exceeded(&self) {
        let threshold = THREAD_FLUSH_THRESHOLD.load(Ordering::Relaxed);
        if threshold <= 0 {
            return;
        }
        // Compare magnitudes as unsigned values so `i64::MIN` is handled too.
        if self.pending_delta().unsigned_abs() >= threshold.unsigned_abs() {
            let _ = self.flush();
        }
    }

    /// Returns the change accumulated since the last flush.
    ///
    /// The difference saturates instead of overflowing, so the sign is always
    /// meaningful even when both totals sit at opposite `i64` limits.
    fn pending_delta(&self) -> i64 {
        self.allocated.get().saturating_sub(self.last_flushed.get())
    }
}
/// Records an allocation of `size` against the calling thread's counters.
///
/// Zero-sized allocations are ignored. Sizes that do not fit in an `i64` are
/// clamped to `i64::MAX` before being forwarded.
pub fn record_alloc(size: usize) {
    if size != 0 {
        let clamped = if size > i64::MAX as usize {
            i64::MAX
        } else {
            size as i64
        };
        THREAD_COUNTERS.with(|tls| tls.on_alloc(clamped));
    }
}
/// Records a free of `size` against the calling thread's counters.
///
/// Zero-sized frees are ignored. Sizes that do not fit in an `i64` are clamped
/// to `i64::MAX` before being forwarded.
pub fn record_free(size: usize) {
    if size != 0 {
        let clamped = if size > i64::MAX as usize {
            i64::MAX
        } else {
            size as i64
        };
        THREAD_COUNTERS.with(|tls| tls.on_free(clamped));
    }
}
/// Publishes the calling thread's unflushed delta to the global counters.
///
/// The totals reported by the flush are not needed here and are discarded.
pub fn flush_thread_counters() {
    let _ = THREAD_COUNTERS.with(|tls| tls.flush());
}
/// Returns whether this thread published a delta since the previous call, and
/// resets the flag to `false`.
pub fn take_thread_flushed_since_check_flag() -> bool {
    // NOTE(review): the outer access to THREAD_COUNTERS does not use its value;
    // it is kept so the sequence of thread-local accesses is unchanged.
    THREAD_COUNTERS.with(|_counters| {
        // `replace` hands back the old value and leaves `false` behind.
        THREAD_FLUSHED_SINCE_CHECK.with(|flag| flag.replace(false))
    })
}
pub fn allocation_stats_snapshot() -> (GlobalAllocationStats, ThreadAllocationStats) {
let (thread_stats, global_total, global_peak) = THREAD_COUNTERS.with(|counters| {
let (global_total, global_peak) = match counters.flush() {
Some((total, peak)) => (total, peak),
None => load_global_counters(),
};
(counters.snapshot(), global_total, global_peak)
});
let global_stats = GlobalAllocationStats {
allocated: global_total.max(0) as u64,
peak: global_peak.max(0) as u64,
};
(global_stats, thread_stats)
}
/// Returns the global half of `allocation_stats_snapshot`.
///
/// Because it goes through `allocation_stats_snapshot`, the calling thread's
/// counters are flushed first.
pub fn global_allocation_stats_snapshot() -> GlobalAllocationStats {
    let (global, _thread) = allocation_stats_snapshot();
    global
}
/// Returns the calling thread's half of `allocation_stats_snapshot`.
///
/// Because it goes through `allocation_stats_snapshot`, the calling thread's
/// counters are flushed as a side effect.
pub fn current_thread_allocation_stats() -> ThreadAllocationStats {
    let (_global, thread) = allocation_stats_snapshot();
    thread
}
pub fn thread_allocation_pending_delta() -> i64 {
THREAD_COUNTERS.with(|counters| counters.pending_delta())
}
/// Raises `GLOBAL_PEAK` to `candidate` if it is larger, and returns the peak
/// in effect afterwards.
///
/// Non-positive candidates never modify the peak; for them the stored peak is
/// returned, floored at zero.
fn update_global_peak(candidate: i64) -> i64 {
    if candidate <= 0 {
        return GLOBAL_PEAK.load(Ordering::Relaxed).max(0);
    }
    // Retry until either our candidate is installed or another thread has
    // already stored something at least as large.
    let outcome = GLOBAL_PEAK.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
        (candidate > current).then_some(candidate)
    });
    match outcome {
        Ok(_previous) => candidate,
        // The stored peak was already >= candidate, so it is the answer.
        Err(current) => current,
    }
}
/// Reads the global total and returns it with the global peak.
///
/// The peak comes from `update_global_peak`, so a total above the stored peak
/// raises the peak as part of this call.
fn load_global_counters() -> (i64, i64) {
    let current = GLOBAL_ALLOCATED.load(Ordering::Relaxed);
    (current, update_global_peak(current))
}
/// Sets the per-thread automatic flush threshold.
///
/// `None` restores `DEFAULT_THREAD_FLUSH_THRESHOLD`. `Some(bytes)` stores
/// `bytes`, clamped to `i64::MAX`. `Some(0)` stores zero, which
/// `thread_flush_threshold` reports as `None`.
pub fn set_thread_flush_threshold(bytes: Option<u64>) {
    let threshold = bytes.map_or(DEFAULT_THREAD_FLUSH_THRESHOLD, |requested| {
        if requested > i64::MAX as u64 {
            i64::MAX
        } else {
            requested as i64
        }
    });
    THREAD_FLUSH_THRESHOLD.store(threshold, Ordering::Relaxed);
}
/// Returns the current automatic flush threshold, or `None` when the stored
/// value is zero or negative.
pub fn thread_flush_threshold() -> Option<u64> {
    match THREAD_FLUSH_THRESHOLD.load(Ordering::Relaxed) {
        positive if positive > 0 => Some(positive as u64),
        _ => None,
    }
}