#![allow(clippy::expect_used)]
#[cfg(unix)]
use libloading::os::unix::Library as UnixLibrary;
#[cfg(unix)]
use libloading::os::unix::RTLD_GLOBAL;
#[cfg(unix)]
use libloading::os::unix::RTLD_LAZY;
use polyplug::runtime_store::RuntimeStore;
use polyplug_abi::GuestContractInstance;
use polyplug_abi::ffi::polyplug_host_alloc;
use polyplug_abi::ffi::polyplug_host_free;
use polyplug_abi::tracking::TrackingAllocator;
use polyplug_abi::{
AbiError, AbiErrorCode, BundleInitContext, GuestContractHandle, GuestContractInterface,
HostApi, PluginDescriptor, StringView,
};
use polyplug_utils::{BundleId, GuestContractId};
/// Filesystem path of the compiled error-test plugin.
///
/// Captured from the `ERROR_PLUGIN_SO` environment variable at compile time via `env!`
/// (the build fails if it is unset).
const ERROR_PLUGIN_SO: &str = env!("ERROR_PLUGIN_SO");
/// Argument block the tests pass (cast to `*const ()`) to dispatch slot 0 of `error.test`.
///
/// `#[repr(C)]` because the plugin reads it across the FFI boundary.
/// NOTE(review): presumably mirrors a struct declared in the error plugin — keep the field
/// order in sync with it.
#[repr(C)]
struct MessageArgs {
    // Host table handed to the plugin for this call (presumably for callbacks such as
    // allocating the error message).
    host: *const HostApi,
}
/// Argument block the chain test passes (cast to `*const ()`) to dispatch slot 2 of
/// `error.test`.
///
/// `#[repr(C)]` because the plugin reads it across the FFI boundary.
/// NOTE(review): presumably mirrors a struct declared in the error plugin — keep the field
/// order in sync with it.
#[repr(C)]
struct ChainArgs {
    // Host table the plugin uses to look up and resolve the target contract.
    host: *const HostApi,
    // Id of the contract to call, as produced by `GuestContractId::id()`.
    target_contract_id: u64,
    // Presumably the dispatch-slot index to invoke on the target contract.
    target_fn_id: u32,
}
std::thread_local! {
    /// Contract registry backing the host callbacks in this file.
    ///
    /// It is thread-local, so every test thread (libtest runs tests on separate threads by
    /// default) works against its own store. `init_error_plugin` replaces it with a fresh
    /// store before each plugin initialisation.
    static ERROR_REGISTRY: std::cell::RefCell<RuntimeStore> =
        std::cell::RefCell::new(RuntimeStore::new());
}
unsafe extern "C" fn chain_find_guest_contract(
_this: *const HostApi,
contract_id: u64,
_min_version: u32,
) -> GuestContractHandle {
ERROR_REGISTRY.with(|cell| {
let registry: core::cell::Ref<'_, RuntimeStore> = cell.borrow();
match registry.find(GuestContractId::from_u64(contract_id), 0) {
Ok(handle) => handle,
Err(_) => GuestContractHandle::null(),
}
})
}
/// `find_all_guest_contracts` callback used by the chain test.
///
/// It always reports an empty set, whatever the requested id or version.
unsafe extern "C" fn chain_find_all_guest_contracts(
    _this: *const HostApi,
    _contract_id: u64,
    _min_version: u32,
) -> polyplug_abi::Array<GuestContractHandle> {
    polyplug_abi::Array::<GuestContractHandle>::empty()
}
/// `resolve_guest_contract` callback used by the chain test.
///
/// Resolves `handle` against this thread's `ERROR_REGISTRY` and yields a null interface
/// pointer when resolution fails.
unsafe extern "C" fn chain_resolve_guest_contract(
    _this: *const HostApi,
    handle: GuestContractHandle,
) -> *const GuestContractInterface {
    ERROR_REGISTRY.with(|store| {
        store
            .borrow()
            .resolve_guest_contract(handle)
            .unwrap_or(core::ptr::null())
    })
}
/// `get_host_contract` stub used by the chain test: no host contracts are ever offered.
unsafe extern "C" fn stub_get_host_contract(
    _this: *const HostApi,
    _contract_id: u64,
    _min_version: u32,
) -> polyplug_abi::HostContractInstance {
    polyplug_abi::HostContractInstance::null()
}
/// `resolve_host_contract_interface` stub used by the chain test: always a null interface.
unsafe extern "C" fn stub_resolve_host_contract_interface(
    _this: *const HostApi,
    _contract_id: u64,
    _min_version: u32,
) -> *const polyplug_abi::HostContractInterface {
    core::ptr::null::<polyplug_abi::HostContractInterface>()
}
/// `alloc` callback: forwards the request unchanged to the host allocator
/// `polyplug_host_alloc`.
unsafe extern "C" fn stub_alloc(_this: *const HostApi, size: usize, align: usize) -> *mut u8 {
    polyplug_host_alloc(size, align)
}
/// `free` callback: hands the block back to `polyplug_host_free` with the caller's size and
/// alignment.
unsafe extern "C" fn stub_free(_this: *const HostApi, ptr: *mut u8, size: usize, align: usize) {
    // SAFETY: forwarded verbatim; the caller must pass a block obtained from the matching
    // `alloc` with the same size and alignment.
    unsafe { polyplug_host_free(ptr, size, align) }
}
unsafe extern "C" fn noop_find_guest_contract(
_this: *const HostApi,
_contract_id: u64,
_min_version: u32,
) -> GuestContractHandle {
GuestContractHandle::null()
}
unsafe extern "C" fn noop_find_all_guest_contracts(
_this: *const HostApi,
_contract_id: u64,
_min_version: u32,
) -> polyplug_abi::Array<GuestContractHandle> {
polyplug_abi::Array::empty()
}
unsafe extern "C" fn noop_resolve_guest_contract(
_this: *const HostApi,
_handle: GuestContractHandle,
) -> *const GuestContractInterface {
core::ptr::null()
}
unsafe extern "C" fn noop_get_host_contract(
_this: *const HostApi,
_contract_id: u64,
_min_version: u32,
) -> polyplug_abi::HostContractInstance {
polyplug_abi::HostContractInstance::null()
}
unsafe extern "C" fn noop_list_bundles(
_this: *const HostApi,
) -> polyplug_abi::Array<polyplug_utils::BundleId> {
polyplug_abi::Array::empty()
}
unsafe extern "C" fn noop_get_dependencies(
_this: *const HostApi,
) -> polyplug_abi::Array<polyplug_abi::DependencyInfo> {
polyplug_abi::Array::empty()
}
unsafe extern "C" fn noop_resolve_host_contract_interface(
_this: *const HostApi,
_contract_id: u64,
_min_version: u32,
) -> *const polyplug_abi::HostContractInterface {
core::ptr::null()
}
unsafe extern "C" fn noop_load_bundle(
_this: *const HostApi,
_path: *const u8,
_path_len: usize,
out_err: *mut AbiError,
) {
if !out_err.is_null() {
unsafe { out_err.write(AbiError::ok()) };
}
}
unsafe extern "C" fn noop_reload_bundle(
_this: *const HostApi,
_path: *const u8,
_path_len: usize,
out_err: *mut AbiError,
) {
if !out_err.is_null() {
unsafe { out_err.write(AbiError::ok()) };
}
}
unsafe extern "C" fn noop_register_host_contract(
_this: *const HostApi,
_interface: *const polyplug_abi::HostContractInterface,
out_err: *mut AbiError,
) {
if !out_err.is_null() {
unsafe { out_err.write(AbiError::ok()) };
}
}
unsafe extern "C" fn noop_register_loader(
_this: *const HostApi,
_loader_ptr: *mut core::ffi::c_void,
out_err: *mut AbiError,
) {
if !out_err.is_null() {
unsafe { out_err.write(AbiError::ok()) };
}
}
unsafe extern "C" fn noop_get_last_error(
_this: *const HostApi,
_buf: *mut u8,
_buf_len: usize,
) -> usize {
0
}
unsafe extern "C" fn noop_get_error_len(_this: *const HostApi) -> usize {
0
}
unsafe extern "C" fn noop_unload_bundle(
_this: *const HostApi,
_bundle_id: BundleId,
out_err: *mut AbiError,
) {
if !out_err.is_null() {
unsafe { out_err.write(AbiError::ok()) };
}
}
/// `register_guest_contract` host callback: records a plugin's contract in this thread's
/// `ERROR_REGISTRY`.
///
/// The outcome is written to `out_err` when it is non-null:
/// * `InvalidPointer`: `descriptor` or `interface` is null, or the contract name has a
///   non-zero length but a null pointer.
/// * `Generic`: the contract name is not valid UTF-8, or the registry rejects it.
/// * `Ok`: the contract was registered.
///
/// In every case the error message is left empty.
unsafe extern "C" fn registry_register_callback(
    _this: *const HostApi,
    descriptor: *const PluginDescriptor,
    interface: *const GuestContractInterface,
    out_err: *mut AbiError,
) {
    // Writes `code` (with an empty message) through `out_err` unless it is null.
    //
    // Safety: `out_err` must be null or valid for writing one `AbiError`.
    unsafe fn report(out_err: *mut AbiError, code: u32) {
        if !out_err.is_null() {
            // SAFETY: non-null, and writable per this function's contract.
            unsafe {
                out_err.write(AbiError {
                    code,
                    message: StringView::null(),
                })
            };
        }
    }

    if descriptor.is_null() || interface.is_null() {
        // SAFETY: assumes the caller passes null or a writable `AbiError` slot.
        unsafe { report(out_err, AbiErrorCode::InvalidPointer as u32) };
        return;
    }
    // SAFETY: both pointers are non-null (checked above). Assumes the plugin keeps them
    // valid for the duration of this call.
    let desc: &PluginDescriptor = unsafe { &*descriptor };
    let iface: &GuestContractInterface = unsafe { &*interface };

    let name_ptr = desc.contract_name.ptr;
    let name_len = desc.contract_name.len;
    let contract_name: &str = if name_len == 0 {
        // Never call `from_raw_parts` here: it requires a non-null pointer even for an
        // empty slice, and an empty name may legitimately arrive as a null view.
        ""
    } else if name_ptr.is_null() {
        // SAFETY: assumes the caller passes null or a writable `AbiError` slot.
        unsafe { report(out_err, AbiErrorCode::InvalidPointer as u32) };
        return;
    } else {
        // SAFETY: non-null. Assumes the plugin guarantees `name_len` readable bytes.
        let bytes: &[u8] = unsafe { core::slice::from_raw_parts(name_ptr, name_len) };
        // The name crosses the FFI boundary, so validate it; `from_utf8_unchecked` on
        // malformed input would be undefined behaviour.
        match core::str::from_utf8(bytes) {
            Ok(name) => name,
            Err(_) => {
                // SAFETY: assumes the caller passes null or a writable `AbiError` slot.
                unsafe { report(out_err, AbiErrorCode::Generic as u32) };
                return;
            }
        }
    };

    let result: Result<GuestContractHandle, _> = ERROR_REGISTRY.with(|reg_cell| {
        let registry: core::cell::Ref<'_, RuntimeStore> = reg_cell.borrow();
        // NOTE(review): the bundle id is derived from the contract id rather than the
        // `bundle_id` in `BundleInitContext` — confirm this is intended for the test.
        // SAFETY: `interface` is non-null. Assumes it stays valid for as long as the
        // registry holds it (the tests never unload the plugin library).
        unsafe {
            registry.register_guest_contract(
                *desc,
                interface,
                contract_name.to_owned(),
                BundleId::from_u64(iface.contract_id.id()),
            )
        }
    });
    let code: u32 = match result {
        Ok(_) => AbiErrorCode::Ok as u32,
        Err(_) => AbiErrorCode::Generic as u32,
    };
    // SAFETY: assumes the caller passes null or a writable `AbiError` slot.
    unsafe { report(out_err, code) };
}
/// Builds the default `HostApi` table handed to the error plugin in these tests.
///
/// Only three entries do real work:
/// * `register_guest_contract` records contracts in this thread's `ERROR_REGISTRY`.
/// * `alloc` and `free` forward to the `polyplug_host_alloc` / `polyplug_host_free`
///   allocator.
///
/// Every other callback is a stub that does nothing, returns null, empty or zero, or
/// reports success. The `runtime` and `reserved` pointers are null.
fn make_host_interface() -> HostApi {
    HostApi {
        // No runtime object backs this table.
        runtime: core::ptr::null_mut(),
        reserved: core::ptr::null(),
        // Callbacks with real behaviour.
        register_guest_contract: registry_register_callback,
        alloc: stub_alloc,
        free: stub_free,
        // Contract lookups that always miss.
        find_guest_contract: noop_find_guest_contract,
        find_all_guest_contracts: noop_find_all_guest_contracts,
        resolve_guest_contract: noop_resolve_guest_contract,
        get_host_contract: noop_get_host_contract,
        resolve_host_contract_interface: noop_resolve_host_contract_interface,
        // Bundle management that does nothing.
        list_bundles: noop_list_bundles,
        get_dependencies: noop_get_dependencies,
        load_bundle: noop_load_bundle,
        reload_bundle: noop_reload_bundle,
        unload_bundle: noop_unload_bundle,
        register_host_contract: noop_register_host_contract,
        register_loader: noop_register_loader,
        // Error reporting, logging, instances and revision tracking.
        get_last_error: noop_get_last_error,
        get_error_len: noop_get_error_len,
        log: stub_host_log,
        create_guest_instance: stub_create_guest_instance,
        destroy_guest_instance: stub_destroy_guest_instance,
        revision_counter: stub_revision_counter,
    }
}
/// Opens the error-test plugin located at `ERROR_PLUGIN_SO`.
///
/// * On Unix, the library is opened with `RTLD_LAZY | RTLD_GLOBAL` and converted into a
///   portable `libloading::Library`.
/// * Elsewhere, `Library::new` is used.
///
/// # Panics
/// Panics if the library cannot be opened. The message names the path and the loader
/// error, so a misconfigured build is easy to diagnose.
fn load_error_plugin() -> libloading::Library {
    #[cfg(unix)]
    {
        // SAFETY: opening a library runs its initialisers. The path is the test plugin
        // produced by this build (`ERROR_PLUGIN_SO` is fixed at compile time).
        let raw: UnixLibrary = unsafe {
            UnixLibrary::open(Some(ERROR_PLUGIN_SO), RTLD_LAZY | RTLD_GLOBAL).unwrap_or_else(
                |err| panic!("failed to load error_plugin .so from {ERROR_PLUGIN_SO}: {err}"),
            )
        };
        libloading::Library::from(raw)
    }
    #[cfg(not(unix))]
    {
        // SAFETY: as above; loading runs the test plugin's initialisers.
        unsafe {
            libloading::Library::new(ERROR_PLUGIN_SO).unwrap_or_else(|err| {
                panic!("failed to load error_plugin .so from {ERROR_PLUGIN_SO}: {err}")
            })
        }
    }
}
fn init_error_plugin(library: &libloading::Library) -> *const GuestContractInterface {
ERROR_REGISTRY.with(|cell| {
*cell.borrow_mut() = RuntimeStore::new();
});
let init_fn: libloading::Symbol<
'_,
unsafe extern "C" fn(*const HostApi, *const BundleInitContext) -> AbiError,
> = unsafe {
library
.get(b"polyplug_init\0")
.expect("polyplug_init symbol not found")
};
let host_interface: HostApi = make_host_interface();
let ctx: BundleInitContext = BundleInitContext {
bundle_path: StringView::null(),
bundle_id: 0,
};
let init_result: AbiError = unsafe {
init_fn(
&host_interface as *const HostApi,
&ctx as *const BundleInitContext,
)
};
assert_eq!(
init_result.code,
AbiErrorCode::Ok as u32,
"polyplug_init must succeed"
);
let contract_id: GuestContractId = GuestContractId::new("error.test", 1);
let handle: GuestContractHandle = ERROR_REGISTRY.with(|cell| {
cell.borrow()
.find(contract_id, 0)
.expect("error.test must be registered")
});
ERROR_REGISTRY.with(|cell| {
cell.borrow()
.resolve_guest_contract(handle)
.expect("interface must be resolvable")
})
}
/// Dispatch slot 0 of `error.test` must report a `Generic` error carrying the message
/// `"test error from plugin"`. The message buffer is then released through
/// `polyplug_host_free`.
#[test]
fn stress_error_code_and_message_received_correctly() {
    let library: libloading::Library = load_error_plugin();
    let interface_ptr: *const GuestContractInterface = init_error_plugin(&library);
    assert!(!interface_ptr.is_null(), "resolved interface must be non-null");
    // SAFETY: non-null, and the plugin stays loaded for the rest of the test.
    let interface: &GuestContractInterface = unsafe { &*interface_ptr };
    // NOTE(review): assumes the native dispatch table has at least one entry — TODO confirm.
    // SAFETY: reads dispatch slot 0 of the plugin's native function table.
    let fn_ptr: *const () = unsafe { *interface.dispatch.native.functions.add(0) };
    // Function pointers may not be null; transmuting a null slot would be UB.
    assert!(!fn_ptr.is_null(), "dispatch slot 0 must be populated");
    // SAFETY: presumably every native dispatch slot uses this wrapper signature
    // (instance, args, out, err) — TODO confirm against the plugin.
    let dispatch_fn: unsafe extern "C" fn(
        GuestContractInstance,
        *const (),
        *mut (),
        *mut AbiError,
    ) = unsafe { core::mem::transmute(fn_ptr) };
    let host_interface: HostApi = make_host_interface();
    let message_args: MessageArgs = MessageArgs {
        host: &host_interface as *const HostApi,
    };
    // Pre-seeded with `Ok` so the `Generic` assertion proves the plugin overwrote it.
    let mut out: AbiError = AbiError {
        code: AbiErrorCode::Ok as u32,
        message: StringView::null(),
    };
    let mut call_result: AbiError = AbiError::ok();
    // SAFETY: args/out/err point to live locals of the types this slot presumably expects.
    unsafe {
        dispatch_fn(
            GuestContractInstance::null(),
            &message_args as *const MessageArgs as *const (),
            &mut out as *mut AbiError as *mut (),
            &mut call_result,
        )
    };
    assert_eq!(
        call_result.code,
        AbiErrorCode::Ok as u32,
        "dispatch wrapper must return Ok"
    );
    assert_eq!(
        out.code,
        AbiErrorCode::Generic as u32,
        "error code must be Generic"
    );
    assert_eq!(out.message.len, 22_usize, "message length must be 22");
    // `from_raw_parts` requires a non-null pointer: fail the test instead of hitting UB.
    assert!(!out.message.ptr.is_null(), "message pointer must be non-null");
    // SAFETY: non-null and `len` bytes long, as asserted above.
    let msg_bytes: &[u8] = unsafe { core::slice::from_raw_parts(out.message.ptr, out.message.len) };
    assert_eq!(msg_bytes, b"test error from plugin", "message must match");
    // Presumably the plugin allocated the message through the host allocator; release it
    // with the same length and byte alignment.
    unsafe {
        polyplug_host_free(out.message.ptr as *mut u8, out.message.len, 1);
    }
    // NOTE(review): the tracker is created right here. Unless `TrackingAllocator::new`
    // observes process-wide counters, this leak check is vacuous — confirm.
    let tracker: TrackingAllocator = TrackingAllocator::new();
    tracker.assert_no_leaks();
    // Skip `Library`'s drop, which would unload the plugin. Presumably this keeps the
    // interface pointers held by `ERROR_REGISTRY` valid.
    core::mem::forget(library);
}
/// Dispatch slot 1 of `error.test` panics inside the plugin.
///
/// The dispatch wrapper must convert that panic into an `AbiErrorCode::Panic` result with
/// the message `"plugin panicked"`. The test process keeps running.
#[test]
fn stress_panic_returns_abi_error_panic_process_continues() {
    let library: libloading::Library = load_error_plugin();
    let interface_ptr: *const GuestContractInterface = init_error_plugin(&library);
    assert!(!interface_ptr.is_null(), "resolved interface must be non-null");
    // SAFETY: non-null, and the plugin stays loaded for the rest of the test.
    let interface: &GuestContractInterface = unsafe { &*interface_ptr };
    // NOTE(review): assumes the native dispatch table has at least two entries — TODO confirm.
    // SAFETY: reads dispatch slot 1 of the plugin's native function table.
    let fn_ptr: *const () = unsafe { *interface.dispatch.native.functions.add(1) };
    // Function pointers may not be null; transmuting a null slot would be UB.
    assert!(!fn_ptr.is_null(), "dispatch slot 1 must be populated");
    // SAFETY: presumably every native dispatch slot uses this wrapper signature
    // (instance, args, out, err) — TODO confirm against the plugin.
    let dispatch_fn: unsafe extern "C" fn(
        GuestContractInstance,
        *const (),
        *mut (),
        *mut AbiError,
    ) = unsafe { core::mem::transmute(fn_ptr) };
    let mut result: AbiError = AbiError::ok();
    // SAFETY: this slot is called with null args/out; only `result` is written.
    unsafe {
        dispatch_fn(
            GuestContractInstance::null(),
            core::ptr::null(),
            core::ptr::null_mut(),
            &mut result,
        )
    };
    assert_eq!(
        result.code,
        AbiErrorCode::Panic as u32,
        "error_panic must return Panic (code={})",
        AbiErrorCode::Panic
    );
    // `from_raw_parts` requires a non-null pointer: fail the test instead of hitting UB.
    assert!(!result.message.ptr.is_null(), "panic message pointer must be non-null");
    // SAFETY: non-null; assumes the wrapper sets `len` to the message length.
    let msg_bytes: &[u8] =
        unsafe { core::slice::from_raw_parts(result.message.ptr, result.message.len) };
    assert_eq!(
        msg_bytes, b"plugin panicked",
        "panic message must be 'plugin panicked'"
    );
    // NOTE(review): unlike the message tests, this message is never freed. Confirm whether
    // it is static or must be released with `polyplug_host_free`.
    // NOTE(review): the tracker is created right here. Unless `TrackingAllocator::new`
    // observes process-wide counters, this leak check is vacuous — confirm.
    let tracker: TrackingAllocator = TrackingAllocator::new();
    tracker.assert_no_leaks();
    // Skip `Library`'s drop, which would unload the plugin. Presumably this keeps the
    // interface pointers held by `ERROR_REGISTRY` valid.
    core::mem::forget(library);
}
/// Dispatch slot 2 of `error.test` calls another contract through the host and propagates
/// its error.
///
/// Here the target is `error.test` itself, function 1, which panics in
/// `stress_panic_returns_abi_error_panic_process_continues`. The propagated code must
/// therefore be `Panic`, while the outer wrapper still reports `Ok`.
#[test]
fn stress_error_chain_b_errors_a_propagates() {
    let library: libloading::Library = load_error_plugin();
    let interface_ptr: *const GuestContractInterface = init_error_plugin(&library);
    assert!(!interface_ptr.is_null(), "resolved interface must be non-null");
    // SAFETY: non-null, and the plugin stays loaded for the rest of the test.
    let interface: &GuestContractInterface = unsafe { &*interface_ptr };
    // Same table as `make_host_interface`, except that guest-contract lookups go to
    // `ERROR_REGISTRY` so the plugin can find the chained target. Struct-update syntax
    // keeps this in sync with the default table instead of duplicating every field.
    let chain_host_interface: HostApi = HostApi {
        find_guest_contract: chain_find_guest_contract,
        find_all_guest_contracts: chain_find_all_guest_contracts,
        resolve_guest_contract: chain_resolve_guest_contract,
        get_host_contract: stub_get_host_contract,
        resolve_host_contract_interface: stub_resolve_host_contract_interface,
        ..make_host_interface()
    };
    let error_contract_id: GuestContractId = GuestContractId::new("error.test", 1);
    let chain_args: ChainArgs = ChainArgs {
        host: &chain_host_interface as *const HostApi,
        target_contract_id: error_contract_id.id(),
        // Function 1 is the panicking entry point.
        target_fn_id: 1_u32,
    };
    // Pre-seeded with `Ok` so the `Panic` assertion proves the plugin overwrote it.
    let mut out: AbiError = AbiError {
        code: AbiErrorCode::Ok as u32,
        message: StringView::null(),
    };
    // NOTE(review): assumes the native dispatch table has at least three entries — TODO
    // confirm.
    // SAFETY: reads dispatch slot 2 of the plugin's native function table.
    let fn_ptr: *const () = unsafe { *interface.dispatch.native.functions.add(2) };
    // Function pointers may not be null; transmuting a null slot would be UB.
    assert!(!fn_ptr.is_null(), "dispatch slot 2 must be populated");
    // SAFETY: presumably every native dispatch slot uses this wrapper signature
    // (instance, args, out, err) — TODO confirm against the plugin.
    let dispatch_fn: unsafe extern "C" fn(
        GuestContractInstance,
        *const (),
        *mut (),
        *mut AbiError,
    ) = unsafe { core::mem::transmute(fn_ptr) };
    let mut call_result: AbiError = AbiError::ok();
    // SAFETY: args/out/err point to live locals of the types this slot presumably expects.
    unsafe {
        dispatch_fn(
            GuestContractInstance::null(),
            &chain_args as *const ChainArgs as *const (),
            &mut out as *mut AbiError as *mut (),
            &mut call_result,
        )
    };
    assert_eq!(
        call_result.code,
        AbiErrorCode::Ok as u32,
        "error_chain_propagate wrapper must return Ok"
    );
    assert_eq!(
        out.code,
        AbiErrorCode::Panic as u32,
        "propagated error must be Panic (={})",
        AbiErrorCode::Panic
    );
    // NOTE(review): `out.message` is never freed here. Confirm whether a propagated message
    // is host-allocated and must be released with `polyplug_host_free`.
    // NOTE(review): the tracker is created right here. Unless `TrackingAllocator::new`
    // observes process-wide counters, this leak check is vacuous — confirm.
    let tracker: TrackingAllocator = TrackingAllocator::new();
    tracker.assert_no_leaks();
    // Skip `Library`'s drop, which would unload the plugin. Presumably this keeps the
    // interface pointers held by `ERROR_REGISTRY` valid.
    core::mem::forget(library);
}
/// The error message returned by dispatch slot 0 must stay readable and unchanged across
/// repeated reads until the host frees it.
#[test]
fn stress_error_message_lifetime_valid_during_read() {
    let library: libloading::Library = load_error_plugin();
    let interface_ptr: *const GuestContractInterface = init_error_plugin(&library);
    assert!(!interface_ptr.is_null(), "resolved interface must be non-null");
    // SAFETY: non-null, and the plugin stays loaded for the rest of the test.
    let interface: &GuestContractInterface = unsafe { &*interface_ptr };
    // NOTE(review): assumes the native dispatch table has at least one entry — TODO confirm.
    // SAFETY: reads dispatch slot 0 of the plugin's native function table.
    let fn_ptr: *const () = unsafe { *interface.dispatch.native.functions.add(0) };
    // Function pointers may not be null; transmuting a null slot would be UB.
    assert!(!fn_ptr.is_null(), "dispatch slot 0 must be populated");
    // SAFETY: presumably every native dispatch slot uses this wrapper signature
    // (instance, args, out, err) — TODO confirm against the plugin.
    let dispatch_fn: unsafe extern "C" fn(
        GuestContractInstance,
        *const (),
        *mut (),
        *mut AbiError,
    ) = unsafe { core::mem::transmute(fn_ptr) };
    let host_interface: HostApi = make_host_interface();
    let message_args: MessageArgs = MessageArgs {
        host: &host_interface as *const HostApi,
    };
    // Pre-seeded with `Ok` so the `Generic` assertion proves the plugin overwrote it.
    let mut out: AbiError = AbiError {
        code: AbiErrorCode::Ok as u32,
        message: StringView::null(),
    };
    let mut call_result: AbiError = AbiError::ok();
    // SAFETY: args/out/err point to live locals of the types this slot presumably expects.
    unsafe {
        dispatch_fn(
            GuestContractInstance::null(),
            &message_args as *const MessageArgs as *const (),
            &mut out as *mut AbiError as *mut (),
            &mut call_result,
        )
    };
    assert_eq!(
        call_result.code,
        AbiErrorCode::Ok as u32,
        "dispatch wrapper must return Ok"
    );
    assert_eq!(
        out.code,
        AbiErrorCode::Generic as u32,
        "error code must be Generic"
    );
    assert_eq!(out.message.len, 22_usize, "message length must be 22");
    // `from_raw_parts` requires a non-null pointer: fail the test instead of hitting UB.
    assert!(!out.message.ptr.is_null(), "message pointer must be non-null");
    for _ in 0_u32..1000_u32 {
        // SAFETY: non-null and `len` bytes long (asserted above); not yet freed.
        let bytes: &[u8] = unsafe { core::slice::from_raw_parts(out.message.ptr, out.message.len) };
        assert_eq!(
            bytes, b"test error from plugin",
            "message must remain stable across 1000 reads"
        );
    }
    // Presumably the plugin allocated the message through the host allocator; release it
    // with the same length and byte alignment.
    unsafe {
        polyplug_host_free(out.message.ptr as *mut u8, out.message.len, 1);
    }
    // NOTE(review): the tracker is created right here. Unless `TrackingAllocator::new`
    // observes process-wide counters, this leak check is vacuous — confirm.
    let tracker: TrackingAllocator = TrackingAllocator::new();
    tracker.assert_no_leaks();
    // Skip `Library`'s drop, which would unload the plugin. Presumably this keeps the
    // interface pointers held by `ERROR_REGISTRY` valid.
    core::mem::forget(library);
}
/// `log` stub: discards every record.
unsafe extern "C" fn stub_host_log(
    _this: *const polyplug_abi::HostApi,
    _level: u32,
    _scope: polyplug_abi::StringView,
    _message: polyplug_abi::StringView,
) {
    // Intentionally empty: none of these tests inspect plugin logging.
}
/// `create_guest_instance` stub: hands back a null instance through `out_instance`, if
/// one was provided.
unsafe extern "C" fn stub_create_guest_instance(
    _this: *const polyplug_abi::HostApi,
    _interface: *const polyplug_abi::GuestContractInterface,
    _args: *const core::ffi::c_void,
    out_instance: *mut polyplug_abi::GuestContractInstance,
) {
    if out_instance.is_null() {
        return;
    }
    // SAFETY: non-null; assumes the caller passes a writable instance slot.
    unsafe { out_instance.write(polyplug_abi::GuestContractInstance::null()) };
}
/// `destroy_guest_instance` stub: there is nothing to tear down.
unsafe extern "C" fn stub_destroy_guest_instance(
    _this: *const polyplug_abi::HostApi,
    _interface: *const polyplug_abi::GuestContractInterface,
    _instance: polyplug_abi::GuestContractInstance,
) {
    // Intentionally empty.
}
/// `revision_counter` stub: exposes no counter (null pointer).
unsafe extern "C" fn stub_revision_counter(_this: *const polyplug_abi::HostApi) -> *const u64 {
    core::ptr::null::<u64>()
}