1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
use bytes::Bytes;
use futures_util::{stream::BoxStream, Stream, StreamExt};
use tokio::time::Duration;

use super::Event;
use crate::{Body, IntoResponse, Response};

/// A server-sent events (`text/event-stream`) response.
///
/// Holds a stream of [`Event`]s to write to the response body and, optionally,
/// an interval at which an SSE comment line is emitted between events. Create
/// one with [`SSE::new`] and configure the interval with [`SSE::keep_alive`].
///
/// # Example
///
/// ```
/// use futures_util::stream;
/// use poem::{
///     handler,
///     http::StatusCode,
///     web::sse::{Event, SSE},
///     Endpoint, Request,
/// };
///
/// #[handler]
/// fn index() -> SSE {
///     SSE::new(stream::iter(vec![
///         Event::message("a"),
///         Event::message("b"),
///         Event::message("c"),
///     ]))
/// }
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let mut resp = index.call(Request::default()).await;
/// assert_eq!(resp.status(), StatusCode::OK);
/// assert_eq!(
///     resp.take_body().into_string().await.unwrap(),
///     "data: a\n\ndata: b\n\ndata: c\n\n"
/// );
/// # });
/// ```
pub struct SSE {
    /// Interval between keep-alive comments; `None` means none are sent.
    keep_alive: Option<Duration>,
    /// The events to send, boxed so the concrete stream type is erased.
    stream: BoxStream<'static, Event>,
}

impl SSE {
    /// Create an SSE response using an event stream.
    pub fn new(stream: impl Stream<Item = Event> + Send + 'static) -> Self {
        Self {
            stream: stream.boxed(),
            keep_alive: None,
        }
    }

    /// Set the keep alive interval.
    #[must_use]
    pub fn keep_alive(self, duration: Duration) -> Self {
        Self {
            keep_alive: Some(duration),
            ..self
        }
    }
}

impl IntoResponse for SSE {
    fn into_response(self) -> Response {
        let mut stream = self
            .stream
            .map(|event| Ok::<_, std::io::Error>(Bytes::from(event.to_string())))
            .boxed();
        if let Some(duration) = self.keep_alive {
            let comment = Bytes::from_static(b":\n\n");
            stream = futures_util::stream::select(
                stream,
                tokio_stream::wrappers::IntervalStream::new(tokio::time::interval_at(
                    tokio::time::Instant::now() + duration,
                    duration,
                ))
                .map(move |_| Ok(comment.clone())),
            )
            .boxed();
        }

        Response::builder()
            .content_type("text/event-stream")
            .body(Body::from_async_read(tokio_util::io::StreamReader::new(
                stream,
            )))
    }
}