use std::alloc::{GlobalAlloc, Layout};
use std::cell::Cell;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Once;
mod syscall;
/// Per-thread slab size in GiB, used when `ZK_ALLOC_SLAB_GB` is unset or does
/// not parse as a `usize`.
const DEFAULT_SLAB_GB: usize = 8;
/// Number of extra slabs reserved on top of the detected CPU count.
const SLACK: usize = 4;
/// Unit type carrying the `GlobalAlloc` implementation of this module.
///
/// All allocator state lives in module-level statics and thread-locals, so the
/// value itself holds no data.
#[derive(Debug)]
pub struct ZkAllocator;
/// Size in bytes of one per-thread slab; stays 0 until `ensure_region` has run.
static SLAB_SIZE: AtomicUsize = AtomicUsize::new(0);
/// Phase counter, incremented by `begin_phase`. A thread whose `ARENA_GEN`
/// differs from this value rewinds its bump pointer on its next arena allocation.
static GENERATION: AtomicUsize = AtomicUsize::new(0);
/// True between `begin_phase` and `end_phase`; `alloc` only uses the arena while set.
static ARENA_ACTIVE: AtomicBool = AtomicBool::new(false);
/// Base address of the mapped region (0 until `ensure_region` has run).
static REGION_BASE: AtomicUsize = AtomicUsize::new(0);
/// Length in bytes of the mapped region; together with `REGION_BASE` it lets
/// `dealloc`/`realloc` recognise arena pointers by address range.
static REGION_SIZE: AtomicUsize = AtomicUsize::new(0);
/// Ensures the one-time region setup in `ensure_region` runs only once.
static REGION_INIT: Once = Once::new();
/// Next slab index to hand out; each thread takes one on its first arena allocation.
static THREAD_IDX: AtomicUsize = AtomicUsize::new(0);
/// Number of slabs in the region (CPU count + `SLACK`); threads that draw an
/// index at or beyond this get no slab.
static MAX_THREADS: AtomicUsize = AtomicUsize::new(0);
/// Number of arena-eligible allocations that were served by the system allocator instead.
static OVERFLOW_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Total requested bytes of the allocations counted in `OVERFLOW_COUNT`.
static OVERFLOW_BYTES: AtomicUsize = AtomicUsize::new(0);
/// Default for `MIN_ARENA_BYTES`.
const DEFAULT_MIN_ARENA_BYTES: usize = 4096;
/// While a phase is active, allocations smaller than this many bytes bypass the
/// arena and go to the system allocator; 0 disables the threshold. Overridden
/// by `ZK_ALLOC_MIN_BYTES` during region setup.
static MIN_ARENA_BYTES: AtomicUsize = AtomicUsize::new(DEFAULT_MIN_ARENA_BYTES);
// Per-thread bump-allocator state. Addresses are stored as `usize`.
thread_local! {
// Next free address inside this thread's slab.
static ARENA_PTR: Cell<usize> = const { Cell::new(0) };
// One past the last usable address of this thread's slab.
static ARENA_END: Cell<usize> = const { Cell::new(0) };
// Start of this thread's slab; 0 means no slab has been assigned yet.
static ARENA_BASE: Cell<usize> = const { Cell::new(0) };
// Value of `GENERATION` at which this thread last rewound its bump pointer.
static ARENA_GEN: Cell<usize> = const { Cell::new(0) };
// Set once this thread drew a slab index beyond `MAX_THREADS`; it then never uses the arena.
static ARENA_NO_SLAB: Cell<bool> = const { Cell::new(false) };
}
/// Lazily maps the arena region and returns its base address.
///
/// The first call (serialised by `REGION_INIT`) does the setup:
/// * reads `ZK_ALLOC_SLAB_GB`, falling back to `DEFAULT_SLAB_GB` when the
///   variable is unset or does not parse;
/// * applies `ZK_ALLOC_MIN_BYTES` to `MIN_ARENA_BYTES` when it parses;
/// * sizes the region as one slab per available CPU plus `SLACK`;
/// * maps the region and publishes `SLAB_SIZE`, `MAX_THREADS`, `REGION_SIZE`
///   and `REGION_BASE`.
///
/// Every later call only loads `REGION_BASE`.
///
/// The process is aborted when the requested region size does not fit in a
/// `usize`, or when the mapping call returns null.
fn ensure_region() -> usize {
    const BYTES_PER_GIB: usize = 1 << 30;

    REGION_INIT.call_once(|| {
        let slab_gb = std::env::var("ZK_ALLOC_SLAB_GB")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(DEFAULT_SLAB_GB);
        let cpus = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(8);
        let max_threads = cpus + SLACK;

        // The slab size comes from the environment, so the arithmetic must be
        // checked: a wrapped product would map a region smaller than
        // `idx * slab_size`, putting later slabs outside the mapping.
        let sizes = slab_gb
            .checked_mul(BYTES_PER_GIB)
            .and_then(|slab| slab.checked_mul(max_threads).map(|region| (slab, region)));
        let (slab_size, region_size) = match sizes {
            Some(pair) => pair,
            // Same failure policy as an unsuccessful mapping below.
            None => std::process::abort(),
        };
        SLAB_SIZE.store(slab_size, Ordering::Release);

        if let Ok(s) = std::env::var("ZK_ALLOC_MIN_BYTES") {
            if let Ok(n) = s.parse::<usize>() {
                MIN_ARENA_BYTES.store(n, Ordering::Release);
            }
        }

        // NOTE(review): assumes `mmap_anonymous` reports failure as a null
        // pointer rather than the raw `MAP_FAILED` value — confirm in `syscall`.
        let ptr = unsafe { syscall::mmap_anonymous(region_size) };
        if ptr.is_null() {
            std::process::abort();
        }
        // Best-effort hint; the result is intentionally not inspected.
        unsafe { syscall::madvise(ptr, region_size, syscall::MADV_NOHUGEPAGE) };

        // `REGION_BASE` is stored last so that a non-zero base implies the
        // other geometry values were already written.
        MAX_THREADS.store(max_threads, Ordering::Release);
        REGION_SIZE.store(region_size, Ordering::Release);
        REGION_BASE.store(ptr as usize, Ordering::Release);
    });
    REGION_BASE.load(Ordering::Acquire)
}
/// Starts an arena phase: from now on eligible allocations are bump-allocated
/// from per-thread slabs, and every slab is rewound on its owner's next
/// arena allocation.
///
/// The generation is bumped *before* `ARENA_ACTIVE` is published. Otherwise a
/// thread allocating concurrently could see the phase as active while its
/// `ARENA_GEN` still matches the old generation, bump-allocate from the stale
/// pointer, and have that live block overwritten once the bump arrives and the
/// slab is rewound.
///
/// NOTE(review): `alloc` reads both atomics with `Relaxed` loads, so on weakly
/// ordered targets this narrows the window rather than formally closing it.
///
/// # Panics
///
/// Panics if a phase is already active; phases must not nest. In that case the
/// generation is left untouched so the running phase is not rewound (two
/// `begin_phase` calls racing on different threads can still bump it once
/// before the loser panics).
pub fn begin_phase() {
    const NESTED: &str =
        "begin_phase() called while another phase is already active — phases must not nest";

    ensure_region();
    // Reject nesting before touching GENERATION, so a misuse cannot rewind
    // slabs underneath the phase that is already running.
    assert!(!ARENA_ACTIVE.load(Ordering::Acquire), "{}", NESTED);
    GENERATION.fetch_add(1, Ordering::Release);
    // The swap re-checks atomically in case another begin_phase raced past the
    // load above.
    let prev_active = ARENA_ACTIVE.swap(true, Ordering::Release);
    assert!(!prev_active, "{}", NESTED);
}
/// Ends the current phase by clearing `ARENA_ACTIVE`, so subsequent
/// allocations go to the system allocator.
///
/// The slabs stay mapped and are not rewound here; rewinding happens lazily
/// after the next `begin_phase` bumps the generation. With the `rayon-flush`
/// feature enabled, `flush_rayon` runs afterwards.
pub fn end_phase() {
    ARENA_ACTIVE.store(false, Ordering::Release);

    #[cfg(feature = "rayon-flush")]
    {
        flush_rayon();
    }
}
/// Submits `FLUSH_JOBS` pairs of empty closures through `rayon::join`.
///
/// NOTE(review): presumably this gives rayon workers a chance to run after the
/// phase flag was cleared — confirm the intent and why 256 rounds are enough.
#[cfg(feature = "rayon-flush")]
fn flush_rayon() {
    const FLUSH_JOBS: usize = 256;
    (0..FLUSH_JOBS).for_each(|_| {
        rayon::join(|| {}, || {});
    });
}
/// Scope guard for an arena phase.
///
/// `PhaseGuard::new` calls `begin_phase`, and dropping the guard calls
/// `end_phase`. The private unit field prevents construction outside this
/// module, so a guard can only be obtained through `new`/`default`.
#[must_use = "dropping a PhaseGuard ends the phase immediately; bind it to a variable"]
pub struct PhaseGuard {
    _private: (),
}
impl PhaseGuard {
pub fn new() -> Self {
begin_phase();
Self { _private: () }
}
}
impl Default for PhaseGuard {
fn default() -> Self {
Self::new()
}
}
impl Drop for PhaseGuard {
    /// Ends the phase that was begun when this guard was created.
    fn drop(&mut self) {
        end_phase()
    }
}
/// Runs `f` inside an arena phase and returns its result.
///
/// The phase is held by a `PhaseGuard`, so it is ended both on normal return
/// and when `f` unwinds.
pub fn phase<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    let scope = PhaseGuard::new();
    let output = f();
    drop(scope);
    output
}
/// Returns `(count, bytes)` for allocations that were counted as arena
/// overflows, i.e. served by the system allocator from the cold path.
///
/// The two counters are read independently with relaxed ordering, so the pair
/// is not a consistent snapshot under concurrent allocation.
pub fn overflow_stats() -> (usize, usize) {
    let count = OVERFLOW_COUNT.load(Ordering::Relaxed);
    let bytes = OVERFLOW_BYTES.load(Ordering::Relaxed);
    (count, bytes)
}
/// Resets both overflow counters to zero (count first, then bytes).
pub fn reset_overflow_stats() {
    for counter in [&OVERFLOW_COUNT, &OVERFLOW_BYTES] {
        counter.store(0, Ordering::Relaxed);
    }
}
/// Returns the per-thread slab size in bytes.
///
/// This is 0 until the region has been set up by `ensure_region`.
pub fn slab_size() -> usize {
    let bytes = SLAB_SIZE.load(Ordering::Relaxed);
    bytes
}
/// Returns the current arena size threshold in bytes.
///
/// While a phase is active, allocations smaller than this go to the system
/// allocator; a value of 0 means no threshold is applied.
pub fn min_arena_bytes() -> usize {
    let threshold = MIN_ARENA_BYTES.load(Ordering::Relaxed);
    threshold
}
/// Slow path of the arena allocator, taken when the inline bump allocation in
/// `alloc` could not be used.
///
/// If this thread has a slab (or can still obtain one) and its generation is
/// stale, the slab is rewound to its base and the request is bump-allocated
/// from there. In every other case — no slab available for this thread, slab
/// exhausted in the current generation, or request larger than the slab — the
/// request is counted as an overflow and served by the system allocator.
///
/// # Safety
///
/// `size` and `align` must come from a valid `Layout` (in particular `align`
/// is a non-zero power of two); they are turned back into a `Layout` without
/// checks.
#[cold]
#[inline(never)]
unsafe fn arena_alloc_cold(size: usize, align: usize) -> *mut u8 {
    let generation = GENERATION.load(Ordering::Relaxed);
    if !ARENA_NO_SLAB.get() && ARENA_GEN.get() != generation {
        let mut base = ARENA_BASE.get();
        if base == 0 {
            // First arena allocation on this thread: claim the next slab index.
            let region = ensure_region();
            let max = MAX_THREADS.load(Ordering::Relaxed);
            let idx = THREAD_IDX.fetch_add(1, Ordering::Relaxed);
            if idx >= max {
                // More threads than slabs: this thread never uses the arena.
                // Count this request like every later fallback on the thread,
                // so the overflow statistics do not miss the first one.
                ARENA_NO_SLAB.set(true);
                // SAFETY: forwarded from this function's own contract.
                return unsafe { overflow_to_system(size, align) };
            }
            let slab_size = SLAB_SIZE.load(Ordering::Relaxed);
            base = region + idx * slab_size;
            ARENA_BASE.set(base);
            ARENA_END.set(base + slab_size);
        }
        // New generation: rewind the slab; everything in it is considered dead.
        ARENA_PTR.set(base);
        ARENA_GEN.set(generation);
        let aligned = (base + align - 1) & !(align - 1);
        let new_ptr = aligned + size;
        if new_ptr <= ARENA_END.get() {
            ARENA_PTR.set(new_ptr);
            return aligned as *mut u8;
        }
    }
    // SAFETY: forwarded from this function's own contract.
    unsafe { overflow_to_system(size, align) }
}

