Skip to main content

doxa_docs/
sse.rs

1//! Runtime support for Server-Sent Event streams typed by an event enum.
2//!
3//! A handler returning [`SseStream<E, S>`] produces an SSE response
4//! whose event frames are named after the variant of `E` carried by
5//! each stream item. The event name and JSON payload are derived from
6//! the [`SseEventMeta`] trait, which the
7//! [`SseEvent`](doxa_macros::SseEvent) derive implements
8//! alongside [`utoipa::ToSchema`] for the same enum.
9//!
10//! The typed enum is the single source of truth for both the wire
11//! format (variant → `event:` line + JSON `data:`) and the OpenAPI
12//! description (`oneOf` of tagged variant objects under
13//! `text/event-stream`). Handlers never construct
14//! [`axum::response::sse::Event`] values directly — [`SseStream`]
15//! owns that conversion so variants cannot drift out of sync with the
16//! rendered documentation.
17//!
18//! # Example
19//!
20//! ```no_run
21//! use doxa::{SseEventMeta, SseStream};
22//! use futures_core::Stream;
23//! use std::convert::Infallible;
24//!
25//! // Normally derived with `#[derive(doxa::SseEvent,
26//! // serde::Serialize)]`; shown here as a hand-written impl for
27//! // clarity.
28//! #[derive(serde::Serialize)]
29//! #[serde(tag = "event", content = "data", rename_all = "snake_case")]
30//! enum BuildEvent {
31//!     Started { id: u64 },
32//!     Progress { done: u64, total: u64 },
33//! }
34//!
35//! impl SseEventMeta for BuildEvent {
36//!     fn event_name(&self) -> &'static str {
37//!         match self {
38//!             Self::Started { .. } => "started",
39//!             Self::Progress { .. } => "progress",
40//!         }
41//!     }
42//!
43//!     fn all_event_names() -> &'static [&'static str] {
44//!         &["started", "progress"]
45//!     }
46//! }
47//!
48//! async fn stream_handler(
49//! ) -> SseStream<BuildEvent, impl Stream<Item = Result<BuildEvent, Infallible>>> {
50//!     let events = futures::stream::iter(vec![
51//!         Ok(BuildEvent::Started { id: 1 }),
52//!         Ok(BuildEvent::Progress { done: 1, total: 10 }),
53//!     ]);
54//!     SseStream::new(events)
55//! }
56//! ```
57//!
58//! [`axum::response::sse::Event`]: https://docs.rs/axum/latest/axum/response/sse/struct.Event.html
59
60use std::convert::Infallible;
61use std::marker::PhantomData;
62
63use axum::response::sse::{Event, KeepAlive, Sse};
64use axum::response::{IntoResponse, Response};
65use futures_core::Stream;
66
67/// Vendor-extension key marking a response's `text/event-stream`
68/// content entry as an SSE stream for the builder's post-process to
69/// recognize.
70///
71/// Emitted by the [`mark_sse_response`] helper (which the method-shortcut
72/// macros call when they infer an `SseStream<E, …>` return type) and
73/// stripped by the builder in both OpenAPI 3.1 and 3.2 output modes so
74/// it never leaks to downstream consumers.
75pub(crate) const SSE_STREAM_MARKER_KEY: &str = "x-sse-stream";
76
77/// Tag `op`'s `200` response's `text/event-stream` content entry with
78/// an `x-sse-stream: true` vendor extension. Invoked from the
79/// [`crate::DocResponseBody`] impl for [`SseStream`] after it inserts
80/// the response; the builder's spec-version post-process reads the
81/// marker to decide whether to rewrite `schema` → `itemSchema`.
82///
83/// Idempotent: repeated calls with the same operation produce a
84/// single marker entry.
85pub(crate) fn mark_sse_response(op: &mut utoipa::openapi::path::Operation) {
86    use utoipa::openapi::RefOr;
87
88    let Some(resp) = op.responses.responses.get_mut("200") else {
89        return;
90    };
91    let RefOr::T(resp) = resp else {
92        return;
93    };
94    let Some(content) = resp.content.get_mut("text/event-stream") else {
95        return;
96    };
97
98    // Round-trip through `serde_json::Value` because utoipa's
99    // `Extensions` type does not expose its inner map by reference.
100    let existing = content
101        .extensions
102        .as_ref()
103        .and_then(|e| serde_json::to_value(e).ok());
104    let already = matches!(
105        existing.as_ref().and_then(|v| v.get(SSE_STREAM_MARKER_KEY)),
106        Some(serde_json::Value::Bool(true))
107    );
108    if already {
109        return;
110    }
111
112    let ext = utoipa::openapi::extensions::ExtensionsBuilder::new()
113        .add(SSE_STREAM_MARKER_KEY, serde_json::Value::Bool(true))
114        .build();
115    match content.extensions.as_mut() {
116        Some(existing) => existing.merge(ext),
117        None => content.extensions = Some(ext),
118    }
119}
120
/// Event-name metadata for an enum whose variants are sent as
/// Server-Sent Events.
///
/// The [`SseEvent`](doxa_macros::SseEvent) derive normally provides
/// the implementation; writing one by hand is supported but uncommon.
/// Two things are exposed:
///
/// - [`Self::event_name`] — the name placed on the `event:` line of the SSE
///   frame for one particular value. The derive uses the snake-case variant
///   name unless a variant overrides it with `#[sse(name = "…")]`.
/// - [`Self::all_event_names`] — every name the enum can emit, listed in
///   declaration order. Intended for documentation and tests rather than the
///   per-frame hot path.
pub trait SseEventMeta {
    /// SSE event name of the variant held by `self`.
    fn event_name(&self) -> &'static str;

