use std::{
collections::{HashMap, HashSet},
sync::atomic::{AtomicU64, Ordering},
time::{Duration, Instant},
};
use dashmap::DashMap;
use parking_lot::RwLock;
use tracing::{debug, instrument, warn};
use crate::{SlabColumnValue, SlabRowId, Value};
/// Lifecycle phase of a [`Transaction`].
///
/// `Transaction::new` starts every transaction in `Active`, and `Active` is the
/// only phase that `is_active` reports as active. Other phases are assigned by
/// callers (e.g. through `TransactionManager::set_phase`, which does not check
/// that transitions happen in any particular order).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TxPhase {
    /// Initial phase of every new transaction.
    Active,
    Committing,
    Committed,
    Aborting,
    Aborted,
}
/// Before/after value of one indexed column, as carried by
/// `UndoEntry::UpdatedRow::index_changes`.
///
/// NOTE(review): presumably used during rollback to move the index entry for
/// `column` from `new_value` back to `old_value` — confirm against the caller.
#[derive(Debug, Clone)]
pub(crate) struct IndexChange {
    /// Name of the indexed column whose value changed.
    pub column: String,
    /// Value of the column before the update.
    pub old_value: Value,
    /// Value of the column after the update.
    pub new_value: Value,
}
/// One entry in a transaction's undo log, describing a single row mutation.
///
/// Every variant identifies the affected row by table name, slab slot
/// (`slab_row_id`) and logical `row_id`; `Transaction::record_undo` uses the
/// `table` field to maintain `affected_tables`.
#[derive(Debug, Clone)]
// All variants intentionally share the `...Row` suffix.
#[allow(clippy::enum_variant_names)]
pub(crate) enum UndoEntry {
    /// A row was inserted.
    InsertedRow {
        table: String,
        slab_row_id: SlabRowId,
        row_id: u64,
        /// Index entries added for the new row — presumably `(column, value)`
        /// pairs; TODO confirm against the writer.
        index_entries: Vec<(String, Value)>,
    },
    /// A row was updated in place.
    UpdatedRow {
        table: String,
        slab_row_id: SlabRowId,
        row_id: u64,
        /// Column values as they were before the update.
        old_values: Vec<SlabColumnValue>,
        /// Indexed columns whose value changed in this update.
        index_changes: Vec<IndexChange>,
    },
    /// A row was deleted.
    DeletedRow {
        table: String,
        slab_row_id: SlabRowId,
        row_id: u64,
        /// Column values of the deleted row.
        old_values: Vec<SlabColumnValue>,
        /// Index entries the deleted row had — presumably `(column, value)`
        /// pairs; TODO confirm against the writer.
        index_entries: Vec<(String, Value)>,
    },
}
/// State of a single transaction: identity, phase, timing, and the undo log
/// recorded while it runs.
#[derive(Debug)]
pub struct Transaction {
    /// Transaction id (`TransactionManager::begin` draws it from `TX_COUNTER`).
    pub tx_id: u64,
    /// Current lifecycle phase; `Transaction::new` starts it as `TxPhase::Active`.
    pub phase: TxPhase,
    /// Wall-clock start time, in milliseconds since the Unix epoch.
    pub started_at_ms: u64,
    /// Milliseconds after `started_at_ms` beyond which `is_expired` returns `true`.
    pub timeout_ms: u64,
    /// Undo entries in the order they were recorded.
    pub(crate) undo_log: Vec<UndoEntry>,
    /// Names of every table referenced by a recorded undo entry.
    pub affected_tables: HashSet<String>,
}
impl Transaction {
    /// Creates a transaction in the [`TxPhase::Active`] phase.
    ///
    /// It is stamped with the current wall-clock time and starts with an empty
    /// undo log and no affected tables.
    #[must_use]
    pub fn new(tx_id: u64, timeout_ms: u64) -> Self {
        let started_at_ms = now_epoch_millis();
        Self {
            tx_id,
            phase: TxPhase::Active,
            started_at_ms,
            timeout_ms,
            undo_log: Vec::new(),
            affected_tables: HashSet::new(),
        }
    }

    /// Returns `true` only while the phase is [`TxPhase::Active`].
    #[must_use]
    pub fn is_active(&self) -> bool {
        matches!(self.phase, TxPhase::Active)
    }

