Skip to main content

ferro_rs/http/
sse.rs

//! Server-Sent Events (SSE) types: wire serialization, streaming body, and response factory.
//!
//! # Overview
//!
//! ```text
//! SseStream::channel(16) → (Sender<SseEvent>, SseStream)
//!     │
//!     ├── Sender<SseEvent>   ── handler spawns a task that calls tx.send(event).await
//!     └── SseStream ──► HttpResponse::sse(stream) ──► into_hyper() ──► FerroBody::Stream
//! ```
//!
//! # Security note
//!
//! The `event` and `id` builder setters on [`SseEvent`] strip `\n`, `\r`, and `\0` characters
//! to prevent SSE field injection: a caller-supplied value cannot inject an extra SSE field or
//! event boundary, and a NUL cannot make the browser silently ignore the `id` field (which
//! would leave a stale last-event-id in place). The `data` field is safe by construction —
//! newlines in `data` produce repeated `data:` lines per the WHATWG spec, which is not an
//! injection risk. The `retry` field is `u64` and cannot carry newlines.
20
21use bytes::Bytes;
22use hyper::body::{Body, Frame, SizeHint};
23use std::fmt;
24use std::pin::Pin;
25use std::task::{Context, Poll};
26use tokio::sync::mpsc;
27use tokio::time::{interval_at, Duration, Instant, Interval};
28
29// ──────────────────────────────────────────────────────────────────────────────
30// SseEvent
31// ──────────────────────────────────────────────────────────────────────────────
32
/// A single Server-Sent Event, serializable to the WHATWG `text/event-stream` wire format.
///
/// Field ordering in the wire output follows the WHATWG spec recommendation:
/// `event:`, `id:`, `retry:`, then one or more `data:` lines, terminated by a blank line.
///
/// The fields are public, so an event may also be built with a struct literal or mutated after
/// construction. Serialization therefore re-applies the same single-line sanitization to
/// `event` and `id` that the builder setters perform, instead of trusting that the setters
/// were used.
///
/// # Example
///
/// ```rust,ignore
/// let event = SseEvent::data("hello")
///     .event("token")
///     .id("42")
///     .retry(3000);
/// // Wire: "event: token\nid: 42\nretry: 3000\ndata: hello\n\n"
/// ```
#[derive(Debug, Clone)]
pub struct SseEvent {
    /// The event payload. Every line (delimited by `\r\n`, `\r`, or `\n`) is written as its
    /// own `data:` line.
    pub data: String,
    /// Optional named event type (`event:` field).
    pub event: Option<String>,
    /// Optional last-event ID (`id:` field).
    pub id: Option<String>,
    /// Optional client reconnection delay in milliseconds (`retry:` field).
    pub retry: Option<u64>,
}

impl SseEvent {
    /// Characters that may not appear in a single-line field (`event`, `id`).
    ///
    /// `\n` and `\r` end an SSE line, so leaving them in would let the rest of the value be
    /// parsed as a separate field or event boundary. A `\0` makes the browser ignore an `id`
    /// field entirely.
    const SINGLE_LINE_FORBIDDEN: [char; 3] = ['\n', '\r', '\0'];

    /// Create an event with the given data string.
    ///
    /// This is the primary constructor; chain `.event()`, `.id()`, `.retry()` for additional fields.
    #[must_use]
    pub fn data(data: impl Into<String>) -> Self {
        Self {
            data: data.into(),
            event: None,
            id: None,
            retry: None,
        }
    }

    /// Set the named event type.
    ///
    /// `event` is a single-line SSE field. Any `\n`, `\r`, or `\0` characters in the value
    /// are stripped to prevent SSE field injection. (`data`, by contrast, may contain line
    /// breaks: they are rendered as multiple `data:` lines per the WHATWG spec.)
    #[must_use]
    pub fn event(mut self, event: impl Into<String>) -> Self {
        self.event = Some(Self::strip_forbidden(event.into()));
        self
    }

    /// Set the last-event ID.
    ///
    /// `id` is a single-line SSE field. Any `\n`, `\r`, or `\0` characters in the value are
    /// stripped. This prevents SSE field injection. It also matters because browsers ignore
    /// an `id` field containing a NUL, which would leave a stale last-event-id in place for
    /// reconnection.
    #[must_use]
    pub fn id(mut self, id: impl Into<String>) -> Self {
        self.id = Some(Self::strip_forbidden(id.into()));
        self
    }

    /// Set the client reconnection delay in milliseconds.
    #[must_use]
    pub fn retry(mut self, ms: u64) -> Self {
        self.retry = Some(ms);
        self
    }

    /// Serialize to the SSE wire format.
    ///
    /// Equivalent to `format!("{event}")` via the [`Display`](fmt::Display) impl.
    pub fn to_wire(&self) -> String {
        self.to_string()
    }

