use crate::{
compiler::{
block_fuel::compile_block_params,
compiled_expr::CompiledExpr,
func_builder::FuncBuilder,
snippets::Snippet,
translator::{InstructionTranslator, ReusableAllocations},
},
CompilationConfig, CompilationError, ConstructorParams, DataSegmentIdx, ElementSegmentIdx,
FuncIdx, FuncRef, GlobalIdx, GlobalVariable, ImportName, Opcode, RwasmModule, RwasmModuleInner,
TableIdx, DEFAULT_MEMORY_INDEX, SNIPPET_FUNC_IDX_UNRESOLVED,
};
use alloc::{boxed::Box, vec::Vec};
use core::{
mem::{replace, take},
ops::Range,
};
use hashbrown::HashMap;
use wasmparser::{
CustomSectionReader, DataKind, DataSectionReader, ElementItems, ElementKind,
ElementSectionReader, Encoding, ExportSectionReader, ExternalKind, FuncType, FunctionBody,
FunctionSectionReader, GlobalSectionReader, ImportSectionReader, MemorySectionReader, Parser,
Payload, TableSectionReader, Type, TypeRef, TypeSectionReader, ValType, Validator,
};
pub struct ModuleParser {
validator: Validator,
compiled_funcs: u32,
allocations: ReusableAllocations,
config: CompilationConfig,
}
impl ModuleParser {
pub fn new(config: CompilationConfig) -> Self {
Self {
validator: Validator::new_with_features(config.wasm_features()),
compiled_funcs: 0,
allocations: ReusableAllocations::default(),
config,
}
}
pub fn parse(&mut self, wasm_binary: &[u8]) -> Result<(), CompilationError> {
let parser = Parser::new(0);
let payloads = parser.parse_all(wasm_binary).collect::<Vec<_>>();
let mut func_bodies = Vec::new();
for payload in payloads {
match payload? {
Payload::CodeSectionEntry(func_body) => {
func_bodies.push(func_body);
}
Payload::End(offset) => {
for func_body in take(&mut func_bodies) {
self.process_code_entry(func_body)?;
}
self.process_end(offset)?;
}
payload => {
self.process_payload(payload)?;
}
}
}
Ok(())
}
pub fn parse_function_exports(
config: CompilationConfig,
wasm_binary: &[u8],
) -> Result<Vec<(Box<str>, FuncIdx, FuncType)>, CompilationError> {
let mut result = Vec::default();
let mut parser = ModuleParser::new(config);
parser.parse(wasm_binary)?;
for (k, v) in &parser.allocations.translation.exported_funcs {
let func_type_idx = parser.allocations.translation.resolve_func_type_index(*v);
let func_type = parser
.allocations
.translation
.func_type_registry
.resolve_original_func_type(func_type_idx)
.clone();
result.push((k.clone(), *v, func_type));
#[cfg(feature = "debug-print")]
println!("{}: func_idx={}, func_type_idx={}", k, v, func_type_idx);
}
Ok(result)
}
pub fn finalize(
mut self,
wasm_binary: &[u8],
) -> Result<(RwasmModule, ConstructorParams), CompilationError> {
if let Some(start_func) = self.allocations.translation.start_func {
if !self.config.allow_start_section {
return Err(CompilationError::StartSectionsAreNotAllowed);
}
self.allocations
.translation
.emit_function_call(start_func, true, false);
}
self.allocations
.translation
.segment_builder
.entrypoint_bytecode
.op_return();
let source_pc = self
.allocations
.translation
.segment_builder
.entrypoint_bytecode
.len() as u32;
if let Some(entrypoint_name) = self.config.entrypoint_name.as_ref() {
let func_idx = self
.allocations
.translation
.exported_funcs
.get(entrypoint_name)
.copied()
.ok_or(CompilationError::MissingEntrypoint)?;
self.allocations
.translation
.emit_function_call(func_idx, true, true);
} else if self.config.state_router.is_none() {
return Err(CompilationError::MissingEntrypoint);
}
self.emit_snippets();
self.emit_state_router()?;
self.allocations
.translation
.segment_builder
.entrypoint_bytecode
.finalize(true);
let mut code_section = self
.allocations
.translation
.segment_builder
.entrypoint_bytecode;
let entrypoint_length = code_section.len() as u32;
code_section.extend(self.allocations.translation.instruction_set.iter());
for instr in code_section.iter_mut() {
match instr {
Opcode::CallInternal(compiled_func)
| Opcode::ReturnCallInternal(compiled_func)
| Opcode::RefFunc(compiled_func) => {
if *compiled_func > 0 {
*compiled_func = self.allocations.translation.func_offsets
[*compiled_func as usize - 1]
+ entrypoint_length;
}
}
_ => continue,
}
}
let mut element_section = self
.allocations
.translation
.segment_builder
.global_element_section;
for elem in element_section.iter_mut() {
if *elem > 0 {
*elem = self.allocations.translation.func_offsets[*elem as usize - 1]
+ entrypoint_length;
}
}
let module = RwasmModuleInner {
code_section,
data_section: self
.allocations
.translation
.segment_builder
.global_memory_section,
elem_section: element_section,
hint_section: wasm_binary.to_vec(),
source_pc,
};
let constructor_params = self.allocations.translation.constructor_params;
Ok((RwasmModule::from(module), constructor_params))
}
pub fn emit_state_router(&mut self) -> Result<(), CompilationError> {
let allow_malformed_entrypoint_func_type = self.config.allow_malformed_entrypoint_func_type;
let Some(state_router) = &self.config.state_router else {
return Ok(());
};
if let Some(opcode) = &state_router.opcode {
self.allocations
.translation
.segment_builder
.entrypoint_bytecode
.push(*opcode);
}
for (entrypoint_name, state_value) in state_router.states.iter() {
let Some(func_idx) = self
.allocations
.translation
.exported_funcs
.get(entrypoint_name)
.copied()
else {
continue;
};
let func_type_idx = self
.allocations
.translation
.resolve_func_type_index(func_idx);
let is_empty_func_type = self
.allocations
.translation
.func_type_registry
.resolve_func_type_ref(func_type_idx, |func_type| {
func_type.params().is_empty() && func_type.results().is_empty()
});
if !is_empty_func_type && !allow_malformed_entrypoint_func_type {
return Err(CompilationError::MalformedFuncType);
}
let entrypoint_bytecode = &mut self
.allocations
.translation
.segment_builder
.entrypoint_bytecode;
entrypoint_bytecode.op_local_get(1u32);
entrypoint_bytecode.op_i32_const(*state_value);
entrypoint_bytecode.op_i32_eq();
entrypoint_bytecode.op_br_if_eqz(3);
entrypoint_bytecode.op_drop();
self.allocations
.translation
.emit_function_call(func_idx, true, true);
}
self.allocations
.translation
.segment_builder
.entrypoint_bytecode
.op_drop();
Ok(())
}
pub fn emit_snippets(&mut self) {
if !self.config.code_snippets {
return;
}
let mut emitted_snippets: HashMap<Snippet, FuncIdx> = HashMap::new();
let snippet_calls = self.allocations.translation.snippet_calls.clone();
for snippet_call in snippet_calls {
let snippet = snippet_call.snippet;
let snippet_func_idx = *emitted_snippets.entry(snippet).or_insert_with(|| {
let new_func_idx = self.next_func();
let alloc = &mut self.allocations.translation;
let func_offset = alloc.instruction_set.len() as u32;
alloc.func_offsets.push(func_offset);
alloc
.instruction_set
.op_stack_check(snippet.max_stack_height());
snippet.emit(&mut alloc.instruction_set);
alloc.instruction_set.op_return();
new_func_idx
});
let loc = snippet_call.loc;
let alloc = &mut self.allocations.translation;
let opcode = alloc.instruction_set.get_nth_mut(loc as usize)
.unwrap_or_else(|| panic!("expected snippet call at index {loc}, but instruction set length is smaller"));
match opcode {
Opcode::CallInternal(func_idx) => {
assert_eq!(*func_idx, SNIPPET_FUNC_IDX_UNRESOLVED);
*func_idx = snippet_func_idx + 1;
}
other => {
panic!("expected Opcode::CallInternal at index {loc}, but found {other:?}")
}
}
}
}
fn process_payload(&mut self, payload: Payload) -> Result<bool, CompilationError> {
match payload {
Payload::Version {
num,
encoding,
range,
} => self.process_version(num, encoding, range),
Payload::TypeSection(section) => self.process_types(section),
Payload::ImportSection(section) => self.process_imports(section),
Payload::InstanceSection(section) => self.process_instances(section),
Payload::FunctionSection(section) => self.process_functions(section),
Payload::TableSection(section) => self.process_tables(section),
Payload::MemorySection(section) => self.process_memories(section),
Payload::TagSection(section) => self.process_tags(section),
Payload::GlobalSection(section) => self.process_globals(section),
Payload::ExportSection(section) => self.process_exports(section),
Payload::StartSection { func, range } => self.process_start(func, range),
Payload::ElementSection(section) => self.process_element(section),
Payload::DataCountSection { count, range } => self.process_data_count(count, range),
Payload::DataSection(section) => self.process_data(section),
Payload::CustomSection(section) => self.process_custom_section(section),
Payload::CodeSectionStart { count, range, .. } => self.process_code_start(count, range),
Payload::CodeSectionEntry(func_body) => self.process_code_entry(func_body),
Payload::UnknownSection { id, range, .. } => self.process_unknown(id, range),
Payload::ModuleSection { parser: _, range } => {
self.process_unsupported_component_model(range)
}
Payload::CoreTypeSection(section) => {
self.process_unsupported_component_model(section.range())
}
Payload::ComponentSection { parser: _, range } => {
self.process_unsupported_component_model(range)
}
Payload::ComponentInstanceSection(section) => {
self.process_unsupported_component_model(section.range())
}
Payload::ComponentAliasSection(section) => {
self.process_unsupported_component_model(section.range())
}
Payload::ComponentTypeSection(section) => {
self.process_unsupported_component_model(section.range())
}
Payload::ComponentCanonicalSection(section) => {
self.process_unsupported_component_model(section.range())
}
Payload::ComponentStartSection { start: _, range } => {
self.process_unsupported_component_model(range)
}
Payload::ComponentImportSection(section) => {
self.process_unsupported_component_model(section.range())
}
Payload::ComponentExportSection(section) => {
self.process_unsupported_component_model(section.range())
}
Payload::End(offset) => {
self.process_end(offset)?;
return Ok(true);
}
}?;
Ok(false)
}
fn process_version(
&mut self,
num: u16,
encoding: Encoding,
range: Range<usize>,
) -> Result<(), CompilationError> {
self.validator
.version(num, encoding, &range)
.map_err(Into::into)
}
fn process_types(&mut self, section: TypeSectionReader) -> Result<(), CompilationError> {
self.validator.type_section(§ion)?;
for func_type in section.into_iter() {
let Type::Func(func_type) = func_type?;
self.allocations
.translation
.func_type_registry
.alloc_func_type(func_type)?;
}
Ok(())
}
fn process_imports(&mut self, section: ImportSectionReader) -> Result<(), CompilationError> {
self.validator.import_section(§ion)?;
for import in section.into_iter() {
let import = import?;
let func_type_index = match import.ty {
TypeRef::Func(func_type_index) => func_type_index,
TypeRef::Global(global_type) => {
let Some(default_value) = self.config.default_imported_global_value else {
return Err(CompilationError::NotSupportedImportType);
};
let global_index = self.allocations.translation.globals.len() as u32;
let global_variable = GlobalVariable::new(global_type, default_value);
self.allocations
.translation
.segment_builder
.add_global_variable(global_index, &global_variable)?;
self.allocations.translation.globals.push(global_variable);
continue;
}
_ => return Err(CompilationError::NotSupportedImportType),
};
let import_name = ImportName::new(import.module, import.name);
let Some(import_linker) = self.config.import_linker.as_ref() else {
return Err(CompilationError::UnresolvedImportFunction);
};
let import_linker_entity = import_linker
.resolve_by_import_name(&import_name)
.cloned()
.ok_or(CompilationError::UnresolvedImportFunction)?;
let func_type = self
.allocations
.translation
.func_type_registry
.resolve_original_func_type(func_type_index);
if !import_linker_entity.matches_func_type(func_type) {
return Err(CompilationError::MalformedImportFunctionType);
}
if !self.config.allow_func_ref_function_types {
for x in func_type.params().iter().chain(func_type.results()) {
if x == &ValType::FuncRef || x == &ValType::ExternRef {
return Err(CompilationError::MalformedImportFunctionType);
}
}
}
let func_idx = self.next_func();
self.allocations
.translation
.compiled_funcs
.push(func_type_index);
if let Some(intrinsic) = import_linker_entity.intrinsic {
self.allocations
.translation
.intrinsic_handler
.intrinsics
.push((func_idx, intrinsic));
}
let allocations = take(&mut self.allocations);
let mut translator = InstructionTranslator::new(
allocations.translation,
self.config.consume_fuel,
self.config.code_snippets,
self.config.consume_fuel_for_params_and_locals,
self.config.max_allowed_memory_pages,
);
translator.prepare(func_idx)?;
let signature_index = translator
.alloc
.func_type_registry
.resolve_func_type_signature(func_type_index);
translator.alloc.instruction_set.op_stack_check(u32::MAX);
if self.config.builtins_consume_fuel {
compile_block_params(
&mut translator.alloc.instruction_set,
import_linker_entity.syscall_fuel_param,
)
}
translator
.alloc
.instruction_set
.op_call(import_linker_entity.sys_func_idx);
translator.alloc.instruction_set.op_return();
translator.finish()?;
let _ = replace(
&mut self.allocations,
ReusableAllocations {
translation: take(&mut translator.alloc),
validation: allocations.validation,
},
);
}
Ok(())
}
fn process_instances(
&mut self,
section: wasmparser::InstanceSectionReader,
) -> Result<(), CompilationError> {
self.validator
.instance_section(§ion)
.map_err(Into::into)
}
fn process_functions(
&mut self,
section: FunctionSectionReader,
) -> Result<(), CompilationError> {
self.validator.function_section(§ion)?;
for func_type_index in section.into_iter() {
let func_type_index = func_type_index?;
self.allocations
.translation
.compiled_funcs
.push(func_type_index);
}
Ok(())
}
fn process_tables(&mut self, section: TableSectionReader) -> Result<(), CompilationError> {
self.validator.table_section(§ion)?;
for (table_idx, table_type) in section.into_iter().enumerate() {
let table_type = table_type?;
let table_idx = TableIdx::try_from(table_idx).unwrap();
self.allocations
.translation
.segment_builder
.emit_table_segment(table_idx, &table_type)?;
self.allocations.translation.tables.push(table_type);
}
Ok(())
}
fn process_memories(&mut self, section: MemorySectionReader) -> Result<(), CompilationError> {
self.validator.memory_section(§ion)?;
for memory_type in section.into_iter() {
let memory_type = memory_type?;
self.allocations.translation.memories.push(memory_type);
let initial_memory =
u32::try_from(memory_type.initial).expect("memory initial size too large");
self.allocations
.translation
.segment_builder
.add_memory_pages(initial_memory, self.config.max_allowed_memory_pages)?;
}
Ok(())
}
fn process_tags(
&mut self,
section: wasmparser::TagSectionReader,
) -> Result<(), CompilationError> {
self.validator.tag_section(§ion).map_err(Into::into)
}
fn process_globals(&mut self, section: GlobalSectionReader) -> Result<(), CompilationError> {
self.validator.global_section(§ion)?;
for global in section.into_iter() {
let global = global?;
let init_expr = CompiledExpr::new(global.init_expr);
let default_value = self.eval_const(init_expr)?;
let global_variable = GlobalVariable::new(global.ty, default_value);
let global_idx = GlobalIdx::from(self.allocations.translation.globals.len() as u32);
self.allocations
.translation
.segment_builder
.add_global_variable(global_idx, &global_variable)?;
self.allocations.translation.globals.push(global_variable);
}
Ok(())
}
fn process_exports(&mut self, section: ExportSectionReader) -> Result<(), CompilationError> {
self.validator.export_section(§ion)?;
for export in section.into_iter() {
let export = export?;
if export.kind == ExternalKind::Func {
let function_name: Box<str> = export.name.into();
self.allocations
.translation
.exported_funcs
.insert(function_name, FuncIdx::from(export.index));
}
}
Ok(())
}
fn process_start(&mut self, func: u32, range: Range<usize>) -> Result<(), CompilationError> {
self.validator.start_section(func, &range)?;
self.allocations.translation.start_func = Some(FuncIdx::from(func));
Ok(())
}
fn process_element(&mut self, section: ElementSectionReader) -> Result<(), CompilationError> {
self.validator.element_section(§ion)?;
for (element_segment_idx, element) in section.into_iter().enumerate() {
let element = element?;
let element_segment_idx = ElementSegmentIdx::from(element_segment_idx as u32);
let element_items_vec = match element.items {
ElementItems::Expressions(section) => section
.into_iter()
.map(|v| {
let compiled_expr = CompiledExpr::new(v?);
compiled_expr
.funcref()
.map(|v| v + 1)
.or_else(|| compiled_expr.eval_const().map(|v| v as i32 as u32))
.ok_or(CompilationError::ConstEvaluationFailed)
})
.collect::<Result<Vec<_>, _>>()?,
ElementItems::Functions(section) => section
.into_iter()
.map(|v| v.map(|v| v + 1).map_err(CompilationError::from))
.collect::<Result<Vec<_>, _>>()?,
};
match element.kind {
ElementKind::Active {
table_index,
offset_expr,
} => {
let compiled_expr = CompiledExpr::new(offset_expr);
let element_offset = u32::try_from(self.eval_const(compiled_expr)?)
.map_err(|_| CompilationError::TableOutOfBounds)?;
let table_idx = TableIdx::try_from(table_index).unwrap();
self.allocations
.translation
.segment_builder
.add_active_elements(
element_segment_idx,
element_offset,
table_idx,
element_items_vec,
);
}
ElementKind::Passive => self
.allocations
.translation
.segment_builder
.add_passive_elements(element_segment_idx, element_items_vec),
ElementKind::Declared => self
.allocations
.translation
.segment_builder
.add_passive_elements(element_segment_idx, []),
};
}
Ok(())
}
fn process_data_count(
&mut self,
count: u32,
range: Range<usize>,
) -> Result<(), CompilationError> {
self.validator
.data_count_section(count, &range)
.map_err(Into::into)
}
fn process_data(&mut self, section: DataSectionReader) -> Result<(), CompilationError> {
self.validator.data_section(§ion)?;
for (data_segment_idx, data) in section.into_iter().enumerate() {
let data = data?;
let data_segment_idx = DataSegmentIdx::from(data_segment_idx as u32);
match data.kind {
DataKind::Active {
memory_index,
offset_expr,
} => {
if memory_index != DEFAULT_MEMORY_INDEX {
return Err(CompilationError::NonDefaultMemoryIndex);
}
let compiled_expr = CompiledExpr::new(offset_expr);
let data_offset = u32::try_from(self.eval_const(compiled_expr)?)
.map_err(|_| CompilationError::MemoryOutOfBounds)?;
self.allocations
.translation
.segment_builder
.add_active_memory(data_segment_idx, data_offset, data.data);
}
DataKind::Passive => self
.allocations
.translation
.segment_builder
.add_passive_memory(data_segment_idx, data.data),
};
}
Ok(())
}
fn eval_const(&self, compiled_expr: CompiledExpr) -> Result<i64, CompilationError> {
compiled_expr
.eval_with_context(
|global_index| {
self.allocations
.translation
.globals
.get(global_index as usize)
.and_then(GlobalVariable::value)
},
|function_index| Some(FuncRef::new(function_index + 1)),
)
.ok_or(CompilationError::ConstEvaluationFailed)
}
fn process_custom_section(
&mut self,
reader: CustomSectionReader,
) -> Result<(), CompilationError> {
self.allocations
.translation
.constructor_params
.try_parse(reader);
Ok(())
}
fn process_code_start(
&mut self,
count: u32,
range: Range<usize>,
) -> Result<(), CompilationError> {
self.validator.code_section_start(count, &range)?;
Ok(())
}
fn next_func(&mut self) -> FuncIdx {
let compiled_func = self.compiled_funcs;
self.compiled_funcs += 1;
FuncIdx::from(compiled_func)
}
fn process_code_entry(&mut self, func_body: FunctionBody) -> Result<(), CompilationError> {
let func_idx = self.next_func();
let allocations = take(&mut self.allocations);
let validator = self.validator.code_section_entry(&func_body)?;
let func_validator = validator.into_validator(allocations.validation);
let allocations = FuncBuilder::new(
func_body,
func_validator,
func_idx,
allocations.translation,
self.config.consume_fuel,
self.config.code_snippets,
self.config.consume_fuel_for_params_and_locals,
self.config.max_allowed_memory_pages,
)
.translate()?;
let _ = replace(&mut self.allocations, allocations);
Ok(())
}
fn process_unknown(&mut self, id: u8, range: Range<usize>) -> Result<(), CompilationError> {
self.validator
.unknown_section(id, &range)
.map_err(Into::into)
}
fn process_unsupported_component_model(
&mut self,
range: Range<usize>,
) -> Result<(), CompilationError> {
panic!(
"rwasm does not support the `component-model` Wasm proposal: bytes[{}..{}]",
range.start, range.end
)
}
fn process_end(&mut self, offset: usize) -> Result<(), CompilationError> {
self.validator.end(offset)?;
Ok(())
}
}