1use std::{
2 fmt::Debug,
3 future::Future,
4 ops::{Deref, DerefMut},
5 sync::Arc,
6 task::Context,
7};
8
9use futures::lock::{Mutex, OwnedMutexLockFuture};
10
11use std::{
12 collections::HashMap,
13 hash::Hash,
14 task::{Poll, Waker},
15};
16
17use futures::FutureExt;
18
19pub struct Shared<T, E> {
21 value: T,
22 wakers: HashMap<E, Waker>,
23}
24
25impl<T, E> Shared<T, E> {
26 fn new(value: T) -> Self {
27 Self {
28 value: value.into(),
29 wakers: Default::default(),
30 }
31 }
32
33 fn register_event_listener(&mut self, event: E, waker: Waker)
34 where
35 E: Eq + Hash,
36 {
37 self.wakers.insert(event, waker);
38 }
39
40 pub fn notify(&mut self, event: E)
42 where
43 E: Eq + Hash + Debug,
44 {
45 if let Some(waker) = self.wakers.remove(&event) {
46 log::trace!("notify event={:?}, wakeup=true", event);
47 waker.wake();
48 } else {
49 log::trace!("notify event={:?}, wakeup=false", event);
50 }
51 }
52
53 pub fn notify_all<Events: AsRef<[E]>>(&mut self, events: Events)
55 where
56 E: Eq + Hash + Debug + Clone,
57 {
58 for event in events.as_ref() {
59 self.notify(event.clone());
60 }
61 }
62
63 pub fn value(&self) -> &T {
65 &self.value
66 }
67
68 pub fn value_mut(&mut self) -> &mut T {
70 &mut self.value
71 }
72}
73
74impl<T, E> Deref for Shared<T, E> {
75 type Target = T;
76 fn deref(&self) -> &Self::Target {
77 &self.value
78 }
79}
80
81impl<T, E> DerefMut for Shared<T, E> {
82 fn deref_mut(&mut self) -> &mut Self::Target {
83 &mut self.value
84 }
85}
86
87#[derive(Debug)]
89pub struct Mediator<T, E> {
90 raw: Arc<Mutex<Shared<T, E>>>,
91}
92
93impl<T, E> Clone for Mediator<T, E> {
94 fn clone(&self) -> Self {
95 Self {
96 raw: self.raw.clone(),
97 }
98 }
99}
100
101impl<T, E> Mediator<T, E> {
102 pub fn new(value: T) -> Self {
104 Self {
105 raw: Arc::new(Mutex::new(Shared::new(value))),
106 }
107 }
108
109 pub async fn with<F, R>(&self, f: F) -> R
111 where
112 F: FnOnce(&T) -> R,
113 {
114 let raw = self.raw.lock().await;
115
116 f(&raw.value)
117 }
118
119 pub async fn with_mut<F, R>(&self, f: F) -> R
121 where
122 F: FnOnce(&mut T) -> R,
123 {
124 let mut raw = self.raw.lock().await;
125
126 f(&mut raw.value)
127 }
128
129 pub fn try_lock(&self) -> Option<futures::lock::MutexGuard<'_, Shared<T, E>>> {
133 self.raw.try_lock()
134 }
135
136 pub async fn notify(&self, event: E)
138 where
139 E: Eq + Hash + Debug,
140 {
141 let mut raw = self.raw.lock().await;
142
143 raw.notify(event);
144 }
145
146 pub async fn notify_all<Events: AsRef<[E]>>(&self, events: Events)
148 where
149 E: Eq + Hash + Clone + Debug,
150 {
151 let mut raw = self.raw.lock().await;
152
153 for event in events.as_ref() {
154 raw.notify(event.clone());
155 }
156 }
157
158 pub fn on_fn<F, R>(&self, event: E, f: F) -> OnEvent<T, E, F>
165 where
166 F: FnMut(&mut Shared<T, E>, &mut Context<'_>) -> Poll<R> + Unpin,
167 T: Unpin + 'static,
168 E: Unpin + Eq + Hash + Debug,
169 R: Unpin,
170 {
171 OnEvent {
172 f: Some(f),
173 raw: self.raw.clone(),
174 lock_future: None,
175 event,
176 }
177 }
178}
179
180pub struct OnEvent<T, E, F>
182where
183 E: Debug,
184{
185 f: Option<F>,
186 raw: Arc<Mutex<Shared<T, E>>>,
187 lock_future: Option<OwnedMutexLockFuture<Shared<T, E>>>,
188 event: E,
189}
190
191impl<T, E, F, R> Future for OnEvent<T, E, F>
192where
193 F: FnMut(&mut Shared<T, E>, &mut Context<'_>) -> Poll<R> + Unpin,
194 T: Unpin,
195 E: Unpin + Eq + Hash + Copy,
196 R: Unpin,
197 E: Debug,
198{
199 type Output = R;
200
201 fn poll(
202 mut self: std::pin::Pin<&mut Self>,
203 cx: &mut std::task::Context<'_>,
204 ) -> Poll<Self::Output> {
205 let mut lock_future = if let Some(lock_future) = self.lock_future.take() {
206 lock_future
207 } else {
208 self.raw.clone().lock_owned()
209 };
210
211 let mut raw = match lock_future.poll_unpin(cx) {
212 Poll::Ready(raw) => raw,
213 _ => {
214 self.lock_future = Some(lock_future);
215
216 return Poll::Pending;
217 }
218 };
219
220 let mut f = self.f.take().unwrap();
221
222 match f(&mut raw, cx) {
223 Poll::Pending => {
224 self.f = Some(f);
225
226 raw.register_event_listener(self.event, cx.waker().clone());
227
228 return Poll::Pending;
229 }
230 poll => {
231 return poll;
232 }
233 }
234 }
235}
236
237#[macro_export]
239macro_rules! on {
240 ($mediator: expr, $event: expr, $fut: expr) => {
241 $mediator.on_fn(Event::A, |mediator_cx, cx| {
242 use $crate::FutureExt;
243 Box::pin($fut(mediator_cx)).poll_unpin(cx)
244 })
245 };
246}
247
248#[cfg(test)]
249mod tests {
250 use std::task::Poll;
251
252 use futures::executor::ThreadPool;
253
254 use futures::task::SpawnExt;
255
256 use crate::{Mediator, Shared};
257
258 #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
259 enum Event {
260 A,
261 B,
262 }
263
264 #[futures_test::test]
265 async fn test_mediator() {
266 let mediator: Mediator<i32, Event> = Mediator::new(1);
267
268 let thread_pool = ThreadPool::builder().pool_size(10).create().unwrap();
269
270 thread_pool
271 .spawn(mediator.on_fn(Event::B, |mediator_cx, _| {
272 if *mediator_cx.value() == 1 {
273 *mediator_cx.value_mut() = 2;
274 mediator_cx.notify(Event::A);
275
276 return Poll::Ready(());
277 }
278
279 return Poll::Pending;
280 }))
281 .unwrap();
282
283 mediator
284 .on_fn(Event::A, |mediator_cx, _| {
285 if *mediator_cx.value() == 1 {
286 return Poll::Pending;
287 }
288
289 return Poll::Ready(());
290 })
291 .await;
292 }
293
294 #[futures_test::test]
295 async fn test_mediator_async() {
296 let mediator: Mediator<i32, Event> = Mediator::new(1);
297
298 let thread_pool = ThreadPool::builder().pool_size(10).create().unwrap();
299
300 async fn assign_2(cx: &mut Shared<i32, Event>) {
301 *cx.value_mut() = 2;
302 }
303
304 thread_pool
305 .spawn_with_handle(on!(mediator, Event::A, assign_2))
306 .unwrap()
307 .await;
308
309 assert_eq!(mediator.with(|value| *value).await, 2);
310 }
311}