use std::time::Duration;
use bytes::Bytes;
use futures::Stream;
use crate::body::{body_from_stream, BoxBody, BoxBodyError};
use crate::response::IntoResponse;
/// A single Server-Sent Events message.
///
/// Create one with [`SseEvent::data`] for a regular event or
/// [`SseEvent::comment`] for a comment. Refine it with the `with_*` builder
/// methods, and serialize it with [`SseEvent::to_wire`].
#[derive(Debug, Clone, Default)]
pub struct SseEvent {
    /// Emitted as an `id:` line when set (CR/LF characters are stripped).
    pub id: Option<String>,
    /// Emitted as an `event:` line when set (CR/LF characters are stripped).
    pub event: Option<String>,
    /// Payload; every `\n`-separated line becomes its own `data:` line.
    pub data: String,
    /// Emitted as a `retry:` line when set. The SSE reconnection delay is
    /// expressed in milliseconds.
    pub retry: Option<u64>,
    /// When set, the message is serialized as comment lines only, and every
    /// other field is ignored.
    comment: Option<String>,
}

impl SseEvent {
    /// Creates an event carrying `data`, with no id, event type or retry.
    pub fn data(data: impl Into<String>) -> Self {
        Self {
            data: data.into(),
            ..Self::default()
        }
    }

    /// Creates a comment message.
    ///
    /// Each `\n`-separated line of `text` is written as a `:`-prefixed line.
    /// SSE clients do not dispatch comments as events.
    pub fn comment(text: impl Into<String>) -> Self {
        Self {
            comment: Some(text.into()),
            ..Self::default()
        }
    }

    /// Returns the event with its `id` set to `id`.
    pub fn with_id(self, id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }

    /// Returns the event with its `event` type set to `event`.
    pub fn with_event(self, event: impl Into<String>) -> Self {
        Self {
            event: Some(event.into()),
            ..self
        }
    }

    /// Returns the event with its `retry` delay set to `ms`.
    pub fn with_retry(self, ms: u64) -> Self {
        Self {
            retry: Some(ms),
            ..self
        }
    }

    /// Serializes the message into the `text/event-stream` wire format.
    ///
    /// - A comment message becomes one `:`-prefixed line per `\n`-separated
    ///   line of the comment, with `\r` characters removed.
    /// - Any other message writes the optional `id:`, `event:` and `retry:`
    ///   lines, in that order. It then writes one `data:` line per
    ///   `\n`-separated line of `data`.
    /// - `\r` characters are dropped from data lines rather than treated as
    ///   line breaks.
    /// - CR/LF are removed from `id` and `event`, so a value cannot inject
    ///   extra fields.
    ///
    /// The output always ends with a blank line, which terminates the message.
    pub fn to_wire(&self) -> String {
        let mut wire = String::new();
        if let Some(text) = &self.comment {
            for segment in text.split('\n') {
                wire.push(':');
                wire.extend(segment.chars().filter(|&ch| ch != '\r'));
                wire.push('\n');
            }
        } else {
            // Single-line fields in emission order; unset fields are skipped.
            let header_lines = [
                ("id: ", self.id.as_deref().map(strip_cr_lf)),
                ("event: ", self.event.as_deref().map(strip_cr_lf)),
                ("retry: ", self.retry.map(|ms| ms.to_string())),
            ];
            for (prefix, value) in header_lines {
                if let Some(value) = value {
                    wire.push_str(prefix);
                    wire.push_str(&value);
                    wire.push('\n');
                }
            }
            for segment in self.data.split('\n') {
                wire.push_str("data: ");
                wire.extend(segment.chars().filter(|&ch| ch != '\r'));
                wire.push('\n');
            }
        }
        // The blank line terminates the event (or comment block).
        wire.push('\n');
        wire
    }
}

/// Returns `s` with every `\r` and `\n` removed, keeping a value on one line.
fn strip_cr_lf(s: &str) -> String {
    s.replace(&['\r', '\n'][..], "")
}
/// An HTTP response whose body is a stream of [`SseEvent`]s.
pub struct SseResponse<S> {
    /// Source of events, consumed when the response is built.
    stream: S,
}

impl<S> SseResponse<S>
where
    S: Stream<Item = SseEvent> + Send + 'static,
{
    /// Wraps `stream` so it can be returned as a response.
    pub fn new(stream: S) -> Self {
        Self { stream }
    }
}

