use std::cell::RefCell;
use std::rc::Rc;
use inkwell::memory_manager::McjitMemoryManager;
use libc::{c_int, c_uint, c_void, size_t};
/// Minimum length of each code-arena mapping: 64 pages of 4 KiB (256 KiB).
/// Larger requests get a mapping sized to fit them instead.
const DEFAULT_CODE_ARENA_BYTES: size_t = 64 * 4096;
/// One anonymous `mmap` mapping owned by the memory manager.
///
/// `base` and `size` are exactly what was used for the `mmap` call, so the pair can be
/// passed back unchanged to `mprotect` and `munmap`.
#[derive(Debug)]
struct MmapRegion {
    /// Start of the mapping, as returned by `mmap`.
    base: *mut u8,
    /// Length of the mapping in bytes (rounded up to a multiple of 4096 by the allocators).
    size: size_t,
}
/// MCJIT memory manager that bump-allocates code sections out of large anonymous
/// mappings, so code sections share a region whenever they fit, and gives every data
/// section its own mapping. All mappings are requested with `MAP_32BIT`.
///
/// Code memory is mapped read/write and switched to read/execute by `finalize_memory`;
/// data memory stays read/write. `destroy` unmaps every region.
#[derive(Debug)]
pub struct ContiguousCodeMemoryManager {
    /// Code mappings in creation order; new code is carved from the last one.
    code_regions: Rc<RefCell<Vec<MmapRegion>>>,
    /// Offset just past the most recent allocation in the last code region.
    code_cursor: Rc<RefCell<size_t>>,
    /// One mapping per data section, in allocation order.
    data_regions: Rc<RefCell<Vec<MmapRegion>>>,
    /// Set to `true` by `finalize_memory` once the code regions were made executable.
    finalized: Rc<RefCell<bool>>,
}
impl ContiguousCodeMemoryManager {
    /// Creates a manager that owns no mappings yet; memory is mapped lazily on the
    /// first allocation request.
    pub fn new() -> Self {
        Self {
            code_regions: Rc::new(RefCell::new(Vec::new())),
            code_cursor: Rc::new(RefCell::new(0)),
            data_regions: Rc::new(RefCell::new(Vec::new())),
            finalized: Rc::new(RefCell::new(false)),
        }
    }

    /// Returns a pointer to `need` bytes of writable code memory aligned to at least
    /// `max(alignment, 16)` bytes, or a null pointer if `mmap` fails.
    ///
    /// Requests are bump-allocated from the most recently mapped code region. A new
    /// region (at least `DEFAULT_CODE_ARENA_BYTES`, rounded up to whole 4 KiB pages,
    /// mapped read/write with `MAP_32BIT`) is created when the current region cannot
    /// hold the request or has already been sealed by `finalize_memory`.
    fn ensure_code_region(&self, need: size_t, alignment: c_uint) -> *mut u8 {
        let page: size_t = 4096;
        // The mask arithmetic below requires a power of two; rounding up still
        // satisfies the caller's request.
        let align = std::cmp::max(alignment as size_t, 16).next_power_of_two();

        let mut regions = self.code_regions.borrow_mut();
        let mut cursor = self.code_cursor.borrow_mut();
        let mut finalized = self.finalized.borrow_mut();

        // `finalize_memory` mprotects every code region to PROT_READ | PROT_EXEC, so
        // writing a new section into the tail of a sealed region would fault.
        // Only reuse the last region while it is still writable.
        if !*finalized {
            if let Some(region) = regions.last() {
                let base_addr = region.base as usize;
                // Align the absolute address rather than the offset so alignments
                // larger than the page size are honored too.
                let aligned_addr = (base_addr + *cursor + align - 1) & !(align - 1);
                let offset = aligned_addr - base_addr;
                if offset + need <= region.size {
                    *cursor = offset + need;
                    // SAFETY: offset + need <= region.size, so the pointer stays inside
                    // the mapping that starts at region.base.
                    return unsafe { region.base.add(offset) };
                }
            }
        }

        // mmap only guarantees page alignment; reserve slack for larger alignments.
        let slack = if align > page { align } else { 0 };
        let want = std::cmp::max(need + slack, DEFAULT_CODE_ARENA_BYTES);
        let want = (want + page - 1) & !(page - 1);
        // SAFETY: anonymous private mapping at a kernel-chosen address; no existing
        // memory is touched.
        let base = unsafe {
            libc::mmap(
                std::ptr::null_mut(),
                want,
                libc::PROT_READ | libc::PROT_WRITE,
                libc::MAP_ANONYMOUS | libc::MAP_PRIVATE | libc::MAP_32BIT,
                -1 as c_int,
                0,
            )
        };
        if base == libc::MAP_FAILED {
            return std::ptr::null_mut();
        }
        let base = base as *mut u8;
        let base_addr = base as usize;
        let offset = ((base_addr + align - 1) & !(align - 1)) - base_addr;
        *cursor = offset + need;
        regions.push(MmapRegion { base, size: want });
        // The newest region is writable; later requests may be served from it until
        // the next `finalize_memory` call seals it.
        *finalized = false;
        // SAFETY: offset is 0 when align <= page and below `slack` otherwise, so
        // offset + need <= want and the pointer lies inside the fresh mapping.
        unsafe { base.add(offset) }
    }
}
impl Default for ContiguousCodeMemoryManager {
fn default() -> Self {
Self::new()
}
}
impl McjitMemoryManager for ContiguousCodeMemoryManager {
    /// Serves a code section from the shared code arena.
    /// An alignment of 0 is treated as a request for 16 bytes.
    fn allocate_code_section(
        &mut self,
        size: size_t,
        alignment: c_uint,
        _section_id: c_uint,
        _section_name: &str,
    ) -> *mut u8 {
        let effective_alignment = match alignment {
            0 => 16,
            requested => requested,
        };
        self.ensure_code_region(size, effective_alignment)
    }

