use dashmap::DashMap;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
/// Identifier of a table; used as the key of the table-level lock map.
pub type TableId = u64;
/// Identifier of a row; used as the key inside the row lock shards.
pub type RowId = u128;
/// Identifier of a transaction; recorded in the holder list of each lock it owns.
pub type TxnId = u64;
/// Lock modes that can be held on a whole table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IntentLock {
    /// Intention to take shared locks on individual rows.
    IntentShared,
    /// Intention to take exclusive locks on individual rows.
    IntentExclusive,
    /// Shared lock on the whole table.
    Shared,
    /// Exclusive lock on the whole table.
    Exclusive,
}

impl IntentLock {
    /// Returns `true` when a lock in mode `self` and a lock in mode `other`
    /// may be held on the same table at the same time.
    ///
    /// The relation is symmetric.
    pub fn is_compatible(&self, other: &IntentLock) -> bool {
        match (*self, *other) {
            // An exclusive table lock tolerates no other lock at all.
            (IntentLock::Exclusive, _) | (_, IntentLock::Exclusive) => false,
            // Intent-shared coexists with every non-exclusive mode.
            (IntentLock::IntentShared, _) | (_, IntentLock::IntentShared) => true,
            // What is left is IntentExclusive / Shared, which only tolerate
            // themselves.
            (held, wanted) => held == wanted,
        }
    }
}
/// Lock modes that can be held on a single row.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LockMode {
    /// May be held by several transactions at once.
    Shared,
    /// Excludes every other holder.
    Exclusive,
}

impl LockMode {
    /// Returns `true` only when both modes are [`LockMode::Shared`]; any pair
    /// involving [`LockMode::Exclusive`] is incompatible.
    pub fn is_compatible(&self, other: &LockMode) -> bool {
        *self == LockMode::Shared && *other == LockMode::Shared
    }
}
/// Lock state of one table: a single mode plus the transactions holding it.
#[derive(Debug)]
struct TableLockEntry {
    /// Mode the table is currently locked in.
    mode: IntentLock,
    /// Transactions that currently hold the table lock.
    holders: Vec<TxnId>,
}

impl TableLockEntry {
    /// Builds an entry locked in `mode` with `txn_id` as its only holder.
    fn new(mode: IntentLock, txn_id: TxnId) -> Self {
        let holders = vec![txn_id];
        Self { mode, holders }
    }
}
/// Lock state of one row: a single mode plus the transactions holding it.
#[derive(Debug)]
struct RowLockEntry {
    /// Mode the row is currently locked in.
    mode: LockMode,
    /// Transactions that currently hold the row lock.
    holders: Vec<TxnId>,
}

