dioxus_fullstack/payloads/
sse.rs

1use crate::{ClientResponse, FromResponse, RequestError, ServerFnError};
2#[cfg(feature = "server")]
3use axum::{
4    response::sse::{Event, KeepAlive},
5    BoxError,
6};
7use futures::io::AsyncBufReadExt;
8use futures::Stream;
9use futures::{StreamExt, TryStreamExt};
10use http::{header::CONTENT_TYPE, HeaderValue, StatusCode};
11use serde::de::DeserializeOwned;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16/// A stream of Server-Sent Events (SSE) that can be used to receive events from the server.
17///
18/// This type implements `Stream` for asynchronous iteration over events.
19/// Events are automatically deserialized from JSON to the specified type `T`.
20#[allow(clippy::type_complexity)]
21pub struct ServerEvents<T> {
22    _marker: std::marker::PhantomData<fn() -> T>,
23
24    // The receiving end from the server
25    client: Option<Pin<Box<dyn Stream<Item = Result<ServerSentEvent, ServerFnError>>>>>,
26
27    #[cfg(feature = "server")]
28    keep_alive: Option<KeepAlive>,
29
30    // The actual SSE response to send to the client
31    #[cfg(feature = "server")]
32    sse: Option<axum::response::Sse<Pin<Box<dyn Stream<Item = Result<Event, BoxError>> + Send>>>>,
33}
34
35impl<T: DeserializeOwned> ServerEvents<T> {
36    /// Receives the next event from the stream, deserializing it to `T`.
37    ///
38    /// Returns `None` if the stream has ended.
39    pub async fn recv(&mut self) -> Option<Result<T, ServerFnError>> {
40        let event = self.next_event().await?;
41        match event {
42            Ok(event) => {
43                let data: Result<T, ServerFnError> =
44                    serde_json::from_str(&event.data).map_err(|err| {
45                        ServerFnError::Serialization(format!(
46                            "failed to deserialize event data: {}",
47                            err
48                        ))
49                    });
50                Some(data)
51            }
52            Err(err) => Some(Err(err)),
53        }
54    }
55}
56
57impl<T> ServerEvents<T> {
58    /// Receives the next raw event from the stream.
59    ///
60    /// Returns `None` if the stream has ended.
61    pub async fn next_event(&mut self) -> Option<Result<ServerSentEvent, ServerFnError>> {
62        self.client.as_mut()?.next().await
63    }
64}
65
66impl<T: DeserializeOwned> Stream for ServerEvents<T> {
67    type Item = Result<T, ServerFnError>;
68
69    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
70        let Some(client) = self.client.as_mut() else {
71            return Poll::Ready(None);
72        };
73
74        match client.as_mut().poll_next(cx) {
75            Poll::Ready(Some(Ok(event))) => {
76                let data = serde_json::from_str(&event.data).map_err(|err| {
77                    ServerFnError::Serialization(format!(
78                        "failed to deserialize event data: {}",
79                        err
80                    ))
81                });
82                Poll::Ready(Some(data))
83            }
84            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
85            Poll::Ready(None) => Poll::Ready(None),
86            Poll::Pending => Poll::Pending,
87        }
88    }
89}
90
91impl<T> FromResponse for ServerEvents<T> {
92    async fn from_response(res: ClientResponse) -> Result<Self, ServerFnError> {
93        let status = res.status();
94        if status != StatusCode::OK {
95            return Err(ServerFnError::Request(RequestError::Status(
96                format!("Expected status 200 OK, got {}", status),
97                status.as_u16(),
98            )));
99        }
100
101        let content_type = res.headers().get(CONTENT_TYPE);
102        if content_type != Some(&HeaderValue::from_static(mime::TEXT_EVENT_STREAM.as_ref())) {
103            return Err(ServerFnError::Request(RequestError::Request(format!(
104                "Expected content type 'text/event-stream', got {:?}",
105                content_type
106            ))));
107        }
108
109        let mut stream = res
110            .bytes_stream()
111            .map(|result| result.map_err(std::io::Error::other))
112            .into_async_read();
113
114        let mut line_buffer = String::new();
115        let mut event_buffer = EventBuffer::new();
116
117        let stream: Pin<Box<dyn Stream<Item = Result<ServerSentEvent, ServerFnError>>>> = Box::pin(
118            async_stream::try_stream! {
119                loop {
120                    line_buffer.clear();
121                    if stream.read_line(&mut line_buffer).await.map_err(|err| ServerFnError::StreamError(err.to_string()))? == 0 {
122                        break;
123                    }
124
125                    let line = if let Some(line) = line_buffer.strip_suffix('\n') {
126                        line
127                    } else {
128                        &line_buffer
129                    };
130
131                    // dispatch
132                    if line.is_empty() {
133                        if let Some(event) = event_buffer.produce_event() {
134                            yield event;
135                        }
136                        continue;
137                    }
138
139                    // Parse line to split field name and value, applying proper trimming.
140                    let (field, value) = line.split_once(':').unwrap_or((line, ""));
141                    let value = value.strip_prefix(' ').unwrap_or(value);
142
143                    // Handle fields - these are the in SSE speci.
144                    match field {
145                        "event" => event_buffer.set_event_type(value),
146                        "data" => event_buffer.push_data(value),
147                        "id" => event_buffer.set_id(value),
148                        "retry" => {
149                            if let Ok(millis) = value.parse() {
150                                event_buffer.set_retry(Duration::from_millis(millis));
151                            }
152                        }
153                        _ => {}
154                    }
155                }
156            },
157        );
158
159        Ok(Self {
160            _marker: std::marker::PhantomData,
161            client: Some(stream),
162
163            #[cfg(feature = "server")]
164            keep_alive: None,
165
166            #[cfg(feature = "server")]
167            sse: None,
168        })
169    }
170}
171
/// Server-Sent Event representation.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ServerSentEvent {
    /// A string identifying the type of event described (`"message"` when the server sent
    /// no `event` field).
    pub event_type: String,

    /// The data field for the message; multiple `data` lines are joined with `\n`.
    pub data: String,

    /// Last event ID value.
    pub last_event_id: Option<String>,

    /// Reconnection time.
    pub retry: Option<Duration>,
}

