#![forbid(unsafe_code)]
use core::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::Duration;
use crate::error::{Error, LockKind, Result};
use crate::pager::page::{Page, PageId};
use crate::pager::{HeaderSnapshot, Pager, ReaderSnapshot};
use crate::platform::{FileBackend, FileHandle, ReaderLock, WriterLock};
/// How long [`ReadTxn::begin`] waits for the reader lock when the caller does
/// not pass an explicit timeout (five seconds).
pub const DEFAULT_BUSY_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Shared state that every transaction over one pager borrows.
///
/// Bundles the mutex-protected pager, the in-process write gate, and an
/// optional lock-sidecar handle used for reader/writer file locks.
#[derive(Debug)]
pub struct TxnEnv<F: FileBackend = FileHandle> {
    /// The pager, shared behind a mutex by all transactions on this env.
    pager: Arc<Mutex<Pager<F>>>,
    /// `true` while an in-process writer holds the write gate.
    write_serialization: Arc<AtomicBool>,
    /// Lock sidecar for reader/writer file locks; `None` skips file locking.
    lock_file: Option<Arc<FileHandle>>,
}

impl<F: FileBackend> TxnEnv<F> {
    /// Wraps `pager` for transactional use, starting with the write gate open.
    ///
    /// When `lock_file` is `Some`, transactions additionally take reader or
    /// writer locks through that handle.
    #[must_use]
    pub fn new(pager: Pager<F>, lock_file: Option<Arc<FileHandle>>) -> Self {
        let shared_pager = Arc::new(Mutex::new(pager));
        let gate = Arc::new(AtomicBool::new(false));
        TxnEnv {
            pager: shared_pager,
            write_serialization: gate,
            lock_file,
        }
    }

    /// Returns the shared, mutex-protected pager.
    #[must_use]
    pub fn pager(&self) -> &Arc<Mutex<Pager<F>>> {
        let shared = &self.pager;
        shared
    }
}
/// Guard proving its holder owns the in-process write gate.
///
/// Dropping the guard reopens the gate for the next writer.
#[derive(Debug)]
#[must_use = "the gate is released when the WriteSerialGuard drops"]
pub struct WriteSerialGuard {
    /// Flag shared with the env's `write_serialization`; `true` while held.
    gate: Arc<AtomicBool>,
}

impl Drop for WriteSerialGuard {
    fn drop(&mut self) {
        // Release ordering pairs with the Acquire compare-exchange that closed
        // the gate, so this writer's effects are visible to the next holder.
        let flag: &AtomicBool = &self.gate;
        flag.store(false, Ordering::Release);
    }
}
/// Write locks taken by [`WriteTxn::acquire`] but not yet turned into a
/// transaction.
///
/// Pass it to [`WriteTxn::from_acquire`] to begin the transaction. Dropping it
/// instead reopens the in-process gate and drops the writer lock (which is
/// presumably released by its own `Drop`; confirm in `platform`).
#[derive(Debug)]
#[must_use = "a WriteAcquire holds the write locks until consumed by WriteTxn::from_acquire"]
pub struct WriteAcquire<F: FileBackend> {
    // Field order is drop order. The cross-process writer lock must go before
    // the in-process gate reopens. Otherwise another thread of this process can
    // pass the gate and take the writer lock through the same shared handle
    // just before this one is released.
    writer_lock: Option<WriterLock>,
    write_guard: WriteSerialGuard,
    _backend: core::marker::PhantomData<fn() -> F>,
}
/// An in-progress write transaction.
///
/// While alive it holds the env's in-process write gate and, when the env has
/// a lock sidecar, the cross-process writer lock. Finish it with
/// [`WriteTxn::commit`] or [`WriteTxn::rollback`]. Dropping it unfinished
/// makes a best-effort rollback of its staged writes.
#[derive(Debug)]
pub struct WriteTxn<'db, F: FileBackend> {
    env: &'db TxnEnv<F>,
    // Declared before `write_guard` on purpose. Fields drop in declaration
    // order, so the writer lock is released before the in-process gate reopens.
    writer_lock: Option<WriterLock>,
    write_guard: Option<WriteSerialGuard>,
    /// Header captured at begin; restored on rollback.
    header_at_begin: Option<HeaderSnapshot>,
    /// Set once commit/rollback has finished, so `Drop` skips its rollback.
    finished: bool,
}
impl<'db, F: FileBackend> WriteTxn<'db, F> {
    /// Acquires the write locks and begins a write transaction in one step.
    ///
    /// Equivalent to [`WriteTxn::acquire`] followed by [`WriteTxn::from_acquire`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::Busy`] if the in-process gate cannot be taken within
    /// `timeout` or the pager mutex is poisoned, and propagates errors from
    /// acquiring the writer lock.
    pub fn begin(env: &'db TxnEnv<F>, timeout: Duration) -> Result<Self> {
        let acquired = Self::acquire(env, timeout)?;
        Self::from_acquire(env, acquired)
    }

