#include "optimizer/top_k_optimizer.h"
#include "planner/operator/logical_limit.h"
#include "planner/operator/logical_order_by.h"
using namespace lbug::planner;
using namespace lbug::common;
namespace lbug {
namespace optimizer {
// Optimizer entry point: runs visitOperator over the plan's root (last) operator
// and installs whatever operator comes back as the plan's new root.
void TopKOptimizer::rewrite(planner::LogicalPlan* plan) {
    auto root = plan->getLastOperator();
    auto newRoot = visitOperator(root);
    plan->setLastOperator(newRoot);
}
// Recursively rewrites the subtree rooted at `op`.
// Each child is rewritten first and re-attached in its original slot. `op` itself is
// then passed to visitOperatorReplaceSwitch, which may return a different operator.
// The flat schema of whichever operator is returned is recomputed before returning it.
std::shared_ptr<LogicalOperator> TopKOptimizer::visitOperator(
    const std::shared_ptr<LogicalOperator>& op) {
    for (auto childIdx = 0u; childIdx < op->getNumChildren(); childIdx++) {
        auto rewrittenChild = visitOperator(op->getChild(childIdx));
        op->setChild(childIdx, rewrittenChild);
    }
    auto replaced = visitOperatorReplaceSwitch(op);
    replaced->computeFlatSchema();
    return replaced;
}
// Folds a LIMIT (and its optional SKIP) into the ORDER BY beneath it, so the sort
// receives the limit/skip counts.
//
// Matched plan shape:
//   LIMIT -> MULTIPLICITY_REDUCER -> PROJECTION -> ORDER_BY
//   LIMIT -> MULTIPLICITY_REDUCER -> ORDER_BY
// On a match, the LIMIT and MULTIPLICITY_REDUCER operators are removed from the plan.
// The operator that sat directly under the reducer (the PROJECTION or the ORDER_BY)
// is returned in their place.
// For any other shape, `op` is returned unchanged.
std::shared_ptr<LogicalOperator> TopKOptimizer::visitLimitReplace(
    std::shared_ptr<LogicalOperator> op) {
    auto limit = op->ptrCast<LogicalLimit>();
    // Without a limit count there is nothing to push into the ORDER BY.
    if (!limit->hasLimitNum()) {
        return op;
    }
    auto multiplicityReducer = limit->getChild(0);
    DASSERT(multiplicityReducer->getOperatorType() == LogicalOperatorType::MULTIPLICITY_REDUCER);
    // NOTE(review): DASSERT is presumably compiled out in release builds. Without this
    // runtime check, an unexpected child of LIMIT would be silently dropped from the
    // plan by the rewrite below. Bail out conservatively instead.
    if (multiplicityReducer->getOperatorType() != LogicalOperatorType::MULTIPLICITY_REDUCER) {
        return op;
    }
    auto projectionOrOrderBy = multiplicityReducer->getChild(0);
    std::shared_ptr<LogicalOrderBy> orderBy;
    if (projectionOrOrderBy->getOperatorType() == LogicalOperatorType::PROJECTION) {
        // Only a PROJECTION sitting directly on top of an ORDER BY is eligible.
        if (projectionOrOrderBy->getChild(0)->getOperatorType() != LogicalOperatorType::ORDER_BY) {
            return op;
        }
        orderBy = std::static_pointer_cast<LogicalOrderBy>(projectionOrOrderBy->getChild(0));
    } else if (projectionOrOrderBy->getOperatorType() == LogicalOperatorType::ORDER_BY) {
        orderBy = std::static_pointer_cast<LogicalOrderBy>(projectionOrOrderBy);
    } else {
        return op;
    }
    DASSERT(orderBy != nullptr);
    // hasLimitNum() was verified on entry, so the limit count is always present here.
    orderBy->setLimitNum(limit->getLimitNum());
    if (limit->hasSkipNum()) {
        orderBy->setSkipNum(limit->getSkipNum());
    }
    return projectionOrOrderBy;
}
} }