use super::lock::{LockManager, LockMode, LockResult};
use super::log::{TransactionLog, WalConfig};
use super::savepoint::TxnSavepoints;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::{Duration, Instant};
/// Identifier of a transaction; `TransactionManager` allocates these from an atomic counter.
pub type TxnId = u64;
/// Logical timestamp drawn from the manager's atomic clock (start, write and commit times).
pub type Timestamp = u64;
/// Builds the [`TxnError::Internal`] reported when the lock described by `context`
/// has been poisoned by a panicking holder.
fn tx_lock_error(context: &'static str) -> TxnError {
    let message = format!("{} lock poisoned", context);
    TxnError::Internal(message)
}
/// Takes a shared guard on `lock`, turning lock poisoning into a [`TxnError`]
/// that names `context` instead of recovering the guard.
fn read_guard_or_err<'a, T>(
    lock: &'a RwLock<T>,
    context: &'static str,
) -> Result<RwLockReadGuard<'a, T>, TxnError> {
    match lock.read() {
        Ok(guard) => Ok(guard),
        Err(_poisoned) => Err(tx_lock_error(context)),
    }
}
/// Takes an exclusive guard on `lock`, turning lock poisoning into a [`TxnError`]
/// that names `context` instead of recovering the guard.
fn write_guard_or_err<'a, T>(
    lock: &'a RwLock<T>,
    context: &'static str,
) -> Result<RwLockWriteGuard<'a, T>, TxnError> {
    match lock.write() {
        Ok(guard) => Ok(guard),
        Err(_poisoned) => Err(tx_lock_error(context)),
    }
}
/// Takes a shared guard on `lock`, ignoring poisoning: if a previous holder
/// panicked, the guard is recovered from the `PoisonError` and returned anyway.
fn recover_read_guard<'a, T>(lock: &'a RwLock<T>) -> RwLockReadGuard<'a, T> {
    lock.read().unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Takes an exclusive guard on `lock`, ignoring poisoning: if a previous holder
/// panicked, the guard is recovered from the `PoisonError` and returned anyway.
fn recover_write_guard<'a, T>(lock: &'a RwLock<T>) -> RwLockWriteGuard<'a, T> {
    lock.write().unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Errors returned by the transaction manager and the [`Transaction`] wrapper.
