axum/response/
sse.rs

//! Server-Sent Events (SSE) responses.
//!
//! Wrap a stream of [`Event`]s in [`Sse`] to send them to the client as a
//! `text/event-stream` response.
//!
//! # Example
//!
//! ```
//! use axum::{
//!     Router,
//!     routing::get,
//!     response::sse::{Event, KeepAlive, Sse},
//! };
//! use std::{time::Duration, convert::Infallible};
//! use tokio_stream::StreamExt as _;
//! use futures_util::stream::{self, Stream};
//!
//! let app = Router::new().route("/sse", get(sse_handler));
//!
//! async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
//!     // A `Stream` that repeats an event every second
//!     let stream = stream::repeat_with(|| Event::default().data("hi!"))
//!         .map(Ok)
//!         .throttle(Duration::from_secs(1));
//!
//!     Sse::new(stream).keep_alive(KeepAlive::default())
//! }
//! # let _: Router = app;
//! ```
27
28use crate::{
29    body::{Bytes, HttpBody},
30    BoxError,
31};
32use axum_core::{
33    body::Body,
34    response::{IntoResponse, Response},
35};
36use bytes::{BufMut, BytesMut};
37use futures_util::stream::{Stream, TryStream};
38use http_body::Frame;
39use pin_project_lite::pin_project;
40use std::{
41    fmt::{self, Write as _},
42    io::Write as _,
43    mem,
44    pin::Pin,
45    task::{ready, Context, Poll},
46    time::Duration,
47};
48use sync_wrapper::SyncWrapper;
49
/// A response that streams [`Event`]s to the client as Server-Sent Events.
///
/// Build one with [`Sse::new`]. Periodic keep-alive messages can be enabled
/// with the `keep_alive` method (requires the `tokio` feature).
#[must_use]
#[derive(Clone)]
pub struct Sse<S> {
    /// The stream whose items become the events of the response body.
    stream: S,
}
56
57impl<S> Sse<S> {
58    /// Create a new [`Sse`] response that will respond with the given stream of
59    /// [`Event`]s.
60    ///
61    /// See the [module docs](self) for more details.
62    pub fn new(stream: S) -> Self
63    where
64        S: TryStream<Ok = Event> + Send + 'static,
65        S::Error: Into<BoxError>,
66    {
67        Sse { stream }
68    }
69
70    /// Configure the interval between keep-alive messages.
71    ///
72    /// Defaults to no keep-alive messages.
73    #[cfg(feature = "tokio")]
74    pub fn keep_alive(self, keep_alive: KeepAlive) -> Sse<KeepAliveStream<S>> {
75        Sse {
76            stream: KeepAliveStream::new(keep_alive, self.stream),
77        }
78    }
79}
80
81impl<S> fmt::Debug for Sse<S> {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        f.debug_struct("Sse")
84            .field("stream", &format_args!("{}", std::any::type_name::<S>()))
85            .finish()
86    }
87}
88
89impl<S, E> IntoResponse for Sse<S>
90where
91    S: Stream<Item = Result<Event, E>> + Send + 'static,
92    E: Into<BoxError>,
93{
94    fn into_response(self) -> Response {
95        (
96            [
97                (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
98                (http::header::CACHE_CONTROL, "no-cache"),
99            ],
100            Body::new(SseBody {
101                event_stream: SyncWrapper::new(self.stream),
102            }),
103        )
104            .into_response()
105    }
106}
107
108pin_project! {
109    struct SseBody<S> {
110        #[pin]
111        event_stream: SyncWrapper<S>,
112    }
113}
114
115impl<S, E> HttpBody for SseBody<S>
116where
117    S: Stream<Item = Result<Event, E>>,
118{
119    type Data = Bytes;
120    type Error = E;
121
122    fn poll_frame(
123        self: Pin<&mut Self>,
124        cx: &mut Context<'_>,
125    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
126        let this = self.project();
127
128        match ready!(this.event_stream.get_pin_mut().poll_next(cx)) {
129            Some(Ok(event)) => Poll::Ready(Some(Ok(Frame::data(event.finalize())))),
130            Some(Err(error)) => Poll::Ready(Some(Err(error))),
131            None => Poll::Ready(None),
132        }
133    }
134}
135
136/// The state of an event's buffer.
137///
138/// This type allows creating events in a `const` context
139/// by using a finalized buffer.
140///
141/// While the buffer is active, more bytes can be written to it.
142/// Once finalized, it's immutable and cheap to clone.
143/// The buffer is active during the event building, but eventually
144/// becomes finalized to send http body frames as [`Bytes`].
145#[derive(Debug, Clone)]
146enum Buffer {
147    Active(BytesMut),
148    Finalized(Bytes),
149}
150
151impl Buffer {
152    /// Returns a mutable reference to the internal buffer.
153    ///
154    /// If the buffer was finalized, this method creates
155    /// a new active buffer with the previous contents.
156    fn as_mut(&mut self) -> &mut BytesMut {
157        match self {
158            Buffer::Active(bytes_mut) => bytes_mut,
159            Buffer::Finalized(bytes) => {
160                *self = Buffer::Active(BytesMut::from(mem::take(bytes)));
161                match self {
162                    Buffer::Active(bytes_mut) => bytes_mut,
163                    Buffer::Finalized(_) => unreachable!(),
164                }
165            }
166        }
167    }
168}
169
170/// Server-sent event
171#[derive(Debug, Clone)]
172#[must_use]
173pub struct Event {
174    buffer: Buffer,
175    flags: EventFlags,
176}
177
178/// Expose [`Event`] as a [`std::fmt::Write`]
179/// such that any form of data can be written as data safely.
180///
181/// This also ensures that newline characters `\r` and `\n`
182/// correctly trigger a split with a new `data: ` prefix.
183///
184/// # Panics
185///
186/// Panics if any `data` has already been written prior to the first write
187/// of this [`EventDataWriter`] instance.
188#[derive(Debug)]
189#[must_use]
190pub struct EventDataWriter {
191    event: Event,
192
193    // Indicates if _this_ EventDataWriter has written data,
194    // this does not say anything about whether or not `event` contains
195    // data or not.
196    data_written: bool,
197}
198
199impl Event {
200    /// Default keep-alive event
201    pub const DEFAULT_KEEP_ALIVE: Self = Self::finalized(Bytes::from_static(b":\n\n"));
202
203    const fn finalized(bytes: Bytes) -> Self {
204        Self {
205            buffer: Buffer::Finalized(bytes),
206            flags: EventFlags::from_bits(0),
207        }
208    }
209
210    /// Use this [`Event`] as a [`EventDataWriter`] to write custom data.
211    ///
212    /// - [`Self::data`] can be used as a shortcut to write `str` data
213    /// - [`Self::json_data`] can be used as a shortcut to write `json` data
214    ///
215    /// Turn it into an [`Event`] again using [`EventDataWriter::into_event`].
216    pub fn into_data_writer(self) -> EventDataWriter {
217        EventDataWriter {
218            event: self,
219            data_written: false,
220        }
221    }
222
223    /// Set the event's data data field(s) (`data: <content>`)
224    ///
225    /// Newlines in `data` will automatically be broken across `data: ` fields.
226    ///
227    /// This corresponds to [`MessageEvent`'s data field].
228    ///
229    /// Note that events with an empty data field will be ignored by the browser.
230    ///
231    /// # Panics
232    ///
233    /// Panics if any `data` has already been written before.
234    ///
235    /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
236    pub fn data<T>(self, data: T) -> Self
237    where
238        T: AsRef<str>,
239    {
240        let mut writer = self.into_data_writer();
241        let _ = writer.write_str(data.as_ref());
242        writer.into_event()
243    }
244
245    /// Set the event's data field to a value serialized as unformatted JSON (`data: <content>`).
246    ///
247    /// This corresponds to [`MessageEvent`'s data field].
248    ///
249    /// # Panics
250    ///
251    /// Panics if any `data` has already been written before.
252    ///
253    /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
254    #[cfg(feature = "json")]
255    pub fn json_data<T>(self, data: T) -> Result<Self, axum_core::Error>
256    where
257        T: serde_core::Serialize,
258    {
259        struct JsonWriter<'a>(&'a mut EventDataWriter);
260        impl std::io::Write for JsonWriter<'_> {
261            #[inline]
262            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
263                Ok(self.0.write_buf(buf))
264            }
265            fn flush(&mut self) -> std::io::Result<()> {
266                Ok(())
267            }
268        }
269
270        let mut writer = self.into_data_writer();
271
272        let json_writer = JsonWriter(&mut writer);
273        serde_json::to_writer(json_writer, &data).map_err(axum_core::Error::new)?;
274
275        Ok(writer.into_event())
276    }
277
278    /// Set the event's comment field (`:<comment-text>`).
279    ///
280    /// This field will be ignored by most SSE clients.
281    ///
282    /// Unlike other functions, this function can be called multiple times to add many comments.
283    ///
284    /// # Panics
285    ///
286    /// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in
287    /// comments.
288    pub fn comment<T>(mut self, comment: T) -> Event
289    where
290        T: AsRef<str>,
291    {
292        self.field("", comment.as_ref());
293        self
294    }
295
296    /// Set the event's name field (`event:<event-name>`).
297    ///
298    /// This corresponds to the `type` parameter given when calling `addEventListener` on an
299    /// [`EventSource`]. For example, `.event("update")` should correspond to
300    /// `.addEventListener("update", ...)`. If no event type is given, browsers will fire a
301    /// [`message` event] instead.
302    ///
303    /// [`EventSource`]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource
304    /// [`message` event]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource/message_event
305    ///
306    /// # Panics
307    ///
308    /// - Panics if `event` contains any newlines or carriage returns.
309    /// - Panics if this function has already been called on this event.
310    pub fn event<T>(mut self, event: T) -> Event
311    where
312        T: AsRef<str>,
313    {
314        if self.flags.contains(EventFlags::HAS_EVENT) {
315            panic!("Called `Event::event` multiple times");
316        }
317        self.flags.insert(EventFlags::HAS_EVENT);
318
319        self.field("event", event.as_ref());
320
321        self
322    }
323
324    /// Set the event's retry timeout field (`retry: <timeout>`).
325    ///
326    /// This sets how long clients will wait before reconnecting if they are disconnected from the
327    /// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they
328    /// wish, such as if they implement exponential backoff.
329    ///
330    /// # Panics
331    ///
332    /// Panics if this function has already been called on this event.
333    pub fn retry(mut self, duration: Duration) -> Event {
334        if self.flags.contains(EventFlags::HAS_RETRY) {
335            panic!("Called `Event::retry` multiple times");
336        }
337        self.flags.insert(EventFlags::HAS_RETRY);
338
339        let buffer = self.buffer.as_mut();
340        buffer.extend_from_slice(b"retry: ");
341
342        let secs = duration.as_secs();
343        let millis = duration.subsec_millis();
344
345        if secs > 0 {
346            // format seconds
347            buffer.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
348
349            // pad milliseconds
350            if millis < 10 {
351                buffer.extend_from_slice(b"00");
352            } else if millis < 100 {
353                buffer.extend_from_slice(b"0");
354            }
355        }
356
357        // format milliseconds
358        buffer.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
359
360        buffer.put_u8(b'\n');
361
362        self
363    }
364
365    /// Set the event's identifier field (`id:<identifier>`).
366    ///
367    /// This corresponds to [`MessageEvent`'s `lastEventId` field]. If no ID is in the event itself,
368    /// the browser will set that field to the last known message ID, starting with the empty
369    /// string.
370    ///
371    /// [`MessageEvent`'s `lastEventId` field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/lastEventId
372    ///
373    /// # Panics
374    ///
375    /// - Panics if `id` contains any newlines, carriage returns or null characters.
376    /// - Panics if this function has already been called on this event.
377    pub fn id<T>(mut self, id: T) -> Event
378    where
379        T: AsRef<str>,
380    {
381        if self.flags.contains(EventFlags::HAS_ID) {
382            panic!("Called `Event::id` multiple times");
383        }
384        self.flags.insert(EventFlags::HAS_ID);
385
386        let id = id.as_ref().as_bytes();
387        assert_eq!(
388            memchr::memchr(b'\0', id),
389            None,
390            "Event ID cannot contain null characters",
391        );
392
393        self.field("id", id);
394        self
395    }
396
397    fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
398        let value = value.as_ref();
399        assert_eq!(
400            memchr::memchr2(b'\r', b'\n', value),
401            None,
402            "SSE field value cannot contain newlines or carriage returns",
403        );
404
405        let buffer = self.buffer.as_mut();
406        buffer.extend_from_slice(name.as_bytes());
407        buffer.put_u8(b':');
408        buffer.put_u8(b' ');
409        buffer.extend_from_slice(value);
410        buffer.put_u8(b'\n');
411    }
412
413    fn finalize(self) -> Bytes {
414        match self.buffer {
415            Buffer::Finalized(bytes) => bytes,
416            Buffer::Active(mut bytes_mut) => {
417                bytes_mut.put_u8(b'\n');
418                bytes_mut.freeze()
419            }
420        }
421    }
422}
423
424impl EventDataWriter {
425    /// Consume the [`EventDataWriter`] and return the [`Event`] once again.
426    ///
427    /// In case any data was written by this instance
428    /// it will also write the trailing `\n` character.
429    pub fn into_event(self) -> Event {
430        let mut event = self.event;
431        if self.data_written {
432            let _ = event.buffer.as_mut().write_char('\n');
433        }
434        event
435    }
436}
437
438impl EventDataWriter {
439    // Assumption: underlying writer never returns an error:
440    // <https://docs.rs/bytes/latest/src/bytes/buf/writer.rs.html#79-82>
441    fn write_buf(&mut self, buf: &[u8]) -> usize {
442        if buf.is_empty() {
443            return 0;
444        }
445
446        let buffer = self.event.buffer.as_mut();
447
448        if !std::mem::replace(&mut self.data_written, true) {
449            if self.event.flags.contains(EventFlags::HAS_DATA) {
450                panic!("Called `Event::data*` multiple times");
451            }
452
453            let _ = buffer.write_str("data: ");
454            self.event.flags.insert(EventFlags::HAS_DATA);
455        }
456
457        let mut writer = buffer.writer();
458
459        let mut last_split = 0;
460        for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
461            let _ = writer.write_all(&buf[last_split..=delimiter]);
462            let _ = writer.write_all(b"data: ");
463            last_split = delimiter + 1;
464        }
465        let _ = writer.write_all(&buf[last_split..]);
466
467        buf.len()
468    }
469}
470
471impl fmt::Write for EventDataWriter {
472    fn write_str(&mut self, s: &str) -> fmt::Result {
473        let _ = self.write_buf(s.as_bytes());
474        Ok(())
475    }
476}
477
478impl Default for Event {
479    fn default() -> Self {
480        Self {
481            buffer: Buffer::Active(BytesMut::new()),
482            flags: EventFlags::from_bits(0),
483        }
484    }
485}
486
/// Bit set recording which once-only [`Event`] fields have been written.
#[derive(Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self::from_bits(0x01);
    const HAS_EVENT: Self = Self::from_bits(0x02);
    const HAS_RETRY: Self = Self::from_bits(0x04);
    const HAS_ID: Self = Self::from_bits(0x08);

    /// Returns the raw bits.
    const fn bits(&self) -> u8 {
        self.0
    }

    /// Wraps raw bits as-is.
    const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Returns `true` if every bit set in `other` is also set in `self`.
    const fn contains(&self, other: Self) -> bool {
        let wanted = other.bits();
        (self.bits() & wanted) == wanted
    }

    /// Sets every bit of `other` in `self`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.bits();
    }
}
512
513/// Configure the interval between keep-alive messages, the content
514/// of each message, and the associated stream.
515#[derive(Debug, Clone)]
516#[must_use]
517pub struct KeepAlive {
518    event: Event,
519    max_interval: Duration,
520}
521
522impl KeepAlive {
523    /// Create a new `KeepAlive`.
524    pub fn new() -> Self {
525        Self {
526            event: Event::DEFAULT_KEEP_ALIVE,
527            max_interval: Duration::from_secs(15),
528        }
529    }
530
531    /// Customize the interval between keep-alive messages.
532    ///
533    /// Default is 15 seconds.
534    pub fn interval(mut self, time: Duration) -> Self {
535        self.max_interval = time;
536        self
537    }
538
539    /// Customize the text of the keep-alive message.
540    ///
541    /// Default is an empty comment.
542    ///
543    /// # Panics
544    ///
545    /// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE
546    /// comments.
547    pub fn text<I>(self, text: I) -> Self
548    where
549        I: AsRef<str>,
550    {
551        self.event(Event::default().comment(text))
552    }
553
554    /// Customize the event of the keep-alive message.
555    ///
556    /// Default is an empty comment.
557    ///
558    /// # Panics
559    ///
560    /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
561    /// comments.
562    pub fn event(mut self, event: Event) -> Self {
563        self.event = Event::finalized(event.finalize());
564        self
565    }
566}
567
568impl Default for KeepAlive {
569    fn default() -> Self {
570        Self::new()
571    }
572}
573
#[cfg(feature = "tokio")]
pin_project! {
    /// A stream adapter that inserts keep-alive events whenever the wrapped
    /// stream stays silent for longer than the configured interval.
    ///
    /// Created by [`Sse::keep_alive`].
    #[derive(Debug)]
    pub struct KeepAliveStream<S> {
        // Fires once the keep-alive interval elapses without an event from `inner`.
        #[pin]
        alive_timer: tokio::time::Sleep,
        #[pin]
        inner: S,
        keep_alive: KeepAlive,
    }
}
586
#[cfg(feature = "tokio")]
impl<S> KeepAliveStream<S> {
    /// Wraps `inner` and arms the timer for the first keep-alive message.
    fn new(keep_alive: KeepAlive, inner: S) -> Self {
        Self {
            alive_timer: tokio::time::sleep(keep_alive.max_interval),
            inner,
            keep_alive,
        }
    }

