use crate::response::IntoResponse;
use crate::BoxError;
use bytes::Bytes;
use futures_util::{
ready,
stream::{Stream, TryStream},
};
use http::Response;
use http_body::Body as HttpBody;
use pin_project_lite::pin_project;
use std::{
borrow::Cow,
fmt,
fmt::Write,
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use sync_wrapper::SyncWrapper;
use tokio::time::Sleep;
/// A response that sends the items of `stream` as `text/event-stream`.
///
/// An optional [`KeepAlive`] can be attached with [`Sse::keep_alive`].
#[derive(Clone)]
pub struct Sse<S> {
    stream: S,
    keep_alive: Option<KeepAlive>,
}

impl<S> Sse<S> {
    /// Wraps `stream` in a new `Sse` with no keep-alive configured.
    ///
    /// The stream must yield `Result<Event, _>` items whose error type
    /// converts into [`BoxError`].
    pub fn new(stream: S) -> Self
    where
        S: TryStream<Ok = Event> + Send + 'static,
        S::Error: Into<BoxError>,
    {
        let keep_alive = None;
        Self { stream, keep_alive }
    }

    /// Sets the keep-alive configuration, replacing any previous one.
    pub fn keep_alive(self, keep_alive: KeepAlive) -> Self {
        Self {
            keep_alive: Some(keep_alive),
            ..self
        }
    }
}
impl<S> fmt::Debug for Sse<S> {
    /// Formats the value without requiring `S: Debug`: the stream is shown by
    /// its type name rather than by its contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let stream_type = std::any::type_name::<S>();
        let mut out = f.debug_struct("Sse");
        // `format_args!` prints the type name bare, without string quotes.
        out.field("stream", &format_args!("{}", stream_type));
        out.field("keep_alive", &self.keep_alive);
        out.finish()
    }
}
impl<S, E> IntoResponse for Sse<S>
where
S: Stream<Item = Result<Event, E>> + Send + 'static,
E: Into<BoxError>,
{
type Body = Body<S>;
type BodyError = E;
fn into_response(self) -> Response<Self::Body> {
let body = Body {
event_stream: SyncWrapper::new(self.stream),
keep_alive: self.keep_alive.map(KeepAliveStream::new),
};
Response::builder()
.header(http::header::CONTENT_TYPE, "text/event-stream")
.header(http::header::CACHE_CONTROL, "no-cache")
.body(body)
.unwrap()
}
}
pin_project! {
/// Response body type used by the `IntoResponse` implementation of `Sse`.
#[derive(Debug)]
pub struct Body<S> {
// The caller-supplied event stream, wrapped in `SyncWrapper`
// (presumably so the body is `Sync` regardless of `S` - confirm against
// the `sync_wrapper` crate documentation).
#[pin]
event_stream: SyncWrapper<S>,
// Keep-alive timer state; `None` when no keep-alive was configured.
#[pin]
keep_alive: Option<KeepAliveStream>,
}
}
impl<S, E> HttpBody for Body<S>
where
    S: Stream<Item = Result<Event, E>>,
{
    type Data = Bytes;
    type Error = E;

    /// Yields the next chunk of the body.
    ///
    /// The event stream is polled first. If it has nothing ready, the
    /// keep-alive timer (when configured) is polled and may yield a comment
    /// event instead. Every real event restarts the keep-alive timer.
    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        let this = self.project();
        let next = this.event_stream.get_pin_mut().poll_next(cx);
        let keep_alive = this.keep_alive.as_pin_mut();

        match (next, keep_alive) {
            // Nothing from the stream: let the keep-alive timer decide.
            (Poll::Pending, Some(timer)) => timer
                .poll_event(cx)
                .map(|ping| Some(Ok(Bytes::from(ping.to_string())))),
            (Poll::Pending, None) => Poll::Pending,
            // A real event went out, so the keep-alive countdown starts over.
            (Poll::Ready(Some(Ok(event))), keep_alive) => {
                if let Some(timer) = keep_alive {
                    timer.reset();
                }
                Poll::Ready(Some(Ok(Bytes::from(event.to_string()))))
            }
            (Poll::Ready(Some(Err(error))), _) => Poll::Ready(Some(Err(error))),
            (Poll::Ready(None), _) => Poll::Ready(None),
        }
    }

    /// This body never produces trailers.
    fn poll_trailers(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
        Poll::Ready(Ok(None))
    }
}
/// A single server-sent event.
///
/// Build one with [`Event::default`] followed by the builder methods, then
/// serialize it to the `text/event-stream` wire format through its
/// [`fmt::Display`] implementation.
#[derive(Default, Debug)]
pub struct Event {
    id: Option<String>,
    data: Option<DataType>,
    event: Option<String>,
    comment: Option<String>,
    retry: Option<Duration>,
}

