#include "processor/operator/scan/count_rel_table.h"
#include "common/system_config.h"
#include "main/client_context.h"
#include "main/database.h"
#include "processor/execution_context.h"
#include "storage/buffer_manager/memory_manager.h"
#include "storage/local_storage/local_rel_table.h"
#include "storage/local_storage/local_storage.h"
#include "storage/table/column.h"
#include "storage/table/column_chunk_data.h"
#include "storage/table/csr_chunked_node_group.h"
#include "storage/table/csr_node_group.h"
#include "storage/table/rel_table_data.h"
#include "transaction/transaction.h"
using namespace lbug::common;
using namespace lbug::storage;
using namespace lbug::transaction;
namespace lbug {
namespace processor {
// Initializes this operator's local state before execution: caches the output vector the
// count is written to and resets the "already emitted" flag and the running total.
// The execution context is not used here.
void CountRelTable::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* /*context*/) {
    // Keep a raw pointer to the vector found at countOutputPos in the result set.
    auto&& outputVector = resultSet->getValueVector(countOutputPos);
    countVector = outputVector.get();
    // Nothing has been counted or emitted yet.
    totalCount = 0;
    hasExecuted = false;
}
// Emits a single tuple holding the number of relationships found in all `relTables` for
// `direction`. Returns true on the first call (after writing the count into slot 0 of
// `countVector`) and false on every subsequent call.
//
// For each rel table the count is accumulated into `totalCount` from three sources:
//   1. the persistent chunked group of each CSR node group: sum of the per-node CSR lengths,
//      minus the deletions reported by getNumDeletions for the current transaction;
//   2. the in-memory chunked groups of each CSR node group: their row counts, minus the
//      deletions reported by getNumDeletions for the current transaction;
//   3. for write transactions only, the transaction-local rel table: the number of row
//      indices stored in its CSR index for `direction`.
bool CountRelTable::getNextTuplesInternal(ExecutionContext* context) {
    if (hasExecuted) {
        return false;
    }
    auto transaction = Transaction::Get(*context->clientContext);
    auto* memoryManager = context->clientContext->getDatabase()->getMemoryManager();
    for (auto* relTable : relTables) {
        auto* relTableData = relTable->getDirectedTableData(direction);
        auto numNodeGroups = relTableData->getNumNodeGroups();
        auto* csrLengthColumn = relTableData->getCSRLengthColumn();
        for (node_group_idx_t nodeGroupIdx = 0; nodeGroupIdx < numNodeGroups; nodeGroupIdx++) {
            auto* nodeGroup = relTableData->getNodeGroup(nodeGroupIdx);
            if (!nodeGroup) {
                continue;
            }
            auto& csrNodeGroup = nodeGroup->cast<CSRNodeGroup>();

            // Source 1: persistent data. The relationship count is taken from the CSR header's
            // per-node lengths rather than from the group's row count.
            if (auto* persistentGroup = csrNodeGroup.getPersistentChunkedGroup()) {
                auto& csrPersistentGroup = persistentGroup->cast<ChunkedCSRNodeGroup>();
                auto& csrHeader = csrPersistentGroup.getCSRHeader();
                auto numNodes = csrHeader.length->getNumValues();
                // An empty header contributes nothing. Do NOT `continue` here: that would
                // also skip the in-memory chunked groups of this node group handled below.
                if (numNodes > 0) {
                    // Scan the length column into a temporary in-memory UINT64 chunk so the
                    // values can be read as a plain array.
                    auto lengthChunk = ColumnChunkFactory::createColumnChunkData(*memoryManager,
                        LogicalType::UINT64(), false, StorageConfig::NODE_GROUP_SIZE,
                        ResidencyState::IN_MEMORY, false);
                    ChunkState chunkState;
                    csrHeader.length->initializeScanState(chunkState, csrLengthColumn);
                    csrLengthColumn->scan(chunkState, lengthChunk.get(), 0, numNodes);
                    auto* lengthData = reinterpret_cast<const uint64_t*>(lengthChunk->getData());
                    row_idx_t groupRelCount = 0;
                    for (offset_t i = 0; i < numNodes; ++i) {
                        groupRelCount += lengthData[i];
                    }
                    totalCount += groupRelCount;
                    if (persistentGroup->hasVersionInfo()) {
                        // NOTE(review): the deletion range is bounded by the summed lengths,
                        // not by the group's row count - confirm this matches the row layout
                        // of the persistent group.
                        auto numDeletions =
                            persistentGroup->getNumDeletions(transaction, 0, groupRelCount);
                        totalCount -= numDeletions;
                    }
                }
            }

            // Source 2: in-memory chunked groups of this node group.
            auto numChunkedGroups = csrNodeGroup.getNumChunkedGroups();
            for (node_group_idx_t i = 0; i < numChunkedGroups; i++) {
                auto* chunkedGroup = csrNodeGroup.getChunkedNodeGroup(i);
                if (chunkedGroup) {
                    auto numRows = chunkedGroup->getNumRows();
                    totalCount += numRows;
                    if (chunkedGroup->hasVersionInfo()) {
                        auto numDeletions = chunkedGroup->getNumDeletions(transaction, 0, numRows);
                        totalCount -= numDeletions;
                    }
                }
            }
        }

        // Source 3: rows held in the transaction's local storage for this table. Only
        // consulted for write transactions.
        if (transaction->isWriteTransaction()) {
            if (auto* localTable =
                    transaction->getLocalStorage()->getLocalTable(relTable->getTableID())) {
                auto& localRelTable = localTable->cast<LocalRelTable>();
                auto& csrIndex = localRelTable.getCSRIndex(direction);
                for (const auto& [nodeOffset, rowIndices] : csrIndex) {
                    totalCount += rowIndices.size();
                }
            }
        }
    }
    hasExecuted = true;
    // Publish exactly one value: the total, stored as a signed 64-bit integer.
    countVector->state->getSelVectorUnsafe().setToUnfiltered(1);
    countVector->setValue<int64_t>(0, static_cast<int64_t>(totalCount));
    return true;
}
} }