#![forbid(unsafe_code)]
use std::time::Duration;
use bytes::Bytes;
use futures_util::stream::{self, StreamExt};
use http::StatusCode;
use tokio::sync::mpsc;
use tokio::time;
use oxihttp_core::{Body, OxiHttpError};
/// A single Server-Sent Events (SSE) message.
///
/// Build one with [`SseEvent::data`] and the `with_*` methods, then turn it
/// into its wire representation with [`SseEvent::encode`].
#[derive(Debug, Clone)]
pub struct SseEvent {
/// Optional event identifier; `encode` writes it as an `id: ` line.
pub id: Option<String>,
/// Optional event name; `encode` writes it as an `event: ` line.
pub event: Option<String>,
/// Payload; `encode` writes one `data: ` line per line of this string.
pub data: String,
/// Optional value for the `retry: ` line. The builder parameter is named
/// `ms`, so this is presumably a delay in milliseconds.
pub retry: Option<u64>,
}
impl SseEvent {
    /// Creates an event that carries only a `data` payload.
    ///
    /// `id`, `event` and `retry` start out unset; add them with the `with_*`
    /// builder methods.
    pub fn data(data: impl Into<String>) -> Self {
        Self {
            id: None,
            event: None,
            data: data.into(),
            retry: None,
        }
    }

    /// Sets the `id` field and returns the event for chaining.
    pub fn with_id(mut self, id: impl Into<String>) -> Self {
        self.id = Some(id.into());
        self
    }

    /// Sets the `event` (event name) field and returns the event for chaining.
    pub fn with_event(mut self, event: impl Into<String>) -> Self {
        self.event = Some(event.into());
        self
    }

    /// Sets the `retry` field to `ms` and returns the event for chaining.
    pub fn with_retry(mut self, ms: u64) -> Self {
        self.retry = Some(ms);
        self
    }

    /// Serializes the event into the `text/event-stream` wire format.
    ///
    /// Fields are written in the order `id`, `event`, `retry`, `data`, each as
    /// `name: value\n`, followed by the blank line that terminates the event.
    ///
    /// * `id` and `event` are single-line fields: any CR or LF characters in
    ///   them are dropped, so a value can never start a new field.
    /// * `data` is split on LF, CRLF and bare CR (all three are line
    ///   terminators for SSE clients) and emitted as one `data: ` line per
    ///   line. A single trailing line break in `data` is not transmitted.
    /// * An empty `data` string still produces one empty `data: ` line.
    pub fn encode(&self) -> Bytes {
        let mut buf = String::with_capacity(self.data.len() + 32);
        if let Some(id) = &self.id {
            Self::push_single_line_field(&mut buf, "id", id);
        }
        if let Some(event) = &self.event {
            Self::push_single_line_field(&mut buf, "event", event);
        }
        if let Some(retry) = self.retry {
            buf.push_str("retry: ");
            buf.push_str(&retry.to_string());
            buf.push('\n');
        }
        // `str::lines` only understands LF and CRLF, so bare CRs are turned
        // into LFs first; otherwise the client would see an extra line break
        // in the middle of a `data:` line and parse the rest as a new field.
        let data = Self::normalize_line_breaks(&self.data);
        for line in data.lines() {
            buf.push_str("data: ");
            buf.push_str(line);
            buf.push('\n');
        }
        if self.data.is_empty() {
            buf.push_str("data: \n");
        }
        // Blank line: end of event.
        buf.push('\n');
        Bytes::from(buf)
    }

    /// Appends `name: value\n` to `buf`, omitting any CR/LF found in `value`.
    fn push_single_line_field(buf: &mut String, name: &str, value: &str) {
        buf.push_str(name);
        buf.push_str(": ");
        buf.extend(value.chars().filter(|c| !matches!(c, '\r' | '\n')));
        buf.push('\n');
    }

