#[cfg(test)]
mod tests;
use std::alloc::{alloc, dealloc, handle_alloc_error, Layout};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use std::sync::OnceLock;
/// Running counters for the non-persistent allocation pool.
///
/// All fields are atomics so they can be updated without holding the
/// allocation-registry mutex.
#[derive(Debug, Default)]
struct MemoryStats {
    // Total bytes currently live (aligned sizes, as recorded by `pemalloc`).
    allocated: AtomicUsize,
    // High-water mark of `allocated`.
    peak: AtomicUsize,
    // Number of live allocations.
    count: AtomicUsize,
}
/// Global usage statistics for the non-persistent pool. Persistent
/// allocations are tracked only in their registry and keep no counters.
static NON_PERSISTENT_STATS: MemoryStats = MemoryStats {
    allocated: AtomicUsize::new(0),
    peak: AtomicUsize::new(0),
    count: AtomicUsize::new(0),
};
/// Lazily-initialized registry of live persistent allocations,
/// keyed by pointer address with the aligned size as the value.
fn persistent_allocs() -> &'static Mutex<HashMap<usize, usize>> {
    static ALLOCS: OnceLock<Mutex<HashMap<usize, usize>>> = OnceLock::new();
    ALLOCS.get_or_init(Mutex::default)
}
/// Lazily-initialized registry of live non-persistent allocations,
/// keyed by pointer address with the aligned size as the value.
fn non_persistent_allocs() -> &'static Mutex<HashMap<usize, usize>> {
    static ALLOCS: OnceLock<Mutex<HashMap<usize, usize>>> = OnceLock::new();
    ALLOCS.get_or_init(Mutex::default)
}
/// All pool allocations are 8-byte aligned.
const ALIGNMENT: usize = 8;

