#include <algorithm>
#include <memory>
#include <numeric>

#include <pass.h>
#include <wasm.h>
namespace wasm {
// Reorders the non-parameter locals of a function so that the ones accessed
// most often (by LocalGet/LocalSet) receive the lowest indices, and removes
// locals that are never accessed at all. Parameters keep their indices.
struct ReorderLocals : public WalkerPass<PostWalker<ReorderLocals>> {
  bool isFunctionParallel() override { return true; }

  // NOTE(review): presumably false because sorting locals and dropping unused
  // ones cannot invalidate non-nullable local usage - confirm against the
  // Pass interface documentation.
  bool requiresNonNullableLocalFixups() override { return false; }

  std::unique_ptr<Pass> create() override {
    return std::make_unique<ReorderLocals>();
  }

  // local index => number of LocalGet/LocalSet nodes that refer to it.
  std::vector<Index> counts;
  // local index => rank of its first access in walk order (1 for the local
  // accessed first, 2 for the next new local, ...). Unseen (0) marks locals
  // that have not been accessed yet.
  std::vector<Index> firstUses;
  // Rank that will be handed to the next local seen for the first time.
  Index firstUseIndex = 1;

  enum { Unseen = 0 };

  // Computes the new local order for |curr|, rewrites curr->vars, re-indexes
  // every LocalGet/LocalSet in the body, and remaps the local name tables.
  void doWalkFunction(Function* curr) {
    if (curr->getNumVars() == 0) {
      // There are no vars to reorder or drop.
      return;
    }

    // Reset all per-function state before gathering usage information.
    const Index num = curr->getNumLocals();
    counts.clear();
    counts.resize(num);
    firstUses.clear();
    firstUses.resize(num, Unseen);
    firstUseIndex = 1;

    // Gather usage counts and first-use ranks via the visit* hooks below.
    walk(curr->body);

    // newToOld[newIndex] == oldIndex. Start from the identity mapping.
    std::vector<Index> newToOld(num);
    std::iota(newToOld.begin(), newToOld.end(), Index(0));

    // Sort with params in front in their original order, followed by vars
    // ordered by descending use count. Ties between used vars are broken by
    // which one was accessed first; unused vars keep their relative order.
    std::sort(newToOld.begin(),
              newToOld.end(),
              [this, curr](Index a, Index b) -> bool {
                const bool aIsParam = curr->isParam(a);
                const bool bIsParam = curr->isParam(b);
                if (aIsParam != bIsParam) {
                  return aIsParam;
                }
                if (aIsParam) {
                  return a < b;
                }
                if (counts[a] != counts[b]) {
                  return counts[a] > counts[b];
                }
                if (counts[a] == 0) {
                  return a < b;
                }
                return firstUses[a] < firstUses[b];
              });

    // The sort must have left the params in the leading slots; verify that
    // and pin each param to its own index.
    const size_t numParams = curr->getParams().size();
    for (size_t i = 0; i < numParams; i++) {
      assert(newToOld[i] < numParams);
      newToOld[i] = i;
    }

    // Rebuild the var list in the new order. Unused vars sort last, so the
    // first var with a zero count marks where the mapping gets truncated.
    std::vector<Type> oldVars;
    std::swap(oldVars, curr->vars);
    const Index varBase = curr->getVarIndexBase();
    for (size_t i = varBase; i < newToOld.size(); i++) {
      const Index oldIndex = newToOld[i];
      if (counts[oldIndex] == 0) {
        newToOld.resize(i);
        break;
      }
      curr->vars.push_back(oldVars[oldIndex - varBase]);
    }
    counts.clear();

    // Invert the mapping. Entries belonging to dropped locals stay 0; they
    // are never looked up because no LocalGet/LocalSet refers to them.
    std::vector<Index> oldToNew(num);
    for (size_t i = 0; i < newToOld.size(); i++) {
      if (curr->isParam(i)) {
        oldToNew[i] = i;
      } else {
        oldToNew[newToOld[i]] = i;
      }
    }

    // Rewrites the index of every LocalGet/LocalSet through oldToNew.
    struct ReIndexer : public PostWalker<ReIndexer> {
      const std::vector<Index>& oldToNew;

      explicit ReIndexer(const std::vector<Index>& oldToNew)
        : oldToNew(oldToNew) {}

      void visitLocalGet(LocalGet* curr) {
        curr->index = oldToNew[curr->index];
      }
      void visitLocalSet(LocalSet* curr) {
        curr->index = oldToNew[curr->index];
      }
    };
    ReIndexer reIndexer(oldToNew);
    reIndexer.walk(curr->body);

    // Remap the name tables. The old index->name table is taken out by swap
    // (no copy); names of dropped locals are not carried over.
    decltype(curr->localNames) oldLocalNames;
    std::swap(oldLocalNames, curr->localNames);
    curr->localIndices.clear();
    for (size_t i = 0; i < newToOld.size(); i++) {
      auto iter = oldLocalNames.find(newToOld[i]);
      if (iter != oldLocalNames.end()) {
        const auto& name = iter->second;
        curr->localNames[i] = name;
        curr->localIndices[name] = i;
      }
    }
  }

  void visitLocalGet(LocalGet* curr) { noteLocalUse(curr->index); }

  void visitLocalSet(LocalSet* curr) { noteLocalUse(curr->index); }

  // Records one access to local |index|: bumps its use count and, on its
  // first access, assigns it the next first-use rank.
  void noteLocalUse(Index index) {
    counts[index]++;
    if (firstUses[index] == Unseen) {
      firstUses[index] = firstUseIndex++;
    }
  }
};
// Factory for the ReorderLocals pass: returns a new heap-allocated instance.
// NOTE(review): the raw pointer is presumably owned by the caller (e.g. a
// pass registry or runner) - confirm at the call sites.
Pass* createReorderLocalsPass() {
  Pass* pass = new ReorderLocals();
  return pass;
}
}