use std::io::Cursor;
use anyhow::{anyhow, bail, Context, Result};
use walrus::{
ir::Value, ConstExpr, ConstOp, ElementItems, FunctionBuilder, FunctionId, FunctionKind,
GlobalId, GlobalKind, MemoryId, Module, RawCustomSection, ValType,
};
use wasmparser::BinaryReader;
/// Returns the id of the module's single linear memory.
///
/// # Errors
///
/// Fails when the module declares no memory at all, or more than one memory.
pub fn get_memory(module: &Module) -> Result<MemoryId> {
    let mut ids = module.memories.iter().map(|mem| mem.id());
    // Pulling two ids is enough to distinguish "none", "exactly one" and "many".
    match (ids.next(), ids.next()) {
        (Some(only), None) => Ok(only),
        (Some(_), Some(_)) => bail!(
            "expected a single memory, found multiple; multiple memories \
             currently not supported"
        ),
        (None, _) => Err(anyhow!(
            "module does not have a memory; must have a memory \
             to transform return pointers into Wasm multi-value"
        )),
    }
}
/// Locates the global that holds the shadow-stack pointer.
///
/// A global explicitly named `__stack_pointer` wins outright. Otherwise the
/// heuristic considers mutable `i32` globals whose local initializer is a
/// nonzero `i32` constant:
/// - no candidates yields `None`;
/// - one candidate is returned;
/// - two candidates logs a warning and returns the first;
/// - three or more yields `None`.
pub fn get_stack_pointer(module: &Module) -> Option<GlobalId> {
    let named = module
        .globals
        .iter()
        .find(|global| global.name.as_deref() == Some("__stack_pointer"));
    if let Some(global) = named {
        return Some(global.id());
    }

    let mut candidates: Vec<GlobalId> = Vec::new();
    for global in module.globals.iter() {
        let has_nonzero_i32_init = matches!(
            global.kind,
            GlobalKind::Local(ConstExpr::Value(Value::I32(n))) if n != 0
        );
        if global.ty == ValType::I32 && global.mutable && has_nonzero_i32_init {
            candidates.push(global.id());
        }
    }

    match candidates.as_slice() {
        [] => None,
        [only] => Some(*only),
        [first, _] => {
            log::warn!("Unable to accurately determine the location of `__stack_pointer`");
            Some(*first)
        }
        _ => None,
    }
}
/// Returns the global exported as `__tls_base`.
///
/// Returns `Some` only when exactly one export with that name refers to an
/// `i32`-typed global; zero or several such exports yield `None`.
pub fn get_tls_base(module: &Module) -> Option<GlobalId> {
    let mut found = module.exports.iter().filter_map(|export| match export.item {
        walrus::ExportItem::Global(id)
            if export.name == "__tls_base" && module.globals.get(id).ty == ValType::I32 =>
        {
            Some(id)
        }
        _ => None,
    });

    let first = found.next()?;
    // A second match makes the answer ambiguous, so report nothing.
    match found.next() {
        Some(_) => None,
        None => Some(first),
    }
}
fn evaluate_const_expr(expr: &ConstExpr, module: &Module) -> Option<Value> {
match expr {
ConstExpr::Value(v) => Some(*v),
ConstExpr::Global(g) => {
match &module.globals.get(*g).kind {
GlobalKind::Local(inner) => evaluate_const_expr(inner, module),
_ => None,
}
}
ConstExpr::Extended(ops) => {
let mut stack: Vec<Value> = Vec::new();
for op in ops {
match op {
ConstOp::I32Const(n) => stack.push(Value::I32(*n)),
ConstOp::I64Const(n) => stack.push(Value::I64(*n)),
ConstOp::F32Const(n) => stack.push(Value::F32(*n)),
ConstOp::F64Const(n) => stack.push(Value::F64(*n)),
ConstOp::GlobalGet(g) => {
let v = match &module.globals.get(*g).kind {
GlobalKind::Local(inner) => evaluate_const_expr(inner, module)?,
_ => return None,
};
stack.push(v);
}
ConstOp::I32Add => {
let (Value::I32(b), Value::I32(a)) = (stack.pop()?, stack.pop()?) else {
return None;
};
stack.push(Value::I32(a.wrapping_add(b)));
}
ConstOp::I32Sub => {
let (Value::I32(b), Value::I32(a)) = (stack.pop()?, stack.pop()?) else {
return None;
};
stack.push(Value::I32(a.wrapping_sub(b)));
}
ConstOp::I32Mul => {
let (Value::I32(b), Value::I32(a)) = (stack.pop()?, stack.pop()?) else {
return None;
};
stack.push(Value::I32(a.wrapping_mul(b)));
}
ConstOp::I64Add => {
let (Value::I64(b), Value::I64(a)) = (stack.pop()?, stack.pop()?) else {
return None;
};
stack.push(Value::I64(a.wrapping_add(b)));
}
ConstOp::I64Sub => {
let (Value::I64(b), Value::I64(a)) = (stack.pop()?, stack.pop()?) else {
return None;
};
stack.push(Value::I64(a.wrapping_sub(b)));
}
ConstOp::I64Mul => {
let (Value::I64(b), Value::I64(a)) = (stack.pop()?, stack.pop()?) else {
return None;
};
stack.push(Value::I64(a.wrapping_mul(b)));
}
_ => return None,
}
}
if stack.len() == 1 {
stack.pop()
} else {
None
}
}
_ => None,
}
}
/// Resolves the function stored at index `idx` of the module's main function table.
///
/// Scans the table's active element segments whose offset evaluates (via
/// `evaluate_const_expr`) to an `i32`. The first segment whose range covers
/// `idx` decides the result.
///
/// # Errors
///
/// Fails when:
/// - the module has no main function table;
/// - the covering segment holds a non-`ref.func` expression at that slot;
/// - no segment covers `idx`.
pub fn get_function_table_entry(module: &Module, idx: u32) -> Result<FunctionId> {
    let table_id = module
        .tables
        .main_function_table()?
        .ok_or_else(|| anyhow!("no function table found in module"))?;

    for segment_id in module.tables.get(table_id).elem_segments.iter() {
        let segment = module.elements.get(*segment_id);

        let walrus::ElementKind::Active { offset, .. } = &segment.kind else {
            continue;
        };
        let Some(Value::I32(base)) = evaluate_const_expr(offset, module) else {
            continue;
        };
        // The offset is an i32 on the wire but denotes an unsigned table index,
        // so it is reinterpreted as u32. Indices below the segment start are
        // skipped rather than underflowing.
        let Some(position) = idx.checked_sub(base as u32) else {
            continue;
        };
        let position = position as usize;

        // Outer `None`: idx lies past this segment's end.
        // Inner `None`: the slot exists but is not a direct function reference.
        let entry: Option<Option<&FunctionId>> = match &segment.items {
            ElementItems::Functions(funcs) => funcs.get(position).map(Some),
            ElementItems::Expressions(_, exprs) => exprs.get(position).map(|item| match item {
                ConstExpr::RefFunc(func) => Some(func),
                _ => None,
            }),
        };

        if let Some(entry) = entry {
            return entry.copied().context("function table entry wasn't filled");
        }
    }

    bail!("failed to find `{idx}` in function table");
}
/// Classifies the module's start function.
///
/// Returns:
/// - `Ok(id)` when the start function is defined locally;
/// - `Err(Some(id))` when it is an imported function;
/// - `Err(None)` when the module has no start function.
///
/// # Panics
///
/// Panics (via `unimplemented!`) if the start function is still uninitialized.
pub fn get_start(module: &mut Module) -> Result<FunctionId, Option<FunctionId>> {
    let Some(start) = module.start else {
        return Err(None);
    };
    match &module.funcs.get(start).kind {
        FunctionKind::Local(_) => Ok(start),
        FunctionKind::Import(_) => Err(Some(start)),
        FunctionKind::Uninitialized(_) => unimplemented!(),
    }
}
/// Returns a mutable builder for the module's start function.
///
/// If there is no local start function, this creates a new parameterless
/// local function and installs it as the module's start function. When the
/// previous start function was an import, the new function's body begins with
/// a call to that import.
pub fn get_or_insert_start_builder(module: &mut Module) -> &mut FunctionBuilder {
    let start = get_start(module).unwrap_or_else(|previous_import| {
        let mut wrapper = FunctionBuilder::new(&mut module.types, &[], &[]);
        if let Some(import) = previous_import {
            wrapper.func_body().call(import);
        }
        let wrapper_id = wrapper.finish(Vec::new(), &mut module.funcs);
        module.start = Some(wrapper_id);
        wrapper_id
    });

    let local = module.funcs.get_mut(start).kind.unwrap_local_mut();
    local.builder_mut()
}
/// Reports whether `feature` is enabled in the module's `target_features`
/// custom section.
///
/// The section is read as:
/// - a LEB128 entry count;
/// - that many entries, each a prefix byte, a LEB128 name length, and the
///   name bytes.
///
/// The first entry whose name matches decides the answer: a `-` prefix means
/// disabled, any other prefix means enabled. A missing section or missing
/// entry yields `false`.
///
/// # Errors
///
/// Fails when:
/// - `feature` is longer than 100 000 bytes;
/// - the section is not a raw custom section;
/// - its bytes are truncated or malformed.
pub fn target_feature(module: &Module, feature: &str) -> Result<bool> {
    anyhow::ensure!(feature.len() <= 100_000, "feature name too long");

    let Some((_, section)) = module
        .customs
        .iter()
        .find(|(_, custom)| custom.name() == "target_features")
    else {
        return Ok(false);
    };
    let section: &RawCustomSection = section
        .as_any()
        .downcast_ref()
        .context("failed to read section")?;

    let mut reader = BinaryReader::new(&section.data, 0);
    let count = reader.read_var_u32()?;
    for _ in 0..count {
        let prefix = reader.read_u8()?;
        let length = reader.read_var_u32()?;
        let name = reader.read_bytes(length as usize)?;
        if name == feature.as_bytes() {
            return Ok(prefix != b'-');
        }
    }
    Ok(false)
}
/// Marks `new_feature` as enabled in the module's `target_features` section.
///
/// The section layout is:
/// - a LEB128 entry count;
/// - entries of a prefix byte, a LEB128 name length, and the name bytes.
///
/// Behavior:
/// - If an entry with this name already exists, its `-` prefix (if any) is
///   flipped to `+` in place and nothing else changes.
/// - Otherwise the count is incremented and a `+` entry is appended.
/// - If the section does not exist yet, an empty one (count 0) is created
///   first.
///
/// # Errors
///
/// Fails when:
/// - `new_feature` is longer than 100 000 bytes;
/// - an existing section is not a raw custom section;
/// - an existing section's bytes are truncated or malformed.
pub fn insert_target_feature(module: &mut Module, new_feature: &str) -> Result<()> {
    anyhow::ensure!(new_feature.len() <= 100_000, "feature name too long");

    let section = module
        .customs
        .iter_mut()
        .find(|(_, custom)| custom.name() == "target_features");

    let section = if let Some((_, section)) = section {
        let section: &mut RawCustomSection = section
            .as_any_mut()
            .downcast_mut()
            .context("failed to read section")?;
        let mut reader = BinaryReader::new(&section.data, 0);
        let count = reader.read_var_u32()?;
        for _ in 0..count {
            // Remember where the prefix byte lives so it can be patched in place.
            let prefix_index = reader.current_position();
            let prefix = reader.read_u8()?;
            let length = reader.read_var_u32()?;
            let feature = reader.read_bytes(length as usize)?;
            if feature == new_feature.as_bytes() {
                if prefix == b'-' {
                    section.data[prefix_index] = b'+';
                }
                return Ok(());
            }
        }
        section
    } else {
        // Start from an empty section: just a zero entry count.
        let mut data = Vec::new();
        leb128::write::unsigned(&mut data, 0).unwrap();
        let id = module.customs.add(RawCustomSection {
            name: String::from("target_features"),
            data,
        });
        module.customs.get_mut(id).unwrap()
    };

    // Rewrite the leading count as count + 1. The re-encoded LEB128 may differ
    // in length, hence `splice` over exactly the bytes the old count occupied.
    let mut data = Cursor::new(&section.data);
    let count = leb128::read::unsigned(&mut data).unwrap();
    let mut new_count = Vec::new();
    leb128::write::unsigned(&mut new_count, count + 1).unwrap();
    section.data.splice(0..data.position() as usize, new_count);

    // Append the new entry: enabled prefix, name length, name bytes.
    section.data.push(b'+');
    leb128::write::unsigned(&mut section.data, new_feature.len() as u64).unwrap();
    section.data.extend(new_feature.as_bytes());
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use walrus::{
        ConstOp, ElementItems, ElementKind, FunctionBuilder, Module, ModuleConfig, RefType,
    };

    /// Builds a module with one exported function placed in a single active
    /// element segment at `offset` of a fresh funcref table.
    fn make_module_with_segment(offset: ConstExpr) -> (Module, FunctionId) {
        let mut config = ModuleConfig::new();
        config.generate_producers_section(false);
        let mut module = Module::with_config(config);
        let func_id =
            FunctionBuilder::new(&mut module.types, &[], &[]).finish(vec![], &mut module.funcs);
        module.exports.add("f", func_id);
        let table_id = module.tables.add_local(false, 64, None, RefType::FUNCREF);
        let elem_id = module.elements.add(
            ElementKind::Active {
                table: table_id,
                offset,
            },
            ElementItems::Functions(vec![func_id]),
        );
        module
            .tables
            .get_mut(table_id)
            .elem_segments
            .insert(elem_id);
        (module, func_id)
    }

    /// Returns a copy of the raw bytes of the module's `target_features` section.
    fn target_features_bytes(module: &Module) -> Vec<u8> {
        let (_, section) = module
            .customs
            .iter()
            .find(|(_, custom)| custom.name() == "target_features")
            .expect("`target_features` section should exist");
        let section: &RawCustomSection = section
            .as_any()
            .downcast_ref()
            .expect("`target_features` should be a raw custom section");
        section.data.clone()
    }

    #[test]
    fn evaluate_immediate_i32() {
        let config = ModuleConfig::new();
        let module = Module::with_config(config);
        let expr = ConstExpr::Value(Value::I32(42));
        assert!(matches!(
            evaluate_const_expr(&expr, &module),
            Some(Value::I32(42))
        ));
    }

    #[test]
    fn evaluate_global_offset() {
        let mut config = ModuleConfig::new();
        config.generate_producers_section(false);
        let mut module = Module::with_config(config);
        let g =
            module
                .globals
                .add_local(ValType::I32, false, false, ConstExpr::Value(Value::I32(5)));
        let expr = ConstExpr::Global(g);
        assert!(matches!(
            evaluate_const_expr(&expr, &module),
            Some(Value::I32(5))
        ));
    }

    #[test]
    fn evaluate_extended_global_plus_const() {
        let mut config = ModuleConfig::new();
        config.generate_producers_section(false);
        let mut module = Module::with_config(config);
        let g =
            module
                .globals
                .add_local(ValType::I32, false, false, ConstExpr::Value(Value::I32(1)));
        let expr = ConstExpr::Extended(vec![
            ConstOp::GlobalGet(g),
            ConstOp::I32Const(7),
            ConstOp::I32Add,
        ]);
        assert!(matches!(
            evaluate_const_expr(&expr, &module),
            Some(Value::I32(8))
        ));
    }

    #[test]
    fn evaluate_extended_returns_none_for_unknown_op() {
        let config = ModuleConfig::new();
        let module = Module::with_config(config);
        let expr = ConstExpr::Extended(vec![ConstOp::RefNull(walrus::RefType::FUNCREF)]);
        assert!(evaluate_const_expr(&expr, &module).is_none());
    }

    #[test]
    fn lookup_with_immediate_i32_offset() {
        let (module, func_id) = make_module_with_segment(ConstExpr::Value(Value::I32(1)));
        let result = get_function_table_entry(&module, 1);
        assert_eq!(result.unwrap(), func_id);
    }

    #[test]
    fn lookup_with_global_offset() {
        let mut config = ModuleConfig::new();
        config.generate_producers_section(false);
        let mut module = Module::with_config(config);
        let g =
            module
                .globals
                .add_local(ValType::I32, false, false, ConstExpr::Value(Value::I32(1)));
        module.exports.add("__table_base", g);
        let func_id =
            FunctionBuilder::new(&mut module.types, &[], &[]).finish(vec![], &mut module.funcs);
        module.exports.add("f", func_id);
        let table_id = module.tables.add_local(false, 4, None, RefType::FUNCREF);
        let elem_id = module.elements.add(
            ElementKind::Active {
                table: table_id,
                offset: ConstExpr::Global(g),
            },
            ElementItems::Functions(vec![func_id]),
        );
        module
            .tables
            .get_mut(table_id)
            .elem_segments
            .insert(elem_id);
        let result = get_function_table_entry(&module, 1);
        assert_eq!(result.unwrap(), func_id);
    }

    #[test]
    fn lookup_with_extended_offset() {
        let mut config = ModuleConfig::new();
        config.generate_producers_section(false);
        let mut module = Module::with_config(config);
        let g =
            module
                .globals
                .add_local(ValType::I32, false, false, ConstExpr::Value(Value::I32(1)));
        module.exports.add("__table_base", g);
        let func_id =
            FunctionBuilder::new(&mut module.types, &[], &[]).finish(vec![], &mut module.funcs);
        module.exports.add("f", func_id);
        let table_id = module.tables.add_local(false, 16, None, RefType::FUNCREF);
        let elem_id = module.elements.add(
            ElementKind::Active {
                table: table_id,
                offset: ConstExpr::Extended(vec![
                    ConstOp::GlobalGet(g),
                    ConstOp::I32Const(4),
                    ConstOp::I32Add,
                ]),
            },
            ElementItems::Functions(vec![func_id]),
        );
        module
            .tables
            .get_mut(table_id)
            .elem_segments
            .insert(elem_id);
        let result = get_function_table_entry(&module, 5);
        assert_eq!(result.unwrap(), func_id);
    }

    #[test]
    fn lookup_fails_gracefully_when_index_not_in_any_segment() {
        let (module, _) = make_module_with_segment(ConstExpr::Value(Value::I32(1)));
        assert!(get_function_table_entry(&module, 99).is_err());
    }

    #[test]
    fn lookup_multi_segment_no_underflow() {
        let mut config = ModuleConfig::new();
        config.generate_producers_section(false);
        let mut module = Module::with_config(config);
        let func_a =
            FunctionBuilder::new(&mut module.types, &[], &[]).finish(vec![], &mut module.funcs);
        module.exports.add("func_a", func_a);
        let func_b =
            FunctionBuilder::new(&mut module.types, &[], &[]).finish(vec![], &mut module.funcs);
        module.exports.add("func_b", func_b);
        let table_id = module.tables.add_local(false, 256, None, RefType::FUNCREF);
        let seg_a = module.elements.add(
            ElementKind::Active {
                table: table_id,
                offset: ConstExpr::Value(Value::I32(0)),
            },
            ElementItems::Functions(vec![func_a]),
        );
        module.tables.get_mut(table_id).elem_segments.insert(seg_a);
        let seg_b = module.elements.add(
            ElementKind::Active {
                table: table_id,
                offset: ConstExpr::Value(Value::I32(128)),
            },
            ElementItems::Functions(vec![func_b]),
        );
        module.tables.get_mut(table_id).elem_segments.insert(seg_b);
        assert_eq!(get_function_table_entry(&module, 0).unwrap(), func_a);
        assert_eq!(get_function_table_entry(&module, 128).unwrap(), func_b);
    }

    #[test]
    fn target_feature_is_false_without_section() {
        let module = Module::with_config(ModuleConfig::new());
        assert!(!target_feature(&module, "atomics").unwrap());
    }

    #[test]
    fn insert_target_feature_creates_section_and_is_idempotent() {
        let mut module = Module::with_config(ModuleConfig::new());
        insert_target_feature(&mut module, "atomics").unwrap();
        assert!(target_feature(&module, "atomics").unwrap());
        assert!(!target_feature(&module, "bulk-memory").unwrap());

        let mut expected = vec![1, b'+', 7];
        expected.extend_from_slice(b"atomics");
        assert_eq!(target_features_bytes(&module), expected);

        // Inserting the same feature again must not append a duplicate entry.
        insert_target_feature(&mut module, "atomics").unwrap();
        assert_eq!(target_features_bytes(&module), expected);
    }

    #[test]
    fn insert_target_feature_enables_disabled_entry() {
        let mut module = Module::with_config(ModuleConfig::new());
        let mut data = vec![1, b'-', 7];
        data.extend_from_slice(b"atomics");
        module.customs.add(RawCustomSection {
            name: String::from("target_features"),
            data,
        });
        assert!(!target_feature(&module, "atomics").unwrap());

        insert_target_feature(&mut module, "atomics").unwrap();
        assert!(target_feature(&module, "atomics").unwrap());

        let mut expected = vec![1, b'+', 7];
        expected.extend_from_slice(b"atomics");
        assert_eq!(target_features_bytes(&module), expected);
    }

    #[test]
    fn insert_target_feature_appends_to_existing_section() {
        let mut module = Module::with_config(ModuleConfig::new());
        insert_target_feature(&mut module, "atomics").unwrap();
        insert_target_feature(&mut module, "bulk-memory").unwrap();
        assert!(target_feature(&module, "atomics").unwrap());
        assert!(target_feature(&module, "bulk-memory").unwrap());

        let mut expected = vec![2, b'+', 7];
        expected.extend_from_slice(b"atomics");
        expected.extend_from_slice(&[b'+', 11]);
        expected.extend_from_slice(b"bulk-memory");
        assert_eq!(target_features_bytes(&module), expected);
    }
}