use std::mem::MaybeUninit;
use std::ptr::NonNull;
use crate::freelist::ThreadFreelist;
use crate::heap::GlobalHeap;
use crate::platform;
use crate::size_class::{self, NUM_SIZE_CLASSES};
use crate::sys_box::SysBox;
/// Maximum number of mappings the per-thread large cache can hold at once.
const LARGE_CACHE_SLOTS: usize = 1024;
/// Upper bound on the total bytes retained by the per-thread large cache (512 MiB).
const MAX_LARGE_CACHE_BYTES: usize = 512 << 20;
/// 512 KiB size threshold exported to other modules.
/// NOTE(review): presumably the size above which freed memory is madvise'd — confirm at use sites.
pub const MADVISE_THRESHOLD: usize = 512 << 10;
/// A single cached mapping: the base pointer and size that are later handed
/// back to `platform::munmap` (or reused) by `LargeCache`.
#[derive(Copy, Clone)]
struct LargeCacheEntry {
    /// Base address of the mapping.
    original_ptr: NonNull<u8>,
    /// Size of the mapping in bytes.
    alloc_size: usize,
}
/// Fixed-capacity, unordered cache of large mappings owned by one `PerThreadHeap`.
///
/// Invariants maintained by the methods in this file:
/// * slots `entries[0..count]` are initialised; slots at and beyond `count` are
///   never read;
/// * `bytes` equals the sum of `alloc_size` over the live entries.
struct LargeCache {
    /// Backing storage; only the first `count` slots hold live entries.
    entries: [MaybeUninit<LargeCacheEntry>; LARGE_CACHE_SLOTS],
    /// Number of live entries.
    count: usize,
    /// Total size in bytes of all live entries.
    bytes: usize,
}
impl LargeCache {
fn new_boxed() -> SysBox<Self> {
unsafe { SysBox::new_zeroed() }
}
const CLOSE_SIZE_TOLERANCE: usize = 8 * 1024;
#[inline]
fn take(&mut self, alloc_size: usize) -> Option<(NonNull<u8>, usize)> {
let count = self.count;
let mut best_idx: usize = usize::MAX;
let mut best_waste: usize = usize::MAX;
for i in 0..count {
let entry = unsafe { self.entries[i].assume_init_read() };
if entry.alloc_size == alloc_size {
self.count -= 1;
self.bytes -= entry.alloc_size;
self.entries[i] = self.entries[self.count];
return Some((entry.original_ptr, entry.alloc_size));
}
let waste = entry.alloc_size.wrapping_sub(alloc_size);
if entry.alloc_size >= alloc_size
&& waste <= Self::CLOSE_SIZE_TOLERANCE
&& waste < best_waste
{
best_waste = waste;
best_idx = i;
}
}
if best_idx < count {
let entry = unsafe { self.entries[best_idx].assume_init_read() };
self.count -= 1;
self.bytes -= entry.alloc_size;
self.entries[best_idx] = self.entries[self.count];
return Some((entry.original_ptr, entry.alloc_size));
}
None
}
#[inline]
fn put(&mut self, original_ptr: NonNull<u8>, alloc_size: usize) -> bool {
while self.count > 0
&& (self.count >= LARGE_CACHE_SLOTS || self.bytes + alloc_size > MAX_LARGE_CACHE_BYTES)
{
self.count -= 1;
let evicted = unsafe { self.entries[self.count].assume_init_read() };
self.bytes -= evicted.alloc_size;
unsafe {
platform::munmap(evicted.original_ptr, evicted.alloc_size);
}
}
if self.bytes + alloc_size > MAX_LARGE_CACHE_BYTES {
return false;
}
let idx = self.count;
self.entries[idx] = MaybeUninit::new(LargeCacheEntry {
original_ptr,
alloc_size,
});
self.count += 1;
self.bytes += alloc_size;
true
}
fn flush(&mut self) {
for i in 0..self.count {
let entry = unsafe { self.entries[i].assume_init_read() };
unsafe {
platform::munmap(entry.original_ptr, entry.alloc_size);
}
}
self.count = 0;
self.bytes = 0;
}
}
impl Drop for LargeCache {
    /// Releases every mapping still held by the cache via `flush`.
    fn drop(&mut self) {
        Self::flush(self);
    }
}
/// Batch size exported for freelist refills.
/// NOTE(review): presumably the number of objects moved per refill — confirm at use sites.
pub const REFILL_BATCH: usize = 64;

