Skip to main content

ranvier_http/
sse.rs

1use crate::response::{HttpResponse, IntoResponse};
2use bytes::Bytes;
3use futures_util::stream::Stream;
4
5use ranvier_core::event::EventSource;
6use serde::de::{self, Deserializer};
7use serde::ser::Serializer;
8use serde::{Deserialize, Serialize};
9use std::convert::Infallible;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
/// A single Server-Sent Event, assembled with the builder methods on `SseEvent`.
///
/// Every field is optional; unset fields are simply omitted from the wire output.
/// `Default` yields an event with no fields set (identical to `SseEvent::default()`).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SseEvent {
    /// Event payload, written as one `data:` line per line of text.
    pub(crate) data: Option<String>,
    /// Value of the `id:` field.
    pub(crate) id: Option<String>,
    /// Value of the `event:` field (the event type name).
    pub(crate) event: Option<String>,
    /// Value of the `retry:` field, sent as whole milliseconds.
    pub(crate) retry: Option<Duration>,
    /// Comment text, written as `:`-prefixed lines.
    pub(crate) comment: Option<String>,
}
22
23impl SseEvent {
24    pub fn default() -> Self {
25        Self {
26            data: None,
27            id: None,
28            event: None,
29            retry: None,
30            comment: None,
31        }
32    }
33
34    pub fn data(mut self, data: impl Into<String>) -> Self {
35        self.data = Some(data.into());
36        self
37    }
38
39    pub fn id(mut self, id: impl Into<String>) -> Self {
40        self.id = Some(id.into());
41        self
42    }
43
44    pub fn event(mut self, event: impl Into<String>) -> Self {
45        self.event = Some(event.into());
46        self
47    }
48
49    pub fn retry(mut self, duration: Duration) -> Self {
50        self.retry = Some(duration);
51        self
52    }
53
54    pub fn comment(mut self, comment: impl Into<String>) -> Self {
55        self.comment = Some(comment.into());
56        self
57    }
58
59    fn serialize(&self) -> String {
60        let mut out = String::new();
61        if let Some(comment) = &self.comment {
62            for line in comment.lines() {
63                out.push_str(&format!(": {}\n", line));
64            }
65        }
66        if let Some(event) = &self.event {
67            out.push_str(&format!("event: {}\n", event));
68        }
69        if let Some(id) = &self.id {
70            out.push_str(&format!("id: {}\n", id));
71        }
72        if let Some(retry) = &self.retry {
73            out.push_str(&format!("retry: {}\n", retry.as_millis()));
74        }
75        if let Some(data) = &self.data {
76            for line in data.lines() {
77                out.push_str(&format!("data: {}\n", line));
78            }
79        }
80        out.push('\n');
81        out
82    }
83}
84
/// Server-Sent Events responder wrapping a stream of `Result<SseEvent, E>` items.
///
/// Converting it with `IntoResponse` streams each event to the client as a
/// `text/event-stream` body.
pub struct Sse<S> {
    /// Event source; moved into the response body on conversion.
    stream: S,
}
88
89// Stub Serialize/Deserialize implementations for Sse.
90// Sse wraps a live stream so real (de)serialization is not meaningful;
91// these impls exist to satisfy the Axon<In, Out, E> type-parameter bounds
92// that require Serialize + DeserializeOwned on every intermediate type.
93impl<S> Serialize for Sse<S> {
94    fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
95        serializer.serialize_str("<<Sse stream>>")
96    }
97}
98
99impl<'de, S> Deserialize<'de> for Sse<S> {
100    fn deserialize<D: Deserializer<'de>>(_deserializer: D) -> Result<Self, D::Error> {
101        Err(de::Error::custom(
102            "Sse<S> cannot be deserialized; it wraps a live stream",
103        ))
104    }
105}
106
107impl<S, E> Sse<S>
108where
109    S: Stream<Item = Result<SseEvent, E>> + Send + 'static,
110    E: Into<Box<dyn std::error::Error + Send + Sync>>,
111{
112    pub fn new(stream: S) -> Self {
113        Self { stream }
114    }
115}
116
/// Stream adapter that turns each `SseEvent` from `inner` into an HTTP body data frame.
///
/// `E` only appears in the item type, so it is recorded with a
/// `PhantomData<fn() -> E>`. Function pointers are always `Send + Sync + Unpin`,
/// so this marker never affects the adapter's auto traits.
pub struct FrameStream<S, E> {
    inner: S,
    _marker: std::marker::PhantomData<fn() -> E>,
}
121
122impl<S, E> Stream for FrameStream<S, E>
123where
124    S: Stream<Item = Result<SseEvent, E>> + Unpin,
125    E: Into<Box<dyn std::error::Error + Send + Sync>>,
126{
127    type Item = Result<http_body::Frame<Bytes>, E>;
128
129    fn poll_next(
130        mut self: Pin<&mut Self>,
131        cx: &mut Context<'_>,
132    ) -> Poll<Option<<Self as Stream>::Item>> {
133        match Pin::new(&mut self.inner).poll_next(cx) {
134            Poll::Ready(Some(Ok(event))) => {
135                let serialized = event.serialize();
136                let frame = http_body::Frame::data(Bytes::from(serialized));
137                Poll::Ready(Some(Ok(frame)))
138            }
139            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
140            Poll::Ready(None) => Poll::Ready(None),
141            Poll::Pending => Poll::Pending,
142        }
143    }
144}
145
146impl<S, E> IntoResponse for Sse<S>
147where
148    S: Stream<Item = Result<SseEvent, E>> + Send + Sync + Unpin + 'static,
149    E: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
150{
151    fn into_response(self) -> HttpResponse {
152        let frame_stream = FrameStream {
153            inner: self.stream,
154            _marker: std::marker::PhantomData,
155        };
156
157        let mut frame_stream = Box::pin(frame_stream);
158        let infallible_stream = async_stream::stream! {
159            while let Some(res) = futures_util::StreamExt::next(&mut frame_stream).await {
160                match res {
161                    Ok(frame) => yield Ok::<_, std::convert::Infallible>(frame),
162                    Err(e) => {
163                        let err: Box<dyn std::error::Error + Send + Sync> = e.into();
164                        tracing::error!("SSE stream terminated with error: {:?}", err);
165                        break;
166                    }
167                }
168            }
169        };
170
171        let body = http_body_util::StreamBody::new(infallible_stream);
172
173        http::Response::builder()
174            .status(http::StatusCode::OK)
175            .header(http::header::CONTENT_TYPE, "text/event-stream")
176            .header(http::header::CACHE_CONTROL, "no-cache")
177            .header(http::header::CONNECTION, "keep-alive")
178            .body(http_body_util::BodyExt::boxed(body))
179            .expect("Valid builder")
180    }
181}
182
183pub fn from_event_source<E, S, F>(
184    mut source: S,
185    mut mapper: F,
186) -> impl Stream<Item = Result<SseEvent, Infallible>> + Send + Sync
187where
188    S: EventSource<E> + Send + 'static,
189    E: Send + 'static,
190    F: FnMut(E) -> SseEvent + Send + 'static,
191{
192    let (tx, mut rx) = tokio::sync::mpsc::channel(16);
193    tokio::spawn(async move {
194        while let Some(event) = source.next_event().await {
195            if tx.send(mapper(event)).await.is_err() {
196                break;
197            }
198        }
199    });
200
201    let stream = async_stream::stream! {
202        while let Some(event) = rx.recv().await {
203            yield Ok(event);
204        }
205    };
206    Box::pin(stream)
207}