//! A tiny and deliberately incomplete WebAssembly interpreter.
//!
//! The [`Interpreter`] defined here executes a locally defined function of a
//! `walrus` module and records the `u32` values that function passes to the
//! imported `__wbindgen_describe` function. Only a small set of instructions
//! is supported; encountering anything else aborts interpretation.
#![deny(missing_docs)]
use crate::wasm_conventions;
use anyhow::{bail, ensure};
use std::collections::{BTreeMap, HashMap, HashSet};
use walrus::ir::InstrSeqId;
use walrus::{ExportId, FunctionId, GlobalId, GlobalKind, LocalFunction, LocalId, Module};
/// A ready-to-go interpreter of a Wasm module.
///
/// An interpreter is created with [`Interpreter::new`] and then used through
/// [`Interpreter::interpret_descriptor`], which runs one function of the
/// module and returns the `u32` values it handed to `__wbindgen_describe`.
#[derive(Default)]
pub struct Interpreter {
// Id of the `__wbindgen_describe` function imported from
// `__wbindgen_placeholder__`, if the module imports it.
describe_id: Option<FunctionId>,
// Id of the `__wbindgen_describe_cast` function imported from
// `__wbindgen_placeholder__`, if the module imports it.
describe_cast_id: Option<FunctionId>,
// Model of linear memory as 32-bit words (indexed by byte address / 4),
// zero-initialised and sized from the initial page count of the module's
// first memory; empty when the module has no memory.
mem: Vec<i32>,
// Operand stack used while evaluating instructions; shared by all frames.
scratch: Vec<i32>,
// The stack pointer global as reported by `wasm_conventions`, if any.
stack_pointer: Option<GlobalId>,
// Value of the stack pointer global captured at the start of the current
// `interpret_descriptor` call; restored when `__wbindgen_describe_cast` is hit.
stack_pointer_initial: i32,
// Current values of the globals the interpreter tracks: locally defined
// globals with an `i32` constant initialiser, plus anything written by
// `global.set` during interpretation.
globals: HashMap<GlobalId, i32>,
// Values passed to `__wbindgen_describe` during the current interpretation.
descriptor: Vec<u32>,
// Id of the `__wbindgen_skip_interpret_calls` export, if the module has one.
skip_interpret: Option<ExportId>,
// Functions called directly from the entry block of the
// `__wbindgen_skip_interpret_calls` export; calls to these are not interpreted.
skip_calls: HashSet<FunctionId>,
// Set once `__wbindgen_describe_cast` has been reached, so enclosing
// instruction sequences stop evaluating as well.
stopped: bool,
}
/// Collects the ids of all functions that function `id` calls directly from
/// its entry block, via either `call` or `return_call`.
///
/// Only the instructions of the entry sequence itself are inspected; nested
/// instruction sequences are not traversed.
///
/// # Panics
///
/// Panics if `id` does not refer to a locally defined function.
fn skip_calls(module: &Module, id: FunctionId) -> HashSet<FunctionId> {
    use walrus::ir::{Call, Instr, ReturnCall};

    let walrus::FunctionKind::Local(local) = &module.funcs.get(id).kind else {
        panic!("can only call locally defined functions")
    };

    let mut callees = HashSet::new();
    for (instr, _) in &local.block(local.entry_block()).instrs {
        if let Instr::Call(Call { func }) | Instr::ReturnCall(ReturnCall { func }) = instr {
            callees.insert(*func);
        }
    }
    callees
}
impl Interpreter {
pub fn new(module: &Module) -> Result<Interpreter, anyhow::Error> {
let mut ret = Interpreter {
mem: module
.memories
.iter()
.next()
.map_or(vec![], |m| vec![0; m.initial as usize * 65536 / 4]),
..Default::default()
};
for global in module.globals.iter() {
if let GlobalKind::Local(walrus::ConstExpr::Value(walrus::ir::Value::I32(n))) =
global.kind
{
ret.globals.insert(global.id(), n);
}
}
if let Some(sp) = wasm_conventions::get_stack_pointer(module) {
ret.stack_pointer = Some(sp);
}
for import in module.imports.iter() {
let id = match import.kind {
walrus::ImportKind::Function(id) => id,
_ => continue,
};
if import.module != "__wbindgen_placeholder__" {
continue;
}
if import.name == "__wbindgen_describe" {
ret.describe_id = Some(id);
} else if import.name == "__wbindgen_describe_cast" {
ret.describe_cast_id = Some(id);
}
}
if let Some(export) = module
.exports
.iter()
.find(|export| export.name == "__wbindgen_skip_interpret_calls")
{
let id = match export.item {
walrus::ExportItem::Function(id) => id,
_ => panic!("__wbindgen_skip_interpret_calls must be an export function"),
};
ret.skip_interpret = Some(export.id());
ret.skip_calls = skip_calls(module, id);
}
Ok(ret)
}
pub fn interpret_descriptor(&mut self, id: FunctionId, module: &Module) -> &[u32] {
self.descriptor.truncate(0);
self.stopped = false;
if let Some(sp) = self.stack_pointer {
self.stack_pointer_initial = self.globals[&sp];
}
let func = module.funcs.get(id);
let ty = module.types.get(func.ty());
self.call(id, module, &vec![0; ty.params().len()]);
if let Some(sp) = self.stack_pointer {
assert_eq!(self.globals[&sp], self.stack_pointer_initial);
}
&self.descriptor
}
pub fn describe_cast_id(&self) -> Option<FunctionId> {
self.describe_cast_id
}
pub fn skip_interpret(&self) -> Option<ExportId> {
self.skip_interpret
}
fn call(&mut self, id: FunctionId, module: &Module, args: &[i32]) {
let func = module.funcs.get(id);
log::trace!("starting a call of {id:?} {:?}", func.name);
log::trace!("arguments {args:?}");
let local = match &func.kind {
walrus::FunctionKind::Local(l) => l,
_ => panic!("can only call locally defined functions"),
};
let mut frame = Frame {
module,
func: local,
interp: self,
locals: BTreeMap::new(),
};
assert_eq!(local.args.len(), args.len());
for (arg, val) in local.args.iter().zip(args) {
frame.locals.insert(*arg, *val);
}
frame.eval(local.entry_block()).unwrap_or_else(|err| {
if let Some(name) = &module.funcs.get(id).name {
panic!("{name}: {err}")
} else {
panic!("{err}")
}
})
}
}
// State for one in-progress function call: the function being executed, its
// locals, and access to the interpreter state shared by all calls.
struct Frame<'a> {
// Module the executing function belongs to; used to resolve callees and types.
module: &'a Module,
// The locally defined function whose instruction sequences are evaluated.
func: &'a LocalFunction,
// Interpreter state shared across frames (operand stack, memory, globals, ...).
interp: &'a mut Interpreter,
// Current values of this call's locals; a local that was never written reads as 0.
locals: BTreeMap<LocalId, i32>,
}
impl Frame<'_> {
fn eval(&mut self, seq: InstrSeqId) -> anyhow::Result<()> {
use walrus::ir::*;
for (instr, _) in self.func.block(seq).iter() {
let stack = &mut self.interp.scratch;
match instr {
Instr::Const(c) => match c.value {
Value::I32(n) => stack.push(n),
_ => bail!("non-i32 constant"),
},
Instr::LocalGet(e) => stack.push(self.locals.get(&e.local).cloned().unwrap_or(0)),
Instr::LocalSet(e) => {
let val = stack.pop().unwrap();
self.locals.insert(e.local, val);
}
Instr::LocalTee(e) => {
let val = *stack.last().unwrap();
self.locals.insert(e.local, val);
}
Instr::GlobalGet(e) => {
let val = *self.interp.globals.get(&e.global).unwrap_or_else(|| {
panic!(
"global {:?} not found, this is a bug in wasm-bindgen",
e.global
)
});
stack.push(val);
}
Instr::GlobalSet(e) => {
let val = stack.pop().unwrap();
self.interp.globals.insert(e.global, val);
}
Instr::Binop(e) => {
let rhs = stack.pop().unwrap();
let lhs = stack.pop().unwrap();
stack.push(match e.op {
BinaryOp::I32Sub => lhs - rhs,
BinaryOp::I32Add => lhs + rhs,
op => bail!("invalid binary op {op:?}"),
});
}
Instr::Load(e) => {
let address = stack.pop().unwrap();
let address = address as u32 + e.arg.offset as u32;
ensure!(
address > 0,
"Read a negative or zero address value from the stack. Did we run out of memory?"
);
let width = e.kind.width();
ensure!(address % width == 0);
let val = self.interp.mem[address as usize / 4];
if width == 4 {
stack.push(val)
} else if width == 1 {
let result = val.to_le_bytes()[(address % 4) as usize];
let LoadKind::I32_8 { kind } = e.kind else {
panic!("Unhandled load kind {:?}", e.kind)
};
match kind {
ExtendedLoad::SignExtend => {
stack.push(result as i8 as i32);
}
ExtendedLoad::ZeroExtend | ExtendedLoad::ZeroExtendAtomic => {
stack.push(result as i32);
}
};
} else {
panic!("Unhandled load width {width}");
}
}
Instr::Store(e) => {
let value = stack.pop().unwrap();
let address = stack.pop().unwrap();
let address = address as u32 + e.arg.offset as u32;
ensure!(
address > 0,
"Read a negative or zero address value from the stack. Did we run out of memory?"
);
let width = e.kind.width();
ensure!(address % width == 0);
let index = address as usize / 4;
if width == 8 {
self.interp.mem[index] = value;
self.interp.mem[index + 1] = 0;
} else if width == 4 {
self.interp.mem[index] = value;
} else if width == 1 {
let mut bytes = self.interp.mem[index].to_le_bytes();
bytes[(address % 4) as usize] = value as u8;
self.interp.mem[index] = i32::from_le_bytes(bytes);
} else {
panic!("Unhandled store width {width}");
}
}
Instr::Return(_) => {
log::trace!("return");
break;
}
Instr::Drop(_) => {
log::trace!("drop");
stack.pop().unwrap();
}
Instr::Call(Call { func }) | Instr::ReturnCall(ReturnCall { func }) => {
let func = *func;
if Some(func) == self.interp.describe_id {
let val = stack.pop().unwrap();
log::trace!("__wbindgen_describe({val})");
self.interp.descriptor.push(val as u32);
} else if Some(func) == self.interp.describe_cast_id {
log::trace!("__wbindgen_describe_cast()");
if let Some(sp) = self.interp.stack_pointer {
self.interp
.globals
.insert(sp, self.interp.stack_pointer_initial);
}
self.interp.stopped = true;
break;
} else {
if self.interp.skip_calls.contains(&func) {
continue;
}
if self
.module
.funcs
.get(func)
.name
.as_ref()
.is_some_and(|name| {
name.starts_with("__llvm_profile_init")
|| name.starts_with("__llvm_profile_register_function")
|| name.starts_with("__llvm_profile_register_function")
})
{
continue;
}
let ty = self.module.types.get(self.module.funcs.get(func).ty());
let mut args = (0..ty.params().len())
.map(|_| stack.pop().unwrap())
.collect::<Vec<_>>();
args.reverse();
self.interp.call(func, self.module, &args);
}
if let Instr::ReturnCall(_) = instr {
log::trace!("return_call");
break;
}
}
Instr::Block(block) => {
self.eval(block.seq)?;
if self.interp.stopped {
break;
}
}
Instr::Try(block) => {
self.eval(block.seq)?;
if self.interp.stopped {
break;
}
}
Instr::TryTable(block) => {
self.eval(block.seq)?;
if self.interp.stopped {
break;
}
}
s => bail!("unknown instruction {s:?}"),
}
}
Ok(())
}
}
// Tests for the interpreter; only compiled when building tests.
#[cfg(test)]
mod smoke_tests;