impl RowLockEntry {
    /// Builds an entry locked in `mode` with `txn_id` as its only holder.
    fn new(mode: LockMode, txn_id: TxnId) -> Self {
        let holders = vec![txn_id];
        Self { mode, holders }
    }
}
/// Row-level lock table split into 256 shards, each guarded by its own mutex.
pub struct ShardedLockTable {
/// One map of row locks per shard; `shard_index` picks the shard for a row id.
shards: [Mutex<HashMap<RowId, RowLockEntry>>; 256],
/// Counters updated as locks are granted, upgraded, refused and released.
stats: LockTableStats,
}
/// `Default` builds the same empty table as `ShardedLockTable::new`.
impl Default for ShardedLockTable {
fn default() -> Self {
Self::new()
}
}
impl ShardedLockTable {
/// Creates a lock table whose shards are all empty and whose counters are
/// all zero.
pub fn new() -> Self {
    let shards = std::array::from_fn(|_| Mutex::new(HashMap::new()));
    let stats = LockTableStats::default();
    Self { shards, stats }
}
/// Maps a row id to a shard number in `0..256`.
///
/// The upper and lower 64-bit halves of the id are each truncated to
/// `usize`, XOR-ed together, and reduced modulo 256.
#[inline]
fn shard_index(&self, row_id: RowId) -> usize {
    let upper = (row_id >> 64) as usize;
    let lower = row_id as usize;
    (upper ^ lower) % 256
}
/// Tries to lock `row_id` in `mode` for `txn_id` without waiting.
///
/// Outcomes:
/// * row not locked: the lock is created and `Acquired` is returned;
/// * `txn_id` already holds the row: `AlreadyHeld`, unless it holds it
///   shared and asks for exclusive, in which case the lock is upgraded
///   (`Acquired`) when `txn_id` is the only holder and refused
///   (`WouldBlock`) otherwise;
/// * row held by others: `Acquired` if the held mode is compatible with
///   `mode`, `WouldBlock` if not.
///
/// The matching statistics counter is bumped for every outcome except
/// `AlreadyHeld`.
pub fn try_lock(&self, row_id: RowId, mode: LockMode, txn_id: TxnId) -> LockResult {
    let stats = &self.stats;
    let mut rows = self.shards[self.shard_index(row_id)].lock();

    if let Some(held) = rows.get_mut(&row_id) {
        let is_holder = held.holders.contains(&txn_id);
        let is_upgrade = held.mode == LockMode::Shared && mode == LockMode::Exclusive;
        let sole_holder = held.holders.len() == 1;

        return match (is_holder, is_upgrade) {
            // Re-entrant request that needs nothing stronger than what is held.
            (true, false) => LockResult::AlreadyHeld,
            // Shared -> exclusive upgrade is only possible without co-holders.
            (true, true) if sole_holder => {
                held.mode = LockMode::Exclusive;
                stats.upgrades.fetch_add(1, Ordering::Relaxed);
                LockResult::Acquired
            }
            // A newcomer may join when the modes are compatible (shared/shared).
            (false, _) if held.mode.is_compatible(&mode) => {
                held.holders.push(txn_id);
                stats.shared_acquired.fetch_add(1, Ordering::Relaxed);
                LockResult::Acquired
            }
            // Blocked upgrade or incompatible newcomer.
            _ => {
                stats.conflicts.fetch_add(1, Ordering::Relaxed);
                LockResult::WouldBlock
            }
        };
    }

    // Nobody holds the row yet: create the entry with this transaction in it.
    rows.insert(row_id, RowLockEntry::new(mode, txn_id));
    let granted = match mode {
        LockMode::Shared => &stats.shared_acquired,
        LockMode::Exclusive => &stats.exclusive_acquired,
    };
    granted.fetch_add(1, Ordering::Relaxed);
    LockResult::Acquired
}
pub fn unlock(&self, row_id: RowId, txn_id: TxnId) -> bool {
let shard_idx = self.shard_index(row_id);
let mut shard = self.shards[shard_idx].lock();
if let Some(entry) = shard.get_mut(&row_id)
&& let Some(pos) = entry.holders.iter().position(|&id| id == txn_id)
{
entry.holders.remove(pos);
self.stats.released.fetch_add(1, Ordering::Relaxed);
if entry.holders.is_empty() {
shard.remove(&row_id);
}
return true;
}
false
}
pub fn unlock_all(&self, txn_id: TxnId) -> usize {
let mut count = 0;
for shard in &self.shards {
let mut shard_guard = shard.lock();
let to_remove: Vec<RowId> = shard_guard
.iter()
.filter(|(_, entry)| entry.holders.contains(&txn_id))
.map(|(&row_id, _)| row_id)
.collect();
for row_id in to_remove {
if let Some(entry) = shard_guard.get_mut(&row_id)
&& let Some(pos) = entry.holders.iter().position(|&id| id == txn_id)
{
entry.holders.remove(pos);
count += 1;
if entry.holders.is_empty() {
shard_guard.remove(&row_id);
}
}
}
}
self.stats
.released
.fetch_add(count as u64, Ordering::Relaxed);
count
}
/// Same as `try_lock`, but remembers the row in `lock_set` when the result
/// is `Acquired`, so it can later be released via `unlock_all_tracked`.
///
/// Nothing is recorded for `AlreadyHeld` or `WouldBlock`. A successful
/// shared-to-exclusive upgrade also reports `Acquired`, so an upgraded row
/// ends up recorded a second time.
pub fn try_lock_tracked(
    &self,
    row_id: RowId,
    mode: LockMode,
    txn_id: TxnId,
    lock_set: &mut TransactionLockSet,
) -> LockResult {
    let outcome = self.try_lock(row_id, mode, txn_id);
    if outcome == LockResult::Acquired {
        lock_set.record(self.shard_index(row_id), row_id);
    }
    outcome
}
pub fn unlock_all_tracked(&self, txn_id: TxnId, lock_set: &TransactionLockSet) -> usize {
let mut count = 0;
for &(shard_idx, row_id) in &lock_set.locks {
let mut shard = self.shards[shard_idx].lock();
if let Some(entry) = shard.get_mut(&row_id)
&& let Some(pos) = entry.holders.iter().position(|&id| id == txn_id)
{
entry.holders.remove(pos);
count += 1;
if entry.holders.is_empty() {
shard.remove(&row_id);
}
}
}
self.stats
.released
.fetch_add(count as u64, Ordering::Relaxed);
count
}
/// Returns the counters maintained by this lock table.
pub fn stats(&self) -> &LockTableStats {
&self.stats
}
}
/// List of `(shard index, row id)` pairs recorded by
/// `ShardedLockTable::try_lock_tracked` for one transaction.
#[derive(Debug, Default)]
pub struct TransactionLockSet {
    /// Recorded locks, in acquisition order.
    locks: Vec<(usize, RowId)>,
}