/// Records one overflow of `size` bytes and allocates from the system allocator.
///
/// # Safety
///
/// `size` and `align` must form a valid `Layout`.
#[inline]
unsafe fn overflow_to_system(size: usize, align: usize) -> *mut u8 {
    OVERFLOW_COUNT.fetch_add(1, Ordering::Relaxed);
    OVERFLOW_BYTES.fetch_add(size, Ordering::Relaxed);
    // SAFETY: the caller guarantees `size`/`align` describe a valid layout.
    unsafe { std::alloc::System.alloc(Layout::from_size_align_unchecked(size, align)) }
}
/// Global allocator that bump-allocates from per-thread slabs while a phase is
/// active and defers to `std::alloc::System` otherwise.
///
/// Arena pointers are recognised by address range (`REGION_BASE` /
/// `REGION_SIZE`); they are never freed individually.
unsafe impl GlobalAlloc for ZkAllocator {
    /// Allocates `layout`.
    ///
    /// Outside a phase, or for requests below the `MIN_ARENA_BYTES` threshold
    /// (when non-zero), the system allocator is used. Otherwise the request is
    /// bump-allocated from this thread's slab if the slab is on the current
    /// generation and has room; anything else goes to `arena_alloc_cold`.
    #[inline(always)]
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        if !ARENA_ACTIVE.load(Ordering::Relaxed) {
            return unsafe { std::alloc::System.alloc(layout) };
        }

