#include <map>
#include <memory>
#include <set>
#include <vector>

#include <ir/local-graph.h>
#include <ir/properties.h>
#include <pass.h>
#include <wasm-builder.h>
#include <wasm.h>
namespace wasm {
// Returns whether `load` may be replaced by (or duplicated as) a load of the
// reinterpreted type from the same address.
//
// Conditions:
//  * The load must be full-size. A partial (extending) load reads fewer bytes
//    than a load of the reinterpreted type would.
//  * The load must not be unreachable; that case is not worth handling.
//  * The load must not be atomic. The replacement is created with
//    Builder::makeLoad, which takes no atomicity argument, so it is a plain
//    load; wasm also has no atomic float loads. Using it would lose the
//    original's atomicity. When the pass duplicates the load into a local,
//    a racing write could also make the two loads disagree.
static bool canReplaceWithReinterpret(Load* load) {
  return load->type != Type::unreachable && !load->isAtomic &&
         load->bytes == load->type.getByteSize();
}
// Follows `get` backwards through the local graph and returns the Load whose
// value it ultimately reads, or nullptr when there is no unique such Load.
//
// Each step requires the current get to have exactly one reaching set, and
// that set must be non-null. A null set presumably stands for the local's
// initial value; TODO confirm against LocalGraph. The set's value is examined
// via its fallthrough. A copy from another local is followed. Seeing the same
// get twice means the copies form a cycle, and the search gives up.
static Load* getSingleLoad(LocalGraph* localGraph,
                           LocalGet* get,
                           const PassOptions& passOptions,
                           Module& module) {
  std::set<LocalGet*> visited{get};
  for (auto* current = get;;) {
    auto& reaching = localGraph->getSetses[current];
    if (reaching.size() != 1) {
      return nullptr;
    }
    auto* onlySet = *reaching.begin();
    if (onlySet == nullptr) {
      return nullptr;
    }
    auto* source =
      Properties::getFallthrough(onlySet->value, passOptions, module);
    if (auto* load = source->dynCast<Load>()) {
      return load;
    }
    auto* copied = source->dynCast<LocalGet>();
    if (copied == nullptr || !visited.insert(copied).second) {
      // Neither a load nor an unvisited copy: nothing single to find.
      return nullptr;
    }
    current = copied;
  }
}
// True when `curr` is one of the four reinterpret unary operators, which
// convert bits between an integer and a float type without changing them.
static bool isReinterpret(Unary* curr) {
  switch (curr->op) {
    case ReinterpretInt32:
    case ReinterpretInt64:
    case ReinterpretFloat32:
    case ReinterpretFloat64:
      return true;
    default:
      return false;
  }
}
// Replaces reinterpret operations on loaded values with loads of the
// reinterpreted type.
//
// Two shapes are rewritten:
//  * (reinterpret (load ..)): the pair becomes a single load of the other
//    type from the same pointer.
//  * (reinterpret (local.get $x)), where $x's only reaching set holds a load:
//    the load is also performed with the reinterpreted type into a fresh
//    local, and the reinterpret becomes a read of that local.
struct AvoidReinterprets : public WalkerPass<PostWalker<AvoidReinterprets>> {
  bool isFunctionParallel() override { return true; }

  std::unique_ptr<Pass> create() override {
    return std::make_unique<AvoidReinterprets>();
  }

  // Bookkeeping for a load whose value is reinterpreted through a local.
  struct Info {
    // Analysis: a reinterpret reads this load's value via a local.get.
    bool reinterpreted = false;
    // Optimization: locals holding the pointer and the reinterpreted value.
    Index ptrLocal = 0;
    Index reinterpretedLocal = 0;
  };
  std::map<Load*, Info> infos;
  // Keys of `infos` in the order they were first seen. Locals are allocated
  // in this order rather than in map (pointer) order, so that the output does
  // not depend on allocation addresses.
  std::vector<Load*> loadOrder;

  // Points at a graph that lives on doWalkFunction's stack; only valid while
  // that call is running.
  LocalGraph* localGraph = nullptr;

  void doWalkFunction(Function* func) {
    LocalGraph localGraph_(func);
    localGraph = &localGraph_;
    PostWalker<AvoidReinterprets>::doWalkFunction(func);
    optimize(func);
    // Drop per-function state so nothing leaks into a later function and no
    // pointer to the destroyed graph remains.
    infos.clear();
    loadOrder.clear();
    localGraph = nullptr;
  }

