#![deny(missing_docs)]
use std::cell::RefCell;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::mem::{take, ManuallyDrop};
use std::ops::Range;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender};
use std::sync::{Arc, Mutex, OnceLock, RwLock};
use std::thread::{JoinHandle, ThreadId};
use std::time::{Duration, Instant};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use thiserror::Error;
// Private module whose only content is the crate README, attached as the module's documentation.
// NOTE(review): presumably this exists so that rustdoc compiles the README's code examples as
// doctests — confirm.
mod readme {
#![doc = include_str!("../README.md")]
}
/// Handle to a region of memory: a start pointer together with a length in bytes.
///
/// Handles are returned by [`allocate`] and given back through [`deallocate`].
pub struct Handle {
/// Pointer to the first byte of the region.
ptr: NonNull<u8>,
/// Length of the region in bytes.
len: usize,
}
// `Handle` stores a raw `NonNull<u8>`, so `Send` and `Sync` are not derived automatically and are
// asserted manually here.
// NOTE(review): presumably sound because a handle is the sole owner of its region while it is
// handed out — confirm against how handles are shared between threads.
unsafe impl Send for Handle {}
unsafe impl Sync for Handle {}
#[allow(clippy::len_without_is_empty)]
impl Handle {
fn new(ptr: NonNull<u8>, len: usize) -> Self {
Self { ptr, len }
}
fn dangling() -> Self {
Self {
ptr: NonNull::dangling(),
len: 0,
}
}
fn is_dangling(&self) -> bool {
self.ptr == NonNull::dangling()
}
fn len(&self) -> usize {
self.len
}
fn as_non_null(&self) -> NonNull<u8> {
self.ptr
}
fn clear(&mut self) -> std::io::Result<()> {
unsafe { self.madvise(MADV_CLEAR_STRATEGY) }
}
fn fast_clear(&mut self) -> std::io::Result<()> {
unsafe { self.madvise(libc::MADV_DONTNEED) }
}
pub fn prefetch<T>(&self, range: Range<usize>) -> Result<(), PrefetchError> {
let elem_size = std::mem::size_of::<T>();
let byte_offset = range.start.checked_mul(elem_size);
let byte_len = (range.end - range.start).checked_mul(elem_size);
let (byte_offset, byte_len) = match (byte_offset, byte_len) {
(Some(o), Some(l)) => (o, l),
_ => {
return Err(PrefetchError::OutOfBounds {
byte_offset: range.start.saturating_mul(elem_size),
byte_len: range.len().saturating_mul(elem_size),
allocation_len: self.len,
});
}
};
if byte_len == 0 || self.is_dangling() {
return Ok(());
}
if byte_offset.saturating_add(byte_len) > self.len {
return Err(PrefetchError::OutOfBounds {
byte_offset,
byte_len,
allocation_len: self.len,
});
}
unsafe {
let ptr = self.as_non_null().as_ptr().add(byte_offset);
libc::madvise(ptr.cast(), byte_len, libc::MADV_WILLNEED);
}
Ok(())
}
unsafe fn madvise(&self, advice: libc::c_int) -> std::io::Result<()> {
let ptr = self.as_non_null().as_ptr().cast();
let ret = unsafe { libc::madvise(ptr, self.len, advice) };
if ret != 0 {
let err = std::io::Error::last_os_error();
return Err(err);
}
Ok(())
}
}
/// Lower bound, in bytes, for how much a refill grows a size class (32 MiB); a refill always
/// adds at least one region.
const INITIAL_SIZE: usize = 32 << 20;
/// Range of valid size classes.
///
/// A size class `c` holds regions of `1 << c` bytes, so valid region sizes run from
/// `1 << 21` (2 MiB) to `1 << 36` (64 GiB).
pub const VALID_SIZE_CLASS: Range<usize> = 21..37;
/// `madvise` advice used by `Handle::clear`: `MADV_FREE` on Linux.
#[cfg(target_os = "linux")]
const MADV_CLEAR_STRATEGY: libc::c_int = libc::MADV_FREE;
/// `madvise` advice used by `Handle::clear`: `MADV_DONTNEED` on targets other than Linux.
#[cfg(not(target_os = "linux"))]
const MADV_CLEAR_STRATEGY: libc::c_int = libc::MADV_DONTNEED;
/// Set once a failed `MADV_HUGEPAGE` call has been reported, so the warning is printed only once.
#[cfg(target_os = "linux")]
static MADV_HUGEPAGE_WARNED: AtomicBool = AtomicBool::new(false);
/// Marker that makes the containing type neither `Send` nor `Sync` (raw pointers are neither).
type PhantomUnsyncUnsend<T> = PhantomData<*mut T>;
/// Errors returned by [`allocate`].
#[derive(Error, Debug)]
pub enum AllocError {
/// An operating-system call failed; wraps the underlying I/O error.
#[error("I/O error")]
Io(#[from] std::io::Error),
/// No free region could be obtained.
#[error("Out of memory")]
OutOfMemory,
/// The requested size maps to the contained size class, which lies outside
/// [`VALID_SIZE_CLASS`].
#[error("Invalid size class")]
InvalidSizeClass(usize),
/// The allocator is not enabled in the current configuration.
#[error("Disabled by configuration")]
Disabled,
/// The obtained region does not satisfy the alignment of the requested element type.
#[error("Memory unsuitable for requested alignment")]
UnalignedMemory,
}
#[derive(Error, Debug)]
pub enum PrefetchError {
#[error("prefetch byte range [{byte_offset}..{end}) exceeds allocation length {allocation_len}", end = byte_offset + byte_len)]
OutOfBounds {
byte_offset: usize,
byte_len: usize,
allocation_len: usize,
},
}
impl AllocError {
#[must_use]
pub fn is_disabled(&self) -> bool {
matches!(self, AllocError::Disabled)
}
}
/// A size class, stored as the base-2 logarithm of the region size in bytes.
#[derive(Clone, Copy)]
struct SizeClass(usize);
impl SizeClass {
const fn new_unchecked(value: usize) -> Self {
Self(value)
}
const fn index(self) -> usize {
self.0 - VALID_SIZE_CLASS.start
}
const fn byte_size(self) -> usize {
1 << self.0
}
const fn from_index(index: usize) -> Self {
Self(index + VALID_SIZE_CLASS.start)
}
fn from_byte_size(byte_size: usize) -> Result<Self, AllocError> {
let class = byte_size.next_power_of_two().trailing_zeros() as usize;
class.try_into()
}
const fn from_byte_size_unchecked(byte_size: usize) -> Self {
Self::new_unchecked(byte_size.next_power_of_two().trailing_zeros() as usize)
}
}
impl TryFrom<usize> for SizeClass {
type Error = AllocError;
fn try_from(value: usize) -> Result<Self, Self::Error> {
if VALID_SIZE_CLASS.contains(&value) {
Ok(SizeClass(value))
} else {
Err(AllocError::InvalidSizeClass(value))
}
}
}
/// Event counters for one size class, kept per thread and once globally.
#[derive(Default, Debug)]
struct AllocStats {
/// Number of allocation requests.
allocations: AtomicU64,
/// Number of allocation requests that found no region on the first attempt and took the
/// locked slow path.
slow_path: AtomicU64,
/// Number of refill attempts, i.e. attempts to map a new area.
refill: AtomicU64,
/// Number of regions handed back.
deallocations: AtomicU64,
/// Number of regions cleared eagerly when handed back.
clear_eager: AtomicU64,
/// Number of regions cleared by the background worker.
clear_slow: AtomicU64,
}
/// Process-wide allocator state, initialized on first access through `GlobalStealer::get_static`.
static INJECTOR: OnceLock<GlobalStealer> = OnceLock::new();
/// Whether allocation is enabled; while `false`, allocations fail with `AllocError::Disabled`.
static LGALLOC_ENABLED: AtomicBool = AtomicBool::new(false);
/// Whether regions that overflow a thread's local buffer are cleared before they are moved to
/// the global queue.
static LGALLOC_EAGER_RETURN: AtomicBool = AtomicBool::new(false);
/// Dampens growth on refill: a refill adds `last_capacity / (dampener + 1)` regions, or the
/// initial capacity if that is larger.
static LGALLOC_GROWTH_DAMPENER: AtomicUsize = AtomicUsize::new(0);
/// Byte budget of each thread-local buffer, per size class (32 MiB by default).
static LOCAL_BUFFER_BYTES: AtomicUsize = AtomicUsize::new(32 << 20);
/// Process-wide allocator state shared by all threads.
struct GlobalStealer {
/// State per size class, indexed by `SizeClass::index`.
size_classes: Vec<SizeClassState>,
/// Join handle of the background worker thread and the channel used to send it new
/// configurations; `None` until a worker has been started.
background_sender: Mutex<Option<(JoinHandle<()>, Sender<BackgroundWorkerConfig>)>>,
}
/// Global state of a single size class.
#[derive(Default)]
struct SizeClassState {
/// Mapped areas backing this size class, each recorded as `(address, length in bytes)`.
areas: RwLock<Vec<ManuallyDrop<(usize, usize)>>>,
/// Regions handed back by threads; the background worker moves them to `clean_injector`
/// after clearing them.
injector: Injector<Handle>,
/// Regions from freshly mapped areas and regions cleared by the background worker.
clean_injector: Injector<Handle>,
/// Serializes the slow allocation path, including refills.
lock: Mutex<()>,
/// Stealer and statistics of every thread that has thread-local state for this size class.
stealers: RwLock<HashMap<ThreadId, PerThreadState<Handle>>>,
/// Statistics that are not tied to a live thread: counters folded in when a thread's local
/// state is dropped, plus the background worker's clear count.
alloc_stats: AllocStats,
/// Total number of bytes mapped for this size class.
total_bytes: AtomicUsize,
/// Number of areas mapped for this size class.
area_count: AtomicUsize,
}
impl GlobalStealer {
fn get_static() -> &'static Self {
INJECTOR.get_or_init(Self::new)
}
fn get_size_class(&self, size_class: SizeClass) -> &SizeClassState {
&self.size_classes[size_class.index()]
}
fn new() -> Self {
let mut size_classes = Vec::with_capacity(VALID_SIZE_CLASS.len());
for _ in VALID_SIZE_CLASS {
size_classes.push(SizeClassState::default());
}
Self {
size_classes,
background_sender: Mutex::default(),
}
}
}
impl Drop for GlobalStealer {
    /// Unmaps every recorded area of every size class, then drops the per-size-class state.
    fn drop(&mut self) {
        for state in self.size_classes.iter_mut() {
            let mut recorded = state.areas.write().expect("lock poisoned");
            recorded
                .drain(..)
                .map(ManuallyDrop::into_inner)
                .for_each(|(addr, len)| {
                    // SAFETY (NOTE(review)): relies on `(addr, len)` describing a mapping
                    // that is still mapped and no longer referenced — confirm.
                    unsafe {
                        libc::munmap(addr as *mut libc::c_void, len);
                    }
                });
        }
        drop(take(&mut self.size_classes));
    }
}
/// What the global state keeps about a thread for one size class.
struct PerThreadState<T> {
/// Stealer for the thread's local queue, so other threads can take regions from it.
stealer: Stealer<T>,
/// The thread's statistics, shared with the thread-local state.
alloc_stats: Arc<AllocStats>,
}
/// Per-thread allocator state: one local queue per size class.
struct ThreadLocalStealer {
/// Local state per size class, indexed by `SizeClass::index`.
size_classes: Vec<LocalSizeClass>,
/// Makes this type neither `Send` nor `Sync`.
_phantom: PhantomUnsyncUnsend<Self>,
}
impl ThreadLocalStealer {
    /// Creates the local state of the current thread, one entry per valid size class.
    fn new() -> Self {
        let thread_id = std::thread::current().id();
        let mut size_classes = Vec::with_capacity(VALID_SIZE_CLASS.len());
        for class in VALID_SIZE_CLASS {
            let local = LocalSizeClass::new(SizeClass::new_unchecked(class), thread_id);
            size_classes.push(local);
        }
        Self {
            size_classes,
            _phantom: PhantomData,
        }
    }

