#include "optimizer/count_rel_table_optimizer.h"
#include "binder/expression/aggregate_function_expression.h"
#include "binder/expression/node_expression.h"
#include "catalog/catalog_entry/node_table_id_pair.h"
#include "function/aggregate/count_star.h"
#include "main/client_context.h"
#include "planner/operator/extend/logical_extend.h"
#include "planner/operator/logical_aggregate.h"
#include "planner/operator/logical_projection.h"
#include "planner/operator/scan/logical_count_rel_table.h"
#include "planner/operator/scan/logical_scan_node_table.h"
using namespace lbug::common;
using namespace lbug::planner;
using namespace lbug::binder;
using namespace lbug::catalog;
namespace lbug {
namespace optimizer {
/**
 * Optimizer entry point: rewrites the whole plan bottom-up through visitOperator, which hands
 * every operator to visitOperatorReplaceSwitch (expected to route AGGREGATE operators to
 * visitAggregateReplace).
 *
 * The operator returned for the root is written back into the plan. Without this, a
 * replacement of the root operator itself (e.g. a plan whose last operator is the COUNT(*)
 * aggregate) would be silently discarded and the plan would keep the stale operator.
 *
 * @param plan Plan to rewrite in place.
 */
void CountRelTableOptimizer::rewrite(LogicalPlan* plan) {
    plan->setLastOperator(visitOperator(plan->getLastOperator()));
}
/**
 * Rewrites the subtree rooted at `op` in post-order. Each child subtree is rewritten and
 * re-attached first; then `op` itself is offered to visitOperatorReplaceSwitch.
 *
 * @param op Root of the subtree to rewrite.
 * @return The operator that takes `op`'s place (either `op` or a replacement), with its flat
 *         schema recomputed.
 */
std::shared_ptr<LogicalOperator> CountRelTableOptimizer::visitOperator(
    const std::shared_ptr<LogicalOperator>& op) {
    // Children first, so that the parent is inspected with its final inputs attached.
    const auto numChildren = op->getNumChildren();
    for (auto childIdx = 0u; childIdx < numChildren; childIdx++) {
        auto rewrittenChild = visitOperator(op->getChild(childIdx));
        op->setChild(childIdx, std::move(rewrittenChild));
    }
    auto replacement = visitOperatorReplaceSwitch(op);
    // Inputs may have changed (or `op` may have been swapped out), so refresh the schema.
    replacement->computeFlatSchema();
    return replacement;
}
/**
 * Tells whether `op` is the only aggregate shape this optimizer can replace: an AGGREGATE with
 * no grouping keys and exactly one aggregate expression, which must be a non-distinct call
 * to CountStarFunction.
 *
 * @param op Operator to inspect; it does not have to be an aggregate.
 * @return true iff `op` is such a key-less, single, non-distinct COUNT(*) aggregate.
 */
bool CountRelTableOptimizer::isSimpleCountStar(LogicalOperator* op) const {
    const bool isAggregate = op->getOperatorType() == LogicalOperatorType::AGGREGATE;
    if (!isAggregate) {
        return false;
    }
    const auto& aggregateOp = op->constCast<LogicalAggregate>();
    // Grouped counts need one value per group; the replacement emits a single count.
    if (aggregateOp.hasKeys()) {
        return false;
    }
    const auto aggExprs = aggregateOp.getAggregates();
    if (aggExprs.size() != 1) {
        return false;
    }
    const auto& onlyExpr = aggExprs.front();
    if (onlyExpr->expressionType != ExpressionType::AGGREGATE_FUNCTION) {
        return false;
    }
    const auto& funcExpr = onlyExpr->constCast<AggregateFunctionExpression>();
    const bool isCountStar = funcExpr.getFunction().name == function::CountStarFunction::name;
    return isCountStar && !funcExpr.isDistinct();
}
/**
 * Checks that the input of a simple COUNT(*) aggregate has exactly this shape, which is the
 * only one visitAggregateReplace knows how to rewrite:
 *
 *     AGGREGATE <- PROJECTION* <- EXTEND <- SCAN_NODE_TABLE
 *
 * Further conditions on the operators:
 *   - every projected expression is an aggregate function;
 *   - the extend is single-direction (not BOTH), over a single-labeled rel, with no rel
 *     properties;
 *   - the node scan reads no properties.
 *
 * @param aggregate The aggregate operator (normally already accepted by isSimpleCountStar).
 * @return true iff the pattern above matches.
 */
bool CountRelTableOptimizer::canOptimize(LogicalOperator* aggregate) const {
    // Walk down through projections, rejecting any that project non-aggregate expressions.
    auto* cursor = aggregate->getChild(0).get();
    for (; cursor->getOperatorType() == LogicalOperatorType::PROJECTION;
         cursor = cursor->getChild(0).get()) {
        for (const auto& projected :
            cursor->constCast<LogicalProjection>().getExpressionsToProject()) {
            if (projected->expressionType != ExpressionType::AGGREGATE_FUNCTION) {
                return false;
            }
        }
    }
    if (cursor->getOperatorType() != LogicalOperatorType::EXTEND) {
        return false;
    }
    const auto& extend = cursor->constCast<LogicalExtend>();
    // The replacement is built from one rel group entry and one direction, and exposes no rel
    // properties, so anything else must keep the original plan.
    const bool extendIsEligible = extend.getDirection() != ExtendDirection::BOTH &&
                                  !extend.getRel()->isMultiLabeled() &&
                                  extend.getProperties().empty();
    if (!extendIsEligible) {
        return false;
    }
    auto* scanCandidate = cursor->getChild(0).get();
    if (scanCandidate->getOperatorType() != LogicalOperatorType::SCAN_NODE_TABLE) {
        return false;
    }
    // NOTE(review): only the scanned properties are inspected here. If the scan can carry other
    // restrictions at this point in the pipeline (e.g. a primary-key lookup or pushed-down
    // predicates), confirm they are impossible or make them block the rewrite as well.
    return scanCandidate->constCast<LogicalScanNodeTable>().getProperties().empty();
}
std::shared_ptr<LogicalOperator> CountRelTableOptimizer::visitAggregateReplace(
std::shared_ptr<LogicalOperator> op) {
if (!isSimpleCountStar(op.get())) {
return op;
}
if (!canOptimize(op.get())) {
return op;
}
auto* current = op->getChild(0).get();
while (current->getOperatorType() == LogicalOperatorType::PROJECTION) {
current = current->getChild(0).get();
}
DASSERT(current->getOperatorType() == LogicalOperatorType::EXTEND);
auto& extend = current->constCast<LogicalExtend>();
auto rel = extend.getRel();
auto boundNode = extend.getBoundNode();
auto nbrNode = extend.getNbrNode();
DASSERT(rel->getNumEntries() == 1);
auto* relGroupEntry = rel->getEntry(0)->ptrCast<RelGroupCatalogEntry>();
auto boundNodeTableIDs = boundNode->getTableIDsSet();
auto nbrNodeTableIDs = nbrNode->getTableIDsSet();
std::vector<table_id_t> relTableIDs;
for (auto& info : relGroupEntry->getRelEntryInfos()) {
table_id_t srcTableID = info.nodePair.srcTableID;
table_id_t dstTableID = info.nodePair.dstTableID;
bool matches = false;
if (extend.extendFromSourceNode()) {
matches =
boundNodeTableIDs.contains(srcTableID) && nbrNodeTableIDs.contains(dstTableID);
} else {
matches =
boundNodeTableIDs.contains(dstTableID) && nbrNodeTableIDs.contains(srcTableID);
}
if (matches) {
relTableIDs.push_back(info.oid);
}
}
if (relTableIDs.empty()) {
return op;
}
auto& aggregate = op->constCast<LogicalAggregate>();
auto countExpr = aggregate.getAggregates()[0];
std::vector<table_id_t> boundNodeTableIDsVec(boundNodeTableIDs.begin(),
boundNodeTableIDs.end());
auto countRelTable =
std::make_shared<LogicalCountRelTable>(relGroupEntry, std::move(relTableIDs),
std::move(boundNodeTableIDsVec), boundNode, extend.getDirection(), countExpr);
countRelTable->computeFlatSchema();
return countRelTable;
}
} }