use async_stream::stream;
use futures::{
stream::{self, BoxStream},
Stream, StreamExt,
};
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
use tracing::warn;
use tungstenite::Message;
/// Produces a stream of response [`Message`]s, given a broadcast receiver of incoming
/// [`Message`]s.
///
/// Implementations may ignore the receiver entirely (as [`StreamResponse`] does) or derive
/// their output from it (as [`MapResponder`] does).
pub trait ResponseStream {
    /// Returns a boxed, `'static` stream of messages to emit.
    ///
    /// `input` is handed over by value, so the returned stream may own it.
    fn handle(&self, input: broadcast::Receiver<Message>) -> BoxStream<'static, Message>;
}
/// A [`ResponseStream`] whose output is independent of incoming messages: every
/// `handle` call builds a new stream from a stored constructor.
pub struct StreamResponse {
    // Type-erased constructor; calling it yields a fresh boxed stream each time.
    stream_ctor: Arc<dyn Fn() -> BoxStream<'static, Message> + Send + Sync + 'static>,
}
impl StreamResponse {
    /// Creates a responder from a stream constructor.
    ///
    /// `ctor` is stored and invoked to build a new stream whenever a stream is requested,
    /// so each request gets its own stream instance.
    ///
    /// The produced stream only has to be `Send + 'static`, which is what
    /// [`StreamExt::boxed`] requires. It does not need to be `Sync`, so streams holding
    /// `Send`-only state (e.g. boxed futures) are accepted.
    pub fn new<F, S>(ctor: F) -> Self
    where
        // The constructor is shared behind an `Arc<dyn Fn + Send + Sync>`, so it needs both.
        F: Fn() -> S + Send + Sync + 'static,
        // `BoxStream` is `Pin<Box<dyn Stream + Send>>`: `Send` suffices, `Sync` is not needed.
        S: Stream<Item = Message> + Send + 'static,
    {
        // Erase the concrete stream type so `StreamResponse` itself stays non-generic.
        let stream_ctor = Arc::new(move || ctor().boxed());
        Self { stream_ctor }
    }
}
impl ResponseStream for StreamResponse {
    /// Builds a brand-new stream by calling the stored constructor.
    ///
    /// The incoming receiver is never read. It is dropped when this method returns,
    /// because a `StreamResponse` does not depend on incoming messages.
    fn handle(&self, _input: broadcast::Receiver<Message>) -> BoxStream<'static, Message> {
        let build = &*self.stream_ctor;
        build()
    }
}
pub fn pending() -> StreamResponse {
StreamResponse::new(stream::pending)
}
/// A [`ResponseStream`] that transforms each incoming message with a stored function.
pub struct MapResponder {
    // Shared transformation; the `Arc` is cloned into each stream returned by `handle`.
    map: Arc<dyn Fn(Message) -> Message + Send + Sync + 'static>,
}
impl MapResponder {
    /// Creates a responder that applies `f` to every message it receives.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(Message) -> Message + Send + Sync + 'static,
    {
        // `Arc<F>` is coerced to the type-erased field type by the struct literal.
        let map = Arc::new(f);
        Self { map }
    }
}
impl ResponseStream for MapResponder {
    /// Streams the mapped form of every message that arrives on `input`.
    ///
    /// Items the [`BroadcastStream`] wrapper yields as `Err` are logged with `warn!` and
    /// skipped rather than ending the stream. The output ends when the wrapped broadcast
    /// stream ends.
    fn handle(&self, input: broadcast::Receiver<Message>) -> BoxStream<'static, Message> {
        // Give the returned `'static` stream its own handle to the shared mapper.
        let mapper = Arc::clone(&self.map);
        let received = BroadcastStream::new(input);
        stream! {
            for await item in received {
                match item {
                    Err(err) => {
                        warn!("Broadcast error: {}", err);
                    }
                    Ok(msg) => yield mapper(msg),
                }
            }
        }
        .boxed()
    }
}
/// Returns a [`MapResponder`] that sends every received message back unchanged.
pub fn echo_response() -> MapResponder {
    MapResponder::new(std::convert::identity)
}