async_event_dispatch/
lib.rs1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3
4use deadqueue::{limited, unlimited};
5use tokio::sync::oneshot::Sender;
6use tokio::sync::{oneshot, RwLock};
7
8enum Queue<T> {
9 Bounded(limited::Queue<T>),
10 Unbounded(unlimited::Queue<T>),
11}
12
/// Selects which kind of queue `Dispatcher::subscribe` creates for each new subscriber.
enum DispatcherType {
    /// Every subscriber gets a growable queue.
    Unbounded,
    /// Every subscriber gets a queue capped at the given number of pending events.
    Bounded(usize),
}
17
18enum Event<T> {
19 Item(T),
20 ItemWait(T, Sender<()>),
21 Close,
22}
23
24impl<T> Queue<T> {
25 pub fn push(&self, item: T) {
26 match self {
27 Queue::Bounded(queue) => {
28 let _ignored = queue.try_push(item);
29 }
30 Queue::Unbounded(queue) => {
31 queue.push(item);
32 }
33 }
34 }
35
36 pub async fn pop(&self) -> T {
37 match self {
38 Queue::Bounded(queue) => queue.pop().await,
39 Queue::Unbounded(queue) => queue.pop().await,
40 }
41 }
42}
43
44pub struct Subscriber<T> {
45 queue: Arc<Queue<Event<T>>>,
46 ref_count: Arc<AtomicUsize>,
47}
48
49impl<T> Subscriber<T> {
50 pub async fn next(&self) -> Option<T> {
51 match self.queue.pop().await {
52 Event::Item(item) => Some(item),
53 Event::ItemWait(item, responder) => {
54 let _ignored = responder.send(());
55 Some(item)
56 }
57 Event::Close => None,
58 }
59 }
60}
61
62impl<T> Clone for Subscriber<T> {
63 fn clone(&self) -> Self {
64 let _ignored = self.ref_count.fetch_add(1, Ordering::Relaxed);
65 Self {
66 queue: self.queue.clone(),
67 ref_count: self.ref_count.clone(),
68 }
69 }
70}
71
72impl<T> Drop for Subscriber<T> {
73 fn drop(&mut self) {
74 let _ignored = self.ref_count.fetch_sub(1, Ordering::Relaxed);
75 }
76}
77
78pub struct Dispatcher<T: Send + 'static> {
79 dispatcher_type: DispatcherType,
80 subscribers: Arc<RwLock<Vec<Subscriber<T>>>>,
81}
82
83impl<T: Clone + Send + 'static> Dispatcher<T> {
84 pub fn new() -> Self {
85 Self {
86 dispatcher_type: DispatcherType::Unbounded,
87 subscribers: Arc::new(RwLock::new(Vec::new())),
88 }
89 }
90
91 pub fn new_bounded(limit: usize) -> Self {
92 Self {
93 dispatcher_type: DispatcherType::Bounded(limit),
94 subscribers: Arc::new(RwLock::new(Vec::new())),
95 }
96 }
97
98 pub async fn cleanup(&self) {
99 let mut subscribers = self.subscribers.write().await;
100 subscribers
101 .retain(|subscriber| subscriber.ref_count.load(Ordering::Relaxed) > 1)
102 }
103
104 pub async fn dispatch(&self, event: T) {
105 let subscribers = self.subscribers.read().await;
106 let mut cleanup = false;
107 for subscriber in subscribers.iter() {
108 if subscriber.ref_count.load(Ordering::Relaxed) > 1 {
109 subscriber.queue.push(Event::Item(event.clone()));
110 } else {
111 cleanup = true;
112 }
113 }
114 if cleanup {
115 drop(subscribers);
116 self.cleanup().await;
117 }
118 }
119
120 pub async fn dispatch_wait(&self, event: T) {
121 let subscribers = self.subscribers.read().await;
122 let mut receivers = Vec::new();
123 let mut cleanup = false;
124 for subscriber in subscribers.iter() {
125 if subscriber.ref_count.load(Ordering::Relaxed) > 1 {
126 let (responder, receiver) = oneshot::channel();
127 receivers.push(receiver);
128 subscriber
129 .queue
130 .push(Event::ItemWait(event.clone(), responder));
131 } else {
132 cleanup = true;
133 }
134 }
135
136 if cleanup {
137 drop(subscribers);
138 self.cleanup().await;
139 }
140
141 for fut in receivers.into_iter() {
142 let _ignored = fut.await;
143 }
144 }
145
146 pub async fn subscriber_count(&self) -> usize {
147 self.subscribers.read().await.len()
148 }
149
150 pub async fn subscribe(&self) -> Subscriber<T> {
151 let subscriber = match self.dispatcher_type {
152 DispatcherType::Bounded(limit) => Subscriber {
153 queue: Arc::new(Queue::Bounded(limited::Queue::new(limit))),
154 ref_count: Arc::new(AtomicUsize::new(1)),
155 },
156 DispatcherType::Unbounded => Subscriber {
157 queue: Arc::new(Queue::Unbounded(unlimited::Queue::new())),
158 ref_count: Arc::new(AtomicUsize::new(1)),
159 },
160 };
161 let mut subscribers = self.subscribers.write().await;
162 subscribers.push(subscriber.clone());
163 subscriber
164 }
165}
166
167impl<T: Clone + Send + 'static> Default for Dispatcher<T> {
168 fn default() -> Self {
169 Dispatcher::new()
170 }
171}
172
173impl<T: Send + 'static> Drop for Dispatcher<T> {
174 fn drop(&mut self) {
175 if let Ok(rt) = tokio::runtime::Handle::try_current() {
176 let subscribers = self.subscribers.clone();
177 rt.spawn(async move {
178 for subscriber in subscribers.read().await.iter() {
179 subscriber.queue.push(Event::Close);
180 }
181 });
182 } else {
183 for subscriber in self.subscribers.blocking_read().iter() {
186 subscriber.queue.push(Event::Close);
187 }
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::Mutex;

    use crate::Dispatcher;

    #[tokio::test]
    async fn it_works() {
        let dispatcher = Dispatcher::<i32>::new();
        let first = dispatcher.subscribe().await;
        let second = dispatcher.subscribe().await;
        let finished = Arc::new(Mutex::new(false));

        // Drain `first` on its own task while this task dispatches.
        let reader_finished = Arc::clone(&finished);
        let reader = tokio::spawn(async move {
            for expected in [42, 69] {
                assert_eq!(first.next().await, Some(expected));
            }
            *reader_finished.lock().await = true;
        });

        for value in [42, 69] {
            dispatcher.dispatch(value).await;
        }
        assert!(reader.await.is_ok());
        assert!(*finished.lock().await);

        // `second` has not read anything yet, so both events are still buffered, in order.
        for expected in [42, 69] {
            assert_eq!(second.next().await, Some(expected));
        }

        // Dropping the dispatcher closes every subscriber.
        drop(dispatcher);
        assert_eq!(second.next().await, None);
    }
}