impl TransactionLockSet {
    /// Creates an empty set with room for eight locks pre-allocated.
    pub fn new() -> Self {
        let locks = Vec::with_capacity(8);
        Self { locks }
    }

    /// Appends one lock to the set; duplicates are not filtered out.
    fn record(&mut self, shard_idx: usize, row_id: RowId) {
        let lock = (shard_idx, row_id);
        self.locks.push(lock);
    }

    /// Number of recorded locks.
    pub fn len(&self) -> usize {
        let Self { locks } = self;
        locks.len()
    }

    /// `true` when nothing has been recorded.
    pub fn is_empty(&self) -> bool {
        let Self { locks } = self;
        locks.is_empty()
    }

    /// Forgets every recorded lock; the locks themselves are not released.
    pub fn clear(&mut self) {
        let Self { locks } = self;
        locks.clear();
    }
}
/// Counters kept by a `ShardedLockTable`; all start at zero.
#[derive(Debug, Default)]
pub struct LockTableStats {
/// Shared row locks granted (new rows and additional shared holders).
pub shared_acquired: AtomicU64,
/// Exclusive row locks granted on previously unlocked rows.
pub exclusive_acquired: AtomicU64,
/// Shared locks upgraded to exclusive by their sole holder.
pub upgrades: AtomicU64,
/// Requests refused with `WouldBlock`.
pub conflicts: AtomicU64,
/// Row locks released.
pub released: AtomicU64,
}
/// Outcome of a non-blocking lock request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LockResult {
/// The lock was granted (or upgraded) by this call.
Acquired,
/// The transaction already held the lock; nothing changed.
AlreadyHeld,
/// The lock is held in an incompatible way and was not granted.
WouldBlock,
/// NOTE(review): not returned by any code visible in this file — confirm intended use.
Conflict,
}
/// Two-level lock manager: table-level locks plus one row lock table per table.
pub struct LockManager {
/// Current lock entry of every locked table.
table_locks: DashMap<TableId, TableLockEntry>,
/// Row lock table of each table, created on first use by `get_row_lock_table`.
row_locks: DashMap<TableId, Arc<ShardedLockTable>>,
/// Counter advanced by `enter_epoch` and read by `current_epoch`.
epoch: AtomicU64,
/// Table-level counters.
stats: LockManagerStats,
}
/// `Default` builds the same empty manager as `LockManager::new`.
impl Default for LockManager {
fn default() -> Self {
Self::new()
}
}
impl LockManager {
/// Creates a manager with no locks, epoch 0 and zeroed counters.
pub fn new() -> Self {
    let table_locks = DashMap::new();
    let row_locks = DashMap::new();
    let epoch = AtomicU64::new(0);
    let stats = LockManagerStats::default();
    Self {
        table_locks,
        row_locks,
        epoch,
        stats,
    }
}
pub fn lock_table(&self, table_id: TableId, mode: IntentLock, txn_id: TxnId) -> LockResult {
use dashmap::mapref::entry::Entry;
match self.table_locks.entry(table_id) {
Entry::Vacant(vacant) => {
vacant.insert(TableLockEntry::new(mode, txn_id));
self.stats
.table_locks_acquired
.fetch_add(1, Ordering::Relaxed);
LockResult::Acquired
}
Entry::Occupied(mut occupied) => {
let entry = occupied.get_mut();
if entry.holders.contains(&txn_id) {
return LockResult::AlreadyHeld;
}
if entry.mode.is_compatible(&mode) {
entry.holders.push(txn_id);
self.stats
.table_locks_acquired
.fetch_add(1, Ordering::Relaxed);
return LockResult::Acquired;
}
self.stats.table_conflicts.fetch_add(1, Ordering::Relaxed);
LockResult::WouldBlock
}
}
}
pub fn unlock_table(&self, table_id: TableId, txn_id: TxnId) -> bool {
if let Some(mut entry) = self.table_locks.get_mut(&table_id)
&& let Some(pos) = entry.holders.iter().position(|&id| id == txn_id)
{
entry.holders.remove(pos);
self.stats
.table_locks_released
.fetch_add(1, Ordering::Relaxed);
if entry.holders.is_empty() {
drop(entry);
self.table_locks.remove(&table_id);
}
return true;
}
false
}
/// Returns the row lock table of `table_id`, creating it on first use.
///
/// The common case (table already known) only takes a read guard on the
/// map; the entry API, which needs exclusive access, is used only when the
/// table has to be created. If two threads race to create it, the entry
/// API ensures both end up with the same table.
fn get_row_lock_table(&self, table_id: TableId) -> Arc<ShardedLockTable> {
    if let Some(existing) = self.row_locks.get(&table_id) {
        return Arc::clone(existing.value());
    }
    let created = self
        .row_locks
        .entry(table_id)
        .or_insert_with(|| Arc::new(ShardedLockTable::new()));
    Arc::clone(created.value())
}
/// Tries to lock one row, taking the matching intent lock on its table
/// first.
///
/// A shared row lock requests `IntentShared` on the table, an exclusive one
/// `IntentExclusive`. If the table lock is refused, that result is returned
/// and the row is not touched. Otherwise the result of the row lock attempt
/// is returned.
///
/// When the row lock is refused, the table intent lock obtained by this
/// call stays held; it is released by `unlock_table` or `release_all`.
pub fn lock_row(
    &self,
    table_id: TableId,
    row_id: RowId,
    mode: LockMode,
    txn_id: TxnId,
) -> LockResult {
    let table_mode = if mode == LockMode::Shared {
        IntentLock::IntentShared
    } else {
        IntentLock::IntentExclusive
    };

    let table_result = self.lock_table(table_id, table_mode, txn_id);
    let table_locked = matches!(
        table_result,
        LockResult::Acquired | LockResult::AlreadyHeld
    );
    if !table_locked {
        return table_result;
    }

    self.get_row_lock_table(table_id)
        .try_lock(row_id, mode, txn_id)
}
/// Releases the lock `txn_id` holds on a row of `table_id`.
///
/// Returns `false` when the table has no row lock table yet or the row is
/// not held by `txn_id`. The table-level intent lock is left in place.
pub fn unlock_row(&self, table_id: TableId, row_id: RowId, txn_id: TxnId) -> bool {
    match self.row_locks.get(&table_id) {
        Some(rows) => rows.unlock(row_id, txn_id),
        None => false,
    }
}
/// Releases every row lock and every table lock held by `txn_id`.
///
/// Row locks are released first, table locks afterwards. The return value
/// is the number of row locks plus the number of table locks released.
pub fn release_all(&self, txn_id: TxnId) -> usize {
    let rows_released: usize = self
        .row_locks
        .iter()
        .map(|table| table.value().unlock_all(txn_id))
        .sum();

    // Collect the ids before unlocking: `unlock_table` needs its own access
    // to the map and must not run while the iterator is alive.
    let held_tables: Vec<TableId> = self
        .table_locks
        .iter()
        .filter_map(|table| {
            table
                .value()
                .holders
                .contains(&txn_id)
                .then(|| *table.key())
        })
        .collect();

    let tables_released = held_tables
        .into_iter()
        .filter(|&table_id| self.unlock_table(table_id, txn_id))
        .count();

    rows_released + tables_released
}
/// Atomically increments the epoch counter and returns the value it held
/// before the increment.
pub fn enter_epoch(&self) -> u64 {
self.epoch.fetch_add(1, Ordering::AcqRel)
}
/// Returns the current value of the epoch counter without changing it.
pub fn current_epoch(&self) -> u64 {
self.epoch.load(Ordering::Acquire)
}
/// Returns the table-level counters maintained by this manager.
pub fn stats(&self) -> &LockManagerStats {
&self.stats
}
}
/// Counters kept by a `LockManager` for table-level locks; all start at zero.
#[derive(Debug, Default)]
pub struct LockManagerStats {
/// Table locks granted to a new holder.
pub table_locks_acquired: AtomicU64,
/// Table locks released.
pub table_locks_released: AtomicU64,
/// Table lock requests refused with `WouldBlock`.
pub table_conflicts: AtomicU64,
}
/// Version counter for optimistic (validate-after-read) access.
///
/// An even value means the guarded data is stable; an odd value means a
/// writer currently holds a [`WriteGuard`]. Readers call `read_version`,
/// read the data, then call `validate` with the version they started from.
pub struct OptimisticVersion {
    version: AtomicU64,
}

