use std::convert::Infallible;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use bytes::Bytes;
use bytes::BytesMut;
use futures_util::Stream;
use futures_util::StreamExt;
use http::StatusCode;
use http::header;
use http_body_util::StreamBody;
use pin_project_lite::pin_project;
use tako_rs_core::body::TakoBody;
use tako_rs_core::responder::Responder;
use tako_rs_core::types::Response;
/// Field name written in front of every raw payload emitted by the `Sse` responder.
const PREFIX: &[u8] = b"data: ";
/// Blank-line terminator that ends (and dispatches) one event.
const SUFFIX: &[u8] = b"\n\n";
/// Fixed framing overhead added around a single-line raw payload.
const PS_LEN: usize = SUFFIX.len() + PREFIX.len();
/// Comment-only frame (`:`-prefixed line plus blank line) sent during idle periods.
const KEEPALIVE_FRAME: &[u8] = b":keepalive\n\n";
/// A single Server-Sent Event, serialised to wire format by `SseEvent::encode`.
///
/// Every field is optional; fields left as `None` are omitted from the encoded frame.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SseEvent {
    /// Payload, written as one `data:` line per input line.
    pub data: Option<String>,
    /// Event type, written as the `event:` field.
    pub event: Option<String>,
    /// Event id, written as the `id:` field.
    pub id: Option<String>,
    /// Reconnection delay in milliseconds, written as the `retry:` field.
    pub retry_ms: Option<u64>,
    /// Comment text, written as `: `-prefixed lines.
    pub comment: Option<String>,
}
impl SseEvent {
    /// Creates an event whose `data` field is `d`.
    pub fn data(d: impl Into<String>) -> Self {
        Self {
            data: Some(d.into()),
            ..Default::default()
        }
    }

    /// Creates a comment-only event; `encode` writes each line of `c` as a
    /// `: `-prefixed comment line.
    pub fn comment(c: impl Into<String>) -> Self {
        Self {
            comment: Some(c.into()),
            ..Default::default()
        }
    }

    /// Creates an event that only carries a `retry:` reconnection delay.
    ///
    /// The delay is sent in whole milliseconds (any sub-millisecond remainder is
    /// dropped). A duration whose millisecond count does not fit in a `u64`
    /// saturates to `u64::MAX` instead of being silently truncated.
    pub fn retry(d: Duration) -> Self {
        Self {
            retry_ms: Some(u64::try_from(d.as_millis()).unwrap_or(u64::MAX)),
            ..Default::default()
        }
    }

    /// Sets the `event:` type, replacing any previous value.
    #[must_use]
    pub fn event(mut self, e: impl Into<String>) -> Self {
        self.event = Some(e.into());
        self
    }

    /// Sets the `id:` value, replacing any previous value.
    #[must_use]
    pub fn id(mut self, i: impl Into<String>) -> Self {
        self.id = Some(i.into());
        self
    }

