#![deny(missing_docs)]
use std::cell::RefCell;
use std::ops::Deref;
use std::sync::{Arc, RwLock, RwLockReadGuard};
use std::thread::LocalKey;
/// A cloneable handle to a world `W` shared behind an `Arc<RwLock<W>>`.
///
/// Every clone of a wrapper refers to the same world. The world is mutated
/// through [`AppWorldWrapper::msg`] and read through [`AppWorldWrapper::read`].
/// `read` refuses to give a thread a second read guard while it still holds
/// one. The handle can be shared across threads when `W: Send + Sync`.
pub struct AppWorldWrapper<W: AppWorld> {
    /// The shared world; all clones of this wrapper point at this allocation.
    world: Arc<RwLock<W>>,
}
/// State that can be wrapped in an [`AppWorldWrapper`] and mutated by messages.
pub trait AppWorld: Sized {
    /// The message type accepted by [`AppWorld::msg`].
    type Message;

    /// Applies `message` to the world, mutating it in place.
    ///
    /// When invoked through [`AppWorldWrapper::msg`], this runs while the
    /// wrapper's write lock is held. A panic here therefore poisons that lock.
    fn msg(&mut self, message: Self::Message);
}
// No `'static` bound needed here: constructing and messaging the world does not
// touch the thread-local read tracker.
impl<W: AppWorld> AppWorldWrapper<W> {
    /// Wraps `world` in a new shared handle.
    ///
    /// Clone the returned wrapper to share the same world with other owners
    /// (or, when `W: Send + Sync`, other threads).
    pub fn new(world: W) -> Self {
        let world = Arc::new(RwLock::new(world));
        Self { world }
    }

    /// Forwards `msg` to [`AppWorld::msg`] while holding the write lock.
    ///
    /// Blocks until the write lock can be acquired.
    ///
    /// # Panics
    ///
    /// - If the lock is poisoned (a previous writer panicked while holding it).
    /// - If `W::msg` panics. This also poisons the lock for every other handle.
    ///
    /// Calling this while the current thread still holds a read guard on the
    /// same world may deadlock or panic, as documented for
    /// `std::sync::RwLock::write`.
    pub fn msg(&self, msg: W::Message) {
        self.world
            .write()
            .expect("AppWorldWrapper lock poisoned: a writer panicked while holding it")
            .msg(msg)
    }
}
impl<W: AppWorld + 'static> AppWorldWrapper<W> {
    // Per-thread flag: `true` while this thread holds a `WorldReadGuard`.
    // It is not keyed by wrapper instance. While a thread holds a guard from
    // any `AppWorldWrapper<W>`, `read` on every wrapper of that type panics
    // on that thread.
    thread_local!(
        static HAS_READ: RefCell<bool> = RefCell::new(false);
    );

    /// Acquires shared read access to the world.
    ///
    /// The returned guard clears this thread's read flag and releases the read
    /// lock when dropped.
    ///
    /// # Panics
    ///
    /// - If the current thread already holds a [`WorldReadGuard`]. A second
    ///   read lock on the same thread can deadlock once another thread is
    ///   queued for the write lock, so the request is rejected up front.
    /// - If the lock is poisoned (a writer panicked while holding it).
    pub fn read(&self) -> WorldReadGuard<'_, W> {
        // Check *before* locking. A recursive read could block forever behind a
        // queued writer, so `self.world.read()` must never be reached then.
        Self::HAS_READ.with(|has_read| {
            if *has_read.borrow() {
                panic!("Thread already holds read guard")
            }
        });

        // Set the flag only once the lock is actually held. If locking panics
        // (poisoned lock), no guard is created and nothing would clear the flag.
        // Setting it earlier would make every later `read` on this thread panic.
        let guard = self
            .world
            .read()
            .expect("AppWorldWrapper lock poisoned: a writer panicked while holding it");
        Self::HAS_READ.with(|has_read| *has_read.borrow_mut() = true);

        WorldReadGuard {
            guard,
            read_tracker: &Self::HAS_READ,
        }
    }

    /// Acquires exclusive write access to the world (only with the
    /// `test-utils` feature).
    ///
    /// Unlike [`read`](Self::read), this does not consult the per-thread read
    /// flag. Calling it while this thread holds a read guard on the same world
    /// may deadlock or panic, as documented for `std::sync::RwLock::write`.
    ///
    /// # Panics
    ///
    /// If the lock is poisoned.
    #[cfg(feature = "test-utils")]
    pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, W> {
        self.world
            .write()
            .expect("AppWorldWrapper lock poisoned: a writer panicked while holding it")
    }
}
impl<W: AppWorld> Clone for AppWorldWrapper<W> {
fn clone(&self) -> Self {
AppWorldWrapper {
world: self.world.clone(),
}
}
}
/// Guard giving shared read access to a world, returned by
/// [`AppWorldWrapper::read`].
///
/// Dereferences to the held [`RwLockReadGuard`], and through it to the world.
/// When dropped, it first clears the current thread's "holds a read guard"
/// flag, then releases the read lock.
///
/// Because `RwLockReadGuard` is `!Send`, this guard is too. It is therefore
/// always dropped on the thread whose flag it set.
#[must_use = "if unused the read lock is released immediately"]
pub struct WorldReadGuard<'a, W> {
    /// The held read lock on the world.
    guard: RwLockReadGuard<'a, W>,
    /// Thread-local flag that this guard resets when dropped.
    read_tracker: &'static LocalKey<RefCell<bool>>,
}
// Exposes the inner lock guard. Auto-deref continues through it to `W`, so
// callers can write `guard.field` directly.
impl<'a, W> Deref for WorldReadGuard<'a, W> {
    type Target = RwLockReadGuard<'a, W>;