impl<S> IntoResponse for SseResponse<S>
where
    S: Stream<Item = SseEvent> + Send + 'static,
{
    /// Builds a `200 OK` response that streams the body.
    ///
    /// Each event's [`SseEvent::to_wire`] encoding is sent as one data frame.
    ///
    /// Headers set:
    /// - `Content-Type: text/event-stream`
    /// - `Cache-Control: no-cache`
    /// - `Connection: keep-alive`
    /// - `X-Accel-Buffering: no`. NOTE(review): presumably this disables
    ///   reverse-proxy (nginx) response buffering; confirm the deployment
    ///   relies on it.
    fn into_response(self) -> http::Response<BoxBody> {
        use futures::StreamExt;

        let frames = self.stream.map(|event| {
            let chunk = Bytes::from(event.to_wire());
            Ok::<_, BoxBodyError>(http_body::Frame::data(chunk))
        });
        let mut response = http::Response::new(body_from_stream(frames));

        let sse_headers = [
            (http::header::CONTENT_TYPE, "text/event-stream"),
            (http::header::CACHE_CONTROL, "no-cache"),
            (http::header::CONNECTION, "keep-alive"),
            (http::HeaderName::from_static("x-accel-buffering"), "no"),
        ];
        let headers = response.headers_mut();
        for (name, value) in sse_headers {
            headers.insert(name, http::HeaderValue::from_static(value));
        }
        response
    }
}
pub fn keep_alive<S>(stream: S, interval: Duration) -> impl Stream<Item = SseEvent>
where
S: Stream<Item = SseEvent> + Send + 'static,
{
let pings = futures::stream::unfold((), move |_| async move {
tokio::time::sleep(interval).await;
Some((SseEvent::comment("keepalive"), ()))
});
futures::stream::select(stream, pings)
}
// Unit tests for the SSE wire format, response headers, body streaming and
// keep-alive merging.
#[cfg(test)]
mod tests {
use super::*;
// A plain data event is one `data:` line plus the blank-line terminator.
#[test]
fn data_event_wire_format() {
let ev = SseEvent::data("hello");
assert_eq!(ev.to_wire(), "data: hello\n\n");
}
// Each `\n`-separated payload line gets its own `data:` prefix.
#[test]
fn multi_line_data_splits_into_multiple_data_lines() {
let ev = SseEvent::data("line one\nline two\nline three");
assert_eq!(
ev.to_wire(),
"data: line one\ndata: line two\ndata: line three\n\n"
);
}
// Optional fields are emitted in id, event, retry order, before the data lines.
#[test]
fn full_event_wire_format() {
let ev = SseEvent::data("payload")
.with_id("42")
.with_event("update")
.with_retry(3000);
assert_eq!(
ev.to_wire(),
"id: 42\nevent: update\nretry: 3000\ndata: payload\n\n"
);
}
// A comment serializes as a `:`-prefixed line with no space after the colon.
#[test]
fn comment_event_wire_format() {
let ev = SseEvent::comment("keepalive");
assert_eq!(ev.to_wire(), ":keepalive\n\n");
}
// Multi-line comments get one `:` prefix per line.
#[test]
fn comment_with_newline_is_split() {
let ev = SseEvent::comment("line1\nline2");
assert_eq!(ev.to_wire(), ":line1\n:line2\n\n");
}
// CR/LF are stripped from id/event so a value cannot inject extra fields.
#[test]
fn cr_lf_is_stripped_from_scalar_fields() {
let ev = SseEvent::data("ok").with_id("1\n2\r3").with_event("a\nb");
let wire = ev.to_wire();
assert!(wire.contains("id: 123\n"));
assert!(wire.contains("event: ab\n"));
}
// A lone `\r` in data is removed, not treated as a line break.
#[test]
fn carriage_return_in_data_is_dropped() {
let ev = SseEvent::data("a\rb\nc\rd");
assert_eq!(ev.to_wire(), "data: ab\ndata: cd\n\n");
}
// into_response yields 200 OK plus the four SSE-related headers.
#[tokio::test]
async fn sse_response_sets_required_headers() {
use futures::stream;
let s = stream::iter(vec![SseEvent::data("hi")]);
let res = SseResponse::new(s).into_response();
assert_eq!(res.status(), http::StatusCode::OK);
assert_eq!(
res.headers().get(http::header::CONTENT_TYPE).unwrap(),
"text/event-stream"
);
assert_eq!(
res.headers().get(http::header::CACHE_CONTROL).unwrap(),
"no-cache"
);
assert_eq!(
res.headers().get(http::header::CONNECTION).unwrap(),
"keep-alive"
);
assert_eq!(res.headers().get("x-accel-buffering").unwrap(), "no");
}
// The collected body is the in-order concatenation of each event's wire form.
#[tokio::test]
async fn sse_response_streams_event_bytes() {
use futures::stream;
use http_body_util::BodyExt;
let s = stream::iter(vec![
SseEvent::data("first"),
SseEvent::data("second").with_event("update"),
]);
let res = SseResponse::new(s).into_response();
let collected = res.into_body().collect().await.unwrap().to_bytes();
let text = std::str::from_utf8(&collected).unwrap();
assert_eq!(text, "data: first\n\nevent: update\ndata: second\n\n");
}
// The source yields one event and then stays pending forever. With a 20 ms
// ping interval, at least one of the first two items must be a keepalive
// comment.
// NOTE(review): this uses real timers; the 200 ms timeouts give headroom but
// could still flake on a heavily loaded runner.
#[tokio::test]
async fn keep_alive_interleaves_pings() {
use futures::StreamExt;
use std::time::Duration;
let pending = futures::stream::pending::<SseEvent>();
let source = futures::stream::iter(vec![SseEvent::data("real")]).chain(pending);
let mut combined = Box::pin(keep_alive(source, Duration::from_millis(20)));
let first = tokio::time::timeout(Duration::from_millis(200), combined.next())
.await
.unwrap()
.unwrap();
let second = tokio::time::timeout(Duration::from_millis(200), combined.next())
.await
.unwrap()
.unwrap();
// Keepalive pings are the only items built with `comment` set.
let saw_ping = first.comment.is_some() || second.comment.is_some();
assert!(saw_ping, "expected at least one keep-alive ping");
}
}