use std::{
any::type_name,
cell::UnsafeCell,
collections::{HashMap, HashSet, VecDeque},
fmt::{self, Debug, Display},
future::Future,
ops::{Deref, DerefMut},
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex, MutexGuard,
},
task::{Context, Poll, Waker},
};
use futures::executor::block_on;
use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
use crate::engine::{
structures::Table,
types::{AsUuid, Name, PoolIdentifier},
};
/// Guard wrapper returned for shared (read) access; logs on release.
pub struct SharedGuard<G>(G);

impl<T, G> Deref for SharedGuard<G>
where
    T: ?Sized,
    G: Deref<Target = T>,
{
    type Target = T;

    fn deref(&self) -> &T {
        self.0.deref()
    }
}

impl<G> Drop for SharedGuard<G> {
    fn drop(&mut self) {
        trace!("Dropping shared lock on {}", type_name::<G>());
    }
}
/// Guard wrapper returned for exclusive (write) access; logs on release.
pub struct ExclusiveGuard<G>(G);

impl<T, G> Deref for ExclusiveGuard<G>
where
    T: ?Sized,
    G: Deref<Target = T>,
{
    type Target = T;

    fn deref(&self) -> &T {
        self.0.deref()
    }
}

impl<G> DerefMut for ExclusiveGuard<G>
where
    G: DerefMut,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.deref_mut()
    }
}

impl<G> Drop for ExclusiveGuard<G> {
    fn drop(&mut self) {
        trace!("Dropping exclusive lock on {}", type_name::<G>());
    }
}
/// Generic wrapper for a lockable value; concrete impls are provided for
/// `Arc<RwLock<T>>` so the lock can be shared across tasks and threads.
#[derive(Debug)]
pub struct Lockable<T>(T);

