use bytes::Bytes;
use futures_util::Stream;
use http::{header, StatusCode};
use pin_project_lite::pin_project;
use rustapi_openapi::{MediaType, Operation, ResponseModifier, ResponseSpec, SchemaRef};
use std::collections::BTreeMap;
use std::fmt::Write;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use crate::response::{IntoResponse, Response};
/// A single Server-Sent Event (or comment line) ready to be written to a
/// `text/event-stream` response.
///
/// Build one with [`SseEvent::new`], [`SseEvent::json_data`] or
/// [`SseEvent::comment`], then chain the `event` / `id` / `retry` builders.
#[derive(Debug, Clone, Default)]
pub struct SseEvent {
    /// Payload, sent as one `data:` line per line of text.
    pub data: String,
    /// Optional event type, sent as the `event:` field.
    pub event: Option<String>,
    /// Optional event id, sent as the `id:` field.
    pub id: Option<String>,
    /// Optional reconnection time, sent as the `retry:` field
    /// (the SSE specification interprets it as milliseconds).
    pub retry: Option<u64>,
    /// When set, the event serializes as comment line(s) only; all other fields are ignored.
    comment: Option<String>,
}

impl SseEvent {
    /// Creates a data event carrying `data` and no other fields.
    pub fn new(data: impl Into<String>) -> Self {
        Self {
            data: data.into(),
            ..Self::default()
        }
    }

    /// Creates a comment-only event (`: text`), typically used for keep-alives.
    pub fn comment(text: impl Into<String>) -> Self {
        Self {
            comment: Some(text.into()),
            ..Self::default()
        }
    }

    /// Sets the `event:` type.
    pub fn event(mut self, event: impl Into<String>) -> Self {
        self.event = Some(event.into());
        self
    }

    /// Sets the `id:` field.
    pub fn id(mut self, id: impl Into<String>) -> Self {
        self.id = Some(id.into());
        self
    }

    /// Sets the `retry:` field.
    pub fn retry(mut self, retry: u64) -> Self {
        self.retry = Some(retry);
        self
    }

    /// Creates a data event whose payload is `data` serialized as JSON.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `data` cannot be serialized.
    pub fn json_data<T: serde::Serialize>(data: &T) -> Result<Self, serde_json::Error> {
        Ok(Self::new(serde_json::to_string(data)?))
    }

    /// Serializes the event into SSE wire format, terminated by a blank line.
    ///
    /// The text is split on every SSE line terminator (`\r\n`, `\n`, or a lone `\r`):
    /// - Multi-line `data` becomes several `data:` lines.
    /// - Multi-line comments become several `:` lines.
    ///
    /// Empty segments are kept, including a trailing one, so the client reassembles
    /// the same text. CR/LF characters cannot be represented in the single-line
    /// `event` and `id` fields, so they are dropped there to stop field injection.
    pub fn to_sse_string(&self) -> String {
        let mut output = String::with_capacity(self.data.len() + 32);

        if let Some(comment) = &self.comment {
            for line in Self::split_lines(comment) {
                output.push_str(": ");
                output.push_str(line);
                output.push('\n');
            }
            output.push('\n');
            return output;
        }

        if let Some(event) = &self.event {
            Self::push_single_line_field(&mut output, "event", event);
        }
        if let Some(id) = &self.id {
            Self::push_single_line_field(&mut output, "id", id);
        }
        if let Some(retry) = self.retry {
            writeln!(output, "retry: {}", retry).expect("writing to a String cannot fail");
        }

        if self.data.is_empty() {
            // An event needs at least one data line to be dispatched by the client.
            output.push_str("data:\n");
        } else {
            for line in Self::split_lines(&self.data) {
                output.push_str("data: ");
                output.push_str(line);
                output.push('\n');
            }
        }

        output.push('\n');
        output
    }

