mod map;
mod row_lock;
use std::cell::Cell;
use std::fmt::Debug;
use std::future::Future;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};
use futures::task::AtomicWaker;
use parking_lot::Mutex;
pub use map::LockMap;
pub use row_lock::{FullRowLock, RowLock};
/// Guard that ties a held [`Lock`] to the [`LockMap`] entry of one primary key.
///
/// When the guard is dropped the lock is unlocked and the map is asked, via
/// `remove_with_lock_check`, to remove the entry stored under `primary_key`.
pub struct LockGuard<LockType, PrimaryKey>
where
    LockType: RowLock,
    PrimaryKey: Hash + Eq + Debug + Clone,
{
    /// The lock released when this guard goes away.
    lock: Arc<Lock>,
    /// Map that is asked to drop the entry for `primary_key` on release.
    lock_map: Arc<LockMap<LockType, PrimaryKey>>,
    /// Key under which the lock is tracked in `lock_map`.
    primary_key: PrimaryKey,
    /// `Cell<()>` is `Send` but not `Sync`, so this zero-sized marker makes the
    /// guard `!Sync` while leaving its `Send`-ness to the other fields.
    _not_sync: PhantomData<Cell<()>>,
}
impl<LockType, PrimaryKey> LockGuard<LockType, PrimaryKey>
where
    LockType: RowLock,
    PrimaryKey: Hash + Eq + Debug + Clone,
{
    /// Creates a guard for `lock`, which is tracked in `lock_map` under `primary_key`.
    ///
    /// The constructor does not change the state of `lock`; it only records what
    /// has to be released when the guard is dropped.
    pub fn new(
        lock: Arc<Lock>,
        lock_map: Arc<LockMap<LockType, PrimaryKey>>,
        primary_key: PrimaryKey,
    ) -> Self {
        Self {
            lock,
            lock_map,
            primary_key,
            _not_sync: PhantomData,
        }
    }

    /// Releases the lock right away by consuming the guard.
    ///
    /// The release itself lives in the `Drop` implementation. Performing it here
    /// as well would run it twice (once explicitly and once when `self` is dropped
    /// at the end of this method), waking waiters twice and issuing a second
    /// `remove_with_lock_check` for the same key.
    pub fn unlock(self) {
        drop(self);
    }
}
impl<LockType, PrimaryKey> Drop for LockGuard<LockType, PrimaryKey>
where
    LockType: RowLock,
    PrimaryKey: Hash + Eq + Debug + Clone,
{
    /// Unlocks the guarded lock, then asks the map to remove the entry for the
    /// guard's primary key.
    fn drop(&mut self) {
        // Borrow the fields immutably; nothing is moved out of the guard.
        let Self {
            lock,
            lock_map,
            primary_key,
            ..
        } = &*self;

        // Order matters: unlock first so the map's lock check sees the released state.
        lock.unlock();
        lock_map.remove_with_lock_check(primary_key);
    }
}
/// Lock state: an identifier, a "locked" flag and the wakers of pending waiters.
///
/// Equality and hashing of a `Lock` consider only its `id`.
#[derive(Debug)]
pub struct Lock {
    /// Identifier of the lock; the only field used by `PartialEq` and `Hash`.
    id: u16,
    /// `true` while the lock is held. Kept behind an `Arc` so that `LockWait`
    /// futures can hold their own handle to the flag.
    locked: Arc<AtomicBool>,
    /// One waker per `wait()` call; every one of them is woken by `unlock()`.
    wakers: Mutex<Vec<Arc<AtomicWaker>>>,
}
/// Two locks compare equal when their ids match; the locked flag and the
/// registered wakers are ignored.
impl PartialEq for Lock {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id
    }
}
// Equality is decided by the `u16` id alone (see `PartialEq`), so it is a full equivalence relation.
impl Eq for Lock {}
/// Hashes only the id, keeping `Hash` consistent with the id-based `PartialEq`.
impl Hash for Lock {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.id.hash(state);
    }
}
impl Drop for Lock {
fn drop(&mut self) {
self.unlock()
}
}
impl Lock {
    /// Creates a lock with the given `id`.
    ///
    /// The lock starts out in the locked state and with no registered waiters.
    pub fn new(id: u16) -> Self {
        Self {
            id,
            locked: Arc::new(AtomicBool::new(true)),
            wakers: Mutex::new(Vec::new()),
        }
    }

    /// Returns the identifier this lock was created with.
    pub fn id(&self) -> u16 {
        self.id
    }

    /// Marks the lock as released and wakes every waiter registered via [`Lock::wait`].
    ///
    /// Registered wakers are not removed, so calling this again wakes them again.
    pub fn unlock(&self) {
        // `Release` pairs with the `Acquire` loads in `LockWait::poll` and
        // `is_locked`: whoever observes `false` also observes the writes made
        // before this store. A `Relaxed` store would give no such guarantee.
        self.locked.store(false, Ordering::Release);

        // The flag is cleared before waking so a woken waiter that re-checks it
        // sees the released state.
        let wakers = self.wakers.lock();
        for waker in wakers.iter() {
            waker.wake();
        }
    }

    /// Marks the lock as held.
    ///
    /// This is an unconditional store: it neither checks the current state nor
    /// waits for the lock to become free.
    pub fn lock(&self) {
        self.locked.store(true, Ordering::Relaxed);
    }

