use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
/// Maximum number of threads that can hold a stats slot simultaneously;
/// sizes the global registry created by `memory_registry()`.
const MAX_THREADS: usize = 64;
/// Per-thread slot holding the live and peak arena byte counts.
///
/// A `thread_id` of 0 marks the slot as unclaimed; real thread ids
/// handed out by `current_thread_id` start at 1.
#[derive(Debug)]
pub struct MemorySlot {
    pub thread_id: AtomicU64,
    pub arena_bytes: AtomicU64,
    pub peak_arena_bytes: AtomicU64,
}

impl MemorySlot {
    /// Builds a zeroed (unclaimed) slot. `const` so slots could also be
    /// placed in static storage.
    const fn new() -> Self {
        MemorySlot {
            peak_arena_bytes: AtomicU64::new(0),
            arena_bytes: AtomicU64::new(0),
            thread_id: AtomicU64::new(0),
        }
    }
}
/// Fixed-capacity, lock-free registry of per-thread memory statistics.
pub struct MemoryStatsRegistry {
// One slot per registered thread; allocated once, never resized.
slots: Box<[MemorySlot]>,
// Count of `register` calls that failed because every slot was taken.
pub overflow_count: AtomicU64,
}
impl MemoryStatsRegistry {
    /// Builds a registry with `capacity` pre-allocated, zeroed slots.
    fn new(capacity: usize) -> Self {
        let slots: Vec<MemorySlot> = (0..capacity).map(|_| MemorySlot::new()).collect();
        Self {
            slots: slots.into_boxed_slice(),
            overflow_count: AtomicU64::new(0),
        }
    }