    /// Obtains a region of `size_class`, or fails with [`AllocError::Disabled`] while the
    /// allocator is switched off.
    fn allocate(&self, size_class: SizeClass) -> Result<Handle, AllocError> {
        if LGALLOC_ENABLED.load(Ordering::Relaxed) {
            let local = &self.size_classes[size_class.index()];
            local.get_with_refill()
        } else {
            Err(AllocError::Disabled)
        }
    }

    /// Hands `mem` back to the size class derived from its length.
    fn deallocate(&self, mem: Handle) {
        let position = SizeClass::from_byte_size_unchecked(mem.len()).index();
        self.size_classes[position].push(mem);
    }
}
thread_local! {
// Allocator state of the current thread, built by `ThreadLocalStealer::new` the first time the
// thread accesses it.
static WORKER: RefCell<ThreadLocalStealer> = RefCell::new(ThreadLocalStealer::new());
}
/// Thread-local state of one size class.
struct LocalSizeClass {
/// Local queue of free regions.
worker: Worker<Handle>,
/// The size class this state serves.
size_class: SizeClass,
/// Global state of the same size class.
size_class_state: &'static SizeClassState,
/// Id of the owning thread; the key under which this state is registered globally.
thread_id: ThreadId,
/// This thread's statistics, also reachable through the global registration.
stats: Arc<AllocStats>,
/// Makes this type neither `Send` nor `Sync`.
_phantom: PhantomUnsyncUnsend<Self>,
}
impl LocalSizeClass {
    /// Creates the thread-local state for `size_class` and registers its stealer and statistics
    /// under `thread_id` in the global state of that size class.
    fn new(size_class: SizeClass, thread_id: ThreadId) -> Self {
        let size_class_state = GlobalStealer::get_static().get_size_class(size_class);
        let worker = Worker::new_lifo();
        let stats = Arc::new(AllocStats::default());
        let registration = PerThreadState {
            stealer: worker.stealer(),
            alloc_stats: Arc::clone(&stats),
        };
        size_class_state
            .stealers
            .write()
            .expect("lock poisoned")
            .insert(thread_id, registration);
        Self {
            worker,
            size_class,
            size_class_state,
            thread_id,
            stats,
            _phantom: PhantomData,
        }
    }

