use futures::stream::{self, StreamExt};
use tokio::sync::{broadcast, mpsc};
use crate::source::Source;
/// Fan-out hub: every element sent by an attached source is delivered to each
/// consumer that is subscribed at the moment the element is published.
///
/// Trait bounds are declared on the `impl` blocks that need them, not on the type.
pub struct BroadcastHub<T> {
    /// Publishing side of the broadcast channel; each attached producer task
    /// sends through its own clone of it.
    sender: broadcast::Sender<T>,
}
impl<T: Clone + Send + 'static> BroadcastHub<T> {
pub fn new(buffer_size: usize) -> Self {
assert!(buffer_size >= 1, "buffer_size must be >= 1");
let (sender, _rx) = broadcast::channel(buffer_size);
Self { sender }
}
pub fn attach(&self, source: Source<T>) {
let tx = self.sender.clone();
tokio::spawn(async move {
let mut s = source.into_boxed();
while let Some(item) = s.next().await {
let _ = tx.send(item); }
});
}
pub fn consumer(&self) -> Source<T> {
let rx = self.sender.subscribe();
let stream = stream::unfold(rx, |mut rx| async move {
loop {
match rx.recv().await {
Ok(item) => return Some((item, rx)),
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => return None,
}
}
});
Source { inner: stream.boxed() }
}
pub fn consumer_count(&self) -> usize {
self.sender.receiver_count()
}
}
/// Fan-in hub: elements from every attached source are funneled into one
/// unbounded channel. That channel's output can be taken once via
/// [`MergeHub::source`].
///
/// Trait bounds are declared on the `impl` blocks that need them, not on the type.
pub struct MergeHub<T> {
    /// The hub's own sending half; attached producer tasks send through clones.
    sender: mpsc::UnboundedSender<T>,
    /// Receiving half, present until the first call to `source` takes it.
    receiver: parking_lot::Mutex<Option<mpsc::UnboundedReceiver<T>>>,
}
impl<T: Send + 'static> Default for MergeHub<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Send + 'static> MergeHub<T> {
pub fn new() -> Self {
let (tx, rx) = mpsc::unbounded_channel();
Self { sender: tx, receiver: parking_lot::Mutex::new(Some(rx)) }
}
pub fn attach(&self, source: Source<T>) {
let tx = self.sender.clone();
tokio::spawn(async move {
let mut s = source.into_boxed();
while let Some(item) = s.next().await {
if tx.send(item).is_err() {
return;
}
}
});
}
pub fn source(&self) -> Source<T> {
match self.receiver.lock().take() {
Some(rx) => Source::from_receiver(rx),
None => Source::empty(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::time::Duration;

    /// Consumers subscribed before the producer starts each see every element.
    #[tokio::test]
    async fn broadcast_hub_fans_to_two_consumers() {
        let hub = BroadcastHub::<i32>::new(16);
        let first = hub.consumer();
        let second = hub.consumer();
        hub.attach(Source::from_iter(vec![1, 2, 3]));
        // Without the hub's sender, the channel closes once the producer finishes.
        drop(hub);

        let (got_first, got_second) = tokio::join!(Sink::collect(first), Sink::collect(second));
        let expected = vec![1, 2, 3];
        assert_eq!(got_first, expected);
        assert_eq!(got_second, expected);
    }

    /// A consumer subscribed after publishing has finished receives nothing.
    #[tokio::test]
    async fn broadcast_hub_late_consumer_misses_earlier_elements() {
        let hub = BroadcastHub::<i32>::new(16);
        let early = hub.consumer();
        hub.attach(Source::from_iter(vec![1, 2, 3]));

        // The hub is still alive, so the stream never closes: read exactly three
        // elements instead of collecting to the end.
        let read_three = async move {
            let mut stream = early.into_boxed();
            let mut seen = Vec::new();
            while seen.len() < 3 {
                if let Some(value) = stream.next().await {
                    seen.push(value);
                } else {
                    break;
                }
            }
            seen
        };
        let seen_early = tokio::time::timeout(Duration::from_millis(200), read_three)
            .await
            .unwrap_or_default();
        assert_eq!(seen_early, vec![1, 2, 3]);

        let late = hub.consumer();
        let seen_late = tokio::time::timeout(Duration::from_millis(50), Sink::collect(late))
            .await
            .unwrap_or_default();
        assert!(seen_late.is_empty());
    }

    /// `consumer_count` reflects live subscriptions.
    #[tokio::test]
    async fn broadcast_hub_consumer_count_grows_with_subscribers() {
        let hub = BroadcastHub::<i32>::new(4);
        assert_eq!(hub.consumer_count(), 0);
        let _subscribers = [hub.consumer(), hub.consumer()];
        assert_eq!(hub.consumer_count(), 2);
    }

    /// Elements from all producers arrive on the single merged stream.
    #[tokio::test]
    async fn merge_hub_aggregates_multiple_producers() {
        let hub = MergeHub::<i32>::new();
        for batch in vec![vec![1, 2, 3], vec![10, 20, 30]] {
            hub.attach(Source::from_iter(batch));
        }
        let merged = hub.source();
        // Release the hub's own sender so the merged stream can terminate.
        drop(hub);

        let mut got = Sink::collect(merged).await;
        got.sort_unstable();
        assert_eq!(got, vec![1, 2, 3, 10, 20, 30]);
    }

    /// A second `source()` call yields an empty stream.
    #[tokio::test]
    async fn merge_hub_source_can_be_taken_only_once() {
        let hub = MergeHub::<i32>::new();
        hub.attach(Source::from_iter(vec![1]));
        drop(hub.source());

        let second = hub.source();
        let got = tokio::time::timeout(Duration::from_millis(50), Sink::collect(second))
            .await
            .unwrap_or_default();
        assert!(got.is_empty());
    }
}