1use crate::{
29 body::{Bytes, HttpBody},
30 BoxError,
31};
32use axum_core::{
33 body::Body,
34 response::{IntoResponse, Response},
35};
36use bytes::{BufMut, BytesMut};
37use futures_util::stream::{Stream, TryStream};
38use http_body::Frame;
39use pin_project_lite::pin_project;
40use std::{
41 fmt::{self, Write as _},
42 io::Write as _,
43 mem,
44 pin::Pin,
45 task::{ready, Context, Poll},
46 time::Duration,
47};
48use sync_wrapper::SyncWrapper;
49
50#[derive(Clone)]
52#[must_use]
53pub struct Sse<S> {
54 stream: S,
55}
56
57impl<S> Sse<S> {
58 pub fn new(stream: S) -> Self
63 where
64 S: TryStream<Ok = Event> + Send + 'static,
65 S::Error: Into<BoxError>,
66 {
67 Sse { stream }
68 }
69
70 #[cfg(feature = "tokio")]
74 pub fn keep_alive(self, keep_alive: KeepAlive) -> Sse<KeepAliveStream<S>> {
75 Sse {
76 stream: KeepAliveStream::new(keep_alive, self.stream),
77 }
78 }
79}
80
81impl<S> fmt::Debug for Sse<S> {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 f.debug_struct("Sse")
84 .field("stream", &format_args!("{}", std::any::type_name::<S>()))
85 .finish()
86 }
87}
88
89impl<S, E> IntoResponse for Sse<S>
90where
91 S: Stream<Item = Result<Event, E>> + Send + 'static,
92 E: Into<BoxError>,
93{
94 fn into_response(self) -> Response {
95 (
96 [
97 (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
98 (http::header::CACHE_CONTROL, "no-cache"),
99 ],
100 Body::new(SseBody {
101 event_stream: SyncWrapper::new(self.stream),
102 }),
103 )
104 .into_response()
105 }
106}
107
108pin_project! {
109 struct SseBody<S> {
110 #[pin]
111 event_stream: SyncWrapper<S>,
112 }
113}
114
115impl<S, E> HttpBody for SseBody<S>
116where
117 S: Stream<Item = Result<Event, E>>,
118{
119 type Data = Bytes;
120 type Error = E;
121
122 fn poll_frame(
123 self: Pin<&mut Self>,
124 cx: &mut Context<'_>,
125 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
126 let this = self.project();
127
128 match ready!(this.event_stream.get_pin_mut().poll_next(cx)) {
129 Some(Ok(event)) => Poll::Ready(Some(Ok(Frame::data(event.finalize())))),
130 Some(Err(error)) => Poll::Ready(Some(Err(error))),
131 None => Poll::Ready(None),
132 }
133 }
134}
135
136#[derive(Debug, Clone)]
146enum Buffer {
147 Active(BytesMut),
148 Finalized(Bytes),
149}
150
151impl Buffer {
152 fn as_mut(&mut self) -> &mut BytesMut {
157 match self {
158 Buffer::Active(bytes_mut) => bytes_mut,
159 Buffer::Finalized(bytes) => {
160 *self = Buffer::Active(BytesMut::from(mem::take(bytes)));
161 match self {
162 Buffer::Active(bytes_mut) => bytes_mut,
163 Buffer::Finalized(_) => unreachable!(),
164 }
165 }
166 }
167 }
168}
169
170#[derive(Debug, Clone)]
172#[must_use]
173pub struct Event {
174 buffer: Buffer,
175 flags: EventFlags,
176}
177
178#[derive(Debug)]
189#[must_use]
190pub struct EventDataWriter {
191 event: Event,
192
193 data_written: bool,
197}
198
199impl Event {
200 pub const DEFAULT_KEEP_ALIVE: Self = Self::finalized(Bytes::from_static(b":\n\n"));
202
203 const fn finalized(bytes: Bytes) -> Self {
204 Self {
205 buffer: Buffer::Finalized(bytes),
206 flags: EventFlags::from_bits(0),
207 }
208 }
209
210 pub fn into_data_writer(self) -> EventDataWriter {
217 EventDataWriter {
218 event: self,
219 data_written: false,
220 }
221 }
222
223 pub fn data<T>(self, data: T) -> Self
237 where
238 T: AsRef<str>,
239 {
240 let mut writer = self.into_data_writer();
241 let _ = writer.write_str(data.as_ref());
242 writer.into_event()
243 }
244
245 #[cfg(feature = "json")]
255 pub fn json_data<T>(self, data: T) -> Result<Self, axum_core::Error>
256 where
257 T: serde_core::Serialize,
258 {
259 struct JsonWriter<'a>(&'a mut EventDataWriter);
260 impl std::io::Write for JsonWriter<'_> {
261 #[inline]
262 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
263 Ok(self.0.write_buf(buf))
264 }
265 fn flush(&mut self) -> std::io::Result<()> {
266 Ok(())
267 }
268 }
269
270 let mut writer = self.into_data_writer();
271
272 let json_writer = JsonWriter(&mut writer);
273 serde_json::to_writer(json_writer, &data).map_err(axum_core::Error::new)?;
274
275 Ok(writer.into_event())
276 }
277
278 pub fn comment<T>(mut self, comment: T) -> Event
289 where
290 T: AsRef<str>,
291 {
292 self.field("", comment.as_ref());
293 self
294 }
295
296 pub fn event<T>(mut self, event: T) -> Event
311 where
312 T: AsRef<str>,
313 {
314 if self.flags.contains(EventFlags::HAS_EVENT) {
315 panic!("Called `Event::event` multiple times");
316 }
317 self.flags.insert(EventFlags::HAS_EVENT);
318
319 self.field("event", event.as_ref());
320
321 self
322 }
323
324 pub fn retry(mut self, duration: Duration) -> Event {
334 if self.flags.contains(EventFlags::HAS_RETRY) {
335 panic!("Called `Event::retry` multiple times");
336 }
337 self.flags.insert(EventFlags::HAS_RETRY);
338
339 let buffer = self.buffer.as_mut();
340 buffer.extend_from_slice(b"retry: ");
341
342 let secs = duration.as_secs();
343 let millis = duration.subsec_millis();
344
345 if secs > 0 {
346 buffer.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
348
349 if millis < 10 {
351 buffer.extend_from_slice(b"00");
352 } else if millis < 100 {
353 buffer.extend_from_slice(b"0");
354 }
355 }
356
357 buffer.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
359
360 buffer.put_u8(b'\n');
361
362 self
363 }
364
365 pub fn id<T>(mut self, id: T) -> Event
378 where
379 T: AsRef<str>,
380 {
381 if self.flags.contains(EventFlags::HAS_ID) {
382 panic!("Called `Event::id` multiple times");
383 }
384 self.flags.insert(EventFlags::HAS_ID);
385
386 let id = id.as_ref().as_bytes();
387 assert_eq!(
388 memchr::memchr(b'\0', id),
389 None,
390 "Event ID cannot contain null characters",
391 );
392
393 self.field("id", id);
394 self
395 }
396
397 fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
398 let value = value.as_ref();
399 assert_eq!(
400 memchr::memchr2(b'\r', b'\n', value),
401 None,
402 "SSE field value cannot contain newlines or carriage returns",
403 );
404
405 let buffer = self.buffer.as_mut();
406 buffer.extend_from_slice(name.as_bytes());
407 buffer.put_u8(b':');
408 buffer.put_u8(b' ');
409 buffer.extend_from_slice(value);
410 buffer.put_u8(b'\n');
411 }
412
413 fn finalize(self) -> Bytes {
414 match self.buffer {
415 Buffer::Finalized(bytes) => bytes,
416 Buffer::Active(mut bytes_mut) => {
417 bytes_mut.put_u8(b'\n');
418 bytes_mut.freeze()
419 }
420 }
421 }
422}
423
424impl EventDataWriter {
425 pub fn into_event(self) -> Event {
430 let mut event = self.event;
431 if self.data_written {
432 let _ = event.buffer.as_mut().write_char('\n');
433 }
434 event
435 }
436}
437
438impl EventDataWriter {
439 fn write_buf(&mut self, buf: &[u8]) -> usize {
442 if buf.is_empty() {
443 return 0;
444 }
445
446 let buffer = self.event.buffer.as_mut();
447
448 if !std::mem::replace(&mut self.data_written, true) {
449 if self.event.flags.contains(EventFlags::HAS_DATA) {
450 panic!("Called `Event::data*` multiple times");
451 }
452
453 let _ = buffer.write_str("data: ");
454 self.event.flags.insert(EventFlags::HAS_DATA);
455 }
456
457 let mut writer = buffer.writer();
458
459 let mut last_split = 0;
460 for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
461 let _ = writer.write_all(&buf[last_split..=delimiter]);
462 let _ = writer.write_all(b"data: ");
463 last_split = delimiter + 1;
464 }
465 let _ = writer.write_all(&buf[last_split..]);
466
467 buf.len()
468 }
469}
470
471impl fmt::Write for EventDataWriter {
472 fn write_str(&mut self, s: &str) -> fmt::Result {
473 let _ = self.write_buf(s.as_bytes());
474 Ok(())
475 }
476}
477
478impl Default for Event {
479 fn default() -> Self {
480 Self {
481 buffer: Buffer::Active(BytesMut::new()),
482 flags: EventFlags::from_bits(0),
483 }
484 }
485}
486
/// Bit set recording which once-only fields an event already contains.
#[derive(Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self::from_bits(1 << 0);
    const HAS_EVENT: Self = Self::from_bits(1 << 1);
    const HAS_RETRY: Self = Self::from_bits(1 << 2);
    const HAS_ID: Self = Self::from_bits(1 << 3);

    /// The raw bit representation.
    const fn bits(&self) -> u8 {
        let Self(raw) = *self;
        raw
    }

    /// Build a flag set from raw bits.
    const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Whether every bit set in `other` is also set in `self`.
    const fn contains(&self, other: Self) -> bool {
        let wanted = other.bits();
        (self.bits() & wanted) == wanted
    }

    /// Set every bit of `other` in `self`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.bits();
    }
}
512
513#[derive(Debug, Clone)]
516#[must_use]
517pub struct KeepAlive {
518 event: Event,
519 max_interval: Duration,
520}
521
522impl KeepAlive {
523 pub fn new() -> Self {
525 Self {
526 event: Event::DEFAULT_KEEP_ALIVE,
527 max_interval: Duration::from_secs(15),
528 }
529 }
530
531 pub fn interval(mut self, time: Duration) -> Self {
535 self.max_interval = time;
536 self
537 }
538
539 pub fn text<I>(self, text: I) -> Self
548 where
549 I: AsRef<str>,
550 {
551 self.event(Event::default().comment(text))
552 }
553
554 pub fn event(mut self, event: Event) -> Self {
563 self.event = Event::finalized(event.finalize());
564 self
565 }
566}
567
568impl Default for KeepAlive {
569 fn default() -> Self {
570 Self::new()
571 }
572}
573
#[cfg(feature = "tokio")]
pin_project! {
    /// Stream adapter returned by `Sse::keep_alive`: forwards the inner
    /// stream's items and emits the configured keep-alive event whenever the
    /// timer fires before the inner stream yields.
    #[derive(Debug)]
    pub struct KeepAliveStream<S> {
        #[pin]
        alive_timer: tokio::time::Sleep,
        #[pin]
        inner: S,
        keep_alive: KeepAlive,
    }
}

