use crate::metadata::{Bindgen, ModuleMetadata};
use anyhow::{anyhow, bail, Context, Result};
use indexmap::{map::Entry, IndexMap, IndexSet};
use wasmparser::{
types::Types, Encoding, ExternalKind, FuncType, Parser, Payload, TypeRef, ValType,
ValidPayload, Validator,
};
use wit_parser::{
abi::{AbiVariant, WasmSignature, WasmType},
Function, InterfaceId, Resolve, WorldId, WorldItem,
};
/// Returns `true` when `name` carries one of the prefixes reserved for
/// canonical-ABI support functions (`cabi_` or `canonical_abi_`).
fn is_canonical_function(name: &str) -> bool {
    const RESERVED_PREFIXES: [&str; 2] = ["cabi_", "canonical_abi_"];
    RESERVED_PREFIXES
        .iter()
        .any(|prefix| name.starts_with(*prefix))
}
/// Converts a wit-parser core wasm signature into the equivalent
/// `wasmparser::FuncType`, mapping every `WasmType` onto its `ValType`.
fn wasm_sig_to_func_type(signature: WasmSignature) -> FuncType {
    // Non-capturing closure, so it is `Copy` and can be reused for both lists.
    let to_val_type = |wasm_ty: &WasmType| match wasm_ty {
        WasmType::I32 => ValType::I32,
        WasmType::I64 => ValType::I64,
        WasmType::F32 => ValType::F32,
        WasmType::F64 => ValType::F64,
    };
    let params = signature.params.iter().map(to_val_type);
    let results = signature.results.iter().map(to_val_type);
    FuncType::new(params, results)
}
/// Import module name an adapter uses for functions it needs from the main
/// core module; functions imported from it are recorded as core exports the
/// adapter requires.
pub const MAIN_MODULE_IMPORT_NAME: &str = "__main_module__";
/// Import module name under which a core module imports bare, top-level world
/// functions (as opposed to functions of an imported interface).
pub const BARE_FUNC_MODULE_NAME: &str = "$root";
/// Summary of what a validated core module requires, produced by
/// `validate_module`.
pub struct ValidatedModule<'a> {
/// Import module name -> names of the functions imported from it. World
/// interface imports are keyed by their world import name; bare top-level
/// function imports are grouped under `BARE_FUNC_MODULE_NAME`.
pub required_imports: IndexMap<&'a str, IndexSet<&'a str>>,
/// For import modules that name an adapter: function name -> the core
/// function type the module imports it with.
pub adapters_required: IndexMap<&'a str, IndexMap<&'a str, FuncType>>,
/// Whether the module exports a memory named `memory`.
pub has_memory: bool,
/// Export name of the realloc function (`cabi_realloc` or
/// `canonical_abi_realloc`), if the module exports one.
pub realloc: Option<&'a str>,
/// Module metadata borrowed from the `Bindgen` given to `validate_module`.
pub metadata: &'a ModuleMetadata,
}
pub fn validate_module<'a>(
bytes: &'a [u8],
metadata: &'a Bindgen,
adapters: &IndexSet<&str>,
) -> Result<ValidatedModule<'a>> {
let mut validator = Validator::new();
let mut types = None;
let mut import_funcs = IndexMap::new();
let mut export_funcs = IndexMap::new();
let mut ret = ValidatedModule {
required_imports: Default::default(),
adapters_required: Default::default(),
has_memory: false,
realloc: None,
metadata: &metadata.metadata,
};
for payload in Parser::new(0).parse_all(bytes) {
let payload = payload?;
if let ValidPayload::End(tys) = validator.payload(&payload)? {
types = Some(tys);
break;
}
match payload {
Payload::Version { encoding, .. } if encoding != Encoding::Module => {
bail!("data is not a WebAssembly module");
}
Payload::ImportSection(s) => {
for import in s {
let import = import?;
match import.ty {
TypeRef::Func(ty) => {
let map = match import_funcs.entry(import.module) {
Entry::Occupied(e) => e.into_mut(),
Entry::Vacant(e) => e.insert(IndexMap::new()),
};
assert!(map.insert(import.name, ty).is_none());
}
_ => bail!("module is only allowed to import functions"),
}
}
}
Payload::ExportSection(s) => {
for export in s {
let export = export?;
match export.kind {
ExternalKind::Func => {
if is_canonical_function(export.name) {
if export.name == "cabi_realloc"
|| export.name == "canonical_abi_realloc"
{
ret.realloc = Some(export.name);
}
continue;
}
assert!(export_funcs.insert(export.name, export.index).is_none())
}
ExternalKind::Memory => {
if export.name == "memory" {
ret.has_memory = true;
}
}
_ => continue,
}
}
}
_ => continue,
}
}
let types = types.unwrap();
let world = &metadata.resolve.worlds[metadata.world];
for (name, funcs) in &import_funcs {
if *name == "$root" {
validate_imports_top_level(&metadata.resolve, metadata.world, funcs, &types)?;
let funcs = funcs.keys().cloned().collect();
let prev = ret.required_imports.insert(BARE_FUNC_MODULE_NAME, funcs);
assert!(prev.is_none());
continue;
}
match world.imports.get(*name) {
Some(WorldItem::Interface(interface)) => {
let funcs =
validate_imported_interface(&metadata.resolve, *interface, name, funcs, &types)
.with_context(|| format!("failed to validate import interface `{name}`"))?;
let prev = ret.required_imports.insert(name, funcs);
assert!(prev.is_none());
}
None if adapters.contains(name) => {
let map = ret.adapters_required.entry(name).or_default();
for (func, ty) in funcs {
let ty = types.func_type_at(*ty).unwrap();
map.insert(func, ty.clone());
}
}
None | Some(WorldItem::Function(_) | WorldItem::Type(_)) => {
bail!("module requires an import interface named `{}`", name)
}
}
}
for (name, item) in world.exports.iter() {
validate_exported_item(&metadata.resolve, item, name, &export_funcs, &types)?;
}
Ok(ret)
}
/// Summary of what a validated adapter module requires, produced by
/// `validate_adapter_module`.
pub struct ValidatedAdapter<'a> {
/// Import module name -> names of the world functions the adapter imports.
/// Top-level world functions are grouped under `BARE_FUNC_MODULE_NAME`;
/// otherwise the key is the world's interface import name.
pub required_imports: IndexMap<&'a str, IndexSet<&'a str>>,
/// `(module, name)` of the memory the adapter imports, if it imports one.
pub needs_memory: Option<(String, String)>,
/// Names the adapter imports from `MAIN_MODULE_IMPORT_NAME`, i.e. exports it
/// needs from the main core module.
pub needs_core_exports: IndexSet<String>,
/// `Some("cabi_import_realloc")` if the adapter exports that function.
pub import_realloc: Option<String>,
/// `Some("cabi_export_realloc")` if the adapter exports that function.
pub export_realloc: Option<String>,
/// Metadata passed through unchanged from the caller.
pub metadata: &'a ModuleMetadata,
}
/// Validates the adapter core module in `bytes` against `world` in `resolve`.
///
/// Compared with `validate_module`, an adapter:
///
/// * has every function body validated, not just module-level structure;
/// * may import a memory, recorded as `needs_memory`;
/// * may import functions from `MAIN_MODULE_IMPORT_NAME`, recorded in
///   `needs_core_exports`;
/// * must export every function listed in `required`, with exactly the given
///   core type.
///
/// All other function imports must be top-level world functions (under
/// `BARE_FUNC_MODULE_NAME`) or functions of interfaces imported by the world,
/// with matching signatures. Each `(module, name)` function import may appear
/// only once. Exports named `cabi_export_realloc` / `cabi_import_realloc` are
/// recorded in `export_realloc` / `import_realloc`.
///
/// # Errors
///
/// Returns an error if parsing or validation fails, an import is of a
/// disallowed kind or is duplicated, an import does not match the world, or a
/// required export is missing or has the wrong type.
pub fn validate_adapter_module<'a>(
    bytes: &[u8],
    resolve: &'a Resolve,
    world: WorldId,
    metadata: &'a ModuleMetadata,
    required: &IndexMap<String, FuncType>,
) -> Result<ValidatedAdapter<'a>> {
    let mut validator = Validator::new();
    // Function imports grouped by import module: module -> (name -> type index).
    let mut import_funcs: IndexMap<&str, IndexMap<&str, u32>> = IndexMap::new();
    // Function exports: export name -> function index.
    let mut export_funcs = IndexMap::new();
    let mut types = None;
    // Function bodies are collected while parsing and validated afterwards.
    let mut funcs = Vec::new();
    let mut ret = ValidatedAdapter {
        required_imports: Default::default(),
        needs_memory: None,
        needs_core_exports: Default::default(),
        import_realloc: None,
        export_realloc: None,
        metadata,
    };

    for payload in Parser::new(0).parse_all(bytes) {
        let payload = payload?;
        match validator.payload(&payload)? {
            ValidPayload::End(tys) => {
                types = Some(tys);
                break;
            }
            ValidPayload::Func(func_validator, body) => {
                funcs.push((func_validator, body));
            }
            _ => {}
        }
        match payload {
            Payload::Version { encoding, .. } if encoding != Encoding::Module => {
                bail!("data is not a WebAssembly module");
            }
            Payload::ImportSection(s) => {
                for import in s {
                    let import = import?;
                    match import.ty {
                        TypeRef::Func(ty) => {
                            // Core wasm permits duplicate imports, but this map can
                            // hold only one type per name. Report the duplicate
                            // instead of panicking on user-provided input.
                            match import_funcs
                                .entry(import.module)
                                .or_default()
                                .entry(import.name)
                            {
                                Entry::Vacant(e) => {
                                    e.insert(ty);
                                }
                                Entry::Occupied(_) => bail!(
                                    "adapter module imports `{}` from `{}` more than once",
                                    import.name,
                                    import.module
                                ),
                            }
                        }
                        TypeRef::Memory(_) => {
                            ret.needs_memory =
                                Some((import.module.to_string(), import.name.to_string()));
                        }
                        _ => {
                            bail!("adapter module is only allowed to import functions and memories")
                        }
                    }
                }
            }
            Payload::ExportSection(s) => {
                for export in s {
                    let export = export?;
                    match export.kind {
                        ExternalKind::Func => {
                            export_funcs.insert(export.name, export.index);
                            if export.name == "cabi_export_realloc" {
                                ret.export_realloc = Some(export.name.to_string());
                            }
                            if export.name == "cabi_import_realloc" {
                                ret.import_realloc = Some(export.name.to_string());
                            }
                        }
                        _ => continue,
                    }
                }
            }
            _ => continue,
        }
    }

    // Validate every function body, reusing the validator allocations.
    let mut resources = Default::default();
    for (func_validator, body) in funcs {
        let mut func_validator = func_validator.into_validator(resources);
        func_validator.validate(&body)?;
        resources = func_validator.into_allocations();
    }

    let types = types.expect("validator yields `End` once the whole module has been parsed");
    for (name, funcs) in import_funcs {
        if name == MAIN_MODULE_IMPORT_NAME {
            ret.needs_core_exports
                .extend(funcs.keys().map(|name| name.to_string()));
            continue;
        }
        // `bytes` is not borrowed for `'a`, so the names stored in `ret` are taken
        // from `resolve` (in its declaration order), not from the import section.
        if name == BARE_FUNC_MODULE_NAME {
            validate_imports_top_level(resolve, world, &funcs, &types)?;
            let funcs = resolve.worlds[world]
                .imports
                .iter()
                .filter_map(|(name, item)| match item {
                    WorldItem::Function(_) if funcs.contains_key(name.as_str()) => {
                        Some(name.as_str())
                    }
                    _ => None,
                })
                .collect();
            ret.required_imports.insert(BARE_FUNC_MODULE_NAME, funcs);
            continue;
        }
        match resolve.worlds[world].imports.get_full(name) {
            Some((_, name, WorldItem::Interface(interface))) => {
                validate_imported_interface(resolve, *interface, name, &funcs, &types)
                    .with_context(|| format!("failed to validate import interface `{name}`"))?;
                let funcs = resolve.interfaces[*interface]
                    .functions
                    .keys()
                    .map(|s| s.as_str())
                    .filter(|s| funcs.contains_key(s))
                    .collect();
                let prev = ret.required_imports.insert(name, funcs);
                assert!(prev.is_none());
            }
            None | Some((_, _, WorldItem::Function(_) | WorldItem::Type(_))) => {
                bail!(
                    "adapter module requires an import interface named `{}`",
                    name
                )
            }
        }
    }

    for (name, ty) in required {
        let idx = match export_funcs.get(name.as_str()) {
            Some(idx) => *idx,
            None => bail!("adapter module did not export `{name}`"),
        };
        let actual = types.function_at(idx).unwrap();
        if ty == actual {
            continue;
        }
        bail!(
            "adapter module export `{name}` does not match the expected signature:\n\
            expected: {:?} -> {:?}\n\
            actual: {:?} -> {:?}\n\
            ",
            ty.params(),
            ty.results(),
            actual.params(),
            actual.results(),
        );
    }
    Ok(ret)
}
/// Checks each function in `funcs` (the module's imports from the bare
/// top-level import module) against `world`.
///
/// `funcs` maps each imported name to its core type index in `types`. Each name
/// must be a top-level function import of the world, and its lowered guest-import
/// signature must equal the declared core type.
fn validate_imports_top_level<'a>(
    resolve: &Resolve,
    world: WorldId,
    funcs: &IndexMap<&'a str, u32>,
    types: &Types,
) -> Result<()> {
    funcs.iter().try_for_each(|(&name, &type_index)| -> Result<()> {
        let func = match resolve.worlds[world].imports.get(name) {
            Some(WorldItem::Function(func)) => func,
            Some(_) => bail!("expected world top-level import `{name}` to be a function"),
            None => bail!("no top-level imported function `{name}` specified"),
        };
        let core_ty = types.func_type_at(type_index).unwrap();
        validate_func(resolve, core_ty, func, AbiVariant::GuestImport)
    })
}
/// Checks the functions a module imports from the world interface `interface`,
/// which is imported as `name`.
///
/// `imports` maps each imported function name to its core type index in
/// `types`. Each name must be a function of the interface whose lowered
/// guest-import signature matches. The matched function names are returned in
/// the order they appear in `imports`.
fn validate_imported_interface<'a>(
    resolve: &'a Resolve,
    interface: InterfaceId,
    name: &str,
    imports: &IndexMap<&str, u32>,
    types: &Types,
) -> Result<IndexSet<&'a str>> {
    imports
        .iter()
        .map(|(func_name, ty)| -> Result<&'a str> {
            let func = resolve.interfaces[interface]
                .functions
                .get(*func_name)
                .ok_or_else(|| {
                    anyhow!(
                        "import interface `{}` is missing function `{}` that is required by the module",
                        name,
                        func_name,
                    )
                })?;
            let core_ty = types.func_type_at(*ty).unwrap();
            validate_func(resolve, core_ty, func, AbiVariant::GuestImport)?;
            Ok(func.name.as_str())
        })
        .collect()
}
/// Verifies that the core type `ty` equals the signature `func` lowers to under
/// the canonical ABI variant `abi`.
///
/// # Errors
///
/// Returns a "type mismatch" error showing both signatures if they differ.
fn validate_func(
    resolve: &Resolve,
    ty: &wasmparser::FuncType,
    func: &Function,
    abi: AbiVariant,
) -> Result<()> {
    let lowered = resolve.wasm_signature(abi, func);
    let expected = wasm_sig_to_func_type(lowered);
    if *ty == expected {
        return Ok(());
    }
    bail!(
        "type mismatch for function `{}`: expected `{:?} -> {:?}` but found `{:?} -> {:?}`",
        func.name,
        expected.params(),
        expected.results(),
        ty.params(),
        ty.results()
    )
}
/// Checks that the module implements the world export `item`, exported as
/// `export_name`.
///
/// For a function, or for each function of an exported interface, the module
/// must export a core function under the name given by `core_export_name`, and
/// its type must match the guest-export lowering. Type exports need nothing.
/// `exports` maps core export names to function indices in `types`.
fn validate_exported_item(
    resolve: &Resolve,
    item: &WorldItem,
    export_name: &str,
    exports: &IndexMap<&str, u32>,
    types: &Types,
) -> Result<()> {
    // Finds the core export implementing `func` and checks its signature.
    let check_export = |func: &Function, interface: Option<&str>| -> Result<()> {
        let core_name = func.core_export_name(interface);
        if let Some(&index) = exports.get(core_name.as_ref()) {
            let core_ty = types.function_at(index).unwrap();
            return validate_func(resolve, core_ty, func, AbiVariant::GuestExport);
        }
        bail!(
            "module does not export required function `{}`",
            core_name
        )
    };

    match item {
        WorldItem::Type(_) => Ok(()),
        WorldItem::Function(func) => check_export(func, None),
        WorldItem::Interface(id) => resolve.interfaces[*id]
            .functions
            .values()
            .try_for_each(|func| {
                check_export(func, Some(export_name)).with_context(|| {
                    format!("failed to validate exported interface `{export_name}`")
                })
            }),
    }
}