    /// Returns `true` once strictly more than `timeout_ms` milliseconds of
    /// wall-clock time have elapsed since `started_at_ms`.
    ///
    /// A clock that went backwards counts as zero elapsed time.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        let elapsed_ms = now_epoch_millis().saturating_sub(self.started_at_ms);
        elapsed_ms > self.timeout_ms
    }

    /// Appends `entry` to the undo log and records its table in `affected_tables`.
    pub(crate) fn record_undo(&mut self, entry: UndoEntry) {
        let touched_table = match &entry {
            UndoEntry::InsertedRow { table, .. }
            | UndoEntry::UpdatedRow { table, .. }
            | UndoEntry::DeletedRow { table, .. } => table.clone(),
        };
        self.affected_tables.insert(touched_table);
        self.undo_log.push(entry);
    }
}
/// A lock on one `(table, row_id)` pair, held by a transaction, that expires
/// after a wall-clock timeout.
#[derive(Debug, Clone)]
pub(crate) struct RowLock {
    /// Table containing the locked row.
    pub table: String,
    /// Id of the locked row within `table`.
    pub row_id: u64,
    /// Transaction that holds the lock.
    pub tx_id: u64,
    /// Acquisition time, in milliseconds since the Unix epoch.
    pub acquired_at_ms: u64,
    /// Milliseconds after `acquired_at_ms` beyond which the lock counts as expired.
    pub timeout_ms: u64,
}
impl RowLock {
    /// Returns `true` once strictly more than `timeout_ms` milliseconds of
    /// wall-clock time have passed since the lock was acquired.
    ///
    /// A clock that went backwards counts as zero time held.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        let held_for_ms = now_epoch_millis().saturating_sub(self.acquired_at_ms);
        held_for_ms > self.timeout_ms
    }
}
/// An optional point in monotonic time (`Instant`) after which an operation
/// should give up. `None` inside means "no deadline".
#[derive(Debug, Clone, Copy)]
pub(crate) struct Deadline {
    deadline: Option<Instant>,
}

impl Deadline {
    /// Creates a deadline `timeout_ms` milliseconds from now. With `None`, the
    /// deadline never expires.
    ///
    /// If the target instant cannot be represented, the deadline is treated as
    /// "never" instead of panicking. Plain `Instant + Duration` panics on overflow.
    #[must_use]
    pub fn from_timeout_ms(timeout_ms: Option<u64>) -> Self {
        Self {
            deadline: timeout_ms
                .and_then(|ms| Instant::now().checked_add(Duration::from_millis(ms))),
        }
    }

    /// A deadline that never expires.
    #[must_use]
    pub const fn never() -> Self {
        Self { deadline: None }
    }

    /// Returns `true` once the current instant has reached the deadline.
    /// Always `false` for [`Deadline::never`].
    #[must_use]
    #[allow(clippy::trivially_copy_pass_by_ref)] // keep `&self` for call-site ergonomics
    pub fn is_expired(&self) -> bool {
        self.deadline.is_some_and(|d| Instant::now() >= d)
    }

    /// Milliseconds left until the deadline, or `None` if there is no deadline.
    ///
    /// Returns `Some(0)` once the deadline has passed. The value saturates at
    /// `u64::MAX` rather than truncating.
    #[must_use]
    #[allow(dead_code)]
    #[allow(clippy::trivially_copy_pass_by_ref)] // keep `&self` for call-site ergonomics
    pub fn remaining_ms(&self) -> Option<u64> {
        self.deadline.map(|d| {
            let left = d.saturating_duration_since(Instant::now()).as_millis();
            u64::try_from(left).unwrap_or(u64::MAX)
        })
    }
}

