dioxus_fullstack/payloads/
sse.rs1use crate::{ClientResponse, FromResponse, RequestError, ServerFnError};
2#[cfg(feature = "server")]
3use axum::{
4 response::sse::{Event, KeepAlive},
5 BoxError,
6};
7use futures::io::AsyncBufReadExt;
8use futures::Stream;
9use futures::{StreamExt, TryStreamExt};
10use http::{header::CONTENT_TYPE, HeaderValue, StatusCode};
11use serde::de::DeserializeOwned;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16#[allow(clippy::type_complexity)]
21pub struct ServerEvents<T> {
22 _marker: std::marker::PhantomData<fn() -> T>,
23
24 client: Option<Pin<Box<dyn Stream<Item = Result<ServerSentEvent, ServerFnError>>>>>,
26
27 #[cfg(feature = "server")]
28 keep_alive: Option<KeepAlive>,
29
30 #[cfg(feature = "server")]
32 sse: Option<axum::response::Sse<Pin<Box<dyn Stream<Item = Result<Event, BoxError>> + Send>>>>,
33}
34
35impl<T: DeserializeOwned> ServerEvents<T> {
36 pub async fn recv(&mut self) -> Option<Result<T, ServerFnError>> {
40 let event = self.next_event().await?;
41 match event {
42 Ok(event) => {
43 let data: Result<T, ServerFnError> =
44 serde_json::from_str(&event.data).map_err(|err| {
45 ServerFnError::Serialization(format!(
46 "failed to deserialize event data: {}",
47 err
48 ))
49 });
50 Some(data)
51 }
52 Err(err) => Some(Err(err)),
53 }
54 }
55}
56
57impl<T> ServerEvents<T> {
58 pub async fn next_event(&mut self) -> Option<Result<ServerSentEvent, ServerFnError>> {
62 self.client.as_mut()?.next().await
63 }
64}
65
66impl<T: DeserializeOwned> Stream for ServerEvents<T> {
67 type Item = Result<T, ServerFnError>;
68
69 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
70 let Some(client) = self.client.as_mut() else {
71 return Poll::Ready(None);
72 };
73
74 match client.as_mut().poll_next(cx) {
75 Poll::Ready(Some(Ok(event))) => {
76 let data = serde_json::from_str(&event.data).map_err(|err| {
77 ServerFnError::Serialization(format!(
78 "failed to deserialize event data: {}",
79 err
80 ))
81 });
82 Poll::Ready(Some(data))
83 }
84 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
85 Poll::Ready(None) => Poll::Ready(None),
86 Poll::Pending => Poll::Pending,
87 }
88 }
89}
90
91impl<T> FromResponse for ServerEvents<T> {
92 async fn from_response(res: ClientResponse) -> Result<Self, ServerFnError> {
93 let status = res.status();
94 if status != StatusCode::OK {
95 return Err(ServerFnError::Request(RequestError::Status(
96 format!("Expected status 200 OK, got {}", status),
97 status.as_u16(),
98 )));
99 }
100
101 let content_type = res.headers().get(CONTENT_TYPE);
102 if content_type != Some(&HeaderValue::from_static(mime::TEXT_EVENT_STREAM.as_ref())) {
103 return Err(ServerFnError::Request(RequestError::Request(format!(
104 "Expected content type 'text/event-stream', got {:?}",
105 content_type
106 ))));
107 }
108
109 let mut stream = res
110 .bytes_stream()
111 .map(|result| result.map_err(std::io::Error::other))
112 .into_async_read();
113
114 let mut line_buffer = String::new();
115 let mut event_buffer = EventBuffer::new();
116
117 let stream: Pin<Box<dyn Stream<Item = Result<ServerSentEvent, ServerFnError>>>> = Box::pin(
118 async_stream::try_stream! {
119 loop {
120 line_buffer.clear();
121 if stream.read_line(&mut line_buffer).await.map_err(|err| ServerFnError::StreamError(err.to_string()))? == 0 {
122 break;
123 }
124
125 let line = if let Some(line) = line_buffer.strip_suffix('\n') {
126 line
127 } else {
128 &line_buffer
129 };
130
131 if line.is_empty() {
133 if let Some(event) = event_buffer.produce_event() {
134 yield event;
135 }
136 continue;
137 }
138
139 let (field, value) = line.split_once(':').unwrap_or((line, ""));
141 let value = value.strip_prefix(' ').unwrap_or(value);
142
143 match field {
145 "event" => event_buffer.set_event_type(value),
146 "data" => event_buffer.push_data(value),
147 "id" => event_buffer.set_id(value),
148 "retry" => {
149 if let Ok(millis) = value.parse() {
150 event_buffer.set_retry(Duration::from_millis(millis));
151 }
152 }
153 _ => {}
154 }
155 }
156 },
157 );
158
159 Ok(Self {
160 _marker: std::marker::PhantomData,
161 client: Some(stream),
162
163 #[cfg(feature = "server")]
164 keep_alive: None,
165
166 #[cfg(feature = "server")]
167 sse: None,
168 })
169 }
170}
171
/// A single event decoded from a `text/event-stream` response body.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ServerSentEvent {
    /// Value of the event's `event:` field, or `"message"` when none was sent.
    pub event_type: String,

    /// The event's `data:` field values, joined with `\n`.
    pub data: String,

    /// Most recent `id:` value seen on the stream so far, if any; carried across events.
    pub last_event_id: Option<String>,

    /// Most recent `retry:` interval seen on the stream so far, if any; carried across events.
    pub retry: Option<Duration>,
}
187
188struct EventBuffer {
190 event_type: String,
191 data: String,
192 last_event_id: Option<String>,
193 retry: Option<Duration>,
194}
195
196impl EventBuffer {
197 #[allow(clippy::new_without_default)]
199 fn new() -> Self {
200 Self {
201 event_type: String::new(),
202 data: String::new(),
203 last_event_id: None,
204 retry: None,
205 }
206 }
207
208 fn produce_event(&mut self) -> Option<ServerSentEvent> {
212 let event = if self.data.is_empty() {
213 None
214 } else {
215 Some(ServerSentEvent {
216 event_type: if self.event_type.is_empty() {
217 "message".to_string()
218 } else {
219 self.event_type.clone()
220 },
221 data: self.data.to_string(),
222 last_event_id: self.last_event_id.clone(),
223 retry: self.retry,
224 })
225 };
226
227 self.event_type.clear();
228 self.data.clear();
229
230 event
231 }
232
233 fn set_event_type(&mut self, event_type: &str) {
235 self.event_type.clear();
236 self.event_type.push_str(event_type);
237 }
238
239 fn push_data(&mut self, data: &str) {
241 if !self.data.is_empty() {
242 self.data.push('\n');
243 }
244 self.data.push_str(data);
245 }
246
247 fn set_id(&mut self, id: &str) {
248 self.last_event_id = Some(id.to_string());
249 }
250
251 fn set_retry(&mut self, retry: Duration) {
252 self.retry = Some(retry);
253 }
254}
255
256#[cfg(feature = "server")]
257pub use server_impl::*;
258
#[cfg(feature = "server")]
mod server_impl {
    use super::*;
    use crate::spawn_platform;
    use axum::response::sse::Sse;
    use axum_core::response::IntoResponse;
    use futures::Future;
    use futures::SinkExt;
    use futures::{Sink, TryStream};
    use serde::Serialize;

