Skip to main content

picoserve/response/
sse.rs

//! Server-Sent Events. See the [server_sent_events](https://github.com/sammhicks/picoserve/blob/main/examples/server_sent_events/src/main.rs) example for usage.
2
3use crate::io::{Read, Write, WriteExt};
4
5use super::StatusCode;
6
7/// Types which can be used as the data of an event.
8pub trait EventData {
9    /// Write event data to the socket.
10    async fn write_to<W: Write>(self, writer: &mut W) -> Result<(), W::Error>;
11}
12
13impl EventData for core::fmt::Arguments<'_> {
14    async fn write_to<W: Write>(self, writer: &mut W) -> Result<(), W::Error> {
15        writer.write_fmt(self).await
16    }
17}
18
19impl EventData for &str {
20    async fn write_to<W: Write>(self, writer: &mut W) -> Result<(), W::Error> {
21        writer.write_all(self.as_bytes()).await
22    }
23}
24
/// JSON payloads delegate to `Json::do_write_to`.
#[cfg(feature = "json")]
impl<T> EventData for super::json::Json<T>
where
    T: serde::Serialize,
{
    async fn write_to<W>(self, writer: &mut W) -> Result<(), W::Error>
    where
        W: Write,
    {
        self.do_write_to(writer).await
    }
}
31
/// Flags shared between an `EventWriter` and `write_events_until_shutdown`, used so that a
/// shutdown never ends the stream part-way through an event.
struct EventWriterState {
    /// Cleared once the connection's shutdown signal has fired.
    is_running: core::cell::Cell<bool>,
    /// Set for the duration of each event or keep-alive write.
    is_currently_writing_event: core::cell::Cell<bool>,
}

