use core::sync::atomic::AtomicU64;
use core::sync::atomic::Ordering;
use std::collections::HashMap;
use std::ffi::OsStr;
use std::sync::Mutex;
use std::sync::PoisonError;
use std::thread::ThreadId;
use mlua::Function;
use mlua::Lua;
use mlua::RegistryKey;
use mlua::Table;
use mlua::Value;
use crate::config::LuaConfig;
use polyplug::Runtime;
use polyplug::error::LoaderError;
use polyplug::loader::BundleLoader;
use polyplug::loader::BundleSource;
use polyplug::loader::ManifestData;
use polyplug::logger::LoggerHandle;
use polyplug_abi::AbiError;
use polyplug_abi::AbiErrorCode;
use polyplug_abi::CallArena;
use polyplug_abi::DispatchType;
use polyplug_abi::GuestContractInstance;
use polyplug_abi::GuestContractInterface;
use polyplug_abi::HostApi;
use polyplug_abi::PluginDescriptor;
use polyplug_abi::StringView;
use polyplug_abi::SupportedLanguage;
use polyplug_abi::VmLoaderData;
use polyplug_abi::dispatch::dispatch_mechanisms::DispatchMechanisms;
use polyplug_abi::dispatch::vm_dispatch::VmDispatch;
use polyplug_abi::types::LogLevel;
use polyplug_abi::types::Version;
use polyplug_utils::BundleId;
use polyplug_utils::GuestContractId;
/// Directory named by the `POLYPLUG_GUEST_LUA_DIR` environment variable at build time.
/// Its `?.lua` pattern is prepended to every bundle VM's `package.path`.
const GUEST_LUA_DIR: &str = env!("POLYPLUG_GUEST_LUA_DIR");
/// Directory named by the `POLYPLUG_ABI_LUA_DIR` environment variable at build time.
/// Its `?.lua` pattern is prepended to every bundle VM's `package.path`.
const ABI_LUA_DIR: &str = env!("POLYPLUG_ABI_LUA_DIR");
/// Scope guard pairing `Runtime::push_init_bundle_id` with `Runtime::pop_init_bundle_id`.
///
/// The id is pushed when the guard is created and popped when it is dropped. This means the
/// pop also runs on early `?` returns out of the function holding the guard.
struct InitBundleGuard<'rt> {
    runtime: &'rt Runtime,
}

impl<'rt> InitBundleGuard<'rt> {
    /// Pushes `bundle_id` on `runtime` and returns the guard that will pop it.
    ///
    /// The push happens before the guard exists. If the push panics, no unmatched pop runs.
    fn enter(runtime: &'rt Runtime, bundle_id: u64) -> Self {
        runtime.push_init_bundle_id(bundle_id);
        InitBundleGuard { runtime }
    }
}

impl Drop for InitBundleGuard<'_> {
    fn drop(&mut self) {
        let owner: &Runtime = self.runtime;
        owner.pop_init_bundle_id();
    }
}
/// Per-contract state behind the `VmLoaderData` pointer handed to the host.
///
/// One value exists per contract registered by a bundle. All contracts of a bundle share the
/// same Lua state (each holds a clone of the `Lua` handle). Fields drop in declaration order.
pub struct LuaLoaderData {
    /// Handle to the bundle's Lua state, used for registry access.
    pub _vm: Lua,
    /// Guest functions of the contract, indexed by the dispatch `fn_id` (slot 0 first).
    pub functions: Vec<Function>,
    /// Allocation helper passed as the last argument of every guest call.
    pub arena_alloc: Function,
    /// Guest `factory(host)` used to build a new instance value.
    pub factory: Function,
    /// Registry key of the instance value used when a dispatch carries instance id 0.
    pub default_impl: RegistryKey,
    /// Instances created through `create_instance`, keyed by their nonzero id.
    pub instances: Mutex<HashMap<u64, RegistryKey>>,
    /// Next instance id to hand out. It starts at 1 so that 0 keeps meaning "default instance".
    pub next_id: AtomicU64,
    /// Identity of the contract this state serves.
    pub contract_id: GuestContractId,
    /// Threads currently inside a dispatch into this VM, used to detect same-thread reentry.
    pub in_dispatch_threads: Mutex<Vec<ThreadId>>,
    /// Held for the duration of each guest function call.
    pub dispatch_lock: Mutex<()>,
    /// Destination for loader error messages.
    pub logger: LoggerHandle,
}
/// Owning handle to one contract's [`LuaLoaderData`].
///
/// The data is boxed so its heap address stays put while the handle itself is moved into
/// collections. That address is what the host receives as `VmLoaderData`.
struct LuaVm(Box<LuaLoaderData>);

impl LuaVm {
    /// Stable address of the boxed loader data.
    fn as_ptr(&self) -> *const LuaLoaderData {
        let inner: &LuaLoaderData = &self.0;
        inner as *const LuaLoaderData
    }

    /// Borrow of the loader data (test-only).
    #[cfg(test)]
    fn data(&self) -> &LuaLoaderData {
        self.0.as_ref()
    }
}
/// Drop guard that removes the current thread's id from a dispatch-tracking list.
///
/// `lua_dispatch_impl` records the calling thread before running guest code. This guard
/// removes that record on every way out of the function.
struct LuaDispatchGuard<'a> {
    threads: &'a Mutex<Vec<ThreadId>>,
}