/// Rounds `size` up to the next multiple of [`ALIGNMENT`].
///
/// Saturates to `usize::MAX` when rounding would overflow: the previous
/// wrapping arithmetic silently produced a tiny size for requests near
/// `usize::MAX`, which would lead to an under-sized allocation.
/// `usize::MAX` can never form a valid `Layout`, so the caller's error
/// path reports the failure instead.
fn align_size(size: usize) -> usize {
    size.checked_add(ALIGNMENT - 1)
        .map_or(usize::MAX, |s| s & !(ALIGNMENT - 1))
}
/// Allocates `size` bytes (rounded up to [`ALIGNMENT`]) and records the
/// block in the persistent or non-persistent registry.
///
/// Non-persistent allocations also update the usage counters in
/// [`NON_PERSISTENT_STATS`]. Aborts via `handle_alloc_error` on failure,
/// so the returned pointer is never null.
///
/// # Safety
/// The returned pointer must be released with `pefree` (or resized with
/// `perealloc`) using the same `persistent` flag.
pub unsafe fn pemalloc(size: usize, persistent: bool) -> *mut u8 {
    // `alloc` with a zero-sized layout is undefined behavior, so a
    // zero-byte request is rounded up to one alignment unit.
    let aligned_size = align_size(size.max(1));
    let layout = match Layout::from_size_align(aligned_size, ALIGNMENT) {
        Ok(l) => l,
        // The request cannot be described by a valid Layout. Report via a
        // *valid* placeholder layout: constructing an invalid layout with
        // `from_size_align_unchecked` (as the old code did) is itself UB.
        Err(_) => handle_alloc_error(Layout::new::<usize>()),
    };
    // SAFETY: `layout` has non-zero size (>= ALIGNMENT) and was validated above.
    let ptr = unsafe { alloc(layout) };
    if ptr.is_null() {
        handle_alloc_error(layout);
    }
    if persistent {
        persistent_allocs()
            .lock()
            .unwrap()
            .insert(ptr as usize, aligned_size);
    } else {
        non_persistent_allocs()
            .lock()
            .unwrap()
            .insert(ptr as usize, aligned_size);
        let current = NON_PERSISTENT_STATS
            .allocated
            .fetch_add(aligned_size, Ordering::Relaxed)
            + aligned_size;
        NON_PERSISTENT_STATS.count.fetch_add(1, Ordering::Relaxed);
        // `fetch_max` closes the load/compare/store race the previous code
        // had when two threads raised the peak concurrently.
        NON_PERSISTENT_STATS.peak.fetch_max(current, Ordering::Relaxed);
    }
    ptr
}
pub unsafe fn perealloc(ptr: *mut u8, new_size: usize, persistent: bool) -> *mut u8 {
if ptr.is_null() {
return unsafe { pemalloc(new_size, persistent) };
}
let old_size = if persistent {
let mut allocs = persistent_allocs().lock().unwrap();
allocs.remove(&(ptr as usize)).unwrap_or(0)
} else {
let mut allocs = non_persistent_allocs().lock().unwrap();
allocs.remove(&(ptr as usize)).unwrap_or(0)
};
let aligned_new_size = align_size(new_size);
let aligned_old_size = align_size(old_size);
if aligned_new_size <= aligned_old_size && !persistent {
let mut allocs = non_persistent_allocs().lock().unwrap();
allocs.insert(ptr as usize, new_size);
let size_diff = aligned_old_size as isize - aligned_new_size as isize;
NON_PERSISTENT_STATS
.allocated
.fetch_sub(size_diff as usize, Ordering::Relaxed);
return ptr;
}
let new_ptr = unsafe { pemalloc(new_size, persistent) };
if !new_ptr.is_null() {
let copy_size = if aligned_new_size < aligned_old_size {
aligned_new_size
} else {
aligned_old_size
};
unsafe {
std::ptr::copy_nonoverlapping(ptr, new_ptr, copy_size);
}
}
unsafe { pefree(ptr, persistent) };
new_ptr
}
/// Releases a block previously returned by `pemalloc`/`perealloc`.
///
/// Null pointers and pointers absent from the corresponding registry are
/// silently ignored. Non-persistent frees also decrement the usage
/// counters.
///
/// # Safety
/// `ptr` must be null or a pointer obtained from `pemalloc`/`perealloc`
/// with the same `persistent` flag, and must not be used afterwards.
pub unsafe fn pefree(ptr: *mut u8, persistent: bool) {
    if ptr.is_null() {
        return;
    }
    let key = ptr as usize;
    if persistent {
        let mut registry = persistent_allocs().lock().unwrap();
        if let Some(size) = registry.remove(&key) {
            // SAFETY: `size` and ALIGNMENT were recorded when this block
            // was allocated, reproducing the layout handed to `alloc`.
            unsafe { dealloc(ptr, Layout::from_size_align_unchecked(size, ALIGNMENT)) };
        }
    } else {
        let mut registry = non_persistent_allocs().lock().unwrap();
        if let Some(size) = registry.remove(&key) {
            // SAFETY: `size` and ALIGNMENT were recorded when this block
            // was allocated, reproducing the layout handed to `alloc`.
            unsafe { dealloc(ptr, Layout::from_size_align_unchecked(size, ALIGNMENT)) };
            NON_PERSISTENT_STATS.allocated.fetch_sub(size, Ordering::Relaxed);
            NON_PERSISTENT_STATS.count.fetch_sub(1, Ordering::Relaxed);
        }
    }
}
/// Bytes currently live in the non-persistent pool (aligned sizes).
pub fn get_memory_usage() -> usize {
    let stats = &NON_PERSISTENT_STATS;
    stats.allocated.load(Ordering::Relaxed)
}
/// High-water mark of non-persistent memory usage, in bytes.
pub fn get_peak_memory_usage() -> usize {
    let stats = &NON_PERSISTENT_STATS;
    stats.peak.load(Ordering::Relaxed)
}
/// Number of live non-persistent allocations.
pub fn get_allocation_count() -> usize {
    let stats = &NON_PERSISTENT_STATS;
    stats.count.load(Ordering::Relaxed)
}