/// Internal buffer used to accumulate the fields of an SSE (Server-Sent Events) stream until
/// a blank line dispatches them as a [`ServerSentEvent`].
#[derive(Default)]
struct EventBuffer {
    /// Value of the most recent `event` field of the pending event (empty if none).
    event_type: String,
    /// Every `data` value of the pending event, each followed by a line feed.
    data: String,
    /// Most recent `id` value; kept across dispatched events.
    last_event_id: Option<String>,
    /// Most recent valid `retry` value; kept across dispatched events.
    retry: Option<Duration>,
}

impl EventBuffer {
    /// Creates a fresh, empty [`EventBuffer`].
    fn new() -> Self {
        Self::default()
    }

    /// Produces a [`ServerSentEvent`] if at least one `data` field was received.
    ///
    /// The event type and data are reset either way, so the buffer is ready for the next
    /// event; the last event ID and retry value are retained.
    fn produce_event(&mut self) -> Option<ServerSentEvent> {
        let event_type = std::mem::take(&mut self.event_type);
        let mut data = std::mem::take(&mut self.data);

        if data.is_empty() {
            return None;
        }
        // Each `data` field appended a trailing LF; only the separators between lines remain.
        if data.ends_with('\n') {
            data.pop();
        }

        Some(ServerSentEvent {
            event_type: if event_type.is_empty() {
                "message".to_owned()
            } else {
                event_type
            },
            data,
            last_event_id: self.last_event_id.clone(),
            retry: self.retry,
        })
    }

    /// Sets the pending event's type, replacing any previous value.
    fn set_event_type(&mut self, event_type: &str) {
        event_type.clone_into(&mut self.event_type);
    }

    /// Appends one `data` value. Empty values are meaningful: `data:` on its own still
    /// produces an event, and contributes an empty line to multi-line data.
    fn push_data(&mut self, data: &str) {
        self.data.push_str(data);
        self.data.push('\n');
    }

    /// Records the last event ID.
    fn set_id(&mut self, id: &str) {
        self.last_event_id = Some(id.to_owned());
    }

    /// Records the reconnection time.
    fn set_retry(&mut self, retry: Duration) {
        self.retry = Some(retry);
    }
}
255
256#[cfg(feature = "server")]
257pub use server_impl::*;
258
#[cfg(feature = "server")]
mod server_impl {
    use super::*;
    use crate::spawn_platform;
    use axum::response::sse::Sse;
    use axum_core::response::IntoResponse;
    use futures::Future;
    use futures::SinkExt;
    use futures::{Sink, TryStream};
    use serde::Serialize;

    /// Interval of the keep-alive comment that server-created streams send by default.
    const DEFAULT_KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(15);

    /// Builds the keep-alive configuration applied by default to server-created streams.
    fn default_keep_alive() -> KeepAlive {
        KeepAlive::new().interval(DEFAULT_KEEP_ALIVE_INTERVAL)
    }

