Skip to main content

broadcaster/
lib.rs

1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4//! broadcaster provides a wrapper for any Stream and Sink implementing the mpsc pattern to enable
5//! broadcasting items. This means that any item sent will be received by every receiver, not just
6//! the first to check (like most mpmc streams). As an example:
7//! ```rust
8//! use broadcaster::BroadcastChannel;
9//!
10//! # use futures_executor::block_on;
11//! use futures_util::StreamExt;
12//!
13//! # block_on(async {
14//! let mut chan = BroadcastChannel::new();
15//! chan.send(&5i32).await?;
16//! assert_eq!(chan.next().await, Some(5));
17//!
18//! let mut chan2 = chan.clone();
19//! chan2.send(&6i32).await?;
20//! assert_eq!(chan.next().await, Some(6));
21//! assert_eq!(chan2.next().await, Some(6));
22//! # Ok::<(), futures_channel::mpsc::SendError>(())
23//! # }).unwrap();
24//! ```
25
26use futures_core::{future::*, stream::*, task::Poll};
27use futures_sink::Sink;
28use futures_util::sink::SinkExt;
29use futures_util::stream::StreamExt;
30use futures_util::future::try_join_all;
31use slab::Slab;
32use std::fmt::{self, Debug};
33use std::sync::Arc;
34
35#[cfg(not(feature = "default-channels"))]
36use std::sync::RwLock;
37
38#[cfg(feature = "default-channels")]
39use parking_lot::RwLock;
40
41#[cfg(feature = "default-channels")]
42use futures_channel::mpsc::*;
43use futures_util::task::Context;
44use std::pin::Pin;
45
46/// A broadcast channel, wrapping any clonable Stream and Sink to have every message sent to every
47/// receiver.
48pub struct BroadcastChannel<
49    T,
50    #[cfg(feature = "default-channels")] S = UnboundedSender<T>,
51    #[cfg(feature = "default-channels")] R = UnboundedReceiver<T>,
52    #[cfg(not(feature = "default-channels"))] S,
53    #[cfg(not(feature = "default-channels"))] R,
54> where
55    T: Send + Clone + 'static,
56    S: Send + Sync + Unpin + Clone + Sink<T>,
57    R: Unpin + Stream<Item = T>,
58{
59    senders: Arc<RwLock<Slab<S>>>,
60    sender_key: usize,
61    receiver: R,
62    ctor: Arc<dyn Fn() -> (S, R) + Send + Sync>,
63}
64
#[cfg(feature = "default-channels")]
impl<T: Send + Clone> BroadcastChannel<T> {
    /// Create a new unbounded channel. Requires the `default-channels` feature.
    ///
    /// The returned handle owns the first sender/receiver pair; clones obtain further pairs
    /// from the same `unbounded` constructor.
    pub fn new() -> Self {
        let (sender, receiver) = unbounded();

        // Register this handle's sender and remember where it lives so `Drop` can remove it.
        let mut registry = Slab::new();
        let sender_key = registry.insert(sender);

        Self {
            senders: Arc::new(RwLock::new(registry)),
            sender_key,
            receiver,
            ctor: Arc::new(unbounded),
        }
    }
}
80
#[cfg(feature = "default-channels")]
impl<T: Send + Clone> BroadcastChannel<T, Sender<T>, Receiver<T>> {
    /// Create a new bounded channel with a specific capacity. Requires the `default-channels`
    /// feature.
    ///
    /// `cap` is forwarded to `channel` both for the initial sender/receiver pair and for every
    /// pair created later when the handle is cloned.
    pub fn with_cap(cap: usize) -> Self {
        let (tx, rx) = channel(cap);
        let mut slab = Slab::new();
        let sender_key = slab.insert(tx);
        Self {
            senders: Arc::new(RwLock::new(slab)),
            sender_key,
            receiver: rx,
            ctor: Arc::new(move || channel(cap)),
        }
    }

