#include "storage/storage_manager.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "common/arrow/arrow.h"
#include "common/file_system/virtual_file_system.h"
#include "common/random_engine.h"
#include "common/serializer/in_mem_file_writer.h"
#include "main/attached_database.h"
#include "main/client_context.h"
#include "main/database.h"
#include "main/database_manager.h"
#include "main/db_config.h"
#include "storage/buffer_manager/buffer_manager.h"
#include "storage/buffer_manager/memory_manager.h"
#include "storage/checkpointer.h"
#include "storage/table/arrow_node_table.h"
#include "storage/table/arrow_rel_table.h"
#include "storage/table/arrow_table_support.h"
#include "storage/table/foreign_rel_table.h"
#include "storage/table/node_table.h"
#include "storage/table/parquet_node_table.h"
#include "storage/table/parquet_rel_table.h"
#include "storage/table/rel_table.h"
#include "storage/wal/wal_replayer.h"
#include "transaction/transaction.h"
#include <format>
using namespace lbug::catalog;
using namespace lbug::common;
using namespace lbug::transaction;
namespace lbug {
namespace storage {
StorageManager::StorageManager(const std::string& databasePath, bool readOnly, bool enableChecksums,
MemoryManager& memoryManager, bool enableCompression, VirtualFileSystem* vfs)
: databasePath{databasePath}, readOnly{readOnly}, dataFH{nullptr}, memoryManager{memoryManager},
enableCompression{enableCompression} {
wal = std::make_unique<WAL>(databasePath, readOnly, enableChecksums, vfs);
shadowFile =
std::make_unique<ShadowFile>(*memoryManager.getBufferManager(), vfs, this->databasePath);
inMemory = main::DBConfig::isDBPathInMemory(databasePath);
registerIndexType(PrimaryKeyIndex::getIndexType());
}
StorageManager::~StorageManager() = default;
void StorageManager::initDataFileHandle(VirtualFileSystem* vfs, main::ClientContext* context) {
if (inMemory) {
dataFH = memoryManager.getBufferManager()->getFileHandle(databasePath,
FileHandle::O_PERSISTENT_FILE_IN_MEM, vfs, context);
} else {
auto flag = readOnly ? FileHandle::O_PERSISTENT_FILE_READ_ONLY :
FileHandle::O_PERSISTENT_FILE_CREATE_NOT_EXISTS;
if (!readOnly) {
flag |= FileHandle::O_LOCKED_PERSISTENT_FILE;
}
dataFH = memoryManager.getBufferManager()->getFileHandle(databasePath, flag, vfs, context);
if (dataFH->getNumPages() == 0) {
if (!readOnly) {
dataFH->getPageManager()->allocatePage();
const auto* initialHeader = getOrInitDatabaseHeader(*context);
auto headerWriter =
std::make_shared<InMemFileWriter>(*MemoryManager::Get(*context));
Serializer headerSerializer(headerWriter);
initialHeader->serialize(headerSerializer);
dataFH->getFileInfo()->writeFile(headerWriter->getPage(0).data(), LBUG_PAGE_SIZE,
StorageConstants::DB_HEADER_PAGE_IDX);
dataFH->getFileInfo()->syncFile();
}
}
}
}
void StorageManager::closeFileHandle() {
if (dataFH != nullptr) {
dataFH->resetFileInfo();
}
}
Table* StorageManager::getTable(table_id_t tableID) {
std::shared_lock lck{mtx};
DASSERT(tables.contains(tableID));
return tables.at(tableID).get();
}
void StorageManager::recover(main::ClientContext& clientContext, bool throwOnWalReplayFailure,
bool enableChecksums) {
const auto walReplayer = std::make_unique<WALReplayer>(clientContext);
walReplayer->replay(throwOnWalReplayFailure, enableChecksums);
}
void StorageManager::createNodeTable(NodeTableCatalogEntry* entry) {
tableNameCache[entry->getTableID()] = entry->getName();
if (!entry->getStorage().empty()) {
if (entry->getStorage().substr(0, 8) == "arrow://") {
std::string arrowId = entry->getStorage().substr(8);
ArrowSchemaWrapper* schema = nullptr;
std::vector<ArrowArrayWrapper>* arrays = nullptr;
if (!ArrowTableSupport::getArrowData(arrowId, schema, arrays)) {
throw common::RuntimeException("Failed to retrieve Arrow data for ID: " + arrowId);
}
ArrowSchemaWrapper schemaCopy = createShallowCopy(*schema);
std::vector<ArrowArrayWrapper> arraysCopy;
arraysCopy.reserve(arrays->size());
for (const auto& arr : *arrays) {
arraysCopy.push_back(createShallowCopy(arr));
}
tables[entry->getTableID()] = std::make_unique<ArrowNodeTable>(this, entry,
&memoryManager, std::move(schemaCopy), std::move(arraysCopy), arrowId);
} else {
tables[entry->getTableID()] =
std::make_unique<ParquetNodeTable>(this, entry, &memoryManager);
}
} else {
tables[entry->getTableID()] = std::make_unique<NodeTable>(this, entry, &memoryManager);
}
}
void StorageManager::addRelTable(RelGroupCatalogEntry* entry, const RelTableCatalogInfo& info) {
if (entry->getScanFunction().has_value()) {
tables[info.oid] = std::make_unique<ForeignRelTable>(entry, info.nodePair.srcTableID,
info.nodePair.dstTableID, this, &memoryManager, *entry->getScanFunction(),
std::move(entry->getScanBindData().value()));
} else if (!entry->getStorage().empty()) {
if (entry->getStorage().substr(0, 8) == "arrow://") {
std::string arrowId = entry->getStorage().substr(8);
ArrowSchemaWrapper* schema = nullptr;
std::vector<ArrowArrayWrapper>* arrays = nullptr;
if (!ArrowTableSupport::getArrowData(arrowId, schema, arrays)) {
throw common::RuntimeException("Failed to retrieve Arrow data for ID: " + arrowId);
}
if (!tables.contains(info.nodePair.srcTableID) ||
!tables.contains(info.nodePair.dstTableID)) {
throw common::RuntimeException(
"Source or destination node table is not initialized for Arrow rel table");
}
auto* fromNodeTable = tables.at(info.nodePair.srcTableID)->ptrCast<NodeTable>();
auto* toNodeTable = tables.at(info.nodePair.dstTableID)->ptrCast<NodeTable>();
if (!fromNodeTable || !toNodeTable) {
throw common::RuntimeException(
"Arrow rel table currently supports only regular node tables");
}
ArrowSchemaWrapper schemaCopy = createShallowCopy(*schema);
std::vector<ArrowArrayWrapper> arraysCopy;
arraysCopy.reserve(arrays->size());
for (const auto& arr : *arrays) {
arraysCopy.push_back(createShallowCopy(arr));
}
tables[info.oid] = std::make_unique<ArrowRelTable>(entry, info.nodePair.srcTableID,
info.nodePair.dstTableID, this, &memoryManager, fromNodeTable, toNodeTable,
std::move(schemaCopy), std::move(arraysCopy), arrowId);
} else {
tables[info.oid] = std::make_unique<ParquetRelTable>(entry, info.nodePair.srcTableID,
info.nodePair.dstTableID, this, &memoryManager);
}
} else {
tables[info.oid] = std::make_unique<RelTable>(entry, info.nodePair.srcTableID,
info.nodePair.dstTableID, this, &memoryManager);
}
}
void StorageManager::createRelTableGroup(RelGroupCatalogEntry* entry) {
for (auto& info : entry->getRelEntryInfos()) {
addRelTable(entry, info);
}
}
void StorageManager::createTable(TableCatalogEntry* entry) {
std::unique_lock lck{mtx};
switch (entry->getType()) {
case CatalogEntryType::NODE_TABLE_ENTRY: {
createNodeTable(entry->ptrCast<NodeTableCatalogEntry>());
} break;
case CatalogEntryType::REL_GROUP_ENTRY: {
createRelTableGroup(entry->ptrCast<RelGroupCatalogEntry>());
} break;
default: {
UNREACHABLE_CODE;
}
}
}
WAL& StorageManager::getWAL() const {
DASSERT(wal);
return *wal;
}
ShadowFile& StorageManager::getShadowFile() const {
DASSERT(shadowFile);
return *shadowFile;
}
void StorageManager::reclaimDroppedTables(const Catalog& catalog) {
std::unique_lock lck{mtx};
std::vector<table_id_t> droppedTables;
for (const auto& [tableID, table] : tables) {
switch (table->getTableType()) {
case TableType::NODE: {
if (!catalog.containsTable(&DUMMY_CHECKPOINT_TRANSACTION, tableID, true)) {
table->reclaimStorage(*dataFH->getPageManager());
droppedTables.push_back(tableID);
}
} break;
case TableType::REL: {
auto& relTable = table->cast<RelTable>();
auto relGroupID = relTable.getRelGroupID();
if (!catalog.containsTable(&DUMMY_CHECKPOINT_TRANSACTION, relGroupID, true)) {
table->reclaimStorage(*dataFH->getPageManager());
droppedTables.push_back(tableID);
} else {
auto relGroupEntry =
catalog.getTableCatalogEntry(&DUMMY_CHECKPOINT_TRANSACTION, relGroupID);
if (!relGroupEntry->cast<RelGroupCatalogEntry>().getRelEntryInfo(
relTable.getFromNodeTableID(), relTable.getToNodeTableID())) {
table->reclaimStorage(*dataFH->getPageManager());
droppedTables.push_back(tableID);
}
}
}
default: {
}
}
}
for (auto tableID : droppedTables) {
tables.erase(tableID);
}
}
bool StorageManager::checkpoint(main::ClientContext* context, PageAllocator& pageAllocator) {
bool hasChanges = false;
const auto catalog = Catalog::Get(*context);
const auto nodeTableEntries = catalog->getNodeTableEntries(&DUMMY_CHECKPOINT_TRANSACTION);
const auto relGroupEntries = catalog->getRelGroupEntries(&DUMMY_CHECKPOINT_TRANSACTION);
std::shared_lock lck{mtx};
for (const auto entry : nodeTableEntries) {
if (!tables.contains(entry->getTableID())) {
throw RuntimeException(std::format(
"Checkpoint failed: table {} not found in storage manager.", entry->getName()));
}
hasChanges =
tables.at(entry->getTableID())->checkpoint(context, entry, pageAllocator) || hasChanges;
}
for (const auto entry : relGroupEntries) {
for (auto& info : entry->getRelEntryInfos()) {
if (!tables.contains(info.oid)) {
throw RuntimeException(std::format(
"Checkpoint failed: table {} not found in storage manager.", entry->getName()));
}
hasChanges =
tables.at(info.oid)->checkpoint(context, entry, pageAllocator) || hasChanges;
}
entry->vacuumColumnIDs(1);
}
lck.unlock();
reclaimDroppedTables(*catalog);
return hasChanges;
}
bool StorageManager::checkpoint(main::ClientContext* context, const Transaction& snapshotTxn,
PageAllocator& pageAllocator, const std::unordered_map<table_id_t, uint64_t>& epochWatermarks) {
bool hasChanges = false;
const auto catalog = Catalog::Get(*context);
const auto nodeTableEntries = catalog->getNodeTableEntries(&snapshotTxn);
const auto relGroupEntries = catalog->getRelGroupEntries(&snapshotTxn);
std::shared_lock lck{mtx};
for (const auto entry : nodeTableEntries) {
if (!tables.contains(entry->getTableID())) {
throw RuntimeException(std::format(
"Checkpoint failed: table {} not found in storage manager.", entry->getName()));
}
const auto watermarkIt = epochWatermarks.find(entry->getTableID());
const uint64_t watermark = watermarkIt != epochWatermarks.end() ? watermarkIt->second : 0;
hasChanges = tables.at(entry->getTableID())
->checkpoint(context, entry, pageAllocator, &snapshotTxn, watermark) ||
hasChanges;
}
for (const auto entry : relGroupEntries) {
for (auto& info : entry->getRelEntryInfos()) {
if (!tables.contains(info.oid)) {
throw RuntimeException(std::format(
"Checkpoint failed: table {} not found in storage manager.", entry->getName()));
}
const auto watermarkIt = epochWatermarks.find(info.oid);
const uint64_t watermark =
watermarkIt != epochWatermarks.end() ? watermarkIt->second : 0;
hasChanges = tables.at(info.oid)->checkpoint(context, entry, pageAllocator,
&snapshotTxn, watermark) ||
hasChanges;
}
entry->vacuumColumnIDs(1);
}
lck.unlock();
reclaimDroppedTables(*catalog);
return hasChanges;
}
std::unordered_map<table_id_t, uint64_t> StorageManager::captureChangeEpochs() const {
std::shared_lock lck{mtx};
std::unordered_map<table_id_t, uint64_t> epochs;
for (const auto& [id, table] : tables) {
epochs[id] = table->getChangeEpoch();
}
return epochs;
}
void StorageManager::finalizeCheckpoint() {
dataFH->getPageManager()->finalizeCheckpoint();
}
void StorageManager::rollbackCheckpoint(const Catalog& catalog) {
std::unique_lock lck{mtx};
const auto nodeTableEntries = catalog.getNodeTableEntries(&DUMMY_CHECKPOINT_TRANSACTION);
for (const auto tableEntry : nodeTableEntries) {
DASSERT(tables.contains(tableEntry->getTableID()));
tables.at(tableEntry->getTableID())->rollbackCheckpoint();
}
dataFH->getPageManager()->rollbackCheckpoint();
}
std::optional<std::reference_wrapper<const IndexType>> StorageManager::getIndexType(
const std::string& typeName) const {
for (auto& indexType : registeredIndexTypes) {
if (StringUtils::caseInsensitiveEquals(indexType.typeName, typeName)) {
return indexType;
}
}
return std::nullopt;
}
void StorageManager::serialize(const Catalog& catalog, Serializer& ser) {
std::shared_lock lck{mtx};
auto nodeTableEntries = catalog.getNodeTableEntries(&DUMMY_CHECKPOINT_TRANSACTION);
auto relGroupEntries = catalog.getRelGroupEntries(&DUMMY_CHECKPOINT_TRANSACTION);
std::sort(nodeTableEntries.begin(), nodeTableEntries.end(),
[](const auto& a, const auto& b) { return a->getTableID() < b->getTableID(); });
std::sort(relGroupEntries.begin(), relGroupEntries.end(),
[](const auto& a, const auto& b) { return a->getTableID() < b->getTableID(); });
ser.writeDebuggingInfo("num_node_tables");
ser.write<uint64_t>(nodeTableEntries.size());
for (const auto tableEntry : nodeTableEntries) {
DASSERT(tables.contains(tableEntry->getTableID()));
ser.writeDebuggingInfo("table_id");
ser.write<table_id_t>(tableEntry->getTableID());
tables.at(tableEntry->getTableID())->serialize(ser);
}
ser.writeDebuggingInfo("num_rel_groups");
ser.write<uint64_t>(relGroupEntries.size());
for (const auto entry : relGroupEntries) {
const auto& relGroupEntry = entry->cast<RelGroupCatalogEntry>();
ser.writeDebuggingInfo("rel_group_id");
ser.write<table_id_t>(relGroupEntry.getTableID());
ser.writeDebuggingInfo("num_inner_rel_tables");
ser.write<uint64_t>(relGroupEntry.getNumRelTables());
for (auto& info : relGroupEntry.getRelEntryInfos()) {
DASSERT(tables.contains(info.oid));
info.serialize(ser);
tables.at(info.oid)->serialize(ser);
}
}
}
void StorageManager::serialize(const Catalog& catalog, const Transaction& snapshotTxn,
Serializer& ser) {
auto nodeTableEntries = catalog.getNodeTableEntries(&snapshotTxn);
auto relGroupEntries = catalog.getRelGroupEntries(&snapshotTxn);
std::sort(nodeTableEntries.begin(), nodeTableEntries.end(),
[](const auto& a, const auto& b) { return a->getTableID() < b->getTableID(); });
std::sort(relGroupEntries.begin(), relGroupEntries.end(),
[](const auto& a, const auto& b) { return a->getTableID() < b->getTableID(); });
std::shared_lock lck{mtx};
ser.writeDebuggingInfo("num_node_tables");
ser.write<uint64_t>(nodeTableEntries.size());
for (const auto tableEntry : nodeTableEntries) {
DASSERT(tables.contains(tableEntry->getTableID()));
ser.writeDebuggingInfo("table_id");
ser.write<table_id_t>(tableEntry->getTableID());
tables.at(tableEntry->getTableID())->serialize(ser);
}
ser.writeDebuggingInfo("num_rel_groups");
ser.write<uint64_t>(relGroupEntries.size());
for (const auto entry : relGroupEntries) {
const auto& relGroupEntry = entry->cast<RelGroupCatalogEntry>();
ser.writeDebuggingInfo("rel_group_id");
ser.write<table_id_t>(relGroupEntry.getTableID());
ser.writeDebuggingInfo("num_inner_rel_tables");
ser.write<uint64_t>(relGroupEntry.getNumRelTables());
for (auto& info : relGroupEntry.getRelEntryInfos()) {
DASSERT(tables.contains(info.oid));
info.serialize(ser);
tables.at(info.oid)->serialize(ser);
}
}
}
void StorageManager::deserialize(main::ClientContext* context, const Catalog* catalog,
Deserializer& deSer) {
std::string key;
deSer.validateDebuggingInfo(key, "num_node_tables");
uint64_t numNodeTables = 0;
deSer.deserializeValue<uint64_t>(numNodeTables);
for (auto i = 0u; i < numNodeTables; i++) {
deSer.validateDebuggingInfo(key, "table_id");
table_id_t tableID = INVALID_TABLE_ID;
deSer.deserializeValue<table_id_t>(tableID);
if (!catalog->containsTable(&DUMMY_TRANSACTION, tableID)) {
throw RuntimeException(
std::format("Load table failed: table {} doesn't exist in catalog.", tableID));
}
DASSERT(!tables.contains(tableID));
auto tableEntry = catalog->getTableCatalogEntry(&DUMMY_TRANSACTION, tableID)
->ptrCast<NodeTableCatalogEntry>();
tableNameCache[tableID] = tableEntry->getName();
if (!tableEntry->getStorage().empty()) {
tables[tableID] = std::make_unique<ParquetNodeTable>(this, tableEntry, &memoryManager);
} else {
tables[tableID] = std::make_unique<NodeTable>(this, tableEntry, &memoryManager);
}
tables[tableID]->deserialize(context, this, deSer);
}
deSer.validateDebuggingInfo(key, "num_rel_groups");
uint64_t numRelGroups = 0;
deSer.deserializeValue<uint64_t>(numRelGroups);
for (auto i = 0u; i < numRelGroups; i++) {
deSer.validateDebuggingInfo(key, "rel_group_id");
table_id_t relGroupID = INVALID_TABLE_ID;
deSer.deserializeValue<table_id_t>(relGroupID);
if (!catalog->containsTable(&DUMMY_TRANSACTION, relGroupID)) {
throw RuntimeException(
std::format("Load table failed: table {} doesn't exist in catalog.", relGroupID));
}
deSer.validateDebuggingInfo(key, "num_inner_rel_tables");
uint64_t numInnerRelTables = 0;
deSer.deserializeValue<uint64_t>(numInnerRelTables);
auto relGroupEntry = catalog->getTableCatalogEntry(&DUMMY_TRANSACTION, relGroupID)
->ptrCast<RelGroupCatalogEntry>();
for (auto k = 0u; k < numInnerRelTables; k++) {
RelTableCatalogInfo info = RelTableCatalogInfo::deserialize(deSer);
DASSERT(!tables.contains(info.oid));
if (!relGroupEntry->getStorage().empty()) {
tables[info.oid] = std::make_unique<ParquetRelTable>(relGroupEntry,
info.nodePair.srcTableID, info.nodePair.dstTableID, this, &memoryManager);
} else {
tables[info.oid] = std::make_unique<RelTable>(relGroupEntry,
info.nodePair.srcTableID, info.nodePair.dstTableID, this, &memoryManager);
}
tables.at(info.oid)->deserialize(context, this, deSer);
}
}
}
common::uuid StorageManager::getOrInitDatabaseID(const main::ClientContext& clientContext) {
return getOrInitDatabaseHeader(clientContext)->databaseID;
}
const storage::DatabaseHeader* StorageManager::getOrInitDatabaseHeader(
const main::ClientContext& clientContext) {
if (databaseHeader == nullptr) {
DASSERT(std::nullopt == DatabaseHeader::readDatabaseHeader(*dataFH->getFileInfo()));
databaseHeader = std::make_unique<DatabaseHeader>(
DatabaseHeader::createInitialHeader(RandomEngine::Get(clientContext)));
}
return databaseHeader.get();
}
void StorageManager::setDatabaseHeader(std::unique_ptr<storage::DatabaseHeader> header) {
DASSERT(!databaseHeader || header->databaseID.value == databaseHeader->databaseID.value);
databaseHeader = std::move(header);
}
StorageManager* StorageManager::Get(const main::ClientContext& context) {
if (context.getAttachedDatabase()) {
return context.getAttachedDatabase()->getStorageManager();
}
auto dbManager = main::DatabaseManager::Get(context);
auto graphStorageManager = dbManager->getDefaultGraphStorageManager();
if (graphStorageManager != nullptr) {
return graphStorageManager;
}
return context.getDatabase()->getStorageManager();
}
} }