    /// Takes a free region without mapping new memory.
    ///
    /// The local queue is tried first. Otherwise regions are taken, in this order, from the
    /// global queue, the global clean queue, and the queues of the registered threads; the
    /// attempt is repeated for as long as a source asks for a retry.
    ///
    /// # Errors
    ///
    /// Returns [`AllocError::OutOfMemory`] if no source has a region.
    #[inline]
    fn get(&self) -> Result<Handle, AllocError> {
        if let Some(mem) = self.worker.pop() {
            return Ok(mem);
        }
        let shared = self.size_class_state;
        let attempt = || {
            // Take at most half of the local buffer budget in one batch, but at least one.
            let limit = 1.max(
                LOCAL_BUFFER_BYTES.load(Ordering::Relaxed) / self.size_class.byte_size() / 2,
            );
            shared
                .injector
                .steal_batch_with_limit_and_pop(&self.worker, limit)
                .or_else(|| {
                    shared
                        .clean_injector
                        .steal_batch_with_limit_and_pop(&self.worker, limit)
                })
                .or_else(|| {
                    let registered = shared.stealers.read().expect("lock poisoned");
                    registered
                        .values()
                        .map(|state| state.stealer.steal())
                        .collect()
                })
        };
        std::iter::repeat_with(attempt)
            .find(|steal| !steal.is_retry())
            .and_then(Steal::success)
            .ok_or(AllocError::OutOfMemory)
    }

