future_mediator/
lib.rs

1use std::{
2    fmt::Debug,
3    future::Future,
4    ops::{Deref, DerefMut},
5    sync::Arc,
6    task::Context,
7};
8
9use futures::lock::{Mutex, OwnedMutexLockFuture};
10
11use std::{
12    collections::HashMap,
13    hash::Hash,
14    task::{Poll, Waker},
15};
16
17use futures::FutureExt;
18
/// State shared between cooperating futures: the user value together with
/// the wakers of the futures that are currently waiting on an event.
pub struct Shared<T, E> {
    /// The user-supplied value that poll functions read and modify.
    value: T,
    /// Waker to fire per event key; the map holds at most one waker per event.
    wakers: HashMap<E, Waker>,
}
24
25impl<T, E> Shared<T, E> {
26    fn new(value: T) -> Self {
27        Self {
28            value: value.into(),
29            wakers: Default::default(),
30        }
31    }
32
33    fn register_event_listener(&mut self, event: E, waker: Waker)
34    where
35        E: Eq + Hash,
36    {
37        self.wakers.insert(event, waker);
38    }
39
40    /// Emit once `event` on
41    pub fn notify(&mut self, event: E)
42    where
43        E: Eq + Hash + Debug,
44    {
45        if let Some(waker) = self.wakers.remove(&event) {
46            log::trace!("notify event={:?}, wakeup=true", event);
47            waker.wake();
48        } else {
49            log::trace!("notify event={:?}, wakeup=false", event);
50        }
51    }
52
53    /// Emit all `events` on
54    pub fn notify_all<Events: AsRef<[E]>>(&mut self, events: Events)
55    where
56        E: Eq + Hash + Debug + Clone,
57    {
58        for event in events.as_ref() {
59            self.notify(event.clone());
60        }
61    }
62
63    /// Get shared value immutable reference.
64    pub fn value(&self) -> &T {
65        &self.value
66    }
67
68    /// Get shared value mutable reference.
69    pub fn value_mut(&mut self) -> &mut T {
70        &mut self.value
71    }
72}
73
74impl<T, E> Deref for Shared<T, E> {
75    type Target = T;
76    fn deref(&self) -> &Self::Target {
77        &self.value
78    }
79}
80
81impl<T, E> DerefMut for Shared<T, E> {
82    fn deref_mut(&mut self) -> &mut Self::Target {
83        &mut self.value
84    }
85}
86
87/// A mediator is a central hub for communication between futures.
88#[derive(Debug)]
89pub struct Mediator<T, E> {
90    raw: Arc<Mutex<Shared<T, E>>>,
91}
92
93impl<T, E> Clone for Mediator<T, E> {
94    fn clone(&self) -> Self {
95        Self {
96            raw: self.raw.clone(),
97        }
98    }
99}
100
101impl<T, E> Mediator<T, E> {
102    /// Create new mediator with shared value.
103    pub fn new(value: T) -> Self {
104        Self {
105            raw: Arc::new(Mutex::new(Shared::new(value))),
106        }
107    }
108
109    /// Acquire the lock and access immutable shared data.
110    pub async fn with<F, R>(&self, f: F) -> R
111    where
112        F: FnOnce(&T) -> R,
113    {
114        let raw = self.raw.lock().await;
115
116        f(&raw.value)
117    }
118
119    /// Acquire the lock and access mutable shared data.
120    pub async fn with_mut<F, R>(&self, f: F) -> R
121    where
122        F: FnOnce(&mut T) -> R,
123    {
124        let mut raw = self.raw.lock().await;
125
126        f(&mut raw.value)
127    }
128
129    /// Attempt to acquire the shared data lock immediately.
130    ///
131    /// If the lock is currently held, this will return `None`.
132    pub fn try_lock(&self) -> Option<futures::lock::MutexGuard<'_, Shared<T, E>>> {
133        self.raw.try_lock()
134    }
135
136    /// Emit one event on.
137    pub async fn notify(&self, event: E)
138    where
139        E: Eq + Hash + Debug,
140    {
141        let mut raw = self.raw.lock().await;
142
143        raw.notify(event);
144    }
145
146    /// Emit all events
147    pub async fn notify_all<Events: AsRef<[E]>>(&self, events: Events)
148    where
149        E: Eq + Hash + Clone + Debug,
150    {
151        let mut raw = self.raw.lock().await;
152
153        for event in events.as_ref() {
154            raw.notify(event.clone());
155        }
156    }
157
158    /// Create a new event handle future with poll function `f` and run once immediately
159    ///
160    /// If `f` returns [`Pending`](Poll::Pending), the system will move the
161    /// handle into the event waiting map, and the future returns `Pending` status.
162    ///
163    /// You can call [`notify`](Mediator::notify) to wake up this poll function and run once again.
164    pub fn on_fn<F, R>(&self, event: E, f: F) -> OnEvent<T, E, F>
165    where
166        F: FnMut(&mut Shared<T, E>, &mut Context<'_>) -> Poll<R> + Unpin,
167        T: Unpin + 'static,
168        E: Unpin + Eq + Hash + Debug,
169        R: Unpin,
170    {
171        OnEvent {
172            f: Some(f),
173            raw: self.raw.clone(),
174            lock_future: None,
175            event,
176        }
177    }
178}
179
180/// Future create by [`on`](Mediator::on)
181pub struct OnEvent<T, E, F>
182where
183    E: Debug,
184{
185    f: Option<F>,
186    raw: Arc<Mutex<Shared<T, E>>>,
187    lock_future: Option<OwnedMutexLockFuture<Shared<T, E>>>,
188    event: E,
189}
190
191impl<T, E, F, R> Future for OnEvent<T, E, F>
192where
193    F: FnMut(&mut Shared<T, E>, &mut Context<'_>) -> Poll<R> + Unpin,
194    T: Unpin,
195    E: Unpin + Eq + Hash + Copy,
196    R: Unpin,
197    E: Debug,
198{
199    type Output = R;
200
201    fn poll(
202        mut self: std::pin::Pin<&mut Self>,
203        cx: &mut std::task::Context<'_>,
204    ) -> Poll<Self::Output> {
205        let mut lock_future = if let Some(lock_future) = self.lock_future.take() {
206            lock_future
207        } else {
208            self.raw.clone().lock_owned()
209        };
210
211        let mut raw = match lock_future.poll_unpin(cx) {
212            Poll::Ready(raw) => raw,
213            _ => {
214                self.lock_future = Some(lock_future);
215
216                return Poll::Pending;
217            }
218        };
219
220        let mut f = self.f.take().unwrap();
221
222        match f(&mut raw, cx) {
223            Poll::Pending => {
224                self.f = Some(f);
225
226                raw.register_event_listener(self.event, cx.waker().clone());
227
228                return Poll::Pending;
229            }
230            poll => {
231                return poll;
232            }
233        }
234    }
235}
236
/// Register an event handler written as an async function.
///
/// Expands to `$mediator.on_fn($event, ..)` with a poll closure that, on
/// every invocation, calls `$fut(mediator_cx)` to build a fresh future,
/// boxes and pins it, and polls it once.
///
/// * `$mediator` - expression providing `on_fn`.
/// * `$event` - event under which the handler is registered.
/// * `$fut` - callable taking the mediator context and returning a future.
#[macro_export]
macro_rules! on {
    ($mediator: expr, $event: expr, $fut: expr) => {
        // Forward the caller's event instead of a hard-coded `Event::A`.
        $mediator.on_fn($event, |mediator_cx, cx| {
            // NOTE(review): `$crate::FutureExt` resolves through the crate
            // root's `use futures::FutureExt;`, which is a private import;
            // confirm that callers in other crates can name it.
            use $crate::FutureExt;
            Box::pin($fut(mediator_cx)).poll_unpin(cx)
        })
    };
}
247
#[cfg(test)]
mod tests {
    use std::task::Poll;

    use futures::executor::ThreadPool;
    use futures::task::SpawnExt;

    use crate::{Mediator, Shared};

    /// Events exchanged between the test futures.
    #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
    enum Event {
        A,
        B,
    }

    /// A listener on `Event::A` completes once a spawned task has changed
    /// the shared value and emitted `Event::A`.
    #[futures_test::test]
    async fn test_mediator() {
        let mediator: Mediator<i32, Event> = Mediator::new(1);

        let thread_pool = ThreadPool::builder().pool_size(10).create().unwrap();

        // Spawned side: flips the value from 1 to 2 and wakes the `A` listener.
        let producer = mediator.on_fn(Event::B, |mediator_cx, _| {
            if *mediator_cx.value() != 1 {
                return Poll::Pending;
            }

            *mediator_cx.value_mut() = 2;
            mediator_cx.notify(Event::A);

            Poll::Ready(())
        });

        thread_pool.spawn(producer).unwrap();

        // Awaited side: stays pending for as long as the value is still 1.
        let consumer = mediator.on_fn(Event::A, |mediator_cx, _| {
            if *mediator_cx.value() == 1 {
                Poll::Pending
            } else {
                Poll::Ready(())
            }
        });

        consumer.await;
    }

    /// An async fn registered through `on!` runs and its write to the shared
    /// value is visible afterwards.
    #[futures_test::test]
    async fn test_mediator_async() {
        async fn assign_2(cx: &mut Shared<i32, Event>) {
            *cx.value_mut() = 2;
        }

        let mediator: Mediator<i32, Event> = Mediator::new(1);

        let thread_pool = ThreadPool::builder().pool_size(10).create().unwrap();

        let handle = thread_pool
            .spawn_with_handle(on!(mediator, Event::A, assign_2))
            .unwrap();

        handle.await;

        assert_eq!(mediator.with(|value| *value).await, 2);
    }
}