use crate::{HttpEntity, Request, Response};
pub use async_sse::Sender;
use futures::StreamExt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio_util::compat::FuturesAsyncReadCompatExt;
/// Builds an [`SseEndpoint`] that invokes `handle` once per incoming request.
///
/// `handle` receives the request together with a [`Sender`] used to push
/// events into the `text/event-stream` response body. Because the endpoint
/// keeps the handler behind an `Arc`, `handle` must be `Fn` (callable many
/// times) and `Send + Sync`.
pub fn endpoint<F, Fut>(handle: F) -> SseEndpoint<F>
where
    F: Fn(Request, Sender) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = crate::Result<()>> + Send + 'static,
{
    SseEndpoint::<F>::new(handle)
}
/// Turns `request` into a Server-Sent Events response driven by `handle`.
///
/// `handle` is called exactly once with the request and a [`Sender`]. It runs
/// on a spawned task, so this function returns the streaming response
/// immediately.
///
/// # Errors
///
/// The current implementation never returns `Err`; the `Result` return type
/// is part of the public signature.
#[allow(clippy::missing_errors_doc)]
pub fn upgrade<F, Fut>(request: Request, handle: F) -> Result<Response, anyhow::Error>
where
    F: FnOnce(Request, Sender) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = crate::Result<()>> + Send + 'static,
{
    let sse_response = handle_sse(request, handle);
    Ok(sse_response)
}
/// Waits for the next item of `stream`, sending a heartbeat event through
/// `sender` every second while the stream stays idle.
///
/// This is [`stream_heartbeat_with_interval`] with a fixed one-second interval.
///
/// Returns `Ok(Some(item))` with the next item, or `Ok(None)` once the stream
/// has ended.
///
/// # Errors
///
/// Returns an error if sending a heartbeat through `sender` fails.
pub async fn stream_heartbeat<I, S: futures::Stream<Item = I> + Unpin>(
    sender: &mut Sender,
    stream: &mut S,
) -> Result<Option<I>, anyhow::Error> {
    stream_heartbeat_with_interval(sender, stream, tokio::time::Duration::from_secs(1)).await
}

/// Waits for the next item of `stream`. Each time `interval` elapses without
/// an item, it sends an SSE event named `_hb` with empty data through `sender`.
///
/// The countdown restarts after every heartbeat, so heartbeats are spaced
/// `interval` apart while the stream is idle. `interval` should be non-zero.
/// With a zero interval an idle stream would produce heartbeats back-to-back.
///
/// Returns `Ok(Some(item))` with the next item, or `Ok(None)` once the stream
/// has ended.
///
/// # Errors
///
/// Returns an error if sending a heartbeat through `sender` fails.
pub async fn stream_heartbeat_with_interval<I, S: futures::Stream<Item = I> + Unpin>(
    sender: &mut Sender,
    stream: &mut S,
    interval: tokio::time::Duration,
) -> Result<Option<I>, anyhow::Error> {
    loop {
        // `StreamExt::next` only polls the stream, so dropping it when the
        // timeout fires does not lose an item; the next loop iteration resumes.
        if let Ok(next_item) = tokio::time::timeout(interval, stream.next()).await {
            return Ok(next_item);
        }
        sender.send("_hb", "", None).await?;
    }
}
/// Endpoint that serves Server-Sent Events by invoking a shared handler for
/// every request.
///
/// The handler lives behind an [`Arc`], so cloning the endpoint is cheap and
/// does not require the handler itself to be `Clone`.
pub struct SseEndpoint<F>(Arc<F>);

impl<F> SseEndpoint<F> {
    /// Wraps `f` in an `Arc` so that all clones share the same handler.
    fn new(f: F) -> Self {
        SseEndpoint(Arc::new(f))
    }
}

// Hand-written rather than derived: `#[derive(Clone)]` would add an `F: Clone`
// bound even though cloning only bumps the `Arc` reference count.
impl<F> Clone for SseEndpoint<F> {
    fn clone(&self) -> Self {
        SseEndpoint(Arc::clone(&self.0))
    }
}

// Hand-written rather than derived: closures never implement `Debug`, so a
// derived impl (which requires `F: Debug`) is unusable for typical handlers.
// Print the handler's type name instead, in the same format `describe` uses.
impl<F> std::fmt::Debug for SseEndpoint<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("SseEndpoint")
            .field(&std::any::type_name::<F>())
            .finish()
    }
}
#[async_trait]
impl<F, Fut> crate::Endpoint for SseEndpoint<F>
where
F: Fn(Request, Sender) -> Fut + Send + Sync + 'static,
Fut: Future<Output = crate::Result<()>> + Send + 'static,
{
async fn apply(self: Pin<&Self>, request: Request) -> Result<Response, anyhow::Error> {
let h = self.0.clone();
#[allow(clippy::redundant_closure)]
Ok(handle_sse(request, move |r, s| h(r, s)))
}
fn describe(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("SseEndpoint")
.field(&std::any::type_name::<F>())
.finish()
}
}
/// Builds a `200` Server-Sent Events response and spawns `handle` to feed it.
///
/// The response carries `Cache-Control: no-cache` and
/// `Content-Type: text/event-stream`. Its body streams whatever `handle`
/// writes through the [`Sender`] half of `async_sse::encode()`.
///
/// `handle` runs on a detached tokio task, and its result is discarded.
/// NOTE(review): handler errors are silently dropped here; consider logging
/// them once a logging facility is available in this crate.
fn handle_sse<F, Fut>(request: Request, handle: F) -> crate::Response
where
    F: FnOnce(Request, Sender) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = crate::Result<()>> + Send + 'static,
{
    let (sender, encoder) = async_sse::encode();
    // Adapt the futures-io reader into a tokio reader, then into a byte stream.
    let byte_stream = tokio_util::io::ReaderStream::new(encoder.compat());

    let with_headers = Response::empty_200()
        .with_header("Cache-Control", "no-cache")
        .expect("Cache-Control is a valid header")
        .with_header("Content-Type", "text/event-stream")
        .expect("Content-Type is a valid header");
    let sse_response = with_headers.with_body(hyper::Body::wrap_stream(byte_stream));

    // Fire and forget: the task outlives this call and its outcome is ignored.
    let producer = async move {
        let _ = handle(request, sender).await;
    };
    tokio::task::spawn(producer);

    sse_response
}