    /// Takes the in-process write gate and then, if `env` has a lock sidecar,
    /// the cross-process writer lock. Each step waits up to `timeout`.
    ///
    /// The pager is not touched. Hand the result to [`WriteTxn::from_acquire`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::Busy`] when the gate stays held past `timeout`, or the
    /// error from acquiring the writer lock. On error nothing stays locked:
    /// the early return drops the gate guard, which reopens the gate.
    pub fn acquire(env: &'db TxnEnv<F>, timeout: Duration) -> Result<WriteAcquire<F>> {
        let write_guard = acquire_write_serialization(&env.write_serialization, timeout)?;
        let writer_lock = env
            .lock_file
            .as_ref()
            .map(|handle| handle.lock_writer(timeout))
            .transpose()?;
        Ok(WriteAcquire {
            write_guard,
            writer_lock,
            _backend: core::marker::PhantomData,
        })
    }

    /// Turns previously acquired locks into a live transaction. It starts a
    /// pager transaction and snapshots the header so rollback can restore it.
    ///
    /// `acq` must come from [`WriteTxn::acquire`] on this same `env`; otherwise
    /// the env's gate is not actually held. This is checked in debug builds.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Busy`] if the pager mutex is poisoned. `acq` is then
    /// dropped, which releases its locks.
    pub fn from_acquire(env: &'db TxnEnv<F>, acq: WriteAcquire<F>) -> Result<Self> {
        debug_assert!(
            Arc::ptr_eq(&acq.write_guard.gate, &env.write_serialization),
            "WriteAcquire must be turned into a WriteTxn on the env it was acquired from",
        );
        let header_at_begin = {
            let mut pager = env.pager.lock().map_err(|_| Error::Busy {
                kind: LockKind::WriterInProcess,
            })?;
            pager.begin_txn();
            pager.header_snapshot()
        };
        Ok(Self {
            env,
            write_guard: Some(acq.write_guard),
            writer_lock: acq.writer_lock,
            header_at_begin: Some(header_at_begin),
            finished: false,
        })
    }

    /// Writes `page` to `id` through the pager as part of this transaction.
    ///
    /// # Errors
    ///
    /// Fails if the pager mutex is poisoned or the pager write fails.
    pub fn write_page(&self, id: PageId, page: &Page) -> Result<()> {
        self.lock_pager()?.write_page(id, page)
    }

    /// Reads page `id` through the pager and returns an owned copy.
    ///
    /// # Errors
    ///
    /// Fails if the pager mutex is poisoned or the pager read fails.
    pub fn read_page(&self, id: PageId) -> Result<Page> {
        let mut guard = self.lock_pager()?;
        let borrowed = guard.read_page(id)?;
        Ok(borrowed.to_owned_page())
    }

    /// Allocates a new page through the pager and returns its id.
    ///
    /// # Errors
    ///
    /// Fails if the pager mutex is poisoned or allocation fails.
    pub fn alloc_page(&self) -> Result<PageId> {
        self.lock_pager()?.alloc_page()
    }