    /// Remove every [`Self::SINGLE_LINE_FORBIDDEN`] character.
    ///
    /// The common clean case reuses the input allocation.
    fn strip_forbidden(value: String) -> String {
        if value.contains(Self::SINGLE_LINE_FORBIDDEN) {
            value.replace(Self::SINGLE_LINE_FORBIDDEN, "")
        } else {
            value
        }
    }
}

impl fmt::Display for SseEvent {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `event` and `id` are sanitized here as well as in the setters, because the public
        // fields can be assigned directly.
        if let Some(event) = &self.event {
            write_single_line_field(f, "event", event)?;
        }
        if let Some(id) = &self.id {
            write_single_line_field(f, "id", id)?;
        }
        if let Some(retry) = self.retry {
            writeln!(f, "retry: {retry}")?;
        }
        // Split on every SSE line terminator: `\r\n`, a lone `\r`, or a lone `\n`. `str::lines`
        // is not enough, because it does not treat a lone `\r` as a line break while browsers
        // do. A lone `\r` left in place would let the rest of the payload be parsed as a new
        // field.
        //
        // Every segment is emitted, including an empty one after a trailing terminator. The
        // client therefore reassembles the original lines joined with `\n`. Empty data yields
        // a single empty segment, i.e. one `data: ` line.
        for line in self
            .data
            .split("\r\n")
            .flat_map(|chunk| chunk.split(['\r', '\n']))
        {
            writeln!(f, "data: {line}")?;
        }
        // Blank line terminates the event.
        writeln!(f)
    }
}

