#include "ir/hashed.h"
#include "ir/manipulation.h"
#include "ir/module-utils.h"
#include "ir/names.h"
#include "ir/utils.h"
#include "opt-utils.h"
#include "pass.h"
#include "support/hash.h"
#include "support/utilities.h"
#include "wasm.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <ostream>
#include <variant>
#include <vector>
namespace wasm {
// The per-function values of one extra parameter: either numeric constants
// (one Literal per function in the class) or direct-call targets (one function
// name per function in the class).
using ConstDiff = std::variant<Literals, std::vector<Name>>;

// Describes one extra parameter of the shared function produced by merging an
// equivalence class.
struct ParamInfo {
  // The argument value for each function in the class, in the same order as
  // EquivalentClass::functions (index 0 belongs to the primary function).
  ConstDiff values;
  // Slots in the primary function's body whose expressions are replaced by a
  // read of this parameter in the shared function.
  std::vector<Expression**> uses;

  // Both arguments are sink parameters taken by value; move them into place
  // rather than copying them a second time.
  ParamInfo(ConstDiff values, std::vector<Expression**> uses)
    : values(std::move(values)), uses(std::move(uses)) {}

  // Returns the wasm type of this parameter: the literal type for constants,
  // or a non-nullable reference to the callee's function type for call
  // targets. Only element 0 is inspected.
  Type getValueType(Module* module) const {
    if (const auto literals = std::get_if<Literals>(&values)) {
      return (*literals)[0].type;
    } else if (auto callees = std::get_if<std::vector<Name>>(&values)) {
      auto* callee = module->getFunction((*callees)[0]);
      return Type(callee->type, NonNullable);
    } else {
      WASM_UNREACHABLE("unexpected const value type");
    }
  }

  // Builds the argument expression that the function at position `index` of
  // the class passes for this parameter: a `const` for literals, or a
  // `ref.func` for call targets.
  Expression*
  lowerToExpression(Builder& builder, Module* module, size_t index) const {
    if (const auto literals = std::get_if<Literals>(&values)) {
      return builder.makeConst((*literals)[index]);
    } else if (auto callees = std::get_if<std::vector<Name>>(&values)) {
      auto fnName = (*callees)[index];
      auto heapType = module->getFunction(fnName)->type;
      return builder.makeRefFunc(fnName, heapType);
    } else {
      WASM_UNREACHABLE("unexpected const value type");
    }
  }
};
struct EquivalentClass {
Function* primaryFunction;
std::vector<Function*> functions;
EquivalentClass(Function* primaryFunction, std::vector<Function*> functions)
: primaryFunction(primaryFunction), functions(functions) {}
bool isEligibleToMerge() { return this->functions.size() >= 2; }
void merge(Module* module, const std::vector<ParamInfo>& params);
bool hasMergeBenefit(Module* module, const std::vector<ParamInfo>& params);
Function* createShared(Module* module, const std::vector<ParamInfo>& params);
Function* replaceWithThunk(Builder& builder,
Function* target,
Function* shared,
const std::vector<ParamInfo>& params,
const std::vector<Expression*>& extraArgs);
bool deriveParams(Module* module,
std::vector<ParamInfo>& params,
bool isIndirectionEnabled);
};
// Merges functions whose bodies differ only in constants (and, when call
// indirection is enabled, in direct-call targets) into one shared function,
// turning each original function into a thunk that forwards to it.
struct MergeSimilarFunctions : public Pass {
  bool invalidatesDWARF() override { return true; }

  void run(Module* module) override {
    std::vector<EquivalentClass> classes;
    collectEquivalentClasses(classes, module);

    // Process classes ordered by the name of their primary function.
    auto byPrimaryName = [](const auto& left, const auto& right) {
      return left.primaryFunction->name < right.primaryFunction->name;
    };
    std::sort(classes.begin(), classes.end(), byPrimaryName);

    const bool callIndirection = isCallIndirectionEnabled(module);
    for (auto& clazz : classes) {
      std::vector<ParamInfo> params;
      // Each stage runs only if every earlier stage succeeded.
      const bool worthMerging =
        clazz.isEligibleToMerge() &&
        clazz.deriveParams(module, params, callIndirection) &&
        clazz.hasMergeBenefit(module, params);
      if (worthMerging) {
        clazz.merge(module, params);
      }
    }
  }

