#include "function/gds/gds.h"
#include "binder/binder.h"
#include "binder/expression/rel_expression.h"
#include "binder/query/reading_clause/bound_table_function_call.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "common/exception/binder.h"
#include "function/table/bind_input.h"
#include "graph/graph_entry_set.h"
#include "graph/on_disk_graph.h"
#include "parser/parser.h"
#include "planner/operator/logical_table_function_call.h"
#include "planner/operator/sip/logical_semi_masker.h"
#include "planner/planner.h"
#include "processor/operator/table_function_call.h"
#include "processor/plan_mapper.h"
#include <format>
using namespace lbug::catalog;
using namespace lbug::common;
using namespace lbug::binder;
using namespace lbug::main;
using namespace lbug::graph;
using namespace lbug::processor;
using namespace lbug::planner;
namespace lbug {
namespace function {
void GDSFuncSharedState::setGraphNodeMask(std::unique_ptr<NodeOffsetMaskMap> maskMap) {
auto onDiskGraph = dynamic_cast_checked<OnDiskGraph*>(graph.get());
onDiskGraph->setNodeOffsetMask(maskMap.get());
graphNodeMask = std::move(maskMap);
}
// Parses and binds an internally generated Cypher query and returns the expressions of its result
// columns. It is used to bind user-supplied projection predicates inside a synthetic
// `MATCH ... RETURN ...` query.
// Throws BinderException unless `cypher` parses to exactly one statement. Because the query embeds
// user text, a predicate containing `;` can produce several statements. An assertion is not a
// user-facing error; without this check, only the first statement would be bound and the rest
// silently ignored.
static expression_vector getResultColumns(const std::string& cypher, ClientContext* context) {
    auto parsedStatements = parser::Parser::parseQuery(cypher);
    if (parsedStatements.size() != 1) {
        throw BinderException(std::format(
            "Expected a single statement when binding graph projection, but got {} in: {}",
            parsedStatements.size(), cypher));
    }
    auto binder = Binder(context);
    auto boundStatement = binder.bind(*parsedStatements[0]);
    return boundStatement->getStatementResult()->getColumns();
}
// Checks that every node table in `connectedNodeTableIDSet` is also in `projectedNodeIDSet`.
// On the first missing table (in the set's iteration order), throws BinderException naming that
// table and the rel table `relName` that connects to it.
static void validateNodeProjected(const table_id_set_t& connectedNodeTableIDSet,
    const table_id_set_t& projectedNodeIDSet, const std::string& relName, Catalog* catalog,
    transaction::Transaction* transaction) {
    for (const auto tableID : connectedNodeTableIDSet) {
        if (projectedNodeIDSet.contains(tableID)) {
            continue;
        }
        const auto missingTableName = catalog->getTableCatalogEntry(transaction, tableID)->getName();
        throw BinderException(
            std::format("{} is connected to {} but not projected.", missingTableName, relName));
    }
}
// Requires that all source and destination node tables of the rel group `entry` appear in
// `projectedNodeIDSet`. Throws BinderException via validateNodeProjected otherwise.
static void validateRelSrcDstNodeAreProjected(const TableCatalogEntry& entry,
    const table_id_set_t& projectedNodeIDSet, Catalog* catalog,
    transaction::Transaction* transaction) {
    const auto& relGroup = entry.constCast<RelGroupCatalogEntry>();
    const auto& relGroupName = relGroup.getName();
    // Source endpoints are checked before destination endpoints.
    validateNodeProjected(relGroup.getSrcNodeTableIDSet(), projectedNodeIDSet, relGroupName,
        catalog, transaction);
    validateNodeProjected(relGroup.getDstNodeTableIDSet(), projectedNodeIDSet, relGroupName,
        catalog, transaction);
}
// Looks up the projected graph registered under `name` and binds it as a native graph entry.
// validateGraphExist checks that `name` is registered. A graph whose entry type is not
// GraphEntryType::NATIVE is rejected with a BinderException that names the graph.
NativeGraphEntry GDSFunction::bindGraphEntry(ClientContext& context, const std::string& name) {
    auto set = GraphEntrySet::Get(context);
    set->validateGraphExist(name);
    auto entry = set->getEntry(name);
    if (entry->type != GraphEntryType::NATIVE) {
        // Give the user an actionable message instead of a placeholder.
        throw BinderException(std::format(
            "Graph {} is not a native projected graph. Only native projected graphs are "
            "supported by this function.",
            name));
    }
    return bindGraphEntry(context, entry->cast<ParsedNativeGraphEntry>());
}
// Binds one node table of a projected graph.
// Returns the table's catalog entry, the bound node expression and, when `predicate` is non-empty,
// the bound predicate expression (otherwise nullptr).
// The predicate is bound through the synthetic query `MATCH (n:<table>) RETURN n, <predicate>`.
// It therefore refers to the node as `n` and must be a single expression.
// Throws BinderException if `tableName` is not a node table or if the predicate does not bind to
// exactly one result column.
static NativeGraphEntryTableInfo bindNodeEntry(ClientContext& context, const std::string& tableName,
    const std::string& predicate) {
    auto catalog = Catalog::Get(context);
    auto transaction = transaction::Transaction::Get(context);
    auto nodeEntry = catalog->getTableCatalogEntry(transaction, tableName);
    if (nodeEntry->getType() != CatalogEntryType::NODE_TABLE_ENTRY) {
        throw BinderException(std::format("{} is not a NODE table.", tableName));
    }
    if (predicate.empty()) {
        auto cypher = std::format("MATCH (n:`{}`) RETURN n", nodeEntry->getName());
        auto columns = getResultColumns(cypher, &context);
        DASSERT(columns.size() == 1);
        return {nodeEntry, columns[0], nullptr};
    }
    auto cypher = std::format("MATCH (n:`{}`) RETURN n, {}", nodeEntry->getName(), predicate);
    auto columns = getResultColumns(cypher, &context);
    // `predicate` is user text spliced into the RETURN list, so a top-level comma yields extra
    // columns. Report that as a user error instead of asserting; without the check, only
    // columns[1] would be used and the remaining columns silently dropped.
    if (columns.size() != 2) {
        throw BinderException(
            std::format("Predicate on NODE table {} must be a single expression, but got: {}",
                nodeEntry->getName(), predicate));
    }
    return {nodeEntry, columns[0], columns[1]};
}
// Binds one rel table of a projected graph.
// Returns the table's catalog entry, the bound rel expression and, when `predicate` is non-empty,
// the bound predicate expression (otherwise nullptr).
// The predicate is bound through the synthetic query
// `MATCH ()-[r:<table>]->() RETURN r, <predicate>`. It therefore refers to the rel as `r` and must
// be a single expression.
// Throws BinderException if `tableName` is not a rel group table or if the predicate does not bind
// to exactly one result column.
static NativeGraphEntryTableInfo bindRelEntry(ClientContext& context, const std::string& tableName,
    const std::string& predicate) {
    auto catalog = Catalog::Get(context);
    auto transaction = transaction::Transaction::Get(context);
    auto relEntry = catalog->getTableCatalogEntry(transaction, tableName);
    if (relEntry->getType() != CatalogEntryType::REL_GROUP_ENTRY) {
        // Mirrors the "{} is not a NODE table." check used when binding node tables.
        throw BinderException(std::format("{} is not a REL table.", tableName));
    }
    if (predicate.empty()) {
        auto cypher = std::format("MATCH ()-[r:`{}`]->() RETURN r", relEntry->getName());
        auto columns = getResultColumns(cypher, &context);
        DASSERT(columns.size() == 1);
        return {relEntry, columns[0], nullptr};
    }
    auto cypher =
        std::format("MATCH ()-[r:`{}`]->() RETURN r, {}", relEntry->getName(), predicate);
    auto columns = getResultColumns(cypher, &context);
    // `predicate` is user text spliced into the RETURN list, so a top-level comma yields extra
    // columns. Report that as a user error instead of asserting; without the check, only
    // columns[1] would be used and the remaining columns silently dropped.
    if (columns.size() != 2) {
        throw BinderException(
            std::format("Predicate on REL table {} must be a single expression, but got: {}",
                relEntry->getName(), predicate));
    }
    return {relEntry, columns[0], columns[1]};
}
// Binds every node and rel table of a parsed native graph projection.
// Node tables are bound first, so each rel table can then be checked against the set of projected
// node table IDs: all of its source and destination node tables must be projected as well.
// A rel table name unknown to the catalog is rejected with a BinderException.
NativeGraphEntry GDSFunction::bindGraphEntry(ClientContext& context,
    const ParsedNativeGraphEntry& entry) {
    auto catalog = Catalog::Get(context);
    auto transaction = transaction::Transaction::Get(context);
    auto bound = NativeGraphEntry{};
    table_id_set_t nodeTableIDs;
    for (const auto& parsedNode : entry.nodeInfos) {
        auto nodeInfo = bindNodeEntry(context, parsedNode.tableName, parsedNode.predicate);
        nodeTableIDs.insert(nodeInfo.entry->getTableID());
        bound.nodeInfos.push_back(std::move(nodeInfo));
    }
    for (const auto& parsedRel : entry.relInfos) {
        if (!catalog->containsTable(transaction, parsedRel.tableName)) {
            throw BinderException(std::format("{} is not a REL table.", parsedRel.tableName));
        }
        auto relInfo = bindRelEntry(context, parsedRel.tableName, parsedRel.predicate);
        validateRelSrcDstNodeAreProjected(*relInfo.entry, nodeTableIDs, catalog, transaction);
        bound.relInfos.push_back(std::move(relInfo));
    }
    return bound;
}
// Creates the rel expression a GDS function outputs, connecting `srcNode` to `dstNode`, and
// registers it in the binder's scope.
// The column name is `name` (default REL_COLUMN_NAME), lower-cased. When the call has YIELD
// variables, the one at `yieldVariableIdx` (default 0) must match that name, and its alias, if
// any, becomes the column name (see bindColumnName).
std::shared_ptr<binder::Expression> GDSFunction::bindRelOutput(const TableFuncBindInput& bindInput,
    const std::vector<catalog::TableCatalogEntry*>& relEntries,
    std::shared_ptr<NodeExpression> srcNode, std::shared_ptr<NodeExpression> dstNode,
    const std::optional<std::string>& name, const std::optional<uint64_t>& yieldVariableIdx) {
    std::string relColumnName = name.value_or(REL_COLUMN_NAME);
    StringUtils::toLower(relColumnName);
    if (!bindInput.yieldVariables.empty()) {
        relColumnName =
            bindColumnName(bindInput.yieldVariables[yieldVariableIdx.value_or(0)], relColumnName);
    }
    auto rel = bindInput.binder->createNonRecursiveQueryRel(relColumnName, relEntries,
        std::move(srcNode), std::move(dstNode), RelDirectionType::SINGLE, {});
    // Scope the rel under the name it was created with, as bindNodeOutput does for nodes.
    // Registering it under the fixed REL_COLUMN_NAME put a renamed or YIELD-aliased rel in scope
    // under a different key than its own name.
    bindInput.binder->addToScope(relColumnName, rel);
    return rel;
}
// Creates the node expression a GDS function outputs and registers it in the binder's scope under
// its final column name.
// That name is `name` (default NODE_COLUMN_NAME), lower-cased. When the call has YIELD variables,
// it is instead resolved against the one at `yieldVariableIdx` (default 0) via bindColumnName.
std::shared_ptr<Expression> GDSFunction::bindNodeOutput(const TableFuncBindInput& bindInput,
    const std::vector<TableCatalogEntry*>& nodeEntries, const std::optional<std::string>& name,
    const std::optional<uint64_t>& yieldVariableIdx) {
    std::string columnName = name.value_or(NODE_COLUMN_NAME);
    StringUtils::toLower(columnName);
    const auto& yields = bindInput.yieldVariables;
    if (!yields.empty()) {
        const auto yieldIdx = yieldVariableIdx.value_or(0);
        columnName = bindColumnName(yields[yieldIdx], columnName);
    }
    auto nodeExpr = bindInput.binder->createQueryNode(columnName, nodeEntries);
    bindInput.binder->addToScope(columnName, nodeExpr);
    return nodeExpr;
}
// Resolves the output column name of a GDS result expression from the YIELD variable that refers
// to it. The YIELD variable must name `expressionName` exactly; otherwise a BinderException is
// thrown. If the variable carries an alias, the alias becomes the column name.
std::string GDSFunction::bindColumnName(const parser::YieldVariable& yieldVariable,
    std::string expressionName) {
    const bool namesMatch = yieldVariable.name == expressionName;
    if (!namesMatch) {
        throw common::BinderException{
            std::format("Unknown variable name: {}.", yieldVariable.name)};
    }
    if (!yieldVariable.hasAlias()) {
        return expressionName;
    }
    return yieldVariable.alias;
}
std::unique_ptr<TableFuncSharedState> GDSFunction::initSharedState(
const TableFuncInitSharedStateInput& input) {
auto bindData = input.bindData->constPtrCast<GDSBindData>();
auto graph =
std::make_unique<OnDiskGraph>(input.context->clientContext, bindData->graphEntry.copy());
return std::make_unique<GDSFuncSharedState>(bindData->getResultTable(), std::move(graph));
}
std::vector<std::shared_ptr<LogicalOperator>> getNodeMaskPlanRoots(const GDSBindData& bindData,
Planner* planner) {
std::vector<std::shared_ptr<LogicalOperator>> nodeMaskPlanRoots;
for (auto& nodeInfo : bindData.graphEntry.nodeInfos) {
if (nodeInfo.predicate == nullptr) {
continue;
}
auto& node = nodeInfo.nodeOrRel->constCast<NodeExpression>();
planner->getCardinliatyEstimatorUnsafe().init(node);
auto p = planner->getNodeSemiMaskPlan(SemiMaskTargetType::GDS_GRAPH_NODE, node,
nodeInfo.predicate);
nodeMaskPlanRoots.push_back(p.getLastOperator());
}
return nodeMaskPlanRoots;
};
// Plans a GDS table function call as a LogicalTableFunctionCall whose children are the node
// semi-mask plans from getNodeMaskPlanRoots.
// If the first output expression (a node) has a non-empty property-scan plan, that plan is
// hash-joined (INNER) onto the result on the node's internal ID.
void GDSFunction::getLogicalPlan(Planner* planner, const BoundReadingClause& readingClause,
    expression_vector predicates, LogicalPlan& plan) {
    const auto& tableFuncCall = readingClause.constCast<BoundTableFunctionCall>();
    auto gdsBindData = tableFuncCall.getBindData()->constPtrCast<GDSBindData>();
    auto callOp = std::make_shared<LogicalTableFunctionCall>(tableFuncCall.getTableFunc(),
        gdsBindData->copy());
    const auto maskRoots = getNodeMaskPlanRoots(*gdsBindData, planner);
    for (const auto& maskRoot : maskRoots) {
        callOp->addChild(maskRoot);
    }
    callOp->computeFactorizedSchema();
    planner->planReadOp(std::move(callOp), predicates, plan);

    auto outputNode = gdsBindData->output[0]->ptrCast<NodeExpression>();
    DASSERT(outputNode != nullptr);
    planner->getCardinliatyEstimatorUnsafe().init(*outputNode);
    auto propertyScanPlan = planner->getNodePropertyScanPlan(*outputNode);
    if (!propertyScanPlan.isEmpty()) {
        expression_vector joinKeys;
        joinKeys.push_back(outputNode->getInternalID());
        planner->appendHashJoin(joinKeys, JoinType::INNER, plan, propertyScanPlan, plan);
    }
}
// Maps a logical GDS table function call to its physical plan.
//
// The GDS call writes its output into a FactorizedTable (`table`), which is stored in a copy of
// the bind data via setResultFTable. The physical TableFunctionCall is placed under a DummySink,
// and the operator returned is a FactorizedTable scan over `table` for `columns`.
// If the logical call has children (the node semi-mask plans), each child plan is mapped and
// attached to the physical call. The semi masks those plans populate are registered in a
// NodeOffsetMaskMap installed on the call's GDSFuncSharedState (and thereby on its graph).
std::unique_ptr<PhysicalOperator> GDSFunction::getPhysicalPlan(PlanMapper* planMapper,
    const LogicalOperator* logicalOp) {
    auto logicalCall = logicalOp->constPtrCast<LogicalTableFunctionCall>();
    // Work on a copy of the bind data; ownership is handed to the call info below.
    auto bindData = logicalCall->getBindData()->copy();
    auto columns = bindData->columns;
    // Result table with a flat schema over the output columns; the GDS computation fills it and
    // the final FTable scan reads from it.
    auto tableSchema = PlanMapper::createFlatFTableSchema(columns, *logicalCall->getSchema());
    auto table = std::make_shared<FactorizedTable>(
        storage::MemoryManager::Get(*planMapper->clientContext), tableSchema.copy());
    bindData->cast<GDSBindData>().setResultFTable(table);
    auto info = TableFunctionCallInfo();
    info.function = logicalCall->getTableFunc();
    info.bindData = std::move(bindData);
    auto initInput =
        TableFuncInitSharedStateInput(info.bindData.get(), planMapper->executionContext);
    auto sharedState = info.function.initSharedStateFunc(initInput);
    auto printInfo =
        std::make_unique<TableFunctionCallPrintInfo>(info.function.name, info.bindData->columns);
    auto call = std::make_unique<TableFunctionCall>(std::move(info), sharedState,
        planMapper->getOperatorID(), std::move(printInfo));
    if (logicalCall->getNumChildren() > 0u) {
        // NOTE(review): assumes initSharedStateFunc produced a GDSFuncSharedState.
        const auto funcSharedState = sharedState->ptrCast<GDSFuncSharedState>();
        // Start from an empty mask map; one mask per node table is added below.
        funcSharedState->setGraphNodeMask(std::make_unique<NodeOffsetMaskMap>());
        auto maskMap = funcSharedState->getGraphNodeMaskMap();
        // NOTE(review): logicalOp is temporarily mapped to `call` while the child plans are
        // mapped, presumably so the semi maskers can resolve their target (added via addTarget)
        // to this physical operator. The mapping is erased afterwards and re-added below.
        planMapper->addOperatorMapping(logicalOp, call.get());
        for (auto logicalRoot : logicalCall->getChildren()) {
            // Each child plan root is expected to have exactly one child: the semi masker.
            DASSERT(logicalRoot->getNumChildren() == 1);
            auto child = logicalRoot->getChild(0);
            DASSERT(child->getOperatorType() == LogicalOperatorType::SEMI_MASKER);
            auto logicalSemiMasker = child->ptrCast<LogicalSemiMasker>();
            // Make this GDS call a target of the semi masker and allocate one mask for every
            // node table it covers.
            logicalSemiMasker->addTarget(logicalOp);
            for (auto tableID : logicalSemiMasker->getNodeTableIDs()) {
                maskMap->addMask(tableID, planMapper->createSemiMask(tableID));
            }
            auto root = planMapper->mapOperator(logicalRoot.get());
            call->addChild(std::move(root));
        }
        planMapper->eraseOperatorMapping(logicalOp);
    }
    // Final mapping of the logical call to the physical call; `call` is owned by the DummySink
    // created next, so the raw pointer stays valid.
    planMapper->addOperatorMapping(logicalOp, call.get());
    physical_op_vector_t children;
    auto dummySink = std::make_unique<DummySink>(std::move(call), planMapper->getOperatorID());
    dummySink->setDescriptor(std::make_unique<ResultSetDescriptor>(logicalCall->getSchema()));
    children.push_back(std::move(dummySink));
    return planMapper->createFTableScanAligned(columns, logicalCall->getSchema(), table,
        DEFAULT_VECTOR_CAPACITY, std::move(children));
}
} }