/// Payload of the `data` field.
#[derive(Debug)]
enum DataType {
    /// Free-form text; may span several lines.
    Text(String),
    /// A value already serialized to a JSON string.
    #[cfg(feature = "json")]
    Json(String),
}

impl Event {
    /// Sets the event's text data, replacing any data set before.
    ///
    /// When serialized, every line of `data` becomes its own `data` field.
    /// `\n`, `\r\n` and a lone `\r` all count as line breaks.
    pub fn data<T>(mut self, data: T) -> Event
    where
        T: Into<String>,
    {
        self.data = Some(DataType::Text(data.into()));
        self
    }

    /// Sets the event's data to the JSON serialization of `data`, replacing
    /// any data set before.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `data` cannot be serialized.
    #[cfg(feature = "json")]
    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
    pub fn json_data<T>(mut self, data: T) -> Result<Event, serde_json::Error>
    where
        T: serde::Serialize,
    {
        self.data = Some(DataType::Json(serde_json::to_string(&data)?));
        Ok(self)
    }

    /// Sets the event's comment. Each line of a multi-line comment is
    /// serialized as its own comment line.
    pub fn comment<T>(mut self, comment: T) -> Event
    where
        T: Into<String>,
    {
        self.comment = Some(comment.into());
        self
    }

    /// Sets the event's name (the `event` field).
    ///
    /// The field is single-line: only the text before the first line
    /// terminator is serialized.
    pub fn event<T>(mut self, event: T) -> Event
    where
        T: Into<String>,
    {
        self.event = Some(event.into());
        self
    }

    /// Sets the `retry` field, serialized in whole milliseconds.
    pub fn retry(mut self, duration: Duration) -> Event {
        self.retry = Some(duration);
        self
    }

    /// Sets the event's identifier (the `id` field).
    ///
    /// The field is single-line: only the text before the first line
    /// terminator is serialized.
    pub fn id<T>(mut self, id: T) -> Event
    where
        T: Into<String>,
    {
        self.id = Some(id.into());
        self
    }
}

/// Returns `true` for the characters an event-stream parser treats as the
/// start of a line terminator.
fn is_sse_line_terminator(c: char) -> bool {
    c == '\r' || c == '\n'
}

/// Splits `text` the way an event-stream parser splits lines: on `\r\n`,
/// `\n` or a lone `\r`.
///
/// Like `str::split`, this always yields at least one item and keeps a
/// trailing empty line.
fn split_sse_lines<'a>(text: &'a str) -> impl Iterator<Item = &'a str> + 'a {
    let mut rest = Some(text);
    std::iter::from_fn(move || {
        let current = rest?;
        match current.find(is_sse_line_terminator) {
            Some(idx) => {
                // "\r\n" is ONE terminator, not two.
                let terminator_len = if current[idx..].starts_with("\r\n") {
                    2
                } else {
                    1
                };
                rest = Some(&current[idx + terminator_len..]);
                Some(&current[..idx])
            }
            None => {
                rest = None;
                Some(current)
            }
        }
    })
}

/// Returns the part of `text` before its first line terminator.
fn first_sse_line(text: &str) -> &str {
    text.split(is_sse_line_terminator).next().unwrap_or("")
}

