#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/node_expression.h"
#include "binder/expression/node_rel_expression.h"
#include "binder/expression/rel_expression.h"
#include "binder/expression_binder.h"
#include "catalog/catalog.h"
#include "common/cast.h"
#include "common/exception/binder.h"
#include "common/string_utils.h"
#include "common/types/types.h"
#include "function/schema/vector_node_rel_functions.h"
#include "function/struct/vector_struct_functions.h"
#include "main/client_context.h"
#include "main/database_manager.h"
#include "parser/expression/parsed_property_expression.h"
#include "transaction/transaction.h"
#include <format>
#include <optional>
#include <string>
using namespace lbug::common;
using namespace lbug::parser;
using namespace lbug::catalog;
namespace lbug {
namespace binder {
// Returns true when `expression` is classified by ExpressionUtil as either a node pattern or a
// rel pattern.
static bool isNodeOrRelPattern(const Expression& expression) {
    if (ExpressionUtil::isNodePattern(expression)) {
        return true;
    }
    return ExpressionUtil::isRelPattern(expression);
}
// Returns true when the logical type of `expression` is NODE, REL or STRUCT, i.e. a type whose
// fields can be reached by struct-style property extraction.
static bool isStructPattern(const Expression& expression) {
    switch (expression.getDataType().getLogicalTypeID()) {
    case LogicalTypeID::NODE:
    case LogicalTypeID::REL:
    case LogicalTypeID::STRUCT:
        return true;
    default:
        return false;
    }
}
static bool isAnyGraphNodeOrRel(const NodeOrRelExpression& nodeOrRel,
main::ClientContext* context) {
auto transaction = transaction::Transaction::Get(*context);
auto useInternal = context->useInternalCatalogEntry();
auto dbManager = main::DatabaseManager::Get(*context);
auto defaultGraphCatalog = dbManager->getDefaultGraphCatalog();
auto catalog = defaultGraphCatalog != nullptr ? defaultGraphCatalog : Catalog::Get(*context);
for (auto& entry : nodeOrRel.getEntries()) {
if (entry->getType() == CatalogEntryType::NODE_TABLE_ENTRY &&
catalog->containsTable(transaction, "_nodes", useInternal) &&
entry->getTableID() ==
catalog->getTableCatalogEntry(transaction, "_nodes", useInternal)->getTableID()) {
return true;
}
if (entry->getType() == CatalogEntryType::REL_GROUP_ENTRY &&
catalog->containsTable(transaction, "_edges", useInternal) &&
entry->getTableID() ==
catalog->getTableCatalogEntry(transaction, "_edges", useInternal)->getTableID()) {
return true;
}
}
return false;
}
// Binds a star property lookup on the first child of `parsedExpression`.
// - Node/rel patterns expand to their non-reserved property expressions.
// - NODE/REL/STRUCT-typed values expand to one struct extraction per field.
// Throws BinderException for any other kind of child.
expression_vector ExpressionBinder::bindPropertyStarExpression(
    const parser::ParsedExpression& parsedExpression) {
    const auto boundChild = bindExpression(*parsedExpression.getChild(0));
    if (isNodeOrRelPattern(*boundChild)) {
        return bindNodeOrRelPropertyStarExpression(*boundChild);
    }
    if (!isStructPattern(*boundChild)) {
        throw BinderException(std::format("Cannot bind property for expression {} with type {}.",
            boundChild->toString(), ExpressionTypeUtil::toString(boundChild->expressionType)));
    }
    return bindStructPropertyStarExpression(boundChild);
}
// Returns the property expressions of the node/rel pattern `child`, in their original order.
// Names that Binder::reservedInPropertyLookup flags as reserved are left out.
expression_vector ExpressionBinder::bindNodeOrRelPropertyStarExpression(const Expression& child) {
    const auto& pattern = child.constCast<NodeOrRelExpression>();
    expression_vector visibleProperties;
    for (const auto& propertyExpr : pattern.getPropertyExpressions()) {
        if (!Binder::reservedInPropertyLookup(propertyExpr->getPropertyName())) {
            visibleProperties.push_back(propertyExpr);
        }
    }
    return visibleProperties;
}
// Expands a struct-like `child` into one struct-extract expression per field of its data type,
// in the order StructType::getFields reports them.
expression_vector ExpressionBinder::bindStructPropertyStarExpression(
    const std::shared_ptr<Expression>& child) {
    const auto& fields = StructType::getFields(child->getDataType());
    expression_vector fieldExprs;
    for (const auto& field : fields) {
        fieldExprs.push_back(bindStructPropertyExpression(child, field.getName()));
    }
    return fieldExprs;
}
// Binds a single property lookup on the first child of `parsedExpression`.
// The child must be NODE, REL, STRUCT or ANY typed; ExpressionUtil::validateDataType enforces
// this.
// Throws BinderException when:
// - the parsed expression is a star lookup;
// - a node/rel pattern accesses a reserved property name;
// - the child is none of the supported kinds.
std::shared_ptr<Expression> ExpressionBinder::bindPropertyExpression(
    const ParsedExpression& parsedExpression) {
    const auto& parsedProperty = parsedExpression.constCast<ParsedPropertyExpression>();
    if (parsedProperty.isStar()) {
        throw BinderException(std::format("Cannot bind {} as a single property expression.",
            parsedExpression.toString()));
    }
    const auto propertyName = parsedProperty.getPropertyName();
    auto boundChild = bindExpression(*parsedExpression.getChild(0));
    ExpressionUtil::validateDataType(*boundChild,
        std::vector<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL,
            LogicalTypeID::STRUCT, LogicalTypeID::ANY});
    // With this binder flag set, every property access is bound as a struct field extraction,
    // whatever kind of child it is applied to.
    if (config.bindOrderByAfterAggregate) {
        return bindStructPropertyExpression(boundChild, propertyName);
    }
    if (isNodeOrRelPattern(*boundChild)) {
        if (Binder::reservedInPropertyLookup(propertyName)) {
            throw BinderException(
                propertyName + " is reserved for system usage. External access is not allowed.");
        }
        return bindNodeOrRelPropertyExpression(*boundChild, propertyName);
    }
    if (isStructPattern(*boundChild)) {
        return bindStructPropertyExpression(boundChild, propertyName);
    }
    if (boundChild->getDataType().getLogicalTypeID() != LogicalTypeID::ANY) {
        throw BinderException(std::format("Cannot bind property for expression {} with type {}.",
            boundChild->toString(), ExpressionTypeUtil::toString(boundChild->expressionType)));
    }
    // ANY-typed child: bind to a fresh ANY-typed variable with a generated unique name.
    return createVariableExpression(LogicalType::ANY(), binder->getUniqueExpressionName(""));
}
// Binds `propertyName` on the node/rel pattern `child`. The cases are checked in this order:
// 1. ROWID (matched case-insensitively against RowIDFunction::name): bound as the OFFSET
//    function applied to a copy of the pattern's internal id, aliased "<child>.<propertyName>".
// 2. InternalKeyword::ID on a NODE-typed child: returns the node's internal id expression.
// 3. A property the pattern already carries: returns a copy of that property expression.
// 4. A missing property on a catch-all ("_nodes"/"_edges") pattern: returns a JSON-typed
//    PropertyExpression covering every bound table.
// Otherwise throws BinderException.
std::shared_ptr<Expression> ExpressionBinder::bindNodeOrRelPropertyExpression(
    const Expression& child, const std::string& propertyName) {
    const auto& pattern = child.constCast<NodeOrRelExpression>();
    if (StringUtils::getUpper(propertyName) == function::RowIDFunction::name) {
        std::shared_ptr<Expression> internalIDCopy;
        if (!ExpressionUtil::isNodePattern(child)) {
            const auto& rel = child.constCast<RelExpression>();
            internalIDCopy = rel.getInternalID()->copy();
        } else {
            const auto& node = child.constCast<NodeExpression>();
            internalIDCopy = node.getInternalID()->copy();
        }
        auto offsetExpr =
            bindScalarFunctionExpression({internalIDCopy}, function::OffsetFunction::name);
        offsetExpr->setAlias(std::format("{}.{}", child.toString(), propertyName));
        return offsetExpr;
    }
    const bool isNodeTyped = child.dataType.getLogicalTypeID() == common::LogicalTypeID::NODE;
    if (propertyName == InternalKeyword::ID && isNodeTyped) {
        return dynamic_cast_checked<const NodeExpression&>(child).getInternalID();
    }
    if (pattern.hasPropertyExpression(propertyName)) {
        return pattern.getPropertyExpression(propertyName)->copy();
    }
    if (!isAnyGraphNodeOrRel(pattern, context)) {
        throw BinderException(
            "Cannot find property " + propertyName + " for " + child.toString() + ".");
    }
    table_id_map_t<SingleLabelPropertyInfo> labelInfos;
    for (const auto& entry : pattern.getEntries()) {
        // NOTE(review): arguments presumably mean (exists=true, isPrimaryKey=false); confirm
        // against SingleLabelPropertyInfo's constructor.
        labelInfos.insert({entry->getTableID(), SingleLabelPropertyInfo(true, false)});
    }
    return std::make_shared<PropertyExpression>(LogicalType::JSON(), propertyName,
        pattern.getUniqueName(), pattern.getVariableName(), std::move(labelInfos));
}
std::shared_ptr<Expression> ExpressionBinder::bindStructPropertyExpression(
std::shared_ptr<Expression> child, const std::string& propertyName) {
auto children = expression_vector{std::move(child), createLiteralExpression(propertyName)};
return bindScalarFunctionExpression(children, function::StructExtractFunctions::name);
}
} }