use crate::{
ids::CodeId,
message::{DispatchKind, WasmEntryPoint},
pages::{PageNumber, PageU32Size, WasmPage},
};
use alloc::{collections::BTreeSet, vec, vec::Vec};
use gear_wasm_instrument::{
parity_wasm::{
self,
builder::ModuleBuilder,
elements::{ExportEntry, GlobalEntry, GlobalType, InitExpr, Instruction, Internal, Module},
},
wasm_instrument::{
self,
gas_metering::{ConstantCostRules, Rules},
},
STACK_END_EXPORT_NAME,
};
use scale_info::{
scale::{Decode, Encode},
TypeInfo,
};
/// Largest number of static wasm pages a program may declare: code
/// construction fails with `CodeError::InvalidStaticPageCount` when the
/// initial size of the imported memory exceeds this value.
pub const MAX_WASM_PAGE_COUNT: u16 = 512;
/// Names of function exports that `get_exports` tolerates even though they
/// do not correspond to any `DispatchKind`.
pub const STATE_EXPORTS: [&str; 2] = ["state", "metahash"];
/// Collects the dispatch kinds exposed by the module's function exports.
///
/// Only exports that refer to functions are inspected. A function export whose
/// name maps to a [`DispatchKind`] is recorded; any other function export is
/// tolerated when its name is listed in [`STATE_EXPORTS`] or when
/// `reject_unnecessary` is `false`.
///
/// # Errors
///
/// * [`CodeError::ExportSectionNotFound`] - the module has no export section.
/// * [`CodeError::NonGearExportFnFound`] - `reject_unnecessary` is set and a
///   function export is neither a dispatch entry point nor a state export.
fn get_exports(
    module: &Module,
    reject_unnecessary: bool,
) -> Result<BTreeSet<DispatchKind>, CodeError> {
    let export_section = module
        .export_section()
        .ok_or(CodeError::ExportSectionNotFound)?;

    let mut kinds = BTreeSet::new();
    for export in export_section.entries() {
        // Non-function exports (globals, memories, tables) are not checked here.
        if !matches!(export.internal(), Internal::Function(_)) {
            continue;
        }

        let name = export.field();
        match DispatchKind::try_from_entry(name) {
            Some(kind) => {
                kinds.insert(kind);
            }
            None if reject_unnecessary && !STATE_EXPORTS.contains(&name) => {
                return Err(CodeError::NonGearExportFnFound);
            }
            None => {}
        }
    }

    Ok(kinds)
}
/// Looks up the export entry whose field name equals `name`.
///
/// Returns `None` when the module has no export section or no entry with
/// that name; the first matching entry wins.
fn get_export_entry<'a>(module: &'a Module, name: &str) -> Option<&'a ExportEntry> {
    module.export_section().and_then(|section| {
        section
            .entries()
            .iter()
            .find(|entry| entry.field() == name)
    })
}
/// Mutable counterpart of the export lookup: returns the first export entry
/// whose field name equals `name`.
///
/// Returns `None` when the module has no export section or no entry with
/// that name.
fn get_export_entry_mut<'a>(module: &'a mut Module, name: &str) -> Option<&'a mut ExportEntry> {
    for entry in module.export_section_mut()?.entries_mut() {
        if entry.field() == name {
            return Some(entry);
        }
    }
    None
}
/// Returns the global index referenced by the export called `name`.
///
/// Yields `None` when there is no such export or when the export refers to
/// something other than a global.
fn get_export_global_index<'a>(module: &'a Module, name: &str) -> Option<&'a u32> {
    let entry = get_export_entry(module, name)?;
    if let Internal::Global(global_index) = entry.internal() {
        Some(global_index)
    } else {
        None
    }
}
/// Returns a mutable reference to the global index stored in the export
/// called `name`, so the export can be re-pointed at another global.
///
/// Yields `None` when there is no such export or when the export refers to
/// something other than a global.
fn get_export_global_index_mut<'a>(module: &'a mut Module, name: &str) -> Option<&'a mut u32> {
    let entry = get_export_entry_mut(module, name)?;
    if let Internal::Global(global_index) = entry.internal_mut() {
        Some(global_index)
    } else {
        None
    }
}
/// Extracts the constant from an init expression of the exact form
/// `i32.const <value>; end`.
///
/// Any other instruction sequence (different length or different
/// instructions) yields `None`.
fn get_init_expr_const_i32(init_expr: &InitExpr) -> Option<i32> {
    match init_expr.code() {
        [Instruction::I32Const(value), Instruction::End] => Some(*value),
        _ => None,
    }
}
/// Returns the entry at position `global_index` of the module's global
/// section, or `None` when the section is absent or the position is out of
/// range.
fn get_global_entry(module: &Module, global_index: u32) -> Option<&GlobalEntry> {
    let section = module.global_section()?;
    let position = global_index as usize;
    section.entries().get(position)
}
/// Reads the constant `i32` initializer of the global at `global_index`.
///
/// # Errors
///
/// * [`CodeError::IncorrectGlobalIndex`] - no global entry exists at that index.
/// * [`CodeError::StackEndInitialization`] - the initializer is not a plain
///   `i32.const` expression.
fn get_global_init_const_i32(module: &Module, global_index: u32) -> Result<i32, CodeError> {
    let Some(entry) = get_global_entry(module, global_index) else {
        return Err(CodeError::IncorrectGlobalIndex);
    };

    match get_init_expr_const_i32(entry.init_expr()) {
        Some(value) => Ok(value),
        None => Err(CodeError::StackEndInitialization),
    }
}
/// Validates the exported stack-end global and makes the export point at an
/// immutable global.
///
/// Does nothing when the module has no global export named
/// [`STACK_END_EXPORT_NAME`]. Otherwise:
///
/// 1. the exported global must be initialized by a plain `i32.const`;
/// 2. every data segment must have a plain `i32.const` offset that is not
///    below the stack end;
/// 3. if the exported global is mutable, a new immutable `i32` global with the
///    same initial value is appended to the global section and the export is
///    redirected to it.
///
/// Offsets are compared as unsigned 32-bit addresses: an `i32.const` used as a
/// memory address denotes a `u32`, so a signed comparison would treat any
/// address with the high bit set as negative and let, for example, a stack end
/// of `-1` (`0xFFFF_FFFF`) pass the overlap check against every data segment.
///
/// # Errors
///
/// * [`CodeError::IncorrectGlobalIndex`] - the export refers to a global that
///   is not present in the global section.
/// * [`CodeError::StackEndInitialization`] - unsupported stack-end initializer.
/// * [`CodeError::DataSegmentInitialization`] - unsupported data segment offset.
/// * [`CodeError::StackEndOverlaps`] - a data segment starts below the stack end.
fn check_and_canonize_gear_stack_end(module: &mut Module) -> Result<(), CodeError> {
    let Some(&stack_end_global_index) = get_export_global_index(module, STACK_END_EXPORT_NAME)
    else {
        return Ok(());
    };

    let stack_end_offset = get_global_init_const_i32(module, stack_end_global_index)?;
    // Bit-for-bit reinterpretation is intended: the constant is a memory address.
    let stack_end_address = stack_end_offset as u32;

    if let Some(data_section) = module.data_section() {
        for data_segment in data_section.entries() {
            let offset = data_segment
                .offset()
                .as_ref()
                .and_then(get_init_expr_const_i32)
                .ok_or(CodeError::DataSegmentInitialization)?;

            // Same reinterpretation as above, so both sides are unsigned addresses.
            if (offset as u32) < stack_end_address {
                return Err(CodeError::StackEndOverlaps);
            }
        }
    }

    let is_mutable = get_global_entry(module, stack_end_global_index)
        .ok_or(CodeError::IncorrectGlobalIndex)?
        .global_type()
        .is_mutable();
    if !is_mutable {
        // Already canonical: the export refers to an immutable global.
        return Ok(());
    }

    // The global entry was found above, so the global section must exist.
    let global_section = module
        .global_section_mut()
        .unwrap_or_else(|| unreachable!("Cannot find global section"));
    let new_global_index = u32::try_from(global_section.entries().len())
        .map_err(|_| CodeError::IncorrectGlobalIndex)?;
    global_section.entries_mut().push(GlobalEntry::new(
        GlobalType::new(parity_wasm::elements::ValueType::I32, false),
        InitExpr::new(vec![
            Instruction::I32Const(stack_end_offset),
            Instruction::End,
        ]),
    ));

    // The export was found at the top of the function, so it must still exist.
    let exported_index = get_export_global_index_mut(module, STACK_END_EXPORT_NAME)
        .unwrap_or_else(|| unreachable!("Cannot find stack end export"));
    *exported_index = new_global_index;

    Ok(())
}
/// Errors that can be produced while checking and instrumenting program code.
#[derive(Debug, PartialEq, Eq, derive_more::Display)]
pub enum CodeError {
/// The module has no import section, so no imported memory can be found.
#[display(fmt = "Import section not found")]
ImportSectionNotFound,
/// The import section contains no memory import.
#[display(fmt = "Memory entry not found")]
MemoryEntryNotFound,
/// The module has no export section.
#[display(fmt = "Export section not found")]
ExportSectionNotFound,
/// Export checking is enabled and neither `init` nor `handle` is exported.
#[display(fmt = "Required export function `init` or `handle` not found")]
RequiredExportFnNotFound,
/// Export checking is enabled and a function export is neither a dispatch
/// entry point nor one of `STATE_EXPORTS`.
#[display(fmt = "Unnecessary function exports found")]
NonGearExportFnFound,
/// `wasmparser` rejected the original bytecode.
#[display(fmt = "Wasm validation failed")]
Validation,
/// `parity_wasm` failed to deserialize the original bytecode.
#[display(fmt = "The wasm bytecode is failed to be decoded")]
Decode,
/// Gas metering instrumentation failed.
#[display(fmt = "Failed to inject instructions for gas metrics: may be in case \
program contains unsupported instructions (floats, memory grow, etc.)")]
GasInjection,
/// Stack height limiter instrumentation failed.
#[display(fmt = "Failed to set stack height limits")]
StackLimitInjection,
/// The instrumented module could not be serialized.
#[display(fmt = "Failed to encode instrumented program")]
Encode,
/// The module contains a start section.
#[display(fmt = "Start section is not allowed for smart contracts")]
StartSectionExists,
/// The initial size of the imported memory is not a valid `WasmPage` or
/// exceeds `MAX_WASM_PAGE_COUNT`.
#[display(fmt = "The wasm bytecode has invalid count of static pages")]
InvalidStaticPageCount,
/// The stack-end global is not initialized by a plain `i32.const`.
#[display(fmt = "Unsupported initialization of gear stack end global variable")]
StackEndInitialization,
/// A data segment offset is not a plain `i32.const`.
#[display(fmt = "Unsupported initialization of data segment")]
DataSegmentInitialization,
/// A data segment starts below the stack end.
#[display(fmt = "Pointer to the stack end overlaps data segment")]
StackEndOverlaps,
/// An exported global index has no matching entry in the global section.
#[display(fmt = "Global index in export is incorrect")]
IncorrectGlobalIndex,
/// A mutable global is exported.
#[display(fmt = "Program cannot have mutable globals in export section")]
MutGlobalExport,
}
/// Checked and instrumented program code together with the original bytes it
/// was built from.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Code {
// Serialized module after all checks and instrumentation passes.
code: Vec<u8>,
// Bytecode exactly as it was passed to the constructor.
original_code: Vec<u8>,
// Dispatch kinds found among the module's function exports.
exports: BTreeSet<DispatchKind>,
// Initial size of the imported memory, in wasm pages.
static_pages: WasmPage,
// Version value taken from the construction config.
instruction_weights_version: u32,
}
/// Ensures that no exported global is mutable.
///
/// The check passes trivially when the module has no export section, exports
/// no globals, or has no global section. Exported indices are looked up
/// directly in the global section entries.
///
/// # Errors
///
/// * [`CodeError::IncorrectGlobalIndex`] - an exported index has no entry in
///   the global section.
/// * [`CodeError::MutGlobalExport`] - an exported global is mutable.
fn check_mut_global_exports(module: &Module) -> Result<(), CodeError> {
    let Some(export_section) = module.export_section() else {
        return Ok(());
    };
    let Some(global_section) = module.global_section() else {
        return Ok(());
    };
    let globals = global_section.entries();

    for export in export_section.entries() {
        let Internal::Global(global_index) = export.internal() else {
            continue;
        };

        let global = globals
            .get(*global_index as usize)
            .ok_or(CodeError::IncorrectGlobalIndex)?;
        if global.global_type().is_mutable() {
            return Err(CodeError::MutGlobalExport);
        }
    }

    Ok(())
}
/// Rejects modules that declare a start section.
///
/// # Errors
///
/// Returns [`CodeError::StartSectionExists`] (after emitting a debug log
/// record) when the module has a start section.
fn check_start_section(module: &Module) -> Result<(), CodeError> {
    if module.start_section().is_none() {
        return Ok(());
    }

    log::debug!("Found start section in contract code, which is not allowed");
    Err(CodeError::StartSectionExists)
}
/// Adds an export named `__gear_stack_height` that refers to the last entry
/// of the module's global section.
///
/// NOTE(review): presumably the last entry is the counter appended by
/// `inject_stack_limiter`; the caller is expected to run that pass first.
///
/// # Panics
///
/// Panics when the module has no global section.
fn export_stack_height(module: Module) -> Module {
    let global_count = module
        .global_section()
        .expect("Global section must be create by `inject_stack_limiter` before")
        .entries()
        .len();
    let stack_height_index = global_count as u32 - 1;

    let export_builder = ModuleBuilder::new().with_module(module).export();
    export_builder
        .field("__gear_stack_height")
        .internal()
        .global(stack_height_index)
        .build()
        .build()
}
/// Options controlling which checks and instrumentation passes are applied
/// when a `Code` is constructed.
pub struct TryNewCodeConfig {
/// Value stored as the instruction weights version of the resulting code.
pub version: u32,
/// Stack height limit; the stack limiter is injected only when this is `Some`.
pub stack_height: Option<u32>,
/// Export the stack height global as `__gear_stack_height`
/// (only has an effect when `stack_height` is `Some`).
pub export_stack_height: bool,
/// Reject unknown function exports and require `init` or `handle`.
pub check_exports: bool,
/// Validate the exported stack-end global and make it immutable.
pub check_and_canonize_stack_end: bool,
/// Reject exported mutable globals.
pub check_mut_global_exports: bool,
/// Reject modules with a start section.
pub check_start_section: bool,
/// Run `wasmparser` validation on the original bytecode first.
pub make_validation: bool,
}
impl Default for TryNewCodeConfig {
/// Default configuration: version `1`, every check enabled, no stack
/// limiter and no stack height export.
fn default() -> Self {
Self {
version: 1,
stack_height: None,
export_stack_height: false,
check_exports: true,
check_and_canonize_stack_end: true,
check_mut_global_exports: true,
check_start_section: true,
make_validation: true,
}
}
}
impl TryNewCodeConfig {
    /// Builds the default configuration with export checking switched off.
    pub fn new_no_exports_check() -> Self {
        let defaults = Self::default();
        Self {
            check_exports: false,
            ..defaults
        }
    }
}
impl Code {
    /// Shared constructor behind every public `try_new*` function.
    ///
    /// The steps run in this order, each one only if enabled in `config`:
    ///
    /// 1. validate `original_code` with `wasmparser`;
    /// 2. deserialize it into a module;
    /// 3. check and canonize the stack-end global;
    /// 4. reject exported mutable globals;
    /// 5. reject a start section;
    /// 6. read the initial page count of the imported memory and bound it by
    ///    [`MAX_WASM_PAGE_COUNT`];
    /// 7. collect exports and, when export checking is on, require `init` or
    ///    `handle`;
    /// 8. inject the stack limiter (and optionally export its global);
    /// 9. inject gas metering when `get_gas_rules` is `Some`;
    /// 10. serialize the resulting module.
    ///
    /// # Errors
    ///
    /// Returns the [`CodeError`] of the first step that fails.
    ///
    /// # Panics
    ///
    /// Panics if the stack height export is requested and the stack limiter
    /// did not add exactly one global.
    fn try_new_internal<R, GetRulesFn>(
        original_code: Vec<u8>,
        get_gas_rules: Option<GetRulesFn>,
        config: TryNewCodeConfig,
    ) -> Result<Self, CodeError>
    where
        R: Rules,
        GetRulesFn: FnMut(&Module) -> R,
    {
        if config.make_validation {
            if let Err(err) = wasmparser::validate(&original_code) {
                log::trace!("Wasm validation failed: {err}");
                return Err(CodeError::Validation);
            }
        }

        let mut module = match parity_wasm::deserialize_buffer::<Module>(&original_code) {
            Ok(decoded) => decoded,
            Err(err) => {
                log::trace!("The wasm bytecode is failed to be decoded: {err}");
                return Err(CodeError::Decode);
            }
        };

        if config.check_and_canonize_stack_end {
            check_and_canonize_gear_stack_end(&mut module)?;
        }
        if config.check_mut_global_exports {
            check_mut_global_exports(&module)?;
        }
        if config.check_start_section {
            check_start_section(&module)?;
        }

        // The static page count is the initial size of the first imported memory.
        let import_section = module
            .import_section()
            .ok_or(CodeError::ImportSectionNotFound)?;
        let initial_pages = import_section
            .entries()
            .iter()
            .find_map(|import| {
                if let parity_wasm::elements::External::Memory(memory) = import.external() {
                    Some(memory.limits().initial())
                } else {
                    None
                }
            })
            .ok_or(CodeError::MemoryEntryNotFound)?;
        let static_pages =
            WasmPage::new(initial_pages).map_err(|_| CodeError::InvalidStaticPageCount)?;
        if static_pages.raw() > u32::from(MAX_WASM_PAGE_COUNT) {
            return Err(CodeError::InvalidStaticPageCount);
        }

        let exports = get_exports(&module, config.check_exports)?;
        let has_entry_point =
            exports.contains(&DispatchKind::Init) || exports.contains(&DispatchKind::Handle);
        if config.check_exports && !has_entry_point {
            return Err(CodeError::RequiredExportFnNotFound);
        }

        if let Some(stack_limit) = config.stack_height {
            // Remember the global count only when the export was requested.
            let globals_before = if config.export_stack_height {
                Some(module.globals_space())
            } else {
                None
            };

            module = match wasm_instrument::inject_stack_limiter(module, stack_limit) {
                Ok(limited) => limited,
                Err(err) => {
                    log::trace!("Failed to inject stack height limits: {err}");
                    return Err(CodeError::StackLimitInjection);
                }
            };

            if let Some(globals_before) = globals_before {
                let globals_after = module.globals_space();
                assert_eq!(globals_after, globals_before + 1);
                module = export_stack_height(module);
            }
        }

        if let Some(mut rules_for) = get_gas_rules {
            let gas_rules = rules_for(&module);
            module = match gear_wasm_instrument::inject(module, &gas_rules, "env") {
                Ok(metered) => metered,
                Err(_) => return Err(CodeError::GasInjection),
            };
        }

        match parity_wasm::elements::serialize(module) {
            Ok(code) => Ok(Self {
                code,
                original_code,
                exports,
                static_pages,
                instruction_weights_version: config.version,
            }),
            Err(err) => {
                log::trace!("Failed to encode instrumented program: {err}");
                Err(CodeError::Encode)
            }
        }
    }