    /// Serializes the event (see [`SseEvent::to_sse_string`]) into [`Bytes`].
    pub fn to_bytes(&self) -> Bytes {
        Bytes::from(self.to_sse_string())
    }

    /// Splits `text` on every SSE line terminator (`\r\n`, `\n`, or a lone `\r`).
    ///
    /// Empty segments, including a trailing one, are yielded so no line breaks are lost.
    fn split_lines(text: &str) -> impl Iterator<Item = &str> {
        let mut remaining = Some(text);
        std::iter::from_fn(move || {
            let current = remaining?;
            match current.find(|c: char| c == '\r' || c == '\n') {
                Some(idx) => {
                    let terminator_len = if current[idx..].starts_with("\r\n") { 2 } else { 1 };
                    remaining = Some(&current[idx + terminator_len..]);
                    Some(&current[..idx])
                }
                None => {
                    remaining = None;
                    Some(current)
                }
            }
        })
    }

    /// Appends `name: value\n`, dropping CR/LF from `value`.
    ///
    /// A CR or LF would end the field early, and the client would parse the
    /// remainder as a separate, injected field.
    fn push_single_line_field(output: &mut String, name: &str, value: &str) {
        output.push_str(name);
        output.push_str(": ");
        output.extend(value.chars().filter(|&c| c != '\r' && c != '\n'));
        output.push('\n');
    }
}
/// Configuration for periodic keep-alive comments on an SSE stream.
///
/// The default sends a `keep-alive` comment every 15 seconds.
#[derive(Debug, Clone)]
pub struct KeepAlive {
    // How often a keep-alive comment is scheduled.
    interval: Duration,
    // Comment text, sent on the wire as `: <text>`.
    text: String,
}

impl Default for KeepAlive {
    fn default() -> Self {
        KeepAlive {
            text: String::from("keep-alive"),
            interval: Duration::from_secs(15),
        }
    }
}

impl KeepAlive {
    /// Creates the default configuration; equivalent to [`KeepAlive::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets how often a keep-alive comment is emitted. Use a non-zero duration.
    pub fn interval(self, interval: Duration) -> Self {
        KeepAlive { interval, ..self }
    }

    /// Sets the comment text carried by each keep-alive.
    pub fn text(self, text: impl Into<String>) -> Self {
        KeepAlive {
            text: text.into(),
            ..self
        }
    }

    /// Returns the configured keep-alive interval.
    pub fn get_interval(&self) -> Duration {
        self.interval
    }

    /// Builds the comment event that is sent as a keep-alive.
    pub fn event(&self) -> SseEvent {
        SseEvent::comment(self.text.clone())
    }
}
/// Response type that streams [`SseEvent`]s to the client as `text/event-stream`.
///
/// Wraps a stream of `Result<SseEvent, E>`; keep-alive comments can be enabled
/// with [`Sse::keep_alive`].
pub struct Sse<S> {
    stream: S,
    keep_alive: Option<KeepAlive>,
}

impl<S> Sse<S> {
    /// Wraps `stream` with keep-alive disabled.
    pub fn new(stream: S) -> Self {
        Sse {
            keep_alive: None,
            stream,
        }
    }

    /// Enables keep-alive comments using `config`, replacing any earlier configuration.
    pub fn keep_alive(self, config: KeepAlive) -> Self {
        Sse {
            keep_alive: Some(config),
            ..self
        }
    }

