#include "optimizer/foreign_join_push_down_optimizer.h"
#include <algorithm>
#include "binder/expression/property_expression.h"
#include "binder/expression/variable_expression.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "common/exception/runtime.h"
#include "main/database_manager.h"
#include "planner/operator/extend/logical_extend.h"
#include "planner/operator/logical_flatten.h"
#include "planner/operator/logical_hash_join.h"
#include "planner/operator/logical_table_function_call.h"
#include "planner/operator/scan/logical_scan_node_table.h"
#include <format>
using namespace lbug::binder;
using namespace lbug::common;
using namespace lbug::planner;
using namespace lbug::catalog;
namespace lbug {
namespace optimizer {
// Applies the foreign-join push-down rewrite to the whole plan, starting from the plan's
// last (root) operator. Replacements below the root are installed through setChild inside
// visitOperator; the operator returned for the root itself is not written back to the plan.
void ForeignJoinPushDownOptimizer::rewrite(LogicalPlan* plan) {
    auto root = plan->getLastOperator();
    visitOperator(root);
}
// Post-order traversal: every child subtree is rewritten and re-attached to `op` before `op`
// itself is offered to visitOperatorReplaceSwitch. The flat schema of whichever operator comes
// back is recomputed, since its children may have been swapped out.
std::shared_ptr<LogicalOperator> ForeignJoinPushDownOptimizer::visitOperator(
    const std::shared_ptr<LogicalOperator>& op) {
    for (auto childIdx = 0u; childIdx < op->getNumChildren(); childIdx++) {
        auto rewrittenChild = visitOperator(op->getChild(childIdx));
        op->setChild(childIdx, std::move(rewrittenChild));
    }
    auto replacement = visitOperatorReplaceSwitch(op);
    replacement->computeFlatSchema();
    return replacement;
}
// True when `op` is a TABLE_FUNCTION_CALL whose table function reports
// supportsPushDownFunc(); false for every other operator type.
static bool isForeignTableFunctionCall(const LogicalOperator* op) {
    return op->getOperatorType() == LogicalOperatorType::TABLE_FUNCTION_CALL &&
           op->constCast<LogicalTableFunctionCall>().getTableFunc().supportsPushDownFunc();
}
// True when `rel` has exactly one catalog entry and that entry, viewed as a
// RelGroupCatalogEntry, has a scan function set. Multi-entry rels are always rejected.
static bool hasForeignScanFunction(const RelExpression* rel) {
    if (rel->getNumEntries() == 1) {
        auto relGroupEntry = rel->getEntry(0)->ptrCast<RelGroupCatalogEntry>();
        return relGroupEntry != nullptr && relGroupEntry->getScanFunction().has_value();
    }
    return false;
}
// Describes the attached foreign database backing a single-table node pattern as
// "<db>(<dbType>)". The database name comes from the node table entry's stored foreign
// database name, or from node->getDbName() for a foreign table entry.
// Returns "" when `node` is null, has != 1 entries, the name is empty, or no database with
// that name is attached to `context`.
static std::string getNodeForeignDatabaseName(const NodeExpression* node,
    main::ClientContext* context) {
    if (!node || node->getNumEntries() != 1) {
        return "";
    }
    auto tableEntry = node->getEntry(0);
    if (!tableEntry) {
        return "";
    }
    std::string foreignDbName;
    switch (tableEntry->getType()) {
    case CatalogEntryType::NODE_TABLE_ENTRY: {
        auto nodeTableEntry = tableEntry->ptrCast<NodeTableCatalogEntry>();
        if (!nodeTableEntry) {
            return "";
        }
        foreignDbName = nodeTableEntry->getForeignDatabaseName();
        break;
    }
    case CatalogEntryType::FOREIGN_TABLE_ENTRY:
        foreignDbName = node->getDbName(tableEntry);
        break;
    default:
        // Any other entry type carries no foreign database name.
        break;
    }
    if (foreignDbName.empty()) {
        return "";
    }
    if (auto attachedDB =
            main::DatabaseManager::Get(*context)->getAttachedDatabase(foreignDbName)) {
        return std::format("{}({})", foreignDbName, attachedDB->getDBType());
    }
    return "";
}
// Describes the attached foreign database backing a single-table rel pattern as
// "<db>(<dbType>)", the same format getNodeForeignDatabaseName produces, so matchPattern can
// compare the two for equality.
//
// The database name is the rel group entry's stored foreign database name, or, when that is
// empty, the part of its storage string before the first '.'. A stored name that already
// contains '(' is assumed to carry the "(<dbType>)" suffix and is returned as-is. A bare name
// is resolved against `context`'s attached databases and formatted. Previously a bare stored
// name was returned verbatim and could never equal the node-side descriptor.
// Returns "" when `rel` is null, has != 1 entries, or no attached database can be resolved.
static std::string getRelForeignDatabaseName(const RelExpression* rel,
    main::ClientContext* context) {
    if (!rel || rel->getNumEntries() != 1) {
        return "";
    }
    auto entry = rel->getEntry(0);
    if (!entry) {
        return "";
    }
    auto relEntry = entry->ptrCast<RelGroupCatalogEntry>();
    if (!relEntry) {
        return "";
    }
    std::string dbName = relEntry->getForeignDatabaseName();
    if (dbName.empty()) {
        // No stored name: derive the database from a "<db>.<table>" storage string.
        auto storage = relEntry->getStorage();
        auto dotPos = storage.find('.');
        if (dotPos == std::string::npos) {
            return "";
        }
        dbName = storage.substr(0, dotPos);
    } else if (dbName.find('(') != std::string::npos) {
        // Stored name already has the "(<dbType>)" suffix; keep the previous behaviour.
        return dbName;
    }
    // Normalize a bare database name into "<db>(<dbType>)".
    auto dbManager = main::DatabaseManager::Get(*context);
    auto attachedDB = dbManager->getAttachedDatabase(dbName);
    if (!attachedDB) {
        return "";
    }
    return std::format("{}({})", dbName, attachedDB->getDBType());
}
// Data matchPattern extracts from a matched plan fragment. buildJoinQuery and
// createJoinTableFunctionCall use it to synthesize the replacement foreign scan. The operator
// pointers are non-owning and point into the plan being rewritten.
struct ForeignJoinPatternInfo {
    // Operators of the matched fragment.
    const LogicalExtend* extend{nullptr};
    const LogicalTableFunctionCall* srcTableFunc{nullptr};
    const LogicalTableFunctionCall* dstTableFunc{nullptr};
    const LogicalHashJoin* outerHashJoin{nullptr};
    const LogicalHashJoin* innerHashJoin{nullptr};
    // Schema of the fragment's root operator; its in-scope expressions drive the output.
    const Schema* outputSchema{nullptr};
    // Table names as they are written into the generated SQL.
    std::string srcTable;
    std::string dstTable;
    std::string relTable;
    // Attached foreign database name, without the "(<dbType>)" suffix.
    std::string dbName;
};
// Recognises the following operator shape rooted at `op`. Child 0 is used as the probe side
// and child 1 as the build side of each hash join.
//
//   HASH_JOIN [INNER]                                            <- outerHashJoin
//     child 0: [optional FLATTEN] ->
//              HASH_JOIN [INNER]                                 <- innerHashJoin
//                child 0: EXTEND -> SCAN_NODE_TABLE              <- extend
//                child 1: push-down TABLE_FUNCTION_CALL          <- srcTableFunc
//     child 1: push-down TABLE_FUNCTION_CALL                     <- dstTableFunc
//
// It also requires:
//  - the extend's rel has exactly one catalog entry with a scan function;
//  - the bound node, neighbour node and rel resolve (through `context`'s attached databases)
//    to the same non-empty "<db>(<dbType>)" descriptor.
//
// Table names for the SQL emitted by buildJoinQuery are then derived from the two scans'
// bind-data descriptions and the rel entry's storage string.
// Returns std::nullopt as soon as any check fails.
static std::optional<ForeignJoinPatternInfo> matchPattern(const LogicalOperator* op,
    main::ClientContext* context) {
    if (op == nullptr) {
        return std::nullopt;
    }
    ForeignJoinPatternInfo info;
    // Captured before the type checks; only meaningful if the whole pattern matches.
    info.outputSchema = op->getSchema();
    // Outer hash join: its build side (child 1) must be the destination foreign scan.
    if (op->getOperatorType() != LogicalOperatorType::HASH_JOIN) {
        return std::nullopt;
    }
    if (op->getNumChildren() < 2) {
        return std::nullopt;
    }
    info.outerHashJoin = op->constPtrCast<LogicalHashJoin>();
    if (info.outerHashJoin->getJoinType() != JoinType::INNER) {
        return std::nullopt;
    }
    auto buildChild = op->getChild(1).get();
    if (buildChild == nullptr || !isForeignTableFunctionCall(buildChild)) {
        return std::nullopt;
    }
    info.dstTableFunc = buildChild->constPtrCast<LogicalTableFunctionCall>();
    auto probeOp = op->getChild(0).get();
    if (probeOp == nullptr) {
        return std::nullopt;
    }
    // A single FLATTEN may sit between the two hash joins; look through it.
    if (probeOp->getOperatorType() == LogicalOperatorType::FLATTEN) {
        if (probeOp->getNumChildren() < 1) {
            return std::nullopt;
        }
        probeOp = probeOp->getChild(0).get();
        if (probeOp == nullptr) {
            return std::nullopt;
        }
    }
    // Inner hash join: its build side (child 1) must be the source foreign scan.
    if (probeOp->getOperatorType() != LogicalOperatorType::HASH_JOIN) {
        return std::nullopt;
    }
    if (probeOp->getNumChildren() < 2) {
        return std::nullopt;
    }
    info.innerHashJoin = probeOp->constPtrCast<LogicalHashJoin>();
    if (info.innerHashJoin->getJoinType() != JoinType::INNER) {
        return std::nullopt;
    }
    auto innerBuildChild = probeOp->getChild(1).get();
    if (innerBuildChild == nullptr || !isForeignTableFunctionCall(innerBuildChild)) {
        return std::nullopt;
    }
    info.srcTableFunc = innerBuildChild->constPtrCast<LogicalTableFunctionCall>();
    // Its probe side must be an EXTEND reading directly from a node table scan.
    auto extendOp = probeOp->getChild(0).get();
    if (extendOp == nullptr || extendOp->getOperatorType() != LogicalOperatorType::EXTEND) {
        return std::nullopt;
    }
    info.extend = extendOp->constPtrCast<LogicalExtend>();
    if (extendOp->getNumChildren() < 1 || extendOp->getChild(0) == nullptr ||
        extendOp->getChild(0)->getOperatorType() != LogicalOperatorType::SCAN_NODE_TABLE) {
        return std::nullopt;
    }
    // The rel must be a single catalog entry that has a scan function.
    if (!hasForeignScanFunction(info.extend->getRel().get())) {
        return std::nullopt;
    }
    // Source node, destination node and rel must all resolve to the same attached database.
    auto srcDbName = getNodeForeignDatabaseName(info.extend->getBoundNode().get(), context);
    auto dstDbName = getNodeForeignDatabaseName(info.extend->getNbrNode().get(), context);
    auto relDbName = getRelForeignDatabaseName(info.extend->getRel().get(), context);
    if (srcDbName.empty() || dstDbName.empty() || relDbName.empty()) {
        return std::nullopt;
    }
    if (srcDbName != dstDbName || srcDbName != relDbName) {
        return std::nullopt;
    }
    // Strip the "(<dbType>)" suffix to obtain the plain attached-database name.
    auto parenPos = srcDbName.find('(');
    if (parenPos != std::string::npos) {
        info.dbName = srcDbName.substr(0, parenPos);
    } else {
        info.dbName = srcDbName;
    }
    // Extracts a table name from a scan description: the text after the first "FROM "
    // (case-sensitive), cut at the next space, with everything up to and including its first
    // '.' removed. Returns "" when "FROM " does not occur.
    // NOTE(review): quoted identifiers containing spaces or dots would be split incorrectly.
    auto extractTableName = [](const std::string& desc) -> std::string {
        auto fromPos = desc.find("FROM ");
        if (fromPos == std::string::npos) {
            return "";
        }
        auto tableName = desc.substr(fromPos + 5);
        auto spacePos = tableName.find(' ');
        if (spacePos != std::string::npos) {
            tableName = tableName.substr(0, spacePos);
        }
        auto dotPos = tableName.find('.');
        if (dotPos != std::string::npos) {
            tableName = tableName.substr(dotPos + 1);
        }
        return tableName;
    };
    auto srcDesc = info.srcTableFunc->getBindData()->getDescription();
    auto dstDesc = info.dstTableFunc->getBindData()->getDescription();
    info.srcTable = extractTableName(srcDesc);
    info.dstTable = extractTableName(dstDesc);
    if (info.srcTable.empty() || info.dstTable.empty()) {
        return std::nullopt;
    }
    // Derive the rel table name from the rel entry's storage string. A single entry is
    // guaranteed by hasForeignScanFunction above.
    auto rel = info.extend->getRel();
    auto relEntry = rel->getEntry(0)->ptrCast<RelGroupCatalogEntry>();
    std::string relStorage = relEntry->getStorage();
    auto dotPos = relStorage.find('.');
    if (dotPos != std::string::npos) {
        // Qualified storage: take the part after the first '.', prefixed with srcTable's
        // qualifier (everything up to and including its last '.') when srcTable has one.
        auto srcDotPos = info.srcTable.find('.');
        if (srcDotPos != std::string::npos) {
            auto dbSchema = info.srcTable.substr(0, info.srcTable.rfind('.') + 1);
            info.relTable = dbSchema + relStorage.substr(dotPos + 1);
        } else {
            info.relTable = relStorage.substr(dotPos + 1);
        }
    } else {
        info.relTable = relStorage;
    }
    // NOTE(review): this removes the first '.'-segment again. For a srcTable with one level
    // of qualification ("schema.table") the prefix added above is dropped entirely, so
    // relTable ends up unqualified while srcTable/dstTable stay qualified in the generated
    // SQL. Confirm this is intended (relTable is also passed to getTableColumnNames).
    auto relDotPos = info.relTable.find('.');
    if (relDotPos != std::string::npos) {
        info.relTable = info.relTable.substr(relDotPos + 1);
    }
    if (info.relTable.empty()) {
        return std::nullopt;
    }
    return info;
}
// Asks the database named `dbName` (among those attached to `context`) for the column names
// of `tableName`. Returns an empty vector when no such database is attached.
static std::vector<std::string> getForeignTableColumnNames(const std::string& dbName,
    const std::string& tableName, main::ClientContext* context) {
    if (auto attachedDB = main::DatabaseManager::Get(*context)->getAttachedDatabase(dbName)) {
        return attachedDB->getTableColumnNames(tableName);
    }
    return {};
}
// Builds one SQL statement that performs the src -> rel -> dst join inside the foreign
// database.
//
// The rel table's first two columns, as reported by the attached database, are the join
// keys. For a FWD extend the first column joins the source node and the second joins the
// destination node; for any other direction the roles are swapped. Both node tables are
// joined on a column named `id`.
// NOTE(review): relies on the rel table storing its endpoint keys as its first two columns.
//
// Returns the query text and the SQL alias of every selected column, in the same order as
// `outputColumns`. Each alias is the expression's unique name with '.' replaced by '_'.
//
// Throws RuntimeException when fewer than two column names can be retrieved for the rel table.
static std::pair<std::string, std::vector<std::string>> buildJoinQuery(
    const ForeignJoinPatternInfo& info, const expression_vector& outputColumns,
    main::ClientContext* context) {
    auto extend = info.extend;
    auto srcNode = extend->getBoundNode();
    auto dstNode = extend->getNbrNode();
    auto rel = extend->getRel();
    std::string srcAlias = srcNode->getVariableName();
    std::string dstAlias = dstNode->getVariableName();
    std::string relAlias = rel->getVariableName();
    auto tableColumnNames = getForeignTableColumnNames(info.dbName, info.relTable, context);
    if (tableColumnNames.size() < 2) {
        throw RuntimeException(std::format(
            "Foreign join push down optimizer: unable to retrieve column names for table '{}.{}', "
            "got {} columns but need at least 2 for join",
            info.dbName, info.relTable, tableColumnNames.size()));
    }
    const std::string& firstCol = tableColumnNames[0];
    const std::string& secondCol = tableColumnNames[1];
    const bool isForward = extend->getDirection() == ExtendDirection::FWD;
    const std::string& srcJoinCol = isForward ? firstCol : secondCol;
    const std::string& dstJoinCol = isForward ? secondCol : firstCol;

    std::string selectClause = "SELECT ";
    std::vector<std::string> columnNames;
    columnNames.reserve(outputColumns.size());
    bool first = true;
    for (auto& col : outputColumns) {
        if (!first) {
            selectClause += ", ";
        }
        first = false;
        std::string colExpr;
        std::string colName = col->getUniqueName();
        if (col->expressionType == ExpressionType::PROPERTY) {
            // Property of a pattern variable: reference it through the variable's SQL alias.
            auto& prop = col->constCast<PropertyExpression>();
            auto rawVarName = prop.getRawVariableName();
            auto propName = prop.getPropertyName();
            if (propName == InternalKeyword::ID) {
                colExpr = std::format("{}.id", rawVarName);
            } else {
                colExpr = std::format("{}.{}", rawVarName, propName);
            }
        } else {
            auto dotPos = colName.find('.');
            if (dotPos != std::string::npos) {
                auto prefix = colName.substr(0, dotPos);
                auto colNamePart = colName.substr(dotPos + 1);
                // Generated unique names have the form "_<id>_<var>"; drop the "_<id>_" part.
                // Only prefixes starting with '_' are treated this way. A plain user variable
                // such as "my_var" would otherwise be truncated to "var", which is not a valid
                // alias in the generated SQL.
                auto underscorePos = (!prefix.empty() && prefix[0] == '_') ?
                                         prefix.find('_', 1) :
                                         std::string::npos;
                auto rawVar = underscorePos != std::string::npos ?
                                  prefix.substr(underscorePos + 1) :
                                  prefix;
                colExpr = std::format("{}.{}", rawVar, colNamePart);
            } else {
                colExpr = colName;
            }
        }
        // '.' is not valid inside an unquoted SQL alias.
        std::replace(colName.begin(), colName.end(), '.', '_');
        selectClause += std::format("{} AS {}", colExpr, colName);
        columnNames.push_back(std::move(colName));
    }
    std::string query = std::format("{} FROM {} {} "
                                    "JOIN {} {} ON {}.id = {}.{} "
                                    "JOIN {} {} ON {}.{} = {}.id",
        selectClause, info.srcTable, srcAlias, info.relTable, relAlias, srcAlias, relAlias,
        srcJoinCol, info.dstTable, dstAlias, relAlias, dstJoinCol, dstAlias);
    return {std::move(query), std::move(columnNames)};
}
static std::shared_ptr<LogicalOperator> createJoinTableFunctionCall(
const ForeignJoinPatternInfo& info, const std::string& joinQuery,
const std::vector<std::string>& columnNames, const expression_vector& outputColumns) {
auto tableFunc = info.srcTableFunc->getTableFunc();
expression_vector resultColumns;
for (size_t i = 0; i < outputColumns.size(); i++) {
auto& col = outputColumns[i];
auto dataType = col->getDataType().copy();
std::string uniqueName = col->getUniqueName();
auto alias = columnNames[i];
resultColumns.push_back(
std::make_shared<VariableExpression>(std::move(dataType), uniqueName, alias));
}
auto originalBindData = info.srcTableFunc->getBindData();
auto newBindData = originalBindData->copyWithQuery(joinQuery, resultColumns, columnNames);
if (!newBindData) {
return nullptr;
}
newBindData->setColumnPredicates({});
auto tableFuncCall =
std::make_shared<LogicalTableFunctionCall>(std::move(tableFunc), std::move(newBindData));
tableFuncCall->computeFlatSchema();
return tableFuncCall;
}
std::shared_ptr<LogicalOperator> ForeignJoinPushDownOptimizer::visitHashJoinReplace(
std::shared_ptr<LogicalOperator> op) {
auto patternInfo = matchPattern(op.get(), this->context);
if (!patternInfo.has_value()) {
return op;
}
auto& info = patternInfo.value();
auto allColumns = info.outputSchema->getExpressionsInScope();
expression_vector outputColumns;
std::unordered_set<std::string> canonicalVarProps;
auto extractCanonicalVarProp = [](const std::string& uniqueName) -> std::string {
if (uniqueName.empty() || uniqueName[0] != '_') {
return "";
}
auto dotPos = uniqueName.find('.');
if (dotPos == std::string::npos) {
return "";
}
auto prefix = uniqueName.substr(0, dotPos); auto secondUnderscore = prefix.find('_', 1);
if (secondUnderscore == std::string::npos || secondUnderscore + 1 >= prefix.size()) {
return "";
}
auto rawVar = prefix.substr(secondUnderscore + 1);
auto prop = uniqueName.substr(dotPos + 1);
if (rawVar.empty() || prop.empty()) {
return "";
}
return rawVar + "." + prop;
};
for (auto& col : allColumns) {
auto canonical = extractCanonicalVarProp(col->getUniqueName());
if (!canonical.empty()) {
canonicalVarProps.insert(canonical);
}
}
auto isCanonicalOrStandalone = [&](const std::string& uniqueName) {
if (uniqueName.empty() || uniqueName[0] == '_') {
return true;
}
return !canonicalVarProps.contains(uniqueName);
};
auto hasLowercaseID = [&](const std::string& uniqueName) {
static constexpr auto internalIDSuffix = "._ID";
static constexpr auto suffixLen = std::char_traits<char>::length(internalIDSuffix);
if (uniqueName.size() <= suffixLen) {
return false;
}
if (uniqueName.rfind(internalIDSuffix) != uniqueName.size() - suffixLen) {
return false;
}
auto lowercaseID = uniqueName.substr(0, uniqueName.size() - suffixLen);
lowercaseID += ".id";
for (auto& expr : allColumns) {
if (expr->getUniqueName() == lowercaseID) {
return true;
}
}
return false;
};
for (auto& col : allColumns) {
auto uniqueName = col->getUniqueName();
if (!isCanonicalOrStandalone(uniqueName)) {
continue;
}
if (hasLowercaseID(uniqueName)) {
continue;
}
outputColumns.push_back(col);
}
if (outputColumns.empty()) {
for (auto& col : allColumns) {
outputColumns.push_back(col);
}
}
auto [joinQuery, columnNames] = buildJoinQuery(info, outputColumns, this->context);
auto result = createJoinTableFunctionCall(info, joinQuery, columnNames, outputColumns);
if (!result) {
return op;
}
return result;
}
} }