par_stream/
broadcast.rs

1use crate::{
2    common::*, config::BufSize, index_stream::IndexStreamExt as _, rt, stream::StreamExt as _,
3    utils,
4};
5use tokio::sync::{oneshot, watch};
6
7/// The build type returned from [broadcast()](crate::par_stream::ParStreamExt::broadcast).
8///
9/// It is used to register new broadcast receivers. Each receiver consumes copies
10/// of items of the stream. The builder is finished by `guard.build()` so that
11/// registered receivers can start consuming data. Otherwise, the receivers
12/// take empty input.
13#[derive(Debug)]
14pub struct BroadcastBuilder<T> {
15    pub(super) buf_size: Option<usize>,
16    pub(super) ready_rx: watch::Receiver<()>,
17    pub(super) senders_tx: Option<oneshot::Sender<Vec<flume::Sender<(usize, T)>>>>,
18    pub(super) senders: Option<Vec<flume::Sender<(usize, T)>>>,
19}
20
21impl<T> BroadcastBuilder<T>
22where
23    T: 'static + Send + Clone,
24{
25    pub fn new<B, St>(stream: St, buf_size: B, send_all: bool) -> BroadcastBuilder<T>
26    where
27        St: 'static + Send + Stream<Item = T>,
28        B: Into<BufSize>,
29    {
30        let (senders_tx, senders_rx) = oneshot::channel();
31        let (ready_tx, ready_rx) = watch::channel(());
32
33        rt::spawn(async move {
34            // wait for receiver list to be ready
35            let senders: Vec<flume::Sender<(usize, T)>> = match senders_rx.await {
36                Ok(senders) => senders,
37                Err(_) => return,
38            };
39
40            // tell subscribers to be ready
41            if ready_tx.send(()).is_err() {
42                return;
43            }
44
45            let num_senders = senders.len();
46
47            match num_senders {
48                0 => {
49                    // fall through for zero senders
50                }
51                1 => {
52                    // fast path for single sender
53                    let sender = senders.into_iter().next().unwrap();
54                    let _ = stream.enumerate().map(Ok).forward(sender.into_sink()).await;
55                }
56                _ => {
57                    // merge senders into a sink
58                    let sink =
59                        futures::sink::unfold(senders, |senders, item: (usize, T)| async move {
60                            // let each sender sends a copy of the item
61                            let futures: stream::FuturesUnordered<_> = senders
62                                .into_iter()
63                                .map(|tx| {
64                                    let item = item.clone();
65
66                                    async move {
67                                        let result = tx.send_async(item).await;
68
69                                        // if sending is successful, return the sender back
70                                        result.map(move |()| tx)
71                                    }
72                                })
73                                .collect();
74
75                            // collect senders back
76                            let senders: Vec<_> = futures
77                                .filter_map(|tx| future::ready(tx.ok()))
78                                .collect()
79                                .await;
80
81                            // finish sink if
82                            // case 1: send_all == true, no senders fail
83                            // case 2: send_all == false, there are successful sender(s)
84                            let n_remaining_senders = senders.len();
85
86                            if (!send_all && n_remaining_senders > 0)
87                                || (send_all && (n_remaining_senders == num_senders))
88                            {
89                                Ok(senders)
90                            } else {
91                                Err(flume::SendError(()))
92                            }
93                        });
94
95                    let _ = stream.enumerate().map(Ok).forward(sink).await;
96                }
97            }
98        });
99
100        BroadcastBuilder {
101            buf_size: buf_size.into().get(),
102            ready_rx,
103            senders_tx: Some(senders_tx),
104            senders: Some(vec![]),
105        }
106    }
107
108    /// Creates a new receiver.
109    pub fn register(&mut self) -> BroadcastStream<T> {
110        let Self {
111            buf_size,
112            ref ready_rx,
113            ref mut senders,
114            ..
115        } = *self;
116        let senders = senders.as_mut().unwrap();
117        let mut ready_rx = ready_rx.clone();
118
119        let (tx, rx) = utils::channel(buf_size);
120        senders.push(tx);
121
122        let stream = rx
123            .into_stream()
124            .reorder_enumerated()
125            .wait_until(async move { ready_rx.changed().await.is_ok() })
126            .boxed();
127
128        BroadcastStream { stream }
129    }
130
131    /// Finish the builder to start broadcasting.
132    pub fn build(mut self) {
133        let senders_tx = self.senders_tx.take().unwrap();
134        let senders = self.senders.take().unwrap();
135        senders_tx.send(senders).unwrap();
136    }
137}
138
139/// The receiver that consumes broadcasted messages from the stream.
140#[pin_project]
141pub struct BroadcastStream<T> {
142    #[pin]
143    pub(super) stream: BoxStream<'static, T>,
144}
145
146impl<T> Stream for BroadcastStream<T> {
147    type Item = T;
148
149    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150        self.project().stream.poll_next(cx)
151    }
152}
153
154// tests
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{par_stream::ParStreamExt as _, utils::async_test};

    async_test! {
        async fn broadcast_test() {
            let mut builder = stream::iter(0..).broadcast(2, true);
            let rx1 = builder.register();
            let rx2 = builder.register();
            builder.build();

            let (ret1, ret2): (Vec<_>, Vec<_>) =
                join!(rx1.take(100).collect(), rx2.take(100).collect());

            // Compare whole vectors so a receiver that ends early (fewer than
            // 100 items) fails the test instead of being silently truncated.
            let expected: Vec<_> = (0..100).collect();
            assert_eq!(ret1, expected);
            assert_eq!(ret2, expected);
        }

        async fn broadcast_and_drop_receiver_test() {
            {
                let mut builder = stream::iter(0..).broadcast(2, false);
                let rx1 = builder.register();
                let rx2 = builder.register();
                builder.build();

                drop(rx2);

                // With `send_all == false` the surviving receiver still gets every item.
                let vec: Vec<_> = rx1.take(100).collect().await;
                let expected: Vec<_> = (0..100).collect();
                assert_eq!(vec, expected);
            }

            {
                let mut builder = stream::iter(0..).broadcast(2, true);
                let mut rx1 = builder.register();
                let rx2 = builder.register();

                // Drop rx2 before `build()` so it is already gone when the first
                // item is broadcast. Dropping it afterwards races the background
                // task, which may deliver several items to both receivers first.
                drop(rx2);
                builder.build();

                // With `send_all == true` the first failed send ends the broadcast.
                assert_eq!(rx1.next().await, Some(0));
                assert!(rx1.next().await.is_none());
            }
        }
    }
}