    /// Creates a new instrumented code from `original_code`.
    ///
    /// All checks of the default [`TryNewCodeConfig`] are applied, gas
    /// metering is injected using the rules produced by `get_gas_rules`, and
    /// the stack limiter is injected when `stack_height` is `Some`.
    ///
    /// # Errors
    ///
    /// Returns a [`CodeError`] describing the first failed check or pass.
    pub fn try_new<R, GetRulesFn>(
        original_code: Vec<u8>,
        version: u32,
        get_gas_rules: GetRulesFn,
        stack_height: Option<u32>,
    ) -> Result<Self, CodeError>
    where
        R: Rules,
        GetRulesFn: FnMut(&Module) -> R,
    {
        let config = TryNewCodeConfig {
            version,
            stack_height,
            ..Default::default()
        };
        Self::try_new_internal(original_code, Some(get_gas_rules), config)
    }

    /// Creates a new code with an explicit `config`, using
    /// [`ConstantCostRules`] for gas metering when `const_rules` is `true`
    /// and skipping gas metering otherwise.
    ///
    /// # Errors
    ///
    /// Returns a [`CodeError`] describing the first failed check or pass.
    pub fn try_new_mock_const_or_no_rules(
        original_code: Vec<u8>,
        const_rules: bool,
        config: TryNewCodeConfig,
    ) -> Result<Self, CodeError> {
        let get_gas_rules = if const_rules {
            Some(|_module: &Module| ConstantCostRules::default())
        } else {
            None
        };
        Self::try_new_internal(original_code, get_gas_rules, config)
    }

