under/sse.rs
//! Async SSE.
//!
//! This module provides wrappers for using the `async-sse` crate with this
//! HTTP library, making it easier to handle SSE (server-sent events)
//! connections. It is gated behind the `sse` feature flag for those who do
//! not want to use it.
6
7use crate::{HttpEntity, Request, Response};
8pub use async_sse::Sender;
9use futures::StreamExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use tokio_util::compat::FuturesAsyncReadCompatExt;
14
15/// Creates an endpoint that can handle SSE connections. This directly
16/// upgrades the HTTP request to SSE unconditionally, before calling the
17/// handler function with the current request and the SSE sender.
18///
19/// # Examples
20/// ```rust,no_run
21/// # use under::*;
22/// use under::sse::Sender;
23///
24/// async fn handle(req: Request, mut sender: Sender) -> Result<(), anyhow::Error> {
25/// sender.send(None, "hello, world!", None).await?;
26/// Ok(())
27/// }
28///
29/// let mut http = under::http();
30/// http.at("/sse").get(under::sse::endpoint(handle));
31/// ```
32pub fn endpoint<F, Fut>(handle: F) -> SseEndpoint<F>
33where
34 F: Fn(Request, Sender) -> Fut + Send + Sync + 'static,
35 Fut: Future<Output = crate::Result<()>> + Send + 'static,
36{
37 SseEndpoint::new(handle)
38}
39
40/// Upgrades a request to SSE. This allows you to check beforehand if a request
41/// should be upgraded to SSE, instead of [`endpoint`], which directly upgrades
42/// the connection.
43///
44/// # Examples
45/// ```rust,no_run
46/// # use under::*;
47/// use under::sse::Sender;
48///
49/// async fn sse(request: Request, mut sender: Sender) -> Result<(), anyhow::Error> {
50/// sender.send(None, "hello, world!", None).await?;
51/// Ok(())
52/// }
53///
54/// fn should_upgrade_to_sse(request: &Request) -> bool {
55/// # return true;
56/// // ...
57/// }
58///
59/// async fn handle(request: Request) -> Result<Response, anyhow::Error> {
60/// if should_upgrade_to_sse(&request) {
61/// under::sse::upgrade(request, sse)
62/// } else {
63/// Ok(Response::empty_404())
64/// }
65/// }
66///
67/// let mut http = under::http();
68/// http.at("/sse").get(handle);
69/// ```
70#[allow(clippy::missing_errors_doc)]
71pub fn upgrade<F, Fut>(request: Request, handle: F) -> Result<Response, anyhow::Error>
72where
73 F: FnOnce(Request, Sender) -> Fut + Send + Sync + 'static,
74 Fut: Future<Output = crate::Result<()>> + Send + 'static,
75{
76 Ok(handle_sse(request, handle))
77}
78
79/// Performs a heartbeat on an SSE connection. This allows the server to
80/// ensure that a client is still connected. This is expected, generally, to
81/// be used in conjunction with either [`endpoint`] or [`upgrade`]. The steam
82/// passed in should be cancellable, and will be cancelled if it does not
83/// resolve within the heartbeat timeout (1s by default). This is mostly
84/// expected to be used in a loop.
85///
86/// # Errors
87/// This will return an error if the heartbeat fails to send. This implies
88/// an issue with the underlying connection.
89///
90/// # Examples
91/// ```rust,no_run
92/// # use under::*;
93/// use under::sse::{Sender, stream_heartbeat};
94///
95/// # fn some_stream() -> impl futures::Stream<Item = u64> {
96/// # futures::stream::iter(vec![1, 2, 3])
97/// # }
98///
99/// async fn sse(request: Request, mut sender: Sender) -> Result<(), anyhow::Error> {
100/// let mut stream = some_stream();
101/// while let Some(event) = stream_heartbeat(&mut sender, &mut stream).await? {
102/// sender.send(None, &format!("{}", event), None).await?;
103/// }
104/// Ok(())
105/// }
106///
107/// let mut http = under::http();
108/// http.at("/sse").get(under::sse::endpoint(sse));
109/// ```
110pub async fn stream_heartbeat<I, S: futures::Stream<Item = I> + Unpin>(
111 sender: &mut Sender,
112 stream: &mut S,
113) -> Result<Option<I>, anyhow::Error> {
114 loop {
115 let time = tokio::time::timeout(tokio::time::Duration::from_secs(1), stream.next()).await;
116
117 match time {
118 Ok(t) => {
119 return Ok(t);
120 }
121 Err(_) => {
122 sender.send("_hb", "", None).await?;
123 }
124 }
125 }
126}
127
128#[derive(Debug, Clone)]
129/// An instance of an SSE endpoint.
130///
131/// This is created by [`endpoint`], and implements the [`crate::Endpoint`]
132/// trait.
133pub struct SseEndpoint<F>(Arc<F>);
134
135impl<F> SseEndpoint<F> {
136 fn new(f: F) -> Self {
137 SseEndpoint(Arc::new(f))
138 }
139}
140
141#[async_trait]
142impl<F, Fut> crate::Endpoint for SseEndpoint<F>
143where
144 F: Fn(Request, Sender) -> Fut + Send + Sync + 'static,
145 Fut: Future<Output = crate::Result<()>> + Send + 'static,
146{
147 async fn apply(self: Pin<&Self>, request: Request) -> Result<Response, anyhow::Error> {
148 let h = self.0.clone();
149 // we need this for lifetime extension. If we pass in `h` directly,
150 // `h` would be bound to the lifetime of this function.
151 #[allow(clippy::redundant_closure)]
152 Ok(handle_sse(request, move |r, s| h(r, s)))
153 }
154
155 fn describe(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 f.debug_tuple("SseEndpoint")
157 .field(&std::any::type_name::<F>())
158 .finish()
159 }
160}
161
162fn handle_sse<F, Fut>(request: Request, handle: F) -> crate::Response
163where
164 F: FnOnce(Request, Sender) -> Fut + Send + Sync + 'static,
165 Fut: Future<Output = crate::Result<()>> + Send + 'static,
166{
167 let (sender, encoder) = async_sse::encode();
168
169 let stream = tokio_util::io::ReaderStream::new(encoder.compat());
170 let response = Response::empty_200()
171 .with_header("Cache-Control", "no-cache")
172 .expect("Cache-Control is a valid header")
173 .with_header("Content-Type", "text/event-stream")
174 .expect("Content-Type is a valid header")
175 .with_body(hyper::Body::wrap_stream(stream));
176
177 tokio::task::spawn(async move {
178 handle(request, sender).await.ok();
179 });
180
181 response
182}