impl Default for OptimisticVersion {
    fn default() -> Self {
        Self::new()
    }
}

impl OptimisticVersion {
    /// Creates a counter at version 0 (stable).
    pub fn new() -> Self {
        Self {
            version: AtomicU64::new(0),
        }
    }

    /// Returns the current version.
    #[inline]
    pub fn read_version(&self) -> u64 {
        self.version.load(Ordering::Acquire)
    }

    /// Returns `true` when `version` is even, i.e. no write was in progress
    /// when it was read.
    #[inline]
    pub fn is_stable(&self, version: u64) -> bool {
        version & 1 == 0
    }

    /// Returns `true` when data read since `read_version` was obtained can
    /// be trusted: `read_version` was stable and the counter has not moved.
    ///
    /// An odd `read_version` is always rejected, because a write was already
    /// in progress when the read started.
    #[inline]
    pub fn validate(&self, read_version: u64) -> bool {
        if !self.is_stable(read_version) {
            return false;
        }
        std::sync::atomic::fence(Ordering::Acquire);
        self.version.load(Ordering::Relaxed) == read_version
    }

    /// Tries to start a write.
    ///
    /// Returns `None` when another write is in progress or the counter
    /// changed concurrently; otherwise makes the version odd and returns
    /// the guard that ends the write.
    pub fn try_write_begin(&self) -> Option<WriteGuard<'_>> {
        let current = self.version.load(Ordering::Acquire);
        if !self.is_stable(current) {
            return None;
        }
        self.version
            .compare_exchange(
                current,
                current.wrapping_add(1),
                Ordering::AcqRel,
                Ordering::Relaxed,
            )
            .ok()
            .map(|_| WriteGuard {
                version: &self.version,
                start_version: current,
            })
    }
}

