use crate::error::{KernelError, KernelResult, TransactionErrorKind};
use crate::wal::LogSequenceNumber;
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Identifier handed out to each transaction by `TxnManager::begin` and
/// `TxnManager::begin_with_isolation`.
pub type TransactionId = u64;
/// Logical timestamp drawn from the manager's `current_ts` counter; used for
/// both snapshot and commit timestamps.
pub type Timestamp = u64;
/// Isolation level requested for a transaction.
///
/// Within this file only [`IsolationLevel::Serializable`] changes behaviour:
/// reads are tracked by `record_read` and `commit` invokes the serialization
/// conflict check. The other levels are stored on the transaction but not
/// otherwise interpreted here.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum IsolationLevel {
    ReadUncommitted,
    ReadCommitted,
    RepeatableRead,
    /// Level used by `TxnManager::begin` (the `Default` value).
    #[default]
    SnapshotIsolation,
    /// Enables read-set tracking and the commit-time conflict check.
    Serializable,
}
/// Lifecycle state of a transaction tracked by `TxnManager`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransactionState {
    /// State assigned when the transaction is begun.
    Active,
    /// Never assigned within this file; `commit` and `abort` accept it the
    /// same way they accept `Active`.
    Preparing,
    /// Set by a successful `commit`.
    Committed,
    /// Set by `abort`.
    Aborted,
}
/// Per-transaction bookkeeping stored in the manager's transaction table.
#[derive(Debug)]
struct TransactionInfo {
    /// Identifier under which this entry is stored in the table.
    id: TransactionId,
    /// Current lifecycle state.
    state: TransactionState,
    /// Value of the manager's logical clock captured when the transaction began.
    snapshot_ts: Timestamp,
    /// Commit timestamp; `None` until `commit` succeeds.
    commit_ts: Option<Timestamp>,
    /// Isolation level the transaction was started with.
    isolation: IsolationLevel,
    /// Wall-clock instant at which the transaction began; drives timeout
    /// detection and cleanup ageing.
    start_time: Instant,
    /// Most recent LSN reported through `set_last_lsn`, if any.
    last_lsn: Option<LogSequenceNumber>,
    /// `(table_id, row_id)` pairs read; populated only for serializable transactions.
    read_set: Vec<(u32, u64)>,
    /// `(table_id, row_id)` pairs written.
    write_set: Vec<(u32, u64)>,
}
/// Allocates transaction ids and timestamps and tracks every known transaction.
pub struct TxnManager {
    /// Next transaction id to hand out.
    next_txn_id: AtomicU64,
    /// Logical clock: read to obtain snapshot timestamps, incremented once per commit.
    current_ts: AtomicU64,
    /// Transaction table. Despite the name it also holds committed and aborted
    /// entries until `cleanup` removes them.
    active_txns: RwLock<HashMap<TransactionId, TransactionInfo>>,
    /// Age beyond which `check_timeouts` reports an active transaction.
    timeout: Duration,
    /// Held for the whole of `commit`, so commits run one at a time.
    commit_lock: Mutex<()>,
}
impl Default for TxnManager {
fn default() -> Self {
Self::new()
}
}
impl TxnManager {
/// Creates a manager whose transaction timeout is 60 seconds.
pub fn new() -> Self {
    const DEFAULT_TIMEOUT_SECS: u64 = 60;
    Self::with_timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
}
/// Creates a manager with the given transaction `timeout`.
///
/// Both the transaction-id counter and the logical clock start at 1, and the
/// transaction table starts empty.
pub fn with_timeout(timeout: Duration) -> Self {
    let next_txn_id = AtomicU64::new(1);
    let current_ts = AtomicU64::new(1);
    let active_txns = RwLock::new(HashMap::new());
    let commit_lock = Mutex::new(());
    Self {
        next_txn_id,
        current_ts,
        active_txns,
        timeout,
        commit_lock,
    }
}
/// Starts a transaction at the default isolation level and returns its id.
pub fn begin(&self) -> TransactionId {
    let isolation = IsolationLevel::default();
    self.begin_with_isolation(isolation)
}
/// Starts a transaction at the given isolation level and returns its id.
///
/// The new transaction is registered as `Active` with a snapshot timestamp
/// equal to the current value of the logical clock.
pub fn begin_with_isolation(&self, isolation: IsolationLevel) -> TransactionId {
    let txn_id = self.next_txn_id.fetch_add(1, Ordering::SeqCst);

    // Read the clock and register the transaction under the same write lock.
    // If the snapshot were captured before taking the lock, a concurrent
    // `min_active_snapshot()` could run in between and report a horizon that
    // does not account for this transaction's (older) snapshot.
    let mut txns = self.active_txns.write();
    let snapshot_ts = self.current_ts.load(Ordering::SeqCst);
    txns.insert(
        txn_id,
        TransactionInfo {
            id: txn_id,
            state: TransactionState::Active,
            snapshot_ts,
            commit_ts: None,
            isolation,
            start_time: Instant::now(),
            last_lsn: None,
            read_set: Vec::new(),
            write_set: Vec::new(),
        },
    );
    txn_id
}
/// Commits `txn_id` and returns its commit timestamp.
///
/// Commits are serialized through `commit_lock`. For serializable
/// transactions the conflict check runs first; if it fails, the error is
/// returned and the transaction's state is left unchanged. The returned
/// timestamp is the value of the logical clock *before* it is advanced.
///
/// # Errors
///
/// * `NotFound` if the transaction is not in the table.
/// * `AlreadyCommitted` / `AlreadyAborted` if it has already finished.
/// * Whatever the serialization conflict check reports.
pub fn commit(&self, txn_id: TransactionId) -> KernelResult<Timestamp> {
    let _guard = self.commit_lock.lock();
    let mut txns = self.active_txns.write();

    // A single mutable lookup: the read/write sets are borrowed for the
    // conflict check rather than cloned, and no second `get_mut` is needed.
    let info = txns.get_mut(&txn_id).ok_or(KernelError::Transaction {
        kind: TransactionErrorKind::NotFound(txn_id),
    })?;

    match info.state {
        TransactionState::Active | TransactionState::Preparing => {
            if info.isolation == IsolationLevel::Serializable {
                self.check_serialization_conflicts_cloned(&info.read_set, &info.write_set)?;
            }
            // `fetch_add` returns the previous clock value.
            let commit_ts = self.current_ts.fetch_add(1, Ordering::SeqCst);
            info.commit_ts = Some(commit_ts);
            info.state = TransactionState::Committed;
            Ok(commit_ts)
        }
        TransactionState::Committed => Err(KernelError::Transaction {
            kind: TransactionErrorKind::AlreadyCommitted,
        }),
        TransactionState::Aborted => Err(KernelError::Transaction {
            kind: TransactionErrorKind::AlreadyAborted,
        }),
    }
}
/// Marks `txn_id` as aborted.
///
/// Aborting an already-aborted transaction is a no-op that succeeds.
///
/// # Errors
///
/// * `NotFound` if the transaction is not in the table.
/// * `AlreadyCommitted` if it has already committed.
pub fn abort(&self, txn_id: TransactionId) -> KernelResult<()> {
    let mut table = self.active_txns.write();
    let Some(entry) = table.get_mut(&txn_id) else {
        return Err(KernelError::Transaction {
            kind: TransactionErrorKind::NotFound(txn_id),
        });
    };

    match entry.state {
        TransactionState::Committed => Err(KernelError::Transaction {
            kind: TransactionErrorKind::AlreadyCommitted,
        }),
        // Idempotent: nothing left to do.
        TransactionState::Aborted => Ok(()),
        TransactionState::Active | TransactionState::Preparing => {
            entry.state = TransactionState::Aborted;
            Ok(())
        }
    }
}
/// Returns `true` only if `txn_id` is known and its state is `Active`.
///
/// Unknown ids and transactions in any other state (including `Preparing`)
/// yield `false`.
pub fn is_active(&self, txn_id: TransactionId) -> bool {
    let table = self.active_txns.read();
    matches!(
        table.get(&txn_id),
        Some(entry) if entry.state == TransactionState::Active
    )
}
/// Returns the snapshot timestamp recorded for `txn_id`, regardless of the
/// transaction's current state.
///
/// # Errors
///
/// `NotFound` if the transaction is not in the table.
pub fn snapshot_ts(&self, txn_id: TransactionId) -> KernelResult<Timestamp> {
    let table = self.active_txns.read();
    match table.get(&txn_id) {
        Some(entry) => Ok(entry.snapshot_ts),
        None => Err(KernelError::Transaction {
            kind: TransactionErrorKind::NotFound(txn_id),
        }),
    }
}
/// Appends `(table_id, row_id)` to the read set of `txn_id`.
///
/// Reads are tracked only for serializable transactions; for any other
/// isolation level, or an unknown id, the call does nothing.
pub fn record_read(&self, txn_id: TransactionId, table_id: u32, row_id: u64) {
    let mut table = self.active_txns.write();
    let Some(entry) = table.get_mut(&txn_id) else {
        return;
    };
    if entry.isolation == IsolationLevel::Serializable {
        entry.read_set.push((table_id, row_id));
    }
}
/// Appends `(table_id, row_id)` to the write set of `txn_id`.
///
/// Unknown ids are silently ignored.
pub fn record_write(&self, txn_id: TransactionId, table_id: u32, row_id: u64) {
    let mut table = self.active_txns.write();
    let Some(entry) = table.get_mut(&txn_id) else {
        return;
    };
    entry.write_set.push((table_id, row_id));
}
/// Records `lsn` as the latest log sequence number for `txn_id`.
///
/// Unknown ids are silently ignored.
pub fn set_last_lsn(&self, txn_id: TransactionId, lsn: LogSequenceNumber) {
    let mut table = self.active_txns.write();
    if let Some(entry) = table.get_mut(&txn_id) {
        entry.last_lsn = Some(lsn);
    }
}
/// Returns the smallest snapshot timestamp among transactions whose state is
/// `Active`, or `None` when there are none.
pub fn min_active_snapshot(&self) -> Option<Timestamp> {
    let table = self.active_txns.read();
    table
        .values()
        .filter_map(|entry| {
            (entry.state == TransactionState::Active).then_some(entry.snapshot_ts)
        })
        .min()
}
/// Returns how many tracked transactions are currently in the `Active` state.
pub fn active_count(&self) -> usize {
    let table = self.active_txns.read();
    let mut live = 0;
    for entry in table.values() {
        if entry.state == TransactionState::Active {
            live += 1;
        }
    }
    live
}
/// Removes finished transactions from the table once they are old enough.
///
/// Entries that are still in flight — `Active` or `Preparing` — are always
/// kept, because `commit` and `abort` can still act on them. `Committed` and
/// `Aborted` entries are dropped once at least `retention` has elapsed since
/// the transaction *started* (no completion time is recorded, so age is
/// measured from `start_time`).
pub fn cleanup(&self, retention: Duration) {
    let now = Instant::now();
    self.active_txns.write().retain(|_, info| match info.state {
        // Evicting a preparing transaction would make its later commit fail
        // with `NotFound`, so treat it like an active one.
        TransactionState::Active | TransactionState::Preparing => true,
        TransactionState::Committed | TransactionState::Aborted => {
            now.duration_since(info.start_time) < retention
        }
    });
}
/// Returns the ids of `Active` transactions that have been running for
/// strictly longer than the configured timeout.
///
/// This only reports them; no transaction state is changed.
pub fn check_timeouts(&self) -> Vec<TransactionId> {
    let now = Instant::now();
    let table = self.active_txns.read();
    let mut expired = Vec::new();
    for entry in table.values() {
        if entry.state != TransactionState::Active {
            continue;
        }
        if now.duration_since(entry.start_time) > self.timeout {
            expired.push(entry.id);
        }
    }
    expired
}
/// Placeholder for a serialization conflict check that inspects the whole
/// transaction table. It currently performs no validation and always
/// returns `Ok(())`.
// No caller is visible in this file, hence the dead-code exemption.
#[allow(dead_code)]
fn check_serialization_conflicts(
    &self,
    _txn: &TransactionInfo,
    _all_txns: &HashMap<TransactionId, TransactionInfo>,
) -> KernelResult<()> {
    Ok(())
}
/// Serialization conflict check invoked by `commit` for serializable
/// transactions.
///
/// Currently a placeholder: both sets are ignored and the check always
/// succeeds, so no conflict is ever reported.
fn check_serialization_conflicts_cloned(
    &self,
    _read_set: &[(u32, u64)],
    _write_set: &[(u32, u64)],
) -> KernelResult<()> {
    Ok(())
}
/// Returns the current value of the logical clock without advancing it.
pub fn current_timestamp(&self) -> Timestamp {
    let clock = &self.current_ts;
    clock.load(Ordering::SeqCst)
}
pub fn restore(&self, next_txn_id: TransactionId, current_ts: Timestamp) {
self.next_txn_id.store(next_txn_id, Ordering::SeqCst);
self.current_ts.store(current_ts, Ordering::SeqCst);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A committed transaction stops being active and receives a non-zero timestamp.
    #[test]
    fn test_begin_commit() {
        let manager = TxnManager::new();
        let txn = manager.begin();
        assert!(manager.is_active(txn));

        let commit_ts = manager.commit(txn).unwrap();
        assert!(!manager.is_active(txn));
        assert!(commit_ts > 0);
    }

    /// An aborted transaction stops being active.
    #[test]
    fn test_begin_abort() {
        let manager = TxnManager::new();
        let txn = manager.begin();
        assert!(manager.is_active(txn));

        manager.abort(txn).unwrap();
        assert!(!manager.is_active(txn));
    }

    /// Snapshot timestamps never go backwards across a commit.
    #[test]
    fn test_snapshot_isolation() {
        let manager = TxnManager::new();

        let earlier = manager.begin();
        let earlier_ts = manager.snapshot_ts(earlier).unwrap();
        manager.commit(earlier).unwrap();

        let later = manager.begin();
        let later_ts = manager.snapshot_ts(later).unwrap();
        assert!(later_ts >= earlier_ts);
    }

    /// Committing the same transaction twice is rejected.
    #[test]
    fn test_double_commit_fails() {
        let manager = TxnManager::new();
        let txn = manager.begin();
        manager.commit(txn).unwrap();
        assert!(manager.commit(txn).is_err());
    }

    /// The minimum snapshot follows the remaining active transactions.
    #[test]
    fn test_min_active_snapshot() {
        let manager = TxnManager::new();
        let first = manager.begin();
        let second = manager.begin();

        let oldest = manager.min_active_snapshot().unwrap();
        assert_eq!(oldest, manager.snapshot_ts(first).unwrap());

        manager.commit(first).unwrap();
        let oldest = manager.min_active_snapshot().unwrap();
        assert_eq!(oldest, manager.snapshot_ts(second).unwrap());
    }
}