use axum::http::header;
use axum::response::sse::Sse;
use axum::response::{IntoResponse, Response};
use futures_core::Stream;
use std::pin::Pin;
/// Formats one Server-Sent Event in the shape Datastar expects:
/// `event: <event_type>` followed by a single `data: <data_key> <payload>` line.
///
/// SSE treats CR, LF and CRLF all as line terminators, so any of them inside
/// `data` would end the `data:` line early and corrupt the event. The payload
/// is therefore collapsed onto one line. It is split on every `\r` and `\n`,
/// each piece is trimmed, empty pieces are dropped, and the rest are joined
/// with single spaces.
///
/// Note: this also collapses whitespace that may be significant (for example
/// inside `<pre>` or `<textarea>`), so such content is not sent verbatim.
///
/// `event_type` and `data_key` are written as-is and must not contain line
/// breaks (checked in debug builds).
pub(crate) fn format_sse_event(event_type: &str, data_key: &str, data: &str) -> String {
    debug_assert!(
        !event_type.contains(['\r', '\n']) && !data_key.contains(['\r', '\n']),
        "SSE event type and data key must not contain line breaks"
    );
    let oneline = data
        // Split on lone CR too: `str::lines` would leave it in place.
        .split(['\r', '\n'])
        .map(str::trim)
        // Drops the empty piece between `\r` and `\n` of a CRLF, plus blank lines.
        .filter(|piece| !piece.is_empty())
        .collect::<Vec<_>>()
        .join(" ");
    format!("event: {event_type}\ndata: {data_key} {oneline}\n\n")
}
/// Turns an already-formatted SSE payload into a complete HTTP response.
///
/// The response carries `Content-Type: text/event-stream` and
/// `Cache-Control: no-cache`, with `body` as the response body.
pub(crate) fn into_sse_response(body: String) -> Response {
    let sse_headers = [
        (header::CONTENT_TYPE, "text/event-stream"),
        (header::CACHE_CONTROL, "no-cache"),
    ];
    (sse_headers, body).into_response()
}
pub struct SseFragment {
html: String,
}
impl SseFragment {
pub fn new<T: askama::Template>(template: T) -> crate::error::Result<Self> {
let html = template
.render()
.map_err(|err| crate::error::Error::Internal(err.to_string()))?;
Ok(Self { html })
}
pub fn from_html(html: String) -> Self {
Self { html }
}
}
impl IntoResponse for SseFragment {
fn into_response(self) -> Response {
into_sse_response(format_sse_event(
"datastar-patch-elements",
"elements",
&self.html,
))
}
}
pub struct SseSignals {
json: String,
}
impl SseSignals {
pub fn new<T: serde::Serialize>(data: &T) -> crate::error::Result<Self> {
let json = serde_json::to_string(data)
.map_err(|err| crate::error::Error::Internal(err.to_string()))?;
Ok(Self { json })
}
}
impl IntoResponse for SseSignals {
fn into_response(self) -> Response {
into_sse_response(format_sse_event(
"datastar-patch-signals",
"signals",
&self.json,
))
}
}
/// Type-erased, heap-pinned stream of SSE events whose error type is
/// `Infallible`, so it can never yield an error.
type BoxedEventStream = Pin<
    Box<dyn Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>> + Send>,
>;

/// A streaming SSE response built from any `Send + 'static` event stream.
///
/// The stream is boxed so handlers can return one concrete type regardless of
/// the underlying stream. It is wrapped with `Sse::new` only, so no keep-alive
/// is configured here.
pub struct SseStream {
    inner: Sse<BoxedEventStream>,
}

impl SseStream {
    /// Boxes and pins `stream`, then wraps it in axum's `Sse` responder.
    pub fn new<S>(stream: S) -> Self
    where
        S: Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>
            + Send
            + 'static,
    {
        let pinned: BoxedEventStream = Box::pin(stream);
        Self {
            inner: Sse::new(pinned),
        }
    }
}