impl EventWriterState {
    /// State for a freshly opened stream: still running, with nothing being written yet.
    fn new() -> Self {
        Self {
            is_running: core::cell::Cell::new(true),
            is_currently_writing_event: core::cell::Cell::new(false),
        }
    }
}
45
46/// Writing events to an [`EventWriter`] will send the events to the client.
47pub struct EventWriter<'a, W: Write> {
48    writer: W,
49    event_writer_state: &'a EventWriterState,
50}
51
52impl<W: Write> EventWriter<'_, W> {
53    async fn do_write<F: core::future::Future>(
54        event_writer_state: &EventWriterState,
55        write_task: F,
56    ) -> F::Output {
57        event_writer_state.is_currently_writing_event.set(true);
58
59        let result = write_task.await;
60
61        event_writer_state.is_currently_writing_event.set(false);
62
63        // If the connection was shutting down, block writing suspend the task to allow `write_events_until_shutdown` to terminate.
64        if !event_writer_state.is_running.get() {
65            return core::future::pending().await;
66        };
67
68        result
69    }
70
71    /// Send an event with an empty name, keeping the connection alive.
72    pub async fn write_keepalive(&mut self) -> Result<(), W::Error> {
73        Self::do_write(self.event_writer_state, async {
74            self.writer.write_all(b":\n\n").await?;
75
76            self.writer.flush().await
77        })
78        .await
79    }
80
81    /// Send an event with a given name and data.
82    pub async fn write_event<T: EventData>(
83        &mut self,
84        event: &str,
85        data: T,
86    ) -> Result<(), W::Error> {
87        pub struct DataWriter<W: Write> {
88            writer: W,
89        }
90
91        impl<W: Write> crate::io::ErrorType for DataWriter<W> {
92            type Error = W::Error;
93        }
94
95        impl<W: Write> Write for DataWriter<W> {
96            async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
97                for line in buf.split_inclusive(|&b| b == b'\n') {
98                    self.writer.write_all(b"data:").await?;
99                    self.writer.write_all(line).await?;
100                }
101
102                self.writer.write_all(b"\n").await?;
103
104                Ok(buf.len())
105            }
106
107            async fn flush(&mut self) -> Result<(), Self::Error> {
108                self.writer.flush().await
109            }
110        }
111
112        Self::do_write(self.event_writer_state, async {
113            self.writer.write_all(b"event:").await?;
114            self.writer.write_all(event.as_bytes()).await?;
115            self.writer.write_all(b"\n").await?;
116
117            data.write_to(&mut DataWriter {
118                writer: &mut self.writer,
119            })
120            .await?;
121
122            self.writer.write_all(b"\n").await?;
123
124            self.writer.flush().await
125        })
126        .await
127    }
128}
129
/// Drives `write_events` until it completes, or until `shutdown_signal` has fired and no event
/// is part-way through being written.
///
/// When `shutdown_signal` completes, `is_running` is cleared. From then on:
/// - if no event is being written, this returns `Ok(())` without polling `write_events` again;
/// - otherwise `write_events` keeps being polled so the in-flight event can be finished. This
///   returns `Ok(())` once the writing flag is cleared, or `write_events`' own result if it
///   completes first.
///
/// `EventWriter::do_write` suspends forever instead of returning when it finishes a write after
/// `is_running` was cleared, so no new event is started once shutdown has begun.
///
/// NOTE(review): whether the shutdown task's flag change is seen within the same poll depends on
/// the polling order of `crate::futures::select`, which is not visible here.
async fn write_events_until_shutdown<E, F: core::future::Future<Output = Result<(), E>>>(
    event_writer_state: &EventWriterState,
    shutdown_signal: impl core::future::Future<Output = ()> + Unpin,
    mut write_events: core::pin::Pin<&mut F>,
) -> Result<(), E> {
    // Never completes: after marking the stream as stopping it parks forever, leaving
    // `write_events_task` to decide when this function returns.
    let shutdown_task = async {
        shutdown_signal.await;
        event_writer_state.is_running.set(false);

        core::future::pending().await
    };

    let write_events_task = core::future::poll_fn(|cx| {
        use core::task::Poll;

        // Normal operation: just drive the event source.
        if event_writer_state.is_running.get() {
            return write_events.as_mut().poll(cx);
        }

        // Shutting down between events: stop without polling the event source again.
        if !event_writer_state.is_currently_writing_event.get() {
            return Poll::Ready(Ok(()));
        }

        // Shutting down mid-event: keep driving the event source so the event is completed.
        if let Poll::Ready(result) = write_events.as_mut().poll(cx) {
            return Poll::Ready(result);
        }

        // The poll above may have finished the in-flight event (after which `do_write` parks
        // forever), in which case the stream can end now.
        if !event_writer_state.is_currently_writing_event.get() {
            return Poll::Ready(Ok(()));
        }

        Poll::Pending
    });

    // Presumably `select` resolves with whichever future completes first. Since `shutdown_task`
    // never completes, the result comes from `write_events_task`.
    crate::futures::select(shutdown_task, write_events_task).await
}
166
167/// Implement this trait to generate events to send to the client.
168pub trait EventSource {
169    /// Produce a stream of events and write them to `writer`
170    async fn write_events<W: Write>(self, writer: EventWriter<W>) -> Result<(), W::Error>;
171}
172
173/// A stream of Events sent by the server. Return an instance of this from the handler function.
174pub struct EventStream<S: EventSource>(pub S);
175
176impl<S: EventSource> EventStream<S> {
177    /// Convert SSE stream into a [`Response`](super::Response) with a status code of "OK"
178    pub fn into_response(self) -> super::Response<impl super::HeadersIter, impl super::Body> {
179        super::Response {
180            status_code: StatusCode::OK,
181            headers: [
182                ("Cache-Control", "no-cache"),
183                ("Content-Type", "text/event-stream"),
184            ],
185            body: self,
186        }
187    }
188}
189
190impl<S: EventSource> super::Body for EventStream<S> {
191    async fn write_response_body<R: Read, W: Write<Error = R::Error>>(
192        self,
193        connection: super::Connection<'_, R>,
194        mut writer: W,
195    ) -> Result<(), W::Error> {
196        writer.flush().await?;
197
198        let shutdown_signal = connection.shutdown_signal.clone();
199
200        let event_writer_state = &EventWriterState::new();
201
202        let write_events = core::pin::pin!(connection.run_until_disconnection(
203            (),
204            self.0.write_events(EventWriter {
205                writer,
206                event_writer_state
207            })
208        ));
209
210        write_events_until_shutdown(event_writer_state, shutdown_signal, write_events).await
211    }
212}
213
214impl<S: EventSource> super::IntoResponse for EventStream<S> {
215    async fn write_to<R: Read, W: super::ResponseWriter<Error = R::Error>>(
216        self,
217        connection: super::Connection<'_, R>,
218        response_writer: W,
219    ) -> Result<crate::ResponseSent, W::Error> {
220        response_writer
221            .write_response(connection, self.into_response())
222            .await
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Event source that writes the same `event`/`data` pair `write_count` times.
    #[derive(Clone)]
    struct TestEventSource {
        event: &'static str,
        data: &'static str,
        write_count: usize,
    }

    impl TestEventSource {
        /// Returns this source with `write_count` replaced.
        fn with_write_count(mut self, write_count: usize) -> Self {
            self.write_count = write_count;
            self
        }
    }

    impl EventSource for TestEventSource {
        async fn write_events<W: Write>(
            self,
            mut writer: EventWriter<'_, W>,
        ) -> Result<(), W::Error> {
            for _ in 0..self.write_count {
                writer.write_event(self.event, self.data).await?;
            }

            Ok(())
        }
    }

    /// Writer that discards its input and counts the total number of bytes written.
    struct CountWriteSize(usize);

    impl crate::io::ErrorType for CountWriteSize {
        type Error = core::convert::Infallible;
    }

    impl Write for CountWriteSize {
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            let write_size = buf.len();

            self.0 += write_size;

            Ok(write_size)
        }

        async fn flush(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Writer that accepts at most one byte per `write` call and yields to the runtime each time.
    /// A single event therefore takes many polls to write. `write_size` counts the bytes accepted.
    struct ThrottledWriter {
        write_size: usize,
    }

    impl crate::io::ErrorType for ThrottledWriter {
        type Error = core::convert::Infallible;
    }

    impl Write for ThrottledWriter {
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            if buf.is_empty() {
                Ok(0)
            } else {
                self.write_size += 1;

                tokio::task::yield_now().await;

                Ok(1)
            }
        }

        async fn flush(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// A shutdown requested while an event is being written must let that event finish and must
    /// not start any further events. Exactly one event's worth of bytes should reach the writer.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn wait_event_to_finish_writing() {
        use futures_util::FutureExt;

        let (shutdown_signal_tx, shutdown_signal_rx) = tokio::sync::oneshot::channel::<()>();

        let event_writer_state = &EventWriterState::new();

        let source = TestEventSource {
            event: "test",
            data: "test",
            write_count: 1,
        };

        // Reference size: the number of bytes a single complete event takes.
        let write_size = {
            let mut count_write_size = CountWriteSize(0);

            let _ = source
                .clone()
                .write_events(EventWriter {
                    writer: &mut count_write_size,
                    event_writer_state,
                })
                .await;

            count_write_size.0
        };

        // After a complete write the shared state must be back to "running, not writing".
        assert!(!event_writer_state.is_currently_writing_event.get());
        assert!(event_writer_state.is_running.get());

        let mut throttle_writer = ThrottledWriter { write_size: 0 };

        // Would write three events if allowed to run to completion.
        let write_events = async {
            source
                .with_write_count(3)
                .write_events(EventWriter {
                    writer: &mut throttle_writer,
                    event_writer_state,
                })
                .await
        };

        {
            let task_shutdown_signal = core::pin::pin!(async {
                let _ = shutdown_signal_rx.await;
            });

            let task_write_events = core::pin::pin!(write_events);

            let mut task = core::pin::pin!(write_events_until_shutdown(
                event_writer_state,
                task_shutdown_signal,
                task_write_events,
            ));

            // A few polls are far too few to push a whole event through `ThrottledWriter`
            // (one byte per write), so the first event is still in progress afterwards.
            for _ in 0..3 {
                assert_eq!(task.as_mut().now_or_never(), None);
            }

            // Request shutdown mid-event, then let the task run to completion.
            let _ = shutdown_signal_tx.send(());

            let _ = task.await;
        }

        // Exactly the first event was finished; none of the remaining two were started.
        assert_eq!(throttle_writer.write_size, write_size);
    }
}