    /// Re-arms the timer so the next keep-alive fires `max_interval` from now.
    ///
    /// `Instant + Duration` panics on overflow. A huge interval (e.g.
    /// `Duration::MAX`) therefore saturates to a far-future deadline instead.
    /// This matches `tokio::time::sleep`, which [`Self::new`] relies on and which
    /// already saturates this way.
    fn reset(self: Pin<&mut Self>) {
        // Roughly 30 years; the same fallback `tokio::time::sleep` uses.
        const FAR_FUTURE: Duration = Duration::from_secs(86400 * 365 * 30);

        let this = self.project();
        let now = tokio::time::Instant::now();
        let deadline = now
            .checked_add(this.keep_alive.max_interval)
            .unwrap_or_else(|| now + FAR_FUTURE);
        this.alive_timer.reset(deadline);
    }
}
603
#[cfg(feature = "tokio")]
impl<S, E> Stream for KeepAliveStream<S>
where
    S: Stream<Item = Result<Event, E>>,
{
    type Item = Result<Event, E>;

    /// Yields the next item of the inner stream. If the inner stream stays
    /// pending for the whole keep-alive interval, yields a keep-alive event instead.
    ///
    /// The timer is re-armed after every yielded event (real or keep-alive),
    /// but not after an error or the end of the stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use std::future::Future;

        let mut this = self.as_mut().project();

        let event = match this.inner.as_mut().poll_next(cx) {
            Poll::Ready(Some(Ok(event))) => event,
            // Errors and the end of the stream pass through untouched.
            Poll::Ready(other) => return Poll::Ready(other),
            Poll::Pending => {
                ready!(this.alive_timer.poll(cx));
                this.keep_alive.event.clone()
            }
        };

        self.reset();
        Poll::Ready(Some(Ok(event)))
    }
}
636
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use serde_json::value::RawValue;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    #[test]
    fn leading_space_is_not_stripped() {
        fn assert_encodes(input: &str, expected: &[u8]) {
            assert_eq!(&*Event::default().data(input).finalize(), expected);
        }

        assert_encodes("\tfoobar", b"data: \tfoobar\n\n");
        assert_encodes(" foobar", b"data:  foobar\n\n");
    }

    #[test]
    fn write_data_writer_str() {
        // Writers that never write anything must leave the event untouched.
        let mut writer = Event::default()
            .into_data_writer()
            .into_event()
            .into_data_writer();
        writer.write_str("").unwrap();
        let mut writer = writer.into_event().into_data_writer();

        for chunk in ["", "moon ", "star\nsun", "", "set", "", " bye\r"] {
            writer.write_str(chunk).unwrap();
        }

        assert_eq!(
            &*writer.into_event().finalize(),
            b"data: moon star\ndata: sunset bye\rdata: \n\n"
        );
    }

    #[test]
    fn valid_json_raw_value_chars_handled() {
        let raw = "{\r\"foo\":  \n\r\r   \"bar\\n\"\n}";
        let value = serde_json::from_str::<&RawValue>(raw).unwrap();
        let event = Event::default().json_data(value).unwrap();
        assert_eq!(
            &*event.finalize(),
            b"data: {\rdata: \"foo\":  \ndata: \rdata: \rdata:    \"bar\\n\"\ndata: }\n\n"
        );
    }

    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let events = vec![
                    Event::default().data("one").comment("this is a comment"),
                    Event::default()
                        .json_data(serde_json::json!({ "foo": "bar" }))
                        .unwrap(),
                    Event::default()
                        .event("three")
                        .retry(Duration::from_secs(30))
                        .id("unique-id"),
                ];
                Sse::new(stream::iter(events).map(Ok::<_, Infallible>))
            }),
        );

        let client = TestClient::new(app);
        let mut response = client.get("/").await;

        assert_eq!(response.headers()["content-type"], "text/event-stream");
        assert_eq!(response.headers()["cache-control"], "no-cache");

        let first = parse_event(&response.chunk_text().await.unwrap());
        assert_eq!(first.get("data").unwrap(), "one");
        assert_eq!(first.get("comment").unwrap(), "this is a comment");

        let second = parse_event(&response.chunk_text().await.unwrap());
        assert_eq!(second.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(!second.contains_key("comment"));

        let third = parse_event(&response.chunk_text().await.unwrap());
        assert_eq!(third.get("event").unwrap(), "three");
        assert_eq!(third.get("retry").unwrap(), "30000");
        assert_eq!(third.get("id").unwrap(), "unique-id");
        assert!(!third.contains_key("comment"));

        assert!(response.chunk_text().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let events = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);
                let keep_alive = KeepAlive::new()
                    .interval(Duration::from_secs(1))
                    .text("keep-alive-text");

                Sse::new(events).keep_alive(keep_alive)
            }),
        );

        let client = TestClient::new(app);
        let mut response = client.get("/").await;

        for _ in 0..5 {
            // Each cycle starts with a real event ...
            let fields = parse_event(&response.chunk_text().await.unwrap());
            assert_eq!(fields.get("data").unwrap(), "msg");

            // ... followed by four seconds of keep-alive messages.
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let fields = parse_event(&response.chunk_text().await.unwrap());
                assert_eq!(fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let events = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);
                let keep_alive = KeepAlive::new()
                    .interval(Duration::from_secs(1))
                    .text("keep-alive-text");

                Sse::new(events).keep_alive(keep_alive)
            }),
        );

        let client = TestClient::new(app);
        let mut response = client.get("/").await;

        // The first real event ...
        let fields = parse_event(&response.chunk_text().await.unwrap());
        assert_eq!(fields.get("data").unwrap(), "msg");

        // ... then four seconds of keep-alive messages ...
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let fields = parse_event(&response.chunk_text().await.unwrap());
            assert_eq!(fields.get("comment").unwrap(), "keep-alive-text");
        }

        // ... then the second and last real event ...
        let fields = parse_event(&response.chunk_text().await.unwrap());
        assert_eq!(fields.get("data").unwrap(), "msg");

        // ... and nothing after it, not even keep-alive messages.
        assert!(response.chunk_text().await.is_none());
    }

    /// Parses one encoded event into a map from field name to trimmed value.
    /// Comment lines (empty field name) are stored under the `"comment"` key.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut lines = payload.lines();

        let fields = lines
            .by_ref()
            .take_while(|line| !line.is_empty())
            .map(|line| {
                let (key, value) = line.split_once(':').unwrap();
                let key = if key.is_empty() { "comment" } else { key };
                (key.to_owned(), value.trim().to_owned())
            })
            .collect::<HashMap<_, _>>();

        // Nothing may follow the blank line that terminates the event.
        assert!(lines.next().is_none());

        fields
    }
}
834}