use std::sync::Arc;
use crate::error::Result;
use crate::oracle::Oracle;
use crate::store::{MemoryStore, VersionStore, WriteEntry};
use crate::timestamp::Timestamp;
use crate::txn::{Snapshot, Transaction};
/// State shared by every handle onto one database.
///
/// `Db` wraps this in an `Arc`, and `Transaction::new` / `Snapshot::new` are
/// handed clones of that same `Arc`.
pub(crate) struct Inner<S: VersionStore> {
/// Backing version store; commits and garbage collection are delegated to it.
pub(crate) store: S,
/// Source of read/commit timestamps and of the garbage-collection low watermark.
oracle: Oracle,
/// Commit log used when the `durability` feature is enabled: `None` when
/// built through `Inner::new`, `Some` when built through `Inner::recovered`.
#[cfg(feature = "durability")]
log: Option<crate::durable::CommitLog>,
}
impl<S: VersionStore> Inner<S> {
    /// Wraps `store` together with a brand-new oracle and no commit log.
    fn new(store: S) -> Self {
        Self {
            store,
            oracle: Oracle::new(),
            #[cfg(feature = "durability")]
            log: None,
        }
    }

    /// Returns the oracle's current read timestamp.
    #[inline]
    fn read_ts(&self) -> Timestamp {
        let oracle = &self.oracle;
        oracle.read_ts()
    }

    /// Forwards to the oracle's `begin_reader` and returns the timestamp it
    /// hands back.
    ///
    /// NOTE(review): presumably this registers the reader so its snapshot
    /// holds back the low watermark — confirm in `Oracle`.
    #[inline]
    pub(crate) fn begin_reader(&self) -> Timestamp {
        let oracle = &self.oracle;
        oracle.begin_reader()
    }

    /// Forwards to the oracle's `end_reader` for the reader that was started
    /// at `read_ts`.
    #[inline]
    pub(crate) fn end_reader(&self, read_ts: Timestamp) {
        let oracle = &self.oracle;
        oracle.end_reader(read_ts);
    }

    /// Runs the store's garbage collection at the oracle's current low
    /// watermark and returns whatever count the store reports.
    fn collect_garbage(&self) -> usize {
        let watermark = self.oracle.low_watermark();
        self.store.collect_garbage(watermark)
    }

    /// Attempts to commit `writes` for a transaction that read at `read_ts`.
    ///
    /// Steps, in order:
    /// 1. allocate a commit timestamp from the oracle;
    /// 2. with `durability` enabled and a log present, encode the log record
    ///    (done up front because `writes` is moved into the store next);
    /// 3. hand the writes and the `reads` keys to the store's `try_commit`;
    /// 4. with `durability` enabled, append the record to the log if the
    ///    store accepted the commit;
    /// 5. tell the oracle the commit timestamp is done, on every path.
    ///
    /// # Errors
    ///
    /// Returns the store's error when `try_commit` rejects the commit, or the
    /// log's error when appending the record fails.
    ///
    /// NOTE(review): on a log-append failure the store has already accepted
    /// the writes and `commit_done` is still called, so the caller sees an
    /// error for a commit that is installed in the store — confirm this is
    /// the intended contract.
    pub(crate) fn commit_writes(
        &self,
        read_ts: Timestamp,
        writes: Vec<WriteEntry>,
        reads: &[Arc<[u8]>],
    ) -> Result<Timestamp> {
        let oracle = &self.oracle;
        let commit_ts = oracle.alloc_commit_ts();

        // Pair the log handle with its encoded record so both are either
        // present or absent together.
        #[cfg(feature = "durability")]
        let pending = self
            .log
            .as_ref()
            .map(|log| (log, crate::durable::encode_for_log(commit_ts, &writes)));

        let outcome = self.store.try_commit(read_ts, commit_ts, writes, reads);

        // Only commits the store accepted are written to the log.
        #[cfg(feature = "durability")]
        if let (Ok(()), Some((log, record))) = (&outcome, pending) {
            if let Err(err) = log.append_committed(&record) {
                oracle.commit_done(commit_ts);
                return Err(err);
            }
        }

        oracle.commit_done(commit_ts);
        outcome.map(|()| commit_ts)
    }

