use std::alloc::{GlobalAlloc, Layout, System};
use std::cell::Cell;
use std::ptr::NonNull;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::freelist::FreeBlock;
use crate::heap::GlobalHeap;
use crate::platform;
use crate::size_class::{self, SMALL_LIMIT};
use crate::thread_heap::{MADVISE_THRESHOLD, PerThreadHeap, REFILL_BATCH, max_thread_cache};
/// Thread-local holder for this thread's `PerThreadHeap` pointer.
///
/// Starts out empty (`None`); `NumaAlloc::thread_heap` installs a heap on the
/// thread's first allocation.
struct ThreadHeapSlot {
    inner: Cell<Option<NonNull<PerThreadHeap>>>,
}

impl ThreadHeapSlot {
    /// Builds an empty slot. `const` so it can seed the `const` thread-local below.
    const fn new() -> Self {
        let inner = Cell::new(None);
        Self { inner }
    }

    /// Returns the currently installed heap pointer, if any.
    #[inline]
    fn get(&self) -> Option<NonNull<PerThreadHeap>> {
        Cell::get(&self.inner)
    }

    /// Replaces the stored heap pointer with `val`.
    #[inline]
    fn set(&self, val: Option<NonNull<PerThreadHeap>>) {
        Cell::set(&self.inner, val);
    }
}

impl Drop for ThreadHeapSlot {
    /// On thread-local teardown, hands the thread's cached blocks back via
    /// `drain_to_node_heap`, then releases the heap's memory with `System.dealloc`.
    /// No destructor is run on the `PerThreadHeap` value itself; only its storage
    /// is freed.
    fn drop(&mut self) {
        let Some(mut heap_ptr) = self.inner.get() else {
            return;
        };
        let layout = Layout::new::<PerThreadHeap>();
        // SAFETY: a stored pointer was allocated with `System` using this same layout
        // and initialised by `NumaAlloc::thread_heap`; it is released exactly once, here.
        unsafe {
            heap_ptr.as_mut().drain_to_node_heap();
            System.dealloc(heap_ptr.as_ptr().cast::<u8>(), layout);
        }
    }
}

thread_local! {
    /// This thread's heap slot. Its `Drop` drains and frees the heap when the
    /// thread's locals are destroyed.
    static TH_PTR: ThreadHeapSlot = const { ThreadHeapSlot::new() };
}
/// Bookkeeping record written immediately before every large-allocation payload.
///
/// `dealloc_large` reads it back from `payload - size_of::<LargeHeader>()` to find
/// the whole block it must cache or unmap.
#[repr(C)]
struct LargeHeader {
    /// Start of the raw block (fresh mmap or large-cache hit), before the
    /// alignment padding and this header.
    original_ptr: NonNull<u8>,
    /// Size in bytes recorded for the raw block (the value from `large_alloc_size`).
    alloc_size: usize,
}
/// NUMA-aware global allocator.
///
/// The backing `GlobalHeap` is created lazily on first use. `next_node` is a
/// counter that `thread_heap` takes modulo the node count to spread new thread
/// heaps across nodes round-robin.
pub struct NumaAlloc {
    heap: OnceLock<GlobalHeap>,
    next_node: AtomicUsize,
}

// SAFETY: NOTE(review) these impls assert thread-safety manually. `next_node` is
// atomic and `heap` is a `OnceLock`; soundness rests on `GlobalHeap` being safe to
// share across threads. Confirm against `GlobalHeap`'s own synchronisation.
unsafe impl Send for NumaAlloc {}
unsafe impl Sync for NumaAlloc {}