/// Write `name: value\n`, dropping any character that would end the line early or invalidate
/// the field (see [`SseEvent::SINGLE_LINE_FORBIDDEN`]).
fn write_single_line_field(f: &mut fmt::Formatter<'_>, name: &str, value: &str) -> fmt::Result {
    f.write_str(name)?;
    f.write_str(": ")?;
    for piece in value.split(SseEvent::SINGLE_LINE_FORBIDDEN) {
        f.write_str(piece)?;
    }
    f.write_str("\n")
}
132
133// ──────────────────────────────────────────────────────────────────────────────
134// SseStream
135// ──────────────────────────────────────────────────────────────────────────────
136
137/// Streaming HTTP body that serializes [`SseEvent`]s from an mpsc channel.
138///
139/// Implements `http_body::Body` so it can be carried as `FerroBody::Stream` through the
140/// framework's hyper serve loop. The stream ends when the [`mpsc::Sender`] is dropped.
141///
142/// A `:ping\n\n` keep-alive comment is emitted every 15 seconds while the channel is idle.
143/// Any real event resets the idle window.
144///
145/// # Bounded back-pressure
146///
147/// The internal channel is bounded (default 16 slots via [`SseStream::channel`]). If the
148/// client is too slow, `Sender::send().await` will apply back-pressure to the producer.
149///
150/// # Connection count limits
151///
152/// Each active `SseStream` holds a TCP connection open. Connection-count limits are the
153/// application's responsibility, not this primitive's.
154pub struct SseStream {
155    receiver: mpsc::Receiver<SseEvent>,
156    ping_interval: Interval,
157}
158
159impl SseStream {
160    /// Create a bounded channel for pushing events and the streaming body.
161    ///
162    /// Returns `(sender, stream)`. The handler holds the sender and pushes events from
163    /// a spawned task; the stream is wrapped in an `HttpResponse::sse` response.
164    ///
165    /// Uses [`interval_at`] with an initial delay equal to `interval_period` so that
166    /// the first ping is deferred — avoiding an immediate `:ping` frame on connection.
167    pub fn channel(buffer: usize) -> (mpsc::Sender<SseEvent>, Self) {
168        let (tx, rx) = mpsc::channel(buffer);
169        // interval_at defers the first tick, avoiding the immediate-first-tick pitfall.
170        let period = Duration::from_secs(15);
171        let ping = interval_at(Instant::now() + period, period);
172        (
173            tx,
174            SseStream {
175                receiver: rx,
176                ping_interval: ping,
177            },
178        )
179    }
180
181    /// Returns `true` if the internal channel has been closed (sender dropped).
182    pub fn is_closed(&self) -> bool {
183        self.receiver.is_closed()
184    }
185
186    /// Create a channel with a custom ping interval period.
187    ///
188    /// Intended for tests that need a short interval without waiting 15 seconds.
189    #[cfg(test)]
190    pub(crate) fn channel_with_interval(
191        buffer: usize,
192        interval_period: Duration,
193    ) -> (mpsc::Sender<SseEvent>, Self) {
194        let (tx, rx) = mpsc::channel(buffer);
195        let ping = interval_at(Instant::now() + interval_period, interval_period);
196        (
197            tx,
198            SseStream {
199                receiver: rx,
200                ping_interval: ping,
201            },
202        )
203    }
204}
205
206impl Body for SseStream {
207    type Data = Bytes;
208    type Error = std::convert::Infallible;
209
210    fn poll_frame(
211        mut self: Pin<&mut Self>,
212        cx: &mut Context<'_>,
213    ) -> Poll<Option<Result<Frame<Bytes>, Self::Error>>> {
214        // Both Receiver and Interval are Unpin — Pin::new is valid without pin-project.
215        match self.receiver.poll_recv(cx) {
216            Poll::Ready(Some(event)) => {
217                // Reset the idle window so the next keep-alive ping fires one full
218                // period AFTER this event, not at the original deadline (WR-01).
219                self.ping_interval.reset();
220                let bytes = Bytes::from(event.to_wire());
221                return Poll::Ready(Some(Ok(Frame::data(bytes))));
222            }
223            Poll::Ready(None) => {
224                // Sender dropped — signal end of stream.
225                return Poll::Ready(None);
226            }
227            Poll::Pending => {}
228        }
229
230        // No event ready — check keep-alive interval.
231        match Pin::new(&mut self.ping_interval).poll_tick(cx) {
232            Poll::Ready(_) => {
233                let ping = Bytes::from_static(b":ping\n\n");
234                Poll::Ready(Some(Ok(Frame::data(ping))))
235            }
236            Poll::Pending => Poll::Pending,
237        }
238    }
239
240    fn is_end_stream(&self) -> bool {
241        // Only terminated when the Sender is dropped; we cannot know ahead of time.
242        false
243    }
244
245    fn size_hint(&self) -> SizeHint {
246        SizeHint::default()
247    }
248}
249
250// ──────────────────────────────────────────────────────────────────────────────
251// Tests
252// ──────────────────────────────────────────────────────────────────────────────
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::response::HttpResponse;
    use futures_util::task::noop_waker;

    /// Result type of a single `SseStream::poll_frame` call.
    type Polled = Poll<Option<Result<Frame<Bytes>, std::convert::Infallible>>>;

    /// Poll `stream` exactly once with a no-op waker, without involving an executor.
    fn poll_once(stream: &mut SseStream) -> Polled {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        Pin::new(stream).poll_frame(&mut cx)
    }

    /// Collect the lines of `wire` that begin with `prefix`.
    fn field_lines<'a>(wire: &'a str, prefix: &str) -> Vec<&'a str> {
        wire.lines().filter(|line| line.starts_with(prefix)).collect()
    }

    // ── SseEvent wire format ──────────────────────────────────────────────────

    /// T-168-01: every optional field is written before `data:`, in spec order.
    #[test]
    fn sse_event_wire_format() {
        let wire = SseEvent::data("hello")
            .event("msg")
            .id("1")
            .retry(3000)
            .to_wire();
        assert_eq!(wire, "event: msg\nid: 1\nretry: 3000\ndata: hello\n\n");
    }

    /// T-168-02: each payload line becomes its own `data:` line.
    #[test]
    fn sse_event_multi_line_data() {
        let expected = "data: line one\ndata: line two\n\n";
        assert_eq!(SseEvent::data("line one\nline two").to_wire(), expected);
    }

    /// An empty payload still produces a single `data: ` line.
    #[test]
    fn sse_event_empty_data() {
        assert_eq!(SseEvent::data("").to_wire(), "data: \n\n");
    }

    /// An event without optional fields is just its `data:` line plus the terminator.
    #[test]
    fn sse_event_data_only() {
        assert_eq!(
            SseEvent::data("hello world").to_wire(),
            "data: hello world\n\n"
        );
    }

    // ── SseStream poll_frame ──────────────────────────────────────────────────

    /// T-168-03: a queued event comes back as one data frame, after which the stream idles.
    #[tokio::test]
    async fn sse_stream_poll_delivers_event() {
        let (tx, mut stream) = SseStream::channel(4);
        tx.send(SseEvent::data("first")).await.unwrap();

        match poll_once(&mut stream) {
            Poll::Ready(Some(Ok(frame))) => {
                let payload = frame.into_data().expect("expected data frame");
                assert_eq!(payload, Bytes::from("data: first\n\n"));
            }
            other => panic!("expected Poll::Ready(Some(Ok(frame))), got {other:?}"),
        }

        // Nothing else is queued (and `tx` is still alive), so the next poll must park.
        let frame2 = poll_once(&mut stream);
        assert!(
            matches!(frame2, Poll::Pending),
            "expected Poll::Pending with no queued events, got {frame2:?}"
        );
    }

    /// T-168-04: an idle stream yields a `:ping` keep-alive once the interval is due.
    ///
    /// A 10 ms period (via the test-only `channel_with_interval`) plus a real sleep avoids
    /// needing tokio's `test-util` feature for `pause`/`advance`.
    #[tokio::test]
    async fn sse_stream_keep_alive_ping() {
        use http_body_util::BodyExt;

        let period = Duration::from_millis(10);
        // Hold the sender so the stream stays open while idle.
        let (_tx, mut stream) = SseStream::channel_with_interval(4, period);

        // Let several periods elapse so the interval is definitely due.
        tokio::time::sleep(period * 3).await;

        // `BodyExt::frame` drives `poll_frame` with a real waker.
        let frame = tokio::time::timeout(Duration::from_millis(200), stream.frame())
            .await
            .expect("timed out waiting for :ping frame")
            .expect("stream ended unexpectedly")
            .expect("poll_frame returned error");

        assert_eq!(
            frame.into_data().expect("expected data frame"),
            Bytes::from_static(b":ping\n\n")
        );
    }

    /// T-168-SEC: line breaks and NULs in `event`/`id` are stripped, never injected.
    ///
    /// A value containing `\n` or `\r` must still produce exactly one `event:`/`id:` line,
    /// so a caller-supplied value cannot smuggle in extra SSE fields.
    #[test]
    fn sse_field_injection_newline_stripped() {
        // `event` carrying a line feed.
        let wire = SseEvent::data("x").event("a\nb").to_wire();
        let event_lines = field_lines(&wire, "event:");
        assert_eq!(
            event_lines.len(),
            1,
            "expected exactly one event: line, got: {wire:?}"
        );
        assert_eq!(
            event_lines[0], "event: ab",
            "embedded newline should be stripped, not injected"
        );

        // `id` carrying a carriage return.
        let wire2 = SseEvent::data("y").id("c\rd").to_wire();
        let id_lines = field_lines(&wire2, "id:");
        assert_eq!(
            id_lines.len(),
            1,
            "expected exactly one id: line, got: {wire2:?}"
        );
        assert_eq!(
            id_lines[0], "id: cd",
            "embedded carriage-return should be stripped, not injected"
        );

        // NUL in both fields: browsers ignore an `id` field containing NUL (IN-01).
        let wire3 = SseEvent::data("z").id("e\0f").event("g\0h").to_wire();
        let both_clean = wire3.contains("id: ef") && wire3.contains("event: gh");
        assert!(
            both_clean,
            "embedded NUL should be stripped from id and event, got: {wire3:?}"
        );
    }

    /// T-168-09: incremental delivery — event N is readable before event N+1 is sent.
    #[tokio::test]
    async fn sse_stream_incremental_delivery() {
        let (tx, mut stream) = SseStream::channel(4);

        assert!(
            matches!(poll_once(&mut stream), Poll::Pending),
            "expected Poll::Pending before send"
        );

        tx.send(SseEvent::data("N")).await.unwrap();
        assert!(
            matches!(poll_once(&mut stream), Poll::Ready(Some(Ok(_)))),
            "expected Poll::Ready after send"
        );

        assert!(
            matches!(poll_once(&mut stream), Poll::Pending),
            "expected Poll::Pending before N+1 send"
        );
    }

    // ── Factory headers + FerroBody::Stream variant (T-168-07, T-168-08) ─────

    /// T-168-07: the SSE response carries all four required headers.
    #[tokio::test]
    async fn sse_factory_headers() {
        let (_, resp) = HttpResponse::sse_channel(16);
        let headers = resp.headers();

        let required = [
            (
                "content-type",
                "text/event-stream",
                "Content-Type must be text/event-stream",
            ),
            ("cache-control", "no-cache", "Cache-Control must be no-cache"),
            ("connection", "keep-alive", "Connection must be keep-alive"),
            ("x-accel-buffering", "no", "X-Accel-Buffering must be no"),
        ];
        for (name, expected, reason) in required {
            let actual = headers.get(name).and_then(|value| value.to_str().ok());
            assert_eq!(actual, Some(expected), "{reason}");
        }
    }

    /// T-168-08 (D-06 / SC#3 reinterpreted): the SSE body is `FerroBody::Stream`, not `Full`.
    ///
    /// A buffered response (via `into_hyper()`) reports `is_streaming() == false`, while the
    /// SSE response (via `sse_channel()`) reports `is_streaming() == true`.
    #[tokio::test]
    async fn sse_response_is_stream_variant() {
        let (_, sse_resp) = HttpResponse::sse_channel(16);
        assert!(
            sse_resp.body().is_streaming(),
            "SSE response body must be FerroBody::Stream"
        );

        let buffered_resp = HttpResponse::text("hello").into_hyper();
        assert!(
            !buffered_resp.body().is_streaming(),
            "buffered response body must NOT be FerroBody::Stream"
        );
    }
}