impl Default for Deadline {
    /// Same as [`Deadline::never`].
    fn default() -> Self {
        Self::never()
    }
}
/// In-memory table of row locks keyed by `(table, row_id)`, with a
/// per-transaction index of the keys each transaction has locked.
#[derive(Debug)]
pub(crate) struct RowLockManager {
    /// Current lock for each `(table, row_id)` key.
    locks: RwLock<HashMap<(String, u64), RowLock>>,
    /// Per-transaction list of lock keys, used by `release` and `locks_held_by`.
    tx_locks: RwLock<HashMap<u64, Vec<(String, u64)>>>,
    /// Timeout given to every lock acquired through `try_lock`.
    pub default_timeout: Duration,
}
/// Equivalent to [`RowLockManager::new`] (30-second lock timeout).
impl Default for RowLockManager {
    fn default() -> Self {
        Self::new()
    }
}
impl RowLockManager {
    /// Creates an empty lock manager whose locks expire after 30 seconds.
    #[must_use]
    #[instrument(skip_all)]
    pub fn new() -> Self {
        Self {
            locks: RwLock::new(HashMap::new()),
            tx_locks: RwLock::new(HashMap::new()),
            default_timeout: Duration::from_secs(30),
        }
    }

    /// Creates an empty lock manager whose locks expire after `timeout`.
    #[must_use]
    #[instrument(skip_all, fields(timeout_secs = timeout.as_secs()))]
    pub fn with_default_timeout(timeout: Duration) -> Self {
        Self {
            locks: RwLock::new(HashMap::new()),
            tx_locks: RwLock::new(HashMap::new()),
            default_timeout: timeout,
        }
    }

    /// Atomically locks every `(table, row_id)` in `rows` on behalf of `tx_id`.
    ///
    /// All rows are checked before any is locked, so either every lock is taken
    /// or none is. A row conflicts only when a *different* transaction holds an
    /// unexpired lock on it. Expired locks are taken over, and rows `tx_id`
    /// already holds are re-stamped with a fresh acquisition time. Each lock
    /// gets `default_timeout` (in ms, saturating at `u64::MAX`). A row listed
    /// more than once, or re-locked, is tracked only once for `tx_id`.
    ///
    /// # Errors
    ///
    /// Returns [`LockConflictInfo`] for the first conflicting row in `rows`
    /// order. In that case no lock is taken or modified.
    #[allow(clippy::significant_drop_tightening)]
    // `tx_id` is recorded automatically as an argument; naming it in `fields`
    // without a value would record it as empty instead.
    #[instrument(skip(self, rows), fields(row_count = rows.len()))]
    pub fn try_lock(
        &self,
        tx_id: u64,
        rows: &[(String, u64)],
    ) -> std::result::Result<(), LockConflictInfo> {
        let mut locks = self.locks.write();
        let mut tx_locks = self.tx_locks.write();

        // Phase 1: reject the whole request if any row is held by someone else.
        for (table, row_id) in rows {
            let key = (table.clone(), *row_id);
            if let Some(existing) = locks.get(&key) {
                if !existing.is_expired() && existing.tx_id != tx_id {
                    warn!(
                        tx_id,
                        table = %table,
                        row_id,
                        blocking_tx = existing.tx_id,
                        "lock conflict"
                    );
                    return Err(LockConflictInfo {
                        blocking_tx: existing.tx_id,
                        table: table.clone(),
                        row_id: *row_id,
                    });
                }
            }
        }

        // Phase 2: take (or refresh) every lock.
        let now_ms = now_epoch_millis();
        let timeout_ms = u64::try_from(self.default_timeout.as_millis()).unwrap_or(u64::MAX);
        for (table, row_id) in rows {
            let key = (table.clone(), *row_id);
            let lock = RowLock {
                table: table.clone(),
                row_id: *row_id,
                tx_id,
                acquired_at_ms: now_ms,
                timeout_ms,
            };
            let previous_owner = locks.insert(key.clone(), lock).map(|prev| prev.tx_id);
            match previous_owner {
                // Refresh of a lock this tx already holds (or a row listed twice
                // in `rows`): the key is already tracked, don't duplicate it.
                Some(owner) if owner == tx_id => {},
                // Takeover of another tx's expired lock: forget the stale key on
                // the old owner so its bookkeeping stays accurate.
                Some(owner) => {
                    Self::untrack(&mut tx_locks, owner, &key);
                    tx_locks.entry(tx_id).or_default().push(key);
                },
                None => tx_locks.entry(tx_id).or_default().push(key),
            }
        }
        Ok(())
    }

    /// Removes `key` from `owner`'s tracked keys and drops the owner's entry
    /// once it becomes empty.
    fn untrack(
        tx_locks: &mut HashMap<u64, Vec<(String, u64)>>,
        owner: u64,
        key: &(String, u64),
    ) {
        if let Some(keys) = tx_locks.get_mut(&owner) {
            keys.retain(|k| k != key);
            if keys.is_empty() {
                tx_locks.remove(&owner);
            }
        }
    }