/// Marks a write in progress on an [`OptimisticVersion`].
///
/// However the guard ends (`commit`, `abort` or plain drop), the counter
/// moves on to the next even value. It is never rolled back to the value it
/// had before the write. Otherwise a reader that overlapped an abandoned
/// write would see its start version again and pass validation.
pub struct WriteGuard<'a> {
    version: &'a AtomicU64,
    start_version: u64,
}

impl<'a> WriteGuard<'a> {
    /// Ends the write and publishes the next stable version.
    pub fn commit(self) {
        drop(self);
    }

    /// Ends an abandoned write.
    ///
    /// The version still advances so that overlapping readers retry;
    /// undoing any partial changes to the data is up to the caller.
    pub fn abort(self) {
        drop(self);
    }
}

impl<'a> Drop for WriteGuard<'a> {
    fn drop(&mut self) {
        // start_version is even, so +2 is the next stable value.
        self.version
            .store(self.start_version.wrapping_add(2), Ordering::Release);
    }
}
/// Keeps the epoch it was created in pinned until it is dropped.
pub struct EpochGuard {
    manager: Arc<EpochManager>,
    epoch: u64,
}

impl Drop for EpochGuard {
    fn drop(&mut self) {
        self.manager.leave_epoch(self.epoch);
    }
}