    /// Returns `true` while the lock is marked as held.
    pub fn is_locked(&self) -> bool {
        // `Acquire` so that a caller seeing `false` is synchronized with the
        // `Release` store performed by `unlock`.
        self.locked.load(Ordering::Acquire)
    }

    /// Registers a new waiter and returns a future that resolves once the lock is
    /// observed as unlocked.
    pub fn wait(&self) -> LockWait {
        let waker = Arc::new(AtomicWaker::new());
        // Keep one handle in the lock (woken by `unlock`) and hand the other to the future.
        self.wakers.lock().push(Arc::clone(&waker));
        LockWait {
            locked: Arc::clone(&self.locked),
            waker,
        }
    }
}
/// Future returned by `Lock::wait`; it completes once the shared flag reads as unlocked.
#[derive(Debug)]
pub struct LockWait {
    /// Handle to the originating lock's "locked" flag.
    locked: Arc<AtomicBool>,
    /// Waker slot shared with the lock; the lock wakes it in `unlock()`.
    waker: Arc<AtomicWaker>,
}
impl Future for LockWait {
    type Output = ();

    /// Resolves as soon as the flag reads `false`; otherwise registers the task's
    /// waker and stays pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &*self;
        let is_held = || this.locked.load(Ordering::Acquire);

        if is_held() {
            this.waker.register(cx.waker());
            // Check again after registering: an unlock that happened between the
            // first check and the registration would otherwise be missed.
            if is_held() {
                return Poll::Pending;
            }
        }
        Poll::Ready(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::panic::AssertUnwindSafe;

    /// Lock map flavour used by the tests in this module.
    type TestLockMap = LockMap<FullRowLock, u64>;

    /// Builds a fresh, default-constructed lock map.
    fn empty_lock_map() -> Arc<TestLockMap> {
        Arc::new(LockMap::default())
    }

    /// A guard going out of scope must release its lock.
    #[test]
    fn test_unlock_on_drop() {
        let lock_map = empty_lock_map();
        let lock = Arc::new(Lock::new(1));
        let pk = 1u64;
        assert!(lock.is_locked());

        {
            let _guard = LockGuard::<FullRowLock, u64>::new(lock.clone(), lock_map.clone(), pk);
            assert!(lock.is_locked());
        }

        assert!(!lock.is_locked());
    }

    /// `LockGuard::unlock` must release the lock.
    #[test]
    fn test_explicit_unlock() {
        let lock_map = empty_lock_map();
        let lock = Arc::new(Lock::new(1));
        let pk = 1u64;
        assert!(lock.is_locked());

        let guard = LockGuard::<FullRowLock, u64>::new(lock.clone(), lock_map.clone(), pk);
        assert!(lock.is_locked());

        guard.unlock();
        assert!(!lock.is_locked());
    }

    /// Unwinding through a scope that owns a guard must still release the lock.
    #[test]
    fn test_panic_releases_lock() {
        let lock_map = empty_lock_map();
        let lock = Arc::new(Lock::new(1));
        let pk = 1u64;
        assert!(lock.is_locked());

        let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| {
            let _guard = LockGuard::<FullRowLock, u64>::new(lock.clone(), lock_map.clone(), pk);
            panic!("test panic");
        }));

        assert!(outcome.is_err());
        assert!(!lock.is_locked());
    }

    /// Several guards over distinct locks can coexist and all release on drop.
    #[test]
    fn test_multiple_guards_can_be_held() {
        let locks: Vec<Arc<Lock>> = (1u16..=3).map(|id| Arc::new(Lock::new(id))).collect();
        let lock_map = empty_lock_map();
        assert!(locks.iter().all(|lock| lock.is_locked()));

        {
            // Lock with id `n` is guarded under primary key `n`.
            let _guards: Vec<LockGuard<FullRowLock, u64>> = locks
                .iter()
                .zip(1u64..)
                .map(|(lock, pk)| LockGuard::new(lock.clone(), lock_map.clone(), pk))
                .collect();
            assert!(locks.iter().all(|lock| lock.is_locked()));
        }

        assert!(locks.iter().all(|lock| !lock.is_locked()));
    }

    /// Compile-time check that the guard can be moved across threads.
    #[test]
    fn test_guard_is_send() {
        fn requires_send<T>()
        where
            T: Send,
        {
        }
        requires_send::<LockGuard<FullRowLock, u64>>();
    }

    /// Dropping the guard must also remove the key's entry from the lock map.
    #[tokio::test]
    async fn test_lock_cleanup_on_guard_drop() {
        use crate::lock::{FullRowLock, RowLock};

        let lock_map: Arc<LockMap<FullRowLock, u64>> = Arc::new(LockMap::default());
        let pk = 42u64;

        let (lock_type, lock) = FullRowLock::with_lock(lock_map.next_id());
        lock_map.insert(pk, Arc::new(tokio::sync::RwLock::new(lock_type)));
        assert!(lock_map.get(&pk).is_some());

        // Create the guard and drop it immediately.
        drop(LockGuard::new(lock, lock_map.clone(), pk));

        assert!(lock_map.get(&pk).is_none());
    }
}