#ifndef wasm_ir_module_h
#define wasm_ir_module_h
#include "ir/find_all.h"
#include "ir/manipulation.h"
#include "ir/properties.h"
#include "pass.h"
#include "support/unique_deferring_queue.h"
#include "wasm.h"
namespace wasm {
namespace ModuleUtils {
// Creates a copy of |func| (name, signature, vars, local names/indices,
// import module/base and a deep copy of the body made with |out| as the
// target module), adds it to |out|, and returns it. |func| must not carry
// Stack IR, which is not copied.
inline Function* copyFunction(Function* func, Module& out) {
  assert(!func->stackIR);
  auto* ret = new Function();
  ret->name = func->name;
  ret->sig = func->sig;
  ret->vars = func->vars;
  ret->localNames = func->localNames;
  ret->localIndices = func->localIndices;
  ret->body = ExpressionManipulator::copy(func->body, out);
  ret->module = func->module;
  ret->base = func->base;
  // debugLocations is keyed by the function's own Expression nodes, so copying
  // the map verbatim would leave keys pointing into |func|'s body rather than
  // the new copy. The copied body has the same shape as the original, so a
  // post-order walk of both visits corresponding nodes in the same order;
  // remap each entry through that correspondence.
  if (func->body && !func->debugLocations.empty()) {
    struct Lister
      : public PostWalker<Lister, UnifiedExpressionVisitor<Lister>> {
      std::vector<Expression*> list;
      void visitExpression(Expression* curr) { list.push_back(curr); }
    };
    Lister originals;
    originals.walk(func->body);
    Lister copies;
    copies.walk(ret->body);
    assert(originals.list.size() == copies.list.size());
    for (size_t i = 0; i < originals.list.size(); ++i) {
      auto iter = func->debugLocations.find(originals.list[i]);
      if (iter != func->debugLocations.end()) {
        ret->debugLocations[copies.list[i]] = iter->second;
      }
    }
  }
  out.addFunction(ret);
  return ret;
}
// Duplicates |global| into |out| and returns the new Global. Name, type,
// mutability and import module/base are copied by value. For a defined
// global, the init expression is deep-copied with |out| as the target
// module. The copy of an imported global gets a null init.
inline Global* copyGlobal(Global* global, Module& out) {
  auto* copy = new Global();
  copy->name = global->name;
  copy->module = global->module;
  copy->base = global->base;
  copy->type = global->type;
  copy->mutable_ = global->mutable_;
  copy->init = global->imported()
                 ? nullptr
                 : ExpressionManipulator::copy(global->init, out);
  out.addGlobal(copy);
  return copy;
}
// Creates a copy of |event| (name, attribute, signature and import
// module/base), adds it to |out|, and returns the copy.
inline Event* copyEvent(Event* event, Module& out) {
  auto* ret = new Event();
  ret->name = event->name;
  ret->attribute = event->attribute;
  ret->sig = event->sig;
  // Events can be imported. Keep the import information, as copyFunction and
  // copyGlobal do; otherwise an imported event would silently become a defined
  // one in |out|.
  ret->module = event->module;
  ret->base = event->base;
  out.addEvent(ret);
  return ret;
}
// Copies the contents of |in| into |out|. Exports, table, memory, start,
// user sections and debug file names are copied by value. Functions, globals
// and events go through the copy* helpers above. Segment offset expressions
// are deep-copied with |out| as the target module.
inline void copyModule(const Module& in, Module& out) {
  for (const auto& exp : in.exports) {
    out.addExport(new Export(*exp));
  }
  for (const auto& func : in.functions) {
    copyFunction(func.get(), out);
  }
  for (const auto& global : in.globals) {
    copyGlobal(global.get(), out);
  }
  for (const auto& event : in.events) {
    copyEvent(event.get(), out);
  }

  // After a by-value copy the segments still point at |in|'s offset
  // expressions; replace each one with a fresh copy owned by |out|.
  auto cloneOffsets = [&out](auto& segments) {
    for (auto& seg : segments) {
      seg.offset = ExpressionManipulator::copy(seg.offset, out);
    }
  };
  out.table = in.table;
  cloneOffsets(out.table.segments);
  out.memory = in.memory;
  cloneOffsets(out.memory.segments);

  out.start = in.start;
  out.userSections = in.userSections;
  out.debugInfoFileNames = in.debugInfoFileNames;
}
// Resets |wasm| to a default-constructed Module by running its destructor and
// then constructing a new Module in the same storage. Everything the old
// module owned is destroyed, so any pointers into its former contents
// (functions, globals, expressions, ...) must not be used afterwards.
// NOTE(review): if the Module constructor were to throw, |wasm| would be left
// destroyed. Placement new also needs <new>, which this header only gets
// transitively; TODO confirm.
inline void clearModule(Module& wasm) {
wasm.~Module();
new (&wasm) Module;
}
// Renames functions according to |map| (old name -> new name) and rewrites
// references to them: the start function, table segment entries, function
// exports, and Call targets inside defined functions. Entries whose old name
// is not a function in |wasm| rename nothing but still rewrite matching
// references. A new name may not collide with a different existing function
// (asserted), so e.g. swapping two names in one call is not supported.
// NOTE(review): only Call expressions are rewritten; confirm no other
// expression kind refers to functions by name.
template<typename T> inline void renameFunctions(Module& wasm, T& map) {
  for (auto& entry : map) {
    auto* func = wasm.getFunctionOrNull(entry.first);
    if (!func) {
      continue;
    }
    assert(!wasm.getFunctionOrNull(entry.second) ||
           func->name == entry.second);
    func->name = entry.second;
  }
  // Re-key the module's name lookup tables with the new names.
  wasm.updateMaps();

  // Rewrites |name| in place if it is one of the renamed functions.
  auto rename = [&map](Name& name) {
    auto found = map.find(name);
    if (found == map.end()) {
      return;
    }
    name = found->second;
  };

  rename(wasm.start);
  for (auto& segment : wasm.table.segments) {
    for (auto& entryName : segment.data) {
      rename(entryName);
    }
  }
  for (auto& exp : wasm.exports) {
    if (exp->kind != ExternalKind::Function) {
      continue;
    }
    rename(exp->value);
  }
  for (auto& func : wasm.functions) {
    if (func->imported()) {
      continue;
    }
    FindAll<Call> callsInBody(func->body);
    for (auto* call : callsInBody.list) {
      rename(call->target);
    }
  }
}
// Convenience wrapper: renames the single function |oldName| to |newName|
// (and the references renameFunctions rewrites).
inline void renameFunction(Module& wasm, Name oldName, Name newName) {
  std::map<Name, Name> single{{oldName, newName}};
  renameFunctions(wasm, single);
}
// Calls |visitor| with the module's memory if it exists and is imported.
template<typename T> inline void iterImportedMemories(Module& wasm, T visitor) {
  auto& mem = wasm.memory;
  if (!mem.exists || !mem.imported()) {
    return;
  }
  visitor(&mem);
}
// Calls |visitor| with the module's memory if it exists and is not imported.
template<typename T> inline void iterDefinedMemories(Module& wasm, T visitor) {
  auto& mem = wasm.memory;
  if (!mem.exists || mem.imported()) {
    return;
  }
  visitor(&mem);
}
// Calls |visitor| with the module's table if it exists and is imported.
template<typename T> inline void iterImportedTables(Module& wasm, T visitor) {
  auto& tbl = wasm.table;
  if (!tbl.exists || !tbl.imported()) {
    return;
  }
  visitor(&tbl);
}
// Calls |visitor| with the module's table if it exists and is not imported.
template<typename T> inline void iterDefinedTables(Module& wasm, T visitor) {
  auto& tbl = wasm.table;
  if (!tbl.exists || tbl.imported()) {
    return;
  }
  visitor(&tbl);
}
// Calls |visitor| on each imported global, in module order.
template<typename T> inline void iterImportedGlobals(Module& wasm, T visitor) {
  for (auto& global : wasm.globals) {
    if (!global->imported()) {
      continue;
    }
    visitor(global.get());
  }
}
// Calls |visitor| on each global defined (not imported) by the module, in
// module order.
template<typename T> inline void iterDefinedGlobals(Module& wasm, T visitor) {
  for (auto& global : wasm.globals) {
    if (global->imported()) {
      continue;
    }
    visitor(global.get());
  }
}
// Calls |visitor| on each imported function, in module order.
template<typename T>
inline void iterImportedFunctions(Module& wasm, T visitor) {
  for (auto& func : wasm.functions) {
    if (!func->imported()) {
      continue;
    }
    visitor(func.get());
  }
}
// Calls |visitor| on each function defined (not imported) by the module, in
// module order.
template<typename T> inline void iterDefinedFunctions(Module& wasm, T visitor) {
  for (auto& func : wasm.functions) {
    if (func->imported()) {
      continue;
    }
    visitor(func.get());
  }
}
// Calls |visitor| on each imported event, in module order.
template<typename T> inline void iterImportedEvents(Module& wasm, T visitor) {
  for (auto& event : wasm.events) {
    if (!event->imported()) {
      continue;
    }
    visitor(event.get());
  }
}
// Calls |visitor| on each event defined (not imported) by the module, in
// module order.
template<typename T> inline void iterDefinedEvents(Module& wasm, T visitor) {
  for (auto& event : wasm.events) {
    if (event->imported()) {
      continue;
    }
    visitor(event.get());
  }
}
// Calls |visitor| on every imported module element, in this order: memory,
// table, globals, functions, events. |visitor| must therefore be callable with
// a pointer to each of those element kinds (e.g. a generic lambda). It is
// passed by value to each helper, so a stateful function object is copied
// once per element kind; keep shared state behind a reference capture.
template<typename T> inline void iterImports(Module& wasm, T visitor) {
iterImportedMemories(wasm, visitor);
iterImportedTables(wasm, visitor);
iterImportedGlobals(wasm, visitor);
iterImportedFunctions(wasm, visitor);
iterImportedEvents(wasm, visitor);
}
// Runs |work| once for every function in |wasm|, giving it a dedicated T to
// fill in, and keeps the results in |map|. Imported functions are processed
// serially on the calling thread. Defined functions are processed through a
// function-parallel pass, so |work| may run concurrently for different
// functions; it must only touch its own T (plus whatever it synchronizes).
template<typename T> struct ParallelFunctionAnalysis {
  Module& wasm;

  typedef std::map<Function*, T> Map;
  Map map;

  typedef std::function<void(Function*, T&)> Func;

  ParallelFunctionAnalysis(Module& wasm, Func work) : wasm(wasm) {
    // Create every entry up front, on this thread, so the parallel phase
    // only looks entries up and never changes the map's structure.
    for (auto& func : wasm.functions) {
      map[func.get()];
    }

    // Imports have no body to walk; handle them serially here.
    for (auto& func : wasm.functions) {
      if (func->imported()) {
        work(func.get(), map[func.get()]);
      }
    }

    struct Mapper : public WalkerPass<PostWalker<Mapper>> {
      bool isFunctionParallel() override { return true; }
      bool modifiesBinaryenIR() override { return false; }

      Mapper(Module& module, Map& map, Func work)
        : module(module), map(map), work(work) {}

      Mapper* create() override { return new Mapper(module, map, work); }

      void doWalkFunction(Function* curr) {
        assert(map.count(curr));
        // Runs concurrently on several threads. Use at(), not operator[]: the
        // standard does not treat operator[] of an associative container as
        // const for data-race purposes, even when the key already exists.
        work(curr, map.at(curr));
      }

    private:
      Module& module;
      Map& map;
      Func work;
    };

    PassRunner runner(&wasm);
    Mapper(wasm, map, work).run(&runner, &wasm);
  }
};
// Builds a per-function info object T for every function in |wasm| and links
// them into a call graph. T must provide the members of FunctionInfo
// (callsTo, calledBy, hasIndirectCall), typically by deriving from it. The
// user-supplied |work| fills in any additional per-function data.
template<typename T> struct CallGraphPropertyAnalysis {
  Module& wasm;

  // Call-graph data every T is expected to carry: direct callees, direct
  // callers, and whether the function contains a call_indirect.
  struct FunctionInfo {
    std::set<Function*> callsTo;
    std::set<Function*> calledBy;
    bool hasIndirectCall = false;
  };

  typedef std::map<Function*, T> Map;
  Map map;

  typedef std::function<void(Function*, T&)> Func;

  CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) {
    ParallelFunctionAnalysis<T> analysis(wasm, [&](Function* func, T& info) {
      work(func, info);
      if (func->imported()) {
        return;
      }
      // Records direct callees and the presence of indirect calls. It only
      // needs the module (to resolve call targets) and this function's info.
      struct Mapper : public PostWalker<Mapper> {
        Mapper(Module* module, T& info) : module(module), info(info) {}

        void visitCall(Call* curr) {
          info.callsTo.insert(module->getFunction(curr->target));
        }
        void visitCallIndirect(CallIndirect* curr) {
          info.hasIndirectCall = true;
        }

      private:
        Module* module;
        T& info;
      } mapper(&wasm, info);
      mapper.walk(func->body);
    });

    map.swap(analysis.map);

    // Invert callsTo into calledBy.
    for (auto& pair : map) {
      auto* func = pair.first;
      auto& info = pair.second;
      for (auto* target : info.callsTo) {
        map[target].calledBy.insert(func);
      }
    }
  }