    /// Creates a new code with an explicit `config` and custom gas rules.
    ///
    /// # Errors
    ///
    /// Returns a [`CodeError`] describing the first failed check or pass.
    pub fn try_new_mock_with_rules<R, GetRulesFn>(
        original_code: Vec<u8>,
        get_gas_rules: GetRulesFn,
        config: TryNewCodeConfig,
    ) -> Result<Self, CodeError>
    where
        R: Rules,
        GetRulesFn: FnMut(&Module) -> R,
    {
        let rules = Some(get_gas_rules);
        Self::try_new_internal(original_code, rules, config)
    }

    /// Returns the bytecode this code was built from.
    pub fn original_code(&self) -> &[u8] {
        self.original_code.as_slice()
    }

    /// Returns the instrumented bytecode.
    pub fn code(&self) -> &[u8] {
        self.code.as_slice()
    }

    /// Returns the dispatch kinds exported by the program.
    pub fn exports(&self) -> &BTreeSet<DispatchKind> {
        &self.exports
    }

    /// Returns the instruction weights version the code was built with.
    pub fn instruction_weights_version(&self) -> u32 {
        self.instruction_weights_version
    }

    /// Returns the initial memory size in wasm pages.
    pub fn static_pages(&self) -> WasmPage {
        self.static_pages
    }

