doxa_docs/sse.rs
1//! Runtime support for Server-Sent Event streams typed by an event enum.
2//!
//! A handler returning [`SseStream<E, S>`] produces an SSE response
//! whose event frames are named after the variant of `E` carried by
//! each stream item. The event name comes from the [`SseEventMeta`]
//! trait, which the [`SseEvent`](doxa_macros::SseEvent) derive
//! implements alongside [`utoipa::ToSchema`] for the same enum; the
//! JSON `data:` payload comes from the enum's `serde::Serialize` impl.
9//!
10//! The typed enum is the single source of truth for both the wire
11//! format (variant → `event:` line + JSON `data:`) and the OpenAPI
12//! description (`oneOf` of tagged variant objects under
13//! `text/event-stream`). Handlers never construct
14//! [`axum::response::sse::Event`] values directly — [`SseStream`]
15//! owns that conversion so variants cannot drift out of sync with the
16//! rendered documentation.
17//!
18//! # Example
19//!
20//! ```no_run
21//! use doxa::{SseEventMeta, SseStream};
22//! use futures_core::Stream;
23//! use std::convert::Infallible;
24//!
25//! // Normally derived with `#[derive(doxa::SseEvent,
26//! // serde::Serialize)]`; shown here as a hand-written impl for
27//! // clarity.
28//! #[derive(serde::Serialize)]
29//! #[serde(tag = "event", content = "data", rename_all = "snake_case")]
30//! enum BuildEvent {
31//! Started { id: u64 },
32//! Progress { done: u64, total: u64 },
33//! }
34//!
35//! impl SseEventMeta for BuildEvent {
36//! fn event_name(&self) -> &'static str {
37//! match self {
38//! Self::Started { .. } => "started",
39//! Self::Progress { .. } => "progress",
40//! }
41//! }
42//!
43//! fn all_event_names() -> &'static [&'static str] {
44//! &["started", "progress"]
45//! }
46//! }
47//!
48//! async fn stream_handler(
49//! ) -> SseStream<BuildEvent, impl Stream<Item = Result<BuildEvent, Infallible>>> {
50//! let events = futures::stream::iter(vec![
51//! Ok(BuildEvent::Started { id: 1 }),
52//! Ok(BuildEvent::Progress { done: 1, total: 10 }),
53//! ]);
54//! SseStream::new(events)
55//! }
56//! ```
57//!
58//! [`axum::response::sse::Event`]: https://docs.rs/axum/latest/axum/response/sse/struct.Event.html
59
60use std::convert::Infallible;
61use std::marker::PhantomData;
62
63use axum::response::sse::{Event, KeepAlive, Sse};
64use axum::response::{IntoResponse, Response};
65use futures_core::Stream;
66
/// Vendor-extension key that marks a response's `text/event-stream`
/// content entry as an SSE stream, so the builder's post-process can
/// recognize it.
///
/// Written (always with the value `true`) by [`mark_sse_response`],
/// which the method-shortcut macros reach when they infer an
/// `SseStream<E, …>` return type. The builder strips the key again in
/// both OpenAPI 3.1 and 3.2 output modes, so it never leaks to
/// downstream consumers of the generated document.
pub(crate) const SSE_STREAM_MARKER_KEY: &str = "x-sse-stream";
76
77/// Tag `op`'s `200` response's `text/event-stream` content entry with
78/// an `x-sse-stream: true` vendor extension. Invoked from the
79/// [`crate::DocResponseBody`] impl for [`SseStream`] after it inserts
80/// the response; the builder's spec-version post-process reads the
81/// marker to decide whether to rewrite `schema` → `itemSchema`.
82///
83/// Idempotent: repeated calls with the same operation produce a
84/// single marker entry.
85pub(crate) fn mark_sse_response(op: &mut utoipa::openapi::path::Operation) {
86 use utoipa::openapi::RefOr;
87
88 let Some(resp) = op.responses.responses.get_mut("200") else {
89 return;
90 };
91 let RefOr::T(resp) = resp else {
92 return;
93 };
94 let Some(content) = resp.content.get_mut("text/event-stream") else {
95 return;
96 };
97
98 // Round-trip through `serde_json::Value` because utoipa's
99 // `Extensions` type does not expose its inner map by reference.
100 let existing = content
101 .extensions
102 .as_ref()
103 .and_then(|e| serde_json::to_value(e).ok());
104 let already = matches!(
105 existing.as_ref().and_then(|v| v.get(SSE_STREAM_MARKER_KEY)),
106 Some(serde_json::Value::Bool(true))
107 );
108 if already {
109 return;
110 }
111
112 let ext = utoipa::openapi::extensions::ExtensionsBuilder::new()
113 .add(SSE_STREAM_MARKER_KEY, serde_json::Value::Bool(true))
114 .build();
115 match content.extensions.as_mut() {
116 Some(existing) => existing.merge(ext),
117 None => content.extensions = Some(ext),
118 }
119}
120
/// Event-name metadata for an enum that is streamed as Server-Sent Events.
///
/// Normally generated by the [`SseEvent`](doxa_macros::SseEvent) derive.
/// Writing the impl by hand is supported but uncommon. With the derive,
/// each variant's name defaults to the snake-case form of the variant
/// identifier and can be overridden per variant with `#[sse(name = "…")]`.
pub trait SseEventMeta {
    /// Every event name this enum can produce, listed in variant
    /// declaration order.
    ///
    /// Intended for documentation and tests. Frame rendering for
    /// [`SseStream`] only consults [`Self::event_name`].
    fn all_event_names() -> &'static [&'static str];

    /// The name written on the `event:` line of the SSE frame that
    /// carries `self`.
    fn event_name(&self) -> &'static str;
}
142
/// A typed SSE response stream.
///
/// Wraps a [`Stream`] of `Result<E, Err>` and turns it into an
/// [`axum::response::sse::Sse`] response via `IntoResponse`:
/// - Each `Ok` item is serialized to JSON and framed under the event name
///   returned by [`SseEventMeta::event_name`].
/// - Each `Err` item is rendered as an `error` frame carrying the error's
///   `Display` text, and the stream keeps going.
///
/// Returning `SseStream<E, _>` from a handler is also what lets the
/// documentation integration attach the `text/event-stream` response and
/// its schema to the operation at documentation-generation time. That
/// integration is the [`#[derive(SseEvent)]`](doxa_macros::SseEvent)
/// machinery together with the crate's `DocResponseBody` impl for this
/// type. Handlers get the OpenAPI description for free.
///
/// Keep-alive comments are enabled by default so intermediaries do not
/// close idle connections; change this with [`Self::with_keep_alive`].
///
/// Converting into a response requires the wrapped stream to be
/// `Send + Unpin + 'static`. A `!Unpin` stream can be pinned first, for
/// example with `Box::pin`.
pub struct SseStream<E, S> {
    // Source of typed events (or upstream errors).
    stream: S,
    // Keep-alive configuration; `None` when the caller opted out via
    // `with_keep_alive(None)`.
    keep_alive: Option<KeepAlive>,
    // Ties `E` to the type without storing one. A `fn() -> E` marker is
    // always `Send + Sync`, so `E` itself need not be.
    _event: PhantomData<fn() -> E>,
}
164
165impl<E, S> SseStream<E, S> {
166 /// Wrap a stream of events. Enables the default keep-alive.
167 pub fn new(stream: S) -> Self {
168 Self {
169 stream,
170 keep_alive: Some(KeepAlive::default()),
171 _event: PhantomData,
172 }
173 }
174
175 /// Replace the keep-alive configuration. Pass [`None`] to
176 /// disable keep-alive frames entirely.
177 pub fn with_keep_alive(mut self, keep_alive: Option<KeepAlive>) -> Self {
178 self.keep_alive = keep_alive;
179 self
180 }
181}
182
183impl<E, S, Err> IntoResponse for SseStream<E, S>
184where
185 E: SseEventMeta + serde::Serialize + Send + 'static,
186 S: Stream<Item = Result<E, Err>> + Send + Unpin + 'static,
187 Err: std::error::Error + Send + Sync + 'static,
188{
189 fn into_response(self) -> Response {
190 let mapped = EventMapStream {
191 inner: self.stream,
192 _event: PhantomData::<fn() -> E>,
193 };
194 // Apply a keep-alive unconditionally so the method chain stays
195 // monomorphic (otherwise the `Sse<…>` generic argument diverges
196 // between the two branches). Callers that pass
197 // `with_keep_alive(None)` get a keep-alive with a one-day
198 // interval — effectively disabled for any realistic request,
199 // but it keeps the return type stable and avoids `Instant`
200 // overflow that an actual `u64::MAX` interval would trigger.
201 let ka = self.keep_alive.unwrap_or_else(|| {
202 KeepAlive::new().interval(std::time::Duration::from_secs(60 * 60 * 24))
203 });
204 Sse::new(mapped).keep_alive(ka).into_response()
205 }
206}
207
208// A thin adapter that maps each `Result<E, Err>` into the
209// `Result<Event, Infallible>` shape axum's `Sse` expects. JSON
210// serialization failures are logged and the frame is replaced with
211// an `error` event — keeping the stream alive is more useful than
212// dropping a whole subscription because one payload was malformed.
213struct EventMapStream<E, S> {
214 inner: S,
215 _event: PhantomData<fn() -> E>,
216}
217
218impl<E, S, Err> Stream for EventMapStream<E, S>
219where
220 E: SseEventMeta + serde::Serialize,
221 S: Stream<Item = Result<E, Err>> + Unpin,
222 Err: std::error::Error,
223{
224 type Item = Result<Event, Infallible>;
225
226 fn poll_next(
227 mut self: std::pin::Pin<&mut Self>,
228 cx: &mut std::task::Context<'_>,
229 ) -> std::task::Poll<Option<Self::Item>> {
230 use std::task::Poll;
231 match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
232 Poll::Ready(Some(Ok(ev))) => Poll::Ready(Some(Ok(event_for(&ev)))),
233 Poll::Ready(Some(Err(err))) => {
234 tracing::warn!(error = %err, "sse upstream stream item failed");
235 let frame = Event::default().event("error").data(err.to_string());
236 Poll::Ready(Some(Ok(frame)))
237 }
238 Poll::Ready(None) => Poll::Ready(None),
239 Poll::Pending => Poll::Pending,
240 }
241 }
242}
243
244impl<E, S> Unpin for EventMapStream<E, S> where S: Unpin {}
245
246/// Build the [`Event`] for a single typed value. Logs and falls back
247/// to an `error` frame if the JSON payload cannot be serialized.
248fn event_for<E>(value: &E) -> Event
249where
250 E: SseEventMeta + serde::Serialize,
251{
252 let name = value.event_name();
253 match Event::default().event(name).json_data(value) {
254 Ok(ev) => ev,
255 Err(err) => {
256 tracing::error!(
257 error = %err,
258 event_name = name,
259 "sse json_data serialization failed; emitting error frame",
260 );
261 Event::default().event("error").data(err.to_string())
262 }
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use futures::stream;
    use http_body_util::BodyExt;

    #[derive(serde::Serialize)]
    #[serde(tag = "event", content = "data", rename_all = "snake_case")]
    enum Ev {
        Started { pipeline: String },
        Done,
    }

    impl SseEventMeta for Ev {
        fn event_name(&self) -> &'static str {
            match self {
                Self::Started { .. } => "started",
                Self::Done => "done",
            }
        }

        fn all_event_names() -> &'static [&'static str] {
            &["started", "done"]
        }
    }

    /// Upstream error used to exercise the `error` frame path.
    #[derive(Debug)]
    struct Boom;

    impl std::fmt::Display for Boom {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("upstream boom")
        }
    }

    impl std::error::Error for Boom {}

    /// Drain a response body into a UTF-8 string for frame inspection.
    async fn body_text(resp: Response) -> String {
        let body: Body = resp.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn into_response_sets_text_event_stream_content_type() {
        let s = SseStream::<Ev, _>::new(stream::iter(Vec::<Result<Ev, Infallible>>::new()));
        let resp: Response = s.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let ct = resp
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .unwrap();
        assert!(ct.to_str().unwrap().starts_with("text/event-stream"));
    }

    #[tokio::test]
    async fn emits_named_event_frame_with_json_data_for_each_item() {
        let items: Vec<Result<Ev, Infallible>> = vec![
            Ok(Ev::Started {
                pipeline: "p1".into(),
            }),
            Ok(Ev::Done),
        ];
        let s = SseStream::<Ev, _>::new(stream::iter(items)).with_keep_alive(None);

        // Drain the response body and inspect the framed output.
        let text = body_text(s.into_response()).await;

        assert!(text.contains("event: started"));
        assert!(text.contains(r#""event":"started""#));
        assert!(text.contains(r#""pipeline":"p1""#));
        assert!(text.contains("event: done"));
        assert!(text.contains(r#""event":"done""#));
    }

    #[tokio::test]
    async fn upstream_error_becomes_error_frame_and_stream_continues() {
        let items: Vec<Result<Ev, Boom>> = vec![Err(Boom), Ok(Ev::Done)];
        let s = SseStream::<Ev, _>::new(stream::iter(items)).with_keep_alive(None);

        let text = body_text(s.into_response()).await;

        // The failed item is rendered as an `error` frame with the
        // error's Display text as data...
        assert!(text.contains("event: error"));
        assert!(text.contains("data: upstream boom"));
        // ...and it does not terminate the subscription: the following
        // item is still delivered, after the error frame.
        let error_at = text.find("event: error").expect("error frame present");
        let done_at = text.find("event: done").expect("done frame present");
        assert!(error_at < done_at, "error frame must precede the next event");
    }
}