use std::{
ops::{Deref, DerefMut},
sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockError},
};
use once_cell::sync::Lazy;
use crate::Sender;
use super::SubscriberFn;
/// A lazily-initialised, lock-protected value whose subscribers are notified
/// after each write performed through a [`SharedStateWriteGuard`].
///
/// Both fields are wrapped in `Lazy`, so nothing is initialised until first
/// access; this is what allows a `SharedState` to be stored in a `static`.
pub struct SharedState<Data> {
    /// The shared value, guarded by a reader-writer lock.
    data: Lazy<RwLock<Data>>,
    /// Callbacks run after every write; a callback returning `false` is removed.
    subscribers: Lazy<RwLock<Vec<SubscriberFn<Data>>>>,
}

impl<Data: std::fmt::Debug> std::fmt::Debug for SharedState<Data> {
    /// Shows the data and, if the subscriber lock can be taken without
    /// blocking, the number of subscribers (otherwise the `try_read` error).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `try_read` so that formatting never blocks on a concurrent writer.
        let subscriber_count = self.subscribers.try_read().map(|list| list.len());
        let mut builder = f.debug_struct("SharedState");
        builder.field("data", &self.data);
        builder.field("subscribers", &subscriber_count);
        builder.finish()
    }
}

impl<Data: Default> Default for SharedState<Data> {
    /// Equivalent to [`SharedState::new`].
    fn default() -> Self {
        SharedState::new()
    }
}
impl<Data> SharedState<Data>
where
    Data: Default,
{
    /// Creates a new shared state.
    ///
    /// The data (starting as `Data::default()`) and the subscriber list are
    /// initialised lazily on first access, which lets this `const fn` be used to
    /// initialise a `static`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            data: Lazy::new(RwLock::default),
            subscribers: Lazy::new(RwLock::default),
        }
    }

    /// Registers a subscriber that sends `f(&data)` through `sender` every time
    /// a write guard for this state is dropped.
    ///
    /// The subscriber is removed once `sender.send` fails (presumably because
    /// the receiving side was dropped — depends on `Sender::send`).
    ///
    /// # Panics
    ///
    /// Panics if the subscriber lock is poisoned. Blocks while a write guard is
    /// alive, so calling this while holding a write guard on the same thread
    /// may deadlock or panic.
    pub fn subscribe<Msg, F>(&self, sender: &Sender<Msg>, f: F)
    where
        F: Fn(&Data) -> Msg + 'static + Send + Sync,
        Msg: Send + 'static,
    {
        // An unconditional subscription is an optional one that always yields a message.
        self.subscribe_optional(sender, move |data| Some(f(data)));
    }

    /// Like [`SharedState::subscribe`], but `f` may return `None` to skip
    /// sending a message for a particular update.
    ///
    /// Returning `None` keeps the subscriber registered. Only a failed send
    /// removes it.
    ///
    /// # Panics
    ///
    /// Same as [`SharedState::subscribe`].
    pub fn subscribe_optional<Msg, F>(&self, sender: &Sender<Msg>, f: F)
    where
        F: Fn(&Data) -> Option<Msg> + 'static + Send + Sync,
        Msg: Send + 'static,
    {
        let sender = sender.clone();
        self.subscribers
            .write()
            .unwrap()
            .push(Box::new(move |data: &Data| {
                if let Some(msg) = f(data) {
                    // `false` tells the write guard to drop this subscriber.
                    sender.send(msg).is_ok()
                } else {
                    true
                }
            }));
    }

    /// Acquires shared read access to the data, blocking while a writer holds it.
    ///
    /// # Panics
    ///
    /// Panics if the data lock is poisoned.
    pub fn read(&self) -> SharedStateReadGuard<'_, Data> {
        SharedStateReadGuard {
            inner: self.data.read().unwrap(),
        }
    }

    /// Attempts to acquire shared read access without blocking.
    ///
    /// # Errors
    ///
    /// Returns [`TryLockError::WouldBlock`] if a writer holds the data lock.
    /// Returns [`TryLockError::Poisoned`] if the lock is poisoned.
    pub fn try_read(
        &self,
    ) -> Result<SharedStateReadGuard<'_, Data>, TryLockError<RwLockReadGuard<'_, Data>>> {
        Ok(SharedStateReadGuard {
            inner: self.data.try_read()?,
        })
    }

    /// Acquires exclusive write access to the data.
    ///
    /// When the returned guard is dropped, every subscriber is invoked with the
    /// data while both locks are still held. Subscribers must therefore not
    /// access this same `SharedState`.
    ///
    /// # Panics
    ///
    /// Panics if either lock is poisoned.
    pub fn write(&self) -> SharedStateWriteGuard<'_, Data> {
        // Lock order: subscribers, then data. `try_write` uses the same order.
        let subscribers = self.subscribers.write().unwrap();
        let data = self.data.write().unwrap();
        SharedStateWriteGuard { data, subscribers }
    }

    /// Attempts to acquire exclusive write access without blocking.
    ///
    /// Both locks are taken in the same order as [`SharedState::write`]
    /// (subscribers, then data), and neither acquisition blocks. Blocking on the
    /// subscriber lock while already holding the data lock could deadlock
    /// against a concurrent `write`.
    ///
    /// # Errors
    ///
    /// Returns [`TryLockError::WouldBlock`] if either lock is currently held.
    /// Returns [`TryLockError::Poisoned`] if the data lock is poisoned.
    ///
    /// # Panics
    ///
    /// Panics if the subscriber lock is poisoned (matching `write`).
    pub fn try_write(
        &self,
    ) -> Result<SharedStateWriteGuard<'_, Data>, TryLockError<RwLockWriteGuard<'_, Data>>> {
        let subscribers = match self.subscribers.try_write() {
            Ok(guard) => guard,
            Err(TryLockError::WouldBlock) => return Err(TryLockError::WouldBlock),
            Err(TryLockError::Poisoned(err)) => {
                panic!("subscriber list lock poisoned: {err}")
            }
        };
        // If this fails, `subscribers` is dropped here and its lock released.
        let data = self.data.try_write()?;
        Ok(SharedStateWriteGuard { data, subscribers })
    }

    /// Returns mutable access to the data through an exclusive borrow.
    ///
    /// No locking is needed here, and subscribers are NOT notified of changes
    /// made through this reference.
    ///
    /// # Panics
    ///
    /// Panics if the data lock is poisoned.
    pub fn get_mut(&mut self) -> &mut Data {
        self.data.get_mut().unwrap()
    }

    /// Returns the raw read guard of the underlying data lock.
    ///
    /// # Panics
    ///
    /// Panics if the data lock is poisoned.
    pub fn read_inner(&self) -> RwLockReadGuard<'_, Data> {
        self.data.read().unwrap()
    }

    /// Returns the raw write guard of the underlying data lock.
    ///
    /// Only the data lock is taken, and subscribers are NOT notified when this
    /// guard is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the data lock is poisoned.
    pub fn write_inner(&self) -> RwLockWriteGuard<'_, Data> {
        self.data.write().unwrap()
    }
}
/// Shared (read-only) access to the data inside a [`SharedState`].
///
/// Returned by [`SharedState::read`] and [`SharedState::try_read`]. The read
/// lock is released when this guard is dropped.
#[derive(Debug)]
pub struct SharedStateReadGuard<'a, Data> {
    inner: RwLockReadGuard<'a, Data>,
}

