#![allow(clippy::expect_used)]
#![allow(clippy::unwrap_used)]
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use polyplug::Runtime;
use polyplug::error::LoaderError;
use polyplug::loader::{BundleLoader, ManifestData};
use polyplug_abi::{
DispatchMechanisms, DispatchType, GuestContractHandle, GuestContractInstance,
GuestContractInterface, HostApi, NativeDispatch, PluginDescriptor, StringView, Version,
};
use polyplug_utils::{BundleId, GuestContractId};
/// Observations captured by `ProbeLoader::load` while it runs.
///
/// Every field starts out as `None` (which is also what `Default` yields) and
/// is set to `Some(..)` once the loader has performed the matching lookup, so
/// a `None` after loading means the loader never ran.
#[derive(Debug, Default)]
struct ProbeResults {
    /// Whether `find_guest_contract` returned a null handle for the declared contract id.
    declared_lookup_null: Option<bool>,
    /// Whether `find_guest_contract` returned a null handle for the undeclared contract id.
    undeclared_lookup_null: Option<bool>,
    /// Length of the array returned by `find_all_guest_contracts` for the declared contract id.
    declared_find_all_count: Option<usize>,
    /// Length of the array returned by `find_all_guest_contracts` for the undeclared contract id.
    undeclared_find_all_count: Option<usize>,
}
/// Test-only bundle loader that probes host contract lookups from inside
/// `load` and publishes what it saw through a shared `ProbeResults`.
struct ProbeLoader {
    /// Contract id that the probed bundle's manifest lists as a dependency.
    declared_contract_id: u64,
    /// Contract id that is registered with the runtime but absent from the manifest.
    undeclared_contract_id: u64,
    /// Shared slot the loader writes its observations into; the test keeps
    /// another `Arc` to the same mutex and reads it after loading.
    results: Arc<Mutex<ProbeResults>>,
}
impl BundleLoader for ProbeLoader {
fn loader_name(&self) -> &'static str {
"probe-enforce"
}
fn loader_language(&self) -> polyplug_abi::SupportedLanguage {
polyplug_abi::SupportedLanguage::Rust
}
fn supports_hot_reload(&self) -> bool {
false
}
fn load(
&self,
manifest: &ManifestData,
_source: &polyplug::loader::BundleSource,
runtime: &Runtime,
) -> Result<(), LoaderError> {
let host_abi: *const HostApi = runtime.host_abi();
let bundle_id: BundleId = BundleId::new(&manifest.name);
runtime.push_init_bundle_id(bundle_id.id());
let declared_handle: GuestContractHandle = unsafe {
((*host_abi).find_guest_contract)(host_abi, self.declared_contract_id, 0_u32)
};
let undeclared_handle: GuestContractHandle = unsafe {
((*host_abi).find_guest_contract)(host_abi, self.undeclared_contract_id, 0_u32)
};
let declared_all: polyplug_abi::Array<GuestContractHandle> = unsafe {
((*host_abi).find_all_guest_contracts)(host_abi, self.declared_contract_id, 0_u32)
};
let undeclared_all: polyplug_abi::Array<GuestContractHandle> = unsafe {
((*host_abi).find_all_guest_contracts)(host_abi, self.undeclared_contract_id, 0_u32)
};
let declared_all_len: usize = declared_all.len;
let undeclared_all_len: usize = undeclared_all.len;
runtime.pop_init_bundle_id();
let mut guard: std::sync::MutexGuard<'_, ProbeResults> =
self.results.lock().unwrap_or_else(|e| e.into_inner());
guard.declared_lookup_null = Some(declared_handle.is_null());
guard.undeclared_lookup_null = Some(undeclared_handle.is_null());
guard.declared_find_all_count = Some(declared_all_len);
guard.undeclared_find_all_count = Some(undeclared_all_len);
Ok(())
}
fn reload(&self, _manifest: &ManifestData, _runtime: &Runtime) -> Result<(), LoaderError> {
Err(LoaderError::HotReloadUnsupported {
loader_name: self.loader_name().to_owned(),
})
}
}
/// Placeholder `create_instance` callback for the test provider.
///
/// Ignores every input and, when `out_instance` is non-null, stores a null
/// `GuestContractInstance` through it. A null `out_instance` is left alone.
///
/// # Safety
///
/// If `out_instance` is non-null it must be valid for a write of one
/// `GuestContractInstance`.
unsafe extern "C" fn noop_create_instance(
    _loader_data: polyplug_abi::dispatch::VmLoaderData,
    _host: *const HostApi,
    _args: *const (),
    out_instance: *mut GuestContractInstance,
) {
    if out_instance.is_null() {
        return;
    }
    // SAFETY: the pointer was checked for null above; the caller is
    // responsible for it being valid for writes (see `# Safety`).
    unsafe { out_instance.write(GuestContractInstance::null()) };
}
/// Placeholder `destroy_instance` callback for the test provider.
///
/// Does nothing: every argument is ignored.
///
/// # Safety
///
/// The body touches none of its arguments, so it places no requirements on
/// them.
unsafe extern "C" fn noop_destroy_instance(
    _loader_data: polyplug_abi::dispatch::VmLoaderData,
    _host: *const HostApi,
    _instance: GuestContractInstance,
) {
    // Intentionally empty.
}
/// Registers a function-less native guest contract with the runtime's registry.
///
/// The contract is registered under `contract_id` at version 1.0.0, with the
/// contract name `"provider.contract"`, on behalf of the bundle identified by
/// `bundle_id`. Its interface is heap-allocated and deliberately leaked so
/// that it satisfies the `'static` borrow.
///
/// Returns the handle produced by the registry.
///
/// # Panics
///
/// Panics with "provider registration should succeed" if the registry
/// rejects the registration.
fn register_provider(runtime: &Runtime, contract_id: u64, bundle_id: u64) -> GuestContractHandle {
    // Both the interface and the descriptor carry the same 1.0.0 version.
    let version_one = || Version {
        major: 1,
        minor: 0,
        patch: 0,
    };

    // No functions are exposed: the dispatch table is empty.
    let empty_dispatch = DispatchMechanisms {
        native: NativeDispatch {
            function_count: 0,
            functions: core::ptr::null(),
        },
    };

    let leaked_interface: &'static GuestContractInterface =
        Box::leak(Box::new(GuestContractInterface {
            contract_id: GuestContractId::from_u64(contract_id),
            contract_version: version_one(),
            dispatch_type: DispatchType::Native,
            create_instance: noop_create_instance,
            destroy_instance: noop_destroy_instance,
            dispatch: empty_dispatch,
        }));

    let provider_descriptor: PluginDescriptor = PluginDescriptor {
        name: StringView::from_static(b"provider"),
        contract_name: StringView::from_static(b"provider.contract"),
        version: version_one(),
    };

    // SAFETY: the interface reference is `'static` (leaked above) and the
    // descriptor's string views point at static byte literals.
    // NOTE(review): any further preconditions of `register_guest_contract`
    // are not visible from this file — confirm against its documentation.
    let registration = unsafe {
        runtime.registry().register_guest_contract(
            provider_descriptor,
            leaked_interface,
            "provider.contract".to_owned(),
            BundleId::from_u64(bundle_id),
        )
    };

    registration.expect("provider registration should succeed")
}
/// Creates an on-disk bundle directory for the probe loader and returns its path.
///
/// Inside `temp`, a directory named `bundle_name` is created containing:
/// - `dummy.so`, an empty file referenced by the manifest's `file` key, and
/// - `manifest.toml`, which selects the `probe-enforce` loader and declares a
///   single contract dependency (`declared.dep@1`) with `declared_contract_id`.
///
/// # Panics
///
/// Panics if the directory or either file cannot be written.
fn write_bundle(temp: &tempfile::TempDir, bundle_name: &str, declared_contract_id: u64) -> PathBuf {
    let root: PathBuf = temp.path().join(bundle_name);
    std::fs::create_dir_all(&root).expect("create bundle dir");

    let artifact_path: PathBuf = root.join("dummy.so");
    std::fs::write(artifact_path, b"").expect("write dummy so");

    let numeric_id: u64 = polyplug_utils::bundle_id(bundle_name);
    // Each `\n\` ends a manifest line; the trailing backslash swallows the
    // source newline and the indentation that follows it.
    let manifest_text: String = format!(
        "id = {id}\n\
         name = \"{name}\"\n\
         loader = \"probe-enforce\"\n\
         file = \"dummy.so\"\n\
         version = \"1.0\"\n\n\
         [[dependency]]\n\
         kind = \"contract\"\n\
         contract = \"declared.dep@1\"\n\
         min_version = \"1.0\"\n\
         contract_id = {dep_id}\n",
        id = numeric_id,
        name = bundle_name,
        dep_id = declared_contract_id,
    );

    let manifest_path: PathBuf = root.join("manifest.toml");
    std::fs::write(manifest_path, manifest_text).expect("write manifest");

    root
}
/// End-to-end check of dependency enforcement during bundle init.
///
/// Two providers are registered, but the probed bundle's manifest declares
/// only one of them as a dependency. During `load` the declared contract must
/// be visible and the undeclared one hidden; afterwards, lookups made directly
/// on the runtime must succeed for both.
#[test]
fn declared_dep_resolves_undeclared_denied_during_init_host_unaffected() {
    let scratch: tempfile::TempDir = tempfile::TempDir::new().expect("temp dir");

    let declared_id: u64 = polyplug_utils::guest_contract_id("declared.dep", 1_u32);
    let undeclared_id: u64 = polyplug_utils::guest_contract_id("undeclared.other", 1_u32);

    // Shared between the loader (writer) and this test (reader).
    let observations: Arc<Mutex<ProbeResults>> = Arc::new(Mutex::new(ProbeResults {
        declared_lookup_null: None,
        undeclared_lookup_null: None,
        declared_find_all_count: None,
        undeclared_find_all_count: None,
    }));

    let probe = ProbeLoader {
        declared_contract_id: declared_id,
        undeclared_contract_id: undeclared_id,
        results: Arc::clone(&observations),
    };
    let runtime: Arc<Runtime> = Runtime::builder()
        .loader(probe)
        .build()
        .expect("runtime build should succeed");

    // Both contracts exist in the registry; only the manifest decides which
    // one the probed bundle is allowed to see during init.
    register_provider(&runtime, declared_id, 0xAAAA_u64);
    register_provider(&runtime, undeclared_id, 0xBBBB_u64);

    let bundle_dir: PathBuf = write_bundle(&scratch, "probe_bundle", declared_id);
    runtime
        .load_bundle(bundle_dir.as_path())
        .expect("load_bundle should succeed");

    // Inspect what the loader saw; the guard is released at the end of this
    // block, before the host-side lookups below.
    {
        let seen: std::sync::MutexGuard<'_, ProbeResults> = observations
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        assert_eq!(
            seen.declared_lookup_null,
            Some(false),
            "declared dependency must resolve to a non-null handle during init"
        );
        assert_eq!(
            seen.undeclared_lookup_null,
            Some(true),
            "undeclared contract must be denied (null handle) during init"
        );
        assert_eq!(
            seen.declared_find_all_count,
            Some(1),
            "declared dependency must be enumerable via find_all during init"
        );
        assert_eq!(
            seen.undeclared_find_all_count,
            Some(0),
            "undeclared contract must enumerate empty via find_all during init"
        );
    }

    // Outside the init window the runtime's own lookups are expected to
    // succeed for both contracts.
    let from_host_declared: Result<GuestContractHandle, _> =
        runtime.find_guest_contract(declared_id, 0_u32);
    assert!(
        from_host_declared.is_ok(),
        "host lookup of declared contract must succeed after init"
    );

    let from_host_undeclared: Result<GuestContractHandle, _> =
        runtime.find_guest_contract(undeclared_id, 0_u32);
    assert!(
        from_host_undeclared.is_ok(),
        "host lookup of the (init-undeclared) contract must succeed after init — \
         enforcement applies only inside the init window"
    );
}