puzz_sse/
lib.rs

1#![forbid(unsafe_code)]
2
3use std::borrow::Cow;
4use std::fmt::{self, Write};
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use futures_core::{ready, Stream};
11use pin_project_lite::pin_project;
12use puzz_core::body::{Body, BodyExt, Bytes};
13use puzz_core::http::header;
14use puzz_core::response::IntoResponse;
15use puzz_core::{BoxError, Response};
16
17pub struct Sse<S> {
18    stream: S,
19    keep_alive: Option<KeepAlive>,
20}
21
22impl<S> Sse<S> {
23    pub fn new(stream: S) -> Self {
24        Self {
25            stream,
26            keep_alive: None,
27        }
28    }
29
30    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
31        self.keep_alive = Some(keep_alive);
32        self
33    }
34}
35
36impl<S> fmt::Debug for Sse<S> {
37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38        f.debug_struct("Sse")
39            .field("stream", &format_args!("{}", std::any::type_name::<S>()))
40            .field("keep_alive", &self.keep_alive)
41            .finish()
42    }
43}
44
45impl<S, E> IntoResponse for Sse<S>
46where
47    S: Stream<Item = Result<Event, E>> + Send + 'static,
48    E: Into<BoxError>,
49{
50    fn into_response(self) -> Response {
51        let body = SseBody {
52            event_stream: self.stream,
53            keep_alive: self.keep_alive.map(KeepAliveStream::new),
54        };
55
56        Response::builder()
57            .header(header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref())
58            .header(header::CACHE_CONTROL, "no-cache")
59            .body(body.boxed())
60            .unwrap()
61    }
62}
63
64pin_project! {
65    struct SseBody<S> {
66        #[pin]
67        event_stream: S,
68        #[pin]
69        keep_alive: Option<KeepAliveStream>,
70    }
71}
72
73impl<S, E> Body for SseBody<S>
74where
75    S: Stream<Item = Result<Event, E>>,
76{
77    type Error = E;
78
79    fn poll_next(
80        self: Pin<&mut Self>,
81        cx: &mut Context<'_>,
82    ) -> Poll<Option<Result<Bytes, Self::Error>>> {
83        let this = self.project();
84
85        match this.event_stream.poll_next(cx) {
86            Poll::Pending => {
87                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
88                    keep_alive
89                        .poll_event(cx)
90                        .map(|e| Some(Ok(Bytes::from(e.to_string()))))
91                } else {
92                    Poll::Pending
93                }
94            }
95            Poll::Ready(Some(Ok(event))) => {
96                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
97                    keep_alive.reset();
98                }
99                Poll::Ready(Some(Ok(Bytes::from(event.to_string()))))
100            }
101            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
102            Poll::Ready(None) => Poll::Ready(None),
103        }
104    }
105}
106
/// A single Server-Sent Event, assembled with the builder methods on `Event`.
///
/// Every field is optional; fields left as `None` are omitted when the event is rendered.
#[derive(Default, Debug)]
pub struct Event {
    /// Value written on the `id:` line.
    id: Option<String>,
    /// Payload written on the `data:` line(s).
    data: Option<DataType>,
    /// Value written on the `event:` line.
    event: Option<String>,
    /// Text written on the `:` comment line.
    comment: Option<String>,
    /// Value written on the `retry:` line, rendered in whole milliseconds.
    retry: Option<Duration>,
}

