#include "storage/table/rel_table_data.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "common/enums/rel_direction.h"
#include "common/types/types.h"
#include "main/client_context.h"
#include "storage/storage_manager.h"
#include "storage/storage_utils.h"
#include "storage/table/node_group.h"
#include "storage/table/rel_table.h"
#include "transaction/transaction.h"
using namespace lbug::catalog;
using namespace lbug::common;
using namespace lbug::transaction;
namespace lbug {
namespace storage {
PersistentVersionRecordHandler::PersistentVersionRecordHandler(RelTableData* relTableData)
: relTableData(relTableData) {}
void PersistentVersionRecordHandler::applyFuncToChunkedGroups(version_record_handler_op_t func,
node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows,
transaction_t commitTS) const {
if (nodeGroupIdx < relTableData->getNumNodeGroups()) {
auto& nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx)->cast<CSRNodeGroup>();
if (auto* persistentChunkedGroup = nodeGroup.getPersistentChunkedGroup()) {
std::invoke(func, *persistentChunkedGroup, startRow, numRows, commitTS);
}
}
}
void PersistentVersionRecordHandler::rollbackInsert(main::ClientContext* context,
node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows) const {
VersionRecordHandler::rollbackInsert(context, nodeGroupIdx, startRow, numRows);
relTableData->rollbackGroupCollectionInsert(numRows, true);
}
InMemoryVersionRecordHandler::InMemoryVersionRecordHandler(RelTableData* relTableData)
: relTableData(relTableData) {}
void InMemoryVersionRecordHandler::applyFuncToChunkedGroups(version_record_handler_op_t func,
node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows,
transaction_t commitTS) const {
auto* nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx);
nodeGroup->applyFuncToChunkedGroups(func, startRow, numRows, commitTS);
}
void InMemoryVersionRecordHandler::rollbackInsert(main::ClientContext* context,
node_group_idx_t nodeGroupIdx, row_idx_t startRow, row_idx_t numRows) const {
VersionRecordHandler::rollbackInsert(context, nodeGroupIdx, startRow, numRows);
auto* nodeGroup = relTableData->getNodeGroupNoLock(nodeGroupIdx);
const auto numRowsToRollback = std::min(numRows, nodeGroup->getNumRows() - startRow);
nodeGroup->rollbackInsert(startRow);
relTableData->rollbackGroupCollectionInsert(numRowsToRollback, false);
}
RelTableData::RelTableData(FileHandle* dataFH, MemoryManager* mm, ShadowFile* shadowFile,
const RelGroupCatalogEntry& relGroupEntry, Table& table, RelDataDirection direction,
table_id_t nbrTableID, bool enableCompression)
: table{table}, mm{mm}, shadowFile{shadowFile}, enableCompression{enableCompression},
direction{direction}, multiplicity{relGroupEntry.getMultiplicity(direction)},
persistentVersionRecordHandler(this), inMemoryVersionRecordHandler(this) {
initCSRHeaderColumns(dataFH);
initPropertyColumns(relGroupEntry, nbrTableID, dataFH);
nodeGroups = std::make_unique<NodeGroupCollection>(*mm, getColumnTypes(), enableCompression,
ResidencyState::ON_DISK, &persistentVersionRecordHandler);
}
void RelTableData::initCSRHeaderColumns(FileHandle* dataFH) {
auto csrOffsetColumnName = StorageUtils::getColumnName("", StorageUtils::ColumnType::CSR_OFFSET,
RelDirectionUtils::relDirectionToString(direction));
csrHeaderColumns.offset = std::make_unique<Column>(csrOffsetColumnName, LogicalType::UINT64(),
dataFH, mm, shadowFile, enableCompression, false );
auto csrLengthColumnName = StorageUtils::getColumnName("", StorageUtils::ColumnType::CSR_LENGTH,
RelDirectionUtils::relDirectionToString(direction));
csrHeaderColumns.length = std::make_unique<Column>(csrLengthColumnName, LogicalType::UINT64(),
dataFH, mm, shadowFile, enableCompression, false );
}
void RelTableData::initPropertyColumns(const RelGroupCatalogEntry& relGroupEntry,
table_id_t nbrTableID, FileHandle* dataFH) {
const auto maxColumnID = relGroupEntry.getMaxColumnID();
columns.resize(maxColumnID + 1);
auto nbrIDColName = StorageUtils::getColumnName("NBR_ID", StorageUtils::ColumnType::DEFAULT,
RelDirectionUtils::relDirectionToString(direction));
auto nbrIDColumn =
std::make_unique<InternalIDColumn>(nbrIDColName, dataFH, mm, shadowFile, enableCompression);
columns[NBR_ID_COLUMN_ID] = std::move(nbrIDColumn);
for (auto& property : relGroupEntry.getProperties()) {
const auto columnID = relGroupEntry.getColumnID(property.getName());
const auto colName = StorageUtils::getColumnName(property.getName(),
StorageUtils::ColumnType::DEFAULT, RelDirectionUtils::relDirectionToString(direction));
columns[columnID] = ColumnFactory::createColumn(colName, property.getType().copy(), dataFH,
mm, shadowFile, enableCompression);
}
columns[NBR_ID_COLUMN_ID]->cast<InternalIDColumn>().setCommonTableID(nbrTableID);
columns[REL_ID_COLUMN_ID]->cast<InternalIDColumn>().setCommonTableID(table.getTableID());
}
bool RelTableData::update(Transaction* transaction, ValueVector& boundNodeIDVector,
const ValueVector& relIDVector, column_id_t columnID, const ValueVector& dataVector) const {
DASSERT(boundNodeIDVector.state->getSelVector().getSelSize() == 1);
DASSERT(relIDVector.state->getSelVector().getSelSize() == 1);
const auto boundNodePos = boundNodeIDVector.state->getSelVector()[0];
const auto relIDPos = relIDVector.state->getSelVector()[0];
if (boundNodeIDVector.isNull(boundNodePos) || relIDVector.isNull(relIDPos)) {
return false;
}
const auto [source, rowIdx] = findMatchingRow(transaction, boundNodeIDVector, relIDVector);
DASSERT(rowIdx != INVALID_ROW_IDX);
const auto boundNodeOffset = boundNodeIDVector.getValue<nodeID_t>(boundNodePos).offset;
const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(boundNodeOffset);
auto& csrNodeGroup = getNodeGroup(nodeGroupIdx)->cast<CSRNodeGroup>();
csrNodeGroup.update(transaction, source, rowIdx, columnID, dataVector);
return true;
}
bool RelTableData::delete_(Transaction* transaction, ValueVector& boundNodeIDVector,
const ValueVector& relIDVector) {
const auto boundNodePos = boundNodeIDVector.state->getSelVector()[0];
const auto relIDPos = relIDVector.state->getSelVector()[0];
if (boundNodeIDVector.isNull(boundNodePos) || relIDVector.isNull(relIDPos)) {
return false;
}
const auto [source, rowIdx] = findMatchingRow(transaction, boundNodeIDVector, relIDVector);
if (rowIdx == INVALID_ROW_IDX) {
return false;
}
const auto boundNodeOffset = boundNodeIDVector.getValue<nodeID_t>(boundNodePos).offset;
const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(boundNodeOffset);
auto& csrNodeGroup = getNodeGroup(nodeGroupIdx)->cast<CSRNodeGroup>();
bool isDeleted = csrNodeGroup.delete_(transaction, source, rowIdx);
if (isDeleted && transaction->shouldAppendToUndoBuffer()) {
transaction->pushDeleteInfo(nodeGroupIdx, rowIdx, 1, getVersionRecordHandler(source));
}
return isDeleted;
}
void RelTableData::addColumn(TableAddColumnState& addColumnState, PageAllocator& pageAllocator) {
auto& definition = addColumnState.propertyDefinition;
columns.push_back(ColumnFactory::createColumn(definition.getName(), definition.getType().copy(),
pageAllocator.getDataFH(), mm, shadowFile, enableCompression));
nodeGroups->addColumn(addColumnState, &pageAllocator);
}
std::pair<CSRNodeGroupScanSource, row_idx_t> RelTableData::findMatchingRow(Transaction* transaction,
ValueVector& boundNodeIDVector, const ValueVector& relIDVector) const {
DASSERT(boundNodeIDVector.state->getSelVector().getSelSize() == 1);
DASSERT(relIDVector.state->getSelVector().getSelSize() == 1);
const auto boundNodePos = boundNodeIDVector.state->getSelVector()[0];
const auto relIDPos = relIDVector.state->getSelVector()[0];
const auto boundNodeOffset = boundNodeIDVector.getValue<nodeID_t>(boundNodePos).offset;
const auto relOffset = relIDVector.getValue<nodeID_t>(relIDPos).offset;
const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(boundNodeOffset);
DataChunk scanChunk(1);
scanChunk.insert(0, std::make_shared<ValueVector>(LogicalType::INTERNAL_ID()));
std::vector columnIDs = {REL_ID_COLUMN_ID, ROW_IDX_COLUMN_ID};
std::vector<const Column*> columns{getColumn(REL_ID_COLUMN_ID), nullptr};
auto scanState = std::make_unique<RelTableScanState>(*mm, &boundNodeIDVector,
std::vector{&scanChunk.getValueVectorMutable(0)}, scanChunk.state, true );
scanState->setToTable(transaction, &table, columnIDs, {}, direction);
scanState->initState(transaction, getNodeGroup(nodeGroupIdx));
row_idx_t matchingRowIdx = INVALID_ROW_IDX;
auto source = CSRNodeGroupScanSource::NONE;
const auto scannedIDVector = scanState->outputVectors[0];
while (true) {
const auto scanResult = scanState->nodeGroup->scan(transaction, *scanState);
if (scanResult == NODE_GROUP_SCAN_EMPTY_RESULT) {
break;
}
for (auto i = 0u; i < scanState->outState->getSelVector().getSelSize(); i++) {
const auto pos = scanState->outState->getSelVector()[i];
if (scannedIDVector->getValue<internalID_t>(pos).offset == relOffset) {
const auto rowIdxPos = scanState->rowIdxVector->state->getSelVector()[i];
matchingRowIdx = scanState->rowIdxVector->getValue<row_idx_t>(rowIdxPos);
source = scanState->nodeGroupScanState->cast<CSRNodeGroupScanState>().source;
break;
}
}
if (matchingRowIdx != INVALID_ROW_IDX) {
break;
}
}
return {source, matchingRowIdx};
}
bool RelTableData::checkIfNodeHasRels(Transaction* transaction,
ValueVector* srcNodeIDVector) const {
DASSERT(srcNodeIDVector->state->isFlat());
const auto nodeIDPos = srcNodeIDVector->state->getSelVector()[0];
const auto nodeOffset = srcNodeIDVector->getValue<nodeID_t>(nodeIDPos).offset;
const auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(nodeOffset);
if (nodeGroupIdx >= getNumNodeGroups()) {
return false;
}
DataChunk scanChunk(1);
scanChunk.insert(0, std::make_shared<ValueVector>(LogicalType::INTERNAL_ID()));
std::vector columnIDs = {REL_ID_COLUMN_ID};
std::vector<const Column*> columns{getColumn(REL_ID_COLUMN_ID)};
auto scanState = std::make_unique<RelTableScanState>(*mm, srcNodeIDVector,
std::vector{&scanChunk.getValueVectorMutable(0)}, scanChunk.state, true );
scanState->setToTable(transaction, &table, columnIDs, {}, direction);
scanState->initState(transaction, getNodeGroup(nodeGroupIdx));
while (true) {
const auto scanResult = scanState->nodeGroup->scan(transaction, *scanState);
if (scanResult == NODE_GROUP_SCAN_EMPTY_RESULT) {
break;
}
if (scanState->outState->getSelVector().getSelSize() > 0) {
return true;
}
}
return false;
}
void RelTableData::pushInsertInfo(const Transaction* transaction, const CSRNodeGroup& nodeGroup,
row_idx_t numRows_, CSRNodeGroupScanSource source) {
DASSERT(source != CSRNodeGroupScanSource::COMMITTED_PERSISTENT ||
!nodeGroup.getPersistentChunkedGroup() ||
nodeGroup.getPersistentChunkedGroup()->getNumRows() == 0);
const auto [startRow, shouldIncrementNumRows] =
(source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) ?
std::make_pair(static_cast<row_idx_t>(0), false) :
std::make_pair(nodeGroup.getNumRows(), true);
nodeGroups->pushInsertInfo(transaction, nodeGroup.getNodeGroupIdx(), startRow, numRows_,
getVersionRecordHandler(source), shouldIncrementNumRows);
}
void RelTableData::checkpoint(const std::vector<column_id_t>& columnIDs,
PageAllocator& pageAllocator, const Transaction* snapshotTxn) {
std::vector<std::unique_ptr<Column>> checkpointColumns;
for (auto i = 0u; i < columnIDs.size(); i++) {
const auto columnID = columnIDs[i];
checkpointColumns.push_back(std::move(columns[columnID]));
}
columns = std::move(checkpointColumns);
std::vector<Column*> checkpointColumnPtrs;
for (const auto& column : columns) {
checkpointColumnPtrs.push_back(column.get());
}
CSRNodeGroupCheckpointState state{columnIDs, std::move(checkpointColumnPtrs), pageAllocator, mm,
csrHeaderColumns.offset.get(), csrHeaderColumns.length.get()};
state.transaction = snapshotTxn;
nodeGroups->checkpoint(*mm, state);
}
void RelTableData::serialize(Serializer& serializer) const {
nodeGroups->serialize(serializer);
}
void RelTableData::deserialize(Deserializer& deSerializer, MemoryManager& memoryManager) {
nodeGroups->deserialize(deSerializer, memoryManager);
}
const VersionRecordHandler* RelTableData::getVersionRecordHandler(
CSRNodeGroupScanSource source) const {
if (source == CSRNodeGroupScanSource::COMMITTED_PERSISTENT) {
return &persistentVersionRecordHandler;
} else {
DASSERT(source == CSRNodeGroupScanSource::COMMITTED_IN_MEMORY);
return &inMemoryVersionRecordHandler;
}
}
void RelTableData::rollbackGroupCollectionInsert(row_idx_t numRows_, bool isPersistent) {
nodeGroups->rollbackInsert(numRows_, !isPersistent);
}
void RelTableData::reclaimStorage(PageAllocator& pageAllocator) const {
nodeGroups->reclaimStorage(pageAllocator);
}
} }