use crate::db::IsolationLevel;
use crate::error::FluxError;
use dashmap::{DashMap, DashSet};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
/// Transaction identifier. `TransactionManager::new_txid` hands ids out in
/// increasing order starting at 1; the value 0 is used in
/// `Version::expirer_txid` to mean "not expired".
pub type TxId = u64;
/// Lifecycle state of a transaction as recorded by `TransactionManager`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Started via `TransactionManager::begin` and not yet committed or aborted.
    Active,
    /// Committed; `Snapshot::is_visible` treats its versions as visible to any
    /// snapshot that does not list it as in-progress.
    Committed,
    /// Aborted (explicitly, by a failed pre-commit hook, or by a serialization
    /// conflict).
    Aborted,
}
/// Per-transaction buffer of pending writes, keyed by record key.
/// NOTE(review): a `None` value presumably marks a pending delete — confirm
/// against the code that fills the workspace.
pub type Workspace<K, V> = HashMap<K, Option<Arc<V>>>;
/// Allocates transaction ids, records transaction status, builds visibility
/// snapshots, and performs serializable-isolation conflict checks at commit.
///
/// All state is held in concurrent maps, atomics, or a mutex, so one manager
/// can be shared by many threads.
pub struct TransactionManager<K: Eq + std::hash::Hash, V> {
    /// Next id handed out by `new_txid`; the first id is 1.
    next_txid: AtomicU64,
    /// Last known status of every transaction id that has not been pruned.
    statuses: DashMap<TxId, TransactionStatus>,
    /// Transactions started by `begin` that have not yet committed or aborted.
    active_transactions: DashMap<TxId, Arc<Transaction<K, V>>>,
    /// Serializes `begin` so that allocating an id, registering it as active,
    /// and taking the snapshot happen as one step (see `begin`).
    begin_lock: std::sync::Mutex<()>,
    /// Statuses of ids below this value are dropped by `prune_statuses`.
    pub min_retainable_txid: AtomicU64,
    /// Key -> ids of transactions tracked as readers of that key; consulted at
    /// serializable commit to flag read/write conflicts.
    pub read_trackers: DashMap<K, DashSet<TxId>>,
    _phantom: std::marker::PhantomData<(K, V)>,
}
impl<K, V> TransactionManager<K, V>
where
    K: Eq + std::hash::Hash + Clone + Serialize + DeserializeOwned,
    V: Clone + Serialize + DeserializeOwned,
{
    /// Creates an empty manager whose first allocated transaction id is 1.
    pub fn new() -> Self {
        Self {
            next_txid: AtomicU64::new(1),
            statuses: DashMap::new(),
            active_transactions: DashMap::new(),
            begin_lock: std::sync::Mutex::new(()),
            min_retainable_txid: AtomicU64::new(0),
            read_trackers: DashMap::new(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Allocates a fresh, unique transaction id.
    ///
    /// Ids obtained directly from this method are not registered as active;
    /// only `begin` does that.
    pub fn new_txid(&self) -> TxId {
        self.next_txid.fetch_add(1, Ordering::SeqCst)
    }

    /// Starts a transaction. It allocates the transaction's id, marks it
    /// `Active`, captures its snapshot, and registers it in the active set.
    pub fn begin(&self) -> Arc<Transaction<K, V>> {
        // Without this lock, a concurrent `begin` could allocate a smaller id
        // and be preempted before registering itself. Our snapshot would then
        // omit it from `xip` although its id is below `xmax`, and its later
        // commit would become visible to us mid-transaction. The mutex guards
        // no data, so a poisoned lock is safe to keep using.
        let _registration = self
            .begin_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let txid = self.new_txid();
        self.statuses.insert(txid, TransactionStatus::Active);
        let snapshot = self.create_snapshot(txid);
        let tx = Arc::new(Transaction::new(txid, snapshot));
        self.active_transactions.insert(txid, Arc::clone(&tx));
        tx
    }

    /// Commits `tx`.
    ///
    /// Under `IsolationLevel::Serializable` the steps are:
    /// 1. `tx` is aborted if it has already been flagged as conflicting.
    /// 2. Other active transactions whose range or prefix scans cover a key
    ///    `tx` inserted, or which are tracked as readers of a key `tx` wrote,
    ///    are flagged.
    /// 3. `tx`'s own flag is checked again.
    ///
    /// Then `on_pre_commit` runs. Only if it succeeds is `tx` marked
    /// `Committed` and removed from the active set.
    ///
    /// # Errors
    /// Returns `FluxError::SerializationConflict` if `tx` is flagged, or the
    /// error returned by `on_pre_commit`. In both cases `tx` is aborted.
    pub fn commit<F>(
        &self,
        tx: &Transaction<K, V>,
        on_pre_commit: F,
        isolation_level: IsolationLevel,
    ) -> Result<(), FluxError>
    where
        F: FnOnce() -> Result<(), FluxError>,
        K: Ord + std::borrow::Borrow<str>,
    {
        if isolation_level == IsolationLevel::Serializable {
            if tx.in_conflict.load(Ordering::SeqCst) {
                self.abort(tx);
                return Err(FluxError::SerializationConflict);
            }
            self.flag_phantom_readers(tx);
            self.flag_stale_readers(tx);
            // A concurrently committing transaction may have flagged `tx` while
            // we were flagging others (e.g. the two sides of a write skew).
            // Re-checking after our own flags are published, with SeqCst on
            // both the flag stores and these loads, means two transactions
            // that flag each other cannot both pass this point.
            if tx.in_conflict.load(Ordering::SeqCst) {
                self.abort(tx);
                return Err(FluxError::SerializationConflict);
            }
        }
        if let Err(e) = on_pre_commit() {
            self.abort(tx);
            return Err(e);
        }
        self.finish(tx, TransactionStatus::Committed);
        Ok(())
    }

    /// Flags every other active transaction whose recorded range scans
    /// (inclusive bounds) or prefix scans cover a key inserted by `tx`.
    fn flag_phantom_readers(&self, tx: &Transaction<K, V>)
    where
        K: Ord + std::borrow::Borrow<str>,
    {
        for inserted in tx.insert_set.iter() {
            let key: &K = inserted.key();
            let key_str: &str = <K as std::borrow::Borrow<str>>::borrow(key);
            for entry in self.active_transactions.iter() {
                let other = entry.value();
                // Skip ourselves, and anyone already flagged (flagging is idempotent).
                if other.id == tx.id || other.in_conflict.load(Ordering::SeqCst) {
                    continue;
                }
                let in_range = other
                    .range_scans
                    .read()
                    .unwrap()
                    .iter()
                    .any(|(start, end)| key >= start && key <= end);
                let covered = in_range
                    || other
                        .prefix_scans
                        .read()
                        .unwrap()
                        .iter()
                        .any(|prefix| key_str.starts_with(prefix.as_str()));
                if covered {
                    other.in_conflict.store(true, Ordering::SeqCst);
                }
            }
        }
    }

    /// Flags every other active transaction tracked in `read_trackers` as a
    /// reader of a key written by `tx`.
    fn flag_stale_readers(&self, tx: &Transaction<K, V>)
    where
        K: std::borrow::Borrow<str>,
    {
        for written in tx.write_set.iter() {
            let key_str: &str = <K as std::borrow::Borrow<str>>::borrow(written.key());
            if let Some(readers) = self.read_trackers.get(key_str) {
                for reader in readers.iter() {
                    let reader_id: TxId = *reader;
                    if reader_id == tx.id {
                        continue;
                    }
                    if let Some(reader_tx) = self.active_transactions.get(&reader_id) {
                        reader_tx.in_conflict.store(true, Ordering::SeqCst);
                    }
                }
            }
        }
    }

    /// Aborts `tx`. It unregisters its reads, marks it `Aborted`, and removes
    /// it from the active set.
    pub fn abort(&self, tx: &Transaction<K, V>) {
        self.finish(tx, TransactionStatus::Aborted);
    }

    /// Shared tail of `commit`/`abort`. It unregisters `tx`'s reads, records
    /// its final `status`, and then removes it from the active set.
    fn finish(&self, tx: &Transaction<K, V>, status: TransactionStatus) {
        self.cleanup_read_trackers(tx);
        self.statuses.insert(tx.id, status);
        self.active_transactions.remove(&tx.id);
    }

    /// Returns the ids of all transactions currently in the active set.
    pub fn get_active_txids(&self) -> HashSet<TxId> {
        self.active_transactions
            .iter()
            .map(|entry| *entry.key())
            .collect()
    }

    /// Returns the id that the next call to `new_txid` will hand out.
    pub fn get_current_txid(&self) -> TxId {
        self.next_txid.load(Ordering::SeqCst)
    }

    /// Returns the recorded status of `txid`. It returns `None` if the id was
    /// never registered or its status has been pruned.
    pub fn get_status(&self, txid: TxId) -> Option<TransactionStatus> {
        self.statuses.get(&txid).map(|s| *s)
    }

    /// Drops the status of every transaction id below `min_retainable_txid`.
    ///
    /// `Snapshot::is_visible` treats a missing status for an id below the
    /// snapshot's `xmin` as committed. NOTE(review): callers presumably must not
    /// raise `min_retainable_txid` past the oldest `xmin` still in use.
    pub fn prune_statuses(&self) {
        let min_txid = self.min_retainable_txid.load(Ordering::Acquire);
        self.statuses.retain(|&txid, _| txid >= min_txid);
    }

    /// Number of transaction statuses currently retained.
    pub fn statuses_len(&self) -> usize {
        self.statuses.len()
    }

    /// Removes `tx.id` from the reader set of every key in `tx.read_set`. It
    /// also drops reader sets that become empty.
    fn cleanup_read_trackers(&self, tx: &Transaction<K, V>) {
        use dashmap::mapref::entry::Entry;
        for read_key in tx.read_set.iter() {
            // The entry holds the shard's write lock, so no other thread can
            // reach this key's reader set between the emptiness check and removal.
            if let Entry::Occupied(entry) = self.read_trackers.entry(read_key.key().clone()) {
                entry.get().remove(&tx.id);
                if entry.get().is_empty() {
                    entry.remove();
                }
            }
        }
    }

    /// Builds the snapshot for the newly allocated transaction `txid`. The
    /// fields are set as follows:
    /// - `xmax` is the next id to be allocated.
    /// - `xip` is the current active set (which does not yet contain `txid`).
    /// - `xmin` is the smallest id in `xip`, or `xmax` when `xip` is empty.
    ///
    /// Called from `begin` while `begin_lock` is held.
    fn create_snapshot(&self, txid: TxId) -> Snapshot {
        let xmax = self.next_txid.load(Ordering::SeqCst);
        let active_txids = self.get_active_txids();
        let xmin = active_txids.iter().min().copied().unwrap_or(xmax);
        Snapshot {
            txid,
            xmin,
            xmax,
            xip: active_txids,
        }
    }
}
/// Point-in-time view that `Snapshot::is_visible` uses to decide which
/// versions its owning transaction may see.
#[derive(Debug, Clone)]
pub struct Snapshot {
    /// Id of the transaction that owns this snapshot.
    pub txid: TxId,
    /// Smallest id in `xip` (or `xmax` if `xip` is empty), as built by
    /// `TransactionManager::create_snapshot`. `is_visible` treats ids below it
    /// that have no recorded status as committed.
    pub xmin: TxId,
    /// The manager's next-id counter when the snapshot was taken. Versions
    /// created or expired by ids >= `xmax` are treated as not (yet) committed.
    pub xmax: TxId,
    /// Ids of transactions that were active when the snapshot was taken. Their
    /// effects are never visible through this snapshot.
    pub xip: HashSet<TxId>,
}
impl Snapshot {
pub fn is_visible<K, V, F>(
&self,
version: &Version<V>,
tx_manager: &TransactionManager<K, F>,
) -> bool
where
K: Eq + std::hash::Hash + Clone + Serialize + DeserializeOwned,
F: Clone + Serialize + DeserializeOwned,
{
let creator_visible = if version.creator_txid == self.txid {
false
} else if version.creator_txid >= self.xmax {
false
} else if self.xip.contains(&version.creator_txid) {
false
} else {
match tx_manager.get_status(version.creator_txid) {
Some(status) => status == TransactionStatus::Committed,
None => version.creator_txid < self.xmin, }
};
if !creator_visible {
return false;
}
let expirer_id = version.expirer_txid.load(Ordering::Acquire);
if expirer_id == 0 {
return true; }
if expirer_id == self.txid {
return false; }
let expirer_visible = if expirer_id >= self.xmax {
false
} else if self.xip.contains(&expirer_id) {
false
} else {
match tx_manager.get_status(expirer_id) {
Some(status) => status == TransactionStatus::Committed,
None => expirer_id < self.xmin, }
};
!expirer_visible
}
}
/// State of one transaction. It holds the transaction's snapshot, the keys it
/// touched, the scans it performed, and its buffered writes and savepoints.
#[derive(Debug)]
pub struct Transaction<K: Eq + std::hash::Hash, V> {
    /// Id allocated by the transaction manager.
    pub id: TxId,
    /// Visibility snapshot taken when the transaction began.
    pub snapshot: Snapshot,
    /// Keys recorded as read by this transaction. When it finishes, these keys
    /// are used to remove it from `TransactionManager::read_trackers`.
    /// NOTE(review): the `TxId` value is presumably the creator of the version
    /// that was read — confirm against the read path.
    pub read_set: DashMap<K, TxId>,
    /// Keys written by this transaction. At serializable commit, other
    /// transactions tracked as readers of these keys are flagged as conflicting.
    pub write_set: DashSet<K>,
    /// Keys inserted by this transaction. At serializable commit, they are
    /// checked against other active transactions' range and prefix scans.
    pub insert_set: DashSet<K>,
    /// `(start, end)` key ranges scanned by this transaction. The commit-time
    /// conflict check treats both bounds as inclusive.
    pub range_scans: RwLock<Vec<(K, K)>>,
    /// Key prefixes scanned by this transaction.
    pub prefix_scans: RwLock<Vec<String>>,
    /// Set by other committing transactions when they conflict with this one.
    /// A serializable commit of a flagged transaction fails with a
    /// serialization conflict.
    pub in_conflict: AtomicBool,
    /// This transaction's buffered writes.
    /// NOTE(review): presumably applied to storage by the commit hook — confirm.
    pub workspace: RwLock<Workspace<K, V>>,
    /// Named copies of the workspace.
    /// NOTE(review): presumably pushed when a savepoint is created and
    /// restored on rollback-to-savepoint — confirm.
    pub savepoints: RwLock<Vec<(String, Workspace<K, V>)>>,
    // Redundant marker: `V` already appears in `workspace` and `savepoints`.
    _phantom: std::marker::PhantomData<V>,
}
impl<K, V> Transaction<K, V>
where
K: Eq + std::hash::Hash + Serialize + DeserializeOwned,
V: Serialize + DeserializeOwned,
{
pub fn new(id: TxId, snapshot: Snapshot) -> Self {
Self {
id,
snapshot,
read_set: DashMap::new(),
write_set: DashSet::new(),
insert_set: DashSet::new(),
range_scans: RwLock::new(Vec::new()),
prefix_scans: RwLock::new(Vec::new()),
in_conflict: AtomicBool::new(false),
workspace: RwLock::new(HashMap::new()), savepoints: RwLock::new(Vec::new()),
_phantom: std::marker::PhantomData,
}
}
}
/// One stored version of a value, stamped with the transaction that created
/// it and, once set, the transaction that expired it.
#[derive(Debug, Serialize, Deserialize)]
pub struct Version<V> {
    /// The stored value.
    pub value: V,
    /// Id of the transaction that created this version.
    pub creator_txid: TxId,
    /// Id of the transaction that expired this version, or 0 if it has not
    /// been expired. It is atomic so it can be updated through a shared
    /// reference; `Snapshot::is_visible` reads it with `Acquire`.
    /// NOTE(review): presumably set by a delete or overwrite — confirm.
    pub expirer_txid: AtomicU64,
}