    /// Borrows the held [`RwLockReadGuard`].
    fn deref(&self) -> &RwLockReadGuard<'a, W> {
        let Self { guard, .. } = self;
        guard
    }
}
impl<'a, W> Drop for WorldReadGuard<'a, W> {
    /// Clears this thread's "holds a read guard" flag so the thread may read
    /// again. The read lock itself is released right after this body runs,
    /// when the `guard` field is dropped.
    fn drop(&mut self) {
        // Use `try_with`, not `with`. If the thread-local has already been
        // destroyed (possible when the guard is dropped during thread
        // teardown), `with` panics, and a panic in `drop` while unwinding
        // aborts the process. In that case there is no flag left to clear, so
        // the access error is deliberately ignored.
        let _ = self.read_tracker.try_with(|has_read| {
            *has_read.borrow_mut() = false;
        });
    }
}
// Tests for the per-thread read-guard tracking in `AppWorldWrapper::read`.
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::time::Duration;
// Scenario: a thread holds a read guard while another thread queues a write
// via `msg`; the first thread then asks for a second read guard. `read` is
// expected to panic on that second call instead of risking a deadlock behind
// the queued writer. That panic shows up as an `Err` from `handle.join()`, and
// `join.expect(...)` turns it into the panic message this test expects.
#[test]
#[should_panic = "Second read attempt panicked"]
fn deadlock_prevention_same_thread_double_read_another_thread_write() {
let world = AppWorldWrapper::new(TestWorld { was_mutated: false });
let world_clone1 = world.clone();
let world_clone2 = world.clone();
let handle = thread::spawn(move || {
let guard_1 = world.read();
assert_eq!(guard_1.was_mutated, false);
// Writer thread: `msg` blocks on the write lock while `guard_1` is held.
let handle = thread::spawn(move || {
world_clone1.msg(());
});
// NOTE(review): presumably gives the writer time to start and block on
// the write lock; this relies on timing rather than synchronization.
thread::sleep(Duration::from_millis(50));
// Expected to panic ("Thread already holds read guard"); the lines after
// it are not reached. Unwinding drops `guard_1`, releasing the read lock
// so the writer thread can proceed.
let guard_3 = world.read();
assert_eq!(guard_3.was_mutated, true);
handle.join().unwrap();
});
let join = handle.join();
// NOTE(review): assumes the queued writer acquires the lock before this
// read is granted (depends on the platform RwLock policy) — confirm.
assert_eq!(world_clone2.read().was_mutated, true);
join.expect("Second read attempt panicked");
}
// Sequential reads on one thread succeed once the first guard is dropped.
#[test]
fn two_non_colliding_reads() {
let world = AppWorldWrapper::new(TestWorld::default());
{
let _guard = world.read();
}
let _guard = world.read();
}
// Minimal world whose only message flips `was_mutated` to `true`.
#[derive(Default)]
struct TestWorld {
was_mutated: bool,
}
impl AppWorld for TestWorld {
type Message = ();
fn msg(&mut self, _message: Self::Message) {
self.was_mutated = true;
}
}
}