    /// Releases every lock tracked for `tx_id`. Does nothing for an unknown `tx_id`.
    ///
    /// A tracked key whose lock is now owned by another transaction is left alone.
    #[instrument(skip(self))]
    pub fn release(&self, tx_id: u64) {
        let mut locks = self.locks.write();
        let mut tx_locks = self.tx_locks.write();
        let Some(keys) = tx_locks.remove(&tx_id) else {
            return;
        };
        for key in keys {
            if locks.get(&key).is_some_and(|lock| lock.tx_id == tx_id) {
                locks.remove(&key);
            }
        }
    }

    /// Returns `true` if `(table, row_id)` has an unexpired lock.
    #[must_use]
    #[instrument(skip(self), fields(table = %table))]
    pub fn is_locked(&self, table: &str, row_id: u64) -> bool {
        let locks = self.locks.read();
        let key = (table.to_string(), row_id);
        locks.get(&key).is_some_and(|lock| !lock.is_expired())
    }

    /// Returns the id of the transaction holding an unexpired lock on
    /// `(table, row_id)`, if there is one.
    #[must_use]
    #[instrument(skip(self), fields(table = %table))]
    pub fn lock_holder(&self, table: &str, row_id: u64) -> Option<u64> {
        let locks = self.locks.read();
        let key = (table.to_string(), row_id);
        locks
            .get(&key)
            .filter(|lock| !lock.is_expired())
            .map(|lock| lock.tx_id)
    }

    /// Removes every expired lock, including its per-transaction bookkeeping,
    /// and returns how many locks were removed.
    #[allow(clippy::significant_drop_tightening)]
    #[instrument(skip(self))]
    pub fn cleanup_expired(&self) -> usize {
        let mut locks = self.locks.write();
        let mut tx_locks = self.tx_locks.write();
        let expired: Vec<((String, u64), u64)> = locks
            .iter()
            .filter(|(_, lock)| lock.is_expired())
            .map(|(key, lock)| (key.clone(), lock.tx_id))
            .collect();
        for (key, tx_id) in &expired {
            debug!(
                tx_id,
                table = %key.0,
                row_id = key.1,
                "expired lock removed"
            );
            locks.remove(key);
            Self::untrack(&mut tx_locks, *tx_id, key);
        }
        let removed = expired.len();
        if removed > 0 {
            debug!(count = removed, "cleaned up expired locks");
        }
        removed
    }

    /// Number of lock entries currently stored, including expired ones not yet
    /// cleaned up.
    #[must_use]
    #[instrument(skip(self))]
    pub fn active_lock_count(&self) -> usize {
        self.locks.read().len()
    }

    /// Number of lock keys tracked for `tx_id` (0 if none).
    #[must_use]
    #[instrument(skip(self))]
    pub fn locks_held_by(&self, tx_id: u64) -> usize {
        self.tx_locks.read().get(&tx_id).map_or(0, Vec::len)
    }
}
/// Error returned by `RowLockManager::try_lock`. Identifies the first row that
/// is already locked by another transaction, and which transaction holds it.
#[derive(Debug, Clone)]
pub(crate) struct LockConflictInfo {
    /// Transaction holding the unexpired conflicting lock.
    pub blocking_tx: u64,
    /// Table of the conflicting row.
    pub table: String,
    /// Id of the conflicting row.
    pub row_id: u64,
}
/// Process-wide source of transaction ids, starting at 1.
///
/// The counter is shared by every `TransactionManager`, so ids handed out by
/// `begin` never collide across managers in the same process.
static TX_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Registry of live transactions plus the row-lock manager they share.
#[derive(Debug)]
pub struct TransactionManager {
    /// Live transactions keyed by id.
    transactions: DashMap<u64, Transaction>,
    /// Row locks taken on behalf of these transactions.
    lock_manager: RowLockManager,
    /// Timeout applied to each transaction created by `begin`.
    pub default_timeout: Duration,
}
/// Equivalent to [`TransactionManager::new`].
impl Default for TransactionManager {
    fn default() -> Self {
        Self::new()
    }
}
// Note on `#[instrument]`: arguments that are not skipped are recorded
// automatically. Naming an argument in `fields(...)` without a value records it
// as an empty field and suppresses the real argument, so ids are not listed there.
impl TransactionManager {
    /// Creates a manager with a 60-second transaction timeout. Its lock manager
    /// comes from `RowLockManager::new` (30-second lock timeout).
    #[must_use]
    #[instrument(skip_all)]
    pub fn new() -> Self {
        Self {
            transactions: DashMap::new(),
            lock_manager: RowLockManager::new(),
            default_timeout: Duration::from_secs(60),
        }
    }