    /// Interval of the keep-alive comments installed by the stream-based constructors.
    const DEFAULT_KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(15);

    /// Keep-alive configuration shared by [`ServerEvents::new`] and [`ServerEvents::from_stream`].
    fn default_keep_alive() -> KeepAlive {
        KeepAlive::new().interval(DEFAULT_KEEP_ALIVE_INTERVAL)
    }

    impl<T: 'static> ServerEvents<T> {
        /// Creates a server-side event stream fed through an [`SseTx`] channel.
        ///
        /// `f` receives the sending half and is handed to `spawn_platform`. Every event sent
        /// through the channel is forwarded to the client. The stream ends when the channel
        /// closes, i.e. once all senders are dropped. Keep-alive comments are sent every
        /// 15 seconds unless changed with [`ServerEvents::with_keep_alive`].
        pub fn new<F, R>(f: impl FnOnce(SseTx<T>) -> F + Send + 'static) -> Self
        where
            F: Future<Output = R> + 'static,
            R: 'static + Send,
        {
            let (sender, receiver) = futures_channel::mpsc::unbounded();

            let tx = SseTx {
                sender,
                _marker: std::marker::PhantomData,
            };

            spawn_platform(move || f(tx));

            // Channel items are already-encoded events; the channel itself never yields errors.
            let stream = receiver.map(Ok::<Event, BoxError>);

            Self {
                _marker: std::marker::PhantomData,
                client: None,
                keep_alive: Some(default_keep_alive()),
                sse: Some(Sse::new(stream.boxed())),
            }
        }

        /// Creates a server-side event stream from a fallible stream of values.
        ///
        /// Each `Ok` value is serialized to JSON and sent as an event's `data`. Errors from the
        /// source stream, and failures to serialize a value, are handed to axum as stream
        /// errors rather than panicking. Keep-alive comments are sent every 15 seconds unless
        /// changed with [`ServerEvents::with_keep_alive`].
        pub fn from_stream<S>(stream: S) -> Self
        where
            S: TryStream<Ok = T, Error = BoxError> + Send + 'static,
            T: Serialize,
        {
            let events = stream.into_stream().map(|item| {
                item.and_then(|value| {
                    axum::response::sse::Event::default()
                        .json_data(value)
                        .map_err(BoxError::from)
                })
            });

            Self {
                _marker: std::marker::PhantomData,
                client: None,
                keep_alive: Some(default_keep_alive()),
                sse: Some(axum::response::Sse::new(events.boxed())),
            }
        }

        /// Replaces the keep-alive used when this value is turned into a response.
        ///
        /// `None` sends the SSE response without adding a keep-alive.
        pub fn with_keep_alive(mut self, keep_alive: Option<KeepAlive>) -> Self {
            self.keep_alive = keep_alive;
            self
        }

        /// Wraps an already-built axum [`Sse`] response.
        ///
        /// No keep-alive is added; any keep-alive configured on `sse` itself is left as-is.
        // The full boxed stream type is spelled out to match the `sse` field exactly.
        #[allow(clippy::type_complexity)]
        pub fn from_sse(
            sse: Sse<Pin<Box<dyn Stream<Item = Result<Event, BoxError>> + Send>>>,
        ) -> Self {
            Self {
                _marker: std::marker::PhantomData,
                client: None,
                keep_alive: None,
                sse: Some(sse),
            }
        }
    }