impl<T> Lockable<Arc<RwLock<T>>> {
    /// Wrap `t` in a reference-counted, async read-write lock.
    pub fn new_shared(t: T) -> Self {
        Lockable(Arc::new(RwLock::new(t)))
    }
}
impl<T> Lockable<Arc<RwLock<T>>>
where
    T: ?Sized,
{
    /// Acquire the lock for shared (read) access, waiting asynchronously.
    pub async fn read(&self) -> SharedGuard<OwnedRwLockReadGuard<T>> {
        trace!("Acquiring shared lock on {}", type_name::<Self>());
        let guard = SharedGuard(Arc::clone(&self.0).read_owned().await);
        trace!("Acquired shared lock on {}", type_name::<Self>());
        guard
    }

    /// Acquire the lock for shared (read) access, blocking the current thread.
    ///
    /// NOTE(review): this uses `block_on`, so it must not be called from
    /// inside an async task or it may stall the executor — confirm call sites.
    pub fn blocking_read(&self) -> SharedGuard<OwnedRwLockReadGuard<T>> {
        block_on(self.read())
    }

    /// Acquire the lock for exclusive (write) access, waiting asynchronously.
    pub async fn write(&self) -> ExclusiveGuard<OwnedRwLockWriteGuard<T>> {
        trace!("Acquiring exclusive lock on {}", type_name::<Self>());
        let guard = ExclusiveGuard(Arc::clone(&self.0).write_owned().await);
        trace!("Acquired exclusive lock on {}", type_name::<Self>());
        guard
    }

    /// Acquire the lock for exclusive (write) access, blocking the current
    /// thread. See the note on `blocking_read`.
    pub fn blocking_write(&self) -> ExclusiveGuard<OwnedRwLockWriteGuard<T>> {
        block_on(self.write())
    }
}
impl<T> Clone for Lockable<Arc<T>>
where
    T: ?Sized,
{
    // Hand-written (not derived) so that cloning does not require `T: Clone`;
    // only the `Arc` reference count is bumped, so all clones share state.
    fn clone(&self) -> Self {
        Lockable(Arc::clone(&self.0))
    }
}
/// Bookkeeping state for `AllOrSomeLock`; only ever accessed through the
/// `Mutex` in `AllOrSomeLock`.
#[derive(Debug)]
struct LockRecord<U, T> {
    // Number of outstanding whole-table read locks.
    all_read_locked: u64,
    // Whether the whole-table write lock is currently held.
    all_write_locked: bool,
    // Per-pool read-lock reference counts.
    read_locked: HashMap<U, u64>,
    // Pools currently write-locked.
    write_locked: HashSet<U>,
    // The protected table; accessed via raw pointer while the surrounding
    // mutex is held.
    inner: UnsafeCell<Table<U, T>>,
    // Tasks parked waiting for a lock, in wake-up order.
    waiting: VecDeque<Waiter<U>>,
    // Tasks that have been woken but not yet re-polled, keyed by task index.
    woken: HashMap<u64, WaitType<U>>,
    // Next unique index to assign to a lock request.
    next_idx: u64,
}
impl<U, T> LockRecord<U, T>
where
    U: AsUuid,
{
    /// If the task with index `idx` was previously woken, consume its woken
    /// entry; when `wait_type` is provided, assert the task was woken for the
    /// same kind of request it is now making.
    fn woken_or_new(&mut self, wait_type: Option<&WaitType<U>>, idx: u64) {
        // Single-lookup form of contains_key + remove.
        if let Some(woken) = self.woken.remove(&idx) {
            if let Some(w) = wait_type {
                assert_eq!(&woken, w);
            }
        }
    }

    /// Consume any woken record for `idx` (asserting it matches `wait_type`)
    /// and assert the request does not conflict with other woken tasks.
    fn assert(&mut self, wait_type: &WaitType<U>, idx: u64) {
        self.woken_or_new(Some(wait_type), idx);
        assert!(!self.conflicts_with_woken(wait_type));
    }

    /// Look up a pool by name or UUID, returning both identifiers if present.
    fn get_by_lock_key(&self, lock_key: &PoolIdentifier<U>) -> Option<(U, Name)> {
        // SAFETY: `inner` is only ever accessed while the mutex wrapping this
        // `LockRecord` is held, so no concurrent mutable access can exist.
        match lock_key {
            PoolIdentifier::Name(ref n) => unsafe { self.inner.get().as_ref() }
                .and_then(|i| i.get_by_name(n).map(|(u, _)| (u, n.clone()))),
            PoolIdentifier::Uuid(u) => unsafe { self.inner.get().as_ref() }
                .and_then(|i| i.get_by_uuid(*u).map(|(n, _)| (*u, n))),
        }
    }

    /// Record acquisition of a read lock on one pool, bumping its reader
    /// count. `idx` is `Some` when called from a polled future (woken
    /// bookkeeping must be checked) and `None` when converting an all-lock
    /// into individual guards.
    fn add_read_lock(&mut self, uuid: U, idx: Option<u64>) {
        // Entry API: one lookup instead of get_mut-then-insert.
        *self.read_locked.entry(uuid).or_insert(0) += 1;
        if let Some(i) = idx {
            self.assert(&WaitType::SomeRead(uuid), i);
        }
        trace!("Lock record after acquisition: {}", self);
    }

    /// Release one read lock on a pool, dropping its entry when the reader
    /// count reaches zero. Panics if no read lock was recorded.
    fn remove_read_lock(&mut self, uuid: U) {
        // Decrement in place; remove only when the count hits zero, instead of
        // removing and conditionally reinserting.
        match self.read_locked.get_mut(&uuid) {
            Some(counter) => {
                *counter -= 1;
                if *counter == 0 {
                    self.read_locked.remove(&uuid);
                }
            }
            None => panic!("Must have acquired lock and incremented lock count"),
        }
        trace!("Lock record after removal: {}", self);
    }

    /// Record acquisition of a write lock on one pool.
    fn add_write_lock(&mut self, uuid: U, idx: Option<u64>) {
        self.write_locked.insert(uuid);
        if let Some(i) = idx {
            self.assert(&WaitType::SomeWrite(uuid), i);
        }
        trace!("Lock record after acquisition: {}", self);
    }

    /// Release the write lock on one pool; panics if it was not held.
    fn remove_write_lock(&mut self, uuid: &U) {
        assert!(self.write_locked.remove(uuid));
        trace!("Lock record after removal: {}", self);
    }

    /// Record acquisition of a whole-table read lock.
    fn add_read_all_lock(&mut self, idx: u64) {
        self.all_read_locked += 1;
        self.assert(&WaitType::AllRead, idx);
        trace!("Lock record after acquisition: {}", self);
    }

    /// Release one whole-table read lock; panics if none was held.
    fn remove_read_all_lock(&mut self) {
        self.all_read_locked = self
            .all_read_locked
            .checked_sub(1)
            .expect("Cannot drop below 0");
        trace!("Lock record after removal: {}", self);
    }

    /// Record acquisition of the whole-table write lock.
    fn add_write_all_lock(&mut self, idx: u64) {
        self.all_write_locked = true;
        self.assert(&WaitType::AllWrite, idx);
        trace!("Lock record after acquisition: {}", self);
    }

    /// Release the whole-table write lock; panics if it was not held.
    fn remove_write_all_lock(&mut self) {
        assert!(self.all_write_locked);
        self.all_write_locked = false;
        trace!("Lock record after removal: {}", self);
    }

    /// Register a task as waiting for a lock.
    ///
    /// A task that has already waited once (`has_waited`) is requeued at the
    /// front so it is not starved by newer arrivals; a first-time waiter goes
    /// to the back.
    fn add_waiter(
        &mut self,
        has_waited: &AtomicBool,
        wait_type: WaitType<U>,
        waker: Waker,
        idx: u64,
    ) {
        // Already queued; nothing to do.
        if self.waiting.iter().any(|w| w.idx == idx) {
            return;
        }
        self.woken_or_new(Some(&wait_type), idx);
        if has_waited.load(Ordering::SeqCst) {
            self.waiting.push_front(Waiter {
                ty: wait_type,
                waker,
                idx,
            });
        } else {
            self.waiting.push_back(Waiter {
                ty: wait_type,
                waker,
                idx,
            });
            has_waited.store(true, Ordering::SeqCst);
        }
        trace!("Lock record after sleep: {}", self);
    }

    /// Determine whether the task making request `ty` must go to sleep.
    ///
    /// A previously woken task never waits again; otherwise a task waits if
    /// any other task is already queued, if the requested lock is held
    /// incompatibly, or if the request conflicts with a woken task.
    fn should_wait(&self, ty: &WaitType<U>, idx: u64) -> bool {
        if self.woken.contains_key(&idx) {
            trace!(
                "Task with index {}, wait type {:?} was woken and can acquire lock",
                idx,
                ty
            );
            false
        } else {
            let should_wait = !self.waiting.is_empty()
                || self.already_acquired(ty)
                || self.conflicts_with_woken(ty);
            if should_wait {
                trace!(
                    "Putting task with index {}, wait type {:?} to sleep",
                    idx,
                    ty
                );
            } else {
                trace!(
                    "Task with index {}, wait type {:?} can acquire lock",
                    idx,
                    ty
                );
            }
            should_wait
        }
    }

    /// Return true if a request of type `ty` cannot run concurrently with an
    /// already-woken request of type `already_woken`.
    fn conflicts(already_woken: &WaitType<U>, ty: &WaitType<U>) -> bool {
        match (already_woken, ty) {
            // Read requests never conflict with other reads.
            (WaitType::SomeRead(_), WaitType::SomeRead(_) | WaitType::AllRead) => false,
            // A single-pool read conflicts with a write on the same pool.
            (WaitType::SomeRead(uuid1), WaitType::SomeWrite(uuid2)) => uuid1 == uuid2,
            (WaitType::SomeRead(_), _) => true,
            // A single-pool write conflicts with any access to the same pool.
            (
                WaitType::SomeWrite(uuid1),
                WaitType::SomeRead(uuid2) | WaitType::SomeWrite(uuid2),
            ) => uuid1 == uuid2,
            (WaitType::SomeWrite(_), _) => true,
            // A whole-table read conflicts with any write.
            (WaitType::AllRead, WaitType::SomeWrite(_) | WaitType::AllWrite) => true,
            (WaitType::AllRead, _) => false,
            // A whole-table write conflicts with everything.
            (WaitType::AllWrite, _) => true,
        }
    }

    /// Return true if `ty` conflicts with any woken, not-yet-run task.
    fn conflicts_with_woken(&self, ty: &WaitType<U>) -> bool {
        // `any` on an empty map is trivially false, so no guard is needed.
        self.woken.values().any(|woken| Self::conflicts(woken, ty))
    }

    /// Return true if the currently held locks make request `ty` ungrantable.
    fn already_acquired(&self, ty: &WaitType<U>) -> bool {
        match ty {
            WaitType::SomeRead(uuid) => self.write_locked.contains(uuid) || self.all_write_locked,
            WaitType::SomeWrite(uuid) => {
                self.read_locked.contains_key(uuid)
                    || self.write_locked.contains(uuid)
                    || self.all_read_locked > 0
                    || self.all_write_locked
            }
            WaitType::AllRead => !self.write_locked.is_empty() || self.all_write_locked,
            WaitType::AllWrite => {
                !self.read_locked.is_empty()
                    || !self.write_locked.is_empty()
                    || self.all_read_locked > 0
                    || self.all_write_locked
            }
        }
    }

    /// Return true if the waiter at the head of the queue could acquire its
    /// requested lock right now.
    fn should_wake(&self) -> bool {
        self.waiting.front().map_or(false, |w| {
            !self.conflicts_with_woken(&w.ty) && !self.already_acquired(&w.ty)
        })
    }

    /// Wake waiters from the front of the queue for as long as the head
    /// request is grantable, moving each woken task into the woken set.
    fn wake(&mut self) {
        while self.should_wake() {
            if let Some(w) = self.waiting.pop_front() {
                self.woken.insert(w.idx, w.ty);
                w.waker.wake();
            }
        }
    }

    /// Remove all traces of task `idx` from the waiting queue and woken set;
    /// called when the future driving the task is dropped (cancelled).
    fn cancel(&mut self, idx: u64) {
        // In-place retain instead of drain/filter/collect rebuilds.
        self.waiting.retain(|waiter| waiter.idx != idx);
        self.woken.retain(|i, _| i != &idx);
    }
}
impl<U, T> Display for LockRecord<U, T>
where
    U: AsUuid,
{
    // Display is used by the trace logging in the lock methods; it reuses the
    // debug-struct format but deliberately omits `inner` (the protected data).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LockRecord")
            .field("all_read_locked", &self.all_read_locked)
            .field("all_write_locked", &self.all_write_locked)
            .field("read_locked", &self.read_locked)
            .field("write_locked", &self.write_locked)
            .field("waiting", &self.waiting)
            .field("woken", &self.woken)
            .field("next_idx", &self.next_idx)
            .finish()
    }
}
/// The kind of lock acquisition a task is requesting.
#[derive(Debug, PartialEq)]
enum WaitType<U> {
    // Read lock on a single pool.
    SomeRead(U),
    // Write lock on a single pool.
    SomeWrite(U),
    // Read lock on the whole table.
    AllRead,
    // Write lock on the whole table.
    AllWrite,
}
/// A parked task waiting for a lock.
struct Waiter<U> {
    // What kind of lock the task requested.
    ty: WaitType<U>,
    // Handle used to reschedule the parked task.
    waker: Waker,
    // Unique index of the lock request.
    idx: u64,
}
impl<U> Debug for Waiter<U>
where
    U: Debug,
{
    // Hand-written so the `waker` field (not usefully printable) is omitted
    // from trace output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Waiter")
            .field("ty", &self.ty)
            .field("idx", &self.idx)
            .finish()
    }
}
/// A lock that can grant read or write access either to a single pool or to
/// the whole table of pools. The inner mutex is only held for short
/// bookkeeping operations; waiting is done by parking the async task.
#[derive(Debug)]
pub struct AllOrSomeLock<U, T> {
    lock_record: Arc<Mutex<LockRecord<U, T>>>,
}
impl<U, T> AllOrSomeLock<U, T>
where
    U: AsUuid,
{
    /// Create a new lock around the given table of pools.
    pub fn new(inner: Table<U, T>) -> Self {
        AllOrSomeLock {
            lock_record: Arc::new(Mutex::new(LockRecord {
                all_read_locked: 0,
                all_write_locked: false,
                read_locked: HashMap::new(),
                write_locked: HashSet::new(),
                inner: UnsafeCell::new(inner),
                waiting: VecDeque::new(),
                woken: HashMap::new(),
                next_idx: 0,
            })),
        }
    }

    /// Lock the internal record; the expect can only fire if a previous
    /// holder panicked (poisoned mutex).
    fn acquire_mutex(&self) -> MutexGuard<'_, LockRecord<U, T>> {
        self.lock_record
            .lock()
            .expect("lock record mutex only locked internally")
    }

    /// Hand out the next unique task index, wrapping on overflow.
    fn next_idx(&self) -> u64 {
        let mut lock_record = self.acquire_mutex();
        let idx = lock_record.next_idx;
        lock_record.next_idx = lock_record.next_idx.wrapping_add(1);
        idx
    }
}
impl<U, T> Clone for AllOrSomeLock<U, T> {
    // Hand-written so no bounds are required on `U` or `T`; all clones share
    // the same lock record via the `Arc`.
    fn clone(&self) -> Self {
        AllOrSomeLock {
            lock_record: Arc::clone(&self.lock_record),
        }
    }
}
impl<U, T> AllOrSomeLock<U, T>
where
    U: AsUuid + Unpin,
    T: Unpin,
{
    /// Request a read lock for the pool identified by `key`; resolves to
    /// `None` if no such pool exists.
    pub async fn read(&self, key: PoolIdentifier<U>) -> Option<SomeLockReadGuard<U, T>> {
        trace!("Acquiring read lock on pool {:?}", key);
        let idx = self.next_idx();
        match SomeRead(self.clone(), key, AtomicBool::new(false), idx).await {
            Some(guard) => {
                trace!("Read lock acquired");
                Some(guard)
            }
            None => {
                trace!("Pool not found");
                None
            }
        }
    }

    /// Request a read lock over the entire table of pools.
    pub async fn read_all(&self) -> AllLockReadGuard<U, T> {
        trace!("Acquiring read lock on all pools");
        let idx = self.next_idx();
        let guard = AllRead(self.clone(), AtomicBool::new(false), idx).await;
        trace!("All read lock acquired");
        guard
    }

    /// Request a write lock for the pool identified by `key`; resolves to
    /// `None` if no such pool exists.
    pub async fn write(&self, key: PoolIdentifier<U>) -> Option<SomeLockWriteGuard<U, T>> {
        trace!("Acquiring write lock on pool {:?}", key);
        let idx = self.next_idx();
        match SomeWrite(self.clone(), key, AtomicBool::new(false), idx).await {
            Some(guard) => {
                trace!("Write lock acquired");
                Some(guard)
            }
            None => {
                trace!("Pool not found");
                None
            }
        }
    }

    /// Request a write lock over the entire table of pools.
    pub async fn write_all(&self) -> AllLockWriteGuard<U, T> {
        trace!("Acquiring write lock on all pools");
        let idx = self.next_idx();
        let guard = AllWrite(self.clone(), AtomicBool::new(false), idx).await;
        trace!("All write lock acquired");
        guard
    }
}
impl<U, T> Default for AllOrSomeLock<U, T>
where
    U: AsUuid,
{
    /// Construct a lock around an empty table.
    fn default() -> Self {
        Self::new(Table::default())
    }
}
/// Future for acquiring a read lock on one pool.
///
/// Fields: lock handle, requested pool identifier, flag recording whether
/// this task has already been queued once (controls requeue position), and
/// the unique task index.
struct SomeRead<U: AsUuid, T>(AllOrSomeLock<U, T>, PoolIdentifier<U>, AtomicBool, u64);