    /// Creates a manager that uses `timeout` for both transactions and row locks.
    #[must_use]
    #[instrument(skip_all, fields(timeout_secs = timeout.as_secs()))]
    pub fn with_timeout(timeout: Duration) -> Self {
        Self {
            transactions: DashMap::new(),
            lock_manager: RowLockManager::with_default_timeout(timeout),
            default_timeout: timeout,
        }
    }

    /// Creates a manager with separate transaction and row-lock timeouts.
    #[must_use]
    #[instrument(skip_all, fields(tx_timeout_secs = tx_timeout.as_secs(), lock_timeout_secs = lock_timeout.as_secs()))]
    pub fn with_timeouts(tx_timeout: Duration, lock_timeout: Duration) -> Self {
        Self {
            transactions: DashMap::new(),
            lock_manager: RowLockManager::with_default_timeout(lock_timeout),
            default_timeout: tx_timeout,
        }
    }

    /// Starts a new `Active` transaction and returns its id.
    ///
    /// Ids come from the process-wide `TX_COUNTER`. The transaction timeout is
    /// `default_timeout` in milliseconds, saturating at `u64::MAX` instead of
    /// truncating.
    #[instrument(skip(self))]
    pub fn begin(&self) -> u64 {
        // Relaxed is enough: only uniqueness of the id is required.
        let tx_id = TX_COUNTER.fetch_add(1, Ordering::Relaxed);
        let timeout_ms = u64::try_from(self.default_timeout.as_millis()).unwrap_or(u64::MAX);
        self.transactions.insert(tx_id, Transaction::new(tx_id, timeout_ms));
        tx_id
    }

    /// Returns the phase of `tx_id`, or `None` if it is not registered.
    #[must_use]
    #[instrument(skip(self))]
    pub fn get(&self, tx_id: u64) -> Option<TxPhase> {
        self.transactions.get(&tx_id).map(|r| r.phase)
    }

    /// Returns `true` if `tx_id` is registered and in the `Active` phase.
    #[must_use]
    #[instrument(skip(self))]
    pub fn is_active(&self, tx_id: u64) -> bool {
        self.transactions.get(&tx_id).is_some_and(|r| r.is_active())
    }

    /// Sets the phase of `tx_id`. Returns `false` if it is not registered.
    ///
    /// No transition validation is performed.
    #[instrument(skip(self))]
    pub fn set_phase(&self, tx_id: u64, phase: TxPhase) -> bool {
        if let Some(mut tx) = self.transactions.get_mut(&tx_id) {
            tx.phase = phase;
            true
        } else {
            false
        }
    }

    /// Appends `entry` to the undo log of `tx_id`. Returns `false` (and drops
    /// `entry`) if the transaction is not registered.
    #[allow(clippy::option_if_let_else)]
    #[instrument(skip(self, entry))]
    pub(crate) fn record_undo(&self, tx_id: u64, entry: UndoEntry) -> bool {
        if let Some(mut tx) = self.transactions.get_mut(&tx_id) {
            tx.record_undo(entry);
            true
        } else {
            false
        }
    }

    /// Returns a clone of the undo log of `tx_id`, or `None` if it is not
    /// registered.
    #[must_use]
    #[instrument(skip(self))]
    pub(crate) fn get_undo_log(&self, tx_id: u64) -> Option<Vec<UndoEntry>> {
        self.transactions.get(&tx_id).map(|r| r.undo_log.clone())
    }

    /// Unregisters `tx_id`. Its row locks are NOT released; call
    /// [`Self::release_locks`] for that.
    #[instrument(skip(self))]
    pub fn remove(&self, tx_id: u64) {
        self.transactions.remove(&tx_id);
    }

