use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, PoisonError};
use crate::error::{Result, TxnError};
use crate::store::{MemoryStore, VersionStore, WriteEntry};
use crate::timestamp::Timestamp;
use crate::txn::{Snapshot, Transaction};
/// State shared by every handle onto one database: the version store plus the
/// timestamp bookkeeping and lock used by the commit path.
pub(crate) struct Inner<S: VersionStore> {
/// Backing store; queried for conflicts and written to during commit.
pub(crate) store: S,
/// Raw value of the next commit timestamp to hand out (initialised to 1).
next_ts: AtomicU64,
/// Raw value of the newest commit timestamp whose writes have been applied;
/// this is what new readers observe via `read_ts`.
last_committed: AtomicU64,
/// Held for the whole of `commit_writes`, so commits run one at a time.
commit_lock: Mutex<()>,
}
impl<S: VersionStore> Inner<S> {
    /// Wraps `store` in fresh shared state: the first commit timestamp to be
    /// issued is 1 and the last-committed marker starts at `Timestamp::ZERO`.
    fn new(store: S) -> Self {
        let first_commit_ts = AtomicU64::new(1);
        let nothing_committed = AtomicU64::new(Timestamp::ZERO.get());
        Self {
            store,
            next_ts: first_commit_ts,
            last_committed: nothing_committed,
            commit_lock: Mutex::new(()),
        }
    }

    /// Returns the timestamp of the most recently published commit.
    ///
    /// The `Acquire` load pairs with the `Release` store at the end of
    /// `commit_writes`.
    #[inline]
    fn read_ts(&self) -> Timestamp {
        let raw = self.last_committed.load(Ordering::Acquire);
        Timestamp::from_raw(raw)
    }

    /// Validates and applies a transaction's buffered writes, returning the
    /// commit timestamp assigned to them.
    ///
    /// `read_ts` is the timestamp the transaction read at; `writes` maps each
    /// written key to its optional new value.
    ///
    /// # Errors
    ///
    /// Returns `TxnError::conflict` if the store reports, for any written key,
    /// a latest commit timestamp greater than `read_ts`. Errors from the
    /// store's `latest_commit_ts` and `apply` are propagated unchanged.
    pub(crate) fn commit_writes(
        &self,
        read_ts: Timestamp,
        writes: std::collections::HashMap<Arc<[u8]>, Option<Arc<[u8]>>>,
    ) -> Result<Timestamp> {
        // One committer at a time. A poisoned lock is recovered instead of
        // panicking: the mutex guards no data of its own.
        let _commit_guard = self
            .commit_lock
            .lock()
            .unwrap_or_else(PoisonError::into_inner);

        // Reject the commit if any key was committed after this transaction's
        // read timestamp.
        for key in writes.keys() {
            match self.store.latest_commit_ts(key)? {
                Some(latest) if latest > read_ts => {
                    return Err(TxnError::conflict(key.len()));
                }
                _ => {}
            }
        }

        // The counter is only advanced here, while the commit lock is held.
        // Note that the timestamp stays consumed even if `apply` fails below.
        let raw_commit_ts = self.next_ts.fetch_add(1, Ordering::Relaxed);
        let commit_ts = Timestamp::from_raw(raw_commit_ts);

        let batch = writes.into_iter().collect::<Vec<WriteEntry>>();
        self.store.apply(commit_ts, batch)?;

        // Publish only after the store has accepted the batch, so readers
        // never obtain a timestamp whose writes have not been applied.
        self.last_committed.store(commit_ts.get(), Ordering::Release);
        Ok(commit_ts)
    }
}
/// Handle to a database backed by a `VersionStore` (`MemoryStore` unless
/// another store type is named).
///
/// The handle only holds a reference-counted pointer to the shared state, so
/// clones of a `Db` operate on the same database.
pub struct Db<S: VersionStore = MemoryStore> {
/// Shared state, also handed to the transactions and snapshots this handle creates.
inner: Arc<Inner<S>>,
}
impl Db<MemoryStore> {
    /// Creates a database backed by a freshly constructed `MemoryStore`.
    #[must_use]
    pub fn new() -> Self {
        let store = MemoryStore::new();
        Self::with_store(store)
    }
}
impl Default for Db<MemoryStore> {
fn default() -> Self {
Db::new()
}
}
impl<S: VersionStore> Db<S> {
    /// Creates a database handle on top of the supplied version store.
    #[must_use]
    pub fn with_store(store: S) -> Self {
        let inner = Arc::new(Inner::new(store));
        Self { inner }
    }