  enum IndirectCalls { IgnoreIndirectCalls, IndirectCallsHaveProperty };

  // Propagates a property from functions to their (transitive) callers.
  // Seeds are functions for which |hasProperty| holds, plus, when
  // |indirectCalls| is IndirectCallsHaveProperty, functions containing an
  // indirect call. Each seed gets |addProperty| called with itself as the
  // second argument. A caller then acquires the property (|addProperty| is
  // called with the callee it was reached from) unless it already has it
  // or |canHaveProperty| rejects it.
  void propagateBack(std::function<bool(const T&)> hasProperty,
                     std::function<bool(const T&)> canHaveProperty,
                     std::function<void(T&, Function*)> addProperty,
                     IndirectCalls indirectCalls) {
    // Functions whose callers still need to be examined.
    UniqueDeferredQueue<Function*> work;
    for (auto& func : wasm.functions) {
      auto& info = map[func.get()];
      if (hasProperty(info) ||
          (indirectCalls == IndirectCallsHaveProperty &&
           info.hasIndirectCall)) {
        addProperty(info, func.get());
        work.push(func.get());
      }
    }
    while (!work.empty()) {
      auto* func = work.pop();
      for (auto* caller : map[func].calledBy) {
        auto& callerInfo = map[caller];
        if (!hasProperty(callerInfo) && canHaveProperty(callerInfo)) {
          addProperty(callerInfo, func);
          work.push(caller);
        }
      }
    }
  }
};
// Collects every Signature used by |wasm| and assigns each an index.
// The signatures counted are: each function's and event's own signature,
// each call_indirect's signature, and (none -> T) for control-flow
// structures whose type T is a tuple. The signatures are appended to
// |signatures| in order of decreasing use count, with ties broken by
// Signature's operator<. |sigIndices| maps each one to its position in that
// order, starting at 0.
inline void
collectSignatures(Module& wasm,
                  std::vector<Signature>& signatures,
                  std::unordered_map<Signature, Index>& sigIndices) {
  using Counts = std::unordered_map<Signature, size_t>;

  // Counts signature uses inside one function body. This runs in parallel
  // across functions, each writing to its own Counts.
  auto updateCounts = [&](Function* func, Counts& counts) {
    if (func->imported()) {
      return;
    }
    struct TypeCounter
      : PostWalker<TypeCounter, UnifiedExpressionVisitor<TypeCounter>> {
      Counts& counts;

      explicit TypeCounter(Counts& counts) : counts(counts) {}

      void visitExpression(Expression* curr) {
        if (auto* call = curr->dynCast<CallIndirect>()) {
          counts[call->sig]++;
        } else if (Properties::isControlFlowStructure(curr)) {
          // Only tuple-typed structures are counted; presumably single-value
          // block types are encoded without a type index.
          if (curr->type.isTuple()) {
            counts[Signature(Type::none, curr->type)]++;
          }
        }
      }
    };
    TypeCounter(counts).walk(func->body);
  };

  ModuleUtils::ParallelFunctionAnalysis<Counts> analysis(wasm, updateCounts);

  // Combine declared signatures with the per-function body counts.
  Counts counts;
  for (auto& curr : wasm.functions) {
    counts[curr->sig]++;
  }
  for (auto& curr : wasm.events) {
    counts[curr->sig]++;
  }
  for (auto& pair : analysis.map) {
    for (auto& innerPair : pair.second) {
      counts[innerPair.first] += innerPair.second;
    }
  }

  // Order by frequency (most used first), then by Signature's operator<, so
  // the result does not depend on unordered_map iteration order.
  std::vector<std::pair<Signature, size_t>> sorted(counts.begin(),
                                                   counts.end());
  std::sort(sorted.begin(), sorted.end(), [](const auto& a, const auto& b) {
    if (a.second != b.second) {
      return a.second > b.second;
    }
    return a.first < b.first;
  });
  signatures.reserve(signatures.size() + sorted.size());
  for (Index i = 0; i < sorted.size(); ++i) {
    sigIndices[sorted[i].first] = i;
    signatures.push_back(sorted[i].first);
  }
}
}
}
#endif