#![allow(clippy::expect_used)]
use polyplug_codegen::{GenerateConfig, Lang, Side};
use polyplug_utils::host_contract_id;
use polyplugc::generate;
use std::path::PathBuf;
/// Writes the fixture `test_interface_api.toml` into `tmp_dir` and returns its path.
///
/// The fixture declares one plugin contract (`example.worker`) and one host
/// contract (`host.logger`, version 1.0.0) with three functions: `log`,
/// `log_level` and `get_level`.
///
/// `tmp_dir` is created (including parents) if it does not exist. The parameter
/// is `&Path` rather than `&PathBuf` (clippy `ptr_arg`); callers holding a
/// `PathBuf` can still pass `&tmp_dir` thanks to deref coercion.
///
/// # Panics
///
/// Panics if the directory cannot be created or the file cannot be written.
fn create_test_api_with_host_contract(tmp_dir: &std::path::Path) -> PathBuf {
    std::fs::create_dir_all(tmp_dir).expect("create tmp_dir");
    let api_toml_path: PathBuf = tmp_dir.join("test_interface_api.toml");
    // NOTE(review): `do_work` below uses the key `return`, whereas every
    // host-contract function uses `returns`. Confirm which key the api.toml
    // schema expects; if unknown keys are ignored, `do_work`'s return type may
    // be silently dropped.
    let content: &str = r#"# Test API with host contract for interface factory tests
[[plugin_contract]]
name = "example.worker"
version = "1.0.0"
[[plugin_contract.functions]]
name = "do_work"
params = [{ name = "input", type = "StringView" }]
return = "StringView"
[[host_contract]]
name = "host.logger"
version = "1.0.0"
[[host_contract.functions]]
name = "log"
params = [{ name = "message", type = "StringView" }]
returns = "void"
[[host_contract.functions]]
name = "log_level"
params = [{ name = "message", type = "StringView" }, { name = "level", type = "u32" }]
returns = "void"
[[host_contract.functions]]
name = "get_level"
returns = "u32"
"#;
    std::fs::write(&api_toml_path, content).expect("failed to write test api.toml");
    api_toml_path
}
/// Runs the host-side Rust generator over the fixture API written into
/// `tmp_dir`, writes every generated file under `tmp_dir`, and returns the text
/// of `host/interface_factories.rs`.
///
/// `tmp_dir` is wiped first. `CARGO_TARGET_TMPDIR` persists between test runs
/// and Cargo never cleans it. Without the wipe, a stale
/// `interface_factories.rs` from an earlier run would be read back even if the
/// generator no longer emitted it, hiding a regression.
///
/// # Panics
///
/// Panics if the old directory cannot be removed, if writing the fixture or any
/// generated file fails, if `polyplugc::generate` fails, or if
/// `host/interface_factories.rs` cannot be read after generation.
fn generate_host_interface_factories(tmp_dir: &PathBuf) -> String {
    if tmp_dir.exists() {
        std::fs::remove_dir_all(tmp_dir).expect("failed to clear stale test output dir");
    }
    let api_toml: PathBuf = create_test_api_with_host_contract(tmp_dir);
    let config: GenerateConfig = GenerateConfig {
        api_toml,
        lang: Lang::Rust,
        side: Side::Host,
        out_dir: tmp_dir.clone(),
    };
    let output = generate(config).expect("polyplugc::generate failed");
    for file in &output.files {
        let file_path: PathBuf = tmp_dir.join(&file.path);
        if let Some(parent) = file_path.parent() {
            std::fs::create_dir_all(parent).expect("failed to create parent dir");
        }
        std::fs::write(&file_path, &file.content).expect("failed to write generated file");
    }
    let interface_factories_path: PathBuf = tmp_dir.join("host").join("interface_factories.rs");
    std::fs::read_to_string(&interface_factories_path).expect("read interface_factories.rs")
}
/// Computes the contract id the generator is expected to embed for host
/// contract `name` at major version `major`.
///
/// Delegates to `polyplug_utils::host_contract_id`, so the tests compare
/// generated output against the library's own id derivation rather than a
/// hard-coded constant.
fn expected_host_contract_id(name: &str, major: u32) -> u64 {
    host_contract_id(name, major)
}
/// Both host-contract factories must be generated with their expected shapes.
///
/// The NATIVE factory takes a `Box<dyn HostLogger>`. The VM factory takes an
/// opaque `bridge_data` pointer plus a C `dispatch_fn`. Both return
/// `&'static HostContractInterface`.
#[test]
fn test_interface_factory_generates_native_and_vm_factories() {
    let tmp_dir: PathBuf =
        PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("interface_factories_native_vm");
    let interfaces: String = generate_host_interface_factories(&tmp_dir);
    // A bare `contains("pub fn create_host_logger_interface")` is also satisfied
    // by `create_host_logger_interface_vm`, which made the NATIVE check vacuous.
    // Require an occurrence not immediately continued by an identifier character.
    let native_fn: &str = "pub fn create_host_logger_interface";
    let has_native_factory: bool = interfaces.match_indices(native_fn).any(|(start, _)| {
        !interfaces[start + native_fn.len()..]
            .starts_with(|c: char| c == '_' || c.is_alphanumeric())
    });
    assert!(
        has_native_factory,
        "NATIVE factory `create_host_logger_interface` must be generated:\n{interfaces}"
    );
    assert!(
        interfaces.contains("pub fn create_host_logger_interface_vm"),
        "VM factory `create_host_logger_interface_vm` must be generated:\n{interfaces}"
    );
    assert!(
        interfaces.contains("-> &'static HostContractInterface"),
        "Factories must return &'static HostContractInterface:\n{interfaces}"
    );
    assert!(
        interfaces.contains("implementation: Box<dyn HostLogger>"),
        "NATIVE factory must take Box<dyn HostLogger>:\n{interfaces}"
    );
    assert!(
        interfaces.contains("bridge_data: *mut c_void"),
        "VM factory must take bridge_data:\n{interfaces}"
    );
    assert!(
        interfaces.contains("dispatch_fn: unsafe extern \"C\" fn"),
        "VM factory must take dispatch_fn:\n{interfaces}"
    );
}
/// The `host.logger` v1 contract id must be embedded as a hex
/// `HostContractId::from(..)` literal exactly twice, once per factory.
#[test]
fn test_interface_factory_header_has_correct_contract_id() {
    let out_dir: PathBuf =
        PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("interface_factories_contract_id");
    let interfaces: String = generate_host_interface_factories(&out_dir);
    let expected_id_hex: String = {
        let expected_id: u64 = expected_host_contract_id("host.logger", 1);
        format!("HostContractId::from(0x{expected_id:016X}_u64)")
    };
    let present: bool = interfaces.contains(expected_id_hex.as_str());
    assert!(
        present,
        "NATIVE interface must have correct contract_id `{expected_id_hex}`:\n{interfaces}"
    );
    let contract_id_count: usize = interfaces.match_indices(expected_id_hex.as_str()).count();
    assert_eq!(
        contract_id_count, 2,
        "contract_id must appear in both NATIVE and VM factories (expected 2, got {contract_id_count}):\n{interfaces}"
    );
}
/// The interface header and the `FUNCTIONS` thunk array must both reflect the
/// three `host.logger` functions (`log`, `log_level`, `get_level`).
#[test]
fn test_interface_factory_header_has_correct_function_count() {
    let expected_count: usize = 3;
    let interfaces: String = generate_host_interface_factories(
        &PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("interface_factories_fn_count"),
    );
    let (expected_count_str, expected_array_size): (String, String) = (
        format!("function_count: {expected_count}"),
        format!(
            "[unsafe extern \"C\" fn(*const c_void, *const (), *mut (), *mut AbiError); {expected_count}]"
        ),
    );
    assert!(
        interfaces.contains(expected_count_str.as_str()),
        "NATIVE interface must have function_count={expected_count}:\n{interfaces}"
    );
    assert!(
        interfaces.contains(expected_array_size.as_str()),
        "FUNCTIONS array must have size {expected_count}:\n{interfaces}"
    );
}
/// Each generated thunk must guard its call with `catch_unwind` (wrapped in
/// `AssertUnwindSafe`), and the generated code must reference
/// `AbiErrorCode::Panic`. One thunk must exist per contract function.
#[test]
fn test_interface_factory_thunks_have_panic_safety() {
    let out_dir: PathBuf =
        PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("interface_factories_panic_safety");
    let interfaces: String = generate_host_interface_factories(&out_dir);
    assert!(
        interfaces.contains("std::panic::catch_unwind"),
        "Thunks must use std::panic::catch_unwind for panic safety:\n{interfaces}"
    );
    assert!(
        interfaces.contains("core::panic::AssertUnwindSafe"),
        "Thunks must use AssertUnwindSafe wrapper:\n{interfaces}"
    );
    assert!(
        interfaces.contains("AbiErrorCode::Panic"),
        "Panic handler must return AbiErrorCode::Panic:\n{interfaces}"
    );
    let catch_unwind_count: usize = interfaces.match_indices("catch_unwind").count();
    assert_eq!(
        catch_unwind_count, 3,
        "Expected 3 catch_unwind calls (one per thunk), got {catch_unwind_count}:\n{interfaces}"
    );
    // One thunk per host.logger function, checked in declaration order.
    for thunk in [
        "host_logger_log_thunk",
        "host_logger_log_level_thunk",
        "host_logger_get_level_thunk",
    ] {
        assert!(interfaces.contains(thunk), "Thunk `{thunk}` must exist:\n{interfaces}");
    }
}
/// The NATIVE factory must tag its interface with `DispatchType::Native`.
#[test]
fn test_interface_factory_native_dispatch_type() {
    let target_tmp: PathBuf = PathBuf::from(env!("CARGO_TARGET_TMPDIR"));
    let interfaces: String =
        generate_host_interface_factories(&target_tmp.join("interface_factories_native_dispatch"));
    let sets_native: bool = interfaces.contains("dispatch_type: DispatchType::Native");
    assert!(sets_native, "NATIVE factory must set dispatch_type to Native:\n{interfaces}");
}
/// The VM factory must tag its interface with `DispatchType::VirtualMachine`.
#[test]
fn test_interface_factory_vm_dispatch_type() {
    let target_tmp: PathBuf = PathBuf::from(env!("CARGO_TARGET_TMPDIR"));
    let interfaces: String =
        generate_host_interface_factories(&target_tmp.join("interface_factories_vm_dispatch"));
    let sets_vm: bool = interfaces.contains("dispatch_type: DispatchType::VirtualMachine");
    assert!(sets_vm, "VM factory must set dispatch_type to VirtualMachine:\n{interfaces}");
}
/// The NATIVE factory must leak both the interface (to hand out a `'static`
/// reference) and the boxed implementation (via `Box::into_raw`).
#[test]
fn test_interface_factory_native_leaks_interface() {
    let out_dir: PathBuf =
        PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("interface_factories_native_leak");
    let interfaces: String = generate_host_interface_factories(&out_dir);
    let required: [(&str, &str); 2] = [
        (
            "Box::leak(Box::new(interface))",
            "NATIVE factory must use Box::leak to return 'static interface",
        ),
        (
            "Box::into_raw(implementation)",
            "NATIVE factory must leak implementation via Box::into_raw",
        ),
    ];
    for (needle, requirement) in required {
        assert!(interfaces.contains(needle), "{requirement}:\n{interfaces}");
    }
}
/// `Box::leak(Box::new(interface))` must occur exactly twice: once in the
/// NATIVE factory and once in the VM factory.
#[test]
fn test_interface_factory_vm_leaks_interface() {
    let interfaces: String = generate_host_interface_factories(
        &PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("interface_factories_vm_leak"),
    );
    let leak_count: usize = interfaces
        .match_indices("Box::leak(Box::new(interface))")
        .count();
    assert_eq!(
        leak_count, 2,
        "Both factories must use Box::leak (expected 2, got {leak_count}):\n{interfaces}"
    );
}
/// The interface must carry the fixture's contract version `1.0.0` as a
/// `Version { major, minor, patch }` literal.
#[test]
fn test_interface_factory_header_has_correct_version() {
    let target_tmp: PathBuf = PathBuf::from(env!("CARGO_TARGET_TMPDIR"));
    let interfaces: String =
        generate_host_interface_factories(&target_tmp.join("interface_factories_version"));
    let expected_version: &str = "contract_version: Version { major: 1, minor: 0, patch: 0 }";
    assert!(
        interfaces.contains(expected_version),
        "interface must have contract_version with major=1, minor=0, patch=0:\n{interfaces}"
    );
}
/// Generated thunks must carry `// SAFETY:` comments that justify the validity
/// of both the `impl_ptr` and the `args` pointers.
#[test]
fn test_interface_factory_thunks_have_safety_comments() {
    let out_dir: PathBuf =
        PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("interface_factories_safety_comments");
    let interfaces: String = generate_host_interface_factories(&out_dir);
    let required: [(&str, &str); 3] = [
        ("// SAFETY:", "Thunks must have SAFETY comments"),
        ("impl_ptr is a valid", "SAFETY comment must explain impl_ptr validity"),
        ("args is a valid", "SAFETY comment must explain args validity"),
    ];
    for (needle, requirement) in required {
        assert!(interfaces.contains(needle), "{requirement}:\n{interfaces}");
    }
}