    /// Splits the code into its instrumented part and the original bytecode.
    ///
    /// The instrumented part records the length of the original bytecode.
    pub fn into_parts(self) -> (InstrumentedCode, Vec<u8>) {
        let Self {
            code,
            original_code,
            exports,
            static_pages,
            instruction_weights_version,
        } = self;

        let original_code_len = original_code.len() as u32;
        let instrumented = InstrumentedCode {
            code,
            original_code_len,
            exports,
            static_pages,
            version: instruction_weights_version,
        };

        (instrumented, original_code)
    }
}
/// A `Code` paired with the `CodeId` of its original bytecode.
#[derive(Clone, Debug)]
pub struct CodeAndId {
code: Code,
code_id: CodeId,
}
impl CodeAndId {
    /// Pairs `code` with the id generated from its original bytecode.
    pub fn new(code: Code) -> Self {
        Self {
            code_id: CodeId::generate(code.original_code()),
            code,
        }
    }

    /// Pairs `code` with a caller-supplied `code_id` without recomputing it.
    ///
    /// In debug builds the id is asserted to match the one generated from the
    /// original bytecode; release builds perform no check.
    pub fn from_parts_unchecked(code: Code, code_id: CodeId) -> Self {
        debug_assert_eq!(code_id, CodeId::generate(code.original_code()));
        Self { code, code_id }
    }