    /// Locks the shared pager for direct access within this transaction.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Busy`] with [`LockKind::WriterInProcess`] if the pager
    /// mutex is poisoned.
    pub fn lock_pager(&self) -> Result<MutexGuard<'_, Pager<F>>> {
        self.env.pager.lock().map_err(|_| Error::Busy {
            kind: LockKind::WriterInProcess,
        })
    }

    /// Returns the env this transaction runs against.
    #[must_use]
    pub fn env(&self) -> &'db TxnEnv<F> {
        self.env
    }

    /// Commits the pager transaction, ends it, and releases the write locks.
    ///
    /// # Errors
    ///
    /// Propagates failures from locking the pager, from the pager commit, and
    /// from releasing the writer lock. If the pager commit fails, `self` is
    /// dropped unfinished and its `Drop` impl rolls back.
    pub fn commit(mut self) -> Result<()> {
        {
            let mut pager = self.lock_pager()?;
            let _lsn = pager.commit()?;
            pager.end_txn();
        }
        self.finished = true;
        self.header_at_begin.take();
        self.release_locks()
    }

    /// Discards staged writes, restores the header captured at begin, ends the
    /// pager transaction, and releases the write locks.
    ///
    /// # Errors
    ///
    /// Propagates failures from locking the pager, from restoring the header
    /// (the `Drop` impl then still ends the pager transaction), and from
    /// releasing the writer lock.
    pub fn rollback(mut self) -> Result<()> {
        let snap = self.header_at_begin.take();
        {
            let mut pager = self.lock_pager()?;
            rollback_pending(&mut pager);
            if let Some(s) = snap {
                pager.restore_header_snapshot(s)?;
            }
            pager.end_txn();
        }
        self.finished = true;
        self.release_locks()
    }

    /// Releases the cross-process writer lock first, then reopens the
    /// in-process gate (reverse of acquisition order).
    ///
    /// Every in-process writer locks through the env's single shared
    /// `lock_file` handle, and file locks on one handle may not exclude other
    /// threads of this process (e.g. POSIX record locks are per-process; TODO
    /// confirm against `platform`). Reopening the gate first would let the next
    /// in-process writer "acquire" the file lock just before it is released
    /// here. The gate is reopened even when releasing the writer lock fails.
    fn release_locks(&mut self) -> Result<()> {
        let released = self.writer_lock.take().map(|lock| lock.release());
        self.write_guard = None;
        if let Some(outcome) = released {
            outcome?;
        }
        Ok(())
    }
}
impl<F: FileBackend> Drop for WriteTxn<'_, F> {
    /// Rolls back a transaction that was neither committed nor rolled back,
    /// then releases the write locks in reverse acquisition order.
    fn drop(&mut self) {
        if !self.finished {
            let snap = self.header_at_begin.take();
            // A poisoned pager mutex skips the rollback; `drop` has no way to
            // report that failure.
            if let Ok(mut pager) = self.env.pager.lock() {
                rollback_pending(&mut pager);
                if let Some(s) = snap {
                    // Best effort: errors cannot be propagated out of `drop`.
                    let _ = pager.restore_header_snapshot(s);
                }
                pager.end_txn();
            }
            #[cfg(feature = "tracing")]
            tracing::debug!("WriteTxn dropped without commit/rollback; pending writes discarded");
        }
        // Release the cross-process writer lock *before* reopening the
        // in-process gate. All in-process writers lock through one shared
        // handle, so reopening the gate first could let another thread think it
        // holds the writer lock just as this one releases it. Both fields are
        // already `None` after a successful commit/rollback.
        if let Some(lock) = self.writer_lock.take() {
            let _ = lock.release();
        }
        self.write_guard = None;
    }
}
fn rollback_pending<F: FileBackend>(pager: &mut Pager<F>) {
pager.rollback_pending_writes();
}
/// Claims the in-process write gate, retrying with capped exponential backoff
/// (1 ms doubling up to 100 ms) until `timeout` elapses.
///
/// A zero `timeout` makes exactly one attempt (try-lock semantics).
///
/// # Errors
///
/// Returns [`Error::Busy`] with [`LockKind::WriterInProcess`] when the gate is
/// still held once the deadline passes.
fn acquire_write_serialization(
    gate: &Arc<AtomicBool>,
    timeout: Duration,
) -> Result<WriteSerialGuard> {
    const MAX_BACKOFF: Duration = Duration::from_millis(100);
    let start = std::time::Instant::now();
    let mut backoff = Duration::from_millis(1);
    // Backstop against an unbounded loop; the deadline check below normally
    // fires long before this many attempts.
    let timeout_millis = u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX);
    let max_iters = timeout_millis.saturating_add(64);
    let mut iters: u64 = 0;
    loop {
        iters = iters.saturating_add(1);
        if iters > max_iters {
            return Err(Error::Busy {
                kind: LockKind::WriterInProcess,
            });
        }
        // Strong CAS on purpose. A spurious failure from the weak variant would
        // cost a full backoff sleep, or report `Busy` on a free gate when
        // `timeout` is zero.
        if gate
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
        {
            return Ok(WriteSerialGuard {
                gate: Arc::clone(gate),
            });
        }
        let elapsed = start.elapsed();
        if elapsed >= timeout {
            return Err(Error::Busy {
                kind: LockKind::WriterInProcess,
            });
        }
        // Never sleep past the deadline: one last attempt happens right at it.
        // `elapsed < timeout` here, so the subtraction cannot underflow.
        std::thread::sleep(backoff.min(timeout - elapsed));
        backoff = (backoff * 2).min(MAX_BACKOFF);
    }
}
#[derive(Debug)]
pub struct ReadTxn<'db, F: FileBackend> {
env: &'db TxnEnv<F>,
snapshot: ReaderSnapshot<F>,
_reader_lock: Option<ReaderLock>,
}
impl<'db, F: FileBackend> ReadTxn<'db, F> {
pub fn begin(env: &'db TxnEnv<F>) -> Result<Self> {
Self::begin_with_timeout(env, DEFAULT_BUSY_TIMEOUT)
}
pub fn begin_with_timeout(env: &'db TxnEnv<F>, timeout: Duration) -> Result<Self> {
let reader_lock = match env.lock_file.as_ref() {
Some(handle) => Some(handle.lock_reader(timeout)?),
None => None,
};
let snapshot = {
let mut pager = env.pager.lock().map_err(|_| Error::Busy {
kind: LockKind::WriterInProcess,
})?;
pager.reader_snapshot()?
};
Ok(Self {
env,
snapshot,
_reader_lock: reader_lock,
})
}
#[must_use]
pub fn pinned_lsn(&self) -> crate::wal::Lsn {
self.snapshot.pinned_lsn()
}
pub fn read_page(&self, id: PageId) -> Result<Page> {
let pager = self.env.pager.lock().map_err(|_| Error::Busy {
kind: LockKind::WriterInProcess,
})?;
Ok(self.snapshot.read_page(&pager, id)?.into_page())
}
#[must_use]
pub fn snapshot(&self) -> &ReaderSnapshot<F> {
&self.snapshot
}
#[must_use]
pub fn env(&self) -> &TxnEnv<F> {
self.env
}
pub fn end(self) {
drop(self);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pager::page::Page;
    use crate::pager::Config;
    use crate::platform::FileHandle;
    use std::sync::Arc;
    use std::thread;
    use tempfile::TempDir;

    /// Opens a fresh pager under `dir`, commits one zeroed page, and wraps the
    /// pager in a [`TxnEnv`] backed by a lock sidecar. Returns the env and the
    /// id of the committed page.
    fn build_env(dir: &TempDir) -> (TxnEnv<FileHandle>, PageId) {
        let path = dir.path().join("txn.obj");
        let mut pager = Pager::open(&path, Config::default()).expect("pager");
        pager.begin_txn();
        let a = pager.alloc_page().expect("alloc");
        let mut page = Page::zeroed();
        page.as_bytes_mut()[0] = 0;
        pager.write_page(a, &page).expect("write");
        let _ = pager.commit().expect("commit");
        pager.end_txn();
        let lock_path = crate::pager::lock_path_for(&path);
        let lock_file = Arc::new(FileHandle::open_or_create(&lock_path).expect("lock file"));
        lock_file.set_len(128).expect("lock sidecar len");
        (TxnEnv::new(pager, Some(lock_file)), a)
    }

    #[test]
    fn write_txn_commit_makes_writes_visible() {
        let dir = TempDir::new().expect("tmp");
        let (env, a) = build_env(&dir);
        let tx = WriteTxn::begin(&env, Duration::from_millis(50)).expect("begin");
        let mut page = Page::zeroed();
        page.as_bytes_mut()[0] = 0x77;
        tx.write_page(a, &page).expect("write");
        tx.commit().expect("commit");
        let rx = ReadTxn::begin(&env).expect("read");
        let observed = rx.read_page(a).expect("read");
        assert_eq!(observed.as_bytes()[0], 0x77);
    }

    #[test]
    fn write_txn_rollback_drops_writes() {
        let dir = TempDir::new().expect("tmp");
        let (env, a) = build_env(&dir);
        let tx = WriteTxn::begin(&env, Duration::from_millis(50)).expect("begin");
        let mut page = Page::zeroed();
        page.as_bytes_mut()[0] = 0x99;
        tx.write_page(a, &page).expect("write");
        tx.rollback().expect("rollback");
        let rx = ReadTxn::begin(&env).expect("read");
        let observed = rx.read_page(a).expect("read");
        assert_eq!(observed.as_bytes()[0], 0, "rollback must discard writes");
    }

    #[test]
    fn in_process_writers_serialize() {
        let dir = TempDir::new().expect("tmp");
        let (env, _a) = build_env(&dir);
        let tx1 = WriteTxn::begin(&env, Duration::from_millis(50)).expect("tx1");
        let err = WriteTxn::begin(&env, Duration::from_millis(10)).expect_err("tx2 busy");
        assert!(matches!(
            err,
            Error::Busy {
                kind: LockKind::WriterInProcess
            }
        ));
        tx1.commit().expect("commit");
        let _tx3 = WriteTxn::begin(&env, Duration::from_millis(50)).expect("tx3");
    }

    #[test]
    fn write_txn_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<WriteTxn<'_, FileHandle>>();
        assert_send::<WriteAcquire<FileHandle>>();
        assert_send::<WriteSerialGuard>();
        assert_send::<ReadTxn<'_, FileHandle>>();
    }

