use bytes::Bytes;
use hyper::body::{Body, Frame, SizeHint};
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tokio::time::{interval_at, Duration, Instant, Interval};
/// A single Server-Sent Events message.
///
/// Build one with [`SseEvent::data`] and the chained builder methods
/// (`event`, `id`, `retry`), then serialise it with [`SseEvent::to_wire`].
///
/// The builder methods strip CR, LF and NUL from `event` and `id`; assigning
/// those public fields directly skips that step.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseEvent {
    /// Payload, written as one or more `data:` lines.
    pub data: String,
    /// Optional event type, written as an `event:` line.
    pub event: Option<String>,
    /// Optional event id, written as an `id:` line.
    pub id: Option<String>,
    /// Optional reconnection delay in milliseconds, written as a `retry:` line.
    pub retry: Option<u64>,
}
impl SseEvent {
    /// Starts a new event whose only populated field is the `data` payload.
    pub fn data(data: impl Into<String>) -> Self {
        let payload: String = data.into();
        Self {
            data: payload,
            event: None,
            id: None,
            retry: None,
        }
    }

    /// Sets the `event:` type name, removing any CR, LF or NUL characters.
    pub fn event(self, event: impl Into<String>) -> Self {
        let cleaned = Self::strip_line_breaks(event.into());
        Self {
            event: Some(cleaned),
            ..self
        }
    }

    /// Sets the `id:` value, removing any CR, LF or NUL characters.
    pub fn id(self, id: impl Into<String>) -> Self {
        let cleaned = Self::strip_line_breaks(id.into());
        Self {
            id: Some(cleaned),
            ..self
        }
    }

    /// Sets the `retry:` reconnection delay, in milliseconds.
    pub fn retry(self, ms: u64) -> Self {
        Self {
            retry: Some(ms),
            ..self
        }
    }

    /// Renders the event as text via its `Display` implementation.
    pub fn to_wire(&self) -> String {
        format!("{self}")
    }

    /// Drops CR, LF and NUL so the value cannot span more than one wire line.
    fn strip_line_breaks(value: String) -> String {
        value.replace(['\n', '\r', '\0'], "")
    }
}
impl fmt::Display for SseEvent {
    /// Serialises the event in `text/event-stream` wire format.
    ///
    /// Fields are written in the order `event`, `id`, `retry`, `data`. The event
    /// ends with a blank line.
    ///
    /// The `data` payload is split on every line terminator an SSE parser
    /// recognises (CRLF, lone CR, lone LF), and each piece gets its own `data:`
    /// line. Empty pieces are kept, so a trailing newline survives the trip.
    /// A bare `\r` therefore cannot start a new field. `event` and `id` also
    /// have CR, LF and NUL removed here, because those fields are public and
    /// may have been set without going through the builder methods.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Writes `name: value` followed by LF, dropping any CR, LF or NUL from
        // `value` so it can never spill onto a second wire line.
        fn single_line_field(
            f: &mut fmt::Formatter<'_>,
            name: &str,
            value: &str,
        ) -> fmt::Result {
            f.write_str(name)?;
            f.write_str(": ")?;
            for piece in value.split(['\n', '\r', '\0']) {
                f.write_str(piece)?;
            }
            f.write_str("\n")
        }