    /// Returns the code id.
    pub fn code_id(&self) -> CodeId {
        self.code_id
    }

    /// Returns a reference to the code.
    pub fn code(&self) -> &Code {
        &self.code
    }

    /// Consumes the pair and returns the code and its id.
    pub fn into_parts(self) -> (Code, CodeId) {
        let Self { code, code_id } = self;
        (code, code_id)
    }
}
/// The instrumented part of a `Code`: everything except the original
/// bytecode, of which only the length is kept.
#[derive(Clone, Debug, Decode, Encode, TypeInfo)]
pub struct InstrumentedCode {
// Instrumented bytecode.
code: Vec<u8>,
// Length in bytes of the bytecode the instrumented code was built from.
original_code_len: u32,
// Dispatch kinds exported by the program.
exports: BTreeSet<DispatchKind>,
// Initial memory size in wasm pages.
static_pages: WasmPage,
// Instruction weights version the code was built with.
version: u32,
}
impl InstrumentedCode {
    /// Returns the instrumented bytecode.
    pub fn code(&self) -> &[u8] {
        self.code.as_slice()
    }

    /// Returns the length of the original (non-instrumented) bytecode.
    pub fn original_code_len(&self) -> u32 {
        self.original_code_len
    }

    /// Returns the instruction weights version the code was built with.
    pub fn instruction_weights_version(&self) -> u32 {
        self.version
    }