impl IntoResponse for SseStream {
    /// Delegates to axum's `Sse` response conversion.
    fn into_response(self) -> Response {
        let Self { inner } = self;
        inner.into_response()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::response::sse::Event;
    use serde_json::json;
    use std::convert::Infallible;
    use std::task::{Context, Poll};

    /// Maximum number of body bytes the tests will buffer.
    const BODY_LIMIT: usize = 64 * 1024;

    /// Buffers the whole response body and decodes it as UTF-8.
    async fn response_body(resp: Response) -> String {
        let bytes = to_bytes(resp.into_body(), BODY_LIMIT)
            .await
            .expect("read body");
        String::from_utf8(bytes.to_vec()).expect("valid utf-8")
    }

    /// Returns header `name` as an owned string, panicking with `missing` if absent.
    fn header_value(resp: &Response, name: header::HeaderName, missing: &str) -> String {
        resp.headers()
            .get(name)
            .expect(missing)
            .to_str()
            .expect("valid str")
            .to_owned()
    }

    /// Asserts the response advertises an SSE content type.
    fn assert_event_stream(resp: &Response) {
        let content_type = header_value(resp, header::CONTENT_TYPE, "Content-Type header");
        assert!(
            content_type.contains("text/event-stream"),
            "expected text/event-stream, got: {content_type}"
        );
    }

    /// Asserts the response disables caching.
    fn assert_no_cache(resp: &Response) {
        let cache = header_value(resp, header::CACHE_CONTROL, "Cache-Control header");
        assert_eq!(cache, "no-cache");
    }

    #[test]
    fn format_sse_event_single_line() {
        let expected = "event: datastar-patch-elements\ndata: elements <div>hello</div>\n\n";
        assert_eq!(
            format_sse_event("datastar-patch-elements", "elements", "<div>hello</div>"),
            expected
        );
    }

    #[test]
    fn format_sse_event_multiline_collapses_to_single_line() {
        let expected = "event: datastar-patch-elements\ndata: elements <div> <p>hi</p> </div>\n\n";
        let html = "<div>\n <p>hi</p>\n</div>";
        assert_eq!(
            format_sse_event("datastar-patch-elements", "elements", html),
            expected
        );
    }

    #[test]
    fn format_sse_event_signals() {
        let expected = "event: datastar-patch-signals\ndata: signals {\"count\":42}\n\n";
        assert_eq!(
            format_sse_event("datastar-patch-signals", "signals", r#"{"count":42}"#),
            expected
        );
    }

    #[tokio::test]
    async fn sse_fragment_from_html_has_correct_content_type() {
        let resp = SseFragment::from_html("<p>test</p>".to_owned()).into_response();
        assert_event_stream(&resp);
    }

    #[tokio::test]
    async fn sse_fragment_from_html_has_cache_control() {
        let resp = SseFragment::from_html("<p>test</p>".to_owned()).into_response();
        assert_no_cache(&resp);
    }

    #[tokio::test]
    async fn sse_fragment_from_html_produces_correct_body() {
        let resp = SseFragment::from_html("<p>test</p>".to_owned()).into_response();
        let expected = "event: datastar-patch-elements\ndata: elements <p>test</p>\n\n";
        assert_eq!(response_body(resp).await, expected);
    }

    #[tokio::test]
    async fn sse_fragment_multiline_html() {
        let html = "<div>\n <span>inner</span>\n</div>".to_owned();
        let resp = SseFragment::from_html(html).into_response();
        let expected =
            "event: datastar-patch-elements\ndata: elements <div> <span>inner</span> </div>\n\n";
        assert_eq!(response_body(resp).await, expected);
    }

    #[tokio::test]
    async fn sse_signals_produces_valid_sse() {
        let signals = SseSignals::new(&json!({"count": 42})).expect("serialize");
        let resp = signals.into_response();
        assert_event_stream(&resp);

        let body = response_body(resp).await;
        assert!(body.starts_with("event: datastar-patch-signals\n"));
        assert!(body.contains("data: signals "));
        assert!(body.contains("\"count\":42"));
    }

    #[tokio::test]
    async fn sse_signals_cache_control() {
        let signals = SseSignals::new(&json!({"ok": true})).expect("serialize");
        assert_no_cache(&signals.into_response());
    }

    #[tokio::test]
    async fn sse_stream_has_event_stream_content_type() {
        let event = Event::default().event("datastar-patch-elements").data("hi");
        let resp = SseStream::new(SingleEventStream::new(event)).into_response();
        assert_event_stream(&resp);
    }

    /// Minimal stream that yields exactly one event and then ends.
    struct SingleEventStream {
        event: Option<Event>,
    }

    impl SingleEventStream {
        fn new(event: Event) -> Self {
            Self { event: Some(event) }
        }
    }

    impl Stream for SingleEventStream {
        type Item = Result<Event, Infallible>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // First poll hands out the event; every later poll sees `None` (end of stream).
            let next = self.event.take().map(Ok);
            Poll::Ready(next)
        }
    }
}