impl<Data> Deref for SharedStateReadGuard<'_, Data> {
    type Target = Data;

    /// Borrows the guarded data.
    fn deref(&self) -> &Data {
        &*self.inner
    }
}
/// Exclusive write access to the data inside a [`SharedState`].
///
/// The guard holds both the data lock and the subscriber-list lock. When it
/// is dropped outside of a panic, every subscriber is called with the current
/// data, and subscribers that return `false` are removed from the list.
pub struct SharedStateWriteGuard<'a, Data> {
    // Fields drop in declaration order, so the data lock is released first.
    data: RwLockWriteGuard<'a, Data>,
    subscribers: RwLockWriteGuard<'a, Vec<SubscriberFn<Data>>>,
}

impl<Data: std::fmt::Debug> std::fmt::Debug for SharedStateWriteGuard<'_, Data> {
    /// Formats the guarded data and the number of registered subscribers.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SharedStateWriteGuard")
            .field("data", &self.data)
            .field("subscribers", &self.subscribers.len())
            .finish()
    }
}

impl<Data> Deref for SharedStateWriteGuard<'_, Data> {
    type Target = Data;

    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl<Data> DerefMut for SharedStateWriteGuard<'_, Data> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.data
    }
}

impl<Data> Drop for SharedStateWriteGuard<'_, Data> {
    /// Notifies every subscriber of the data and prunes those returning `false`.
    fn drop(&mut self) {
        // While unwinding, the data may be only partially updated, and std will
        // poison both locks as the guards drop. Invoking subscribers here would
        // also turn any subscriber panic into a panic-while-panicking abort.
        if std::thread::panicking() {
            return;
        }
        let data = &*self.data;
        self.subscribers.retain(|subscriber| subscriber(data));
    }
}
#[cfg(test)]
mod test {
    use super::SharedState;

    /// Global state exercised by the test; starts at `u8::default()`.
    static COUNTER: SharedState<u8> = SharedState::new();

    #[test]
    fn shared_state() {
        // A fresh state holds the default value.
        assert_eq!(*COUNTER.read(), 0);

        // The temporary write guard drops at the end of the statement.
        *COUNTER.write() += 1;
        assert_eq!(*COUNTER.read(), 1);

        // After subscribing, dropping a write guard forwards the new value.
        let (tx, rx) = crate::channel();
        COUNTER.subscribe(&tx, |value| *value);
        *COUNTER.write() += 1;
        assert_eq!(rx.recv_sync().unwrap(), 2);
        assert_eq!(*COUNTER.read(), 2);
    }
}