    /// Returns the keep-alive configuration, if one was set.
    pub fn get_keep_alive(&self) -> Option<&KeepAlive> {
        self.keep_alive.as_ref()
    }
}
pin_project! {
    /// Stream adapter that serializes an inner stream of `Result<SseEvent, E>`
    /// into SSE wire-format byte chunks, interleaving keep-alive comments
    /// when a keep-alive timer is present.
    pub struct SseStream<S> {
        // Source of events; polled first on every `poll_next`.
        #[pin]
        inner: S,
        // Keep-alive settings; supplies the comment emitted when the timer ticks.
        keep_alive: Option<KeepAlive>,
        // Timer driving keep-alive emission; `None` means no keep-alives are sent.
        #[pin]
        keep_alive_timer: Option<tokio::time::Interval>,
    }
}
impl<S, E> Stream for SseStream<S>
where
    S: Stream<Item = Result<SseEvent, E>>,
{
    type Item = Result<Bytes, E>;

    /// Yields the next chunk of SSE bytes.
    ///
    /// Items from the inner stream (events, errors, end-of-stream) always take
    /// priority. The keep-alive timer is polled only while the inner stream is
    /// pending, and a keep-alive comment is emitted when it ticks. Each real event
    /// resets the timer, so keep-alives go out only after a full interval with no
    /// traffic.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        match this.inner.poll_next(cx) {
            Poll::Ready(Some(Ok(event))) => {
                // Traffic is flowing, so the connection is not idle: push the next
                // keep-alive a full interval into the future.
                if let Some(mut timer) = this.keep_alive_timer.as_pin_mut() {
                    timer.reset();
                }
                return Poll::Ready(Some(Ok(event.to_bytes())));
            }
            Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Pending => {}
        }

        // The inner stream is pending (its waker is registered); check whether a
        // keep-alive is due.
        if let Some(mut timer) = this.keep_alive_timer.as_pin_mut() {
            if timer.poll_tick(cx).is_ready() {
                if let Some(keep_alive) = this.keep_alive.as_ref() {
                    return Poll::Ready(Some(Ok(keep_alive.event().to_bytes())));
                }
            }
        }
        Poll::Pending
    }
}
impl<S, E> IntoResponse for Sse<S>
where
S: Stream<Item = Result<SseEvent, E>> + Send + 'static,
E: std::error::Error + Send + Sync + 'static,
{
fn into_response(self) -> Response {
let timer = self.keep_alive.as_ref().map(|k| {
let mut interval = tokio::time::interval(k.interval);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
interval
});
let stream = SseStream {
inner: self.stream,
keep_alive: self.keep_alive,
keep_alive_timer: timer,
};
use futures_util::StreamExt;
let stream =
stream.map(|res| res.map_err(|e| crate::error::ApiError::internal(e.to_string())));
let body = crate::response::Body::from_stream(stream);
http::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/event-stream")
.header(header::CACHE_CONTROL, "no-cache")
.header(header::CONNECTION, "keep-alive")
.header("X-Accel-Buffering", "no") .body(body)
.unwrap()
}
}
impl<S> ResponseModifier for Sse<S> {
    /// Documents the operation's `"200"` response as a `text/event-stream` body,
    /// with a string schema and an example event.
    fn update_response(op: &mut Operation) {
        let schema = SchemaRef::Inline(serde_json::json!({
            "type": "string",
            "description": "Server-Sent Events stream. Events follow the SSE format: 'event: <type>\\ndata: <json>\\n\\n'",
        }));
        let example =
            serde_json::json!("event: message\ndata: {\"id\": 1, \"text\": \"Hello\"}\n\n");
        let media_type = MediaType {
            schema: Some(schema),
            example: Some(example),
        };
        let content = BTreeMap::from([("text/event-stream".to_string(), media_type)]);

        op.responses.insert(
            "200".to_string(),
            ResponseSpec {
                description: "Server-Sent Events stream for real-time updates".to_string(),
                content,
                headers: BTreeMap::new(),
            },
        );
    }
}
/// Drains `stream` and concatenates the wire format of every event into one buffer.
///
/// # Errors
///
/// Returns the first error produced by the stream; events read before it are discarded.
pub async fn collect_sse_events<S, E>(stream: S) -> Result<Bytes, E>
where
    S: Stream<Item = Result<SseEvent, E>> + Send,
{
    use futures_util::StreamExt;

    futures_util::pin_mut!(stream);
    let mut collected: Vec<u8> = Vec::new();
    loop {
        match stream.next().await {
            Some(Ok(event)) => collected.extend_from_slice(event.to_sse_string().as_bytes()),
            Some(Err(err)) => return Err(err),
            None => break,
        }
    }
    Ok(collected.into())
}
/// Builds a complete, non-streaming SSE response from a known set of events.
///
/// Every event is serialized up front into a single body; use [`Sse`] for a live stream.
pub fn sse_response<I>(events: I) -> Response
where
    I: IntoIterator<Item = SseEvent>,
{
    let body: String = events
        .into_iter()
        .map(|event| event.to_sse_string())
        .collect();

    http::Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/event-stream")
        .header(header::CACHE_CONTROL, "no-cache")
        .header(header::CONNECTION, "keep-alive")
        .header("X-Accel-Buffering", "no")
        .body(crate::response::Body::from(body))
        .unwrap()
}
/// Wraps an iterator of fallible events in an [`Sse`] response.
///
/// The items are collected eagerly so the returned stream has a concrete, nameable type.
pub fn sse_from_iter<I, E>(
    events: I,
) -> Sse<futures_util::stream::Iter<std::vec::IntoIter<Result<SseEvent, E>>>>
where
    I: IntoIterator<Item = Result<SseEvent, E>>,
{
    let buffered = events.into_iter().collect::<Vec<Result<SseEvent, E>>>();
    Sse::new(futures_util::stream::iter(buffered))
}
#[cfg(test)]
mod tests {
    // Unit and property tests for SSE serialization and response headers.
    use super::*;
    use proptest::prelude::*;
    // A bare data event serializes to one `data:` line plus the blank-line terminator.
    #[test]
    fn test_sse_event_basic() {
        let event = SseEvent::new("Hello, World!");
        let output = event.to_sse_string();
        assert_eq!(output, "data: Hello, World!\n\n");
    }
    // Setting an event type emits an `event:` field alongside the data.
    #[test]
    fn test_sse_event_with_event_type() {
        let event = SseEvent::new("Hello").event("greeting");
        let output = event.to_sse_string();
        assert!(output.contains("event: greeting\n"));
        assert!(output.contains("data: Hello\n"));
    }
    // Setting an id emits an `id:` field alongside the data.
    #[test]
    fn test_sse_event_with_id() {
        let event = SseEvent::new("Hello").id("123");
        let output = event.to_sse_string();
        assert!(output.contains("id: 123\n"));
        assert!(output.contains("data: Hello\n"));
    }
    // Setting a retry emits a `retry:` field alongside the data.
    #[test]
    fn test_sse_event_with_retry() {
        let event = SseEvent::new("Hello").retry(5000);
        let output = event.to_sse_string();
        assert!(output.contains("retry: 5000\n"));
        assert!(output.contains("data: Hello\n"));
    }
    // Each line of multi-line data gets its own `data:` prefix.
    #[test]
    fn test_sse_event_multiline_data() {
        let event = SseEvent::new("Line 1\nLine 2\nLine 3");
        let output = event.to_sse_string();
        assert!(output.contains("data: Line 1\n"));
        assert!(output.contains("data: Line 2\n"));
        assert!(output.contains("data: Line 3\n"));
    }
    // All optional fields together, with the event still terminated by a blank line.
    #[test]
    fn test_sse_event_full() {
        let event = SseEvent::new("Hello").event("message").id("1").retry(3000);
        let output = event.to_sse_string();
        assert!(output.contains("event: message\n"));
        assert!(output.contains("id: 1\n"));
        assert!(output.contains("retry: 3000\n"));
        assert!(output.contains("data: Hello\n"));
        assert!(output.ends_with("\n\n"));
    }
    // `Sse::into_response` sets the streaming headers. No keep-alive is configured,
    // so no tokio interval is constructed.
    #[test]
    fn test_sse_response_headers() {
        use futures_util::stream;
        let events: Vec<Result<SseEvent, std::convert::Infallible>> =
            vec![Ok(SseEvent::new("test"))];
        let sse = Sse::new(stream::iter(events));
        let response = sse.into_response();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/event-stream"
        );
        assert_eq!(
            response.headers().get(header::CACHE_CONTROL).unwrap(),
            "no-cache"
        );
        assert_eq!(
            response.headers().get(header::CONNECTION).unwrap(),
            "keep-alive"
        );
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        // Property: for random single-line data and random optional fields, the
        // serialized event contains every prefixed field and ends with a blank line,
        // and the wrapping response carries the SSE headers.
        #[test]
        fn prop_sse_response_format(
            data in "[a-zA-Z0-9 ]{1,50}",
            event_type in proptest::option::of("[a-zA-Z][a-zA-Z0-9_]{0,20}"),
            event_id in proptest::option::of("[a-zA-Z0-9]{1,10}"),
            retry_time in proptest::option::of(1000u64..60000u64),
        ) {
            use futures_util::stream;
            let mut event = SseEvent::new(data.clone());
            if let Some(ref et) = event_type {
                event = event.event(et.clone());
            }
            if let Some(ref id) = event_id {
                event = event.id(id.clone());
            }
            if let Some(retry) = retry_time {
                event = event.retry(retry);
            }
            let sse_string = event.to_sse_string();
            prop_assert!(
                sse_string.ends_with("\n\n"),
                "SSE event must end with double newline, got: {:?}",
                sse_string
            );
            prop_assert!(
                sse_string.contains(&format!("data: {}", data)),
                "SSE event must contain data field with 'data: ' prefix"
            );
            if let Some(ref et) = event_type {
                prop_assert!(
                    sse_string.contains(&format!("event: {}", et)),
                    "SSE event must contain event type with 'event: ' prefix"
                );
            }
            if let Some(ref id) = event_id {
                prop_assert!(
                    sse_string.contains(&format!("id: {}", id)),
                    "SSE event must contain ID with 'id: ' prefix"
                );
            }
            if let Some(retry) = retry_time {
                prop_assert!(
                    sse_string.contains(&format!("retry: {}", retry)),
                    "SSE event must contain retry with 'retry: ' prefix"
                );
            }
            // Same event, now wrapped in a response (no keep-alive) to check headers.
            let events: Vec<Result<SseEvent, std::convert::Infallible>> = vec![Ok(event)];
            let sse = Sse::new(stream::iter(events));
            let response = sse.into_response();
            prop_assert_eq!(
                response.headers().get(header::CONTENT_TYPE).map(|v| v.to_str().unwrap()),
                Some("text/event-stream"),
                "SSE response must have Content-Type: text/event-stream"
            );
            prop_assert_eq!(
                response.headers().get(header::CACHE_CONTROL).map(|v| v.to_str().unwrap()),
                Some("no-cache"),
                "SSE response must have Cache-Control: no-cache"
            );
            prop_assert_eq!(
                response.headers().get(header::CONNECTION).map(|v| v.to_str().unwrap()),
                Some("keep-alive"),
                "SSE response must have Connection: keep-alive"
            );
        }
        // Property: every line of LF-joined multi-line data appears with a `data: `
        // prefix, and the event still ends with a blank line.
        #[test]
        fn prop_sse_multiline_data_format(
            lines in proptest::collection::vec("[a-zA-Z0-9 ]{1,30}", 1..5),
        ) {
            let data = lines.join("\n");
            let event = SseEvent::new(data.clone());
            let sse_string = event.to_sse_string();
            for line in lines.iter() {
                prop_assert!(
                    sse_string.contains(&format!("data: {}", line)),
                    "Each line of multiline data must be prefixed with 'data: '"
                );
            }
            prop_assert!(
                sse_string.ends_with("\n\n"),
                "SSE event must end with double newline"
            );
        }
    }
}