#ifndef wasm_ir_module_h
#define wasm_ir_module_h
#include "pass.h"
#include "support/unique_deferring_queue.h"
#include "wasm.h"
namespace wasm::ModuleUtils {
// Declaration only; the definition is not part of this header.
// Takes a source function, a destination module |out|, an optional |newName|
// (defaults to an empty Name()) and an optional |fileIndexMap| (defaults to
// std::nullopt), and returns a Function*.
// NOTE(review): presumably the result is the copy that was added to |out|, an
// empty |newName| keeps the original name, and |fileIndexMap| remaps debug
// info file indices between modules - confirm against the definition.
Function*
copyFunction(Function* func,
Module& out,
Name newName = Name(),
std::optional<std::vector<Index>> fileIndexMap = std::nullopt);
// Per-entity copy helpers (declarations only). Each takes the entity to copy
// plus a destination module |out| and returns a pointer of the same entity
// type.
// NOTE(review): presumably the returned pointer is the new copy owned by
// |out| - confirm against the definitions, which are not in this header.
Global* copyGlobal(Global* global, Module& out);
Tag* copyTag(Tag* tag, Module& out);
ElementSegment* copyElementSegment(const ElementSegment* segment, Module& out);
Table* copyTable(const Table* table, Module& out);
Memory* copyMemory(const Memory* memory, Module& out);
DataSegment* copyDataSegment(const DataSegment* segment, Module& out);
// Whole-module helpers (declarations only). The copy functions read from |in|
// (taken by const reference) and write into |out|; clearModule mutates |wasm|.
// NOTE(review): the exact difference between copyModuleItems and copyModule is
// not visible here - check the definitions before relying on either.
void copyModuleItems(const Module& in, Module& out);
void copyModule(const Module& in, Module& out);
void clearModule(Module& wasm);
// Function renaming helpers (declarations only).
// renameFunctions takes a caller-supplied container |map| of type T by
// non-const reference.
// NOTE(review): presumably |map| maps old function names to new ones and all
// uses in |wasm| are updated - confirm against the definition.
template<typename T> void renameFunctions(Module& wasm, T& map);
// Single-function variant that takes the old and new names explicitly.
void renameFunction(Module& wasm, Name oldName, Name newName);
// Calls |visitor| with a raw pointer to every entry of wasm.memories whose
// imported() is true, in container order. Other memories are skipped.
template<typename T> inline void iterImportedMemories(Module& wasm, T visitor) {
  for (auto& memory : wasm.memories) {
    if (!memory->imported()) {
      continue;
    }
    visitor(memory.get());
  }
}
// Calls |visitor| with a raw pointer to every entry of wasm.memories whose
// imported() is false, in container order. Imported memories are skipped.
template<typename T> inline void iterDefinedMemories(Module& wasm, T visitor) {
  for (auto& memory : wasm.memories) {
    if (memory->imported()) {
      continue;
    }
    visitor(memory.get());
  }
}
// Calls |visitor| on each data segment that is not passive and whose |memory|
// field equals the given |memory| name. Passive segments are rejected before
// the name comparison is made.
template<typename T>
inline void iterMemorySegments(Module& wasm, Name memory, T visitor) {
  for (auto& dataSegment : wasm.dataSegments) {
    if (dataSegment->isPassive) {
      continue;
    }
    if (dataSegment->memory == memory) {
      visitor(dataSegment.get());
    }
  }
}
// Calls |visitor| on each entry of wasm.dataSegments whose isPassive flag is
// false, regardless of which memory it refers to.
template<typename T>
inline void iterActiveDataSegments(Module& wasm, T visitor) {
  for (auto& dataSegment : wasm.dataSegments) {
    if (dataSegment->isPassive) {
      continue;
    }
    visitor(dataSegment.get());
  }
}
// Calls |visitor| with a raw pointer to every entry of wasm.tables whose
// imported() is true, in container order.
template<typename T> inline void iterImportedTables(Module& wasm, T visitor) {
  for (auto& table : wasm.tables) {
    if (!table->imported()) {
      continue;
    }
    visitor(table.get());
  }
}
// Calls |visitor| with a raw pointer to every entry of wasm.tables whose
// imported() is false, in container order.
template<typename T> inline void iterDefinedTables(Module& wasm, T visitor) {
  for (auto& table : wasm.tables) {
    if (table->imported()) {
      continue;
    }
    visitor(table.get());
  }
}
// Calls |visitor| on each element segment whose |table| field equals the given
// |table| name. The name must satisfy table.is(); this is asserted up front.
template<typename T>
inline void iterTableSegments(Module& wasm, Name table, T visitor) {
  assert(table.is() && "Table name must not be null");
  for (auto& elemSegment : wasm.elementSegments) {
    if (!(elemSegment->table == table)) {
      continue;
    }
    visitor(elemSegment.get());
  }
}
// Calls |visitor| on each element segment whose |table| name is set, i.e. for
// which table.is() is true. Segments without a table name are skipped.
template<typename T>
inline void iterActiveElementSegments(Module& wasm, T visitor) {
  for (auto& elemSegment : wasm.elementSegments) {
    if (!elemSegment->table.is()) {
      continue;
    }
    visitor(elemSegment.get());
  }
}
// Calls |visitor| with a raw pointer to every entry of wasm.globals whose
// imported() is true, in container order.
template<typename T> inline void iterImportedGlobals(Module& wasm, T visitor) {
  for (auto& global : wasm.globals) {
    if (!global->imported()) {
      continue;
    }
    visitor(global.get());
  }
}
// Calls |visitor| with a raw pointer to every entry of wasm.globals whose
// imported() is false, in container order.
template<typename T> inline void iterDefinedGlobals(Module& wasm, T visitor) {
  for (auto& global : wasm.globals) {
    if (global->imported()) {
      continue;
    }
    visitor(global.get());
  }
}
// Calls |visitor| with a raw pointer to every entry of wasm.functions whose
// imported() is true, in container order.
template<typename T>
inline void iterImportedFunctions(Module& wasm, T visitor) {
  for (auto& function : wasm.functions) {
    if (!function->imported()) {
      continue;
    }
    visitor(function.get());
  }
}
// Calls |visitor| with a raw pointer to every entry of wasm.functions whose
// imported() is false, in container order.
template<typename T> inline void iterDefinedFunctions(Module& wasm, T visitor) {
  for (auto& function : wasm.functions) {
    if (function->imported()) {
      continue;
    }
    visitor(function.get());
  }
}
// Calls |visitor| with a raw pointer to every entry of wasm.tags whose
// imported() is true, in container order.
template<typename T> inline void iterImportedTags(Module& wasm, T visitor) {
  for (auto& tag : wasm.tags) {
    if (!tag->imported()) {
      continue;
    }
    visitor(tag.get());
  }
}
// Calls |visitor| with a raw pointer to every entry of wasm.tags whose
// imported() is false, in container order.
template<typename T> inline void iterDefinedTags(Module& wasm, T visitor) {
  for (auto& tag : wasm.tags) {
    if (tag->imported()) {
      continue;
    }
    visitor(tag.get());
  }
}
// Visits every imported memory, table, global, function and tag of |wasm|, in
// that order, by delegating to the per-kind iterImported* helpers.
// |visitor| must be callable with a pointer to each of those five entity
// types (e.g. a generic lambda or an overload set).
// Note that |visitor| is handed to each helper by value, so every helper works
// on its own copy; state mutated inside one copy is not seen by the next.
template<typename T> inline void iterImports(Module& wasm, T visitor) {
iterImportedMemories(wasm, visitor);
iterImportedTables(wasm, visitor);
iterImportedGlobals(wasm, visitor);
iterImportedFunctions(wasm, visitor);
iterImportedTags(wasm, visitor);
}
// Visits every imported entity of |wasm|, passing the matching ExternalKind
// together with a raw pointer to the entity. Kinds are visited in the order
// functions, tables, memories, globals, tags; entities whose imported() is
// false are skipped.
template<typename T> inline void iterImportable(Module& wasm, T visitor) {
  for (auto& function : wasm.functions) {
    if (!function->imported()) {
      continue;
    }
    visitor(ExternalKind::Function, function.get());
  }
  for (auto& table : wasm.tables) {
    if (!table->imported()) {
      continue;
    }
    visitor(ExternalKind::Table, table.get());
  }
  for (auto& memory : wasm.memories) {
    if (!memory->imported()) {
      continue;
    }
    visitor(ExternalKind::Memory, memory.get());
  }
  for (auto& global : wasm.globals) {
    if (!global->imported()) {
      continue;
    }
    visitor(ExternalKind::Global, global.get());
  }
  for (auto& tag : wasm.tags) {
    if (!tag->imported()) {
      continue;
    }
    visitor(ExternalKind::Tag, tag.get());
  }
}
// Visits every item stored in |wasm|, imported or not, passing the matching
// ModuleItemKind together with a raw pointer to the item. Kinds are visited in
// the order functions, tables, memories, globals, tags, data segments,
// element segments.
template<typename T> inline void iterModuleItems(Module& wasm, T visitor) {
  for (auto& function : wasm.functions) {
    visitor(ModuleItemKind::Function, function.get());
  }
  for (auto& table : wasm.tables) {
    visitor(ModuleItemKind::Table, table.get());
  }
  for (auto& memory : wasm.memories) {
    visitor(ModuleItemKind::Memory, memory.get());
  }
  for (auto& global : wasm.globals) {
    visitor(ModuleItemKind::Global, global.get());
  }
  for (auto& tag : wasm.tags) {
    visitor(ModuleItemKind::Tag, tag.get());
  }
  for (auto& dataSegment : wasm.dataSegments) {
    visitor(ModuleItemKind::DataSegment, dataSegment.get());
  }
  for (auto& elemSegment : wasm.elementSegments) {
    visitor(ModuleItemKind::ElementSegment, elemSegment.get());
  }
}
// Two-parameter alias for std::map. std::map itself has additional defaulted
// template parameters (comparator, allocator), so this alias gives a template
// with exactly the <typename, typename> shape that a template template
// parameter of that form can accept.
template<typename K, typename V> using DefaultMap = std::map<K, V>;
// Runs a callback over every function of a module and keeps one result of
// type T per function in |map|.
//
// Template parameters:
//   T    - per-function result type; an entry is default-created through
//          map[func] for every function before any callback runs.
//   Mut  - value returned (converted to bool) by the internal pass's
//          modifiesBinaryenIR(); defaults to Immutable.
//   MapT - two-parameter map template used for storage; defaults to
//          DefaultMap. It needs operator[] and count().
template<typename T,
         Mutability Mut = Immutable,
         template<typename, typename> class MapT = DefaultMap>
struct ParallelFunctionAnalysis {
  using Map = MapT<Function*, T>;
  using Func = std::function<void(Function*, T&)>;

