use std::convert::Infallible;
use std::marker::PhantomData;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use futures_core::Stream;
/// OpenAPI extension key that [`mark_sse_response`] writes onto a
/// `text/event-stream` content entry.
pub(crate) const SSE_STREAM_MARKER_KEY: &str = "x-sse-stream";

/// Sets the `x-sse-stream: true` extension on the `text/event-stream` content
/// of `op`'s `"200"` response.
///
/// The function returns without touching `op` when:
/// - there is no `"200"` response, or it is not a `RefOr::T` value;
/// - that response has no `text/event-stream` content entry;
/// - the marker is already present with the value `true`.
///
/// When the content already carries extensions the marker is merged into them;
/// otherwise a fresh extensions value holding only the marker is installed.
pub(crate) fn mark_sse_response(op: &mut utoipa::openapi::path::Operation) {
    use utoipa::openapi::extensions::ExtensionsBuilder;
    use utoipa::openapi::RefOr;

    let Some(RefOr::T(response)) = op.responses.responses.get_mut("200") else {
        return;
    };
    let Some(media) = response.content.get_mut("text/event-stream") else {
        return;
    };

    // The current extensions are inspected through their JSON form; if they
    // cannot be serialized they are treated as "marker not set".
    let current_json = media
        .extensions
        .as_ref()
        .and_then(|ext| serde_json::to_value(ext).ok());
    if let Some(json) = current_json {
        if json.get(SSE_STREAM_MARKER_KEY) == Some(&serde_json::Value::Bool(true)) {
            return;
        }
    }

    let marker = ExtensionsBuilder::new()
        .add(SSE_STREAM_MARKER_KEY, serde_json::Value::Bool(true))
        .build();
    if let Some(current) = media.extensions.as_mut() {
        current.merge(marker);
    } else {
        media.extensions = Some(marker);
    }
}
/// Naming metadata for a type whose values are sent as SSE events.
pub trait SseEventMeta {
    /// Name of the event this particular value represents.
    ///
    /// `event_for` in this file puts the returned string into the frame's
    /// `event` field.
    fn event_name(&self) -> &'static str;

    /// Static list of event names for the implementing type.
    ///
    /// NOTE(review): presumably the complete set of values `event_name` can
    /// return; nothing visible in this file consumes it, so confirm against
    /// the callers.
    fn all_event_names() -> &'static [&'static str];
}
/// Wrapper pairing a stream `S` with the event type `E` it is expected to yield.
///
/// `E` exists only at the type level (`PhantomData<fn() -> E>`); no value of
/// type `E` is stored in the wrapper itself.
pub struct SseStream<E, S> {
    /// The wrapped stream, stored exactly as passed to [`SseStream::new`].
    stream: S,
    /// Keep-alive setting; `Some(KeepAlive::default())` unless replaced through
    /// [`SseStream::with_keep_alive`].
    keep_alive: Option<KeepAlive>,
    /// Type-level tag for `E`.
    _event: PhantomData<fn() -> E>,
}

impl<E, S> SseStream<E, S> {
    /// Wraps `stream` with the default keep-alive setting.
    pub fn new(stream: S) -> Self {
        SseStream {
            _event: PhantomData,
            keep_alive: Some(KeepAlive::default()),
            stream,
        }
    }

    /// Returns the wrapper with its keep-alive setting replaced by `keep_alive`.
    pub fn with_keep_alive(self, keep_alive: Option<KeepAlive>) -> Self {
        Self { keep_alive, ..self }
    }
}
/// Turns the wrapper into an axum SSE response.
///
/// The wrapped stream is adapted through `EventMapStream`, so the body is a
/// sequence of SSE `Event` frames. The keep-alive setting decides whether axum's
/// keep-alive is attached to the response.
impl<E, S, Err> IntoResponse for SseStream<E, S>
where
    E: SseEventMeta + serde::Serialize + Send + 'static,
    S: Stream<Item = Result<E, Err>> + Send + Unpin + 'static,
    Err: std::error::Error + Send + Sync + 'static,
{
    fn into_response(self) -> Response {
        let Self {
            stream, keep_alive, ..
        } = self;
        let sse = Sse::new(EventMapStream {
            inner: stream,
            _event: PhantomData::<fn() -> E>,
        });
        match keep_alive {
            Some(keep_alive) => sse.keep_alive(keep_alive).into_response(),
            // `None` means keep-alive is switched off: attach nothing at all
            // rather than a very long interval, which would still arm a timer
            // and eventually emit a keep-alive frame on a long-lived connection.
            None => sse.into_response(),
        }
    }
}
struct EventMapStream<E, S> {
inner: S,
_event: PhantomData<fn() -> E>,
}
impl<E, S, Err> Stream for EventMapStream<E, S>
where
E: SseEventMeta + serde::Serialize,
S: Stream<Item = Result<E, Err>> + Unpin,
Err: std::error::Error,
{
type Item = Result<Event, Infallible>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
use std::task::Poll;
match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
Poll::Ready(Some(Ok(ev))) => Poll::Ready(Some(Ok(event_for(&ev)))),
Poll::Ready(Some(Err(err))) => {
tracing::warn!(error = %err, "sse upstream stream item failed");
let frame = Event::default().event("error").data(err.to_string());
Poll::Ready(Some(Ok(frame)))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
impl<E, S> Unpin for EventMapStream<E, S> where S: Unpin {}
fn event_for<E>(value: &E) -> Event
where
E: SseEventMeta + serde::Serialize,
{
let name = value.event_name();
match Event::default().event(name).json_data(value) {
Ok(ev) => ev,
Err(err) => {
tracing::error!(
error = %err,
event_name = name,
"sse json_data serialization failed; emitting error frame",
);
Event::default().event("error").data(err.to_string())
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use futures::stream;
    use http_body_util::BodyExt;

    /// Two-variant event type used as the payload in the tests below.
    #[derive(serde::Serialize)]
    #[serde(tag = "event", content = "data", rename_all = "snake_case")]
    enum Ev {
        Started { pipeline: String },
        Done,
    }

    impl SseEventMeta for Ev {
        fn event_name(&self) -> &'static str {
            match self {
                Ev::Done => "done",
                Ev::Started { .. } => "started",
            }
        }

        fn all_event_names() -> &'static [&'static str] {
            &["started", "done"]
        }
    }

    /// An empty stream still produces a `200` response typed as `text/event-stream`.
    #[tokio::test]
    async fn into_response_sets_text_event_stream_content_type() {
        let no_events: Vec<Result<Ev, Infallible>> = Vec::new();
        let response: Response = SseStream::<Ev, _>::new(stream::iter(no_events)).into_response();

        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(content_type.starts_with("text/event-stream"));
    }

    /// Each item shows up in the body as a named frame with its JSON payload.
    #[tokio::test]
    async fn emits_named_event_frame_with_json_data_for_each_item() {
        let started = Ev::Started {
            pipeline: "p1".into(),
        };
        let events: Vec<Result<Ev, Infallible>> = vec![Ok(started), Ok(Ev::Done)];

        let response: Response = SseStream::<Ev, _>::new(stream::iter(events))
            .with_keep_alive(None)
            .into_response();
        let body: Body = response.into_body();
        let collected = body.collect().await.unwrap().to_bytes();
        let text = std::str::from_utf8(&collected).unwrap();

        let expected_fragments = [
            "event: started",
            r#""event":"started""#,
            r#""pipeline":"p1""#,
            "event: done",
            r#""event":"done""#,
        ];
        for fragment in expected_fragments {
            assert!(
                text.contains(fragment),
                "missing {fragment:?} in body {text:?}"
            );
        }
    }
}