    /// Starts a transaction, passing it the shared state and the
    /// last-committed timestamp observed at the time of the call.
    pub fn begin(&self) -> Transaction<S> {
        let shared = Arc::clone(&self.inner);
        let start_ts = self.inner.read_ts();
        Transaction::new(shared, start_ts)
    }

    /// Creates a snapshot, passing it the shared state and the last-committed
    /// timestamp observed at the time of the call.
    #[must_use]
    pub fn snapshot(&self) -> Snapshot<S> {
        let shared = Arc::clone(&self.inner);
        let as_of = self.inner.read_ts();
        Snapshot::new(shared, as_of)
    }

    /// Returns the timestamp of the most recently published commit
    /// (`Timestamp::ZERO` on a database that has never committed).
    #[must_use]
    pub fn last_committed(&self) -> Timestamp {
        let inner = &self.inner;
        inner.read_ts()
    }
}
impl<S: VersionStore> Clone for Db<S> {
    /// Produces another handle onto the same shared state. Only the
    /// reference count is bumped; the store itself is not copied.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// A brand-new database reports timestamp zero and holds no keys.
    #[test]
    fn test_new_database_is_empty_at_zero() {
        let database = Db::new();
        assert_eq!(database.last_committed(), Timestamp::ZERO);
        assert_eq!(database.begin().get(b"k").unwrap(), None);
    }

    /// A committed write is readable by a transaction started afterwards.
    #[test]
    fn test_commit_makes_writes_visible_to_later_transactions() {
        let database = Db::new();
        let mut writer = database.begin();
        writer.put(b"k".to_vec(), b"v".to_vec());
        let commit_ts = writer.commit().unwrap();
        assert!(commit_ts > Timestamp::ZERO);
        assert_eq!(
            database.begin().get(b"k").unwrap().as_deref(),
            Some(&b"v"[..])
        );
    }

    /// A snapshot keeps returning the value that was current when it was taken.
    #[test]
    fn test_snapshot_is_isolated_from_later_commits() {
        let database = Db::new();

        let mut first_writer = database.begin();
        first_writer.put(b"k".to_vec(), b"v1".to_vec());
        let _ = first_writer.commit().unwrap();

        let frozen = database.snapshot();

        let mut second_writer = database.begin();
        second_writer.put(b"k".to_vec(), b"v2".to_vec());
        let _ = second_writer.commit().unwrap();

        assert_eq!(frozen.get(b"k").unwrap().as_deref(), Some(&b"v1"[..]));
    }

    /// Two overlapping transactions write the same key: the first commit wins
    /// and the second fails with a retryable error.
    #[test]
    fn test_write_write_conflict_aborts_later_committer() {
        let database = Db::new();
        let mut winner = database.begin();
        let mut loser = database.begin();
        winner.put(b"k".to_vec(), b"a".to_vec());
        loser.put(b"k".to_vec(), b"b".to_vec());

        assert!(winner.commit().is_ok());
        let failure = loser.commit().expect_err("second committer must lose");
        assert!(failure.is_retryable());

        assert_eq!(
            database.begin().get(b"k").unwrap().as_deref(),
            Some(&b"a"[..])
        );
    }

    /// Overlapping transactions that touch different keys both commit.
    #[test]
    fn test_disjoint_keys_do_not_conflict() {
        let database = Db::new();
        let mut left = database.begin();
        let mut right = database.begin();
        left.put(b"a".to_vec(), b"1".to_vec());
        right.put(b"b".to_vec(), b"2".to_vec());
        assert!(left.commit().is_ok());
        assert!(right.commit().is_ok());
    }

    /// Committing a transaction with no writes yields the timestamp it read at.
    #[test]
    fn test_read_only_commit_returns_snapshot_timestamp() {
        let database = Db::new();
        let mut writer = database.begin();
        writer.put(b"k".to_vec(), b"v".to_vec());
        let written_at = writer.commit().unwrap();

        let read_only = database.begin();
        assert_eq!(read_only.commit().unwrap(), written_at);
    }

    /// Writes buffered in a rolled-back transaction never become visible.
    #[test]
    fn test_rollback_discards_writes() {
        let database = Db::new();
        let mut abandoned = database.begin();
        abandoned.put(b"k".to_vec(), b"v".to_vec());
        abandoned.rollback();
        assert_eq!(database.begin().get(b"k").unwrap(), None);
    }

    /// A cloned handle observes commits made through the original handle.
    #[test]
    fn test_clone_shares_state() {
        let original = Db::new();
        let alias = original.clone();
        let mut writer = original.begin();
        writer.put(b"k".to_vec(), b"v".to_vec());
        let _ = writer.commit().unwrap();
        assert_eq!(
            alias.begin().get(b"k").unwrap().as_deref(),
            Some(&b"v"[..])
        );
    }
}