    /// Atomically claims the first free slot for the calling thread.
    ///
    /// A slot is free while its `thread_id` is 0 (real ids start at 1).
    /// Returns the claimed index, or `None` — after bumping
    /// `overflow_count` — when every slot is occupied.
    ///
    /// Note: each call claims a fresh slot, so calling this twice from the
    /// same thread consumes two slots; callers are expected to cache the
    /// index (see `get_or_register_slot`).
    pub fn register(&self) -> Option<usize> {
        let thread_id = current_thread_id();
        // `position` short-circuits on the first successful claim, just
        // like the explicit loop it replaces.
        let claimed = self.slots.iter().position(|slot| {
            slot.thread_id
                .compare_exchange(0, thread_id, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
        });
        if claimed.is_none() {
            self.overflow_count.fetch_add(1, Ordering::Relaxed);
        }
        claimed
    }

    /// Records the current arena size for `slot_idx` and raises its peak
    /// if the new value exceeds it.
    ///
    /// Out-of-range indices are silently ignored so a stale index can
    /// never panic the stats path.
    #[inline]
    pub fn update_arena(&self, slot_idx: usize, arena_bytes: usize) {
        if let Some(slot) = self.slots.get(slot_idx) {
            let bytes = arena_bytes as u64;
            slot.arena_bytes.store(bytes, Ordering::Relaxed);
            // `fetch_max` atomically keeps the larger of the stored peak
            // and `bytes` — it replaces the manual compare_exchange_weak
            // loop, which reimplemented exactly this operation.
            slot.peak_arena_bytes.fetch_max(bytes, Ordering::Relaxed);
        }
    }

    /// Sums statistics across all currently-registered slots.
    ///
    /// The result is a best-effort snapshot: per-slot loads are relaxed,
    /// so values written concurrently may be only partially reflected.
    pub fn aggregate_stats(&self) -> AggregateMemoryStats {
        let mut total_arena_bytes: u64 = 0;
        let mut total_peak_arena_bytes: u64 = 0;
        let mut active_threads: usize = 0;
        for slot in self.slots.iter() {
            // Acquire pairs with the AcqRel claim in `register`, so a
            // nonzero id means the slot is (or was) owned by a thread.
            let thread_id = slot.thread_id.load(Ordering::Acquire);
            if thread_id > 0 {
                active_threads += 1;
                total_arena_bytes += slot.arena_bytes.load(Ordering::Relaxed);
                total_peak_arena_bytes += slot.peak_arena_bytes.load(Ordering::Relaxed);
            }
        }
        AggregateMemoryStats {
            active_threads,
            total_arena_bytes,
            total_peak_arena_bytes,
            overflow_count: self.overflow_count.load(Ordering::Relaxed),
        }
    }

    /// Number of slots this registry was created with.
    pub fn capacity(&self) -> usize {
        self.slots.len()
    }
}
/// Snapshot of registry-wide totals returned by
/// `MemoryStatsRegistry::aggregate_stats`.
#[derive(Debug, Clone, Copy)]
pub struct AggregateMemoryStats {
// Number of slots with a nonzero `thread_id` at snapshot time.
pub active_threads: usize,
// Sum of `arena_bytes` over active slots.
pub total_arena_bytes: u64,
// Sum of per-thread peaks (not the peak of the sum).
pub total_peak_arena_bytes: u64,
// Failed registrations observed so far.
pub overflow_count: u64,
}
/// Monotonic source of process-unique thread ids. Starts at 1 because 0
/// is reserved to mean "unclaimed slot" in the registry.
static NEXT_THREAD_ID: AtomicU64 = AtomicU64::new(1);

thread_local! {
    // Each thread draws its id exactly once, on first access.
    static THIS_THREAD_ID: u64 = NEXT_THREAD_ID.fetch_add(1, Ordering::Relaxed);
}

/// Returns this thread's stable, process-unique id (never 0).
fn current_thread_id() -> u64 {
    THIS_THREAD_ID.with(|id| *id)
}
// Lazily-initialized global registry shared by every thread in the process.
static MEMORY_REGISTRY: OnceLock<MemoryStatsRegistry> = OnceLock::new();
/// Returns the process-wide registry, creating it with `MAX_THREADS`
/// slots on first use.
pub fn memory_registry() -> &'static MemoryStatsRegistry {
MEMORY_REGISTRY.get_or_init(|| MemoryStatsRegistry::new(MAX_THREADS))
}
thread_local! {
// Cached registry slot index for this thread; stays `None` until the
// thread registers successfully (or while the registry is full).
static SLOT_INDEX: std::cell::Cell<Option<usize>> = const { std::cell::Cell::new(None) };
}
/// Returns this thread's registry slot index, registering on first call.
///
/// The index is cached in `SLOT_INDEX`, so the registry is scanned at
/// most once per thread while registration succeeds. If the registry is
/// full the cache stays `None` and later calls retry registration.
pub fn get_or_register_slot() -> Option<usize> {
    SLOT_INDEX.with(|cell| {
        cell.get().or_else(|| {
            let fresh = memory_registry().register();
            cell.set(fresh);
            fresh
        })
    })
}
/// Publishes the calling thread's current arena size to its registry slot.
///
/// A no-op when the thread has no cached slot (never registered, or the
/// registry was full) — this function never registers on its own.
#[inline]
pub fn update_arena_stats(arena_bytes: usize) {
    let registered = SLOT_INDEX.with(std::cell::Cell::get);
    if let Some(slot_idx) = registered {
        memory_registry().update_arena(slot_idx, arena_bytes);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_basic() {
        let registry = MemoryStatsRegistry::new(4);
        let slot = registry.register();
        assert!(slot.is_some());
        let idx = slot.unwrap();
        registry.update_arena(idx, 1024);
        let stats = registry.aggregate_stats();
        assert_eq!(stats.active_threads, 1);
        assert_eq!(stats.total_arena_bytes, 1024);
    }

    #[test]
    fn test_registry_overflow() {
        let registry = MemoryStatsRegistry::new(2);
        // Direct `register` calls each claim a fresh slot, even from one
        // thread, so two calls fill a capacity-2 registry.
        assert!(registry.register().is_some());
        assert!(registry.register().is_some());
        assert_eq!(registry.overflow_count.load(Ordering::Relaxed), 0);
        // The registry is now full: the next call must fail and the
        // failure must be counted. (The original test stopped short of
        // this and never exercised the overflow path it is named for.)
        assert!(registry.register().is_none());
        assert_eq!(registry.overflow_count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_thread_local_slot() {
        // Repeated calls from the same thread must return the cached index.
        let slot1 = get_or_register_slot();
        let slot2 = get_or_register_slot();
        assert_eq!(slot1, slot2);
    }

    #[test]
    fn test_update_helpers() {
        let slot = get_or_register_slot();
        if slot.is_some() {
            update_arena_stats(2048);
            // `>=` because other tests share the global registry and may
            // have contributed their own arena bytes.
            let stats = memory_registry().aggregate_stats();
            assert!(stats.total_arena_bytes >= 2048);
        }
    }

    #[test]
    fn test_concurrent_registration() {
        use std::thread;
        let handles: Vec<_> = (0..4)
            .map(|i| {
                thread::spawn(move || {
                    let slot = get_or_register_slot();
                    if slot.is_some() {
                        update_arena_stats(1000 * (i + 1));
                    }
                    slot.is_some()
                })
            })
            .collect();
        let mut registered_count = 0;
        for h in handles {
            if h.join().unwrap() {
                registered_count += 1;
            }
        }
        // Slots are never released, so every thread that registered here
        // (plus any from other tests) still counts as active.
        let stats = memory_registry().aggregate_stats();
        assert!(stats.active_threads >= registered_count);
    }

    #[test]
    fn test_thread_ids_are_unique() {
        use std::collections::HashSet;
        use std::sync::{Arc, Mutex};
        use std::thread;
        let ids = Arc::new(Mutex::new(HashSet::new()));
        let handles: Vec<_> = (0..8)
            .map(|_| {
                let ids = Arc::clone(&ids);
                thread::spawn(move || {
                    let id = current_thread_id();
                    ids.lock().unwrap().insert(id);
                    id
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        let unique_count = ids.lock().unwrap().len();
        assert_eq!(unique_count, 8, "Thread IDs should be unique");
    }
}