async_event_dispatch/
lib.rs

1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3
4use deadqueue::{limited, unlimited};
5use tokio::sync::oneshot::Sender;
6use tokio::sync::{oneshot, RwLock};
7
8enum Queue<T> {
9  Bounded(limited::Queue<T>),
10  Unbounded(unlimited::Queue<T>),
11}
12
/// Selects which kind of queue `Dispatcher::subscribe` creates for each new
/// subscriber.
enum DispatcherType {
  /// Growable queues with no capacity limit.
  Unbounded,
  /// Fixed-capacity queues created with the given limit.
  Bounded(usize),
}
17
18enum Event<T> {
19  Item(T),
20  ItemWait(T, Sender<()>),
21  Close,
22}
23
24impl<T> Queue<T> {
25  pub fn push(&self, item: T) {
26    match self {
27      Queue::Bounded(queue) => {
28        let _ignored = queue.try_push(item);
29      }
30      Queue::Unbounded(queue) => {
31        queue.push(item);
32      }
33    }
34  }
35
36  pub async fn pop(&self) -> T {
37    match self {
38      Queue::Bounded(queue) => queue.pop().await,
39      Queue::Unbounded(queue) => queue.pop().await,
40    }
41  }
42}
43
44pub struct Subscriber<T> {
45  queue: Arc<Queue<Event<T>>>,
46  ref_count: Arc<AtomicUsize>,
47}
48
49impl<T> Subscriber<T> {
50  pub async fn next(&self) -> Option<T> {
51    match self.queue.pop().await {
52      Event::Item(item) => Some(item),
53      Event::ItemWait(item, responder) => {
54        let _ignored = responder.send(());
55        Some(item)
56      }
57      Event::Close => None,
58    }
59  }
60}
61
62impl<T> Clone for Subscriber<T> {
63  fn clone(&self) -> Self {
64    let _ignored = self.ref_count.fetch_add(1, Ordering::Relaxed);
65    Self {
66      queue: self.queue.clone(),
67      ref_count: self.ref_count.clone(),
68    }
69  }
70}
71
72impl<T> Drop for Subscriber<T> {
73  fn drop(&mut self) {
74    let _ignored = self.ref_count.fetch_sub(1, Ordering::Relaxed);
75  }
76}
77
78pub struct Dispatcher<T: Send + 'static> {
79  dispatcher_type: DispatcherType,
80  subscribers: Arc<RwLock<Vec<Subscriber<T>>>>,
81}
82
83impl<T: Clone + Send + 'static> Dispatcher<T> {
84  pub fn new() -> Self {
85    Self {
86      dispatcher_type: DispatcherType::Unbounded,
87      subscribers: Arc::new(RwLock::new(Vec::new())),
88    }
89  }
90
91  pub fn new_bounded(limit: usize) -> Self {
92    Self {
93      dispatcher_type: DispatcherType::Bounded(limit),
94      subscribers: Arc::new(RwLock::new(Vec::new())),
95    }
96  }
97
98  pub async fn cleanup(&self) {
99    let mut subscribers = self.subscribers.write().await;
100    subscribers
101      .retain(|subscriber| subscriber.ref_count.load(Ordering::Relaxed) > 1)
102  }
103
104  pub async fn dispatch(&self, event: T) {
105    let subscribers = self.subscribers.read().await;
106    let mut cleanup = false;
107    for subscriber in subscribers.iter() {
108      if subscriber.ref_count.load(Ordering::Relaxed) > 1 {
109        subscriber.queue.push(Event::Item(event.clone()));
110      } else {
111        cleanup = true;
112      }
113    }
114    if cleanup {
115      drop(subscribers);
116      self.cleanup().await;
117    }
118  }
119
120  pub async fn dispatch_wait(&self, event: T) {
121    let subscribers = self.subscribers.read().await;
122    let mut receivers = Vec::new();
123    let mut cleanup = false;
124    for subscriber in subscribers.iter() {
125      if subscriber.ref_count.load(Ordering::Relaxed) > 1 {
126        let (responder, receiver) = oneshot::channel();
127        receivers.push(receiver);
128        subscriber
129          .queue
130          .push(Event::ItemWait(event.clone(), responder));
131      } else {
132        cleanup = true;
133      }
134    }
135
136    if cleanup {
137      drop(subscribers);
138      self.cleanup().await;
139    }
140
141    for fut in receivers.into_iter() {
142      let _ignored = fut.await;
143    }
144  }
145
146  pub async fn subscriber_count(&self) -> usize {
147    self.subscribers.read().await.len()
148  }
149
150  pub async fn subscribe(&self) -> Subscriber<T> {
151    let subscriber = match self.dispatcher_type {
152      DispatcherType::Bounded(limit) => Subscriber {
153        queue: Arc::new(Queue::Bounded(limited::Queue::new(limit))),
154        ref_count: Arc::new(AtomicUsize::new(1)),
155      },
156      DispatcherType::Unbounded => Subscriber {
157        queue: Arc::new(Queue::Unbounded(unlimited::Queue::new())),
158        ref_count: Arc::new(AtomicUsize::new(1)),
159      },
160    };
161    let mut subscribers = self.subscribers.write().await;
162    subscribers.push(subscriber.clone());
163    subscriber
164  }
165}
166
167impl<T: Clone + Send + 'static> Default for Dispatcher<T> {
168  fn default() -> Self {
169    Dispatcher::new()
170  }
171}
172
173impl<T: Send + 'static> Drop for Dispatcher<T> {
174  fn drop(&mut self) {
175    if let Ok(rt) = tokio::runtime::Handle::try_current() {
176      let subscribers = self.subscribers.clone();
177      rt.spawn(async move {
178        for subscriber in subscribers.read().await.iter() {
179          subscriber.queue.push(Event::Close);
180        }
181      });
182    } else {
183      // blocking_read() is deadlock-safe as we do not hand out
184      // any guards and there aren't any references anymore
185      for subscriber in self.subscribers.blocking_read().iter() {
186        subscriber.queue.push(Event::Close);
187      }
188    }
189  }
190}
191
#[cfg(test)]
mod tests {
  use std::sync::Arc;

  use tokio::sync::Mutex;

  use crate::Dispatcher;

  /// Every subscriber sees every event in order (even one that is only read
  /// later), and dropping the dispatcher ends the stream with `None`.
  #[tokio::test]
  async fn it_works() {
    let dispatcher = Dispatcher::<i32>::new();
    let eager = dispatcher.subscribe().await;
    let lagging = dispatcher.subscribe().await;

    let finished = Arc::new(Mutex::new(false));
    let reader = {
      let finished = Arc::clone(&finished);
      tokio::spawn(async move {
        for expected in [42, 69] {
          assert_eq!(eager.next().await, Some(expected));
        }
        *finished.lock().await = true;
      })
    };

    for value in [42, 69] {
      dispatcher.dispatch(value).await;
    }
    assert!(reader.await.is_ok());
    assert!(*finished.lock().await);

    // The lagging subscriber was not read concurrently, yet nothing was lost.
    for expected in [42, 69] {
      assert_eq!(lagging.next().await, Some(expected));
    }

    drop(dispatcher);
    assert_eq!(lagging.next().await, None);
  }
}