#[cfg(feature = "fallback")]
use std::alloc::{GlobalAlloc, Layout, System};
use std::ptr::NonNull;
use std::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};
use crate::config::SECURE_WIPE_PATTERN;
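/// A lock-free bump allocator over a caller-provided memory region.
///
/// Allocation is a single compare-and-swap on `cursor`. Individual
/// allocations cannot be freed; the whole arena is reclaimed at once via
/// `reset` or `secure_reset`.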
pub struct BumpAlloc {
base: NonNull<u8>,
limit: NonNull<u8>,
cursor: AtomicUsize,
is_recycled: AtomicBool,
#[cfg(feature = "fallback")]
fallback_count: AtomicUsize,
#[cfg(feature = "fallback")]
fallback_bytes: AtomicUsize,
}
impl BumpAlloc {
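/// Creates an allocator over the region `[base, base + size)`.
///
/// # Safety
///
/// `base` must be non-null and valid for reads and writes of `size` bytes
/// for the lifetime of the returned allocator, and `base + size` must stay
/// within the same allocated object (so the end pointer is well defined).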
#[inline]
pub unsafe fn new(base: *mut u8, size: usize) -> Self {
debug_assert!(!base.is_null());
debug_assert!(size > 0);
let base_nn = NonNull::new_unchecked(base);
let limit_nn = NonNull::new_unchecked(base.add(size));
Self {
base: base_nn,
limit: limit_nn,
cursor: AtomicUsize::new(base as usize),
is_recycled: AtomicBool::new(false),
#[cfg(feature = "fallback")]
fallback_count: AtomicUsize::new(0),
#[cfg(feature = "fallback")]
fallback_bytes: AtomicUsize::new(0),
}
}
#[inline]
pub fn base_ptr(&self) -> *mut u8 {
self.base.as_ptr()
}
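/// Bump-allocates `size` bytes aligned to `align`.
///
/// Returns null if `size` is zero, `align` is zero or not a power of two,
/// or the arena is exhausted (unless the `fallback` feature routes the
/// request to the system allocator instead).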
#[inline(always)]
pub fn alloc(&self, size: usize, align: usize) -> *mut u8 {
if size == 0 || align == 0 || !align.is_power_of_two() {
return std::ptr::null_mut();
}
loop {
let current = self.cursor.load(Ordering::Relaxed);
let aligned = match current.checked_add(align - 1) {
Some(v) => v & !(align - 1),
None => return self.handle_exhaustion(size, align),
};
let next = match aligned.checked_add(size) {
Some(v) => v,
None => return self.handle_exhaustion(size, align),
};
if next > self.limit.as_ptr() as usize {
return self.handle_exhaustion(size, align);
}
if self
.cursor
.compare_exchange_weak(current, next, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
return aligned as *mut u8;
}
}
}
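// Slow path: the bump region cannot satisfy the request (exhausted or the
// arithmetic overflowed). With the `fallback` feature this defers to the
// system allocator; otherwise it reports failure with a null pointer.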
#[cold]
#[inline(never)]
fn handle_exhaustion(&self, size: usize, align: usize) -> *mut u8 {
#[cfg(debug_assertions)]
{
eprintln!(
"[nalloc] Arena exhausted: requested {} bytes (align {}), remaining {} bytes",
size,
align,
self.remaining()
);
}
#[cfg(feature = "fallback")]
{
let layout = match Layout::from_size_align(size, align) {
Ok(l) => l,
Err(_) => return std::ptr::null_mut(),
};
let ptr = unsafe { System.alloc(layout) };
if !ptr.is_null() {
self.fallback_count.fetch_add(1, Ordering::Relaxed);
self.fallback_bytes.fetch_add(size, Ordering::Relaxed);
#[cfg(debug_assertions)]
eprintln!("[nalloc] Fallback allocation: {} bytes", size);
}
ptr
}
#[cfg(not(feature = "fallback"))]
{
std::ptr::null_mut()
}
}
#[inline]
pub fn is_recycled(&self) -> bool {
self.is_recycled.load(Ordering::Acquire)
}
#[cfg(feature = "fallback")]
#[inline]
pub fn fallback_count(&self) -> usize {
self.fallback_count.load(Ordering::Relaxed)
}
#[cfg(feature = "fallback")]
#[inline]
pub fn fallback_bytes(&self) -> usize {
self.fallback_bytes.load(Ordering::Relaxed)
}
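/// Rewinds the cursor to the start of the arena and marks it recycled.
///
/// # Safety
///
/// All pointers previously handed out by `alloc` must no longer be in use;
/// the memory they point to may be handed out again after the reset.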
#[inline]
pub unsafe fn reset(&self) {
self.cursor
.store(self.base.as_ptr() as usize, Ordering::SeqCst);
self.is_recycled.store(true, Ordering::Release);
#[cfg(feature = "fallback")]
{
self.fallback_count.store(0, Ordering::Relaxed);
self.fallback_bytes.store(0, Ordering::Relaxed);
}
}
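/// Overwrites the entire arena with `SECURE_WIPE_PATTERN` before resetting,
/// so previously allocated contents cannot be read back.
///
/// # Safety
///
/// Same contract as `reset`.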
#[inline]
pub unsafe fn secure_reset(&self) {
let base = self.base.as_ptr();
let size = self.limit.as_ptr() as usize - base as usize;
Self::volatile_memset(base, SECURE_WIPE_PATTERN, size);
fence(Ordering::SeqCst);
self.reset();
}
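// Wipes `len` bytes in a way the optimizer must not elide: prefers a
// platform primitive with that guarantee where one exists, otherwise falls
// back to volatile stores.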
#[inline(never)]
#[allow(unreachable_code)]
unsafe fn volatile_memset(ptr: *mut u8, value: u8, len: usize) {
#[cfg(any(target_os = "linux", target_os = "android"))]
if value == 0 {
extern "C" {
fn explicit_bzero(s: *mut libc::c_void, n: libc::size_t);
}
explicit_bzero(ptr as *mut libc::c_void, len);
return;
}
#[cfg(target_vendor = "apple")]
{
extern "C" {
fn memset_s(
s: *mut libc::c_void,
smax: libc::size_t,
c: libc::c_int,
n: libc::size_t,
) -> libc::c_int;
}
let _ = memset_s(ptr as *mut libc::c_void, len, value as libc::c_int, len);
return;
}
// Windows: RtlSecureZeroMemory (and the SecureZeroMemory macro) is a
// header-only inline in winnt.h, not a reliably exported symbol, so an
// `extern` declaration can fail to link. Fall through to the portable
// volatile loop below, which provides the same non-elidable guarantee.
// Portable fallback. `ptr` is not guaranteed to be word-aligned (the
// arena may come from a `Vec<u8>`), and a volatile write through an
// unaligned `*mut usize` is undefined behavior, so write unaligned head
// bytes first, then whole words, then the tail bytes.
let word = std::mem::size_of::<usize>();
let mut offset = 0usize;
while offset < len && (ptr.add(offset) as usize) % word != 0 {
std::ptr::write_volatile(ptr.add(offset), value);
offset += 1;
}
// Broadcast the byte pattern into a full machine word.
let mut pattern = 0usize;
for i in 0..word {
pattern |= (value as usize) << (i * 8);
}
let full_words = (len - offset) / word;
let word_ptr = ptr.add(offset) as *mut usize;
for i in 0..full_words {
std::ptr::write_volatile(word_ptr.add(i), pattern);
}
offset += full_words * word;
while offset < len {
std::ptr::write_volatile(ptr.add(offset), value);
offset += 1;
}
}
#[inline]
pub fn capacity(&self) -> usize {
self.limit.as_ptr() as usize - self.base.as_ptr() as usize
}
#[inline]
pub fn used(&self) -> usize {
self.cursor.load(Ordering::Relaxed) - self.base.as_ptr() as usize
}
#[inline]
pub fn remaining(&self) -> usize {
self.capacity() - self.used()
}
}
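// SAFETY: the arena is only mutated through atomics (`cursor`,
// `is_recycled`) or inside `unsafe` methods whose contracts make the caller
// responsible for exclusivity, so sharing across threads is sound.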
unsafe impl Send for BumpAlloc {}
unsafe impl Sync for BumpAlloc {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_nonnull_safety() {
let mut buffer = vec![0u8; 1024];
let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };
assert_eq!(alloc.capacity(), 1024);
assert_eq!(alloc.used(), 0);
assert_eq!(alloc.remaining(), 1024);
assert!(!alloc.is_recycled());
}
#[test]
fn test_recycled_flag() {
let mut buffer = vec![0u8; 1024];
let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };
assert!(!alloc.is_recycled());
let _ = alloc.alloc(64, 8);
assert!(!alloc.is_recycled());
unsafe { alloc.reset() };
assert!(alloc.is_recycled());
}
#[test]
fn test_secure_reset_wipes_memory() {
let mut buffer = vec![0xFFu8; 1024];
let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };
let ptr = alloc.alloc(512, 8);
assert!(!ptr.is_null());
unsafe {
std::ptr::write_bytes(ptr, 0xAB, 512);
}
unsafe { alloc.secure_reset() };
for i in 0..1024 {
assert_eq!(buffer[i], SECURE_WIPE_PATTERN, "Byte {} not wiped", i);
}
}
#[test]
fn test_alignment() {
let mut buffer = vec![0u8; 4096];
let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };
for align_pow in 0..8 {
let align = 1usize << align_pow;
let ptr = alloc.alloc(64, align);
assert!(!ptr.is_null());
assert_eq!((ptr as usize) % align, 0, "Alignment {} failed", align);
}
}
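// `alloc` is specified to reject degenerate requests with null rather than
// panic; these checks pin that behavior down.
#[test]
fn test_rejects_degenerate_requests() {
let mut buffer = vec![0u8; 128];
let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };
assert!(alloc.alloc(0, 8).is_null(), "zero size must fail");
assert!(alloc.alloc(8, 0).is_null(), "zero align must fail");
assert!(alloc.alloc(8, 3).is_null(), "non-power-of-two align must fail");
assert_eq!(alloc.used(), 0, "rejected requests must not move the cursor");
}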
#[test]
#[cfg(feature = "fallback")]
fn test_fallback_allocation() {
let mut buffer = vec![0u8; 256];
let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };
let _ = alloc.alloc(256, 1);
let ptr = alloc.alloc(64, 8);
assert!(!ptr.is_null(), "Fallback allocation should succeed");
assert!(alloc.fallback_count() > 0, "Fallback count should increase");
assert!(alloc.fallback_bytes() >= 64, "Fallback bytes should track");
unsafe {
System.dealloc(ptr, Layout::from_size_align(64, 8).unwrap());
}
}
#[test]
#[cfg(not(feature = "fallback"))]
fn test_exhaustion_returns_null() {
let mut buffer = vec![0u8; 256];
let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };
let _ = alloc.alloc(256, 1);
let ptr = alloc.alloc(64, 8);
assert!(
ptr.is_null(),
"Should return null when exhausted without fallback"
);
}
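// Exercises the CAS loop in `alloc` from several threads at once. Assumes
// std::thread::scope (Rust 1.63+) so the borrows of `alloc` stay scoped.
#[test]
fn test_concurrent_alloc() {
let mut buffer = vec![0u8; 4096];
let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };
std::thread::scope(|s| {
for _ in 0..4 {
s.spawn(|| {
for _ in 0..8 {
let ptr = alloc.alloc(16, 8);
assert!(!ptr.is_null());
unsafe { ptr.write_bytes(0xEE, 16) };
}
});
}
});
// 4 threads x 8 allocations x 16 bytes; after the first allocation the
// cursor stays 8-aligned, so at most 7 bytes of padding are ever added.
assert!(alloc.used() >= 4 * 8 * 16);
assert!(alloc.used() <= 4 * 8 * 16 + 7);
}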
}