impl<U, T> Future for SomeRead<U, T>
where
    U: AsUuid + Unpin,
    T: Unpin,
{
    type Output = Option<SomeLockReadGuard<U, T>>;

    fn poll(self: Pin<&mut Self>, cxt: &mut Context<'_>) -> Poll<Self::Output> {
        let mut lock_record = self.0.acquire_mutex();
        // Resolve the identifier; if the pool is gone, discard any woken
        // record for this task and let other waiters proceed.
        let (uuid, name) = if let Some((uuid, name)) = lock_record.get_by_lock_key(&self.1) {
            (uuid, name)
        } else {
            lock_record.woken_or_new(None, self.3);
            lock_record.wake();
            return Poll::Ready(None);
        };
        let wait_type = WaitType::SomeRead(uuid);
        let poll = if lock_record.should_wait(&wait_type, self.3) {
            // Park this task; its waker is invoked later by `LockRecord::wake`.
            lock_record.add_waiter(&self.2, wait_type, cxt.waker().clone(), self.3);
            Poll::Pending
        } else {
            lock_record.add_read_lock(uuid, Some(self.3));
            // SAFETY: the read-lock entry just recorded keeps this pool entry
            // alive and unaliased by writers until the guard is dropped.
            let (_, rf) = unsafe { lock_record.inner.get().as_ref() }
                .expect("cannot be null")
                .get_by_uuid(uuid)
                .expect("Checked above");
            Poll::Ready(Some(SomeLockReadGuard(
                self.0.clone(),
                uuid,
                name,
                rf as *const _,
            )))
        };
        poll
    }
}