    /// Takes a free region, mapping a new area if none is available.
    ///
    /// When the first attempt reports out-of-memory, the size class's slow-path lock is taken,
    /// the attempt is repeated (another thread may have refilled in the meantime), and only
    /// then is a refill tried.
    fn get_with_refill(&self) -> Result<Handle, AllocError> {
        self.stats.allocations.fetch_add(1, Ordering::Relaxed);
        let first_try = self.get();
        if !matches!(first_try, Err(AllocError::OutOfMemory)) {
            return first_try;
        }
        self.stats.slow_path.fetch_add(1, Ordering::Relaxed);
        let _guard = self.size_class_state.lock.lock().expect("lock poisoned");
        match self.get() {
            Ok(mem) => Ok(mem),
            Err(_) => self.try_refill_and_get(),
        }
    }

    /// Takes back a region.
    ///
    /// The region goes to the local queue while that queue holds fewer regions than the local
    /// buffer budget allows. Otherwise it goes to the global queue, cleared first if eager
    /// return is enabled.
    fn push(&self, mut mem: Handle) {
        debug_assert_eq!(mem.len(), self.size_class.byte_size());
        self.stats.deallocations.fetch_add(1, Ordering::Relaxed);
        let local_capacity =
            LOCAL_BUFFER_BYTES.load(Ordering::Relaxed) / self.size_class.byte_size();
        if self.worker.len() < local_capacity {
            self.worker.push(mem);
            return;
        }
        if LGALLOC_EAGER_RETURN.load(Ordering::Relaxed) {
            self.stats.clear_eager.fetch_add(1, Ordering::Relaxed);
            mem.fast_clear().expect("clearing successful");
        }
        self.size_class_state.injector.push(mem);
    }