    /// Returns the dispatch kinds exported by the program.
    pub fn exports(&self) -> &BTreeSet<DispatchKind> {
        &self.exports
    }

    /// Returns the initial memory size in wasm pages.
    pub fn static_pages(&self) -> WasmPage {
        self.static_pages
    }

    /// Consumes the value and returns the instrumented bytecode.
    pub fn into_code(self) -> Vec<u8> {
        let Self { code, .. } = self;
        code
    }
}
/// An `InstrumentedCode` paired with its `CodeId`.
#[derive(Clone, Debug, Decode, Encode)]
pub struct InstrumentedCodeAndId {
code: InstrumentedCode,
code_id: CodeId,
}
impl InstrumentedCodeAndId {
    /// Returns a reference to the instrumented code.
    pub fn code(&self) -> &InstrumentedCode {
        &self.code
    }

    /// Returns the code id.
    pub fn code_id(&self) -> CodeId {
        self.code_id
    }

    /// Consumes the pair and returns the instrumented code and its id.
    pub fn into_parts(self) -> (InstrumentedCode, CodeId) {
        let Self { code, code_id } = self;
        (code, code_id)
    }
}
impl From<CodeAndId> for InstrumentedCodeAndId {
fn from(code_and_id: CodeAndId) -> Self {
let (code, code_id) = code_and_id.into_parts();
let (code, _) = code.into_parts();
Self { code, code_id }
}
}
#[cfg(test)]
mod tests {
use crate::code::{Code, CodeError};
use alloc::vec::Vec;
use gear_wasm_instrument::wasm_instrument::gas_metering::ConstantCostRules;
// Compiles WAT text into a wasm binary, validating it on the way;
// panics when the conversion fails.
fn wat2wasm(s: &str) -> Vec<u8> {
wabt::Wat2Wasm::new()
.validate(true)
.convert(s)
.unwrap()
.as_ref()
.to_vec()
}
// A function export that is neither a dispatch entry point nor a state
// export must be rejected by the default checks.
#[test]
fn reject_unknown_exports() {
// NOTE(review): the name says "import", but the entry below is an export.
const WAT: &str = r#"
(module
(import "env" "memory" (memory 1))
(export "this_import_is_unknown" (func $test))
(func $test)
)
"#;
let original_code = wat2wasm(WAT);
assert_eq!(
Code::try_new(original_code, 1, |_| ConstantCostRules::default(), None),
Err(CodeError::NonGearExportFnFound)
);
}
// Exporting only `handle_signal` is not enough: `init` or `handle` is required.
#[test]
fn required_fn_not_found() {
const WAT: &str = r#"
(module
(import "env" "memory" (memory 1))
(export "handle_signal" (func $handle_signal))
(func $handle_signal)
)
"#;
let original_code = wat2wasm(WAT);
assert_eq!(
Code::try_new(original_code, 1, |_| ConstantCostRules::default(), None),
Err(CodeError::RequiredExportFnNotFound)
);
}
// Construction with a stack height limit must succeed for a minimal program.
#[test]
fn stack_limit_injection_works() {
const WAT: &str = r#"
(module
(import "env" "memory" (memory 1))
(export "init" (func $init))
(func $init)
)
"#;
let original_code = wat2wasm(WAT);
let _ = Code::try_new(
original_code,
1,
|_| ConstantCostRules::default(),
Some(16 * 1024),
)
.unwrap();
}
}