impl<U, T> Drop for SomeRead<U, T>
where
    U: AsUuid,
{
    // Dropping the future (cancellation) must remove any queued or woken
    // state for this task, or later waiters could be blocked forever.
    fn drop(&mut self) {
        let mut lock_record = self.0.acquire_mutex();
        lock_record.cancel(self.3);
    }
}
/// Guard for a read lock on a single pool.
///
/// Holds the lock handle, the pool's UUID and name, and a raw pointer to the
/// pool entry; the pointer remains valid because the per-pool read count in
/// the lock record stays non-zero until this guard is dropped.
pub struct SomeLockReadGuard<U: AsUuid, T>(AllOrSomeLock<U, T>, U, Name, *const T);

impl<U, T> SomeLockReadGuard<U, T>
where
    U: AsUuid,
{
    /// Return the pool name, UUID and a shared reference to the pool data.
    pub fn as_tuple(&self) -> (Name, U, &T) {
        (
            self.2.clone(),
            self.1,
            unsafe { self.3.as_ref() }.expect("Cannot create null pointer from Rust references"),
        )
    }
}

// SAFETY: the pointer stays valid for the guard's lifetime because the pool
// remains read-locked until drop.
// NOTE(review): moving a read guard to another thread exposes `&T` there
// while other readers may exist on other threads, which normally also
// requires `T: Sync` — confirm this bound is intentional.
unsafe impl<U, T> Send for SomeLockReadGuard<U, T>
where
    U: AsUuid + Send,
    T: Send,
{
}

