#include "optimizer/foreign_join_push_down_optimizer.h"
#include <algorithm>
#include <cctype>
#include "binder/expression/property_expression.h"
#include "binder/expression/variable_expression.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "common/exception/runtime.h"
#include "main/database_manager.h"
#include "planner/operator/extend/logical_extend.h"
#include "planner/operator/logical_filter.h"
#include "planner/operator/logical_flatten.h"
#include "planner/operator/logical_hash_join.h"
#include "planner/operator/logical_table_function_call.h"
#include "planner/operator/scan/logical_scan_node_table.h"
#include <format>
using namespace lbug::binder;
using namespace lbug::common;
using namespace lbug::planner;
using namespace lbug::catalog;
namespace lbug {
namespace optimizer {
// Entry point: rewrites `plan` bottom-up, collapsing matched foreign join fragments into single
// foreign table function calls.
void ForeignJoinPushDownOptimizer::rewrite(LogicalPlan* plan) {
    // visitOperator may return a different operator than the one it was given (see
    // visitHashJoinReplace); store the result back so a replacement of the root itself is kept
    // instead of being silently dropped.
    auto root = plan->getLastOperator();
    auto rewrittenRoot = visitOperator(root);
    plan->setLastOperator(std::move(rewrittenRoot));
}
// Post-order traversal: every child subtree is rewritten (and re-attached) before `op` itself is
// offered to visitOperatorReplaceSwitch. The returned operator, which may differ from `op`, has
// its flat schema recomputed before being handed back to the parent.
std::shared_ptr<LogicalOperator> ForeignJoinPushDownOptimizer::visitOperator(
    const std::shared_ptr<LogicalOperator>& op) {
    const auto numChildren = op->getNumChildren();
    for (auto childIdx = 0u; childIdx < numChildren; ++childIdx) {
        auto rewrittenChild = visitOperator(op->getChild(childIdx));
        op->setChild(childIdx, std::move(rewrittenChild));
    }
    auto replaced = visitOperatorReplaceSwitch(op);
    replaced->computeFlatSchema();
    return replaced;
}
// True when `op` is a TABLE_FUNCTION_CALL whose table function reports supportsPushDownFunc(),
// i.e. a scan that can accept work pushed down into it.
static bool isForeignTableFunctionCall(const LogicalOperator* op) {
    const bool isTableFuncCall =
        op->getOperatorType() == LogicalOperatorType::TABLE_FUNCTION_CALL;
    return isTableFuncCall &&
           op->constCast<LogicalTableFunctionCall>().getTableFunc().supportsPushDownFunc();
}
// True when `rel` is bound to exactly one rel group catalog entry and that entry carries a scan
// function. Returns false for a null rel, a multi-table rel or a missing catalog entry (the
// original dereferenced getEntry(0) without checking it).
static bool hasForeignScanFunction(const RelExpression* rel) {
    if (rel == nullptr || rel->getNumEntries() != 1) {
        return false;
    }
    auto entry = rel->getEntry(0);
    if (!entry) {
        return false;
    }
    auto relEntry = entry->ptrCast<RelGroupCatalogEntry>();
    return relEntry && relEntry->getScanFunction().has_value();
}
// Resolves the attached database behind a single-table node pattern and returns it formatted as
// "<dbName>(<dbType>)". The database name comes from the node table entry's foreign database
// name, or from node->getDbName() for a foreign table entry. Returns "" when the node is null,
// spans several tables, has no catalog entry, yields no database name, or names a database that
// is not attached.
static std::string getNodeForeignDatabaseName(const NodeExpression* node,
    main::ClientContext* context) {
    if (node == nullptr || node->getNumEntries() != 1) {
        return "";
    }
    const auto entry = node->getEntry(0);
    if (!entry) {
        return "";
    }
    std::string dbName;
    const auto entryType = entry->getType();
    if (entryType == CatalogEntryType::FOREIGN_TABLE_ENTRY) {
        dbName = node->getDbName(entry);
    } else if (entryType == CatalogEntryType::NODE_TABLE_ENTRY) {
        const auto nodeEntry = entry->ptrCast<NodeTableCatalogEntry>();
        if (!nodeEntry) {
            return "";
        }
        dbName = nodeEntry->getForeignDatabaseName();
    }
    if (dbName.empty()) {
        return "";
    }
    const auto attachedDB = main::DatabaseManager::Get(*context)->getAttachedDatabase(dbName);
    if (!attachedDB) {
        return "";
    }
    return std::format("{}({})", dbName, attachedDB->getDBType());
}
// Resolves the attached database behind a single-table rel pattern and returns it formatted as
// "<dbName>(<dbType>)", the same shape getNodeForeignDatabaseName produces. matchPattern
// compares the two strings for equality, so the shapes must agree.
//
// The database name is the rel entry's stored foreign database name if present, otherwise the
// part of its storage string before the first '.'. A stored name that already contains '(' is
// assumed to be pre-formatted and is returned as-is. A raw stored name is formatted like every
// other path; previously it was returned raw and could never equal a node's formatted name.
// Returns "" when the rel is null, spans several tables, has no entry, yields no database
// name, or names a database that is not attached.
static std::string getRelForeignDatabaseName(const RelExpression* rel,
    main::ClientContext* context) {
    if (!rel || rel->getNumEntries() != 1) {
        return "";
    }
    auto entry = rel->getEntry(0);
    if (!entry) {
        return "";
    }
    auto relEntry = entry->ptrCast<RelGroupCatalogEntry>();
    if (!relEntry) {
        return "";
    }
    std::string dbName;
    auto storedName = relEntry->getForeignDatabaseName();
    if (!storedName.empty()) {
        if (storedName.find('(') != std::string::npos) {
            // Already "<dbName>(<dbType>)".
            return storedName;
        }
        dbName = std::move(storedName);
    } else {
        auto storage = relEntry->getStorage();
        auto dotPos = storage.find('.');
        if (dotPos == std::string::npos) {
            return "";
        }
        dbName = storage.substr(0, dotPos);
    }
    auto dbManager = main::DatabaseManager::Get(*context);
    auto attachedDB = dbManager->getAttachedDatabase(dbName);
    if (!attachedDB) {
        return "";
    }
    return std::format("{}({})", dbName, attachedDB->getDBType());
}
// Everything matchPattern() learns about a pushable src-node / rel / dst-node plan fragment.
struct ForeignJoinPatternInfo {
    // Operators of the matched fragment (raw, non-owning pointers into the plan).
    const LogicalExtend* extend{nullptr};
    const LogicalTableFunctionCall* srcTableFunc{nullptr}; // build side of the inner join
    const LogicalTableFunctionCall* dstTableFunc{nullptr}; // build side of the outer join
    const LogicalHashJoin* outerHashJoin{nullptr};
    const LogicalHashJoin* innerHashJoin{nullptr};
    const LogicalFilter* relFilter{nullptr}; // optional FILTER directly above the extend
    const Schema* outputSchema{nullptr};     // schema of the outer hash join
    // Foreign table names as used in the generated SQL.
    std::string srcTable;
    std::string dstTable;
    std::string relTable;
    // Attached database name, without the "(<dbType>)" suffix.
    std::string dbName;
};
// Drops any qualifier (everything up to and including the last '.') from a table name, then
// removes one pair of surrounding double quotes if present.
// Example: `db.schema."Person"` -> `Person`.
static std::string getUnqualifiedTableName(std::string tableName) {
    const auto lastDot = tableName.rfind('.');
    std::string unqualified =
        lastDot == std::string::npos ? std::move(tableName) : tableName.substr(lastDot + 1);
    const bool isQuoted =
        unqualified.size() >= 2 && unqualified.front() == '"' && unqualified.back() == '"';
    if (!isQuoted) {
        return unqualified;
    }
    return unqualified.substr(1, unqualified.size() - 2);
}
static std::optional<ForeignJoinPatternInfo> matchPattern(const LogicalOperator* op,
main::ClientContext* context) {
if (op == nullptr) {
return std::nullopt;
}
ForeignJoinPatternInfo info;
info.outputSchema = op->getSchema();
if (op->getOperatorType() != LogicalOperatorType::HASH_JOIN) {
return std::nullopt;
}
if (op->getNumChildren() < 2) {
return std::nullopt;
}
info.outerHashJoin = op->constPtrCast<LogicalHashJoin>();
if (info.outerHashJoin->getJoinType() != JoinType::INNER) {
return std::nullopt;
}
auto buildChild = op->getChild(1).get();
if (buildChild == nullptr || !isForeignTableFunctionCall(buildChild)) {
return std::nullopt;
}
info.dstTableFunc = buildChild->constPtrCast<LogicalTableFunctionCall>();
auto probeOp = op->getChild(0).get();
if (probeOp == nullptr) {
return std::nullopt;
}
if (probeOp->getOperatorType() == LogicalOperatorType::FLATTEN) {
if (probeOp->getNumChildren() < 1) {
return std::nullopt;
}
probeOp = probeOp->getChild(0).get();
if (probeOp == nullptr) {
return std::nullopt;
}
}
if (probeOp->getOperatorType() != LogicalOperatorType::HASH_JOIN) {
return std::nullopt;
}
if (probeOp->getNumChildren() < 2) {
return std::nullopt;
}
info.innerHashJoin = probeOp->constPtrCast<LogicalHashJoin>();
if (info.innerHashJoin->getJoinType() != JoinType::INNER) {
return std::nullopt;
}
auto innerBuildChild = probeOp->getChild(1).get();
if (innerBuildChild == nullptr || !isForeignTableFunctionCall(innerBuildChild)) {
return std::nullopt;
}
info.srcTableFunc = innerBuildChild->constPtrCast<LogicalTableFunctionCall>();
auto extendOp = probeOp->getChild(0).get();
if (extendOp != nullptr && extendOp->getOperatorType() == LogicalOperatorType::FILTER) {
info.relFilter = extendOp->constPtrCast<LogicalFilter>();
if (extendOp->getNumChildren() < 1) {
return std::nullopt;
}
extendOp = extendOp->getChild(0).get();
}
if (extendOp == nullptr || extendOp->getOperatorType() != LogicalOperatorType::EXTEND) {
return std::nullopt;
}
info.extend = extendOp->constPtrCast<LogicalExtend>();
if (extendOp->getNumChildren() < 1 || extendOp->getChild(0) == nullptr ||
extendOp->getChild(0)->getOperatorType() != LogicalOperatorType::SCAN_NODE_TABLE) {
return std::nullopt;
}
if (!hasForeignScanFunction(info.extend->getRel().get())) {
return std::nullopt;
}
auto srcDbName = getNodeForeignDatabaseName(info.extend->getBoundNode().get(), context);
auto dstDbName = getNodeForeignDatabaseName(info.extend->getNbrNode().get(), context);
auto relDbName = getRelForeignDatabaseName(info.extend->getRel().get(), context);
if (srcDbName.empty() || dstDbName.empty() || relDbName.empty()) {
return std::nullopt;
}
if (srcDbName != dstDbName || srcDbName != relDbName) {
return std::nullopt;
}
auto parenPos = srcDbName.find('(');
if (parenPos != std::string::npos) {
info.dbName = srcDbName.substr(0, parenPos);
} else {
info.dbName = srcDbName;
}
auto extractTableName = [](const std::string& desc) -> std::string {
auto fromPos = desc.find("FROM ");
if (fromPos == std::string::npos) {
return "";
}
auto tableName = desc.substr(fromPos + 5);
auto spacePos = tableName.find(' ');
if (spacePos != std::string::npos) {
tableName = tableName.substr(0, spacePos);
}
return tableName;
};
auto srcDesc = info.srcTableFunc->getBindData()->getDescription();
auto dstDesc = info.dstTableFunc->getBindData()->getDescription();
info.srcTable = extractTableName(srcDesc);
info.dstTable = extractTableName(dstDesc);
if (info.srcTable.empty() || info.dstTable.empty()) {
return std::nullopt;
}
auto rel = info.extend->getRel();
auto relEntry = rel->getEntry(0)->ptrCast<RelGroupCatalogEntry>();
std::string relStorage = relEntry->getStorage();
auto dotPos = relStorage.find('.');
if (dotPos != std::string::npos) {
auto srcDotPos = info.srcTable.find('.');
if (srcDotPos != std::string::npos) {
auto dbSchema = info.srcTable.substr(0, info.srcTable.rfind('.') + 1);
info.relTable = dbSchema + relStorage.substr(dotPos + 1);
} else {
info.relTable = relStorage.substr(dotPos + 1);
}
} else {
info.relTable = relStorage;
}
if (info.relTable.empty()) {
return std::nullopt;
}
return info;
}
// Asks the attached database `dbName` for the column names of `tableName`. Returns an empty
// vector when no database with that name is attached.
static std::vector<std::string> getForeignTableColumnNames(const std::string& dbName,
    const std::string& tableName, main::ClientContext* context) {
    if (auto attached = main::DatabaseManager::Get(*context)->getAttachedDatabase(dbName);
        attached) {
        return attached->getTableColumnNames(tableName);
    }
    return {};
}
// Result of buildJoinQuery: the SQL text plus, per output column, the "<expr> AS <alias>"
// projection entry and its display label (index-aligned with each other).
struct JoinQueryInfo {
    std::string query{};
    std::vector<std::string> columnNames{};
    std::vector<std::string> displayNames{};
};
// Turns an arbitrary string into a plain SQL identifier: every character other than an ASCII
// letter, digit or '_' becomes '_'. A "col_" prefix is added when the result is empty or starts
// with a digit.
static std::string sanitizeSQLAlias(std::string alias) {
    auto isIdentifierChar = [](char ch) {
        return std::isalnum(static_cast<unsigned char>(ch)) != 0 || ch == '_';
    };
    std::replace_if(
        alias.begin(), alias.end(), [&](char ch) { return !isIdentifierChar(ch); }, '_');
    const bool needsPrefix =
        alias.empty() || std::isdigit(static_cast<unsigned char>(alias.front())) != 0;
    if (needsPrefix) {
        alias.insert(0, "col_");
    }
    return alias;
}
// Builds the SQL that replaces the src-node / rel / dst-node fragment:
//
//   SELECT {} FROM <srcTable> <src> JOIN <relTable> <rel> ON <src>.<srcID> = <rel>.<srcJoinCol>
//   JOIN <dstTable> <dst> ON <rel>.<dstJoinCol> = <dst>.<dstID>
//
// The literal "{}" after SELECT is left as a placeholder for the projection list. It is
// presumably filled from `columnNames` by the bind data's copyWithQuery (not visible here).
// `columnNames` holds one "<expr> AS <alias>" entry per output column; `displayNames` holds the
// matching "<var>.<prop>"-style labels.
//
// Assumptions (not verified here): the rel table's first two columns are its endpoint keys in
// (from, to) order. Each node table's first column is its key; the internal ID keyword is used
// when the node table's columns cannot be listed.
// Throws RuntimeException when fewer than two rel-table columns can be retrieved.
static JoinQueryInfo buildJoinQuery(const ForeignJoinPatternInfo& info,
    const expression_vector& outputColumns, main::ClientContext* context) {
    auto extend = info.extend;
    auto srcNode = extend->getBoundNode();
    auto dstNode = extend->getNbrNode();
    auto rel = extend->getRel();
    // Pattern variable names double as the SQL table aliases.
    std::string srcAlias = srcNode->getVariableName();
    std::string dstAlias = dstNode->getVariableName();
    std::string relAlias = rel->getVariableName();
    auto tableColumnNames =
        getForeignTableColumnNames(info.dbName, getUnqualifiedTableName(info.relTable), context);
    if (tableColumnNames.size() < 2) {
        throw RuntimeException(std::format(
            "Foreign join push down optimizer: unable to retrieve column names for table '{}.{}', "
            "got {} columns but need at least 2 for join",
            info.dbName, info.relTable, tableColumnNames.size()));
    }
    // A non-forward extend swaps which rel endpoint column joins to the bound node.
    const bool isForward = extend->getDirection() == ExtendDirection::FWD;
    std::string srcJoinCol = isForward ? tableColumnNames[0] : tableColumnNames[1];
    std::string dstJoinCol = isForward ? tableColumnNames[1] : tableColumnNames[0];
    // Key column of a node table: its first column, or the internal ID keyword as a fallback.
    auto getNodeIDColumn = [&](const std::string& tableName) {
        auto columnNames =
            getForeignTableColumnNames(info.dbName, getUnqualifiedTableName(tableName), context);
        if (columnNames.empty()) {
            return std::string{InternalKeyword::ID};
        }
        return columnNames[0];
    };
    auto srcIDCol = getNodeIDColumn(info.srcTable);
    auto dstIDCol = getNodeIDColumn(info.dstTable);
    std::vector<std::string> columnNames;
    std::vector<std::string> displayNames;
    columnNames.reserve(outputColumns.size());
    displayNames.reserve(outputColumns.size());
    for (auto& col : outputColumns) {
        std::string colExpr;
        std::string colName;
        std::string displayName;
        if (col->expressionType == ExpressionType::PROPERTY) {
            auto& prop = col->constCast<PropertyExpression>();
            auto rawVarName = prop.getRawVariableName();
            auto propName = prop.getPropertyName();
            if (propName == InternalKeyword::ID) {
                // The internal ID maps to the node table's key column.
                auto idColumn = rawVarName == srcAlias ? srcIDCol : dstIDCol;
                colExpr = std::format("{}.{}", rawVarName, idColumn);
            } else {
                colExpr = std::format("{}.{}", rawVarName, propName);
            }
            colName = sanitizeSQLAlias(std::format("{}_{}", rawVarName, propName));
            displayName = std::format("{}.{}", rawVarName, propName);
        } else {
            auto uniqueName = col->getUniqueName();
            auto dotPos = uniqueName.find('.');
            if (dotPos != std::string::npos) {
                auto prefix = uniqueName.substr(0, dotPos);
                auto colNamePart = uniqueName.substr(dotPos + 1);
                // Only a prefix starting with '_' (e.g. "_0_a") carries the raw variable after
                // its second underscore. Without the leading-'_' check, a user variable such as
                // "my_var" would be mangled into the alias "var".
                auto underscorePos = (!prefix.empty() && prefix[0] == '_') ?
                                         prefix.find('_', 1) :
                                         std::string::npos;
                if (underscorePos != std::string::npos) {
                    auto rawVar = prefix.substr(underscorePos + 1);
                    colExpr = std::format("{}.{}", rawVar, colNamePart);
                    colName = sanitizeSQLAlias(std::format("{}_{}", rawVar, colNamePart));
                    displayName = std::format("{}.{}", rawVar, colNamePart);
                } else {
                    colExpr = std::format("{}.{}", prefix, colNamePart);
                    colName = sanitizeSQLAlias(uniqueName);
                    displayName = uniqueName;
                }
            } else {
                colExpr = uniqueName;
                colName = sanitizeSQLAlias(uniqueName);
                displayName = uniqueName;
            }
        }
        columnNames.push_back(std::format("{} AS {}", colExpr, colName));
        displayNames.push_back(std::move(displayName));
    }
    std::string query = std::format("SELECT {{}} FROM {} {} "
                                    "JOIN {} {} ON {}.{} = {}.{} "
                                    "JOIN {} {} ON {}.{} = {}.{}",
        info.srcTable, srcAlias, info.relTable, relAlias, srcAlias, srcIDCol, relAlias, srcJoinCol,
        info.dstTable, dstAlias, relAlias, dstJoinCol, dstAlias, dstIDCol);
    return {std::move(query), std::move(columnNames), std::move(displayNames)};
}
static std::shared_ptr<LogicalOperator> createJoinTableFunctionCall(
const ForeignJoinPatternInfo& info, const std::string& joinQuery,
const std::vector<std::string>& columnNames, const std::vector<std::string>& displayNames,
const expression_vector& outputColumns) {
auto tableFunc = info.srcTableFunc->getTableFunc();
expression_vector resultColumns;
for (size_t i = 0; i < outputColumns.size(); i++) {
auto& col = outputColumns[i];
auto dataType = col->getDataType().copy();
std::string uniqueName = col->getUniqueName();
auto alias = displayNames[i];
resultColumns.push_back(
std::make_shared<VariableExpression>(std::move(dataType), uniqueName, alias));
}
auto originalBindData = info.srcTableFunc->getBindData();
auto newBindData = originalBindData->copyWithQuery(joinQuery, resultColumns, columnNames);
if (!newBindData) {
return nullptr;
}
newBindData->setColumnPredicates({});
auto tableFuncCall =
std::make_shared<LogicalTableFunctionCall>(std::move(tableFunc), std::move(newBindData));
tableFuncCall->computeFlatSchema();
return tableFuncCall;
}
// Replaces a matched src-node / rel / dst-node fragment (see matchPattern) with one foreign
// table function call that runs the whole join as a single SQL query. Returns `op` unchanged
// when the pattern does not match or the bind data cannot be rebuilt. If the fragment had a
// FILTER above the extend, that filter is re-applied on top of the new call.
//
// Output columns, deduplicated by unique name (first occurrence wins):
//  1. every in-scope expression, except (a) a plain "<var>.<prop>" name whose generated
//     "_<x>_<var>.<prop>" counterpart is also in scope, and (b) "<X>._ID" when "<X>.id" is also
//     in scope;
//  2. every property of the bound node, rel and neighbour node whose name does not start
//     with '_';
//  3. if steps 1-2 produced nothing, every in-scope expression.
std::shared_ptr<LogicalOperator> ForeignJoinPushDownOptimizer::visitHashJoinReplace(
    std::shared_ptr<LogicalOperator> op) {
    auto patternInfo = matchPattern(op.get(), this->context);
    if (!patternInfo.has_value()) {
        return op;
    }
    auto& info = patternInfo.value();
    auto allColumns = info.outputSchema->getExpressionsInScope();
    expression_vector outputColumns;
    std::unordered_set<std::string> outputColumnNames;
    // Appends `column` unless one with the same unique name was appended already.
    auto appendOutputColumn = [&](const std::shared_ptr<Expression>& column) {
        if (!outputColumnNames.insert(column->getUniqueName()).second) {
            return;
        }
        outputColumns.push_back(column);
    };
    // Maps a generated "_<x>_<var>.<prop>" unique name to "<var>.<prop>"; "" for other shapes.
    auto extractCanonicalVarProp = [](const std::string& uniqueName) -> std::string {
        if (uniqueName.empty() || uniqueName[0] != '_') {
            return "";
        }
        auto dotPos = uniqueName.find('.');
        if (dotPos == std::string::npos) {
            return "";
        }
        auto prefix = uniqueName.substr(0, dotPos);
        auto secondUnderscore = prefix.find('_', 1);
        if (secondUnderscore == std::string::npos || secondUnderscore + 1 >= prefix.size()) {
            return "";
        }
        auto rawVar = prefix.substr(secondUnderscore + 1);
        auto prop = uniqueName.substr(dotPos + 1);
        if (rawVar.empty() || prop.empty()) {
            return "";
        }
        return rawVar + "." + prop;
    };
    // Single pass over the scope. It records every unique name, so the "_ID"/"id" check below is
    // an O(1) lookup instead of a rescan of the scope, and every canonical "<var>.<prop>" that has
    // a generated counterpart.
    std::unordered_set<std::string> scopeNames;
    std::unordered_set<std::string> canonicalVarProps;
    for (auto& col : allColumns) {
        auto uniqueName = col->getUniqueName();
        auto canonical = extractCanonicalVarProp(uniqueName);
        if (!canonical.empty()) {
            canonicalVarProps.insert(std::move(canonical));
        }
        scopeNames.insert(std::move(uniqueName));
    }
    // Keeps generated ('_'-prefixed) names and plain names with no generated duplicate.
    auto isCanonicalOrStandalone = [&](const std::string& uniqueName) {
        if (uniqueName.empty() || uniqueName[0] == '_') {
            return true;
        }
        return !canonicalVarProps.contains(uniqueName);
    };
    // True for "<X>._ID" (with non-empty <X>) when "<X>.id" is also in scope.
    auto hasLowercaseID = [&](const std::string& uniqueName) {
        static constexpr auto internalIDSuffix = "._ID";
        static constexpr auto suffixLen = std::char_traits<char>::length(internalIDSuffix);
        if (uniqueName.size() <= suffixLen) {
            return false;
        }
        auto stemLen = uniqueName.size() - suffixLen;
        if (uniqueName.compare(stemLen, suffixLen, internalIDSuffix) != 0) {
            return false;
        }
        return scopeNames.contains(uniqueName.substr(0, stemLen) + ".id");
    };
    for (auto& col : allColumns) {
        auto uniqueName = col->getUniqueName();
        if (!isCanonicalOrStandalone(uniqueName)) {
            continue;
        }
        if (hasLowercaseID(uniqueName)) {
            continue;
        }
        appendOutputColumn(col);
    }
    // Add the user-visible properties of each pattern element (e.g. so a re-applied rel filter
    // can see the rel's properties).
    auto appendPatternProperties = [&](const std::shared_ptr<NodeOrRelExpression>& pattern) {
        for (auto& property : pattern->getPropertyExpressions()) {
            if (property->getPropertyName().starts_with("_")) {
                continue;
            }
            appendOutputColumn(property);
        }
    };
    appendPatternProperties(info.extend->getBoundNode());
    appendPatternProperties(info.extend->getRel());
    appendPatternProperties(info.extend->getNbrNode());
    if (outputColumns.empty()) {
        for (auto& col : allColumns) {
            appendOutputColumn(col);
        }
    }
    auto joinQueryInfo = buildJoinQuery(info, outputColumns, this->context);
    auto result = createJoinTableFunctionCall(info, joinQueryInfo.query, joinQueryInfo.columnNames,
        joinQueryInfo.displayNames, outputColumns);
    if (!result) {
        return op;
    }
    if (info.relFilter != nullptr) {
        result = std::make_shared<LogicalFilter>(info.relFilter->getPredicate(), std::move(result));
        result->computeFlatSchema();
    }
    return result;
}
} }