    /// Assembles state from an already-populated `store`, a pre-seeded
    /// `oracle`, and the commit `log` that later commits are appended to.
    #[cfg(feature = "durability")]
    fn recovered(store: S, oracle: Oracle, log: crate::durable::CommitLog) -> Self {
        let log = Some(log);
        Self { store, oracle, log }
    }
}
/// Handle onto a database backed by a `VersionStore` (`MemoryStore` unless
/// another store type is named).
///
/// Cloning a handle clones the inner `Arc`, so all clones share the same
/// underlying state.
pub struct Db<S: VersionStore = MemoryStore> {
/// Shared state; also handed to the transactions and snapshots this handle creates.
inner: Arc<Inner<S>>,
}
impl Db<MemoryStore> {
    /// Creates a database on top of a fresh `MemoryStore`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_store(MemoryStore::new())
    }

    /// Opens the commit log at `path` and rebuilds a database from it.
    ///
    /// The commits recovered from the log are sorted by commit timestamp and
    /// installed, in that order, into a fresh `MemoryStore`. The oracle is
    /// seeded with the highest recovered commit timestamp, or
    /// `Timestamp::ZERO` when the log yielded no commits.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `CommitLog::open`.
    #[cfg(feature = "durability")]
    #[cfg_attr(docsrs, doc(cfg(feature = "durability")))]
    pub fn open(path: impl AsRef<std::path::Path>) -> Result<Db<MemoryStore>> {
        let (log, mut commits) = crate::durable::CommitLog::open(path)?;
        commits.sort_by_key(|entry| entry.commit_ts);

        let store = MemoryStore::new();
        // Replay every commit while tracking the largest timestamp seen.
        let highest = commits.into_iter().fold(Timestamp::ZERO, |seen, entry| {
            let ts = entry.commit_ts;
            let seen = seen.max(ts);
            store.install_recovered(ts, entry.writes);
            seen
        });

        let inner = Inner::recovered(store, Oracle::recovered(highest), log);
        Ok(Db {
            inner: Arc::new(inner),
        })
    }
}
impl Default for Db<MemoryStore> {
fn default() -> Self {
Db::new()
}
}
impl<S: VersionStore> Db<S> {
    /// Creates a database on top of the given `store`, with a fresh oracle.
    #[must_use]
    pub fn with_store(store: S) -> Self {
        let inner = Arc::new(Inner::new(store));
        Self { inner }
    }

    /// Starts a transaction with the serializable flag off.
    ///
    /// Conflicting writes to the same key abort the later committer, while
    /// write skew across different keys is allowed (see the tests in this
    /// module).
    pub fn begin(&self) -> Transaction<S> {
        let shared = Arc::clone(&self.inner);
        Transaction::new(shared, false)
    }

    /// Starts a transaction with the serializable flag on.
    ///
    /// Unlike [`Db::begin`], such a transaction is rejected at commit when it
    /// would produce write skew (see the tests in this module).
    #[cfg(feature = "serializable")]
    #[cfg_attr(docsrs, doc(cfg(feature = "serializable")))]
    pub fn begin_serializable(&self) -> Transaction<S> {
        let shared = Arc::clone(&self.inner);
        Transaction::new(shared, true)
    }

    /// Takes a snapshot that keeps seeing the data as of its creation, even
    /// after later commits.
    pub fn snapshot(&self) -> Snapshot<S> {
        let shared = Arc::clone(&self.inner);
        Snapshot::new(shared)
    }

    /// Returns the oracle's current read timestamp; `Timestamp::ZERO` for a
    /// database created with `Db::new` that has seen no commits.
    #[must_use]
    pub fn last_committed(&self) -> Timestamp {
        let shared = &self.inner;
        shared.read_ts()
    }

    /// Runs garbage collection in the store at the oracle's low watermark and
    /// returns the count the store reports.
    ///
    /// NOTE(review): presumably the count is the number of reclaimed
    /// versions — confirm against the `VersionStore` contract.
    pub fn collect_garbage(&self) -> usize {
        let shared = &self.inner;
        shared.collect_garbage()
    }
}
impl<S: VersionStore> Clone for Db<S> {
    /// Returns another handle onto the same shared state; only the `Arc` is
    /// cloned, so no `S: Clone` bound is needed.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
#[cfg(all(test, not(loom)))]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Writes `key -> value` in a transaction of its own and commits it,
    /// panicking if the commit fails.
    fn commit_put(db: &Db, key: &[u8], value: &[u8]) {
        let mut txn = db.begin();
        txn.put(key.to_vec(), value.to_vec());
        let _ = txn.commit().unwrap();
    }

    /// A fresh database reports timestamp zero and holds no keys.
    #[test]
    fn test_new_database_is_empty_at_zero() {
        let db = Db::new();
        assert_eq!(db.last_committed(), Timestamp::ZERO);
        assert_eq!(db.begin().get(b"k").unwrap(), None);
    }

    /// A committed write gets a non-zero timestamp and is readable afterwards.
    #[test]
    fn test_commit_makes_writes_visible_to_later_transactions() {
        let db = Db::new();
        let mut writer = db.begin();
        writer.put(b"k".to_vec(), b"v".to_vec());
        let commit_ts = writer.commit().unwrap();
        assert!(commit_ts > Timestamp::ZERO);
        assert_eq!(db.begin().get(b"k").unwrap().as_deref(), Some(&b"v"[..]));
    }

    /// A snapshot keeps seeing the value that was current when it was taken.
    #[test]
    fn test_snapshot_is_isolated_from_later_commits() {
        let db = Db::new();
        commit_put(&db, b"k", b"v1");
        let frozen = db.snapshot();
        commit_put(&db, b"k", b"v2");
        assert_eq!(frozen.get(b"k").unwrap().as_deref(), Some(&b"v1"[..]));
    }