    impl<T: 'static> ServerEvents<T> {
        /// Create a `ServerEvents` from a function that is given a sender to send events to
        /// the client.
        ///
        /// `f` receives an [`SseTx`] and is spawned in the background with `spawn_platform`.
        /// Every event sent through the transmitter is forwarded to the client. The response
        /// stream ends once all senders are dropped, e.g. when the user function ends.
        ///
        /// By default, a comment is sent every 15 seconds to keep the connection alive; see
        /// [`ServerEvents::with_keep_alive`].
        pub fn new<F, R>(f: impl FnOnce(SseTx<T>) -> F + Send + 'static) -> Self
        where
            F: Future<Output = R> + 'static,
            R: 'static + Send,
        {
            let (sender, receiver) = futures_channel::mpsc::unbounded();

            let tx = SseTx {
                sender,
                _marker: std::marker::PhantomData,
            };

            // Spawn the user function in the background.
            spawn_platform(move || f(tx));

            // Receiving from the channel cannot fail, so every event is wrapped in `Ok`.
            let stream = receiver.map(Ok::<Event, BoxError>);

            Self {
                _marker: std::marker::PhantomData,
                client: None,
                keep_alive: Some(default_keep_alive()),
                sse: Some(Sse::new(stream.boxed())),
            }
        }

        /// Create a `ServerEvents` from a `TryStream` of events.
        ///
        /// Each `Ok` value is serialized to JSON and sent as the `data` of one event.
        /// Errors from `stream` are forwarded to the [`Sse`] response body. Values that fail
        /// to serialize are forwarded the same way, rather than panicking.
        ///
        /// A 15-second keep-alive is configured by default.
        pub fn from_stream<S>(stream: S) -> Self
        where
            S: TryStream<Ok = T, Error = BoxError> + Send + 'static,
            T: Serialize,
        {
            // A serialization failure (e.g. a map with non-string keys) must not panic the
            // task that drives the response body; surface it as a stream error instead.
            let stream = stream.and_then(|value| {
                futures::future::ready(Event::default().json_data(value).map_err(BoxError::from))
            });

            Self {
                _marker: std::marker::PhantomData,
                client: None,
                keep_alive: Some(default_keep_alive()),
                sse: Some(Sse::new(stream.boxed())),
            }
        }

        /// Set the keep-alive configuration for the SSE connection.
        ///
        /// A `None` value will disable the default `KeepAlive` of 15 seconds.
        pub fn with_keep_alive(mut self, keep_alive: Option<KeepAlive>) -> Self {
            self.keep_alive = keep_alive;
            self
        }

        /// Create a `ServerEvents` from an existing Axum `Sse` response.
        ///
        /// No keep-alive is added; configure it on `sse` or with
        /// [`ServerEvents::with_keep_alive`].
        // The full boxed stream type is part of the public signature; spelling it out keeps
        // the documented type explicit.
        #[allow(clippy::type_complexity)]
        pub fn from_sse(
            sse: Sse<Pin<Box<dyn Stream<Item = Result<Event, BoxError>> + Send>>>,
        ) -> Self {
            Self {
                _marker: std::marker::PhantomData,
                client: None,
                keep_alive: None,
                sse: Some(sse),
            }
        }
    }

    impl<T> IntoResponse for ServerEvents<T> {
        /// Converts the server-side SSE stream into an HTTP response, applying the configured
        /// keep-alive, if any.
        ///
        /// # Panics
        ///
        /// Panics if this value holds no SSE response, i.e. it was not built by one of the
        /// server-side constructors (for example, it came from [`FromResponse`]).
        fn into_response(self) -> axum_core::response::Response {
            let sse = self
                .sse
                .expect("SSE should be initialized before using it as a response");

            match self.keep_alive {
                Some(keep_alive) => sse.keep_alive(keep_alive).into_response(),
                None => sse.into_response(),
            }
        }
    }

    /// A transmitter for sending events to the SSE stream.
    ///
    /// Values of type `T` are serialized to JSON before being queued. The underlying
    /// unbounded sender is also reachable through `Deref`/`DerefMut`, for sending pre-built
    /// [`Event`]s.
    pub struct SseTx<T> {
        sender: futures_channel::mpsc::UnboundedSender<Event>,
        _marker: std::marker::PhantomData<fn() -> T>,
    }

    impl<T: Serialize> SseTx<T> {
        /// Serializes `event` to JSON and queues it for delivery to the client.
        ///
        /// # Errors
        ///
        /// Fails if `event` cannot be serialized, or if the channel's receiving side has been
        /// dropped.
        pub async fn send(&mut self, event: T) -> anyhow::Result<()> {
            let event = Event::default().json_data(event)?;
            self.sender.unbounded_send(event)?;
            Ok(())
        }
    }

    impl<T> std::ops::Deref for SseTx<T> {
        type Target = futures_channel::mpsc::UnboundedSender<Event>;

        fn deref(&self) -> &Self::Target {
            &self.sender
        }
    }

    impl<T> std::ops::DerefMut for SseTx<T> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.sender
        }
    }

    /// Sink adapter: each item is serialized to JSON and forwarded to the underlying sender.
    impl<T: Serialize> Sink<T> for SseTx<T> {
        type Error = anyhow::Error;

        fn poll_ready(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            self.sender.poll_ready_unpin(cx).map_err(anyhow::Error::from)
        }

        fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
            let event = Event::default().json_data(item)?;
            self.sender.start_send(event).map_err(anyhow::Error::from)
        }

        fn poll_flush(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            self.sender.poll_flush_unpin(cx).map_err(anyhow::Error::from)
        }

        fn poll_close(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            self.sender.poll_close_unpin(cx).map_err(anyhow::Error::from)
        }
    }
}