use crate::txn::version_store::TransactionId;
use crate::types::RowId;
use crate::{Result, StorageError};
use dashmap::DashMap;
use parking_lot::{Condvar, Mutex, RwLock};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use std::time::Duration;
/// Strength of a row lock.
///
/// Any number of transactions may hold [`LockMode::Shared`] on the same row at once, while
/// [`LockMode::Exclusive`] is incompatible with a lock held by any other transaction.
#[derive(Clone, Copy)]
#[derive(Debug)]
#[derive(PartialEq, Eq)]
pub enum LockMode {
    /// Read lock; compatible with other shared locks.
    Shared,
    /// Write lock; incompatible with any lock held by another transaction.
    Exclusive,
}
/// A transaction waiting to be granted a row lock.
///
/// NOTE(review): every field is `_`-prefixed and nothing in this file constructs or reads a
/// `LockWaiter`; presumably reserved for a future blocking-acquire path — confirm before use.
#[derive(Debug)]
struct LockWaiter {
    // Presumably the transaction that is waiting.
    _txn_id: TransactionId,
    // Presumably the lock mode being requested.
    _mode: LockMode,
    // Presumably signalled when the lock is granted (not used in this file).
    _condvar: Arc<Condvar>,
    // Presumably flipped to `true` once the lock is granted (not used in this file).
    _granted: Arc<Mutex<bool>>,
}
/// Lock state for a single row.
///
/// `holders` lists every transaction currently holding a lock on the row, together with its
/// mode. A transaction appears at most once: granting replaces any earlier entry for the same
/// transaction.
struct LockEntry {
    holders: RwLock<Vec<(TransactionId, LockMode)>>,
    // Placeholder for a wait queue; nothing in this file enqueues or reads waiters yet.
    _waiters: Mutex<VecDeque<LockWaiter>>,
}

impl LockEntry {
    /// Creates an entry with no holders and no waiters.
    fn new() -> Self {
        Self {
            holders: RwLock::new(Vec::new()),
            _waiters: Mutex::new(VecDeque::new()),
        }
    }

    /// Returns `true` if `txn_id` may hold `mode` given the current `holders`.
    ///
    /// * `Shared` conflicts only with an exclusive lock held by another transaction.
    /// * `Exclusive` requires that no other transaction holds any lock on the row.
    fn is_compatible(
        holders: &[(TransactionId, LockMode)],
        mode: LockMode,
        txn_id: TransactionId,
    ) -> bool {
        match mode {
            LockMode::Shared => !holders
                .iter()
                .any(|(tid, m)| *m == LockMode::Exclusive && *tid != txn_id),
            LockMode::Exclusive => match holders {
                [] => true,
                [(tid, _)] => *tid == txn_id,
                _ => false,
            },
        }
    }

    /// Grants `mode` to `txn_id` if it is compatible with the current holders.
    ///
    /// The compatibility check and the grant run under a single write guard. Two transactions
    /// therefore cannot both observe a compatible state and both be granted conflicting locks,
    /// which was possible when checking and granting took the lock separately.
    ///
    /// Returns `true` if the lock was granted.
    fn try_grant(&self, txn_id: TransactionId, mode: LockMode) -> bool {
        let mut holders = self.holders.write();
        if !Self::is_compatible(holders.as_slice(), mode, txn_id) {
            return false;
        }
        // Replace (rather than add to) any existing entry so a txn is listed at most once.
        holders.retain(|(tid, _)| *tid != txn_id);
        holders.push((txn_id, mode));
        true
    }

    /// Removes any lock `txn_id` holds on this row; a no-op if it holds none.
    fn release(&self, txn_id: TransactionId) {
        self.holders.write().retain(|(tid, _)| *tid != txn_id);
    }

    /// Returns the mode `txn_id` currently holds on this row, if any.
    fn holds_lock(&self, txn_id: TransactionId) -> Option<LockMode> {
        self.holders
            .read()
            .iter()
            .find(|(tid, _)| *tid == txn_id)
            .map(|(_, mode)| *mode)
    }

    /// Returns `true` if no transaction holds a lock on this row.
    fn is_unheld(&self) -> bool {
        self.holders.read().is_empty()
    }
}

