#![no_std]
use libc::{c_int, c_void, siginfo_t, size_t};
use mersenne_twister::MersenneTwister;
use rand::{Rng, SeedableRng};
use nix::sys::mman::{mmap, munmap, MapFlags, ProtFlags};
use nix::sys::signal::{sigaction, SaFlags, SigAction, SigHandler, SigSet, Signal};
use core::cell::UnsafeCell;
use core::mem::MaybeUninit;
use core::sync::atomic::{AtomicBool, AtomicU8, Ordering};
/// Mutable allocator state that only exists once `RmallocState::init` has run.
struct RmallocInner {
    /// Pseudo-random generator that supplies candidate page numbers.
    pub rng: MersenneTwister,
    /// Set to `true` by the fault handler while an allocation is in progress;
    /// cleared and re-read by the page probe around each test access.
    pub probe_fail: bool,
}
/// Global allocator state: two atomic flags plus lazily-initialised inner data.
struct RmallocState {
    /// Initialisation phase: `0` = not started, `1` = in progress, `2` = done.
    initializing: AtomicU8,
    /// Spin-lock flag held for the duration of one `mallocate` call; the fault
    /// handler also reads it to decide whether a fault belongs to a page probe.
    mallocating: AtomicBool,
    /// Storage for the inner state; left uninitialised until `init` writes it.
    state: UnsafeCell<MaybeUninit<RmallocInner>>,
}
// Allows `RmallocState` to live in a `static` even though it contains an `UnsafeCell`.
// NOTE(review): soundness depends on every access to `state` being gated by the
// `initializing` / `mallocating` flags; the fault handler writes `probe_fail`
// without taking the lock itself — confirm this is acceptable.
unsafe impl Sync for RmallocState {}
impl RmallocState {
const fn new() -> Self {
RmallocState {
initializing: AtomicU8::new(0),
mallocating: AtomicBool::new(false),
state: UnsafeCell::new(MaybeUninit::uninit()),
}
}
fn init(&self) {
let init_state = self.initializing.compare_and_swap(0, 1, Ordering::SeqCst);
if init_state == 2 {
return;
} else if init_state == 0 {
let rng: MersenneTwister = SeedableRng::from_seed(0xaaaA_aaAA_aaaA_aaAa);
let inner = RmallocInner {
rng,
probe_fail: false,
};
let mut segv_handler_mask = SigSet::empty();
segv_handler_mask.add(Signal::SIGBUS);
segv_handler_mask.add(Signal::SIGSEGV);
let sa = SigAction::new(
SigHandler::SigAction(handle_segv),
SaFlags::SA_RESTART | SaFlags::SA_SIGINFO,
segv_handler_mask,
);
unsafe {
let _sigsegv = sigaction(Signal::SIGSEGV, &sa).expect("can set sigsegv handler");
let _sigbus = sigaction(Signal::SIGBUS, &sa).expect("can set sigbus handler");
core::ptr::write(self.state.get().as_mut().unwrap().as_mut_ptr(), inner);
}
self.initializing.store(2, Ordering::SeqCst);
} else {
loop {
if self.initializing.load(Ordering::SeqCst) == 2 {
break;
}
}
}
}
fn mallocate(&self, page_count: usize) -> *mut c_void {
let _guard = self.begin_mallocate();
loop {
let guess = unsafe {
self.state
.get()
.as_mut()
.unwrap()
.as_mut_ptr()
.as_mut()
.unwrap()
.rng
.next_u64() as usize
};
let page_address = (guess.wrapping_mul(PAGE_SIZE)) & 0x0000_ffff_ffff_ffff;
if page_address.checked_add(page_count * PAGE_SIZE).is_none() {
continue;
}
let mut probe_failed = false;
for i in 0..page_count {
if !self.probe_page(page_address + (i * PAGE_SIZE)) {
probe_failed = true;
break;
}
}
if probe_failed {
continue;
}
return unsafe {
mmap(
page_address as *mut c_void,
page_count * PAGE_SIZE,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_ANON | MapFlags::MAP_PRIVATE,
-1,
0,
)
.expect("can mmap")
};
}
}
fn probe_page(&self, page_addr: usize) -> bool {
unsafe {
let ptr = (page_addr as *mut AtomicU8).as_ref().unwrap();
self.state
.get()
.as_mut()
.unwrap()
.as_mut_ptr()
.as_mut()
.unwrap()
.probe_fail = false;
ptr.fetch_xor(0, Ordering::SeqCst);
if self
.state
.get()
.as_mut()
.unwrap()
.as_mut_ptr()
.as_mut()
.unwrap()
.probe_fail
{
self.state
.get()
.as_mut()
.unwrap()
.as_mut_ptr()
.as_mut()
.unwrap()
.probe_fail = false;
ptr.load(Ordering::SeqCst);
}
!self
.state
.get()
.as_mut()
.unwrap()
.as_mut_ptr()
.as_mut()
.unwrap()
.probe_fail
}
}
fn begin_mallocate(&self) -> MallocGuard {
loop {
if !self.mallocating.swap(true, Ordering::SeqCst) {
break;
}
}
MallocGuard::of(self)
}
}
/// Guard representing ownership of the `mallocating` flag of an `RmallocState`.
struct MallocGuard<'a> {
    /// The allocator state whose flag is cleared when this guard is dropped.
    state: &'a RmallocState,
}
impl<'a> MallocGuard<'a> {
fn of(state: &'a RmallocState) -> Self {
MallocGuard { state }
}
}
impl Drop for MallocGuard<'_> {
    /// Releases the `mallocating` flag.
    ///
    /// # Panics
    /// Panics if the flag is not set when the guard is dropped.
    fn drop(&mut self) {
        // The flag must still be held by us; anything else is a logic error.
        assert!(self.state.mallocating.load(Ordering::SeqCst));
        self.state.mallocating.store(false, Ordering::SeqCst);
    }
}
pub extern "C" fn handle_segv(
_signum: c_int,
_siginfo_ptr: *mut siginfo_t,
_ucontext_ptr: *mut c_void,
) {
if RMALLOC_STATE.mallocating.load(Ordering::SeqCst) {
unsafe {
RMALLOC_STATE
.state
.get()
.as_mut()
.unwrap()
.as_mut_ptr()
.as_mut()
.unwrap()
.probe_fail = true;
}
} else {
}
}
/// Header stored at the very start of every mapping, directly in front of the
/// pointer handed to the caller. With the padding it occupies 64 bytes.
#[repr(C)]
struct AllocMetadata {
    /// Number of pages in the mapping, header included.
    page_count: u32,
    /// Filler bringing the header up to 64 bytes.
    _padding: [u8; 60],
}
/// Process-wide allocator state used by the exported C entry points and by the
/// fault handler; starts uninitialised and is set up by `RmallocState::init`.
static RMALLOC_STATE: RmallocState = RmallocState::new();
/// Size in bytes of one page as assumed throughout this file.
/// NOTE(review): hard-coded; confirm it matches the target's real page size.
const PAGE_SIZE: usize = 4096;
/// C `malloc`: returns a pointer to at least `sz` usable bytes, or null when
/// the request cannot be represented.
///
/// Each allocation is its own page-granular mapping with an `AllocMetadata`
/// header in front of the returned pointer. Requests whose size (plus header
/// and page rounding) overflows `usize`, or whose page count does not fit the
/// header's `u32` field, return null instead of wrapping to a smaller block.
#[no_mangle]
pub extern "C" fn malloc(sz: size_t) -> *mut c_void {
    RMALLOC_STATE.init();
    let header_len = core::mem::size_of::<AllocMetadata>();

    // Header + payload, rounded up to whole pages, with every step checked.
    let page_rounded_size = match sz
        .checked_add(header_len)
        .and_then(|total| total.checked_add(PAGE_SIZE - 1))
    {
        Some(padded) => padded & !(PAGE_SIZE - 1),
        None => return core::ptr::null_mut(),
    };
    let page_count = page_rounded_size / PAGE_SIZE;
    // The header stores the page count as `u32`; refuse what it cannot hold.
    if page_count == 0 || page_count > u32::MAX as usize {
        return core::ptr::null_mut();
    }

    let ptr = RMALLOC_STATE.mallocate(page_count);
    if ptr.is_null() {
        return ptr;
    }
    // SAFETY: `ptr` is the start of a fresh read/write mapping of
    // `page_count` pages, which is at least `header_len` bytes long.
    unsafe {
        core::ptr::write(
            ptr as *mut AllocMetadata,
            AllocMetadata {
                page_count: page_count as u32,
                _padding: [0; 60],
            },
        );
        (ptr as *mut u8).add(header_len) as *mut c_void
    }
}
/// C `calloc`: allocates zero-filled storage for `nmemb` elements of `size`
/// bytes each.
///
/// Returns null when `nmemb * size` overflows `usize` (instead of allocating a
/// wrapped, too-small block) or when the underlying `malloc` returns null.
#[no_mangle]
pub extern "C" fn calloc(nmemb: size_t, size: size_t) -> *mut c_void {
    let total_size = match nmemb.checked_mul(size) {
        Some(total) => total,
        None => return core::ptr::null_mut(),
    };
    let ptr = malloc(total_size);
    if ptr.is_null() {
        return ptr;
    }
    // SAFETY: `malloc` returned a non-null block of at least `total_size` bytes.
    unsafe {
        core::ptr::write_bytes(ptr as *mut u8, 0, total_size);
    }
    ptr
}
#[no_mangle]
pub extern "C" fn realloc(ptr: *mut c_void, size: size_t) -> *mut c_void {
let new_region = malloc(size);
if new_region == core::ptr::null_mut() {
return new_region;
}
if ptr == core::ptr::null_mut() {
return new_region;
}
let alloc_metadata =
(ptr as usize - core::mem::size_of::<AllocMetadata>()) as *mut AllocMetadata;
let old_page_count = unsafe { alloc_metadata.as_ref().unwrap().page_count as usize };
let old_sz = old_page_count * PAGE_SIZE - core::mem::size_of::<AllocMetadata>();
let copy_sz = core::cmp::min(old_sz, size);
for i in 0..copy_sz {
unsafe {
*(new_region as *mut u8).offset(i as isize) = *(ptr as *mut u8).offset(i as isize);
}
}
new_region
}
/// C `free`: unmaps the whole mapping that backs `ptr`. A null `ptr` is a no-op.
///
/// The length comes from the `AllocMetadata` header stored directly in front
/// of `ptr`.
///
/// # Panics
/// Panics if `munmap` reports an error.
#[no_mangle]
pub extern "C" fn free(ptr: *mut c_void) {
    if ptr.is_null() {
        return;
    }
    // Step back over the header to reach the start of the mapping.
    let mapping_start = ptr as usize - core::mem::size_of::<AllocMetadata>();
    let header = mapping_start as *mut AllocMetadata;
    unsafe {
        let pages = header.as_ref().unwrap().page_count as usize;
        let mapping_len = pages.wrapping_mul(PAGE_SIZE);
        munmap(header as *mut c_void, mapping_len).unwrap();
    }
}