// SAFETY: the guard only hands out `&T`, which is safe to share across
// threads when `T: Sync`.
unsafe impl<U, T> Sync for SomeLockReadGuard<U, T>
where
    U: AsUuid + Sync,
    T: Sync,
{
}

impl<U, T> Deref for SomeLockReadGuard<U, T>
where
    U: AsUuid,
{
    type Target = T;

    fn deref(&self) -> &Self::Target {
        unsafe { self.3.as_ref() }.expect("Cannot create null pointer through references in Rust")
    }
}

impl<U, T> Drop for SomeLockReadGuard<U, T>
where
    U: AsUuid,
{
    // Decrement the pool's read count and wake any now-grantable waiters.
    fn drop(&mut self) {
        trace!("Dropping read lock on pool with UUID {}", self.1);
        let mut lock_record = self.0.acquire_mutex();
        lock_record.remove_read_lock(self.1);
        lock_record.wake();
        trace!("Read lock on pool with UUID {} dropped", self.1);
    }
}
/// Future for acquiring a write lock on one pool.
///
/// Fields mirror `SomeRead`: lock handle, requested pool identifier,
/// has-waited flag, and unique task index.
struct SomeWrite<U: AsUuid, T>(AllOrSomeLock<U, T>, PoolIdentifier<U>, AtomicBool, u64);