/// Row-level shared/exclusive lock manager.
///
/// Acquisition never blocks: a conflicting request fails immediately with
/// `StorageError::Transaction`. All locks of a transaction are released together via
/// [`LockManager::release_locks`].
pub struct LockManager {
    /// Lock state per row. Entries are dropped once no transaction holds them and no other
    /// thread has a handle to them.
    locks: DashMap<RowId, Arc<LockEntry>>,
    /// Rows each transaction has been granted, so `release_locks` knows what to release.
    txn_locks: Arc<Mutex<HashMap<TransactionId, HashSet<RowId>>>>,
    /// Wait-for graph: waiter -> set of transactions it waits on.
    wait_for: Arc<Mutex<HashMap<TransactionId, HashSet<TransactionId>>>>,
    // Not read anywhere in this file yet.
    _deadlock_timeout: Duration,
}

impl LockManager {
    /// Creates an empty lock manager.
    pub fn new() -> Self {
        Self {
            locks: DashMap::new(),
            txn_locks: Arc::new(Mutex::new(HashMap::new())),
            wait_for: Arc::new(Mutex::new(HashMap::new())),
            _deadlock_timeout: Duration::from_secs(5),
        }
    }

    /// Acquires a shared (read) lock on `row_id` for `txn_id`.
    ///
    /// Succeeds without change if the transaction already holds any lock on the row.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::Transaction` if another transaction holds an exclusive lock on
    /// the row. The call never blocks.
    pub fn acquire_shared(&self, txn_id: TransactionId, row_id: RowId) -> Result<()> {
        self.acquire_lock(txn_id, row_id, LockMode::Shared)
    }

    /// Acquires an exclusive (write) lock on `row_id` for `txn_id`, upgrading a shared lock the
    /// same transaction already holds.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::Transaction` if any other transaction holds a lock on the row.
    /// The call never blocks.
    pub fn acquire_exclusive(&self, txn_id: TransactionId, row_id: RowId) -> Result<()> {
        self.acquire_lock(txn_id, row_id, LockMode::Exclusive)
    }

    /// Returns `true` if `start_txn` is reachable from `current_txn` in the wait-for graph via
    /// transactions not already in `visited`. Reaching `start_txn` counts only once at least
    /// one transaction has been visited, so the trivial zero-length path is not a cycle.
    ///
    /// The graph mutex is taken once for the whole search. parking_lot's `Mutex` is not
    /// reentrant, so re-locking it on every recursive step would self-deadlock. The search is
    /// iterative so long wait chains cannot overflow the stack.
    #[allow(dead_code)] // deadlock detection is not yet called from `acquire_lock`
    fn has_cycle(
        &self,
        start_txn: TransactionId,
        current_txn: TransactionId,
        visited: &mut HashSet<TransactionId>,
    ) -> bool {
        let wait_for = self.wait_for.lock();
        let mut stack = vec![current_txn];
        while let Some(txn) = stack.pop() {
            if txn == start_txn && !visited.is_empty() {
                return true;
            }
            if !visited.insert(txn) {
                continue;
            }
            if let Some(next) = wait_for.get(&txn) {
                stack.extend(next.iter().copied());
            }
        }
        false
    }

    /// Returns `true` if `txn_id` lies on a cycle in the wait-for graph.
    #[allow(dead_code)] // deadlock detection is not yet called from `acquire_lock`
    fn detect_deadlock(&self, txn_id: TransactionId) -> bool {
        let mut visited = HashSet::new();
        self.has_cycle(txn_id, txn_id, &mut visited)
    }

    /// Records that `waiter` waits on each of `holders` (self-edges are skipped).
    #[allow(dead_code)] // nothing populates the wait-for graph yet
    fn add_wait_for(&self, waiter: TransactionId, holders: &[TransactionId]) {
        let mut wait_for = self.wait_for.lock();
        let entry = wait_for.entry(waiter).or_default();
        entry.extend(holders.iter().copied().filter(|&holder| holder != waiter));
    }

