use std::{fmt, mem, str};
use wasmer::wasmparser::{
BinaryReaderError, CompositeInnerType, Export, FuncToValidate, FunctionBody, Import,
MemoryType, Parser, Payload, TableType, ValidPayload, Validator, ValidatorResources,
WasmFeatures,
};
use crate::{VmError, VmResult};
/// Wrapper that gives any `T` a `Debug` implementation without requiring `T: Debug`.
///
/// The wrapped value itself is never formatted: the output consists only of the
/// type name of `T` rendered as a field-less, non-exhaustive struct. This lets a
/// containing type derive `Debug` even though one of its members cannot.
#[derive(Default)]
pub struct OpaqueDebug<T>(pub T);

impl<T> fmt::Debug for OpaqueDebug<T> {
    /// Writes `std::any::type_name::<T>()` followed by the non-exhaustive marker.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner_type = std::any::type_name::<T>();
        let mut opaque = f.debug_struct(inner_type);
        opaque.finish_non_exhaustive()
    }
}
/// State of function-body validation for a parsed module.
///
/// `ParsedWasm::parse` creates this in the `Pending` state and fills it with the
/// function bodies it encounters. `ParsedWasm::validate_funcs` later replaces it
/// with `Success` or `Error`, which then acts as a cached outcome.
#[derive(Debug)]
pub enum FunctionValidator<'a> {
/// Functions, each paired with its body, that have not been validated yet.
/// Wrapped in `OpaqueDebug` so the enum can still derive `Debug`.
Pending(OpaqueDebug<Vec<(FuncToValidate<ValidatorResources>, FunctionBody<'a>)>>),
/// All queued function bodies were validated without error.
Success,
/// Validation failed; holds the error reported for the first failing body.
Error(BinaryReaderError),
}
impl<'a> FunctionValidator<'a> {
    /// Queues a function and its body for later validation.
    ///
    /// # Panics
    ///
    /// Panics when the validator is no longer in the `Pending` state, i.e. after
    /// validation has already produced `Success` or `Error`.
    fn push(&mut self, item: (FuncToValidate<ValidatorResources>, FunctionBody<'a>)) {
        match self {
            Self::Pending(OpaqueDebug(queue)) => queue.push(item),
            Self::Success | Self::Error(_) => {
                panic!("attempted to push function into non-pending validator")
            }
        }
    }
}
/// Facts collected from a Wasm module by `ParsedWasm::parse`.
///
/// The struct borrows from the byte slice that was parsed (lifetime `'a`).
#[derive(Debug)]
pub struct ParsedWasm<'a> {
/// Version number taken from the module's `Payload::Version` payload.
pub version: u16,
/// Entries of the export section.
pub exports: Vec<Export<'a>>,
/// Entries of the import section.
pub imports: Vec<Import<'a>>,
/// Table types of the entries in the table section.
pub tables: Vec<TableType>,
/// Entries of the memory section.
pub memories: Vec<MemoryType>,
/// Number of function bodies encountered while parsing.
pub function_count: usize,
/// Total number of types in the type section, counting every type of every group.
pub type_count: u32,
/// Parameter count of each function type, in declaration order.
/// Looked up by type index when processing the function section.
pub type_params: Vec<usize>,
/// Largest parameter count among the function types in the type section.
pub max_func_params: usize,
/// Largest result count among the function types in the type section.
pub max_func_results: usize,
/// Sum of the parameter counts of all entries in the function section.
pub total_func_params: usize,
/// Function bodies still awaiting validation, or the cached validation outcome.
pub func_validator: FunctionValidator<'a>,
/// Value of the `cw_migrate_version` custom section parsed as a `u64`, if present.
pub contract_migrate_version: Option<u64>,
}
impl<'a> ParsedWasm<'a> {
pub fn parse(wasm: &'a [u8]) -> VmResult<Self> {
let features = WasmFeatures::MUTABLE_GLOBAL
| WasmFeatures::SATURATING_FLOAT_TO_INT
| WasmFeatures::SIGN_EXTENSION
| WasmFeatures::MULTI_VALUE
| WasmFeatures::FLOATS
| WasmFeatures::REFERENCE_TYPES;
let mut validator = Validator::new_with_features(features);
let mut this = Self {
version: 0,
exports: vec![],
imports: vec![],
tables: vec![],
memories: vec![],
function_count: 0,
type_count: 0,
type_params: Vec::new(),
max_func_params: 0,
max_func_results: 0,
total_func_params: 0,
func_validator: FunctionValidator::Pending(OpaqueDebug::default()),
contract_migrate_version: None,
};
for p in Parser::new(0).parse_all(wasm) {
let p = p?;
if let ValidPayload::Func(fv, body) = validator.payload(&p)? {
this.func_validator.push((fv, body));
this.function_count += 1;
}
match p {
Payload::TypeSection(t) => {
this.type_count = 0;
this.type_params = Vec::with_capacity(t.count() as usize);
for group in t.into_iter() {
let types = group?.into_types();
this.type_count += types.len() as u32;
for ty in types {
match ty.composite_type.inner {
CompositeInnerType::Func(ft) => {
this.type_params.push(ft.params().len());
this.max_func_params =
core::cmp::max(ft.params().len(), this.max_func_params);
this.max_func_results =
core::cmp::max(ft.results().len(), this.max_func_results);
}
CompositeInnerType::Array(_) | CompositeInnerType::Struct(_) => {
}
}
}
}
}
Payload::FunctionSection(section) => {
for a in section {
let type_index = a? as usize;
this.total_func_params +=
this.type_params.get(type_index).ok_or_else(|| {
VmError::static_validation_err(
"Wasm bytecode error: function uses unknown type index",
)
})?
}
}
Payload::Version { num, .. } => this.version = num,
Payload::ImportSection(i) => {
this.imports = i.into_iter().collect::<Result<Vec<_>, _>>()?;
}
Payload::TableSection(t) => {
this.tables = t
.into_iter()
.map(|r| r.map(|t| t.ty))
.collect::<Result<Vec<_>, _>>()?;
}
Payload::MemorySection(m) => {
this.memories = m.into_iter().collect::<Result<Vec<_>, _>>()?;
}
Payload::ExportSection(e) => {
this.exports = e.into_iter().collect::<Result<Vec<_>, _>>()?;
}
Payload::CustomSection(reader) if reader.name() == "cw_migrate_version" => {
let raw_version = str::from_utf8(reader.data())
.map_err(|err| VmError::static_validation_err(err.to_string()))?;
this.contract_migrate_version = Some(
raw_version
.parse::<u64>()
.map_err(|err| VmError::static_validation_err(err.to_string()))?,
);
}
_ => {} }
}
Ok(this)
}
pub fn validate_funcs(&mut self) -> VmResult<()> {
match self.func_validator {
FunctionValidator::Pending(OpaqueDebug(ref mut funcs)) => {
let result = (|| {
let mut allocations = <_>::default();
for (func, body) in mem::take(funcs) {
let mut validator = func.into_validator(allocations);
validator.validate(&body)?;
allocations = validator.into_allocations();
}
Ok(())
})();
self.func_validator = match result {
Ok(()) => FunctionValidator::Success,
Err(err) => FunctionValidator::Error(err),
};
self.validate_funcs()
}
FunctionValidator::Success => Ok(()),
FunctionValidator::Error(ref err) => Err(err.clone().into()),
}
}
}
#[cfg(test)]
mod test {
    use super::ParsedWasm;

    /// A numeric `cw_migrate_version` custom section ends up in
    /// `contract_migrate_version`.
    #[test]
    fn read_migrate_version() {
        let bytes =
            wat::parse_str(r#"( module ( @custom "cw_migrate_version" "42" ) )"#).unwrap();

        let module = ParsedWasm::parse(&bytes).unwrap();

        assert_eq!(module.contract_migrate_version, Some(42));
    }

    /// A `cw_migrate_version` custom section that is not a number makes parsing fail.
    #[test]
    fn read_migrate_version_fails() {
        let bytes =
            wat::parse_str(r#"( module ( @custom "cw_migrate_version" "not a number" ) )"#)
                .unwrap();

        let outcome = ParsedWasm::parse(&bytes);

        assert!(outcome.is_err());
    }

    /// `function_count` counts defined functions, not exports.
    #[test]
    fn parsed_wasm_counts_functions_correctly() {
        // An empty module defines no functions.
        let empty_bytes = wat::parse_str(r#"(module)"#).unwrap();
        let empty = ParsedWasm::parse(&empty_bytes).unwrap();
        assert_eq!(empty.function_count, 0);

        // Two functions; both exports point at the same function.
        let populated_bytes = wat::parse_str(
            r#"(module
(type (func))
(func (type 0) nop)
(func (type 0) nop)
(export "foo" (func 0))
(export "bar" (func 0))
)"#,
        )
        .unwrap();
        let populated = ParsedWasm::parse(&populated_bytes).unwrap();
        assert_eq!(populated.function_count, 2);
    }

    /// `max_func_params` and `max_func_results` hold the largest counts seen.
    #[test]
    fn parsed_wasm_counts_func_io_correctly() {
        // Without any types both maxima stay at zero.
        let empty_bytes = wat::parse_str(r#"(module)"#).unwrap();
        let empty = ParsedWasm::parse(&empty_bytes).unwrap();
        assert_eq!(empty.max_func_params, 0);
        assert_eq!(empty.max_func_results, 0);

        // Type 0 has the most parameters (3), type 1 the most results (2).
        let populated_bytes = wat::parse_str(
            r#"(module
(type (func (param i32 i32 i32) (result i32)))
(type (func (param i32) (result i32 i32)))
(func (type 1) i32.const 42 i32.const 42)
(func (type 0) i32.const 42)
)"#,
        )
        .unwrap();
        let populated = ParsedWasm::parse(&populated_bytes).unwrap();
        assert_eq!(populated.max_func_params, 3);
        assert_eq!(populated.max_func_results, 2);
    }
}