use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use flowscope::{FlowEvent, FlowExtractor};
use futures_core::Stream;
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use crate::async_adapters::flow_stream::{FlowStream, NoReassembler};
use crate::traits::PacketSource;
/// Handle to a background task that republishes a flow stream's events to any
/// number of subscribers over a `tokio::sync::broadcast` channel.
///
/// Create subscribers with [`FlowBroadcast::subscribe`]. Dropping this handle
/// aborts the publishing task.
///
/// NOTE(review): this handle keeps a `broadcast::Sender` alive, and tokio only
/// reports a closed channel once every `Sender` is gone. Subscribers therefore
/// do not observe end-of-stream (`None`) when the source stream finishes; they
/// only observe it after this handle is dropped. Confirm that this is the
/// intended contract.
pub struct FlowBroadcast<K> {
    // Kept so new subscribers can be created and the receiver count queried.
    sender: broadcast::Sender<Arc<FlowEvent<K>>>,
    // The spawned publishing task; aborted in `Drop`.
    task: tokio::task::JoinHandle<()>,
}
impl<K: Send + Sync + 'static> FlowBroadcast<K> {
    /// Registers a new subscriber on the broadcast channel.
    ///
    /// The subscriber only sees events published after this call. Earlier
    /// events are never replayed to it.
    pub fn subscribe(&self) -> FlowSubscriber<K> {
        let receiver = self.sender.subscribe();
        let inner = BroadcastStream::new(receiver);
        FlowSubscriber { inner }
    }

    /// Number of receivers currently attached to the underlying channel.
    ///
    /// Each live [`FlowSubscriber`] holds one receiver.
    pub fn receiver_count(&self) -> usize {
        let sender = &self.sender;
        sender.receiver_count()
    }
}
impl<K> Drop for FlowBroadcast<K> {
fn drop(&mut self) {
self.task.abort();
}
}
/// A [`Stream`] of events received from a [`FlowBroadcast`].
///
/// Created by [`FlowBroadcast::subscribe`]. Each item is either a shared
/// (`Arc`) event or a [`BroadcastRecvError::Lagged`] if this subscriber fell
/// behind the channel's capacity.
pub struct FlowSubscriber<K> {
    // tokio-stream adapter that exposes the broadcast receiver as a `Stream`.
    inner: BroadcastStream<Arc<FlowEvent<K>>>,
}
#[derive(Debug, thiserror::Error)]
pub enum BroadcastRecvError {
#[error("subscriber lagged by {0} events")]
Lagged(u64),
}
impl<K: Send + Sync + 'static> Stream for FlowSubscriber<K> {
    type Item = Result<Arc<FlowEvent<K>>, BroadcastRecvError>;

    /// Polls the wrapped broadcast stream.
    ///
    /// Events and end-of-stream pass through unchanged. tokio-stream's lag
    /// error is translated into this module's [`BroadcastRecvError`].
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = Pin::new(&mut self.inner);
        inner.poll_next(cx).map(|maybe_item| {
            maybe_item.map(|received| {
                received.map_err(|BroadcastStreamRecvError::Lagged(skipped)| {
                    BroadcastRecvError::Lagged(skipped)
                })
            })
        })
    }
}
impl<S, E> FlowStream<S, E, (), NoReassembler>
where
S: PacketSource + std::os::unix::io::AsRawFd + Send + Unpin + 'static,
E: FlowExtractor + Unpin + Send + 'static,
E::Key: Clone + Send + Sync + Unpin + 'static,
{
pub fn broadcast(self, buffer: usize) -> FlowBroadcast<E::Key> {
let (sender, _initial_rx) = broadcast::channel(buffer);
let publisher = sender.clone();
let task = tokio::spawn(async move {
let mut stream = self;
let mut stream = std::pin::Pin::new(&mut stream);
loop {
let next = std::future::poll_fn(|cx| stream.as_mut().poll_next(cx)).await;
match next {
None => break,
Some(Ok(event)) => {
let _ = publisher.send(Arc::new(event));
}
Some(Err(_)) => {
}
}
}
});
FlowBroadcast { sender, task }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio_stream::StreamExt;

    /// Reads `count` events from `sub` and returns their `Started` keys in
    /// arrival order.
    ///
    /// Fails the test in any of these cases:
    /// - an event takes longer than a second to arrive;
    /// - the stream ends early;
    /// - the subscriber lags;
    /// - a non-`Started` event arrives.
    async fn started_keys(sub: &mut FlowSubscriber<u32>, count: usize) -> Vec<u32> {
        let mut keys = Vec::with_capacity(count);
        for _ in 0..count {
            let event = tokio::time::timeout(Duration::from_secs(1), sub.next())
                .await
                .expect("timed out waiting for a broadcast event")
                .expect("subscriber stream ended early")
                .expect("subscriber lagged");
            match &*event {
                FlowEvent::Started { key, .. } => keys.push(*key),
                _ => panic!("unexpected event variant"),
            }
        }
        keys
    }

    #[tokio::test(flavor = "current_thread")]
    async fn two_subscribers_see_all_events() {
        use flowscope::{FlowSide, Timestamp};
        const EVENTS: u32 = 5;

        let (sender, _initial_rx) = broadcast::channel::<Arc<FlowEvent<u32>>>(16);
        let publisher = sender.clone();
        // Hold the publisher back until both subscribers exist. The test then
        // does not depend on when the runtime first polls the spawned task.
        let (start_tx, start_rx) = tokio::sync::oneshot::channel::<()>();
        let task = tokio::spawn(async move {
            if start_rx.await.is_err() {
                return;
            }
            for i in 0..EVENTS {
                let _ = publisher.send(Arc::new(FlowEvent::Started {
                    key: i,
                    side: FlowSide::Initiator,
                    ts: Timestamp::default(),
                    l4: None,
                }));
            }
        });
        let bc: FlowBroadcast<u32> = FlowBroadcast { sender, task };
        let mut sub_a = bc.subscribe();
        let mut sub_b = bc.subscribe();
        start_tx
            .send(())
            .expect("publisher task dropped its start signal");

        // Every subscriber must see every event, in publication order.
        let expected: Vec<u32> = (0..EVENTS).collect();
        assert_eq!(started_keys(&mut sub_a, expected.len()).await, expected);
        assert_eq!(started_keys(&mut sub_b, expected.len()).await, expected);
    }

    #[test]
    fn receiver_count_zero_then_one() {
        let (sender, _) = broadcast::channel::<Arc<FlowEvent<u32>>>(8);
        // A JoinHandle is needed to build a FlowBroadcast; spawn a no-op task
        // on a throwaway runtime that outlives the handle.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build test runtime");
        let bc = FlowBroadcast {
            sender,
            task: runtime.spawn(async {}),
        };
        assert_eq!(bc.receiver_count(), 0);
        let sub = bc.subscribe();
        assert_eq!(bc.receiver_count(), 1);
        // Dropping the subscriber releases its receiver.
        drop(sub);
        assert_eq!(bc.receiver_count(), 0);
    }
}