under/
sse.rs

1//! Async SSE.
2//!
3//! This adds some wrappers around using the `async-sse` crate with this
4//! HTTP library, making it easier to handle SSE connections.  It is gated
5//! behind the `sse` feature flag for those who do not want to use it.
6
7use crate::{HttpEntity, Request, Response};
8pub use async_sse::Sender;
9use futures::StreamExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use tokio_util::compat::FuturesAsyncReadCompatExt;
14
15/// Creates an endpoint that can handle SSE connections.  This directly
16/// upgrades the HTTP request to SSE unconditionally, before calling the
17/// handler function with the current request and the SSE sender.
18///
19/// # Examples
20/// ```rust,no_run
21/// # use under::*;
22/// use under::sse::Sender;
23///
24/// async fn handle(req: Request, mut sender: Sender) -> Result<(), anyhow::Error> {
25///     sender.send(None, "hello, world!", None).await?;
26///     Ok(())
27/// }
28///
29/// let mut http = under::http();
30/// http.at("/sse").get(under::sse::endpoint(handle));
31/// ```
32pub fn endpoint<F, Fut>(handle: F) -> SseEndpoint<F>
33where
34    F: Fn(Request, Sender) -> Fut + Send + Sync + 'static,
35    Fut: Future<Output = crate::Result<()>> + Send + 'static,
36{
37    SseEndpoint::new(handle)
38}
39
40/// Upgrades a request to SSE.  This allows you to check beforehand if a request
41/// should be upgraded to SSE, instead of [`endpoint`], which directly upgrades
42/// the connection.
43///
44/// # Examples
45/// ```rust,no_run
46/// # use under::*;
47/// use under::sse::Sender;
48///
49/// async fn sse(request: Request, mut sender: Sender) -> Result<(), anyhow::Error> {
50///     sender.send(None, "hello, world!", None).await?;
51///     Ok(())
52/// }
53///
54/// fn should_upgrade_to_sse(request: &Request) -> bool {
55/// #    return true;
56///     // ...
57/// }
58///
59/// async fn handle(request: Request) -> Result<Response, anyhow::Error> {
60///     if should_upgrade_to_sse(&request) {
61///         under::sse::upgrade(request, sse)
62///     } else {
63///        Ok(Response::empty_404())
64///     }
65/// }
66///
67/// let mut http = under::http();
68/// http.at("/sse").get(handle);
69/// ```
70#[allow(clippy::missing_errors_doc)]
71pub fn upgrade<F, Fut>(request: Request, handle: F) -> Result<Response, anyhow::Error>
72where
73    F: FnOnce(Request, Sender) -> Fut + Send + Sync + 'static,
74    Fut: Future<Output = crate::Result<()>> + Send + 'static,
75{
76    Ok(handle_sse(request, handle))
77}
78
79/// Performs a heartbeat on an SSE connection.  This allows the server to
80/// ensure that a client is still connected.  This is expected, generally, to
81/// be used in conjunction with either [`endpoint`] or [`upgrade`].  The steam
82/// passed in should be cancellable, and will be cancelled if it does not
83/// resolve within the heartbeat timeout (1s by default).  This is mostly
84/// expected to be used in a loop.
85///
86/// # Errors
87/// This will return an error if the heartbeat fails to send.  This implies
88/// an issue with the underlying connection.
89///
90/// # Examples
91/// ```rust,no_run
92/// # use under::*;
93/// use under::sse::{Sender, stream_heartbeat};
94///
95/// # fn some_stream() -> impl futures::Stream<Item = u64> {
96/// #     futures::stream::iter(vec![1, 2, 3])
97/// # }
98///
99/// async fn sse(request: Request, mut sender: Sender) -> Result<(), anyhow::Error> {
100///     let mut stream = some_stream();
101///     while let Some(event) = stream_heartbeat(&mut sender, &mut stream).await? {
102///         sender.send(None, &format!("{}", event), None).await?;
103///     }
104///     Ok(())
105/// }
106///
107/// let mut http = under::http();
108/// http.at("/sse").get(under::sse::endpoint(sse));
109/// ```
110pub async fn stream_heartbeat<I, S: futures::Stream<Item = I> + Unpin>(
111    sender: &mut Sender,
112    stream: &mut S,
113) -> Result<Option<I>, anyhow::Error> {
114    loop {
115        let time = tokio::time::timeout(tokio::time::Duration::from_secs(1), stream.next()).await;
116
117        match time {
118            Ok(t) => {
119                return Ok(t);
120            }
121            Err(_) => {
122                sender.send("_hb", "", None).await?;
123            }
124        }
125    }
126}
127
/// An instance of an SSE endpoint.
///
/// This is created by [`endpoint`], and implements the [`crate::Endpoint`]
/// trait.  The handler is stored behind an [`Arc`] so clones share it.
pub struct SseEndpoint<F>(Arc<F>);

// `Clone` and `Debug` are written by hand instead of derived.  The derives
// would require `F: Clone` / `F: Debug`.  Cloning only bumps the `Arc`
// refcount, though, and plain `fn` items and closures are not `Debug`.
impl<F> Clone for SseEndpoint<F> {
    fn clone(&self) -> Self {
        SseEndpoint(Arc::clone(&self.0))
    }
}

impl<F> std::fmt::Debug for SseEndpoint<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show the handler's type name, since `F` itself may not be `Debug`.
        f.debug_tuple("SseEndpoint")
            .field(&std::any::type_name::<F>())
            .finish()
    }
}

impl<F> SseEndpoint<F> {
    /// Wraps `f` in an [`Arc`] so it can be shared between clones.
    fn new(f: F) -> Self {
        SseEndpoint(Arc::new(f))
    }
}
140
141#[async_trait]
142impl<F, Fut> crate::Endpoint for SseEndpoint<F>
143where
144    F: Fn(Request, Sender) -> Fut + Send + Sync + 'static,
145    Fut: Future<Output = crate::Result<()>> + Send + 'static,
146{
147    async fn apply(self: Pin<&Self>, request: Request) -> Result<Response, anyhow::Error> {
148        let h = self.0.clone();
149        // we need this for lifetime extension.  If we pass in `h` directly,
150        // `h` would be bound to the lifetime of this function.
151        #[allow(clippy::redundant_closure)]
152        Ok(handle_sse(request, move |r, s| h(r, s)))
153    }
154
155    fn describe(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        f.debug_tuple("SseEndpoint")
157            .field(&std::any::type_name::<F>())
158            .finish()
159    }
160}
161
162fn handle_sse<F, Fut>(request: Request, handle: F) -> crate::Response
163where
164    F: FnOnce(Request, Sender) -> Fut + Send + Sync + 'static,
165    Fut: Future<Output = crate::Result<()>> + Send + 'static,
166{
167    let (sender, encoder) = async_sse::encode();
168
169    let stream = tokio_util::io::ReaderStream::new(encoder.compat());
170    let response = Response::empty_200()
171        .with_header("Cache-Control", "no-cache")
172        .expect("Cache-Control is a valid header")
173        .with_header("Content-Type", "text/event-stream")
174        .expect("Content-Type is a valid header")
175        .with_body(hyper::Body::wrap_stream(stream));
176
177    tokio::task::spawn(async move {
178        handle(request, sender).await.ok();
179    });
180
181    response
182}