use parking_lot::{Condvar, MappedMutexGuard, Mutex, MutexGuard};
use std::sync::Arc;
#[must_use = "Shared data will be unlocked on drop!"]
pub struct SharedDataLockedUpgradable<T> {
shared_data: SharedData<T>,
}
impl<T> SharedDataLockedUpgradable<T> {
pub fn upgrade(&mut self) -> MappedMutexGuard<'_, T> {
MutexGuard::map(self.shared_data.inner.lock(), |i| &mut i.shared_data)
}
}
impl<T> Drop for SharedDataLockedUpgradable<T> {
fn drop(&mut self) {
let mut inner = self.shared_data.inner.lock();
inner.locked = false;
self.shared_data.cond_var.notify_all();
}
}
/// Guard returned by [`SharedData::shared_data_locked`].
///
/// While it is alive it holds both the mutex and the logical lock, and it
/// dereferences to the shared data. Dropping it clears the logical lock and
/// wakes all waiters. Alternatively, [`Self::release_mutex`] gives up only the
/// mutex and hands the logical lock to a [`SharedDataLockedUpgradable`].
#[must_use = "Shared data will be unlocked on drop!"]
pub struct SharedDataLocked<'a, T> {
    /// The held mutex guard.
    inner: MutexGuard<'a, SharedDataInner<T>>,
    /// Handle to the shared state. It becomes `None` once `release_mutex` has
    /// moved it out, which tells `Drop` to leave the logical lock set.
    shared_data: Option<SharedData<T>>,
}

impl<'a, T> SharedDataLocked<'a, T> {
    /// Releases the mutex while keeping the shared data logically locked.
    ///
    /// Other threads keep waiting in `shared_data` / `shared_data_locked`
    /// until the returned handle is dropped.
    pub fn release_mutex(mut self) -> SharedDataLockedUpgradable<T> {
        let shared_data = self
            .shared_data
            .take()
            .expect("`shared_data` is only taken on drop; qed");
        // `self` (and the mutex guard it owns) is dropped when this function
        // returns. Its `Drop` then sees `None` and leaves `locked` untouched.
        SharedDataLockedUpgradable { shared_data }
    }
}

impl<'a, T> Drop for SharedDataLocked<'a, T> {
    fn drop(&mut self) {
        match self.shared_data.take() {
            Some(owner) => {
                // The mutex is still held through `self.inner`, so the flag can
                // be cleared directly before waking the waiters.
                self.inner.locked = false;
                owner.cond_var.notify_all();
            }
            // `release_mutex` already moved the logical lock elsewhere.
            None => {}
        }
    }
}

impl<'a, T> std::ops::Deref for SharedDataLocked<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let state: &SharedDataInner<T> = &self.inner;
        &state.shared_data
    }
}

impl<'a, T> std::ops::DerefMut for SharedDataLocked<'a, T> {
    fn deref_mut(&mut self) -> &mut T {
        let state: &mut SharedDataInner<T> = &mut self.inner;
        &mut state.shared_data
    }
}
/// State guarded by the mutex inside [`SharedData`].
struct SharedDataInner<T> {
    /// `true` while a `SharedDataLocked` or `SharedDataLockedUpgradable` owns
    /// the logical lock. `SharedData`'s accessors wait on the condition
    /// variable until it is cleared.
    locked: bool,
    /// The wrapped user value.
    shared_data: T,
}
pub struct SharedData<T> {
inner: Arc<Mutex<SharedDataInner<T>>>,
cond_var: Arc<Condvar>,
}
impl<T> Clone for SharedData<T> {
fn clone(&self) -> Self {
Self { inner: self.inner.clone(), cond_var: self.cond_var.clone() }
}
}
impl<T> SharedData<T> {
pub fn new(shared_data: T) -> Self {
Self {
inner: Arc::new(Mutex::new(SharedDataInner { shared_data, locked: false })),
cond_var: Default::default(),
}
}
pub fn shared_data(&self) -> MappedMutexGuard<'_, T> {
let mut guard = self.inner.lock();
while guard.locked {
self.cond_var.wait(&mut guard);
}
debug_assert!(!guard.locked);
MutexGuard::map(guard, |i| &mut i.shared_data)
}
pub fn shared_data_locked(&self) -> SharedDataLocked<'_, T> {
let mut guard = self.inner.lock();
while guard.locked {
self.cond_var.wait(&mut guard);
}
debug_assert!(!guard.locked);
guard.locked = true;
SharedDataLocked { inner: guard, shared_data: Some(self.clone()) }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises a single `SharedData` from many threads.
    ///
    /// Odd workers use `shared_data()`. Even workers use
    /// `shared_data_locked()` + `release_mutex()` + `upgrade()`. The main
    /// thread holds the logical lock for a while, so every worker has to queue
    /// up behind it.
    #[test]
    fn shared_data_locking_works() {
        const THREADS: u32 = 100;
        let shared_data = SharedData::new(0u32);

        // Take the logical lock before spawning so every worker blocks on it.
        let lock = shared_data.shared_data_locked();

        let handles: Vec<_> = (0..THREADS)
            .map(|i| {
                let data = shared_data.clone();
                std::thread::spawn(move || {
                    if i % 2 == 1 {
                        *data.shared_data() += 1;
                    } else {
                        let mut lock = data.shared_data_locked().release_mutex();
                        // Keep the logical lock without the mutex for a while.
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *lock.upgrade() += 1;
                    }
                })
            })
            .collect();

        // Give up the mutex but keep the logical lock: workers must keep waiting.
        let lock = lock.release_mutex();
        std::thread::sleep(std::time::Duration::from_millis(100));
        drop(lock);

        // Join instead of polling the counter. A panicking worker then fails
        // the test instead of leaving the counter short and looping forever.
        for handle in handles {
            handle.join().expect("worker thread panicked");
        }

        // Every worker incremented exactly once.
        assert_eq!(*shared_data.shared_data(), THREADS);
    }
}