    /// Maps a new area, returns its first region and puts the remaining ones into the global
    /// clean queue.
    ///
    /// The new area holds as many regions as the most recently mapped area plus a growth of
    /// `last_capacity / (growth_dampener + 1)` regions, but at least the initial capacity.
    ///
    /// # Errors
    ///
    /// Returns the error of the mapping call if it fails.
    fn try_refill_and_get(&self) -> Result<Handle, AllocError> {
        self.stats.refill.fetch_add(1, Ordering::Relaxed);
        let region_len = self.size_class.byte_size();
        let mut stash = self.size_class_state.areas.write().expect("lock poisoned");

        let initial_capacity = std::cmp::max(1, INITIAL_SIZE / region_len);
        let last_capacity = stash.last().map_or(0, |area| area.1) / region_len;
        let growth_dampener = LGALLOC_GROWTH_DAMPENER.load(Ordering::Relaxed);
        let growth = std::cmp::max(
            initial_capacity,
            last_capacity / (growth_dampener.saturating_add(1)),
        );
        let next_capacity = last_capacity + growth;
        let next_byte_len = next_capacity * region_len;

        let (mmap_ptr, slice) = mmap_anonymous(next_byte_len)?;
        let shared = self.size_class_state;
        shared.total_bytes.fetch_add(next_byte_len, Ordering::Relaxed);
        shared.area_count.fetch_add(1, Ordering::Relaxed);

        let mut regions = slice.chunks_exact_mut(region_len).map(|chunk| {
            let start = NonNull::new(chunk.as_mut_ptr()).expect("non-null");
            Handle::new(start, region_len)
        });
        let first = regions.next().expect("At least one chunk allocated.");
        for region in regions {
            shared.clean_injector.push(region);
        }
        stash.push(ManuallyDrop::new((mmap_ptr, next_byte_len)));
        Ok(first)
    }
}
fn mmap_anonymous(len: usize) -> Result<(usize, &'static mut [u8]), AllocError> {
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
len,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
-1,
0,
)
};
if ptr == libc::MAP_FAILED {
return Err(std::io::Error::last_os_error().into());
}
#[cfg(target_os = "linux")]
{
let ret = unsafe { libc::madvise(ptr, len, libc::MADV_HUGEPAGE) };
if ret == -1 && !MADV_HUGEPAGE_WARNED.swap(true, Ordering::Relaxed) {
eprintln!(
"lgalloc: MADV_HUGEPAGE failed: {}. Transparent huge pages may be disabled.",
std::io::Error::last_os_error()
);
}
}
let slice = unsafe { std::slice::from_raw_parts_mut(ptr.cast::<u8>(), len) };
Ok((ptr as usize, slice))
}
impl Drop for LocalSizeClass {
    /// Retires the thread-local state: removes the thread's registration, moves its free
    /// regions to the global queue and adds its counters to the global statistics.
    fn drop(&mut self) {
        let shared = self.size_class_state;

        // A poisoned lock is tolerated here: the registration is simply left in place.
        if let Ok(mut registered) = shared.stealers.write() {
            registered.remove(&self.thread_id);
        }

        while let Some(mem) = self.worker.pop() {
            shared.injector.push(mem);
        }

        let local = &self.stats;
        let global = &shared.alloc_stats;
        for (total, mine) in [
            (&global.allocations, &local.allocations),
            (&global.refill, &local.refill),
            (&global.slow_path, &local.slow_path),
            (&global.deallocations, &local.deallocations),
            (&global.clear_slow, &local.clear_slow),
            (&global.clear_eager, &local.clear_eager),
        ] {
            total.fetch_add(mine.load(Ordering::Relaxed), Ordering::Relaxed);
        }
    }
}
/// Runs `f` with shared access to the current thread's allocator state and returns its result.
fn thread_context<R, F: FnOnce(&ThreadLocalStealer) -> R>(f: F) -> R {
    WORKER.with(|cell| {
        let stealer = cell.borrow();
        f(&stealer)
    })
}
/// Allocates memory for at least `capacity` elements of type `T`.
///
/// Returns a pointer to the memory, the number of elements that actually fit, and the
/// [`Handle`] that must be passed to [`deallocate`] to release the memory. The request is
/// rounded up to at least one page and then to the next size class, so the returned capacity
/// can exceed `capacity`.
///
/// For zero-sized `T` the result is a dangling pointer with capacity `usize::MAX`; for a
/// `capacity` of zero it is a dangling pointer with capacity 0. Neither case touches the
/// allocator.
///
/// # Errors
///
/// * [`AllocError::InvalidSizeClass`] if the request does not map to a valid size class, or if
///   its size in bytes does not fit in `usize`.
/// * [`AllocError::Disabled`] if the allocator is not enabled.
/// * [`AllocError::UnalignedMemory`] if the obtained region is not aligned for `T`.
/// * Any error that occurs while obtaining a region.
pub fn allocate<T>(capacity: usize) -> Result<(NonNull<T>, usize, Handle), AllocError> {
    let elem_size = std::mem::size_of::<T>();
    if elem_size == 0 {
        return Ok((NonNull::dangling(), usize::MAX, Handle::dangling()));
    }
    if capacity == 0 {
        return Ok((NonNull::dangling(), 0, Handle::dangling()));
    }

    // A plain multiplication panics in debug builds and wraps in release builds, where the
    // wrapped (small) size would yield a region far smaller than requested. Such a request
    // cannot map to any valid size class, so report it the same way as other oversized
    // requests.
    let requested = elem_size
        .checked_mul(capacity)
        .ok_or(AllocError::InvalidSizeClass(usize::BITS as usize))?;
    // Round up to at least a page.
    let byte_len = std::cmp::max(page_size::get(), requested);
    let size_class = SizeClass::from_byte_size(byte_len)?;

    let handle = thread_context(|s| s.allocate(size_class))?;
    debug_assert_eq!(handle.len(), size_class.byte_size());
    let ptr: NonNull<T> = handle.as_non_null().cast();
    // Give the region back rather than handing out memory that is misaligned for `T`.
    if ptr.as_ptr().align_offset(std::mem::align_of::<T>()) != 0 {
        thread_context(move |s| s.deallocate(handle));
        return Err(AllocError::UnalignedMemory);
    }
    let actual_capacity = handle.len() / elem_size;
    Ok((ptr, actual_capacity, handle))
}
/// Releases the memory behind `handle` so that it can be handed out again.
///
/// Dangling handles refer to no memory and are ignored.
pub fn deallocate(handle: Handle) {
    if !handle.is_dangling() {
        thread_context(|s| s.deallocate(handle));
    }
}
/// State of the background thread that periodically clears handed-back regions.
struct BackgroundWorker {
/// Configuration currently in effect.
config: BackgroundWorkerConfig,
/// Channel on which new configurations arrive.
receiver: Receiver<BackgroundWorkerConfig>,
/// The process-wide allocator state.
global_stealer: &'static GlobalStealer,
/// Scratch queue that holds the regions taken in one batch while they are cleared.
worker: Worker<Handle>,
}
impl BackgroundWorker {
    /// Creates a worker that receives its configuration through `receiver`.
    ///
    /// Until a configuration arrives the interval is `Duration::MAX`, so no maintenance runs.
    fn new(receiver: Receiver<BackgroundWorkerConfig>) -> Self {
        let config = BackgroundWorkerConfig {
            interval: Duration::MAX,
            ..Default::default()
        };
        let global_stealer = GlobalStealer::get_static();
        let worker = Worker::new_fifo();
        Self {
            config,
            receiver,
            global_stealer,
            worker,
        }
    }