        if let Some(event) = &self.event {
            single_line_field(f, "event", event)?;
        }
        if let Some(id) = &self.id {
            single_line_field(f, "id", id)?;
        }
        if let Some(retry) = self.retry {
            writeln!(f, "retry: {retry}")?;
        }
        // Normalise CRLF first so it counts as one break, then split on the
        // remaining lone CR / LF. An empty payload yields a single empty piece,
        // which produces the `data: ` line clients expect.
        let normalized = self.data.replace("\r\n", "\n");
        for line in normalized.split(['\r', '\n']) {
            writeln!(f, "data: {line}")?;
        }
        writeln!(f)
    }
}
/// HTTP response body that streams [`SseEvent`]s received from an mpsc channel
/// as `text/event-stream` frames. When no event has been sent for one ping
/// period, it emits a `:ping` comment instead.
pub struct SseStream {
/// Receiving half of the event channel; events are written in arrival order.
receiver: mpsc::Receiver<SseEvent>,
/// Keep-alive timer; `poll_frame` resets it every time a real event is sent.
ping_interval: Interval,
}
impl SseStream {
pub fn channel(buffer: usize) -> (mpsc::Sender<SseEvent>, Self) {
let (tx, rx) = mpsc::channel(buffer);
let period = Duration::from_secs(15);
let ping = interval_at(Instant::now() + period, period);
(
tx,
SseStream {
receiver: rx,
ping_interval: ping,
},
)
}
pub fn is_closed(&self) -> bool {
self.receiver.is_closed()
}
#[cfg(test)]
pub(crate) fn channel_with_interval(
buffer: usize,
interval_period: Duration,
) -> (mpsc::Sender<SseEvent>, Self) {
let (tx, rx) = mpsc::channel(buffer);
let ping = interval_at(Instant::now() + interval_period, interval_period);
(
tx,
SseStream {
receiver: rx,
ping_interval: ping,
},
)
}
}
impl Body for SseStream {
type Data = Bytes;
type Error = std::convert::Infallible;
fn poll_frame(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Bytes>, Self::Error>>> {
match self.receiver.poll_recv(cx) {
Poll::Ready(Some(event)) => {
self.ping_interval.reset();
let bytes = Bytes::from(event.to_wire());
return Poll::Ready(Some(Ok(Frame::data(bytes))));
}
Poll::Ready(None) => {
return Poll::Ready(None);
}
Poll::Pending => {}
}
match Pin::new(&mut self.ping_interval).poll_tick(cx) {
Poll::Ready(_) => {
let ping = Bytes::from_static(b":ping\n\n");
Poll::Ready(Some(Ok(Frame::data(ping))))
}
Poll::Pending => Poll::Pending,
}
}
fn is_end_stream(&self) -> bool {
false
}
fn size_hint(&self) -> SizeHint {
SizeHint::default()
}
}
// Unit tests for the SSE wire format and for `SseStream` used as a hyper body.
#[cfg(test)]
mod tests {
use super::*;
use crate::http::response::HttpResponse;
use futures_util::task::noop_waker;
// A fully populated event is written as event, id, retry, then data,
// terminated by a blank line.
#[test]
fn sse_event_wire_format() {
let event = SseEvent::data("hello").event("msg").id("1").retry(3000);
let wire = event.to_wire();
assert_eq!(wire, "event: msg\nid: 1\nretry: 3000\ndata: hello\n\n");
}
// Each line of a multi-line payload gets its own `data:` line.
#[test]
fn sse_event_multi_line_data() {
let event = SseEvent::data("line one\nline two");
let wire = event.to_wire();
assert_eq!(wire, "data: line one\ndata: line two\n\n");
}
// An empty payload still emits one (empty) `data:` line.
#[test]
fn sse_event_empty_data() {
let event = SseEvent::data("");
let wire = event.to_wire();
assert_eq!(wire, "data: \n\n");
}
// With no optional fields set, only the `data:` line is written.
#[test]
fn sse_event_data_only() {
let wire = SseEvent::data("hello world").to_wire();
assert_eq!(wire, "data: hello world\n\n");
}
// A queued event comes back on the first poll. The second poll, with nothing
// queued and the 15s keep-alive not yet due, must be Pending.
#[tokio::test]
async fn sse_stream_poll_delivers_event() {
let (tx, mut stream) = SseStream::channel(4);
tx.send(SseEvent::data("first")).await.unwrap();
// A no-op waker is enough: the test polls manually and never waits on a wake-up.
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let frame = Pin::new(&mut stream).poll_frame(&mut cx);
match frame {
Poll::Ready(Some(Ok(f))) => {
let data = f.into_data().expect("expected data frame");
assert_eq!(data, Bytes::from("data: first\n\n"));
}
other => panic!("expected Poll::Ready(Some(Ok(frame))), got {other:?}"),
}
let frame2 = Pin::new(&mut stream).poll_frame(&mut cx);
assert!(
matches!(frame2, Poll::Pending),
"expected Poll::Pending with no queued events, got {frame2:?}"
);
}
// With a 10ms ping period and a real-time sleep of three periods, the first
// keep-alive tick is already due when the body is polled, so the next frame
// must be the `:ping` comment. `_tx` stays alive so the body does not end.
#[tokio::test]
async fn sse_stream_keep_alive_ping() {
let period = Duration::from_millis(10);
let (_tx, mut stream) = SseStream::channel_with_interval(4, period);
tokio::time::sleep(period * 3).await;
use http_body_util::BodyExt;
let frame = tokio::time::timeout(Duration::from_millis(200), stream.frame())
.await
.expect("timed out waiting for :ping frame")
.expect("stream ended unexpectedly")
.expect("poll_frame returned error");
let data = frame.into_data().expect("expected data frame");
assert_eq!(data, Bytes::from_static(b":ping\n\n"));
}
// The `event`/`id` builders must strip LF, CR and NUL so a value cannot
// inject additional wire lines.
#[test]
fn sse_field_injection_newline_stripped() {
let wire = SseEvent::data("x").event("a\nb").to_wire();
let event_lines: Vec<&str> = wire.lines().filter(|l| l.starts_with("event:")).collect();
assert_eq!(
event_lines.len(),
1,
"expected exactly one event: line, got: {wire:?}"
);
assert_eq!(
event_lines[0], "event: ab",
"embedded newline should be stripped, not injected"
);
let wire2 = SseEvent::data("y").id("c\rd").to_wire();
let id_lines: Vec<&str> = wire2.lines().filter(|l| l.starts_with("id:")).collect();
assert_eq!(
id_lines.len(),
1,
"expected exactly one id: line, got: {wire2:?}"
);
assert_eq!(
id_lines[0], "id: cd",
"embedded carriage-return should be stripped, not injected"
);
let wire3 = SseEvent::data("z").id("e\0f").event("g\0h").to_wire();
assert!(
wire3.contains("id: ef") && wire3.contains("event: gh"),
"embedded NUL should be stripped from id and event, got: {wire3:?}"
);
}
// Frames are delivered one at a time as they are sent: Pending before the
// send, Ready right after it, then Pending again.
#[tokio::test]
async fn sse_stream_incremental_delivery() {
let (tx, mut stream) = SseStream::channel(4);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let before = Pin::new(&mut stream).poll_frame(&mut cx);
assert!(
matches!(before, Poll::Pending),
"expected Poll::Pending before send"
);
tx.send(SseEvent::data("N")).await.unwrap();
let after = Pin::new(&mut stream).poll_frame(&mut cx);
assert!(
matches!(after, Poll::Ready(Some(Ok(_)))),
"expected Poll::Ready after send"
);
let still_pending = Pin::new(&mut stream).poll_frame(&mut cx);
assert!(
matches!(still_pending, Poll::Pending),
"expected Poll::Pending before N+1 send"
);
}
// `HttpResponse::sse_channel` (defined elsewhere in the crate) must set the
// SSE-related response headers checked below.
#[tokio::test]
async fn sse_factory_headers() {
let (_, resp) = HttpResponse::sse_channel(16);
let headers = resp.headers();
let header_value =
|name: &str| -> Option<&str> { headers.get(name).and_then(|v| v.to_str().ok()) };
assert_eq!(
header_value("content-type"),
Some("text/event-stream"),
"Content-Type must be text/event-stream"
);
assert_eq!(
header_value("cache-control"),
Some("no-cache"),
"Cache-Control must be no-cache"
);
assert_eq!(
header_value("connection"),
Some("keep-alive"),
"Connection must be keep-alive"
);
assert_eq!(
header_value("x-accel-buffering"),
Some("no"),
"X-Accel-Buffering must be no"
);
}
// An SSE response must carry a streaming body; an ordinary buffered text
// response must not.
#[tokio::test]
async fn sse_response_is_stream_variant() {
let (_, sse_resp) = HttpResponse::sse_channel(16);
assert!(
sse_resp.body().is_streaming(),
"SSE response body must be FerroBody::Stream"
);
let buffered_resp = HttpResponse::text("hello").into_hyper();
assert!(
!buffered_resp.body().is_streaming(),
"buffered response body must NOT be FerroBody::Stream"
);
}
}