#pragma once
#include "common/constants.h"
#include "common/enums/extend_direction.h"
#include "common/enums/query_rel_type.h"
#include "function/gds/rec_joins.h"
#include "node_expression.h"
namespace lbug {
namespace binder {
enum class RelDirectionType : uint8_t {
SINGLE = 0,
BOTH = 1,
UNKNOWN = 2,
};
class RelExpression;
struct RecursiveInfo {
std::shared_ptr<NodeExpression> node = nullptr;
std::shared_ptr<NodeExpression> nodeCopy = nullptr;
std::shared_ptr<RelExpression> rel = nullptr;
std::shared_ptr<Expression> nodePredicate = nullptr;
std::shared_ptr<Expression> relPredicate = nullptr;
expression_vector nodeProjectionList;
expression_vector relProjectionList;
std::unique_ptr<function::RJAlgorithm> function;
std::unique_ptr<function::RJBindData> bindData;
};
class LBUG_API RelExpression final : public NodeOrRelExpression {
public:
RelExpression(common::LogicalType dataType, std::string uniqueName, std::string variableName,
std::vector<catalog::TableCatalogEntry*> entries, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType,
common::QueryRelType relType)
: NodeOrRelExpression{std::move(dataType), std::move(uniqueName), std::move(variableName),
std::move(entries)},
srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, directionType{directionType},
relType{relType} {}
bool isRecursive() const {
return dataType.getLogicalTypeID() == common::LogicalTypeID::RECURSIVE_REL;
}
bool isMultiLabeled() const override;
bool isBoundByMultiLabeledNode() const {
return srcNode->isMultiLabeled() || dstNode->isMultiLabeled();
}
std::shared_ptr<NodeExpression> getSrcNode() const { return srcNode; }
std::string getSrcNodeName() const { return srcNode->getUniqueName(); }
void setDstNode(std::shared_ptr<NodeExpression> node) { dstNode = std::move(node); }
std::shared_ptr<NodeExpression> getDstNode() const { return dstNode; }
std::string getDstNodeName() const { return dstNode->getUniqueName(); }
void setLeftNode(std::shared_ptr<NodeExpression> node) { leftNode = std::move(node); }
std::shared_ptr<NodeExpression> getLeftNode() const { return leftNode; }
void setRightNode(std::shared_ptr<NodeExpression> node) { rightNode = std::move(node); }
std::shared_ptr<NodeExpression> getRightNode() const { return rightNode; }
common::QueryRelType getRelType() const { return relType; }
void setDirectionExpr(std::shared_ptr<Expression> expr) { directionExpr = std::move(expr); }
bool hasDirectionExpr() const { return directionExpr != nullptr; }
std::shared_ptr<Expression> getDirectionExpr() const { return directionExpr; }
RelDirectionType getDirectionType() const { return directionType; }
std::shared_ptr<PropertyExpression> getInternalID() const override {
return getPropertyExpression(common::InternalKeyword::ID);
}
void setRecursiveInfo(std::unique_ptr<RecursiveInfo> recursiveInfo_) {
recursiveInfo = std::move(recursiveInfo_);
}
const RecursiveInfo* getRecursiveInfo() const { return recursiveInfo.get(); }
std::shared_ptr<Expression> getLengthExpression() const {
DASSERT(recursiveInfo != nullptr);
return recursiveInfo->bindData->lengthExpr;
}
bool isSelfLoop() const { return *srcNode == *dstNode; }
std::string detailsToString() const;
std::vector<common::ExtendDirection> getExtendDirections() const;
std::vector<common::table_id_t> getInnerRelTableIDs() const;
private:
std::shared_ptr<NodeExpression> srcNode;
std::shared_ptr<NodeExpression> dstNode;
std::shared_ptr<NodeExpression> leftNode;
std::shared_ptr<NodeExpression> rightNode;
RelDirectionType directionType;
std::shared_ptr<Expression> directionExpr;
common::QueryRelType relType;
std::unique_ptr<RecursiveInfo> recursiveInfo;
};
} }