    /// Runs the worker loop until every sender of the configuration channel is gone.
    ///
    /// The loop waits for a new configuration until the next maintenance is due. A new
    /// configuration replaces the current one and restarts the schedule from now; a timeout
    /// triggers one maintenance run and schedules the next one `interval` after the previous
    /// deadline.
    fn run(&mut self) {
        let mut next_cleanup: Option<Instant> = None;
        loop {
            // Without a deadline (interval too large to represent), wait indefinitely.
            let timeout = next_cleanup.map_or(Duration::MAX, |next_cleanup| {
                next_cleanup.saturating_duration_since(Instant::now())
            });
            match self.receiver.recv_timeout(timeout) {
                Ok(config) => {
                    self.config = config;
                    next_cleanup = None;
                }
                Err(RecvTimeoutError::Disconnected) => break,
                Err(RecvTimeoutError::Timeout) => {
                    self.maintenance();
                }
            }
            next_cleanup = next_cleanup
                .unwrap_or_else(Instant::now)
                .checked_add(self.config.interval);
        }
    }

    /// Clears handed-back regions of every size class and records how many were cleared.
    fn maintenance(&self) {
        for (index, size_class_state) in self.global_stealer.size_classes.iter().enumerate() {
            let size_class = SizeClass::from_index(index);
            let count = self.clear(size_class, size_class_state, &self.worker);
            size_class_state
                .alloc_stats
                .clear_slow
                .fetch_add(count.try_into().expect("must fit"), Ordering::Relaxed);
        }
    }