impl<U, T> Future for SomeWrite<U, T>
where
    U: AsUuid + Unpin,
    T: Unpin,
{
    type Output = Option<SomeLockWriteGuard<U, T>>;

    fn poll(self: Pin<&mut Self>, cxt: &mut Context<'_>) -> Poll<Self::Output> {
        let mut lock_record = self.0.acquire_mutex();
        // Resolve the identifier; if the pool is gone, discard any woken
        // record for this task and let other waiters proceed.
        let (uuid, name) = if let Some((uuid, name)) = lock_record.get_by_lock_key(&self.1) {
            (uuid, name)
        } else {
            lock_record.woken_or_new(None, self.3);
            lock_record.wake();
            return Poll::Ready(None);
        };
        let wait_type = WaitType::SomeWrite(uuid);
        let poll = if lock_record.should_wait(&wait_type, self.3) {
            // Park this task; its waker is invoked later by `LockRecord::wake`.
            lock_record.add_waiter(&self.2, wait_type, cxt.waker().clone(), self.3);
            Poll::Pending
        } else {
            lock_record.add_write_lock(uuid, Some(self.3));
            // SAFETY: the write-lock entry just recorded gives this task
            // exclusive access to the pool entry until the guard is dropped.
            let (_, rf) = unsafe { lock_record.inner.get().as_mut() }
                .expect("cannot be null")
                .get_mut_by_uuid(uuid)
                .expect("Checked above");
            Poll::Ready(Some(SomeLockWriteGuard(
                self.0.clone(),
                uuid,
                name,
                rf as *mut _,
            )))
        };
        poll
    }
}

impl<U, T> Drop for SomeWrite<U, T>
where
    U: AsUuid,
{
    // Cancellation: remove any queued or woken state for this task.
    fn drop(&mut self) {
        let mut lock_record = self.0.acquire_mutex();
        lock_record.cancel(self.3);
    }
}
/// Guard for a write lock on a single pool.
///
/// Holds the lock handle, the pool's UUID and name, and a raw pointer to the
/// pool entry; the pointer remains valid and exclusive because the pool stays
/// in the lock record's write-locked set until this guard is dropped.
pub struct SomeLockWriteGuard<U: AsUuid, T>(AllOrSomeLock<U, T>, U, Name, *mut T);

impl<U, T> SomeLockWriteGuard<U, T>
where
    U: AsUuid,
{
    /// Return the pool name, UUID and an exclusive reference to the pool data.
    pub fn as_mut_tuple(&mut self) -> (Name, U, &mut T) {
        (
            self.2.clone(),
            self.1,
            unsafe { self.3.as_mut() }.expect("Cannot create null pointer from Rust references"),
        )
    }
}

// SAFETY: the guard has exclusive access to the pool entry for its lifetime,
// so it may move between threads when `T: Send`.
unsafe impl<U, T> Send for SomeLockWriteGuard<U, T>
where
    U: AsUuid + Send,
    T: Send,
{
}

// SAFETY: shared access through the guard only yields `&T`.
unsafe impl<U, T> Sync for SomeLockWriteGuard<U, T>
where
    U: AsUuid + Sync,
    T: Sync,
{
}

impl<U, T> Deref for SomeLockWriteGuard<U, T>
where
    U: AsUuid,
{
    type Target = T;

    fn deref(&self) -> &Self::Target {
        unsafe { self.3.as_ref() }.expect("Cannot create null pointer through references in Rust")
    }
}

impl<U, T> DerefMut for SomeLockWriteGuard<U, T>
where
    U: AsUuid,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { self.3.as_mut() }.expect("Cannot create null pointer through references in Rust")
    }
}

impl<U, T> Drop for SomeLockWriteGuard<U, T>
where
    U: AsUuid,
{
    // Release the pool's write lock and wake any now-grantable waiters.
    fn drop(&mut self) {
        trace!("Dropping write lock on pool with UUID {}", self.1);
        let mut lock_record = self.0.acquire_mutex();
        lock_record.remove_write_lock(&self.1);
        lock_record.wake();
        trace!("Write lock on pool with UUID {} dropped", self.1);
    }
}
/// Future for acquiring a read lock over the whole table.
struct AllRead<U: AsUuid, T>(AllOrSomeLock<U, T>, AtomicBool, u64);