  Module& wasm;
  Map map;

  // Creates a map entry for each function in |wasm| and then runs |work| on
  // all of them via doAnalysis().
  // NOTE(review): creating the entries up front presumably ensures that the
  // parallel phase only touches keys that already exist - confirm.
  ParallelFunctionAnalysis(Module& wasm, Func work) : wasm(wasm) {
    for (auto& fn : wasm.functions) {
      map[fn.get()];
    }
    doAnalysis(work);
  }

  // Invokes |work| on every function together with its map entry.
  void doAnalysis(Func work) {
    // Imported functions are handled right here, one after another, before
    // the pass below is run.
    for (auto& fn : wasm.functions) {
      if (!fn->imported()) {
        continue;
      }
      Function* imported = fn.get();
      work(imported, map[imported]);
    }

    // The remaining work is delegated to a pass that declares itself
    // function-parallel, so the pass runner may call doWalkFunction() for
    // several functions concurrently.
    struct Mapper : public WalkerPass<PostWalker<Mapper>> {
      Mapper(Module& module, Map& map, Func work)
        : module(module), map(map), work(work) {}

      bool isFunctionParallel() override { return true; }
      bool modifiesBinaryenIR() override { return Mut; }

      // Each new instance shares the same module, result map and callback.
      std::unique_ptr<Pass> create() override {
        return std::make_unique<Mapper>(module, map, work);
      }

      // Instead of walking the body, hand the function and its pre-existing
      // map entry to the callback.
      void doWalkFunction(Function* curr) {
        assert(map.count(curr));
        T& result = map[curr];
        work(curr, result);
      }

    private:
      Module& module;
      Map& map;
      Func work;
    };

    PassRunner runner(&wasm);
    Mapper mapper(wasm, map, work);
    mapper.run(&runner, &wasm);
  }
};
// Builds per-function call graph information for a module and offers a helper
// that propagates a property backwards from callees to their callers.
//
// T is the per-function info type. This code reads and writes callsTo,
// calledBy and hasNonDirectCall through T, so T must provide those members
// (for example by deriving from FunctionInfo below).
template<typename T> struct CallGraphPropertyAnalysis {
  Module& wasm;

  // Call graph data recorded for one function.
  struct FunctionInfo {
    // Targets of the direct calls found in the function body.
    std::set<Function*> callsTo;
    // Functions whose callsTo contains this function.
    std::set<Function*> calledBy;
    // Set when the body contains a CallIndirect or a CallRef.
    bool hasNonDirectCall = false;
  };

  using Map = std::map<Function*, T>;
  Map map;

  using Func = std::function<void(Function*, T&)>;

  // Runs |work| on every function (imports included) and, for functions that
  // are not imported, additionally scans the body to fill in callsTo and
  // hasNonDirectCall. Afterwards calledBy is derived from callsTo.
  CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) {
    ParallelFunctionAnalysis<T> analysis(wasm, [&](Function* func, T& info) {
      work(func, info);
      if (func->imported()) {
        // Nothing to scan for an import.
        return;
      }
      // The walker only needs the module (to resolve call targets) and the
      // info to fill in; it deliberately does not keep a copy of |work|.
      struct Mapper : public PostWalker<Mapper> {
        Mapper(Module* module, T& info) : module(module), info(info) {}
        void visitCall(Call* curr) {
          info.callsTo.insert(module->getFunction(curr->target));
        }
        void visitCallIndirect(CallIndirect* curr) {
          info.hasNonDirectCall = true;
        }
        void visitCallRef(CallRef* curr) { info.hasNonDirectCall = true; }

      private:
        Module* module;
        T& info;
      } mapper(&wasm, info);
      mapper.walk(func->body);
    });

    map.swap(analysis.map);

    // Invert the callsTo edges to obtain calledBy.
    for (auto& [func, info] : map) {
      for (auto* target : info.callsTo) {
        map[target].calledBy.insert(func);
      }
    }
  }