    /// Moves regions of `size_class` from the global queue to the clean queue, clearing each
    /// one on the way, and returns the number of regions cleared.
    ///
    /// At most `clear_bytes` worth of regions, rounded up to whole regions, are processed.
    ///
    /// # Panics
    ///
    /// Panics if clearing a region fails.
    fn clear(
        &self,
        size_class: SizeClass,
        state: &SizeClassState,
        worker: &Worker<Handle>,
    ) -> usize {
        let byte_size = size_class.byte_size();
        let clear_bytes = self.config.clear_bytes;
        // Ceiling division without the intermediate `clear_bytes + byte_size - 1`, which
        // overflows for very large `clear_bytes` (e.g. `usize::MAX` meaning "no limit").
        let mut limit = clear_bytes / byte_size + usize::from(clear_bytes % byte_size != 0);
        let mut count = 0;
        let mut steal = Steal::Retry;
        while limit > 0 && !steal.is_empty() {
            steal = std::iter::repeat_with(|| state.injector.steal_batch_with_limit(worker, limit))
                .find(|s| !s.is_retry())
                .unwrap_or(Steal::Empty);
            while let Some(mut mem) = worker.pop() {
                match mem.clear() {
                    Ok(()) => count += 1,
                    Err(e) => panic!("Syscall failed: {e:?}"),
                }
                state.clean_injector.push(mem);
                limit -= 1;
            }
        }
        count
    }
}
/// Applies `config` to the allocator.
///
/// Only the options that are set in `config` are changed. If a background configuration is
/// given, it is sent to the running background worker; if no worker is running, or the running
/// one no longer accepts messages, a new worker thread is started with that configuration.
///
/// # Panics
///
/// Panics if the internal lock is poisoned or if the worker thread cannot be started.
pub fn lgalloc_set_config(config: &LgAlloc) {
    let stealer = GlobalStealer::get_static();

    if let Some(enabled) = config.enabled {
        LGALLOC_ENABLED.store(enabled, Ordering::Relaxed);
    }
    if let Some(eager_return) = config.eager_return {
        LGALLOC_EAGER_RETURN.store(eager_return, Ordering::Relaxed);
    }
    if let Some(growth_dampener) = config.growth_dampener {
        LGALLOC_GROWTH_DAMPENER.store(growth_dampener, Ordering::Relaxed);
    }
    if let Some(local_buffer_bytes) = config.local_buffer_bytes {
        LOCAL_BUFFER_BYTES.store(local_buffer_bytes, Ordering::Relaxed);
    }

    let Some(background) = config.background_config.clone() else {
        return;
    };
    let mut slot = stealer.background_sender.lock().expect("lock poisoned");
    // Offer the configuration to the running worker; keep it if it could not be delivered.
    let undelivered = match &*slot {
        Some((_, sender)) => sender.send(background).err().map(|failed| failed.0),
        None => Some(background),
    };
    if let Some(pending) = undelivered {
        let (sender, receiver) = std::sync::mpsc::channel();
        let mut worker = BackgroundWorker::new(receiver);
        let join_handle = std::thread::Builder::new()
            .name("lgalloc-0".to_string())
            .spawn(move || worker.run())
            .expect("thread started successfully");
        sender.send(pending).expect("Receiver exists");
        *slot = Some((join_handle, sender));
    }
}
/// Configuration of the background worker that clears handed-back regions.
#[derive(Default, Debug, Clone, Eq, PartialEq)]
pub struct BackgroundWorkerConfig {
/// Time between two maintenance runs.
pub interval: Duration,
/// Upper bound, in bytes, on the memory cleared per size class in one maintenance run;
/// rounded up to whole regions.
pub clear_bytes: usize,
}
/// Allocator configuration, applied with [`lgalloc_set_config`].
///
/// Every option is optional; options left as `None` keep their current value.
#[derive(Default, Clone, Eq, PartialEq)]
pub struct LgAlloc {
/// Whether the allocator is enabled.
pub enabled: Option<bool>,
/// Configuration of the background worker.
pub background_config: Option<BackgroundWorkerConfig>,
/// Whether regions that overflow a thread's local buffer are cleared when handed back.
pub eager_return: Option<bool>,
/// Dampener for the growth of newly mapped areas; larger values slow the growth.
pub growth_dampener: Option<usize>,
/// Size in bytes of each thread-local buffer, per size class.
pub local_buffer_bytes: Option<usize>,
}
impl LgAlloc {
    /// Creates a configuration in which no option is set.
    #[must_use]
    pub fn new() -> Self {
        Self {
            enabled: None,
            background_config: None,
            eager_return: None,
            growth_dampener: None,
            local_buffer_bytes: None,
        }
    }

    /// Records whether the allocator should be switched on or off.
    fn set_enabled(&mut self, enabled: bool) -> &mut Self {
        self.enabled = Some(enabled);
        self
    }

    /// Requests that the allocator be enabled.
    pub fn enable(&mut self) -> &mut Self {
        self.set_enabled(true)
    }

    /// Requests that the allocator be disabled.
    pub fn disable(&mut self) -> &mut Self {
        self.set_enabled(false)
    }

