1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use crate::channel::{unbounded, UnboundedReceiver, UnboundedSender};
use async_lock::{Mutex, RwLock};
use async_trait::async_trait;
use std::fmt;

/// Read-only view of a [`SubscribableRwLock`].
///
/// The trait exposes only two operations: subscribing to state changes and
/// taking a clone of the current state. It offers no way to mutate the state.
#[async_trait]
pub trait ReadView<T: Clone> {
    /// Subscribe to state changes.
    ///
    /// Returns the receiving half of an unbounded channel on which the updated
    /// state is received upon each state change.
    async fn subscribe(&self) -> UnboundedReceiver<T>;
    /// Asynchronously clone the internal state and return the clone.
    async fn cloned(&self) -> T;
}

/// A [`ReadView`] with the additional requirement of being thread-safe.
///
/// This trait declares no methods of its own; it only bundles [`ReadView`]
/// with the `Send`, `Sync` and [`std::fmt::Debug`] supertraits, and requires
/// the state type `T` to be `Send + Sync` as well.
pub trait ThreadedReadView<T: Clone + Sync + Send>:
    Send + Sync + ReadView<T> + std::fmt::Debug
{
}

/// A [`RwLock`] that can register subscribers to be notified upon state change.
///
/// The state of type `T` lives behind an [`RwLock`]; subscribers are kept as
/// the sending halves of unbounded channels behind a separate [`Mutex`].
#[derive(Default)]
pub struct SubscribableRwLock<T: Clone> {
    /// The list of subscribers to the rwlock, one channel sender per subscriber.
    subscribers: Mutex<Vec<UnboundedSender<T>>>,
    /// The lock holding the state.
    rw_lock: RwLock<T>,
}

/// Marks [`SubscribableRwLock`] as a [`ThreadedReadView`] whenever its state
/// type is `Clone + Sync + Send + Debug`. The impl body is intentionally empty
/// because the trait declares no methods.
impl<T: Clone + Sync + Send + std::fmt::Debug> ThreadedReadView<T> for SubscribableRwLock<T> {}

#[async_trait]
impl<T: Clone + Send + Sync> ReadView<T> for SubscribableRwLock<T> {
    /// Register a new subscriber and hand back its receiving end.
    ///
    /// A fresh unbounded channel is created; its sending half is appended to
    /// the subscriber list and its receiving half is returned to the caller.
    async fn subscribe(&self) -> UnboundedReceiver<T> {
        let (tx, rx) = unbounded();
        let mut subscribers = self.subscribers.lock().await;
        subscribers.push(tx);
        rx
    }

    /// Take the read lock and return a clone of the current state.
    async fn cloned(&self) -> T {
        let guard = self.rw_lock.read().await;
        (*guard).clone()
    }
}

impl<T: Clone> SubscribableRwLock<T> {
    /// Create a new [`SubscribableRwLock`] holding `t`, with no subscribers.
    pub fn new(t: T) -> Self {
        Self {
            subscribers: Mutex::new(Vec::new()),
            rw_lock: RwLock::new(t),
        }
    }

    /// Mutate the state in place and notify all subscribers of the result.
    ///
    /// `cb` is invoked with a mutable reference to the state while the write
    /// lock is held. After it returns, the new state is cloned, the write lock
    /// is released, and the clone is sent to every subscriber.
    pub async fn modify<F>(&self, cb: F)
    where
        F: FnOnce(&mut T),
    {
        let mut lock = self.rw_lock.write().await;
        cb(&mut *lock);
        let result = lock.clone();
        // Release the write lock before notifying so the state is not held
        // locked while the snapshot is being sent out.
        // NOTE(review): because the lock is released first, two concurrent
        // `modify` calls may notify subscribers in a different order than the
        // modifications were applied — confirm callers tolerate this.
        drop(lock);
        self.notify_change_subscribers(result).await;
    }

    /// Send a clone of `t` to every subscriber and forget the subscribers
    /// whose send failed (presumably because their receiver is gone).
    async fn notify_change_subscribers(&self, t: T) {
        let mut subscribers = self.subscribers.lock().await;
        // `send` must be awaited, so it cannot run inside `retain`'s
        // synchronous closure; record the per-subscriber outcome first.
        let mut delivered = Vec::with_capacity(subscribers.len());
        for sender in subscribers.iter() {
            delivered.push(sender.send(t.clone()).await.is_ok());
        }
        // `retain` visits each element exactly once in the original order, so
        // the flags line up with the senders. This prunes in a single O(n)
        // pass instead of one O(n) `Vec::remove` per failed subscriber.
        let mut delivered = delivered.into_iter();
        subscribers.retain(|_| delivered.next().unwrap_or(true));
    }
}

impl<T: Copy> SubscribableRwLock<T> {
    /// Take the read lock and return a bitwise copy of the current state.
    pub async fn copied(&self) -> T {
        let guard = self.rw_lock.read().await;
        *guard
    }
}

impl<T: fmt::Debug + Clone> fmt::Debug for SubscribableRwLock<T> {
    /// Format as `SubscribableRwLock { data: .. }`.
    ///
    /// The state is printed only if a read lock can be taken without waiting;
    /// otherwise the `data` field is shown as `<locked>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Placeholder printed in place of the state when the read lock
        /// cannot be acquired immediately.
        struct Locked;
        impl fmt::Debug for Locked {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<locked>")
            }
        }

        let mut out = f.debug_struct("SubscribableRwLock");
        if let Some(guard) = self.rw_lock.try_read() {
            out.field("data", &&*guard);
        } else {
            out.field("data", &Locked);
        }
        out.finish()
    }
}