use std::collections::HashMap;
use walrus::{FunctionId, FunctionKind, LocalFunction, Module};
use walrus::ir::{self, dfs_in_order, Instr, Visitor};
use crate::host_functions::HostFunction;
use super::{AnalyzedModule, HostCallSite};
/// Summary of a single function body, as produced by
/// `AnalyzedModule::analyze_function_body`.
///
/// The `Default` value (all zeros / empty / `false`) is what non-local
/// functions report.
#[derive(Default)]
pub struct FunctionBodyStats {
    /// Number of instructions visited while walking the body.
    pub instruction_count: usize,
    /// Number of direct `call` instructions whose callee is a local function.
    pub local_call_count: usize,
    /// One entry per direct `call` to an imported function, in traversal order.
    pub host_calls: Vec<HostCallSite>,
    /// Whether the body analyzer saw conditional control flow
    /// (e.g. `if`/`else`, `br_if`).
    pub has_branches: bool,
    /// Whether the body contains a `loop` construct.
    pub has_loops: bool,
}
/// Visitor state used by `AnalyzedModule::analyze_function_body` to gather a
/// `FunctionBodyStats` in a single walk over a function body.
struct BodyAnalyzer<'a> {
    // Read-only context consulted while visiting calls.
    module: &'a Module,
    host_func_map: &'a HashMap<FunctionId, &'static HostFunction>,
    // Accumulators, moved into `FunctionBodyStats` once the walk finishes.
    instruction_count: usize,
    local_call_count: usize,
    host_calls: Vec<HostCallSite>,
    has_branches: bool,
    has_loops: bool,
}
impl<'instr> Visitor<'instr> for BodyAnalyzer<'_> {
    /// Counts every instruction reached by the traversal.
    fn visit_instr(&mut self, _instr: &'instr Instr, _loc: &'instr ir::InstrLocId) {
        self.instruction_count += 1;
    }

    /// Classifies a direct call.
    ///
    /// Calls to imported functions are recorded as `HostCallSite`s. Calls to
    /// local functions are only counted. Calls to uninitialized functions are
    /// ignored.
    fn visit_call(&mut self, call: &ir::Call) {
        match &self.module.funcs.get(call.func).kind {
            FunctionKind::Import(imp) => {
                let import = self.module.imports.get(imp.import);
                // Prefer the semantic names from the host-function table; fall
                // back to the raw import names for imports it does not know.
                let (semantic_module, semantic_name) = match self.host_func_map.get(&call.func) {
                    Some(hf) => (hf.module.to_string(), hf.name.to_string()),
                    None => (import.module.clone(), import.name.clone()),
                };
                self.host_calls.push(HostCallSite {
                    semantic_module,
                    semantic_name,
                    raw_module: import.module.clone(),
                    raw_field: import.name.clone(),
                });
            }
            FunctionKind::Local(_) => self.local_call_count += 1,
            FunctionKind::Uninitialized(_) => {}
        }
    }

    fn visit_if_else(&mut self, _: &ir::IfElse) {
        self.has_branches = true;
    }

    fn visit_br_if(&mut self, _: &ir::BrIf) {
        self.has_branches = true;
    }

    /// `br_table` is a multi-way conditional branch (a switch).
    ///
    /// It must count as branching just like `if`/`br_if`. Otherwise a small
    /// dispatch function would look like a straight-line wrapper.
    fn visit_br_table(&mut self, _: &ir::BrTable) {
        self.has_branches = true;
    }

    fn visit_loop(&mut self, _: &ir::Loop) {
        self.has_loops = true;
    }
}
impl AnalyzedModule {
    /// Follows a chain of thin wrapper functions to the function that does the real work.
    ///
    /// Starting at `func_id`, this steps into the single local function that the
    /// current function calls. It keeps stepping as long as the current function:
    /// - is a local (non-imported) function,
    /// - has at most `MAX_WRAPPER_INSTRUCTIONS` instructions,
    /// - contains no branches and no loops,
    /// - calls no imported function registered in `host_func_map`, and
    /// - directly calls exactly one distinct local function.
    ///
    /// At most `MAX_TRACE_DEPTH` steps are taken. This also bounds the walk when
    /// wrappers call each other in a cycle. Returns the last function reached,
    /// which is `func_id` itself if it is not a wrapper.
    pub fn trace_to_impl(&self, func_id: FunctionId) -> FunctionId {
        const MAX_TRACE_DEPTH: usize = 5;
        const MAX_WRAPPER_INSTRUCTIONS: usize = 10;

        let mut current = func_id;
        for _ in 0..MAX_TRACE_DEPTH {
            let local_func = match &self.module.funcs.get(current).kind {
                FunctionKind::Local(lf) => lf,
                _ => break,
            };

            // Each check below can end the trace on its own, so their order does
            // not affect the result. Running the statistics walk first means
            // functions that fail it (most real implementations) skip the other
            // two body walks.
            let stats = self.analyze_function_body(current);
            if stats.instruction_count > MAX_WRAPPER_INSTRUCTIONS
                || stats.has_branches
                || stats.has_loops
            {
                break;
            }

            // A function that talks to a known host function is the implementation.
            if has_any_host_call(&self.module, &self.host_func_map, local_func) {
                break;
            }

            match collect_unique_local_calls(&self.module, local_func).as_slice() {
                [next] => current = *next,
                _ => break,
            }
        }
        current
    }