        let size = layout.size();
        let threshold = MIN_ARENA_BYTES.load(Ordering::Relaxed);
        if threshold != 0 && size < threshold {
            return unsafe { std::alloc::System.alloc(layout) };
        }

        let align = layout.align();
        let current = GENERATION.load(Ordering::Relaxed);
        if ARENA_GEN.get() == current {
            // Round the bump pointer up to `align` (a power of two).
            let mask = align - 1;
            let start = (ARENA_PTR.get() + mask) & !mask;
            let end = start + size;
            if end <= ARENA_END.get() {
                ARENA_PTR.set(end);
                return start as *mut u8;
            }
        }
        unsafe { arena_alloc_cold(size, align) }
    }

    /// Frees `ptr`: a no-op for pointers inside the arena region, otherwise
    /// forwarded to the system allocator.
    #[inline(always)]
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        let region_start = REGION_BASE.load(Ordering::Relaxed);
        let region_len = REGION_SIZE.load(Ordering::Relaxed);
        let addr = ptr as usize;
        let outside_arena =
            region_start == 0 || addr < region_start || addr >= region_start + region_len;
        if outside_arena {
            unsafe { std::alloc::System.dealloc(ptr, layout) };
        }
    }

    /// Resizes the block at `ptr`.
    ///
    /// Shrinking (or same-size) requests return `ptr` unchanged. Growing a
    /// system block is delegated to the system allocator; growing an arena
    /// block allocates afresh through `alloc` and copies the old contents.
    ///
    /// NOTE(review): the in-place shrink also applies to system-backed blocks,
    /// so their eventual `dealloc` reaches the system allocator with the
    /// smaller size — confirm this is acceptable on all supported targets.
    #[inline(always)]
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        let old_size = layout.size();
        if new_size <= old_size {
            return ptr;
        }

        let region_start = REGION_BASE.load(Ordering::Relaxed);
        let region_len = REGION_SIZE.load(Ordering::Relaxed);
        let addr = ptr as usize;
        let in_arena =
            region_start != 0 && addr >= region_start && addr < region_start + region_len;

        if in_arena {
            let grown = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) };
            let fresh = unsafe { self.alloc(grown) };
            if !fresh.is_null() {
                // `copy` (memmove semantics) rather than `copy_nonoverlapping`:
                // after a slab rewind the new block may overlap an old arena block.
                unsafe { std::ptr::copy(ptr, fresh, old_size) };
                unsafe { self.dealloc(ptr, layout) };
            }
            fresh
        } else {
            unsafe { std::alloc::System.realloc(ptr, layout, new_size) }
        }
    }
}