    /// Removes `txn_id` from the wait-for graph, both as a waiter and as a waited-on holder.
    fn remove_wait_for(&self, txn_id: TransactionId) {
        let mut wait_for = self.wait_for.lock();
        wait_for.remove(&txn_id);
        // A finished transaction can no longer be waited on. Drop inbound edges too so stale
        // edges cannot form a false cycle, and prune waiters left with nothing to wait on.
        wait_for.retain(|_, holders| {
            holders.remove(&txn_id);
            !holders.is_empty()
        });
    }

    /// Shared implementation of the `acquire_*` methods.
    ///
    /// Re-requesting a lock the transaction already covers is a no-op. Requesting exclusive
    /// while holding shared attempts an upgrade.
    fn acquire_lock(&self, txn_id: TransactionId, row_id: RowId, mode: LockMode) -> Result<()> {
        // Clone the handle while the shard guard is held, then release the guard before
        // touching the entry's own lock. The extra strong count also stops
        // `remove_if_unused` from dropping the entry out from under us.
        let entry = self
            .locks
            .entry(row_id)
            .or_insert_with(|| Arc::new(LockEntry::new()))
            .clone();

        match entry.holds_lock(txn_id) {
            Some(LockMode::Shared) if mode == LockMode::Exclusive => {
                return self.upgrade_lock(txn_id, row_id, &entry);
            }
            // Already holds a mode at least as strong as the one requested.
            Some(_) => return Ok(()),
            None => {}
        }

        if entry.try_grant(txn_id, mode) {
            self.txn_locks.lock().entry(txn_id).or_default().insert(row_id);
            return Ok(());
        }

        Err(StorageError::Transaction(format!(
            "Lock conflict: txn {} cannot acquire {:?} lock on row {}",
            txn_id, mode, row_id
        )))
    }

    /// Upgrades `txn_id`'s shared lock on `row_id` to exclusive.
    ///
    /// This succeeds only when `txn_id` is the sole holder. The row is already recorded in
    /// `txn_locks` from the original shared grant.
    fn upgrade_lock(&self, txn_id: TransactionId, row_id: RowId, entry: &LockEntry) -> Result<()> {
        if entry.try_grant(txn_id, LockMode::Exclusive) {
            return Ok(());
        }
        Err(StorageError::Transaction(format!(
            "Cannot upgrade lock: txn {} on row {}, other transactions hold locks",
            txn_id, row_id
        )))
    }

    /// Drops `row_id`'s entry from the lock table if nobody holds it and no other thread has
    /// a handle to it.
    ///
    /// `remove_if` evaluates the predicate under the shard's write lock. `acquire_lock` clones
    /// the `Arc` under that same lock, so an entry seen with a strong count of 1 cannot be
    /// picked up concurrently by an acquirer.
    fn remove_if_unused(&self, row_id: RowId) {
        self.locks
            .remove_if(&row_id, |_, entry| Arc::strong_count(entry) == 1 && entry.is_unheld());
    }

    /// Releases every lock held by `txn_id` and removes it from the wait-for graph.
    ///
    /// Rows left without holders are dropped from the lock table so it does not grow without
    /// bound. Releasing a transaction that holds nothing is a no-op.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub fn release_locks(&self, txn_id: TransactionId) -> Result<()> {
        self.remove_wait_for(txn_id);
        let locked_rows = self.txn_locks.lock().remove(&txn_id).unwrap_or_default();
        for row_id in locked_rows {
            // The shard read guard from `get` is dropped at the end of this `if let`, before
            // `remove_if_unused` takes the same shard's write lock.
            if let Some(entry) = self.locks.get(&row_id) {
                entry.release(txn_id);
            }
            self.remove_if_unused(row_id);
        }
        Ok(())
    }

    /// Returns a snapshot of lock-table counters.
    ///
    /// The table size and the per-transaction counts are read under different locks, so under
    /// concurrent activity they are only approximately consistent with one another.
    pub fn stats(&self) -> LockManagerStats {
        // Read the table size before taking `txn_locks` so the two are never held together.
        let total_locks = self.locks.len() as u64;
        let txn_locks = self.txn_locks.lock();
        LockManagerStats {
            total_locks,
            active_transactions: txn_locks.len() as u64,
            total_locked_rows: txn_locks.values().map(|rows| rows.len() as u64).sum(),
        }
    }
}
impl Default for LockManager {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time counters returned by `LockManager::stats`.
#[derive(Debug, Clone)]
pub struct LockManagerStats {
    /// Number of row entries in the lock manager's lock table.
    pub total_locks: u64,
    /// Number of transactions that have been granted at least one lock and not yet released.
    pub active_transactions: u64,
    /// Sum over transactions of the rows each has been granted; a row shared by several
    /// transactions is counted once per transaction.
    pub total_locked_rows: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Several transactions may hold shared locks on the same row at once.
    #[test]
    fn test_shared_lock_compatibility() {
        let manager = LockManager::new();
        for txn in 1..=3 {
            manager.acquire_shared(txn, 100).unwrap();
        }
        assert_eq!(manager.stats().active_transactions, 3);
    }