    /// Try sending a value on a bounded channel. Requires the `default-channels` feature.
    ///
    /// A clone of `item` is offered to each registered sender in turn.
    ///
    /// # Errors
    ///
    /// Returns the first `TrySendError` reported by a sender. Iteration stops at that point, so
    /// senders visited earlier have already accepted the item while later ones are not tried.
    pub fn try_send(&self, item: &T) -> Result<(), TrySendError<T>> {
        // `senders()` takes the read lock only long enough to snapshot the slab, so the
        // feature-specific locking logic lives in exactly one place.
        for (_, sender) in self.senders().iter_mut() {
            sender.try_send(item.clone())?;
        }
        Ok(())
    }
}
110
111impl<T, S, R> BroadcastChannel<T, S, R>
112where
113    T: Send + Clone + 'static,
114    S: Send + Sync + Unpin + Clone + Sink<T>,
115    R: Unpin + Stream<Item = T>,
116{
117    /// Construct a new channel from any Sink and Stream. For proper functionality, cloning a
118    /// Sender will create a new sink that also sends data to Receiver.
119    pub fn with_ctor(ctor: Arc<dyn Fn() -> (S, R) + Send + Sync>) -> Self {
120        let (tx, rx) = ctor();
121        let mut slab = Slab::new();
122        let sender_key = slab.insert(tx);
123        Self {
124            senders: Arc::new(RwLock::new(slab)),
125            sender_key,
126            receiver: rx,
127            ctor,
128        }
129    }
130
131    /// Send an item to all receivers in the channel, including this one. This is because
132    /// futures-channel does not support comparing a sender and receiver. If this is not the
133    /// desired behavior, you must handle it yourself.
134    pub async fn send(&self, item: &T) -> Result<(), S::Error> {
135        let mut senders = self.senders();
136        try_join_all(senders.iter_mut().map(|(_, s)| s.send(item.clone()))).await?;
137        Ok(())
138    }
139
140    /// Receive a single value from the channel.
141    pub fn recv(&mut self) -> impl Future<Output = Option<T>> + '_ {
142        self.next()
143    }
144
145    /// Internal helper method to get a copy of the senders
146    fn senders(&self) -> Slab<S> {
147        // can't be split up because of how async/await works
148        #[cfg(feature = "parking-lot")]
149        let senders: Slab<S> = Slab::clone(&*self.senders.read());
150
151        #[cfg(not(feature = "parking-lot"))]
152        let senders: Slab<S> = Slab::clone(&*self.senders.read().unwrap());
153
154        senders
155    }
156}
157
158impl<T, S, R> Clone for BroadcastChannel<T, S, R>
159where
160    T: Send + Clone + 'static,
161    S: Send + Sync + Unpin + Clone + Sink<T>,
162    R: Unpin + Stream<Item = T>,
163{
164    fn clone(&self) -> Self {
165        let (tx, rx) = (self.ctor)();
166        #[cfg(feature = "parking-lot")]
167        let sender_key = self.senders.write().insert(tx);
168
169        #[cfg(not(feature = "parking-lot"))]
170        let sender_key = self.senders.write().unwrap().insert(tx);
171
172        Self {
173            senders: self.senders.clone(),
174            sender_key,
175            receiver: rx,
176            ctor: self.ctor.clone(),
177        }
178    }
179}
180
181impl<T, S, R> Drop for BroadcastChannel<T, S, R>
182where
183    T: Send + Clone + 'static,
184    S: Send + Sync + Unpin + Clone + Sink<T>,
185    R: Unpin + Stream<Item = T>,
186{
187    fn drop(&mut self) {
188        #[cfg(feature = "parking-lot")]
189        self.senders.write().remove(self.sender_key);
190
191        #[cfg(not(feature = "parking-lot"))]
192        self.senders.write().unwrap().remove(self.sender_key);
193    }
194}
195
196impl<T, S, R> Debug for BroadcastChannel<T, S, R>
197where
198    T: Send + Clone + 'static,
199    S: Send + Sync + Unpin + Clone + Debug + Sink<T>,
200    R: Unpin + Debug + Stream<Item = T>,
201{
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        f.debug_struct("BroadcastChannel")
204            .field("senders", &self.senders)
205            .field("sender_key", &self.sender_key)
206            .field("receiver", &self.receiver)
207            .finish()
208    }
209}
210
211impl<T, S, R> Stream for BroadcastChannel<T, S, R>
212where
213    T: Send + Clone + 'static,
214    S: Send + Sync + Unpin + Clone + Sink<T>,
215    R: Unpin + Stream<Item = T>,
216{
217    type Item = T;
218
219    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
220        (&mut self.receiver).poll_next_unpin(cx)
221    }
222}
223
224impl<T, S, R> Sink<T> for &BroadcastChannel<T, S, R>
225where
226    T: Send + Clone + 'static,
227    S: Send + Sync + Unpin + Clone + Sink<T>,
228    R: Unpin + Stream<Item = T>,
229{
230    type Error = S::Error;
231
232    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
233        (*self)
234            .senders()
235            .iter_mut()
236            .map(|(_, sender)| Pin::new(sender).poll_ready(cx))
237            .find_map(|poll| match poll {
238                Poll::Ready(Err(_)) | Poll::Pending => Some(poll),
239                _ => None,
240            })
241            .or_else(|| Some(Poll::Ready(Ok(()))))
242            .unwrap()
243    }
244
245    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
246        (*self)
247            .senders()
248            .iter_mut()
249            .map(|(_, sender)| Pin::new(sender).start_send(item.clone()))
250            .collect::<Result<_, _>>()
251    }
252
253    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
254        (*self)
255            .senders()
256            .iter_mut()
257            .map(|(_, sender)| Pin::new(sender).poll_flush(cx))
258            .find_map(|poll| match poll {
259                Poll::Ready(Err(_)) | Poll::Pending => Some(poll),
260                _ => None,
261            })
262            .or_else(|| Some(Poll::Ready(Ok(()))))
263            .unwrap()
264    }
265
266    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
267        (*self)
268            .senders()
269            .iter_mut()
270            .map(|(_, sender)| Pin::new(sender).poll_close(cx))
271            .find_map(|poll| match poll {
272                Poll::Ready(Err(_)) | Poll::Pending => Some(poll),
273                _ => None,
274            })
275            .or_else(|| Some(Poll::Ready(Ok(()))))
276            .unwrap()
277    }
278}
279
280impl<T, S, R> Sink<T> for BroadcastChannel<T, S, R>
281    where
282        T: Send + Clone + 'static,
283        S: Send + Sync + Unpin + Clone + Sink<T>,
284        R: Unpin + Stream<Item = T>,
285{
286    type Error = S::Error;
287
288    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
289        Sink::poll_ready(Pin::new(&mut &*self), cx)
290    }
291
292    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
293        Sink::start_send(Pin::new(&mut &*self), item)
294    }
295
296    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
297        Sink::poll_flush(Pin::new(&mut &*self), cx)
298    }
299
300    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
301        Sink::poll_close(Pin::new(&mut &*self), cx)
302    }
303}
304
#[cfg(all(feature = "default-channels", test))]
mod test {
    use super::BroadcastChannel;
    use futures_channel::mpsc::SendError;
    use futures_core::future::Future;
    use futures_executor::block_on;
    use futures_util::future::{ready, FutureExt};
    use futures_util::{SinkExt, StreamExt};