    /// Serialises the event into its wire form, terminated by a blank line.
    ///
    /// Fields are written in a fixed order: comment lines, `event`, `id`,
    /// `retry`, then `data`; `None` fields are skipped. Comment and data text is
    /// split on `\n` into one line each, with `\r` removed. Line breaks inside
    /// `event` and `id` are replaced with spaces so those values stay on a single
    /// line and cannot inject extra fields.
    pub fn encode(&self) -> Bytes {
        let mut buf = BytesMut::with_capacity(64);
        if let Some(c) = self.comment.as_deref() {
            for line in c.split('\n') {
                buf.extend_from_slice(b": ");
                buf.extend_from_slice(strip_cr(line).as_bytes());
                buf.extend_from_slice(b"\n");
            }
        }
        if let Some(e) = self.event.as_deref() {
            buf.extend_from_slice(b"event: ");
            buf.extend_from_slice(sanitize_single_line(e).as_bytes());
            buf.extend_from_slice(b"\n");
        }
        if let Some(i) = self.id.as_deref() {
            buf.extend_from_slice(b"id: ");
            buf.extend_from_slice(sanitize_single_line(i).as_bytes());
            buf.extend_from_slice(b"\n");
        }
        if let Some(r) = self.retry_ms {
            buf.extend_from_slice(b"retry: ");
            buf.extend_from_slice(r.to_string().as_bytes());
            buf.extend_from_slice(b"\n");
        }
        if let Some(d) = self.data.as_deref() {
            // One `data:` line per input line; the client rejoins them with `\n`.
            for line in d.split('\n') {
                buf.extend_from_slice(b"data: ");
                buf.extend_from_slice(strip_cr(line).as_bytes());
                buf.extend_from_slice(b"\n");
            }
        }
        // Blank line terminates the event.
        buf.extend_from_slice(b"\n");
        buf.freeze()
    }
}
/// Collapses `s` onto a single line for use as an `event:` or `id:` value.
///
/// Each line break becomes exactly one space. A break is `\r\n`, a lone `\n`,
/// or a lone `\r`; these are the three terminators the SSE format recognises.
/// As a result, a value can never end its field line early and smuggle in
/// extra fields.
fn sanitize_single_line(s: &str) -> String {
    // Fold CRLF first so it counts as one break rather than two.
    s.replace("\r\n", "\n").replace(['\n', '\r'], " ")
}
/// Returns a copy of `s` with every carriage return (`\r`) removed.
fn strip_cr(s: &str) -> String {
    s.chars().filter(|&ch| ch != '\r').collect()
}
/// Server-Sent Events responder for a stream of raw byte payloads.
///
/// Each stream item is framed as the `data` of one event when the responder is
/// turned into a response. Build one with `Sse::new`, or use `Sse::events` for a
/// stream of structured `SseEvent`s.
#[doc(alias = "sse")]
#[doc(alias = "eventsource")]
#[must_use = "an `Sse` responder has no effect unless it is converted into a response"]
pub struct Sse<S> {
    /// Source of raw payloads.
    pub(crate) stream: S,
    /// Idle period after which a keep-alive comment frame is sent; `None` disables it.
    pub(crate) keepalive: Option<Duration>,
}
impl<S> Sse<S>
where
S: Stream<Item = Bytes> + Send + 'static,
{
pub fn new(stream: S) -> Self {
Self {
stream,
keepalive: None,
}
}
}
impl<S> Sse<S> {
    /// Sends a `:keepalive` comment frame whenever `period` elapses without a
    /// payload. Calling this again replaces the previously configured period.
    pub fn keep_alive(self, period: Duration) -> Self {
        Self {
            keepalive: Some(period),
            ..self
        }
    }
}
impl<S> Responder for Sse<S>
where
    S: Stream<Item = Bytes> + Send + 'static,
{
    /// Converts the raw payload stream into a `text/event-stream` response.
    ///
    /// Each stream item becomes one event. A payload spanning several lines is
    /// written as consecutive `data:` lines with CR bytes dropped, which
    /// EventSource clients rejoin with `\n`. This means an embedded newline can
    /// neither end the event early nor inject other fields such as `event:` or
    /// `id:`. When `keep_alive` was configured, keep-alive comment frames are
    /// interleaved during idle periods.
    fn into_response(self) -> Response {
        /// Frames one raw payload as a complete SSE event.
        fn encode_payload(msg: &[u8]) -> Bytes {
            // Common case: no line terminators, so a single `data:` line suffices.
            if !msg.iter().any(|&b| b == b'\n' || b == b'\r') {
                let mut buf = BytesMut::with_capacity(PS_LEN + msg.len());
                buf.extend_from_slice(PREFIX);
                buf.extend_from_slice(msg);
                buf.extend_from_slice(SUFFIX);
                return buf.freeze();
            }
            // Same rules as `SseEvent::encode`: one `data:` line per LF-separated
            // segment, CR bytes removed so CRLF payloads behave like LF ones.
            let lines = msg.iter().filter(|&&b| b == b'\n').count() + 1;
            let mut buf =
                BytesMut::with_capacity(msg.len() + lines * (PREFIX.len() + 1) + 1);
            for line in msg.split(|&b| b == b'\n') {
                buf.extend_from_slice(PREFIX);
                buf.extend(line.iter().copied().filter(|&b| b != b'\r'));
                buf.extend_from_slice(b"\n");
            }
            // Blank line terminates the event.
            buf.extend_from_slice(b"\n");
            buf.freeze()
        }

        let mapped = self
            .stream
            .map(|msg| Ok::<_, Infallible>(http_body::Frame::data(encode_payload(&msg))));
        let body = if let Some(period) = self.keepalive {
            let stream = KeepAliveStream::new(mapped, period, Bytes::from_static(KEEPALIVE_FRAME));
            TakoBody::new(StreamBody::new(stream))
        } else {
            TakoBody::new(StreamBody::new(mapped))
        };
        build_sse_response(body)
    }
}
/// Server-Sent Events responder for a stream of structured `SseEvent`s.
///
/// Created by `Sse::events`. Each event is serialised with `SseEvent::encode`
/// when the responder is turned into a response.
#[must_use = "an `SseEvents` responder has no effect unless it is converted into a response"]
pub struct SseEvents<S> {
    /// Source of events.
    stream: S,
    /// Idle period after which a keep-alive comment frame is sent; `None` disables it.
    keepalive: Option<Duration>,
}
impl<S> Sse<S> {
    /// Builds a responder for a stream of structured `SseEvent`s.
    ///
    /// Keep-alive frames start disabled; enable them with `SseEvents::keep_alive`.
    pub fn events(stream: S) -> SseEvents<S>
    where
        S: Stream<Item = SseEvent> + Send + 'static,
    {
        let keepalive = None;
        SseEvents { stream, keepalive }
    }
}
impl<S> SseEvents<S> {
pub fn keep_alive(mut self, period: Duration) -> Self {
self.keepalive = Some(period);
self
}
}
impl<S> Responder for SseEvents<S>
where
S: Stream<Item = SseEvent> + Send + 'static,
{
fn into_response(self) -> Response {
let mapped = self
.stream
.map(|ev| Ok::<_, Infallible>(http_body::Frame::data(ev.encode())));
let body = if let Some(period) = self.keepalive {
let stream = KeepAliveStream::new(mapped, period, Bytes::from_static(KEEPALIVE_FRAME));
TakoBody::new(StreamBody::new(stream))
} else {
TakoBody::new(StreamBody::new(mapped))
};
build_sse_response(body)
}
}
/// Wraps `body` in a `200 OK` response carrying the SSE headers.
///
/// - `Content-Type: text/event-stream` marks the body as an event stream.
/// - `Cache-Control` forbids caching of the live stream.
/// - `X-Accel-Buffering: no` asks buffering reverse proxies (nginx honours it)
///   to forward events as soon as they are written.
///
/// `Connection` is deliberately not set. It is a connection-specific header
/// that HTTP/2 forbids (RFC 9113 §8.2.2), and HTTP/1.1 connections are
/// persistent by default, so sending it only produces protocol warnings.
fn build_sse_response(body: TakoBody) -> Response {
    http::Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/event-stream")
        .header(header::CACHE_CONTROL, "no-cache, no-store, must-revalidate")
        .header("X-Accel-Buffering", "no")
        .body(body)
        // Only static, well-formed header names/values are used, so this cannot fail.
        .expect("valid SSE response")
}
pin_project! {
/// Stream adapter that forwards every item of `inner` and, whenever no item
/// has been produced for `period`, yields `keepalive_frame` as an extra data
/// frame.
struct KeepAliveStream<S> {
// Wrapped frame stream; pinned so it is polled through a projected `Pin<&mut S>`.
#[pin]
inner: S,
// Idle timer; re-armed after every forwarded item and every keep-alive frame.
#[pin]
sleep: tokio::time::Sleep,
// Idle interval used when re-arming `sleep`.
period: Duration,
// Pre-encoded bytes emitted when the idle timer fires.
keepalive_frame: Bytes,
// Set once `inner` has returned `None`; afterwards `inner` is never polled
// again and the adapter itself reports end of stream.
inner_done: bool,
}
}
impl<S> KeepAliveStream<S> {
fn new(inner: S, period: Duration, keepalive_frame: Bytes) -> Self {
Self {
inner,
sleep: tokio::time::sleep(period),
period,
keepalive_frame,
inner_done: false,
}
}
}
impl<S> Stream for KeepAliveStream<S>
where
S: Stream<Item = Result<http_body::Frame<Bytes>, Infallible>>,
{
type Item = Result<http_body::Frame<Bytes>, Infallible>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
if !*this.inner_done {
match this.inner.as_mut().poll_next(cx) {
Poll::Ready(Some(item)) => {
let deadline = tokio::time::Instant::now() + *this.period;
this.sleep.as_mut().reset(deadline);
return Poll::Ready(Some(item));
}
Poll::Ready(None) => {
*this.inner_done = true;
}
Poll::Pending => {}
}
}
if *this.inner_done {
return Poll::Ready(None);
}
if this.sleep.as_mut().poll(cx).is_ready() {
let frame = http_body::Frame::data(this.keepalive_frame.clone());
let deadline = tokio::time::Instant::now() + *this.period;
this.sleep.as_mut().reset(deadline);
return Poll::Ready(Some(Ok(frame)));
}
Poll::Pending
}
}
/// Returns the request's `Last-Event-ID` header, trimmed of surrounding whitespace.
///
/// Returns `None` when the header is absent or cannot be read as ASCII text
/// (see `HeaderValue::to_str`). A present but blank header yields
/// `Some(String::new())`.
pub fn last_event_id(headers: &http::HeaderMap) -> Option<String> {
    let text = headers.get("last-event-id")?.to_str().ok()?;
    Some(text.trim().to_owned())
}
/// Returns the raw bytes of the `Last-Event-ID` header with leading and
/// trailing ASCII whitespace removed.
///
/// Unlike `last_event_id`, this accepts non-ASCII bytes. It returns `None`
/// both when the header is absent and when its value is empty or all
/// whitespace.
pub fn last_event_id_bytes(headers: &http::HeaderMap) -> Option<Vec<u8>> {
    let raw = headers.get("last-event-id")?.as_bytes();
    let is_content = |b: &u8| !b.is_ascii_whitespace();
    let first = raw.iter().position(is_content)?;
    // A content byte was found above, so the reverse scan always succeeds.
    let last = raw.iter().rposition(is_content)?;
    Some(raw[first..=last].to_vec())
}
#[cfg(test)]
mod tests {
    use super::SseEvent;
    use std::time::Duration;

