1use std::{
101 collections::HashMap,
102 sync::{Arc, RwLock},
103 thread,
104};
105
106use crossbeam::channel::{bounded, unbounded, Receiver, Sender, TryRecvError};
107
108struct BusInner<T: Clone> {
109 senders: HashMap<usize, Sender<T>>,
110 next_id: usize,
111}
112
113impl<T: Clone> BusInner<T> {
114 pub fn add_rx(&mut self) -> Receiver<T> {
115 let (sender, receiver) = unbounded::<T>();
116 self.senders.insert(self.next_id, sender);
117 self.next_id += 1;
118 receiver
119 }
120
121 pub fn broadcast(&self, event: T) -> Vec<usize> {
122 let mut disconnected = Vec::with_capacity(0);
123
124 if let Some(((last_id, last_sender), the_rest)) = self.get_sorted_senders().split_last() {
125 for (id, sender) in the_rest.iter() {
126 if sender.send(event.clone()).is_err() {
127 disconnected.push(**id);
128 }
129 }
130
131 if last_sender.send(event).is_err() {
132 disconnected.push(**last_id);
133 };
134 }
135
136 disconnected
137 }
138
139 pub fn remove_senders(&mut self, ids: &[usize]) {
140 for id in ids {
141 self.senders.remove(&id);
142 }
143 }
144
145 fn get_sorted_senders(&self) -> Vec<(&usize, &Sender<T>)> {
146 let mut senders = self.senders.iter().collect::<Vec<(&usize, &Sender<T>)>>();
147 senders.sort_by_key(|(id, _)| **id);
148 senders
149 }
150}
151
152impl<T: Clone> Default for BusInner<T> {
153 fn default() -> Self {
154 BusInner {
155 senders: Default::default(),
156 next_id: 0,
157 }
158 }
159}
160
161#[derive(Clone)]
162pub struct Bus<T: Clone> {
163 inner: Arc<RwLock<BusInner<T>>>,
164}
165
166impl<T: Clone> Bus<T> {
167 pub fn new() -> Self {
169 Bus {
170 inner: Default::default(),
171 }
172 }
173
174 pub fn add_rx(&self) -> Receiver<T> {
176 self.inner.write().expect("Lock was poisoned").add_rx()
177 }
178
179 pub fn broadcast(&self, event: T) {
181 let disconnected = {
182 self.inner
183 .read()
184 .expect("Lock was poisoned")
185 .broadcast(event)
186 };
187
188 if !disconnected.is_empty() {
189 self.inner
190 .write()
191 .expect("Lock was poisoned")
192 .remove_senders(&disconnected);
193 }
194 }
195}
196
197impl<T: Clone> Default for Bus<T> {
198 fn default() -> Self {
199 Bus::new()
200 }
201}
202
/// Boxed, thread-transferable callback that is invoked once per delivered event.
type BoxedFn<T> = Box<dyn FnMut(T) + Send>;
204
205struct DropSignal {
206 tx_signal: Sender<()>,
207}
208
209impl DropSignal {
210 pub fn new(tx_signal: Sender<()>) -> Arc<Self> {
211 Arc::new(DropSignal { tx_signal })
212 }
213}
214
215impl Drop for DropSignal {
216 fn drop(&mut self) {
217 let _ = self.tx_signal.send(());
218 }
219}
220
221#[derive(Clone)]
222pub struct Subscription {
223 terminate: Arc<DropSignal>,
224}
225
226impl Subscription {
227 pub fn new(terminate: Sender<()>) -> Self {
228 Subscription {
229 terminate: DropSignal::new(terminate),
230 }
231 }
232}
233
234pub trait SubscribeToReader<T: Send + 'static> {
235 #[must_use]
236 fn subscribe_on_thread(&self, callback: BoxedFn<T>) -> Subscription;
237 fn subscribe(&self, callback: BoxedFn<T>);
238}
239
240impl<T: Send + 'static> SubscribeToReader<T> for Receiver<T> {
241 #[must_use]
242 fn subscribe_on_thread(&self, mut callback: BoxedFn<T>) -> Subscription {
243 let (terminate_tx, terminate_rx) = bounded::<()>(0);
244 let receiver = self.clone();
245
246 thread::Builder::new()
247 .name("Receiver subscription thread".to_string())
248 .spawn(move || loop {
249 for event in receiver.try_iter() {
250 callback(event);
251 }
252
253 match terminate_rx.try_recv() {
254 Err(TryRecvError::Empty) => {}
255 _ => return,
256 }
257 })
258 .expect("Could not start Receiver subscription thread");
259
260 Subscription::new(terminate_tx)
261 }
262
263 fn subscribe(&self, mut callback: BoxedFn<T>) {
264 for event in self.iter() {
265 callback(event);
266 }
267 }
268}
269
270impl<T: Clone + Send + 'static> SubscribeToReader<T> for Bus<T> {
271 #[must_use]
272 fn subscribe_on_thread(&self, callback: BoxedFn<T>) -> Subscription {
273 self.add_rx().subscribe_on_thread(callback)
274 }
275
276 fn subscribe(&self, callback: BoxedFn<T>) {
277 self.add_rx().subscribe(callback)
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crossbeam::channel::RecvTimeoutError;
    use std::time::Duration;

    /// How long a test waits for a subscription thread to forward an event.
    const WAIT: Duration = Duration::from_millis(100);

    #[derive(Clone, PartialEq, Debug)]
    struct Something;

    #[derive(Clone, PartialEq, Debug)]
    enum Event {
        Start,
        Stop(Vec<Something>),
    }

    /// Asserts that the next event forwarded into `rx` equals `expected`.
    fn assert_next_event(rx: &Receiver<Event>, expected: Event) {
        match rx.recv_timeout(WAIT) {
            Ok(received) => assert_eq!(received, expected),
            Err(_) => panic!("Event not received"),
        }
    }

    /// Asserts that the forwarding callback has been dropped, i.e. the
    /// subscription thread stopped and released its sender.
    fn assert_subscription_stopped(rx: &Receiver<Event>) {
        match rx.recv_timeout(WAIT) {
            Err(RecvTimeoutError::Disconnected) => {}
            _ => panic!("Subscription has been dropped so we should not get any events"),
        }
    }

    /// Builds a callback that forwards every event into `tx`.
    fn forward_to(tx: Sender<Event>) -> BoxedFn<Event> {
        Box::new(move |event| tx.send(event).unwrap())
    }

    #[test]
    fn subscribe_on_thread() {
        let bus = Bus::<Event>::new();

        // Extra subscribers that ignore every event; they must not interfere.
        let _ignoring_a = bus.subscribe_on_thread(Box::new(move |_event| {}));
        let _ignoring_b = bus.subscribe_on_thread(Box::new(move |_event| {}));

        let (tx_test, rx_test) = unbounded::<Event>();

        {
            let _forwarding = bus.subscribe_on_thread(forward_to(tx_test));

            bus.broadcast(Event::Start);
            bus.broadcast(Event::Stop(vec![Something]));

            assert_next_event(&rx_test, Event::Start);
            assert_next_event(&rx_test, Event::Stop(vec![Something]));
        }

        bus.broadcast(Event::Start);
        assert_subscription_stopped(&rx_test);
    }

    #[test]
    fn clone_subscription_without_dropping() {
        let bus = Bus::<Event>::new();
        let (tx_test, rx_test) = unbounded::<Event>();

        {
            let subscription = bus.subscribe_on_thread(forward_to(tx_test));

            // Dropping a clone must not stop the thread while the original lives.
            {
                // The clone is the whole point of this test.
                #[allow(clippy::redundant_clone)]
                let _short_lived = subscription.clone();
            }

            bus.broadcast(Event::Start);
            assert_next_event(&rx_test, Event::Start);
        }

        bus.broadcast(Event::Start);
        assert_subscription_stopped(&rx_test);
    }
}