  // Differing direct calls are parameterized as call_ref on ref.func values;
  // this is only enabled when both reference types and GC are available.
  bool isCallIndirectionEnabled(Module* module) const {
    return module->features.hasReferenceTypes() && module->features.hasGC();
  }

  bool areInEquvalentClass(Function* lhs, Function* rhs, Module* module);
  void collectEquivalentClasses(std::vector<EquivalentClass>& classes,
                                Module* module);
};
bool MergeSimilarFunctions::areInEquvalentClass(Function* lhs,
                                                Function* rhs,
                                                Module* module) {
  // Only defined functions with the same signature and the same number of
  // vars can share a single body.
  if (lhs->imported() || rhs->imported()) {
    return false;
  }
  if (lhs->type != rhs->type || lhs->getNumVars() != rhs->getNumVars()) {
    return false;
  }

  // Comparison hook for flexibleEqual. It returns true to accept a pair that
  // differs only in a constant value or (with call indirection) a call
  // target. Every other pair returns false, presumably leaving the verdict
  // to flexibleEqual's default structural comparison.
  ExpressionAnalyzer::ExprComparer comparer = [&](Expression* lhsExpr,
                                                  Expression* rhsExpr) {
    const bool sameKind =
      lhsExpr->_id == rhsExpr->_id && lhsExpr->type == rhsExpr->type;
    if (!sameKind) {
      return false;
    }

    if (auto* lhsConst = lhsExpr->dynCast<Const>()) {
      // Values may differ (they become parameters); literal types may not.
      auto* rhsConst = rhsExpr->dynCast<Const>();
      return lhsConst->value.type == rhsConst->value.type;
    }

    auto* lhsCall = lhsExpr->dynCast<Call>();
    if (!lhsCall || !isCallIndirectionEnabled(module)) {
      return false;
    }
    auto* rhsCall = rhsExpr->dynCast<Call>();
    if (lhsCall->operands.size() != rhsCall->operands.size()) {
      return false;
    }
    // Targets may differ, but their function types must match so a single
    // call_ref can stand in for both.
    Function* lhsTarget = module->getFunction(lhsCall->target);
    Function* rhsTarget = module->getFunction(rhsCall->target);
    if (lhsTarget->type != rhsTarget->type) {
      return false;
    }
    // The hook short-circuits the traversal, so operands are compared here.
    for (Index opIdx = 0; opIdx < lhsCall->operands.size(); opIdx++) {
      if (!ExpressionAnalyzer::flexibleEqual(
            lhsCall->operands[opIdx], rhsCall->operands[opIdx], comparer)) {
        return false;
      }
    }
    return true;
  };

  return ExpressionAnalyzer::flexibleEqual(lhs->body, rhs->body, comparer);
}
void MergeSimilarFunctions::collectEquivalentClasses(
  std::vector<EquivalentClass>& classes, Module* module) {
  auto hashes = FunctionHasher::createMap(module);
  PassRunner runner(module);
  // Custom hashing hook. Constants add nothing to the digest. Direct calls
  // add only their operands and tail-call flag, not the target. Functions
  // that differ only in those values therefore share a bucket. (Returning
  // true presumably tells the hasher the node was handled here.)
  std::function<bool(Expression*, size_t&)> ignoringConsts =
    [&](Expression* expr, size_t& digest) {
      if (expr->is<Const>()) {
        return true;
      }
      if (auto* call = expr->dynCast<Call>()) {
        for (auto operand : call->operands) {
          rehash(digest,
                 ExpressionAnalyzer::flexibleHash(operand, ignoringConsts));
        }
        rehash(digest, call->isReturn);
        return true;
      }
      return false;
    };
  FunctionHasher(&hashes, ignoringConsts).run(&runner, module);

  // Bucket defined functions by hash, preserving module order within buckets.
  std::map<size_t, std::vector<Function*>> hashGroups;
  ModuleUtils::iterDefinedFunctions(
    *module, [&](Function* func) { hashGroups[hashes[func]].push_back(func); });

  for (auto& [_, hashGroup] : hashGroups) {
    // A lone function has nothing to merge with.
    if (hashGroup.size() < 2) {
      continue;
    }
    // Equal hashes do not imply equivalence. Split the bucket so each
    // function joins the first class whose primary it is equivalent to, or
    // else starts a new class.
    std::vector<EquivalentClass> classesInGroup = {
      EquivalentClass(hashGroup[0], {hashGroup[0]})};
    for (Index i = 1; i < hashGroup.size(); i++) {
      auto* func = hashGroup[i];
      bool found = false;
      for (auto& newClass : classesInGroup) {
        if (areInEquvalentClass(newClass.primaryFunction, func, module)) {
          newClass.functions.push_back(func);
          found = true;
          break;
        }
      }
      if (!found) {
        classesInGroup.push_back(EquivalentClass(func, {func}));
      }
    }
    // classesInGroup is discarded right after this, so move each class (and
    // its function list) out instead of copying it.
    for (auto& clazz : classesInGroup) {
      classes.push_back(std::move(clazz));
    }
  }
}
bool EquivalentClass::deriveParams(Module* module,
                                   std::vector<ParamInfo>& params,
                                   bool isCallIndirectionEnabled) {
  // Embedders following the WebAssembly JS API limits reject functions with
  // more than 1000 parameters. Never derive a shared signature beyond that.
  constexpr size_t MaxSharedParams = 1000;

  // Depth-first walk over expression slots (Expression**), so callers can
  // record where a value lives and later substitute it. Each step pops the
  // current slot and pushes the slots of its children.
  struct DeepValueIterator {
    SmallVector<Expression**, 10> tasks;
    DeepValueIterator(Expression** root) { tasks.push_back(root); }
    void operator++() {
      ChildIterator it(*tasks.back());
      tasks.pop_back();
      for (Expression*& child : it) {
        tasks.push_back(&child);
      }
    }
    Expression*& operator*() {
      assert(!empty());
      return *tasks.back();
    }
    bool empty() { return tasks.empty(); }
  };
  if (primaryFunction->imported()) {
    return false;
  }
  // Walk the primary body and every sibling body in lockstep; members of a
  // class are expected to have the same shape.
  DeepValueIterator primaryIt(&primaryFunction->body);
  std::vector<DeepValueIterator> siblingIterators;
  assert(functions.size() >= 2);
  // functions[0] is the primary itself, so siblings start at index 1.
  for (auto func = functions.begin() + 1; func != functions.end(); ++func) {
    siblingIterators.emplace_back(&(*func)->body);
  }
  for (; !primaryIt.empty(); ++primaryIt) {
    Expression*& primary = *primaryIt;
    ConstDiff diff;
    Literals values;
    std::vector<Name> names;
    bool isAllSame = true;
    if (auto* primaryConst = primary->dynCast<Const>()) {
      values.push_back(primaryConst->value);
      for (auto& it : siblingIterators) {
        Expression*& sibling = *it;
        ++it;
        if (auto* siblingConst = sibling->dynCast<Const>()) {
          isAllSame &= primaryConst->value == siblingConst->value;
          values.push_back(siblingConst->value);
        } else {
          WASM_UNREACHABLE(
            "all sibling functions should have the same instruction type");
        }
      }
      diff = values;
    } else if (isCallIndirectionEnabled && primary->is<Call>()) {
      auto* primaryCall = primary->dynCast<Call>();
      names.push_back(primaryCall->target);
      for (auto& it : siblingIterators) {
        Expression*& sibling = *it;
        ++it;
        if (auto* siblingCall = sibling->dynCast<Call>()) {
          isAllSame &= primaryCall->target == siblingCall->target;
          names.push_back(siblingCall->target);
        } else {
          WASM_UNREACHABLE(
            "all sibling functions should have the same instruction type");
        }
      }
      diff = names;
    } else {
      // Anything else must be identical across the class; just advance the
      // siblings in step with the primary.
      for (auto& it : siblingIterators) {
        assert((*it)->_id == primary->_id);
        ++it;
      }
      continue;
    }
    // Values shared by every member stay inline; no parameter needed.
    if (isAllSame) {
      continue;
    }
    // Reuse an existing parameter when it carries exactly the same values.
    bool paramReused = false;
    for (auto& param : params) {
      if (param.values == diff) {
        param.uses.push_back(&primary);
        paramReused = true;
        break;
      }
    }
    if (!paramReused) {
      params.push_back(ParamInfo(diff, {&primary}));
      // Give up on this class rather than emit an over-wide signature.
      if (primaryFunction->getNumParams() + params.size() > MaxSharedParams) {
        return false;
      }
    }
  }
  return true;
}
// Adds the shared function for this class, then rewrites every member
// (including the primary) into a thunk that forwards its own argument values
// for each extra parameter.
void EquivalentClass::merge(Module* module,
                            const std::vector<ParamInfo>& params) {
  Function* sharedFn = createShared(module, params);
  Builder builder(*module);
  for (size_t funcIdx = 0; funcIdx < functions.size(); ++funcIdx) {
    std::vector<Expression*> extraArgs;
    extraArgs.reserve(params.size());
    for (const auto& param : params) {
      extraArgs.push_back(param.lowerToExpression(builder, module, funcIdx));
    }
    replaceWithThunk(builder, functions[funcIdx], sharedFn, params, extraArgs);
  }
}
// Heuristic size model. Merging is worthwhile when the instructions saved by
// keeping a single copy of the body exceed the cost the model charges for the
// thunks.
bool EquivalentClass::hasMergeBenefit(Module* module,
                                      const std::vector<ParamInfo>& params) {
  constexpr size_t INSTR_WEIGHT = 1;
  constexpr size_t CODE_SEC_LOCALS_WEIGHT = 1;
  constexpr size_t CODE_SEC_ENTRY_WEIGHT = 2;
  constexpr size_t FUNC_SEC_ENTRY_WEIGHT = 2;

  // Every member becomes a thunk.
  const size_t thunkCount = functions.size();
  const Index bodySize = Measurer::measure(primaryFunction->body);

  // All but one copy of the body disappear.
  const size_t savedInstrs = (thunkCount - 1) * bodySize;

  // A thunk is one call plus one operand per original param and per extra
  // param.
  const size_t instrsPerThunk =
    1 + primaryFunction->getParams().size() + params.size();
  const size_t totalThunkInstrs = thunkCount * instrsPerThunk;

  // Additional per-thunk section overhead charged by the model.
  const size_t overheadPerThunk = params.size() * CODE_SEC_LOCALS_WEIGHT +
                                  CODE_SEC_ENTRY_WEIGHT +
                                  FUNC_SEC_ENTRY_WEIGHT;

  const size_t cost =
    totalThunkInstrs * INSTR_WEIGHT + thunkCount * overheadPerThunk;
  const size_t gain = savedInstrs * INSTR_WEIGHT;
  return cost < gain;
}
// Builds the shared function for this class and adds it to the module.
// The shared function's locals are laid out as:
//   [primary's params][one param per ParamInfo][primary's vars]
// Its body is a copy of the primary body in which each parameterized constant
// becomes a local.get of its param, each parameterized direct call becomes a
// call_ref through its param, and var indices are shifted past the extra
// params.
Function* EquivalentClass::createShared(Module* module,
                                        const std::vector<ParamInfo>& params) {
  Name fnName = Names::getValidFunctionName(*module,
                                            std::string("byn$mgfn-shared$") +
                                              primaryFunction->name.toString());
  Builder builder(*module);
  std::vector<Type> sigParams;
  Index extraParamBase = primaryFunction->getNumParams();
  Index newVarBase = primaryFunction->getNumParams() + params.size();
  for (const auto& param : primaryFunction->getParams()) {
    sigParams.push_back(param);
  }
  for (const auto& param : params) {
    sigParams.push_back(param.getValueType(module));
  }
  Signature sig(Type(sigParams), primaryFunction->getResults());

  // Index every parameterized expression of the primary body by the param
  // that replaces it. Building this once gives the copier one lookup per node
  // instead of a scan over every use of every param for every node.
  // emplace() keeps the first (lowest) index, matching a front-to-back scan.
  std::map<Expression*, Index> paramIndexOfUse;
  for (Index paramIdx = 0; paramIdx < params.size(); paramIdx++) {
    for (Expression** use : params[paramIdx].uses) {
      paramIndexOfUse.emplace(*use, paramIdx);
    }
  }

  ExpressionManipulator::CustomCopier copier =
    [&](Expression* expr) -> Expression* {
    if (!expr) {
      return nullptr;
    }
    if (auto found = paramIndexOfUse.find(expr);
        found != paramIndexOfUse.end()) {
      Index paramIdx = found->second;
      auto* paramExpr = builder.makeLocalGet(
        extraParamBase + paramIdx, params[paramIdx].getValueType(module));
      if (expr->is<Const>()) {
        // A differing constant becomes a read of its parameter.
        return paramExpr;
      }
      // Otherwise this is a differing direct call. Call through the function
      // reference param instead, and copy the operands with this same copier
      // so nested parameterized values are substituted too.
      auto* call = expr->cast<Call>();
      ExpressionList operands(module->allocator);
      for (auto* operand : call->operands) {
        operands.push_back(
          ExpressionManipulator::flexibleCopy(operand, *module, copier));
      }
      auto returnType = module->getFunction(call->target)->getResults();
      return builder.makeCallRef(
        paramExpr, operands, returnType, call->isReturn);
    }
    // NOTE: local.get/local.set nodes on vars are re-indexed in place and
    // reused rather than copied. This mutates the primary body, and relies on
    // merge() replacing that body with a thunk afterwards.
    if (auto* localGet = expr->dynCast<LocalGet>()) {
      if (primaryFunction->isVar(localGet->index)) {
        localGet->index =
          newVarBase + (localGet->index - primaryFunction->getNumParams());
        localGet->finalize();
        return localGet;
      }
    }
    if (auto* localSet = expr->dynCast<LocalSet>()) {
      if (primaryFunction->isVar(localSet->index)) {
        auto operand =
          ExpressionManipulator::flexibleCopy(localSet->value, *module, copier);
        localSet->index =
          newVarBase + (localSet->index - primaryFunction->getNumParams());
        localSet->value = operand;
        localSet->finalize();
        return localSet;
      }
    }
    return nullptr;
  };
  Expression* body =
    ExpressionManipulator::flexibleCopy(primaryFunction->body, *module, copier);
  auto vars = primaryFunction->vars;
  std::unique_ptr<Function> f =
    builder.makeFunction(fnName, sig, std::move(vars), body);
  return module->addFunction(std::move(f));
}
// Rewrites `target` into a thunk. Its body becomes a single call to `shared`
// that forwards each of target's own params (via local.get) followed by
// `extraArgs`. Its vars are dropped because the new body only reads params.
// `params` is unused here. Returns `target`.
Function*
EquivalentClass::replaceWithThunk(Builder& builder,
                                  Function* target,
                                  Function* shared,
                                  const std::vector<ParamInfo>& params,
                                  const std::vector<Expression*>& extraArgs) {
  Type ownParams = target->getParams();
  std::vector<Expression*> forwarded;
  forwarded.reserve(ownParams.size() + extraArgs.size());
  for (Index paramIdx = 0; paramIdx < ownParams.size(); ++paramIdx) {
    forwarded.push_back(builder.makeLocalGet(paramIdx, ownParams[paramIdx]));
  }
  forwarded.insert(forwarded.end(), extraArgs.begin(), extraArgs.end());

  auto* forwardingCall =
    builder.makeCall(shared->name, forwarded, target->getResults());
  target->vars.clear();
  target->body = forwardingCall;
  return target;
}
// Factory for this pass; returns a newly heap-allocated MergeSimilarFunctions.
Pass* createMergeSimilarFunctionsPass() {
  Pass* pass = new MergeSimilarFunctions();
  return pass;
}
}