use crate::{Error, Result, Storable, serialization::helpers, traits::KeyType};
use parking_lot::{Mutex, RwLock};
use rocksdb::{BoundColumnFamily, WriteBatch as RocksWriteBatch};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::path::Path;
use std::sync::Arc;
use tracing::{debug, error, info, instrument, warn};
/// Cheap-to-clone handle to an open RocksDB instance.
///
/// All clones share the same underlying database and shutdown flag via
/// the inner `Arc`.
#[derive(Clone)]
pub struct Database {
    pub(crate) inner: Arc<DatabaseInner>,
}
/// Shared state behind every `Database` clone.
pub(crate) struct DatabaseInner {
    // The live RocksDB handle; also cloned into Collections/Transactions.
    pub(crate) db: Arc<rocksdb::DB>,
    // Set to `true` by `Database::shutdown()`; checked by every public op.
    shutdown: Arc<RwLock<bool>>,
}
impl Database {
    /// Wraps a freshly opened RocksDB handle in the shared, cloneable facade.
    pub(crate) fn new(db: rocksdb::DB) -> Self {
        Self {
            inner: Arc::new(DatabaseInner {
                db: Arc::new(db),
                shutdown: Arc::new(RwLock::new(false)),
            }),
        }
    }
    /// Returns a typed [`Collection`] over the column family `name`.
    ///
    /// # Errors
    /// Fails if the database has been shut down, or if the column family was
    /// not declared in `DatabaseConfig::add_column_family()` before opening.
    #[instrument(skip(self))]
    pub fn collection<T: Storable>(&self, name: &str) -> Result<Collection<T>> {
        // Shared shutdown check (previously duplicated inline here).
        let _guard = self.check_shutdown()?;
        // Probe the handle up front so a bad name fails fast with a helpful message.
        self.inner.db.cf_handle(name).ok_or_else(|| {
            error!("Column family '{}' not found", name);
            Error::Database(format!(
                "Column family '{}' does not exist. Ensure it was declared in DatabaseConfig::add_column_family() before opening the database.",
                name
            ))
        })?;
        debug!("Created collection for column family '{}'", name);
        Ok(Collection::new(
            Arc::clone(&self.inner.db),
            name,
            Arc::clone(&self.inner.shutdown),
        ))
    }
    /// Lists every column family present on disk at this database's path.
    pub fn list_collections(&self) -> Result<Vec<String>> {
        let _guard = self.check_shutdown()?;
        rocksdb::DB::list_cf(&rocksdb::Options::default(), self.inner.db.path())
            .map_err(|e| Error::Database(format!("Failed to list collections: {}", e)))
    }
    /// Flushes all in-memory writes to disk.
    ///
    /// NOTE: deliberately does NOT check the shutdown flag — `shutdown()`
    /// calls this while holding the `shutdown` write lock, and parking_lot
    /// locks are not reentrant, so taking the read lock here would deadlock.
    #[instrument(skip(self))]
    pub fn flush(&self) -> Result<()> {
        info!("Flushing database");
        self.inner.db.flush().map_err(|e| {
            error!("Flush failed: {}", e);
            Error::Database(format!("Flush failed: {}", e))
        })
    }
    /// Compacts the entire keyspace of the default column family.
    #[instrument(skip(self))]
    pub fn compact_all(&self) -> Result<()> {
        // Consistency fix: every other public operation rejects a shut-down DB.
        let _guard = self.check_shutdown()?;
        info!("Compacting entire database");
        self.inner.db.compact_range::<&[u8], &[u8]>(None, None);
        Ok(())
    }
    /// Creates a new backup of the live database under `backup_path`.
    #[instrument(skip(self, backup_path))]
    pub fn backup<P: AsRef<Path>>(&self, backup_path: P) -> Result<()> {
        use rocksdb::backup::{BackupEngine, BackupEngineOptions};
        let _guard = self.check_shutdown()?;
        let path = backup_path.as_ref();
        info!("Creating backup at {:?}", path);
        let backup_opts = BackupEngineOptions::new(path).map_err(|e| {
            error!("Failed to create backup options: {}", e);
            Error::Database(format!("Failed to create backup options: {}", e))
        })?;
        let mut backup_engine =
            BackupEngine::open(&backup_opts, &rocksdb::Env::new()?).map_err(|e| {
                error!("Failed to open backup engine: {}", e);
                Error::Database(format!("Failed to open backup engine: {}", e))
            })?;
        backup_engine
            .create_new_backup(&self.inner.db)
            .map_err(|e| {
                error!("Failed to create backup: {}", e);
                Error::Database(format!("Failed to create backup: {}", e))
            })?;
        info!("Backup created successfully");
        Ok(())
    }
    /// Returns a read guard proving the database is still open, or an error
    /// once `shutdown()` has completed.
    #[inline]
    fn check_shutdown(&self) -> Result<parking_lot::RwLockReadGuard<'_, bool>> {
        let guard = self.inner.shutdown.read();
        if *guard {
            return Err(Error::Database("Database has been shut down".to_string()));
        }
        Ok(guard)
    }
    /// Restores the most recent backup found at `backup_path` into `restore_path`.
    ///
    /// Associated function: the target database must not be open while restoring.
    pub fn restore_from_backup<P: AsRef<Path>>(backup_path: P, restore_path: P) -> Result<()> {
        use rocksdb::backup::{BackupEngine, BackupEngineOptions, RestoreOptions};
        let backup_path = backup_path.as_ref();
        let restore_path = restore_path.as_ref();
        info!(
            "Restoring from backup {:?} to {:?}",
            backup_path, restore_path
        );
        let backup_opts = BackupEngineOptions::new(backup_path).map_err(|e| {
            error!("Failed to create backup options: {}", e);
            Error::Database(format!("Failed to create backup options: {}", e))
        })?;
        let mut backup_engine =
            BackupEngine::open(&backup_opts, &rocksdb::Env::new()?).map_err(|e| {
                error!("Failed to open backup engine: {}", e);
                Error::Database(format!("Failed to open backup engine: {}", e))
            })?;
        let restore_opts = RestoreOptions::default();
        // Both DB and WAL directories restore to the same location.
        backup_engine
            .restore_from_latest_backup(restore_path, restore_path, &restore_opts)
            .map_err(|e| {
                error!("Failed to restore backup: {}", e);
                Error::Database(format!("Failed to restore backup: {}", e))
            })?;
        info!("Backup restored successfully");
        Ok(())
    }
    /// Lists metadata for every backup stored at `backup_path`.
    pub fn list_backups<P: AsRef<Path>>(backup_path: P) -> Result<Vec<BackupInfo>> {
        use rocksdb::backup::{BackupEngine, BackupEngineOptions};
        let path = backup_path.as_ref();
        let backup_opts = BackupEngineOptions::new(path)
            .map_err(|e| Error::Database(format!("Failed to create backup options: {}", e)))?;
        let backup_engine = BackupEngine::open(&backup_opts, &rocksdb::Env::new()?)
            .map_err(|e| Error::Database(format!("Failed to open backup engine: {}", e)))?;
        let infos = backup_engine.get_backup_info();
        Ok(infos
            .iter()
            .map(|info| BackupInfo {
                backup_id: info.backup_id,
                timestamp: info.timestamp,
                size: info.size,
            })
            .collect())
    }
    /// Starts a new transaction that buffers writes and provides
    /// read-your-writes semantics until `commit()`.
    #[instrument(skip(self))]
    pub fn transaction(&self) -> Result<Transaction> {
        // Shared shutdown check (previously duplicated inline here).
        let _guard = self.check_shutdown()?;
        Ok(Transaction::new(
            Arc::clone(&self.inner.db),
            Arc::clone(&self.inner.shutdown),
        ))
    }
    /// Flushes and then marks the database as shut down.
    ///
    /// The shutdown flag is only set if the flush succeeds; on flush failure
    /// the database stays operational and the error is returned.
    #[instrument(skip(self))]
    pub fn shutdown(&self) -> Result<()> {
        info!("Shutting down database");
        // Hold the write lock across the flush so no operation can slip in
        // between the final flush and the flag being set.
        let mut shutdown_guard = self.inner.shutdown.write();
        let flush_result = self.flush();
        if flush_result.is_ok() {
            *shutdown_guard = true;
            info!("Database shutdown complete");
        } else {
            error!("Shutdown failed: flush error, database remains operational");
        }
        flush_result
    }
}
// SAFETY(review): likely redundant — `Database` only holds an
// `Arc<DatabaseInner>`, so the auto impls should already apply whenever the
// inner types are Send + Sync; confirm and consider removing these manual impls.
unsafe impl Send for Database {}
unsafe impl Sync for Database {}
/// Metadata about one stored backup, as reported by the backup engine.
#[derive(Debug, Clone)]
pub struct BackupInfo {
    // Engine-assigned identifier of the backup.
    pub backup_id: u32,
    // Creation time as reported by RocksDB (presumably Unix seconds — confirm).
    pub timestamp: i64,
    // Total size of the backup in bytes.
    pub size: u64,
}
/// Typed view over a single column family, storing values of type `T`.
#[derive(Debug)]
pub struct Collection<T: Storable> {
    db: Arc<rocksdb::DB>,
    // Name of the bound column family; the handle is re-resolved per call.
    cf_name: String,
    // Shared flag set by `Database::shutdown()`.
    shutdown: Arc<RwLock<bool>>,
    // Marker only: no `T` value is ever stored in the collection itself.
    _phantom: PhantomData<T>,
}
impl<T: Storable> Collection<T> {
    /// Binds a typed collection to the column family `name`.
    fn new(db: Arc<rocksdb::DB>, name: &str, shutdown: Arc<RwLock<bool>>) -> Self {
        Self {
            db,
            cf_name: name.to_string(),
            shutdown,
            _phantom: PhantomData,
        }
    }
    /// Resolves the bound column family handle.
    fn cf<'a>(&'a self) -> Result<Arc<BoundColumnFamily<'a>>> {
        self.db
            .cf_handle(&self.cf_name)
            .ok_or_else(|| Error::Database(format!("Column family '{}' not found", self.cf_name)))
    }
    /// Returns a read guard proving the database is still open.
    #[inline]
    fn check_shutdown(&self) -> Result<parking_lot::RwLockReadGuard<'_, bool>> {
        let guard = self.shutdown.read();
        if *guard {
            return Err(Error::Database("Database has been shut down".to_string()));
        }
        Ok(guard)
    }
    /// Validates, serializes, and writes `value` under its own key.
    #[instrument(skip(self, value))]
    pub fn put(&self, value: &T) -> Result<()> {
        let _guard = self.check_shutdown()?;
        value.validate()?;
        let key = value.key();
        let key_bytes = key.to_bytes()?;
        let value_bytes = helpers::serialize(value)?;
        debug!("Putting value in collection '{}'", self.cf_name);
        let cf = self.cf()?;
        self.db.put_cf(&cf, key_bytes, value_bytes).map_err(|e| {
            error!("Failed to put value: {}", e);
            Error::Database(format!("Failed to put value: {}", e))
        })?;
        // Storable hook fired only after the write succeeded.
        value.on_stored();
        Ok(())
    }
    /// Reads and deserializes the value stored under `key`, if any.
    #[instrument(skip(self))]
    pub fn get(&self, key: &T::Key) -> Result<Option<T>> {
        let _guard = self.check_shutdown()?;
        let key_bytes = key.to_bytes()?;
        let cf = self.cf()?;
        match self.db.get_cf(&cf, key_bytes)? {
            Some(value_bytes) => {
                let value: T = helpers::deserialize(&value_bytes)?;
                Ok(Some(value))
            }
            None => Ok(None),
        }
    }
    /// Like [`get`](Self::get), then resolves the value's references via `db`.
    #[instrument(skip(self, db))]
    pub fn get_with_refs(&self, key: &T::Key, db: &crate::Database) -> Result<Option<T>>
    where
        T: crate::Referable,
    {
        // Delegate the read so the two paths cannot drift apart.
        let value = self.get(key)?;
        if let Some(ref v) = value {
            v.resolve_all(db)?;
        }
        Ok(value)
    }
    /// Fetches many keys in one RocksDB multi-get; result order matches `keys`,
    /// with `None` for missing entries.
    #[instrument(skip(self, keys))]
    pub fn get_many(&self, keys: &[T::Key]) -> Result<Vec<Option<T>>> {
        let _guard = self.check_shutdown()?;
        if keys.is_empty() {
            return Ok(Vec::new());
        }
        let key_bytes: Result<Vec<Vec<u8>>> = keys.iter().map(|k| k.to_bytes()).collect();
        let key_bytes = key_bytes?;
        let cf = self.cf()?;
        let cf_refs: Vec<_> = key_bytes.iter().map(|k| (&cf, k.as_slice())).collect();
        let results = self.db.multi_get_cf(cf_refs);
        let mut output = Vec::with_capacity(keys.len());
        for result in results {
            match result {
                Ok(Some(value_bytes)) => {
                    let value: T = helpers::deserialize(&value_bytes)?;
                    output.push(Some(value));
                }
                Ok(None) => output.push(None),
                Err(e) => {
                    return Err(Error::Database(format!("Multi-get failed: {}", e)));
                }
            }
        }
        Ok(output)
    }
    /// Like [`get_many`](Self::get_many), then resolves each found value's
    /// references via `db`.
    #[instrument(skip(self, keys, db))]
    pub fn get_many_with_refs(
        &self,
        keys: &[T::Key],
        db: &crate::Database,
    ) -> Result<Vec<Option<T>>>
    where
        T: crate::Referable,
    {
        // Delegate the bulk read so the two paths cannot drift apart.
        let results = self.get_many(keys)?;
        for value in results.iter().flatten() {
            value.resolve_all(db)?;
        }
        Ok(results)
    }
    /// Deletes the entry under `key` (no error if it did not exist).
    #[instrument(skip(self))]
    pub fn delete(&self, key: &T::Key) -> Result<()> {
        let _guard = self.check_shutdown()?;
        let key_bytes = key.to_bytes()?;
        debug!("Deleting key from collection '{}'", self.cf_name);
        let cf = self.cf()?;
        self.db.delete_cf(&cf, key_bytes).map_err(|e| {
            error!("Failed to delete: {}", e);
            Error::Database(format!("Failed to delete: {}", e))
        })
    }
    /// Returns `true` if `key` is present.
    ///
    /// Uses a pinned read so the value bytes are never copied or deserialized.
    #[instrument(skip(self))]
    pub fn exists(&self, key: &T::Key) -> Result<bool> {
        let _guard = self.check_shutdown()?;
        let key_bytes = key.to_bytes()?;
        let cf = self.cf()?;
        Ok(self.db.get_pinned_cf(&cf, key_bytes)?.is_some())
    }
    /// Starts an atomic write batch for this collection.
    pub fn batch(&self) -> Batch<T> {
        Batch::new(Arc::clone(&self.db), self.cf_name.clone())
    }
    /// Takes a point-in-time snapshot for consistent reads.
    pub fn snapshot(&self) -> Snapshot<T> {
        Snapshot::new(Arc::clone(&self.db), self.cf_name.clone())
    }
    /// Iterates the whole collection from the first key.
    pub fn iter(&self) -> Result<Iterator<T>> {
        let _guard = self.check_shutdown()?;
        Ok(Iterator::new(
            Arc::clone(&self.db),
            self.cf_name.clone(),
            IteratorMode::Start,
            Arc::clone(&self.shutdown),
        ))
    }
    /// Iterates forward starting at `key` (inclusive).
    pub fn iter_from(&self, key: &T::Key) -> Result<Iterator<T>> {
        let _guard = self.check_shutdown()?;
        let key_bytes = key.to_bytes()?;
        Ok(Iterator::new(
            Arc::clone(&self.db),
            self.cf_name.clone(),
            IteratorMode::From(key_bytes),
            Arc::clone(&self.shutdown),
        ))
    }
    /// Approximate key count from RocksDB's `estimate-num-keys` property.
    pub fn estimate_num_keys(&self) -> Result<u64> {
        // Consistency fix: reject operations on a shut-down database.
        let _guard = self.check_shutdown()?;
        let cf = self.cf()?;
        self.db
            .property_int_value_cf(&cf, "rocksdb.estimate-num-keys")
            .map(|v| v.unwrap_or(0))
            .map_err(|e| Error::Database(format!("Failed to get estimate: {}", e)))
    }
    /// Flushes this column family's memtable to disk.
    #[instrument(skip(self))]
    pub fn flush(&self) -> Result<()> {
        // Consistency fix: reject operations on a shut-down database.
        let _guard = self.check_shutdown()?;
        info!("Flushing collection '{}'", self.cf_name);
        let cf = self.cf()?;
        self.db.flush_cf(&cf).map_err(|e| {
            error!("Flush failed: {}", e);
            Error::Database(format!("Flush failed: {}", e))
        })
    }
    /// Compacts the key range `[start, end]`; `None` means unbounded.
    #[instrument(skip(self, start, end))]
    pub fn compact_range(&self, start: Option<&T::Key>, end: Option<&T::Key>) -> Result<()> {
        // Consistency fix: reject operations on a shut-down database.
        let _guard = self.check_shutdown()?;
        let start_bytes = start.map(|k| k.to_bytes()).transpose()?;
        let end_bytes = end.map(|k| k.to_bytes()).transpose()?;
        info!("Compacting range in collection '{}'", self.cf_name);
        let cf = self.cf()?;
        self.db
            .compact_range_cf(&cf, start_bytes.as_deref(), end_bytes.as_deref());
        Ok(())
    }
    /// Name of the underlying column family.
    pub fn name(&self) -> &str {
        &self.cf_name
    }
}
// SAFETY(review): these impls assert Send/Sync for *every* `T`, even
// `T: !Send`. `Collection` only holds `PhantomData<T>` and never stores a
// `T` value, which is presumably the justification — confirm that no
// `T: !Send` use case can observe a violation before relying on this.
unsafe impl<T: Storable> Send for Collection<T> {}
unsafe impl<T: Storable> Sync for Collection<T> {}
/// Buffered set of writes against one column family, applied atomically on
/// `commit()`.
pub struct Batch<T: Storable> {
    db: Arc<rocksdb::DB>,
    cf_name: String,
    // Accumulated operations; nothing touches the DB until commit.
    batch: RocksWriteBatch,
    // Marker only: ties the batch to the value type it serializes.
    _phantom: PhantomData<T>,
}
impl<T: Storable> Batch<T> {
fn new(db: Arc<rocksdb::DB>, cf_name: String) -> Self {
Self {
db,
cf_name,
batch: RocksWriteBatch::default(),
_phantom: PhantomData,
}
}
pub fn put(&mut self, value: &T) -> Result<()> {
value.validate()?;
let key = value.key();
let key_bytes = key.to_bytes()?;
let value_bytes = helpers::serialize(value)?;
let cf = self.db.cf_handle(&self.cf_name).ok_or_else(|| {
Error::Database(format!("Column family '{}' not found", self.cf_name))
})?;
self.batch.put_cf(&cf, &key_bytes, &value_bytes);
Ok(())
}
pub fn delete(&mut self, key: &T::Key) -> Result<()> {
let key_bytes = key.to_bytes()?;
let cf = self.db.cf_handle(&self.cf_name).ok_or_else(|| {
Error::Database(format!("Column family '{}' not found", self.cf_name))
})?;
self.batch.delete_cf(&cf, &key_bytes);
Ok(())
}
pub fn clear(&mut self) {
self.batch.clear();
}
pub fn len(&self) -> usize {
self.batch.len()
}
pub fn is_empty(&self) -> bool {
self.batch.is_empty()
}
#[instrument(skip(self))]
pub fn commit(self) -> Result<()> {
let op_count = self.batch.len();
debug!(
"Committing batch with {} operations to '{}'",
op_count, self.cf_name
);
self.db.write(self.batch).map_err(|e| {
error!("Batch commit failed: {}", e);
Error::Database(format!("Batch commit failed: {}", e))
})
}
}
/// Point-in-time, read-only view of one column family.
pub struct Snapshot<T: Storable> {
    // Keeps the DB alive (and at a stable heap address) for the snapshot's
    // whole lifetime; declared before `snapshot_ptr` — see `Drop` for ordering.
    db: Arc<rocksdb::DB>,
    // Heap-allocated snapshot with its DB lifetime erased to 'static; created
    // by `Box::into_raw` in `new` and freed in `Drop`. The raw pointer also
    // makes this type automatically !Send/!Sync.
    snapshot_ptr: *const rocksdb::SnapshotWithThreadMode<'static, rocksdb::DB>,
    cf_name: String,
    _phantom: PhantomData<T>,
}
impl<T: Storable> Snapshot<T> {
    /// Captures a point-in-time snapshot of `db` for column family `cf_name`.
    fn new(db: Arc<rocksdb::DB>, cf_name: String) -> Self {
        // SAFETY: the snapshot borrows the DB, but both must live in the same
        // struct, so the borrow lifetime is erased via transmute and the
        // snapshot is boxed to pin its address. Soundness relies on:
        // (1) the `db` Arc stored below keeping the DB alive at a stable heap
        //     address for as long as this struct exists, and
        // (2) `Drop` freeing the snapshot before the `db` field is released.
        let snapshot_ptr = unsafe {
            let snapshot = db.snapshot();
            let static_snapshot: rocksdb::SnapshotWithThreadMode<'static, rocksdb::DB> =
                std::mem::transmute(snapshot);
            let boxed = Box::new(static_snapshot);
            Box::into_raw(boxed) as *const _
        };
        Self {
            db,
            snapshot_ptr,
            cf_name,
            _phantom: PhantomData,
        }
    }
    /// Reborrows the boxed snapshot at the (shorter) lifetime of `&self`.
    fn snapshot(&self) -> &rocksdb::SnapshotWithThreadMode<'_, rocksdb::DB> {
        // SAFETY: `snapshot_ptr` came from `Box::into_raw` in `new` and is
        // only freed in `Drop`, so it is valid for `&self`'s lifetime.
        unsafe { &*self.snapshot_ptr }
    }
    /// Resolves the bound column family handle.
    fn cf<'a>(&'a self) -> Result<Arc<BoundColumnFamily<'a>>> {
        self.db
            .cf_handle(&self.cf_name)
            .ok_or_else(|| Error::Database(format!("Column family '{}' not found", self.cf_name)))
    }
    /// Reads the value under `key` as of the snapshot's creation time.
    pub fn get(&self, key: &T::Key) -> Result<Option<T>> {
        let key_bytes = key.to_bytes()?;
        let cf = self.cf()?;
        match self.snapshot().get_cf(&cf, key_bytes)? {
            Some(value_bytes) => {
                let value: T = helpers::deserialize(&value_bytes)?;
                Ok(Some(value))
            }
            None => Ok(None),
        }
    }
    /// True if `key` existed at the snapshot's creation time.
    pub fn exists(&self, key: &T::Key) -> Result<bool> {
        Ok(self.get(key)?.is_some())
    }
}
impl<T: Storable> Drop for Snapshot<T> {
    fn drop(&mut self) {
        // SAFETY: reclaims the Box leaked in `new`. `Drop::drop` runs before
        // the struct's fields are dropped, so the snapshot is released while
        // the `db` Arc is still alive.
        unsafe {
            let _ = Box::from_raw(
                self.snapshot_ptr as *mut rocksdb::SnapshotWithThreadMode<'static, rocksdb::DB>,
            );
        }
    }
}
/// Where an `Iterator` begins: the first key, or a caller-provided start key.
enum IteratorMode {
    // Iterate from the beginning of the column family.
    Start,
    // Iterate forward from this serialized key (inclusive).
    From(Vec<u8>),
}
/// Outcome of `Iterator::for_each`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IterationStatus {
    // The callback saw every entry.
    Completed,
    // The callback returned `false` before the end.
    StoppedEarly,
}
/// Lazy cursor over a collection; each traversal method re-checks shutdown
/// and opens a fresh RocksDB iterator.
pub struct Iterator<T: Storable> {
    db: Arc<rocksdb::DB>,
    cf_name: String,
    // Starting position, fixed at construction.
    mode: IteratorMode,
    shutdown: Arc<RwLock<bool>>,
    _phantom: PhantomData<T>,
}
impl<T: Storable> Iterator<T> {
fn new(
db: Arc<rocksdb::DB>,
cf_name: String,
mode: IteratorMode,
shutdown: Arc<RwLock<bool>>,
) -> Self {
Self {
db,
cf_name,
mode,
shutdown,
_phantom: PhantomData,
}
}
fn cf<'a>(&'a self) -> Result<Arc<BoundColumnFamily<'a>>> {
self.db
.cf_handle(&self.cf_name)
.ok_or_else(|| Error::Database(format!("Column family '{}' not found", self.cf_name)))
}
#[inline]
fn check_shutdown(&self) -> Result<parking_lot::RwLockReadGuard<'_, bool>> {
let guard = self.shutdown.read();
if *guard {
return Err(Error::Database("Database has been shut down".to_string()));
}
Ok(guard)
}
pub fn collect_all(&self) -> Result<Vec<T>> {
let _guard = self.check_shutdown()?;
let mut results = Vec::new();
let cf = self.cf()?;
let iter = match &self.mode {
IteratorMode::Start => self.db.iterator_cf(&cf, rocksdb::IteratorMode::Start),
IteratorMode::From(key) => self.db.iterator_cf(
&cf,
rocksdb::IteratorMode::From(key, rocksdb::Direction::Forward),
),
};
for item in iter {
let (_key, value_bytes) =
item.map_err(|e| Error::IteratorError(format!("Iterator error: {}", e)))?;
let value: T = helpers::deserialize(&value_bytes)?;
results.push(value);
}
Ok(results)
}
pub fn for_each<F>(&self, mut f: F) -> Result<IterationStatus>
where
F: FnMut(T) -> bool,
{
let _guard = self.check_shutdown()?;
let cf = self.cf()?;
let iter = match &self.mode {
IteratorMode::Start => self.db.iterator_cf(&cf, rocksdb::IteratorMode::Start),
IteratorMode::From(key) => self.db.iterator_cf(
&cf,
rocksdb::IteratorMode::From(key, rocksdb::Direction::Forward),
),
};
for item in iter {
let (_key, value_bytes) =
item.map_err(|e| Error::IteratorError(format!("Iterator error: {}", e)))?;
let value: T = helpers::deserialize(&value_bytes)?;
if !f(value) {
return Ok(IterationStatus::StoppedEarly);
}
}
Ok(IterationStatus::Completed)
}
pub fn count(&self) -> Result<usize> {
let _guard = self.check_shutdown()?;
let cf = self.cf()?;
let iter = match &self.mode {
IteratorMode::Start => self.db.iterator_cf(&cf, rocksdb::IteratorMode::Start),
IteratorMode::From(key) => self.db.iterator_cf(
&cf,
rocksdb::IteratorMode::From(key, rocksdb::Direction::Forward),
),
};
let mut count = 0;
for item in iter {
item.map_err(|e| Error::IteratorError(format!("Iterator error: {}", e)))?;
count += 1;
}
Ok(count)
}
}
/// Buffered transaction: writes accumulate in `batch` and are mirrored into
/// `cache` so reads within the transaction see uncommitted changes.
pub struct Transaction {
    db: Arc<rocksdb::DB>,
    // Pending writes, applied atomically by `commit()`.
    batch: Mutex<RocksWriteBatch>,
    // Read-your-writes overlay plus operation/memory accounting.
    cache: Mutex<TransactionCache>,
    shutdown: Arc<RwLock<bool>>,
}
/// Overlay of pending transaction writes, keyed by (column family, key bytes).
/// `None` values record pending deletes.
struct TransactionCache {
    data: HashMap<(String, Vec<u8>), Option<Vec<u8>>>,
    // Number of distinct keys touched (not raw operation count).
    operation_count: usize,
    // Approximate memory footprint of `data`, used for the size limit.
    total_bytes: usize,
}
impl TransactionCache {
    /// Creates an empty cache with zeroed accounting.
    fn new() -> Self {
        Self {
            data: HashMap::default(),
            operation_count: 0,
            total_bytes: 0,
        }
    }
    /// Records a pending write (`Some`) or delete (`None`) for `key`,
    /// enforcing per-transaction operation and memory budgets.
    fn insert(&mut self, key: (String, Vec<u8>), value: Option<Vec<u8>>) -> Result<()> {
        const MAX_OPERATIONS: usize = 100_000;
        const MAX_BYTES: usize = 100 * 1024 * 1024;
        const HASHMAP_OVERHEAD: usize = 32;
        // Reject once the distinct-key budget is exhausted.
        if self.operation_count >= MAX_OPERATIONS {
            return Err(Error::Database(format!(
                "Transaction limit exceeded: maximum {} operations allowed",
                MAX_OPERATIONS
            )));
        }
        // Approximate footprint of one entry: key parts + payload + fixed overhead.
        let footprint = |payload: &Option<Vec<u8>>| -> i64 {
            (key.0.len() + key.1.len() + payload.as_ref().map_or(0, |v| v.len()) + HASHMAP_OVERHEAD)
                as i64
        };
        let incoming = footprint(&value);
        // When the key is being overwritten, its previous footprint is released.
        let outgoing = self.data.get(&key).map_or(0, |old| footprint(old));
        let projected = (self.total_bytes as i64 + incoming - outgoing) as usize;
        if projected > MAX_BYTES {
            return Err(Error::Database(format!(
                "Transaction memory limit exceeded: maximum {}MB allowed",
                MAX_BYTES / (1024 * 1024)
            )));
        }
        if !self.data.contains_key(&key) {
            self.operation_count += 1;
        }
        self.total_bytes = projected;
        self.data.insert(key, value);
        Ok(())
    }
    /// Looks up the pending state for `key`; `Some(None)` means pending delete.
    fn get(&self, key: &(String, Vec<u8>)) -> Option<&Option<Vec<u8>>> {
        self.data.get(key)
    }
    /// Drops all pending entries and resets the accounting counters.
    fn clear(&mut self) {
        self.data.clear();
        self.operation_count = 0;
        self.total_bytes = 0;
    }
}
impl Transaction {
    /// Creates an empty transaction over `db`.
    fn new(db: Arc<rocksdb::DB>, shutdown: Arc<RwLock<bool>>) -> Self {
        Self {
            db,
            batch: Mutex::new(RocksWriteBatch::default()),
            cache: Mutex::new(TransactionCache::new()),
            shutdown,
        }
    }
    /// Returns a read guard proving the database is still open.
    #[inline]
    fn check_shutdown(&self) -> Result<parking_lot::RwLockReadGuard<'_, bool>> {
        let guard = self.shutdown.read();
        if *guard {
            return Err(Error::Database("Database has been shut down".to_string()));
        }
        Ok(guard)
    }
    /// Returns a typed view of column family `name` scoped to this transaction.
    #[instrument(skip(self))]
    pub fn collection<'txn, T: Storable>(
        &'txn self,
        name: &str,
    ) -> Result<TransactionCollection<'txn, T>> {
        let _guard = self.check_shutdown()?;
        // Probe the handle so a bad name fails here rather than on first use.
        self.db.cf_handle(name).ok_or_else(|| {
            error!("Column family '{}' not found", name);
            Error::Database(format!("Column family '{}' not found", name))
        })?;
        debug!("Created transaction collection for '{}'", name);
        Ok(TransactionCollection::new(
            Arc::clone(&self.db),
            name.to_string(),
            &self.batch,
            &self.cache,
        ))
    }
    /// Atomically writes every buffered operation to the database.
    #[instrument(skip(self))]
    pub fn commit(self) -> Result<()> {
        // The guard borrows `self`, so drop it before moving fields out below.
        let guard = self.check_shutdown()?;
        drop(guard);
        let db = self.db;
        let batch = self.batch.into_inner();
        let op_count = batch.len();
        info!("Committing transaction with {} operations", op_count);
        db.write(batch).map_err(|e| {
            error!("Failed to commit transaction: {}", e);
            Error::Database(format!("Failed to commit transaction: {}", e))
        })
    }
    /// Discards all buffered operations without writing anything.
    #[instrument(skip(self))]
    pub fn rollback(self) -> Result<()> {
        let op_count = self.batch.lock().len();
        warn!("Rolling back transaction with {} operations", op_count);
        Ok(())
    }
    /// Empties both the write batch and the read cache; the transaction
    /// remains usable afterwards.
    pub fn clear(&self) -> Result<()> {
        self.batch.lock().clear();
        self.cache.lock().clear();
        Ok(())
    }
    /// Number of operations currently buffered in the write batch.
    pub fn len(&self) -> Result<usize> {
        Ok(self.batch.lock().len())
    }
    /// True when no operations have been buffered.
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.batch.lock().is_empty())
    }
}
// SAFETY(review): `RocksWriteBatch` wraps native state, so the auto traits
// may not apply; the batch and cache are only reachable through their
// `Mutex`es, which serializes all access — presumably sufficient, but
// confirm upstream that `WriteBatch` may be moved between threads.
unsafe impl Send for Transaction {}
unsafe impl Sync for Transaction {}
/// Typed view of one column family inside a running [`Transaction`];
/// borrows the transaction's shared batch and read-your-writes cache.
pub struct TransactionCollection<'txn, T: Storable> {
    db: Arc<rocksdb::DB>,
    cf_name: String,
    // Shared with the owning Transaction and all sibling collections.
    batch: &'txn Mutex<RocksWriteBatch>,
    cache: &'txn Mutex<TransactionCache>,
    _phantom: PhantomData<T>,
}
impl<'txn, T: Storable> TransactionCollection<'txn, T> {
    /// Binds a transaction-scoped view to column family `cf_name`.
    fn new(
        db: Arc<rocksdb::DB>,
        cf_name: String,
        batch: &'txn Mutex<RocksWriteBatch>,
        cache: &'txn Mutex<TransactionCache>,
    ) -> Self {
        Self {
            db,
            cf_name,
            batch,
            cache,
            _phantom: PhantomData,
        }
    }
    /// Resolves the bound column family handle.
    fn cf<'a>(&'a self) -> Result<Arc<BoundColumnFamily<'a>>> {
        self.db
            .cf_handle(&self.cf_name)
            .ok_or_else(|| Error::Database(format!("Column family '{}' not found", self.cf_name)))
    }
    /// Buffers a put in the transaction.
    ///
    /// BUG FIX: the cache is updated *before* the write batch.
    /// `TransactionCache::insert` enforces the per-transaction operation and
    /// memory limits; appending to the batch first left a queued write that
    /// `commit()` would apply even though this call returned an error.
    #[instrument(skip(self, value))]
    pub fn put(&self, value: &T) -> Result<()> {
        value.validate()?;
        let key = value.key();
        let key_bytes = key.to_bytes()?;
        let value_bytes = helpers::serialize(value)?;
        debug!("Transaction put in collection '{}'", self.cf_name);
        // Lock order (batch, then cache) matches `delete` to avoid deadlock.
        let mut batch = self.batch.lock();
        let mut cache = self.cache.lock();
        let cf = self.cf()?;
        cache.insert(
            (self.cf_name.clone(), key_bytes.clone()),
            Some(value_bytes.clone()),
        )?;
        batch.put_cf(&cf, &key_bytes, &value_bytes);
        value.on_stored();
        Ok(())
    }
    /// Reads `key`, seeing this transaction's own uncommitted writes first.
    #[instrument(skip(self))]
    pub fn get(&self, key: &T::Key) -> Result<Option<T>> {
        let key_bytes = key.to_bytes()?;
        let cache_key = (self.cf_name.clone(), key_bytes.clone());
        // Pending state wins: Some(bytes) = pending put, None = pending delete.
        let cached_value = self.cache.lock().get(&cache_key).cloned();
        if let Some(cached) = cached_value {
            debug!("Transaction cache hit for key in '{}'", self.cf_name);
            return match cached {
                Some(value_bytes) => {
                    let value: T = helpers::deserialize(&value_bytes)?;
                    Ok(Some(value))
                }
                None => Ok(None),
            };
        }
        // Fall back to the committed state in the database.
        let cf = self.cf()?;
        match self.db.get_cf(&cf, key_bytes)? {
            Some(value_bytes) => {
                let value: T = helpers::deserialize(&value_bytes)?;
                Ok(Some(value))
            }
            None => Ok(None),
        }
    }
    /// Buffers a delete; the cache is updated first for the same reason as
    /// in `put` — a limit error must leave the batch untouched.
    #[instrument(skip(self))]
    pub fn delete(&self, key: &T::Key) -> Result<()> {
        let key_bytes = key.to_bytes()?;
        debug!("Transaction delete in collection '{}'", self.cf_name);
        let mut batch = self.batch.lock();
        let mut cache = self.cache.lock();
        let cf = self.cf()?;
        cache.insert((self.cf_name.clone(), key_bytes.clone()), None)?;
        batch.delete_cf(&cf, &key_bytes);
        Ok(())
    }
    /// True if `key` is visible to this transaction (pending or committed).
    pub fn exists(&self, key: &T::Key) -> Result<bool> {
        Ok(self.get(key)?.is_some())
    }
    /// Bulk read honoring transaction-local state: cached entries (puts and
    /// deletes) are answered from the cache, the rest via one multi-get.
    /// Result order matches `keys`.
    #[instrument(skip(self, keys))]
    pub fn get_many(&self, keys: &[T::Key]) -> Result<Vec<Option<T>>> {
        if keys.is_empty() {
            return Ok(Vec::new());
        }
        let key_bytes: Vec<Vec<u8>> = keys
            .iter()
            .map(|k| k.to_bytes())
            .collect::<Result<Vec<Vec<u8>>>>()?;
        let mut results: Vec<Option<T>> = (0..keys.len()).map(|_| None).collect();
        let mut uncached_indices = Vec::new();
        let mut uncached_keys = Vec::new();
        {
            // Scope the cache lock to the classification pass.
            let cache = self.cache.lock();
            for (i, kb) in key_bytes.iter().enumerate() {
                let cache_key = (self.cf_name.clone(), kb.clone());
                if let Some(cached) = cache.get(&cache_key) {
                    results[i] = match cached {
                        Some(value_bytes) => Some(helpers::deserialize(value_bytes)?),
                        None => None,
                    };
                } else {
                    uncached_indices.push(i);
                    uncached_keys.push(kb.clone());
                }
            }
        }
        if !uncached_keys.is_empty() {
            let cf = self.cf()?;
            let cf_refs: Vec<_> = uncached_keys.iter().map(|k| (&cf, k.as_slice())).collect();
            let db_results = self.db.multi_get_cf(cf_refs);
            debug_assert_eq!(
                db_results.len(),
                uncached_keys.len(),
                "RocksDB multi_get violated contract: got {} results but expected {}",
                db_results.len(),
                uncached_keys.len()
            );
            for (result_idx, db_result) in db_results.into_iter().enumerate() {
                let original_idx = uncached_indices[result_idx];
                results[original_idx] = match db_result {
                    Ok(Some(value_bytes)) => Some(helpers::deserialize(&value_bytes)?),
                    Ok(None) => None,
                    Err(e) => return Err(Error::Database(format!("Multi-get failed: {}", e))),
                };
            }
        }
        Ok(results)
    }
}
// SAFETY(review): asserts Send/Sync for every `T` (only `PhantomData<T>` is
// stored) and relies on the borrowed `Mutex`es to serialize access to the
// shared batch/cache — confirm both assumptions before relying on this.
unsafe impl<'txn, T: Storable> Send for TransactionCollection<'txn, T> {}
unsafe impl<'txn, T: Storable> Sync for TransactionCollection<'txn, T> {}
#[cfg(test)]
mod tests {
    use borsh::{BorshDeserialize, BorshSerialize};
    use super::*;
    use crate::DatabaseConfig;
    // Minimal storable type used by every test below.
    #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)]
    struct TestItem {
        id: u64,
        data: String,
    }
    impl Storable for TestItem {
        type Key = u64;
        fn key(&self) -> Self::Key {
            self.id
        }
    }
    // Opens a fresh database in a unique temp directory per call; the counter
    // keeps concurrently running tests from sharing a path.
    fn create_test_db() -> Database {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, Ordering::SeqCst);
        let path = std::env::temp_dir().join(format!("ngdb_test_{}", id));
        let _ = std::fs::remove_dir_all(&path);
        DatabaseConfig::new(&path)
            .create_if_missing(true)
            .add_column_family("test")
            .open()
            .expect("Failed to create test database")
    }
    // Round-trip: a stored value is returned intact.
    #[test]
    fn test_collection_put_and_get() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();
        let item = TestItem {
            id: 1,
            data: "test".to_string(),
        };
        collection.put(&item).unwrap();
        let retrieved = collection.get(&1).unwrap();
        assert_eq!(Some(item), retrieved);
    }
    // Deleted keys read back as None.
    #[test]
    fn test_collection_delete() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();
        let item = TestItem {
            id: 1,
            data: "test".to_string(),
        };
        collection.put(&item).unwrap();
        collection.delete(&1).unwrap();
        assert_eq!(None, collection.get(&1).unwrap());
    }
    // A committed batch makes all of its writes visible.
    #[test]
    fn test_batch() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();
        let mut batch = collection.batch();
        for i in 0..10 {
            batch
                .put(&TestItem {
                    id: i,
                    data: format!("item_{}", i),
                })
                .unwrap();
        }
        batch.commit().unwrap();
        for i in 0..10 {
            let item = collection.get(&i).unwrap().unwrap();
            assert_eq!(i, item.id);
        }
    }
    // collect_all returns every stored value.
    #[test]
    fn test_iterator() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();
        for i in 0..5 {
            collection
                .put(&TestItem {
                    id: i,
                    data: format!("item_{}", i),
                })
                .unwrap();
        }
        let items = collection.iter().unwrap().collect_all().unwrap();
        assert_eq!(5, items.len());
    }
    // get_many preserves key order and yields None for missing keys.
    #[test]
    fn test_get_many() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();
        for i in 0..10 {
            collection
                .put(&TestItem {
                    id: i,
                    data: format!("item_{}", i),
                })
                .unwrap();
        }
        let keys = vec![1, 3, 5, 99]; let results = collection.get_many(&keys).unwrap();
        assert_eq!(4, results.len());
        assert!(results[0].is_some());
        assert_eq!(1, results[0].as_ref().unwrap().id);
        assert!(results[1].is_some());
        assert_eq!(3, results[1].as_ref().unwrap().id);
        assert!(results[2].is_some());
        assert_eq!(5, results[2].as_ref().unwrap().id);
        assert!(results[3].is_none());
    }
    // Uncommitted writes are visible inside the transaction and persist after commit.
    #[test]
    fn test_transaction() {
        let db = create_test_db();
        let txn = db.transaction().unwrap();
        let collection = txn.collection::<TestItem>("test").unwrap();
        collection
            .put(&TestItem {
                id: 1,
                data: "test".to_string(),
            })
            .unwrap();
        assert!(collection.get(&1).unwrap().is_some());
        txn.commit().unwrap();
        let regular_collection = db.collection::<TestItem>("test").unwrap();
        assert!(regular_collection.get(&1).unwrap().is_some());
    }
    // get_many merges committed rows with transaction-local puts and deletes,
    // and the committed view only changes after commit().
    #[test]
    fn test_transaction_get_many() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();
        collection
            .put(&TestItem {
                id: 1,
                data: "one".to_string(),
            })
            .unwrap();
        collection
            .put(&TestItem {
                id: 2,
                data: "two".to_string(),
            })
            .unwrap();
        collection
            .put(&TestItem {
                id: 5,
                data: "five".to_string(),
            })
            .unwrap();
        let txn = db.transaction().unwrap();
        let txn_collection = txn.collection::<TestItem>("test").unwrap();
        txn_collection
            .put(&TestItem {
                id: 3,
                data: "three".to_string(),
            })
            .unwrap();
        txn_collection
            .put(&TestItem {
                id: 4,
                data: "four".to_string(),
            })
            .unwrap();
        txn_collection.delete(&5).unwrap();
        let keys = vec![1, 2, 3, 4, 5, 6];
        let results = txn_collection.get_many(&keys).unwrap();
        assert!(results[0].is_some()); assert_eq!(results[0].as_ref().unwrap().data, "one");
        assert!(results[1].is_some()); assert_eq!(results[1].as_ref().unwrap().data, "two");
        assert!(results[2].is_some()); assert_eq!(results[2].as_ref().unwrap().data, "three");
        assert!(results[3].is_some()); assert_eq!(results[3].as_ref().unwrap().data, "four");
        assert!(results[4].is_none()); assert!(results[5].is_none());
        let committed_results = collection.get_many(&keys).unwrap();
        assert!(committed_results[2].is_none()); assert!(committed_results[3].is_none()); assert!(committed_results[4].is_some());
        txn.commit().unwrap();
        let final_results = collection.get_many(&keys).unwrap();
        assert!(final_results[2].is_some()); assert!(final_results[3].is_some()); assert!(final_results[4].is_none()); }
    // The 100_000-operation transaction cap rejects the next put with an error.
    #[test]
    fn test_transaction_limits() {
        let db = create_test_db();
        let txn = db.transaction().unwrap();
        let collection = txn.collection::<TestItem>("test").unwrap();
        for i in 0..100_001 {
            let result = collection.put(&TestItem {
                id: i,
                data: format!("item_{}", i),
            });
            if i < 100_000 {
                assert!(result.is_ok());
            } else {
                assert!(result.is_err());
                assert!(result.unwrap_err().to_string().contains("limit exceeded"));
                break;
            }
        }
    }
    // After shutdown(), existing collections and new lookups both fail.
    #[test]
    fn test_shutdown_prevents_operations() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();
        let item = TestItem {
            id: 1,
            data: "test".to_string(),
        };
        assert!(collection.put(&item).is_ok());
        db.shutdown().unwrap();
        let item2 = TestItem {
            id: 2,
            data: "test2".to_string(),
        };
        let result = collection.put(&item2);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("shut down"));
        let result = db.collection::<TestItem>("test");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("shut down"));
    }
    // Iterators created before shutdown must refuse to run afterwards.
    #[test]
    fn test_iterator_checks_shutdown() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();
        for i in 0..5 {
            collection
                .put(&TestItem {
                    id: i,
                    data: format!("item_{}", i),
                })
                .unwrap();
        }
        let iter = collection.iter().unwrap();
        db.shutdown().unwrap();
        let result = iter.collect_all();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("shut down"));
    }
    // shutdown() from another thread completes (lock is released) and
    // subsequent operations fail cleanly.
    #[test]
    fn test_shutdown_lock_is_released() {
        use std::sync::Arc;
        use std::thread;
        use std::time::Duration;
        let db = Arc::new(create_test_db());
        let collection = db.collection::<TestItem>("test").unwrap();
        collection
            .put(&TestItem {
                id: 1,
                data: "test".to_string(),
            })
            .unwrap();
        let db_clone = Arc::clone(&db);
        let shutdown_handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            db_clone.shutdown()
        });
        thread::sleep(Duration::from_millis(100));
        let shutdown_result = shutdown_handle.join().unwrap();
        assert!(shutdown_result.is_ok());
        let result = collection.put(&TestItem {
            id: 2,
            data: "test2".to_string(),
        });
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("shut down"));
        let result = db.collection::<TestItem>("another");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("shut down"));
    }
    // list_collections honors the shutdown flag.
    #[test]
    fn test_list_collections_checks_shutdown() {
        let db = create_test_db();
        let collections = db.list_collections();
        assert!(collections.is_ok());
        db.shutdown().unwrap();
        let result = db.list_collections();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("shut down"));
    }
}