use bytes::Bytes;
use futures::Stream;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::{broadcast, RwLock};
use tracing::{debug, warn};
use crate::errors::Result;
pub struct OutputStream {
rx: broadcast::Receiver<Bytes>,
closed: Arc<AtomicBool>,
}
impl OutputStream {
fn new(rx: broadcast::Receiver<Bytes>, closed: Arc<AtomicBool>) -> Self {
Self { rx, closed }
}
}
impl Stream for OutputStream {
type Item = Bytes;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.closed.load(Ordering::SeqCst) {
return Poll::Ready(None);
}
match self.rx.try_recv() {
Ok(item) => Poll::Ready(Some(item)),
Err(broadcast::error::TryRecvError::Empty) => {
if self.closed.load(Ordering::SeqCst) {
return Poll::Ready(None);
}
cx.waker().wake_by_ref();
Poll::Pending
}
Err(broadcast::error::TryRecvError::Lagged(skipped)) => {
warn!(skipped, "Output stream lagged, messages were dropped");
cx.waker().wake_by_ref();
Poll::Pending
}
Err(broadcast::error::TryRecvError::Closed) => Poll::Ready(None),
}
}
}
pub struct ChannelMultiplexer {
output_tx: broadcast::Sender<Bytes>,
closed: Arc<AtomicBool>,
#[allow(dead_code)] sequence_counter: Arc<RwLock<i64>>,
}
impl ChannelMultiplexer {
pub fn new() -> Self {
let (output_tx, _) = broadcast::channel(1024);
Self {
output_tx,
closed: Arc::new(AtomicBool::new(false)),
sequence_counter: Arc::new(RwLock::new(0)),
}
}
pub fn output_stream(&self) -> OutputStream {
OutputStream::new(self.output_tx.subscribe(), Arc::clone(&self.closed))
}
pub fn close(&self) {
debug!("Closing channel multiplexer");
self.closed.store(true, Ordering::SeqCst);
}
pub async fn send_output(&self, data: Bytes) -> Result<()> {
if self.output_tx.send(data).is_err() {
debug!("No active output stream receivers");
}
Ok(())
}
#[allow(dead_code)] pub async fn next_sequence(&self) -> i64 {
let mut counter = self.sequence_counter.write().await;
*counter += 1;
*counter
}
}
impl Default for ChannelMultiplexer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use tokio::time::{timeout, Duration};

    /// Upper bound for any single wait, so a regression fails the test instead of hanging it.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    #[tokio::test]
    async fn test_output_stream() {
        let mux = ChannelMultiplexer::new();
        let mut stream = mux.output_stream();
        mux.send_output(Bytes::from("test1")).await.unwrap();
        mux.send_output(Bytes::from("test2")).await.unwrap();
        // Both chunks are already buffered in the broadcast channel; no sleep is needed.
        assert_eq!(stream.next().await, Some(Bytes::from("test1")));
        assert_eq!(stream.next().await, Some(Bytes::from("test2")));
    }

    #[tokio::test]
    async fn test_multiple_output_streams() {
        let mux = ChannelMultiplexer::new();
        let mut stream1 = mux.output_stream();
        let mut stream2 = mux.output_stream();
        mux.send_output(Bytes::from("broadcast")).await.unwrap();
        assert_eq!(stream1.next().await, Some(Bytes::from("broadcast")));
        assert_eq!(stream2.next().await, Some(Bytes::from("broadcast")));
    }

    #[tokio::test]
    async fn test_close_ends_stream_and_stays_ended() {
        let mux = ChannelMultiplexer::new();
        let mut stream = mux.output_stream();
        mux.send_output(Bytes::from("pending")).await.unwrap();
        mux.close();
        // Closing takes precedence over chunks still buffered for this subscriber.
        assert_eq!(stream.next().await, None);
        assert_eq!(stream.next().await, None);
    }

    #[tokio::test]
    async fn test_close_wakes_waiting_stream() {
        let mux = ChannelMultiplexer::new();
        let mut stream = mux.output_stream();
        let waiter = tokio::spawn(async move { stream.next().await });
        // Give the spawned task a chance to start waiting on an empty channel.
        tokio::time::sleep(Duration::from_millis(10)).await;
        mux.close();
        let next = timeout(RECV_TIMEOUT, waiter)
            .await
            .expect("stream did not end after close")
            .expect("waiting task panicked");
        assert_eq!(next, None);
    }

    #[tokio::test]
    async fn test_dropping_multiplexer_ends_stream() {
        let mux = ChannelMultiplexer::new();
        let mut stream = mux.output_stream();
        drop(mux);
        let next = timeout(RECV_TIMEOUT, stream.next())
            .await
            .expect("stream did not end after the sender was dropped");
        assert_eq!(next, None);
    }

    #[tokio::test]
    async fn test_lagged_stream_resumes_at_oldest_retained_chunk() {
        let mux = ChannelMultiplexer::new();
        let mut stream = mux.output_stream();
        // `new()` uses a 1024-slot channel, so 6 extra sends overwrite chunks "0".."5".
        for i in 0..1030 {
            mux.send_output(Bytes::from(i.to_string())).await.unwrap();
        }
        let next = timeout(RECV_TIMEOUT, stream.next())
            .await
            .expect("lagged stream did not resume");
        assert_eq!(next, Some(Bytes::from("6")));
    }

    #[tokio::test]
    async fn test_sequence_numbers() {
        let mux = ChannelMultiplexer::new();
        let seq1 = mux.next_sequence().await;
        let seq2 = mux.next_sequence().await;
        let seq3 = mux.next_sequence().await;
        assert_eq!(seq1, 1);
        assert_eq!(seq2, 2);
        assert_eq!(seq3, 3);
    }
}