impl<U, T> Future for AllRead<U, T>
where
    U: AsUuid,
{
    type Output = AllLockReadGuard<U, T>;

    fn poll(self: Pin<&mut Self>, cxt: &mut Context<'_>) -> Poll<Self::Output> {
        let mut lock_record = self.0.acquire_mutex();
        let wait_type = WaitType::AllRead;
        let poll = if lock_record.should_wait(&wait_type, self.2) {
            // Park the task until `LockRecord::wake` reschedules it.
            lock_record.add_waiter(&self.1, wait_type, cxt.waker().clone(), self.2);
            Poll::Pending
        } else {
            lock_record.add_read_all_lock(self.2);
            Poll::Ready(AllLockReadGuard(
                self.0.clone(),
                lock_record.inner.get() as *const _,
            ))
        };
        poll
    }
}

impl<U, T> Drop for AllRead<U, T>
where
    U: AsUuid,
{
    // Cancellation: remove any queued or woken state for this task.
    fn drop(&mut self) {
        let mut lock_record = self.0.acquire_mutex();
        lock_record.cancel(self.2);
    }
}
/// Guard for a read lock over the whole table; the raw pointer stays valid
/// while the lock record's all-read count is non-zero.
pub struct AllLockReadGuard<U: AsUuid, T>(AllOrSomeLock<U, T>, *const Table<U, T>);
// `From` rather than a hand-written `Into` (clippy::from_over_into); callers
// using `.into()` keep working via the standard blanket impl.
impl<U, T> From<AllLockReadGuard<U, T>> for Vec<SomeLockReadGuard<U, T>>
where
    U: AsUuid,
{
    /// Convert a whole-table read guard into individual per-pool read guards
    /// without ever releasing read access.
    ///
    /// The per-pool read counts are registered while the lock-record mutex is
    /// held; the all-read count is decremented afterwards, when `guard` is
    /// dropped at the end of this function (after the mutex guard).
    #[allow(clippy::needless_collect)]
    fn from(guard: AllLockReadGuard<U, T>) -> Self {
        let mut lock_record = guard.0.acquire_mutex();
        // Holding the all-read lock excludes all writers.
        assert!(lock_record.write_locked.is_empty());
        assert!(!lock_record.all_write_locked);
        // First collect pointers while borrowing the table immutably, then
        // register the read locks with the record mutably.
        let guards = unsafe { lock_record.inner.get().as_ref() }
            .expect("cannot be null")
            .iter()
            .map(|(n, u, t)| {
                (
                    *u,
                    SomeLockReadGuard(guard.0.clone(), *u, n.clone(), t as *const _),
                )
            })
            .collect::<Vec<_>>();
        guards
            .into_iter()
            .map(|(u, g)| {
                lock_record.add_read_lock(u, None);
                g
            })
            .collect::<Vec<_>>()
    }
}
// SAFETY: the table pointer stays valid while the all-read count is non-zero.
// NOTE(review): as with `SomeLockReadGuard`, `Send` for a read guard arguably
// also requires `T: Sync` — confirm this bound is intentional.
unsafe impl<U, T> Send for AllLockReadGuard<U, T>
where
    U: AsUuid + Send,
    T: Send,
{
}

// SAFETY: the guard only provides `&Table<U, T>` access.
unsafe impl<U, T> Sync for AllLockReadGuard<U, T>
where
    U: AsUuid + Sync,
    T: Sync,
{
}

impl<U, T> Deref for AllLockReadGuard<U, T>
where
    U: AsUuid,
{
    type Target = Table<U, T>;

    fn deref(&self) -> &Self::Target {
        unsafe { self.1.as_ref() }.expect("Cannot create null pointer through references in Rust")
    }
}

impl<U, T> Drop for AllLockReadGuard<U, T>
where
    U: AsUuid,
{
    // Decrement the all-read count and wake any now-grantable waiters.
    fn drop(&mut self) {
        trace!("Dropping all read lock");
        let mut lock_record = self.0.acquire_mutex();
        lock_record.remove_read_all_lock();
        lock_record.wake();
        trace!("All read lock dropped");
    }
}
/// Future for acquiring a write lock over the whole table.
struct AllWrite<U: AsUuid, T>(AllOrSomeLock<U, T>, AtomicBool, u64);

impl<U, T> Future for AllWrite<U, T>
where
    U: AsUuid,
{
    type Output = AllLockWriteGuard<U, T>;

    fn poll(self: Pin<&mut Self>, cxt: &mut Context<'_>) -> Poll<Self::Output> {
        let mut lock_record = self.0.acquire_mutex();
        let wait_type = WaitType::AllWrite;
        let poll = if lock_record.should_wait(&wait_type, self.2) {
            // Park the task until `LockRecord::wake` reschedules it.
            lock_record.add_waiter(&self.1, wait_type, cxt.waker().clone(), self.2);
            Poll::Pending
        } else {
            lock_record.add_write_all_lock(self.2);
            Poll::Ready(AllLockWriteGuard(self.0.clone(), lock_record.inner.get()))
        };
        poll
    }
}