    /// Borrows the underlying row-lock manager.
    #[must_use]
    pub(crate) const fn lock_manager(&self) -> &RowLockManager {
        &self.lock_manager
    }

    /// Releases every row lock held by `tx_id`.
    #[instrument(skip(self))]
    pub fn release_locks(&self, tx_id: u64) {
        self.lock_manager.release(tx_id);
    }

    /// Number of registered transactions in the `Active` phase.
    #[must_use]
    #[instrument(skip(self))]
    pub fn active_count(&self) -> usize {
        self.transactions.iter().filter(|r| r.is_active()).count()
    }

    /// Unregisters every expired transaction, whatever its phase, and releases
    /// its row locks. Returns how many transactions were expired.
    ///
    /// No undo processing happens here; undo logs of the removed transactions
    /// are discarded.
    #[instrument(skip(self))]
    pub fn cleanup_expired(&self) -> usize {
        // Collect ids first: mutating a DashMap while one of its iterators is
        // alive can deadlock.
        let expired_ids: Vec<u64> = self
            .transactions
            .iter()
            .filter(|entry| entry.value().is_expired())
            .map(|entry| *entry.key())
            .collect();
        for &tx_id in &expired_ids {
            // Unregister before freeing locks so the transaction no longer reports
            // as active by the time its rows become available to others.
            self.transactions.remove(&tx_id);
            self.lock_manager.release(tx_id);
        }
        if !expired_ids.is_empty() {
            debug!(count = expired_ids.len(), "cleaned up expired transactions");
        }
        expired_ids.len()
    }

    /// Removes expired row locks; returns how many were removed.
    #[instrument(skip(self))]
    pub fn cleanup_expired_locks(&self) -> usize {
        self.lock_manager.cleanup_expired()
    }

    /// Number of stored row locks (including expired ones not yet cleaned up).
    #[must_use]
    #[instrument(skip(self))]
    pub fn active_lock_count(&self) -> usize {
        self.lock_manager.active_lock_count()
    }

    /// Number of row-lock keys tracked for `tx_id`.
    #[must_use]
    #[instrument(skip(self))]
    pub fn locks_held_by(&self, tx_id: u64) -> usize {
        self.lock_manager.locks_held_by(tx_id)
    }

    /// Returns `true` if `(table, row_id)` has an unexpired lock.
    #[must_use]
    #[instrument(skip(self), fields(table = %table))]
    pub fn is_row_locked(&self, table: &str, row_id: u64) -> bool {
        self.lock_manager.is_locked(table, row_id)
    }

