1use futures::stream::{self, StreamExt};
14use tokio::sync::{broadcast, mpsc};
15
16use crate::source::Source;
17
18pub struct BroadcastHub<T: Clone + Send + 'static> {
22 sender: broadcast::Sender<T>,
23}
24
25impl<T: Clone + Send + 'static> BroadcastHub<T> {
26 pub fn new(buffer_size: usize) -> Self {
27 assert!(buffer_size >= 1, "buffer_size must be >= 1");
28 let (sender, _rx) = broadcast::channel(buffer_size);
29 Self { sender }
30 }
31
32 pub fn attach(&self, source: Source<T>) {
35 let tx = self.sender.clone();
36 tokio::spawn(async move {
37 let mut s = source.into_boxed();
38 while let Some(item) = s.next().await {
39 let _ = tx.send(item); }
41 });
42 }
43
44 pub fn consumer(&self) -> Source<T> {
48 let rx = self.sender.subscribe();
49 let stream = stream::unfold(rx, |mut rx| async move {
50 loop {
51 match rx.recv().await {
52 Ok(item) => return Some((item, rx)),
53 Err(broadcast::error::RecvError::Lagged(_)) => continue,
54 Err(broadcast::error::RecvError::Closed) => return None,
55 }
56 }
57 });
58 Source { inner: stream.boxed() }
59 }
60
61 pub fn consumer_count(&self) -> usize {
63 self.sender.receiver_count()
64 }
65}
66
67pub struct MergeHub<T: Send + 'static> {
71 sender: mpsc::UnboundedSender<T>,
72 receiver: parking_lot::Mutex<Option<mpsc::UnboundedReceiver<T>>>,
74}
75
76impl<T: Send + 'static> Default for MergeHub<T> {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl<T: Send + 'static> MergeHub<T> {
83 pub fn new() -> Self {
84 let (tx, rx) = mpsc::unbounded_channel();
85 Self { sender: tx, receiver: parking_lot::Mutex::new(Some(rx)) }
86 }
87
88 pub fn attach(&self, source: Source<T>) {
90 let tx = self.sender.clone();
91 tokio::spawn(async move {
92 let mut s = source.into_boxed();
93 while let Some(item) = s.next().await {
94 if tx.send(item).is_err() {
95 return;
96 }
97 }
98 });
99 }
100
101 pub fn source(&self) -> Source<T> {
104 match self.receiver.lock().take() {
105 Some(rx) => Source::from_receiver(rx),
106 None => Source::empty(),
107 }
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::time::Duration;

    #[tokio::test]
    async fn broadcast_hub_fans_to_two_consumers() {
        let hub = BroadcastHub::<i32>::new(16);
        // Both consumers subscribe before the source is attached, so neither misses anything.
        let (first, second) = (hub.consumer(), hub.consumer());
        hub.attach(Source::from_iter(vec![1, 2, 3]));
        // Dropping the hub releases its sender; the streams close once the attach task ends.
        drop(hub);

        let (got_first, got_second) = tokio::join!(Sink::collect(first), Sink::collect(second));
        assert_eq!(got_first, vec![1, 2, 3]);
        assert_eq!(got_second, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn broadcast_hub_late_consumer_misses_earlier_elements() {
        let hub = BroadcastHub::<i32>::new(16);
        let early = hub.consumer();
        hub.attach(Source::from_iter(vec![1, 2, 3]));

        // The hub stays alive, so `early` never closes: read exactly three items under a timeout.
        let early_items = tokio::time::timeout(Duration::from_millis(200), async move {
            let mut stream = early.into_boxed();
            let mut items = Vec::new();
            while items.len() < 3 {
                if let Some(item) = stream.next().await {
                    items.push(item);
                } else {
                    break;
                }
            }
            items
        })
        .await
        .unwrap_or_default();
        assert_eq!(early_items, vec![1, 2, 3]);

        // Subscribed after everything was sent: nothing is replayed, and the stream
        // never closes while the hub lives, hence the short timeout.
        let late = hub.consumer();
        let late_items = tokio::time::timeout(Duration::from_millis(50), Sink::collect(late))
            .await
            .unwrap_or_default();
        assert!(late_items.is_empty());
    }

    #[tokio::test]
    async fn broadcast_hub_consumer_count_grows_with_subscribers() {
        let hub = BroadcastHub::<i32>::new(4);
        assert_eq!(hub.consumer_count(), 0);
        // Keep both consumers alive; dropping one would unsubscribe it.
        let _subscribers = [hub.consumer(), hub.consumer()];
        assert_eq!(hub.consumer_count(), 2);
    }

    #[tokio::test]
    async fn merge_hub_aggregates_multiple_producers() {
        let hub = MergeHub::<i32>::new();
        hub.attach(Source::from_iter(vec![1, 2, 3]));
        hub.attach(Source::from_iter(vec![10, 20, 30]));
        let merged = hub.source();
        // Dropping the hub closes its sender so the merged stream ends after both producers.
        drop(hub);

        let mut collected = Sink::collect(merged).await;
        // Interleaving between producers is nondeterministic, so compare in sorted order.
        collected.sort_unstable();
        assert_eq!(collected, vec![1, 2, 3, 10, 20, 30]);
    }

    #[tokio::test]
    async fn merge_hub_source_can_be_taken_only_once() {
        let hub = MergeHub::<i32>::new();
        hub.attach(Source::from_iter(vec![1]));
        // The first take is discarded immediately; only the second one is inspected.
        drop(hub.source());
        let second = hub.source();
        let items = tokio::time::timeout(Duration::from_millis(50), Sink::collect(second))
            .await
            .unwrap_or_default();
        assert!(items.is_empty());
    }
}