    #[test]
    fn panic_in_writer_releases_gate_and_rolls_back() {
        let dir = TempDir::new().expect("tmp");
        let (env, a) = build_env(&dir);
        let env = Arc::new(env);
        let env_for_panic = Arc::clone(&env);
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let tx = WriteTxn::begin(&env_for_panic, Duration::from_millis(50)).expect("begin");
            let mut page = Page::zeroed();
            page.as_bytes_mut()[0] = 0xEE;
            tx.write_page(a, &page).expect("write");
            panic!("simulated mid-write crash");
        }));
        assert!(result.is_err(), "the closure must have panicked");
        let tx2 = WriteTxn::begin(&env, Duration::from_millis(50))
            .expect("gate must be free after a panicking writer");
        tx2.commit().expect("commit");
        let rx = ReadTxn::begin(&env).expect("read");
        let observed = rx.read_page(a).expect("read");
        assert_eq!(
            observed.as_bytes()[0],
            0,
            "panicking writer's staged write must have been rolled back",
        );
    }

    #[test]
    fn n_writers_serialize_with_no_deadlock() {
        let dir = TempDir::new().expect("tmp");
        let (env, a) = build_env(&dir);
        // Scoped threads may borrow the env directly; no `Arc` needed.
        let env = &env;
        let n_writers = 4usize;
        let iters_per_writer = 250usize;
        thread::scope(|scope| {
            let handles: Vec<_> = (0..n_writers)
                .map(|w| {
                    scope.spawn(move || {
                        for i in 0..iters_per_writer {
                            let tx = WriteTxn::begin(env, Duration::from_secs(30))
                                .expect("begin under load");
                            let mut p = Page::zeroed();
                            // Always non-zero, so the final check can tell a
                            // committed value from the initial zero.
                            p.as_bytes_mut()[0] =
                                u8::try_from((w * 1000 + i) % 250 + 1).expect("byte fits");
                            tx.write_page(a, &p).expect("write");
                            tx.commit().expect("commit");
                        }
                    })
                })
                .collect();
            for h in handles {
                h.join().expect("join");
            }
        });
        let rx = ReadTxn::begin(env).expect("read");
        let p = rx.read_page(a).expect("read");
        assert_ne!(p.as_bytes()[0], 0, "some writer's value must be visible");
    }

    #[test]
    fn drop_without_commit_warns_and_rolls_back() {
        let dir = TempDir::new().expect("tmp");
        let (env, a) = build_env(&dir);
        {
            let tx = WriteTxn::begin(&env, Duration::from_millis(50)).expect("begin");
            let mut page = Page::zeroed();
            page.as_bytes_mut()[0] = 0xAB;
            tx.write_page(a, &page).expect("write");
        }
        let rx = ReadTxn::begin(&env).expect("read");
        let observed = rx.read_page(a).expect("read");
        assert_eq!(
            observed.as_bytes()[0],
            0,
            "drop-without-commit must roll back",
        );
    }

    #[test]
    fn read_txn_sees_consistent_snapshot() {
        let dir = TempDir::new().expect("tmp");
        let (env, a) = build_env(&dir);
        let rx = ReadTxn::begin(&env).expect("read");
        {
            let tx = WriteTxn::begin(&env, Duration::from_millis(50)).expect("write");
            let mut p = Page::zeroed();
            p.as_bytes_mut()[0] = 0x55;
            tx.write_page(a, &p).expect("write");
            tx.commit().expect("commit");
        }
        let observed = rx.read_page(a).expect("read");
        assert_eq!(
            observed.as_bytes()[0],
            0,
            "snapshot must isolate reader from concurrent commits",
        );
    }

    #[test]
    fn acquire_then_from_acquire_commits() {
        let dir = TempDir::new().expect("tmp");
        let (env, a) = build_env(&dir);
        let acq = WriteTxn::acquire(&env, Duration::from_millis(50)).expect("acquire");
        let tx = WriteTxn::from_acquire(&env, acq).expect("from_acquire");
        let mut page = Page::zeroed();
        page.as_bytes_mut()[0] = 0x42;
        tx.write_page(a, &page).expect("write");
        tx.commit().expect("commit");
        let rx = ReadTxn::begin(&env).expect("read");
        assert_eq!(rx.read_page(a).expect("read").as_bytes()[0], 0x42);
    }

    #[test]
    fn dropped_acquire_frees_gate() {
        let dir = TempDir::new().expect("tmp");
        let (env, _a) = build_env(&dir);
        let acq = WriteTxn::acquire(&env, Duration::from_millis(50)).expect("acquire");
        drop(acq);
        let tx = WriteTxn::begin(&env, Duration::from_millis(50))
            .expect("gate must be free after dropping a WriteAcquire");
        tx.rollback().expect("rollback");
    }

    #[test]
    fn zero_timeout_acquires_free_gate() {
        let gate = Arc::new(AtomicBool::new(false));
        let guard = acquire_write_serialization(&gate, Duration::ZERO)
            .expect("a free gate must be acquired without waiting");
        assert!(gate.load(Ordering::Acquire));
        drop(guard);
        assert!(!gate.load(Ordering::Acquire));
    }

    #[test]
    fn held_gate_reports_busy_after_timeout() {
        let gate = Arc::new(AtomicBool::new(true));
        let started = std::time::Instant::now();
        let err = acquire_write_serialization(&gate, Duration::from_millis(20))
            .expect_err("a held gate must not be acquired");
        assert!(matches!(
            err,
            Error::Busy {
                kind: LockKind::WriterInProcess
            }
        ));
        assert!(started.elapsed() >= Duration::from_millis(20));
        assert!(
            gate.load(Ordering::Acquire),
            "a failed acquire must not clear a gate it never held",
        );
    }
}