    /// Sets the configuration of the background worker.
    pub fn with_background_config(&mut self, config: BackgroundWorkerConfig) -> &mut Self {
        self.background_config = Some(config);
        self
    }

    /// Sets whether regions are cleared eagerly when handed back.
    pub fn eager_return(&mut self, eager_return: bool) -> &mut Self {
        self.eager_return = Some(eager_return);
        self
    }

    /// Sets the growth dampener for newly mapped areas.
    pub fn growth_dampener(&mut self, growth_dampener: usize) -> &mut Self {
        self.growth_dampener = Some(growth_dampener);
        self
    }

    /// Sets the size in bytes of each thread-local buffer.
    pub fn local_buffer_bytes(&mut self, local_buffer_bytes: usize) -> &mut Self {
        self.local_buffer_bytes = Some(local_buffer_bytes);
        self
    }
}
/// Collects the current statistics of every size class.
///
/// # Panics
///
/// Panics if an internal lock is poisoned.
pub fn lgalloc_stats() -> LgAllocStats {
    let size_class = GlobalStealer::get_static()
        .size_classes
        .iter()
        .enumerate()
        .map(|(index, state)| {
            let region_bytes = SizeClass::from_index(index).byte_size();
            (region_bytes, SizeClassStats::from(state))
        })
        .collect();
    LgAllocStats { size_class }
}
/// Statistics of the allocator, as returned by [`lgalloc_stats`].
#[derive(Debug)]
pub struct LgAllocStats {
/// One entry per size class: the region size in bytes and the statistics of that size class.
pub size_class: Vec<(usize, SizeClassStats)>,
}
/// Statistics of a single size class.
#[derive(Debug)]
pub struct SizeClassStats {
/// Number of areas mapped for this size class.
pub areas: usize,
/// Total size in bytes of all areas mapped for this size class.
pub area_total_bytes: usize,
/// Number of free regions: the sum of `thread_regions`, `global_regions` and `clean_regions`.
pub free_regions: usize,
/// Number of free regions in the global clean queue.
pub clean_regions: usize,
/// Number of free regions in the global queue.
pub global_regions: usize,
/// Number of free regions held in thread-local queues.
pub thread_regions: usize,
/// Total number of allocation requests.
pub allocations: u64,
/// Total number of allocation requests that took the slow path.
pub slow_path: u64,
/// Total number of refill attempts.
pub refill: u64,
/// Total number of regions handed back.
pub deallocations: u64,
/// Total number of regions cleared eagerly when handed back.
pub clear_eager_total: u64,
/// Total number of regions cleared by the background worker.
pub clear_slow_total: u64,
}
impl From<&SizeClassState> for SizeClassStats {
fn from(size_class_state: &SizeClassState) -> Self {
let areas = size_class_state.area_count.load(Ordering::Relaxed);
let area_total_bytes = size_class_state.total_bytes.load(Ordering::Relaxed);
let global_regions = size_class_state.injector.len();
let clean_regions = size_class_state.clean_injector.len();
let stealers = size_class_state.stealers.read().expect("lock poisoned");
let mut thread_regions = 0;
let mut allocations = 0;
let mut deallocations = 0;
let mut refill = 0;
let mut slow_path = 0;
let mut clear_eager_total = 0;
let mut clear_slow_total = 0;
for thread_state in stealers.values() {
thread_regions += thread_state.stealer.len();
let thread_stats = &*thread_state.alloc_stats;
allocations += thread_stats.allocations.load(Ordering::Relaxed);
deallocations += thread_stats.deallocations.load(Ordering::Relaxed);
refill += thread_stats.refill.load(Ordering::Relaxed);
slow_path += thread_stats.slow_path.load(Ordering::Relaxed);
clear_eager_total += thread_stats.clear_eager.load(Ordering::Relaxed);
clear_slow_total += thread_stats.clear_slow.load(Ordering::Relaxed);
}
let free_regions = thread_regions + global_regions + clean_regions;
let global_stats = &size_class_state.alloc_stats;
allocations += global_stats.allocations.load(Ordering::Relaxed);
deallocations += global_stats.deallocations.load(Ordering::Relaxed);
refill += global_stats.refill.load(Ordering::Relaxed);
slow_path += global_stats.slow_path.load(Ordering::Relaxed);
clear_eager_total += global_stats.clear_eager.load(Ordering::Relaxed);
clear_slow_total += global_stats.clear_slow.load(Ordering::Relaxed);
Self {
areas,
area_total_bytes,
free_regions,
global_regions,
clean_regions,
thread_regions,
allocations,
deallocations,
refill,
slow_path,
clear_eager_total,
clear_slow_total,
}
}
}