    /// Two transactions writing the same key: the first commit wins, the
    /// second fails with a retryable error.
    #[test]
    fn test_write_write_conflict_aborts_later_committer() {
        let db = Db::new();
        let mut first = db.begin();
        let mut second = db.begin();
        first.put(b"k".to_vec(), b"a".to_vec());
        second.put(b"k".to_vec(), b"b".to_vec());
        assert!(first.commit().is_ok());
        let err = second.commit().expect_err("second committer must lose");
        assert!(err.is_retryable());
        assert_eq!(db.begin().get(b"k").unwrap().as_deref(), Some(&b"a"[..]));
    }

    /// Concurrent transactions touching different keys both commit.
    #[test]
    fn test_disjoint_keys_do_not_conflict() {
        let db = Db::new();
        let mut first = db.begin();
        let mut second = db.begin();
        first.put(b"a".to_vec(), b"1".to_vec());
        second.put(b"b".to_vec(), b"2".to_vec());
        assert!(first.commit().is_ok());
        assert!(second.commit().is_ok());
    }

    /// Committing a transaction without writes yields the timestamp of the
    /// preceding commit.
    #[test]
    fn test_read_only_commit_returns_snapshot_timestamp() {
        let db = Db::new();
        let mut writer = db.begin();
        writer.put(b"k".to_vec(), b"v".to_vec());
        let commit_ts = writer.commit().unwrap();
        let reader = db.begin();
        assert_eq!(reader.commit().unwrap(), commit_ts);
    }

    /// Writes of a rolled-back transaction never become visible.
    #[test]
    fn test_rollback_discards_writes() {
        let db = Db::new();
        let mut abandoned = db.begin();
        abandoned.put(b"k".to_vec(), b"v".to_vec());
        abandoned.rollback();
        assert_eq!(db.begin().get(b"k").unwrap(), None);
    }

    /// With no reader held, garbage collection reclaims something and the
    /// newest value stays readable.
    #[test]
    fn test_gc_reclaims_when_no_reader_is_held() {
        let db = Db::new();
        for version in 0..5u8 {
            commit_put(&db, b"k", &[version]);
        }
        let reclaimed = db.collect_garbage();
        assert!(reclaimed > 0);
        assert_eq!(db.begin().get(b"k").unwrap().as_deref(), Some(&[4u8][..]));
    }

    /// A live snapshot keeps its version readable across a collection; once
    /// it is dropped, collection reclaims something.
    #[test]
    fn test_held_snapshot_pins_gc() {
        let db = Db::new();
        commit_put(&db, b"k", &[1]);
        let pinned = db.snapshot();
        commit_put(&db, b"k", &[2]);
        let _ = db.collect_garbage();
        assert_eq!(pinned.get(b"k").unwrap().as_deref(), Some(&[1u8][..]));
        drop(pinned);
        let reclaimed = db.collect_garbage();
        assert!(reclaimed > 0);
        assert_eq!(db.begin().get(b"k").unwrap().as_deref(), Some(&[2u8][..]));
    }

    /// A cloned handle observes commits made through the original handle.
    #[test]
    fn test_clone_shares_state() {
        let db = Db::new();
        let alias = db.clone();
        commit_put(&db, b"k", b"v");
        assert_eq!(alias.begin().get(b"k").unwrap().as_deref(), Some(&b"v"[..]));
    }

    /// Two serializable transactions that each read both keys and write a
    /// different one: the second commit is rejected as retryable.
    #[cfg(feature = "serializable")]
    #[test]
    fn test_serializable_rejects_write_skew() {
        let db = Db::new();
        let mut setup = db.begin();
        setup.put(b"x".to_vec(), vec![1]);
        setup.put(b"y".to_vec(), vec![1]);
        let _ = setup.commit().unwrap();
        let mut left = db.begin_serializable();
        let mut right = db.begin_serializable();
        let _ = left.get(b"x").unwrap();
        let _ = left.get(b"y").unwrap();
        let _ = right.get(b"x").unwrap();
        let _ = right.get(b"y").unwrap();
        left.put(b"x".to_vec(), vec![0]);
        right.put(b"y".to_vec(), vec![0]);
        assert!(left.commit().is_ok());
        let err = right.commit().expect_err("write skew must be rejected");
        assert!(err.is_retryable());
    }

    /// The same read/write pattern under plain `begin` transactions commits
    /// on both sides.
    #[cfg(feature = "serializable")]
    #[test]
    fn test_snapshot_txn_allows_write_skew() {
        let db = Db::new();
        let mut setup = db.begin();
        setup.put(b"x".to_vec(), vec![1]);
        setup.put(b"y".to_vec(), vec![1]);
        let _ = setup.commit().unwrap();
        let mut left = db.begin();
        let mut right = db.begin();
        let _ = left.get(b"x").unwrap();
        let _ = left.get(b"y").unwrap();
        let _ = right.get(b"x").unwrap();
        let _ = right.get(b"y").unwrap();
        left.put(b"x".to_vec(), vec![0]);
        right.put(b"y".to_vec(), vec![0]);
        assert!(left.commit().is_ok());
        assert!(right.commit().is_ok());
    }
}