    // Compile-time probes for auto traits, by type and by value.
    fn assert_impl_send<T: Send>() {}
    fn assert_impl_sync<T: Sync>() {}
    fn assert_val_impl_send<T: Send>(_val: &T) {}
    fn assert_val_impl_sync<T: Sync>(_val: &T) {}

    /// A value sent on a handle is received by that same handle.
    #[test]
    fn send_next() {
        let mut channel = BroadcastChannel::new();
        block_on(channel.send(&5)).unwrap();
        let received = block_on(channel.next());
        assert_eq!(received, Some(5));
    }

    /// Exercises some of the extension methods from StreamExt and SinkExt.
    #[test]
    fn split() {
        fn plus_1(num: usize) -> impl Future<Output = Result<usize, SendError>> {
            ready(Ok(num + 1))
        }

        let original = BroadcastChannel::new();
        let duplicate = original.clone();

        let (tx_half, rx_half) = original.split();
        let mut incrementing = tx_half.with(plus_1);
        block_on(incrementing.send(5)).unwrap();
        block_on(duplicate.send(&10)).unwrap();

        let collected = block_on(rx_half.take(2).collect::<Vec<_>>());
        assert_eq!(collected, vec![6, 10]);
    }

    /// The whole send/receive round trip completes without ever returning `Pending`.
    #[test]
    fn now_or_never() {
        let scenario = async {
            let mut first = BroadcastChannel::new();
            first.send(&5i32).await?;
            assert_eq!(first.next().await, Some(5));

            let mut second = first.clone();
            second.send(&6i32).await?;
            assert_eq!(first.next().await, Some(6));
            assert_eq!(second.next().await, Some(6));
            Ok::<(), futures_channel::mpsc::SendError>(())
        };
        scenario.now_or_never().unwrap().unwrap();
    }

    /// `try_send` on a bounded channel reaches both the original and the cloned handle.
    #[test]
    fn try_send() {
        let scenario = async {
            let mut first = BroadcastChannel::with_cap(2);
            first.try_send(&5i32)?;
            assert_eq!(first.next().await, Some(5));

            let mut second = first.clone();
            second.try_send(&6i32)?;
            assert_eq!(first.next().await, Some(6));
            assert_eq!(second.next().await, Some(6));
            Ok::<(), futures_channel::mpsc::TrySendError<i32>>(())
        };
        scenario.now_or_never().unwrap().unwrap();
    }

    /// Two handles both receive a broadcast item, and the driving future is `Send + Sync`.
    #[test]
    fn recv_two() {
        let scenario = async {
            let mut first = BroadcastChannel::new();
            first.send(&5i32).await?;
            assert_eq!(first.next().await, Some(5));

            let mut second = first.clone();
            second.send(&6i32).await?;
            assert_eq!(first.next().await, Some(6));
            assert_eq!(second.next().await, Some(6));
            Ok::<(), futures_channel::mpsc::SendError>(())
        };
        assert_val_impl_send(&scenario);
        assert_val_impl_sync(&scenario);
        block_on(scenario).unwrap();
    }

    /// The default channel type is both `Send` and `Sync`.
    #[test]
    fn send_sync() {
        assert_impl_send::<BroadcastChannel<i32>>();
        assert_impl_sync::<BroadcastChannel<i32>>();
    }
}
395    fn send_sync() {
396        assert_impl_send::<BroadcastChannel<i32>>();
397        assert_impl_sync::<BroadcastChannel<i32>>();
398    }
399}