impl Drop for LuaDispatchGuard<'_> {
    fn drop(&mut self) {
        let current: ThreadId = std::thread::current().id();
        let mut tracked: std::sync::MutexGuard<'_, Vec<ThreadId>> =
            self.threads.lock().unwrap_or_else(PoisonError::into_inner);
        // Only the first matching entry is removed. Order is irrelevant, so the O(1)
        // swap_remove is used.
        let slot: Option<usize> = tracked.iter().position(|id: &ThreadId| *id == current);
        if let Some(index) = slot {
            tracked.swap_remove(index);
        }
    }
}
/// C-ABI `create_instance` entry point for Lua contracts.
///
/// Calls the contract's Lua `factory` with the host pointer (as an `i64`). It stores the
/// returned value in the Lua registry under a fresh nonzero id, then writes an instance whose
/// `data` is that id.
///
/// `GuestContractInstance::null()` is written instead in these cases:
/// - the loader data is null;
/// - the calling thread is currently dispatching into this VM;
/// - the factory call or registry insertion fails (both are logged).
///
/// Nothing is written when `out_instance` is null.
///
/// # Safety
///
/// `out_instance` must be null or valid for a write. `loader_data.data` must be null or
/// point to a live `LuaLoaderData`.
unsafe extern "C" fn lua_create_instance(
    loader_data: VmLoaderData,
    host: *const HostApi,
    _args: *const (),
    out_instance: *mut GuestContractInstance,
) {
    if out_instance.is_null() {
        return;
    }
    let raw_data: *const LuaLoaderData = loader_data.data as *const LuaLoaderData;
    if raw_data.is_null() {
        // SAFETY: `out_instance` was checked non-null above.
        unsafe { out_instance.write(GuestContractInstance::null()) };
        return;
    }
    // SAFETY: non-null, and points to live loader data per this function's contract.
    let data: &LuaLoaderData = unsafe { &*raw_data };
    let current_thread: ThreadId = std::thread::current().id();
    // The tracking lock is only held for this statement.
    let reentrant: bool = data
        .in_dispatch_threads
        .lock()
        .unwrap_or_else(PoisonError::into_inner)
        .contains(&current_thread);
    let produced: GuestContractInstance = if reentrant {
        GuestContractInstance::null()
    } else {
        let host_arg: i64 = host as usize as i64;
        match data.factory.call::<Value>(host_arg) {
            Err(e) => {
                data.logger.log(LogLevel::Error, "loader.lua", || {
                    format!("Lua create_instance: factory call failed: {e}")
                });
                GuestContractInstance::null()
            }
            Ok(built) => match data._vm.create_registry_value(built) {
                Err(e) => {
                    data.logger.log(LogLevel::Error, "loader.lua", || {
                        format!("Lua create_instance: registry_value failed: {e}")
                    });
                    GuestContractInstance::null()
                }
                Ok(registry_key) => {
                    let fresh_id: u64 = data.next_id.fetch_add(1, Ordering::Relaxed);
                    let mut registry: std::sync::MutexGuard<'_, HashMap<u64, RegistryKey>> =
                        data.instances.lock().unwrap_or_else(PoisonError::into_inner);
                    registry.insert(fresh_id, registry_key);
                    GuestContractInstance {
                        data: fresh_id as usize as *mut core::ffi::c_void,
                        contract_id: data.contract_id,
                    }
                }
            },
        }
    };
    // SAFETY: `out_instance` was checked non-null above.
    unsafe { out_instance.write(produced) };
}
/// C-ABI `destroy_instance` entry point for Lua contracts.
///
/// Removes the instance's registry key from `instances`. The removed `RegistryKey` is then
/// dropped. Id 0 (the default instance, never stored in the map) and null loader data are
/// ignored.
///
/// # Safety
///
/// `_loader_data.data` must be null or point to a live `LuaLoaderData`.
unsafe extern "C" fn lua_destroy_instance(
    _loader_data: VmLoaderData,
    _host: *const HostApi,
    instance: GuestContractInstance,
) {
    let instance_id: u64 = instance.data as usize as u64;
    let raw_data: *const LuaLoaderData = _loader_data.data as *const LuaLoaderData;
    if instance_id == 0 || raw_data.is_null() {
        return;
    }
    // SAFETY: non-null, and points to live loader data per this function's contract.
    let data: &LuaLoaderData = unsafe { &*raw_data };
    let mut registry: std::sync::MutexGuard<'_, HashMap<u64, RegistryKey>> =
        data.instances.lock().unwrap_or_else(PoisonError::into_inner);
    registry.remove(&instance_id);
}
/// C-ABI dispatch entry point for Lua contracts.
///
/// Forwards everything to `lua_dispatch_impl`. It then writes the resulting `AbiError`
/// into `out_err` when that pointer is non-null. The call runs even when `out_err` is null.
///
/// # Safety
///
/// Same contract as `lua_dispatch_impl`. In addition, `out_err` must be null or valid for
/// a write.
unsafe extern "C" fn lua_dispatch(
    loader_data: VmLoaderData,
    instance: GuestContractInstance,
    fn_id: u32,
    args: *const (),
    out: *mut (),
    arena: *mut CallArena,
    out_err: *mut AbiError,
) {
    // SAFETY: arguments are forwarded unchanged under the same contract.
    let outcome: AbiError =
        unsafe { lua_dispatch_impl(loader_data, instance, fn_id, args, out, arena) };
    if out_err.is_null() {
        return;
    }
    // SAFETY: non-null and writable per this function's contract. `write` avoids
    // reading or dropping whatever the slot held before.
    unsafe { out_err.write(outcome) };
}
unsafe fn lua_dispatch_impl(
loader_data: VmLoaderData,
instance: GuestContractInstance,
fn_id: u32,
args: *const (),
out: *mut (),
arena: *mut CallArena,
) -> AbiError {
let data: &LuaLoaderData = unsafe { &*(loader_data.data as *const LuaLoaderData) };
let this_thread: ThreadId = std::thread::current().id();
{
let mut threads: std::sync::MutexGuard<'_, Vec<ThreadId>> = data
.in_dispatch_threads
.lock()
.unwrap_or_else(PoisonError::into_inner);
if threads.contains(&this_thread) {
drop(threads);
return AbiError {
code: AbiErrorCode::ReentrantCall as u32,
message: StringView::null(),
};
}
threads.push(this_thread);
}
let _dispatch_guard: LuaDispatchGuard<'_> = LuaDispatchGuard {
threads: &data.in_dispatch_threads,
};
let lua_fn: &Function = match data.functions.get(fn_id as usize) {
Some(f) => f,
None => {
return AbiError {
code: AbiErrorCode::FunctionNotAvailable as u32,
message: StringView::null(),
};
}
};
let instance_id: u64 = instance.data as usize as u64;
let instance_value: Value = if instance_id == 0 {
match data._vm.registry_value::<Value>(&data.default_impl) {
Ok(v) => v,
Err(_) => {
return AbiError {
code: AbiErrorCode::Generic as u32,
message: StringView::null(),
};
}
}
} else {
let map: std::sync::MutexGuard<'_, HashMap<u64, RegistryKey>> = data
.instances
.lock()
.unwrap_or_else(PoisonError::into_inner);
match map.get(&instance_id) {
Some(key) => match data._vm.registry_value::<Value>(key) {
Ok(v) => v,
Err(_) => {
return AbiError {
code: AbiErrorCode::Generic as u32,
message: StringView::null(),
};
}
},
None => {
return AbiError {
code: AbiErrorCode::FunctionNotAvailable as u32,
message: StringView::null(),
};
}
}
};
let args_i64: i64 = args as usize as i64;
let out_i64: i64 = out as usize as i64;
let call_result: Result<(), mlua::Error> = {
let _dispatch_lock: std::sync::MutexGuard<'_, ()> = data
.dispatch_lock
.lock()
.unwrap_or_else(PoisonError::into_inner);
let arena_i64: i64 = arena as usize as i64;
lua_fn.call::<()>((
instance_value,
args_i64,
out_i64,
arena_i64,
data.arena_alloc.clone(),
))
};
match call_result {
Ok(()) => AbiError::ok(),
Err(e) => {
data.logger.log(LogLevel::Error, "loader.lua", || {
format!("Lua function call failed: {e}")
});
AbiError {
code: AbiErrorCode::Generic as u32,
message: StringView::null(),
}
}
}
}
/// Bundle loader that runs Lua bundles in mlua VMs.
///
/// Each load creates a fresh Lua state per bundle and one [`LuaLoaderData`] per registered
/// contract. The loader keeps those alive until the bundle is unloaded or superseded by a
/// reload.
pub struct LuaLoader {
    /// Loader configuration supplied at construction.
    pub config: LuaConfig,
    /// Bundle id -> VM state of every contract registered by the currently loaded version.
    live: Mutex<HashMap<BundleId, Vec<LuaVm>>>,
    /// Number of VMs handed to `schedule_reclaim` so far.
    scheduled_reclaims: AtomicU64,
}
impl LuaLoader {
pub fn new(config: LuaConfig) -> Self {
Self {
config,
live: Mutex::new(HashMap::new()),
scheduled_reclaims: AtomicU64::new(0),
}
}
fn schedule_reclaim(&self, state: Vec<LuaVm>) {
for vm in state {
self.scheduled_reclaims.fetch_add(1, Ordering::Relaxed);
crossbeam_epoch::pin().defer(move || drop(vm));
}
}
fn prepend_package_field(
lua: &Lua,
bundle: &str,
field: &str,
entries: &str,
) -> Result<(), LoaderError> {
let package: Table = lua
.globals()
.get::<Table>("package")
.map_err(|e: mlua::Error| LoaderError::InitFailed {
bundle: bundle.to_owned(),
error: format!("Lua VM init failed: missing package table: {e}"),
})?;
let current: String = package
.get::<String>(field)
.unwrap_or_else(|_: mlua::Error| String::new());
let combined: String = if current.is_empty() {
entries.to_owned()
} else {
format!("{entries};{current}")
};
package
.set(field, combined)
.map_err(|e: mlua::Error| LoaderError::InitFailed {
bundle: bundle.to_owned(),
error: format!("Lua VM init failed: package.{field} update failed: {e}"),
})
}
fn build_arena_alloc(
lua: &Lua,
bundle: &str,
host_interface: *const HostApi,
) -> Result<Function, LoaderError> {
let host_addr: usize = host_interface as usize;
lua.create_function(
move |_lua_ctx: &Lua, (size, arena_addr): (u32, i64)| -> mlua::Result<i64> {
let arena: *mut CallArena = arena_addr as usize as *mut CallArena;
let ptr: *mut u8 = if arena.is_null() {
let host: *const HostApi = host_addr as *const HostApi;
if host.is_null() {
core::ptr::null_mut()
} else {
unsafe { ((*host).alloc)(host, size as usize, 1) }
}
} else {
unsafe { (*arena).alloc(size as usize, 1) }
};
Ok(ptr as usize as i64)
},
)
.map_err(|e: mlua::Error| LoaderError::InitFailed {
bundle: bundle.to_owned(),
error: format!("Lua VM init failed: arena allocator creation failed: {e}"),
})
}
fn resolve_source(
manifest: &ManifestData,
source: &BundleSource,
) -> Result<(String, String, Option<String>), LoaderError> {
match source {
BundleSource::Path(_) => {
let bundle_path: std::path::PathBuf = if !manifest.file.is_empty() {
manifest.path.join(&manifest.file)
} else {
return Err(LoaderError::ManifestMissingFile {
bundle: manifest.name.clone(),
});
};
if !bundle_path.exists() {
return Err(LoaderError::InitFailed {
bundle: manifest.name.clone(),
error: format!(
"Lua script load failed at {}: file does not exist",
bundle_path.display()
),
});
}
let source_text: String =
std::fs::read_to_string(&bundle_path).map_err(|e: std::io::Error| {
LoaderError::InitFailed {
bundle: manifest.name.clone(),
error: format!(
"Lua script load failed at {}: {}",
bundle_path.display(),
e
),
}
})?;
let chunk_name: String = bundle_path
.file_name()
.map(|n: &OsStr| n.to_string_lossy().into_owned())
.unwrap_or_else(|| bundle_path.display().to_string());
let bundle_dir_str: String = manifest.path.to_string_lossy().into_owned();
Ok((source_text, chunk_name, Some(bundle_dir_str)))
}
BundleSource::Code(code) => Ok((code.clone(), manifest.name.clone(), None)),
BundleSource::Bytes(bytes) => {
let source_text: String =
String::from_utf8(bytes.clone()).map_err(|_: std::string::FromUtf8Error| {
LoaderError::InvalidSourceEncoding {
loader: "lua",
source_kind: source.kind(),
bundle: manifest.name.clone(),
}
})?;
Ok((source_text, manifest.name.clone(), None))
}
}
}
fn load_inner(
&self,
manifest: &ManifestData,
source: &BundleSource,
runtime: &Runtime,
) -> Result<(), LoaderError> {
let (source_text, chunk_name, bundle_dir): (String, String, Option<String>) =
Self::resolve_source(manifest, source)?;
let bundle_id: u64 = manifest.id;
let bundle_dir_str: String = bundle_dir
.clone()
.unwrap_or_else(|| manifest.path.to_string_lossy().into_owned());
let lua: Lua = unsafe { Lua::unsafe_new() };
let guest_dir_fwd: String = GUEST_LUA_DIR.replace('\\', "/");
let abi_dir_fwd: String = ABI_LUA_DIR.replace('\\', "/");
let cpath_ext: &str = if cfg!(windows) { "dll" } else { "so" };
let path_entries: String = match &bundle_dir {
Some(dir) => {
let bundle_dir_fwd: String = dir.replace('\\', "/");
format!(
"{bundle_dir_fwd}/?.lua;{bundle_dir_fwd}/?.init.lua;{guest_dir_fwd}/?.lua;{abi_dir_fwd}/?.lua"
)
}
None => format!("{guest_dir_fwd}/?.lua;{abi_dir_fwd}/?.lua"),
};
Self::prepend_package_field(&lua, &manifest.name, "path", &path_entries)?;
if let Some(dir) = &bundle_dir {
let bundle_dir_fwd: String = dir.replace('\\', "/");
let cpath_entries: String = format!("{bundle_dir_fwd}/?.{cpath_ext}");
Self::prepend_package_field(&lua, &manifest.name, "cpath", &cpath_entries)?;
}
lua.load(&source_text)
.set_name(&chunk_name)
.exec()
.map_err(|e: mlua::Error| LoaderError::InitFailed {
bundle: manifest.name.clone(),
error: format!("Lua script load failed for {}: {}", chunk_name, e),
})?;
let bundle_name: String = chunk_name;
let init_fn: Function =
lua.globals()
.get::<Function>("polyplug_init")
.map_err(|_: mlua::Error| LoaderError::InitFailed {
bundle: bundle_name.clone(),
error: format!(
"Lua plugin missing polyplug_init function: bundle={}",
bundle_name
),
})?;
let host_interface: *const HostApi = runtime.as_context_ptr();
let arena_alloc: Function = Self::build_arena_alloc(&lua, &bundle_name, host_interface)?;
let _init_window: InitBundleGuard<'_> = InitBundleGuard::enter(runtime, bundle_id);
let bundle_path_static: &'static str = Box::leak(bundle_dir_str.clone().into_boxed_str());
let ctx: polyplug_abi::BundleInitContext = polyplug_abi::BundleInitContext {
bundle_path: polyplug_abi::StringView {
ptr: bundle_path_static.as_ptr(),
len: bundle_path_static.len(),
},
bundle_id,
};
let host_interface_i64: i64 = host_interface as usize as i64;
let ctx_ptr: i64 = &ctx as *const polyplug_abi::BundleInitContext as i64;
let (handlers, abi_error): (Table, Table) = init_fn
.call::<(Table, Table)>((host_interface_i64, ctx_ptr))
.map_err(|e: mlua::Error| LoaderError::InitFailed {
bundle: bundle_name.clone(),
error: format!("Lua polyplug_init failed: {}", e),
})?;
let init_code: u32 = abi_error.get::<u32>("code").unwrap_or(0_u32);
if init_code != AbiErrorCode::Ok as u32 {
let init_message: Option<String> = abi_error
.get::<String>("message")
.ok()
.filter(|s: &String| !s.is_empty());
let error: String = match init_message {
Some(msg) => msg,
None => format!("Lua polyplug_init returned error code {}", init_code),
};
return Err(LoaderError::InitFailed {
bundle: bundle_name.clone(),
error,
});
}
let mut bundle_vm_state: Vec<LuaVm> = Vec::new();
let mut registered: u32 = 0_u32;
for pair in handlers.pairs::<String, Table>() {
let (contract_name_str, entry): (String, Table) =
pair.map_err(|e: mlua::Error| LoaderError::InitFailed {
bundle: bundle_name.clone(),
error: format!("Lua handlers iteration error: {}", e),
})?;
let contract_version: u32 = entry.get::<u32>("contract_version").unwrap_or(1_u32);
let plugin_name_str: String = entry
.get::<String>("plugin_name")
.unwrap_or_else(|_: mlua::Error| bundle_name.clone());
let functions_table: Table =
entry.get::<Table>("functions").map_err(|e: mlua::Error| {
LoaderError::InitFailed {
bundle: bundle_name.clone(),
error: format!(
"Lua handlers error: missing functions table for contract '{}': {}",
contract_name_str, e
),
}
})?;
let function_count: u32 = {
let mut count: u32 = 0_u32;
let mut idx: i64 = 0_i64;
loop {
let v: Value = functions_table.get::<Value>(idx).unwrap_or(Value::Nil);
if v == Value::Nil {
break;
}
count += 1;
idx += 1;
}
count
};
let mut lua_functions: Vec<Function> = Vec::with_capacity(function_count as usize);
for slot_idx in 0..function_count {
let lua_fn: Function = functions_table.get::<Function>(slot_idx as i64).map_err(
|e: mlua::Error| LoaderError::InitFailed {
bundle: bundle_name.clone(),
error: format!(
"Lua function slot {} error for contract '{}': {}",
slot_idx, contract_name_str, e
),
},
)?;
lua_functions.push(lua_fn);
}
let factory: Function =
entry.get::<Function>("factory").map_err(|e: mlua::Error| {
LoaderError::InitFailed {
bundle: bundle_name.clone(),
error: format!(
"Lua handlers error: contract '{}' needs a `factory` function: {}",
contract_name_str, e
),
}
})?;
let host_i64: i64 = host_interface as usize as i64;
let default_value: Value =
factory.call::<Value>(host_i64).map_err(|e: mlua::Error| {
LoaderError::InitFailed {
bundle: bundle_name.clone(),
error: format!(
"factory for contract '{}' failed building the default instance: {}",
contract_name_str, e
),
}
})?;
let default_impl: RegistryKey =
lua.create_registry_value(default_value)
.map_err(|e: mlua::Error| LoaderError::InitFailed {
bundle: bundle_name.clone(),
error: format!(
"failed to register default impl for contract '{}': {}",
contract_name_str, e
),
})?;
let cid: GuestContractId = GuestContractId::new(&contract_name_str, contract_version);
let loader_data: LuaVm = LuaVm(Box::new(LuaLoaderData {
_vm: lua.clone(),
functions: lua_functions,
arena_alloc: arena_alloc.clone(),
factory,
default_impl,
instances: Mutex::new(HashMap::new()),
next_id: AtomicU64::new(1),
contract_id: cid,
in_dispatch_threads: Mutex::new(Vec::new()),
dispatch_lock: Mutex::new(()),
logger: runtime.logger(),
}));
let loader_data_ptr: *const LuaLoaderData = loader_data.as_ptr();
bundle_vm_state.push(loader_data);
let plugin_interface: GuestContractInterface = GuestContractInterface {
contract_id: cid,
contract_version: Version {
major: contract_version,
minor: 0,
patch: 0,
},
dispatch_type: DispatchType::VirtualMachine,
create_instance: lua_create_instance,
destroy_instance: lua_destroy_instance,
dispatch: DispatchMechanisms {
vm: VmDispatch {
call: lua_dispatch,
loader_data: VmLoaderData {
data: loader_data_ptr as *mut LuaLoaderData as *mut core::ffi::c_void,
},
},
},
};
let interface_for_reg: GuestContractInterface = plugin_interface;
let static_interface: *const GuestContractInterface =
&interface_for_reg as *const GuestContractInterface;
let contract_display_name: String =
format!("{}@{}", contract_name_str, contract_version);
let descriptor: PluginDescriptor = PluginDescriptor {
name: StringView {
ptr: plugin_name_str.as_ptr(),
len: plugin_name_str.len(),
},
contract_name: StringView {
ptr: contract_display_name.as_ptr(),
len: contract_display_name.len(),
},
version: Version {
major: contract_version,
minor: 0,
patch: 0,
},
};
let mut reg_result: AbiError = AbiError::ok();
unsafe {
((*host_interface).register_guest_contract)(
host_interface,
&descriptor as *const PluginDescriptor,
static_interface,
&mut reg_result,
)
};
if !reg_result.is_ok() {
self.schedule_reclaim(bundle_vm_state);
return Err(LoaderError::InitFailed {
bundle: bundle_name,
error: format!(
"register_guest_contract error for contract '{}': code={:?}",
contract_name_str, reg_result.code
),
});
}
registered += 1;
}
if registered == 0 {
return Err(LoaderError::InitFailed {
bundle: bundle_name,
error: "Lua plugin registered no contracts: polyplug_init returned an empty registrations table".to_owned(),
});
}
let superseded: Option<Vec<LuaVm>> = {
let mut live: std::sync::MutexGuard<'_, HashMap<BundleId, Vec<LuaVm>>> =
self.live.lock().unwrap_or_else(PoisonError::into_inner);
live.insert(BundleId::from_u64(bundle_id), bundle_vm_state)
};
if let Some(old_state) = superseded {
self.schedule_reclaim(old_state);
}
Ok(())
}
#[cfg(test)]
fn live_vm_count(&self, bundle_id: BundleId) -> usize {
let live: std::sync::MutexGuard<'_, HashMap<BundleId, Vec<LuaVm>>> =
self.live.lock().unwrap_or_else(PoisonError::into_inner);
live.get(&bundle_id).map(Vec::len).unwrap_or(0)
}
#[cfg(test)]
fn scheduled_reclaim_count(&self) -> u64 {
self.scheduled_reclaims.load(Ordering::Relaxed)
}
}
impl BundleLoader for LuaLoader {
    /// Identifier of this loader: `"lua"`.
    fn loader_name(&self) -> &'static str {
        "lua"
    }

    /// Language handled by this loader.
    fn loader_language(&self) -> SupportedLanguage {
        SupportedLanguage::Lua
    }

    /// Lua bundles support in-place reloads.
    fn supports_hot_reload(&self) -> bool {
        true
    }

    /// Loads a bundle from `source`, which may be a path, inline code or raw bytes.
    fn load(
        &self,
        manifest: &ManifestData,
        source: &BundleSource,
        runtime: &Runtime,
    ) -> Result<(), LoaderError> {
        self.load_inner(manifest, source, runtime)
    }

    /// Reloads the bundle from disk, using the manifest's directory as a path source.
    fn reload(&self, manifest: &ManifestData, runtime: &Runtime) -> Result<(), LoaderError> {
        let from_disk: BundleSource = BundleSource::Path(manifest.path.clone());
        self.load_inner(manifest, &from_disk, runtime)
    }

    /// Drops the bundle from the live map and schedules its VMs for deferred reclamation.
    ///
    /// Unknown bundle ids are a no-op. The live-map lock is released before reclamation is
    /// scheduled.
    fn unload(&self, bundle_id: BundleId, _runtime: &Runtime) -> Result<(), LoaderError> {
        let removed: Option<Vec<LuaVm>> = self
            .live
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
            .remove(&bundle_id);
        if let Some(state) = removed {
            self.schedule_reclaim(state);
        }
        Ok(())
    }
}
// Unit tests for the Lua loader.
// - Bundle lifecycle: load, unload, reload and deferred reclamation.
// - Dispatch behaviour: normal calls, same-thread reentry, cross-thread concurrency and
//   per-call arena isolation.
// - `package.path` handling.
#[cfg(test)]
mod tests {
    // Tests use `expect` so setup failures abort with a descriptive message.
    #![allow(clippy::expect_used)]
    use core::sync::atomic::AtomicUsize;
    use core::sync::atomic::Ordering;
    use std::sync::Arc;
    use std::sync::Barrier;
    use super::*;