impl Default for NumaAlloc {
    /// Same as [`NumaAlloc::new`].
    fn default() -> Self {
        NumaAlloc::new()
    }
}
impl NumaAlloc {
pub const fn new() -> Self {
Self {
heap: OnceLock::new(),
next_node: AtomicUsize::new(0),
}
}
fn heap(&self) -> &GlobalHeap {
self.heap.get_or_init(|| {
let topo = platform::detect_topology();
GlobalHeap::new(topo.num_nodes).expect("numalloc: failed to mmap heap region")
})
}
fn thread_heap(&self) -> NonNull<PerThreadHeap> {
if let Ok(Some(ptr)) = TH_PTR.try_with(ThreadHeapSlot::get) {
return ptr;
}
let heap = self.heap();
let node = self.next_node.fetch_add(1, Ordering::Relaxed) % heap.num_nodes();
platform::bind_thread_to_node(node);
let layout = Layout::new::<PerThreadHeap>();
let raw = unsafe { System.alloc(layout) } as *mut PerThreadHeap;
let Some(nn) = NonNull::new(raw) else {
std::alloc::handle_alloc_error(layout);
};
unsafe {
nn.as_ptr()
.write(PerThreadHeap::new(node, NonNull::from(heap)));
}
let _ = TH_PTR.try_with(|slot| slot.set(Some(nn)));
nn
}
}
unsafe impl GlobalAlloc for NumaAlloc {
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
let effective_size = layout.size().max(layout.align());
if effective_size > SMALL_LIMIT {
return unsafe { self.alloc_large(layout) };
}
let class_idx = match size_class::size_class_index(effective_size) {
Some(i) => i,
None => return std::ptr::null_mut(),
};
let th = unsafe { self.thread_heap().as_mut() };
let heap = self.heap();
let node = th.node_id;
let fl = th.freelist_mut(class_idx);
if let Some(block) = fl.pop() {
return block.as_ptr().cast();
}
let node_fl = heap.node_region(node).node_heap.freelist(class_idx);
if let Some(first) = node_fl.pop() {
unsafe { first.as_ref().write_next(None) };
let mut tail = first;
let mut count = 1usize;
while count < REFILL_BATCH {
let Some(b) = node_fl.pop() else { break };
unsafe { b.as_ref().write_next(None) };
unsafe { tail.as_ref().write_next(Some(b)) };
tail = b;
count += 1;
}
fl.push_chain(first, tail, count);
return fl.pop().unwrap().as_ptr().cast();
}
let region = heap.node_region(node);
let bag_size = size_class::bag_size_for_class(class_idx);
let Some(bag) = region.allocate_bag(bag_size) else {
return unsafe { self.alloc_large(layout) };
};
let obj_size = size_class::size_for_class(class_idx);
let count = bag_size / obj_size;
for i in 0..count {
let obj =
unsafe { NonNull::new_unchecked(bag.as_ptr().add(i * obj_size) as *mut FreeBlock) };
fl.push(obj);
}
fl.pop().unwrap().as_ptr().cast()
}
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
let Some(ptr) = NonNull::new(ptr) else { return };
let effective_size = layout.size().max(layout.align());
if effective_size > SMALL_LIMIT {
unsafe { dealloc_large(ptr, Some(self)) };
return;
}
let heap = self.heap();
if !heap.is_owned(ptr) {
unsafe { dealloc_large(ptr, Some(self)) };
return;
}
let class_idx = match size_class::size_class_index(effective_size) {
Some(i) => i,
None => return,
};
let origin_node = match heap.node_for_ptr(ptr) {
Some(n) => n,
None => return,
};
let th = unsafe { self.thread_heap().as_mut() };
let current_node = th.node_id;
let block = ptr.cast::<FreeBlock>();
if origin_node == current_node {
let fl = th.freelist_mut(class_idx);
fl.push(block);
let cache_limit = max_thread_cache(class_idx);
if fl.count() > cache_limit
&& let Some((head, tail, _)) = fl.drain(cache_limit / 2)
{
heap.node_region(current_node)
.node_heap
.freelist(class_idx)
.push_chain(head, tail);
}
} else {
heap.node_region(origin_node)
.node_heap
.freelist(class_idx)
.push(block);
}
}
unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
let effective_size = layout.size().max(layout.align());
if effective_size > SMALL_LIMIT {
return unsafe { self.alloc(layout) };
}
let ptr = unsafe { self.alloc(layout) };
if !ptr.is_null() {
unsafe { std::ptr::write_bytes(ptr, 0, layout.size()) };
}
ptr
}
unsafe fn realloc(&self, ptr: *mut u8, old_layout: Layout, new_size: usize) -> *mut u8 {
let old_effective = old_layout.size().max(old_layout.align());
let new_effective = new_size.max(old_layout.align());
if old_effective <= SMALL_LIMIT
&& new_effective <= SMALL_LIMIT
&& let (Some(old_cls), Some(new_cls)) = (
size_class::size_class_index(old_effective),
size_class::size_class_index(new_effective),
)
&& old_cls == new_cls
{
return ptr;
}
let new_layout = unsafe { Layout::from_size_align_unchecked(new_size, old_layout.align()) };
let new_ptr = unsafe { self.alloc(new_layout) };
if new_ptr.is_null() {
return new_ptr;
}
let copy_size = old_layout.size().min(new_size);
unsafe {
std::ptr::copy_nonoverlapping(ptr, new_ptr, copy_size);
self.dealloc(ptr, old_layout);
}
new_ptr
}
}
impl NumaAlloc {
#[inline]
fn large_alloc_size(layout: &Layout) -> usize {
let page_size = platform::page_size();
let header_size = std::mem::size_of::<LargeHeader>();
let align = layout.align().max(std::mem::align_of::<LargeHeader>());
let alloc_size = header_size + (align - 1) + layout.size();
(alloc_size + page_size - 1) & !(page_size - 1)
}
#[inline]
unsafe fn prepare_large_payload(
raw: NonNull<u8>,
alloc_size: usize,
layout: &Layout,
) -> *mut u8 {
let header_size = std::mem::size_of::<LargeHeader>();
let align = layout.align().max(std::mem::align_of::<LargeHeader>());
let payload_addr = (raw.as_ptr() as usize + header_size + align - 1) & !(align - 1);
let header_ptr =
unsafe { NonNull::new_unchecked((payload_addr - header_size) as *mut LargeHeader) };
unsafe {
header_ptr.as_ptr().write(LargeHeader {
original_ptr: raw,
alloc_size,
});
}
payload_addr as *mut u8
}
unsafe fn alloc_large(&self, layout: Layout) -> *mut u8 {
let alloc_size = Self::large_alloc_size(&layout);
let th = unsafe { self.thread_heap().as_mut() };
if let Some((raw, _)) = th.large_cache_take(alloc_size) {
return unsafe { Self::prepare_large_payload(raw, alloc_size, &layout) };
}
let Some(raw) = (unsafe { platform::mmap_anonymous(alloc_size) }) else {
return std::ptr::null_mut();
};
let node = th.node_id;
unsafe {
platform::bind_to_node(raw, alloc_size, node);
}
unsafe { Self::prepare_large_payload(raw, alloc_size, &layout) }
}
#[inline]
fn try_cache_large(&self, original: NonNull<u8>, alloc_size: usize) -> bool {
if let Ok(Some(mut th)) = TH_PTR.try_with(ThreadHeapSlot::get) {
let th = unsafe { th.as_mut() };
if th.large_cache_put(original, alloc_size) {
if alloc_size >= MADVISE_THRESHOLD {
unsafe {
platform::madvise_dontneed(original, alloc_size);
}
}
return true;
}
}
false
}
}
unsafe fn dealloc_large(ptr: NonNull<u8>, allocator: Option<&NumaAlloc>) {
let header_size = std::mem::size_of::<LargeHeader>();
let header_ptr =
unsafe { NonNull::new_unchecked(ptr.as_ptr().sub(header_size) as *mut LargeHeader) };
let header = unsafe { header_ptr.as_ref() };
let original = header.original_ptr;
let size = header.alloc_size;
if let Some(alloc) = allocator
&& alloc.try_cache_large(original, size)
{
return;
}
unsafe {
platform::munmap(original, size);
}
}