    /// All event names the enum can produce, ordered as the variants
    /// are declared.
    fn all_event_names() -> &'static [&'static str];
}
142
143/// A typed SSE response stream.
144///
145/// Wraps a [`Stream`] of `Result<E, Err>` and produces an
146/// [`axum::response::sse::Sse`] response on `IntoResponse`. Each
147/// stream item is serialized to JSON and framed with the event name
148/// returned by [`SseEventMeta::event_name`].
149///
150/// The [`SseStream`] newtype is what the
151/// [`#[derive(SseEvent)]`](doxa_macros::SseEvent) integration
152/// reads at documentation-generation time to attach the
153/// `text/event-stream` response and its schema to the operation —
154/// handlers that return `SseStream<E, _>` get the right OpenAPI
155/// description for free.
156///
157/// Keep-alive comments are enabled by default so intermediaries do
158/// not close idle connections; swap via [`Self::with_keep_alive`].
159pub struct SseStream<E, S> {
160    stream: S,
161    keep_alive: Option<KeepAlive>,
162    _event: PhantomData<fn() -> E>,
163}
164
165impl<E, S> SseStream<E, S> {
166    /// Wrap a stream of events. Enables the default keep-alive.
167    pub fn new(stream: S) -> Self {
168        Self {
169            stream,
170            keep_alive: Some(KeepAlive::default()),
171            _event: PhantomData,
172        }
173    }
174
175    /// Replace the keep-alive configuration. Pass [`None`] to
176    /// disable keep-alive frames entirely.
177    pub fn with_keep_alive(mut self, keep_alive: Option<KeepAlive>) -> Self {
178        self.keep_alive = keep_alive;
179        self
180    }
181}
182
183impl<E, S, Err> IntoResponse for SseStream<E, S>
184where
185    E: SseEventMeta + serde::Serialize + Send + 'static,
186    S: Stream<Item = Result<E, Err>> + Send + Unpin + 'static,
187    Err: std::error::Error + Send + Sync + 'static,
188{
189    fn into_response(self) -> Response {
190        let mapped = EventMapStream {
191            inner: self.stream,
192            _event: PhantomData::<fn() -> E>,
193        };
194        // Apply a keep-alive unconditionally so the method chain stays
195        // monomorphic (otherwise the `Sse<…>` generic argument diverges
196        // between the two branches). Callers that pass
197        // `with_keep_alive(None)` get a keep-alive with a one-day
198        // interval — effectively disabled for any realistic request,
199        // but it keeps the return type stable and avoids `Instant`
200        // overflow that an actual `u64::MAX` interval would trigger.
201        let ka = self.keep_alive.unwrap_or_else(|| {
202            KeepAlive::new().interval(std::time::Duration::from_secs(60 * 60 * 24))
203        });
204        Sse::new(mapped).keep_alive(ka).into_response()
205    }
206}
207
208// A thin adapter that maps each `Result<E, Err>` into the
209// `Result<Event, Infallible>` shape axum's `Sse` expects. JSON
210// serialization failures are logged and the frame is replaced with
211// an `error` event — keeping the stream alive is more useful than
212// dropping a whole subscription because one payload was malformed.
213struct EventMapStream<E, S> {
214    inner: S,
215    _event: PhantomData<fn() -> E>,
216}
217
218impl<E, S, Err> Stream for EventMapStream<E, S>
219where
220    E: SseEventMeta + serde::Serialize,
221    S: Stream<Item = Result<E, Err>> + Unpin,
222    Err: std::error::Error,
223{
224    type Item = Result<Event, Infallible>;
225
226    fn poll_next(
227        mut self: std::pin::Pin<&mut Self>,
228        cx: &mut std::task::Context<'_>,
229    ) -> std::task::Poll<Option<Self::Item>> {
230        use std::task::Poll;
231        match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
232            Poll::Ready(Some(Ok(ev))) => Poll::Ready(Some(Ok(event_for(&ev)))),
233            Poll::Ready(Some(Err(err))) => {
234                tracing::warn!(error = %err, "sse upstream stream item failed");
235                let frame = Event::default().event("error").data(err.to_string());
236                Poll::Ready(Some(Ok(frame)))
237            }
238            Poll::Ready(None) => Poll::Ready(None),
239            Poll::Pending => Poll::Pending,
240        }
241    }
242}
243
244impl<E, S> Unpin for EventMapStream<E, S> where S: Unpin {}
245
246/// Build the [`Event`] for a single typed value. Logs and falls back
247/// to an `error` frame if the JSON payload cannot be serialized.
248fn event_for<E>(value: &E) -> Event
249where
250    E: SseEventMeta + serde::Serialize,
251{
252    let name = value.event_name();
253    match Event::default().event(name).json_data(value) {
254        Ok(ev) => ev,
255        Err(err) => {
256            tracing::error!(
257                error = %err,
258                event_name = name,
259                "sse json_data serialization failed; emitting error frame",
260            );
261            Event::default().event("error").data(err.to_string())
262        }
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::header::CONTENT_TYPE;
    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use futures::stream;
    use http_body_util::BodyExt;

    /// Small event enum used as the typed payload in these tests.
    #[derive(serde::Serialize)]
    #[serde(tag = "event", content = "data", rename_all = "snake_case")]
    enum Ev {
        Started { pipeline: String },
        Done,
    }

    impl SseEventMeta for Ev {
        fn event_name(&self) -> &'static str {
            match self {
                Self::Started { .. } => "started",
                Self::Done => "done",
            }
        }

        fn all_event_names() -> &'static [&'static str] {
            &["started", "done"]
        }
    }

    /// An empty stream still produces a `200` response advertising
    /// `text/event-stream`.
    #[tokio::test]
    async fn into_response_sets_text_event_stream_content_type() {
        let no_events: Vec<Result<Ev, Infallible>> = Vec::new();
        let resp: Response = SseStream::<Ev, _>::new(stream::iter(no_events)).into_response();

        assert_eq!(resp.status(), StatusCode::OK);
        let content_type = resp
            .headers()
            .get(CONTENT_TYPE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(content_type.starts_with("text/event-stream"));
    }

    /// Each item shows up as a frame named after its variant, with
    /// the serialized JSON as payload.
    #[tokio::test]
    async fn emits_named_event_frame_with_json_data_for_each_item() {
        let started = Ev::Started {
            pipeline: "p1".into(),
        };
        let items: Vec<Result<Ev, Infallible>> = vec![Ok(started), Ok(Ev::Done)];
        let resp: Response = SseStream::<Ev, _>::new(stream::iter(items))
            .with_keep_alive(None)
            .into_response();

        // Read the whole body so the framed output can be inspected
        // as text.
        let body: Body = resp.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        let text = std::str::from_utf8(&bytes).unwrap();

        let expected_fragments = [
            "event: started",
            r#""event":"started""#,
            r#""pipeline":"p1""#,
            "event: done",
            r#""event":"done""#,
        ];
        for fragment in expected_fragments {
            assert!(text.contains(fragment));
        }
    }
}