    /// An exclusive holder blocks both shared and exclusive requests from others.
    #[test]
    fn test_exclusive_lock_blocks() {
        let manager = LockManager::new();
        manager.acquire_exclusive(1, 100).unwrap();
        let shared_attempt = manager.acquire_shared(2, 100);
        assert!(shared_attempt.is_err());
        let exclusive_attempt = manager.acquire_exclusive(2, 100);
        assert!(exclusive_attempt.is_err());
    }

    /// A shared request conflicts with another transaction's exclusive lock.
    #[test]
    fn test_exclusive_blocks_shared() {
        let manager = LockManager::new();
        manager.acquire_exclusive(1, 100).unwrap();
        let attempt = manager.acquire_shared(2, 100);
        assert!(attempt.is_err());
    }

    /// An exclusive request conflicts with another transaction's shared lock.
    #[test]
    fn test_shared_blocks_exclusive() {
        let manager = LockManager::new();
        manager.acquire_shared(1, 100).unwrap();
        let attempt = manager.acquire_exclusive(2, 100);
        assert!(attempt.is_err());
    }

    /// Releasing a transaction frees its rows for others.
    #[test]
    fn test_lock_release() {
        let manager = LockManager::new();
        manager.acquire_exclusive(1, 100).unwrap();
        manager.release_locks(1).unwrap();
        manager.acquire_exclusive(2, 100).unwrap();
    }

    /// A sole shared holder can upgrade to exclusive.
    #[test]
    fn test_lock_upgrade() {
        let manager = LockManager::new();
        manager.acquire_shared(1, 100).unwrap();
        manager.release_locks(1).unwrap();
        manager.acquire_shared(1, 100).unwrap();
        manager.acquire_exclusive(1, 100).unwrap();
    }

    /// One transaction can lock several rows, and releasing clears all of them.
    #[test]
    fn test_multiple_row_locks() {
        let manager = LockManager::new();
        for row in [100, 200, 300] {
            manager.acquire_exclusive(1, row).unwrap();
        }
        assert_eq!(manager.stats().total_locked_rows, 3);

        manager.release_locks(1).unwrap();
        assert_eq!(manager.stats().active_transactions, 0);
    }

    /// Shared locks taken and released from several threads leave no active transactions.
    #[test]
    fn test_concurrent_shared_locks() {
        let manager = Arc::new(LockManager::new());
        let workers: Vec<_> = (0..5)
            .map(|txn| {
                let manager = Arc::clone(&manager);
                thread::spawn(move || {
                    manager.acquire_shared(txn, 100).unwrap();
                    thread::sleep(Duration::from_millis(10));
                    manager.release_locks(txn).unwrap();
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(manager.stats().active_transactions, 0);
    }

    /// Stats reflect one row per transaction across three transactions.
    #[test]
    fn test_lock_statistics() {
        let manager = LockManager::new();
        manager.acquire_exclusive(1, 100).unwrap();
        manager.acquire_exclusive(2, 200).unwrap();
        manager.acquire_shared(3, 300).unwrap();

        let LockManagerStats {
            total_locks,
            active_transactions,
            total_locked_rows,
        } = manager.stats();
        assert_eq!(active_transactions, 3);
        assert_eq!(total_locked_rows, 3);
        assert!(total_locks > 0);
    }
}