/// Writes one `name:value` line. `value` must not contain line terminators.
fn write_sse_field(f: &mut fmt::Formatter<'_>, name: &str, value: &str) -> fmt::Result {
    f.write_str(name)?;
    f.write_char(':')?;
    // Parsers drop one space after the colon; emit a sacrificial one so a
    // value that itself starts with a space arrives intact.
    if value.starts_with(' ') {
        f.write_char(' ')?;
    }
    f.write_str(value)?;
    f.write_char('\n')
}

impl fmt::Display for Event {
    /// Serializes the event in `text/event-stream` format.
    ///
    /// Fields are written in the order comment, `event`, `data`, `id`,
    /// `retry`, followed by the blank line that terminates the event.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(comment) = &self.comment {
            // Every line needs its own ':' prefix, otherwise the client
            // would parse the continuation lines as fields.
            for line in split_sse_lines(comment) {
                f.write_char(':')?;
                f.write_str(line)?;
                f.write_char('\n')?;
            }
        }

        if let Some(event) = &self.event {
            // A client stops reading the value at the first line terminator
            // anyway; cutting there keeps the remainder from being parsed as
            // extra fields.
            write_sse_field(f, "event", first_sse_line(event))?;
        }

        match &self.data {
            Some(DataType::Text(data)) => {
                for line in split_sse_lines(data) {
                    write_sse_field(f, "data", line)?;
                }
            }
            #[cfg(feature = "json")]
            Some(DataType::Json(data)) => write_sse_field(f, "data", data)?,
            None => {}
        }

        if let Some(id) = &self.id {
            write_sse_field(f, "id", first_sse_line(id))?;
        }

        if let Some(duration) = self.retry {
            writeln!(f, "retry:{}", duration.as_millis())?;
        }

        // Blank line: end of this event.
        f.write_char('\n')
    }
}
/// Configuration for keep-alive comment events: the comment text to send and
/// the maximum interval between messages.
#[derive(Debug, Clone)]
pub struct KeepAlive {
    comment_text: Cow<'static, str>,
    max_interval: Duration,
}

impl KeepAlive {
    /// Creates the default configuration: empty comment text and a
    /// 15 second interval.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum interval between keep-alive messages.
    pub fn interval(self, time: Duration) -> Self {
        Self {
            max_interval: time,
            ..self
        }
    }

    /// Sets the text of the keep-alive comment.
    pub fn text<I>(self, text: I) -> Self
    where
        I: Into<Cow<'static, str>>,
    {
        Self {
            comment_text: text.into(),
            ..self
        }
    }
}

impl Default for KeepAlive {
    /// Same as [`KeepAlive::new`].
    fn default() -> Self {
        Self {
            comment_text: Cow::Borrowed(""),
            max_interval: Duration::from_secs(15),
        }
    }
}
pin_project! {
/// Runtime state of a configured keep-alive: the settings plus the timer
/// that counts down to the next keep-alive comment.
#[derive(Debug)]
struct KeepAliveStream {
// The settings (comment text and interval) this timer was created from.
keep_alive: KeepAlive,
// Timer created with `max_interval`; pinned because it is polled as a future.
#[pin]
alive_timer: Sleep,
}
}
impl KeepAliveStream {
    /// Starts a timer that fires after `keep_alive`'s configured interval.
    fn new(keep_alive: KeepAlive) -> Self {
        let alive_timer = tokio::time::sleep(keep_alive.max_interval);
        Self {
            alive_timer,
            keep_alive,
        }
    }

    /// Restarts the countdown: the timer will fire one full interval from now.
    fn reset(self: Pin<&mut Self>) {
        let this = self.project();
        let deadline = tokio::time::Instant::now() + this.keep_alive.max_interval;
        this.alive_timer.reset(deadline);
    }

    /// Polls the timer. Once it has fired, restarts it and returns a comment
    /// event carrying the configured keep-alive text.
    fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Event> {
        let this = self.as_mut().project();
        // Returns `Poll::Pending` from this function while the timer is running.
        ready!(this.alive_timer.poll(cx));

        let text = this.keep_alive.comment_text.clone();
        let ping = Event::default().comment(text);
        self.reset();
        Poll::Ready(ping)
    }
}