    /// The loader reports its name as `lua`.
    #[test]
    fn lua_loader_name() {
        let loader: LuaLoader = LuaLoader::new(LuaConfig::default());
        assert_eq!(loader.loader_name(), "lua");
    }

    /// Minimal bundle registering one contract, `test.unload`, with a single no-op function
    /// in slot 0 and a factory returning an empty table.
    fn unload_plugin_script() -> &'static [u8] {
        br#"
local function impl_noop(_instance, _a, _o, _arena_ptr, _arena_alloc) end
function polyplug_init(_host, _ctx)
local registrations = {
["test.unload"] = {
contract_version = 1,
plugin_name = "test-unload",
factory = function(_host) return {} end,
functions = { [0] = impl_noop },
},
}
return registrations, { code = 0 }
end
"#
    }

    /// Writes `unload_plugin_script` as `bundle.lua` into a fresh temp dir and returns a
    /// manifest pointing at it. The returned `TempDir` must stay alive while the bundle is
    /// used.
    fn write_unload_bundle(name: &str) -> (tempfile::TempDir, ManifestData) {
        let dir: tempfile::TempDir = tempfile::tempdir().expect("tempdir");
        std::fs::write(dir.path().join("bundle.lua"), unload_plugin_script())
            .expect("write bundle.lua");
        let manifest: ManifestData = ManifestData {
            id: polyplug_utils::bundle_id(name),
            name: name.to_owned(),
            loader: "lua".to_owned(),
            file: "bundle.lua".to_owned(),
            path: dir.path().to_path_buf(),
            version: String::new(),
            provides: Vec::new(),
            function_count: std::collections::HashMap::new(),
            dependencies: Vec::new(),
            needs_reinit_on_dep_reload: false,
            bundle_dependencies: Vec::new(),
        };
        (dir, manifest)
    }

    /// Load installs one live VM; unload removes it and schedules exactly one reclaim.
    #[test]
    fn unload_removes_live_and_schedules_reclaim() {
        let loader: LuaLoader = LuaLoader::new(LuaConfig::default());
        let runtime: std::sync::Arc<polyplug::Runtime> = polyplug::runtime::RuntimeBuilder::new()
            .loader(LuaLoader::new(LuaConfig::default()))
            .build()
            .expect("runtime build must succeed");
        let (_dir, manifest): (tempfile::TempDir, ManifestData) =
            write_unload_bundle("lua_unload_quiescent");
        let bundle_id: BundleId = BundleId::from_u64(manifest.id);
        loader
            .load(
                &manifest,
                &BundleSource::Path(manifest.path.clone()),
                &runtime,
            )
            .expect("load must succeed");
        assert_eq!(
            loader.live_vm_count(bundle_id),
            1,
            "one contract's VM state must be owned after load"
        );
        loader
            .unload(bundle_id, &runtime)
            .expect("unload must succeed");
        assert_eq!(
            loader.live_vm_count(bundle_id),
            0,
            "unload must remove the bundle's VM state from the live map"
        );
        assert_eq!(
            loader.scheduled_reclaim_count(),
            1,
            "unload must schedule the VM state for epoch-deferred reclaim"
        );
    }

    /// Five load/unload cycles keep at most one live entry and schedule one reclaim per
    /// unload. The runtime registry is invalidated between cycles (presumably so the next
    /// registration is not rejected as a duplicate).
    #[test]
    fn unload_load_loop_is_bounded() {
        let loader: LuaLoader = LuaLoader::new(LuaConfig::default());
        let runtime: std::sync::Arc<polyplug::Runtime> = polyplug::runtime::RuntimeBuilder::new()
            .loader(LuaLoader::new(LuaConfig::default()))
            .build()
            .expect("runtime build must succeed");
        let (_dir, manifest): (tempfile::TempDir, ManifestData) =
            write_unload_bundle("lua_unload_loop");
        let bundle_id: BundleId = BundleId::from_u64(manifest.id);
        for _ in 0..5 {
            loader
                .load(
                    &manifest,
                    &BundleSource::Path(manifest.path.clone()),
                    &runtime,
                )
                .expect("load must succeed");
            assert_eq!(
                loader.live_vm_count(bundle_id),
                1,
                "live map must hold exactly one entry per load"
            );
            loader
                .unload(bundle_id, &runtime)
                .expect("unload must succeed");
            runtime
                .registry()
                .invalidate_bundle(bundle_id)
                .expect("invalidate must succeed");
            assert_eq!(
                loader.live_vm_count(bundle_id),
                0,
                "unload must reclaim the entry each iteration"
            );
        }
        assert_eq!(
            loader.scheduled_reclaim_count(),
            5,
            "each of the 5 unloads must schedule its VM state for epoch-deferred reclaim"
        );
    }

    /// A second load of the same bundle replaces the live entry and schedules the
    /// superseded VM set for reclaim.
    #[test]
    fn reload_replaces_live_and_reclaims_superseded_vm() {
        let loader: LuaLoader = LuaLoader::new(LuaConfig::default());
        let runtime: std::sync::Arc<polyplug::Runtime> = polyplug::runtime::RuntimeBuilder::new()
            .loader(LuaLoader::new(LuaConfig::default()))
            .build()
            .expect("runtime build must succeed");
        let (_dir, manifest): (tempfile::TempDir, ManifestData) =
            write_unload_bundle("lua_reload_reclaim");
        let bundle_id: BundleId = BundleId::from_u64(manifest.id);
        loader
            .load(
                &manifest,
                &BundleSource::Path(manifest.path.clone()),
                &runtime,
            )
            .expect("first load must succeed");
        assert_eq!(
            loader.live_vm_count(bundle_id),
            1,
            "first load installs exactly one live VM entry"
        );
        assert_eq!(
            loader.scheduled_reclaim_count(),
            0,
            "nothing is superseded by the first load"
        );
        runtime
            .registry()
            .invalidate_bundle(bundle_id)
            .expect("invalidate must succeed");
        loader
            .load(
                &manifest,
                &BundleSource::Path(manifest.path.clone()),
                &runtime,
            )
            .expect("second load (reload) must succeed");
        assert_eq!(
            loader.live_vm_count(bundle_id),
            1,
            "reload replaces the live VM set — the live map must not grow"
        );
        assert_eq!(
            loader.scheduled_reclaim_count(),
            1,
            "reload must schedule the superseded VM set for epoch-deferred reclaim"
        );
    }

    /// Marks the current thread as dispatching into the bundle's VM, then unloads. Unload
    /// must still succeed and defer reclamation rather than block or fail.
    #[test]
    fn unload_schedules_reclaim_even_when_in_flight() {
        let loader: LuaLoader = LuaLoader::new(LuaConfig::default());
        let runtime: std::sync::Arc<polyplug::Runtime> = polyplug::runtime::RuntimeBuilder::new()
            .loader(LuaLoader::new(LuaConfig::default()))
            .build()
            .expect("runtime build must succeed");
        let (_dir, manifest): (tempfile::TempDir, ManifestData) =
            write_unload_bundle("lua_unload_deferred");
        let bundle_id: BundleId = BundleId::from_u64(manifest.id);
        loader
            .load(
                &manifest,
                &BundleSource::Path(manifest.path.clone()),
                &runtime,
            )
            .expect("load must succeed");
        {
            let live: std::sync::MutexGuard<'_, HashMap<BundleId, Vec<LuaVm>>> =
                loader.live.lock().unwrap_or_else(PoisonError::into_inner);
            let state: &Vec<LuaVm> = live.get(&bundle_id).expect("bundle must be live");
            let mut threads: std::sync::MutexGuard<'_, Vec<ThreadId>> = state[0]
                .data()
                .in_dispatch_threads
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            threads.push(std::thread::current().id());
        }
        loader
            .unload(bundle_id, &runtime)
            .expect("unload must succeed even when marked in-flight");
        assert_eq!(
            loader.live_vm_count(bundle_id),
            0,
            "unload must remove the bundle from the live map"
        );
        assert_eq!(
            loader.scheduled_reclaim_count(),
            1,
            "unload must schedule epoch-deferred reclaim even when marked in-flight"
        );
    }

    /// Builds a leaked `LuaLoaderData` around `vm` for direct dispatch tests.
    ///
    /// The data uses a trivial factory, an empty-table default instance and an arena
    /// allocator with a null host. It is leaked via `Box::into_raw`, so the returned
    /// reference is `'static` and the raw `VmLoaderData` pointer never dangles.
    fn make_loader_data(
        vm: Lua,
        functions: Vec<Function>,
    ) -> (VmLoaderData, &'static LuaLoaderData) {
        let factory: Function = vm
            .load("return function(_host) return {} end")
            .eval::<Function>()
            .expect("create trivial factory");
        let default_impl: RegistryKey = vm
            .create_registry_value(Value::Table(
                vm.create_table().expect("create default impl table"),
            ))
            .expect("register default impl");
        let arena_alloc: Function =
            LuaLoader::build_arena_alloc(&vm, "make_loader_data", core::ptr::null())
                .expect("build arena allocator");
        let boxed: Box<LuaLoaderData> = Box::new(LuaLoaderData {
            _vm: vm,
            functions,
            arena_alloc,
            factory,
            default_impl,
            instances: Mutex::new(HashMap::new()),
            next_id: AtomicU64::new(1),
            contract_id: GuestContractId::new("test.make_loader_data", 1),
            in_dispatch_threads: Mutex::new(Vec::new()),
            dispatch_lock: Mutex::new(()),
            logger: LoggerHandle::default_stderr(),
        });
        let ptr: *mut LuaLoaderData = Box::into_raw(boxed);
        let data_ref: &'static LuaLoaderData = unsafe { &*ptr };
        let vm_loader_data: VmLoaderData = VmLoaderData {
            data: ptr as *mut core::ffi::c_void,
        };
        (vm_loader_data, data_ref)
    }

    /// A plain call on the default instance succeeds and leaves no thread tracked.
    #[test]
    fn lua_dispatch_normal_call_succeeds() {
        let lua: Lua = unsafe { Lua::unsafe_new() };
        let noop: Function = lua
            .create_function(|_, (_instance, _a, _o): (Value, i64, i64)| Ok(()))
            .expect("create_function should succeed");
        let (vm_loader_data, data_ref): (VmLoaderData, &'static LuaLoaderData) =
            make_loader_data(lua, vec![noop]);
        let mut out_buf: i32 = 0;
        let err: AbiError = unsafe {
            lua_dispatch_impl(
                vm_loader_data,
                GuestContractInstance::null(),
                0,
                core::ptr::null(),
                &mut out_buf as *mut i32 as *mut (),
                core::ptr::null_mut(),
            )
        };
        assert!(err.is_ok(), "normal dispatch should return Ok");
        assert!(
            data_ref
                .in_dispatch_threads
                .lock()
                .expect("tracking mutex must not be poisoned")
                .is_empty(),
            "thread tracking must be empty after a normal dispatch"
        );
    }

    /// The guest function dispatches into its own VM from the same thread. The nested
    /// call must report `ReentrantCall` (recorded in a Lua global), the outer call must
    /// succeed, and a later call must work normally.
    #[test]
    fn lua_dispatch_reentrant_call_is_rejected_and_vm_recovers() {
        let lua: Lua = unsafe { Lua::unsafe_new() };
        // Filled with the loader-data address once it exists, so the guest fn can reach it.
        let loader_data_cell: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
        let cell_for_fn: Arc<AtomicUsize> = Arc::clone(&loader_data_cell);
        let reentrant_fn: Function = lua
            .create_function(
                move |lua_ctx: &Lua, (_instance, _a, _o): (Value, i64, i64)| {
                    let ptr_usize: usize = cell_for_fn.load(Ordering::Acquire);
                    let vm_loader_data: VmLoaderData = VmLoaderData {
                        data: ptr_usize as *mut core::ffi::c_void,
                    };
                    let nested: AbiError = unsafe {
                        lua_dispatch_impl(
                            vm_loader_data,
                            GuestContractInstance::null(),
                            0,
                            core::ptr::null(),
                            core::ptr::null_mut(),
                            core::ptr::null_mut(),
                        )
                    };
                    lua_ctx.globals().set("_nested_code", nested.code as i64)?;
                    Ok(())
                },
            )
            .expect("create_function should succeed");
        let (vm_loader_data, data_ref): (VmLoaderData, &'static LuaLoaderData) =
            make_loader_data(lua, vec![reentrant_fn]);
        loader_data_cell.store(vm_loader_data.data as usize, Ordering::Release);
        let outer: AbiError = unsafe {
            lua_dispatch_impl(
                vm_loader_data,
                GuestContractInstance::null(),
                0,
                core::ptr::null(),
                core::ptr::null_mut(),
                core::ptr::null_mut(),
            )
        };
        assert!(outer.is_ok(), "outer dispatch should complete Ok");
        let nested_code: i64 = data_ref
            ._vm
            .globals()
            .get::<i64>("_nested_code")
            .expect("nested code global must be set by the guest fn");
        assert_eq!(
            nested_code,
            AbiErrorCode::ReentrantCall as i64,
            "nested same-VM dispatch must return ReentrantCall"
        );
        assert!(
            data_ref
                .in_dispatch_threads
                .lock()
                .expect("tracking mutex must not be poisoned")
                .is_empty(),
            "thread tracking must be empty after the outer dispatch returns"
        );
        let recovered: AbiError = unsafe {
            lua_dispatch_impl(
                vm_loader_data,
                GuestContractInstance::null(),
                0,
                core::ptr::null(),
                core::ptr::null_mut(),
                core::ptr::null_mut(),
            )
        };
        assert!(
            recovered.is_ok(),
            "VM must remain usable after a rejected reentrant call"
        );
    }

    /// Cross-thread calls are not treated as reentrant.
    ///
    /// Thread A parks inside guest fn 0 between two barriers. A second thread then
    /// dispatches fn 1 into the same VM. Releasing A lets both finish, and both must
    /// succeed, leaving no thread tracked.
    #[test]
    fn lua_dispatch_cross_thread_concurrent_call_succeeds() {
        let lua: Lua = unsafe { Lua::unsafe_new() };
        let entered: Arc<Barrier> = Arc::new(Barrier::new(2));
        let release: Arc<Barrier> = Arc::new(Barrier::new(2));
        let entered_for_fn: Arc<Barrier> = Arc::clone(&entered);
        let release_for_fn: Arc<Barrier> = Arc::clone(&release);
        let blocking_fn: Function = lua
            .create_function(
                move |_lua_ctx: &Lua, (_instance, _a, _o): (Value, i64, i64)| {
                    entered_for_fn.wait();
                    release_for_fn.wait();
                    Ok(())
                },
            )
            .expect("create_function should succeed");
        let noop_fn: Function = lua
            .create_function(|_lua_ctx: &Lua, (_instance, _a, _o): (Value, i64, i64)| Ok(()))
            .expect("create_function should succeed");
        let (vm_loader_data, data_ref): (VmLoaderData, &'static LuaLoaderData) =
            make_loader_data(lua, vec![blocking_fn, noop_fn]);
        // Raw pointers are not Send; the address crosses threads as a usize.
        let data_addr: usize = vm_loader_data.data as usize;
        let handle: std::thread::JoinHandle<AbiError> = std::thread::spawn(move || {
            let vm_loader_data_a: VmLoaderData = VmLoaderData {
                data: data_addr as *mut core::ffi::c_void,
            };
            unsafe {
                lua_dispatch_impl(
                    vm_loader_data_a,
                    GuestContractInstance::null(),
                    0,
                    core::ptr::null(),
                    core::ptr::null_mut(),
                    core::ptr::null_mut(),
                )
            }
        });
        entered.wait();
        let main_handle: std::thread::JoinHandle<AbiError> = std::thread::spawn(move || {
            let vm_loader_data_b: VmLoaderData = VmLoaderData {
                data: data_addr as *mut core::ffi::c_void,
            };
            unsafe {
                lua_dispatch_impl(
                    vm_loader_data_b,
                    GuestContractInstance::null(),
                    1,
                    core::ptr::null(),
                    core::ptr::null_mut(),
                    core::ptr::null_mut(),
                )
            }
        });
        release.wait();
        let a_result: AbiError = handle.join().expect("thread A must not panic");
        let b_result: AbiError = main_handle
            .join()
            .expect("concurrent thread must not panic");
        assert!(a_result.is_ok(), "the initial dispatch must succeed");
        assert!(
            b_result.is_ok(),
            "a concurrent cross-thread dispatch must succeed, not return ReentrantCall (got code {})",
            b_result.code
        );
        assert!(
            data_ref
                .in_dispatch_threads
                .lock()
                .expect("tracking mutex must not be poisoned")
                .is_empty(),
            "thread tracking must be empty after both dispatches return"
        );
    }

    /// Two threads repeatedly dispatch into the same VM, each with its own `CallArena`.
    /// Every allocation made through `arena_alloc` must land inside the calling thread's
    /// arena buffer, i.e. the arena argument is never mixed up between concurrent calls.
    #[test]
    fn lua_dispatch_concurrent_arena_returns_stay_isolated() {
        let lua: Lua = unsafe { Lua::unsafe_new() };
        // Guest fn: allocates 64 bytes via `arena_alloc` and writes the address to `out`.
        let make_alloc_fn = |lua: &Lua| -> Function {
            lua.create_function(
                |_lua_ctx: &Lua,
                 (_instance, _a, out_ptr, arena_ptr, arena_alloc): (
                    Value,
                    i64,
                    i64,
                    i64,
                    Function,
                )|
                 -> mlua::Result<()> {
                    let addr: i64 = arena_alloc.call::<i64>((64_u32, arena_ptr))?;
                    let out: *mut i64 = out_ptr as usize as *mut i64;
                    unsafe { *out = addr };
                    Ok(())
                },
            )
            .expect("create_function should succeed")
        };
        let fn0: Function = make_alloc_fn(&lua);
        let fn1: Function = make_alloc_fn(&lua);
        let (vm_loader_data, _data_ref): (VmLoaderData, &'static LuaLoaderData) =
            make_loader_data(lua, vec![fn0, fn1]);
        let data_addr: usize = vm_loader_data.data as usize;
        let buf_a: &'static mut [u8] = Box::leak(vec![0_u8; 4096].into_boxed_slice());
        let buf_b: &'static mut [u8] = Box::leak(vec![0_u8; 4096].into_boxed_slice());
        // Address ranges each thread's allocations must fall into.
        let a_lo: usize = buf_a.as_ptr() as usize;
        let a_hi: usize = a_lo + buf_a.len();
        let b_lo: usize = buf_b.as_ptr() as usize;
        let b_hi: usize = b_lo + buf_b.len();
        let arena_a: &'static mut CallArena =
            Box::leak(Box::new(CallArena::new(buf_a, core::ptr::null())));
        let arena_b: &'static mut CallArena =
            Box::leak(Box::new(CallArena::new(buf_b, core::ptr::null())));
        let arena_a_addr: usize = arena_a as *mut CallArena as usize;
        let arena_b_addr: usize = arena_b as *mut CallArena as usize;
        const ITERS: usize = 2_000;
        let start: Arc<Barrier> = Arc::new(Barrier::new(2));
        let start_a: Arc<Barrier> = Arc::clone(&start);
        let start_b: Arc<Barrier> = Arc::clone(&start);
        let handle_a: std::thread::JoinHandle<Result<(), String>> = std::thread::spawn(move || {
            start_a.wait();
            for i in 0..ITERS {
                let arena: &mut CallArena = unsafe { &mut *(arena_a_addr as *mut CallArena) };
                arena.reset();
                let mut out: i64 = 0;
                let vm: VmLoaderData = VmLoaderData {
                    data: data_addr as *mut core::ffi::c_void,
                };
                let err: AbiError = unsafe {
                    lua_dispatch_impl(
                        vm,
                        GuestContractInstance::null(),
                        0,
                        core::ptr::null(),
                        &mut out as *mut i64 as *mut (),
                        arena_a_addr as *mut CallArena,
                    )
                };
                if !err.is_ok() {
                    return Err(format!("A iter {i}: dispatch failed code={}", err.code));
                }
                let p: usize = out as usize;
                if !(p >= a_lo && p < a_hi) {
                    return Err(format!(
                        "A iter {i}: allocation {p:#x} escaped arena A buffer [{a_lo:#x}, {a_hi:#x}) — threaded arena was misattributed"
                    ));
                }
            }
            Ok(())
        });
        let handle_b: std::thread::JoinHandle<Result<(), String>> = std::thread::spawn(move || {
            start_b.wait();
            for i in 0..ITERS {
                let arena: &mut CallArena = unsafe { &mut *(arena_b_addr as *mut CallArena) };
                arena.reset();
                let mut out: i64 = 0;
                let vm: VmLoaderData = VmLoaderData {
                    data: data_addr as *mut core::ffi::c_void,
                };
                let err: AbiError = unsafe {
                    lua_dispatch_impl(
                        vm,
                        GuestContractInstance::null(),
                        1,
                        core::ptr::null(),
                        &mut out as *mut i64 as *mut (),
                        arena_b_addr as *mut CallArena,
                    )
                };
                if !err.is_ok() {
                    return Err(format!("B iter {i}: dispatch failed code={}", err.code));
                }
                let p: usize = out as usize;
                if !(p >= b_lo && p < b_hi) {
                    return Err(format!(
                        "B iter {i}: allocation {p:#x} escaped arena B buffer [{b_lo:#x}, {b_hi:#x}) — threaded arena was misattributed"
                    ));
                }
            }
            Ok(())
        });
        let a_outcome: Result<(), String> = handle_a.join().expect("thread A must not panic");
        let b_outcome: Result<(), String> = handle_b.join().expect("thread B must not panic");
        if let Err(e) = a_outcome {
            panic!("{e}");
        }
        if let Err(e) = b_outcome {
            panic!("{e}");
        }
    }

    /// A directory name containing Lua syntax is stored verbatim in `package.path` and is
    /// never executed as code.
    #[test]
    fn prepend_package_field_does_not_execute_injected_code() {
        let lua: Lua = unsafe { Lua::unsafe_new() };
        let malicious: &str = "/tmp/evil\";_G._INJECTED=true;package.path=\"x/?.lua";
        let entries: String = format!("{malicious}/?.lua");
        LuaLoader::prepend_package_field(&lua, "test-bundle", "path", &entries)
            .expect("prepend_package_field should succeed for any path bytes");
        let injected: Value = lua
            .globals()
            .get::<Value>("_INJECTED")
            .expect("globals lookup should not fail");
        assert_eq!(
            injected,
            Value::Nil,
            "injected Lua code executed — package.path was interpreted as source"
        );
        let package: Table = lua
            .globals()
            .get::<Table>("package")
            .expect("package table must exist");
        let path: String = package
            .get::<String>("path")
            .expect("package.path must be a string");
        assert!(
            path.contains(malicious),
            "package.path should contain the raw entry verbatim: {path}"
        );
    }
}