// lbug 0.16.1
//
// An in-process property graph database management system built for query speed and scalability.
// Documentation
#include "storage/local_storage/local_rel_table.h"

#include <algorithm>
#include <numeric>

#include "common/enums/rel_direction.h"
#include "storage/table/rel_table.h"
#include "transaction/transaction.h"

using namespace lbug::common;
using namespace lbug::transaction;

namespace lbug {
namespace storage {

// Builds the column type list of the local rel table: two INTERNAL_ID columns (src and dst node
// IDs) followed by a copy of every property type of `table`, in catalog order.
static std::vector<LogicalType> getTypesForLocalRelTable(const catalog::TableCatalogEntry& table) {
    constexpr auto numNodeIDColumns = 2u;
    std::vector<LogicalType> columnTypes;
    columnTypes.reserve(table.getNumProperties() + numNodeIDColumns);
    // The src and dst node ID columns always come first.
    for (auto i = 0u; i < numNodeIDColumns; i++) {
        columnTypes.push_back(LogicalType::INTERNAL_ID());
    }
    for (auto& propertyDef : table.getProperties()) {
        columnTypes.push_back(propertyDef.getType().copy());
    }
    return columnTypes;
}

// Creates the local node group that holds uncommitted rels (columns as produced by
// getTypesForLocalRelTable) and one directed index per storage direction of the rel table.
LocalRelTable::LocalRelTable(const catalog::TableCatalogEntry* tableEntry, const Table& table,
    MemoryManager& mm)
    : LocalTable{table} {
    auto columnTypes = getTypesForLocalRelTable(*tableEntry);
    localNodeGroup =
        std::make_unique<NodeGroup>(mm, 0, false, std::move(columnTypes), INVALID_ROW_IDX);
    for (auto storageDirection : table.cast<RelTable>().getStorageDirections()) {
        directedIndices.emplace_back(storageDirection);
    }
}

// Appends the rels described by `state` to the local table.
//
// The work happens in four phases whose order matters:
//   1. Reject the whole batch (return false, nothing modified) if any bound node ID is null.
//   2. Assign a local rel ID to every row: MAX_NUM_ROWS_IN_TABLE + local row index.
//   3. Append src ID, dst ID and all property vectors to the local node group.
//   4. Register every new local row in each directed index under its bound node offset.
// Returns true once the rows have been appended and indexed.
bool LocalRelTable::insert(Transaction*, TableInsertState& state) {
    const auto& relInsertState = state.cast<RelTableInsertState>();
    // The first property vector carries the rel IDs.
    const auto relIDs = relInsertState.propertyVectors[0];
    DASSERT(relIDs->dataType.getPhysicalType() == PhysicalTypeID::INTERNAL_ID);
    const auto rowCount = relIDs->state->getSelVector().getSelSize();

    // Maps the logical row to a position in `vec`: a flat vector always reads its single
    // selected position, an unflat one reads the row-th selected position.
    const auto resolvePos = [](const ValueVector& vec, uint64_t row) -> sel_t {
        const auto& sel = vec.state->getSelVector();
        if (!vec.state->isFlat()) {
            return sel[row];
        }
        return sel[0];
    };

    // Phase 1: validate before touching any state.
    for (uint64_t row = 0; row < rowCount; ++row) {
        for (auto& csrIndex : directedIndices) {
            const auto& boundNodeIDs = relInsertState.getBoundNodeIDVector(csrIndex.direction);
            if (boundNodeIDs.isNull(resolvePos(boundNodeIDs, row))) {
                return false;
            }
        }
    }

    // Phase 2: write the local rel IDs into the rel ID vector.
    const auto firstNewRow = localNodeGroup->getNumRows();
    for (uint64_t row = 0; row < rowCount; ++row) {
        const auto pos = resolvePos(*relIDs, row);
        const auto localOffset = StorageConstants::MAX_NUM_ROWS_IN_TABLE + firstNewRow + row;
        relIDs->setValue<internalID_t>(pos, internalID_t{localOffset, table.getTableID()});
        relIDs->setNull(pos, false);
    }

    // Phase 3: src ID, dst ID, then every property vector (rel ID included).
    std::vector<ValueVector*> columnsToAppend;
    columnsToAppend.push_back(&relInsertState.srcNodeIDVector);
    columnsToAppend.push_back(&relInsertState.dstNodeIDVector);
    for (uint64_t col = 0; col < relInsertState.propertyVectors.size(); ++col) {
        columnsToAppend.push_back(relInsertState.propertyVectors[col]);
    }
    localNodeGroup->append(&DUMMY_TRANSACTION, columnsToAppend, 0, rowCount);

    // Phase 4: make the new rows reachable from each stored direction.
    for (uint64_t row = 0; row < rowCount; ++row) {
        const auto localRow = static_cast<row_idx_t>(firstNewRow + row);
        for (auto& csrIndex : directedIndices) {
            const auto& boundNodeIDs = relInsertState.getBoundNodeIDVector(csrIndex.direction);
            const auto boundOffset = boundNodeIDs.readNodeOffset(resolvePos(boundNodeIDs, row));
            csrIndex.index[boundOffset].push_back(localRow);
        }
    }

    return true;
}

// Updates one property of a locally inserted rel.
//
// The rel is located by intersecting the row lists of its bound nodes in every directed index
// and then matching the rel ID offset (see findMatchingRow). Returns false when a bound node ID
// or the rel ID is null, or when no local row matches; returns true after the update is applied.
bool LocalRelTable::update(Transaction* transaction, TableUpdateState& state) {
    DASSERT(transaction->isDummy());
    const auto& relUpdateState = state.cast<RelTableUpdateState>();

    // One candidate row list per stored direction.
    std::vector<row_idx_vec_t*> candidateRowLists;
    for (auto& csrIndex : directedIndices) {
        const auto& boundNodeIDs = relUpdateState.getBoundNodeIDVector(csrIndex.direction);
        DASSERT(boundNodeIDs.state->getSelVector().getSelSize() == 1);
        const auto boundPos = boundNodeIDs.state->getSelVector()[0];
        if (boundNodeIDs.isNull(boundPos)) {
            return false;
        }
        const auto boundOffset = boundNodeIDs.readNodeOffset(boundPos);
        candidateRowLists.push_back(&csrIndex.index[boundOffset]);
    }

    const auto& relIDs = relUpdateState.relIDVector;
    const auto relIDPos = relIDs.state->getSelVector()[0];
    if (relIDs.isNull(relIDPos)) {
        return false;
    }
    const auto targetRow =
        findMatchingRow(transaction, candidateRowLists, relIDs.readNodeOffset(relIDPos));
    if (targetRow == INVALID_ROW_IDX) {
        return false;
    }

    DASSERT(relUpdateState.columnID != NBR_ID_COLUMN_ID);
    // The direction is only consulted for NBR_ID_COLUMN_ID, which is excluded above, so FWD is
    // just a placeholder here.
    const auto localColumnID =
        rewriteLocalColumnID(RelDataDirection::FWD, relUpdateState.columnID);
    localNodeGroup->update(transaction, targetRow, localColumnID, relUpdateState.propertyVector);
    return true;
}

// Removes a locally inserted rel from the directed indices.
//
// The detach-delete direction is paired with the src node ID vector and the opposite direction
// with the dst node ID vector. A direction only takes part when its key index is within
// `directedIndices`, i.e. when that direction is actually stored.
//
// Returns false when a participating node ID or the rel ID is null, or when no local row
// matches the rel ID; returns true after the matched row has been erased from every
// participating index entry. Only index entries are erased here; this function does not touch
// the row inside `localNodeGroup`.
bool LocalRelTable::delete_(Transaction* transaction, TableDeleteState& state) {
    const auto& deleteState = state.cast<RelTableDeleteState>();

    // `directedIndices` must only be subscripted after the bounds check: with a single stored
    // direction one of the two key indices is out of range, and forming a reference through
    // operator[] for it would be undefined behaviour.
    std::vector<std::pair<DirectedCSRIndex&, ValueVector&>> directedIndicesAndNodeIDVectors;
    const auto directedIndexPos =
        RelDirectionUtils::relDirectionToKeyIdx(deleteState.detachDeleteDirection);
    if (directedIndexPos < directedIndices.size()) {
        directedIndicesAndNodeIDVectors.emplace_back(directedIndices[directedIndexPos],
            deleteState.srcNodeIDVector);
    }
    const auto reverseDirectedIndexPos = RelDirectionUtils::relDirectionToKeyIdx(
        RelDirectionUtils::getOppositeDirection(deleteState.detachDeleteDirection));
    if (reverseDirectedIndexPos < directedIndices.size()) {
        directedIndicesAndNodeIDVectors.emplace_back(directedIndices[reverseDirectedIndexPos],
            deleteState.dstNodeIDVector);
    }

    // Collect the row list of the bound node in every participating direction.
    std::vector<row_idx_vec_t*> rowIndicesToDeleteFrom;
    rowIndicesToDeleteFrom.reserve(directedIndicesAndNodeIDVectors.size());
    for (auto& [csrIndex, nodeIDVector] : directedIndicesAndNodeIDVectors) {
        DASSERT(nodeIDVector.state->getSelVector().getSelSize() == 1);
        const auto nodePos = nodeIDVector.state->getSelVector()[0];
        if (nodeIDVector.isNull(nodePos)) {
            return false;
        }
        const auto nodeOffset = nodeIDVector.readNodeOffset(nodePos);
        DASSERT(csrIndex.index.contains(nodeOffset));
        rowIndicesToDeleteFrom.push_back(&csrIndex.index[nodeOffset]);
    }

    const auto relIDPos = deleteState.relIDVector.state->getSelVector()[0];
    if (deleteState.relIDVector.isNull(relIDPos)) {
        return false;
    }
    const auto relOffset = deleteState.relIDVector.readNodeOffset(relIDPos);
    const auto matchedRow = findMatchingRow(transaction, rowIndicesToDeleteFrom, relOffset);
    if (matchedRow == INVALID_ROW_IDX) {
        return false;
    }

    for (auto* rowIndexToDeleteFrom : rowIndicesToDeleteFrom) {
        std::erase(*rowIndexToDeleteFrom, matchedRow);
    }
    return true;
}

// Adds the new column to the local node group. No file handle and no column statistics are
// supplied (both arguments are null). Always returns true.
bool LocalRelTable::addColumn(TableAddColumnState& addColumnState) {
    localNodeGroup->addColumn(addColumnState,
        /* FileHandle */ nullptr,
        /* newColumnStats */ nullptr);
    return true;
}

// Returns true iff the node selected in the (flat) `srcNodeIDVector` has at least one local rel
// registered in the index of `direction`. An index entry with an empty row list counts as
// "no rels".
bool LocalRelTable::checkIfNodeHasRels(ValueVector* srcNodeIDVector,
    RelDataDirection direction) const {
    DASSERT(srcNodeIDVector->state->isFlat());
    const auto selectedPos = srcNodeIDVector->state->getSelVector()[0];
    const auto offsetOfNode = srcNodeIDVector->getValue<nodeID_t>(selectedPos).offset;
    const auto& csrIndex =
        directedIndices[RelDirectionUtils::relDirectionToKeyIdx(direction)].index;
    const auto entry = csrIndex.find(offsetOfNode);
    if (entry == csrIndex.end()) {
        return false;
    }
    return !entry->second.empty();
}

// Resets the local scan state attached to `state`: drops the cached row list and rewinds the
// scan cursor. Expects an UNCOMMITTED-source scan state with a local table scan state set.
void LocalRelTable::initializeScan(TableScanState& state) {
    auto& scanState = state.cast<RelTableScanState>();
    DASSERT(scanState.source == TableScanSource::UNCOMMITTED);
    DASSERT(scanState.localTableScanState);
    auto* localState = scanState.localTableScanState.operator->();
    localState->nextRowToScan = 0;
    localState->rowIndices.clear();
}

// Translates every entry of `columnIDs` to its column ID in the local table via
// rewriteLocalColumnID, preserving order.
std::vector<column_id_t> LocalRelTable::rewriteLocalColumnIDs(RelDataDirection direction,
    const std::vector<column_id_t>& columnIDs) {
    std::vector<column_id_t> rewritten;
    rewritten.reserve(columnIDs.size());
    for (const auto columnID : columnIDs) {
        rewritten.push_back(rewriteLocalColumnID(direction, columnID));
    }
    return rewritten;
}

// Translates a rel table column ID to the matching local table column ID.
//  - NBR_ID_COLUMN_ID maps to LOCAL_NBR_NODE_ID_COLUMN_ID when scanning FWD and to
//    LOCAL_BOUND_NODE_ID_COLUMN_ID for any other direction.
//  - Every other column is shifted up by one.
column_id_t LocalRelTable::rewriteLocalColumnID(RelDataDirection direction, column_id_t columnID) {
    if (columnID != NBR_ID_COLUMN_ID) {
        return columnID + 1;
    }
    if (direction == RelDataDirection::FWD) {
        return LOCAL_NBR_NODE_ID_COLUMN_ID;
    }
    return LOCAL_BOUND_NODE_ID_COLUMN_ID;
}

// Produces the next batch of local rels for the bound nodes cached in `state`.
//
// The scan is resumable: `currBoundNodeIdx` (which bound node is being processed) and the local
// scan state's `rowIndices` / `nextRowToScan` (which of that node's rows have been emitted) are
// carried across calls. Each successful call emits at most DEFAULT_VECTOR_CAPACITY rows that all
// belong to a single bound node.
//
// Returns true when a batch was looked up, false once every cached bound node is exhausted.
bool LocalRelTable::scan(const Transaction* transaction, TableScanState& state) const {
    auto& relScanState = state.cast<RelTableScanState>();
    DASSERT(relScanState.localTableScanState);
    auto& localScanState = *relScanState.localTableScanState;
    while (true) {
        // All bound nodes have been processed: nothing more to scan.
        if (relScanState.currBoundNodeIdx >= relScanState.cachedBoundNodeSelVector.getSelSize()) {
            return false;
        }
        const auto boundNodePos =
            relScanState.cachedBoundNodeSelVector[relScanState.currBoundNodeIdx];
        const auto boundNodeOffset = relScanState.nodeIDVector->readNodeOffset(boundNodePos);
        auto& localCSRIndex =
            directedIndices[RelDirectionUtils::relDirectionToKeyIdx(relScanState.direction)].index;
        // Load this node's row list only when no list is cached yet; a non-empty cache means we
        // are resuming a partially scanned node.
        if (localScanState.rowIndices.empty() && localCSRIndex.contains(boundNodeOffset)) {
            localScanState.rowIndices = localCSRIndex.at(boundNodeOffset);
            localScanState.nextRowToScan = 0;
            DASSERT(
                std::is_sorted(localScanState.rowIndices.begin(), localScanState.rowIndices.end()));
        }
        DASSERT(localScanState.rowIndices.size() >= localScanState.nextRowToScan);
        const auto numToScan =
            std::min(localScanState.rowIndices.size() - localScanState.nextRowToScan,
                DEFAULT_VECTOR_CAPACITY);
        // Either the node has no local rels or all of them were already emitted: reset the
        // per-node state and move on to the next bound node.
        if (numToScan == 0) {
            relScanState.currBoundNodeIdx++;
            localScanState.nextRowToScan = 0;
            localScanState.rowIndices.clear();
            continue;
        }
        // Stage the row indices of this batch for the node group lookup.
        for (auto i = 0u; i < numToScan; i++) {
            localScanState.rowIdxVector->setValue<row_idx_t>(i,
                localScanState.rowIndices[localScanState.nextRowToScan + i]);
        }
        localScanState.outState->setToUnflat();
        localScanState.rowIdxVector->state->getSelVectorUnsafe().setToUnfiltered(numToScan);
        // The lookup result is intentionally not inspected here.
        [[maybe_unused]] auto lookupRes =
            localNodeGroup->lookupMultiple(transaction, localScanState);
        localScanState.nextRowToScan += numToScan;
        // Point the node ID vector at the bound node this batch belongs to.
        relScanState.setNodeIDVectorToFlat(
            relScanState.cachedBoundNodeSelVector[relScanState.currBoundNodeIdx]);
        return true;
    }
}

// Builds a scan state that reads only the local rel ID column for the rows in `intersectRows`.
// The first vector of `scanChunk` is used as the output vector, the chunk's selection size is
// set to the number of rows, and the row indices are written into the state's row index vector.
static std::unique_ptr<RelTableScanState> setupLocalTableScanState(DataChunk& scanChunk,
    std::span<row_idx_t> intersectRows) {
    const std::vector relIDColumnOnly{LOCAL_REL_ID_COLUMN_ID};
    const auto numRows = intersectRows.size();

    auto localScanState = std::make_unique<RelTableScanState>(nullptr,
        std::vector{&scanChunk.getValueVectorMutable(0)}, scanChunk.state);
    localScanState->columnIDs = relIDColumnOnly;
    localScanState->nodeGroupScanState->chunkStates.resize(relIDColumnOnly.size());

    scanChunk.state->getSelVectorUnsafe().setSelSize(numRows);
    uint64_t outPos = 0;
    for (const auto rowIdx : intersectRows) {
        localScanState->rowIdxVector->setValue<row_idx_t>(outPos, rowIdx);
        ++outPos;
    }
    return localScanState;
}

// Finds the local row that holds the rel with offset `relOffset`.
//
// `rowIndicesToCheck` holds one row list per direction (the local rows attached to the rel's
// bound node in that direction). The candidate rows are the intersection of all lists; their
// rel IDs are then read from the local node group in batches of DEFAULT_VECTOR_CAPACITY and
// compared against `relOffset`.
//
// Note: every list in `rowIndicesToCheck` is sorted in place as a side effect.
//
// Returns the matching row index, or INVALID_ROW_IDX when there is no list to check or no
// candidate row carries the requested rel offset.
row_idx_t LocalRelTable::findMatchingRow(const Transaction* transaction,
    const std::vector<row_idx_vec_t*>& rowIndicesToCheck, offset_t relOffset) const {
    // Without any row list there is nothing to intersect; bail out instead of dereferencing a
    // non-existent first element.
    if (rowIndicesToCheck.empty()) {
        return INVALID_ROW_IDX;
    }
    // std::set_intersection requires sorted inputs.
    for (auto* rowIndex : rowIndicesToCheck) {
        std::sort(rowIndex->begin(), rowIndex->end());
    }
    // Start from the first list and narrow it down with each remaining one.
    std::vector<row_idx_t> intersectRows = *rowIndicesToCheck.front();
    for (uint64_t listIdx = 1; listIdx < rowIndicesToCheck.size(); ++listIdx) {
        const auto* rowIndex = rowIndicesToCheck[listIdx];
        std::vector<row_idx_t> narrowed;
        std::set_intersection(intersectRows.begin(), intersectRows.end(), rowIndex->begin(),
            rowIndex->end(), std::back_inserter(narrowed));
        intersectRows.swap(narrowed);
    }
    // Loop over relID column chunks to find the relID.
    const auto numVectorsToScan =
        ceilDiv(static_cast<uint64_t>(intersectRows.size()), DEFAULT_VECTOR_CAPACITY);
    for (uint64_t vectorIdx = 0; vectorIdx < numVectorsToScan; ++vectorIdx) {
        DataChunk scanChunk(1);
        scanChunk.insert(0, std::make_shared<ValueVector>(LogicalType::INTERNAL_ID()));

        const uint64_t startRowToScan = vectorIdx * DEFAULT_VECTOR_CAPACITY;
        const auto endRowToScan = std::min(startRowToScan + DEFAULT_VECTOR_CAPACITY,
            static_cast<uint64_t>(intersectRows.size()));
        std::span currentRowsToCheck{intersectRows.begin() + startRowToScan,
            intersectRows.begin() + endRowToScan};
        const auto scanState = setupLocalTableScanState(scanChunk, currentRowsToCheck);

        // The lookup result is intentionally not inspected; the scanned rel IDs are read from
        // the output vector below.
        [[maybe_unused]] auto lookupRes = localNodeGroup->lookupMultiple(transaction, *scanState);
        const auto scannedRelIDVector = scanState->outputVectors[0];
        DASSERT(
            scannedRelIDVector->state->getSelVector().getSelSize() == currentRowsToCheck.size());
        for (auto i = 0u; i < currentRowsToCheck.size(); i++) {
            if (scannedRelIDVector->getValue<internalID_t>(i).offset == relOffset) {
                return currentRowsToCheck[i];
            }
        }
    }
    return INVALID_ROW_IDX;
}

} // namespace storage
} // namespace lbug