/// Epoch-based deferred destruction.
///
/// Readers `pin` the current epoch, `retire` stores values tagged with the
/// epoch they were retired in, and `advance` moves the global epoch forward
/// and drops retired values that no pinned reader can still be using.
/// Pin counts are kept in four slots indexed by `epoch % 4`.
pub struct EpochManager {
    global_epoch: AtomicU64,
    epoch_counts: [AtomicUsize; 4],
    retired: Mutex<Vec<(u64, Box<dyn Send>)>>,
}

impl Default for EpochManager {
    fn default() -> Self {
        Self::new()
    }
}

impl EpochManager {
    /// Creates a manager at epoch 0 with no pins and nothing retired.
    pub fn new() -> Self {
        Self {
            global_epoch: AtomicU64::new(0),
            epoch_counts: std::array::from_fn(|_| AtomicUsize::new(0)),
            retired: Mutex::new(Vec::new()),
        }
    }

    /// Pins the current epoch until the returned guard is dropped.
    ///
    /// The epoch is re-read after the pin count has been incremented. If it
    /// moved in between, the pin is undone and retried, so a guard is never
    /// registered in the slot of an epoch that is no longer current.
    pub fn pin(self: &Arc<Self>) -> EpochGuard {
        loop {
            let epoch = self.global_epoch.load(Ordering::SeqCst);
            let slot = &self.epoch_counts[(epoch % 4) as usize];
            slot.fetch_add(1, Ordering::SeqCst);
            if self.global_epoch.load(Ordering::SeqCst) == epoch {
                return EpochGuard {
                    manager: Arc::clone(self),
                    epoch,
                };
            }
            slot.fetch_sub(1, Ordering::SeqCst);
        }
    }

    /// Removes one pin from `epoch`'s slot; called when a guard is dropped.
    fn leave_epoch(&self, epoch: u64) {
        self.epoch_counts[(epoch % 4) as usize].fetch_sub(1, Ordering::SeqCst);
    }

    /// Tries to move the global epoch forward by one.
    ///
    /// Nothing happens while a reader is still pinned two epochs back, or
    /// when another thread advanced the epoch concurrently. After a
    /// successful advance, values retired two or more epochs before the old
    /// epoch are dropped. During the first two epochs no such epoch exists,
    /// so nothing is reclaimed then.
    pub fn advance(&self) {
        let current = self.global_epoch.load(Ordering::SeqCst);
        // `current + 2` and `current - 2` share a slot modulo 4.
        let old_slot = (current.wrapping_add(2) % 4) as usize;
        if self.epoch_counts[old_slot].load(Ordering::SeqCst) != 0 {
            return;
        }
        // Compare-and-swap so that racing callers cannot advance twice on
        // the strength of a single slot check.
        let advanced = self
            .global_epoch
            .compare_exchange(
                current,
                current.wrapping_add(1),
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .is_ok();
        if !advanced {
            return;
        }
        if let Some(safe_epoch) = current.checked_sub(2) {
            self.reclaim(safe_epoch);
        }
    }

    /// Hands `item` over for deferred destruction, tagged with the current
    /// epoch. It is dropped by a later `advance`.
    pub fn retire<T: Send + 'static>(&self, item: T) {
        let epoch = self.global_epoch.load(Ordering::SeqCst);
        let mut retired = self.retired.lock();
        retired.push((epoch, Box::new(item)));
    }

