use std::io;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use super::{Page, PageType, Pager, PagerConfig};
use crate::storage::wal::{
CheckpointError, CheckpointMode, CheckpointResult, Checkpointer, Transaction,
TransactionManager, TxError,
};
/// Options that control how a [`Database`] is opened and maintained.
#[derive(Debug, Clone)]
pub struct DatabaseConfig {
/// Cache capacity, forwarded unchanged to the pager as `PagerConfig::cache_size`.
/// NOTE(review): presumably a number of pages rather than bytes — confirm against the pager.
pub cache_size: usize,
/// When `true`, page allocation and checkpoints are rejected with
/// `DatabaseError::ReadOnly`, and WAL recovery and the background writer are
/// skipped when the database is opened.
pub read_only: bool,
/// Forwarded to the pager as `PagerConfig::create`.
/// NOTE(review): presumably permits creating a missing database file — confirm.
pub create: bool,
/// Checkpoint mode used by `Database::checkpoint`.
pub checkpoint_mode: CheckpointMode,
/// Value of the pages-since-checkpoint counter at which
/// `Database::maybe_auto_checkpoint` runs a checkpoint; `0` disables it.
pub auto_checkpoint_threshold: u32,
/// Forwarded to the pager as `PagerConfig::verify_checksums`.
pub verify_checksums: bool,
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
cache_size: 10_000,
read_only: false,
create: true,
checkpoint_mode: CheckpointMode::Full,
auto_checkpoint_threshold: 1000,
verify_checksums: true,
}
}
}
/// Errors reported by [`Database`] operations.
#[derive(Debug)]
pub enum DatabaseError {
/// An underlying I/O failure (for example while creating the transaction manager).
Io(io::Error),
/// A pager operation failed; carries the pager error rendered as text.
Pager(String),
/// An internal lock was poisoned; the payload names the lock.
LockPoisoned(&'static str),
/// A failure reported by the transaction layer.
Transaction(TxError),
/// A failure reported while checkpointing or recovering the WAL.
Checkpoint(CheckpointError),
/// A mutating operation was attempted on a database opened read-only.
ReadOnly,
/// An operation was attempted after the database was marked closed.
Closed,
}
impl std::fmt::Display for DatabaseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Io(e) => write!(f, "I/O error: {}", e),
Self::Pager(msg) => write!(f, "Pager error: {}", msg),
Self::LockPoisoned(name) => write!(f, "Lock poisoned: {}", name),
Self::Transaction(e) => write!(f, "Transaction error: {}", e),
Self::Checkpoint(e) => write!(f, "Checkpoint error: {}", e),
Self::ReadOnly => write!(f, "Database is read-only"),
Self::Closed => write!(f, "Database is closed"),
}
}
}
// Marks `DatabaseError` as a standard error type. `source()` is not overridden,
// so it keeps the trait's default of `None`; the wrapped error is only visible
// through the `Display` text.
impl std::error::Error for DatabaseError {}
impl From<io::Error> for DatabaseError {
fn from(e: io::Error) -> Self {
Self::Io(e)
}
}
impl From<TxError> for DatabaseError {
fn from(e: TxError) -> Self {
Self::Transaction(e)
}
}
impl From<CheckpointError> for DatabaseError {
fn from(e: CheckpointError) -> Self {
Self::Checkpoint(e)
}
}
/// Lifecycle state of a [`Database`], stored behind its `state` lock.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DbState {
/// The database accepts operations.
Open,
/// `close` or `Drop` has run; guarded operations fail with `DatabaseError::Closed`.
Closed,
}
/// An open database: a pager over the main file plus a transaction manager
/// backed by a write-ahead log stored next to it.
pub struct Database {
/// Path of the main database file, as passed to `open_with_config`.
path: PathBuf,
/// Path of the write-ahead log: `path` with its extension replaced by `rdb-wal`.
wal_path: PathBuf,
/// Shared pager used for all page reads, allocations and syncs.
pager: Arc<Pager>,
/// Shared transaction manager created over the same pager and the WAL path.
tx_manager: Arc<TransactionManager>,
/// The configuration this database was opened with.
config: DatabaseConfig,
/// Open/closed flag checked by the guarded operations.
state: RwLock<DbState>,
/// Counter fed by `increment_page_count` and reset to zero by `checkpoint`.
pages_since_checkpoint: RwLock<u32>,
/// Background writer handle; `None` when the database is read-only or the
/// `REDDB_BGWRITER` environment variable does not enable it.
/// NOTE(review): presumably kept so the writer lives as long as the database — confirm.
#[allow(dead_code)]
bgwriter: Option<crate::storage::cache::bgwriter::BgWriterHandle>,
}
impl Database {
fn state_read(&self) -> Result<RwLockReadGuard<'_, DbState>, DatabaseError> {
self.state
.read()
.map_err(|_| DatabaseError::LockPoisoned("database state"))
}
fn state_write(&self) -> Result<RwLockWriteGuard<'_, DbState>, DatabaseError> {
self.state
.write()
.map_err(|_| DatabaseError::LockPoisoned("database state"))
}
fn pages_since_checkpoint_read(&self) -> Result<RwLockReadGuard<'_, u32>, DatabaseError> {
self.pages_since_checkpoint
.read()
.map_err(|_| DatabaseError::LockPoisoned("pages since checkpoint"))
}
fn pages_since_checkpoint_write(&self) -> Result<RwLockWriteGuard<'_, u32>, DatabaseError> {
self.pages_since_checkpoint
.write()
.map_err(|_| DatabaseError::LockPoisoned("pages since checkpoint"))
}
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, DatabaseError> {
Self::open_with_config(path, DatabaseConfig::default())
}
pub fn bgwriter_stats(&self) -> Option<crate::storage::cache::bgwriter::BgWriterStatsSnapshot> {
self.bgwriter.as_ref().map(|h| h.stats.snapshot())
}
pub fn open_with_config<P: AsRef<Path>>(
path: P,
config: DatabaseConfig,
) -> Result<Self, DatabaseError> {
let path = path.as_ref().to_path_buf();
let wal_path = path.with_extension("rdb-wal");
let pager_config = PagerConfig {
cache_size: config.cache_size,
read_only: config.read_only,
create: config.create,
verify_checksums: config.verify_checksums,
double_write: true,
encryption: None,
};
let pager =
Pager::open(&path, pager_config).map_err(|e| DatabaseError::Pager(e.to_string()))?;
let pager = Arc::new(pager);
if wal_path.exists() && !config.read_only {
let recovery_result = Checkpointer::recover(&pager, &wal_path)?;
if recovery_result.pages_checkpointed > 0 {
tracing::info!(
transactions = recovery_result.transactions_processed,
pages = recovery_result.pages_checkpointed,
"WAL recovery applied"
);
}
}
let tx_manager = Arc::new(
TransactionManager::new(Arc::clone(&pager), &wal_path).map_err(DatabaseError::Io)?,
);
let bgwriter = if config.read_only
|| !matches!(
std::env::var("REDDB_BGWRITER").ok().as_deref(),
Some("1") | Some("true") | Some("on")
) {
None
} else {
let flusher = std::sync::Arc::new(
crate::storage::cache::bgwriter::PagerDirtyFlusher::new(Arc::downgrade(&pager)),
);
Some(crate::storage::cache::bgwriter::spawn(
flusher,
crate::storage::cache::bgwriter::BgWriterConfig::default(),
))
};
Ok(Self {
path,
wal_path,
pager,
tx_manager,
config,
state: RwLock::new(DbState::Open),
pages_since_checkpoint: RwLock::new(0),
bgwriter,
})
}
fn check_open(&self) -> Result<(), DatabaseError> {
if *self.state_read()? == DbState::Closed {
return Err(DatabaseError::Closed);
}
Ok(())
}
pub fn begin(&self) -> Result<Transaction, DatabaseError> {
self.check_open()?;
Ok(self.tx_manager.begin()?)
}
pub fn pager(&self) -> &Arc<Pager> {
&self.pager
}
pub fn tx_manager(&self) -> &Arc<TransactionManager> {
&self.tx_manager
}
pub fn allocate_page(&self, page_type: PageType) -> Result<Page, DatabaseError> {
self.check_open()?;
if self.config.read_only {
return Err(DatabaseError::ReadOnly);
}
self.pager
.allocate_page(page_type)
.map_err(|e| DatabaseError::Pager(e.to_string()))
}
pub fn read_page(&self, page_id: u32) -> Result<Page, DatabaseError> {
self.check_open()?;
self.pager
.read_page(page_id)
.map_err(|e| DatabaseError::Pager(e.to_string()))
}
pub fn checkpoint(&self) -> Result<CheckpointResult, DatabaseError> {
self.check_open()?;
if self.config.read_only {
return Err(DatabaseError::ReadOnly);
}
let checkpointer = Checkpointer::new(self.config.checkpoint_mode);
let result = checkpointer.checkpoint(&self.pager, &self.wal_path)?;
*self.pages_since_checkpoint_write()? = 0;
Ok(result)
}
pub fn maybe_auto_checkpoint(&self) -> Result<Option<CheckpointResult>, DatabaseError> {
if self.config.auto_checkpoint_threshold == 0 {
return Ok(None);
}
let pages = *self.pages_since_checkpoint_read()?;
if pages >= self.config.auto_checkpoint_threshold {
Ok(Some(self.checkpoint()?))
} else {
Ok(None)
}
}
pub fn increment_page_count(&self, count: u32) {
let mut pages = self
.pages_since_checkpoint
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
*pages = pages.saturating_add(count);
}
pub fn sync(&self) -> Result<(), DatabaseError> {
self.check_open()?;
self.pager
.sync()
.map_err(|e| DatabaseError::Pager(e.to_string()))?;
self.tx_manager.sync_wal()?;
Ok(())
}
pub fn close(self) -> Result<(), DatabaseError> {
*self.state_write()? = DbState::Closed;
if self.tx_manager.has_active_transactions() {
tracing::warn!("closing database with active transactions");
}
if !self.config.read_only {
let checkpointer = Checkpointer::new(CheckpointMode::Truncate);
let _ = checkpointer.checkpoint(&self.pager, &self.wal_path);
}
let _ = self.pager.sync();
Ok(())
}
pub fn path(&self) -> &Path {
&self.path
}
pub fn wal_path(&self) -> &Path {
&self.wal_path
}
pub fn is_read_only(&self) -> bool {
self.config.read_only
}
pub fn page_count(&self) -> u32 {
self.pager.page_count().unwrap_or(0)
}
pub fn file_size(&self) -> Result<u64, DatabaseError> {
self.pager
.file_size()
.map_err(|e| DatabaseError::Pager(e.to_string()))
}
pub fn cache_stats(&self) -> super::page_cache::CacheStats {
self.pager.cache_stats()
}
}
impl Drop for Database {
fn drop(&mut self) {
let state = self
.state
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner());
if *state == DbState::Open {
drop(state);
let mut state = self
.state
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
*state = DbState::Closed;
drop(state);
if !self.config.read_only {
let checkpointer = Checkpointer::new(CheckpointMode::Full);
let _ = checkpointer.checkpoint(&self.pager, &self.wal_path);
}
let _ = self.pager.sync();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Per-process sequence number that keeps temp paths distinct even when
    /// two tests read the clock within the same tick.
    static NEXT_DB_ID: AtomicU64 = AtomicU64::new(0);

    /// Returns a fresh database path inside the system temp directory.
    ///
    /// The name combines the process id, a nanosecond timestamp and a
    /// per-process counter, so tests running in parallel (or on platforms with
    /// a coarse clock) never share a file.
    fn temp_db_path() -> PathBuf {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let seq = NEXT_DB_ID.fetch_add(1, Ordering::Relaxed);
        std::env::temp_dir().join(format!(
            "reddb_engine_test_{}_{}_{}.rdb",
            std::process::id(),
            timestamp,
            seq
        ))
    }

    /// Removes the database file and its WAL, ignoring missing files.
    fn cleanup(path: &Path) {
        let _ = fs::remove_file(path);
        let wal_path = path.with_extension("rdb-wal");
        let _ = fs::remove_file(wal_path);
    }

    /// A new database starts with three pages and keeps that count when reopened.
    #[test]
    fn test_database_open_create() {
        let path = temp_db_path();
        cleanup(&path);
        {
            let db = Database::open(&path).unwrap();
            assert!(!db.is_read_only());
            assert_eq!(db.page_count(), 3);
        }
        {
            let db = Database::open(&path).unwrap();
            assert_eq!(db.page_count(), 3);
        }
        cleanup(&path);
    }

    /// A committed page write is visible through `read_page`.
    #[test]
    fn test_database_transaction() {
        let path = temp_db_path();
        cleanup(&path);
        {
            let db = Database::open(&path).unwrap();
            let page = db.allocate_page(PageType::BTreeLeaf).unwrap();
            let page_id = page.page_id();
            let mut tx = db.begin().unwrap();
            let mut page = Page::new(PageType::BTreeLeaf, page_id);
            page.as_bytes_mut()[100] = 0xAB;
            tx.write_page(page_id, page).unwrap();
            tx.commit().unwrap();
            let read_page = db.read_page(page_id).unwrap();
            assert_eq!(read_page.as_bytes()[100], 0xAB);
        }
        cleanup(&path);
    }

    /// Committed data survives dropping the handle without `close` and reopening.
    #[test]
    fn test_database_crash_recovery() {
        let path = temp_db_path();
        cleanup(&path);
        let page_id;
        {
            let db = Database::open(&path).unwrap();
            let page = db.allocate_page(PageType::BTreeLeaf).unwrap();
            page_id = page.page_id();
            let mut tx = db.begin().unwrap();
            let mut page = Page::new(PageType::BTreeLeaf, page_id);
            page.as_bytes_mut()[100] = 0xCD;
            tx.write_page(page_id, page).unwrap();
            tx.commit().unwrap();
            db.sync().unwrap();
            // NOTE(review): the handle is dropped here, and `Drop` still runs a
            // checkpoint, so this is not a true crash simulation.
        }
        {
            let db = Database::open(&path).unwrap();
            let read_page = db.read_page(page_id).unwrap();
            assert_eq!(read_page.as_bytes()[100], 0xCD);
        }
        cleanup(&path);
    }

    /// A manual checkpoint reports both committed transactions and their pages.
    #[test]
    fn test_database_checkpoint() {
        let path = temp_db_path();
        cleanup(&path);
        {
            let db = Database::open(&path).unwrap();
            let page1 = db.allocate_page(PageType::BTreeLeaf).unwrap();
            let page2 = db.allocate_page(PageType::BTreeLeaf).unwrap();
            let mut tx1 = db.begin().unwrap();
            let mut p1 = Page::new(PageType::BTreeLeaf, page1.page_id());
            p1.as_bytes_mut()[100] = 0x11;
            tx1.write_page(page1.page_id(), p1).unwrap();
            tx1.commit().unwrap();
            let mut tx2 = db.begin().unwrap();
            let mut p2 = Page::new(PageType::BTreeLeaf, page2.page_id());
            p2.as_bytes_mut()[100] = 0x22;
            tx2.write_page(page2.page_id(), p2).unwrap();
            tx2.commit().unwrap();
            let result = db.checkpoint().unwrap();
            assert_eq!(result.transactions_processed, 2);
            assert!(result.pages_checkpointed >= 2);
            db.close().unwrap();
        }
        {
            let db = Database::open(&path).unwrap();
            assert!(db.page_count() >= 3);
        }
        cleanup(&path);
    }

    /// A read-only handle reports itself as such and rejects page allocation.
    #[test]
    fn test_database_read_only() {
        let path = temp_db_path();
        cleanup(&path);
        {
            let db = Database::open(&path).unwrap();
            // Only the allocation itself matters here; the page is not used.
            db.allocate_page(PageType::BTreeLeaf).unwrap();
            db.close().unwrap();
        }
        {
            let config = DatabaseConfig {
                read_only: true,
                ..Default::default()
            };
            let db = Database::open_with_config(&path, config).unwrap();
            assert!(db.is_read_only());
            assert!(db.allocate_page(PageType::BTreeLeaf).is_err());
        }
        cleanup(&path);
    }

    /// Two overlapping transactions on different pages both commit and are readable.
    #[test]
    fn test_database_multiple_transactions() {
        let path = temp_db_path();
        cleanup(&path);
        {
            let db = Database::open(&path).unwrap();
            let page1 = db.allocate_page(PageType::BTreeLeaf).unwrap();
            let page2 = db.allocate_page(PageType::BTreeLeaf).unwrap();
            let mut tx1 = db.begin().unwrap();
            let mut tx2 = db.begin().unwrap();
            let mut p1 = Page::new(PageType::BTreeLeaf, page1.page_id());
            p1.as_bytes_mut()[100] = 0x11;
            tx1.write_page(page1.page_id(), p1).unwrap();
            let mut p2 = Page::new(PageType::BTreeLeaf, page2.page_id());
            p2.as_bytes_mut()[100] = 0x22;
            tx2.write_page(page2.page_id(), p2).unwrap();
            tx1.commit().unwrap();
            tx2.commit().unwrap();
            let r1 = db.read_page(page1.page_id()).unwrap();
            let r2 = db.read_page(page2.page_id()).unwrap();
            assert_eq!(r1.as_bytes()[100], 0x11);
            assert_eq!(r2.as_bytes()[100], 0x22);
        }
        cleanup(&path);
    }

    /// A poisoned state lock surfaces as `LockPoisoned` instead of a panic.
    #[test]
    fn test_begin_returns_structured_error_when_state_lock_is_poisoned() {
        let path = temp_db_path();
        cleanup(&path);
        {
            let db = Arc::new(Database::open(&path).unwrap());
            let poison_target = Arc::clone(&db);
            // Panic while holding the write guard to poison the lock.
            let _ = std::thread::spawn(move || {
                let _guard = poison_target
                    .state
                    .write()
                    .expect("state lock should be acquired");
                panic!("poison database state");
            })
            .join();
            match db.begin() {
                Ok(_) => panic!("begin should fail after state lock poisoning"),
                Err(err) => {
                    assert!(matches!(err, DatabaseError::LockPoisoned("database state")))
                }
            }
        }
        cleanup(&path);
    }
}