impl<U, T> Drop for AllWrite<U, T>
where
    U: AsUuid,
{
    // Cancellation: remove any queued or woken state for this task.
    fn drop(&mut self) {
        let mut lock_record = self.0.acquire_mutex();
        lock_record.cancel(self.2);
    }
}
/// Guard for a write lock over the whole table; the raw pointer stays valid
/// and exclusive while the lock record's all-write flag is set.
pub struct AllLockWriteGuard<U: AsUuid, T>(AllOrSomeLock<U, T>, *mut Table<U, T>);
// `From` rather than a hand-written `Into` (clippy::from_over_into); callers
// using `.into()` keep working via the standard blanket impl.
impl<U, T> From<AllLockWriteGuard<U, T>> for Vec<SomeLockWriteGuard<U, T>>
where
    U: AsUuid,
{
    /// Convert a whole-table write guard into individual per-pool write
    /// guards without ever releasing write access.
    ///
    /// The per-pool write locks are registered while the lock-record mutex is
    /// held; the all-write flag is cleared afterwards, when `guard` is
    /// dropped at the end of this function (after the mutex guard).
    #[allow(clippy::needless_collect)]
    fn from(guard: AllLockWriteGuard<U, T>) -> Self {
        let mut lock_record = guard.0.acquire_mutex();
        // Holding the all-write lock excludes every other reader and writer.
        assert!(lock_record.read_locked.is_empty());
        assert!(lock_record.write_locked.is_empty());
        assert_eq!(lock_record.all_read_locked, 0);
        // First collect pointers while borrowing the table, then register the
        // write locks with the record mutably.
        let guards = unsafe { lock_record.inner.get().as_mut() }
            .expect("cannot be null")
            .iter_mut()
            .map(|(n, u, t)| {
                (
                    *u,
                    SomeLockWriteGuard(guard.0.clone(), *u, n.clone(), t as *mut _),
                )
            })
            .collect::<Vec<_>>();
        guards
            .into_iter()
            .map(|(u, g)| {
                lock_record.add_write_lock(u, None);
                g
            })
            .collect::<Vec<_>>()
    }
}
// SAFETY: the guard has exclusive access to the table while the all-write
// flag is set, so it may move between threads when `T: Send`.
unsafe impl<U, T> Send for AllLockWriteGuard<U, T>
where
    U: AsUuid + Send,
    T: Send,
{
}

// SAFETY: shared access through the guard only yields `&Table<U, T>`.
unsafe impl<U, T> Sync for AllLockWriteGuard<U, T>
where
    U: AsUuid + Sync,
    T: Sync,
{
}

impl<U, T> Deref for AllLockWriteGuard<U, T>
where
    U: AsUuid,
{
    type Target = Table<U, T>;

    fn deref(&self) -> &Self::Target {
        unsafe { self.1.as_ref() }.expect("Cannot create null pointer through references in Rust")
    }
}

impl<U, T> DerefMut for AllLockWriteGuard<U, T>
where
    U: AsUuid,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { self.1.as_mut() }.expect("Cannot create null pointer through references in Rust")
    }
}

impl<U, T> Drop for AllLockWriteGuard<U, T>
where
    U: AsUuid,
{
    // Clear the all-write flag and wake any now-grantable waiters.
    fn drop(&mut self) {
        trace!("Dropping all write lock");
        let mut lock_record = self.0.acquire_mutex();
        lock_record.remove_write_all_lock();
        lock_record.wake();
        trace!("All write lock dropped");
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use futures::poll;
    use crate::engine::types::PoolUuid;

    // Dropping a pending lock future (cancellation) must remove its entry
    // from the waiting queue so it cannot block later waiters.
    #[test]
    fn test_cancelled_future() {
        let lock = AllOrSomeLock::new(Table::<PoolUuid, bool>::default());
        // Hold the all-write lock so subsequent read requests must wait.
        let _write_all = test_async!(lock.write_all());
        let read_all = Box::pin(lock.read_all());
        assert!(matches!(
            test_async!(async { poll!(read_all) }),
            Poll::Pending
        ));
        // The first pending future was dropped (shadowed) above, cancelling
        // its waiter; this second one also ends up pending and is dropped at
        // the end of the async block.
        let read_all = Box::pin(lock.read_all());
        assert!(matches!(
            test_async!(async { poll!(read_all) }),
            Poll::Pending
        ));
        let len = lock.lock_record.lock().unwrap().waiting.len();
        // Both pending futures have been dropped, so no waiters remain.
        assert_eq!(len, 0);
    }
}