sse_stream/
body.rs

1use std::{
2    pin::Pin,
3    task::{ready, Context, Poll},
4    time::Duration,
5};
6
7use crate::Sse;
8use bytes::Bytes;
9use futures_util::Stream;
10use http_body::{Body, Frame};
11use std::future::Future;
12pin_project_lite::pin_project! {
13    pub struct SseBody<S, T = NeverTimer> {
14        #[pin]
15        pub event_stream: S,
16        #[pin]
17        pub keep_alive: Option<KeepAliveStream<T>>,
18    }
19}
20
21impl<S, E> SseBody<S, NeverTimer>
22where
23    S: Stream<Item = Result<Sse, E>>,
24{
25    pub fn new(stream: S) -> Self {
26        Self {
27            event_stream: stream,
28            keep_alive: None,
29        }
30    }
31}
32
33impl<S, E, T> SseBody<S, T>
34where
35    S: Stream<Item = Result<Sse, E>>,
36    T: Timer,
37{
38    pub fn new_keep_alive(stream: S, keep_alive: KeepAlive) -> Self {
39        Self {
40            event_stream: stream,
41            keep_alive: Some(KeepAliveStream::new(keep_alive)),
42        }
43    }
44
45    pub fn with_keep_alive<T2: Timer>(self, keep_alive: KeepAlive) -> SseBody<S, T2> {
46        SseBody {
47            event_stream: self.event_stream,
48            keep_alive: Some(KeepAliveStream::new(keep_alive)),
49        }
50    }
51}
52
53impl<S, E, T> Body for SseBody<S, T>
54where
55    S: Stream<Item = Result<Sse, E>>,
56    T: Timer,
57{
58    type Data = Bytes;
59    type Error = E;
60
61    fn poll_frame(
62        self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
65        let this = self.project();
66
67        match this.event_stream.poll_next(cx) {
68            Poll::Pending => {
69                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
70                    keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
71                } else {
72                    Poll::Pending
73                }
74            }
75            Poll::Ready(Some(Ok(event))) => {
76                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
77                    keep_alive.reset();
78                }
79                Poll::Ready(Some(Ok(Frame::data(event.into()))))
80            }
81            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
82            Poll::Ready(None) => Poll::Ready(None),
83        }
84    }
85}
86
87/// Configure the interval between keep-alive messages, the content
88/// of each message, and the associated stream.
89#[derive(Debug, Clone)]
90#[must_use]
91pub struct KeepAlive {
92    event: Bytes,
93    max_interval: Duration,
94}
95
96impl KeepAlive {
97    /// Create a new `KeepAlive`.
98    pub fn new() -> Self {
99        Self {
100            event: Bytes::from_static(b":\n\n"),
101            max_interval: Duration::from_secs(15),
102        }
103    }
104
105    /// Customize the interval between keep-alive messages.
106    ///
107    /// Default is 15 seconds.
108    pub fn interval(mut self, time: Duration) -> Self {
109        self.max_interval = time;
110        self
111    }
112
113    /// Customize the event of the keep-alive message.
114    ///
115    /// Default is an empty comment.
116    ///
117    /// # Panics
118    ///
119    /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
120    /// comments.
121    pub fn event(mut self, event: Sse) -> Self {
122        self.event = event.into();
123        self
124    }
125
126    /// Customize the event of the keep-alive message with a comment
127    pub fn comment(mut self, comment: &str) -> Self {
128        self.event = format!(": {}\n\n", comment).into();
129        self
130    }
131}
132
133impl Default for KeepAlive {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
/// A resettable timer used to schedule keep-alive messages.
///
/// The timer is a future that completes (with `()`) when it fires.
pub trait Timer: Future<Output = ()> {
    /// Moves the timer's deadline to `instant`.
    fn reset(self: Pin<&mut Self>, instant: std::time::Instant);

    /// Creates a timer that is expected to fire once `duration` has elapsed.
    fn from_duration(duration: Duration) -> Self;
}

/// A [`Timer`] that never fires: polling it always returns `Poll::Pending`.
///
/// This is the default timer type parameter of `SseBody`. With it, no
/// keep-alive message is ever emitted.
#[derive(Debug, Clone, Copy, Default)]
pub struct NeverTimer;

impl Future for NeverTimer {
    type Output = ();

    fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
        // The context is ignored on purpose. This future never completes, so it
        // never needs to wake the task.
        Poll::Pending
    }
}

impl Timer for NeverTimer {
    fn from_duration(_: Duration) -> Self {
        Self
    }

    fn reset(self: Pin<&mut Self>, _: std::time::Instant) {
        // Nothing to re-arm: this timer never fires.
    }
}
163
164pin_project_lite::pin_project! {
165    #[derive(Debug)]
166    struct KeepAliveStream<S> {
167        keep_alive: KeepAlive,
168        #[pin]
169        alive_timer: S,
170    }
171}
172
173impl<S> KeepAliveStream<S>
174where
175    S: Timer,
176{
177    fn new(keep_alive: KeepAlive) -> Self {
178        Self {
179            alive_timer: S::from_duration(keep_alive.max_interval),
180            keep_alive,
181        }
182    }
183
184    fn reset(self: Pin<&mut Self>) {
185        let this = self.project();
186        this.alive_timer
187            .reset(std::time::Instant::now() + this.keep_alive.max_interval);
188    }
189
190    fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
191        let this = self.as_mut().project();
192
193        ready!(this.alive_timer.poll(cx));
194
195        let event = this.keep_alive.event.clone();
196
197        self.reset();
198
199        Poll::Ready(event)
200    }
201}