/// The two forms in which an event's data payload can be stored.
#[derive(Debug)]
enum DataType {
    /// Plain text; each `\n`-separated segment becomes its own `data:` line.
    Text(String),
    /// JSON text produced by `serde_json::to_string`.
    Json(String),
}
122
123impl Event {
124    pub fn data<T>(mut self, data: T) -> Event
125    where
126        T: Into<String>,
127    {
128        let data = data.into();
129        assert_eq!(
130            memchr::memchr(b'\r', data.as_bytes()),
131            None,
132            "SSE data cannot contain carriage returns",
133        );
134        self.data = Some(DataType::Text(data));
135        self
136    }
137
138    pub fn json_data<T>(mut self, data: T) -> serde_json::Result<Event>
139    where
140        T: serde::Serialize,
141    {
142        self.data = Some(DataType::Json(serde_json::to_string(&data)?));
143        Ok(self)
144    }
145
146    pub fn comment<T>(mut self, comment: T) -> Event
147    where
148        T: Into<String>,
149    {
150        let comment = comment.into();
151        assert_eq!(
152            memchr::memchr2(b'\r', b'\n', comment.as_bytes()),
153            None,
154            "SSE comment cannot contain newlines or carriage returns"
155        );
156        self.comment = Some(comment);
157        self
158    }
159
160    pub fn event<T>(mut self, event: T) -> Event
161    where
162        T: Into<String>,
163    {
164        let event = event.into();
165        assert_eq!(
166            memchr::memchr2(b'\r', b'\n', event.as_bytes()),
167            None,
168            "SSE event name cannot contain newlines or carriage returns"
169        );
170        self.event = Some(event);
171        self
172    }
173
174    pub fn retry(mut self, duration: Duration) -> Event {
175        self.retry = Some(duration);
176        self
177    }
178
179    pub fn id<T>(mut self, id: T) -> Event
180    where
181        T: Into<String>,
182    {
183        let id = id.into();
184        assert_eq!(
185            memchr::memchr3(b'\r', b'\n', b'\0', id.as_bytes()),
186            None,
187            "Event ID cannot contain newlines, carriage returns or null characters",
188        );
189        self.id = Some(id);
190        self
191    }
192}
193
194impl fmt::Display for Event {
195    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
196        if let Some(comment) = &self.comment {
197            ":".fmt(f)?;
198            comment.fmt(f)?;
199            f.write_char('\n')?;
200        }
201
202        if let Some(event) = &self.event {
203            "event: ".fmt(f)?;
204            event.fmt(f)?;
205            f.write_char('\n')?;
206        }
207
208        match &self.data {
209            Some(DataType::Text(data)) => {
210                for line in data.split('\n') {
211                    "data: ".fmt(f)?;
212                    line.fmt(f)?;
213                    f.write_char('\n')?;
214                }
215            }
216
217            Some(DataType::Json(data)) => {
218                "data:".fmt(f)?;
219                data.fmt(f)?;
220                f.write_char('\n')?;
221            }
222            None => {}
223        }
224
225        if let Some(id) = &self.id {
226            "id: ".fmt(f)?;
227            id.fmt(f)?;
228            f.write_char('\n')?;
229        }
230
231        if let Some(duration) = &self.retry {
232            "retry:".fmt(f)?;
233
234            let secs = duration.as_secs();
235            let millis = duration.subsec_millis();
236
237            if secs > 0 {
238                // format seconds
239                secs.fmt(f)?;
240
241                // pad milliseconds
242                if millis < 10 {
243                    f.write_str("00")?;
244                } else if millis < 100 {
245                    f.write_char('0')?;
246                }
247            }
248
249            // format milliseconds
250            millis.fmt(f)?;
251
252            f.write_char('\n')?;
253        }
254
255        f.write_char('\n')?;
256
257        Ok(())
258    }
259}
260
/// Configuration for keep-alive messages emitted while an SSE stream is idle.
///
/// Each keep-alive is written as an SSE comment line (`:<text>`).
#[derive(Debug, Clone)]
pub struct KeepAlive {
    /// Text placed after the `:` of each keep-alive comment line.
    comment_text: Cow<'static, str>,
    /// Longest time the stream may stay silent before a keep-alive is emitted.
    max_interval: Duration,
}

impl KeepAlive {
    /// Creates a configuration with an empty comment and a 15-second interval.
    pub fn new() -> Self {
        Self {
            comment_text: Cow::Borrowed(""),
            max_interval: Duration::from_secs(15),
        }
    }

    /// Sets the maximum idle time before a keep-alive comment is sent.
    pub fn interval(mut self, time: Duration) -> Self {
        self.max_interval = time;
        self
    }

    /// Sets the text carried by each keep-alive comment.
    ///
    /// # Panics
    ///
    /// Panics if `text` contains a newline (`\n`) or carriage return (`\r`).
    pub fn text<I>(mut self, text: I) -> Self
    where
        I: Into<Cow<'static, str>>,
    {
        let text = text.into();
        // Validate eagerly. The text is turned into an SSE comment on every keep-alive, and an
        // invalid comment would otherwise panic later, while the response body is polled.
        assert!(
            !text.contains(|c: char| c == '\r' || c == '\n'),
            "SSE comment cannot contain newlines or carriage returns"
        );
        self.comment_text = text;
        self
    }
}

impl Default for KeepAlive {
    /// Equivalent to [`KeepAlive::new`].
    fn default() -> Self {
        Self::new()
    }
}
294
295pin_project! {
296    #[derive(Debug)]
297    pub(crate) struct KeepAliveStream {
298        keep_alive: KeepAlive,
299        #[pin]
300        alive_timer: tokio::time::Sleep,
301    }
302}
303
304impl KeepAliveStream {
305    pub(crate) fn new(keep_alive: KeepAlive) -> Self {
306        Self {
307            alive_timer: tokio::time::sleep(keep_alive.max_interval),
308            keep_alive,
309        }
310    }
311
312    pub(crate) fn reset(self: Pin<&mut Self>) {
313        let this = self.project();
314        this.alive_timer
315            .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
316    }
317
318    pub(crate) fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Event> {
319        let this = self.as_mut().project();
320
321        ready!(this.alive_timer.poll(cx));
322
323        let comment_str = this.keep_alive.comment_text.clone();
324        let event = Event::default().comment(comment_str);
325
326        self.reset();
327
328        Poll::Ready(event)
329    }
330}