use std::collections::HashSet;
use wasmtime::Module;
use crate::error::WasmError;
use crate::manifest::{capability_to_imports, PluginType};
pub fn validate_exports(module: &Module, plugin_type: PluginType) -> Result<(), WasmError> {
let exports: HashSet<&str> = module.exports().map(|e| e.name()).collect();
if !exports.contains("memory") {
return Err(WasmError::MissingExport("memory".into()));
}
for required in plugin_type.required_exports() {
if !exports.contains(required) {
return Err(WasmError::MissingExport((*required).into()));
}
}
for export in module.exports() {
match export.name() {
"init" => validate_init_signature(&export)?,
"on_request" => validate_on_request_signature(&export)?,
"on_response" => validate_on_response_signature(&export)?,
"dispatch" => validate_dispatch_signature(&export)?,
_ => {} }
}
Ok(())
}
/// Rejects imports from the `barbacane` host module that the plugin did not declare.
///
/// `host_set_output` is always permitted. Every other `barbacane` import must appear in
/// `capability_to_imports(c)` for some capability `c` in `declared_capabilities`. Imports
/// from any module other than `barbacane` are not inspected here.
///
/// # Errors
///
/// Returns [`WasmError::UndeclaredImport`] naming the first offending import, in the order
/// the module lists its imports.
pub fn validate_imports(
    module: &Module,
    declared_capabilities: &[String],
) -> Result<(), WasmError> {
    let mut permitted: HashSet<&str> = HashSet::from(["host_set_output"]);
    for cap in declared_capabilities {
        permitted.extend(capability_to_imports(cap));
    }

    let offender = module
        .imports()
        .find(|imp| imp.module() == "barbacane" && !permitted.contains(imp.name()));

    match offender {
        Some(imp) => Err(WasmError::UndeclaredImport(imp.name().into())),
        None => Ok(()),
    }
}
fn validate_init_signature(export: &wasmtime::ExportType) -> Result<(), WasmError> {
let func = match export.ty() {
wasmtime::ExternType::Func(f) => f,
_ => {
return Err(WasmError::InvalidExportSignature {
name: "init".into(),
expected: "function".into(),
actual: format!("{:?}", export.ty()),
})
}
};
let params: Vec<_> = func.params().collect();
if params.len() != 2
|| !matches!(params[0], wasmtime::ValType::I32)
|| !matches!(params[1], wasmtime::ValType::I32)
{
return Err(WasmError::InvalidExportSignature {
name: "init".into(),
expected: "(i32, i32) -> i32".into(),
actual: format!("{:?}", func),
});
}
let results: Vec<_> = func.results().collect();
if results.len() != 1 || !matches!(results[0], wasmtime::ValType::I32) {
return Err(WasmError::InvalidExportSignature {
name: "init".into(),
expected: "(i32, i32) -> i32".into(),
actual: format!("{:?}", func),
});
}
Ok(())
}
/// Validates the `on_request` export: it must be a function `(i32, i32) -> i32`.
fn validate_on_request_signature(export: &wasmtime::ExportType) -> Result<(), WasmError> {
    let label = "on_request";
    validate_standard_handler_signature(export, label)
}
/// Validates the `on_response` export: it must be a function `(i32, i32) -> i32`.
fn validate_on_response_signature(export: &wasmtime::ExportType) -> Result<(), WasmError> {
    let label = "on_response";
    validate_standard_handler_signature(export, label)
}
/// Validates the `dispatch` export: it must be a function `(i32, i32) -> i32`.
fn validate_dispatch_signature(export: &wasmtime::ExportType) -> Result<(), WasmError> {
    let label = "dispatch";
    validate_standard_handler_signature(export, label)
}
/// Checks that `export` is a function taking two `i32` parameters and returning one `i32`.
///
/// `name` is used only to label the error.
///
/// # Errors
///
/// Returns [`WasmError::InvalidExportSignature`] with:
/// * `expected = "function"` when the export is not a function (`actual` is the export's
///   `ExternType` in `Debug` form);
/// * `expected = "(i32, i32) -> i32"` when the parameter or result types differ (`actual`
///   is the function type in `Debug` form).
fn validate_standard_handler_signature(
    export: &wasmtime::ExportType,
    name: &str,
) -> Result<(), WasmError> {
    let func_ty = match export.ty() {
        wasmtime::ExternType::Func(ft) => ft,
        not_a_func => {
            return Err(WasmError::InvalidExportSignature {
                name: name.into(),
                expected: "function".into(),
                actual: format!("{:?}", not_a_func),
            });
        }
    };

    let param_types: Vec<_> = func_ty.params().collect();
    let result_types: Vec<_> = func_ty.results().collect();

    // Slice patterns pin both the arity and the element types in one step.
    let params_match = matches!(
        param_types.as_slice(),
        [wasmtime::ValType::I32, wasmtime::ValType::I32]
    );
    let results_match = matches!(result_types.as_slice(), [wasmtime::ValType::I32]);

    if params_match && results_match {
        Ok(())
    } else {
        Err(WasmError::InvalidExportSignature {
            name: name.into(),
            expected: "(i32, i32) -> i32".into(),
            actual: format!("{:?}", func_ty),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `capability` maps to (at least) every import named in `expected`.
    fn assert_maps_to(capability: &str, expected: &[&'static str]) {
        let imports = capability_to_imports(capability);
        for wanted in expected {
            assert!(imports.contains(wanted));
        }
    }

    #[test]
    fn capability_to_imports_log() {
        assert_maps_to("log", &["host_log"]);
    }

    #[test]
    fn capability_to_imports_context() {
        assert_maps_to("context_get", &["host_context_get", "host_context_read_result"]);
        assert_maps_to("context_set", &["host_context_set"]);
    }

    #[test]
    fn capability_to_imports_telemetry() {
        assert_maps_to("telemetry", &["host_metric_counter_inc", "host_span_start"]);
    }

    #[test]
    fn capability_to_imports_uuid() {
        assert_maps_to("generate_uuid", &["host_uuid_generate", "host_uuid_read_result"]);
    }

    #[test]
    fn unknown_capability_returns_empty() {
        assert!(capability_to_imports("unknown").is_empty());
    }
}