    /// Returns the transaction holding an unexpired lock on `(table, row_id)`.
    #[must_use]
    #[instrument(skip(self), fields(table = %table))]
    pub fn row_lock_holder(&self, table: &str, row_id: u64) -> Option<u64> {
        self.lock_manager.lock_holder(table, row_id)
    }
}
/// Milliseconds since the Unix epoch according to the system (wall) clock.
///
/// Returns 0 if the clock reads earlier than the epoch. The `u128 -> u64` cast
/// cannot truncate for any date within the next several hundred million years.
#[allow(clippy::cast_possible_truncation)]
fn now_epoch_millis() -> u64 {
    std::time::UNIX_EPOCH
        .elapsed()
        .map_or(0, |since_epoch| since_epoch.as_millis() as u64)
}
/// Unit tests for transactions, row locks, deadlines and the transaction manager.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh transaction starts Active.
    #[test]
    fn test_tx_phase_transitions() {
        let tx = Transaction::new(1, 60000);
        assert_eq!(tx.phase, TxPhase::Active);
        assert!(tx.is_active());
    }

    // Backdating the start time past a 0 ms timeout makes the transaction expired.
    #[test]
    fn test_tx_expiration() {
        let mut tx = Transaction::new(1, 0);
        tx.started_at_ms = 0;
        assert!(tx.is_expired());
    }

    // Recording an undo entry appends to the log and marks the table affected.
    #[test]
    fn test_tx_record_undo() {
        let mut tx = Transaction::new(1, 60000);
        tx.record_undo(UndoEntry::InsertedRow {
            table: "users".to_string(),
            slab_row_id: SlabRowId::new(0),
            row_id: 1,
            index_entries: vec![],
        });
        assert_eq!(tx.undo_log.len(), 1);
        assert!(tx.affected_tables.contains("users"));
    }

    // Lock, conflict from another tx, and re-lock by the owner.
    #[test]
    fn test_row_lock_manager_basic() {
        let mgr = RowLockManager::new();
        let result = mgr.try_lock(1, &[("users".to_string(), 1)]);
        assert!(result.is_ok());
        assert!(mgr.is_locked("users", 1));
        assert_eq!(mgr.lock_holder("users", 1), Some(1));
        // A different transaction is blocked by tx 1.
        let result = mgr.try_lock(2, &[("users".to_string(), 1)]);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.blocking_tx, 1);
        // The owner may lock the same row again.
        let result = mgr.try_lock(1, &[("users".to_string(), 1)]);
        assert!(result.is_ok());
    }

    // Releasing a transaction frees all of its locks.
    #[test]
    fn test_row_lock_manager_release() {
        let mgr = RowLockManager::new();
        mgr.try_lock(1, &[("users".to_string(), 1), ("users".to_string(), 2)])
            .unwrap();
        assert_eq!(mgr.locks_held_by(1), 2);
        mgr.release(1);
        assert!(!mgr.is_locked("users", 1));
        assert!(!mgr.is_locked("users", 2));
        assert_eq!(mgr.locks_held_by(1), 0);
    }

    // begin -> Committing -> Committed -> remove.
    #[test]
    fn test_transaction_manager_lifecycle() {
        let mgr = TransactionManager::new();
        let tx_id = mgr.begin();
        assert!(mgr.is_active(tx_id));
        assert_eq!(mgr.get(tx_id), Some(TxPhase::Active));
        mgr.set_phase(tx_id, TxPhase::Committing);
        assert_eq!(mgr.get(tx_id), Some(TxPhase::Committing));
        assert!(!mgr.is_active(tx_id));
        mgr.set_phase(tx_id, TxPhase::Committed);
        mgr.remove(tx_id);
        assert_eq!(mgr.get(tx_id), None);
    }

    // Undo entries recorded through the manager are retrievable.
    #[test]
    fn test_transaction_manager_undo_log() {
        let mgr = TransactionManager::new();
        let tx_id = mgr.begin();
        mgr.record_undo(
            tx_id,
            UndoEntry::InsertedRow {
                table: "users".to_string(),
                slab_row_id: SlabRowId::new(0),
                row_id: 1,
                index_entries: vec![],
            },
        );
        let undo_log = mgr.get_undo_log(tx_id).unwrap();
        assert_eq!(undo_log.len(), 1);
    }

    #[test]
    fn test_lock_conflict_info() {
        let info = LockConflictInfo {
            blocking_tx: 1,
            table: "users".to_string(),
            row_id: 42,
        };
        assert_eq!(info.blocking_tx, 1);
        assert_eq!(info.table, "users");
        assert_eq!(info.row_id, 42);
    }

    // A lock acquired at epoch 0 with a 1 ms timeout is long expired.
    #[test]
    fn test_row_lock_expiration() {
        let lock = RowLock {
            table: "users".to_string(),
            row_id: 1,
            tx_id: 1,
            acquired_at_ms: 0,
            timeout_ms: 1,
        };
        assert!(lock.is_expired());
    }

    // One try_lock call can lock rows across several tables.
    #[test]
    fn test_row_lock_manager_multiple_rows() {
        let mgr = RowLockManager::new();
        let rows = vec![
            ("users".to_string(), 1),
            ("users".to_string(), 2),
            ("orders".to_string(), 1),
        ];
        let result = mgr.try_lock(1, &rows);
        assert!(result.is_ok());
        assert!(mgr.is_locked("users", 1));
        assert!(mgr.is_locked("users", 2));
        assert!(mgr.is_locked("orders", 1));
        assert!(!mgr.is_locked("users", 3));
    }

    // Only transactions in the Active phase are counted.
    #[test]
    fn test_transaction_manager_active_count() {
        let mgr = TransactionManager::new();
        let tx1 = mgr.begin();
        let tx2 = mgr.begin();
        assert_eq!(mgr.active_count(), 2);
        mgr.set_phase(tx1, TxPhase::Committed);
        assert_eq!(mgr.active_count(), 1);
        mgr.set_phase(tx2, TxPhase::Aborted);
        assert_eq!(mgr.active_count(), 0);
    }

    // Each UndoEntry variant can be constructed and destructured.
    #[test]
    fn test_undo_entry_variants() {
        let insert = UndoEntry::InsertedRow {
            table: "t".to_string(),
            slab_row_id: SlabRowId::new(0),
            row_id: 1,
            index_entries: vec![("col".to_string(), Value::Int(42))],
        };
        let update = UndoEntry::UpdatedRow {
            table: "t".to_string(),
            slab_row_id: SlabRowId::new(0),
            row_id: 1,
            old_values: vec![SlabColumnValue::Int(42)],
            index_changes: vec![IndexChange {
                column: "col".to_string(),
                old_value: Value::Int(42),
                new_value: Value::Int(43),
            }],
        };
        let delete = UndoEntry::DeletedRow {
            table: "t".to_string(),
            slab_row_id: SlabRowId::new(0),
            row_id: 1,
            old_values: vec![SlabColumnValue::Int(42)],
            index_entries: vec![("col".to_string(), Value::Int(42))],
        };
        match insert {
            UndoEntry::InsertedRow { row_id, .. } => assert_eq!(row_id, 1),
            _ => panic!("Wrong variant"),
        }
        match update {
            UndoEntry::UpdatedRow { old_values, .. } => assert_eq!(old_values.len(), 1),
            _ => panic!("Wrong variant"),
        }
        match delete {
            UndoEntry::DeletedRow { index_entries, .. } => assert_eq!(index_entries.len(), 1),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_index_change() {
        let change = IndexChange {
            column: "age".to_string(),
            old_value: Value::Int(30),
            new_value: Value::Int(31),
        };
        assert_eq!(change.column, "age");
    }

    #[test]
    fn test_lock_manager_default() {
        let mgr = RowLockManager::default();
        assert_eq!(mgr.active_lock_count(), 0);
    }

    #[test]
    fn test_transaction_manager_default() {
        let mgr = TransactionManager::default();
        assert_eq!(mgr.active_count(), 0);
    }

    // With a 1 ms lock timeout, a 5 ms sleep lets cleanup remove the lock.
    #[test]
    fn test_cleanup_expired_locks() {
        let mgr = RowLockManager::with_default_timeout(Duration::from_millis(1));
        mgr.try_lock(1, &[("users".to_string(), 1)]).unwrap();
        std::thread::sleep(Duration::from_millis(5));
        let cleaned = mgr.cleanup_expired();
        assert_eq!(cleaned, 1);
        assert!(!mgr.is_locked("users", 1));
    }

    // Sanity bounds on the wall clock: between 2020 and 2100.
    #[test]
    fn test_now_epoch_millis_returns_reasonable_value() {
        let now = now_epoch_millis();
        assert!(now > 1_577_836_800_000, "epoch time should be after 2020");
        assert!(now < 4_102_444_800_000, "epoch time should be before 2100");
    }

    #[test]
    fn test_deadline_from_timeout_ms() {
        let deadline = Deadline::from_timeout_ms(Some(1000));
        assert!(!deadline.is_expired());
        assert!(deadline.remaining_ms().is_some());
    }

    #[test]
    fn test_deadline_never_expires() {
        let deadline = Deadline::never();
        assert!(!deadline.is_expired());
        assert!(deadline.remaining_ms().is_none());
    }

    #[test]
    fn test_deadline_is_expired() {
        let deadline = Deadline::from_timeout_ms(Some(0));
        std::thread::sleep(Duration::from_millis(1));
        assert!(deadline.is_expired());
    }

    #[test]
    fn test_deadline_default() {
        let deadline = Deadline::default();
        assert!(!deadline.is_expired());
        assert!(deadline.remaining_ms().is_none());
    }

    #[test]
    fn test_deadline_none_timeout() {
        let deadline = Deadline::from_timeout_ms(None);
        assert!(!deadline.is_expired());
        assert!(deadline.remaining_ms().is_none());
    }

    // NOTE(review): same assertions as `test_deadline_is_expired`; one could be dropped.
    #[test]
    fn test_zero_timeout_immediate_expiry() {
        let deadline = Deadline::from_timeout_ms(Some(0));
        std::thread::sleep(Duration::from_millis(1));
        assert!(deadline.is_expired());
    }
}