use std::cell::RefCell;
use std::collections::HashMap;
use crate::error::{Error, Result};
use crate::state::Lua;
use crate::sys::*;
/// Raw allocator signature used by the VM: `(ud, ptr, osize, nsize) -> new_ptr`.
type RawAllocFn = unsafe extern "C" fn(
    *mut core::ffi::c_void,
    *mut core::ffi::c_void,
    usize,
    usize,
) -> *mut core::ffi::c_void;

/// Per-VM state for the byte-limiting allocator wrapper.
///
/// A boxed instance is installed as the VM's allocator userdata; the allocator it
/// replaced is remembered here so every request can be forwarded to it.
struct MemoryControl {
    /// Maximum total bytes; `0` disables the check.
    limit: usize,
    /// The VM's global state, read to obtain the current `totalbytes`.
    global: *mut core::ffi::c_void,
    /// Allocator function that was installed before the wrapper.
    base: RawAllocFn,
    /// Userdata that `base` expects as its first argument.
    base_ud: *mut core::ffi::c_void,
}
thread_local! {
    /// Memory-category name -> id tables, keyed by a VM's global-state address.
    static MEMORY_CATEGORIES: RefCell<HashMap<*mut core::ffi::c_void, HashMap<String, u8>>> =
        RefCell::default();
    /// Allocation-limit controls, keyed by a VM's global-state address. Boxing gives each
    /// control a stable address that can be handed to the VM as allocator userdata.
    static MEMORY_CONTROLS: RefCell<HashMap<*mut core::ffi::c_void, Box<MemoryControl>>> =
        RefCell::default();
}
/// Returns the address of `state`'s global state, used as the per-VM key in the
/// thread-local registries.
///
/// # Safety
///
/// `state` must be a valid, dereferenceable `lua_State` pointer.
unsafe fn global_key(state: *mut lua_State) -> *mut core::ffi::c_void {
    // SAFETY: the caller guarantees `state` is valid for reads.
    let global = unsafe { (*state).global };
    global.cast()
}
/// Allocator installed by `Lua::set_memory_limit`.
///
/// A growing request (`nsize > osize`) is refused with a null return when it would push
/// the VM's `totalbytes` past the configured limit. A limit of `0` means unlimited.
/// Every other request, including frees and shrinks, is forwarded unchanged to the
/// allocator that was installed before this one.
///
/// # Safety
///
/// `ud` must point to a live `MemoryControl` whose `global` field points at the VM's
/// global state; the remaining arguments follow the VM allocator contract.
unsafe extern "C" fn limited_alloc(
    ud: *mut core::ffi::c_void,
    ptr: *mut core::ffi::c_void,
    osize: usize,
    nsize: usize,
) -> *mut core::ffi::c_void {
    // SAFETY: `ud` is the `MemoryControl` installed alongside this allocator.
    let control = unsafe { &*ud.cast::<MemoryControl>() };
    let enforced = control.limit != 0 && nsize > osize;
    if enforced {
        let global = control.global as *const luaur_vm::records::global_state::global_State;
        // SAFETY: `control.global` was taken from the VM this allocator serves.
        let in_use = unsafe { (*global).totalbytes };
        // Saturating math keeps a bogus `osize` from wrapping the projection.
        let after = in_use.saturating_sub(osize).saturating_add(nsize);
        if after > control.limit {
            return core::ptr::null_mut();
        }
    }
    // SAFETY: forwarding to the previously installed allocator with its own userdata.
    unsafe { (control.base)(control.base_ud, ptr, osize, nsize) }
}
impl Lua {
pub fn set_memory_limit(&self, limit: usize) -> Result<usize> {
let state = self.state();
unsafe {
let key = global_key(state);
let g = (*state).global;
let prev = MEMORY_CONTROLS.with(|m| {
let mut map = m.borrow_mut();
if let Some(ctrl) = map.get_mut(&key) {
let prev = ctrl.limit;
ctrl.limit = limit;
Some(prev)
} else {
None
}
});
if let Some(prev) = prev {
return Ok(prev);
}
let base = (*g).frealloc.expect("VM allocator must be set");
let base_ud = (*g).ud;
let ctrl = Box::new(MemoryControl {
limit,
global: g as *mut core::ffi::c_void,
base,
base_ud,
});
let ctrl_ptr = (&*ctrl) as *const MemoryControl as *mut core::ffi::c_void;
MEMORY_CONTROLS.with(|m| {
m.borrow_mut().insert(key, ctrl);
});
(*g).ud = ctrl_ptr;
(*g).frealloc = Some(limited_alloc);
Ok(0)
}
}
pub fn set_memory_category(&self, name: &str) -> Result<()> {
if name.is_empty() || !name.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'_') {
return Err(Error::runtime(format!(
"invalid memory category name: {name:?}"
)));
}
let state = self.state();
let key = unsafe { global_key(state) };
let id = MEMORY_CATEGORIES.with(|m| -> Result<u8> {
let mut map = m.borrow_mut();
let cats = map.entry(key).or_insert_with(|| {
let mut h = HashMap::new();
h.insert("main".to_string(), 0u8);
h
});
if let Some(&id) = cats.get(name) {
return Ok(id);
}
let next = cats.len();
if next >= 255 {
return Err(Error::runtime(
"too many memory categories (limit 255)".to_string(),
));
}
let id = next as u8;
cats.insert(name.to_string(), id);
Ok(id)
})?;
lua_setmemcat(state, id as c_int);
Ok(())
}
pub fn memory_category_bytes(&self, name: &str) -> Option<usize> {
let state = self.state();
let key = unsafe { global_key(state) };
let id =
MEMORY_CATEGORIES.with(|m| m.borrow().get(&key).and_then(|c| c.get(name).copied()))?;
unsafe {
let g = (*state).global;
Some((*g).memcatbytes[id as usize])
}
}
}