#[cfg(feature = "tokio")]
impl<S> KeepAliveStream<S> {
    /// Wrap `inner`, arming the timer for one full interval from now.
    fn new(keep_alive: KeepAlive, inner: S) -> Self {
        let alive_timer = tokio::time::sleep(keep_alive.max_interval);
        Self {
            alive_timer,
            inner,
            keep_alive,
        }
    }

    /// Re-arm the timer so it next fires one full interval from now.
    fn reset(self: Pin<&mut Self>) {
        let projected = self.project();
        let deadline = tokio::time::Instant::now() + projected.keep_alive.max_interval;
        projected.alive_timer.reset(deadline);
    }
}
603
#[cfg(feature = "tokio")]
impl<S, E> Stream for KeepAliveStream<S>
where
    S: Stream<Item = Result<Event, E>>,
{
    type Item = Result<Event, E>;

    /// Poll the inner stream first; if it is not ready, fall back to the
    /// keep-alive timer.
    ///
    /// Any successfully yielded event, real or keep-alive, re-arms the timer.
    /// Errors and end-of-stream are forwarded without touching the timer.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use std::future::Future;

        let mut projected = self.as_mut().project();

        let next_event = match projected.inner.as_mut().poll_next(cx) {
            Poll::Ready(Some(Ok(event))) => event,
            Poll::Ready(other) => return Poll::Ready(other),
            Poll::Pending => {
                // Nothing from the inner stream: wait for the idle timer.
                ready!(projected.alive_timer.poll(cx));
                projected.keep_alive.event.clone()
            }
        };

        self.reset();
        Poll::Ready(Some(Ok(next_event)))
    }
}
636
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use serde_json::value::RawValue;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    #[test]
    fn leading_space_is_not_stripped() {
        let no_leading_space = Event::default().data("\tfoobar");
        assert_eq!(&*no_leading_space.finalize(), b"data: \tfoobar\n\n");

        // The caller's leading space is kept after the `data: ` prefix, so
        // the encoded line has two spaces after the colon.
        let leading_space = Event::default().data(" foobar");
        assert_eq!(&*leading_space.finalize(), b"data:  foobar\n\n");
    }

    #[test]
    fn write_data_writer_str() {
        // Round-tripping through writers that write nothing (or only empty
        // strings) must not leave a stray `data` line behind.
        let mut writer = Event::default()
            .into_data_writer()
            .into_event()
            .into_data_writer();
        writer.write_str("").unwrap();
        let mut writer = writer.into_event().into_data_writer();

        writer.write_str("").unwrap();
        writer.write_str("moon ").unwrap();
        writer.write_str("star\nsun").unwrap();
        writer.write_str("").unwrap();
        writer.write_str("set").unwrap();
        writer.write_str("").unwrap();
        writer.write_str(" bye\r").unwrap();

        let event = writer.into_event();

        assert_eq!(
            &*event.finalize(),
            b"data: moon star\ndata: sunset bye\rdata: \n\n"
        );
    }

    #[test]
    fn valid_json_raw_value_chars_handled() {
        let json_string = "{\r\"foo\": \n\r\r \"bar\\n\"\n}";
        let json_raw_value_event = Event::default()
            .json_data(serde_json::from_str::<&RawValue>(json_string).unwrap())
            .unwrap();
        // Every raw `\r`/`\n` starts a new `data: ` line; the space before
        // `"bar"` is preserved, giving two spaces after that line's colon.
        assert_eq!(
            &*json_raw_value_event.finalize(),
            b"data: {\rdata: \"foo\": \ndata: \rdata: \rdata:  \"bar\\n\"\ndata: }\n\n"
        );
    }

    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::iter(vec![
                    Event::default().data("one").comment("this is a comment"),
                    Event::default()
                        .json_data(serde_json::json!({ "foo": "bar" }))
                        .unwrap(),
                    Event::default()
                        .event("three")
                        .retry(Duration::from_secs(30))
                        .id("unique-id"),
                ])
                .map(Ok::<_, Infallible>);
                Sse::new(stream)
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        assert_eq!(stream.headers()["content-type"], "text/event-stream");
        assert_eq!(stream.headers()["cache-control"], "no-cache");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "one");
        assert_eq!(event_fields.get("comment").unwrap(), "this is a comment");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(!event_fields.contains_key("comment"));

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("event").unwrap(), "three");
        assert_eq!(event_fields.get("retry").unwrap(), "30000");
        assert_eq!(event_fields.get("id").unwrap(), "unique-id");
        assert!(!event_fields.contains_key("comment"));

        assert!(stream.chunk_text().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        for _ in 0..5 {
            // A real event arrives first...
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("data").unwrap(), "msg");

            // ...followed by one keep-alive per idle second until the next one.
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let event_fields = parse_event(&stream.chunk_text().await.unwrap());
                assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        // First real event.
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // Keep-alives while waiting for the second event.
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
        }

        // Second (last) real event.
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // Once the inner stream ends, no more keep-alives are produced.
        assert!(stream.chunk_text().await.is_none());
    }

    /// Parse one encoded event into a `field -> value` map (values trimmed).
    /// Comment lines (empty field name) are stored under `"comment"`.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();

        let mut lines = payload.lines().peekable();
        while let Some(line) = lines.next() {
            if line.is_empty() {
                // The blank line ends the event; nothing may follow it.
                assert!(lines.next().is_none());
                break;
            }

            let (mut key, value) = line.split_once(':').unwrap();
            let value = value.trim();
            if key.is_empty() {
                key = "comment";
            }
            fields.insert(key.to_owned(), value.to_owned());
        }

        fields
    }
}