#[derive(Debug, Clone)]
pub enum TxnError {
    /// No transaction with this id is tracked (never begun, or removed by cleanup).
    NotFound(TxnId),
    /// The transaction has already committed.
    AlreadyCommitted(TxnId),
    /// The transaction has already aborted (also reported for other non-active states).
    AlreadyAborted(TxnId),
    /// Conflicting write on `key` held by `holder`.
    /// NOTE(review): not constructed anywhere in this file.
    WriteConflict { key: Vec<u8>, holder: TxnId },
    /// The lock manager reported a wait-for cycle involving these transactions.
    Deadlock(Vec<TxnId>),
    /// The per-transaction lock limit (`max_locks_per_txn` of the lock manager) was hit.
    LockLimitExceeded { limit: usize },
    /// Waiting for the lock on `key` exceeded the configured `timeout`.
    LockTimeout { key: Vec<u8>, timeout: Duration },
    /// Optimistic validation failed: `key` was read at `expected_ts` but a commit at
    /// `actual_ts` has since been published for it.
    ValidationFailed {
        key: Vec<u8>,
        expected_ts: Timestamp,
        actual_ts: Timestamp,
    },
    /// Write-ahead log failure. NOTE(review): not constructed anywhere in this file.
    LogError(String),
    /// `rollback_to_savepoint` was given an unknown savepoint name.
    SavepointNotFound(String),
    /// Transaction exceeded its lifetime. NOTE(review): not constructed anywhere in this file.
    Timeout(TxnId),
    /// Internal failure, e.g. a poisoned lock or an unexpected lock-manager result.
    Internal(String),
}
impl std::fmt::Display for TxnError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TxnError::NotFound(id) => write!(f, "Transaction {} not found", id),
TxnError::AlreadyCommitted(id) => write!(f, "Transaction {} already committed", id),
TxnError::AlreadyAborted(id) => write!(f, "Transaction {} already aborted", id),
TxnError::WriteConflict { key, holder } => {
write!(f, "Write conflict on {:?}, held by txn {}", key, holder)
}
TxnError::Deadlock(cycle) => write!(f, "Deadlock detected: {:?}", cycle),
TxnError::LockLimitExceeded { limit } => {
write!(f, "Lock limit exceeded: max {}", limit)
}
TxnError::LockTimeout { key, timeout } => {
write!(f, "Lock timeout on {:?} after {:?}", key, timeout)
}
TxnError::ValidationFailed {
key,
expected_ts,
actual_ts,
} => {
write!(
f,
"Validation failed for {:?}: expected ts {}, actual {}",
key, expected_ts, actual_ts
)
}
TxnError::LogError(msg) => write!(f, "WAL error: {}", msg),
TxnError::SavepointNotFound(name) => write!(f, "Savepoint '{}' not found", name),
TxnError::Timeout(id) => write!(f, "Transaction {} timed out", id),
TxnError::Internal(msg) => write!(f, "Internal error: {}", msg),
}
}
}
impl std::error::Error for TxnError {}
/// Lifecycle state of a transaction as tracked by [`TransactionManager`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TxnState {
    /// Accepting reads, writes, locks and savepoints.
    Active,
    /// Accepted by `commit`/`abort` like `Active`. NOTE(review): no code in this file
    /// transitions into this state — presumably reserved for two-phase commit; confirm.
    Preparing,
    /// Committed; `commit`/`abort` now report `AlreadyCommitted`.
    Committed,
    /// Aborted; `commit`/`abort` now report `AlreadyAborted`.
    Aborted,
}
/// Isolation level requested for a transaction.
///
/// The level is recorded on each [`TxnHandle`]; optimistic validation is controlled
/// separately by [`TxnConfig::optimistic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IsolationLevel {
    ReadUncommitted,
    ReadCommitted,
    /// The default level, also used by [`TxnConfig::new`].
    #[default]
    SnapshotIsolation,
    Serializable,
}
/// Tunables for a [`TransactionManager`].
#[derive(Debug, Clone)]
pub struct TxnConfig {
    /// Isolation level used by `TransactionManager::begin`.
    pub isolation_level: IsolationLevel,
    /// Maximum time `acquire_lock` waits inside the lock manager.
    pub lock_timeout: Duration,
    /// Maximum transaction lifetime.
    /// NOTE(review): not read by the manager in this file — confirm it is enforced elsewhere.
    pub txn_timeout: Duration,
    /// When `true`, commit validates the read set against later commits.
    pub optimistic: bool,
    /// Whether the manager opens a write-ahead log at construction.
    pub wal_enabled: bool,
    /// WAL flush policy applied on commit.
    pub wal_sync: WalSyncMode,
}
/// When the write-ahead log is flushed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WalSyncMode {
    /// Flush right after every commit record.
    EveryCommit,
    /// Flush on a timer. NOTE(review): no periodic flusher is visible in this file.
    Periodic(Duration),
    /// Never flush explicitly on commit.
    None,
}
impl TxnConfig {
    /// Returns the default configuration: snapshot isolation, 30 s lock timeout,
    /// 300 s transaction timeout, optimistic validation, WAL enabled and flushed on
    /// every commit.
    pub fn new() -> Self {
        Self {
            isolation_level: IsolationLevel::SnapshotIsolation,
            lock_timeout: Duration::from_secs(30),
            txn_timeout: Duration::from_secs(300),
            optimistic: true,
            wal_enabled: true,
            wal_sync: WalSyncMode::EveryCommit,
        }
    }
    /// Sets the isolation level used for transactions started with `begin`.
    #[must_use]
    pub fn with_isolation(mut self, level: IsolationLevel) -> Self {
        self.isolation_level = level;
        self
    }
    /// Sets how long lock acquisition may wait before timing out.
    #[must_use]
    pub fn with_lock_timeout(mut self, timeout: Duration) -> Self {
        self.lock_timeout = timeout;
        self
    }
    /// Enables or disables optimistic read-set validation at commit.
    #[must_use]
    pub fn with_optimistic(mut self, enabled: bool) -> Self {
        self.optimistic = enabled;
        self
    }
    /// Sets the maximum transaction lifetime stored in `txn_timeout`.
    #[must_use]
    pub fn with_txn_timeout(mut self, timeout: Duration) -> Self {
        self.txn_timeout = timeout;
        self
    }
    /// Enables or disables the write-ahead log.
    #[must_use]
    pub fn with_wal_enabled(mut self, enabled: bool) -> Self {
        self.wal_enabled = enabled;
        self
    }
    /// Sets the WAL flush policy applied on commit.
    #[must_use]
    pub fn with_wal_sync(mut self, mode: WalSyncMode) -> Self {
        self.wal_sync = mode;
        self
    }
}
impl Default for TxnConfig {
    /// Same as [`TxnConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Cloneable descriptor of a transaction returned by `TransactionManager::begin*`.
#[derive(Debug, Clone)]
pub struct TxnHandle {
    /// Identifier allocated from the manager's id counter.
    pub id: TxnId,
    /// Timestamp drawn from the manager's clock when the transaction began.
    pub start_ts: Timestamp,
    /// Isolation level requested when the transaction began.
    pub isolation: IsolationLevel,
}
impl TxnHandle {
    /// Returns the transaction identifier.
    #[must_use]
    pub fn id(&self) -> TxnId {
        let Self { id, .. } = self;
        *id
    }
    /// Returns the timestamp assigned when the transaction began.
    #[must_use]
    pub fn start_ts(&self) -> Timestamp {
        let Self { start_ts, .. } = self;
        *start_ts
    }
}
/// Per-transaction bookkeeping kept by [`TransactionManager`].
struct TransactionState {
    // Copy of the handle returned to the caller at `begin`.
    handle: TxnHandle,
    // Current lifecycle state.
    state: TxnState,
    // Wall-clock start time; `cleanup` measures age from here.
    start_time: Instant,
    // (key, timestamp observed by the read) pairs checked during optimistic validation.
    read_set: Vec<(Vec<u8>, Timestamp)>,
    // Writes in recording order; truncated by `rollback_to_savepoint`.
    write_set: Vec<WriteEntry>,
    // Named savepoints for this transaction.
    savepoints: TxnSavepoints,
    // Keys whose locks this transaction acquired through `acquire_lock`.
    locks_held: Vec<Vec<u8>>,
}
/// One recorded write. `None` for `old_value` means insert; `None` for `new_value`
/// means delete.
#[derive(Debug, Clone)]
struct WriteEntry {
    // Key that was written; published to `committed_ts` on commit.
    key: Vec<u8>,
    // Value before the write, if any.
    old_value: Option<Vec<u8>>,
    // Value after the write, if any.
    new_value: Option<Vec<u8>>,
    // Timestamp allocated when the write was recorded.
    // NOTE(review): old_value/new_value/timestamp are not read anywhere in this file.
    timestamp: Timestamp,
}
/// Owning convenience wrapper around a transaction id and its [`TransactionManager`].
///
/// Every method forwards to the manager; `commit` and `abort` consume the wrapper.
/// NOTE(review): there is no `Drop` impl, so dropping a wrapper without calling
/// `commit`/`abort` leaves the transaction `Active` in the manager.
pub struct Transaction {
    id: TxnId,
    coordinator: Arc<TransactionManager>,
}
impl Transaction {
    /// Returns the wrapped transaction id.
    pub fn id(&self) -> TxnId {
        self.id
    }
    /// Records that `key` was read at `read_ts` (ignored if the txn is not active).
    pub fn record_read(&self, key: &[u8], read_ts: Timestamp) {
        self.coordinator.record_read(self.id, key, read_ts);
    }
    /// Records a write of `key` from `old_value` to `new_value` (ignored if not active).
    pub fn record_write(&self, key: &[u8], old_value: Option<&[u8]>, new_value: Option<&[u8]>) {
        self.coordinator
            .record_write(self.id, key, old_value, new_value);
    }
    /// Creates a savepoint named `name`.
    ///
    /// # Errors
    /// Propagates the manager's error (unknown or non-active transaction, poisoned lock).
    pub fn savepoint(&self, name: &str) -> Result<(), TxnError> {
        self.coordinator.create_savepoint(self.id, name)
    }
    /// Discards writes recorded after savepoint `name`.
    ///
    /// # Errors
    /// Propagates the manager's error, including `SavepointNotFound`.
    pub fn rollback_to(&self, name: &str) -> Result<(), TxnError> {
        self.coordinator.rollback_to_savepoint(self.id, name)
    }
    /// Commits the transaction, consuming the wrapper.
    ///
    /// If the commit fails while the transaction is still `Active` (for example on
    /// `ValidationFailed`), the transaction is aborted on a best-effort basis. The
    /// wrapper is consumed, so the caller has no handle left to abort it; without
    /// this the transaction would stay active and keep its locks indefinitely.
    ///
    /// # Errors
    /// Returns the original commit error unchanged.
    pub fn commit(self) -> Result<(), TxnError> {
        let result = self.coordinator.commit(self.id);
        if result.is_err() && self.coordinator.is_active(self.id) {
            // Best-effort cleanup: the commit error is what the caller needs to see.
            let _ = self.coordinator.abort(self.id);
        }
        result
    }
    /// Aborts the transaction, consuming the wrapper.
    ///
    /// # Errors
    /// Propagates the manager's error (unknown or already-finished transaction).
    pub fn abort(self) -> Result<(), TxnError> {
        self.coordinator.abort(self.id)
    }
}
/// Central coordinator tracking every transaction's state, read/write sets,
/// savepoints and held locks, plus the last commit timestamp of each written key.
pub struct TransactionManager {
    // Configuration supplied at construction.
    config: TxnConfig,
    // Source of transaction ids; starts at 1.
    next_id: AtomicU64,
    // Logical clock for start, write and commit timestamps; starts at 1.
    current_ts: AtomicU64,
    // Bookkeeping keyed by id; finished entries stay until `cleanup` removes them.
    transactions: RwLock<HashMap<TxnId, TransactionState>>,
    // Key-level lock table used by `acquire_lock`.
    lock_manager: LockManager,
    // Write-ahead log; `None` when disabled or when opening it failed.
    log: Option<TransactionLog>,
    // Latest commit timestamp per written key, consulted by optimistic validation.
    committed_ts: RwLock<HashMap<Vec<u8>, Timestamp>>,
}
impl TransactionManager {
pub fn new(config: TxnConfig) -> Self {
let log = if config.wal_enabled {
Some(TransactionLog::new(WalConfig::default()))
} else {
None
};
Self {
config,
next_id: AtomicU64::new(1),
current_ts: AtomicU64::new(1),
transactions: RwLock::new(HashMap::new()),
lock_manager: LockManager::with_defaults(),
log: log.and_then(|r| r.ok()),
committed_ts: RwLock::new(HashMap::new()),
}
}
pub fn with_default_config() -> Self {
Self::new(TxnConfig::default())
}
pub fn config(&self) -> &TxnConfig {
&self.config
}
fn next_timestamp(&self) -> Timestamp {
self.current_ts.fetch_add(1, Ordering::SeqCst)
}
pub fn begin(&self) -> TxnHandle {
self.begin_with_isolation(self.config.isolation_level)
}
pub fn begin_with_isolation(&self, isolation: IsolationLevel) -> TxnHandle {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let start_ts = self.next_timestamp();
let handle = TxnHandle {
id,
start_ts,
isolation,
};
let state = TransactionState {
handle: handle.clone(),
state: TxnState::Active,
start_time: Instant::now(),
read_set: Vec::new(),
write_set: Vec::new(),
savepoints: TxnSavepoints::new(id),
locks_held: Vec::new(),
};
if let Some(ref log) = self.log {
let _ = log.log_begin(id);
}
recover_write_guard(&self.transactions).insert(id, state);
handle
}
pub fn begin_transaction(self: &Arc<Self>) -> Transaction {
let handle = self.begin();
Transaction {
id: handle.id,
coordinator: Arc::clone(self),
}
}
pub fn record_read(&self, txn_id: TxnId, key: &[u8], read_ts: Timestamp) {
let mut txns = recover_write_guard(&self.transactions);
if let Some(state) = txns.get_mut(&txn_id) {
if state.state == TxnState::Active {
state.read_set.push((key.to_vec(), read_ts));
}
}
}
pub fn record_write(
&self,
txn_id: TxnId,
key: &[u8],
old_value: Option<&[u8]>,
new_value: Option<&[u8]>,
) {
let timestamp = self.next_timestamp();
let mut txns = recover_write_guard(&self.transactions);
if let Some(state) = txns.get_mut(&txn_id) {
if state.state == TxnState::Active {
let entry = WriteEntry {
key: key.to_vec(),
old_value: old_value.map(|v| v.to_vec()),
new_value: new_value.map(|v| v.to_vec()),
timestamp,
};
if let Some(ref log) = self.log {
if let Some(old) = old_value {
if let Some(new) = new_value {
let _ =
log.log_update(txn_id, key.to_vec(), old.to_vec(), new.to_vec());
} else {
let _ = log.log_delete(txn_id, key.to_vec(), old.to_vec());
}
} else if let Some(new) = new_value {
let _ = log.log_insert(txn_id, key.to_vec(), new.to_vec());
}
}
state.write_set.push(entry);
}
}
}
pub fn acquire_lock(&self, txn_id: TxnId, key: &[u8], mode: LockMode) -> Result<(), TxnError> {
{
let txns = read_guard_or_err(&self.transactions, "transaction manager state")?;
let state = txns.get(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
if state.state != TxnState::Active {
return Err(TxnError::AlreadyAborted(txn_id));
}
}
match self
.lock_manager
.acquire_with_timeout(txn_id, key, mode, self.config.lock_timeout)
{
LockResult::Granted | LockResult::Upgraded | LockResult::AlreadyHeld => {
let mut txns = write_guard_or_err(&self.transactions, "transaction manager state")?;
if let Some(state) = txns.get_mut(&txn_id) {
if !state.locks_held.contains(&key.to_vec()) {
state.locks_held.push(key.to_vec());
}
}
Ok(())
}
LockResult::Waiting => {
Err(TxnError::Internal(
"Lock returned Waiting unexpectedly".to_string(),
))
}
LockResult::Timeout => Err(TxnError::LockTimeout {
key: key.to_vec(),
timeout: self.config.lock_timeout,
}),
LockResult::Deadlock(cycle) => Err(TxnError::Deadlock(cycle)),
LockResult::LockLimitExceeded => Err(TxnError::LockLimitExceeded {
limit: self.lock_manager.config().max_locks_per_txn,
}),
LockResult::TxnNotFound => Err(TxnError::NotFound(txn_id)),
}
}
fn release_locks(&self, txn_id: TxnId) {
let locks = {
let txns = recover_read_guard(&self.transactions);
txns.get(&txn_id)
.map(|s| s.locks_held.clone())
.unwrap_or_default()
};
for key in locks {
self.lock_manager.release(txn_id, &key);
}
}
fn validate(&self, txn_id: TxnId) -> Result<(), TxnError> {
let txns = read_guard_or_err(&self.transactions, "transaction manager state")?;
let state = txns.get(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
if !self.config.optimistic {
return Ok(());
}
let committed = read_guard_or_err(&self.committed_ts, "transaction manager committed_ts")?;
for (key, read_ts) in &state.read_set {
if let Some(&commit_ts) = committed.get(key) {
if commit_ts > *read_ts && commit_ts > state.handle.start_ts {
return Err(TxnError::ValidationFailed {
key: key.clone(),
expected_ts: *read_ts,
actual_ts: commit_ts,
});
}
}
}
Ok(())
}
pub fn commit(&self, txn_id: TxnId) -> Result<(), TxnError> {
self.validate(txn_id)?;
let commit_ts = self.next_timestamp();
{
let mut txns = write_guard_or_err(&self.transactions, "transaction manager state")?;
let state = txns.get_mut(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
match state.state {
TxnState::Active | TxnState::Preparing => {
state.state = TxnState::Committed;
}
TxnState::Committed => return Err(TxnError::AlreadyCommitted(txn_id)),
TxnState::Aborted => return Err(TxnError::AlreadyAborted(txn_id)),
}
let mut committed =
write_guard_or_err(&self.committed_ts, "transaction manager committed_ts")?;
for entry in &state.write_set {
committed.insert(entry.key.clone(), commit_ts);
}
}
if let Some(ref log) = self.log {
let _ = log.log_commit(txn_id);
if matches!(self.config.wal_sync, WalSyncMode::EveryCommit) {
let _ = log.flush();
}
}
self.release_locks(txn_id);
Ok(())
}
pub fn abort(&self, txn_id: TxnId) -> Result<(), TxnError> {
{
let mut txns = write_guard_or_err(&self.transactions, "transaction manager state")?;
let state = txns.get_mut(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
match state.state {
TxnState::Active | TxnState::Preparing => {
state.state = TxnState::Aborted;
}
TxnState::Committed => return Err(TxnError::AlreadyCommitted(txn_id)),
TxnState::Aborted => return Err(TxnError::AlreadyAborted(txn_id)),
}
}
if let Some(ref log) = self.log {
let _ = log.log_abort(txn_id);
}
self.release_locks(txn_id);
Ok(())
}
pub fn create_savepoint(&self, txn_id: TxnId, name: &str) -> Result<(), TxnError> {
let mut txns = write_guard_or_err(&self.transactions, "transaction manager state")?;
let state = txns.get_mut(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
if state.state != TxnState::Active {
return Err(TxnError::AlreadyAborted(txn_id));
}
let write_set_index = state.write_set.len();
let lock_count = state.locks_held.len();
state
.savepoints
.create(name.to_string(), 0, lock_count, write_set_index);
Ok(())
}
pub fn rollback_to_savepoint(&self, txn_id: TxnId, name: &str) -> Result<(), TxnError> {
let mut txns = write_guard_or_err(&self.transactions, "transaction manager state")?;
let state = txns.get_mut(&txn_id).ok_or(TxnError::NotFound(txn_id))?;
if state.state != TxnState::Active {
return Err(TxnError::AlreadyAborted(txn_id));
}
let savepoint = state
.savepoints
.get(name)
.ok_or_else(|| TxnError::SavepointNotFound(name.to_string()))?;
state.write_set.truncate(savepoint.write_set_index);
state.savepoints.release(name);
Ok(())
}
pub fn get_state(&self, txn_id: TxnId) -> Option<TxnState> {
self.transactions
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.get(&txn_id)
.map(|s| s.state)
}
pub fn is_active(&self, txn_id: TxnId) -> bool {
self.get_state(txn_id) == Some(TxnState::Active)
}
pub fn active_count(&self) -> usize {
self.transactions
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.values()
.filter(|s| s.state == TxnState::Active)
.count()
}
pub fn oldest_active_ts(&self) -> Option<Timestamp> {
self.transactions
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.values()
.filter(|s| s.state == TxnState::Active)
.map(|s| s.handle.start_ts)
.min()
}
pub fn cleanup(&self, max_age: Duration) {
let mut txns = recover_write_guard(&self.transactions);
let now = Instant::now();
txns.retain(|_, state| {
if state.state == TxnState::Active {
true
} else {
now.duration_since(state.start_time) < max_age
}
});
}
}
impl Default for TransactionManager {
fn default() -> Self {
Self::with_default_config()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager with the default configuration, shared setup for every test.
    fn fresh_manager() -> TransactionManager {
        TransactionManager::with_default_config()
    }

    #[test]
    fn test_begin_commit() {
        let manager = fresh_manager();
        let id = manager.begin().id;
        assert!(manager.is_active(id));
        manager.commit(id).unwrap();
        assert!(!manager.is_active(id));
        assert_eq!(manager.get_state(id), Some(TxnState::Committed));
    }

    #[test]
    fn test_begin_abort() {
        let manager = fresh_manager();
        let id = manager.begin().id;
        assert!(manager.is_active(id));
        manager.abort(id).unwrap();
        assert!(!manager.is_active(id));
        assert_eq!(manager.get_state(id), Some(TxnState::Aborted));
    }

    #[test]
    fn test_double_commit() {
        let manager = fresh_manager();
        let id = manager.begin().id;
        manager.commit(id).unwrap();
        let second_attempt = manager.commit(id);
        assert!(matches!(second_attempt, Err(TxnError::AlreadyCommitted(_))));
    }

    #[test]
    fn test_transaction_wrapper() {
        let manager = Arc::new(fresh_manager());
        let txn = manager.begin_transaction();
        let id = txn.id();
        txn.record_write(b"key1", None, Some(b"value1"));
        txn.commit().unwrap();
        assert!(!manager.is_active(id));
    }

    #[test]
    fn test_savepoints() {
        let manager = fresh_manager();
        let id = manager.begin().id;
        manager.record_write(id, b"key1", None, Some(b"v1"));
        manager.create_savepoint(id, "sp1").unwrap();
        manager.record_write(id, b"key2", None, Some(b"v2"));
        manager.record_write(id, b"key3", None, Some(b"v3"));
        manager.rollback_to_savepoint(id, "sp1").unwrap();
        manager.commit(id).unwrap();
    }

    #[test]
    fn test_isolation_levels() {
        let manager = fresh_manager();
        let read_committed = manager.begin_with_isolation(IsolationLevel::ReadCommitted);
        let snapshot = manager.begin_with_isolation(IsolationLevel::SnapshotIsolation);
        assert_eq!(read_committed.isolation, IsolationLevel::ReadCommitted);
        assert_eq!(snapshot.isolation, IsolationLevel::SnapshotIsolation);
        manager.abort(read_committed.id).unwrap();
        manager.abort(snapshot.id).unwrap();
    }

    #[test]
    fn test_active_count() {
        let manager = fresh_manager();
        assert_eq!(manager.active_count(), 0);
        let first = manager.begin().id;
        let second = manager.begin().id;
        assert_eq!(manager.active_count(), 2);
        manager.commit(first).unwrap();
        assert_eq!(manager.active_count(), 1);
        manager.abort(second).unwrap();
        assert_eq!(manager.active_count(), 0);
    }

    #[test]
    fn test_oldest_active_ts() {
        let manager = fresh_manager();
        let earliest = manager.begin().start_ts;
        let _later = manager.begin();
        assert_eq!(manager.oldest_active_ts(), Some(earliest));
    }

    #[test]
    fn test_config() {
        let config = TxnConfig::new()
            .with_isolation(IsolationLevel::Serializable)
            .with_lock_timeout(Duration::from_secs(10))
            .with_optimistic(false);
        assert_eq!(config.isolation_level, IsolationLevel::Serializable);
        assert_eq!(config.lock_timeout, Duration::from_secs(10));
        assert!(!config.optimistic);
    }
}