    /// Rewrites CRLF and bare CR as LF, borrowing when `text` has no CR.
    fn normalize_line_breaks(text: &str) -> std::borrow::Cow<'_, str> {
        if text.contains('\r') {
            std::borrow::Cow::Owned(text.replace("\r\n", "\n").replace('\r', "\n"))
        } else {
            std::borrow::Cow::Borrowed(text)
        }
    }
}
/// Producer handle for an SSE stream.
///
/// Returned by `SseResponse::channel` together with the matching
/// `SseResponse`; events passed to `send`/`try_send` are queued on the
/// wrapped channel.
pub struct SseSender {
// Sending half of the Tokio mpsc channel carrying `SseEvent`s.
tx: mpsc::Sender<SseEvent>,
}
impl SseSender {
    /// Queues `event` on the channel, awaiting the underlying channel send.
    ///
    /// # Errors
    ///
    /// Returns `OxiHttpError::Body` when the channel send fails, which the
    /// message reports as the channel being closed.
    pub async fn send(&self, event: SseEvent) -> Result<(), OxiHttpError> {
        match self.tx.send(event).await {
            Ok(()) => Ok(()),
            Err(_closed) => Err(OxiHttpError::Body("SSE channel closed".into())),
        }
    }

    /// Attempts to queue `event` without awaiting.
    ///
    /// # Errors
    ///
    /// Returns `OxiHttpError::Body` when the event could not be queued; the
    /// error does not distinguish a full channel from a closed one.
    pub fn try_send(&self, event: SseEvent) -> Result<(), OxiHttpError> {
        if self.tx.try_send(event).is_err() {
            return Err(OxiHttpError::Body("SSE channel full or closed".into()));
        }
        Ok(())
    }
}
/// Consumer side of an SSE stream, convertible into a streaming HTTP response.
///
/// Created by `SseResponse::channel`; `into_response` turns the received
/// events into a `text/event-stream` body.
pub struct SseResponse {
// Receiving half of the Tokio mpsc channel fed by the paired `SseSender`.
rx: mpsc::Receiver<SseEvent>,
// Interval used by `into_response` to time `:` comment (heartbeat) frames;
// `None` disables them.
heartbeat: Option<Duration>,
}
impl SseResponse {
    /// Creates a connected sender/response pair backed by a bounded channel.
    ///
    /// `buffer` is the number of events that can be queued before
    /// `SseSender::send` has to wait. Tokio's `mpsc::channel` panics when
    /// asked for a capacity of zero, so a request for `0` is raised to `1`.
    /// Heartbeats are disabled until [`SseResponse::with_heartbeat`] is called.
    pub fn channel(buffer: usize) -> (SseSender, Self) {
        let (tx, rx) = mpsc::channel(buffer.max(1));
        (
            SseSender { tx },
            Self {
                rx,
                heartbeat: None,
            },
        )
    }

    /// Enables heartbeat frames and returns the response for chaining.
    ///
    /// When no event has been received for `interval`, the body emits an SSE
    /// comment frame (`:\n\n`), which clients ignore.
    pub fn with_heartbeat(mut self, interval: Duration) -> Self {
        self.heartbeat = Some(interval);
        self
    }

    /// Converts the receiver into a streaming `text/event-stream` response.
    ///
    /// The body yields one encoded frame per received [`SseEvent`] and ends
    /// once every `SseSender` has been dropped. With a heartbeat configured, a
    /// `:\n\n` comment frame is produced whenever `interval` elapses without an
    /// event.
    ///
    /// No background task is spawned: the receiver lives inside the body
    /// stream, so dropping the body closes the channel and pending or later
    /// `SseSender::send` calls fail immediately.
    ///
    /// If the response builder reports an error, a response with an empty body
    /// is returned instead.
    pub fn into_response(self) -> http::Response<Body> {
        let Self { rx, heartbeat } = self;
        let frames = stream::unfold(rx, move |mut rx| async move {
            let frame = match heartbeat {
                // `recv` is cancel-safe, so an elapsed timeout never loses an
                // event; the next iteration simply waits on the channel again.
                Some(interval) => match time::timeout(interval, rx.recv()).await {
                    Ok(Some(event)) => event.encode(),
                    Ok(None) => return None,
                    Err(_idle) => Bytes::from_static(b":\n\n"),
                },
                None => match rx.recv().await {
                    Some(event) => event.encode(),
                    None => return None,
                },
            };
            Some((frame, rx))
        })
        // The body expects fallible chunks; producing a frame cannot fail.
        .map(Ok::<Bytes, OxiHttpError>);

        http::Response::builder()
            .status(StatusCode::OK)
            .header("Content-Type", "text/event-stream")
            .header("Cache-Control", "no-cache")
            .header("Connection", "keep-alive")
            // Asks nginx-style reverse proxies not to buffer the stream.
            .header("X-Accel-Buffering", "no")
            .body(Body::stream(Box::pin(frames)))
            .unwrap_or_else(|_| http::Response::new(Body::Empty))
    }
}