1use crate::{
2 common::*, config::BufSize, index_stream::IndexStreamExt as _, rt, stream::StreamExt as _,
3 utils,
4};
5use tokio::sync::{oneshot, watch};
6
/// Builder that fans a single source stream out to several [`BroadcastStream`]
/// receivers.
///
/// Receivers are added with `register`; `build` hands them to the background
/// task spawned by `new`, which then starts forwarding items.
#[derive(Debug)]
pub struct BroadcastBuilder<T> {
    /// Optional capacity handed to `utils::channel` for each registered receiver.
    pub(super) buf_size: Option<usize>,
    /// Readiness signal fired by the background task once it has received the
    /// senders; every registered stream holds a clone and waits on it.
    pub(super) ready_rx: watch::Receiver<()>,
    /// One-shot channel that `build` uses to pass the senders to the background task.
    pub(super) senders_tx: Option<oneshot::Sender<Vec<flume::Sender<(usize, T)>>>>,
    /// Index-tagged senders for the receivers registered so far; taken by `build`.
    pub(super) senders: Option<Vec<flume::Sender<(usize, T)>>>,
}
20
21impl<T> BroadcastBuilder<T>
22where
23 T: 'static + Send + Clone,
24{
25 pub fn new<B, St>(stream: St, buf_size: B, send_all: bool) -> BroadcastBuilder<T>
26 where
27 St: 'static + Send + Stream<Item = T>,
28 B: Into<BufSize>,
29 {
30 let (senders_tx, senders_rx) = oneshot::channel();
31 let (ready_tx, ready_rx) = watch::channel(());
32
33 rt::spawn(async move {
34 let senders: Vec<flume::Sender<(usize, T)>> = match senders_rx.await {
36 Ok(senders) => senders,
37 Err(_) => return,
38 };
39
40 if ready_tx.send(()).is_err() {
42 return;
43 }
44
45 let num_senders = senders.len();
46
47 match num_senders {
48 0 => {
49 }
51 1 => {
52 let sender = senders.into_iter().next().unwrap();
54 let _ = stream.enumerate().map(Ok).forward(sender.into_sink()).await;
55 }
56 _ => {
57 let sink =
59 futures::sink::unfold(senders, |senders, item: (usize, T)| async move {
60 let futures: stream::FuturesUnordered<_> = senders
62 .into_iter()
63 .map(|tx| {
64 let item = item.clone();
65
66 async move {
67 let result = tx.send_async(item).await;
68
69 result.map(move |()| tx)
71 }
72 })
73 .collect();
74
75 let senders: Vec<_> = futures
77 .filter_map(|tx| future::ready(tx.ok()))
78 .collect()
79 .await;
80
81 let n_remaining_senders = senders.len();
85
86 if (!send_all && n_remaining_senders > 0)
87 || (send_all && (n_remaining_senders == num_senders))
88 {
89 Ok(senders)
90 } else {
91 Err(flume::SendError(()))
92 }
93 });
94
95 let _ = stream.enumerate().map(Ok).forward(sink).await;
96 }
97 }
98 });
99
100 BroadcastBuilder {
101 buf_size: buf_size.into().get(),
102 ready_rx,
103 senders_tx: Some(senders_tx),
104 senders: Some(vec![]),
105 }
106 }
107
108 pub fn register(&mut self) -> BroadcastStream<T> {
110 let Self {
111 buf_size,
112 ref ready_rx,
113 ref mut senders,
114 ..
115 } = *self;
116 let senders = senders.as_mut().unwrap();
117 let mut ready_rx = ready_rx.clone();
118
119 let (tx, rx) = utils::channel(buf_size);
120 senders.push(tx);
121
122 let stream = rx
123 .into_stream()
124 .reorder_enumerated()
125 .wait_until(async move { ready_rx.changed().await.is_ok() })
126 .boxed();
127
128 BroadcastStream { stream }
129 }
130
131 pub fn build(mut self) {
133 let senders_tx = self.senders_tx.take().unwrap();
134 let senders = self.senders.take().unwrap();
135 senders_tx.send(senders).unwrap();
136 }
137}
138
/// Receiving half of a broadcast, returned by `BroadcastBuilder::register`.
///
/// Yields the items that the broadcast task forwards to this receiver.
#[pin_project]
pub struct BroadcastStream<T> {
    /// Type-erased inner stream assembled in `BroadcastBuilder::register`.
    #[pin]
    pub(super) stream: BoxStream<'static, T>,
}
145
146impl<T> Stream for BroadcastStream<T> {
147 type Item = T;
148
149 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150 self.project().stream.poll_next(cx)
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{par_stream::ParStreamExt as _, utils::async_test};

    async_test! {
        /// Both receivers must observe exactly the first 100 source items, in order.
        async fn broadcast_test() {
            let mut builder = stream::iter(0..).broadcast(2, true);
            let rx1 = builder.register();
            let rx2 = builder.register();
            builder.build();

            let (ret1, ret2): (Vec<_>, Vec<_>) =
                join!(rx1.take(100).collect(), rx2.take(100).collect());

            // Compare whole vectors: zipping against `0..100` would truncate to
            // the shorter side and let a receiver that ended early pass.
            let expected: Vec<_> = (0..100).collect();
            assert_eq!(ret1, expected);
            assert_eq!(ret2, expected);
        }

        /// Dropping one receiver keeps the other alive when `send_all` is false,
        /// and stops the broadcast after the first item when it is true.
        async fn broadcast_and_drop_receiver_test() {
            {
                let mut builder = stream::iter(0..).broadcast(2, false);
                let rx1 = builder.register();
                let rx2 = builder.register();
                builder.build();

                drop(rx2);

                let vec: Vec<_> = rx1.take(100).collect().await;
                let expected: Vec<_> = (0..100).collect();
                assert_eq!(vec, expected);
            }

            {
                let mut builder = stream::iter(0..).broadcast(2, true);
                let mut rx1 = builder.register();
                let rx2 = builder.register();
                builder.build();

                drop(rx2);
                assert_eq!(rx1.next().await, Some(0));
                assert!(rx1.next().await.is_none());
            }
        }
    }
}