use std::collections::HashMap;
use std::sync::RwLock;
use mlua::Function;
use mlua::Lua;
use polyplug::host_bridge::BridgeError;
use polyplug::host_bridge::RuntimeLanguageBridge;
use polyplug_abi::AbiError;
use polyplug_abi::AbiErrorCode;
use polyplug_abi::StringView;
use polyplug_abi::SupportedLanguage;
pub struct LuaHostBridge {
lua: Lua,
contracts: RwLock<HashMap<u64, Function>>,
}
impl LuaHostBridge {
pub fn new() -> LuaHostBridge {
let lua: Lua = unsafe { Lua::unsafe_new() };
LuaHostBridge {
lua,
contracts: RwLock::new(HashMap::new()),
}
}
pub fn with_capacity(capacity: usize) -> LuaHostBridge {
let lua: Lua = unsafe { Lua::unsafe_new() };
LuaHostBridge {
lua,
contracts: RwLock::new(HashMap::with_capacity(capacity)),
}
}
pub fn lua(&self) -> &Lua {
&self.lua
}
}
impl Default for LuaHostBridge {
fn default() -> LuaHostBridge {
LuaHostBridge::new()
}
}
impl RuntimeLanguageBridge for LuaHostBridge {
fn runtime_type(&self) -> SupportedLanguage {
SupportedLanguage::Lua
}
fn register_host_contract(
&mut self,
contract_id: u64,
implementation: Box<dyn core::any::Any>,
) -> Result<(), BridgeError> {
let callable: Function = implementation
.downcast::<Function>()
.map_err(|_| BridgeError::TypeMismatch {
contract_id,
expected: "Function".to_owned(),
got: "unknown type".to_owned(),
})
.map(|boxed| *boxed)?;
let mut contracts: std::sync::RwLockWriteGuard<'_, HashMap<u64, Function>> = self
.contracts
.write()
.map_err(|_| BridgeError::VmRegistrationFailed {
contract_id,
reason: "failed to acquire write lock on contracts map".to_owned(),
})?;
if contracts.contains_key(&contract_id) {
return Err(BridgeError::DuplicateContract { contract_id });
}
contracts.insert(contract_id, callable);
Ok(())
}
unsafe fn call_host_contract(
&self,
contract_id: u64,
fn_id: u32,
args: *const (),
out: *mut (),
) -> AbiError {
let contracts_guard: std::sync::RwLockReadGuard<'_, HashMap<u64, Function>> =
match self.contracts.read() {
Ok(guard) => guard,
Err(_) => {
return AbiError {
code: AbiErrorCode::HostContractCallFailed as u32,
message: StringView::from_static(
b"failed to acquire read lock on contracts map",
),
};
}
};
let callable: &Function = match contracts_guard.get(&contract_id) {
Some(f) => f,
None => {
return AbiError {
code: AbiErrorCode::HostContractNotFound as u32,
message: StringView::from_static(b"host contract not found"),
};
}
};
let fn_id_arg: u32 = fn_id;
let args_ptr: i64 = args as usize as i64;
let out_ptr: i64 = out as usize as i64;
let call_result: Result<(), mlua::Error> =
callable.call::<()>((fn_id_arg, args_ptr, out_ptr));
match call_result {
Ok(()) => AbiError::ok(),
Err(e) => {
let message: String = format!("Lua exception: {}", e);
let message_static: &'static str = Box::leak(message.into_boxed_str());
AbiError {
code: AbiErrorCode::HostContractCallFailed as u32,
message: StringView {
ptr: message_static.as_ptr(),
len: message_static.len(),
},
}
}
}
}
}
unsafe impl Send for LuaHostBridge {}
unsafe impl Sync for LuaHostBridge {}
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used)]
use super::*;
#[test]
fn bridge_new_creates_empty_bridge() {
let bridge: LuaHostBridge = LuaHostBridge::new();
let contracts: std::sync::RwLockReadGuard<'_, HashMap<u64, Function>> =
bridge.contracts.read().expect("read lock");
assert!(contracts.is_empty());
}
#[test]
fn bridge_default_creates_empty_bridge() {
let bridge: LuaHostBridge = LuaHostBridge::default();
let contracts: std::sync::RwLockReadGuard<'_, HashMap<u64, Function>> =
bridge.contracts.read().expect("read lock");
assert!(contracts.is_empty());
}
#[test]
fn bridge_with_capacity_creates_empty_bridge() {
let bridge: LuaHostBridge = LuaHostBridge::with_capacity(10);
let contracts: std::sync::RwLockReadGuard<'_, HashMap<u64, Function>> =
bridge.contracts.read().expect("read lock");
assert!(contracts.is_empty());
}
#[test]
fn bridge_runtime_type_returns_lua() {
let bridge: LuaHostBridge = LuaHostBridge::new();
assert_eq!(bridge.runtime_type(), SupportedLanguage::Lua);
}
#[test]
fn bridge_lua_returns_reference() {
let bridge: LuaHostBridge = LuaHostBridge::new();
let lua: &Lua = bridge.lua();
let result: i64 = lua.load("return 42").eval::<i64>().expect("eval");
assert_eq!(result, 42);
}
#[test]
fn bridge_register_host_contract_success() {
let mut bridge: LuaHostBridge = LuaHostBridge::new();
let callable: Function = bridge
.lua()
.load("function(fn_id, args, out) return fn_id end")
.eval::<Function>()
.expect("eval function");
let result: Result<(), BridgeError> =
bridge.register_host_contract(1234, Box::new(callable));
assert!(result.is_ok());
let contracts: std::sync::RwLockReadGuard<'_, HashMap<u64, Function>> =
bridge.contracts.read().expect("read lock");
assert!(contracts.contains_key(&1234));
}
#[test]
fn bridge_register_host_contract_duplicate_fails() {
let mut bridge: LuaHostBridge = LuaHostBridge::new();
let callable: Function = bridge
.lua()
.load("function(fn_id, args, out) return fn_id end")
.eval::<Function>()
.expect("eval function");
let callable2: Function = bridge
.lua()
.load("function(fn_id, args, out) return fn_id * 2 end")
.eval::<Function>()
.expect("eval function 2");
let result1: Result<(), BridgeError> =
bridge.register_host_contract(1234, Box::new(callable));
assert!(result1.is_ok());
let result2: Result<(), BridgeError> =
bridge.register_host_contract(1234, Box::new(callable2));
assert!(result2.is_err());
let err: BridgeError = result2.expect_err("should fail");
assert!(matches!(
err,
BridgeError::DuplicateContract { contract_id: 1234 }
));
}
#[test]
fn bridge_register_host_contract_type_mismatch_fails() {
let mut bridge: LuaHostBridge = LuaHostBridge::new();
let result: Result<(), BridgeError> = bridge.register_host_contract(1234, Box::new(42i32));
assert!(result.is_err());
let err: BridgeError = result.expect_err("should fail");
assert!(matches!(
err,
BridgeError::TypeMismatch {
contract_id: 1234,
..
}
));
}
#[test]
fn bridge_call_host_contract_not_found() {
let bridge: LuaHostBridge = LuaHostBridge::new();
let result: AbiError =
unsafe { bridge.call_host_contract(9999, 0, core::ptr::null(), core::ptr::null_mut()) };
assert_eq!(result.code, AbiErrorCode::HostContractNotFound as u32);
}
#[test]
fn bridge_call_host_contract_success() {
let mut bridge: LuaHostBridge = LuaHostBridge::new();
let callable: Function = bridge
.lua()
.load("function(fn_id, args, out) return fn_id end")
.eval::<Function>()
.expect("eval function");
bridge
.register_host_contract(1234, Box::new(callable))
.expect("register");
let result: AbiError =
unsafe { bridge.call_host_contract(1234, 5, core::ptr::null(), core::ptr::null_mut()) };
assert!(result.is_ok());
}
#[test]
fn bridge_call_host_contract_exception() {
let mut bridge: LuaHostBridge = LuaHostBridge::new();
let callable: Function = bridge
.lua()
.load("function(fn_id, args, out) error('test error') end")
.eval::<Function>()
.expect("eval function");
bridge
.register_host_contract(1234, Box::new(callable))
.expect("register");
let result: AbiError =
unsafe { bridge.call_host_contract(1234, 0, core::ptr::null(), core::ptr::null_mut()) };
assert_eq!(result.code, AbiErrorCode::HostContractCallFailed as u32);
}
}