  // How propagateBack() treats functions that have hasNonDirectCall set.
  enum NonDirectCalls { IgnoreNonDirectCalls, NonDirectCallsHaveProperty };

  // Propagates a property from functions to their (transitive) callers.
  //
  //   hasProperty     - returns whether an info already has the property.
  //   canHaveProperty - returns whether an info is allowed to receive it.
  //   addProperty     - applies the property to an info; the second argument
  //                     is the function itself for the initial set, and the
  //                     callee the property came from during propagation.
  //   nonDirectCalls  - with NonDirectCallsHaveProperty, functions that have
  //                     hasNonDirectCall set are part of the initial set too.
  void propagateBack(std::function<bool(const T&)> hasProperty,
                     std::function<bool(const T&)> canHaveProperty,
                     std::function<void(T&, Function*)> addProperty,
                     NonDirectCalls nonDirectCalls) {
    // Functions whose state was just updated and whose callers still need to
    // be looked at.
    UniqueDeferredQueue<Function*> work;
    for (auto& func : wasm.functions) {
      auto* current = func.get();
      // Look the entry up once instead of once per use.
      T& info = map[current];
      if (hasProperty(info) || (nonDirectCalls == NonDirectCallsHaveProperty &&
                                info.hasNonDirectCall)) {
        addProperty(info, current);
        work.push(current);
      }
    }
    while (!work.empty()) {
      auto* func = work.pop();
      for (auto* caller : map[func].calledBy) {
        T& callerInfo = map[caller];
        // A caller gains the property only if it lacks it so far and is
        // permitted to have it.
        if (!hasProperty(callerInfo) && canHaveProperty(callerInfo)) {
          addProperty(callerInfo, func);
          work.push(caller);
        }
      }
    }
  }
};
// Heap type collection helpers (declarations only; defined elsewhere). Each
// takes the module by non-const reference and returns a vector of HeapType by
// value.
// NOTE(review): going by the names, these return all heap types used in the
// module, the publicly visible subset, and the private subset respectively -
// confirm against the definitions, including any ordering guarantee.
std::vector<HeapType> collectHeapTypes(Module& wasm);
std::vector<HeapType> getPublicHeapTypes(Module& wasm);
std::vector<HeapType> getPrivateHeapTypes(Module& wasm);
// Bundles a list of heap types with a lookup table from heap type to Index.
// NOTE(review): presumably indices[types[i]] == i for every i - confirm
// against the code that fills this struct.
struct IndexedHeapTypes {
// The heap types, as a sequence.
std::vector<HeapType> types;
// Maps a heap type to an Index.
std::unordered_map<HeapType, Index> indices;
};
// Declaration only. Takes the module and returns an IndexedHeapTypes by value.
// NOTE(review): the name suggests the types are ordered to optimize something
// (e.g. output size) - confirm the ordering criterion against the definition.
IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm);
}
#endif