    /// Walks the body of `func_id` once and summarizes it.
    ///
    /// Returns `FunctionBodyStats::default()` for functions without a local
    /// body, i.e. imported and uninitialized functions.
    pub fn analyze_function_body(&self, func_id: FunctionId) -> FunctionBodyStats {
        let local_func = match &self.module.funcs.get(func_id).kind {
            FunctionKind::Local(lf) => lf,
            _ => return FunctionBodyStats::default(),
        };

        let mut visitor = BodyAnalyzer {
            module: &self.module,
            host_func_map: &self.host_func_map,
            host_calls: Vec::new(),
            local_call_count: 0,
            has_branches: false,
            has_loops: false,
            instruction_count: 0,
        };
        dfs_in_order(&mut visitor, local_func, local_func.entry_block());

        FunctionBodyStats {
            host_calls: visitor.host_calls,
            local_call_count: visitor.local_call_count,
            has_branches: visitor.has_branches,
            has_loops: visitor.has_loops,
            instruction_count: visitor.instruction_count,
        }
    }
}
/// Returns the distinct local functions that `func` calls directly.
///
/// Callees appear in the order each is first encountered during an in-order
/// DFS of the body. Calls to imported or uninitialized functions are ignored,
/// and so are indirect calls, since only `call` instructions are visited.
fn collect_unique_local_calls(
    module: &Module,
    func: &LocalFunction,
) -> Vec<FunctionId> {
    use std::collections::HashSet;

    struct Collector<'a> {
        module: &'a Module,
        // `seen` gives O(1) de-duplication; `calls` keeps first-seen order.
        // The previous `Vec::contains` check made this quadratic in the number
        // of distinct callees.
        seen: HashSet<FunctionId>,
        calls: Vec<FunctionId>,
    }

    impl<'instr> Visitor<'instr> for Collector<'_> {
        fn visit_call(&mut self, call: &ir::Call) {
            let is_local = matches!(
                &self.module.funcs.get(call.func).kind,
                FunctionKind::Local(_)
            );
            if is_local && self.seen.insert(call.func) {
                self.calls.push(call.func);
            }
        }
    }

    let mut collector = Collector {
        module,
        seen: HashSet::new(),
        calls: Vec::new(),
    };
    dfs_in_order(&mut collector, func, func.entry_block());
    collector.calls
}
/// Reports whether `func` directly calls any imported function that has an
/// entry in `host_func_map`.
fn has_any_host_call(
    module: &Module,
    host_func_map: &HashMap<FunctionId, &'static HostFunction>,
    func: &LocalFunction,
) -> bool {
    struct HostCallFinder<'a> {
        module: &'a Module,
        host_func_map: &'a HashMap<FunctionId, &'static HostFunction>,
        found: bool,
    }

    impl<'instr> Visitor<'instr> for HostCallFinder<'_> {
        fn visit_call(&mut self, call: &ir::Call) {
            // After the first match the answer is settled, so later calls are skipped.
            if !self.found {
                let callee = call.func;
                let is_import = matches!(
                    &self.module.funcs.get(callee).kind,
                    FunctionKind::Import(_)
                );
                self.found = is_import && self.host_func_map.contains_key(&callee);
            }
        }
    }

    let mut finder = HostCallFinder {
        module,
        host_func_map,
        found: false,
    };
    dfs_in_order(&mut finder, func, func.entry_block());
    finder.found
}