    /// Drops every retired value tagged with an epoch `<= safe_epoch`.
    fn reclaim(&self, safe_epoch: u64) {
        let mut retired = self.retired.lock();
        retired.retain(|(epoch, _)| *epoch > safe_epoch);
    }
}
// Unit tests for the lock tables, the lock manager, the optimistic version
// counter and the epoch manager.
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
// Checks all 16 ordered pairs of table lock modes.
#[test]
fn test_intent_lock_compatibility() {
use IntentLock::*;
assert!(IntentShared.is_compatible(&IntentShared));
assert!(IntentShared.is_compatible(&IntentExclusive));
assert!(IntentShared.is_compatible(&Shared));
assert!(!IntentShared.is_compatible(&Exclusive));
assert!(IntentExclusive.is_compatible(&IntentShared));
assert!(IntentExclusive.is_compatible(&IntentExclusive));
assert!(!IntentExclusive.is_compatible(&Shared));
assert!(!IntentExclusive.is_compatible(&Exclusive));
assert!(Shared.is_compatible(&IntentShared));
assert!(!Shared.is_compatible(&IntentExclusive));
assert!(Shared.is_compatible(&Shared));
assert!(!Shared.is_compatible(&Exclusive));
assert!(!Exclusive.is_compatible(&IntentShared));
assert!(!Exclusive.is_compatible(&IntentExclusive));
assert!(!Exclusive.is_compatible(&Shared));
assert!(!Exclusive.is_compatible(&Exclusive));
}
// An exclusive row lock blocks other transactions until it is released;
// other rows stay lockable.
#[test]
fn test_sharded_lock_table_basic() {
let table = ShardedLockTable::new();
assert_eq!(
table.try_lock(1, LockMode::Exclusive, 100),
LockResult::Acquired
);
assert_eq!(
table.try_lock(1, LockMode::Exclusive, 200),
LockResult::WouldBlock
);
assert_eq!(
table.try_lock(1, LockMode::Shared, 200),
LockResult::WouldBlock
);
assert_eq!(
table.try_lock(2, LockMode::Exclusive, 200),
LockResult::Acquired
);
assert!(table.unlock(1, 100));
assert_eq!(
table.try_lock(1, LockMode::Exclusive, 200),
LockResult::Acquired
);
}
// Several transactions can share a row; an exclusive request succeeds only
// after every shared holder has released.
#[test]
fn test_sharded_lock_table_shared() {
let table = ShardedLockTable::new();
assert_eq!(
table.try_lock(1, LockMode::Shared, 100),
LockResult::Acquired
);
assert_eq!(
table.try_lock(1, LockMode::Shared, 200),
LockResult::Acquired
);
assert_eq!(
table.try_lock(1, LockMode::Shared, 300),
LockResult::Acquired
);
assert_eq!(
table.try_lock(1, LockMode::Exclusive, 400),
LockResult::WouldBlock
);
assert!(table.unlock(1, 100));
assert!(table.unlock(1, 200));
assert!(table.unlock(1, 300));
assert_eq!(
table.try_lock(1, LockMode::Exclusive, 400),
LockResult::Acquired
);
}
// A sole shared holder can upgrade to exclusive, which then blocks others.
#[test]
fn test_sharded_lock_upgrade() {
let table = ShardedLockTable::new();
assert_eq!(
table.try_lock(1, LockMode::Shared, 100),
LockResult::Acquired
);
assert_eq!(
table.try_lock(1, LockMode::Exclusive, 100),
LockResult::Acquired
);
assert_eq!(
table.try_lock(1, LockMode::Shared, 200),
LockResult::WouldBlock
);
}
// Row locks taken through the manager: different rows of one table can be
// locked by different transactions, the same row cannot.
#[test]
fn test_lock_manager_hierarchical() {
let manager = LockManager::new();
assert_eq!(
manager.lock_row(1, 100, LockMode::Exclusive, 1000),
LockResult::Acquired
);
assert_eq!(
manager.lock_row(1, 200, LockMode::Shared, 2000),
LockResult::Acquired
);
assert_eq!(
manager.lock_row(1, 100, LockMode::Exclusive, 2000),
LockResult::WouldBlock
);
let released = manager.release_all(1000);
assert!(released >= 1);
}
// The version is odd while a write guard is alive and becomes 2 after the
// first commit.
#[test]
fn test_optimistic_version() {
let version = OptimisticVersion::new();
let v = version.read_version();
assert!(version.is_stable(v));
assert!(version.validate(v));
{
let guard = version.try_write_begin().unwrap();
let v_during = version.read_version();
assert!(!version.is_stable(v_during)); guard.commit();
}
let v2 = version.read_version();
assert!(version.is_stable(v2));
assert_eq!(v2, 2);
}
// A second write cannot begin while the first guard is alive, but can after
// the first commit.
#[test]
fn test_optimistic_concurrent() {
let version = Arc::new(OptimisticVersion::new());
let guard = version.try_write_begin().unwrap();
let version2 = version.clone();
let result = version2.try_write_begin();
assert!(result.is_none());
guard.commit();
let guard2 = version.try_write_begin().unwrap();
guard2.commit();
}
// Smoke test: pin, retire and advance with and without live guards.
#[test]
fn test_epoch_manager() {
let manager = Arc::new(EpochManager::new());
let guard1 = manager.pin();
assert_eq!(guard1.epoch, 0);
manager.retire(vec![1, 2, 3]);
manager.advance();
let guard2 = manager.pin();
assert!(guard2.epoch >= guard1.epoch);
drop(guard1);
drop(guard2);
manager.advance();
manager.advance();
}
// 1000 consecutive row ids must land in more than 100 distinct shards.
#[test]
fn test_sharded_distribution() {
let table = ShardedLockTable::new();
for i in 0..1000u128 {
assert_eq!(table.try_lock(i, LockMode::Shared, 1), LockResult::Acquired);
}
let mut non_empty_shards = 0;
for shard in &table.shards {
if !shard.lock().is_empty() {
non_empty_shards += 1;
}
}
assert!(
non_empty_shards > 100,
"Expected better distribution: {} shards used",
non_empty_shards
);
}
// `unlock_all` releases only the given transaction's rows.
#[test]
fn test_unlock_all() {
let table = ShardedLockTable::new();
for i in 0..50u128 {
table.try_lock(i, LockMode::Exclusive, 100);
}
for i in 50..100u128 {
table.try_lock(i, LockMode::Exclusive, 200);
}
let released = table.unlock_all(100);
assert_eq!(released, 50);
assert_eq!(
table.try_lock(50, LockMode::Exclusive, 300),
LockResult::WouldBlock
);
assert_eq!(
table.try_lock(0, LockMode::Exclusive, 300),
LockResult::Acquired
);
}
// 16 threads each lock 100 rows of their own; every lock must be granted
// and counted.
#[test]
fn test_concurrent_locks() {
let table = Arc::new(ShardedLockTable::new());
let mut handles = vec![];
for txn_id in 0..16u64 {
let table = table.clone();
handles.push(thread::spawn(move || {
let start = txn_id as u128 * 100;
for i in 0..100 {
let result = table.try_lock(start + i, LockMode::Exclusive, txn_id);
assert_eq!(result, LockResult::Acquired);
}
}));
}
for handle in handles {
handle.join().unwrap();
}
assert_eq!(
table.stats().exclusive_acquired.load(Ordering::Relaxed),
1600
);
}
}