#include <libsolidity/analysis/ControlFlowRevertPruner.h>
#include <libsolutil/Algorithms.h>
#include <range/v3/algorithm/remove.hpp>
namespace solidity::frontend
{
namespace
{
// Picks the contract in whose context `_function` is analysed when it is reached from
// `_callingContract`:
//  - `_callingContract`, if it is non-null and derives from the contract annotated on `_function`;
//  - otherwise the contract annotated on `_function`.
// Returns nullptr when `_function` has no contract annotation (presumably a free function — verify).
ContractDefinition const* findScopeContract(FunctionDefinition const& _function, ContractDefinition const* _callingContract)
{
	auto const* declaringContract = _function.annotation().contract;
	if (!declaringContract)
		return nullptr;

	bool const callerInherits = _callingContract && _callingContract->derivesFrom(*declaringContract);
	return callerInherits ? _callingContract : declaringContract;
}
}
// Entry point of the pass.
// 1. Register every function/contract pair that has a flow in the CFG with an
//    undetermined (Unknown) revert state.
// 2. Compute the actual revert states.
// 3. Rewire the call sites of functions that never return normally.
void ControlFlowRevertPruner::run()
{
	for (auto const& [functionTuple, functionFlow]: m_cfg.allFunctionFlows())
		m_functions[functionTuple] = RevertState::Unknown;

	findRevertStates();
	modifyFunctionFlows();
}
void ControlFlowRevertPruner::findRevertStates()
{
std::set<CFG::FunctionContractTuple> pendingFunctions = util::keys(m_functions);
std::map<CFG::FunctionContractTuple, std::set<CFG::FunctionContractTuple>> wakeUp;
while (!pendingFunctions.empty())
{
CFG::FunctionContractTuple item = *pendingFunctions.begin();
pendingFunctions.erase(pendingFunctions.begin());
if (m_functions[item] != RevertState::Unknown)
continue;
bool foundExit = false;
bool foundUnknown = false;
FunctionFlow const& functionFlow = m_cfg.functionFlow(*item.function, item.contract);
solidity::util::BreadthFirstSearch<CFGNode*>{{functionFlow.entry}}.run(
[&](CFGNode* _node, auto&& _addChild) {
if (_node == functionFlow.exit)
foundExit = true;
if (auto const* functionCall = _node->functionCall)
{
auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.contract);
if (resolvedFunction && resolvedFunction->isImplemented())
{
CFG::FunctionContractTuple calledFunctionTuple{
findScopeContract(*resolvedFunction, item.contract),
resolvedFunction
};
switch (m_functions.at(calledFunctionTuple))
{
case RevertState::Unknown:
wakeUp[calledFunctionTuple].insert(item);
foundUnknown = true;
return;
case RevertState::AllPathsRevert:
return;
case RevertState::HasNonRevertingPath:
break;
}
}
}
for (CFGNode* exit: _node->exits)
_addChild(exit);
}
);
auto& revertState = m_functions[item];
if (foundExit)
revertState = RevertState::HasNonRevertingPath;
else if (!foundUnknown)
revertState = RevertState::AllPathsRevert;
if (revertState != RevertState::Unknown && wakeUp.count(item))
{
for (CFG::FunctionContractTuple const& nextItem: wakeUp[item])
if (m_functions.at(nextItem) == RevertState::Unknown)
pendingFunctions.insert(nextItem);
wakeUp.erase(item);
}
}
}
// Rewires every function flow so that a call to an implemented function whose revert
// state is AllPathsRevert (or still Unknown, e.g. unresolved recursion) leads directly to
// the caller's revert node.
// - The call node is unlinked from its former successors.
// - Its only exit becomes functionFlow.revert.
// - Traversal does not continue past such a call node.
void ControlFlowRevertPruner::modifyFunctionFlows()
{
	for (auto const& item: m_functions)
	{
		FunctionFlow const& functionFlow = m_cfg.functionFlow(*item.first.function, item.first.contract);
		solidity::util::BreadthFirstSearch<CFGNode*>{{functionFlow.entry}}.run(
			[&](CFGNode* _node, auto&& _addChild) {
				if (auto const* functionCall = _node->functionCall)
				{
					auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.first.contract);
					if (resolvedFunction && resolvedFunction->isImplemented())
						switch (m_functions.at({findScopeContract(*resolvedFunction, item.first.contract), resolvedFunction}))
						{
						case RevertState::Unknown:
							[[fallthrough]];
						case RevertState::AllPathsRevert:
							// Remove the back-edge from each former successor. ranges::remove only
							// moves the kept elements forward and returns the new logical end; the
							// container keeps its size. The leftover tail must be erased, otherwise
							// _node lingers in (or gets duplicated within) the successor's entries.
							for (CFGNode* node: _node->exits)
								node->entries.erase(ranges::remove(node->entries, _node), node->entries.end());
							_node->exits = {functionFlow.revert};
							functionFlow.revert->entries.push_back(_node);
							return;
						default:
							break;
						}
				}
				for (CFGNode* exit: _node->exits)
					_addChild(exit);
			}
		);
	}
}
}