    /// Maps a dedicated read/write anonymous region (requested with `MAP_32BIT`) for
    /// one data section. The length is `max(size, alignment)` rounded up to whole 4 KiB
    /// pages, and an alignment of 0 is treated as 8. Returns null if `mmap` fails.
    fn allocate_data_section(
        &mut self,
        size: size_t,
        alignment: c_uint,
        _section_id: c_uint,
        _section_name: &str,
        _is_read_only: bool,
    ) -> *mut u8 {
        const PAGE: size_t = 4096;
        let effective_alignment: size_t = match alignment {
            0 => 8,
            requested => requested as size_t,
        };
        let len = (size.max(effective_alignment) + PAGE - 1) & !(PAGE - 1);

        // SAFETY: anonymous private mapping at a kernel-chosen address.
        let mapping = unsafe {
            libc::mmap(
                std::ptr::null_mut(),
                len,
                libc::PROT_READ | libc::PROT_WRITE,
                libc::MAP_ANONYMOUS | libc::MAP_PRIVATE | libc::MAP_32BIT,
                -1 as c_int,
                0,
            )
        };
        if mapping == libc::MAP_FAILED {
            return std::ptr::null_mut();
        }

        let start = mapping.cast::<u8>();
        self.data_regions.borrow_mut().push(MmapRegion {
            base: start,
            size: len,
        });
        start
    }

    /// Switches every code region to read+execute. The first failing `mprotect` aborts
    /// with an error; `finalized` is only set once all regions succeeded.
    fn finalize_memory(&mut self) -> Result<(), String> {
        self.code_regions.borrow().iter().try_for_each(|region| {
            // SAFETY: `region` describes a live mapping created by this manager.
            let rc = unsafe {
                libc::mprotect(
                    region.base as *mut c_void,
                    region.size,
                    libc::PROT_READ | libc::PROT_EXEC,
                )
            };
            if rc == 0 {
                Ok(())
            } else {
                // Capture errno immediately, before anything else can clobber it.
                let errno = std::io::Error::last_os_error();
                Err(format!(
                    "mprotect RX failed for code region {:p} (size {}): {errno}",
                    region.base, region.size
                ))
            }
        })?;
        *self.finalized.borrow_mut() = true;
        Ok(())
    }

    /// Unmaps all code regions, then all data regions, leaving both lists empty.
    /// `munmap` failures are ignored (best-effort teardown).
    fn destroy(&mut self) {
        let mut code = self.code_regions.borrow_mut();
        let mut data = self.data_regions.borrow_mut();
        for region in code.drain(..).chain(data.drain(..)) {
            // SAFETY: each region was produced by `mmap` in this manager and is
            // removed from its list here, so it is unmapped exactly once.
            unsafe {
                libc::munmap(region.base.cast::<c_void>(), region.size);
            }
        }
    }
}