    impl<T> IntoResponse for ServerEvents<T> {
        /// Converts the server-side SSE into a response, applying the keep-alive if one is set.
        ///
        /// # Panics
        ///
        /// Panics if this value holds no SSE, e.g. one obtained from a client response.
        fn into_response(self) -> axum_core::response::Response {
            let sse = self
                .sse
                .expect("SSE should be initialized before using it as a response");

            match self.keep_alive {
                Some(keep_alive) => sse.keep_alive(keep_alive).into_response(),
                None => sse.into_response(),
            }
        }
    }

    /// Sending half passed to the closure of [`ServerEvents::new`].
    ///
    /// Values are JSON-encoded into an event's `data` before being queued on an unbounded
    /// channel.
    pub struct SseTx<T> {
        sender: futures_channel::mpsc::UnboundedSender<axum::response::sse::Event>,
        _marker: std::marker::PhantomData<fn() -> T>,
    }

    impl<T: Serialize> SseTx<T> {
        /// Serializes `event` to JSON and queues it for the client.
        ///
        /// # Errors
        ///
        /// Fails if `event` cannot be serialized or if the receiving side has been dropped.
        pub async fn send(&mut self, event: T) -> anyhow::Result<()> {
            let event = axum::response::sse::Event::default().json_data(event)?;
            self.sender.unbounded_send(event)?;
            Ok(())
        }
    }

    /// Exposes the underlying channel sender, e.g. to send pre-built `Event`s directly.
    impl<T> std::ops::Deref for SseTx<T> {
        type Target = futures_channel::mpsc::UnboundedSender<axum::response::sse::Event>;
        fn deref(&self) -> &Self::Target {
            &self.sender
        }
    }

    impl<T> std::ops::DerefMut for SseTx<T> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.sender
        }
    }

    /// `Sink` adapter: each item is JSON-encoded into an event, then delegated to the channel.
    impl<T: Serialize> Sink<T> for SseTx<T> {
        type Error = anyhow::Error;

        fn poll_ready(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            self.sender.poll_ready_unpin(cx).map_err(Into::into)
        }

        fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
            let event = axum::response::sse::Event::default().json_data(item)?;
            self.sender.start_send(event).map_err(Into::into)
        }

        fn poll_flush(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            self.sender.poll_flush_unpin(cx).map_err(Into::into)
        }

        fn poll_close(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            self.sender.poll_close_unpin(cx).map_err(Into::into)
        }
    }
}