/// Per-size-class upper bound on thread-cached objects, computed at compile time.
const MAX_CACHE_TABLE: [usize; NUM_SIZE_CLASSES] = build_max_cache_table();

/// Builds `MAX_CACHE_TABLE`.
///
/// For each class the limit is twice the number of objects that fit in one
/// bag (`bag_size_for_class(i) / SIZE_CLASSES[i]`), clamped to `[64, 2048]`.
const fn build_max_cache_table() -> [usize; NUM_SIZE_CLASSES] {
    /// Clamps `value` into `lo..=hi` (const-compatible replacement for `Ord::clamp`).
    const fn clamp(value: usize, lo: usize, hi: usize) -> usize {
        if value < lo {
            lo
        } else if value > hi {
            hi
        } else {
            value
        }
    }

    let mut limits = [0usize; NUM_SIZE_CLASSES];
    let mut class = 0;
    while class < NUM_SIZE_CLASSES {
        let object_size = size_class::SIZE_CLASSES[class];
        let bag_bytes = size_class::bag_size_for_class(class);
        let objects_per_bag = bag_bytes / object_size;
        limits[class] = clamp(objects_per_bag * 2, 64, 2048);
        class += 1;
    }
    limits
}
/// Returns the compile-time thread-cache limit for size class `class_idx`.
///
/// # Panics
/// Panics if `class_idx >= NUM_SIZE_CLASSES`.
#[inline]
pub fn max_thread_cache(class_idx: usize) -> usize {
    let limits = &MAX_CACHE_TABLE;
    limits[class_idx]
}
/// Thread-local allocation state: one freelist per size class plus a cache of
/// large mappings, tied to a node of a shared `GlobalHeap`.
pub struct PerThreadHeap {
    /// Node whose region receives this heap's freelists on drop.
    pub node_id: usize,
    /// Shared heap; dereferenced unsafely on drop.
    /// NOTE(review): must outlive this value — confirm ownership.
    pub global_heap: NonNull<GlobalHeap>,
    /// Thread-local freelists, indexed by size class.
    freelists: [ThreadFreelist; NUM_SIZE_CLASSES],
    /// Cache of freed large mappings available for reuse.
    large_cache: SysBox<LargeCache>,
}
impl PerThreadHeap {
pub fn new(node_id: usize, global_heap: NonNull<GlobalHeap>) -> Self {
Self {
node_id,
global_heap,
freelists: std::array::from_fn(|_| ThreadFreelist::new()),
large_cache: LargeCache::new_boxed(),
}
}
#[inline]
pub fn freelist_mut(&mut self, class_index: usize) -> &mut ThreadFreelist {
&mut self.freelists[class_index]
}
#[inline]
pub fn large_cache_take(&mut self, alloc_size: usize) -> Option<(NonNull<u8>, usize)> {
self.large_cache.take(alloc_size)
}
#[inline]
pub fn large_cache_put(&mut self, original_ptr: NonNull<u8>, alloc_size: usize) -> bool {
self.large_cache.put(original_ptr, alloc_size)
}
}
impl Drop for PerThreadHeap {
fn drop(&mut self) {
let heap = unsafe { self.global_heap.as_ref() };
let node = self.node_id;
for class_idx in 0..NUM_SIZE_CLASSES {
if let Some((head, tail, _)) = self.freelists[class_idx].drain_all() {
heap.node_region(node)
.node_heap
.freelist(class_idx)
.push_chain(head, tail);
}
}
}
}