    /// `event` and `id` values must not break out of their line: embedded
    /// CR/LF is replaced so a hostile value cannot inject new fields.
    #[test]
    fn event_and_id_strip_crlf() {
        let frame = SseEvent::data("payload")
            .event("legit\nid: hostile")
            .id("a\r\nb")
            .encode();
        let s = std::str::from_utf8(&frame).unwrap();
        assert!(
            s.contains("event: legit id: hostile\n"),
            "expected sanitized event line, got: {s:?}"
        );
        assert!(
            s.contains("id: a b\n"),
            "expected sanitized id line, got: {s:?}"
        );
        assert!(!s.contains('\r'));
    }

    /// Fields are emitted in a fixed order and the frame ends with a blank line.
    #[test]
    fn encode_field_order_and_terminator() {
        let frame = SseEvent::data("x").event("e").id("1").encode();
        assert_eq!(&frame[..], b"event: e\nid: 1\ndata: x\n\n");
    }

    /// Multi-line data becomes one `data:` line per input line, with CR removed.
    #[test]
    fn multiline_data_split_into_data_lines() {
        let frame = SseEvent::data("a\r\nb\nc").encode();
        assert_eq!(&frame[..], b"data: a\ndata: b\ndata: c\n\n");
    }

    /// Comment-only and retry-only events use their dedicated field syntax.
    #[test]
    fn comment_and_retry_frames() {
        assert_eq!(&SseEvent::comment("ping").encode()[..], b": ping\n\n");
        assert_eq!(
            &SseEvent::retry(Duration::from_millis(1500)).encode()[..],
            b"retry: 1500\n\n"
        );
    }
}