  void visitUnary(Unary* curr) {
    if (!isReinterpret(curr)) {
      return;
    }
    // Only a direct local.get operand can be rewritten by FinalOptimizer
    // below. Looking through other fallthrough values here would mark loads
    // whose extra reinterpreted copy is then never read.
    auto* get = curr->value->dynCast<LocalGet>();
    if (!get) {
      return;
    }
    if (auto* load =
          getSingleLoad(localGraph, get, getPassOptions(), *getModule())) {
      auto [iter, inserted] = infos.try_emplace(load);
      if (inserted) {
        loadOrder.push_back(load);
      }
      iter->second.reinterpreted = true;
    }
  }

  void optimize(Function* func) {
    // Allocate locals for every load we can optimize and forget the rest, so
    // afterwards `infos` holds exactly the loads to rewrite.
    for (auto* load : loadOrder) {
      auto iter = infos.find(load);
      auto& info = iter->second;
      if (info.reinterpreted && canReplaceWithReinterpret(load)) {
        auto* mem = getModule()->getMemory(load->memory);
        info.ptrLocal = Builder::addVar(func, mem->indexType);
        info.reinterpretedLocal =
          Builder::addVar(func, load->type.reinterpret());
      } else {
        infos.erase(iter);
      }
    }

    // Rewrites the function using the locals allocated above.
    struct FinalOptimizer : public PostWalker<FinalOptimizer> {
      std::map<Load*, Info>& infos;
      LocalGraph* localGraph;
      Module* module;
      const PassOptions& passOptions;

      FinalOptimizer(std::map<Load*, Info>& infos,
                     LocalGraph* localGraph,
                     Module* module,
                     const PassOptions& passOptions)
        : infos(infos), localGraph(localGraph), module(module),
          passOptions(passOptions) {}

      void visitUnary(Unary* curr) {
        if (!isReinterpret(curr)) {
          return;
        }
        // Inspect the operand itself, not its fallthrough. Replacing the
        // whole reinterpret would otherwise risk dropping side effects of
        // whatever wraps the fallthrough value.
        auto* value = curr->value;
        if (auto* load = value->dynCast<Load>()) {
          // A reinterpret of a load: load the other type directly.
          if (canReplaceWithReinterpret(load)) {
            replaceCurrent(makeReinterpretedLoad(load, load->ptr));
          }
        } else if (auto* get = value->dynCast<LocalGet>()) {
          // A reinterpret of a local holding an optimized load: read the
          // local that visitLoad fills with the reinterpreted value.
          if (auto* load =
                getSingleLoad(localGraph, get, passOptions, *module)) {
            auto iter = infos.find(load);
            if (iter != infos.end()) {
              Builder builder(*module);
              replaceCurrent(builder.makeLocalGet(
                iter->second.reinterpretedLocal, load->type.reinterpret()));
            }
          }
        }
      }

      void visitLoad(Load* curr) {
        auto iter = infos.find(curr);
        if (iter == infos.end()) {
          return;
        }
        auto& info = iter->second;
        Builder builder(*module);
        auto* ptr = curr->ptr;
        auto indexType = module->getMemory(curr->memory)->indexType;
        // Evaluate the pointer once into a local, load the reinterpreted type
        // into its own local, then perform the original load from the same
        // local. The original load stays last so the block still yields its
        // value. Presumably this also lets getFallthrough on the enclosing set
        // keep reaching this Load.
        curr->ptr = builder.makeLocalGet(info.ptrLocal, indexType);
        replaceCurrent(builder.makeBlock(
          {builder.makeLocalSet(info.ptrLocal, ptr),
           builder.makeLocalSet(
             info.reinterpretedLocal,
             makeReinterpretedLoad(
               curr, builder.makeLocalGet(info.ptrLocal, indexType))),
           curr}));
      }

      // Builds a load with the same size, offset, alignment and memory as
      // `load`, reading from `ptr`, but with the reinterpreted result type.
      Load* makeReinterpretedLoad(Load* load, Expression* ptr) {
        Builder builder(*module);
        return builder.makeLoad(load->bytes,
                                false,
                                load->offset,
                                load->align,
                                ptr,
                                load->type.reinterpret(),
                                load->memory);
      }
    } finalOptimizer(infos, localGraph, getModule(), getPassOptions());

    finalOptimizer.setModule(getModule());
    finalOptimizer.walk(func->body);
  }
};
// Creates a fresh AvoidReinterprets pass. The caller owns the returned object
// (presumably the pass registry; confirm at the call site).
Pass* createAvoidReinterpretsPass() {
  auto* pass = new AvoidReinterprets();
  return pass;
}
}