async_compatibility_layer/async_primitives/
broadcast.rs

1#![allow(clippy::must_use_candidate, clippy::module_name_repetitions)]
2use crate::art::async_block_on;
3use crate::channel::{SendError, UnboundedReceiver, UnboundedRecvError, UnboundedSender};
4use async_lock::RwLock;
5
6use std::{
7    collections::HashMap,
8    fmt::Debug,
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        Arc,
12    },
13};
14
15/// Internals for a broadcast queue sender
16struct BroadcastSenderInner<T> {
17    /// Atomic int used for assigning ids
18    count: AtomicUsize,
19    /// Map of IDs to channels
20    outputs: RwLock<HashMap<usize, UnboundedSender<T>>>,
21}
22
23/// Public interface for a broadcast queue sender
24#[derive(Clone)]
25pub struct BroadcastSender<T> {
26    /// Underlying shared implementation details
27    inner: Arc<BroadcastSenderInner<T>>,
28}
29
30impl<T> Debug for BroadcastSender<T> {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("BroadcastSender")
33            .field("inner", &"inner")
34            .finish()
35    }
36}
37
38impl<T> BroadcastSender<T>
39where
40    T: Clone,
41{
42    /// Asynchronously sends a value to all connected receivers
43    ///
44    /// # Errors
45    ///
46    /// Will return `Err` if one of the downstream receivers was disconnected without being properly
47    /// dropped.
48    pub async fn send_async(&self, item: T) -> Result<(), SendError<T>> {
49        let map = self.inner.outputs.read().await;
50        for sender in map.values() {
51            sender.send(item.clone()).await?;
52        }
53        Ok(())
54    }
55
56    /// Asynchronously creates a new handle
57    pub async fn handle_async(&self) -> BroadcastReceiver<T> {
58        let id = self.inner.count.fetch_add(1, Ordering::SeqCst);
59        let (send, recv) = crate::channel::unbounded();
60        let mut map = self.inner.outputs.write().await;
61        map.insert(id, send);
62        BroadcastReceiver {
63            id,
64            output: recv,
65            handle: self.clone(),
66        }
67    }
68
69    /// Synchronously creates a new handle
70    pub fn handle_sync(&self) -> BroadcastReceiver<T> {
71        async_block_on(self.handle_async())
72    }
73}
74
75/// Broadcast queue receiver
76pub struct BroadcastReceiver<T> {
77    /// ID for this receiver
78    id: usize,
79    /// Queue output
80    output: UnboundedReceiver<T>,
81    /// Handle to the sender internals
82    handle: BroadcastSender<T>,
83}
84
85impl<T> BroadcastReceiver<T>
86where
87    T: Clone,
88{
89    /// Asynchronously receives a value
90    ///
91    /// # Errors
92    ///
93    /// Will return `Err` if the upstream sender has been disconnected.
94    pub async fn recv_async(&mut self) -> Result<T, UnboundedRecvError> {
95        self.output.recv().await
96    }
97
98    /// Returns a value, if one is available
99    pub fn try_recv(&mut self) -> Option<T> {
100        self.output.try_recv().ok()
101    }
102
103    /// Asynchronously clones this handle
104    pub async fn clone_async(&self) -> Self {
105        self.handle.handle_async().await
106    }
107}
108
109impl<T> Drop for BroadcastReceiver<T> {
110    /// Remove self from sender's map
111    fn drop(&mut self) {
112        let mut map = async_block_on(self.handle.inner.outputs.write());
113        map.remove(&self.id);
114    }
115}
116
117impl<T> Clone for BroadcastReceiver<T>
118where
119    T: Clone,
120{
121    fn clone(&self) -> Self {
122        async_block_on(self.clone_async())
123    }
124}
125
126/// Creates a sender, receiver pair
127pub fn channel<T: Clone>() -> (BroadcastSender<T>, BroadcastReceiver<T>) {
128    let inner = BroadcastSenderInner {
129        count: AtomicUsize::from(0),
130        outputs: RwLock::new(HashMap::new()),
131    };
132    let input = BroadcastSender {
133        inner: Arc::new(inner),
134    };
135    let output = input.handle_sync();
136    (input, output)
137}